Skip to content

Commit c850a88

Browse files
committed
Refactor logic into separated CORS middleware and update related tests
1 parent 1284cb4 commit c850a88

File tree

5 files changed

+90
-74
lines changed

5 files changed

+90
-74
lines changed

pkg/http/handler.go

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,11 @@ package http
33
import (
44
"context"
55
"errors"
6-
"strings"
76
"log/slog"
87
"net/http"
98

109
ghcontext "github.com/github/github-mcp-server/pkg/context"
1110
"github.com/github/github-mcp-server/pkg/github"
12-
"github.com/github/github-mcp-server/pkg/http/headers"
1311
"github.com/github/github-mcp-server/pkg/http/middleware"
1412
"github.com/github/github-mcp-server/pkg/http/oauth"
1513
"github.com/github/github-mcp-server/pkg/inventory"
@@ -415,39 +413,3 @@ func PATScopeFilter(b *inventory.Builder, r *http.Request, fetcher scopes.Fetche
415413

416414
return b
417415
}
418-
419-
// corsAllowHeaders is the precomputed Access-Control-Allow-Headers value.
420-
var corsAllowHeaders = strings.Join([]string{
421-
"Content-Type",
422-
"Mcp-Session-Id",
423-
"Mcp-Protocol-Version",
424-
"Last-Event-ID",
425-
headers.AuthorizationHeader,
426-
headers.MCPReadOnlyHeader,
427-
headers.MCPToolsetsHeader,
428-
headers.MCPToolsHeader,
429-
headers.MCPExcludeToolsHeader,
430-
headers.MCPFeaturesHeader,
431-
headers.MCPLockdownHeader,
432-
headers.MCPInsidersHeader,
433-
}, ", ")
434-
435-
// SetCorsHeaders is middleware that sets CORS headers to allow browser-based
436-
// MCP clients to connect from any origin. This is safe because the server
437-
// authenticates via bearer tokens (not cookies), so cross-origin requests
438-
// cannot exploit ambient credentials.
439-
func SetCorsHeaders(h http.Handler) http.Handler {
440-
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
441-
w.Header().Set("Access-Control-Allow-Origin", "*")
442-
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
443-
w.Header().Set("Access-Control-Max-Age", "86400")
444-
w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id, WWW-Authenticate")
445-
w.Header().Set("Access-Control-Allow-Headers", corsAllowHeaders)
446-
447-
if r.Method == http.MethodOptions {
448-
w.WriteHeader(http.StatusOK)
449-
return
450-
}
451-
h.ServeHTTP(w, r)
452-
})
453-
}

pkg/http/handler_test.go

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -662,41 +662,6 @@ func buildStaticInventoryFromTools(cfg *ServerConfig, tools []inventory.ServerTo
662662
return inv.AvailableTools(ctx), inv.AvailableResourceTemplates(ctx), inv.AvailablePrompts(ctx)
663663
}
664664

