Skip to content

Commit dc6878c

Browse files
committed
feat: forward sampling provider_opts (top_k, repetition_penalty, etc.) to provider APIs
Add support for passing sampling parameters via provider_opts to all provider backends. This enables custom OpenAI-compatible providers (vLLM, Ollama, llama.cpp) to receive parameters such as top_k, repetition_penalty, min_p, and seed that they support but the native OpenAI API does not.

Provider support:
- OpenAI/custom: top_k, repetition_penalty, min_p, typical_p via SetExtraFields; seed via the native field
- Anthropic: top_k via the native TopK field
- Gemini: top_k via the native TopK field
- Bedrock: top_k via AdditionalModelRequestFields

Also refactors Bedrock's buildAdditionalModelRequestFields to avoid early returns that would discard top_k when the thinking budget is invalid, and extracts isThinkingEnabled() to decouple thinking detection from the presence of additional fields.

Assisted-By: docker-agent
1 parent 88ef828 commit dc6878c

11 files changed

Lines changed: 370 additions & 36 deletions

File tree

agent-schema.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@
552552
},
553553
"provider_opts": {
554554
"type": "object",
555-
"description": "Provider-specific options. dmr: runtime_flags. anthropic/amazon-bedrock (Claude): interleaved_thinking (boolean, default true). openai: transport ('sse' or 'websocket') to choose between SSE and WebSocket streaming for the Responses API. openai/anthropic/google: rerank_prompt (string) to fully override the system prompt used for RAG reranking (advanced - prefer using results.reranking.criteria for domain-specific guidance).",
555+
"description": "Provider-specific options. Sampling parameters: top_k (integer, supported by anthropic, google, amazon-bedrock, and custom OpenAI-compatible providers like vLLM/Ollama), repetition_penalty (float, forwarded to custom OpenAI-compatible providers), min_p (float, forwarded to custom providers), seed (integer, forwarded to OpenAI). Infrastructure options: dmr: runtime_flags. anthropic/amazon-bedrock (Claude): interleaved_thinking (boolean, default true). openai: transport ('sse' or 'websocket') to choose between SSE and WebSocket streaming for the Responses API. openai/anthropic/google: rerank_prompt (string) to fully override the system prompt used for RAG reranking (advanced - prefer using results.reranking.criteria for domain-specific guidance).",
556556
"additionalProperties": true
557557
},
558558
"track_usage": {

examples/sampling-opts.yaml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#!/usr/bin/env docker agent run
2+
3+
# This example shows how to use provider_opts to pass sampling parameters
4+
# like top_k and repetition_penalty to different providers.
5+
6+
agents:
7+
root:
8+
model: gpt
9+
description: "Assistant with custom sampling parameters"
10+
instruction: |
11+
You are a helpful assistant running on a local model with tuned sampling parameters.
12+
13+
models:
14+
gpt:
15+
provider: openai
16+
model: gpt-4o
17+
temperature: 0.7
18+
top_p: 0.9
19+
provider_opts:
20+
top_k: 40
21+
repetition_penalty: 1.15
22+
min_p: 0.05

pkg/model/provider/anthropic/beta_client.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
1414

1515
"github.com/docker/docker-agent/pkg/chat"
16+
"github.com/docker/docker-agent/pkg/model/provider/providerutil"
1617
"github.com/docker/docker-agent/pkg/rag/prompts"
1718
"github.com/docker/docker-agent/pkg/rag/types"
1819
"github.com/docker/docker-agent/pkg/tools"
@@ -115,6 +116,12 @@ func (c *Client) createBetaStream(
115116
"max_tokens", maxTokens,
116117
"message_count", len(params.Messages))
117118

119+
// Forward top_k from provider_opts (Anthropic natively supports it)
120+
if topK, ok := providerutil.GetProviderOptInt64(c.ModelConfig.ProviderOpts, "top_k"); ok {
121+
params.TopK = param.NewOpt(topK)
122+
slog.Debug("Anthropic Beta provider_opts: set top_k", "value", topK)
123+
}
124+
118125
stream := client.Beta.Messages.NewStreaming(ctx, params)
119126
trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage
120127
ad := c.newBetaStreamAdapter(stream, trackUsage)
@@ -293,6 +300,12 @@ func (c *Client) Rerank(ctx context.Context, query string, documents []types.Doc
293300
params.TopP = param.NewOpt(*c.ModelConfig.TopP)
294301
}
295302

303+
// Forward top_k from provider_opts (Anthropic natively supports it)
304+
if topK, ok := providerutil.GetProviderOptInt64(c.ModelConfig.ProviderOpts, "top_k"); ok {
305+
params.TopK = param.NewOpt(topK)
306+
slog.Debug("Anthropic Beta provider_opts: set top_k", "value", topK)
307+
}
308+
296309
// Use streaming API to avoid timeout errors for operations that may take longer than 10 minutes
297310
stream := client.Beta.Messages.NewStreaming(ctx, params)
298311

pkg/model/provider/anthropic/client.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/docker/docker-agent/pkg/httpclient"
2525
"github.com/docker/docker-agent/pkg/model/provider/base"
2626
"github.com/docker/docker-agent/pkg/model/provider/options"
27+
"github.com/docker/docker-agent/pkg/model/provider/providerutil"
2728
"github.com/docker/docker-agent/pkg/tools"
2829
)
2930

