Skip to content

Commit 556f27e

Browse files
committed
Inject lastResponseID as previous_response_id in WebSocket requests
The wsPool already tracked lastResponseID from completed responses but never forwarded it to subsequent requests. Now, wsPool.Stream() injects it as previous_response_id when the caller hasn't already set one, enabling server-side context caching across multi-turn exchanges. Add tests covering automatic injection, caller override preservation, and survival across reconnections. Assisted-By: docker-agent
1 parent 75b96c2 commit 556f27e

2 files changed

Lines changed: 242 additions & 0 deletions

File tree

pkg/model/provider/openai/ws_pool.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"time"
1010

1111
"github.com/gorilla/websocket"
12+
"github.com/openai/openai-go/v3/packages/param"
1213
"github.com/openai/openai-go/v3/responses"
1314
)
1415

@@ -71,6 +72,13 @@ func (p *wsPool) Stream(
7172
p.mu.Lock()
7273
defer p.mu.Unlock()
7374

75+
// Inject previous_response_id for server-side context caching when the
76+
// caller hasn't already set one and we have a response from an earlier
77+
// exchange on this pool.
78+
if p.lastResponseID != "" && !params.PreviousResponseID.Valid() {
79+
params.PreviousResponseID = param.NewOpt(p.lastResponseID)
80+
}
81+
7482
// Close stale connections.
7583
if p.conn != nil && p.conn.isExpired() {
7684
slog.Debug("Closing expired WebSocket connection",
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
package openai
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"net/http"
7+
"net/http/httptest"
8+
"strings"
9+
"testing"
10+
11+
"github.com/gorilla/websocket"
12+
"github.com/openai/openai-go/v3/packages/param"
13+
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
// testWSServerCapture starts a test WebSocket server that captures each
18+
// response.create message into the returned slice and replies with the
19+
// given canned events.
20+
func testWSServerCapture(t *testing.T, events []map[string]any) (*httptest.Server, *[]map[string]json.RawMessage) {
21+
t.Helper()
22+
23+
var captured []map[string]json.RawMessage
24+
25+
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
26+
27+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
28+
conn, err := upgrader.Upgrade(w, r, nil)
29+
if err != nil {
30+
t.Errorf("WebSocket upgrade failed: %v", err)
31+
return
32+
}
33+
defer conn.Close()
34+
35+
for {
36+
// Read a response.create message.
37+
_, data, err := conn.ReadMessage()
38+
if err != nil {
39+
return
40+
}
41+
42+
var createMsg map[string]json.RawMessage
43+
if err := json.Unmarshal(data, &createMsg); err != nil {
44+
t.Errorf("Failed to unmarshal response.create: %v", err)
45+
return
46+
}
47+
captured = append(captured, createMsg)
48+
49+
// Send events.
50+
for _, event := range events {
51+
eventData, _ := json.Marshal(event)
52+
if err := conn.WriteMessage(websocket.TextMessage, eventData); err != nil {
53+
return
54+
}
55+
}
56+
}
57+
}))
58+
59+
return srv, &captured
60+
}
61+
62+
func completedEvent(responseID string) map[string]any {
63+
return map[string]any{
64+
"type": "response.completed",
65+
"response": map[string]any{
66+
"id": responseID,
67+
"output": []any{},
68+
"usage": map[string]any{
69+
"input_tokens": 5,
70+
"output_tokens": 1,
71+
"total_tokens": 6,
72+
"input_tokens_details": map[string]any{
73+
"cached_tokens": 0,
74+
},
75+
"output_tokens_details": map[string]any{
76+
"reasoning_tokens": 0,
77+
},
78+
},
79+
},
80+
}
81+
}
82+
83+
func TestWSPool_InjectsPreviousResponseID(t *testing.T) {
84+
t.Parallel()
85+
86+
events := []map[string]any{
87+
completedEvent("resp_first"),
88+
}
89+
90+
srv, captured := testWSServerCapture(t, events)
91+
defer srv.Close()
92+
93+
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
94+
pool := newWSPool(wsURL, func(_ context.Context) (http.Header, error) {
95+
return http.Header{}, nil
96+
})
97+
defer pool.Close()
98+
99+
ctx := t.Context()
100+
101+
// --- First request: no previous_response_id should be set.
102+
stream1, err := pool.Stream(ctx, defaultTestParams())
103+
require.NoError(t, err)
104+
drainStream(t, stream1)
105+
106+
// After draining, the pool should have captured the response ID.
107+
assert.Equal(t, "resp_first", pool.lastResponseID)
108+
109+
// --- Second request: the pool should inject previous_response_id automatically.
110+
// Change events for the second request to return a different ID.
111+
// (The server always sends the same events we initialized, so we verify
112+
// the injection from the captured request.)
113+
stream2, err := pool.Stream(ctx, defaultTestParams())
114+
require.NoError(t, err)
115+
drainStream(t, stream2)
116+
117+
// Verify captured messages.
118+
require.Len(t, *captured, 2)
119+
120+
// First request: no previous_response_id.
121+
assertPreviousResponseID(t, (*captured)[0], "")
122+
123+
// Second request: pool injects the ID from the first response.
124+
assertPreviousResponseID(t, (*captured)[1], "resp_first")
125+
}
126+
127+
func TestWSPool_CallerPreviousResponseIDNotOverwritten(t *testing.T) {
128+
t.Parallel()
129+
130+
events := []map[string]any{
131+
completedEvent("resp_pool"),
132+
}
133+
134+
srv, captured := testWSServerCapture(t, events)
135+
defer srv.Close()
136+
137+
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
138+
pool := newWSPool(wsURL, func(_ context.Context) (http.Header, error) {
139+
return http.Header{}, nil
140+
})
141+
defer pool.Close()
142+
143+
ctx := t.Context()
144+
145+
// First request — populate lastResponseID.
146+
stream1, err := pool.Stream(ctx, defaultTestParams())
147+
require.NoError(t, err)
148+
drainStream(t, stream1)
149+
150+
assert.Equal(t, "resp_pool", pool.lastResponseID)
151+
152+
// Second request with caller-provided previous_response_id.
153+
params := defaultTestParams()
154+
params.PreviousResponseID = param.NewOpt("caller_resp_999")
155+
156+
stream2, err := pool.Stream(ctx, params)
157+
require.NoError(t, err)
158+
drainStream(t, stream2)
159+
160+
require.Len(t, *captured, 2)
161+
162+
// The caller's ID must NOT be overwritten by the pool.
163+
assertPreviousResponseID(t, (*captured)[1], "caller_resp_999")
164+
}
165+
166+
func TestWSPool_LastResponseIDSurvivesReconnect(t *testing.T) {
167+
t.Parallel()
168+
169+
events := []map[string]any{
170+
completedEvent("resp_survive"),
171+
}
172+
173+
srv, captured := testWSServerCapture(t, events)
174+
defer srv.Close()
175+
176+
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
177+
pool := newWSPool(wsURL, func(_ context.Context) (http.Header, error) {
178+
return http.Header{}, nil
179+
})
180+
defer pool.Close()
181+
182+
ctx := t.Context()
183+
184+
// First request.
185+
stream1, err := pool.Stream(ctx, defaultTestParams())
186+
require.NoError(t, err)
187+
drainStream(t, stream1)
188+
189+
assert.Equal(t, "resp_survive", pool.lastResponseID)
190+
191+
// Force a reconnect by closing the pooled connection.
192+
pool.Close()
193+
194+
// Second request after reconnection.
195+
stream2, err := pool.Stream(ctx, defaultTestParams())
196+
require.NoError(t, err)
197+
drainStream(t, stream2)
198+
199+
require.Len(t, *captured, 2)
200+
201+
// The lastResponseID should survive the reconnect.
202+
assertPreviousResponseID(t, (*captured)[1], "resp_survive")
203+
}
204+
205+
// drainStream reads all events from a responseEventStream until exhausted.
206+
func drainStream(t *testing.T, stream responseEventStream) {
207+
t.Helper()
208+
for stream.Next() {
209+
// consume
210+
}
211+
require.NoError(t, stream.Err())
212+
require.NoError(t, stream.Close())
213+
}
214+
215+
// assertPreviousResponseID checks that the captured response.create message
216+
// contains (or omits) the expected previous_response_id.
217+
func assertPreviousResponseID(t *testing.T, msg map[string]json.RawMessage, expected string) {
218+
t.Helper()
219+
220+
raw, ok := msg["previous_response_id"]
221+
if expected == "" {
222+
// Either absent or null.
223+
if ok {
224+
assert.JSONEq(t, "null", string(raw),
225+
"expected previous_response_id to be absent or null")
226+
}
227+
return
228+
}
229+
230+
require.True(t, ok, "expected previous_response_id in request")
231+
var got string
232+
require.NoError(t, json.Unmarshal(raw, &got))
233+
assert.Equal(t, expected, got)
234+
}

0 commit comments

Comments
 (0)