Skip to content

Commit 659ec86

Browse files
authored
Merge pull request #2434 from dgageot/fix7
fix: reject OAuth callback when expected state has not been set (CSRF)
2 parents fb1a459 + 9c3f798 commit 659ec86

2 files changed

Lines changed: 28 additions & 2 deletions

File tree

pkg/tools/mcp/oauth_server.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,14 @@ func (cs *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request)
116116
return
117117
}
118118

119-
// Verify state parameter for CSRF protection
119+
// Verify state parameter for CSRF protection.
120+
// Reject the callback if the expected state has not been set yet
121+
// (i.e. the callback arrived before SetExpectedState was called).
120122
cs.mu.Lock()
121123
expectedState := cs.expectedState
122124
cs.mu.Unlock()
123125

124-
if expectedState != "" && state != expectedState {
126+
if expectedState == "" || state != expectedState {
125127
cs.errCh <- fmt.Errorf("state mismatch: expected %s, got %s", expectedState, state)
126128
w.WriteHeader(http.StatusBadRequest)
127129
fmt.Fprint(w, "Invalid state parameter")

pkg/tools/mcp/oauth_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,27 @@ func TestRefreshAccessToken_SendsEmptyClientIDWhenNotStored(t *testing.T) {
263263
t.Error("client_secret should not be sent when empty")
264264
}
265265
}
266+
267+
// TestCallbackServer_RejectsCallbackBeforeStateSet verifies that a callback
268+
// arriving before SetExpectedState is called is rejected (CSRF protection).
269+
func TestCallbackServer_RejectsCallbackBeforeStateSet(t *testing.T) {
270+
cs, err := NewCallbackServer()
271+
if err != nil {
272+
t.Fatal(err)
273+
}
274+
if err := cs.Start(); err != nil {
275+
t.Fatal(err)
276+
}
277+
defer func() { _ = cs.Shutdown(t.Context()) }()
278+
279+
// Send a callback before SetExpectedState has been called.
280+
resp, err := http.Get(cs.GetRedirectURI() + "?code=authcode&state=anything")
281+
if err != nil {
282+
t.Fatal(err)
283+
}
284+
defer resp.Body.Close()
285+
286+
if resp.StatusCode != http.StatusBadRequest {
287+
t.Errorf("expected 400, got %d — callback accepted without expected state set", resp.StatusCode)
288+
}
289+
}

0 commit comments

Comments
 (0)