|
7 | 7 | "net/http/httptest" |
8 | 8 | "slices" |
9 | 9 | "sort" |
| 10 | + "strings" |
10 | 11 | "testing" |
11 | 12 |
|
12 | 13 | ghcontext "github.com/github/github-mcp-server/pkg/context" |
@@ -660,3 +661,160 @@ func buildStaticInventoryFromTools(cfg *ServerConfig, tools []inventory.ServerTo |
660 | 661 | ctx := context.Background() |
661 | 662 | return inv.AvailableTools(ctx), inv.AvailableResourceTemplates(ctx), inv.AvailablePrompts(ctx) |
662 | 663 | } |
| 664 | + |
| 665 | +func TestSetCorsHeaders(t *testing.T) { |
| 666 | + inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { |
| 667 | + w.WriteHeader(http.StatusOK) |
| 668 | + }) |
| 669 | + handler := SetCorsHeaders(inner) |
| 670 | + |
| 671 | + t.Run("OPTIONS preflight returns 200 with CORS headers", func(t *testing.T) { |
| 672 | + req := httptest.NewRequest(http.MethodOptions, "/", nil) |
| 673 | + req.Header.Set("Origin", "http://localhost:6274") |
| 674 | + rr := httptest.NewRecorder() |
| 675 | + handler.ServeHTTP(rr, req) |
| 676 | + |
| 677 | + assert.Equal(t, http.StatusOK, rr.Code) |
| 678 | + assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin")) |
| 679 | + assert.Contains(t, rr.Header().Get("Access-Control-Allow-Methods"), "POST") |
| 680 | + assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Authorization") |
| 681 | + assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Content-Type") |
| 682 | + assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Mcp-Session-Id") |
| 683 | + assert.Contains(t, rr.Header().Get("Access-Control-Expose-Headers"), "Mcp-Session-Id") |
| 684 | + }) |
| 685 | + |
| 686 | + t.Run("POST request includes CORS headers", func(t *testing.T) { |
| 687 | + req := httptest.NewRequest(http.MethodPost, "/", nil) |
| 688 | + req.Header.Set("Origin", "http://localhost:6274") |
| 689 | + rr := httptest.NewRecorder() |
| 690 | + handler.ServeHTTP(rr, req) |
| 691 | + |
| 692 | + assert.Equal(t, http.StatusOK, rr.Code) |
| 693 | + assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin")) |
| 694 | + }) |
| 695 | +} |
| 696 | + |
| 697 | +func TestCrossOriginProtection(t *testing.T) { |
| 698 | + jsonRPCBody := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"test","version":"0.1"}}}` |
| 699 | + |
| 700 | + newHandler := func(t *testing.T, crossOriginProtection *http.CrossOriginProtection) http.Handler { |
| 701 | + t.Helper() |
| 702 | + |
| 703 | + apiHost, err := utils.NewAPIHost("https://api.githubcopilot.com") |
| 704 | + require.NoError(t, err) |
| 705 | + |
| 706 | + handler := NewHTTPMcpHandler( |
| 707 | + context.Background(), |
| 708 | + &ServerConfig{ |
| 709 | + Version: "test", |
| 710 | + CrossOriginProtection: crossOriginProtection, |
| 711 | + }, |
| 712 | + nil, |
| 713 | + translations.NullTranslationHelper, |
| 714 | + slog.Default(), |
| 715 | + apiHost, |
| 716 | + WithInventoryFactory(func(_ *http.Request) (*inventory.Inventory, error) { |
| 717 | + return inventory.NewBuilder().Build() |
| 718 | + }), |
| 719 | + WithGitHubMCPServerFactory(func(_ *http.Request, _ github.ToolDependencies, _ *inventory.Inventory, _ *github.MCPServerConfig) (*mcp.Server, error) { |
| 720 | + return mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil), nil |
| 721 | + }), |
| 722 | + WithScopeFetcher(allScopesFetcher{}), |
| 723 | + ) |
| 724 | + |
| 725 | + r := chi.NewRouter() |
| 726 | + handler.RegisterMiddleware(r) |
| 727 | + handler.RegisterRoutes(r) |
| 728 | + return r |
| 729 | + } |
| 730 | + |
| 731 | + tests := []struct { |
| 732 | + name string |
| 733 | + crossOriginProtection *http.CrossOriginProtection |
| 734 | + secFetchSite string |
| 735 | + origin string |
| 736 | + expectedStatusCode int |
| 737 | + }{ |
| 738 | + { |
| 739 | + name: "SDK default rejects cross-site when no bypass configured", |
| 740 | + secFetchSite: "cross-site", |
| 741 | + origin: "https://evil.example.com", |
| 742 | + expectedStatusCode: http.StatusForbidden, |
| 743 | + }, |
| 744 | + { |
| 745 | + name: "SDK default allows same-origin request", |
| 746 | + secFetchSite: "same-origin", |
| 747 | + expectedStatusCode: http.StatusOK, |
| 748 | + }, |
| 749 | + { |
| 750 | + name: "SDK default allows request without Sec-Fetch-Site (native client)", |
| 751 | + secFetchSite: "", |
| 752 | + expectedStatusCode: http.StatusOK, |
| 753 | + }, |
| 754 | + { |
| 755 | + name: "bypass protection allows cross-site request", |
| 756 | + crossOriginProtection: func() *http.CrossOriginProtection { |
| 757 | + p := http.NewCrossOriginProtection() |
| 758 | + p.AddInsecureBypassPattern("/") |
| 759 | + return p |
| 760 | + }(), |
| 761 | + secFetchSite: "cross-site", |
| 762 | + origin: "https://example.com", |
| 763 | + expectedStatusCode: http.StatusOK, |
| 764 | + }, |
| 765 | + { |
| 766 | + name: "bypass protection still allows same-origin request", |
| 767 | + crossOriginProtection: func() *http.CrossOriginProtection { |
| 768 | + p := http.NewCrossOriginProtection() |
| 769 | + p.AddInsecureBypassPattern("/") |
| 770 | + return p |
| 771 | + }(), |
| 772 | + secFetchSite: "same-origin", |
| 773 | + expectedStatusCode: http.StatusOK, |
| 774 | + }, |
| 775 | + { |
| 776 | + name: "bypass protection allows request without Sec-Fetch-Site (native client)", |
| 777 | + crossOriginProtection: func() *http.CrossOriginProtection { |
| 778 | + p := http.NewCrossOriginProtection() |
| 779 | + p.AddInsecureBypassPattern("/") |
| 780 | + return p |
| 781 | + }(), |
| 782 | + secFetchSite: "", |
| 783 | + expectedStatusCode: http.StatusOK, |
| 784 | + }, |
| 785 | + { |
| 786 | + // Mirrors RunHTTPServer's auto-bypass: nil config → create bypass. |
| 787 | + name: "server default allows cross-site request (nil triggers auto-bypass)", |
| 788 | + crossOriginProtection: func() *http.CrossOriginProtection { |
| 789 | + p := http.NewCrossOriginProtection() |
| 790 | + p.AddInsecureBypassPattern("/") |
| 791 | + return p |
| 792 | + }(), |
| 793 | + secFetchSite: "cross-site", |
| 794 | + origin: "https://example.com", |
| 795 | + expectedStatusCode: http.StatusOK, |
| 796 | + }, |
| 797 | + } |
| 798 | + |
| 799 | + for _, tt := range tests { |
| 800 | + t.Run(tt.name, func(t *testing.T) { |
| 801 | + h := newHandler(t, tt.crossOriginProtection) |
| 802 | + |
| 803 | + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(jsonRPCBody)) |
| 804 | + req.Header.Set("Content-Type", "application/json") |
| 805 | + req.Header.Set("Accept", "application/json, text/event-stream") |
| 806 | + req.Header.Set(headers.AuthorizationHeader, "Bearer github_pat_xyz") |
| 807 | + if tt.secFetchSite != "" { |
| 808 | + req.Header.Set("Sec-Fetch-Site", tt.secFetchSite) |
| 809 | + } |
| 810 | + if tt.origin != "" { |
| 811 | + req.Header.Set("Origin", tt.origin) |
| 812 | + } |
| 813 | + |
| 814 | + rr := httptest.NewRecorder() |
| 815 | + h.ServeHTTP(rr, req) |
| 816 | + |
| 817 | + assert.Equal(t, tt.expectedStatusCode, rr.Code, "unexpected status code; body: %s", rr.Body.String()) |
| 818 | + }) |
| 819 | + } |
| 820 | +} |
0 commit comments