Skip to content

Commit 29bffa4

Browse files
authored
Merge pull request #2365 from dgageot/debug-auth
Debug oauth
2 parents 8af0911 + 46731c1 commit 29bffa4

8 files changed

Lines changed: 840 additions & 8 deletions

File tree

cmd/root/debug.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ func newDebugCmd() *cobra.Command {
5555
addRuntimeConfigFlags(cmd, &flags.runConfig)
5656

5757
cmd.AddCommand(newDebugAuthCmd())
58+
cmd.AddCommand(newDebugOAuthCmd())
5859

5960
return cmd
6061
}

cmd/root/debug_oauth.go

Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
package root
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"io"
7+
"time"
8+
9+
"github.com/spf13/cobra"
10+
11+
"github.com/docker/docker-agent/pkg/config"
12+
"github.com/docker/docker-agent/pkg/config/latest"
13+
"github.com/docker/docker-agent/pkg/telemetry"
14+
"github.com/docker/docker-agent/pkg/tools/mcp"
15+
)
16+
17+
func newDebugOAuthCmd() *cobra.Command {
18+
cmd := &cobra.Command{
19+
Use: "oauth",
20+
Short: "OAuth token management",
21+
}
22+
23+
cmd.AddCommand(newDebugOAuthListCmd())
24+
cmd.AddCommand(newDebugOAuthRemoveCmd())
25+
cmd.AddCommand(newDebugOAuthLoginCmd())
26+
27+
return cmd
28+
}
29+
30+
func newDebugOAuthListCmd() *cobra.Command {
31+
var jsonOutput bool
32+
33+
cmd := &cobra.Command{
34+
Use: "list",
35+
Short: "List all stored OAuth tokens",
36+
Args: cobra.NoArgs,
37+
RunE: func(cmd *cobra.Command, _ []string) (commandErr error) {
38+
ctx := cmd.Context()
39+
telemetry.TrackCommand(ctx, "debug", []string{"oauth", "list"})
40+
defer func() {
41+
telemetry.TrackCommandError(ctx, "debug", []string{"oauth", "list"}, commandErr)
42+
}()
43+
44+
w := cmd.OutOrStdout()
45+
46+
entries, err := mcp.ListOAuthTokens()
47+
if err != nil {
48+
return fmt.Errorf("failed to list OAuth tokens: %w", err)
49+
}
50+
51+
if len(entries) == 0 {
52+
if jsonOutput {
53+
return json.NewEncoder(w).Encode([]any{})
54+
}
55+
fmt.Fprintln(w, "No OAuth tokens stored.")
56+
return nil
57+
}
58+
59+
if jsonOutput {
60+
return printOAuthListJSON(w, entries)
61+
}
62+
63+
printOAuthListText(w, entries)
64+
return nil
65+
},
66+
}
67+
68+
cmd.Flags().BoolVar(&jsonOutput, "json", false, "Output in JSON format")
69+
70+
return cmd
71+
}
72+
73+
type oauthListEntry struct {
74+
ResourceURL string `json:"resource_url"`
75+
TokenType string `json:"token_type,omitempty"`
76+
Scope string `json:"scope,omitempty"`
77+
ExpiresAt time.Time `json:"expires_at,omitzero"`
78+
Expired bool `json:"expired"`
79+
AccessToken string `json:"access_token"`
80+
RefreshToken bool `json:"has_refresh_token"`
81+
}
82+
83+
func printOAuthListJSON(w io.Writer, entries []mcp.OAuthTokenEntry) error {
84+
var out []oauthListEntry
85+
for _, e := range entries {
86+
out = append(out, oauthListEntry{
87+
ResourceURL: e.ResourceURL,
88+
TokenType: e.Token.TokenType,
89+
Scope: e.Token.Scope,
90+
ExpiresAt: e.Token.ExpiresAt,
91+
Expired: e.Token.IsExpired(),
92+
AccessToken: truncateToken(e.Token.AccessToken),
93+
RefreshToken: e.Token.RefreshToken != "",
94+
})
95+
}
96+
enc := json.NewEncoder(w)
97+
enc.SetIndent("", " ")
98+
return enc.Encode(out)
99+
}
100+
101+
func printOAuthListText(w io.Writer, entries []mcp.OAuthTokenEntry) {
102+
for i, e := range entries {
103+
if i > 0 {
104+
fmt.Fprintln(w)
105+
}
106+
fmt.Fprintf(w, "Resource: %s\n", e.ResourceURL)
107+
if e.Token.TokenType != "" {
108+
fmt.Fprintf(w, "Token Type: %s\n", e.Token.TokenType)
109+
}
110+
if e.Token.Scope != "" {
111+
fmt.Fprintf(w, "Scope: %s\n", e.Token.Scope)
112+
}
113+
fmt.Fprintf(w, "Access Token: %s\n", truncateToken(e.Token.AccessToken))
114+
fmt.Fprintf(w, "Refresh Token: %v\n", e.Token.RefreshToken != "")
115+
if !e.Token.ExpiresAt.IsZero() {
116+
fmt.Fprintf(w, "Expires at: %s\n", e.Token.ExpiresAt.Local().Format(time.RFC3339))
117+
}
118+
if e.Token.IsExpired() {
119+
fmt.Fprintln(w, "Status: ❌ Expired")
120+
} else {
121+
fmt.Fprintln(w, "Status: ✅ Valid")
122+
}
123+
}
124+
}
125+
126+
func truncateToken(token string) string {
127+
const previewLen = 10
128+
if len(token) <= previewLen*2 {
129+
return token
130+
}
131+
return token[:previewLen] + "..." + token[len(token)-previewLen:]
132+
}
133+
134+
func newDebugOAuthRemoveCmd() *cobra.Command {
135+
return &cobra.Command{
136+
Use: "remove <resource-url>",
137+
Short: "Remove a stored OAuth token",
138+
Args: cobra.ExactArgs(1),
139+
RunE: func(cmd *cobra.Command, args []string) (commandErr error) {
140+
ctx := cmd.Context()
141+
telemetry.TrackCommand(ctx, "debug", []string{"oauth", "remove"})
142+
defer func() {
143+
telemetry.TrackCommandError(ctx, "debug", []string{"oauth", "remove"}, commandErr)
144+
}()
145+
146+
if err := mcp.RemoveOAuthToken(args[0]); err != nil {
147+
return err
148+
}
149+
150+
fmt.Fprintf(cmd.OutOrStdout(), "Removed OAuth token for %s\n", args[0])
151+
return nil
152+
},
153+
}
154+
}
155+
156+
func newDebugOAuthLoginCmd() *cobra.Command {
157+
var flags debugFlags
158+
159+
cmd := &cobra.Command{
160+
Use: "login <agent-file> <mcp-name>",
161+
Short: "Perform OAuth login for a remote MCP server",
162+
Args: cobra.ExactArgs(2),
163+
RunE: func(cmd *cobra.Command, args []string) (commandErr error) {
164+
ctx := cmd.Context()
165+
telemetry.TrackCommand(ctx, "debug", []string{"oauth", "login"})
166+
defer func() {
167+
telemetry.TrackCommandError(ctx, "debug", []string{"oauth", "login"}, commandErr)
168+
}()
169+
170+
agentFile := args[0]
171+
mcpName := args[1]
172+
173+
// Load the agent config to find the MCP server URL.
174+
agentSource, err := config.Resolve(agentFile, flags.runConfig.EnvProvider())
175+
if err != nil {
176+
return err
177+
}
178+
179+
cfg, err := config.Load(ctx, agentSource)
180+
if err != nil {
181+
return err
182+
}
183+
184+
serverURL, err := findMCPRemoteURL(cfg, mcpName)
185+
if err != nil {
186+
return err
187+
}
188+
189+
w := cmd.OutOrStdout()
190+
fmt.Fprintf(w, "Starting OAuth login for %s (%s)...\n", mcpName, serverURL)
191+
192+
if err := mcp.PerformOAuthLogin(ctx, serverURL); err != nil {
193+
return fmt.Errorf("OAuth login failed: %w", err)
194+
}
195+
196+
fmt.Fprintf(w, "✅ OAuth login successful for %s\n", serverURL)
197+
return nil
198+
},
199+
}
200+
201+
addRuntimeConfigFlags(cmd, &flags.runConfig)
202+
203+
return cmd
204+
}
205+
206+
// findMCPRemoteURL looks up the remote URL for the named MCP server in the config.
207+
// It matches by name (top-level mcps key or toolset name), by URL substring,
208+
// or returns the only remote MCP if there is exactly one.
209+
func findMCPRemoteURL(cfg *latest.Config, name string) (string, error) {
210+
// Collect all remote MCP URLs with their identifiers.
211+
type mcpEntry struct {
212+
label string
213+
url string
214+
}
215+
var all []mcpEntry
216+
217+
for k, m := range cfg.MCPs {
218+
if m.Remote.URL != "" {
219+
all = append(all, mcpEntry{label: k, url: m.Remote.URL})
220+
}
221+
}
222+
for _, agent := range cfg.Agents {
223+
for _, ts := range agent.Toolsets {
224+
if ts.Type == "mcp" && ts.Remote.URL != "" {
225+
label := ts.Name
226+
if label == "" {
227+
label = ts.Remote.URL
228+
}
229+
all = append(all, mcpEntry{label: label, url: ts.Remote.URL})
230+
}
231+
}
232+
}
233+
234+
// Exact match by name/label.
235+
for _, e := range all {
236+
if e.label == name {
237+
return e.url, nil
238+
}
239+
}
240+
241+
// Exact match by URL.
242+
for _, e := range all {
243+
if e.url == name {
244+
return e.url, nil
245+
}
246+
}
247+
248+
// Build helpful error.
249+
var labels []string
250+
for _, e := range all {
251+
labels = append(labels, e.label)
252+
}
253+
if len(labels) > 0 {
254+
return "", fmt.Errorf("MCP %q not found; available: %v", name, labels)
255+
}
256+
return "", fmt.Errorf("MCP %q not found; no remote MCPs found in config", name)
257+
}

