Skip to content

Commit 2372105

Browse files
ekcaseydoringeman
authored andcommitted
Optionally set context-size in artifact (docker#92)
Signed-off-by: Emily Casey <emily.casey@docker.com>
1 parent 6a01954 commit 2372105

6 files changed

Lines changed: 50 additions & 0 deletions

File tree

pkg/distribution/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ make build
3131
# Package a model with license files and push to a registry
3232
./bin/model-distribution-tool package --licenses license1.txt --licenses license2.txt ./model.gguf registry.example.com/models/llama:v1.0
3333

34+
# Package a model with a default context size and push to a registry
35+
./bin/model-distribution-tool ./model.gguf --context-size 2048 registry.example.com/models/llama:v1.0
36+
3437
# Push a model from the content store to the registry
3538
./bin/model-distribution-tool push registry.example.com/models/llama:v1.0
3639

pkg/distribution/builder/builder.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ func (b *Builder) WithLicense(path string) (*Builder, error) {
3838
}, nil
3939
}
4040

41+
func (b *Builder) WithContextSize(size uint64) *Builder {
42+
return &Builder{
43+
model: mutate.ContextSize(b.model, size),
44+
}
45+
}
46+
4147
// Target represents a build target
4248
type Target interface {
4349
Write(context.Context, types.ModelArtifact, io.Writer) error

pkg/distribution/internal/mutate/model.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ type model struct {
1616
base types.ModelArtifact
1717
appended []v1.Layer
1818
configMediaType ggcr.MediaType
19+
contextSize *uint64
1920
}
2021

2122
func (m *model) Descriptor() (types.Descriptor, error) {
@@ -123,6 +124,9 @@ func (m *model) RawConfigFile() ([]byte, error) {
123124
}
124125
cf.RootFS.DiffIDs = append(cf.RootFS.DiffIDs, diffID)
125126
}
127+
if m.contextSize != nil {
128+
cf.Config.ContextSize = m.contextSize
129+
}
126130
raw, err := json.Marshal(cf)
127131
if err != nil {
128132
return nil, err

pkg/distribution/internal/mutate/mutate.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,10 @@ func ConfigMediaType(mdl types.ModelArtifact, mt ggcr.MediaType) types.ModelArti
2020
configMediaType: mt,
2121
}
2222
}
23+
24+
func ContextSize(mdl types.ModelArtifact, cs uint64) types.ModelArtifact {
25+
return &model{
26+
base: mdl,
27+
contextSize: &cs,
28+
}
29+
}

pkg/distribution/internal/mutate/mutate_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,32 @@ func TestConfigMediaTypes(t *testing.T) {
8383
t.Fatalf("Expected media type %s, got %s", newMediaType, manifest2.Config.MediaType)
8484
}
8585
}
86+
87+
func TestContextSize(t *testing.T) {
88+
mdl1, err := gguf.NewModel(filepath.Join("..", "..", "assets", "dummy.gguf"))
89+
if err != nil {
90+
t.Fatalf("Failed to create model: %v", err)
91+
}
92+
cfg, err := mdl1.Config()
93+
if err != nil {
94+
t.Fatalf("Failed to get config file: %v", err)
95+
}
96+
if cfg.ContextSize != nil {
97+
t.Fatalf("Epected nil context size got %d", cfg.ContextSize)
98+
}
99+
100+
// set the context size
101+
mdl2 := mutate.ContextSize(mdl1, 2096)
102+
103+
// check the config
104+
cfg2, err := mdl2.Config()
105+
if err != nil {
106+
t.Fatalf("Failed to get config file: %v", err)
107+
}
108+
if cfg2.ContextSize == nil {
109+
t.Fatal("Expected non-nil context")
110+
}
111+
if *cfg2.ContextSize != uint64(2096) {
112+
t.Fatalf("Expected context size of 2096 got %d", *cfg2.ContextSize)
113+
}
114+
}

pkg/distribution/types/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ type Config struct {
4444
Architecture string `json:"architecture,omitempty"`
4545
Size string `json:"size,omitempty"`
4646
GGUF map[string]string `json:"gguf,omitempty"`
47+
ContextSize *uint64 `json:"context_size,omitempty"`
4748
}
4849

4950
// Descriptor provides metadata about the provenance of the model.

0 commit comments

Comments
 (0)