
Commit 53a35dc

Merge pull request #2186 from dgageot/ws
Add WebSocket transport for OpenAI Responses API streaming
2 parents b729670 + 556f27e

10 files changed

Lines changed: 1315 additions & 16 deletions


agent-schema.json

Lines changed: 1 addition & 1 deletion
```diff
@@ -547,7 +547,7 @@
     },
     "provider_opts": {
       "type": "object",
-      "description": "Provider-specific options. dmr: runtime_flags. anthropic/amazon-bedrock (Claude): interleaved_thinking (boolean, default true). openai/anthropic/google: rerank_prompt (string) to fully override the system prompt used for RAG reranking (advanced - prefer using results.reranking.criteria for domain-specific guidance).",
+      "description": "Provider-specific options. dmr: runtime_flags. anthropic/amazon-bedrock (Claude): interleaved_thinking (boolean, default true). openai: transport ('sse' or 'websocket') to choose between SSE and WebSocket streaming for the Responses API. openai/anthropic/google: rerank_prompt (string) to fully override the system prompt used for RAG reranking (advanced - prefer using results.reranking.criteria for domain-specific guidance).",
       "additionalProperties": true
     },
     "track_usage": {
```

docs/providers/openai/index.md

Lines changed: 30 additions & 0 deletions
````diff
@@ -77,3 +77,33 @@ models:
     model: gpt-4o
     base_url: https://your-proxy.example.com/v1
 ```
+
+## WebSocket Transport
+
+For OpenAI Responses API models (gpt-4.1+, o-series, gpt-5), you can use WebSocket streaming instead of the default SSE (Server-Sent Events):
+
+```yaml
+models:
+  fast-gpt:
+    provider: openai
+    model: gpt-4.1
+    provider_opts:
+      transport: websocket # Use WebSocket instead of SSE
+```
+
+### Benefits
+
+- **~40% faster** for workflows with 20+ tool calls
+- **Persistent connection** reduces per-turn overhead
+- **Server-side caching** of connection state
+- **Automatic fallback** to SSE if WebSocket fails
+
+### Requirements
+
+- Only works with Responses API models: `gpt-4.1+`, `o1`, `o3`, `o4`, `gpt-5`
+- NOT compatible with `--gateway` flag (automatically falls back to SSE)
+- Requires `OPENAI_API_KEY` environment variable
+
+### Example
+
+See [`examples/websocket_transport.yaml`]({{ '/examples/websocket_transport/' | relative_url }}) for a complete example.
````

examples/websocket_transport.yaml

Lines changed: 42 additions & 0 deletions
```yaml
#!/usr/bin/env docker agent run

# Example: WebSocket Transport for OpenAI Responses API
#
# This example demonstrates how to use WebSocket streaming instead of
# Server-Sent Events (SSE) for the OpenAI Responses API.
#
# WebSocket transport maintains a persistent connection across tool-call
# rounds, reducing per-turn overhead and improving end-to-end latency
# for agentic workflows with many tool calls.
#
# Benefits of WebSocket over SSE:
# - ~40% faster end-to-end execution for workflows with 20+ tool calls
# - Persistent connection reduces per-turn continuation overhead
# - Connection-local state caching on the server
# - Falls back to SSE automatically if WebSocket connection fails
#
# Requirements:
# - Works only with OpenAI Responses API models (gpt-4.1+, o-series, gpt-5)
# - Requires OPENAI_API_KEY environment variable (or use token_key)
# - NOT compatible with --gateway flag (automatically falls back to SSE)
#
# Run with:
#   docker agent run websocket_transport.yaml

models:
  gpt-ws:
    provider: openai
    model: gpt-4.1
    provider_opts:
      transport: websocket # Use WebSocket instead of SSE

agents:
  root:
    model: gpt-ws
    description: Assistant using WebSocket streaming
    instruction: |
      You are a helpful assistant. Answer questions concisely.
    toolsets:
      - type: shell # Real toolset for demonstrating multi-turn tool calls
commands:
  demo: "List the files in the current directory, then count how many are YAML files"
```

pkg/model/provider/openai/client.go

Lines changed: 103 additions & 11 deletions
```diff
@@ -7,6 +7,7 @@ import (
 	"errors"
 	"fmt"
 	"log/slog"
+	"net/http"
 	"net/url"
 	"strings"
 
@@ -29,12 +30,16 @@ import (
 	"github.com/docker/docker-agent/pkg/tools"
 )
 
-// Client represents an OpenAI client wrapper
-// It implements the provider.Provider interface
+// Client represents an OpenAI client wrapper.
+// It implements the provider.Provider interface.
 type Client struct {
 	base.Config
 
 	clientFn func(context.Context) (*openai.Client, error)
+
+	// wsPool is initialized in NewClient when transport=websocket is configured.
+	// It maintains a persistent WebSocket connection across requests.
+	wsPool *wsPool
 }
 
 // NewClient creates a new OpenAI client from the provided configuration
@@ -140,14 +145,32 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
 	slog.Debug("OpenAI client created successfully", "model", cfg.Model)
 
-	return &Client{
+	client := &Client{
 		Config: base.Config{
 			ModelConfig:  *cfg,
 			ModelOptions: globalOptions,
 			Env:          env,
 		},
 		clientFn: clientFn,
-	}, nil
+	}
+
+	// Pre-create the WebSocket pool when the transport is configured.
+	// The pool is cheap (no connections opened until the first Stream call)
+	// and eager init avoids a data race on the lazy path.
+	if getTransport(cfg) == "websocket" && globalOptions.Gateway() == "" {
+		baseURL := cmp.Or(cfg.BaseURL, "https://api.openai.com/v1")
+		client.wsPool = newWSPool(httpToWSURL(baseURL), client.buildWSHeaderFn())
+	}
+
+	return client, nil
+}
+
+// Close releases resources held by the client, including any pooled WebSocket
+// connections. It is safe to call Close multiple times.
+func (c *Client) Close() {
+	if c.wsPool != nil {
+		c.wsPool.Close()
+	}
 }
 
 // convertMessages converts chat.Message to openai.ChatCompletionMessageParamUnion
@@ -307,12 +330,6 @@ func (c *Client) CreateResponseStream(
 		return nil, errors.New("at least one message is required")
 	}
 
-	client, err := c.clientFn(ctx)
-	if err != nil {
-		slog.Error("Failed to create OpenAI client", "error", err)
-		return nil, err
-	}
-
 	input := convertMessagesToResponseInput(messages)
 
 	params := responses.ResponseNewParams{
@@ -398,10 +415,85 @@
 		slog.Error("Failed to marshal OpenAI responses request to JSON", "error", err)
 	}
 
+	// Choose transport: WebSocket or SSE (default).
+	// WebSocket is disabled when using a Gateway since most gateways don't support it.
+	transport := getTransport(&c.ModelConfig)
+	trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage
+
+	if transport == "websocket" && c.ModelOptions.Gateway() == "" {
+		stream, err := c.createWebSocketStream(ctx, params)
+		if err != nil {
+			slog.Warn("WebSocket stream failed, falling back to SSE", "error", err)
+			// Fall through to SSE below.
+		} else {
+			slog.Debug("OpenAI responses WebSocket stream created successfully", "model", c.ModelConfig.Model)
+			return newResponseStreamAdapter(stream, trackUsage), nil
+		}
+	} else if transport == "websocket" {
+		slog.Debug("WebSocket transport requested but Gateway is configured, using SSE",
+			"model", c.ModelConfig.Model,
+			"gateway", c.ModelOptions.Gateway())
+	}
+
+	client, err := c.clientFn(ctx)
+	if err != nil {
+		slog.Error("Failed to create OpenAI client", "error", err)
+		return nil, err
+	}
 	stream := client.Responses.NewStreaming(ctx, params)
 
 	slog.Debug("OpenAI responses stream created successfully", "model", c.ModelConfig.Model)
-	return newResponseStreamAdapter(stream, c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage), nil
+	return newResponseStreamAdapter(stream, trackUsage), nil
+}
+
+// createWebSocketStream sends a request over the pre-initialized WebSocket
+// pool, returning a responseEventStream.
+func (c *Client) createWebSocketStream(
+	ctx context.Context,
+	params responses.ResponseNewParams,
+) (responseEventStream, error) {
+	if c.wsPool == nil {
+		return nil, errors.New("websocket pool not initialized")
+	}
+
+	return c.wsPool.Stream(ctx, params)
+}
+
+// buildWSHeaderFn returns a function that produces the HTTP headers needed
+// for the WebSocket handshake, including the Authorization header.
+func (c *Client) buildWSHeaderFn() func(ctx context.Context) (http.Header, error) {
+	return func(ctx context.Context) (http.Header, error) {
+		h := http.Header{}
+
+		// Resolve the API key using the same logic as the HTTP client.
+		var apiKey string
+		if c.ModelConfig.TokenKey != "" {
+			apiKey, _ = c.Env.Get(ctx, c.ModelConfig.TokenKey)
+		}
+		if apiKey == "" {
+			// Fall back to the standard OPENAI_API_KEY env var via the
+			// environment provider so that secret resolution is
+			// consistent with the HTTP client path.
+			apiKey, _ = c.Env.Get(ctx, "OPENAI_API_KEY")
+		}
+		if apiKey != "" {
+			h.Set("Authorization", "Bearer "+apiKey)
+		}
+
+		return h, nil
+	}
+}
+
+// getTransport returns the streaming transport preference from ProviderOpts.
+// Valid values are "sse" (default) and "websocket".
+func getTransport(cfg *latest.ModelConfig) string {
+	if cfg == nil || cfg.ProviderOpts == nil {
+		return "sse"
+	}
+	if t, ok := cfg.ProviderOpts["transport"].(string); ok {
+		return strings.ToLower(t)
+	}
+	return "sse"
+}
 
 func convertMessagesToResponseInput(messages []chat.Message) []responses.ResponseInputItemUnionParam {
```
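
The `getTransport` helper above is the single switch point for the new behavior, so it is a natural unit-test target. A table-driven test might look like the following sketch; the `latest` import path is assumed (the diff only shows the package alias), as is `ProviderOpts` being a `map[string]any`, which the type assertion in `getTransport` suggests:

```go
package openai

import (
	"testing"

	// Assumed import path; the diff only shows the alias "latest".
	latest "github.com/docker/docker-agent/pkg/config/latest"
)

func TestGetTransport(t *testing.T) {
	cases := []struct {
		name string
		cfg  *latest.ModelConfig
		want string
	}{
		{"nil config defaults to sse", nil, "sse"},
		{"missing provider_opts defaults to sse", &latest.ModelConfig{}, "sse"},
		{"websocket selected", &latest.ModelConfig{
			ProviderOpts: map[string]any{"transport": "websocket"},
		}, "websocket"},
		{"value is lowercased", &latest.ModelConfig{
			ProviderOpts: map[string]any{"transport": "WebSocket"},
		}, "websocket"},
		{"non-string value falls back to sse", &latest.ModelConfig{
			ProviderOpts: map[string]any{"transport": 42},
		}, "sse"},
	}

	for _, tc := range cases {
		if got := getTransport(tc.cfg); got != tc.want {
			t.Errorf("%s: getTransport() = %q, want %q", tc.name, got, tc.want)
		}
	}
}
```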
Lines changed: 23 additions & 0 deletions
```go
package openai

import "github.com/openai/openai-go/v3/responses"

// responseEventStream abstracts over SSE and WebSocket transports for
// streaming Responses API events.
//
// The ssestream.Stream[responses.ResponseStreamEventUnion] type already
// satisfies this interface, so it can be used directly.
type responseEventStream interface {
	// Next advances the stream to the next event.
	// Returns false when the stream is exhausted or an error occurred.
	Next() bool

	// Current returns the most recently decoded event.
	Current() responses.ResponseStreamEventUnion

	// Err returns the first non-EOF error encountered by the stream.
	Err() error

	// Close releases resources held by the stream.
	Close() error
}
```
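
To make the contract concrete, here is a minimal sketch of how a caller drains any `responseEventStream`, regardless of which transport backs it. The helper is hypothetical (not part of this commit), and it assumes `ResponseStreamEventUnion` exposes a `Type` field as in the openai-go SDK:

```go
package openai

import "fmt"

// drainEvents is a hypothetical helper showing the consumption pattern:
// iterate with Next, read with Current, check Err once the loop ends.
func drainEvents(s responseEventStream) error {
	defer s.Close() // release the SSE decoder or pooled WebSocket connection

	for s.Next() { // false once the stream is exhausted or fails
		ev := s.Current()    // most recently decoded event
		fmt.Println(ev.Type) // e.g. "response.output_text.delta" (assumed field)
	}
	return s.Err() // first non-EOF error; nil on a clean end of stream
}
```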

pkg/model/provider/openai/response_stream.go

Lines changed: 8 additions & 4 deletions
```diff
@@ -13,15 +13,19 @@ import (
 	"github.com/docker/docker-agent/pkg/tools"
 )
 
-// ResponseStreamAdapter adapts the OpenAI responses stream to our interface
+// Compile-time check: ssestream.Stream satisfies responseEventStream.
+var _ responseEventStream = (*ssestream.Stream[responses.ResponseStreamEventUnion])(nil)
+
+// ResponseStreamAdapter adapts the OpenAI responses stream to our interface.
+// It works with any responseEventStream implementation (SSE or WebSocket).
 type ResponseStreamAdapter struct {
-	stream         *ssestream.Stream[responses.ResponseStreamEventUnion]
+	stream         responseEventStream
 	trackUsage     bool
 	itemCallIDMap  map[string]string
 	itemHasContent map[string]bool
 }
 
-func newResponseStreamAdapter(stream *ssestream.Stream[responses.ResponseStreamEventUnion], trackUsage bool) *ResponseStreamAdapter {
+func newResponseStreamAdapter(stream responseEventStream, trackUsage bool) *ResponseStreamAdapter {
 	return &ResponseStreamAdapter{
 		stream:     stream,
 		trackUsage: trackUsage,
@@ -254,5 +258,5 @@ func (a *ResponseStreamAdapter) Recv() (chat.MessageStreamResponse, error) {
 
 // Close closes the stream
 func (a *ResponseStreamAdapter) Close() {
-	a.stream.Close()
+	_ = a.stream.Close()
 }
```
