Commit b22ccaf

Add WebSocket transport for OpenAI Responses API streaming
Introduce an optional WebSocket transport as an alternative to SSE for the OpenAI Responses API. Users can enable it via provider_opts:

    provider_opts:
      transport: websocket

Key changes:

- Add responseEventStream interface to abstract SSE and WebSocket transports
- Refactor ResponseStreamAdapter to accept any responseEventStream
- Implement wsStream (WebSocket transport) and wsPool (connection pool with 55-min TTL, auto-reconnect, and lastResponseID tracking)
- Integrate WebSocket path in CreateResponseStream with automatic SSE fallback on connection failure
- No new dependencies (reuses existing gorilla/websocket)

The existing ResponseStreamAdapter.Recv() logic is fully reused since WebSocket events use the same JSON schema as SSE events.

Assisted-By: docker-agent
1 parent b729670 commit b22ccaf

9 files changed

Lines changed: 1040 additions & 14 deletions
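The wsStream and wsPool files mentioned in the commit message are among the nine changed files but are not included in the excerpt below. As a rough, illustrative sketch of the pool described above (the struct name comes from the commit; all fields, methods, and the dial callback are assumptions, not the actual implementation):

```go
package openai

import (
	"sync"
	"time"

	"github.com/gorilla/websocket"
)

// wsPool keeps a single WebSocket connection alive across requests so that
// consecutive tool-call rounds reuse the same transport. Sketch only: the
// real pool also tracks lastResponseID and reconnects on read errors.
type wsPool struct {
	mu             sync.Mutex
	conn           *websocket.Conn
	connectedAt    time.Time
	ttl            time.Duration // 55 * time.Minute per the commit message
	lastResponseID string        // lets follow-up requests reference prior server state
}

// ensureConn returns a live connection, redialing when none exists or the
// TTL has elapsed. A dial failure is returned to the caller, which then
// falls back to SSE.
func (p *wsPool) ensureConn(dial func() (*websocket.Conn, error)) (*websocket.Conn, error) {
	p.mu.Lock()
	defer p.mu.Unlock()

	if p.conn == nil || time.Since(p.connectedAt) > p.ttl {
		if p.conn != nil {
			_ = p.conn.Close() // drop the expired connection before redialing
		}
		c, err := dial()
		if err != nil {
			return nil, err
		}
		p.conn = c
		p.connectedAt = time.Now()
	}
	return p.conn, nil
}
```

The 55-minute TTL presumably keeps the pooled connection comfortably inside the server-side connection lifetime, and surfacing the dial error lets CreateResponseStream fall back to SSE, as shown in the client.go diff below.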


agent-schema.json

Lines changed: 1 addition & 1 deletion
@@ -547,7 +547,7 @@
     },
     "provider_opts": {
       "type": "object",
-      "description": "Provider-specific options. dmr: runtime_flags. anthropic/amazon-bedrock (Claude): interleaved_thinking (boolean, default true). openai/anthropic/google: rerank_prompt (string) to fully override the system prompt used for RAG reranking (advanced - prefer using results.reranking.criteria for domain-specific guidance).",
+      "description": "Provider-specific options. dmr: runtime_flags. anthropic/amazon-bedrock (Claude): interleaved_thinking (boolean, default true). openai: transport ('sse' or 'websocket') to choose between SSE and WebSocket streaming for the Responses API. openai/anthropic/google: rerank_prompt (string) to fully override the system prompt used for RAG reranking (advanced - prefer using results.reranking.criteria for domain-specific guidance).",
       "additionalProperties": true
     },
     "track_usage": {

docs/providers/openai/index.md

Lines changed: 30 additions & 0 deletions
@@ -77,3 +77,33 @@ models:
     model: gpt-4o
     base_url: https://your-proxy.example.com/v1
 ```
+
+## WebSocket Transport
+
+For OpenAI Responses API models (gpt-4.1+, o-series, gpt-5), you can use WebSocket streaming instead of the default SSE (Server-Sent Events):
+
+```yaml
+models:
+  fast-gpt:
+    provider: openai
+    model: gpt-4.1
+    provider_opts:
+      transport: websocket # Use WebSocket instead of SSE
+```
+
+### Benefits
+
+- **~40% faster** for workflows with 20+ tool calls
+- **Persistent connection** reduces per-turn overhead
+- **Server-side caching** of connection state
+- **Automatic fallback** to SSE if WebSocket fails
+
+### Requirements
+
+- Only works with Responses API models: `gpt-4.1+`, `o1`, `o3`, `o4`, `gpt-5`
+- NOT compatible with `--gateway` flag (automatically falls back to SSE)
+- Requires `OPENAI_API_KEY` environment variable
+
+### Example
+
+See [`examples/websocket_transport.yaml`]({{ '/examples/websocket_transport/' | relative_url }}) for a complete example.

examples/websocket_transport.yaml

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+#!/usr/bin/env docker agent run
+
+# Example: WebSocket Transport for OpenAI Responses API
+#
+# This example demonstrates how to use WebSocket streaming instead of
+# Server-Sent Events (SSE) for the OpenAI Responses API.
+#
+# WebSocket transport maintains a persistent connection across tool-call
+# rounds, reducing per-turn overhead and improving end-to-end latency
+# for agentic workflows with many tool calls.
+#
+# Benefits of WebSocket over SSE:
+# - ~40% faster end-to-end execution for workflows with 20+ tool calls
+# - Persistent connection reduces per-turn continuation overhead
+# - Connection-local state caching on the server
+# - Falls back to SSE automatically if WebSocket connection fails
+#
+# Requirements:
+# - Works only with OpenAI Responses API models (gpt-4.1+, o-series, gpt-5)
+# - Requires OPENAI_API_KEY environment variable (or use token_key)
+# - NOT compatible with --gateway flag (automatically falls back to SSE)
+#
+# Run with:
+#   docker agent run websocket_transport.yaml
+
+models:
+  gpt-ws:
+    provider: openai
+    model: gpt-4.1
+    provider_opts:
+      transport: websocket # Use WebSocket instead of SSE
+
+agents:
+  root:
+    model: gpt-ws
+    description: Assistant using WebSocket streaming
+    instruction: |
+      You are a helpful assistant. Answer questions concisely.
+    toolsets:
+      - type: shell # Real toolset for demonstrating multi-turn tool calls
+commands:
+  demo: "List the files in the current directory, then count how many are YAML files"

pkg/model/provider/openai/client.go

Lines changed: 87 additions & 9 deletions
@@ -7,7 +7,9 @@ import (
     "errors"
     "fmt"
     "log/slog"
+    "net/http"
     "net/url"
+    "os"
     "strings"

     "github.com/openai/openai-go/v3"
@@ -29,12 +31,16 @@
     "github.com/docker/docker-agent/pkg/tools"
 )

-// Client represents an OpenAI client wrapper
-// It implements the provider.Provider interface
+// Client represents an OpenAI client wrapper.
+// It implements the provider.Provider interface.
 type Client struct {
     base.Config

     clientFn func(context.Context) (*openai.Client, error)
+
+    // wsPool is lazily initialized when transport=websocket is configured.
+    // It maintains a persistent WebSocket connection across requests.
+    wsPool *wsPool
 }

 // NewClient creates a new OpenAI client from the provided configuration
@@ -307,12 +313,6 @@ func (c *Client) CreateResponseStream(
         return nil, errors.New("at least one message is required")
     }

-    client, err := c.clientFn(ctx)
-    if err != nil {
-        slog.Error("Failed to create OpenAI client", "error", err)
-        return nil, err
-    }
-
     input := convertMessagesToResponseInput(messages)

     params := responses.ResponseNewParams{
@@ -398,10 +398,88 @@
         slog.Error("Failed to marshal OpenAI responses request to JSON", "error", err)
     }

+    // Choose transport: WebSocket or SSE (default).
+    // WebSocket is disabled when using a Gateway since most gateways don't support it.
+    transport := getTransport(&c.ModelConfig)
+    trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage
+
+    if transport == "websocket" && c.ModelOptions.Gateway() == "" {
+        stream, err := c.createWebSocketStream(ctx, params)
+        if err != nil {
+            slog.Error("WebSocket stream failed, falling back to SSE", "error", err)
+            // Fall through to SSE below.
+        } else {
+            slog.Debug("OpenAI responses WebSocket stream created successfully", "model", c.ModelConfig.Model)
+            return newResponseStreamAdapter(stream, trackUsage), nil
+        }
+    } else if transport == "websocket" {
+        slog.Debug("WebSocket transport requested but Gateway is configured, using SSE",
+            "model", c.ModelConfig.Model,
+            "gateway", c.ModelOptions.Gateway())
+    }
+
+    client, err := c.clientFn(ctx)
+    if err != nil {
+        slog.Error("Failed to create OpenAI client", "error", err)
+        return nil, err
+    }
     stream := client.Responses.NewStreaming(ctx, params)

     slog.Debug("OpenAI responses stream created successfully", "model", c.ModelConfig.Model)
-    return newResponseStreamAdapter(stream, c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage), nil
+    return newResponseStreamAdapter(stream, trackUsage), nil
+}
+
+// createWebSocketStream initializes (or reuses) a WebSocket connection and
+// sends the response.create message, returning a responseEventStream.
+func (c *Client) createWebSocketStream(
+    ctx context.Context,
+    params responses.ResponseNewParams,
+) (responseEventStream, error) {
+    if c.wsPool == nil {
+        // Lazy-init the pool on first WebSocket call.
+        baseURL := cmp.Or(c.ModelConfig.BaseURL, "https://api.openai.com/v1")
+        wsURL := httpToWSURL(baseURL)
+
+        headerFn := c.buildWSHeaderFn()
+        c.wsPool = newWSPool(wsURL, headerFn)
+    }
+
+    return c.wsPool.Stream(ctx, params)
+}
+
+// buildWSHeaderFn returns a function that produces the HTTP headers needed
+// for the WebSocket handshake, including the Authorization header.
+func (c *Client) buildWSHeaderFn() func(ctx context.Context) (http.Header, error) {
+    return func(ctx context.Context) (http.Header, error) {
+        h := http.Header{}
+
+        // Resolve the API key using the same logic as the HTTP client.
+        var apiKey string
+        if c.ModelConfig.TokenKey != "" {
+            apiKey, _ = c.Env.Get(ctx, c.ModelConfig.TokenKey)
+        }
+        if apiKey == "" {
+            // Fall back to the standard OPENAI_API_KEY env var.
+            apiKey = os.Getenv("OPENAI_API_KEY")
+        }
+        if apiKey != "" {
+            h.Set("Authorization", "Bearer "+apiKey)
+        }
+
+        return h, nil
+    }
+}
+
+// getTransport returns the streaming transport preference from ProviderOpts.
+// Valid values are "sse" (default) and "websocket".
+func getTransport(cfg *latest.ModelConfig) string {
+    if cfg == nil || cfg.ProviderOpts == nil {
+        return "sse"
+    }
+    if t, ok := cfg.ProviderOpts["transport"].(string); ok {
+        return strings.ToLower(t)
+    }
+    return "sse"
 }

 func convertMessagesToResponseInput(messages []chat.Message) []responses.ResponseInputItemUnionParam {
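The client code above calls httpToWSURL and newWSPool, which live in files not shown in this excerpt. A minimal sketch of what the URL conversion might look like, assuming the only transformation needed is swapping the scheme (the real helper may also append a streaming path):

```go
package openai

import "strings"

// httpToWSURL converts an HTTP(S) base URL into its WebSocket equivalent.
// Sketch only: derived from the call site above, not the actual implementation.
func httpToWSURL(baseURL string) string {
	switch {
	case strings.HasPrefix(baseURL, "https://"):
		return "wss://" + strings.TrimPrefix(baseURL, "https://")
	case strings.HasPrefix(baseURL, "http://"):
		return "ws://" + strings.TrimPrefix(baseURL, "http://")
	default:
		return baseURL // assume it is already a ws:// or wss:// URL
	}
}
```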
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+package openai
+
+import "github.com/openai/openai-go/v3/responses"
+
+// responseEventStream abstracts over SSE and WebSocket transports for
+// streaming Responses API events.
+//
+// The ssestream.Stream[responses.ResponseStreamEventUnion] type already
+// satisfies this interface, so it can be used directly.
+type responseEventStream interface {
+    // Next advances the stream to the next event.
+    // Returns false when the stream is exhausted or an error occurred.
+    Next() bool
+
+    // Current returns the most recently decoded event.
+    Current() responses.ResponseStreamEventUnion
+
+    // Err returns the first non-EOF error encountered by the stream.
+    Err() error
+
+    // Close releases resources held by the stream.
+    Close() error
+}
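The wsStream type named in the commit message is the WebSocket-side implementation of this interface; its file is not part of this excerpt. A minimal sketch of how a gorilla/websocket-backed stream could satisfy responseEventStream, with the framing assumed to be one JSON event per message (the real wsStream also handles reconnects and terminal events):

```go
package openai

import (
	"encoding/json"

	"github.com/gorilla/websocket"
	"github.com/openai/openai-go/v3/responses"
)

// wsStream decodes Responses API events from a WebSocket connection.
// Sketch only: field names and message framing are assumptions.
type wsStream struct {
	conn    *websocket.Conn
	current responses.ResponseStreamEventUnion
	err     error
}

// Next reads the next frame and decodes it into the same event union the
// SSE stream yields, which is why ResponseStreamAdapter.Recv() is unchanged.
func (s *wsStream) Next() bool {
	if s.err != nil {
		return false
	}
	_, data, err := s.conn.ReadMessage()
	if err != nil {
		s.err = err
		return false
	}
	if err := json.Unmarshal(data, &s.current); err != nil {
		s.err = err
		return false
	}
	return true
}

func (s *wsStream) Current() responses.ResponseStreamEventUnion { return s.current }

func (s *wsStream) Err() error { return s.err }

func (s *wsStream) Close() error { return s.conn.Close() }
```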

pkg/model/provider/openai/response_stream.go

Lines changed: 8 additions & 4 deletions
@@ -13,15 +13,19 @@ import (
     "github.com/docker/docker-agent/pkg/tools"
 )

-// ResponseStreamAdapter adapts the OpenAI responses stream to our interface
+// Compile-time check: ssestream.Stream satisfies responseEventStream.
+var _ responseEventStream = (*ssestream.Stream[responses.ResponseStreamEventUnion])(nil)
+
+// ResponseStreamAdapter adapts the OpenAI responses stream to our interface.
+// It works with any responseEventStream implementation (SSE or WebSocket).
 type ResponseStreamAdapter struct {
-    stream *ssestream.Stream[responses.ResponseStreamEventUnion]
+    stream responseEventStream
     trackUsage bool
     itemCallIDMap map[string]string
     itemHasContent map[string]bool
 }

-func newResponseStreamAdapter(stream *ssestream.Stream[responses.ResponseStreamEventUnion], trackUsage bool) *ResponseStreamAdapter {
+func newResponseStreamAdapter(stream responseEventStream, trackUsage bool) *ResponseStreamAdapter {
     return &ResponseStreamAdapter{
         stream: stream,
         trackUsage: trackUsage,
@@ -254,5 +258,5 @@ func (a *ResponseStreamAdapter) Recv() (chat.MessageStreamResponse, error) {

 // Close closes the stream
 func (a *ResponseStreamAdapter) Close() {
-    a.stream.Close()
+    _ = a.stream.Close()
 }
