@@ -6,14 +6,22 @@ import (
66 "errors"
77 "io"
88 "net/http"
9+ "strings"
910 "testing"
1011
1112 mockdesktop "github.com/docker/model-runner/cmd/cli/mocks"
13+ "github.com/docker/model-runner/pkg/distribution/distribution"
1214 "github.com/stretchr/testify/assert"
1315 "github.com/stretchr/testify/require"
1416 "go.uber.org/mock/gomock"
1517)
1618
19+ // errorReadCloser is an io.ReadCloser whose Read always returns an error.
20+ type errorReadCloser struct { err error }
21+
22+ func (e * errorReadCloser ) Read (_ []byte ) (int , error ) { return 0 , e .err }
23+ func (e * errorReadCloser ) Close () error { return nil }
24+
1725func TestPullRetryOnNetworkError (t * testing.T ) {
1826 ctrl := gomock .NewController (t )
1927 defer ctrl .Finish ()
@@ -59,7 +67,7 @@ func TestPullNoRetryOn4xxError(t *testing.T) {
5967 assert .Contains (t , err .Error (), "Model not found" )
6068}
6169
62- func TestPullRetryOn5xxError (t * testing.T ) {
70+ func TestPullNoRetryOn500Error (t * testing.T ) {
6371 ctrl := gomock .NewController (t )
6472 defer ctrl .Finish ()
6573
@@ -68,21 +76,83 @@ func TestPullRetryOn5xxError(t *testing.T) {
6876 mockContext := NewContextForMock (mockClient )
6977 client := New (mockContext )
7078
71- // First attempt fails with 500, second succeeds
72- gomock .InOrder (
73- mockClient .EXPECT ().Do (gomock .Any ()).Return (& http.Response {
74- StatusCode : http .StatusInternalServerError ,
75- Body : io .NopCloser (bytes .NewBufferString ("Internal server error" )),
76- }, nil ),
77- mockClient .EXPECT ().Do (gomock .Any ()).Return (& http.Response {
78- StatusCode : http .StatusOK ,
79- Body : io .NopCloser (bytes .NewBufferString (`{"type":"success","message":"Model pulled successfully"}` )),
80- }, nil ),
81- )
79+ // 500 is not retried (deterministic server error), so only 1 call.
80+ mockClient .EXPECT ().Do (gomock .Any ()).Return (& http.Response {
81+ StatusCode : http .StatusInternalServerError ,
82+ Body : io .NopCloser (bytes .NewBufferString ("Internal server error" )),
83+ }, nil ).Times (1 )
8284
8385 printer := NewSimplePrinter (func (s string ) {})
8486 _ , _ , err := client .Pull (modelName , printer )
85- assert .NoError (t , err )
87+ assert .Error (t , err )
88+ assert .Contains (t , err .Error (), "Internal server error" )
89+ }
90+
91+ func TestPullNoRetryOn422Error (t * testing.T ) {
92+ ctrl := gomock .NewController (t )
93+ defer ctrl .Finish ()
94+
95+ modelName := "test-model"
96+ mockClient := mockdesktop .NewMockDockerHttpClient (ctrl )
97+ mockContext := NewContextForMock (mockClient )
98+ client := New (mockContext )
99+
100+ // 422 (unsupported media type) must not be retried.
101+ unsupportedMsg := `error while pulling model: config type "v0.3" is not supported` +
102+ ` - try upgrading`
103+ mockClient .EXPECT ().Do (gomock .Any ()).Return (& http.Response {
104+ StatusCode : http .StatusUnprocessableEntity ,
105+ Body : io .NopCloser (bytes .NewBufferString (unsupportedMsg )),
106+ }, nil ).Times (1 )
107+
108+ printer := NewSimplePrinter (func (s string ) {})
109+ _ , _ , err := client .Pull (modelName , printer )
110+ require .Error (t , err )
111+ // The sentinel must be preserved so callers can use errors.Is.
112+ assert .True (t , errors .Is (err , distribution .ErrUnsupportedMediaType ))
113+ }
114+
115+ func TestPullRetriesOnTransientGatewayErrors (t * testing.T ) {
116+ // 502 and 504 are transient gateway/proxy errors and should be retried.
117+ // Note: 503 is intercepted by doRequestWithAuthContext as ErrServiceUnavailable
118+ // and is covered separately by TestPullRetryOnServiceUnavailable.
119+ transientCodes := []struct {
120+ code int
121+ name string
122+ body string
123+ }{
124+ {http .StatusBadGateway , "502 Bad Gateway" , "Bad Gateway" },
125+ {http .StatusGatewayTimeout , "504 Gateway Timeout" , "Gateway Timeout" },
126+ }
127+
128+ for _ , tc := range transientCodes {
129+ t .Run (tc .name , func (t * testing.T ) {
130+ ctrl := gomock .NewController (t )
131+ defer ctrl .Finish ()
132+
133+ mockClient := mockdesktop .NewMockDockerHttpClient (ctrl )
134+ mockContext := NewContextForMock (mockClient )
135+ client := New (mockContext )
136+
137+ // First attempt fails with the transient error, second succeeds.
138+ gomock .InOrder (
139+ mockClient .EXPECT ().Do (gomock .Any ()).Return (& http.Response {
140+ StatusCode : tc .code ,
141+ Body : io .NopCloser (bytes .NewBufferString (tc .body )),
142+ }, nil ),
143+ mockClient .EXPECT ().Do (gomock .Any ()).Return (& http.Response {
144+ StatusCode : http .StatusOK ,
145+ Body : io .NopCloser (bytes .NewBufferString (
146+ `{"type":"success","message":"Model pulled successfully"}` ,
147+ )),
148+ }, nil ),
149+ )
150+
151+ printer := NewSimplePrinter (func (s string ) {})
152+ _ , _ , err := client .Pull ("test-model" , printer )
153+ assert .NoError (t , err )
154+ })
155+ }
86156}
87157
88158func TestPullRetryOnServiceUnavailable (t * testing.T ) {
@@ -127,7 +197,7 @@ func TestPullMaxRetriesExhausted(t *testing.T) {
127197 printer := NewSimplePrinter (func (s string ) {})
128198 _ , _ , err := client .Pull (modelName , printer )
129199 assert .Error (t , err )
130- assert .Contains (t , err .Error (), "failed to download after 3 retries" )
200+ assert .Contains (t , err .Error (), "download failed after 3 retries" )
131201}
132202
133203func TestPushRetryOnNetworkError (t * testing.T ) {
@@ -341,3 +411,45 @@ func TestIsTemplateIncompatibleError(t *testing.T) {
341411 })
342412 }
343413}
414+
415+ func TestPullBodyReadFailure (t * testing.T ) {
416+ ctrl := gomock .NewController (t )
417+ defer ctrl .Finish ()
418+
419+ mockClient := mockdesktop .NewMockDockerHttpClient (ctrl )
420+ mockContext := NewContextForMock (mockClient )
421+ client := New (mockContext )
422+
423+ // Response body read fails. Use a non-retryable 500 status so the test
424+ // completes in a single attempt.
425+ mockClient .EXPECT ().Do (gomock .Any ()).Return (& http.Response {
426+ StatusCode : http .StatusInternalServerError ,
427+ Body : & errorReadCloser {err : errors .New ("connection reset" )},
428+ }, nil ).Times (1 )
429+
430+ printer := NewSimplePrinter (func (s string ) {})
431+ _ , _ , err := client .Pull ("test-model" , printer )
432+ require .Error (t , err )
433+ assert .Contains (t , err .Error (), "failed to read response body" )
434+ }
435+
436+ func TestDisplayProgressNonJSONLines (t * testing.T ) {
437+ // Simulate a proxy returning an HTML error page instead of a progress stream.
438+ htmlBody := "<html><body><h1>502 Bad Gateway</h1></body></html>\n "
439+ printer := NewSimplePrinter (func (string ) {})
440+ _ , _ , err := DisplayProgress (strings .NewReader (htmlBody ), printer )
441+ require .Error (t , err )
442+ assert .Contains (t , err .Error (), "unexpected response from server" )
443+ assert .Contains (t , err .Error (), "502 Bad Gateway" )
444+ }
445+
446+ func TestDisplayProgressMixedContent (t * testing.T ) {
447+ // Valid progress followed by some unparseable lines: the valid progress
448+ // should be honoured and no error returned for the stray lines.
449+ body := `{"type":"success","message":"Model pulled successfully"}` + "\n " +
450+ "<html>some extra garbage</html>\n "
451+ printer := NewSimplePrinter (func (string ) {})
452+ msg , _ , err := DisplayProgress (strings .NewReader (body ), printer )
453+ require .NoError (t , err )
454+ assert .Equal (t , "Model pulled successfully" , msg )
455+ }
0 commit comments