Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions pkg/distribution/huggingface/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ func BuildModel(ctx context.Context, client *Client, repo, revision, tag string,
_ = progress.WriteProgress(progressWriter, "Building model artifact...", 0, 0, 0, "", "pull")
}

model, err := buildModelFromFiles(result.LocalPaths, weightFiles, configFiles, tempDir, createdTime)
model, err := buildModelFromFiles(
result.LocalPaths, weightFiles, configFiles, mmprojFile, tempDir, createdTime,
)
if err != nil {
return nil, fmt.Errorf("build model: %w", err)
}
Expand All @@ -103,14 +105,20 @@ func BuildModel(ctx context.Context, client *Client, repo, revision, tag string,
// which preserves directory structure and adds each file as an individual layer with
// filepath annotations. For GGUF models, it uses the V0.1 packaging (FromPaths)
// for backward compatibility.
func buildModelFromFiles(localPaths map[string]string, weightFiles, configFiles []RepoFile, tempDir string, createdTime *time.Time) (types.ModelArtifact, error) {
func buildModelFromFiles(
localPaths map[string]string,
weightFiles, configFiles []RepoFile,
mmprojFile *RepoFile,
tempDir string,
createdTime *time.Time,
) (types.ModelArtifact, error) {
// Check if this is a safetensors model - use V0.2 packaging
if isSafetensorsModel(weightFiles) {
return buildSafetensorsModelV02(tempDir, createdTime)
}

// For GGUF models, use V0.1 packaging (backward compatible)
return buildGGUFModelV01(localPaths, weightFiles, configFiles, createdTime)
return buildGGUFModelV01(localPaths, weightFiles, configFiles, mmprojFile, createdTime)
}

// buildSafetensorsModelV02 builds a safetensors model using V0.2 layer-per-file packaging.
Expand All @@ -133,7 +141,12 @@ func buildSafetensorsModelV02(tempDir string, createdTime *time.Time) (types.Mod
}

// buildGGUFModelV01 builds a GGUF model using V0.1 packaging (backward compatible).
func buildGGUFModelV01(localPaths map[string]string, weightFiles, configFiles []RepoFile, createdTime *time.Time) (types.ModelArtifact, error) {
func buildGGUFModelV01(
localPaths map[string]string,
weightFiles, configFiles []RepoFile,
mmprojFile *RepoFile,
createdTime *time.Time,
) (types.ModelArtifact, error) {
// Collect weight file paths (sorted for reproducibility)
var weightPaths []string
for _, f := range weightFiles {
Expand All @@ -157,7 +170,19 @@ func buildGGUFModelV01(localPaths map[string]string, weightFiles, configFiles []
return nil, fmt.Errorf("create builder: %w", err)
}

// Check for chat template and add it
// Add multimodal projector if present (F16 preferred, selected upstream).
if mmprojFile != nil {
localPath, ok := localPaths[mmprojFile.Path]
if !ok {
return nil, fmt.Errorf("missing local path for mmproj %s", mmprojFile.Path)
}
b, err = b.WithMultimodalProjector(localPath)
if err != nil {
return nil, fmt.Errorf("add mmproj: %w", err)
}
}

// Check for chat template and add it.
for _, f := range configFiles {
if isChatTemplate(f.Path) {
localPath := localPaths[f.Path]
Expand Down
83 changes: 83 additions & 0 deletions pkg/distribution/huggingface/model_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package huggingface

import (
"path/filepath"
"testing"
"time"

"github.com/docker/model-runner/pkg/distribution/types"
)

// TestBuildGGUFModelV01WithMMProj verifies that buildGGUFModelV01 includes
// the multimodal projector as a MediaTypeMultimodalProjector layer when an
// mmprojFile is provided.
func TestBuildGGUFModelV01WithMMProj(t *testing.T) {
assetsDir := filepath.Join("..", "assets")
ggufPath := filepath.Join(assetsDir, "dummy.gguf")
mmprojPath := filepath.Join(assetsDir, "dummy.mmproj")

weightFiles := []RepoFile{
{Type: "file", Path: "dummy.gguf"},
}
mmprojFile := &RepoFile{Type: "file", Path: "mmproj-model-f16.gguf"}
localPaths := map[string]string{
"dummy.gguf": ggufPath,
"mmproj-model-f16.gguf": mmprojPath,
}

artifact, err := buildGGUFModelV01(localPaths, weightFiles, nil, mmprojFile, nil)
if err != nil {
t.Fatalf("buildGGUFModelV01 failed: %v", err)
}

// Retrieve the manifest and look for the mmproj layer.
manifest, err := artifact.Manifest()
if err != nil {
t.Fatalf("get manifest: %v", err)
}

found := false
for _, layer := range manifest.Layers {
if layer.MediaType == types.MediaTypeMultimodalProjector {
found = true
break
}
}
if !found {
t.Errorf("expected manifest to contain a %s layer, but none was found",
types.MediaTypeMultimodalProjector)
}
}

// TestBuildGGUFModelV01WithoutMMProj verifies that buildGGUFModelV01 succeeds
// and produces no MediaTypeMultimodalProjector layer when no mmprojFile is
// provided.
func TestBuildGGUFModelV01WithoutMMProj(t *testing.T) {
assetsDir := filepath.Join("..", "assets")
ggufPath := filepath.Join(assetsDir, "dummy.gguf")

weightFiles := []RepoFile{
{Type: "file", Path: "dummy.gguf"},
}
localPaths := map[string]string{
"dummy.gguf": ggufPath,
}
createdTime := time.Now()

artifact, err := buildGGUFModelV01(localPaths, weightFiles, nil, nil, &createdTime)
if err != nil {
t.Fatalf("buildGGUFModelV01 failed: %v", err)
}

manifest, err := artifact.Manifest()
if err != nil {
t.Fatalf("get manifest: %v", err)
}

for _, layer := range manifest.Layers {
if layer.MediaType == types.MediaTypeMultimodalProjector {
t.Errorf("expected no %s layer, but one was found",
types.MediaTypeMultimodalProjector)
}
}
}
Loading