diff --git a/pkg/http/handler.go b/pkg/http/handler.go index d55d7c53d..3b182f46b 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -3,11 +3,13 @@ package http import ( "context" "errors" + "fmt" "log/slog" "net/http" ghcontext "github.com/github/github-mcp-server/pkg/context" "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/http/headers" "github.com/github/github-mcp-server/pkg/http/middleware" "github.com/github/github-mcp-server/pkg/http/oauth" "github.com/github/github-mcp-server/pkg/inventory" @@ -226,7 +228,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { mcpHandler := mcp.NewStreamableHTTPHandler(func(_ *http.Request) *mcp.Server { return ghServer }, &mcp.StreamableHTTPOptions{ - Stateless: true, + Stateless: true, + CrossOriginProtection: h.config.CrossOriginProtection, }) mcpHandler.ServeHTTP(w, r) @@ -412,3 +415,31 @@ func PATScopeFilter(b *inventory.Builder, r *http.Request, fetcher scopes.Fetche return b } + +// SetCorsHeaders is middleware that sets CORS headers to allow browser-based +// MCP clients to connect from any origin. This is safe because the server +// authenticates via bearer tokens (not cookies), so cross-origin requests +// cannot exploit ambient credentials. +func SetCorsHeaders(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") + w.Header().Set("Access-Control-Max-Age", "86400") + w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id") + w.Header().Set("Access-Control-Allow-Headers", fmt.Sprintf( + "Content-Type, Authorization, Mcp-Session-Id, Mcp-Protocol-Version, Last-Event-ID, %s, %s, %s, %s, %s, %s", + headers.MCPReadOnlyHeader, + headers.MCPToolsetsHeader, + headers.MCPToolsHeader, + headers.MCPExcludeToolsHeader, + headers.MCPFeaturesHeader, + headers.AuthorizationHeader, + )) + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return + } + h.ServeHTTP(w, r) + }) +} diff --git a/pkg/http/server.go b/pkg/http/server.go index d1e8192ba..703115e38 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -86,6 +86,10 @@ type ServerConfig struct { // InsidersMode indicates if we should enable experimental features. InsidersMode bool + + // CrossOriginProtection configures the SDK's cross-origin request protection. + // If nil, the SDK default (reject cross-origin POSTs) is used. + CrossOriginProtection *http.CrossOriginProtection } func RunHTTPServer(cfg ServerConfig) error { @@ -159,6 +163,14 @@ func RunHTTPServer(cfg ServerConfig) error { serverOptions = append(serverOptions, WithScopeFetcher(scopeFetcher)) } + // Bypass cross-origin protection: this server uses bearer tokens, not + // cookies, so CSRF checks are unnecessary. + if cfg.CrossOriginProtection == nil { + p := http.NewCrossOriginProtection() + p.AddInsecureBypassPattern("/") + cfg.CrossOriginProtection = p + } + r := chi.NewRouter() handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, apiHost, append(serverOptions, WithFeatureChecker(featureChecker), WithOAuthConfig(oauthCfg))...) oauthHandler, err := oauth.NewAuthHandler(oauthCfg, apiHost) @@ -167,6 +179,8 @@ func RunHTTPServer(cfg ServerConfig) error { } r.Group(func(r chi.Router) { + r.Use(SetCorsHeaders) + // Register Middleware First, needs to be before route registration handler.RegisterMiddleware(r)