Skip to content

Commit a27c84f

Browse files
authored
fix: improve error handling and retry logic in pull operations (#795)
* fix: improve error handling and retry logic in pull operations
* fix: refactor error handling for unexpected non-JSON progress data
* fix: enhance error reporting for pull and push operations by improving response body handling
* fix: refine retry logic for push operations to handle specific gateway errors
* fix: enhance error handling for unsupported media types in pull operations
1 parent 8996da2 commit a27c84f

7 files changed

Lines changed: 270 additions & 45 deletions

File tree

cmd/cli/commands/utils.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/docker/model-runner/cmd/cli/desktop"
1212
"github.com/docker/model-runner/cmd/cli/pkg/standalone"
13+
"github.com/docker/model-runner/pkg/distribution/distribution"
1314
"github.com/docker/model-runner/pkg/distribution/oci/reference"
1415
"github.com/docker/model-runner/pkg/inference/backends/vllm"
1516
"github.com/moby/term"
@@ -53,6 +54,13 @@ func handleClientError(err error, message string) error {
5354
var buf bytes.Buffer
5455
printNextSteps(&buf, []string{enableVLLM})
5556
return fmt.Errorf("%w\n%s", err, strings.TrimRight(buf.String(), "\n"))
57+
} else if errors.Is(err, distribution.ErrUnsupportedMediaType) {
58+
// The model uses a newer config format than this client supports.
59+
var buf bytes.Buffer
60+
printNextSteps(&buf, []string{
61+
"Upgrade Docker Model Runner to the latest version to support this model",
62+
})
63+
return fmt.Errorf("%s: %w\n%s", message, err, strings.TrimRight(buf.String(), "\n"))
5664
}
5765
return fmt.Errorf("%s: %w", message, err)
5866
}

cmd/cli/desktop/desktop.go

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -144,10 +144,29 @@ func (c *Client) Pull(model string, printer standalone.StatusPrinter) (string, b
144144
defer resp.Body.Close()
145145

146146
if resp.StatusCode != http.StatusOK {
147-
body, _ := io.ReadAll(resp.Body)
148-
err := fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, string(body))
149-
// Only retry on server errors (5xx), not client errors (4xx)
150-
shouldRetry := resp.StatusCode >= 500 && resp.StatusCode < 600
147+
body, readErr := io.ReadAll(resp.Body)
148+
var bodyStr string
149+
if readErr != nil {
150+
bodyStr = fmt.Sprintf("failed to read response body: %v", readErr)
151+
} else {
152+
bodyStr = strings.TrimSpace(string(body))
153+
}
154+
var err error
155+
if resp.StatusCode == http.StatusUnprocessableEntity {
156+
// 422 means the model uses a config type this client does not
157+
// support. Reattach the sentinel so callers can use errors.Is.
158+
err = fmt.Errorf("pulling %s failed with status %s: %w: %s",
159+
model, resp.Status, distribution.ErrUnsupportedMediaType, bodyStr)
160+
} else {
161+
err = fmt.Errorf("pulling %s failed with status %s: %s",
162+
model, resp.Status, bodyStr)
163+
}
164+
// Only retry on gateway/proxy errors (502, 503, 504).
165+
// Do not retry 500 (usually deterministic server errors) or
166+
// 4xx (client errors including 422 for unsupported media type).
167+
shouldRetry := resp.StatusCode == http.StatusBadGateway ||
168+
resp.StatusCode == http.StatusServiceUnavailable ||
169+
resp.StatusCode == http.StatusGatewayTimeout
151170
return "", false, err, shouldRetry
152171
}
153172

@@ -235,7 +254,7 @@ func (c *Client) withRetries(
235254
}
236255
}
237256

238-
return "", progressShown, fmt.Errorf("failed to %s after %d retries: %w", operationName, maxRetries, lastErr)
257+
return "", progressShown, fmt.Errorf("%s failed after %d retries: %w", operationName, maxRetries, lastErr)
239258
}
240259

