|
4 | 4 | "encoding/json" |
5 | 5 | "testing" |
6 | 6 |
|
| 7 | + "github.com/docker/model-runner/pkg/distribution/modelpack" |
7 | 8 | "github.com/docker/model-runner/pkg/distribution/types" |
8 | 9 | "github.com/stretchr/testify/assert" |
9 | 10 | "github.com/stretchr/testify/require" |
@@ -478,6 +479,100 @@ func TestToOpenAIList(t *testing.T) { |
478 | 479 | assert.Nil(t, result.Data[1].DMR) |
479 | 480 | } |
480 | 481 |
|
| 482 | +func TestNormalizeConfigModelPack(t *testing.T) { |
| 483 | + // Test that normalizeConfig converts ModelPack config to Docker format. |
| 484 | + // This simulates what happens in ToModel() when the server normalizes |
| 485 | + // CNCF configs before serializing the API response. |
| 486 | + mp := &modelpack.Model{ |
| 487 | + Descriptor: modelpack.ModelDescriptor{ |
| 488 | + Family: "qwen3", |
| 489 | + }, |
| 490 | + Config: modelpack.ModelConfig{ |
| 491 | + Architecture: "qwen3", |
| 492 | + Format: "safetensors", |
| 493 | + ParamSize: "0.6B", |
| 494 | + Quantization: "F16", |
| 495 | + }, |
| 496 | + ModelFS: modelpack.ModelFS{ |
| 497 | + Type: "layers", |
| 498 | + }, |
| 499 | + } |
| 500 | + |
| 501 | + normalized := normalizeConfig(mp) |
| 502 | + require.NotNil(t, normalized) |
| 503 | + |
| 504 | + // Should be converted to *types.Config |
| 505 | + dockerCfg, ok := normalized.(*types.Config) |
| 506 | + require.True(t, ok, "Normalized config should be *types.Config") |
| 507 | + assert.Equal(t, types.FormatSafetensors, dockerCfg.Format) |
| 508 | + assert.Equal(t, "0.6B", dockerCfg.Parameters) |
| 509 | + assert.Equal(t, "F16", dockerCfg.Quantization) |
| 510 | + assert.Equal(t, "qwen3", dockerCfg.Architecture) |
| 511 | + assert.Equal(t, "0.6B", dockerCfg.Size) |
| 512 | +} |
| 513 | + |
| 514 | +func TestNormalizeConfigDocker(t *testing.T) { |
| 515 | + // Docker format configs should pass through unchanged. |
| 516 | + dockerCfg := &types.Config{ |
| 517 | + Format: "gguf", |
| 518 | + Parameters: "7B", |
| 519 | + Quantization: "Q4_K_M", |
| 520 | + Architecture: "llama", |
| 521 | + Size: "7B", |
| 522 | + } |
| 523 | + |
| 524 | + normalized := normalizeConfig(dockerCfg) |
| 525 | + assert.Equal(t, dockerCfg, normalized, "Docker config should pass through unchanged") |
| 526 | +} |
| 527 | + |
| 528 | +func TestNormalizeConfigNil(t *testing.T) { |
| 529 | + assert.Nil(t, normalizeConfig(nil), "nil config should return nil") |
| 530 | +} |
| 531 | + |
| 532 | +func TestToModelWithModelPackConfig(t *testing.T) { |
| 533 | + // Test that ToModel properly normalizes a ModelPack config and |
| 534 | + // the resulting JSON is always in Docker format. |
| 535 | + mp := &modelpack.Model{ |
| 536 | + Config: modelpack.ModelConfig{ |
| 537 | + Architecture: "qwen3", |
| 538 | + Format: "gguf", |
| 539 | + ParamSize: "0.6B", |
| 540 | + Quantization: "Q8_0", |
| 541 | + }, |
| 542 | + } |
| 543 | + |
| 544 | + m := &mockModel{ |
| 545 | + id: "sha256:cncf123456789012", |
| 546 | + tags: []string{"aistaging/qwen3-cncf:0.6B"}, |
| 547 | + config: mp, |
| 548 | + desc: types.Descriptor{}, |
| 549 | + } |
| 550 | + |
| 551 | + apiModel, err := ToModel(m) |
| 552 | + require.NoError(t, err) |
| 553 | + |
| 554 | + // Config should be normalized to *types.Config |
| 555 | + dockerCfg, ok := apiModel.Config.(*types.Config) |
| 556 | + require.True(t, ok, "Config should be normalized to *types.Config") |
| 557 | + assert.Equal(t, types.FormatGGUF, dockerCfg.Format) |
| 558 | + assert.Equal(t, "0.6B", dockerCfg.Parameters) |
| 559 | + assert.Equal(t, "Q8_0", dockerCfg.Quantization) |
| 560 | + assert.Equal(t, "qwen3", dockerCfg.Architecture) |
| 561 | + assert.Equal(t, "0.6B", dockerCfg.Size) |
| 562 | + |
| 563 | + // Verify the JSON output is always Docker format (flat structure) |
| 564 | + jsonData, err := json.Marshal(apiModel) |
| 565 | + require.NoError(t, err) |
| 566 | + |
| 567 | + var unmarshaled Model |
| 568 | + err = json.Unmarshal(jsonData, &unmarshaled) |
| 569 | + require.NoError(t, err) |
| 570 | + |
| 571 | + assert.Equal(t, "0.6B", unmarshaled.Config.GetParameters()) |
| 572 | + assert.Equal(t, "Q8_0", unmarshaled.Config.GetQuantization()) |
| 573 | + assert.Equal(t, "qwen3", unmarshaled.Config.GetArchitecture()) |
| 574 | +} |
| 575 | + |
481 | 576 | // Helper function to create int32 pointers |
482 | 577 | func int32Ptr(i int32) *int32 { |
483 | 578 | return &i |
|
0 commit comments