Skip to content

Commit fdf6028

Browse files
committed
Simplify RAG event forwarding and clean up RAGTool
- Remove RAGInitializer interface and StartBackgroundRAGInit indirection. RAG event callbacks are now wired in configureToolsetHandlers alongside other handler setup, using the same pattern as Elicitable/OAuthCapable. - Remove deprecated NewManagers wrapper (no callers after toolset refactor). - Clean up RAGTool: unexport internal types (QueryRAGArgs, QueryResult), inline sortResults, remove verbose debug logging from Tools(), simplify handleQueryRAG. Assisted-By: docker-agent
1 parent 10dcb3c commit fdf6028

6 files changed

Lines changed: 61 additions & 165 deletions

File tree

pkg/app/app.go

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,6 @@ import (
3030
"github.com/docker/docker-agent/pkg/tui/messages"
3131
)
3232

33-
// RAGInitializer is implemented by runtimes that support background RAG initialization.
34-
// Local runtimes use this to start indexing early; remote runtimes typically do not.
35-
type RAGInitializer interface {
36-
StartBackgroundRAGInit(ctx context.Context, sendEvent func(runtime.Event))
37-
}
38-
3933
type App struct {
4034
runtime runtime.Runtime
4135
session *session.Session
@@ -122,18 +116,6 @@ func New(ctx context.Context, rt runtime.Runtime, sess *session.Session, opts ..
122116
}
123117
}()
124118

125-
// If the runtime supports background RAG initialization, start it
126-
// and forward events to the TUI. Remote runtimes typically handle RAG server-side
127-
// and won't implement this optional interface.
128-
if ragRuntime, ok := rt.(RAGInitializer); ok {
129-
go ragRuntime.StartBackgroundRAGInit(ctx, func(event runtime.Event) {
130-
select {
131-
case app.events <- event:
132-
case <-ctx.Done():
133-
}
134-
})
135-
}
136-
137119
// Subscribe to tool list changes so the sidebar updates immediately
138120
// when an MCP server adds or removes tools (outside of a RunStream).
139121
if tcs, ok := rt.(runtime.ToolsChangeSubscriber); ok {

pkg/rag/builder.go

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,30 +24,6 @@ type ManagersBuildConfig struct {
2424
Models map[string]latest.ModelConfig // Model configurations from config
2525
}
2626

27-
// NewManagers constructs all RAG managers defined in the config.
28-
//
29-
// Deprecated: Use NewManager for per-toolset creation instead.
30-
func NewManagers(ctx context.Context, cfg *latest.Config, buildCfg ManagersBuildConfig) ([]*Manager, error) {
31-
if len(cfg.RAG) == 0 {
32-
return nil, nil
33-
}
34-
35-
var managers []*Manager
36-
37-
for ragName, ragToolset := range cfg.RAG {
38-
if ragToolset.RAGConfig == nil {
39-
continue
40-
}
41-
mgr, err := NewManager(ctx, ragName, ragToolset.RAGConfig, buildCfg)
42-
if err != nil {
43-
return nil, err
44-
}
45-
managers = append(managers, mgr)
46-
}
47-
48-
return managers, nil
49-
}
50-
5127
// NewManager constructs a single RAG manager from a RAGConfig.
5228
func NewManager(
5329
ctx context.Context,

pkg/runtime/loop.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,6 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
9696
// Emit team information
9797
events <- TeamInfo(r.agentDetailsFromTeam(), a.Name())
9898

99-
// Initialize RAG and forward events
100-
r.StartBackgroundRAGInit(ctx, func(event Event) {
101-
events <- event
102-
})
103-
10499
r.emitAgentWarnings(a, chanSend(events))
105100
r.configureToolsetHandlers(a, events)
106101

@@ -534,6 +529,11 @@ func (r *LocalRuntime) configureToolsetHandlers(a *agent.Agent, events chan Even
534529
func() { events <- Authorization(tools.ElicitationActionAccept, a.Name()) },
535530
r.managedOAuth,
536531
)
532+
533+
// Wire RAG event forwarding so the TUI shows indexing progress.
534+
if ragTool, ok := tools.As[*builtin.RAGTool](toolset); ok {
535+
ragTool.SetEventCallback(ragEventForwarder(ragTool.Name(), r, chanSend(events)))
536+
}
537537
}
538538
}
539539

