Skip to content

Commit 99f9adb

Browse files
authored
Merge pull request #2429 from dgageot/fix2
fix: add mutex to protect lastSelectedID in rule-based router
2 parents f809e54 + da03166 commit 99f9adb

2 files changed

Lines changed: 54 additions & 2 deletions

File tree

pkg/model/provider/rulebased/client.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"errors"
1212
"fmt"
1313
"log/slog"
14+
"sync"
1415

1516
"github.com/blevesearch/bleve/v2"
1617
"github.com/blevesearch/bleve/v2/mapping"
@@ -45,6 +46,7 @@ type Client struct {
4546
routes []Provider
4647
fallback Provider
4748
index bleve.Index
49+
mu sync.RWMutex
4850
lastSelectedID string // ID of the provider selected by the most recent call
4951
}
5052

@@ -165,10 +167,13 @@ func (c *Client) CreateChatCompletionStream(
165167
return nil, errors.New("no provider available for routing")
166168
}
167169

168-
c.lastSelectedID = provider.ID()
170+
selectedID := provider.ID()
171+
c.mu.Lock()
172+
c.lastSelectedID = selectedID
173+
c.mu.Unlock()
169174
slog.Debug("Rule-based router selected model",
170175
"router", c.ID(),
171-
"selected_model", c.lastSelectedID,
176+
"selected_model", selectedID,
172177
"message_count", len(messages),
173178
)
174179

@@ -179,6 +184,8 @@ func (c *Client) CreateChatCompletionStream(
179184
// recent CreateChatCompletionStream call. This allows callers to display
180185
// the YAML-configured sub-model name for rule-based routing.
181186
func (c *Client) LastSelectedModelID() string {
187+
c.mu.RLock()
188+
defer c.mu.RUnlock()
182189
return c.lastSelectedID
183190
}
184191

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package rulebased
2+
3+
import (
4+
"sync"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
9+
"github.com/docker/docker-agent/pkg/chat"
10+
"github.com/docker/docker-agent/pkg/config/latest"
11+
)
12+
13+
func TestLastSelectedModelID_Concurrent(t *testing.T) {
14+
t.Parallel()
15+
16+
cfg := &latest.ModelConfig{
17+
Provider: "openai",
18+
Model: "gpt-4o",
19+
Routing: []latest.RoutingRule{
20+
{
21+
Model: "anthropic/claude-3-haiku",
22+
Examples: []string{"hello", "hi there"},
23+
},
24+
},
25+
}
26+
27+
client, err := NewClient(t.Context(), cfg, nil, nil, mockProviderFactory)
28+
require.NoError(t, err)
29+
defer client.Close()
30+
31+
var wg sync.WaitGroup
32+
for range 100 {
33+
wg.Add(2)
34+
go func() {
35+
defer wg.Done()
36+
messages := []chat.Message{{Role: chat.MessageRoleUser, Content: "hello"}}
37+
_, _ = client.CreateChatCompletionStream(t.Context(), messages, nil)
38+
}()
39+
go func() {
40+
defer wg.Done()
41+
_ = client.LastSelectedModelID()
42+
}()
43+
}
44+
wg.Wait()
45+
}

0 commit comments

Comments
 (0)