@@ -172,6 +172,80 @@ func TestGetValidToken_UsesStoredCredentialsForRefresh(t *testing.T) {
172172 }
173173}
174174
175+ // TestGetValidToken_UsesStoredAuthServerForRefresh verifies that silent
176+ // refresh uses the discovered auth server rather than assuming the MCP
177+ // server URL also hosts the OAuth metadata.
178+ func TestGetValidToken_UsesStoredAuthServerForRefresh (t * testing.T ) {
179+ var refreshRequests int
180+
181+ authMux := http .NewServeMux ()
182+ authSrv := httptest .NewServer (authMux )
183+ defer authSrv .Close ()
184+
185+ authMux .HandleFunc ("/.well-known/oauth-authorization-server" , func (w http.ResponseWriter , _ * http.Request ) {
186+ w .Header ().Set ("Content-Type" , "application/json" )
187+ _ = json .NewEncoder (w ).Encode (map [string ]any {
188+ "issuer" : authSrv .URL ,
189+ "token_endpoint" : authSrv .URL + "/token" ,
190+ "authorization_endpoint" : authSrv .URL + "/authorize" ,
191+ })
192+ })
193+ authMux .HandleFunc ("/token" , func (w http.ResponseWriter , r * http.Request ) {
194+ refreshRequests ++
195+ if err := r .ParseForm (); err != nil {
196+ t .Fatal (err )
197+ }
198+ if got := r .FormValue ("refresh_token" ); got != "old-rt" {
199+ t .Fatalf ("refresh_token = %q, want %q" , got , "old-rt" )
200+ }
201+
202+ w .Header ().Set ("Content-Type" , "application/json" )
203+ _ = json .NewEncoder (w ).Encode (map [string ]any {
204+ "access_token" : "fresh-at" ,
205+ "token_type" : "Bearer" ,
206+ "expires_in" : 3600 ,
207+ "refresh_token" : "fresh-rt" ,
208+ })
209+ })
210+
211+ mcpSrv := httptest .NewServer (http .NotFoundHandler ())
212+ defer mcpSrv .Close ()
213+
214+ store := NewInMemoryTokenStore ()
215+ expiredToken := & OAuthToken {
216+ AccessToken : "old-at" ,
217+ TokenType : "Bearer" ,
218+ RefreshToken : "old-rt" ,
219+ ExpiresAt : time .Now ().Add (- 1 * time .Hour ),
220+ ClientID : "stored-cid" ,
221+ ClientSecret : "stored-csec" ,
222+ AuthServer : authSrv .URL ,
223+ }
224+ if err := store .StoreToken (mcpSrv .URL , expiredToken ); err != nil {
225+ t .Fatal (err )
226+ }
227+
228+ transport := & oauthTransport {
229+ base : http .DefaultTransport ,
230+ tokenStore : store ,
231+ baseURL : mcpSrv .URL ,
232+ }
233+
234+ got := transport .getValidToken (t .Context ())
235+ if got == nil {
236+ t .Fatal ("getValidToken returned nil, expected refreshed token" )
237+ }
238+ if got .AccessToken != "fresh-at" {
239+ t .Fatalf ("AccessToken = %q, want %q" , got .AccessToken , "fresh-at" )
240+ }
241+ if got .AuthServer != authSrv .URL {
242+ t .Fatalf ("AuthServer = %q, want %q" , got .AuthServer , authSrv .URL )
243+ }
244+ if refreshRequests != 1 {
245+ t .Fatalf ("refreshRequests = %d, want 1" , refreshRequests )
246+ }
247+ }
248+
175249// TestOAuthTokenClientCredentials_JSONRoundTrip verifies that ClientID and
176250// ClientSecret survive JSON serialization (important for keyring storage).
177251func TestOAuthTokenClientCredentials_JSONRoundTrip (t * testing.T ) {
@@ -183,6 +257,7 @@ func TestOAuthTokenClientCredentials_JSONRoundTrip(t *testing.T) {
183257 ExpiresAt : time .Date (2025 , 6 , 1 , 0 , 0 , 0 , 0 , time .UTC ),
184258 ClientID : "cid" ,
185259 ClientSecret : "csec" ,
260+ AuthServer : "https://auth.example.com" ,
186261 }
187262
188263 data , err := json .Marshal (token )
@@ -201,6 +276,9 @@ func TestOAuthTokenClientCredentials_JSONRoundTrip(t *testing.T) {
201276 if got .ClientSecret != "csec" {
202277 t .Errorf ("ClientSecret = %q, want %q" , got .ClientSecret , "csec" )
203278 }
279+ if got .AuthServer != "https://auth.example.com" {
280+ t .Errorf ("AuthServer = %q, want %q" , got .AuthServer , "https://auth.example.com" )
281+ }
204282}
205283
206284// TestOAuthTokenClientCredentials_OmittedWhenEmpty verifies the omitempty
0 commit comments