Skip to content

Commit e4f454c

Browse files
committed
Simplify WebSocket pool code structure
- Promote lastResponseID from wsConnection to wsPool so it naturally survives all connection transitions without manual threading - Extract closeLocked(), dialLocked(), invalidateConn() helpers to eliminate duplicated connection lifecycle logic in Stream() - Replace loop-of-one reconnect with a single dialLocked() call - Extract sendResponseCreate() to deduplicate marshal+send between dialWebSocket() and sendOnExisting() - Remove wsMaxReconnectAttempts constant (was always 1) - Simplify wsConnection struct to just conn + createdAt Net result: -18 lines, fewer code paths, same behavior. Assisted-By: docker-agent
1 parent 6013188 commit e4f454c

2 files changed

Lines changed: 85 additions & 103 deletions

File tree

pkg/model/provider/openai/ws_pool.go

Lines changed: 63 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package openai
22

33
import (
44
"context"
5-
"encoding/json"
65
"fmt"
76
"log/slog"
87
"net/http"
@@ -17,23 +16,13 @@ const (
1716
// wsMaxConnectionAge is the maximum lifetime of a WebSocket connection.
1817
// OpenAI enforces a 60-minute limit; we reconnect slightly earlier.
1918
wsMaxConnectionAge = 55 * time.Minute
20-
21-
// wsMaxReconnectAttempts is the maximum number of times a broken
22-
// connection will be replaced with a fresh one within a single
23-
// Stream call before the error is propagated to the caller.
24-
wsMaxReconnectAttempts = 1
2519
)
2620

2721
// wsConnection holds a WebSocket connection together with bookkeeping
2822
// metadata for the connection pool.
2923
type wsConnection struct {
3024
conn *websocket.Conn
3125
createdAt time.Time
32-
33-
// lastResponseID is the ID of the most recent response completed on
34-
// this connection. It can be passed as previous_response_id in subsequent
35-
// requests to enable server-side context caching.
36-
lastResponseID string
3726
}
3827

3928
// isExpired returns true when the connection has been open longer than
@@ -50,6 +39,12 @@ type wsPool struct {
5039
mu sync.Mutex
5140
conn *wsConnection
5241

42+
// lastResponseID is the ID of the most recent response completed on
43+
// this pool. It can be passed as previous_response_id in subsequent
44+
// requests to enable server-side context caching.
45+
// It lives on the pool (not wsConnection) so it survives reconnections.
46+
lastResponseID string
47+
5348
// wsURL is the WebSocket endpoint (e.g. wss://api.openai.com/v1/responses).
5449
wsURL string
5550

@@ -76,99 +71,89 @@ func (p *wsPool) Stream(
7671
p.mu.Lock()
7772
defer p.mu.Unlock()
7873

79-
// Close stale connections, preserving the last response ID.
80-
var prevResponseID string
74+
// Close stale connections.
8175
if p.conn != nil && p.conn.isExpired() {
8276
slog.Debug("Closing expired WebSocket connection",
8377
"age", time.Since(p.conn.createdAt))
84-
prevResponseID = p.conn.lastResponseID
85-
_ = p.conn.conn.Close()
86-
p.conn = nil
78+
p.closeLocked()
8779
}
8880

8981
// Establish a new connection if needed.
9082
if p.conn == nil {
91-
headers, err := p.headerFn(ctx)
92-
if err != nil {
93-
return nil, fmt.Errorf("websocket pool: headers: %w", err)
94-
}
95-
96-
stream, err := dialWebSocket(ctx, p.wsURL, headers, params)
97-
if err != nil {
98-
return nil, err
99-
}
100-
101-
p.conn = &wsConnection{
102-
conn: stream.conn,
103-
createdAt: time.Now(),
104-
lastResponseID: prevResponseID,
105-
}
106-
107-
return &pooledStream{pool: p, inner: stream}, nil
83+
return p.dialLocked(ctx, params)
10884
}
10985

11086
// Reuse existing connection: send a new response.create.
11187
stream, err := sendOnExisting(p.conn.conn, params)
11288
if err != nil {
113-
// Connection is broken; tear down and retry with a fresh one.
114-
// We only attempt wsMaxReconnectAttempts reconnections to avoid
115-
// unbounded loops if the server keeps rejecting connections.
89+
// Connection is broken; tear down and reconnect once.
11690
slog.Warn("Existing WebSocket connection failed, reconnecting", "error", err)
117-
prevResponseID := p.conn.lastResponseID
118-
_ = p.conn.conn.Close()
119-
p.conn = nil
120-
121-
var lastErr error
122-
for attempt := range wsMaxReconnectAttempts {
123-
headers, err2 := p.headerFn(ctx)
124-
if err2 != nil {
125-
lastErr = fmt.Errorf("websocket pool: headers on reconnect (attempt %d/%d): %w", attempt+1, wsMaxReconnectAttempts, err2)
126-
continue
127-
}
128-
stream, err2 = dialWebSocket(ctx, p.wsURL, headers, params)
129-
if err2 != nil {
130-
lastErr = fmt.Errorf("websocket pool: reconnect (attempt %d/%d): %w", attempt+1, wsMaxReconnectAttempts, err2)
131-
continue
132-
}
133-
p.conn = &wsConnection{
134-
conn: stream.conn,
135-
createdAt: time.Now(),
136-
lastResponseID: prevResponseID,
137-
}
138-
return &pooledStream{pool: p, inner: stream}, nil
139-
}
140-
return nil, lastErr
91+
p.closeLocked()
92+
return p.dialLocked(ctx, params)
93+
}
94+
95+
return &pooledStream{pool: p, inner: stream}, nil
96+
}
97+
98+
// dialLocked opens a fresh WebSocket connection and stores it in the pool.
99+
// Caller must hold p.mu.
100+
func (p *wsPool) dialLocked(
101+
ctx context.Context,
102+
params responses.ResponseNewParams,
103+
) (*pooledStream, error) {
104+
headers, err := p.headerFn(ctx)
105+
if err != nil {
106+
return nil, fmt.Errorf("websocket pool: headers: %w", err)
107+
}
108+
109+
stream, err := dialWebSocket(ctx, p.wsURL, headers, params)
110+
if err != nil {
111+
return nil, err
112+
}
113+
114+
p.conn = &wsConnection{
115+
conn: stream.conn,
116+
createdAt: time.Now(),
141117
}
142118

143119
return &pooledStream{pool: p, inner: stream}, nil
144120
}
145121

122+
// closeLocked closes the current connection. lastResponseID is preserved
123+
// on the pool so it survives reconnections. Caller must hold p.mu.
124+
func (p *wsPool) closeLocked() {
125+
if p.conn == nil {
126+
return
127+
}
128+
_ = p.conn.conn.Close()
129+
p.conn = nil
130+
}
131+
132+
// invalidateConn tears down the pooled connection if it matches conn.
133+
// Called by pooledStream.Close when the stream encountered an error,
134+
// so the pool does not hand out a broken connection.
135+
func (p *wsPool) invalidateConn(conn *websocket.Conn) {
136+
p.mu.Lock()
137+
defer p.mu.Unlock()
138+
139+
if p.conn != nil && p.conn.conn == conn {
140+
p.closeLocked()
141+
}
142+
}
143+
146144
// Close shuts down the pooled connection.
147145
func (p *wsPool) Close() {
148146
p.mu.Lock()
149147
defer p.mu.Unlock()
150148

151-
if p.conn != nil {
152-
_ = p.conn.conn.Close()
153-
p.conn = nil
154-
}
149+
p.closeLocked()
155150
}
156151

157152
// sendOnExisting sends a response.create on an already-open connection and
158153
// returns a wsStream that reads events from it.
159154
func sendOnExisting(conn *websocket.Conn, params responses.ResponseNewParams) (*wsStream, error) {
160-
paramsJSON, err := json.Marshal(params)
161-
if err != nil {
162-
return nil, fmt.Errorf("websocket: marshal params: %w", err)
163-
}
164-
165-
msg := wsCreateMessage{
166-
Type: "response.create",
167-
Params: paramsJSON,
168-
}
169-
170-
if err := conn.WriteJSON(msg); err != nil {
171-
return nil, fmt.Errorf("websocket: write response.create: %w", err)
155+
if err := sendResponseCreate(conn, params); err != nil {
156+
return nil, err
172157
}
173158

174159
slog.Debug("WebSocket response.create sent (reused connection)")
@@ -196,9 +181,7 @@ func (s *pooledStream) Next() bool {
196181
event := s.inner.Current()
197182
if isTerminalEvent(event.Type) && event.Response.ID != "" {
198183
s.pool.mu.Lock()
199-
if s.pool.conn != nil {
200-
s.pool.conn.lastResponseID = event.Response.ID
201-
}
184+
s.pool.lastResponseID = event.Response.ID
202185
s.pool.mu.Unlock()
203186
}
204187

@@ -220,16 +203,8 @@ func (s *pooledStream) Close() error {
220203
s.inner.done = true
221204

222205
if s.inner.Err() != nil {
223-
// Connection is likely broken; tear it down so the pool
224-
// doesn't hand out a dead socket.
225-
s.pool.mu.Lock()
226-
if s.pool.conn != nil && s.pool.conn.conn == s.inner.conn {
227-
_ = s.pool.conn.conn.Close()
228-
s.pool.conn = nil
229-
}
230-
s.pool.mu.Unlock()
206+
s.pool.invalidateConn(s.inner.conn)
231207
}
232208

233-
// Do NOT close the WebSocket connection when healthy—it stays in the pool.
234209
return nil
235210
}

pkg/model/provider/openai/ws_stream.go

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,26 @@ type wsStream struct {
6060
// Compile-time check: wsStream satisfies responseEventStream.
6161
var _ responseEventStream = (*wsStream)(nil)
6262

63+
// sendResponseCreate marshals params and writes a response.create message
64+
// on the given WebSocket connection.
65+
func sendResponseCreate(conn *websocket.Conn, params responses.ResponseNewParams) error {
66+
paramsJSON, err := json.Marshal(params)
67+
if err != nil {
68+
return fmt.Errorf("websocket: marshal params: %w", err)
69+
}
70+
71+
msg := wsCreateMessage{
72+
Type: "response.create",
73+
Params: paramsJSON,
74+
}
75+
76+
if err := conn.WriteJSON(msg); err != nil {
77+
return fmt.Errorf("websocket: write response.create: %w", err)
78+
}
79+
80+
return nil
81+
}
82+
6383
// dialWebSocket opens a WebSocket connection, sends the response.create
6484
// message, and returns a stream that yields server events.
6585
func dialWebSocket(
@@ -84,22 +104,9 @@ func dialWebSocket(
84104
return nil, fmt.Errorf("websocket dial %s: %w", wsURL, err)
85105
}
86106

87-
// Marshal the params using the SDK's MarshalJSON so all field
88-
// encodings (omitzero, unions, etc.) are handled correctly.
89-
paramsJSON, err := json.Marshal(params)
90-
if err != nil {
91-
conn.Close()
92-
return nil, fmt.Errorf("websocket: marshal params: %w", err)
93-
}
94-
95-
msg := wsCreateMessage{
96-
Type: "response.create",
97-
Params: paramsJSON,
98-
}
99-
100-
if err := conn.WriteJSON(msg); err != nil {
107+
if err := sendResponseCreate(conn, params); err != nil {
101108
conn.Close()
102-
return nil, fmt.Errorf("websocket: write response.create: %w", err)
109+
return nil, err
103110
}
104111

105112
slog.Debug("WebSocket response.create sent", "url", wsURL)

0 commit comments

Comments
 (0)