241260
func (c *Client) Push(model string, printer standalone.StatusPrinter) (string, bool, error) {
@@ -272,10 +291,19 @@ func (c *Client) Push(model string, printer standalone.StatusPrinter) (string, b
272291
defer resp.Body.Close()
273292

274293
if resp.StatusCode != http.StatusOK {
275-
body, _ := io.ReadAll(resp.Body)
276-
err := fmt.Errorf("pushing %s failed with status %s: %s", model, resp.Status, string(body))
277-
// Only retry on server errors (5xx), not client errors (4xx)
278-
shouldRetry := resp.StatusCode >= 500 && resp.StatusCode < 600
294+
body, readErr := io.ReadAll(resp.Body)
295+
var bodyStr string
296+
if readErr != nil {
297+
bodyStr = fmt.Sprintf("(failed to read response body: %v)", readErr)
298+
} else {
299+
bodyStr = strings.TrimSpace(string(body))
300+
}
301+
err := fmt.Errorf("pushing %s failed with status %s: %s", model, resp.Status, bodyStr)
302+
// Only retry on gateway/proxy errors. Do not retry plain 500
303+
// (usually deterministic server errors) or 4xx (client errors).
304+
shouldRetry := resp.StatusCode == http.StatusBadGateway ||
305+
resp.StatusCode == http.StatusServiceUnavailable ||
306+
resp.StatusCode == http.StatusGatewayTimeout
279307
return "", false, err, shouldRetry
280308
}
281309

cmd/cli/desktop/desktop_test.go

Lines changed: 126 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,22 @@ import (
66
"errors"
77
"io"
88
"net/http"
9+
"strings"
910
"testing"
1011

1112
mockdesktop "github.com/docker/model-runner/cmd/cli/mocks"
13+
"github.com/docker/model-runner/pkg/distribution/distribution"
1214
"github.com/stretchr/testify/assert"
1315
"github.com/stretchr/testify/require"
1416
"go.uber.org/mock/gomock"
1517
)
1618

19+
// errorReadCloser is an io.ReadCloser whose Read always returns an error.
20+
type errorReadCloser struct{ err error }
21+
22+
func (e *errorReadCloser) Read(_ []byte) (int, error) { return 0, e.err }
23+
func (e *errorReadCloser) Close() error { return nil }
24+
1725
func TestPullRetryOnNetworkError(t *testing.T) {
1826
ctrl := gomock.NewController(t)
1927
defer ctrl.Finish()
@@ -59,7 +67,7 @@ func TestPullNoRetryOn4xxError(t *testing.T) {
5967
assert.Contains(t, err.Error(), "Model not found")
6068
}
6169

62-
func TestPullRetryOn5xxError(t *testing.T) {
70+
func TestPullNoRetryOn500Error(t *testing.T) {
6371
ctrl := gomock.NewController(t)
6472
defer ctrl.Finish()
6573

@@ -68,21 +76,83 @@ func TestPullRetryOn5xxError(t *testing.T) {
6876
mockContext := NewContextForMock(mockClient)
6977
client := New(mockContext)
7078

71-
// First attempt fails with 500, second succeeds
72-
gomock.InOrder(
73-
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
74-
StatusCode: http.StatusInternalServerError,
75-
Body: io.NopCloser(bytes.NewBufferString("Internal server error")),
76-
}, nil),
77-
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
78-
StatusCode: http.StatusOK,
79-
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
80-
}, nil),
81-
)
79+
// 500 is not retried (deterministic server error), so only 1 call.
80+
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
81+
StatusCode: http.StatusInternalServerError,
82+
Body: io.NopCloser(bytes.NewBufferString("Internal server error")),
83+
}, nil).Times(1)
8284

