Skip to content

Commit 491326d

Browse files
committed
feat(unpack): improve runtime config handling and streamline layer unpacking logic
1 parent f383f26 commit 491326d

3 files changed

Lines changed: 29 additions & 48 deletions

File tree

pkg/distribution/builder/from_directory.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import (
1414
"github.com/docker/model-runner/pkg/distribution/types"
1515
)
1616

17+
const rootFSType = "rootfs"
18+
1719
// DirectoryOptions configures the behavior of FromDirectory.
1820
type DirectoryOptions struct {
1921
// Exclusions is a list of patterns to exclude from packaging.
@@ -130,7 +132,7 @@ func FromDirectory(dirPath string, opts ...DirectoryOption) (*Builder, error) {
130132
fileType := files.Classify(path)
131133
mediaType := fileTypeToMediaType(fileType)
132134

133-
// Track format from weight files
135+
// Track format from weight files (only weight file types affect format detection)
134136
switch fileType {
135137
case files.FileTypeSafetensors:
136138
if detectedFormat == "" {
@@ -147,10 +149,8 @@ func FromDirectory(dirPath string, opts ...DirectoryOption) (*Builder, error) {
147149
detectedFormat = types.FormatDiffusers
148150
}
149151
weightFiles = append(weightFiles, path)
150-
case files.FileTypeUnknown:
151-
case files.FileTypeConfig:
152-
case files.FileTypeLicense:
153-
case files.FileTypeChatTemplate:
152+
case files.FileTypeUnknown, files.FileTypeConfig, files.FileTypeLicense, files.FileTypeChatTemplate:
153+
// Non-weight files don't affect format detection
154154
}
155155

156156
// Create layer with relative path annotation
@@ -182,7 +182,7 @@ func FromDirectory(dirPath string, opts ...DirectoryOption) (*Builder, error) {
182182
return nil, fmt.Errorf("no weight files (safetensors, GGUF, or DDUF) found in directory: %s", dirPath)
183183
}
184184

185-
// Build config - use the first weight file for metadata extraction
185+
// Build config
186186
config := types.Config{
187187
Format: detectedFormat,
188188
}
@@ -199,12 +199,12 @@ func FromDirectory(dirPath string, opts ...DirectoryOption) (*Builder, error) {
199199
Created: &created,
200200
},
201201
RootFS: oci.RootFS{
202-
Type: "rootfs",
202+
Type: rootFSType,
203203
DiffIDs: diffIDs,
204204
},
205205
},
206206
LayerList: layers,
207-
ConfigMediaType: types.MediaTypeModelConfigV02, // V0.2: layer-per-file with filepath annotations
207+
ConfigMediaType: types.MediaTypeModelConfigV02,
208208
}
209209

210210
return &Builder{
@@ -290,10 +290,11 @@ func fileTypeToMediaType(ft files.FileType) oci.MediaType {
290290
return types.MediaTypeLicense
291291
case files.FileTypeChatTemplate:
292292
return types.MediaTypeChatTemplate
293+
case files.FileTypeUnknown:
294+
return types.MediaTypeModelFile
293295
case files.FileTypeConfig:
294296
return types.MediaTypeModelFile
295297
default:
296-
// For unknown files, use the generic model file type
297298
return types.MediaTypeModelFile
298299
}
299300
}

pkg/distribution/internal/bundle/unpack.go

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,13 @@ func hasLayerWithMediaType(model types.Model, targetMediaType oci.MediaType) boo
175175
}
176176
}
177177

