Skip to content

Commit 45be5d1

Browse files
authored
Merge pull request #2448 from pandego/fix/2417-mcp-oauth-refresh
fix(mcp): reuse discovered auth server for token refresh
2 parents cb465bf + 602c0d8 commit 45be5d1

3 files changed

Lines changed: 85 additions & 2 deletions

File tree

pkg/tools/mcp/oauth.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,10 @@ func (t *oauthTransport) getValidToken(ctx context.Context) *OAuthToken {
259259
slog.Debug("Attempting silent token refresh", "url", t.baseURL)
260260

261261
o := &oauth{metadataClient: &http.Client{Timeout: 5 * time.Second}}
262-
metadata, err := o.getAuthorizationServerMetadata(ctx, t.baseURL)
262+
authServer := cmp.Or(token.AuthServer, t.baseURL)
263+
metadata, err := o.getAuthorizationServerMetadata(ctx, authServer)
263264
if err != nil {
264-
slog.Debug("Failed to fetch auth server metadata for refresh", "error", err)
265+
slog.Debug("Failed to fetch auth server metadata for refresh", "auth_server", authServer, "error", err)
265266
return nil
266267
}
267268

@@ -273,6 +274,7 @@ func (t *oauthTransport) getValidToken(ctx context.Context) *OAuthToken {
273274
t.mu.Unlock()
274275
return nil
275276
}
277+
newToken.AuthServer = authServer
276278

277279
t.mu.Lock()
278280
t.refreshFailedAt = time.Time{} // reset on success
@@ -443,6 +445,7 @@ func (t *oauthTransport) handleManagedOAuthFlow(ctx context.Context, authServer,
443445

444446
token.ClientID = clientID
445447
token.ClientSecret = clientSecret
448+
token.AuthServer = resourceMetadata.AuthorizationServers[0]
446449

447450
if err := t.tokenStore.StoreToken(t.baseURL, token); err != nil {
448451
return fmt.Errorf("failed to store token: %w", err)
@@ -539,6 +542,7 @@ func (t *oauthTransport) handleUnmanagedOAuthFlow(ctx context.Context, authServe
539542
token.ExpiresIn = int(expiresIn)
540543
token.ExpiresAt = time.Now().Add(time.Duration(token.ExpiresIn) * time.Second)
541544
}
545+
token.AuthServer = resourceMetadata.AuthorizationServers[0]
542546

543547
if refreshToken, ok := tokenData["refresh_token"].(string); ok {
544548
token.RefreshToken = refreshToken

pkg/tools/mcp/oauth_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,80 @@ func TestGetValidToken_UsesStoredCredentialsForRefresh(t *testing.T) {
172172
}
173173
}
174174

175+
// TestGetValidToken_UsesStoredAuthServerForRefresh verifies that silent
176+
// refresh uses the discovered auth server rather than assuming the MCP
177+
// server URL also hosts the OAuth metadata.
178+
func TestGetValidToken_UsesStoredAuthServerForRefresh(t *testing.T) {
179+
var refreshRequests int
180+
181+
authMux := http.NewServeMux()
182+
authSrv := httptest.NewServer(authMux)
183+
defer authSrv.Close()
184+
185+
authMux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, _ *http.Request) {
186+
w.Header().Set("Content-Type", "application/json")
187+
_ = json.NewEncoder(w).Encode(map[string]any{
188+
"issuer": authSrv.URL,
189+
"token_endpoint": authSrv.URL + "/token",
190+
"authorization_endpoint": authSrv.URL + "/authorize",
191+
})
192+
})
193+
authMux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
194+
refreshRequests++
195+
if err := r.ParseForm(); err != nil {
196+
t.Fatal(err)
197+
}
198+
if got := r.FormValue("refresh_token"); got != "old-rt" {
199+
t.Fatalf("refresh_token = %q, want %q", got, "old-rt")
200+
}
201+
202+
w.Header().Set("Content-Type", "application/json")
203+
_ = json.NewEncoder(w).Encode(map[string]any{
204+
"access_token": "fresh-at",
205+
"token_type": "Bearer",
206+
"expires_in": 3600,
207+
"refresh_token": "fresh-rt",
208+
})
209+
})
210+
211+
mcpSrv := httptest.NewServer(http.NotFoundHandler())
212+
defer mcpSrv.Close()
213+
214+
store := NewInMemoryTokenStore()
215+
expiredToken := &OAuthToken{
216+
AccessToken: "old-at",
217+
TokenType: "Bearer",
218+
RefreshToken: "old-rt",
219+
ExpiresAt: time.Now().Add(-1 * time.Hour),
220+
ClientID: "stored-cid",
221+
ClientSecret: "stored-csec",
222+
AuthServer: authSrv.URL,
223+
}
224+
if err := store.StoreToken(mcpSrv.URL, expiredToken); err != nil {
225+
t.Fatal(err)
226+
}
227+
228+
transport := &oauthTransport{
229+
base: http.DefaultTransport,
230+
tokenStore: store,
231+
baseURL: mcpSrv.URL,
232+
}
233+
234+
got := transport.getValidToken(t.Context())
235+
if got == nil {
236+
t.Fatal("getValidToken returned nil, expected refreshed token")
237+
}
238+
if got.AccessToken != "fresh-at" {
239+
t.Fatalf("AccessToken = %q, want %q", got.AccessToken, "fresh-at")
240+
}
241+
if got.AuthServer != authSrv.URL {
242+
t.Fatalf("AuthServer = %q, want %q", got.AuthServer, authSrv.URL)
243+
}
244+
if refreshRequests != 1 {
245+
t.Fatalf("refreshRequests = %d, want 1", refreshRequests)
246+
}
247+
}
248+
175249
// TestOAuthTokenClientCredentials_JSONRoundTrip verifies that ClientID and
176250
// ClientSecret survive JSON serialization (important for keyring storage).
177251
func TestOAuthTokenClientCredentials_JSONRoundTrip(t *testing.T) {
@@ -183,6 +257,7 @@ func TestOAuthTokenClientCredentials_JSONRoundTrip(t *testing.T) {
183257
ExpiresAt: time.Date(2025, 6, 1, 0, 0, 0, 0, time.UTC),
184258
ClientID: "cid",
185259
ClientSecret: "csec",
260+
AuthServer: "https://auth.example.com",
186261
}
187262

188263
data, err := json.Marshal(token)
@@ -201,6 +276,9 @@ func TestOAuthTokenClientCredentials_JSONRoundTrip(t *testing.T) {
201276
if got.ClientSecret != "csec" {
202277
t.Errorf("ClientSecret = %q, want %q", got.ClientSecret, "csec")
203278
}
279+
if got.AuthServer != "https://auth.example.com" {
280+
t.Errorf("AuthServer = %q, want %q", got.AuthServer, "https://auth.example.com")
281+
}
204282
}
205283

206284
// TestOAuthTokenClientCredentials_OmittedWhenEmpty verifies the omitempty

pkg/tools/mcp/tokenstore.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ type OAuthToken struct {
2626
ExpiresAt time.Time `json:"expires_at"`
2727
ClientID string `json:"client_id,omitempty"`
2828
ClientSecret string `json:"client_secret,omitempty"`
29+
AuthServer string `json:"auth_server,omitempty"`
2930
}
3031

3132
// IsExpired checks if the token is expired

0 commit comments

Comments
 (0)