Skip to content

Commit 32ebaf5

Browse files
committed
Use Desktop proxy for all HTTP downloads
Updates fetch, openapi, api tools, gateway catalog, skills cache, and models.dev store to use remote.NewTransport for Desktop proxy support when downloading external content. Assisted-By: docker-agent Signed-off-by: Guillaume Tardif <guillaume.tardif@gmail.com>
1 parent 878630a commit 32ebaf5

9 files changed

Lines changed: 52 additions & 42 deletions

File tree

pkg/gateway/catalog.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"time"
1414

1515
"github.com/docker/docker-agent/pkg/paths"
16+
"github.com/docker/docker-agent/pkg/remote"
1617
)
1718

1819
const (
@@ -166,10 +167,6 @@ func saveToDisk(path string, catalog Catalog, etag string) {
166167
}
167168
}
168169

169-
// catalogClient is a dedicated HTTP client for catalog fetches, isolated from
170-
// http.DefaultClient so that other parts of the process cannot interfere.
171-
var catalogClient = &http.Client{}
172-
173170
// fetchFromNetwork fetches the catalog, using the ETag for conditional requests.
174171
// It returns (nil, "", nil) when the server responds with 304 Not Modified.
175172
func fetchFromNetwork(ctx context.Context, etag string) (Catalog, string, error) {
@@ -185,6 +182,7 @@ func fetchFromNetwork(ctx context.Context, etag string) (Catalog, string, error)
185182
req.Header.Set("If-None-Match", etag)
186183
}
187184

185+
catalogClient := &http.Client{Transport: remote.NewTransport(ctx)}
188186
resp, err := catalogClient.Do(req)
189187
if err != nil {
190188
return nil, "", err

pkg/modelsdev/store.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"strings"
1414
"sync"
1515
"time"
16+
17+
"github.com/docker/docker-agent/pkg/remote"
1618
)
1719