178-
func unpackRuntimeConfig(bundle *Bundle, mdl types.Model) error {
178+
// configProvider is an interface for types that provide a Config() method.
179+
// Both types.Model and types.ModelArtifact satisfy this interface.
180+
type configProvider interface {
181+
Config() (types.ModelConfig, error)
182+
}
183+
184+
func unpackRuntimeConfig(bundle *Bundle, mdl configProvider) error {
179185
cfg, err := mdl.Config()
180186
if err != nil {
181187
return err
@@ -317,11 +323,8 @@ func unpackSafetensorsWithAnnotations(bundle *Bundle, modelDir string, safetenso
317323
digestHex := filepath.Base(srcPath)
318324

319325
var destRelPath string
320-
if annotatedPath, ok := blobToFilepath[digestHex]; ok && strings.Contains(annotatedPath, "/") {
321-
// Use the annotated path (contains subdirectory)
322-
destRelPath = annotatedPath
323-
} else if annotatedPath, ok := blobToFilepath[digestHex]; ok {
324-
// Annotation exists but is just a filename - use it
326+
if annotatedPath, ok := blobToFilepath[digestHex]; ok {
327+
// Use the annotated path
325328
destRelPath = annotatedPath
326329
} else {
327330
// No annotation found - use legacy naming
@@ -422,7 +425,6 @@ func unpackDirTarArchives(bundle *Bundle, mdl types.Model) error {
422425
for _, layer := range layers {
423426
mediaType, err := layer.MediaType()
424427
if err != nil {
425-
fmt.Printf("Warning: failed to get media type for layer: %v", err)
426428
continue
427429
}
428430

@@ -641,21 +643,18 @@ func UnpackFromLayers(dir string, model types.ModelArtifact) (*Bundle, error) {
641643
for _, layer := range layers {
642644
mediaType, err := layer.MediaType()
643645
if err != nil {
644-
fmt.Printf("Warning: error getting media type: %v\n", err)
645646
continue
646647
}
647648

648649
// Get the filepath annotation
649650
dp, ok := layer.(descriptorProvider)
650651
if !ok {
651-
fmt.Printf("Warning: layer is not a descriptorProvider\n")
652652
continue
653653
}
654654

655655
desc := dp.GetDescriptor()
656656
relPath, exists := desc.Annotations[types.AnnotationFilePath]
657657
if !exists || relPath == "" {
658-
fmt.Printf("Warning: layer missing filepath annotation\n")
659658
continue
660659
}
661660

@@ -687,28 +686,16 @@ func UnpackFromLayers(dir string, model types.ModelArtifact) (*Bundle, error) {
687686
updateBundleFieldsFromLayer(bundle, mediaType, relPath)
688687
}
689688

690-
// Create the runtime config from the model
691-
cfg, err := model.Config()
692-
if err != nil {
693-
return nil, fmt.Errorf("get model config: %w", err)
694-
}
695-
696-
// Write runtime config to bundle root
697-
f, err := os.Create(filepath.Join(bundle.dir, "config.json"))
698-
if err != nil {
699-
return nil, fmt.Errorf("create runtime config file: %w", err)
700-
}
701-
defer f.Close()
702-
if err := json.NewEncoder(f).Encode(cfg); err != nil {
703-
return nil, fmt.Errorf("encode runtime config: %w", err)
689+
// Create the runtime config
690+
if err := unpackRuntimeConfig(bundle, model); err != nil {
691+
return nil, fmt.Errorf("add config.json to runtime bundle: %w", err)
704692
}
705-
bundle.runtimeConfig = cfg
706693

707694
return bundle, nil
708695
}
709696

710-
// unpackLayerToFile unpacks a single layer to the destination path.
711-
// It tries to use hard linking for local layers, falling back to copying for remote layers.
697+
// unpackLayerToFile unpacks a single layer to the destination path using hard linking.
698+
// It requires the layer to be a local layer with a file path (pathProvider interface).
712699
func unpackLayerToFile(destPath string, layer oci.Layer) error {
713700
// Try to get the layer's local path for hard linking
714701
type pathProvider interface {
@@ -724,6 +711,7 @@ func unpackLayerToFile(destPath string, layer oci.Layer) error {
724711

725712
// updateBundleFieldsFromLayer updates the bundle tracking fields based on the unpacked layer.
726713
func updateBundleFieldsFromLayer(bundle *Bundle, mediaType oci.MediaType, relPath string) {
714+
//nolint:exhaustive // only tracking specific model-related media types
727715
switch mediaType {
728716
case types.MediaTypeGGUF:
729717
if bundle.ggufFile == "" {

pkg/distribution/internal/bundle/unpack_test.go

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,7 @@ import (
88

99
func TestValidatePathWithinDirectory(t *testing.T) {
1010
// Create a temporary directory for testing
11-
baseDir, err := os.MkdirTemp("", "unpack-test-*")
12-
if err != nil {
13-
t.Fatalf("Failed to create temp dir: %v", err)
14-
}
15-
defer os.RemoveAll(baseDir)
11+
baseDir := t.TempDir()
1612

1713
tests := []struct {
1814
name string
@@ -98,7 +94,7 @@ func TestValidatePathWithinDirectory(t *testing.T) {
9894

9995
// Tricky paths that might bypass naive checks
10096
{
101-
name: "dot dot in middle",
97+
name: ".. in middle",
10298
targetPath: "foo/../bar/model.safetensors",
10399
expectError: false,
104100
description: "Path with .. that stays within directory should be valid",
@@ -127,11 +123,7 @@ func TestValidatePathWithinDirectory(t *testing.T) {
127123

128124
func TestValidatePathWithinDirectory_RealFilesystem(t *testing.T) {
129125
// Create a temporary directory structure
130-
baseDir, err := os.MkdirTemp("", "unpack-realfs-test-*")
131-
if err != nil {
132-
t.Fatalf("Failed to create temp dir: %v", err)
133-
}
134-
defer os.RemoveAll(baseDir)
126+
baseDir := t.TempDir()
135127

136128
// Create a sibling directory that attacker might try to access
137129
siblingDir := filepath.Join(filepath.Dir(baseDir), "sibling-secret")
@@ -148,7 +140,7 @@ func TestValidatePathWithinDirectory_RealFilesystem(t *testing.T) {
148140

149141
// Try to escape to the sibling directory
150142
escapePath := "../sibling-secret/secret.txt"
151-
err = validatePathWithinDirectory(baseDir, escapePath)
143+
err := validatePathWithinDirectory(baseDir, escapePath)
152144
if err == nil {
153145
t.Errorf("Expected error when attempting to escape to sibling directory, but validation passed")
154146
}

0 commit comments

Comments
 (0)