Skip to content

Commit f010c3f

Browse files
authored
Merge pull request #2287 from dgageot/board/fix-docker-agent-issue-2280-f0d4a711
Add Vertex AI Model Garden support for non-Gemini models
2 parents e753f38 + ca7dbde commit f010c3f

3 files changed

Lines changed: 272 additions & 0 deletions

File tree

pkg/model/provider/provider.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/docker/docker-agent/pkg/model/provider/openai"
2020
"github.com/docker/docker-agent/pkg/model/provider/options"
2121
"github.com/docker/docker-agent/pkg/model/provider/rulebased"
22+
"github.com/docker/docker-agent/pkg/model/provider/vertexai"
2223
"github.com/docker/docker-agent/pkg/rag/types"
2324
"github.com/docker/docker-agent/pkg/tools"
2425
)
@@ -242,6 +243,11 @@ func createDirectProvider(ctx context.Context, cfg *latest.ModelConfig, env envi
242243
case "anthropic":
243244
return anthropic.NewClient(ctx, enhancedCfg, env, opts...)
244245
case "google":
246+
// Route non-Gemini models on Vertex AI (Model Garden) through the
247+
// OpenAI-compatible endpoint instead of the Gemini SDK.
248+
if vertexai.IsModelGardenConfig(enhancedCfg) {
249+
return vertexai.NewClient(ctx, enhancedCfg, env, opts...)
250+
}
245251
return gemini.NewClient(ctx, enhancedCfg, env, opts...)
246252
case "dmr":
247253
return dmr.NewClient(ctx, enhancedCfg, opts...)
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
// Package vertexai provides support for non-Gemini models hosted on
2+
// Google Cloud's Vertex AI Model Garden via the OpenAI-compatible endpoint.
3+
//
4+
// Vertex AI Model Garden hosts models from various publishers (Anthropic,
5+
// Meta, Mistral, etc.) and exposes them through an OpenAI-compatible API.
6+
// This package configures the OpenAI provider to talk to that endpoint
7+
// using Google Cloud Application Default Credentials for authentication.
8+
//
9+
// Usage in agent config:
10+
//
11+
// models:
12+
// claude-on-vertex:
13+
// provider: google
14+
// model: claude-sonnet-4-20250514
15+
// provider_opts:
16+
// project: my-gcp-project
17+
// location: us-east5
18+
// publisher: anthropic
19+
package vertexai
20+
21+
import (
22+
"context"
23+
"errors"
24+
"fmt"
25+
"log/slog"
26+
"net/url"
27+
"regexp"
28+
"strings"
29+
"sync"
30+
31+
"golang.org/x/oauth2"
32+
"golang.org/x/oauth2/google"
33+
34+
"github.com/docker/docker-agent/pkg/config/latest"
35+
"github.com/docker/docker-agent/pkg/environment"
36+
"github.com/docker/docker-agent/pkg/model/provider/openai"
37+
"github.com/docker/docker-agent/pkg/model/provider/options"
38+
)
39+
40+
// cloudPlatformScope is the OAuth2 scope required for Vertex AI API access.
41+
const cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
42+
43+
// validGCPIdentifier matches GCP project IDs and location names.
44+
// Project IDs: 6-30 chars, lowercase letters, digits, hyphens.
45+
// Locations: lowercase letters, digits, hyphens (e.g. us-central1).
46+
var validGCPIdentifier = regexp.MustCompile(`^[a-z][a-z0-9-]{1,29}$`)
47+
48+
// IsModelGardenConfig returns true when the ModelConfig describes a
49+
// non-Gemini model on Vertex AI (i.e. the "publisher" provider_opt is set).
50+
func IsModelGardenConfig(cfg *latest.ModelConfig) bool {
51+
if cfg == nil || cfg.ProviderOpts == nil {
52+
return false
53+
}
54+
publisher, _ := cfg.ProviderOpts["publisher"].(string)
55+
return publisher != "" && !strings.EqualFold(publisher, "google")
56+
}
57+
58+
// NewClient creates an OpenAI-compatible client pointing at the Vertex AI
59+
// Model Garden endpoint. It uses Google Application Default Credentials
60+
// for authentication.
61+
func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (*openai.Client, error) {
62+
project, _ := cfg.ProviderOpts["project"].(string)
63+
location, _ := cfg.ProviderOpts["location"].(string)
64+
publisher, _ := cfg.ProviderOpts["publisher"].(string)
65+
66+
// Expand env vars in project/location.
67+
var err error
68+
project, err = environment.Expand(ctx, project, env)
69+
if err != nil {
70+
return nil, fmt.Errorf("expanding project: %w", err)
71+
}
72+
location, err = environment.Expand(ctx, location, env)
73+
if err != nil {
74+
return nil, fmt.Errorf("expanding location: %w", err)
75+
}
76+
77+
// Fall back to environment variables if not set in provider_opts.
78+
if project == "" {
79+
project, _ = env.Get(ctx, "GOOGLE_CLOUD_PROJECT")
80+
}
81+
if location == "" {
82+
location, _ = env.Get(ctx, "GOOGLE_CLOUD_LOCATION")
83+
}
84+
85+
if project == "" {
86+
return nil, errors.New("vertex AI Model Garden requires a GCP project (set provider_opts.project or GOOGLE_CLOUD_PROJECT)")
87+
}
88+
if location == "" {
89+
return nil, errors.New("vertex AI Model Garden requires a GCP location (set provider_opts.location or GOOGLE_CLOUD_LOCATION)")
90+
}
91+
92+
// Validate project and location to prevent URL path manipulation.
93+
if !validGCPIdentifier.MatchString(project) {
94+
return nil, fmt.Errorf("invalid GCP project ID: %q", project)
95+
}
96+
if !validGCPIdentifier.MatchString(location) {
97+
return nil, fmt.Errorf("invalid GCP location: %q", location)
98+
}
99+
100+
// Build the base URL for the OpenAI-compatible endpoint.
101+
// https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-partner-models#openai_sdk
102+
baseURL := "https://" + location + "-aiplatform.googleapis.com/v1beta1/projects/" +
103+
url.PathEscape(project) + "/locations/" + url.PathEscape(location) + "/endpoints/openapi"
104+
105+
slog.Debug("Creating Vertex AI Model Garden client",
106+
"publisher", publisher,
107+
"project", project,
108+
"location", location,
109+
"model", cfg.Model,
110+
"base_url", baseURL,
111+
)
112+
113+
// Get a GCP access token using Application Default Credentials.
114+
tokenSource, err := google.DefaultTokenSource(ctx, cloudPlatformScope)
115+
if err != nil {
116+
return nil, fmt.Errorf("failed to obtain GCP credentials for Vertex AI: %w (run 'gcloud auth application-default login')", err)
117+
}
118+
token, err := tokenSource.Token()
119+
if err != nil {
120+
return nil, fmt.Errorf("failed to get GCP access token: %w", err)
121+
}
122+
123+
// Build a modified config that the OpenAI provider can use.
124+
// We override the base URL and set the token directly.
125+
oaiCfg := cfg.Clone()
126+
oaiCfg.BaseURL = baseURL
127+
// Use a synthetic token key env var — we'll set it in a wrapper env provider.
128+
const tokenEnvVar = "_VERTEX_AI_ACCESS_TOKEN"
129+
oaiCfg.TokenKey = tokenEnvVar
130+
131+
// Remove provider_opts that are specific to Vertex AI / not relevant for OpenAI.
132+
delete(oaiCfg.ProviderOpts, "project")
133+
delete(oaiCfg.ProviderOpts, "location")
134+
delete(oaiCfg.ProviderOpts, "publisher")
135+
136+
// Force chat completions API type (Vertex AI OpenAI endpoint uses this).
137+
if oaiCfg.ProviderOpts == nil {
138+
oaiCfg.ProviderOpts = map[string]any{}
139+
}
140+
oaiCfg.ProviderOpts["api_type"] = "openai_chatcompletions"
141+
142+
// Wrap the environment provider to inject the GCP access token.
143+
wrappedEnv := &tokenEnv{
144+
Provider: env,
145+
key: tokenEnvVar,
146+
tok: token.AccessToken,
147+
ts: tokenSource,
148+
}
149+
150+
return openai.NewClient(ctx, oaiCfg, wrappedEnv, opts...)
151+
}
152+
153+
// tokenEnv wraps an environment.Provider to inject a GCP access token.
154+
// It refreshes the token on each Get call to handle token expiry.
155+
type tokenEnv struct {
156+
environment.Provider
157+
158+
key string
159+
mu sync.Mutex
160+
tok string
161+
ts oauth2.TokenSource
162+
}
163+
164+
func (e *tokenEnv) Get(ctx context.Context, name string) (string, bool) {
165+
if name == e.key {
166+
e.mu.Lock()
167+
defer e.mu.Unlock()
168+
169+
// Refresh token if needed — TokenSource handles caching.
170+
tok, err := e.ts.Token()
171+
if err != nil {
172+
slog.Warn("Failed to refresh GCP access token, using cached", "error", err)
173+
return e.tok, true
174+
}
175+
e.tok = tok.AccessToken
176+
return e.tok, true
177+
}
178+
return e.Provider.Get(ctx, name)
179+
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package vertexai
2+
3+
import (
4+
"testing"
5+
6+
"github.com/docker/docker-agent/pkg/config/latest"
7+
)
8+
9+
func TestIsModelGardenConfig(t *testing.T) {
10+
tests := []struct {
11+
name string
12+
cfg *latest.ModelConfig
13+
want bool
14+
}{
15+
{
16+
name: "nil config",
17+
cfg: nil,
18+
want: false,
19+
},
20+
{
21+
name: "no provider_opts",
22+
cfg: &latest.ModelConfig{Provider: "google", Model: "gemini-2.5-flash"},
23+
want: false,
24+
},
25+
{
26+
name: "no publisher",
27+
cfg: &latest.ModelConfig{
28+
Provider: "google",
29+
Model: "gemini-2.5-flash",
30+
ProviderOpts: map[string]any{"project": "my-project", "location": "us-central1"},
31+
},
32+
want: false,
33+
},
34+
{
35+
name: "publisher=google",
36+
cfg: &latest.ModelConfig{
37+
Provider: "google",
38+
Model: "gemini-2.5-flash",
39+
ProviderOpts: map[string]any{"project": "my-project", "location": "us-central1", "publisher": "google"},
40+
},
41+
want: false,
42+
},
43+
{
44+
name: "publisher=anthropic",
45+
cfg: &latest.ModelConfig{
46+
Provider: "google",
47+
Model: "claude-sonnet-4-20250514",
48+
ProviderOpts: map[string]any{"project": "my-project", "location": "us-east5", "publisher": "anthropic"},
49+
},
50+
want: true,
51+
},
52+
{
53+
name: "publisher=meta",
54+
cfg: &latest.ModelConfig{
55+
Provider: "google",
56+
Model: "meta/llama-4-maverick-17b-128e-instruct-maas",
57+
ProviderOpts: map[string]any{"project": "my-project", "location": "us-central1", "publisher": "meta"},
58+
},
59+
want: true,
60+
},
61+
}
62+
63+
for _, tt := range tests {
64+
t.Run(tt.name, func(t *testing.T) {
65+
got := IsModelGardenConfig(tt.cfg)
66+
if got != tt.want {
67+
t.Errorf("IsModelGardenConfig() = %v, want %v", got, tt.want)
68+
}
69+
})
70+
}
71+
}
72+
73+
func TestValidGCPIdentifier(t *testing.T) {
74+
valid := []string{"my-project", "us-central1", "project123", "ab"}
75+
for _, s := range valid {
76+
if !validGCPIdentifier.MatchString(s) {
77+
t.Errorf("expected %q to be valid", s)
78+
}
79+
}
80+
81+
invalid := []string{"", "A", "../foo", "my project", "a", "123abc", "my_project/../../evil"}
82+
for _, s := range invalid {
83+
if validGCPIdentifier.MatchString(s) {
84+
t.Errorf("expected %q to be invalid", s)
85+
}
86+
}
87+
}

0 commit comments

Comments
 (0)