pkg/runtime/rag.go

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,15 @@
11
package runtime
22

33
import (
4-
"context"
54
"fmt"
65
"log/slog"
76

87
ragtypes "github.com/docker/docker-agent/pkg/rag/types"
9-
"github.com/docker/docker-agent/pkg/tools"
108
"github.com/docker/docker-agent/pkg/tools/builtin"
119
)
1210

13-
// StartBackgroundRAGInit discovers RAG toolsets from agents and wires up event
14-
// forwarding so the TUI can display indexing progress. Actual initialization
15-
// happens lazily when the tool is first used (via tools.Startable).
16-
func (r *LocalRuntime) StartBackgroundRAGInit(ctx context.Context, sendEvent func(Event)) {
17-
for _, name := range r.team.AgentNames() {
18-
a, err := r.team.Agent(name)
19-
if err != nil {
20-
continue
21-
}
22-
for _, ts := range a.ToolSets() {
23-
ragTool, ok := tools.As[*builtin.RAGTool](ts)
24-
if !ok {
25-
continue
26-
}
27-
ragTool.SetEventCallback(ragEventForwarder(ctx, ragTool.Name(), r, sendEvent))
28-
}
29-
}
30-
}
31-
3211
// ragEventForwarder returns a callback that converts RAG manager events to runtime events.
33-
func ragEventForwarder(ctx context.Context, ragName string, r *LocalRuntime, sendEvent func(Event)) builtin.RAGEventCallback {
12+
func ragEventForwarder(ragName string, r *LocalRuntime, sendEvent func(Event)) builtin.RAGEventCallback {
3413
return func(ragEvent ragtypes.Event) {
3514
agentName := r.CurrentAgentName()
3615
slog.Debug("Forwarding RAG event", "type", ragEvent.Type, "rag", ragName, "agent", agentName)
@@ -57,7 +36,5 @@ func ragEventForwarder(ctx context.Context, ragName string, r *LocalRuntime, sen
5736
default:
5837
slog.Debug("Unhandled RAG event type", "type", ragEvent.Type, "rag", ragName)
5938
}
60-
61-
_ = ctx // available for future use
6239
}
6340
}

pkg/tools/builtin/rag.go

Lines changed: 48 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,21 @@ import (
1717
// RAGEventCallback is called to forward RAG manager events during initialization.
1818
type RAGEventCallback func(event ragtypes.Event)
1919

20-
// RAGTool provides document querying capabilities for a single RAG source
20+
// RAGTool provides document querying capabilities for a single RAG source.
2121
type RAGTool struct {
2222
manager *rag.Manager
2323
toolName string
2424
eventCallback RAGEventCallback
2525
}
2626

27-
// Verify interface compliance
27+
// Verify interface compliance.
2828
var (
2929
_ tools.ToolSet = (*RAGTool)(nil)
3030
_ tools.Instructable = (*RAGTool)(nil)
3131
_ tools.Startable = (*RAGTool)(nil)
3232
)
3333

34-
// NewRAGTool creates a new RAG tool for a single RAG manager
35-
// toolName is the name to use for the tool (typically from config or manager name)
34+
// NewRAGTool creates a new RAG tool for a single RAG manager.
3635
func NewRAGTool(manager *rag.Manager, toolName string) *RAGTool {
3736
return &RAGTool{
3837
manager: manager,
@@ -45,39 +44,28 @@ func (t *RAGTool) Name() string {
4544
return t.toolName
4645
}
4746

48-
type QueryRAGArgs struct {
49-
Query string `json:"query" jsonschema:"Search query"`
50-
}
51-
52-
type QueryResult struct {
53-
SourcePath string `json:"source_path" jsonschema:"Path to the source document"`
54-
Content string `json:"content" jsonschema:"Relevant document chunk content"`
55-
Similarity float64 `json:"similarity" jsonschema:"Similarity score (0-1)"`
56-
ChunkIndex int `json:"chunk_index" jsonschema:"Index of the chunk within the source document"`
57-
}
58-
59-
// SetEventCallback sets a callback to receive RAG manager events during initialization.
60-
// This must be called before Start() to receive indexing progress events.
47+
// SetEventCallback sets a callback to receive RAG manager events during
48+
// initialization. Must be called before Start().
6149
func (t *RAGTool) SetEventCallback(cb RAGEventCallback) {
6250
t.eventCallback = cb
6351
}
6452

65-
// Start initializes the RAG manager (indexes documents).
53+
// Start initializes the RAG manager (indexes documents) and starts a
54+
// file watcher for incremental updates.
6655
func (t *RAGTool) Start(ctx context.Context) error {
6756
if t.manager == nil {
6857
return nil
6958
}
70-
slog.Debug("Starting RAG tool initialization", "tool", t.toolName)
7159

72-
// Forward RAG manager events if a callback is set
60+
// Forward RAG manager events if a callback is set.
7361
if t.eventCallback != nil {
7462
go t.forwardEvents(ctx)
7563
}
7664

7765
if err := t.manager.Initialize(ctx); err != nil {
7866
return fmt.Errorf("failed to initialize RAG manager %q: %w", t.toolName, err)
7967
}
80-
// Start file watcher in background
68+
8169
go func() {
8270
if err := t.manager.StartFileWatcher(ctx); err != nil {
8371
slog.Error("Failed to start RAG file watcher", "tool", t.toolName, "error", err)
@@ -86,6 +74,14 @@ func (t *RAGTool) Start(ctx context.Context) error {
8674
return nil
8775
}
8876

77+
// Stop closes the RAG manager and releases resources.
78+
func (t *RAGTool) Stop(_ context.Context) error {
79+
if t.manager == nil {
80+
return nil
81+
}
82+
return t.manager.Close()
83+
}
84+
8985
// forwardEvents reads events from the RAG manager and forwards them via the callback.
9086
func (t *RAGTool) forwardEvents(ctx context.Context) {
9187
for {
@@ -101,27 +97,27 @@ func (t *RAGTool) forwardEvents(ctx context.Context) {
10197
}
10298
}
10399

104-
// Stop closes the RAG manager and releases resources.
105-
func (t *RAGTool) Stop(_ context.Context) error {
106-
if t.manager == nil {
107-
return nil
108-
}
109-
return t.manager.Close()
110-
}
111-
112100
func (t *RAGTool) Instructions() string {
113101
if t.manager != nil {
114-
instruction := t.manager.ToolInstruction()
115-
if instruction != "" {
102+
if instruction := t.manager.ToolInstruction(); instruction != "" {
116103
return instruction
117104
}
118105
}
119-
120-
// Default instruction if none provided
121106
return fmt.Sprintf("Search documents in %s to find relevant code or documentation. "+
122107
"Provide a clear search query describing what you need.", t.toolName)
123108
}
124109

110+
type queryRAGArgs struct {
111+
Query string `json:"query" jsonschema:"Search query"`
112+
}
113+
114+
type queryResult struct {
115+
SourcePath string `json:"source_path" jsonschema:"Path to the source document"`
116+
Content string `json:"content" jsonschema:"Relevant document chunk content"`
117+
Similarity float64 `json:"similarity" jsonschema:"Similarity score (0-1)"`
118+
ChunkIndex int `json:"chunk_index" jsonschema:"Index of the chunk within the source document"`
119+
}
120+
125121
func (t *RAGTool) Tools(context.Context) ([]tools.Tool, error) {
126122
var description string
127123
if t.manager != nil {
@@ -131,83 +127,53 @@ func (t *RAGTool) Tools(context.Context) ([]tools.Tool, error) {
131127
"Provide a natural language query describing what you need. "+
132128
"Returns the most relevant document chunks with file paths.", t.toolName))
133129

134-
paramsSchema := tools.MustSchemaFor[QueryRAGArgs]()
135-
outputSchema := tools.MustSchemaFor[[]QueryResult]()
136-
137-
// Log schemas for debugging
138-
if paramsJSON, err := json.Marshal(paramsSchema); err == nil {
139-
slog.Debug("RAG tool parameters schema",
140-
"tool_name", t.toolName,
141-
"schema", string(paramsJSON))
142-
}
143-
if outputJSON, err := json.Marshal(outputSchema); err == nil {
144-
slog.Debug("RAG tool output schema",
145-
"tool_name", t.toolName,
146-
"schema", string(outputJSON))
147-
}
148-
149-
tool := tools.Tool{
130+
return []tools.Tool{{
150131
Name: t.toolName,
151132
Category: "knowledge",
152133
Description: description,
153-
Parameters: paramsSchema,
154-
OutputSchema: outputSchema,
134+
Parameters: tools.MustSchemaFor[queryRAGArgs](),
135+
OutputSchema: tools.MustSchemaFor[[]queryResult](),
155136
Handler: tools.NewHandler(t.handleQueryRAG),
156137
Annotations: tools.ToolAnnotations{
157138
ReadOnlyHint: true,
158139
Title: "Query " + t.toolName,
159140
},
160-
}
161-
162-
slog.Debug("RAG tool registered",
163-
"tool_name", tool.Name,
164-
"category", tool.Category,
165-
"description", description,
166-
"title", tool.Annotations.Title,
167-
"read_only", tool.Annotations.ReadOnlyHint)
168-
169-
return []tools.Tool{tool}, nil
141+
}}, nil
170142
}
171143

172-
func (t *RAGTool) handleQueryRAG(ctx context.Context, args QueryRAGArgs) (*tools.ToolCallResult, error) {
144+
func (t *RAGTool) handleQueryRAG(ctx context.Context, args queryRAGArgs) (*tools.ToolCallResult, error) {
173145
if args.Query == "" {
174146
return nil, errors.New("query cannot be empty")
175147
}
176148

177149
results, err := t.manager.Query(ctx, args.Query)
178150
if err != nil {
179-
slog.Error("RAG query failed", "rag", t.manager.Name(), "error", err)
180151
return nil, fmt.Errorf("RAG query failed: %w", err)
181152
}
182153

183-
allResults := make([]QueryResult, 0, len(results))
184-
for _, result := range results {
185-
allResults = append(allResults, QueryResult{
186-
SourcePath: result.Document.SourcePath,
187-
Content: result.Document.Content,
188-
Similarity: result.Similarity,
189-
ChunkIndex: result.Document.ChunkIndex,
154+
out := make([]queryResult, 0, len(results))
155+
for _, r := range results {
156+
out = append(out, queryResult{
157+
SourcePath: r.Document.SourcePath,
158+
Content: r.Document.Content,
159+
Similarity: r.Similarity,
160+
ChunkIndex: r.Document.ChunkIndex,
190161
})
191162
}
192163

193-
sortResults(allResults)
164+
slices.SortFunc(out, func(a, b queryResult) int {
165+
return cmp.Compare(b.Similarity, a.Similarity)
166+
})
194167

195-
maxResults := 10
196-
if len(allResults) > maxResults {
197-
allResults = allResults[:maxResults]
168+
const maxResults = 10
169+
if len(out) > maxResults {
170+
out = out[:maxResults]
198171
}
199172

200-
resultJSON, err := json.Marshal(allResults)
173+
resultJSON, err := json.Marshal(out)
201174
if err != nil {
202175
return nil, fmt.Errorf("failed to marshal results: %w", err)
203176
}
204177

205178
return tools.ResultSuccess(string(resultJSON)), nil
206179
}
207-
208-
// sortResults sorts query results by similarity in descending order
209-
func sortResults(results []QueryResult) {
210-
slices.SortFunc(results, func(a, b QueryResult) int {
211-
return cmp.Compare(b.Similarity, a.Similarity) // Descending order
212-
})
213-
}

0 commit comments

Comments
 (0)