Skip to content

Commit 8d9a4c5

Browse files
authored
Merge pull request #2242 from dgageot/refactor-compact
Refactor compaction
2 parents 338f317 + 644680c commit 8d9a4c5

7 files changed

Lines changed: 330 additions & 288 deletions

File tree

pkg/compaction/compaction.go

Lines changed: 4 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,7 @@
1-
// Package compaction provides conversation compaction (summarization) for
2-
// chat sessions that approach their model's context window limit.
3-
//
4-
// It is designed as a standalone component that can be used independently of
5-
// the runtime loop. The package exposes:
6-
//
7-
// - [BuildPrompt]: prepares a conversation for summarization by appending
8-
// the compaction prompt and sanitizing message costs.
9-
// - [ShouldCompact]: decides whether a session needs compaction based on
10-
// token usage and context window limits.
11-
// - [EstimateMessageTokens]: a fast heuristic for estimating the token
12-
// count of a single chat message.
13-
// - [HasConversationMessages]: checks whether a message list contains any
14-
// non-system messages worth summarizing.
151
package compaction
162

173
import (
184
_ "embed"
19-
"time"
205

216
"github.com/docker/docker-agent/pkg/chat"
227
)
@@ -26,68 +11,23 @@ var (
2611
SystemPrompt string
2712

2813
//go:embed prompts/compaction-user.txt
29-
userPrompt string
14+
UserPrompt string
3015
)
3116

3217
// contextThreshold is the fraction of the context window at which compaction
3318
// is triggered. When the estimated token usage exceeds this fraction of the
3419
// context limit, compaction is recommended.
3520
const contextThreshold = 0.9
3621

37-
// Result holds the outcome of a compaction operation.
38-
type Result struct {
39-
// Summary is the generated summary text.
40-
Summary string
41-
42-
// InputTokens is the token count reported by the summarization model,
43-
// used as an approximation of the new context size after compaction.
44-
InputTokens int64
45-
46-
// Cost is the cost of the summarization request in dollars.
47-
Cost float64
48-
}
49-
50-
// BuildPrompt prepares the messages for a summarization request.
51-
// It clones the conversation (zeroing per-message costs so they aren't
52-
// double-counted), then appends a user message containing the compaction
53-
// prompt. If additionalPrompt is non-empty it is included as extra
54-
// instructions.
55-
//
56-
// Callers should first check [HasConversationMessages] to avoid sending
57-
// an empty conversation to the model.
58-
func BuildPrompt(messages []chat.Message, additionalPrompt string) []chat.Message {
59-
prompt := userPrompt
60-
if additionalPrompt != "" {
61-
prompt += "\n\nAdditional instructions from user: " + additionalPrompt
62-
}
63-
64-
out := make([]chat.Message, len(messages), len(messages)+1)
65-
for i, msg := range messages {
66-
cloned := msg
67-
cloned.Cost = 0
68-
cloned.CacheControl = false
69-
out[i] = cloned
70-
}
71-
out = append(out, chat.Message{
72-
Role: chat.MessageRoleUser,
73-
Content: prompt,
74-
CreatedAt: time.Now().Format(time.RFC3339),
75-
})
76-
77-
return out
78-
}
79-
8022
// ShouldCompact reports whether a session's context usage has crossed the
81-
// compaction threshold. It returns true when the estimated total token count
23+
// compaction threshold. It returns true when the total token count
8224
// (input + output + addedTokens) exceeds [contextThreshold] (90%) of
83-
// contextLimit. A non-positive contextLimit is treated as unlimited and
84-
// always returns false.
25+
// contextLimit.
8526
func ShouldCompact(inputTokens, outputTokens, addedTokens, contextLimit int64) bool {
8627
if contextLimit <= 0 {
8728
return false
8829
}
89-
estimated := inputTokens + outputTokens + addedTokens
90-
return estimated > int64(float64(contextLimit)*contextThreshold)
30+
return (inputTokens + outputTokens + addedTokens) > int64(float64(contextLimit)*contextThreshold)
9131
}
9232