@@ -337,6 +338,12 @@ func (c *Client) CreateChatCompletionStream(
337338
slog.Debug("Anthropic extended thinking enabled, ignoring temperature/top_p settings")
338339
}
339340

341+
// Forward top_k from provider_opts (Anthropic natively supports it)
342+
if topK, ok := providerutil.GetProviderOptInt64(c.ModelConfig.ProviderOpts, "top_k"); ok {
343+
params.TopK = param.NewOpt(topK)
344+
slog.Debug("Anthropic provider_opts: set top_k", "value", topK)
345+
}
346+
340347
if len(requestTools) > 0 {
341348
slog.Debug("Adding tools to Anthropic request", "tool_count", len(requestTools))
342349
}

pkg/model/provider/bedrock/client.go

Lines changed: 60 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"github.com/docker/docker-agent/pkg/environment"
2121
"github.com/docker/docker-agent/pkg/model/provider/base"
2222
"github.com/docker/docker-agent/pkg/model/provider/options"
23+
"github.com/docker/docker-agent/pkg/model/provider/providerutil"
2324
"github.com/docker/docker-agent/pkg/modelsdev"
2425
"github.com/docker/docker-agent/pkg/tools"
2526
)
@@ -244,7 +245,7 @@ func (c *Client) buildConverseStreamInput(messages []chat.Message, requestTools
244245
}
245246

246247
// Set inference configuration (temp/topP are suppressed when thinking is on).
247-
input.InferenceConfig = c.buildInferenceConfig(additionalFields != nil)
248+
input.InferenceConfig = c.buildInferenceConfig(c.isThinkingEnabled())
248249

