
Commit f92ffe7

Merge pull request #2226 from dgageot/board/docker-agent-issue-2220-feasibility-chec-c088f140
feat: forward sampling provider_opts (top_k, repetition_penalty) to provider APIs
2 parents 3336e23 + 4520285 commit f92ffe7

12 files changed

Lines changed: 405 additions & 38 deletions


agent-schema.json

Lines changed: 1 addition & 1 deletion
@@ -552,7 +552,7 @@
     },
     "provider_opts": {
       "type": "object",
-      "description": "Provider-specific options. dmr: runtime_flags. anthropic/amazon-bedrock (Claude): interleaved_thinking (boolean, default true). openai: transport ('sse' or 'websocket') to choose between SSE and WebSocket streaming for the Responses API. openai/anthropic/google: rerank_prompt (string) to fully override the system prompt used for RAG reranking (advanced - prefer using results.reranking.criteria for domain-specific guidance).",
+      "description": "Provider-specific options. Sampling parameters: top_k (integer, supported by anthropic, google, amazon-bedrock, and custom OpenAI-compatible providers like vLLM/Ollama), repetition_penalty (float, forwarded to custom OpenAI-compatible providers), min_p (float, forwarded to custom providers), seed (integer, forwarded to OpenAI). Infrastructure options: dmr: runtime_flags. anthropic/amazon-bedrock (Claude): interleaved_thinking (boolean, default true). openai: transport ('sse' or 'websocket') to choose between SSE and WebSocket streaming for the Responses API. openai/anthropic/google: rerank_prompt (string) to fully override the system prompt used for RAG reranking (advanced - prefer using results.reranking.criteria for domain-specific guidance).",
       "additionalProperties": true
     },
     "track_usage": {

examples/sampling-opts.yaml

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+#!/usr/bin/env docker agent run
+
+# This example shows how to use provider_opts to pass sampling parameters
+# like top_k and repetition_penalty to different providers.
+
+agents:
+  root:
+    model: gpt
+    description: "Assistant with custom sampling parameters"
+    instruction: |
+      You are a helpful assistant running on a local model with tuned sampling parameters.
+
+models:
+  gpt:
+    provider: openai
+    model: gpt-4o
+    temperature: 0.7
+    top_p: 0.9
+    provider_opts:
+      top_k: 40
+      repetition_penalty: 1.15
+      min_p: 0.05

pkg/model/provider/anthropic/beta_client.go

Lines changed: 13 additions & 0 deletions
@@ -13,6 +13,7 @@ import (
 	"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
 
 	"github.com/docker/docker-agent/pkg/chat"
+	"github.com/docker/docker-agent/pkg/model/provider/providerutil"
 	"github.com/docker/docker-agent/pkg/rag/prompts"
 	"github.com/docker/docker-agent/pkg/rag/types"
 	"github.com/docker/docker-agent/pkg/tools"
@@ -115,6 +116,12 @@ func (c *Client) createBetaStream(
 		"max_tokens", maxTokens,
 		"message_count", len(params.Messages))
 
+	// Forward top_k from provider_opts (Anthropic natively supports it)
+	if topK, ok := providerutil.GetProviderOptInt64(c.ModelConfig.ProviderOpts, "top_k"); ok {
+		params.TopK = param.NewOpt(topK)
+		slog.Debug("Anthropic Beta provider_opts: set top_k", "value", topK)
+	}
+
 	stream := client.Beta.Messages.NewStreaming(ctx, params)
 	trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage
 	ad := c.newBetaStreamAdapter(stream, trackUsage)
@@ -293,6 +300,12 @@ func (c *Client) Rerank(ctx context.Context, query string, documents []types.Doc
 		params.TopP = param.NewOpt(*c.ModelConfig.TopP)
 	}
 
+	// Forward top_k from provider_opts (Anthropic natively supports it)
+	if topK, ok := providerutil.GetProviderOptInt64(c.ModelConfig.ProviderOpts, "top_k"); ok {
+		params.TopK = param.NewOpt(topK)
+		slog.Debug("Anthropic Beta provider_opts: set top_k", "value", topK)
+	}
+
 	// Use streaming API to avoid timeout errors for operations that may take longer than 10 minutes
 	stream := client.Beta.Messages.NewStreaming(ctx, params)
 
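Note: the providerutil helpers these hunks rely on (GetProviderOptInt64, GetProviderOptFloat64, SamplingProviderOptsKeys) are not among the files visible in this commit view. As a rough sketch only, assuming provider_opts values arrive as the untyped scalars a YAML/JSON decoder produces, GetProviderOptInt64 might look like the following; the real code in pkg/model/provider/providerutil may differ:

package providerutil

// Hypothetical sketch -- not the actual implementation from this commit.

// GetProviderOptInt64 reads an integer-valued option from provider_opts,
// tolerating the numeric types YAML/JSON decoders typically produce.
func GetProviderOptInt64(opts map[string]any, key string) (int64, bool) {
	switch n := opts[key].(type) { // indexing a nil map is safe and yields nil
	case int:
		return int64(n), true
	case int64:
		return n, true
	case float64:
		if n == float64(int64(n)) { // accept only whole numbers
			return int64(n), true
		}
	}
	return 0, false
}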

pkg/model/provider/anthropic/client.go

Lines changed: 7 additions & 0 deletions
@@ -24,6 +24,7 @@ import (
 	"github.com/docker/docker-agent/pkg/httpclient"
 	"github.com/docker/docker-agent/pkg/model/provider/base"
 	"github.com/docker/docker-agent/pkg/model/provider/options"
+	"github.com/docker/docker-agent/pkg/model/provider/providerutil"
 	"github.com/docker/docker-agent/pkg/tools"
 )
 
@@ -337,6 +338,12 @@ func (c *Client) CreateChatCompletionStream(
 		slog.Debug("Anthropic extended thinking enabled, ignoring temperature/top_p settings")
 	}
 
+	// Forward top_k from provider_opts (Anthropic natively supports it)
+	if topK, ok := providerutil.GetProviderOptInt64(c.ModelConfig.ProviderOpts, "top_k"); ok {
+		params.TopK = param.NewOpt(topK)
+		slog.Debug("Anthropic provider_opts: set top_k", "value", topK)
+	}
+
 	if len(requestTools) > 0 {
 		slog.Debug("Adding tools to Anthropic request", "tool_count", len(requestTools))
 	}

pkg/model/provider/bedrock/client.go

Lines changed: 77 additions & 35 deletions
@@ -20,6 +20,7 @@ import (
 	"github.com/docker/docker-agent/pkg/environment"
 	"github.com/docker/docker-agent/pkg/model/provider/base"
 	"github.com/docker/docker-agent/pkg/model/provider/options"
+	"github.com/docker/docker-agent/pkg/model/provider/providerutil"
 	"github.com/docker/docker-agent/pkg/modelsdev"
 	"github.com/docker/docker-agent/pkg/tools"
 )
@@ -244,7 +245,7 @@ func (c *Client) buildConverseStreamInput(messages []chat.Message, requestTools
 	}
 
 	// Set inference configuration (temp/topP are suppressed when thinking is on).
-	input.InferenceConfig = c.buildInferenceConfig(additionalFields != nil)
+	input.InferenceConfig = c.buildInferenceConfig(c.isThinkingEnabled())
 
 	// Convert and set tools
 	if len(requestTools) > 0 {
@@ -278,59 +279,100 @@ func (c *Client) buildInferenceConfig(thinkingEnabled bool) *types.InferenceConf
 }
 
 func (c *Client) interleavedThinkingEnabled() bool {
-	return getProviderOpt[bool](c.ModelConfig.ProviderOpts, "interleaved_thinking")
-}
-
-func (c *Client) promptCachingEnabled() bool {
-	if getProviderOpt[bool](c.ModelConfig.ProviderOpts, "disable_prompt_caching") {
-		return false
+	// Default to true, matching the documented schema behavior.
+	v, ok := c.ModelConfig.ProviderOpts["interleaved_thinking"]
+	if !ok {
+		return true
 	}
-	return c.cachingSupported
+	b, ok := v.(bool)
+	if !ok {
+		slog.Warn("Bedrock provider_opts type mismatch",
+			"key", "interleaved_thinking",
+			"expected_type", "bool",
+			"actual_type", fmt.Sprintf("%T", v),
+			"value", v)
+		return true
+	}
+	return b
 }
 
-// buildAdditionalModelRequestFields configures Claude's extended thinking (reasoning) mode.
-func (c *Client) buildAdditionalModelRequestFields() document.Interface {
+// isThinkingEnabled returns true if a valid thinking budget is configured.
+// It mirrors the validation in buildAdditionalModelRequestFields but without
+// side effects (no logging), so it can safely be used to gate inference config.
+func (c *Client) isThinkingEnabled() bool {
 	if c.ModelConfig.ThinkingBudget == nil {
-		return nil
+		return false
 	}
 	tokens := c.ModelConfig.ThinkingBudget.Tokens
 	if t, ok := c.ModelConfig.ThinkingBudget.EffortTokens(); ok {
 		tokens = t
 	}
-	if tokens <= 0 {
-		return nil
-	}
-
-	// Validate minimum (Claude requires at least 1024 tokens for thinking)
 	if tokens < 1024 {
-		slog.Warn("Bedrock thinking_budget below minimum (1024), ignoring",
-			"tokens", tokens)
-		return nil
+		return false
 	}
-
-	// Validate against max_tokens
 	if c.ModelConfig.MaxTokens != nil && tokens >= int(*c.ModelConfig.MaxTokens) {
-		slog.Warn("Bedrock thinking_budget must be less than max_tokens, ignoring",
-			"thinking_budget", tokens,
-			"max_tokens", *c.ModelConfig.MaxTokens)
-		return nil
+		return false
 	}
+	return true
+}
 
-	slog.Debug("Bedrock request using thinking_budget", "budget_tokens", tokens)
+func (c *Client) promptCachingEnabled() bool {
+	if getProviderOpt[bool](c.ModelConfig.ProviderOpts, "disable_prompt_caching") {
+		return false
+	}
+	return c.cachingSupported
+}
 
-	fields := map[string]any{
-		"thinking": map[string]any{
-			"type":          "enabled",
-			"budget_tokens": tokens,
-		},
+// buildAdditionalModelRequestFields configures Claude's extended thinking (reasoning) mode
+// and forwards supported sampling parameters from provider_opts (e.g. top_k).
+func (c *Client) buildAdditionalModelRequestFields() document.Interface {
+	fields := map[string]any{}
+
+	// Forward top_k from provider_opts (Anthropic on Bedrock supports it)
+	if topK, ok := providerutil.GetProviderOptInt64(c.ModelConfig.ProviderOpts, "top_k"); ok {
+		fields["top_k"] = topK
+		slog.Debug("Bedrock provider_opts: set top_k", "value", topK)
 	}
 
-	// Add anthropic_beta field for interleaved thinking
-	if c.interleavedThinkingEnabled() {
-		fields["anthropic_beta"] = []string{"interleaved-thinking-2025-05-14"}
-		slog.Debug("Bedrock request using interleaved thinking beta")
+	// Configure thinking budget if present and valid
+	if budget := c.ModelConfig.ThinkingBudget; budget != nil {
+		tokens := budget.Tokens
+		if t, ok := budget.EffortTokens(); ok {
+			tokens = t
+		}
+
+		valid := tokens > 0
+		if valid && tokens < 1024 {
+			slog.Warn("Bedrock thinking_budget below minimum (1024), ignoring", "tokens", tokens)
+			valid = false
+		}
+		if valid && c.ModelConfig.MaxTokens != nil && tokens >= int(*c.ModelConfig.MaxTokens) {
+			slog.Warn("Bedrock thinking_budget must be less than max_tokens, ignoring",
+				"thinking_budget", tokens,
+				"max_tokens", *c.ModelConfig.MaxTokens)
+			valid = false
+		}
+
+		if valid {
+			slog.Debug("Bedrock request using thinking_budget", "budget_tokens", tokens)
+			fields["thinking"] = map[string]any{
+				"type":          "enabled",
+				"budget_tokens": tokens,
+			}
+
+			if c.interleavedThinkingEnabled() {
+				fields["anthropic_beta"] = []string{"interleaved-thinking-2025-05-14"}
+				slog.Debug("Bedrock request using interleaved thinking beta")
+			} else {
+				slog.Warn("Bedrock thinking_budget is set but interleaved_thinking is explicitly disabled; " +
+					"the anthropic_beta header will not be sent, which may cause the thinking budget to be ignored")
+			}
+		}
 	}
 
+	if len(fields) == 0 {
+		return nil
+	}
 	return document.NewLazyDocument(fields)
 }
 
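For orientation, a small self-contained sketch of the payload shape buildAdditionalModelRequestFields now assembles when top_k is set alongside a valid thinking budget with the default interleaved_thinking; values here are illustrative, not captured from Bedrock:

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Mirrors the fields map built in buildAdditionalModelRequestFields;
	// top_k and budget_tokens are illustrative values.
	fields := map[string]any{
		"top_k": int64(40),
		"thinking": map[string]any{
			"type":          "enabled",
			"budget_tokens": 2048,
		},
		"anthropic_beta": []string{"interleaved-thinking-2025-05-14"},
	}
	out, _ := json.Marshal(fields)
	fmt.Println(string(out))
	// {"anthropic_beta":["interleaved-thinking-2025-05-14"],"thinking":{"budget_tokens":2048,"type":"enabled"},"top_k":40}
}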

pkg/model/provider/bedrock/client_test.go

Lines changed: 2 additions & 2 deletions
@@ -854,7 +854,7 @@ func TestInterleavedThinkingEnabled_NotSet(t *testing.T) {
 		},
 	}
 
-	assert.False(t, client.interleavedThinkingEnabled())
+	assert.True(t, client.interleavedThinkingEnabled())
 }
 
 func TestInterleavedThinkingEnabled_NilProviderOpts(t *testing.T) {
@@ -870,7 +870,7 @@ func TestInterleavedThinkingEnabled_NilProviderOpts(t *testing.T) {
 		},
 	}
 
-	assert.False(t, client.interleavedThinkingEnabled())
+	assert.True(t, client.interleavedThinkingEnabled())
 }
 
 func TestBuildAdditionalModelRequestFields_WithInterleavedThinking(t *testing.T) {

pkg/model/provider/gemini/client.go

Lines changed: 7 additions & 0 deletions
@@ -21,6 +21,7 @@ import (
 	"github.com/docker/docker-agent/pkg/httpclient"
 	"github.com/docker/docker-agent/pkg/model/provider/base"
 	"github.com/docker/docker-agent/pkg/model/provider/options"
+	"github.com/docker/docker-agent/pkg/model/provider/providerutil"
 	"github.com/docker/docker-agent/pkg/rag/prompts"
 	"github.com/docker/docker-agent/pkg/rag/types"
 	"github.com/docker/docker-agent/pkg/tools"
@@ -352,6 +353,12 @@ func (c *Client) buildConfig() *genai.GenerateContentConfig {
 		config.PresencePenalty = new(float32(*c.ModelConfig.PresencePenalty))
 	}
 
+	// Forward top_k from provider_opts (Gemini natively supports it)
+	if topK, ok := providerutil.GetProviderOptFloat64(c.ModelConfig.ProviderOpts, "top_k"); ok {
+		config.TopK = new(float32(topK))
+		slog.Debug("Gemini provider_opts: set top_k", "value", topK)
+	}
+
 	// Apply thinking configuration for Gemini models.
 	// See https://ai.google.dev/gemini-api/docs/thinking
 	if c.ModelOptions.NoThinking() {

pkg/model/provider/openai/client.go

Lines changed: 7 additions & 0 deletions
@@ -312,6 +312,11 @@ func (c *Client) CreateChatCompletionStream(
 		return nil, err
 	}
 
+	// Forward sampling-related provider_opts as extra body fields.
+	// This allows custom/OpenAI-compatible providers (vLLM, Ollama, etc.)
+	// to receive parameters like top_k, repetition_penalty, etc.
+	applySamplingProviderOpts(&params, c.ModelConfig.ProviderOpts)
+
 	stream := client.Chat.Completions.NewStreaming(ctx, params)
 
 	slog.Debug("OpenAI chat completion stream created successfully", "model", c.ModelConfig.Model)
@@ -842,6 +847,8 @@ func (c *Client) Rerank(ctx context.Context, query string, documents []types.Doc
 		},
 	}
 
+	applySamplingProviderOpts(&params, c.ModelConfig.ProviderOpts)
+
 	resp, err := client.Chat.Completions.New(ctx, params)
 	if err != nil {
 		slog.Error("OpenAI rerank request failed", "error", err)
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+package openai
+
+import (
+	"log/slog"
+
+	oai "github.com/openai/openai-go/v3"
+
+	"github.com/docker/docker-agent/pkg/model/provider/providerutil"
+)
+
+// applySamplingProviderOpts forwards sampling-related provider_opts as extra
+// body fields on the OpenAI ChatCompletionNewParams. This enables custom
+// OpenAI-compatible providers (vLLM, Ollama, llama.cpp, etc.) to receive
+// parameters like top_k, repetition_penalty, min_p, etc. that the native
+// OpenAI API does not support but these backends do.
+func applySamplingProviderOpts(params *oai.ChatCompletionNewParams, opts map[string]any) {
+	if len(opts) == 0 {
+		return
+	}
+
+	extras := make(map[string]any)
+
+	for _, key := range providerutil.SamplingProviderOptsKeys() {
+		if key == "seed" {
+			// seed is a native ChatCompletionNewParams field (int64),
+			// so set it directly rather than as an extra field.
+			if v, ok := providerutil.GetProviderOptInt64(opts, key); ok {
+				params.Seed = oai.Int(v)
+				slog.Debug("OpenAI provider_opts: set seed", "value", v)
+			}
+			continue
+		}
+
+		if v, ok := providerutil.GetProviderOptFloat64(opts, key); ok {
+			extras[key] = v
+			slog.Debug("OpenAI provider_opts: forwarding sampling param", "key", key, "value", v)
+		} else if vi, ok := providerutil.GetProviderOptInt64(opts, key); ok {
+			extras[key] = vi
+			slog.Debug("OpenAI provider_opts: forwarding sampling param", "key", key, "value", vi)
+		}
+	}
+
+	if len(extras) > 0 {
+		params.SetExtraFields(extras)
+	}
+}
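This helper is added in a new file in the openai package whose path is not visible in this capture. A hedged usage sketch, written as a same-package test and assuming the providerutil getters accept plain Go int values (their implementation isn't shown here); the opts values are illustrative:

package openai

import (
	"testing"

	oai "github.com/openai/openai-go/v3"
)

func TestApplySamplingProviderOpts(t *testing.T) {
	params := oai.ChatCompletionNewParams{}
	applySamplingProviderOpts(&params, map[string]any{
		"top_k":              40,   // forwarded as an extra body field
		"repetition_penalty": 1.15, // forwarded as an extra body field
		"seed":               7,    // set on the native Seed field
	})

	if params.Seed.Value != 7 {
		t.Errorf("expected seed=7, got %+v", params.Seed)
	}
	// top_k and repetition_penalty are serialized into the request body
	// via SetExtraFields; they do not appear as typed params fields.
}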
