Skip to content

Commit e98016c

Browse files
committed
Add CORS headers and cross-origin protection tests
1 parent 6b28af5 commit e98016c

1 file changed

Lines changed: 158 additions & 0 deletions

File tree

pkg/http/handler_test.go

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"net/http/httptest"
88
"slices"
99
"sort"
10+
"strings"
1011
"testing"
1112

1213
ghcontext "github.com/github/github-mcp-server/pkg/context"
@@ -660,3 +661,160 @@ func buildStaticInventoryFromTools(cfg *ServerConfig, tools []inventory.ServerTo
660661
ctx := context.Background()
661662
return inv.AvailableTools(ctx), inv.AvailableResourceTemplates(ctx), inv.AvailablePrompts(ctx)
662663
}
664+
665+
func TestSetCorsHeaders(t *testing.T) {
666+
inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
667+
w.WriteHeader(http.StatusOK)
668+
})
669+
handler := SetCorsHeaders(inner)
670+
671+
t.Run("OPTIONS preflight returns 200 with CORS headers", func(t *testing.T) {
672+
req := httptest.NewRequest(http.MethodOptions, "/", nil)
673+
req.Header.Set("Origin", "http://localhost:6274")
674+
rr := httptest.NewRecorder()
675+
handler.ServeHTTP(rr, req)
676+
677+
assert.Equal(t, http.StatusOK, rr.Code)
678+
assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin"))
679+
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Methods"), "POST")
680+
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Authorization")
681+
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Content-Type")
682+
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Mcp-Session-Id")
683+
assert.Contains(t, rr.Header().Get("Access-Control-Expose-Headers"), "Mcp-Session-Id")
684+
})
685+
686+
t.Run("POST request includes CORS headers", func(t *testing.T) {
687+
req := httptest.NewRequest(http.MethodPost, "/", nil)
688+
req.Header.Set("Origin", "http://localhost:6274")
689+
rr := httptest.NewRecorder()
690+
handler.ServeHTTP(rr, req)
691+
692+
assert.Equal(t, http.StatusOK, rr.Code)
693+
assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin"))
694+
})
695+
}
696+
697+
func TestCrossOriginProtection(t *testing.T) {
698+
jsonRPCBody := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"test","version":"0.1"}}}`
699+
700+
newHandler := func(t *testing.T, crossOriginProtection *http.CrossOriginProtection) http.Handler {
701+
t.Helper()
702+
703+
apiHost, err := utils.NewAPIHost("https://api.githubcopilot.com")
704+
require.NoError(t, err)
705+
706+
handler := NewHTTPMcpHandler(
707+
context.Background(),
708+
&ServerConfig{
709+
Version: "test",
710+
CrossOriginProtection: crossOriginProtection,
711+
},
712+
nil,
713+
translations.NullTranslationHelper,
714+
slog.Default(),
715+
apiHost,
716+
WithInventoryFactory(func(_ *http.Request) (*inventory.Inventory, error) {
717+
return inventory.NewBuilder().Build()
718+
}),
719+
WithGitHubMCPServerFactory(func(_ *http.Request, _ github.ToolDependencies, _ *inventory.Inventory, _ *github.MCPServerConfig) (*mcp.Server, error) {
720+
return mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil), nil
721+
}),
722+
WithScopeFetcher(allScopesFetcher{}),
723+
)
724+
725+
r := chi.NewRouter()
726+
handler.RegisterMiddleware(r)
727+
handler.RegisterRoutes(r)
728+
return r
729+
}
730+
731+
tests := []struct {
732+
name string
733+
crossOriginProtection *http.CrossOriginProtection
734+
secFetchSite string
735+
origin string
736+
expectedStatusCode int
737+
}{
738+
{
739+
name: "SDK default rejects cross-site when no bypass configured",
740+
secFetchSite: "cross-site",
741+
origin: "https://evil.example.com",
742+
expectedStatusCode: http.StatusForbidden,
743+
},
744+
{
745+
name: "SDK default allows same-origin request",
746+
secFetchSite: "same-origin",
747+
expectedStatusCode: http.StatusOK,
748+
},
749+
{
750+
name: "SDK default allows request without Sec-Fetch-Site (native client)",
751+
secFetchSite: "",
752+
expectedStatusCode: http.StatusOK,
753+
},
754+
{
755+
name: "bypass protection allows cross-site request",
756+
crossOriginProtection: func() *http.CrossOriginProtection {
757+
p := http.NewCrossOriginProtection()
758+
p.AddInsecureBypassPattern("/")
759+
return p
760+
}(),
761+
secFetchSite: "cross-site",
762+
origin: "https://example.com",
763+
expectedStatusCode: http.StatusOK,
764+
},
765+
{
766+
name: "bypass protection still allows same-origin request",
767+
crossOriginProtection: func() *http.CrossOriginProtection {
768+
p := http.NewCrossOriginProtection()
769+
p.AddInsecureBypassPattern("/")
770+
return p
771+
}(),
772+
secFetchSite: "same-origin",
773+
expectedStatusCode: http.StatusOK,
774+
},
775+
{
776+
name: "bypass protection allows request without Sec-Fetch-Site (native client)",
777+
crossOriginProtection: func() *http.CrossOriginProtection {
778+
p := http.NewCrossOriginProtection()
779+
p.AddInsecureBypassPattern("/")
780+
return p
781+
}(),
782+
secFetchSite: "",
783+
expectedStatusCode: http.StatusOK,
784+
},
785+
{
786+
// Mirrors RunHTTPServer's auto-bypass: nil config → create bypass.
787+
name: "server default allows cross-site request (nil triggers auto-bypass)",
788+
crossOriginProtection: func() *http.CrossOriginProtection {
789+
p := http.NewCrossOriginProtection()
790+
p.AddInsecureBypassPattern("/")
791+
return p
792+
}(),
793+
secFetchSite: "cross-site",
794+
origin: "https://example.com",
795+
expectedStatusCode: http.StatusOK,
796+
},
797+
}
798+
799+
for _, tt := range tests {
800+
t.Run(tt.name, func(t *testing.T) {
801+
h := newHandler(t, tt.crossOriginProtection)
802+
803+
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(jsonRPCBody))
804+
req.Header.Set("Content-Type", "application/json")
805+
req.Header.Set("Accept", "application/json, text/event-stream")
806+
req.Header.Set(headers.AuthorizationHeader, "Bearer github_pat_xyz")
807+
if tt.secFetchSite != "" {
808+
req.Header.Set("Sec-Fetch-Site", tt.secFetchSite)
809+
}
810+
if tt.origin != "" {
811+
req.Header.Set("Origin", tt.origin)
812+
}
813+
814+
rr := httptest.NewRecorder()
815+
h.ServeHTTP(rr, req)
816+
817+
assert.Equal(t, tt.expectedStatusCode, rr.Code, "unexpected status code; body: %s", rr.Body.String())
818+
})
819+
}
820+
}

0 commit comments

Comments
 (0)