Skip to content

Commit 46731c1

Browse files
committed
Add debug oauth commands: list, remove, and login
Add 'docker agent debug oauth list' to show all stored OAuth tokens. Add 'docker agent debug oauth remove' to delete a token by resource URL. Add 'docker agent debug oauth login' to trigger an OAuth flow for a remote MCP server. The keyring token store now maintains an index key to minimize keychain prompts when listing tokens, with a fallback to Keys() for pre-existing tokens. Assisted-By: docker-agent
1 parent 8af8b56 commit 46731c1

4 files changed

Lines changed: 562 additions & 7 deletions

File tree

cmd/root/debug.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ func newDebugCmd() *cobra.Command {
5555
addRuntimeConfigFlags(cmd, &flags.runConfig)
5656

5757
cmd.AddCommand(newDebugAuthCmd())
58+
cmd.AddCommand(newDebugOAuthCmd())
5859

5960
return cmd
6061
}

cmd/root/debug_oauth.go

Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
package root
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"io"
7+
"time"
8+
9+
"github.com/spf13/cobra"
10+
11+
"github.com/docker/docker-agent/pkg/config"
12+
"github.com/docker/docker-agent/pkg/config/latest"
13+
"github.com/docker/docker-agent/pkg/telemetry"
14+
"github.com/docker/docker-agent/pkg/tools/mcp"
15+
)
16+
17+
func newDebugOAuthCmd() *cobra.Command {
18+
cmd := &cobra.Command{
19+
Use: "oauth",
20+
Short: "OAuth token management",
21+
}
22+
23+
cmd.AddCommand(newDebugOAuthListCmd())
24+
cmd.AddCommand(newDebugOAuthRemoveCmd())
25+
cmd.AddCommand(newDebugOAuthLoginCmd())
26+
27+
return cmd
28+
}
29+
30+
func newDebugOAuthListCmd() *cobra.Command {
31+
var jsonOutput bool
32+
33+
cmd := &cobra.Command{
34+
Use: "list",
35+
Short: "List all stored OAuth tokens",
36+
Args: cobra.NoArgs,
37+
RunE: func(cmd *cobra.Command, _ []string) (commandErr error) {
38+
ctx := cmd.Context()
39+
telemetry.TrackCommand(ctx, "debug", []string{"oauth", "list"})
40+
defer func() {
41+
telemetry.TrackCommandError(ctx, "debug", []string{"oauth", "list"}, commandErr)
42+
}()
43+
44+
w := cmd.OutOrStdout()
45+
46+
entries, err := mcp.ListOAuthTokens()
47+
if err != nil {
48+
return fmt.Errorf("failed to list OAuth tokens: %w", err)
49+
}
50+
51+
if len(entries) == 0 {
52+
if jsonOutput {
53+
return json.NewEncoder(w).Encode([]any{})
54+
}
55+
fmt.Fprintln(w, "No OAuth tokens stored.")
56+
return nil
57+
}
58+
59+
if jsonOutput {
60+
return printOAuthListJSON(w, entries)
61+
}
62+
63+
printOAuthListText(w, entries)
64+
return nil
65+
},
66+
}
67+
68+
cmd.Flags().BoolVar(&jsonOutput, "json", false, "Output in JSON format")
69+
70+
return cmd
71+
}
72+
73+
type oauthListEntry struct {
74+
ResourceURL string `json:"resource_url"`
75+
TokenType string `json:"token_type,omitempty"`
76+
Scope string `json:"scope,omitempty"`
77+
ExpiresAt time.Time `json:"expires_at,omitzero"`
78+
Expired bool `json:"expired"`
79+
AccessToken string `json:"access_token"`
80+
RefreshToken bool `json:"has_refresh_token"`
81+
}
82+
83+
func printOAuthListJSON(w io.Writer, entries []mcp.OAuthTokenEntry) error {
84+
var out []oauthListEntry
85+
for _, e := range entries {
86+
out = append(out, oauthListEntry{
87+
ResourceURL: e.ResourceURL,
88+
TokenType: e.Token.TokenType,
89+
Scope: e.Token.Scope,
90+
ExpiresAt: e.Token.ExpiresAt,
91+
Expired: e.Token.IsExpired(),
92+
AccessToken: truncateToken(e.Token.AccessToken),
93+
RefreshToken: e.Token.RefreshToken != "",
94+
})
95+
}
96+
enc := json.NewEncoder(w)
97+
enc.SetIndent("", " ")
98+
return enc.Encode(out)
99+
}
100+
101+
func printOAuthListText(w io.Writer, entries []mcp.OAuthTokenEntry) {
102+
for i, e := range entries {
103+
if i > 0 {
104+
fmt.Fprintln(w)
105+
}
106+
fmt.Fprintf(w, "Resource: %s\n", e.ResourceURL)
107+
if e.Token.TokenType != "" {
108+
fmt.Fprintf(w, "Token Type: %s\n", e.Token.TokenType)
109+
}
110+
if e.Token.Scope != "" {
111+
fmt.Fprintf(w, "Scope: %s\n", e.Token.Scope)
112+
}
113+
fmt.Fprintf(w, "Access Token: %s\n", truncateToken(e.Token.AccessToken))
114+
fmt.Fprintf(w, "Refresh Token: %v\n", e.Token.RefreshToken != "")
115+
if !e.Token.ExpiresAt.IsZero() {
116+
fmt.Fprintf(w, "Expires at: %s\n", e.Token.ExpiresAt.Local().Format(time.RFC3339))
117+
}
118+
if e.Token.IsExpired() {
119+
fmt.Fprintln(w, "Status: ❌ Expired")
120+
} else {
121+
fmt.Fprintln(w, "Status: ✅ Valid")
122+
}
123+
}
124+
}
125+
126+
func truncateToken(token string) string {
127+
const previewLen = 10
128+
if len(token) <= previewLen*2 {
129+
return token
130+
}
131+
return token[:previewLen] + "..." + token[len(token)-previewLen:]
132+
}
133+
134+
func newDebugOAuthRemoveCmd() *cobra.Command {
135+
return &cobra.Command{
136+
Use: "remove <resource-url>",
137+
Short: "Remove a stored OAuth token",
138+
Args: cobra.ExactArgs(1),
139+
RunE: func(cmd *cobra.Command, args []string) (commandErr error) {
140+
ctx := cmd.Context()
141+
telemetry.TrackCommand(ctx, "debug", []string{"oauth", "remove"})
142+
defer func() {
143+
telemetry.TrackCommandError(ctx, "debug", []string{"oauth", "remove"}, commandErr)
144+
}()
145+
146+
if err := mcp.RemoveOAuthToken(args[0]); err != nil {
147+
return err
148+
}
149+
150+
fmt.Fprintf(cmd.OutOrStdout(), "Removed OAuth token for %s\n", args[0])
151+
return nil
152+
},
153+
}
154+
}
155+
156+
func newDebugOAuthLoginCmd() *cobra.Command {
157+
var flags debugFlags
158+
159+
cmd := &cobra.Command{
160+
Use: "login <agent-file> <mcp-name>",
161+
Short: "Perform OAuth login for a remote MCP server",
162+
Args: cobra.ExactArgs(2),
163+
RunE: func(cmd *cobra.Command, args []string) (commandErr error) {
164+
ctx := cmd.Context()
165+
telemetry.TrackCommand(ctx, "debug", []string{"oauth", "login"})
166+
defer func() {
167+
telemetry.TrackCommandError(ctx, "debug", []string{"oauth", "login"}, commandErr)
168+
}()
169+
170+
agentFile := args[0]
171+
mcpName := args[1]
172+
173+
// Load the agent config to find the MCP server URL.
174+
agentSource, err := config.Resolve(agentFile, flags.runConfig.EnvProvider())
175+
if err != nil {
176+
return err
177+
}
178+
179+
cfg, err := config.Load(ctx, agentSource)
180+
if err != nil {
181+
return err
182+
}
183+
184+
serverURL, err := findMCPRemoteURL(cfg, mcpName)
185+
if err != nil {
186+
return err
187+
}
188+
189+
w := cmd.OutOrStdout()
190+
fmt.Fprintf(w, "Starting OAuth login for %s (%s)...\n", mcpName, serverURL)
191+
192+
if err := mcp.PerformOAuthLogin(ctx, serverURL); err != nil {
193+
return fmt.Errorf("OAuth login failed: %w", err)
194+
}
195+
196+
fmt.Fprintf(w, "✅ OAuth login successful for %s\n", serverURL)
197+
return nil
198+
},
199+
}
200+
201+
addRuntimeConfigFlags(cmd, &flags.runConfig)
202+
203+
return cmd
204+
}
205+
206+
// findMCPRemoteURL looks up the remote URL for the named MCP server in the config.
207+
// It matches by name (top-level mcps key or toolset name), by URL substring,
208+
// or returns the only remote MCP if there is exactly one.
209+
func findMCPRemoteURL(cfg *latest.Config, name string) (string, error) {
210+
// Collect all remote MCP URLs with their identifiers.
211+
type mcpEntry struct {
212+
label string
213+
url string
214+
}
215+
var all []mcpEntry
216+
217+
for k, m := range cfg.MCPs {
218+
if m.Remote.URL != "" {
219+
all = append(all, mcpEntry{label: k, url: m.Remote.URL})
220+
}
221+
}
222+
for _, agent := range cfg.Agents {
223+
for _, ts := range agent.Toolsets {
224+
if ts.Type == "mcp" && ts.Remote.URL != "" {
225+
label := ts.Name
226+
if label == "" {
227+
label = ts.Remote.URL
228+
}
229+
all = append(all, mcpEntry{label: label, url: ts.Remote.URL})
230+
}
231+
}
232+
}
233+
234+
// Exact match by name/label.
235+
for _, e := range all {
236+
if e.label == name {
237+
return e.url, nil
238+
}
239+
}
240+
241+
// Exact match by URL.
242+
for _, e := range all {
243+
if e.url == name {
244+
return e.url, nil
245+
}
246+
}
247+
248+
// Build helpful error.
249+
var labels []string
250+
for _, e := range all {
251+
labels = append(labels, e.label)
252+
}
253+
if len(labels) > 0 {
254+
return "", fmt.Errorf("MCP %q not found; available: %v", name, labels)
255+
}
256+
return "", fmt.Errorf("MCP %q not found; no remote MCPs found in config", name)
257+
}

