Skip to content

Commit 8e0348b

Browse files
committed
fix: serialize concurrent RunSession calls to prevent tool_use/tool_result mismatch
When two HTTP requests target the same session concurrently, the second can inject user messages while the first is mid-tool-call, producing a tool_use block without a matching tool_result and causing Anthropic API errors.

Add a per-session streaming mutex to activeRuntimes. RunSession uses TryLock to fail fast with ErrSessionBusy (HTTP 409) when the session is already streaming. Message addition is deferred until after the lock is acquired, so a rejected request never mutates the session. TryLock is called on the calling goroutine inside RunSession; Unlock is deferred in the background goroutine after RunStream completes, so the lock is held continuously from before message addition through the entire stream, including all tool-call processing.

Assisted-By: docker-agent
1 parent b355a29 commit 8e0348b

3 files changed

Lines changed: 263 additions & 15 deletions

File tree

pkg/server/server.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"cmp"
55
"context"
66
"encoding/json"
7+
"errors"
78
"fmt"
89
"log/slog"
910
"net"
@@ -285,6 +286,9 @@ func (s *Server) runAgent(c echo.Context) error {
285286

286287
streamChan, err := s.sm.RunSession(c.Request().Context(), sessionID, agentFilename, currentAgent, messages)
287288
if err != nil {
289+
if errors.Is(err, ErrSessionBusy) {
290+
return echo.NewHTTPError(http.StatusConflict, err.Error())
291+
}
288292
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to run session: %v", err))
289293
}
290294

pkg/server/session_manager.go

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ type activeRuntimes struct {
2727
cancel context.CancelFunc
2828
session *session.Session // The actual session object used by the runtime
2929
titleGen *sessiontitle.Generator // Title generator (includes fallback models)
30+
31+
streaming sync.Mutex // Held while a RunStream is in progress; serialises concurrent requests
3032
}
3133

3234
// SessionManager manages sessions for HTTP and Connect-RPC servers.
@@ -134,6 +136,9 @@ func (sm *SessionManager) DeleteSession(ctx context.Context, sessionID string) e
134136
return nil
135137
}
136138

139+
// ErrSessionBusy is returned when a session is already processing a request.
140+
var ErrSessionBusy = errors.New("session is already processing a request")
141+
137142
// RunSession runs a session with the given messages.
138143
func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilename, currentAgent string, messages []api.Message) (<-chan runtime.Event, error) {
139144
sm.mux.Lock()
@@ -146,19 +151,6 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena
146151
rc := sm.runConfig.Clone()
147152
rc.WorkingDir = sess.WorkingDir
148153

149-
// Collect user messages for potential title generation
150-
var userMessages []string
151-
for _, msg := range messages {
152-
sess.AddMessage(session.UserMessage(msg.Content, msg.MultiContent...))
153-
if msg.Content != "" {
154-
userMessages = append(userMessages, msg.Content)
155-
}
156-
}
157-
158-
if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil {
159-
return nil, err
160-
}
161-
162154
runtimeSession, exists := sm.runtimeSessions.Load(sessionID)
163155
streamCtx, cancel := context.WithCancel(ctx)
164156
var titleGen *sessiontitle.Generator
@@ -177,17 +169,45 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena
177169
}
178170
sm.runtimeSessions.Store(sessionID, runtimeSession)
179171
} else {
180-
// Update the session pointer in case it was reloaded
181-
runtimeSession.session = sess
182172
titleGen = runtimeSession.titleGen
183173
}
184174

175+
// Reject the request immediately if the session is already streaming.
176+
// This prevents interleaving user messages while a tool call is in
177+
// progress, which would produce a tool_use without a matching
178+
// tool_result and cause provider errors.
179+
if !runtimeSession.streaming.TryLock() {
180+
cancel()
181+
return nil, ErrSessionBusy
182+
}
183+
184+
// Now that we hold the streaming lock, it is safe to mutate the session.
185+
// Collect user messages for potential title generation
186+
var userMessages []string
187+
for _, msg := range messages {
188+
sess.AddMessage(session.UserMessage(msg.Content, msg.MultiContent...))
189+
if msg.Content != "" {
190+
userMessages = append(userMessages, msg.Content)
191+
}
192+
}
193+
194+
if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil {
195+
runtimeSession.streaming.Unlock()
196+
cancel()
197+
return nil, err
198+
}
199+
200+
// Update the session pointer so the runtime sees the latest messages.
201+
runtimeSession.session = sess
202+
185203
streamChan := make(chan runtime.Event)
186204

