Skip to content

Commit 3e7a791

Browse files
fix: check local store before pulling HuggingFace models (#617)
* fix: check local store before pulling HuggingFace models

  Fixes #616

  Previously, HuggingFace models were always downloaded from the Hub even if the same model already existed in the local store. This caused unnecessary bandwidth usage and slower pull times.

  This change adds a cache check before calling pullNativeHuggingFace(), similar to how OCI models are handled. If the model is found in the local store, it returns immediately without downloading.

  Also includes hf.co to huggingface.co URL normalization to ensure consistent cache lookups regardless of which URL format is used.

* refactor: apply Gemini code review suggestions

  - Fix error handling: only treat ErrModelNotFound as a cache miss; propagate other errors.
  - Consolidate test functions into a table-driven test for better maintainability.
1 parent 2e6fc88 commit 3e7a791

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

pkg/distribution/distribution/client.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,24 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter
249249
// HuggingFace references always use native pull (download raw files from HF Hub)
250250
if isHuggingFaceReference(originalReference) {
251251
c.log.Infoln("Using native HuggingFace pull for:", utils.SanitizeForLog(reference))
252+
253+
// Check if model already exists in local store (reference is already normalized)
254+
localModel, err := c.store.Read(reference)
255+
if err == nil {
256+
c.log.Infoln("HuggingFace model found in local store:", utils.SanitizeForLog(reference))
257+
cfg, err := localModel.Config()
258+
if err != nil {
259+
return fmt.Errorf("getting cached model config: %w", err)
260+
}
261+
if err := progress.WriteSuccess(progressWriter, fmt.Sprintf("Using cached model: %s", cfg.GetSize()), oci.ModePull); err != nil {
262+
c.log.Warnf("Writing progress: %v", err)
263+
}
264+
return nil
265+
}
266+
if !errors.Is(err, ErrModelNotFound) {
267+
return fmt.Errorf("checking for cached HuggingFace model: %w", err)
268+
}
269+
252270
// Pass original reference to preserve case-sensitivity for HuggingFace API
253271
return c.pullNativeHuggingFace(ctx, originalReference, progressWriter, token)
254272
}

pkg/distribution/distribution/client_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,3 +1140,56 @@ func randomFile(size int64) (string, error) {
11401140

11411141
return f.Name(), nil
11421142
}
1143+
1144+
func TestPullHuggingFaceModelFromCache(t *testing.T) {
1145+
testCases := []struct {
1146+
name string
1147+
pullRef string
1148+
}{
1149+
{
1150+
name: "full URL",
1151+
pullRef: "huggingface.co/testorg/testmodel:latest",
1152+
},
1153+
{
1154+
name: "short URL",
1155+
pullRef: "hf.co/testorg/testmodel:latest",
1156+
},
1157+
}
1158+
1159+
for _, tc := range testCases {
1160+
t.Run(tc.name, func(t *testing.T) {
1161+
tempDir := t.TempDir()
1162+
1163+
// Create client
1164+
client, err := newTestClient(tempDir)
1165+
if err != nil {
1166+
t.Fatalf("Failed to create client: %v", err)
1167+
}
1168+
1169+
// Create a test model and write it to the store with a normalized HuggingFace tag
1170+
model, err := gguf.NewModel(testGGUFFile)
1171+
if err != nil {
1172+
t.Fatalf("Failed to create model: %v", err)
1173+
}
1174+
1175+
// Store with normalized tag (huggingface.co)
1176+
hfTag := "huggingface.co/testorg/testmodel:latest"
1177+
if err := client.store.Write(model, []string{hfTag}, nil); err != nil {
1178+
t.Fatalf("Failed to write model to store: %v", err)
1179+
}
1180+
1181+
// Now try to pull using the test case's reference - it should use the cache
1182+
var progressBuffer bytes.Buffer
1183+
err = client.PullModel(t.Context(), tc.pullRef, &progressBuffer)
1184+
if err != nil {
1185+
t.Fatalf("Failed to pull model from cache: %v", err)
1186+
}
1187+
1188+
// Verify that progress shows it was cached
1189+
progressOutput := progressBuffer.String()
1190+
if !strings.Contains(progressOutput, "Using cached model") {
1191+
t.Errorf("Expected progress to indicate cached model, got: %s", progressOutput)
1192+
}
1193+
})
1194+
}
1195+
}

0 commit comments

Comments
 (0)