Skip to content

Commit 8af8b56

Browse files
committed
fix: persist OAuth client credentials for token refresh
The silent token refresh flow was broken because clientID and clientSecret were not stored with the OAuthToken. When getValidToken attempted to refresh an expired token, it passed empty strings for both credentials, causing OAuth servers to reject the request. - Add ClientID and ClientSecret fields to OAuthToken struct - Store client credentials when exchanging authorization codes for tokens - Preserve client credentials through token refresh cycles - Pass stored credentials in getValidToken silent refresh path - Add tests covering credential persistence and refresh flow Assisted-By: docker-agent
1 parent 8af0911 commit 8af8b56

4 files changed

Lines changed: 278 additions & 1 deletion

File tree

pkg/tools/mcp/oauth.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ func (t *oauthTransport) getValidToken(ctx context.Context) *OAuthToken {
237237
return nil
238238
}
239239

240-
newToken, err := RefreshAccessToken(ctx, metadata.TokenEndpoint, token.RefreshToken, "", "")
240+
newToken, err := RefreshAccessToken(ctx, metadata.TokenEndpoint, token.RefreshToken, token.ClientID, token.ClientSecret)
241241
if err != nil {
242242
slog.Debug("Token refresh failed, will require interactive auth", "error", err)
243243
return nil
@@ -389,6 +389,9 @@ func (t *oauthTransport) handleManagedOAuthFlow(ctx context.Context, authServer,
389389
return fmt.Errorf("failed to exchange code for token: %w", err)
390390
}
391391

392+
token.ClientID = clientID
393+
token.ClientSecret = clientSecret
394+
392395
if err := t.tokenStore.StoreToken(t.baseURL, token); err != nil {
393396
return fmt.Errorf("failed to store token: %w", err)
394397
}

pkg/tools/mcp/oauth_helpers.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ func ExchangeCodeForToken(ctx context.Context, tokenEndpoint, code, codeVerifier
7979
token.ExpiresAt = time.Now().Add(time.Duration(token.ExpiresIn) * time.Second)
8080
}
8181

82+
token.ClientID = clientID
83+
token.ClientSecret = clientSecret
84+
8285
return &token, nil
8386
}
8487

@@ -202,5 +205,9 @@ func RefreshAccessToken(ctx context.Context, tokenEndpoint, refreshToken, client
202205
token.RefreshToken = refreshToken
203206
}
204207

208+
// Preserve client credentials so subsequent refreshes work
209+
token.ClientID = clientID
210+
token.ClientSecret = clientSecret
211+
205212
return &token, nil
206213
}