9333
// EstimateMessageTokens returns a rough token-count estimate for a single
@@ -121,15 +61,3 @@ func EstimateMessageTokens(msg *chat.Message) int64 {
12161
}
12262
return int64(chars/charsPerToken) + perMessageOverhead
12363
}
124-
125-
// HasConversationMessages reports whether messages contains at least one
126-
// non-system message. A session with only system prompts has no conversation
127-
// to summarize.
128-
func HasConversationMessages(messages []chat.Message) bool {
129-
for _, msg := range messages {
130-
if msg.Role != chat.MessageRoleSystem {
131-
return true
132-
}
133-
}
134-
return false
135-
}

pkg/compaction/compaction_test.go

Lines changed: 0 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"testing"
55

66
"github.com/stretchr/testify/assert"
7-
"github.com/stretchr/testify/require"
87

98
"github.com/docker/docker-agent/pkg/chat"
109
"github.com/docker/docker-agent/pkg/tools"
@@ -164,137 +163,3 @@ func TestShouldCompact(t *testing.T) {
164163
})
165164
}
166165
}
167-
168-
func TestHasConversationMessages(t *testing.T) {
169-
t.Parallel()
170-
171-
tests := []struct {
172-
name string
173-
messages []chat.Message
174-
want bool
175-
}{
176-
{
177-
name: "empty",
178-
messages: nil,
179-
want: false,
180-
},
181-
{
182-
name: "system only",
183-
messages: []chat.Message{
184-
{Role: chat.MessageRoleSystem, Content: "You are helpful."},
185-
},
186-
want: false,
187-
},
188-
{
189-
name: "system and user",
190-
messages: []chat.Message{
191-
{Role: chat.MessageRoleSystem, Content: "You are helpful."},
192-
{Role: chat.MessageRoleUser, Content: "Hello"},
193-
},
194-
want: true,
195-
},
196-
{
197-
name: "only user",
198-
messages: []chat.Message{
199-
{Role: chat.MessageRoleUser, Content: "Hello"},
200-
},
201-
want: true,
202-
},
203-
{
204-
name: "assistant message",
205-
messages: []chat.Message{
206-
{Role: chat.MessageRoleAssistant, Content: "Hi there"},
207-
},
208-
want: true,
209-
},
210-
}
211-
212-
for _, tt := range tests {
213-
t.Run(tt.name, func(t *testing.T) {
214-
t.Parallel()
215-
got := HasConversationMessages(tt.messages)
216-
assert.Equal(t, tt.want, got)
217-
})
218-
}
219-
}
220-
221-
func TestBuildPrompt(t *testing.T) {
222-
t.Parallel()
223-
224-
messages := []chat.Message{
225-
{Role: chat.MessageRoleSystem, Content: "You are helpful."},
226-
{Role: chat.MessageRoleUser, Content: "Hello", Cost: 0.05},
227-
{Role: chat.MessageRoleAssistant, Content: "Hi!", Cost: 0.10},
228-
}
229-
230-
t.Run("basic", func(t *testing.T) {
231-
t.Parallel()
232-
233-
out := BuildPrompt(messages, "")
234-
235-
// Original messages + appended summarization prompt.
236-
require.Len(t, out, 4)
237-
238-
// Costs are zeroed.
239-
for _, msg := range out[:3] {
240-
assert.Zero(t, msg.Cost, "cost should be zeroed for %q", msg.Content)
241-
}
242-
243-
// Last message is the summarization prompt.
244-
last := out[len(out)-1]
245-
assert.Equal(t, chat.MessageRoleUser, last.Role)
246-
assert.Contains(t, last.Content, "summary")
247-
assert.NotEmpty(t, last.CreatedAt)
248-
})
249-
250-
t.Run("with additional prompt", func(t *testing.T) {
251-
t.Parallel()
252-
253-
out := BuildPrompt(messages, "focus on code changes")
254-
255-
last := out[len(out)-1]
256-
assert.Contains(t, last.Content, "Additional instructions from user: focus on code changes")
257-
})
258-
259-
t.Run("does not modify original messages", func(t *testing.T) {
260-
t.Parallel()
261-
262-
original := []chat.Message{
263-
{Role: chat.MessageRoleUser, Content: "Hello", Cost: 0.05},
264-
}
265-
266-
_ = BuildPrompt(original, "")
267-
268-
assert.InDelta(t, 0.05, original[0].Cost, 1e-9)
269-
assert.Len(t, original, 1)
270-
})
271-
272-
t.Run("strips CacheControl from cloned messages", func(t *testing.T) {
273-
t.Parallel()
274-
275-
input := []chat.Message{
276-
{Role: chat.MessageRoleSystem, Content: "system", CacheControl: true},
277-
{Role: chat.MessageRoleSystem, Content: "context", CacheControl: true},
278-
{Role: chat.MessageRoleUser, Content: "hello"},
279-
}
280-
281-
out := BuildPrompt(input, "")
282-
283-
// All cloned messages should have CacheControl=false
284-
for i, msg := range out {
285-
assert.False(t, msg.CacheControl, "message %d should have CacheControl stripped", i)
286-
}
287-
// Original should be unchanged
288-
assert.True(t, input[0].CacheControl)
289-
assert.True(t, input[1].CacheControl)
290-
})
291-
}
292-
293-
func TestPromptsAreEmbedded(t *testing.T) {
294-
t.Parallel()
295-
296-
assert.NotEmpty(t, SystemPrompt, "compaction system prompt should be embedded")
297-
assert.NotEmpty(t, userPrompt, "compaction user prompt should be embedded")
298-
assert.Contains(t, SystemPrompt, "summary")
299-
assert.Contains(t, userPrompt, "summary")
300-
}