pkg/tools/mcp/oauth.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ func (t *oauthTransport) getValidToken(ctx context.Context) *OAuthToken {
237237
return nil
238238
}
239239

240-
newToken, err := RefreshAccessToken(ctx, metadata.TokenEndpoint, token.RefreshToken, "", "")
240+
newToken, err := RefreshAccessToken(ctx, metadata.TokenEndpoint, token.RefreshToken, token.ClientID, token.ClientSecret)
241241
if err != nil {
242242
slog.Debug("Token refresh failed, will require interactive auth", "error", err)
243243
return nil
@@ -389,6 +389,9 @@ func (t *oauthTransport) handleManagedOAuthFlow(ctx context.Context, authServer,
389389
return fmt.Errorf("failed to exchange code for token: %w", err)
390390
}
391391

392+
token.ClientID = clientID
393+
token.ClientSecret = clientSecret
394+
392395
if err := t.tokenStore.StoreToken(t.baseURL, token); err != nil {
393396
return fmt.Errorf("failed to store token: %w", err)
394397
}

pkg/tools/mcp/oauth_helpers.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ func ExchangeCodeForToken(ctx context.Context, tokenEndpoint, code, codeVerifier
7979
token.ExpiresAt = time.Now().Add(time.Duration(token.ExpiresIn) * time.Second)
8080
}
8181

82+
token.ClientID = clientID
83+
token.ClientSecret = clientSecret
84+
8285
return &token, nil
8386
}
8487

@@ -202,5 +205,9 @@ func RefreshAccessToken(ctx context.Context, tokenEndpoint, refreshToken, client
202205
token.RefreshToken = refreshToken
203206
}
204207

208+
// Preserve client credentials so subsequent refreshes work
209+
token.ClientID = clientID
210+
token.ClientSecret = clientSecret
211+
205212
return &token, nil
206213
}

0 commit comments

Comments
 (0)