665-
func TestSetCorsHeaders(t *testing.T) {
666-
inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
667-
w.WriteHeader(http.StatusOK)
668-
})
669-
handler := SetCorsHeaders(inner)
670-
671-
t.Run("OPTIONS preflight returns 200 with CORS headers", func(t *testing.T) {
672-
req := httptest.NewRequest(http.MethodOptions, "/", nil)
673-
req.Header.Set("Origin", "http://localhost:6274")
674-
rr := httptest.NewRecorder()
675-
handler.ServeHTTP(rr, req)
676-
677-
assert.Equal(t, http.StatusOK, rr.Code)
678-
assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin"))
679-
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Methods"), "POST")
680-
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Authorization")
681-
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Content-Type")
682-
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Mcp-Session-Id")
683-
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "X-MCP-Lockdown")
684-
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "X-MCP-Insiders")
685-
assert.Contains(t, rr.Header().Get("Access-Control-Expose-Headers"), "Mcp-Session-Id")
686-
assert.Contains(t, rr.Header().Get("Access-Control-Expose-Headers"), "WWW-Authenticate")
687-
})
688-
689-
t.Run("POST request includes CORS headers", func(t *testing.T) {
690-
req := httptest.NewRequest(http.MethodPost, "/", nil)
691-
req.Header.Set("Origin", "http://localhost:6274")
692-
rr := httptest.NewRecorder()
693-
handler.ServeHTTP(rr, req)
694-
695-
assert.Equal(t, http.StatusOK, rr.Code)
696-
assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin"))
697-
})
698-
}
699-
700665
func TestCrossOriginProtection(t *testing.T) {
701666
jsonRPCBody := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"test","version":"0.1"}}}`
702667

pkg/http/middleware/cors.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package middleware
2+
3+
import (
4+
"net/http"
5+
"strings"
6+
7+
"github.com/github/github-mcp-server/pkg/http/headers"
8+
)
9+
10+
// SetCorsHeaders is middleware that sets CORS headers to allow browser-based
11+
// MCP clients to connect from any origin. This is safe because the server
12+
// authenticates via bearer tokens (not cookies), so cross-origin requests
13+
// cannot exploit ambient credentials.
14+
func SetCorsHeaders(h http.Handler) http.Handler {
15+
allowHeaders := strings.Join([]string{
16+
"Content-Type",
17+
"Mcp-Session-Id",
18+
"Mcp-Protocol-Version",
19+
"Last-Event-ID",
20+
headers.AuthorizationHeader,
21+
headers.MCPReadOnlyHeader,
22+
headers.MCPToolsetsHeader,
23+
headers.MCPToolsHeader,
24+
headers.MCPExcludeToolsHeader,
25+
headers.MCPFeaturesHeader,
26+
headers.MCPLockdownHeader,
27+
headers.MCPInsidersHeader,
28+
}, ", ")
29+
30+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
31+
w.Header().Set("Access-Control-Allow-Origin", "*")
32+
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
33+
w.Header().Set("Access-Control-Max-Age", "86400")
34+
w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id, WWW-Authenticate")
35+
w.Header().Set("Access-Control-Allow-Headers", allowHeaders)
36+
37+
if r.Method == http.MethodOptions {
38+
w.WriteHeader(http.StatusOK)
39+
return
40+
}
41+
h.ServeHTTP(w, r)
42+
})
43+
}

pkg/http/middleware/cors_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package middleware_test
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/github/github-mcp-server/pkg/http/middleware"
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
func TestSetCorsHeaders(t *testing.T) {
13+
inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
14+
w.WriteHeader(http.StatusOK)
15+
})
16+
handler := middleware.SetCorsHeaders(inner)
17+
18+
t.Run("OPTIONS preflight returns 200 with CORS headers", func(t *testing.T) {
19+
req := httptest.NewRequest(http.MethodOptions, "/", nil)
20+
req.Header.Set("Origin", "http://localhost:6274")
21+
rr := httptest.NewRecorder()
22+
handler.ServeHTTP(rr, req)
23+
24+
assert.Equal(t, http.StatusOK, rr.Code)
25+
assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin"))
26+
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Methods"), "POST")
27+
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Authorization")
28+
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Content-Type")
29+
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Mcp-Session-Id")
30+
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "X-MCP-Lockdown")
31+
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "X-MCP-Insiders")
32+
assert.Contains(t, rr.Header().Get("Access-Control-Expose-Headers"), "Mcp-Session-Id")
33+
assert.Contains(t, rr.Header().Get("Access-Control-Expose-Headers"), "WWW-Authenticate")
34+
})
35+
36+
t.Run("POST request includes CORS headers", func(t *testing.T) {
37+
req := httptest.NewRequest(http.MethodPost, "/", nil)
38+
req.Header.Set("Origin", "http://localhost:6274")
39+
rr := httptest.NewRecorder()
40+
handler.ServeHTTP(rr, req)
41+
42+
assert.Equal(t, http.StatusOK, rr.Code)
43+
assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin"))
44+
})
45+
}

pkg/http/server.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313

1414
ghcontext "github.com/github/github-mcp-server/pkg/context"
1515
"github.com/github/github-mcp-server/pkg/github"
16+
"github.com/github/github-mcp-server/pkg/http/middleware"
1617
"github.com/github/github-mcp-server/pkg/http/oauth"
1718
"github.com/github/github-mcp-server/pkg/inventory"
1819
"github.com/github/github-mcp-server/pkg/lockdown"
@@ -180,7 +181,7 @@ func RunHTTPServer(cfg ServerConfig) error {
180181
}
181182

182183
r.Group(func(r chi.Router) {
183-
r.Use(SetCorsHeaders)
184+
r.Use(middleware.SetCorsHeaders)
184185

185186
// Register Middleware First, needs to be before route registration
186187
handler.RegisterMiddleware(r)

0 commit comments

Comments
 (0)