@@ -3,8 +3,13 @@ package mcp
33import (
44 "context"
55 "fmt"
6+ "io"
7+ "iter"
68 "net"
79 "net/http"
10+ "os"
11+ "os/exec"
12+ "sync"
813 "sync/atomic"
914 "testing"
1015 "time"
@@ -238,3 +243,250 @@ func TestRemoteReconnectRefreshesTools(t *testing.T) {
238243 assert .Contains (t , toolNames , "ns_shared" )
239244 assert .NotContains (t , toolNames , "ns_alpha" , "stale tool from old server should not be present" )
240245}
246+
247+ // failingInitClient is a mock mcpClient whose Initialize method returns a
248+ // configurable error for the first N calls, then succeeds.
249+ type failingInitClient struct {
250+ mu sync.Mutex
251+ initErr error // error to return from Initialize
252+ failsLeft int // how many more times Initialize should fail
253+ initCalls int // total Initialize calls
254+ waitCh chan struct {}
255+ toolsToList []* gomcp.Tool
256+ }
257+
258+ func (m * failingInitClient ) Initialize (_ context.Context , _ * gomcp.InitializeRequest ) (* gomcp.InitializeResult , error ) {
259+ m .mu .Lock ()
260+ defer m .mu .Unlock ()
261+ m .initCalls ++
262+ if m .failsLeft > 0 {
263+ m .failsLeft --
264+ return nil , m .initErr
265+ }
266+ if m .waitCh != nil {
267+ m .waitCh = make (chan struct {})
268+ }
269+ return & gomcp.InitializeResult {}, nil
270+ }
271+
272+ func (m * failingInitClient ) ListTools (_ context.Context , _ * gomcp.ListToolsParams ) iter.Seq2 [* gomcp.Tool , error ] {
273+ m .mu .Lock ()
274+ t := m .toolsToList
275+ m .mu .Unlock ()
276+ return func (yield func (* gomcp.Tool , error ) bool ) {
277+ for _ , tool := range t {
278+ if ! yield (tool , nil ) {
279+ return
280+ }
281+ }
282+ }
283+ }
284+
285+ func (m * failingInitClient ) CallTool (context.Context , * gomcp.CallToolParams ) (* gomcp.CallToolResult , error ) {
286+ return & gomcp.CallToolResult {Content : []gomcp.Content {& gomcp.TextContent {Text : "ok" }}}, nil
287+ }
288+
289+ func (m * failingInitClient ) ListPrompts (context.Context , * gomcp.ListPromptsParams ) iter.Seq2 [* gomcp.Prompt , error ] {
290+ return func (func (* gomcp.Prompt , error ) bool ) {}
291+ }
292+
293+ func (m * failingInitClient ) GetPrompt (context.Context , * gomcp.GetPromptParams ) (* gomcp.GetPromptResult , error ) {
294+ return & gomcp.GetPromptResult {}, nil
295+ }
296+
297+ func (m * failingInitClient ) SetElicitationHandler (tools.ElicitationHandler ) {}
298+ func (m * failingInitClient ) SetOAuthSuccessHandler (func ()) {}
299+ func (m * failingInitClient ) SetManagedOAuth (bool ) {}
300+ func (m * failingInitClient ) SetToolListChangedHandler (func ()) {}
301+ func (m * failingInitClient ) SetPromptListChangedHandler (func ()) {}
302+
303+ func (m * failingInitClient ) Wait () error {
304+ m .mu .Lock ()
305+ ch := m .waitCh
306+ m .mu .Unlock ()
307+ if ch == nil {
308+ select {}
309+ }
310+ <- ch
311+ return nil
312+ }
313+
314+ func (m * failingInitClient ) Close (context.Context ) error {
315+ m .mu .Lock ()
316+ if m .waitCh != nil {
317+ select {
318+ case <- m .waitCh :
319+ default :
320+ close (m .waitCh )
321+ }
322+ }
323+ m .mu .Unlock ()
324+ return nil
325+ }
326+
327+ // TestStdioStartReturnsErrorWhenServerUnavailable verifies that a stdio toolset
328+ // propagates errServerUnavailable when Initialize returns io.EOF, and that
329+ // started remains false so the runtime can retry.
330+ func TestStdioStartReturnsErrorWhenServerUnavailable (t * testing.T ) {
331+ t .Parallel ()
332+
333+ mock := & failingInitClient {
334+ initErr : io .EOF ,
335+ failsLeft : 1 ,
336+ }
337+
338+ ts := & Toolset {
339+ name : "test-stdio" ,
340+ mcpClient : mock ,
341+ logID : "test-cmd" ,
342+ }
343+
344+ err := ts .Start (t .Context ())
345+ require .Error (t , err )
346+ require .ErrorIs (t , err , errServerUnavailable )
347+
348+ ts .mu .Lock ()
349+ started := ts .started
350+ ts .mu .Unlock ()
351+ assert .False (t , started , "stdio toolset must not be marked as started when server is unavailable" )
352+ }
353+
354+ // TestStdioStartReturnsErrorWhenBinaryNotFound verifies that exec.ErrNotFound
355+ // from Initialize is treated the same as io.EOF for stdio toolsets.
356+ func TestStdioStartReturnsErrorWhenBinaryNotFound (t * testing.T ) {
357+ t .Parallel ()
358+
359+ mock := & failingInitClient {
360+ initErr : fmt .Errorf ("start command: %w" , exec .ErrNotFound ),
361+ failsLeft : 1 ,
362+ }
363+
364+ ts := & Toolset {
365+ name : "test-stdio" ,
366+ mcpClient : mock ,
367+ logID : "missing-binary" ,
368+ }
369+
370+ err := ts .Start (t .Context ())
371+ require .Error (t , err )
372+ require .ErrorIs (t , err , errServerUnavailable )
373+
374+ ts .mu .Lock ()
375+ started := ts .started
376+ ts .mu .Unlock ()
377+ assert .False (t , started , "stdio toolset must not be marked as started when binary is not found" )
378+ }
379+
380+ // TestStdioLazyRetrySucceedsWhenBinaryAppears verifies the end-to-end retry
381+ // scenario: turn 1 fails with EOF (binary not yet available), turn 2 succeeds
382+ // once the binary "appears" (mock stops failing).
383+ func TestStdioLazyRetrySucceedsWhenBinaryAppears (t * testing.T ) {
384+ t .Parallel ()
385+
386+ pingTool := & gomcp.Tool {Name : "ping" }
387+ mock := & failingInitClient {
388+ initErr : io .EOF ,
389+ failsLeft : 1 ,
390+ toolsToList : []* gomcp.Tool {pingTool },
391+ waitCh : make (chan struct {}),
392+ }
393+
394+ ts := & Toolset {
395+ name : "test-stdio" ,
396+ mcpClient : mock ,
397+ logID : "lazy-binary" ,
398+ }
399+
400+ // Turn 1: Start fails — binary not available yet.
401+ err := ts .Start (t .Context ())
402+ require .Error (t , err )
403+ require .ErrorIs (t , err , errServerUnavailable )
404+
405+ // Turn 2: Binary has "appeared" (mock will succeed).
406+ err = ts .Start (t .Context ())
407+ require .NoError (t , err )
408+
409+ ts .mu .Lock ()
410+ started := ts .started
411+ ts .mu .Unlock ()
412+ assert .True (t , started , "stdio toolset must be started after successful retry" )
413+
414+ toolList , err := ts .Tools (t .Context ())
415+ require .NoError (t , err )
416+ require .Len (t , toolList , 1 )
417+ assert .Equal (t , "test-stdio_ping" , toolList [0 ].Name )
418+
419+ _ = ts .Stop (t .Context ())
420+ }
421+
422+ // TestRemoteStartRetriesWhenUnavailable verifies that a remote toolset also
423+ // returns an error and stays un-started when the server is unavailable (EOF),
424+ // confirming retry-on-next-turn applies to all toolset types.
425+ func TestRemoteStartRetriesWhenUnavailable (t * testing.T ) {
426+ t .Parallel ()
427+
428+ mock := & failingInitClient {
429+ initErr : io .EOF ,
430+ failsLeft : 1 ,
431+ }
432+
433+ ts := & Toolset {
434+ name : "test-remote" ,
435+ mcpClient : mock ,
436+ logID : "remote-server" ,
437+ }
438+
439+ err := ts .Start (t .Context ())
440+ require .Error (t , err )
441+ require .ErrorIs (t , err , errServerUnavailable )
442+
443+ ts .mu .Lock ()
444+ started := ts .started
445+ ts .mu .Unlock ()
446+ assert .False (t , started , "remote toolset must not be marked as started when server is unavailable" )
447+ }
448+
449+ // TestStartableToolSetRetryAcrossTurns is a full integration test using
450+ // tools.NewStartable to wrap an MCP Toolset. It verifies that when a stdio
451+ // toolset fails N turns, the StartableToolSet keeps retrying and succeeds
452+ // on turn N+1.
453+ func TestStartableToolSetRetryAcrossTurns (t * testing.T ) {
454+ t .Parallel ()
455+
456+ const failTurns = 3
457+
458+ pingTool := & gomcp.Tool {Name : "ping" }
459+ mock := & failingInitClient {
460+ initErr : fmt .Errorf ("command not found: %w" , os .ErrNotExist ),
461+ failsLeft : failTurns ,
462+ toolsToList : []* gomcp.Tool {pingTool },
463+ waitCh : make (chan struct {}),
464+ }
465+
466+ mcpToolset := & Toolset {
467+ name : "retry-test" ,
468+ mcpClient : mock ,
469+ logID : "retry-binary" ,
470+ }
471+
472+ startable := tools .NewStartable (mcpToolset )
473+
474+ // Turns 1..N: Start fails, IsStarted stays false.
475+ for turn := 1 ; turn <= failTurns ; turn ++ {
476+ err := startable .Start (t .Context ())
477+ require .Error (t , err , "turn %d should fail" , turn )
478+ assert .False (t , startable .IsStarted (), "turn %d: should not be started" , turn )
479+ }
480+
481+ // Turn N+1: binary is now available, Start succeeds.
482+ err := startable .Start (t .Context ())
483+ require .NoError (t , err )
484+ assert .True (t , startable .IsStarted ())
485+
486+ toolList , err := mcpToolset .Tools (t .Context ())
487+ require .NoError (t , err )
488+ require .Len (t , toolList , 1 )
489+ assert .Equal (t , "retry-test_ping" , toolList [0 ].Name )
490+
491+ _ = startable .Stop (t .Context ())
492+ }
0 commit comments