@@ -2,7 +2,6 @@ package openai
22
33import (
44 "context"
5- "encoding/json"
65 "fmt"
76 "log/slog"
87 "net/http"
@@ -17,23 +16,13 @@ const (
1716 // wsMaxConnectionAge is the maximum lifetime of a WebSocket connection.
1817 // OpenAI enforces a 60-minute limit; we reconnect slightly earlier.
1918 wsMaxConnectionAge = 55 * time .Minute
20-
21- // wsMaxReconnectAttempts is the maximum number of times a broken
22- // connection will be replaced with a fresh one within a single
23- // Stream call before the error is propagated to the caller.
24- wsMaxReconnectAttempts = 1
2519)
2620
2721// wsConnection holds a WebSocket connection together with bookkeeping
2822// metadata for the connection pool.
2923type wsConnection struct {
3024 conn * websocket.Conn // underlying socket; the pool compares it by pointer to recognize its own connection
3125 createdAt time.Time // set to time.Now() when the connection is dialed; used to report the connection's age
32-
33- // lastResponseID is the ID of the most recent response completed on
34- // this connection. It can be passed as previous_response_id in subsequent
35- // requests to enable server-side context caching.
36- lastResponseID string
3726}
3827
3928// isExpired returns true when the connection has been open longer than
@@ -50,6 +39,12 @@ type wsPool struct {
5039 mu sync.Mutex
5140 conn * wsConnection
5241
42+ // lastResponseID is the ID of the most recent response completed on
43+ // this pool. It can be passed as previous_response_id in subsequent
44+ // requests to enable server-side context caching.
45+ // It lives on the pool (not wsConnection) so it survives reconnections.
46+ lastResponseID string
47+
5348 // wsURL is the WebSocket endpoint (e.g. wss://api.openai.com/v1/responses).
5449 wsURL string
5550
@@ -76,99 +71,89 @@ func (p *wsPool) Stream(
7671 p .mu .Lock ()
7772 defer p .mu .Unlock ()
7873
79- // Close stale connections, preserving the last response ID.
80- var prevResponseID string
74+ // Close stale connections.
8175 if p .conn != nil && p .conn .isExpired () {
8276 slog .Debug ("Closing expired WebSocket connection" ,
8377 "age" , time .Since (p .conn .createdAt ))
84- prevResponseID = p .conn .lastResponseID
85- _ = p .conn .conn .Close ()
86- p .conn = nil
78+ p .closeLocked ()
8779 }
8880
8981 // Establish a new connection if needed.
9082 if p .conn == nil {
91- headers , err := p .headerFn (ctx )
92- if err != nil {
93- return nil , fmt .Errorf ("websocket pool: headers: %w" , err )
94- }
95-
96- stream , err := dialWebSocket (ctx , p .wsURL , headers , params )
97- if err != nil {
98- return nil , err
99- }
100-
101- p .conn = & wsConnection {
102- conn : stream .conn ,
103- createdAt : time .Now (),
104- lastResponseID : prevResponseID ,
105- }
106-
107- return & pooledStream {pool : p , inner : stream }, nil
83+ return p .dialLocked (ctx , params )
10884 }
10985
11086 // Reuse existing connection: send a new response.create.
11187 stream , err := sendOnExisting (p .conn .conn , params )
11288 if err != nil {
113- // Connection is broken; tear down and retry with a fresh one.
114- // We only attempt wsMaxReconnectAttempts reconnections to avoid
115- // unbounded loops if the server keeps rejecting connections.
89+ // Connection is broken; tear down and reconnect once.
11690 slog .Warn ("Existing WebSocket connection failed, reconnecting" , "error" , err )
117- prevResponseID := p .conn .lastResponseID
118- _ = p .conn .conn .Close ()
119- p .conn = nil
120-
121- var lastErr error
122- for attempt := range wsMaxReconnectAttempts {
123- headers , err2 := p .headerFn (ctx )
124- if err2 != nil {
125- lastErr = fmt .Errorf ("websocket pool: headers on reconnect (attempt %d/%d): %w" , attempt + 1 , wsMaxReconnectAttempts , err2 )
126- continue
127- }
128- stream , err2 = dialWebSocket (ctx , p .wsURL , headers , params )
129- if err2 != nil {
130- lastErr = fmt .Errorf ("websocket pool: reconnect (attempt %d/%d): %w" , attempt + 1 , wsMaxReconnectAttempts , err2 )
131- continue
132- }
133- p .conn = & wsConnection {
134- conn : stream .conn ,
135- createdAt : time .Now (),
136- lastResponseID : prevResponseID ,
137- }
138- return & pooledStream {pool : p , inner : stream }, nil
139- }
140- return nil , lastErr
91+ p .closeLocked ()
92+ return p .dialLocked (ctx , params )
93+ }
94+
95+ return & pooledStream {pool : p , inner : stream }, nil
96+ }
97+
98+ // dialLocked obtains headers from p.headerFn, dials p.wsURL, and stores the new connection in p.conn.
99+ // It returns a pooledStream wrapping the dialed stream. Caller must hold p.mu.
100+ func (p * wsPool ) dialLocked (
101+ ctx context.Context ,
102+ params responses.ResponseNewParams , // passed through to dialWebSocket unchanged
103+ ) (* pooledStream , error ) {
104+ headers , err := p .headerFn (ctx )
105+ if err != nil {
106+ return nil , fmt .Errorf ("websocket pool: headers: %w" , err )
107+ }
108+
109+ stream , err := dialWebSocket (ctx , p .wsURL , headers , params )
110+ if err != nil {
111+ return nil , err // dial errors are returned unwrapped (no "websocket pool:" prefix)
112+ }
113+
114+ p .conn = & wsConnection { // replaces p.conn without closing a previous value; Stream only dials when p.conn is nil
115+ conn : stream .conn ,
116+ createdAt : time .Now (),
141117 }
142118
143119 return & pooledStream {pool : p , inner : stream }, nil
144120}
145121
122+ // closeLocked closes the pooled connection, if any, and sets p.conn to nil; it is a no-op when p.conn is already nil.
123+ // p.lastResponseID is left untouched, so it survives reconnections. Caller must hold p.mu.
124+ func (p * wsPool ) closeLocked () {
125+ if p .conn == nil {
126+ return
127+ }
128+ _ = p .conn .conn .Close () // Close error is discarded: the connection is dropped either way
129+ p .conn = nil
130+ }
131+
132+ // invalidateConn tears down the pooled connection if it matches conn.
133+ // Called by pooledStream.Close when the stream encountered an error,
134+ // so the pool does not hand out a broken connection. It locks p.mu itself, so the caller must not hold it.
135+ func (p * wsPool ) invalidateConn (conn * websocket.Conn ) {
136+ p .mu .Lock ()
137+ defer p .mu .Unlock ()
138+
139+ if p .conn != nil && p .conn .conn == conn { // pointer-identity check: a connection that was already replaced is left alone
140+ p .closeLocked ()
141+ }
142+ }
143+
146144// Close shuts down the pooled connection, if one is open. It locks p.mu itself, so callers must not hold it.
147145func (p * wsPool ) Close () {
148146 p .mu .Lock ()
149147 defer p .mu .Unlock ()
150148
151- if p .conn != nil {
152- _ = p .conn .conn .Close ()
153- p .conn = nil
154- }
149+ p .closeLocked () // no-op when p.conn is nil
155150}
156151
157152// sendOnExisting sends a response.create on an already-open connection and
158153// returns a wsStream that reads events from it.
159154func sendOnExisting (conn * websocket.Conn , params responses.ResponseNewParams ) (* wsStream , error ) {
160- paramsJSON , err := json .Marshal (params )
161- if err != nil {
162- return nil , fmt .Errorf ("websocket: marshal params: %w" , err )
163- }
164-
165- msg := wsCreateMessage {
166- Type : "response.create" ,
167- Params : paramsJSON ,
168- }
169-
170- if err := conn .WriteJSON (msg ); err != nil {
171- return nil , fmt .Errorf ("websocket: write response.create: %w" , err )
155+ if err := sendResponseCreate (conn , params ); err != nil {
156+ return nil , err
172157 }
173158
174159 slog .Debug ("WebSocket response.create sent (reused connection)" )
@@ -196,9 +181,7 @@ func (s *pooledStream) Next() bool {
196181 event := s .inner .Current ()
197182 if isTerminalEvent (event .Type ) && event .Response .ID != "" {
198183 s .pool .mu .Lock ()
199- if s .pool .conn != nil {
200- s .pool .conn .lastResponseID = event .Response .ID
201- }
184+ s .pool .lastResponseID = event .Response .ID
202185 s .pool .mu .Unlock ()
203186 }
204187
@@ -220,16 +203,8 @@ func (s *pooledStream) Close() error {
220203 s .inner .done = true
221204
222205 if s .inner .Err () != nil {
223- // Connection is likely broken; tear it down so the pool
224- // doesn't hand out a dead socket.
225- s .pool .mu .Lock ()
226- if s .pool .conn != nil && s .pool .conn .conn == s .inner .conn {
227- _ = s .pool .conn .conn .Close ()
228- s .pool .conn = nil
229- }
230- s .pool .mu .Unlock ()
206+ s .pool .invalidateConn (s .inner .conn )
231207 }
232208
233- // Do NOT close the WebSocket connection when healthy—it stays in the pool.
234209 return nil
235210}
0 commit comments