pkg/tools/mcp/oauth_test.go

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
package mcp
2+
3+
import (
4+
"encoding/json"
5+
"net/http"
6+
"net/http/httptest"
7+
"net/url"
8+
"testing"
9+
"time"
10+
)
11+
12+
// TestExchangeCodeForToken_PreservesClientCredentials verifies that
13+
// ExchangeCodeForToken stores the client_id and client_secret on the
14+
// returned OAuthToken so they are available for subsequent refresh calls.
15+
func TestExchangeCodeForToken_PreservesClientCredentials(t *testing.T) {
16+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
17+
if err := r.ParseForm(); err != nil {
18+
t.Fatal(err)
19+
}
20+
if got := r.FormValue("client_id"); got != "my-client" {
21+
t.Errorf("client_id = %q, want %q", got, "my-client")
22+
}
23+
if got := r.FormValue("client_secret"); got != "my-secret" {
24+
t.Errorf("client_secret = %q, want %q", got, "my-secret")
25+
}
26+
27+
w.Header().Set("Content-Type", "application/json")
28+
_ = json.NewEncoder(w).Encode(map[string]any{
29+
"access_token": "at-new",
30+
"token_type": "Bearer",
31+
"expires_in": 3600,
32+
"refresh_token": "rt-new",
33+
})
34+
}))
35+
defer srv.Close()
36+
37+
token, err := ExchangeCodeForToken(t.Context(), srv.URL, "code", "verifier", "my-client", "my-secret", "http://localhost/callback")
38+
if err != nil {
39+
t.Fatalf("ExchangeCodeForToken: %v", err)
40+
}
41+
42+
if token.ClientID != "my-client" {
43+
t.Errorf("ClientID = %q, want %q", token.ClientID, "my-client")
44+
}
45+
if token.ClientSecret != "my-secret" {
46+
t.Errorf("ClientSecret = %q, want %q", token.ClientSecret, "my-secret")
47+
}
48+
}
49+
50+
// TestRefreshAccessToken_PreservesClientCredentials verifies that
51+
// RefreshAccessToken carries the client credentials through to the new token.
52+
func TestRefreshAccessToken_PreservesClientCredentials(t *testing.T) {
53+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
54+
if err := r.ParseForm(); err != nil {
55+
t.Fatal(err)
56+
}
57+
if got := r.FormValue("client_id"); got != "cid" {
58+
t.Errorf("client_id = %q, want %q", got, "cid")
59+
}
60+
if got := r.FormValue("client_secret"); got != "csec" {
61+
t.Errorf("client_secret = %q, want %q", got, "csec")
62+
}
63+
64+
w.Header().Set("Content-Type", "application/json")
65+
_ = json.NewEncoder(w).Encode(map[string]any{
66+
"access_token": "at-refreshed",
67+
"token_type": "Bearer",
68+
"expires_in": 7200,
69+
// Server does NOT return a new refresh_token – old one should be preserved.
70+
})
71+
}))
72+
defer srv.Close()
73+
74+
token, err := RefreshAccessToken(t.Context(), srv.URL, "old-rt", "cid", "csec")
75+
if err != nil {
76+
t.Fatalf("RefreshAccessToken: %v", err)
77+
}
78+
79+
if token.AccessToken != "at-refreshed" {
80+
t.Errorf("AccessToken = %q, want %q", token.AccessToken, "at-refreshed")
81+
}
82+
if token.RefreshToken != "old-rt" {
83+
t.Errorf("RefreshToken = %q, want %q (should be preserved)", token.RefreshToken, "old-rt")
84+
}
85+
if token.ClientID != "cid" {
86+
t.Errorf("ClientID = %q, want %q", token.ClientID, "cid")
87+
}
88+
if token.ClientSecret != "csec" {
89+
t.Errorf("ClientSecret = %q, want %q", token.ClientSecret, "csec")
90+
}
91+
}
92+
93+
// TestGetValidToken_UsesStoredCredentialsForRefresh verifies that the
94+
// oauthTransport.getValidToken method sends the stored client credentials
95+
// when silently refreshing an expired token.
96+
func TestGetValidToken_UsesStoredCredentialsForRefresh(t *testing.T) {
97+
var receivedClientID, receivedClientSecret string
98+
99+
// Use a mux so we can reference srv.URL in closures (srv is assigned before handlers run).
100+
mux := http.NewServeMux()
101+
srv := httptest.NewServer(mux)
102+
defer srv.Close()
103+
104+
mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, _ *http.Request) {
105+
w.Header().Set("Content-Type", "application/json")
106+
_ = json.NewEncoder(w).Encode(map[string]any{
107+
"issuer": srv.URL,
108+
"token_endpoint": srv.URL + "/token",
109+
"authorization_endpoint": srv.URL + "/authorize",
110+
})
111+
})
112+
mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
113+
if err := r.ParseForm(); err != nil {
114+
t.Fatal(err)
115+
}
116+
receivedClientID = r.FormValue("client_id")
117+
receivedClientSecret = r.FormValue("client_secret")
118+
119+
w.Header().Set("Content-Type", "application/json")
120+
_ = json.NewEncoder(w).Encode(map[string]any{
121+
"access_token": "fresh-at",
122+
"token_type": "Bearer",
123+
"expires_in": 3600,
124+
"refresh_token": "fresh-rt",
125+
})
126+
})
127+
128+
// Pre-populate an expired token with stored client credentials.
129+
store := NewInMemoryTokenStore()
130+
expiredToken := &OAuthToken{
131+
AccessToken: "old-at",
132+
TokenType: "Bearer",
133+
RefreshToken: "old-rt",
134+
ExpiresAt: time.Now().Add(-1 * time.Hour), // expired
135+
ClientID: "stored-cid",
136+
ClientSecret: "stored-csec",
137+
}
138+
if err := store.StoreToken(srv.URL, expiredToken); err != nil {
139+
t.Fatal(err)
140+
}
141+
142+
transport := &oauthTransport{
143+
base: http.DefaultTransport,
144+
tokenStore: store,
145+
baseURL: srv.URL,
146+
}
147+
148+
got := transport.getValidToken(t.Context())
149+
if got == nil {
150+
t.Fatal("getValidToken returned nil, expected refreshed token")
151+
}
152+
if got.AccessToken != "fresh-at" {
153+
t.Errorf("AccessToken = %q, want %q", got.AccessToken, "fresh-at")
154+
}
155+
if receivedClientID != "stored-cid" {
156+
t.Errorf("token endpoint received client_id = %q, want %q", receivedClientID, "stored-cid")
157+
}
158+
if receivedClientSecret != "stored-csec" {
159+
t.Errorf("token endpoint received client_secret = %q, want %q", receivedClientSecret, "stored-csec")
160+
}
161+
162+
// Verify the refreshed token also carries the credentials forward.
163+
updated, err := store.GetToken(srv.URL)
164+
if err != nil {
165+
t.Fatalf("GetToken after refresh: %v", err)
166+
}
167+
if updated.ClientID != "stored-cid" {
168+
t.Errorf("stored ClientID = %q, want %q", updated.ClientID, "stored-cid")
169+
}
170+
if updated.ClientSecret != "stored-csec" {
171+
t.Errorf("stored ClientSecret = %q, want %q", updated.ClientSecret, "stored-csec")
172+
}
173+
}
174+
175+
// TestOAuthTokenClientCredentials_JSONRoundTrip verifies that ClientID and
176+
// ClientSecret survive JSON serialization (important for keyring storage).
177+
func TestOAuthTokenClientCredentials_JSONRoundTrip(t *testing.T) {
178+
token := &OAuthToken{
179+
AccessToken: "at",
180+
TokenType: "Bearer",
181+
RefreshToken: "rt",
182+
ExpiresIn: 3600,
183+
ExpiresAt: time.Date(2025, 6, 1, 0, 0, 0, 0, time.UTC),
184+
ClientID: "cid",
185+
ClientSecret: "csec",
186+
}
187+
188+
data, err := json.Marshal(token)
189+
if err != nil {
190+
t.Fatalf("Marshal: %v", err)
191+
}
192+
193+
var got OAuthToken
194+
if err := json.Unmarshal(data, &got); err != nil {
195+
t.Fatalf("Unmarshal: %v", err)
196+
}
197+
198+
if got.ClientID != "cid" {
199+
t.Errorf("ClientID = %q, want %q", got.ClientID, "cid")
200+
}
201+
if got.ClientSecret != "csec" {
202+
t.Errorf("ClientSecret = %q, want %q", got.ClientSecret, "csec")
203+
}
204+
}
205+
206+
// TestOAuthTokenClientCredentials_OmittedWhenEmpty verifies the omitempty
207+
// tag works so tokens without client credentials don't leak empty fields.
208+
func TestOAuthTokenClientCredentials_OmittedWhenEmpty(t *testing.T) {
209+
token := &OAuthToken{
210+
AccessToken: "at",
211+
TokenType: "Bearer",
212+
}
213+
214+
data, err := json.Marshal(token)
215+
if err != nil {
216+
t.Fatalf("Marshal: %v", err)
217+
}
218+
219+
var raw map[string]any
220+
if err := json.Unmarshal(data, &raw); err != nil {
221+
t.Fatal(err)
222+
}
223+
224+
if _, ok := raw["client_id"]; ok {
225+
t.Error("client_id should be omitted when empty")
226+
}
227+
if _, ok := raw["client_secret"]; ok {
228+
t.Error("client_secret should be omitted when empty")
229+
}
230+
}
231+
232+
// TestRefreshAccessToken_SendsEmptyClientIDWhenNotStored ensures that when
233+
// no client credentials were stored (legacy tokens), the refresh still
234+
// sends whatever was provided (empty string), matching the old behavior.
235+
func TestRefreshAccessToken_SendsEmptyClientIDWhenNotStored(t *testing.T) {
236+
var receivedForm url.Values
237+
238+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
239+
if err := r.ParseForm(); err != nil {
240+
t.Fatal(err)
241+
}
242+
receivedForm = r.Form
243+
244+
w.Header().Set("Content-Type", "application/json")
245+
_ = json.NewEncoder(w).Encode(map[string]any{
246+
"access_token": "new-at",
247+
"token_type": "Bearer",
248+
})
249+
}))
250+
defer srv.Close()
251+
252+
_, err := RefreshAccessToken(t.Context(), srv.URL, "rt", "", "")
253+
if err != nil {
254+
t.Fatal(err)
255+
}
256+
257+
// client_id is always sent (even empty) per the current implementation.
258+
if got := receivedForm.Get("client_id"); got != "" {
259+
t.Errorf("client_id = %q, want empty", got)
260+
}
261+
// client_secret should NOT be sent when empty.
262+
if receivedForm.Has("client_secret") {
263+
t.Error("client_secret should not be sent when empty")
264+
}
265+
}

pkg/tools/mcp/tokenstore.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ type OAuthToken struct {
2424
RefreshToken string `json:"refresh_token,omitempty"`
2525
Scope string `json:"scope,omitempty"`
2626
ExpiresAt time.Time `json:"expires_at"`
27+
ClientID string `json:"client_id,omitempty"`
28+
ClientSecret string `json:"client_secret,omitempty"`
2729
}
2830

2931
// IsExpired checks if the token is expired

0 commit comments

Comments
 (0)