|
| 1 | +package mcp |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "fmt" |
| 6 | + "net" |
| 7 | + "net/http" |
| 8 | + "sync/atomic" |
| 9 | + "testing" |
| 10 | + "time" |
| 11 | + |
| 12 | + "github.com/google/jsonschema-go/jsonschema" |
| 13 | + gomcp "github.com/modelcontextprotocol/go-sdk/mcp" |
| 14 | + "github.com/stretchr/testify/assert" |
| 15 | + "github.com/stretchr/testify/require" |
| 16 | + |
| 17 | + "github.com/docker/docker-agent/pkg/tools" |
| 18 | +) |
| 19 | + |
| 20 | +// TestRemoteReconnectAfterServerRestart verifies that a Toolset backed by a |
| 21 | +// real remote (streamable-HTTP) MCP server transparently recovers when the |
| 22 | +// server is restarted. |
| 23 | +// |
| 24 | +// The scenario: |
| 25 | +// 1. Start a minimal MCP server with a "ping" tool. |
| 26 | +// 2. Connect a Toolset, call "ping" — succeeds. |
| 27 | +// 3. Shut down the server (simulates crash / restart). |
| 28 | +// 4. Start a **new** server on the same address. |
| 29 | +// 5. Call "ping" again — this must succeed after automatic reconnection. |
| 30 | +// |
| 31 | +// Without the ErrSessionMissing recovery logic the second call would fail |
| 32 | +// because the new server does not know the old session ID. |
| 33 | +func TestRemoteReconnectAfterServerRestart(t *testing.T) { |
| 34 | + t.Parallel() |
| 35 | + |
| 36 | + // Use a fixed listener address so we can restart on the same port. |
| 37 | + ln, err := net.Listen("tcp", "127.0.0.1:0") |
| 38 | + require.NoError(t, err) |
| 39 | + addr := ln.Addr().String() |
| 40 | + ln.Close() // We only needed the address; close so startServer can bind it. |
| 41 | + |
| 42 | + var callCount atomic.Int32 |
| 43 | + |
| 44 | + // startServer creates a minimal MCP server on addr with a "ping" tool |
| 45 | + // and returns a function to shut it down. |
| 46 | + startServer := func(t *testing.T) (shutdown func()) { |
| 47 | + t.Helper() |
| 48 | + |
| 49 | + s := gomcp.NewServer(&gomcp.Implementation{Name: "test-server", Version: "1.0.0"}, nil) |
| 50 | + s.AddTool(&gomcp.Tool{ |
| 51 | + Name: "ping", |
| 52 | + InputSchema: &jsonschema.Schema{Type: "object"}, |
| 53 | + }, func(_ context.Context, _ *gomcp.CallToolRequest) (*gomcp.CallToolResult, error) { |
| 54 | + n := callCount.Add(1) |
| 55 | + return &gomcp.CallToolResult{ |
| 56 | + Content: []gomcp.Content{&gomcp.TextContent{Text: fmt.Sprintf("pong-%d", n)}}, |
| 57 | + }, nil |
| 58 | + }) |
| 59 | + |
| 60 | + // Retry Listen until the port is available (e.g. after a server shutdown). |
| 61 | + var srvLn net.Listener |
| 62 | + require.Eventually(t, func() bool { |
| 63 | + var listenErr error |
| 64 | + srvLn, listenErr = net.Listen("tcp", addr) |
| 65 | + return listenErr == nil |
| 66 | + }, 2*time.Second, 50*time.Millisecond, "port %s not available in time", addr) |
| 67 | + |
| 68 | + srv := &http.Server{ |
| 69 | + Handler: gomcp.NewStreamableHTTPHandler(func(*http.Request) *gomcp.Server { return s }, nil), |
| 70 | + } |
| 71 | + go func() { _ = srv.Serve(srvLn) }() |
| 72 | + |
| 73 | + return func() { _ = srv.Close() } |
| 74 | + } |
| 75 | + |
| 76 | + callPing := func(t *testing.T, ts *Toolset) string { |
| 77 | + t.Helper() |
| 78 | + result, callErr := ts.callTool(t.Context(), tools.ToolCall{ |
| 79 | + Function: tools.FunctionCall{Name: "ping", Arguments: "{}"}, |
| 80 | + }) |
| 81 | + require.NoError(t, callErr) |
| 82 | + return result.Output |
| 83 | + } |
| 84 | + |
| 85 | + // --- Step 1–2: Start first server, connect toolset --- |
| 86 | + shutdown1 := startServer(t) |
| 87 | + |
| 88 | + ts := NewRemoteToolset("test", fmt.Sprintf("http://%s/mcp", addr), "streamable-http", nil) |
| 89 | + require.NoError(t, ts.Start(t.Context())) |
| 90 | + |
| 91 | + toolList, err := ts.Tools(t.Context()) |
| 92 | + require.NoError(t, err) |
| 93 | + require.Len(t, toolList, 1) |
| 94 | + assert.Equal(t, "test_ping", toolList[0].Name) |
| 95 | + |
| 96 | + // --- Step 3: Call succeeds on original server --- |
| 97 | + assert.Equal(t, "pong-1", callPing(t, ts)) |
| 98 | + |
| 99 | + // --- Step 4: Shut down the server --- |
| 100 | + shutdown1() |
| 101 | + |
| 102 | + // Capture the current restarted channel before the reconnect |
| 103 | + ts.mu.Lock() |
| 104 | + restartedCh := ts.restarted |
| 105 | + ts.mu.Unlock() |
| 106 | + |
| 107 | + // --- Step 5–6: Start a fresh server, call again --- |
| 108 | + shutdown2 := startServer(t) |
| 109 | + t.Cleanup(func() { |
| 110 | + _ = ts.Stop(t.Context()) |
| 111 | + shutdown2() |
| 112 | + }) |
| 113 | + |
| 114 | + // This call triggers ErrSessionMissing recovery and must succeed transparently. |
| 115 | + assert.Equal(t, "pong-2", callPing(t, ts)) |
| 116 | + |
| 117 | + // Verify that watchConnection actually restarted the connection by checking |
| 118 | + // that the restarted channel was closed (signaling reconnect completion). |
| 119 | + select { |
| 120 | + case <-restartedCh: |
| 121 | + // Success: the channel was closed, meaning reconnect happened |
| 122 | + case <-time.After(100 * time.Millisecond): |
| 123 | + t.Fatal("reconnect did not complete: restarted channel was not closed") |
| 124 | + } |
| 125 | +} |
0 commit comments