pkg/tools/mcp/oauth_login.go

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
package mcp
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"errors"
7+
"fmt"
8+
"log/slog"
9+
"net/http"
10+
"net/url"
11+
"time"
12+
13+
"golang.org/x/oauth2"
14+
)
15+
16+
// PerformOAuthLogin performs a standalone OAuth flow for the given MCP server URL.
17+
// It discovers the authorization server metadata, performs dynamic client registration,
18+
// opens the browser for user authorization, and stores the resulting token in the keyring.
19+
func PerformOAuthLogin(ctx context.Context, serverURL string) error {
20+
tokenStore := NewKeyringTokenStore()
21+
22+
o := &oauth{metadataClient: &http.Client{Timeout: 5 * time.Second}}
23+
24+
// Derive the base origin (scheme + host) from the server URL.
25+
// The well-known endpoints live at the origin, not under the SSE/path.
26+
parsed, err := url.Parse(serverURL)
27+
if err != nil {
28+
return fmt.Errorf("invalid server URL: %w", err)
29+
}
30+
baseURL := parsed.Scheme + "://" + parsed.Host
31+
32+
// Discover protected resource metadata.
33+
resourceURL := baseURL + "/.well-known/oauth-protected-resource"
34+
resp, err := http.Get(resourceURL) //nolint:gosec // URL is user-provided
35+
if err != nil {
36+
return fmt.Errorf("failed to fetch protected resource metadata: %w", err)
37+
}
38+
defer resp.Body.Close()
39+
40+
authServer := baseURL
41+
if resp.StatusCode == http.StatusOK {
42+
var resourceMetadata protectedResourceMetadata
43+
if decErr := json.NewDecoder(resp.Body).Decode(&resourceMetadata); decErr == nil && len(resourceMetadata.AuthorizationServers) > 0 {
44+
authServer = resourceMetadata.AuthorizationServers[0]
45+
}
46+
}
47+
48+
// Discover authorization server metadata.
49+
authServerMetadata, err := o.getAuthorizationServerMetadata(ctx, authServer)
50+
if err != nil {
51+
return fmt.Errorf("failed to fetch authorization server metadata: %w", err)
52+
}
53+
54+
// Set up the callback server for the redirect.
55+
callbackServer, err := NewCallbackServer()
56+
if err != nil {
57+
return fmt.Errorf("failed to create callback server: %w", err)
58+
}
59+
defer func() {
60+
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
61+
defer cancel()
62+
if err := callbackServer.Shutdown(shutdownCtx); err != nil {
63+
slog.Error("Failed to shutdown callback server", "error", err)
64+
}
65+
}()
66+
67+
if err := callbackServer.Start(); err != nil {
68+
return fmt.Errorf("failed to start callback server: %w", err)
69+
}
70+
71+
redirectURI := callbackServer.GetRedirectURI()
72+
73+
// Dynamic client registration.
74+
var clientID, clientSecret string
75+
if authServerMetadata.RegistrationEndpoint != "" {
76+
clientID, clientSecret, err = RegisterClient(ctx, authServerMetadata, redirectURI, nil)
77+
if err != nil {
78+
return fmt.Errorf("dynamic client registration failed: %w", err)
79+
}
80+
} else {
81+
return errors.New("authorization server does not support dynamic client registration")
82+
}
83+
84+
// Generate PKCE and state.
85+
state, err := GenerateState()
86+
if err != nil {
87+
return fmt.Errorf("failed to generate state: %w", err)
88+
}
89+
callbackServer.SetExpectedState(state)
90+
verifier := GeneratePKCEVerifier()
91+
92+
authURL := BuildAuthorizationURL(
93+
authServerMetadata.AuthorizationEndpoint,
94+
clientID,
95+
redirectURI,
96+
state,
97+
oauth2.S256ChallengeFromVerifier(verifier),
98+
serverURL,
99+
)
100+
101+
// Open the browser and wait for the callback.
102+
code, receivedState, err := RequestAuthorizationCode(ctx, authURL, callbackServer, state)
103+
if err != nil {
104+
return fmt.Errorf("failed to get authorization code: %w", err)
105+
}
106+
107+
if receivedState != state {
108+
return errors.New("state mismatch in authorization response")
109+
}
110+
111+
// Exchange the code for a token.
112+
token, err := ExchangeCodeForToken(ctx, authServerMetadata.TokenEndpoint, code, verifier, clientID, clientSecret, redirectURI)
113+
if err != nil {
114+
return fmt.Errorf("failed to exchange code for token: %w", err)
115+
}
116+
117+
token.ClientID = clientID
118+
token.ClientSecret = clientSecret
119+
120+
if err := tokenStore.StoreToken(serverURL, token); err != nil {
121+
return fmt.Errorf("failed to store token: %w", err)
122+
}
123+
124+
return nil
125+
}

0 commit comments

Comments
 (0)