Skip to content

Commit 6013188

Browse files
committed
Address PR #2186 review feedback for WebSocket pool
- Add Client.Close() to release pooled WebSocket connections - Invalidate broken connections in pooledStream.Close() instead of returning dead sockets to the pool - Preserve lastResponseID across reconnections (expired + broken) so server-side context caching survives connection resets - Add wsMaxReconnectAttempts constant with bounded retry loop to prevent unbounded reconnection attempts - Replace os.Getenv("OPENAI_API_KEY") with c.Env.Get() for consistent secret resolution via the environment provider - Treat websocket.CloseNoStatusReceived as a normal close condition Assisted-By: docker-agent
1 parent b22ccaf commit 6013188

3 files changed

Lines changed: 62 additions & 21 deletions

File tree

pkg/model/provider/openai/client.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"log/slog"
1010
"net/http"
1111
"net/url"
12-
"os"
1312
"strings"
1413

1514
"github.com/openai/openai-go/v3"
@@ -156,6 +155,14 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
156155
}, nil
157156
}
158157

158+
// Close releases resources held by the client, including any pooled WebSocket
159+
// connections. It is safe to call Close multiple times.
160+
func (c *Client) Close() {
161+
if c.wsPool != nil {
162+
c.wsPool.Close()
163+
}
164+
}
165+
159166
// convertMessages converts chat.Message to openai.ChatCompletionMessageParamUnion
160167
// using the shared oaistream implementation.
161168
func convertMessages(messages []chat.Message) []openai.ChatCompletionMessageParamUnion {
@@ -459,8 +466,10 @@ func (c *Client) buildWSHeaderFn() func(ctx context.Context) (http.Header, error
459466
apiKey, _ = c.Env.Get(ctx, c.ModelConfig.TokenKey)
460467
}
461468
if apiKey == "" {
462-
// Fall back to the standard OPENAI_API_KEY env var.
463-
apiKey = os.Getenv("OPENAI_API_KEY")
469+
// Fall back to the standard OPENAI_API_KEY env var via the
470+
// environment provider so that secret resolution is
471+
// consistent with the HTTP client path.
472+
apiKey, _ = c.Env.Get(ctx, "OPENAI_API_KEY")
464473
}
465474
if apiKey != "" {
466475
h.Set("Authorization", "Bearer "+apiKey)

pkg/model/provider/openai/ws_pool.go

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ const (
1717
// wsMaxConnectionAge is the maximum lifetime of a WebSocket connection.
1818
// OpenAI enforces a 60-minute limit; we reconnect slightly earlier.
1919
wsMaxConnectionAge = 55 * time.Minute
20+
21+
// wsMaxReconnectAttempts is the maximum number of times a broken
22+
// connection will be replaced with a fresh one within a single
23+
// Stream call before the error is propagated to the caller.
24+
wsMaxReconnectAttempts = 1
2025
)
2126

2227
// wsConnection holds a WebSocket connection together with bookkeeping
@@ -71,10 +76,12 @@ func (p *wsPool) Stream(
7176
p.mu.Lock()
7277
defer p.mu.Unlock()
7378

74-
// Close stale connections.
79+
// Close stale connections, preserving the last response ID.
80+
var prevResponseID string
7581
if p.conn != nil && p.conn.isExpired() {
7682
slog.Debug("Closing expired WebSocket connection",
7783
"age", time.Since(p.conn.createdAt))
84+
prevResponseID = p.conn.lastResponseID
7885
_ = p.conn.conn.Close()
7986
p.conn = nil
8087
}
@@ -92,8 +99,9 @@ func (p *wsPool) Stream(
9299
}
93100

94101
p.conn = &wsConnection{
95-
conn: stream.conn,
96-
createdAt: time.Now(),
102+
conn: stream.conn,
103+
createdAt: time.Now(),
104+
lastResponseID: prevResponseID,
97105
}
98106

99107
return &pooledStream{pool: p, inner: stream}, nil
@@ -103,23 +111,33 @@ func (p *wsPool) Stream(
103111
stream, err := sendOnExisting(p.conn.conn, params)
104112
if err != nil {
105113
// Connection is broken; tear down and retry with a fresh one.
114+
// We only attempt wsMaxReconnectAttempts reconnections to avoid
115+
// unbounded loops if the server keeps rejecting connections.
106116
slog.Warn("Existing WebSocket connection failed, reconnecting", "error", err)
117+
prevResponseID := p.conn.lastResponseID
107118
_ = p.conn.conn.Close()
108119
p.conn = nil
109120

110-
headers, err2 := p.headerFn(ctx)
111-
if err2 != nil {
112-
return nil, fmt.Errorf("websocket pool: headers on reconnect: %w", err2)
113-
}
114-
stream, err2 = dialWebSocket(ctx, p.wsURL, headers, params)
115-
if err2 != nil {
116-
return nil, fmt.Errorf("websocket pool: reconnect: %w", err2)
117-
}
118-
p.conn = &wsConnection{
119-
conn: stream.conn,
120-
createdAt: time.Now(),
121+
var lastErr error
122+
for attempt := range wsMaxReconnectAttempts {
123+
headers, err2 := p.headerFn(ctx)
124+
if err2 != nil {
125+
lastErr = fmt.Errorf("websocket pool: headers on reconnect (attempt %d/%d): %w", attempt+1, wsMaxReconnectAttempts, err2)
126+
continue
127+
}
128+
stream, err2 = dialWebSocket(ctx, p.wsURL, headers, params)
129+
if err2 != nil {
130+
lastErr = fmt.Errorf("websocket pool: reconnect (attempt %d/%d): %w", attempt+1, wsMaxReconnectAttempts, err2)
131+
continue
132+
}
133+
p.conn = &wsConnection{
134+
conn: stream.conn,
135+
createdAt: time.Now(),
136+
lastResponseID: prevResponseID,
137+
}
138+
return &pooledStream{pool: p, inner: stream}, nil
121139
}
122-
return &pooledStream{pool: p, inner: stream}, nil
140+
return nil, lastErr
123141
}
124142

125143
return &pooledStream{pool: p, inner: stream}, nil
@@ -195,10 +213,23 @@ func (s *pooledStream) Err() error {
195213
return s.inner.Err()
196214
}
197215

198-
// Close releases the stream but keeps the underlying connection alive in
199-
// the pool for reuse.
216+
// Close releases the stream. If the stream encountered an error, the
217+
// underlying connection is invalidated so that the pool opens a fresh one
218+
// on the next request. Otherwise the connection stays in the pool for reuse.
200219
func (s *pooledStream) Close() error {
201220
s.inner.done = true
202-
// Do NOT close the WebSocket connection—it stays in the pool.
221+
222+
if s.inner.Err() != nil {
223+
// Connection is likely broken; tear it down so the pool
224+
// doesn't hand out a dead socket.
225+
s.pool.mu.Lock()
226+
if s.pool.conn != nil && s.pool.conn.conn == s.inner.conn {
227+
_ = s.pool.conn.conn.Close()
228+
s.pool.conn = nil
229+
}
230+
s.pool.mu.Unlock()
231+
}
232+
233+
// Do NOT close the WebSocket connection when healthy—it stays in the pool.
203234
return nil
204235
}

pkg/model/provider/openai/ws_stream.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ func (s *wsStream) Next() bool {
119119
if websocket.IsCloseError(err,
120120
websocket.CloseNormalClosure,
121121
websocket.CloseGoingAway,
122+
websocket.CloseNoStatusReceived,
122123
) {
123124
s.done = true
124125
return false

0 commit comments

Comments
 (0)