diff --git a/cmd/modelfile/generate.go b/cmd/modelfile/generate.go index 985c88c7..3cad9361 100644 --- a/cmd/modelfile/generate.go +++ b/cmd/modelfile/generate.go @@ -62,7 +62,7 @@ func init() { flags.StringVar(&generateConfig.Precision, "precision", "", "specify model precision, such as bf16, fp16, int8, etc") flags.StringVar(&generateConfig.Quantization, "quantization", "", "specify model quantization, such as awq, gptq, etc") flags.StringVarP(&generateConfig.Output, "output", "O", ".", "specify the output path of modelfilem, must be a directory") - flags.BoolVar(&generateConfig.IgnoreUnrecognizedFileTypes, "ignore-unrecognized-file-types", false, "ignore the unrecognized file types in the workspace") + flags.BoolVar(&generateConfig.IgnoreUnrecognizedFileTypes, "ignore-unrecognized-file-types", false, "[deprecated] ignore the unrecognized file types in the workspace") flags.BoolVar(&generateConfig.Overwrite, "overwrite", false, "overwrite the existing modelfile") if err := viper.BindPFlags(flags); err != nil { diff --git a/pkg/backend/attach.go b/pkg/backend/attach.go index 2011a5ff..8dfb1069 100644 --- a/pkg/backend/attach.go +++ b/pkg/backend/attach.go @@ -107,6 +107,7 @@ func (b *backend) Attach(ctx context.Context, filepath string, cfg *config.Attac pb.Start() defer pb.Stop() + // TODO: Copy old flags to the new layer. 
newLayers, err := proc.Process(ctx, builder, ".", processor.WithProgressTracker(pb)) if err != nil { return fmt.Errorf("failed to process layers: %w", err) @@ -266,11 +267,11 @@ func (b *backend) getProcessor(filepath string) processor.Processor { } if modelfile.IsFileType(filepath, modelfile.ModelFilePatterns) { - return processor.NewModelProcessor(b.store, modelspec.MediaTypeModelWeight, []string{filepath}) + return processor.NewModelProcessor(b.store, modelspec.MediaTypeModelWeight, []string{filepath}, make(map[string]map[string]string)) } if modelfile.IsFileType(filepath, modelfile.CodeFilePatterns) { - return processor.NewCodeProcessor(b.store, modelspec.MediaTypeModelCode, []string{filepath}) + return processor.NewCodeProcessor(b.store, modelspec.MediaTypeModelCode, []string{filepath}, make(map[string]map[string]string)) } if modelfile.IsFileType(filepath, modelfile.DocFilePatterns) { diff --git a/pkg/backend/build.go b/pkg/backend/build.go index e1967ae7..7c7e0250 100644 --- a/pkg/backend/build.go +++ b/pkg/backend/build.go @@ -170,7 +170,7 @@ func (b *backend) getProcessors(modelfile modelfile.Modelfile, cfg *config.Build if cfg.Raw { mediaType = modelspec.MediaTypeModelWeightRaw } - processors = append(processors, processor.NewModelProcessor(b.store, mediaType, models)) + processors = append(processors, processor.NewModelProcessor(b.store, mediaType, models, modelfile.GetModelFlags())) } if codes := modelfile.GetCodes(); len(codes) > 0 { @@ -178,7 +178,7 @@ func (b *backend) getProcessors(modelfile modelfile.Modelfile, cfg *config.Build if cfg.Raw { mediaType = modelspec.MediaTypeModelCodeRaw } - processors = append(processors, processor.NewCodeProcessor(b.store, mediaType, codes)) + processors = append(processors, processor.NewCodeProcessor(b.store, mediaType, codes, modelfile.GetCodeFlags())) } if docs := modelfile.GetDocs(); len(docs) > 0 { diff --git a/pkg/backend/build/builder.go b/pkg/backend/build/builder.go index 08de314c..9cb1ebd5 100644 --- 
a/pkg/backend/build/builder.go +++ b/pkg/backend/build/builder.go @@ -54,8 +54,8 @@ const ( // Builder is an interface for building artifacts. type Builder interface { - // BuildLayer builds the layer blob from the given file path. - BuildLayer(ctx context.Context, mediaType, workDir, path string, hooks hooks.Hooks) (ocispec.Descriptor, error) + // BuildLayer builds the layer blob from the given file path with optional extra annotations. + BuildLayer(ctx context.Context, mediaType, workDir, path string, annotations map[string]string, hooks hooks.Hooks) (ocispec.Descriptor, error) // BuildConfig builds the config blob of the artifact. BuildConfig(ctx context.Context, layers []ocispec.Descriptor, modelConfig *buildconfig.Model, hooks hooks.Hooks) (ocispec.Descriptor, error) @@ -119,7 +119,7 @@ type abstractBuilder struct { interceptor interceptor.Interceptor } -func (ab *abstractBuilder) BuildLayer(ctx context.Context, mediaType, workDir, path string, hooks hooks.Hooks) (ocispec.Descriptor, error) { +func (ab *abstractBuilder) BuildLayer(ctx context.Context, mediaType, workDir, path string, annotations map[string]string, hooks hooks.Hooks) (ocispec.Descriptor, error) { info, err := os.Stat(path) if err != nil { return ocispec.Descriptor{}, fmt.Errorf("failed to get file info: %w", err) @@ -223,6 +223,13 @@ func (ab *abstractBuilder) BuildLayer(ctx context.Context, mediaType, workDir, p } desc.Annotations[modelspec.AnnotationFileMetadata] = string(metadataStr) + // Add extra annotations if provided + if annotations != nil { + for key, value := range annotations { + desc.Annotations[key] = value + } + } + return desc, nil } diff --git a/pkg/backend/build/builder_test.go b/pkg/backend/build/builder_test.go index c4f641b2..874afcd0 100644 --- a/pkg/backend/build/builder_test.go +++ b/pkg/backend/build/builder_test.go @@ -126,7 +126,7 @@ func (s *BuilderTestSuite) TestBuildLayer() { s.mockOutputStrategy.On("OutputLayer", mock.Anything, "test/media-type.tar", 
"test-file.txt", mock.AnythingOfType("string"), mock.AnythingOfType("int64"), mock.AnythingOfType("*io.PipeReader"), mock.Anything). Return(expectedDesc, nil) - desc, err := s.builder.BuildLayer(context.Background(), "test/media-type.tar", s.tempDir, s.tempFile, hooks.NewHooks()) + desc, err := s.builder.BuildLayer(context.Background(), "test/media-type.tar", s.tempDir, s.tempFile, nil, hooks.NewHooks()) s.NoError(err) s.Equal(expectedDesc.MediaType, desc.MediaType) s.Equal(expectedDesc.Digest, desc.Digest) @@ -134,12 +134,12 @@ func (s *BuilderTestSuite) TestBuildLayer() { }) s.Run("file not found", func() { - _, err := s.builder.BuildLayer(context.Background(), "test/media-type.tar", s.tempDir, filepath.Join(s.tempDir, "non-existent.txt"), hooks.NewHooks()) + _, err := s.builder.BuildLayer(context.Background(), "test/media-type.tar", s.tempDir, filepath.Join(s.tempDir, "non-existent.txt"), nil, hooks.NewHooks()) s.Error(err) }) s.Run("directory not supported", func() { - _, err := s.builder.BuildLayer(context.Background(), "test/media-type.tar", s.tempDir, s.tempDir, hooks.NewHooks()) + _, err := s.builder.BuildLayer(context.Background(), "test/media-type.tar", s.tempDir, s.tempDir, nil, hooks.NewHooks()) s.Error(err) s.True(strings.Contains(err.Error(), "is a directory and not supported yet")) }) diff --git a/pkg/backend/build_test.go b/pkg/backend/build_test.go index d284e1f4..867d2102 100644 --- a/pkg/backend/build_test.go +++ b/pkg/backend/build_test.go @@ -29,7 +29,9 @@ func TestGetProcessors(t *testing.T) { modelfile := &modelfile.Modelfile{} modelfile.On("GetConfigs").Return([]string{"config1", "config2"}) modelfile.On("GetModels").Return([]string{"model1", "model2"}) + modelfile.On("GetModelFlags").Return(make(map[string]map[string]string)) modelfile.On("GetCodes").Return([]string{"1.py", "2.py"}) + modelfile.On("GetCodeFlags").Return(make(map[string]map[string]string)) modelfile.On("GetDocs").Return([]string{"doc1", "doc2"}) b := &backend{} diff --git 
a/pkg/backend/inspect.go b/pkg/backend/inspect.go index 86934876..108fc03d 100644 --- a/pkg/backend/inspect.go +++ b/pkg/backend/inspect.go @@ -20,6 +20,8 @@ import ( "context" "encoding/json" "fmt" + "sort" + "strings" "time" godigest "github.com/opencontainers/go-digest" @@ -62,6 +64,35 @@ type InspectedModelArtifactLayer struct { Size int64 `json:"Size"` // Filepath is the filepath of the model artifact layer. Filepath string `json:"Filepath"` + // Flags is the flags of the model artifact layer. + Flags string `json:"Flags,omitempty"` +} + +// collectFlags renders every layer annotation except the filepath and file-metadata ones +// as comma-separated key=value pairs, sorted so that the output is deterministic. +func collectFlags(annotations map[string]string) string { + if annotations == nil { + return "" + } + + var flags []string + // Skip the filepath annotation since it's already displayed separately + // Also skip the file metadata annotation since it's internal metadata + skipAnnotations := map[string]bool{ + modelspec.AnnotationFilepath: true, + modelspec.AnnotationFileMetadata: true, + } + + for key, value := range annotations { + if !skipAnnotations[key] { + flags = append(flags, fmt.Sprintf("%s=%s", key, value)) + } + } + + // Go randomizes map iteration order; sort so repeated inspections of the same layer agree. + sort.Strings(flags) + + return strings.Join(flags, ",") } // Inspect inspects the target from the storage. @@ -107,6 +138,7 @@ func (b *backend) Inspect(ctx context.Context, target string, cfg *config.Inspec Digest: layer.Digest.String(), Size: layer.Size, Filepath: layer.Annotations[modelspec.AnnotationFilepath], + Flags: collectFlags(layer.Annotations), }) } diff --git a/pkg/backend/processor/base.go b/pkg/backend/processor/base.go index 2dcca78a..3e0a6cc2 100644 --- a/pkg/backend/processor/base.go +++ b/pkg/backend/processor/base.go @@ -46,8 +46,8 @@ type base struct { patterns []string } -// Process implements the Processor interface, which can be reused by other processors.
-func (b *base) Process(ctx context.Context, builder build.Builder, workDir string, opts ...ProcessOption) ([]ocispec.Descriptor, error) { +// Process implements the Processor interface with file flags support. +func (b *base) Process(ctx context.Context, builder build.Builder, workDir string, fileFlags map[string]map[string]string, opts ...ProcessOption) ([]ocispec.Descriptor, error) { processOpts := &processOptions{} for _, opt := range opts { opt(processOpts) @@ -103,7 +103,20 @@ func (b *base) Process(ctx context.Context, builder build.Builder, workDir strin eg.Go(func() error { return retry.Do(func() error { - desc, err := builder.BuildLayer(ctx, b.mediaType, workDir, path, hooks.NewHooks( + // Get relative path for looking up flags + relPath, relErr := filepath.Rel(absWorkDir, path) + if relErr != nil { + return relErr + } + + anno := make(map[string]string) + if fileFlags != nil { + if flags, exists := fileFlags[relPath]; exists { + anno = flags + } + } + + desc, buildErr := builder.BuildLayer(ctx, b.mediaType, workDir, path, anno, hooks.NewHooks( hooks.WithOnStart(func(name string, size int64, reader io.Reader) io.Reader { return tracker.Add(internalpb.NormalizePrompt("Building layer"), name, size, reader) }), @@ -114,9 +127,10 @@ func (b *base) Process(ctx context.Context, builder build.Builder, workDir strin tracker.Complete(name, fmt.Sprintf("%s %s", internalpb.NormalizePrompt("Built layer"), desc.Digest)) }), )) - if err != nil { + + if buildErr != nil { cancel() - return err + return buildErr } mu.Lock() diff --git a/pkg/backend/processor/code.go b/pkg/backend/processor/code.go index f14dbb3e..d827babe 100644 --- a/pkg/backend/processor/code.go +++ b/pkg/backend/processor/code.go @@ -30,7 +30,7 @@ const ( ) // NewCodeProcessor creates a new code processor. 
-func NewCodeProcessor(store storage.Storage, mediaType string, patterns []string) Processor { +func NewCodeProcessor(store storage.Storage, mediaType string, patterns []string, flags map[string]map[string]string) Processor { return &codeProcessor{ base: &base{ name: codeProcessorName, @@ -38,12 +38,14 @@ func NewCodeProcessor(store storage.Storage, mediaType string, patterns []string mediaType: mediaType, patterns: patterns, }, + flags: flags, } } // codeProcessor is the processor to process the code file. type codeProcessor struct { - base *base + base *base + flags map[string]map[string]string } func (p *codeProcessor) Name() string { @@ -51,5 +53,5 @@ func (p *codeProcessor) Name() string { } func (p *codeProcessor) Process(ctx context.Context, builder build.Builder, workDir string, opts ...ProcessOption) ([]ocispec.Descriptor, error) { - return p.base.Process(ctx, builder, workDir, opts...) + return p.base.Process(ctx, builder, workDir, p.flags, opts...) } diff --git a/pkg/backend/processor/code_test.go b/pkg/backend/processor/code_test.go index 0c562e2d..3971bcfd 100644 --- a/pkg/backend/processor/code_test.go +++ b/pkg/backend/processor/code_test.go @@ -44,7 +44,7 @@ type codeProcessorSuite struct { func (s *codeProcessorSuite) SetupTest() { s.mockStore = &storage.Storage{} s.mockBuilder = &buildmock.Builder{} - s.processor = NewCodeProcessor(s.mockStore, modelspec.MediaTypeModelCode, []string{"*.py"}) + s.processor = NewCodeProcessor(s.mockStore, modelspec.MediaTypeModelCode, []string{"*.py"}, make(map[string]map[string]string)) // generate test files for prorcess. 
s.workDir = s.Suite.T().TempDir() if err := os.WriteFile(filepath.Join(s.workDir, "test.py"), []byte(""), 0644); err != nil { @@ -58,7 +58,7 @@ func (s *codeProcessorSuite) TestName() { func (s *codeProcessorSuite) TestProcess() { ctx := context.Background() - s.mockBuilder.On("BuildLayer", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(ocispec.Descriptor{ + s.mockBuilder.On("BuildLayer", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(ocispec.Descriptor{ Digest: godigest.Digest("sha256:1234567890abcdef"), Size: int64(1024), Annotations: map[string]string{ diff --git a/pkg/backend/processor/doc.go b/pkg/backend/processor/doc.go index a322d357..9f718072 100644 --- a/pkg/backend/processor/doc.go +++ b/pkg/backend/processor/doc.go @@ -51,5 +51,5 @@ func (p *docProcessor) Name() string { } func (p *docProcessor) Process(ctx context.Context, builder build.Builder, workDir string, opts ...ProcessOption) ([]ocispec.Descriptor, error) { - return p.base.Process(ctx, builder, workDir, opts...) + return p.base.Process(ctx, builder, workDir, nil, opts...) 
} diff --git a/pkg/backend/processor/doc_test.go b/pkg/backend/processor/doc_test.go index f30d088a..7f2a6cbb 100644 --- a/pkg/backend/processor/doc_test.go +++ b/pkg/backend/processor/doc_test.go @@ -58,7 +58,7 @@ func (s *docProcessorSuite) TestName() { func (s *docProcessorSuite) TestProcess() { ctx := context.Background() - s.mockBuilder.On("BuildLayer", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(ocispec.Descriptor{ + s.mockBuilder.On("BuildLayer", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(ocispec.Descriptor{ Digest: godigest.Digest("sha256:1234567890abcdef"), Size: int64(1024), Annotations: map[string]string{ diff --git a/pkg/backend/processor/model.go b/pkg/backend/processor/model.go index 8169cc17..826af1d5 100644 --- a/pkg/backend/processor/model.go +++ b/pkg/backend/processor/model.go @@ -30,7 +30,7 @@ const ( ) // NewModelProcessor creates a new model processor. -func NewModelProcessor(store storage.Storage, mediaType string, patterns []string) Processor { +func NewModelProcessor(store storage.Storage, mediaType string, patterns []string, flags map[string]map[string]string) Processor { return &modelProcessor{ base: &base{ name: modelProcessorName, @@ -38,12 +38,14 @@ func NewModelProcessor(store storage.Storage, mediaType string, patterns []strin mediaType: mediaType, patterns: patterns, }, + flags: flags, } } // modelProcessor is the processor to process the model file. type modelProcessor struct { - base *base + base *base + flags map[string]map[string]string } func (p *modelProcessor) Name() string { @@ -51,5 +53,5 @@ func (p *modelProcessor) Name() string { } func (p *modelProcessor) Process(ctx context.Context, builder build.Builder, workDir string, opts ...ProcessOption) ([]ocispec.Descriptor, error) { - return p.base.Process(ctx, builder, workDir, opts...) + return p.base.Process(ctx, builder, workDir, p.flags, opts...) 
} diff --git a/pkg/backend/processor/model_config.go b/pkg/backend/processor/model_config.go index b6c55ecf..2a2b1d60 100644 --- a/pkg/backend/processor/model_config.go +++ b/pkg/backend/processor/model_config.go @@ -51,5 +51,5 @@ func (p *modelConfigProcessor) Name() string { } func (p *modelConfigProcessor) Process(ctx context.Context, builder build.Builder, workDir string, opts ...ProcessOption) ([]ocispec.Descriptor, error) { - return p.base.Process(ctx, builder, workDir, opts...) + return p.base.Process(ctx, builder, workDir, nil, opts...) } diff --git a/pkg/backend/processor/model_config_test.go b/pkg/backend/processor/model_config_test.go index 1cd49e48..a898d1cf 100644 --- a/pkg/backend/processor/model_config_test.go +++ b/pkg/backend/processor/model_config_test.go @@ -58,7 +58,7 @@ func (s *modelConfigProcessorSuite) TestName() { func (s *modelConfigProcessorSuite) TestProcess() { ctx := context.Background() - s.mockBuilder.On("BuildLayer", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(ocispec.Descriptor{ + s.mockBuilder.On("BuildLayer", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(ocispec.Descriptor{ Digest: godigest.Digest("sha256:1234567890abcdef"), Size: int64(1024), Annotations: map[string]string{ diff --git a/pkg/backend/processor/model_test.go b/pkg/backend/processor/model_test.go index 1c29c8c9..4afff308 100644 --- a/pkg/backend/processor/model_test.go +++ b/pkg/backend/processor/model_test.go @@ -44,7 +44,7 @@ type modelProcessorSuite struct { func (s *modelProcessorSuite) SetupTest() { s.mockStore = &storage.Storage{} s.mockBuilder = &buildmock.Builder{} - s.processor = NewModelProcessor(s.mockStore, modelspec.MediaTypeModelWeight, []string{"model"}) + s.processor = NewModelProcessor(s.mockStore, modelspec.MediaTypeModelWeight, []string{"model"}, make(map[string]map[string]string)) // generate test files for prorcess. 
s.workDir = s.Suite.T().TempDir() if err := os.WriteFile(filepath.Join(s.workDir, "model"), []byte(""), 0644); err != nil { @@ -58,7 +58,7 @@ func (s *modelProcessorSuite) TestName() { func (s *modelProcessorSuite) TestProcess() { ctx := context.Background() - s.mockBuilder.On("BuildLayer", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(ocispec.Descriptor{ + s.mockBuilder.On("BuildLayer", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(ocispec.Descriptor{ Digest: godigest.Digest("sha256:1234567890abcdef"), Size: int64(1024), Annotations: map[string]string{ diff --git a/pkg/config/modelfile/modelfile.go b/pkg/config/modelfile/modelfile.go index 2f1251a3..fe1c6f50 100644 --- a/pkg/config/modelfile/modelfile.go +++ b/pkg/config/modelfile/modelfile.go @@ -31,7 +31,7 @@ type GenerateConfig struct { Name string Version string Output string - IgnoreUnrecognizedFileTypes bool + IgnoreUnrecognizedFileTypes bool // [deprecated] will be removed in the next release Overwrite bool Arch string Family string @@ -83,7 +83,7 @@ func (g *GenerateConfig) Validate() error { // If the output path does not exist, we can create the modelfile. 
if _, err := os.Stat(g.Output); err == nil { if !g.Overwrite { - return fmt.Errorf("Modelfile already exists at %s - use --overwrite to overwrite", g.Output) + return fmt.Errorf("modelfile already exists at %s - use --overwrite to overwrite", g.Output) } } diff --git a/pkg/modelfile/constants.go b/pkg/modelfile/constants.go index 92d07aca..6cf001db 100644 --- a/pkg/modelfile/constants.go +++ b/pkg/modelfile/constants.go @@ -17,6 +17,7 @@ package modelfile import ( + "fmt" "path/filepath" "strings" ) @@ -26,6 +27,8 @@ var ( ConfigFilePatterns = []string{ "*.json", // JSON configuration files "*.jsonl", // JSON Lines format + "*.json5", // JSON5 files + "*.jsonc", // JSON with comments "*.yaml", // YAML configuration files "*.yml", // YAML alternative extension "*.toml", // TOML configuration files @@ -45,6 +48,12 @@ var ( "*.meta", // Model metadata "*tokenizer.model*", // Tokenizer files (e.g., Mistral v3) "config.json.*", // Model configuration variants + "*.hparams", // Hyperparameter files + "*.params", // Parameter files + "*.hyperparams", // Hyperparameter configuration + "*.wandb", // Weights & Biases configuration + "*.mlflow", // MLflow configuration + "*.tensorboard", // TensorBoard configuration } // Model file patterns - supported model file extensions. @@ -56,29 +65,75 @@ var ( "*.bin", // General binary format "*.pt", // PyTorch model "*.pth", // PyTorch model (alternative extension) + "*.mar", // PyTorch Model Archive + "*.pte", // PyTorch ExecuTorch format + "*.pt2", // PyTorch 2.0 export format + "*.ptl", // PyTorch Mobile format // TensorFlow formats. "*.tflite", // TensorFlow Lite "*.h5", // Keras HDF5 format "*.hdf", // Hierarchical Data Format "*.hdf5", // HDF5 (alternative extension) + "*.pb", // TensorFlow SavedModel/Frozen Graph + "*.meta", // TensorFlow checkpoint metadata + "*.data-*", // TensorFlow checkpoint data files + "*.index", // TensorFlow checkpoint index + + // GGML formats. 
+ "*.gguf", // GGML Universal Format + "*.ggml", // GGML format (legacy) + "*.ggmf", // GGMF format (deprecated) + "*.ggjt", // GGJT format (deprecated) + "*.q4_0", // GGML Q4_0 quantization + "*.q4_1", // GGML Q4_1 quantization + "*.q5_0", // GGML Q5_0 quantization + "*.q5_1", // GGML Q5_1 quantization + "*.q8_0", // GGML Q8_0 quantization + "*.f16", // GGML F16 format + "*.f32", // GGML F32 format + + // checkpoint formats. + "*.ckpt", // Checkpoint format + "*.checkpoint", // Checkpoint format (alternative extension) + "*.dist_ckpt", // Distributed checkpoint format + + // Semantics-specific formats + "*.tensor", // Generic tensor format + "*.weights", // Generic weights format + "*.state", // State files + "*.embedding", // Embedding files + "*.vocab", // Vocabulary files (when binary) // Other ML frameworks. "*.ot", // OpenVINO format "*.engine", // TensorRT format "*.trt", // TensorRT format (alternative extension) "*.onnx", // Open Neural Network Exchange format - "*.gguf", // GGML Universal Format "*.msgpack", // MessagePack serialization "*.model", // Some NLP frameworks "*.pkl", // Pickle format "*.pickle", // Pickle format (alternative extension) - "*.ckpt", // Checkpoint format - "*.checkpoint", // Checkpoint format (alternative extension) + "*.keras", // Keras native format + "*.joblib", // Joblib serialization (scikit-learn) + "*.npy", // NumPy array format + "*.npz", // NumPy compressed archive + "*.nc", // NetCDF format + "*.mlmodel", // Apple Core ML format + "*.coreml", // Apple Core ML format (alternative) + "*.mleap", // MLeap format (Spark ML) + "*.surml", // SurrealML format + "*.llamafile", // Llamafile format + "*.caffemodel", // Caffe model format + "*.prototxt", // Caffe model definition + "*.dlc", // Qualcomm Deep Learning Container + "*.circle", // Samsung Circle format + "*.nb", // Neural Network Binary format } // Code file patterns - supported script and notebook files. 
CodeFilePatterns = []string{ + // language source files "*.py", // Python source files "*.ipynb", // Jupyter notebooks "*.sh", // Shell scripts @@ -88,11 +143,18 @@ var ( "*.hxx", // C++ header files "*.cpp", // C++ source files "*.cc", // C++ source files + "*.cxx", // C++ source files (alternative) + "*.c++", // C++ source files (alternative) "*.hpp", // C++ header files "*.hh", // C++ header files + "*.h++", // C++ header files (alternative) "*.java", // Java source files "*.js", // JavaScript source files + "*.mjs", // JavaScript ES6 modules + "*.cjs", // CommonJS modules + "*.jsx", // React JSX files "*.ts", // TypeScript source files + "*.tsx", // TypeScript JSX files "*.go", // Go source files "*.rs", // Rust source files "*.swift", // Swift source files @@ -100,30 +162,118 @@ var ( "*.php", // PHP source files "*.scala", // Scala source files "*.kt", // Kotlin source files + "*.kts", // Kotlin script files "*.r", // R source files + "*.R", // R source files (alternative) "*.m", // MATLAB/Objective-C source files + "*.mm", // Objective-C++ source files "*.f", // Fortran source files "*.f90", // Fortran 90 source files + "*.f95", // Fortran 95 source files + "*.f03", // Fortran 2003 source files + "*.f08", // Fortran 2008 source files "*.jl", // Julia source files "*.lua", // Lua source files "*.pl", // Perl source files + "*.pm", // Perl modules "*.cs", // C# source files "*.vb", // Visual Basic source files "*.dart", // Dart source files "*.groovy", // Groovy source files "*.elm", // Elm source files "*.erl", // Erlang source files + "*.hrl", // Erlang header files "*.ex", // Elixir source files + "*.exs", // Elixir script files "*.hs", // Haskell source files + "*.lhs", // Literate Haskell source files "*.clj", // Clojure source files "*.cljs", // ClojureScript source files - "*.cljc", // Clojure Common Lisp source files + "*.cljc", // Clojure Common source files "*.cl", // Common Lisp source files "*.lisp", // Lisp source files + "*.lsp", // Lisp source 
files (alternative) "*.scm", // Scheme source files + "*.ss", // Scheme source files (alternative) + "*.rkt", // Racket source files + "*.sql", // SQL files + "*.psql", // PostgreSQL files + "*.mysql", // MySQL files + "*.sqlite", // SQLite files + "*.zig", // Zig source files "*.cu", // CUDA source files "*.cuh", // CUDA header files + // Scripting and automation + "*.bash", // Bash scripts + "*.zsh", // Zsh scripts + "*.fish", // Fish shell scripts + "*.csh", // C shell scripts + "*.tcsh", // TC shell scripts + "*.ksh", // Korn shell scripts + "*.ps1", // PowerShell scripts + "*.psm1", // PowerShell modules + "*.psd1", // PowerShell data files + "*.bat", // Windows batch files + "*.cmd", // Windows command files + "*.vbs", // VBScript files + "*.wsf", // Windows Script Files + "*.applescript", // AppleScript files + "*.scpt", // AppleScript compiled files + "*.awk", // AWK scripts + "*.sed", // sed scripts + "*.expect", // Expect scripts + + // Build and project files + "*.env", // Environment variable files + "*.env.*", // Environment files with suffixes + ".env*", // Environment files (hidden) + "Makefile*", // Makefile variants + "*.dockerfile", // Dockerfile configurations + "Dockerfile*", // Dockerfile variants + "*.mk", // Make include files + "*.cmake", // CMake files + "CMakeLists.txt", // CMake configuration + "*.gradle", // Gradle build files + "*.gradle.kts", // Kotlin DSL Gradle files + "build.gradle*", // Gradle build files + "settings.gradle*", // Gradle settings files + "*.sbt", // SBT build files + "*.mill", // Mill build files + "*.bazel", // Bazel build files + "*.bzl", // Bazel extension files + "BUILD*", // Bazel BUILD files + "WORKSPACE*", // Bazel WORKSPACE files + "*.buck", // Buck build files + "BUCK*", // Buck BUILD files + "*.ninja", // Ninja build files + "*.gyp", // GYP build files + "*.gypi", // GYP include files + "*.waf", // Waf build files + "wscript*", // Waf build scripts + "package.json", // Node.js package file + 
"package-lock.json", // Node.js lock file + "yarn.lock", // Yarn lock file + "pnpm-lock.yaml", // PNPM lock file + "requirements*.txt", // Python requirements + "Pipfile*", // Python Pipenv files + "pyproject.toml", // Python project configuration + "setup.cfg", // Python setup configuration + "tox.ini", // Python tox configuration + "poetry.lock", // Python Poetry lock file + "Cargo.toml", // Rust package configuration + "Cargo.lock", // Rust lock file + "go.mod", // Go module file + "go.sum", // Go checksum file + "composer.json", // PHP Composer file + "composer.lock", // PHP Composer lock file + "Gemfile*", // Ruby Gemfile + "*.gemspec", // Ruby gem specification + "mix.exs", // Elixir Mix file + "mix.lock", // Elixir Mix lock file + "rebar.config", // Erlang Rebar config + "rebar.lock", // Erlang Rebar lock file + // Library files. "*.so", // Shared object files "*.dll", // Dynamic Link Library @@ -144,6 +294,93 @@ var ( "*requirements*", // Dependency specifications "*.log", // Log files + // Office documents + "*.doc", // Microsoft Word 97-2003 Document + "*.docx", // Microsoft Word Document + "*.docm", // Word Macro-Enabled Document + "*.dot", // Word 97-2003 Template + "*.dotx", // Word Template + "*.dotm", // Word Macro-Enabled Template + "*.rtf", // Rich Text Format + "*.odt", // OpenDocument Text + "*.ott", // OpenDocument Text Template + "*.fodt", // Flat OpenDocument Text + "*.pages", // Apple Pages document + "*.wpd", // WordPerfect document + + // Spreadsheet documents + "*.xls", // Microsoft Excel 97-2003 Workbook + "*.xlsx", // Microsoft Excel Workbook + "*.xlsm", // Excel Macro-Enabled Workbook + "*.xlsb", // Excel Binary Workbook + "*.xlt", // Excel 97-2003 Template + "*.xltx", // Excel Template + "*.xltm", // Excel Macro-Enabled Template + "*.ods", // OpenDocument Spreadsheet + "*.ots", // OpenDocument Spreadsheet Template + "*.fods", // Flat OpenDocument Spreadsheet + "*.numbers", // Apple Numbers spreadsheet + "*.csv", // Comma-Separated 
Values + + // Presentation documents + "*.ppt", // Microsoft PowerPoint 97-2003 Presentation + "*.pptx", // Microsoft PowerPoint Presentation + "*.pptm", // PowerPoint Macro-Enabled Presentation + "*.pps", // PowerPoint 97-2003 Show + "*.ppsx", // PowerPoint Show + "*.ppsm", // PowerPoint Macro-Enabled Show + "*.pot", // PowerPoint 97-2003 Template + "*.potx", // PowerPoint Template + "*.potm", // PowerPoint Macro-Enabled Template + "*.odp", // OpenDocument Presentation + "*.otp", // OpenDocument Presentation Template + "*.fodp", // Flat OpenDocument Presentation + "*.key", // Apple Keynote presentation + + // eBook formats + "*.epub", // Electronic Publication + "*.mobi", // Mobipocket eBook + "*.azw", // Amazon Kindle eBook + "*.azw3", // Amazon Kindle eBook (KF8) + "*.fb2", // FictionBook 2.0 + "*.fb3", // FictionBook 3.0 + "*.lit", // Microsoft Literature + "*.pdb", // Palm Database/Document File + "*.djvu", // DjVu document + "*.djv", // DjVu document (alternative extension) + + // Web and markup documents + "*.html", // HyperText Markup Language + "*.htm", // HyperText Markup Language (alternative) + "*.xhtml", // Extensible HyperText Markup Language + "*.mhtml", // MIME HTML (Web Archive) + "*.mht", // MIME HTML (Web Archive, alternative) + "*.xml", // eXtensible Markup Language + "*.xsl", // eXtensible Stylesheet Language + "*.xslt", // XSL Transformations + + // Technical documentation formats + "*.tex", // LaTeX document + "*.latex", // LaTeX document (alternative) + "*.ltx", // LaTeX document (alternative) + "*.bib", // BibTeX bibliography + "*.rst", // reStructuredText + "*.asciidoc", // AsciiDoc + "*.adoc", // AsciiDoc (alternative) + "*.textile", // Textile markup + "*.wiki", // Wiki markup + "*.mediawiki", // MediaWiki markup + "*.org", // Org-mode document + "*.texi", // Texinfo document + "*.texinfo", // Texinfo document (alternative) + "*.info", // GNU Info document + "*.man", // Manual page + + // Archive and compressed documents + "*.chm", // 
Compiled HTML Help + "*.hlp", // Windows Help File + "*.xps", // XML Paper Specification + // Image assets. "*.jpg", // JPEG image format "*.jpeg", // JPEG alternative extension @@ -180,6 +417,14 @@ var ( "*.pyo", // Python optimized bytecode "*.pyd", // Python dynamic modules } + + // WeightFileSizeThreshold is the size (128 MiB) above which SizeShouldBeWeightFile reports a file as a likely weight file. + WeightFileSizeThreshold int64 = 128 * 1024 * 1024 + + // Workspace limits + MaxSingleFileSize int64 = 128 * 1024 * 1024 * 1024 // 128GB + MaxWorkspaceFileCount int = 1024 // 1024 files + MaxTotalWorkspaceSize int64 = 8 * 1024 * 1024 * 1024 * 1024 // 8TB ) // IsFileType checks if the filename matches any of the given patterns @@ -216,3 +461,23 @@ func isSkippable(filename string) bool { return false } + +// SizeShouldBeWeightFile reports whether size exceeds WeightFileSizeThreshold; a large file of unknown type is usually a weight file. +func SizeShouldBeWeightFile(size int64) bool { + return size > WeightFileSizeThreshold +} + +// formatBytes renders a byte count with 1024-based units, e.g. 512 -> "512 B" and 1536 -> "1.5KB". +func formatBytes(bytes int64) string { + const unit = 1024 + if bytes < unit { + return fmt.Sprintf("%d B", bytes) + } + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + units := []string{"B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"} + return fmt.Sprintf("%.1f%s", float64(bytes)/float64(div), units[exp+1]) +} diff --git a/pkg/modelfile/modelfile.go b/pkg/modelfile/modelfile.go index edccb59d..d2912625 100644 --- a/pkg/modelfile/modelfile.go +++ b/pkg/modelfile/modelfile.go @@ -28,6 +28,7 @@ import ( configmodelfile "github.com/CloudNativeAI/modctl/pkg/config/modelfile" modefilecommand "github.com/CloudNativeAI/modctl/pkg/modelfile/command" "github.com/CloudNativeAI/modctl/pkg/modelfile/parser" + modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1" "github.com/emirpasic/gods/sets/hashset" ) @@ -81,6 +82,12 @@ type Modelfile interface { // GetQuantization returns the value of the quantization command in the modelfile.
GetQuantization() string + // GetModelFlags returns the flags of all model files in the modelfile, keyed by file path. + GetModelFlags() map[string]map[string]string + + // GetCodeFlags returns the flags of all code files in the modelfile, keyed by file path. + GetCodeFlags() map[string]map[string]string + // Content returns the content of the modelfile. Content() []byte } @@ -100,17 +107,21 @@ type modelfile struct { paramsize string precision string quantization string + modelFlags map[string]map[string]string + codeFlags map[string]map[string]string } // NewModelfile creates a new modelfile by the path of the modelfile. // It parses the modelfile and returns the modelfile interface. func NewModelfile(path string) (Modelfile, error) { mf := &modelfile{ - config: hashset.New(), - model: hashset.New(), - code: hashset.New(), - dataset: hashset.New(), - doc: hashset.New(), + config: hashset.New(), + model: hashset.New(), + code: hashset.New(), + dataset: hashset.New(), + doc: hashset.New(), + modelFlags: make(map[string]map[string]string), + codeFlags: make(map[string]map[string]string), } if err := mf.parseFile(path); err != nil { @@ -138,9 +149,13 @@ func (mf *modelfile) parseFile(path string) error { case modefilecommand.CONFIG: mf.config.Add(child.GetNext().GetValue()) case modefilecommand.MODEL: - mf.model.Add(child.GetNext().GetValue()) + filePath := child.GetNext().GetValue() + mf.model.Add(filePath) + mf.modelFlags[filePath] = child.GetAttributes() case modefilecommand.CODE: - mf.code.Add(child.GetNext().GetValue()) + filePath := child.GetNext().GetValue() + mf.code.Add(filePath) + mf.codeFlags[filePath] = child.GetAttributes() case modefilecommand.DATASET: mf.dataset.Add(child.GetNext().GetValue()) case modefilecommand.DOC: @@ -197,15 +212,21 @@ func (mf *modelfile) parseFile(path string) error { // paramsize, precision, and quantization.
func NewModelfileByWorkspace(workspace string, config *configmodelfile.GenerateConfig) (Modelfile, error) { mf := &modelfile{ - workspace: workspace, - config: hashset.New(), - model: hashset.New(), - code: hashset.New(), - dataset: hashset.New(), - doc: hashset.New(), + workspace: workspace, + config: hashset.New(), + model: hashset.New(), + code: hashset.New(), + dataset: hashset.New(), + doc: hashset.New(), + modelFlags: make(map[string]map[string]string), + codeFlags: make(map[string]map[string]string), } - if err := mf.generateByWorkspace(config.IgnoreUnrecognizedFileTypes); err != nil { + if err := mf.validateWorkspace(); err != nil { + return nil, err + } + + if err := mf.generateByWorkspace(); err != nil { return nil, err } @@ -217,8 +238,42 @@ func NewModelfileByWorkspace(workspace string, config *configmodelfile.GenerateC return mf, nil } +// validateWorkspace validates the workspace directory +func (mf *modelfile) validateWorkspace() error { + // check if the workspace is a directory, symbolic link, or empty + info, err := os.Lstat(mf.workspace) + if err != nil { + return fmt.Errorf("access to workspace failed: %s", err) + } + + // check if the workspace is a symbolic link + if info.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("for simplicity, the workspace should not be a symbolic link: %s", mf.workspace) + } + + // check if the workspace is a directory + if !info.IsDir() { + return fmt.Errorf("the workspace is not a directory: %s", mf.workspace) + } + + // check if the workspace is empty by reading directory contents + entries, err := os.ReadDir(mf.workspace) + if err != nil { + return fmt.Errorf("failed to read workspace directory: %s", err) + } + if len(entries) == 0 { + return fmt.Errorf("the workspace is empty: %s", mf.workspace) + } + + return nil +} + // generateByWorkspace generates the modelfile by the workspace's files. 
-func (mf *modelfile) generateByWorkspace(ignoreUnrecognizedFileTypes bool) error { +func (mf *modelfile) generateByWorkspace() error { + // Initialize counters for workspace limits validation + var fileCount int + var totalSize int64 + // Walk the path and get the files. if err := filepath.Walk(mf.workspace, func(path string, info os.FileInfo, err error) error { if err != nil { @@ -240,6 +295,26 @@ func (mf *modelfile) generateByWorkspace(ignoreUnrecognizedFileTypes bool) error return nil } + // Check workspace limits for regular files + fileCount++ + fileSize := info.Size() + totalSize += fileSize + + // Check single file size limit + if fileSize > MaxSingleFileSize { + return fmt.Errorf("file %s exceeds maximum single file size limit of %d bytes (%s)", path, MaxSingleFileSize, formatBytes(MaxSingleFileSize)) + } + + // Check file count limit + if fileCount > MaxWorkspaceFileCount { + return fmt.Errorf("workspace exceeds maximum file count limit of %d files", MaxWorkspaceFileCount) + } + + // Check total workspace size limit + if totalSize > MaxTotalWorkspaceSize { + return fmt.Errorf("workspace exceeds maximum total size limit of %d bytes (%s)", MaxTotalWorkspaceSize, formatBytes(MaxTotalWorkspaceSize)) + } + // Get relative path from the base directory. relPath, err := filepath.Rel(mf.workspace, path) if err != nil { @@ -256,12 +331,24 @@ func (mf *modelfile) generateByWorkspace(ignoreUnrecognizedFileTypes bool) error case IsFileType(filename, DocFilePatterns): mf.doc.Add(relPath) default: - // Skip unrecognized files if IgnoreUnrecognizedFileTypes is true. - if ignoreUnrecognizedFileTypes { - return nil + // If the file is large, usually it is a weight file. 
+ if SizeShouldBeWeightFile(info.Size()) { + mf.model.Add(relPath) + // Add untested flag for files detected by file size + if mf.modelFlags[relPath] == nil { + mf.modelFlags[relPath] = make(map[string]string) + } + mf.modelFlags[relPath][modelspec.AnnotationMediaTypeUntested] = "true" + } else { + mf.code.Add(relPath) + // Add untested flag for files detected by file size + if mf.codeFlags[relPath] == nil { + mf.codeFlags[relPath] = make(map[string]string) + } + mf.codeFlags[relPath][modelspec.AnnotationMediaTypeUntested] = "true" } - return fmt.Errorf("unknown file type: %s", filename) + return nil } return nil @@ -269,8 +356,8 @@ func (mf *modelfile) generateByWorkspace(ignoreUnrecognizedFileTypes bool) error return err } - if mf.model.Size() == 0 { - return fmt.Errorf("no recognized model files found in directory - you may need to edit the Modelfile manually") + if mf.model.Size() == 0 && mf.code.Size() == 0 && mf.dataset.Size() == 0 { + return fmt.Errorf("no model/code/dataset found - you have to create the Modelfile by yourself") } return nil @@ -470,6 +557,16 @@ func (mf *modelfile) GetQuantization() string { return mf.quantization } +// GetModelFlags returns the flags associated with a specific model file path +func (mf *modelfile) GetModelFlags() map[string]map[string]string { + return mf.modelFlags +} + +// GetCodeFlags returns the flags associated with a specific code file path +func (mf *modelfile) GetCodeFlags() map[string]map[string]string { + return mf.codeFlags +} + // Content returns the content of the modelfile. func (mf *modelfile) Content() []byte { content := "" @@ -486,8 +583,8 @@ func (mf *modelfile) Content() []byte { // Add multi-value commands. 
content += mf.writeMultiField("Config files (Generated from the files in the workspace directory)", modefilecommand.CONFIG, mf.GetConfigs(), ConfigFilePatterns) - content += mf.writeMultiField("Code files (Generated from the files in the workspace directory)", modefilecommand.CODE, mf.GetCodes(), CodeFilePatterns) - content += mf.writeMultiField("Model files (Generated from the files in the workspace directory)", modefilecommand.MODEL, mf.GetModels(), ModelFilePatterns) + content += mf.writeMultiFieldWithFlags("Code files (Generated from the files in the workspace directory)", modefilecommand.CODE, mf.GetCodes(), CodeFilePatterns, mf.codeFlags) + content += mf.writeMultiFieldWithFlags("Model files (Generated from the files in the workspace directory)", modefilecommand.MODEL, mf.GetModels(), ModelFilePatterns, mf.modelFlags) content += mf.writeMultiField("Documentation files (Generated from the files in the workspace directory)", modefilecommand.DOC, mf.GetDocs(), DocFilePatterns) return []byte(content) } @@ -515,3 +612,30 @@ func (mf *modelfile) writeMultiField(comment, cmd string, values []string, patte return content } + +func (mf *modelfile) writeMultiFieldWithFlags(comment, cmd string, values []string, patterns []string, flags map[string]map[string]string) string { + if len(values) == 0 { + return "" + } + + content := fmt.Sprintf("\n# %s\n", comment) + content += fmt.Sprintf("# Supported file types: %s\n", strings.Join(patterns, ", ")) + + sort.Strings(values) + for _, value := range values { + if fileFlags, hasFlags := flags[value]; hasFlags && len(fileFlags) > 0 { + // Build flags string + var flagParts []string + for key, val := range fileFlags { + flagParts = append(flagParts, fmt.Sprintf("%s=%s", key, val)) + } + sort.Strings(flagParts) // Sort for consistent output + flagsStr := strings.Join(flagParts, ",") + content += fmt.Sprintf("%s --label=%s %s\n", cmd, flagsStr, value) + } else { + content += fmt.Sprintf("%s %s\n", cmd, value) + } + } + + return 
content +} diff --git a/pkg/modelfile/modelfile_test.go b/pkg/modelfile/modelfile_test.go index a4c9a971..538268de 100644 --- a/pkg/modelfile/modelfile_test.go +++ b/pkg/modelfile/modelfile_test.go @@ -19,12 +19,14 @@ package modelfile import ( "encoding/json" "errors" + "fmt" "os" "path/filepath" "sort" "testing" configmodelfile "github.com/CloudNativeAI/modctl/pkg/config/modelfile" + modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1" "github.com/emirpasic/gods/sets/hashset" "github.com/stretchr/testify/assert" ) @@ -269,25 +271,24 @@ NAME bar func TestNewModelfileByWorkspace(t *testing.T) { testcases := []struct { - name string - setupFiles map[string]string - setupDirs []string - configJson map[string]interface{} - genConfigJson map[string]interface{} - config *configmodelfile.GenerateConfig - ignoreUnrecognizedFileType bool - expectError bool - expectConfigs []string - expectModels []string - expectCodes []string - expectDocs []string - expectName string - expectArch string - expectFamily string - expectFormat string - expectParamsize string - expectPrecision string - expectQuantization string + name string + setupFiles map[string]string + setupDirs []string + configJson map[string]interface{} + genConfigJson map[string]interface{} + config *configmodelfile.GenerateConfig + expectError bool + expectConfigs []string + expectModels []string + expectCodes []string + expectDocs []string + expectName string + expectArch string + expectFamily string + expectFormat string + expectParamsize string + expectPrecision string + expectQuantization string }{ { name: "basic case", @@ -302,13 +303,12 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "test-model", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json"}, - expectModels: []string{"model.bin"}, - expectCodes: []string{"model.py", "tokenizer.py"}, - expectDocs: []string{"README.md", "LICENSE"}, - expectName: 
"test-model", + expectError: false, + expectConfigs: []string{"config.json"}, + expectModels: []string{"model.bin"}, + expectCodes: []string{"model.py", "tokenizer.py"}, + expectDocs: []string{"README.md", "LICENSE"}, + expectName: "test-model", }, { name: "empty workspace", @@ -316,12 +316,11 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "empty-model", }, - ignoreUnrecognizedFileType: false, - expectError: true, - expectConfigs: []string{}, - expectModels: []string{}, - expectCodes: []string{}, - expectName: "empty-model", + expectError: true, + expectConfigs: []string{}, + expectModels: []string{}, + expectCodes: []string{}, + expectName: "empty-model", }, { name: "with config.json values", @@ -337,15 +336,14 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "config-model", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json"}, - expectModels: []string{"model.bin"}, - expectCodes: []string{}, - expectName: "config-model", - expectArch: "transformer", - expectFamily: "llama", - expectPrecision: "float16", + expectError: false, + expectConfigs: []string{"config.json"}, + expectModels: []string{"model.bin"}, + expectCodes: []string{}, + expectName: "config-model", + expectArch: "transformer", + expectFamily: "llama", + expectPrecision: "float16", }, { name: "nested directory structure", @@ -370,8 +368,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "nested-model", }, - ignoreUnrecognizedFileType: false, - expectError: false, + expectError: false, expectConfigs: []string{ "config.json", "docs/config/parameters.yaml", @@ -406,12 +403,11 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "deep-nested", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"level1/config.json"}, - 
expectModels: []string{"level1/level2/level3/model.bin"}, - expectCodes: []string{"level1/level2/level3/level4/code.py"}, - expectName: "deep-nested", + expectError: false, + expectConfigs: []string{"level1/config.json"}, + expectModels: []string{"level1/level2/level3/model.bin"}, + expectCodes: []string{"level1/level2/level3/level4/code.py"}, + expectName: "deep-nested", }, { name: "hidden files and directories", @@ -429,12 +425,11 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "hidden-test", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json"}, - expectModels: []string{"model.bin"}, - expectCodes: []string{}, - expectName: "hidden-test", + expectError: false, + expectConfigs: []string{"config.json"}, + expectModels: []string{"model.bin"}, + expectCodes: []string{}, + expectName: "hidden-test", }, { name: "multiple config files in directories", @@ -453,16 +448,15 @@ func TestNewModelfileByWorkspace(t *testing.T) { Name: "multi-config", Format: "pytorch", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json", "models/config.json", "models/gen_config.json"}, - expectModels: []string{"model.bin"}, - expectCodes: []string{}, - expectName: "multi-config", - expectArch: "transformer", - expectFamily: "gpt2", - expectFormat: "pytorch", - expectPrecision: "float32", + expectError: false, + expectConfigs: []string{"config.json", "models/config.json", "models/gen_config.json"}, + expectModels: []string{"model.bin"}, + expectCodes: []string{}, + expectName: "multi-config", + expectArch: "transformer", + expectFamily: "gpt2", + expectFormat: "pytorch", + expectPrecision: "float32", }, { name: "special filename characters", @@ -482,8 +476,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "special-chars", }, - ignoreUnrecognizedFileType: false, - expectError: false, + expectError: 
false, expectConfigs: []string{ "config with spaces.json", "dir-with-hyphens/config.json", @@ -520,8 +513,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "mixed-types", }, - ignoreUnrecognizedFileType: false, - expectError: false, + expectError: false, expectConfigs: []string{ "configs/main.json", "configs/params.yaml", @@ -554,8 +546,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "same-names", }, - ignoreUnrecognizedFileType: false, - expectError: false, + expectError: false, expectConfigs: []string{ "dir1/config.json", "dir2/config.json", @@ -604,8 +595,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { Name: "llama-7b", ParamSize: "7B", }, - ignoreUnrecognizedFileType: false, - expectError: false, + expectError: false, expectConfigs: []string{ "config.json", "generation_config.json", @@ -650,14 +640,13 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "conflict-test", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json", "generation_config.json"}, - expectModels: []string{"model.bin"}, - expectCodes: []string{}, - expectName: "conflict-test", - expectFamily: "llama", - expectPrecision: "float32", + expectError: false, + expectConfigs: []string{"config.json", "generation_config.json"}, + expectModels: []string{"model.bin"}, + expectCodes: []string{}, + expectName: "conflict-test", + expectFamily: "llama", + expectPrecision: "float32", }, { name: "skipping internal directories", @@ -679,12 +668,11 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "skip-test", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json"}, - expectModels: []string{"normal/model.bin"}, - expectCodes: []string{"valid_dir/model.py"}, - expectName: "skip-test", + expectError: false, + 
expectConfigs: []string{"config.json"}, + expectModels: []string{"normal/model.bin"}, + expectCodes: []string{"valid_dir/model.py"}, + expectName: "skip-test", }, } @@ -734,7 +722,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { // Set workspace in config tc.config.Workspace = tempDir - tc.config.IgnoreUnrecognizedFileTypes = tc.ignoreUnrecognizedFileType + tc.config.IgnoreUnrecognizedFileTypes = false // Call the function being tested mf, err := NewModelfileByWorkspace(tempDir, tc.config) @@ -1202,6 +1190,379 @@ func createHashSet(items []string) *hashset.Set { for _, item := range items { set.Add(item) } - return set } + +func TestValidateWorkspace(t *testing.T) { + tests := []struct { + name string + setupFunc func() (string, func()) // returns workspace path and cleanup function + expectedError string + }{ + { + name: "valid_directory", + setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "validate_workspace_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + // Create a test file to make directory non-empty + testFile := filepath.Join(tmpDir, "test.txt") + err = os.WriteFile(testFile, []byte("test content"), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create test file: %v", err) + } + + return tmpDir, func() { os.RemoveAll(tmpDir) } + }, + expectedError: "", + }, + { + name: "non_existent_directory", + setupFunc: func() (string, func()) { + return "/non/existent/path", func() {} + }, + expectedError: "access to workspace failed:", + }, + { + name: "file_instead_of_directory", + setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "validate_workspace_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + testFile := filepath.Join(tmpDir, "test.txt") + err = os.WriteFile(testFile, []byte("test content"), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create test file: %v", err) + } + + return testFile, func() { 
os.RemoveAll(tmpDir) } + }, + expectedError: "the workspace is not a directory:", + }, + { + name: "empty_directory", + setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "validate_workspace_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + return tmpDir, func() { os.RemoveAll(tmpDir) } + }, + expectedError: "the workspace is empty:", + }, + { + name: "symbolic_link", + setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "validate_workspace_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + // Create target directory with content + targetDir := filepath.Join(tmpDir, "target") + err = os.Mkdir(targetDir, 0755) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create target dir: %v", err) + } + + testFile := filepath.Join(targetDir, "test.txt") + err = os.WriteFile(testFile, []byte("test content"), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create test file: %v", err) + } + + // Create symbolic link + linkPath := filepath.Join(tmpDir, "link") + err = os.Symlink(targetDir, linkPath) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create symlink: %v", err) + } + + return linkPath, func() { os.RemoveAll(tmpDir) } + }, + expectedError: "for simplicity, the workspace should not be a symbolic link:", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + workspace, cleanup := tt.setupFunc() + defer cleanup() + + mf := &modelfile{ + workspace: workspace, + config: hashset.New(), + model: hashset.New(), + code: hashset.New(), + dataset: hashset.New(), + doc: hashset.New(), + } + + err := mf.validateWorkspace() + + if tt.expectedError == "" { + assert.NoError(t, err, "Expected no error for test case: %s", tt.name) + } else { + assert.Error(t, err, "Expected error for test case: %s", tt.name) + assert.Contains(t, err.Error(), tt.expectedError, "Error message should contain expected text for test case: 
%s", tt.name) + } + }) + } +} + +func TestWorkspaceLimits(t *testing.T) { + tests := []struct { + name string + setupFunc func() (string, func()) // returns workspace path and cleanup function + expectedError string + }{ + { + name: "single_file_exceeds_128GB_limit", + setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "workspace_limits_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + // Create a test file that simulates exceeding 128GB + // We'll use a sparse file to avoid actually creating 128GB+ of data + testFile := filepath.Join(tmpDir, "large_model.bin") + file, err := os.Create(testFile) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create test file: %v", err) + } + + // Seek to position that would make file appear larger than 128GB + largeSize := MaxSingleFileSize + 1 + _, err = file.Seek(largeSize-1, 0) + if err != nil { + file.Close() + os.RemoveAll(tmpDir) + t.Fatalf("Failed to seek in file: %v", err) + } + + // Write one byte at the end to make the file that size + _, err = file.Write([]byte{0}) + if err != nil { + file.Close() + os.RemoveAll(tmpDir) + t.Fatalf("Failed to write to file: %v", err) + } + file.Close() + + return tmpDir, func() { os.RemoveAll(tmpDir) } + }, + expectedError: "exceeds maximum single file size limit", + }, + { + name: "file_count_exceeds_2048_limit", + setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "workspace_limits_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + // Create more than 2048 files + for i := 0; i <= MaxWorkspaceFileCount; i++ { + testFile := filepath.Join(tmpDir, fmt.Sprintf("file_%d.txt", i)) + err = os.WriteFile(testFile, []byte("test"), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create test file %d: %v", i, err) + } + } + + return tmpDir, func() { os.RemoveAll(tmpDir) } + }, + expectedError: "exceeds maximum file count limit", + }, + { + name: 
"total_workspace_size_exceeds_8TB_limit", + setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "workspace_limits_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + // Create a few large files that together exceed 8TB + // Each file will be just under 128GB (single file limit) + // We'll create 70 files of ~120GB each to exceed 8TB total + fileSize := MaxSingleFileSize - (1024 * 1024 * 1024) // 127GB per file + numFiles := 70 // 70 * 127GB = ~8.9TB + + for i := 0; i < numFiles; i++ { + testFile := filepath.Join(tmpDir, fmt.Sprintf("file_%d.bin", i)) + file, err := os.Create(testFile) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create test file %d: %v", i, err) + } + + // Use sparse file technique + _, err = file.Seek(fileSize-1, 0) + if err != nil { + file.Close() + os.RemoveAll(tmpDir) + t.Fatalf("Failed to seek in file %d: %v", i, err) + } + + _, err = file.Write([]byte{0}) + if err != nil { + file.Close() + os.RemoveAll(tmpDir) + t.Fatalf("Failed to write to file %d: %v", i, err) + } + file.Close() + } + + return tmpDir, func() { os.RemoveAll(tmpDir) } + }, + expectedError: "exceeds maximum total size limit", + }, + { + name: "workspace_within_all_limits", + setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "workspace_limits_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + // Create a reasonable number of small files + for i := 0; i < 10; i++ { + testFile := filepath.Join(tmpDir, fmt.Sprintf("small_file_%d.txt", i)) + err = os.WriteFile(testFile, []byte("small content"), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create test file %d: %v", i, err) + } + } + + // Add a config file to make it a valid workspace + configFile := filepath.Join(tmpDir, "config.json") + err = os.WriteFile(configFile, []byte(`{"model_type": "test"}`), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create config file: 
%v", err) + } + + // Add a model file to make it a valid workspace + modelFile := filepath.Join(tmpDir, "model.safetensors") + err = os.WriteFile(modelFile, []byte("fake model data"), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create model file: %v", err) + } + + return tmpDir, func() { os.RemoveAll(tmpDir) } + }, + expectedError: "", + }, + { + name: "exactly_at_file_count_limit", + setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "workspace_limits_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + // Create exactly 2048 files (should be allowed) + // Include one model file to make it valid + modelFile := filepath.Join(tmpDir, "model.safetensors") + err = os.WriteFile(modelFile, []byte("fake model data"), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create model file: %v", err) + } + + // Create the remaining files to reach exactly 2048 + for i := 1; i < MaxWorkspaceFileCount; i++ { + testFile := filepath.Join(tmpDir, fmt.Sprintf("file_%d.txt", i)) + err = os.WriteFile(testFile, []byte("test"), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create test file %d: %v", i, err) + } + } + + return tmpDir, func() { os.RemoveAll(tmpDir) } + }, + expectedError: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + workspace, cleanup := tt.setupFunc() + defer cleanup() + + // Create a modelfile instance and try to generate by workspace + config := &configmodelfile.GenerateConfig{} + _, err := NewModelfileByWorkspace(workspace, config) + + if tt.expectedError == "" { + assert.NoError(t, err, "Expected no error for test case: %s", tt.name) + } else { + assert.Error(t, err, "Expected error for test case: %s", tt.name) + assert.Contains(t, err.Error(), tt.expectedError, "Error message should contain expected text for test case: %s", tt.name) + } + }) + } +} + +func TestDefaultBranchUntestedFlag(t *testing.T) { + // Create a 
temporary directory + tempDir := t.TempDir() + + // Create files that will fall into the default branch + // Use file extensions that don't match any known patterns + unknownLargeFile := filepath.Join(tempDir, "unknown_large.unknown") + unknownSmallFile := filepath.Join(tempDir, "unknown_small.unknown") + + // Create a large file (>128MB) that should go to model + largeContent := make([]byte, 129*1024*1024) // 129MB + err := os.WriteFile(unknownLargeFile, largeContent, 0644) + assert.NoError(t, err) + + // Create a small file that should go to code + smallContent := []byte("some unknown file content") + err = os.WriteFile(unknownSmallFile, smallContent, 0644) + assert.NoError(t, err) + + // Generate modelfile from workspace + config := &configmodelfile.GenerateConfig{ + Name: "test-untested-flags", + } + + mf, err := NewModelfileByWorkspace(tempDir, config) + assert.NoError(t, err) + + // Cast to concrete type to access flags + modelfile := mf.(*modelfile) + + // Check that the large file was added to models with the untested flag + assert.Contains(t, mf.GetModels(), "unknown_large.unknown") + modelFlags := modelfile.GetModelFlags() + assert.Contains(t, modelFlags, "unknown_large.unknown") + assert.Equal(t, "true", modelFlags["unknown_large.unknown"][modelspec.AnnotationMediaTypeUntested]) + + // Check that the small file was added to codes with the untested flag + assert.Contains(t, mf.GetCodes(), "unknown_small.unknown") + codeFlags := modelfile.GetCodeFlags() + assert.Contains(t, codeFlags, "unknown_small.unknown") + assert.Equal(t, "true", codeFlags["unknown_small.unknown"][modelspec.AnnotationMediaTypeUntested]) + + // Check that the generated content includes the flags + content := string(mf.Content()) + assert.Contains(t, content, fmt.Sprintf("MODEL --label=%s=true unknown_large.unknown", modelspec.AnnotationMediaTypeUntested)) + assert.Contains(t, content, fmt.Sprintf("CODE --label=%s=true unknown_small.unknown", modelspec.AnnotationMediaTypeUntested)) +} 
diff --git a/pkg/modelfile/parser/args_parser.go b/pkg/modelfile/parser/args_parser.go index 5f871f22..cd4cacf8 100644 --- a/pkg/modelfile/parser/args_parser.go +++ b/pkg/modelfile/parser/args_parser.go @@ -23,8 +23,12 @@ import ( // parseStringArgs parses the string type of args and returns a Node, for example: // "MODEL foo" args' value is "foo". func parseStringArgs(args []string, start, end int) (Node, error) { - if len(args) != 1 { - return nil, errors.New("invalid args") + if len(args) == 0 { + return nil, errors.New("no args provided") + } + + if len(args) > 1 { + return nil, errors.New("too many args provided") } if args[0] == "" { diff --git a/pkg/modelfile/parser/parser.go b/pkg/modelfile/parser/parser.go index 84ed5b1c..7ff21e43 100644 --- a/pkg/modelfile/parser/parser.go +++ b/pkg/modelfile/parser/parser.go @@ -18,9 +18,11 @@ package parser import ( "bufio" + "errors" "fmt" "io" "strings" + "unicode" "github.com/CloudNativeAI/modctl/pkg/modelfile/command" ) @@ -96,13 +98,26 @@ func isEmptyContinuationLine(line string) bool { // parseCommandLine parses the command line and returns the command node with the args node. // Need to walk the next node of the command node to get the args node. 
func parseCommandLine(line string, start, end int) (Node, error) { - cmd, args, err := splitCommand(line) + cmd, args, flags, err := splitCommand(line) if err != nil { return nil, err } switch cmd { - case command.CONFIG, command.MODEL, command.CODE, command.DATASET, command.DOC, command.NAME, command.ARCH, command.FAMILY, command.FORMAT, command.PARAMSIZE, command.PRECISION, command.QUANTIZATION: + case command.CONFIG, command.DOC, command.DATASET, command.NAME, command.ARCH, command.FAMILY, command.FORMAT, command.PARAMSIZE, command.PRECISION, command.QUANTIZATION: + if len(args) != 1 { + return nil, errors.New("command " + cmd + " requires exactly one argument") + } + + argsNode, err := parseStringArgs(args, start, end) + if err != nil { + return nil, err + } + cmdNode := NewNode(cmd, start, end) + cmdNode.AddNext(argsNode) + + return cmdNode, nil + case command.MODEL, command.CODE: argsNode, err := parseStringArgs(args, start, end) if err != nil { return nil, err @@ -110,20 +125,134 @@ func parseCommandLine(line string, start, end int) (Node, error) { cmdNode := NewNode(cmd, start, end) cmdNode.AddNext(argsNode) + + // Add flags as attributes if any exist + if len(flags) > 0 { + for _, flag := range flags { + // Parse the flag to get key and value + key, value := parseFlagKeyValue(flag) + if key != "" { + cmdNode.AddAttribute(key, value) + } + } + } return cmdNode, nil default: return nil, fmt.Errorf("invalid command: %s", cmd) } } -// splitCommand splits the command line into the command and the args. Returns the -// command and the args, and an error if the command line is invalid. -// Example: "MODEL foo" returns "MODEL", ["foo"] and nil. 
// parseFlagKeyValue splits a flag value of the form "key=value" at the first
// "=" and returns the key and the value.
// Example: "key=value" returns "key", "value".
// Example: "org.cnai.model.file.mediatype.untested=true" returns
// "org.cnai.model.file.mediatype.untested", "true".
// When no "=" is present the whole input is returned as the key with an empty value.
func parseFlagKeyValue(flag string) (string, string) {
	key, value, _ := strings.Cut(flag, "=")
	return key, value
}

// splitCommand splits the command line into the command, args, and flags. Returns the
// command (upper-cased), the args, the flags, and an error if the command line is invalid.
// A line must contain at least two whitespace-separated fields.
// Example: "MODEL --label=key=value /home/user/model.safetensors" returns
// "MODEL", ["/home/user/model.safetensors"], ["key=value"], nil.
func splitCommand(line string) (string, []string, []string, error) {
	parts := strings.Fields(line)
	if len(parts) < 2 {
		return "", nil, nil, fmt.Errorf("invalid command line: %s", line)
	}

	cmd := strings.ToUpper(parts[0])

	// strings.Fields skips leading whitespace, so it has to be dropped here as
	// well before slicing past the keyword; otherwise an indented line would be
	// cut in the middle of the keyword.
	trimmed := strings.TrimLeftFunc(line, unicode.IsSpace)
	restOfLine := strings.TrimSpace(trimmed[len(parts[0]):])

	remaining, flags, err := extractCommandFlags(restOfLine)
	if err != nil {
		return "", nil, nil, err
	}

	// Parse remaining content as args; args stays nil when nothing is left.
	var args []string
	if remaining != "" {
		args = strings.Fields(remaining)
	}

	return cmd, args, flags, nil
}

// extractCommandFlags parses the leading command flags of line and returns the
// remaining part of the line and the flag values (values only, without the flag name).
// Only the --label flag is accepted ("--label" or "--label=<value>"); the first
// token that is any other "--" flag stops flag parsing and is returned as part of
// the remaining content. A bare "--" terminates flag parsing and is dropped.
// A --label flag with an empty value contributes nothing to the result.
// The returned error is currently always nil.
func extractCommandFlags(line string) (string, []string, error) {
	// Deliberately a non-nil empty slice: callers and tests compare against []string{}.
	flags := []string{}
	rest := line

	for {
		// Work on whole runes rather than bytes so that multi-byte UTF-8
		// characters are never mistaken for whitespace.
		rest = strings.TrimLeftFunc(rest, unicode.IsSpace)
		if !isFlag(rest, 0) {
			break
		}

		// The flag token runs up to the next whitespace rune (or end of input).
		end := strings.IndexFunc(rest, unicode.IsSpace)
		if end == -1 {
			end = len(rest)
		}
		token := rest[:end]

		if token == "--" {
			// Double dash terminator: everything after it is remaining content.
			return rest[2:], flags, nil
		}

		// Only the exact --label flag is recognised; anything else (including
		// look-alikes such as --labels) is handed back as an argument.
		if token != "--label" && !strings.HasPrefix(token, "--label=") {
			return rest, flags, nil
		}

		if value := extractFlagValue(token); value != "" {
			flags = append(flags, value)
		}
		rest = rest[end:]
	}

	// Return remaining content after flags.
	return rest, flags, nil
}

// isFlag checks if the content at position i starts with "--".
func isFlag(line string, i int) bool {
	return i+1 < len(line) && line[i] == '-' && line[i+1] == '-'
}

// extractFlagValue extracts the value from a --label flag string, i.e. everything
// after the first "=".
// Example: "--label=key=value" returns "key=value".
// Example: "--label=" returns "".
// Example: "--label" (no "=") returns "".
func extractFlagValue(flag string) string {
	_, value, _ := strings.Cut(strings.TrimPrefix(flag, "--"), "=")
	return value
}
--- a/pkg/modelfile/parser/parser_test.go +++ b/pkg/modelfile/parser/parser_test.go @@ -17,9 +17,11 @@ package parser import ( + "fmt" "strings" "testing" + modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1" "github.com/stretchr/testify/assert" ) @@ -193,24 +195,69 @@ func TestSplitCommand(t *testing.T) { expectErr bool cmd string args []string + flags []string }{ - {"MODEL foo", false, "MODEL", []string{"foo"}}, - {"NAME bar", false, "NAME", []string{"bar"}}, - {"invalid", true, "", nil}, + {"MODEL foo", false, "MODEL", []string{"foo"}, []string{}}, + {"NAME bar", false, "NAME", []string{"bar"}, []string{}}, + {"MODEL --label=key=value /home/user/model.safetensors", false, "MODEL", []string{"/home/user/model.safetensors"}, []string{"key=value"}}, + {"MODEL --untested --experimental=test model.safetensors", false, "MODEL", []string{"--untested", "--experimental=test", "model.safetensors"}, []string{}}, + {"CONFIG --format=json config.yaml", false, "CONFIG", []string{"--format=json", "config.yaml"}, []string{}}, + {"MODEL --label=flag1=value1 --label=flag2=value2 model.bin", false, "MODEL", []string{"model.bin"}, []string{"flag1=value1", "flag2=value2"}}, + {"MODEL --untested model.safetensors", false, "MODEL", []string{"--untested", "model.safetensors"}, []string{}}, + {"invalid", true, "", nil, nil}, } assert := assert.New(t) for _, tc := range testCases { - cmd, args, err := splitCommand(tc.line) + cmd, args, flags, err := splitCommand(tc.line) if tc.expectErr { assert.Error(err) assert.Empty(cmd) assert.Nil(args) + assert.Nil(flags) continue } assert.NoError(err) assert.Equal(tc.cmd, cmd) assert.Equal(tc.args, args) + assert.Equal(tc.flags, flags) + } +} + +func TestExtractFlagValue(t *testing.T) { + testCases := []struct { + name string + flag string + expectedValue string + }{ + { + name: "label flag with key=value", + flag: "--label=key=value", + expectedValue: "key=value", + }, + { + name: "label flag with empty value", + flag: "--label=", + 
expectedValue: "", + }, + { + name: "complex label flag value", + flag: fmt.Sprintf("--label=%s=true", modelspec.AnnotationMediaTypeUntested), + expectedValue: fmt.Sprintf("%s=true", modelspec.AnnotationMediaTypeUntested), + }, + { + name: "label flag without value", + flag: "--label", + expectedValue: "", + }, + } + + assert := assert.New(t) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + value := extractFlagValue(tc.flag) + assert.Equal(tc.expectedValue, value, "Value mismatch for test case: %s", tc.name) + }) } } diff --git a/test/mocks/backend/build/builder.go b/test/mocks/backend/build/builder.go index b35539ab..e12578ba 100644 --- a/test/mocks/backend/build/builder.go +++ b/test/mocks/backend/build/builder.go @@ -102,9 +102,9 @@ func (_c *Builder_BuildConfig_Call) RunAndReturn(run func(context.Context, []v1. return _c } -// BuildLayer provides a mock function with given fields: ctx, mediaType, workDir, path, _a4 -func (_m *Builder) BuildLayer(ctx context.Context, mediaType string, workDir string, path string, _a4 hooks.Hooks) (v1.Descriptor, error) { - ret := _m.Called(ctx, mediaType, workDir, path, _a4) +// BuildLayer provides a mock function with given fields: ctx, mediaType, workDir, path, annotations, _a5 +func (_m *Builder) BuildLayer(ctx context.Context, mediaType string, workDir string, path string, annotations map[string]string, _a5 hooks.Hooks) (v1.Descriptor, error) { + ret := _m.Called(ctx, mediaType, workDir, path, annotations, _a5) if len(ret) == 0 { panic("no return value specified for BuildLayer") @@ -112,17 +112,17 @@ func (_m *Builder) BuildLayer(ctx context.Context, mediaType string, workDir str var r0 v1.Descriptor var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, string, hooks.Hooks) (v1.Descriptor, error)); ok { - return rf(ctx, mediaType, workDir, path, _a4) + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, map[string]string, hooks.Hooks) (v1.Descriptor, error)); 
ok { + return rf(ctx, mediaType, workDir, path, annotations, _a5) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, string, hooks.Hooks) v1.Descriptor); ok { - r0 = rf(ctx, mediaType, workDir, path, _a4) + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, map[string]string, hooks.Hooks) v1.Descriptor); ok { + r0 = rf(ctx, mediaType, workDir, path, annotations, _a5) } else { r0 = ret.Get(0).(v1.Descriptor) } - if rf, ok := ret.Get(1).(func(context.Context, string, string, string, hooks.Hooks) error); ok { - r1 = rf(ctx, mediaType, workDir, path, _a4) + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, map[string]string, hooks.Hooks) error); ok { + r1 = rf(ctx, mediaType, workDir, path, annotations, _a5) } else { r1 = ret.Error(1) } @@ -140,14 +140,15 @@ type Builder_BuildLayer_Call struct { // - mediaType string // - workDir string // - path string -// - _a4 hooks.Hooks -func (_e *Builder_Expecter) BuildLayer(ctx interface{}, mediaType interface{}, workDir interface{}, path interface{}, _a4 interface{}) *Builder_BuildLayer_Call { - return &Builder_BuildLayer_Call{Call: _e.mock.On("BuildLayer", ctx, mediaType, workDir, path, _a4)} +// - annotations map[string]string +// - _a5 hooks.Hooks +func (_e *Builder_Expecter) BuildLayer(ctx interface{}, mediaType interface{}, workDir interface{}, path interface{}, annotations interface{}, _a5 interface{}) *Builder_BuildLayer_Call { + return &Builder_BuildLayer_Call{Call: _e.mock.On("BuildLayer", ctx, mediaType, workDir, path, annotations, _a5)} } -func (_c *Builder_BuildLayer_Call) Run(run func(ctx context.Context, mediaType string, workDir string, path string, _a4 hooks.Hooks)) *Builder_BuildLayer_Call { +func (_c *Builder_BuildLayer_Call) Run(run func(ctx context.Context, mediaType string, workDir string, path string, annotations map[string]string, _a5 hooks.Hooks)) *Builder_BuildLayer_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), 
args[1].(string), args[2].(string), args[3].(string), args[4].(hooks.Hooks)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(map[string]string), args[5].(hooks.Hooks)) }) return _c } @@ -157,7 +158,7 @@ func (_c *Builder_BuildLayer_Call) Return(_a0 v1.Descriptor, _a1 error) *Builder return _c } -func (_c *Builder_BuildLayer_Call) RunAndReturn(run func(context.Context, string, string, string, hooks.Hooks) (v1.Descriptor, error)) *Builder_BuildLayer_Call { +func (_c *Builder_BuildLayer_Call) RunAndReturn(run func(context.Context, string, string, string, map[string]string, hooks.Hooks) (v1.Descriptor, error)) *Builder_BuildLayer_Call { _c.Call.Return(run) return _c } diff --git a/test/mocks/modelfile/modelfile.go b/test/mocks/modelfile/modelfile.go index cea3df3f..e99e2fb9 100644 --- a/test/mocks/modelfile/modelfile.go +++ b/test/mocks/modelfile/modelfile.go @@ -125,6 +125,53 @@ func (_c *Modelfile_GetArch_Call) RunAndReturn(run func() string) *Modelfile_Get return _c } +// GetCodeFlags provides a mock function with no fields +func (_m *Modelfile) GetCodeFlags() map[string]map[string]string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetCodeFlags") + } + + var r0 map[string]map[string]string + if rf, ok := ret.Get(0).(func() map[string]map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]map[string]string) + } + } + + return r0 +} + +// Modelfile_GetCodeFlags_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCodeFlags' +type Modelfile_GetCodeFlags_Call struct { + *mock.Call +} + +// GetCodeFlags is a helper method to define mock.On call +func (_e *Modelfile_Expecter) GetCodeFlags() *Modelfile_GetCodeFlags_Call { + return &Modelfile_GetCodeFlags_Call{Call: _e.mock.On("GetCodeFlags")} +} + +func (_c *Modelfile_GetCodeFlags_Call) Run(run func()) *Modelfile_GetCodeFlags_Call { + 
_c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Modelfile_GetCodeFlags_Call) Return(_a0 map[string]map[string]string) *Modelfile_GetCodeFlags_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Modelfile_GetCodeFlags_Call) RunAndReturn(run func() map[string]map[string]string) *Modelfile_GetCodeFlags_Call { + _c.Call.Return(run) + return _c +} + // GetCodes provides a mock function with no fields func (_m *Modelfile) GetCodes() []string { ret := _m.Called() @@ -403,6 +450,53 @@ func (_c *Modelfile_GetFormat_Call) RunAndReturn(run func() string) *Modelfile_G return _c } +// GetModelFlags provides a mock function with no fields +func (_m *Modelfile) GetModelFlags() map[string]map[string]string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetModelFlags") + } + + var r0 map[string]map[string]string + if rf, ok := ret.Get(0).(func() map[string]map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]map[string]string) + } + } + + return r0 +} + +// Modelfile_GetModelFlags_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetModelFlags' +type Modelfile_GetModelFlags_Call struct { + *mock.Call +} + +// GetModelFlags is a helper method to define mock.On call +func (_e *Modelfile_Expecter) GetModelFlags() *Modelfile_GetModelFlags_Call { + return &Modelfile_GetModelFlags_Call{Call: _e.mock.On("GetModelFlags")} +} + +func (_c *Modelfile_GetModelFlags_Call) Run(run func()) *Modelfile_GetModelFlags_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Modelfile_GetModelFlags_Call) Return(_a0 map[string]map[string]string) *Modelfile_GetModelFlags_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Modelfile_GetModelFlags_Call) RunAndReturn(run func() map[string]map[string]string) *Modelfile_GetModelFlags_Call { + _c.Call.Return(run) + return _c +} + // GetModels provides a 
mock function with no fields func (_m *Modelfile) GetModels() []string { ret := _m.Called()