pkg/runtime/loop.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,13 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
128128
}
129129
loopDetector := newToolLoopDetector(loopThreshold)
130130

131+
// overflowCompactions counts how many consecutive context-overflow
132+
// auto-compactions have been attempted without a successful model
133+
// call in between. This prevents an infinite loop when compaction
134+
// cannot reduce the context below the model's limit.
135+
const maxOverflowCompactions = 1
136+
var overflowCompactions int
137+
131138
// toolModelOverride holds the per-toolset model from the most recent
132139
// tool calls. It applies for one LLM turn, then resets.
133140
var toolModelOverride string
@@ -248,13 +255,14 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
248255
slog.Debug("Failed to get model definition", "error", err)
249256
}
250257

258+
// We can only compact if we know the limit.
251259
var contextLimit int64
252260
if m != nil {
253261
contextLimit = int64(m.Limit.Context)
254-
}
255262

256-
if m != nil && r.sessionCompaction && compaction.ShouldCompact(sess.InputTokens, sess.OutputTokens, 0, contextLimit) {
257-
r.Summarize(ctx, sess, "", events)
263+
if r.sessionCompaction && compaction.ShouldCompact(sess.InputTokens, sess.OutputTokens, 0, contextLimit) {
264+
r.Summarize(ctx, sess, "", events)
265+
}
258266
}
259267

260268
messages := sess.GetMessages(a)
@@ -280,13 +288,18 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
280288
// Auto-recovery: if the error is a context overflow and
281289
// session compaction is enabled, compact the conversation
282290
// and retry the request instead of surfacing raw errors.
283-
if _, ok := errors.AsType[*modelerrors.ContextOverflowError](err); ok && r.sessionCompaction {
291+
// We allow at most maxOverflowCompactions consecutive attempts
292+
// to avoid an infinite loop when compaction cannot reduce
293+
// the context enough.
294+
if _, ok := errors.AsType[*modelerrors.ContextOverflowError](err); ok && r.sessionCompaction && overflowCompactions < maxOverflowCompactions {
295+
overflowCompactions++
284296
slog.Warn("Context window overflow detected, attempting auto-compaction",
285297
"agent", a.Name(),
286298
"session_id", sess.ID,
287299
"input_tokens", sess.InputTokens,
288300
"output_tokens", sess.OutputTokens,
289301
"context_limit", contextLimit,
302+
"attempt", overflowCompactions,
290303
)
291304
events <- Warning(
292305
"The conversation has exceeded the model's context window. Automatically compacting the conversation history...",
@@ -313,6 +326,9 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
313326
return
314327
}
315328

329+
// A successful model call resets the overflow compaction counter.
330+
overflowCompactions = 0
331+
316332
if usedModel != nil && usedModel.ID() != model.ID() {
317333
slog.Info("Used fallback model", "agent", a.Name(), "primary", model.ID(), "used", usedModel.ID())
318334
events <- AgentInfo(a.Name(), usedModel.ID(), a.Description(), a.WelcomeMessage())

0 commit comments

Comments (0)