8385
printer := NewSimplePrinter(func(s string) {})
8486
_, _, err := client.Pull(modelName, printer)
85-
assert.NoError(t, err)
87+
assert.Error(t, err)
88+
assert.Contains(t, err.Error(), "Internal server error")
89+
}
90+
91+
func TestPullNoRetryOn422Error(t *testing.T) {
92+
ctrl := gomock.NewController(t)
93+
defer ctrl.Finish()
94+
95+
modelName := "test-model"
96+
mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
97+
mockContext := NewContextForMock(mockClient)
98+
client := New(mockContext)
99+
100+
// 422 (unsupported media type) must not be retried.
101+
unsupportedMsg := `error while pulling model: config type "v0.3" is not supported` +
102+
` - try upgrading`
103+
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
104+
StatusCode: http.StatusUnprocessableEntity,
105+
Body: io.NopCloser(bytes.NewBufferString(unsupportedMsg)),
106+
}, nil).Times(1)
107+
108+
printer := NewSimplePrinter(func(s string) {})
109+
_, _, err := client.Pull(modelName, printer)
110+
require.Error(t, err)
111+
// The sentinel must be preserved so callers can use errors.Is.
112+
assert.True(t, errors.Is(err, distribution.ErrUnsupportedMediaType))
113+
}
114+
115+
func TestPullRetriesOnTransientGatewayErrors(t *testing.T) {
116+
// 502 and 504 are transient gateway/proxy errors and should be retried.
117+
// Note: 503 is intercepted by doRequestWithAuthContext as ErrServiceUnavailable
118+
// and is covered separately by TestPullRetryOnServiceUnavailable.
119+
transientCodes := []struct {
120+
code int
121+
name string
122+
body string
123+
}{
124+
{http.StatusBadGateway, "502 Bad Gateway", "Bad Gateway"},
125+
{http.StatusGatewayTimeout, "504 Gateway Timeout", "Gateway Timeout"},
126+
}
127+
128+
for _, tc := range transientCodes {
129+
t.Run(tc.name, func(t *testing.T) {
130+
ctrl := gomock.NewController(t)
131+
defer ctrl.Finish()
132+
133+
mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
134+
mockContext := NewContextForMock(mockClient)
135+
client := New(mockContext)
136+
137+
// First attempt fails with the transient error, second succeeds.
138+
gomock.InOrder(
139+
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
140+
StatusCode: tc.code,
141+
Body: io.NopCloser(bytes.NewBufferString(tc.body)),
142+
}, nil),
143+
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
144+
StatusCode: http.StatusOK,
145+
Body: io.NopCloser(bytes.NewBufferString(
146+
`{"type":"success","message":"Model pulled successfully"}`,
147+
)),
148+
}, nil),
149+
)
150+
151+
printer := NewSimplePrinter(func(s string) {})
152+
_, _, err := client.Pull("test-model", printer)
153+
assert.NoError(t, err)
154+
})
155+
}
86156
}
87157

88158
func TestPullRetryOnServiceUnavailable(t *testing.T) {
@@ -127,7 +197,7 @@ func TestPullMaxRetriesExhausted(t *testing.T) {
127197
printer := NewSimplePrinter(func(s string) {})
128198
_, _, err := client.Pull(modelName, printer)
129199
assert.Error(t, err)
130-
assert.Contains(t, err.Error(), "failed to download after 3 retries")
200+
assert.Contains(t, err.Error(), "download failed after 3 retries")
131201
}
132202

133203
func TestPushRetryOnNetworkError(t *testing.T) {
@@ -341,3 +411,45 @@ func TestIsTemplateIncompatibleError(t *testing.T) {
341411
})
342412
}
343413
}
414+
415+
func TestPullBodyReadFailure(t *testing.T) {
416+
ctrl := gomock.NewController(t)
417+
defer ctrl.Finish()
418+
419+
mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
420+
mockContext := NewContextForMock(mockClient)
421+
client := New(mockContext)
422+
423+
// Response body read fails. Use a non-retryable 500 status so the test
424+
// completes in a single attempt.
425+
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
426+
StatusCode: http.StatusInternalServerError,
427+
Body: &errorReadCloser{err: errors.New("connection reset")},
428+
}, nil).Times(1)
429+
430+
printer := NewSimplePrinter(func(s string) {})
431+
_, _, err := client.Pull("test-model", printer)
432+
require.Error(t, err)
433+
assert.Contains(t, err.Error(), "failed to read response body")
434+
}
435+
436+
func TestDisplayProgressNonJSONLines(t *testing.T) {
437+
// Simulate a proxy returning an HTML error page instead of a progress stream.
438+
htmlBody := "<html><body><h1>502 Bad Gateway</h1></body></html>\n"
439+
printer := NewSimplePrinter(func(string) {})
440+
_, _, err := DisplayProgress(strings.NewReader(htmlBody), printer)
441+
require.Error(t, err)
442+
assert.Contains(t, err.Error(), "unexpected response from server")
443+
assert.Contains(t, err.Error(), "502 Bad Gateway")
444+
}
445+
446+
func TestDisplayProgressMixedContent(t *testing.T) {
447+
// Valid progress followed by some unparseable lines: the valid progress
448+
// should be honoured and no error returned for the stray lines.
449+
body := `{"type":"success","message":"Model pulled successfully"}` + "\n" +
450+
"<html>some extra garbage</html>\n"
451+
printer := NewSimplePrinter(func(string) {})
452+
msg, _, err := DisplayProgress(strings.NewReader(body), printer)
453+
require.NoError(t, err)
454+
assert.Equal(t, "Model pulled successfully", msg)
455+
}

0 commit comments

Comments
 (0)