Skip to content

Commit afd1dfc

Browse files
authored
Merge pull request #2180 from dgageot/simplify-models
Simplify models
2 parents 09719ae + 73c7ebd commit afd1dfc

4 files changed

Lines changed: 37 additions & 128 deletions

File tree

pkg/config/model_alias_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ func TestResolveModelAliases(t *testing.T) {
1616
mockData := &modelsdev.Database{
1717
Providers: map[string]modelsdev.Provider{
1818
"anthropic": {
19-
ID: "anthropic",
20-
Name: "Anthropic",
2119
Models: map[string]modelsdev.Model{
2220
"claude-sonnet-4-5": {Name: "Claude Sonnet 4.5 (latest)"},
2321
"claude-sonnet-4-5-20250929": {Name: "Claude Sonnet 4.5"},

pkg/modelsdev/store.go

Lines changed: 32 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,14 @@ type Store struct {
3131
cacheFile string
3232
mu sync.Mutex
3333
db *Database
34-
etag string // ETag from last successful fetch, used for conditional requests
3534
}
3635

37-
// singleton holds the process-wide Store instance. It is initialised lazily
38-
// on the first call to NewStore. All subsequent calls return the same value.
39-
var singleton = sync.OnceValues(func() (*Store, error) {
36+
// NewStore returns the process-wide singleton Store.
37+
//
38+
// The database is loaded lazily on the first call to GetDatabase and
39+
// then cached in memory so that every caller shares one copy.
40+
// The first call creates the cache directory if it does not exist.
41+
var NewStore = sync.OnceValues(func() (*Store, error) {
4042
homeDir, err := os.UserHomeDir()
4143
if err != nil {
4244
return nil, fmt.Errorf("failed to get user home directory: %w", err)
@@ -52,15 +54,6 @@ var singleton = sync.OnceValues(func() (*Store, error) {
5254
}, nil
5355
})
5456

55-
// NewStore returns the process-wide singleton Store.
56-
//
57-
// The database is loaded lazily on the first call to GetDatabase and
58-
// then cached in memory so that every caller shares one copy.
59-
// The first call creates the cache directory if it does not exist.
60-
func NewStore() (*Store, error) {
61-
return singleton()
62-
}
63-
6457
// NewDatabaseStore creates a Store pre-populated with the given database.
6558
// The returned store serves data entirely from memory and never fetches
6659
// from the network or touches the filesystem, making it suitable for
@@ -78,18 +71,17 @@ func (s *Store) GetDatabase(ctx context.Context) (*Database, error) {
7871
return s.db, nil
7972
}
8073

81-
db, etag, err := loadDatabase(ctx, s.cacheFile)
74+
db, err := loadDatabase(ctx, s.cacheFile)
8275
if err != nil {
8376
return nil, err
8477
}
8578

8679
s.db = db
87-
s.etag = etag
8880
return db, nil
8981
}
9082

91-
// GetProvider returns a specific provider by ID.
92-
func (s *Store) GetProvider(ctx context.Context, providerID string) (*Provider, error) {
83+
// getProvider returns a specific provider by ID.
84+
func (s *Store) getProvider(ctx context.Context, providerID string) (*Provider, error) {
9385
db, err := s.GetDatabase(ctx)
9486
if err != nil {
9587
return nil, err
@@ -112,30 +104,23 @@ func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) {
112104
providerID := parts[0]
113105
modelID := parts[1]
114106

115-
provider, err := s.GetProvider(ctx, providerID)
107+
provider, err := s.getProvider(ctx, providerID)
116108
if err != nil {
117109
return nil, err
118110
}
119111

120112
model, exists := provider.Models[modelID]
121-
if !exists {
122-
// For amazon-bedrock, try stripping region/inference profile prefixes
123-
// Bedrock uses prefixes for cross-region inference profiles,
124-
// but models.dev stores models without these prefixes.
125-
//
126-
// Strip known region prefixes and retry lookup.
127-
if providerID == "amazon-bedrock" {
128-
if before, after, ok := strings.Cut(modelID, "."); ok {
129-
possibleRegionPrefix := before
130-
if isBedrockRegionPrefix(possibleRegionPrefix) {
131-
normalizedModelID := after
132-
model, exists = provider.Models[normalizedModelID]
133-
if exists {
134-
return &model, nil
135-
}
136-
}
137-
}
113+
114+
// For amazon-bedrock, try stripping region/inference profile prefixes.
115+
// Bedrock uses prefixes for cross-region inference profiles,
116+
// but models.dev stores models without these prefixes.
117+
if !exists && providerID == "amazon-bedrock" {
118+
if prefix, after, ok := strings.Cut(modelID, "."); ok && bedrockRegionPrefixes[prefix] {
119+
model, exists = provider.Models[after]
138120
}
121+
}
122+
123+
if !exists {
139124
return nil, fmt.Errorf("model %q not found in provider %q", modelID, providerID)
140125
}
141126

@@ -144,12 +129,11 @@ func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) {
144129

145130
// loadDatabase loads the database from the local cache file or
146131
// falls back to fetching from the models.dev API.
147-
// It returns the database and the ETag associated with the data.
148-
func loadDatabase(ctx context.Context, cacheFile string) (*Database, string, error) {
132+
func loadDatabase(ctx context.Context, cacheFile string) (*Database, error) {
149133
// Try to load from cache first
150134
cached, err := loadFromCache(cacheFile)
151135
if err == nil && time.Since(cached.LastRefresh) < refreshInterval {
152-
return &cached.Database, cached.ETag, nil
136+
return &cached.Database, nil
153137
}
154138

155139
// Cache is stale or doesn't exist — try a conditional fetch with the ETag.
@@ -163,9 +147,9 @@ func loadDatabase(ctx context.Context, cacheFile string) (*Database, string, err
163147
// If API fetch fails but we have cached data, use it regardless of age.
164148
if cached != nil {
165149
slog.Debug("API fetch failed, using stale cache", "error", fetchErr)
166-
return &cached.Database, cached.ETag, nil
150+
return &cached.Database, nil
167151
}
168-
return nil, "", fmt.Errorf("failed to fetch from API and no cached data available: %w", fetchErr)
152+
return nil, fmt.Errorf("failed to fetch from API and no cached data available: %w", fetchErr)
169153
}
170154

171155
// database is nil when the server returned 304 Not Modified.
@@ -175,15 +159,15 @@ func loadDatabase(ctx context.Context, cacheFile string) (*Database, string, err
175159
if saveErr := saveToCache(cacheFile, &cached.Database, cached.ETag); saveErr != nil {
176160
slog.Warn("Failed to update cache timestamp", "error", saveErr)
177161
}
178-
return &cached.Database, cached.ETag, nil
162+
return &cached.Database, nil
179163
}
180164

181165
// Save the fresh data to cache.
182166
if saveErr := saveToCache(cacheFile, database, newETag); saveErr != nil {
183167
slog.Warn("Failed to save to cache", "error", saveErr)
184168
}
185169

186-
return database, newETag, nil
170+
return database, nil
187171
}
188172

189173
// fetchFromAPI fetches the models.dev database.
@@ -230,7 +214,6 @@ func fetchFromAPI(ctx context.Context, etag string) (*Database, string, error) {
230214

231215
return &Database{
232216
Providers: providers,
233-
UpdatedAt: time.Now(),
234217
}, newETag, nil
235218
}
236219

@@ -249,11 +232,9 @@ func loadFromCache(cacheFile string) (*CachedData, error) {
249232
}
250233

251234
func saveToCache(cacheFile string, database *Database, etag string) error {
252-
now := time.Now()
253235
cached := CachedData{
254236
Database: *database,
255-
CachedAt: now,
256-
LastRefresh: now,
237+
LastRefresh: time.Now(),
257238
ETag: etag,
258239
}
259240

@@ -286,8 +267,7 @@ func (s *Store) ResolveModelAlias(ctx context.Context, providerID, modelName str
286267
return modelName
287268
}
288269

289-
// Get the provider from the database
290-
provider, err := s.GetProvider(ctx, providerID)
270+
provider, err := s.getProvider(ctx, providerID)
291271
if err != nil {
292272
return modelName
293273
}
@@ -319,46 +299,8 @@ func (s *Store) ResolveModelAlias(ctx context.Context, providerID, modelName str
319299
// stores models without regional prefixes. AWS uses these for cross-region inference profiles.
320300
// See: https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference.html
321301
var bedrockRegionPrefixes = map[string]bool{
322-
"us": true, // US region inference profile
323-
"eu": true, // EU region inference profile
324-
"apac": true, // Asia Pacific region inference profile
325-
"global": true, // Global inference profile (routes to any available region)
326-
}
327-
328-
// isBedrockRegionPrefix returns true if the prefix is a known Bedrock regional/inference profile prefix.
329-
func isBedrockRegionPrefix(prefix string) bool {
330-
return bedrockRegionPrefixes[prefix]
331-
}
332-
333-
// ModelSupportsReasoning checks if the given model ID supports reasoning/thinking.
334-
//
335-
// This function implements fail-open semantics:
336-
// - If modelID is empty or not in "provider/model" format, returns true (fail-open)
337-
// - If models.dev lookup fails for any reason, returns true (fail-open)
338-
// - If lookup succeeds, returns the model's Reasoning field value
339-
func ModelSupportsReasoning(ctx context.Context, modelID string) bool {
340-
// Fail-open for empty model ID
341-
if modelID == "" {
342-
return true
343-
}
344-
345-
// Fail-open if not in provider/model format
346-
if !strings.Contains(modelID, "/") {
347-
slog.Debug("Model ID not in provider/model format, assuming reasoning supported to allow user choice", "model_id", modelID)
348-
return true
349-
}
350-
351-
store, err := NewStore()
352-
if err != nil {
353-
slog.Debug("Failed to create modelsdev store, assuming reasoning supported to allow user choice", "error", err)
354-
return true
355-
}
356-
357-
model, err := store.GetModel(ctx, modelID)
358-
if err != nil {
359-
slog.Debug("Failed to lookup model in models.dev, assuming reasoning supported to allow user choice", "model_id", modelID, "error", err)
360-
return true
361-
}
362-
363-
return model.Reasoning
302+
"us": true,
303+
"eu": true,
304+
"apac": true,
305+
"global": true,
364306
}

pkg/modelsdev/types.go

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,20 @@ import "time"
55
// Database represents the complete models.dev database
66
type Database struct {
77
Providers map[string]Provider `json:"providers"`
8-
UpdatedAt time.Time `json:"updated_at"`
98
}
109

1110
// Provider represents an AI model provider
1211
type Provider struct {
13-
ID string `json:"id"`
14-
Name string `json:"name"`
15-
Doc string `json:"doc,omitempty"`
16-
API string `json:"api,omitempty"`
17-
NPM string `json:"npm,omitempty"`
18-
Env []string `json:"env,omitempty"`
1912
Models map[string]Model `json:"models"`
2013
}
2114

2215
// Model represents an AI model with its specifications and capabilities
2316
type Model struct {
24-
ID string `json:"id"`
25-
Name string `json:"name"`
26-
Family string `json:"family,omitempty"`
27-
Attachment bool `json:"attachment"`
28-
Reasoning bool `json:"reasoning"`
29-
Temperature bool `json:"temperature"`
30-
ToolCall bool `json:"tool_call"`
31-
Knowledge string `json:"knowledge,omitempty"`
32-
ReleaseDate string `json:"release_date"`
33-
LastUpdated string `json:"last_updated"`
34-
OpenWeights bool `json:"open_weights"`
35-
Cost *Cost `json:"cost,omitempty"`
36-
Limit Limit `json:"limit"`
37-
Modalities Modalities `json:"modalities"`
17+
Name string `json:"name"`
18+
Family string `json:"family,omitempty"`
19+
Cost *Cost `json:"cost,omitempty"`
20+
Limit Limit `json:"limit"`
21+
Modalities Modalities `json:"modalities"`
3822
}
3923

4024
// Cost represents the pricing information for a model
@@ -60,7 +44,6 @@ type Modalities struct {
6044
// CachedData represents the cached models.dev data with metadata
6145
type CachedData struct {
6246
Database Database `json:"database"`
63-
CachedAt time.Time `json:"cached_at"`
6447
LastRefresh time.Time `json:"last_refresh"`
6548
ETag string `json:"etag,omitempty"`
6649
}

pkg/runtime/model_switcher_test.go

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -244,25 +244,20 @@ func TestBuildCatalogChoices(t *testing.T) {
244244
db := &modelsdev.Database{
245245
Providers: map[string]modelsdev.Provider{
246246
"openai": {
247-
ID: "openai",
248-
Name: "OpenAI",
249247
Models: map[string]modelsdev.Model{
250248
"gpt-4o": {
251-
ID: "gpt-4o",
252249
Name: "GPT-4o",
253250
Modalities: modelsdev.Modalities{
254251
Output: []string{"text"},
255252
},
256253
},
257254
"dall-e-3": {
258-
ID: "dall-e-3",
259255
Name: "DALL-E 3",
260256
Modalities: modelsdev.Modalities{
261257
Output: []string{"image"}, // Not a text model
262258
},
263259
},
264260
"text-embedding-3-large": {
265-
ID: "text-embedding-3-large",
266261
Name: "Text Embedding 3 Large",
267262
Family: "text-embedding",
268263
Modalities: modelsdev.Modalities{
@@ -272,11 +267,8 @@ func TestBuildCatalogChoices(t *testing.T) {
272267
},
273268
},
274269
"anthropic": {
275-
ID: "anthropic",
276-
Name: "Anthropic",
277270
Models: map[string]modelsdev.Model{
278271
"claude-sonnet-4-0": {
279-
ID: "claude-sonnet-4-0",
280272
Name: "Claude Sonnet 4",
281273
Modalities: modelsdev.Modalities{
282274
Output: []string{"text"},
@@ -285,11 +277,8 @@ func TestBuildCatalogChoices(t *testing.T) {
285277
},
286278
},
287279
"unsupported": {
288-
ID: "unsupported",
289-
Name: "Unsupported Provider",
290280
Models: map[string]modelsdev.Model{
291281
"some-model": {
292-
ID: "some-model",
293282
Name: "Some Model",
294283
Modalities: modelsdev.Modalities{
295284
Output: []string{"text"},
@@ -348,11 +337,8 @@ func TestBuildCatalogChoicesWithDuplicates(t *testing.T) {
348337
db := &modelsdev.Database{
349338
Providers: map[string]modelsdev.Provider{
350339
"openai": {
351-
ID: "openai",
352-
Name: "OpenAI",
353340
Models: map[string]modelsdev.Model{
354341
"gpt-4o": {
355-
ID: "gpt-4o",
356342
Name: "GPT-4o",
357343
Modalities: modelsdev.Modalities{
358344
Output: []string{"text"},

0 commit comments

Comments
 (0)