187205
// Check if we need to generate a title
188206
needsTitle := sess.Title == "" && len(userMessages) > 0 && titleGen != nil
189207

190208
go func() {
209+
defer runtimeSession.streaming.Unlock()
210+
191211
// Start title generation in parallel if needed
192212
if needsTitle {
193213
go sm.generateTitle(ctx, sess, titleGen, userMessages, streamChan)

pkg/server/session_manager_test.go

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
package server
2+
3+
import (
4+
"context"
5+
"sync"
6+
"sync/atomic"
7+
"testing"
8+
"time"
9+
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
13+
"github.com/docker/docker-agent/pkg/api"
14+
"github.com/docker/docker-agent/pkg/concurrent"
15+
"github.com/docker/docker-agent/pkg/config"
16+
"github.com/docker/docker-agent/pkg/runtime"
17+
"github.com/docker/docker-agent/pkg/session"
18+
"github.com/docker/docker-agent/pkg/sessiontitle"
19+
"github.com/docker/docker-agent/pkg/tools"
20+
)
21+
22+
// fakeRuntime is a minimal Runtime that records concurrent RunStream calls.
23+
type fakeRuntime struct {
24+
runtime.Runtime
25+
26+
concurrentStreams atomic.Int32
27+
maxConcurrent atomic.Int32
28+
streamDelay time.Duration
29+
}
30+
31+
func (f *fakeRuntime) RunStream(_ context.Context, _ *session.Session) <-chan runtime.Event {
32+
cur := f.concurrentStreams.Add(1)
33+
for {
34+
old := f.maxConcurrent.Load()
35+
if cur <= old || f.maxConcurrent.CompareAndSwap(old, cur) {
36+
break
37+
}
38+
}
39+
40+
ch := make(chan runtime.Event)
41+
go func() {
42+
time.Sleep(f.streamDelay)
43+
f.concurrentStreams.Add(-1)
44+
close(ch)
45+
}()
46+
return ch
47+
}
48+
49+
// Resume is a no-op; the fake runtime never pauses.
func (f *fakeRuntime) Resume(_ context.Context, _ runtime.ResumeRequest) {}
50+
51+
func (f *fakeRuntime) ResumeElicitation(_ context.Context, _ tools.ElicitationAction, _ map[string]any) error {
52+
return nil
53+
}
54+
55+
func newTestSessionManager(t *testing.T, sess *session.Session, fake *fakeRuntime) *SessionManager {
56+
t.Helper()
57+
58+
ctx := t.Context()
59+
store := session.NewInMemorySessionStore()
60+
require.NoError(t, store.AddSession(ctx, sess))
61+
62+
sm := &SessionManager{
63+
runtimeSessions: concurrent.NewMap[string, *activeRuntimes](),
64+
sessionStore: store,
65+
Sources: config.Sources{},
66+
runConfig: &config.RuntimeConfig{},
67+
}
68+
69+
// Pre-register a runtime for this session so RunSession skips agent loading.
70+
sm.runtimeSessions.Store(sess.ID, &activeRuntimes{
71+
runtime: fake,
72+
session: sess,
73+
titleGen: (*sessiontitle.Generator)(nil),
74+
})
75+
76+
return sm
77+
}
78+
79+
// TestRunSession_ConcurrentRequestReturnsErrSessionBusy verifies that a
80+
// second RunSession call on a session that is already streaming returns
81+
// ErrSessionBusy instead of silently interleaving messages.
82+
func TestRunSession_ConcurrentRequestReturnsErrSessionBusy(t *testing.T) {
83+
t.Parallel()
84+
85+
ctx := t.Context()
86+
sess := session.New()
87+
fake := &fakeRuntime{streamDelay: 500 * time.Millisecond}
88+
sm := newTestSessionManager(t, sess, fake)
89+
90+
// Start the first stream.
91+
ch1, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
92+
{Content: "first"},
93+
})
94+
require.NoError(t, err)
95+
96+
// Give the goroutine a moment to acquire the streaming lock.
97+
time.Sleep(50 * time.Millisecond)
98+
99+
// The second request should fail immediately with ErrSessionBusy.
100+
_, err = sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
101+
{Content: "second"},
102+
})
103+
require.ErrorIs(t, err, ErrSessionBusy)
104+
105+
// Drain first stream to let it complete.
106+
for range ch1 {
107+
}
108+
109+
// After the first stream finishes, a new request should succeed.
110+
ch3, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
111+
{Content: "third"},
112+
})
113+
require.NoError(t, err)
114+
for range ch3 {
115+
}
116+
}
117+
118+
// TestRunSession_MessagesNotAddedWhenBusy verifies that when a session
119+
// is busy, the rejected request does not mutate the session's messages.
120+
func TestRunSession_MessagesNotAddedWhenBusy(t *testing.T) {
121+
t.Parallel()
122+
123+
ctx := t.Context()
124+
sess := session.New()
125+
fake := &fakeRuntime{streamDelay: 500 * time.Millisecond}
126+
sm := newTestSessionManager(t, sess, fake)
127+
128+
ch1, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
129+
{Content: "first"},
130+
})
131+
require.NoError(t, err)
132+
133+
time.Sleep(50 * time.Millisecond)
134+
135+
msgCountBefore := len(sess.GetAllMessages())
136+
137+
_, err = sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
138+
{Content: "should not be added"},
139+
})
140+
require.ErrorIs(t, err, ErrSessionBusy)
141+
142+
// Messages should not have been added.
143+
assert.Len(t, sess.GetAllMessages(), msgCountBefore)
144+
145+
for range ch1 {
146+
}
147+
}
148+
149+
// TestRunSession_SequentialRequestsSucceed verifies that sequential
150+
// (non-overlapping) requests on the same session work normally.
151+
func TestRunSession_SequentialRequestsSucceed(t *testing.T) {
152+
t.Parallel()
153+
154+
ctx := t.Context()
155+
sess := session.New()
156+
fake := &fakeRuntime{streamDelay: 10 * time.Millisecond}
157+
sm := newTestSessionManager(t, sess, fake)
158+
159+
for range 3 {
160+
ch, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
161+
{Content: "hello"},
162+
})
163+
require.NoError(t, err)
164+
for range ch {
165+
}
166+
}
167+
168+
assert.Equal(t, int32(1), fake.maxConcurrent.Load())
169+
}
170+
171+
// TestRunSession_DifferentSessionsConcurrently verifies that concurrent
172+
// requests on *different* sessions are not blocked by each other.
173+
func TestRunSession_DifferentSessionsConcurrently(t *testing.T) {
174+
t.Parallel()
175+
176+
ctx := t.Context()
177+
store := session.NewInMemorySessionStore()
178+
fake1 := &fakeRuntime{streamDelay: 200 * time.Millisecond}
179+
fake2 := &fakeRuntime{streamDelay: 200 * time.Millisecond}
180+
181+
sess1 := session.New()
182+
sess2 := session.New()
183+
require.NoError(t, store.AddSession(ctx, sess1))
184+
require.NoError(t, store.AddSession(ctx, sess2))
185+
186+
sm := &SessionManager{
187+
runtimeSessions: concurrent.NewMap[string, *activeRuntimes](),
188+
sessionStore: store,
189+
Sources: config.Sources{},
190+
runConfig: &config.RuntimeConfig{},
191+
}
192+
193+
sm.runtimeSessions.Store(sess1.ID, &activeRuntimes{
194+
runtime: fake1, session: sess1, titleGen: (*sessiontitle.Generator)(nil),
195+
})
196+
sm.runtimeSessions.Store(sess2.ID, &activeRuntimes{
197+
runtime: fake2, session: sess2, titleGen: (*sessiontitle.Generator)(nil),
198+
})
199+
200+
var wg sync.WaitGroup
201+
wg.Add(2)
202+
203+
go func() {
204+
defer wg.Done()
205+
ch, err := sm.RunSession(ctx, sess1.ID, "agent", "root", []api.Message{{Content: "a"}})
206+
assert.NoError(t, err)
207+
for range ch {
208+
}
209+
}()
210+
211+
go func() {
212+
defer wg.Done()
213+
ch, err := sm.RunSession(ctx, sess2.ID, "agent", "root", []api.Message{{Content: "b"}})
214+
assert.NoError(t, err)
215+
for range ch {
216+
}
217+
}()
218+
219+
wg.Wait()
220+
221+
// Both sessions should have streamed (1 each).
222+
assert.Equal(t, int32(1), fake1.maxConcurrent.Load())
223+
assert.Equal(t, int32(1), fake2.maxConcurrent.Load())
224+
}

0 commit comments

Comments
 (0)