249250
// Convert and set tools
250251
if len(requestTools) > 0 {
@@ -281,56 +282,80 @@ func (c *Client) interleavedThinkingEnabled() bool {
281282
return getProviderOpt[bool](c.ModelConfig.ProviderOpts, "interleaved_thinking")
282283
}
283284

284-
func (c *Client) promptCachingEnabled() bool {
285-
if getProviderOpt[bool](c.ModelConfig.ProviderOpts, "disable_prompt_caching") {
286-
return false
287-
}
288-
return c.cachingSupported
289-
}
290-
291-
// buildAdditionalModelRequestFields configures Claude's extended thinking (reasoning) mode.
292-
func (c *Client) buildAdditionalModelRequestFields() document.Interface {
285+
// isThinkingEnabled returns true if a valid thinking budget is configured.
286+
// It mirrors the validation in buildAdditionalModelRequestFields but without
287+
// side effects (no logging), so it can safely be used to gate inference config.
288+
func (c *Client) isThinkingEnabled() bool {
293289
if c.ModelConfig.ThinkingBudget == nil {
294-
return nil
290+
return false
295291
}
296292
tokens := c.ModelConfig.ThinkingBudget.Tokens
297293
if t, ok := c.ModelConfig.ThinkingBudget.EffortTokens(); ok {
298294
tokens = t
299295
}
300-
if tokens <= 0 {
301-
return nil
302-
}
303-
304-
// Validate minimum (Claude requires at least 1024 tokens for thinking)
305296
if tokens < 1024 {
306-
slog.Warn("Bedrock thinking_budget below minimum (1024), ignoring",
307-
"tokens", tokens)
308-
return nil
297+
return false
309298
}
310-
311-
// Validate against max_tokens
312299
if c.ModelConfig.MaxTokens != nil && tokens >= int(*c.ModelConfig.MaxTokens) {
313-
slog.Warn("Bedrock thinking_budget must be less than max_tokens, ignoring",
314-
"thinking_budget", tokens,
315-
"max_tokens", *c.ModelConfig.MaxTokens)
316-
return nil
300+
return false
317301
}
302+
return true
303+
}
318304

319-
slog.Debug("Bedrock request using thinking_budget", "budget_tokens", tokens)
305+
func (c *Client) promptCachingEnabled() bool {
306+
if getProviderOpt[bool](c.ModelConfig.ProviderOpts, "disable_prompt_caching") {
307+
return false
308+
}
309+
return c.cachingSupported
310+
}
320311

321-
fields := map[string]any{
322-
"thinking": map[string]any{
323-
"type": "enabled",
324-
"budget_tokens": tokens,
325-
},
312+
// buildAdditionalModelRequestFields configures Claude's extended thinking (reasoning) mode
313+
// and forwards supported sampling parameters from provider_opts (e.g. top_k).
314+
func (c *Client) buildAdditionalModelRequestFields() document.Interface {
315+
fields := map[string]any{}
316+
317+
// Forward top_k from provider_opts (Anthropic on Bedrock supports it)
318+
if topK, ok := providerutil.GetProviderOptInt64(c.ModelConfig.ProviderOpts, "top_k"); ok {
319+
fields["top_k"] = topK
320+
slog.Debug("Bedrock provider_opts: set top_k", "value", topK)
326321
}
327322

328-
// Add anthropic_beta field for interleaved thinking
329-
if c.interleavedThinkingEnabled() {
330-
fields["anthropic_beta"] = []string{"interleaved-thinking-2025-05-14"}
331-
slog.Debug("Bedrock request using interleaved thinking beta")
323+
// Configure thinking budget if present and valid
324+
if budget := c.ModelConfig.ThinkingBudget; budget != nil {
325+
tokens := budget.Tokens
326+
if t, ok := budget.EffortTokens(); ok {
327+
tokens = t
328+
}
329+
330+
valid := tokens > 0
331+
if valid && tokens < 1024 {
332+
slog.Warn("Bedrock thinking_budget below minimum (1024), ignoring", "tokens", tokens)
333+
valid = false
334+
}
335+
if valid && c.ModelConfig.MaxTokens != nil && tokens >= int(*c.ModelConfig.MaxTokens) {
336+
slog.Warn("Bedrock thinking_budget must be less than max_tokens, ignoring",
337+
"thinking_budget", tokens,
338+
"max_tokens", *c.ModelConfig.MaxTokens)
339+
valid = false
340+
}
341+
342+
if valid {
343+
slog.Debug("Bedrock request using thinking_budget", "budget_tokens", tokens)
344+
fields["thinking"] = map[string]any{
345+
"type": "enabled",
346+
"budget_tokens": tokens,
347+
}
348+
349+
if c.interleavedThinkingEnabled() {
350+
fields["anthropic_beta"] = []string{"interleaved-thinking-2025-05-14"}
351+
slog.Debug("Bedrock request using interleaved thinking beta")
352+
}
353+
}
332354
}
333355

356+
if len(fields) == 0 {
357+
return nil
358+
}
334359
return document.NewLazyDocument(fields)
335360
}
336361

pkg/model/provider/gemini/client.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"github.com/docker/docker-agent/pkg/httpclient"
2222
"github.com/docker/docker-agent/pkg/model/provider/base"
2323
"github.com/docker/docker-agent/pkg/model/provider/options"
24+
"github.com/docker/docker-agent/pkg/model/provider/providerutil"
2425
"github.com/docker/docker-agent/pkg/rag/prompts"
2526
"github.com/docker/docker-agent/pkg/rag/types"
2627
"github.com/docker/docker-agent/pkg/tools"
@@ -352,6 +353,12 @@ func (c *Client) buildConfig() *genai.GenerateContentConfig {
352353
config.PresencePenalty = new(float32(*c.ModelConfig.PresencePenalty))
353354
}
354355

356+
// Forward top_k from provider_opts (Gemini natively supports it)
357+
if topK, ok := providerutil.GetProviderOptFloat64(c.ModelConfig.ProviderOpts, "top_k"); ok {
358+
config.TopK = new(float32(topK))
359+
slog.Debug("Gemini provider_opts: set top_k", "value", topK)
360+
}
361+
355362
// Apply thinking configuration for Gemini models.
356363
// See https://ai.google.dev/gemini-api/docs/thinking
357364
if c.ModelOptions.NoThinking() {

pkg/model/provider/openai/client.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,11 @@ func (c *Client) CreateChatCompletionStream(
312312
return nil, err
313313
}
314314

315+
// Forward sampling-related provider_opts as extra body fields.
316+
// This allows custom/OpenAI-compatible providers (vLLM, Ollama, etc.)
317+
// to receive parameters like top_k, repetition_penalty, etc.
318+
applySamplingProviderOpts(&params, c.ModelConfig.ProviderOpts)
319+
315320
stream := client.Chat.Completions.NewStreaming(ctx, params)
316321

317322
slog.Debug("OpenAI chat completion stream created successfully", "model", c.ModelConfig.Model)
@@ -842,6 +847,8 @@ func (c *Client) Rerank(ctx context.Context, query string, documents []types.Doc
842847
},
843848
}
844849

850+
applySamplingProviderOpts(&params, c.ModelConfig.ProviderOpts)
851+
845852
resp, err := client.Chat.Completions.New(ctx, params)
846853
if err != nil {
847854
slog.Error("OpenAI rerank request failed", "error", err)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package openai
2+
3+
import (
4+
"log/slog"
5+
6+
oai "github.com/openai/openai-go/v3"
7+
8+
"github.com/docker/docker-agent/pkg/model/provider/providerutil"
9+
)
10+
11+
// applySamplingProviderOpts forwards sampling-related provider_opts as extra
12+
// body fields on the OpenAI ChatCompletionNewParams. This enables custom
13+
// OpenAI-compatible providers (vLLM, Ollama, llama.cpp, etc.) to receive
14+
// parameters like top_k, repetition_penalty, min_p, etc. that the native
15+
// OpenAI API does not support but these backends do.
16+
func applySamplingProviderOpts(params *oai.ChatCompletionNewParams, opts map[string]any) {
17+
if len(opts) == 0 {
18+
return
19+
}
20+
21+
extras := make(map[string]any)
22+
23+
for _, key := range providerutil.SamplingProviderOptsKeys() {
24+
if key == "seed" {
25+
// seed is a native ChatCompletionNewParams field (int64),
26+
// so set it directly rather than as an extra field.
27+
if v, ok := providerutil.GetProviderOptInt64(opts, key); ok {
28+
params.Seed = oai.Int(v)
29+
slog.Debug("OpenAI provider_opts: set seed", "value", v)
30+
}
31+
continue
32+
}
33+
34+
if v, ok := providerutil.GetProviderOptFloat64(opts, key); ok {
35+
extras[key] = v
36+
slog.Debug("OpenAI provider_opts: forwarding sampling param", "key", key, "value", v)
37+
} else if vi, ok := providerutil.GetProviderOptInt64(opts, key); ok {
38+
extras[key] = vi
39+
slog.Debug("OpenAI provider_opts: forwarding sampling param", "key", key, "value", vi)
40+
}
41+
}
42+
43+
if len(extras) > 0 {
44+
params.SetExtraFields(extras)
45+
}
46+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package openai
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
7+
oai "github.com/openai/openai-go/v3"
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestApplySamplingProviderOpts(t *testing.T) {
13+
tests := []struct {
14+
name string
15+
opts map[string]any
16+
wantKeys []string // keys expected in JSON output
17+
}{
18+
{
19+
name: "nil opts",
20+
opts: nil,
21+
},
22+
{
23+
name: "empty opts",
24+
opts: map[string]any{},
25+
},
26+
{
27+
name: "top_k forwarded",
28+
opts: map[string]any{"top_k": 40},
29+
wantKeys: []string{"top_k"},
30+
},
31+
{
32+
name: "repetition_penalty forwarded",
33+
opts: map[string]any{"repetition_penalty": 1.15},
34+
wantKeys: []string{"repetition_penalty"},
35+
},
36+
{
37+
name: "multiple sampling opts",
38+
opts: map[string]any{"top_k": 50, "repetition_penalty": 1.1, "min_p": 0.05},
39+
wantKeys: []string{"top_k", "repetition_penalty", "min_p"},
40+
},
41+
{
42+
name: "non-sampling opts ignored",
43+
opts: map[string]any{"api_type": "openai_chatcompletions", "transport": "websocket"},
44+
},
45+
{
46+
name: "seed set natively",
47+
opts: map[string]any{"seed": 42},
48+
wantKeys: []string{"seed"},
49+
},
50+
}
51+
52+
for _, tt := range tests {
53+
t.Run(tt.name, func(t *testing.T) {
54+
params := oai.ChatCompletionNewParams{
55+
Model: "test-model",
56+
}
57+
applySamplingProviderOpts(&params, tt.opts)
58+
59+
// Marshal to JSON and check for expected keys
60+
data, err := json.Marshal(params)
61+
require.NoError(t, err)
62+
63+
var m map[string]any
64+
require.NoError(t, json.Unmarshal(data, &m))
65+
66+
for _, key := range tt.wantKeys {
67+
assert.Contains(t, m, key, "expected key %q in JSON output", key)
68+
}
69+
70+
// Non-sampling keys should never appear
71+
assert.NotContains(t, m, "api_type")
72+
assert.NotContains(t, m, "transport")
73+
})
74+
}
75+
}

0 commit comments

Comments
 (0)