Skip to content

Commit 0c2bf5d

Browse files
authored
Merge pull request #2212 from dgageot/mcp-reconnect
fix: recover from ErrSessionMissing when remote MCP server restarts
2 parents e2610a1 + 02c372e commit 0c2bf5d

3 files changed

Lines changed: 263 additions & 0 deletions

File tree

pkg/tools/mcp/mcp.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ type Toolset struct {
5858
// toolsChangedHandler is called after the tool cache is refreshed
5959
// following a ToolListChanged notification from the server.
6060
toolsChangedHandler func()
61+
62+
// restarted is closed and replaced whenever the connection is
63+
// successfully restarted by watchConnection, allowing callers
64+
// waiting on a reconnect to be unblocked.
65+
restarted chan struct{}
6166
}
6267

6368
// invalidateCache clears the cached tools and prompts and bumps the
@@ -68,6 +73,10 @@ func (ts *Toolset) invalidateCache() {
6873
ts.cacheGen++
6974
}
7075

76+
// sessionMissingRetryTimeout is the maximum time to wait for watchConnection
77+
// to restart the MCP server after an ErrSessionMissing error.
78+
const sessionMissingRetryTimeout = 35 * time.Second
79+
7180
var (
7281
_ tools.ToolSet = (*Toolset)(nil)
7382
_ tools.Describer = (*Toolset)(nil)
@@ -145,6 +154,8 @@ func (ts *Toolset) Start(ctx context.Context) error {
145154
return nil
146155
}
147156

157+
ts.restarted = make(chan struct{})
158+
148159
if err := ts.doStart(ctx); err != nil {
149160
if errors.Is(err, errServerUnavailable) {
150161
// The server is unreachable but the error is non-fatal.
@@ -307,6 +318,9 @@ func (ts *Toolset) tryRestart(ctx context.Context) bool {
307318
}
308319

309320
ts.started = true
321+
// Signal anyone waiting for a reconnect.
322+
close(ts.restarted)
323+
ts.restarted = make(chan struct{})
310324
ts.mu.Unlock()
311325

312326
slog.Info("MCP server restarted successfully", "server", ts.logID)
@@ -438,6 +452,16 @@ func (ts *Toolset) callTool(ctx context.Context, toolCall tools.ToolCall) (*tool
438452
request.Arguments = args
439453

440454
resp, err := ts.mcpClient.CallTool(ctx, request)
455+
456+
// If the server lost our session (e.g. it restarted), force a
457+
// reconnection and retry the call once.
458+
if errors.Is(err, mcp.ErrSessionMissing) {
459+
slog.Warn("MCP session missing, forcing reconnect and retrying", "tool", toolCall.Function.Name, "server", ts.logID)
460+
if waitErr := ts.forceReconnectAndWait(ctx); waitErr != nil {
461+
return nil, fmt.Errorf("failed to reconnect after session loss: %w", waitErr)
462+
}
463+
resp, err = ts.mcpClient.CallTool(ctx, request)
464+
}
441465
if err != nil {
442466
if errors.Is(err, context.Canceled) || errors.Is(ctx.Err(), context.Canceled) {
443467
slog.Debug("CallTool canceled by context", "tool", toolCall.Function.Name)
@@ -453,6 +477,33 @@ func (ts *Toolset) callTool(ctx context.Context, toolCall tools.ToolCall) (*tool
453477
return result, nil
454478
}
455479

480+
// forceReconnectAndWait closes the current session to trigger watchConnection's
481+
// restart logic, then waits for the reconnection to complete.
482+
func (ts *Toolset) forceReconnectAndWait(ctx context.Context) error {
483+
ts.mu.Lock()
484+
restartCh := ts.restarted
485+
alreadyRestarting := !ts.started
486+
ts.mu.Unlock()
487+
488+
if !alreadyRestarting {
489+
// Force-close the session so that Wait() returns and watchConnection
490+
// kicks in with its restart loop. Skip this if watchConnection has
491+
// already detected the disconnect (started==false) to avoid killing
492+
// a connection that tryRestart may be establishing concurrently.
493+
_ = ts.mcpClient.Close(context.WithoutCancel(ctx))
494+
}
495+
496+
// Wait for watchConnection to complete a successful restart.
497+
select {
498+
case <-restartCh:
499+
return nil
500+
case <-ctx.Done():
501+
return ctx.Err()
502+
case <-time.After(sessionMissingRetryTimeout):
503+
return errors.New("timed out waiting for MCP server reconnection")
504+
}
505+
}
506+
456507
func (ts *Toolset) Stop(ctx context.Context) error {
457508
slog.Debug("Stopping MCP toolset", "server", ts.logID)
458509

pkg/tools/mcp/mcp_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ package mcp
22

33
import (
44
"context"
5+
"fmt"
56
"iter"
7+
"sync"
8+
"sync/atomic"
69
"testing"
710

811
"github.com/modelcontextprotocol/go-sdk/mcp"
@@ -51,6 +54,47 @@ func (m *mockMCPClient) Wait() error { return nil }
5154

5255
func (m *mockMCPClient) Close(context.Context) error { return nil }
5356

57+
// reconnectableMockClient extends mockMCPClient with reconnect simulation.
58+
type reconnectableMockClient struct {
59+
mockMCPClient
60+
61+
mu sync.Mutex
62+
waitCh chan struct{} // closed when Close is called, unblocking Wait
63+
}
64+
65+
func newReconnectableMock() *reconnectableMockClient {
66+
return &reconnectableMockClient{
67+
waitCh: make(chan struct{}),
68+
}
69+
}
70+
71+
func (m *reconnectableMockClient) Initialize(context.Context, *mcp.InitializeRequest) (*mcp.InitializeResult, error) {
72+
m.mu.Lock()
73+
m.waitCh = make(chan struct{}) // fresh channel for each session
74+
m.mu.Unlock()
75+
return &mcp.InitializeResult{}, nil
76+
}
77+
78+
func (m *reconnectableMockClient) Wait() error {
79+
m.mu.Lock()
80+
ch := m.waitCh
81+
m.mu.Unlock()
82+
<-ch
83+
return nil
84+
}
85+
86+
func (m *reconnectableMockClient) Close(context.Context) error {
87+
m.mu.Lock()
88+
// Close the wait channel to unblock Wait().
89+
select {
90+
case <-m.waitCh:
91+
default:
92+
close(m.waitCh)
93+
}
94+
m.mu.Unlock()
95+
return nil
96+
}
97+
5498
func TestCallToolStripsNullArguments(t *testing.T) {
5599
t.Parallel()
56100

@@ -251,3 +295,46 @@ func TestProcessMCPContent(t *testing.T) {
251295
// callToolResult wraps the given content items in a CallToolResult.
func callToolResult(content ...mcp.Content) *mcp.CallToolResult {
	result := &mcp.CallToolResult{}
	result.Content = content
	return result
}
298+
299+
func TestCallToolRecoversFromErrSessionMissing(t *testing.T) {
300+
t.Parallel()
301+
302+
var callCount atomic.Int32
303+
304+
mock := newReconnectableMock()
305+
mock.callToolFn = func(_ context.Context, _ *mcp.CallToolParams) (*mcp.CallToolResult, error) {
306+
n := callCount.Add(1)
307+
if n == 1 {
308+
// First call: simulate server restart by returning ErrSessionMissing.
309+
return nil, fmt.Errorf("tools/call: %w", mcp.ErrSessionMissing)
310+
}
311+
// Second call (after reconnect): succeed.
312+
return &mcp.CallToolResult{
313+
Content: []mcp.Content{&mcp.TextContent{Text: "recovered"}},
314+
}, nil
315+
}
316+
317+
ts := &Toolset{
318+
started: true,
319+
mcpClient: mock,
320+
logID: "test-server",
321+
restarted: make(chan struct{}),
322+
}
323+
324+
// Start the watchConnection goroutine as Start() would.
325+
go ts.watchConnection(t.Context())
326+
327+
result, err := ts.callTool(t.Context(), tools.ToolCall{
328+
Function: tools.FunctionCall{
329+
Name: "test_tool",
330+
Arguments: `{"key": "value"}`,
331+
},
332+
})
333+
334+
require.NoError(t, err)
335+
assert.Equal(t, "recovered", result.Output)
336+
assert.Equal(t, int32(2), callCount.Load(), "expected exactly 2 CallTool invocations (1 failed + 1 retry)")
337+
338+
// Clean up: stop the watcher.
339+
_ = ts.Stop(t.Context())
340+
}

pkg/tools/mcp/reconnect_test.go

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
package mcp
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net"
7+
"net/http"
8+
"sync/atomic"
9+
"testing"
10+
"time"
11+
12+
"github.com/google/jsonschema-go/jsonschema"
13+
gomcp "github.com/modelcontextprotocol/go-sdk/mcp"
14+
"github.com/stretchr/testify/assert"
15+
"github.com/stretchr/testify/require"
16+
17+
"github.com/docker/docker-agent/pkg/tools"
18+
)
19+
20+
// TestRemoteReconnectAfterServerRestart verifies that a Toolset backed by a
21+
// real remote (streamable-HTTP) MCP server transparently recovers when the
22+
// server is restarted.
23+
//
24+
// The scenario:
25+
// 1. Start a minimal MCP server with a "ping" tool.
26+
// 2. Connect a Toolset, call "ping" — succeeds.
27+
// 3. Shut down the server (simulates crash / restart).
28+
// 4. Start a **new** server on the same address.
29+
// 5. Call "ping" again — this must succeed after automatic reconnection.
30+
//
31+
// Without the ErrSessionMissing recovery logic the second call would fail
32+
// because the new server does not know the old session ID.
33+
func TestRemoteReconnectAfterServerRestart(t *testing.T) {
34+
t.Parallel()
35+
36+
// Use a fixed listener address so we can restart on the same port.
37+
ln, err := net.Listen("tcp", "127.0.0.1:0")
38+
require.NoError(t, err)
39+
addr := ln.Addr().String()
40+
ln.Close() // We only needed the address; close so startServer can bind it.
41+
42+
var callCount atomic.Int32
43+
44+
// startServer creates a minimal MCP server on addr with a "ping" tool
45+
// and returns a function to shut it down.
46+
startServer := func(t *testing.T) (shutdown func()) {
47+
t.Helper()
48+
49+
s := gomcp.NewServer(&gomcp.Implementation{Name: "test-server", Version: "1.0.0"}, nil)
50+
s.AddTool(&gomcp.Tool{
51+
Name: "ping",
52+
InputSchema: &jsonschema.Schema{Type: "object"},
53+
}, func(_ context.Context, _ *gomcp.CallToolRequest) (*gomcp.CallToolResult, error) {
54+
n := callCount.Add(1)
55+
return &gomcp.CallToolResult{
56+
Content: []gomcp.Content{&gomcp.TextContent{Text: fmt.Sprintf("pong-%d", n)}},
57+
}, nil
58+
})
59+
60+
// Retry Listen until the port is available (e.g. after a server shutdown).
61+
var srvLn net.Listener
62+
require.Eventually(t, func() bool {
63+
var listenErr error
64+
srvLn, listenErr = net.Listen("tcp", addr)
65+
return listenErr == nil
66+
}, 2*time.Second, 50*time.Millisecond, "port %s not available in time", addr)
67+
68+
srv := &http.Server{
69+
Handler: gomcp.NewStreamableHTTPHandler(func(*http.Request) *gomcp.Server { return s }, nil),
70+
}
71+
go func() { _ = srv.Serve(srvLn) }()
72+
73+
return func() { _ = srv.Close() }
74+
}
75+
76+
callPing := func(t *testing.T, ts *Toolset) string {
77+
t.Helper()
78+
result, callErr := ts.callTool(t.Context(), tools.ToolCall{
79+
Function: tools.FunctionCall{Name: "ping", Arguments: "{}"},
80+
})
81+
require.NoError(t, callErr)
82+
return result.Output
83+
}
84+
85+
// --- Step 1–2: Start first server, connect toolset ---
86+
shutdown1 := startServer(t)
87+
88+
ts := NewRemoteToolset("test", fmt.Sprintf("http://%s/mcp", addr), "streamable-http", nil)
89+
require.NoError(t, ts.Start(t.Context()))
90+
91+
toolList, err := ts.Tools(t.Context())
92+
require.NoError(t, err)
93+
require.Len(t, toolList, 1)
94+
assert.Equal(t, "test_ping", toolList[0].Name)
95+
96+
// --- Step 3: Call succeeds on original server ---
97+
assert.Equal(t, "pong-1", callPing(t, ts))
98+
99+
// --- Step 4: Shut down the server ---
100+
shutdown1()
101+
102+
// Capture the current restarted channel before the reconnect
103+
ts.mu.Lock()
104+
restartedCh := ts.restarted
105+
ts.mu.Unlock()
106+
107+
// --- Step 5–6: Start a fresh server, call again ---
108+
shutdown2 := startServer(t)
109+
t.Cleanup(func() {
110+
_ = ts.Stop(t.Context())
111+
shutdown2()
112+
})
113+
114+
// This call triggers ErrSessionMissing recovery and must succeed transparently.
115+
assert.Equal(t, "pong-2", callPing(t, ts))
116+
117+
// Verify that watchConnection actually restarted the connection by checking
118+
// that the restarted channel was closed (signaling reconnect completion).
119+
select {
120+
case <-restartedCh:
121+
// Success: the channel was closed, meaning reconnect happened
122+
case <-time.After(100 * time.Millisecond):
123+
t.Fatal("reconnect did not complete: restarted channel was not closed")
124+
}
125+
}

0 commit comments

Comments
 (0)