77 "errors"
88 "fmt"
99 "log/slog"
10+ "net/http"
1011 "net/url"
1112 "strings"
1213
@@ -29,12 +30,16 @@ import (
2930 "github.com/docker/docker-agent/pkg/tools"
3031)
3132
32- // Client represents an OpenAI client wrapper
33- // It implements the provider.Provider interface
33+ // Client represents an OpenAI client wrapper.
34+ // It implements the provider.Provider interface.
3435type Client struct {
3536 base.Config
3637
3738 clientFn func (context.Context ) (* openai.Client , error )
39+
40+ // wsPool is initialized in NewClient when transport=websocket is configured.
41+ // It maintains a persistent WebSocket connection across requests.
42+ wsPool * wsPool
3843}
3944
4045// NewClient creates a new OpenAI client from the provided configuration
@@ -140,14 +145,32 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
140145
141146 slog .Debug ("OpenAI client created successfully" , "model" , cfg .Model )
142147
143- return & Client {
148+ client := & Client {
144149 Config : base.Config {
145150 ModelConfig : * cfg ,
146151 ModelOptions : globalOptions ,
147152 Env : env ,
148153 },
149154 clientFn : clientFn ,
150- }, nil
155+ }
156+
157+ // Pre-create the WebSocket pool when the transport is configured.
158+ // The pool is cheap (no connections opened until the first Stream call)
159+ // and eager init avoids a data race on the lazy path.
160+ if getTransport (cfg ) == "websocket" && globalOptions .Gateway () == "" {
161+ baseURL := cmp .Or (cfg .BaseURL , "https://api.openai.com/v1" )
162+ client .wsPool = newWSPool (httpToWSURL (baseURL ), client .buildWSHeaderFn ())
163+ }
164+
165+ return client , nil
166+ }
167+
168+ // Close releases resources held by the client, including any pooled WebSocket
169+ // connections. It is safe to call Close multiple times.
170+ func (c * Client ) Close () {
171+ if c .wsPool != nil {
172+ c .wsPool .Close ()
173+ }
151174}
152175
153176// convertMessages converts chat.Message to openai.ChatCompletionMessageParamUnion
@@ -307,12 +330,6 @@ func (c *Client) CreateResponseStream(
307330 return nil , errors .New ("at least one message is required" )
308331 }
309332
310- client , err := c .clientFn (ctx )
311- if err != nil {
312- slog .Error ("Failed to create OpenAI client" , "error" , err )
313- return nil , err
314- }
315-
316333 input := convertMessagesToResponseInput (messages )
317334
318335 params := responses.ResponseNewParams {
@@ -398,10 +415,85 @@ func (c *Client) CreateResponseStream(
398415 slog .Error ("Failed to marshal OpenAI responses request to JSON" , "error" , err )
399416 }
400417
418+ // Choose transport: WebSocket or SSE (default).
419+ // WebSocket is disabled when using a Gateway since most gateways don't support it.
420+ transport := getTransport (& c .ModelConfig )
421+ trackUsage := c .ModelConfig .TrackUsage == nil || * c .ModelConfig .TrackUsage
422+
423+ if transport == "websocket" && c .ModelOptions .Gateway () == "" {
424+ stream , err := c .createWebSocketStream (ctx , params )
425+ if err != nil {
426+ slog .Warn ("WebSocket stream failed, falling back to SSE" , "error" , err )
427+ // Fall through to SSE below.
428+ } else {
429+ slog .Debug ("OpenAI responses WebSocket stream created successfully" , "model" , c .ModelConfig .Model )
430+ return newResponseStreamAdapter (stream , trackUsage ), nil
431+ }
432+ } else if transport == "websocket" {
433+ slog .Debug ("WebSocket transport requested but Gateway is configured, using SSE" ,
434+ "model" , c .ModelConfig .Model ,
435+ "gateway" , c .ModelOptions .Gateway ())
436+ }
437+
438+ client , err := c .clientFn (ctx )
439+ if err != nil {
440+ slog .Error ("Failed to create OpenAI client" , "error" , err )
441+ return nil , err
442+ }
401443 stream := client .Responses .NewStreaming (ctx , params )
402444
403445 slog .Debug ("OpenAI responses stream created successfully" , "model" , c .ModelConfig .Model )
404- return newResponseStreamAdapter (stream , c .ModelConfig .TrackUsage == nil || * c .ModelConfig .TrackUsage ), nil
446+ return newResponseStreamAdapter (stream , trackUsage ), nil
447+ }
448+
449+ // createWebSocketStream sends a request over the pre-initialized WebSocket
450+ // pool, returning a responseEventStream.
451+ func (c * Client ) createWebSocketStream (
452+ ctx context.Context ,
453+ params responses.ResponseNewParams ,
454+ ) (responseEventStream , error ) {
455+ if c .wsPool == nil {
456+ return nil , errors .New ("websocket pool not initialized" )
457+ }
458+
459+ return c .wsPool .Stream (ctx , params )
460+ }
461+
462+ // buildWSHeaderFn returns a function that produces the HTTP headers needed
463+ // for the WebSocket handshake, including the Authorization header.
464+ func (c * Client ) buildWSHeaderFn () func (ctx context.Context ) (http.Header , error ) {
465+ return func (ctx context.Context ) (http.Header , error ) {
466+ h := http.Header {}
467+
468+ // Resolve the API key using the same logic as the HTTP client.
469+ var apiKey string
470+ if c .ModelConfig .TokenKey != "" {
471+ apiKey , _ = c .Env .Get (ctx , c .ModelConfig .TokenKey )
472+ }
473+ if apiKey == "" {
474+ // Fall back to the standard OPENAI_API_KEY env var via the
475+ // environment provider so that secret resolution is
476+ // consistent with the HTTP client path.
477+ apiKey , _ = c .Env .Get (ctx , "OPENAI_API_KEY" )
478+ }
479+ if apiKey != "" {
480+ h .Set ("Authorization" , "Bearer " + apiKey )
481+ }
482+
483+ return h , nil
484+ }
485+ }
486+
487+ // getTransport returns the streaming transport preference from ProviderOpts.
488+ // Valid values are "sse" (default) and "websocket".
489+ func getTransport (cfg * latest.ModelConfig ) string {
490+ if cfg == nil || cfg .ProviderOpts == nil {
491+ return "sse"
492+ }
493+ if t , ok := cfg .ProviderOpts ["transport" ].(string ); ok {
494+ return strings .ToLower (t )
495+ }
496+ return "sse"
405497}
406498
407499func convertMessagesToResponseInput (messages []chat.Message ) []responses.ResponseInputItemUnionParam {
0 commit comments