Skip to content

Commit 12ade90

Browse files
authored
Merge pull request #2205 from dgageot/refacto-rag
Simplify the runtime related RAG code a bit
2 parents 4b3a638 + 4265af5 commit 12ade90

9 files changed

Lines changed: 136 additions & 167 deletions

File tree

pkg/app/app.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ import (
3030
"github.com/docker/docker-agent/pkg/tui/messages"
3131
)
3232

33+
// RAGInitializer is implemented by runtimes that support background RAG initialization.
34+
// Local runtimes use this to start indexing early; remote runtimes typically do not.
35+
type RAGInitializer interface {
36+
StartBackgroundRAGInit(ctx context.Context, sendEvent func(runtime.Event))
37+
}
38+
3339
type App struct {
3440
runtime runtime.Runtime
3541
session *session.Session
@@ -119,7 +125,7 @@ func New(ctx context.Context, rt runtime.Runtime, sess *session.Session, opts ..
119125
// If the runtime supports background RAG initialization, start it
120126
// and forward events to the TUI. Remote runtimes typically handle RAG server-side
121127
// and won't implement this optional interface.
122-
if ragRuntime, ok := rt.(runtime.RAGInitializer); ok {
128+
if ragRuntime, ok := rt.(RAGInitializer); ok {
123129
go ragRuntime.StartBackgroundRAGInit(ctx, func(event runtime.Event) {
124130
select {
125131
case app.events <- event:

pkg/rag/builder.go

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,13 @@ type ManagersBuildConfig struct {
2525
}
2626

2727
// NewManagers constructs all RAG managers defined in the config.
28-
func NewManagers(
29-
ctx context.Context,
30-
cfg *latest.Config,
31-
buildCfg ManagersBuildConfig,
32-
) (map[string]*Manager, error) {
33-
managers := make(map[string]*Manager)
34-
28+
func NewManagers(ctx context.Context, cfg *latest.Config, buildCfg ManagersBuildConfig) ([]*Manager, error) {
3529
if len(cfg.RAG) == 0 {
36-
return managers, nil
30+
return nil, nil
3731
}
3832

33+
var managers []*Manager
34+
3935
for ragName, ragCfg := range cfg.RAG {
4036
// Validate that we have at least one strategy
4137
if len(ragCfg.Strategies) == 0 {
@@ -69,7 +65,7 @@ func NewManagers(
6965
return nil, fmt.Errorf("failed to create RAG manager %q: %w", ragName, err)
7066
}
7167

72-
managers[ragName] = manager
68+
managers = append(managers, manager)
7369

7470
strategyNames := make([]string, len(strategyConfigs))
7571
for i, sc := range strategyConfigs {

pkg/runtime/event.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -558,12 +558,12 @@ type RAGIndexingStartedEvent struct {
558558
StrategyName string `json:"strategy_name"`
559559
}
560560

561-
func RAGIndexingStarted(ragName, strategyName, agentName string) Event {
561+
func RAGIndexingStarted(ragName, strategyName string) Event {
562562
return &RAGIndexingStartedEvent{
563563
Type: "rag_indexing_started",
564564
RAGName: ragName,
565565
StrategyName: strategyName,
566-
AgentContext: newAgentContext(agentName),
566+
AgentContext: newAgentContext(""),
567567
}
568568
}
569569

@@ -596,12 +596,12 @@ type RAGIndexingCompletedEvent struct {
596596
StrategyName string `json:"strategy_name"`
597597
}
598598

599-
func RAGIndexingCompleted(ragName, strategyName, agentName string) Event {
599+
func RAGIndexingCompleted(ragName, strategyName string) Event {
600600
return &RAGIndexingCompletedEvent{
601601
Type: "rag_indexing_completed",
602602
RAGName: ragName,
603603
StrategyName: strategyName,
604-
AgentContext: newAgentContext(agentName),
604+
AgentContext: newAgentContext(""),
605605
}
606606
}
607607

@@ -635,7 +635,6 @@ type MessageAddedEvent struct {
635635
Message *session.Message `json:"-"`
636636
}
637637

638-
func (e *MessageAddedEvent) GetAgentName() string { return e.AgentName }
639638
func (e *MessageAddedEvent) GetSessionID() string { return e.SessionID }
640639

641640
func MessageAdded(sessionID string, msg *session.Message, agentName string) Event {
@@ -657,8 +656,6 @@ type SubSessionCompletedEvent struct {
657656
SubSession any `json:"sub_session"` // *session.Session
658657
}
659658

660-
func (e *SubSessionCompletedEvent) GetAgentName() string { return e.AgentName }
661-
662659
func SubSessionCompleted(parentSessionID string, subSession any, agentName string) Event {
663660
return &SubSessionCompletedEvent{
664661
Type: "sub_session_completed",

pkg/runtime/loop.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
9797
events <- TeamInfo(r.agentDetailsFromTeam(), a.Name())
9898

9999
// Initialize RAG and forward events
100-
r.InitializeRAG(ctx, events)
100+
r.StartBackgroundRAGInit(ctx, func(event Event) {
101+
events <- event
102+
})
101103

102104
r.emitAgentWarnings(a, chanSend(events))
103105
r.configureToolsetHandlers(a, events)

pkg/runtime/rag.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package runtime
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"log/slog"
7+
8+
"github.com/docker/docker-agent/pkg/rag"
9+
"github.com/docker/docker-agent/pkg/rag/types"
10+
)
11+
12+
// StartBackgroundRAGInit initializes RAG in background and forwards events
13+
// Should be called early (e.g., by App) to start indexing before RunStream
14+
func (r *LocalRuntime) StartBackgroundRAGInit(ctx context.Context, sendEvent func(Event)) {
15+
if r.ragInitialized.Swap(true) {
16+
return
17+
}
18+
19+
ragManagers := r.team.RAGManagers()
20+
if len(ragManagers) == 0 {
21+
return
22+
}
23+
24+
// Set up event forwarding BEFORE starting initialization
25+
r.forwardRAGEvents(ctx, ragManagers, sendEvent)
26+
initializeRAG(ctx, ragManagers)
27+
startRAGFileWatchers(ctx, ragManagers)
28+
}
29+
30+
// forwardRAGEvents forwards RAG manager events to the given callback
31+
// Consolidates duplicated event forwarding logic
32+
func (r *LocalRuntime) forwardRAGEvents(ctx context.Context, ragManagers []*rag.Manager, sendEvent func(Event)) {
33+
for _, mgr := range ragManagers {
34+
go func() {
35+
ragName := mgr.Name()
36+
slog.Debug("Starting RAG event forwarder goroutine", "rag", ragName)
37+
for {
38+
select {
39+
case <-ctx.Done():
40+
slog.Debug("RAG event forwarder stopped", "rag", ragName)
41+
return
42+
case ragEvent, ok := <-mgr.Events():
43+
if !ok {
44+
slog.Debug("RAG events channel closed", "rag", ragName)
45+
return
46+
}
47+
48+
agentName := r.CurrentAgentName()
49+
slog.Debug("Forwarding RAG event", "type", ragEvent.Type, "rag", ragName, "agent", agentName)
50+
51+
switch ragEvent.Type {
52+
case types.EventTypeIndexingStarted:
53+
sendEvent(RAGIndexingStarted(ragName, ragEvent.StrategyName))
54+
case types.EventTypeIndexingProgress:
55+
if ragEvent.Progress != nil {
56+
sendEvent(RAGIndexingProgress(ragName, ragEvent.StrategyName, ragEvent.Progress.Current, ragEvent.Progress.Total, agentName))
57+
}
58+
case types.EventTypeIndexingComplete:
59+
sendEvent(RAGIndexingCompleted(ragName, ragEvent.StrategyName))
60+
case types.EventTypeUsage:
61+
// Convert RAG usage to TokenUsageEvent so TUI displays it
62+
sendEvent(NewTokenUsageEvent("", agentName, &Usage{
63+
InputTokens: ragEvent.TotalTokens,
64+
ContextLength: ragEvent.TotalTokens,
65+
Cost: ragEvent.Cost,
66+
}))
67+
case types.EventTypeError:
68+
if ragEvent.Error != nil {
69+
sendEvent(Error(fmt.Sprintf("RAG %s error: %v", ragName, ragEvent.Error)))
70+
}
71+
default:
72+
// Log unhandled events for debugging
73+
slog.Debug("Unhandled RAG event type", "type", ragEvent.Type, "rag", ragName)
74+
}
75+
}
76+
}
77+
}()
78+
}
79+
}
80+
81+
// InitializeRAG initializes all RAG managers in the background
82+
func initializeRAG(ctx context.Context, ragManagers []*rag.Manager) {
83+
for _, mgr := range ragManagers {
84+
go func() {
85+
slog.Debug("Starting RAG manager initialization goroutine", "rag", mgr.Name())
86+
if err := mgr.Initialize(ctx); err != nil {
87+
slog.Error("Failed to initialize RAG manager", "rag", mgr.Name(), "error", err)
88+
} else {
89+
slog.Info("RAG manager initialized successfully", "rag", mgr.Name())
90+
}
91+
}()
92+
}
93+
}
94+
95+
// StartRAGFileWatchers starts file watchers for all RAG managers
96+
func startRAGFileWatchers(ctx context.Context, ragManagers []*rag.Manager) {
97+
for _, mgr := range ragManagers {
98+
go func() {
99+
slog.Debug("Starting RAG file watcher goroutine", "rag", mgr.Name())
100+
if err := mgr.StartFileWatcher(ctx); err != nil {
101+
slog.Error("Failed to start RAG file watcher", "rag", mgr.Name(), "error", err)
102+
}
103+
}()
104+
}
105+
}

pkg/runtime/runtime.go

Lines changed: 0 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ import (
1818
"github.com/docker/docker-agent/pkg/config/types"
1919
"github.com/docker/docker-agent/pkg/hooks"
2020
"github.com/docker/docker-agent/pkg/modelsdev"
21-
"github.com/docker/docker-agent/pkg/rag"
22-
ragtypes "github.com/docker/docker-agent/pkg/rag/types"
2321
"github.com/docker/docker-agent/pkg/session"
2422
"github.com/docker/docker-agent/pkg/sessiontitle"
2523
"github.com/docker/docker-agent/pkg/team"
@@ -163,12 +161,6 @@ type ModelStore interface {
163161
GetDatabase(ctx context.Context) (*modelsdev.Database, error)
164162
}
165163

166-
// RAGInitializer is implemented by runtimes that support background RAG initialization.
167-
// Local runtimes use this to start indexing early; remote runtimes typically do not.
168-
type RAGInitializer interface {
169-
StartBackgroundRAGInit(ctx context.Context, sendEvent func(Event))
170-
}
171-
172164
// ToolsChangeSubscriber is implemented by runtimes that can notify when
173165
// toolsets report a change in their tool list (e.g. after an MCP
174166
// ToolListChanged notification). The provided callback is invoked
@@ -340,107 +332,6 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
340332
return r, nil
341333
}
342334

343-
// StartBackgroundRAGInit initializes RAG in background and forwards events
344-
// Should be called early (e.g., by App) to start indexing before RunStream
345-
func (r *LocalRuntime) StartBackgroundRAGInit(ctx context.Context, sendEvent func(Event)) {
346-
if r.ragInitialized.Swap(true) {
347-
return
348-
}
349-
350-
ragManagers := r.team.RAGManagers()
351-
if len(ragManagers) == 0 {
352-
return
353-
}
354-
355-
slog.Debug("Starting background RAG initialization with event forwarding", "manager_count", len(ragManagers))
356-
357-
// Set up event forwarding BEFORE starting initialization
358-
// This ensures all events are captured
359-
r.forwardRAGEvents(ctx, ragManagers, sendEvent)
360-
361-
// Now start initialization (events will be forwarded)
362-
r.team.InitializeRAG(ctx)
363-
r.team.StartRAGFileWatchers(ctx)
364-
}
365-
366-
// forwardRAGEvents forwards RAG manager events to the given callback
367-
// Consolidates duplicated event forwarding logic
368-
func (r *LocalRuntime) forwardRAGEvents(ctx context.Context, ragManagers map[string]*rag.Manager, sendEvent func(Event)) {
369-
for _, mgr := range ragManagers {
370-
go func(mgr *rag.Manager) {
371-
ragName := mgr.Name()
372-
slog.Debug("Starting RAG event forwarder goroutine", "rag", ragName)
373-
for {
374-
select {
375-
case <-ctx.Done():
376-
slog.Debug("RAG event forwarder stopped", "rag", ragName)
377-
return
378-
case ragEvent, ok := <-mgr.Events():
379-
if !ok {
380-
slog.Debug("RAG events channel closed", "rag", ragName)
381-
return
382-
}
383-
384-
agentName := r.CurrentAgentName()
385-
slog.Debug("Forwarding RAG event", "type", ragEvent.Type, "rag", ragName, "agent", agentName)
386-
387-
switch ragEvent.Type {
388-
case ragtypes.EventTypeIndexingStarted:
389-
sendEvent(RAGIndexingStarted(ragName, ragEvent.StrategyName, agentName))
390-
case ragtypes.EventTypeIndexingProgress:
391-
if ragEvent.Progress != nil {
392-
sendEvent(RAGIndexingProgress(ragName, ragEvent.StrategyName, ragEvent.Progress.Current, ragEvent.Progress.Total, agentName))
393-
}
394-
case ragtypes.EventTypeIndexingComplete:
395-
sendEvent(RAGIndexingCompleted(ragName, ragEvent.StrategyName, agentName))
396-
case ragtypes.EventTypeUsage:
397-
// Convert RAG usage to TokenUsageEvent so TUI displays it
398-
sendEvent(NewTokenUsageEvent("", agentName, &Usage{
399-
InputTokens: ragEvent.TotalTokens,
400-
ContextLength: ragEvent.TotalTokens,
401-
Cost: ragEvent.Cost,
402-
}))
403-
case ragtypes.EventTypeError:
404-
if ragEvent.Error != nil {
405-
sendEvent(Error(fmt.Sprintf("RAG %s error: %v", ragName, ragEvent.Error)))
406-
}
407-
default:
408-
// Log unhandled events for debugging
409-
slog.Debug("Unhandled RAG event type", "type", ragEvent.Type, "rag", ragName)
410-
}
411-
}
412-
}
413-
}(mgr)
414-
}
415-
}
416-
417-
// InitializeRAG is called within RunStream as a fallback when background init wasn't used
418-
// (e.g., for exec command or API mode where there's no App)
419-
func (r *LocalRuntime) InitializeRAG(ctx context.Context, events chan Event) {
420-
// If already initialized via StartBackgroundRAGInit, skip entirely
421-
// Event forwarding was already set up there
422-
if r.ragInitialized.Swap(true) {
423-
slog.Debug("RAG already initialized, event forwarding already active", "manager_count", len(r.team.RAGManagers()))
424-
return
425-
}
426-
427-
ragManagers := r.team.RAGManagers()
428-
if len(ragManagers) == 0 {
429-
return
430-
}
431-
432-
slog.Debug("Setting up RAG initialization (fallback path for non-TUI)", "manager_count", len(ragManagers))
433-
434-
// Set up event forwarding BEFORE starting initialization
435-
r.forwardRAGEvents(ctx, ragManagers, func(event Event) {
436-
events <- event
437-
})
438-
439-
// Start initialization and file watchers
440-
r.team.InitializeRAG(ctx)
441-
r.team.StartRAGFileWatchers(ctx)
442-
}
443-
444335
func (r *LocalRuntime) CurrentAgentName() string {
445336
r.currentAgentMu.RLock()
446337
defer r.currentAgentMu.RUnlock()

pkg/runtime/runtime_test.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -550,12 +550,8 @@ func TestStartBackgroundRAGInit_StopsForwardingAfterContextCancel(t *testing.T)
550550
_ = mgr.Close()
551551
}()
552552

553-
tm := team.New(team.WithRAGManagers(map[string]*rag.Manager{
554-
"default": mgr,
555-
}))
556-
557553
rt := &LocalRuntime{
558-
team: tm,
554+
team: team.New(team.WithRAGManagers([]*rag.Manager{mgr})),
559555
currentAgent: "root",
560556
}
561557

0 commit comments

Comments
 (0)