1820
const (
@@ -183,7 +185,7 @@ func fetchFromAPI(ctx context.Context, etag string) (*Database, string, error) {
183185
req.Header.Set("If-None-Match", etag)
184186
}
185187

186-
resp, err := (&http.Client{Timeout: 30 * time.Second}).Do(req)
188+
resp, err := (&http.Client{Timeout: 30 * time.Second, Transport: remote.NewTransport(ctx)}).Do(req)
187189
if err != nil {
188190
return nil, "", fmt.Errorf("failed to fetch from API: %w", err)
189191
}

pkg/skills/cache.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package skills
22

33
import (
4+
"context"
45
"crypto/sha256"
56
"encoding/hex"
67
"encoding/json"
@@ -13,11 +14,12 @@ import (
1314
"strconv"
1415
"strings"
1516
"time"
17+
18+
"github.com/docker/docker-agent/pkg/remote"
1619
)
1720

1821
type diskCache struct {
19-
baseDir string
20-
httpClient *http.Client
22+
baseDir string
2123
}
2224

2325
type cacheMetadata struct {
@@ -29,9 +31,6 @@ type cacheMetadata struct {
2931
func newDiskCache(baseDir string) *diskCache {
3032
return &diskCache{
3133
baseDir: baseDir,
32-
httpClient: &http.Client{
33-
Timeout: 30 * time.Second,
34-
},
3534
}
3635
}
3736

@@ -68,10 +67,14 @@ func (c *diskCache) Get(baseURL, skillName, filePath string) (string, bool) {
6867

6968
// FetchAndStore downloads a file from the given URL and stores it in the cache.
7069
// It respects Cache-Control headers to determine expiry.
71-
func (c *diskCache) FetchAndStore(baseURL, skillName, filePath, fileURL string) (string, error) {
70+
func (c *diskCache) FetchAndStore(ctx context.Context, baseURL, skillName, filePath, fileURL string) (string, error) {
7271
slog.Debug("Fetching remote skill file", "url", fileURL)
7372

74-
resp, err := c.httpClient.Get(fileURL)
73+
httpClient := &http.Client{
74+
Timeout: 30 * time.Second,
75+
Transport: remote.NewTransport(ctx),
76+
}
77+
resp, err := httpClient.Get(fileURL)
7578
if err != nil {
7679
return "", fmt.Errorf("fetching %s: %w", fileURL, err)
7780
}

pkg/skills/cache_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func TestDiskCache_FetchAndStore(t *testing.T) {
2222

2323
cache := newDiskCache(t.TempDir())
2424

25-
content, err := cache.FetchAndStore("https://example.com", "my-skill", "SKILL.md", srv.URL+"/SKILL.md")
25+
content, err := cache.FetchAndStore(t.Context(), "https://example.com", "my-skill", "SKILL.md", srv.URL+"/SKILL.md")
2626
require.NoError(t, err)
2727
assert.Equal(t, "file content", content)
2828

@@ -54,7 +54,7 @@ func TestDiskCache_Get_Cached(t *testing.T) {
5454

5555
cache := newDiskCache(t.TempDir())
5656

57-
_, err := cache.FetchAndStore("https://example.com", "skill", "SKILL.md", srv.URL+"/SKILL.md")
57+
_, err := cache.FetchAndStore(t.Context(), "https://example.com", "skill", "SKILL.md", srv.URL+"/SKILL.md")
5858
require.NoError(t, err)
5959

6060
content, ok := cache.Get("https://example.com", "skill", "SKILL.md")
@@ -71,7 +71,7 @@ func TestDiskCache_Get_Expired(t *testing.T) {
7171

7272
cache := newDiskCache(t.TempDir())
7373

74-
_, err := cache.FetchAndStore("https://example.com", "skill", "SKILL.md", srv.URL+"/SKILL.md")
74+
_, err := cache.FetchAndStore(t.Context(), "https://example.com", "skill", "SKILL.md", srv.URL+"/SKILL.md")
7575
require.NoError(t, err)
7676

7777
// The max-age=0 should make it immediately expired
@@ -87,7 +87,7 @@ func TestDiskCache_NestedFiles(t *testing.T) {
8787

8888
cache := newDiskCache(t.TempDir())
8989

90-
content, err := cache.FetchAndStore("https://example.com", "my-skill", "references/FORMS.md", srv.URL+"/file")
90+
content, err := cache.FetchAndStore(t.Context(), "https://example.com", "my-skill", "references/FORMS.md", srv.URL+"/file")
9191
require.NoError(t, err)
9292
assert.Equal(t, "nested file content", content)
9393

@@ -152,7 +152,7 @@ func TestDiskCache_HTTPError(t *testing.T) {
152152

153153
cache := newDiskCache(t.TempDir())
154154

155-
_, err := cache.FetchAndStore("https://example.com", "skill", "SKILL.md", srv.URL+"/notfound")
155+
_, err := cache.FetchAndStore(t.Context(), "https://example.com", "skill", "SKILL.md", srv.URL+"/notfound")
156156
require.Error(t, err)
157157
assert.Contains(t, err.Error(), "HTTP 404")
158158
}

pkg/skills/remote.go

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package skills
22

33
import (
4+
"context"
45
"encoding/json"
56
"fmt"
67
"io"
@@ -11,6 +12,7 @@ import (
1112
"time"
1213

1314
"github.com/docker/docker-agent/pkg/paths"
15+
"github.com/docker/docker-agent/pkg/remote"
1416
)
1517

1618
// remoteIndex represents the index.json served at /.well-known/skills/index.json
@@ -24,10 +26,6 @@ type remoteSkillEntry struct {
2426
Files []string `json:"files"`
2527
}
2628

27-
var defaultHTTPClient = &http.Client{
28-
Timeout: 30 * time.Second,
29-
}
30-
3129
func defaultCache() *diskCache {
3230
return newDiskCache(filepath.Join(paths.GetCacheDir(), "skills"))
3331
}
@@ -37,16 +35,20 @@ func defaultCache() *diskCache {
3735
// into a disk cache so the agent can read them without network requests during
3836
// task execution.
3937
func loadRemoteSkills(baseURL string) []Skill {
40-
return loadRemoteSkillsWithCache(baseURL, defaultCache())
38+
return loadRemoteSkillsWithCache(context.Background(), baseURL, defaultCache())
4139
}
4240

43-
func loadRemoteSkillsWithCache(baseURL string, cache *diskCache) []Skill {
41+
func loadRemoteSkillsWithCache(ctx context.Context, baseURL string, cache *diskCache) []Skill {
4442
baseURL = strings.TrimRight(baseURL, "/")
4543
indexURL := baseURL + "/.well-known/skills/index.json"
4644

4745
slog.Debug("Fetching remote skills index", "url", indexURL)
4846

49-
resp, err := defaultHTTPClient.Get(indexURL)
47+
httpClient := &http.Client{
48+
Timeout: 30 * time.Second,
49+
Transport: remote.NewTransport(ctx),
50+
}
51+
resp, err := httpClient.Get(indexURL)
5052
if err != nil {
5153
slog.Warn("Failed to fetch remote skills index", "url", indexURL, "error", err)
5254
return nil
@@ -77,7 +79,7 @@ func loadRemoteSkillsWithCache(baseURL string, cache *diskCache) []Skill {
7779
}
7880

7981
cacheDir := cache.cacheDir(baseURL, entry.Name)
80-
prefetchFiles(cache, baseURL, entry.Name, entry.Files)
82+
prefetchFiles(ctx, cache, baseURL, entry.Name, entry.Files)
8183

8284
skill := Skill{
8385
Name: entry.Name,
@@ -96,7 +98,7 @@ func loadRemoteSkillsWithCache(baseURL string, cache *diskCache) []Skill {
9698
// prefetchFiles downloads all files listed in the index for a skill,
9799
// storing them in the disk cache. Files already in cache (and not expired)
98100
// are skipped.
99-
func prefetchFiles(cache *diskCache, baseURL, skillName string, files []string) {
101+
func prefetchFiles(ctx context.Context, cache *diskCache, baseURL, skillName string, files []string) {
100102
for _, file := range files {
101103
if !isValidFilePath(file) {
102104
slog.Debug("Skipping invalid file path in skill", "skill", skillName, "file", file)
@@ -108,7 +110,7 @@ func prefetchFiles(cache *diskCache, baseURL, skillName string, files []string)
108110
}
109111

110112
fileURL := fmt.Sprintf("%s/.well-known/skills/%s/%s", baseURL, skillName, file)
111-
if _, err := cache.FetchAndStore(baseURL, skillName, file, fileURL); err != nil {
113+
if _, err := cache.FetchAndStore(ctx, baseURL, skillName, file, fileURL); err != nil {
112114
slog.Warn("Failed to prefetch skill file", "skill", skillName, "file", file, "error", err)
113115
}
114116
}

pkg/skills/remote_test.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func TestLoadRemoteSkills(t *testing.T) {
4747

4848
cacheDir := t.TempDir()
4949
cache := newDiskCache(cacheDir)
50-
skills := loadRemoteSkillsWithCache(srv.URL, cache)
50+
skills := loadRemoteSkillsWithCache(t.Context(), srv.URL, cache)
5151

5252
require.Len(t, skills, 2)
5353

@@ -84,7 +84,7 @@ func TestLoadRemoteSkills(t *testing.T) {
8484
defer srv.Close()
8585

8686
cache := newDiskCache(t.TempDir())
87-
skills := loadRemoteSkillsWithCache(srv.URL+"/", cache)
87+
skills := loadRemoteSkillsWithCache(t.Context(), srv.URL+"/", cache)
8888
require.Len(t, skills, 1)
8989

9090
content, err := os.ReadFile(skills[0].FilePath)
@@ -99,7 +99,7 @@ func TestLoadRemoteSkills(t *testing.T) {
9999
}))
100100
defer srv.Close()
101101

102-
skills := loadRemoteSkillsWithCache(srv.URL, newDiskCache(t.TempDir()))
102+
skills := loadRemoteSkillsWithCache(t.Context(), srv.URL, newDiskCache(t.TempDir()))
103103
assert.Empty(t, skills)
104104
})
105105

@@ -110,7 +110,7 @@ func TestLoadRemoteSkills(t *testing.T) {
110110
}))
111111
defer srv.Close()
112112

113-
skills := loadRemoteSkillsWithCache(srv.URL, newDiskCache(t.TempDir()))
113+
skills := loadRemoteSkillsWithCache(t.Context(), srv.URL, newDiskCache(t.TempDir()))
114114
assert.Empty(t, skills)
115115
})
116116

@@ -121,15 +121,15 @@ func TestLoadRemoteSkills(t *testing.T) {
121121
}))
122122
defer srv.Close()
123123

124-
skills := loadRemoteSkillsWithCache(srv.URL, newDiskCache(t.TempDir()))
124+
skills := loadRemoteSkillsWithCache(t.Context(), srv.URL, newDiskCache(t.TempDir()))
125125
assert.Empty(t, skills)
126126
})
127127

128128
t.Run("server returns 404", func(t *testing.T) {
129129
srv := httptest.NewServer(http.NotFoundHandler())
130130
defer srv.Close()
131131

132-
skills := loadRemoteSkillsWithCache(srv.URL, newDiskCache(t.TempDir()))
132+
skills := loadRemoteSkillsWithCache(t.Context(), srv.URL, newDiskCache(t.TempDir()))
133133
assert.Empty(t, skills)
134134
})
135135

@@ -139,12 +139,12 @@ func TestLoadRemoteSkills(t *testing.T) {
139139
}))
140140
defer srv.Close()
141141

142-
skills := loadRemoteSkillsWithCache(srv.URL, newDiskCache(t.TempDir()))
142+
skills := loadRemoteSkillsWithCache(t.Context(), srv.URL, newDiskCache(t.TempDir()))
143143
assert.Empty(t, skills)
144144
})
145145

146146
t.Run("unreachable server", func(t *testing.T) {
147-
skills := loadRemoteSkillsWithCache("http://127.0.0.1:1", newDiskCache(t.TempDir()))
147+
skills := loadRemoteSkillsWithCache(t.Context(), "http://127.0.0.1:1", newDiskCache(t.TempDir()))
148148
assert.Empty(t, skills)
149149
})
150150

@@ -168,12 +168,12 @@ func TestLoadRemoteSkills(t *testing.T) {
168168
cache := newDiskCache(t.TempDir())
169169

170170
// First load
171-
skills1 := loadRemoteSkillsWithCache(srv.URL, cache)
171+
skills1 := loadRemoteSkillsWithCache(t.Context(), srv.URL, cache)
172172
require.Len(t, skills1, 1)
173173
assert.Equal(t, 2, fetchCount) // index.json + SKILL.md
174174

175175
// Second load — SKILL.md should be cached
176-
skills2 := loadRemoteSkillsWithCache(srv.URL, cache)
176+
skills2 := loadRemoteSkillsWithCache(t.Context(), srv.URL, cache)
177177
require.Len(t, skills2, 1)
178178
assert.Equal(t, 3, fetchCount) // only index.json re-fetched, SKILL.md from cache
179179
})
@@ -193,7 +193,7 @@ func TestLoadRemoteSkills(t *testing.T) {
193193
defer srv.Close()
194194

195195
cache := newDiskCache(t.TempDir())
196-
skills := loadRemoteSkillsWithCache(srv.URL, cache)
196+
skills := loadRemoteSkillsWithCache(t.Context(), srv.URL, cache)
197197
require.Len(t, skills, 1)
198198
// Only SKILL.md should have been fetched, not the malicious paths
199199
})

pkg/tools/builtin/api.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414

1515
"github.com/docker/docker-agent/pkg/config/latest"
1616
"github.com/docker/docker-agent/pkg/js"
17+
"github.com/docker/docker-agent/pkg/remote"
1718
"github.com/docker/docker-agent/pkg/tools"
1819
)
1920

@@ -30,7 +31,8 @@ var (
3031

3132
func (t *APITool) callTool(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) {
3233
client := &http.Client{
33-
Timeout: 30 * time.Second,
34+
Timeout: 30 * time.Second,
35+
Transport: remote.NewTransport(ctx),
3436
}
3537

3638
endpoint := t.config.Endpoint

pkg/tools/builtin/fetch.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/k3a/html2text"
1515
"github.com/temoto/robotstxt"
1616

17+
"github.com/docker/docker-agent/pkg/remote"
1718
"github.com/docker/docker-agent/pkg/tools"
1819
"github.com/docker/docker-agent/pkg/useragent"
1920
)
@@ -49,7 +50,8 @@ func (h *fetchHandler) CallTool(ctx context.Context, params FetchToolArgs) (*too
4950

5051
// Set timeout if specified
5152
client := &http.Client{
52-
Timeout: h.timeout,
53+
Timeout: h.timeout,
54+
Transport: remote.NewTransport(ctx),
5355
}
5456
if params.Timeout > 0 {
5557
client.Timeout = time.Duration(params.Timeout) * time.Second

pkg/tools/builtin/openapi.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515

1616
"github.com/getkin/kin-openapi/openapi3"
1717

18+
"github.com/docker/docker-agent/pkg/remote"
1819
"github.com/docker/docker-agent/pkg/tools"
1920
"github.com/docker/docker-agent/pkg/upstream"
2021
"github.com/docker/docker-agent/pkg/useragent"
@@ -70,7 +71,7 @@ func (t *OpenAPITool) fetchSpec(ctx context.Context) (*openapi3.T, error) {
7071
req.Header.Set("Accept", "application/json")
7172
setHeaders(req, t.headers)
7273

73-
resp, err := (&http.Client{Timeout: httpTimeout}).Do(req)
74+
resp, err := (&http.Client{Timeout: httpTimeout, Transport: remote.NewTransport(ctx)}).Do(req)
7475
if err != nil {
7576
return nil, fmt.Errorf("request failed: %w", err)
7677
}
@@ -398,7 +399,7 @@ func (h *openAPIHandler) callTool(ctx context.Context, params openAPICallArgs) (
398399
req.Header.Set("Accept", "application/json")
399400
setHeaders(req, h.headers)
400401

401-
resp, err := (&http.Client{Timeout: httpTimeout}).Do(req)
402+
resp, err := (&http.Client{Timeout: httpTimeout, Transport: remote.NewTransport(ctx)}).Do(req)
402403
if err != nil {
403404
return nil, fmt.Errorf("request failed: %w", err)
404405
}

0 commit comments

Comments
 (0)