Skip to content

Commit 2a0a2ff

Browse files
authored
Merge pull request #791 from docker/modelpack-support
Modelpack support
2 parents e81c289 + 75a4f80 commit 2a0a2ff

19 files changed

Lines changed: 1065 additions & 810 deletions

File tree

pkg/distribution/builder/builder_test.go

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
"time"
1111

1212
"github.com/docker/model-runner/pkg/distribution/builder"
13-
"github.com/docker/model-runner/pkg/distribution/oci"
13+
"github.com/docker/model-runner/pkg/distribution/internal/testutil"
1414
"github.com/docker/model-runner/pkg/distribution/types"
1515
)
1616

@@ -398,8 +398,7 @@ func TestFromModelWithAdditionalLayers(t *testing.T) {
398398

399399
// TestFromModelErrorHandling tests that FromModel properly handles and surfaces errors from mdl.Layers()
400400
func TestFromModelErrorHandling(t *testing.T) {
401-
// Create a mock model that fails when Layers() is called
402-
mockModel := &mockFailingModel{}
401+
mockModel := testutil.WithLayersError(testutil.NewGGUFArtifact(t, filepath.Join("..", "assets", "dummy.gguf")), fmt.Errorf("simulated layers error"))
403402

404403
// Attempt to create a builder from the failing model
405404
_, err := builder.FromModel(mockModel)
@@ -424,12 +423,3 @@ func (ft *fakeTarget) Write(ctx context.Context, artifact types.ModelArtifact, w
424423
ft.artifact = artifact
425424
return nil
426425
}
427-
428-
// mockFailingModel is a mock that fails when Layers() is called
429-
type mockFailingModel struct {
430-
types.ModelArtifact
431-
}
432-
433-
func (m *mockFailingModel) Layers() ([]oci.Layer, error) {
434-
return nil, fmt.Errorf("simulated layers error")
435-
}

pkg/distribution/distribution/bundle_test.go

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@ import (
66
"path/filepath"
77
"testing"
88

9-
"github.com/docker/model-runner/pkg/distribution/builder"
10-
"github.com/docker/model-runner/pkg/distribution/internal/mutate"
11-
"github.com/docker/model-runner/pkg/distribution/internal/partial"
9+
"github.com/docker/model-runner/pkg/distribution/internal/testutil"
1210
"github.com/docker/model-runner/pkg/distribution/types"
1311
)
1412

@@ -22,12 +20,7 @@ func TestBundle(t *testing.T) {
2220
t.Fatalf("Failed to create client: %v", err)
2321
}
2422

25-
// Load dummy model from assets directory
26-
b, err := builder.FromPath(filepath.Join("..", "assets", "dummy.gguf"))
27-
if err != nil {
28-
t.Fatalf("Failed to create model: %v", err)
29-
}
30-
mdl := b.Model()
23+
mdl := testutil.NewGGUFArtifact(t, filepath.Join("..", "assets", "dummy.gguf"))
3124
singleGGUFID, err := mdl.ID()
3225
if err != nil {
3326
t.Fatalf("Failed to get model ID: %v", err)
@@ -36,12 +29,11 @@ func TestBundle(t *testing.T) {
3629
t.Fatalf("Failed to write model to store: %v", err)
3730
}
3831

39-
// Load model with multi-modal projector file
40-
mmprojLayer, err := partial.NewLayer(filepath.Join("..", "assets", "dummy.mmproj"), types.MediaTypeMultimodalProjector)
41-
if err != nil {
42-
t.Fatalf("Failed to create mmproj layer: %v", err)
43-
}
44-
mmprojMdl := mutate.AppendLayers(mdl, mmprojLayer)
32+
mmprojMdl := testutil.NewGGUFArtifact(
33+
t,
34+
filepath.Join("..", "assets", "dummy.gguf"),
35+
testutil.Layer(filepath.Join("..", "assets", "dummy.mmproj"), types.MediaTypeMultimodalProjector),
36+
)
4537
mmprojMdlID, err := mmprojMdl.ID()
4638
if err != nil {
4739
t.Fatalf("Failed to get model ID: %v", err)
@@ -50,12 +42,11 @@ func TestBundle(t *testing.T) {
5042
t.Fatalf("Failed to write model to store: %v", err)
5143
}
5244

53-
// Load model with template file
54-
templateLayer, err := partial.NewLayer(filepath.Join("..", "assets", "template.jinja"), types.MediaTypeChatTemplate)
55-
if err != nil {
56-
t.Fatalf("Failed to create chat template layer: %v", err)
57-
}
58-
templateMdl := mutate.AppendLayers(mdl, templateLayer)
45+
templateMdl := testutil.NewGGUFArtifact(
46+
t,
47+
filepath.Join("..", "assets", "dummy.gguf"),
48+
testutil.Layer(filepath.Join("..", "assets", "template.jinja"), types.MediaTypeChatTemplate),
49+
)
5950
templateMdlID, err := templateMdl.ID()
6051
if err != nil {
6152
t.Fatalf("Failed to get model ID: %v", err)
@@ -64,12 +55,12 @@ func TestBundle(t *testing.T) {
6455
t.Fatalf("Failed to write model to store: %v", err)
6556
}
6657

67-
// Load sharded dummy model from asset directory
68-
shardedB, err := builder.FromPath(filepath.Join("..", "assets", "dummy-00001-of-00002.gguf"))
69-
if err != nil {
70-
t.Fatalf("Failed to create model: %v", err)
71-
}
72-
shardedMdl := shardedB.Model()
58+
shardedMdl := testutil.NewDockerArtifact(
59+
t,
60+
types.Config{Format: types.FormatGGUF},
61+
testutil.Layer(filepath.Join("..", "assets", "dummy-00001-of-00002.gguf"), types.MediaTypeGGUF),
62+
testutil.Layer(filepath.Join("..", "assets", "dummy-00002-of-00002.gguf"), types.MediaTypeGGUF),
63+
)
7364
shardedGGUFID, err := shardedMdl.ID()
7465
if err != nil {
7566
t.Fatalf("Failed to get model ID: %v", err)

pkg/distribution/distribution/client.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/docker/model-runner/pkg/distribution/internal/mutate"
1616
"github.com/docker/model-runner/pkg/distribution/internal/progress"
1717
"github.com/docker/model-runner/pkg/distribution/internal/store"
18+
"github.com/docker/model-runner/pkg/distribution/modelpack"
1819
"github.com/docker/model-runner/pkg/distribution/oci"
1920
"github.com/docker/model-runner/pkg/distribution/oci/authn"
2021
"github.com/docker/model-runner/pkg/distribution/oci/remote"
@@ -777,7 +778,9 @@ func checkCompat(image types.ModelArtifact, log *slog.Logger, reference string,
777778
if err != nil {
778779
return err
779780
}
780-
if manifest.Config.MediaType != types.MediaTypeModelConfigV01 && manifest.Config.MediaType != types.MediaTypeModelConfigV02 {
781+
if manifest.Config.MediaType != types.MediaTypeModelConfigV01 &&
782+
manifest.Config.MediaType != types.MediaTypeModelConfigV02 &&
783+
manifest.Config.MediaType != modelpack.MediaTypeModelConfigV1 {
781784
return fmt.Errorf("config type %q is unsupported: %w", manifest.Config.MediaType, ErrUnsupportedMediaType)
782785
}
783786

0 commit comments

Comments
 (0)