diff --git a/cmd/modelfile/generate.go b/cmd/modelfile/generate.go index 985c88c7..4166879d 100644 --- a/cmd/modelfile/generate.go +++ b/cmd/modelfile/generate.go @@ -65,6 +65,10 @@ func init() { flags.BoolVar(&generateConfig.IgnoreUnrecognizedFileTypes, "ignore-unrecognized-file-types", false, "ignore the unrecognized file types in the workspace") flags.BoolVar(&generateConfig.Overwrite, "overwrite", false, "overwrite the existing modelfile") + // Mark the ignore-unrecognized-file-types flag as deprecated and hidden + flags.MarkDeprecated("ignore-unrecognized-file-types", "this flag will be removed in the next release") + flags.MarkHidden("ignore-unrecognized-file-types") + if err := viper.BindPFlags(flags); err != nil { panic(fmt.Errorf("bind cache list flags to viper: %w", err)) } diff --git a/pkg/config/modelfile/modelfile.go b/pkg/config/modelfile/modelfile.go index 2f1251a3..f2025052 100644 --- a/pkg/config/modelfile/modelfile.go +++ b/pkg/config/modelfile/modelfile.go @@ -31,7 +31,7 @@ type GenerateConfig struct { Name string Version string Output string - IgnoreUnrecognizedFileTypes bool + IgnoreUnrecognizedFileTypes bool // [deprecated] will be removed in the next release Overwrite bool Arch string Family string diff --git a/pkg/modelfile/constants.go b/pkg/modelfile/constants.go index 92d07aca..ef221d7c 100644 --- a/pkg/modelfile/constants.go +++ b/pkg/modelfile/constants.go @@ -19,6 +19,8 @@ package modelfile import ( "path/filepath" "strings" + + "github.com/dustin/go-humanize" ) var ( @@ -26,6 +28,8 @@ var ( ConfigFilePatterns = []string{ "*.json", // JSON configuration files "*.jsonl", // JSON Lines format + "*.json5", // JSON5 files + "*.jsonc", // JSON with comments "*.yaml", // YAML configuration files "*.yml", // YAML alternative extension "*.toml", // TOML configuration files @@ -45,6 +49,12 @@ var ( "*.meta", // Model metadata "*tokenizer.model*", // Tokenizer files (e.g., Mistral v3) "config.json.*", // Model configuration variants + "*.hparams", // Hyperparameter files + "*.params", // Parameter files + "*.hyperparams", // Hyperparameter configuration + "*.wandb", // Weights & Biases configuration + "*.mlflow", // MLflow configuration + "*.tensorboard", // TensorBoard configuration } // Model file patterns - supported model file extensions. @@ -56,29 +66,75 @@ var ( "*.bin", // General binary format "*.pt", // PyTorch model "*.pth", // PyTorch model (alternative extension) + "*.mar", // PyTorch Model Archive + "*.pte", // PyTorch ExecuTorch format + "*.pt2", // PyTorch 2.0 export format + "*.ptl", // PyTorch Mobile format // TensorFlow formats. "*.tflite", // TensorFlow Lite "*.h5", // Keras HDF5 format "*.hdf", // Hierarchical Data Format "*.hdf5", // HDF5 (alternative extension) + "*.pb", // TensorFlow SavedModel/Frozen Graph + "*.meta", // TensorFlow checkpoint metadata + "*.data-*", // TensorFlow checkpoint data files + "*.index", // TensorFlow checkpoint index + + // GGML formats. + "*.gguf", // GGML Universal Format + "*.ggml", // GGML format (legacy) + "*.ggmf", // GGMF format (deprecated) + "*.ggjt", // GGJT format (deprecated) + "*.q4_0", // GGML Q4_0 quantization + "*.q4_1", // GGML Q4_1 quantization + "*.q5_0", // GGML Q5_0 quantization + "*.q5_1", // GGML Q5_1 quantization + "*.q8_0", // GGML Q8_0 quantization + "*.f16", // GGML F16 format + "*.f32", // GGML F32 format + + // checkpoint formats. + "*.ckpt", // Checkpoint format + "*.checkpoint", // Checkpoint format (alternative extension) + "*.dist_ckpt", // Distributed checkpoint format + + // Semantics-specific formats + "*.tensor", // Generic tensor format + "*.weights", // Generic weights format + "*.state", // State files + "*.embedding", // Embedding files + "*.vocab", // Vocabulary files (when binary) // Other ML frameworks. "*.ot", // OpenVINO format "*.engine", // TensorRT format "*.trt", // TensorRT format (alternative extension) "*.onnx", // Open Neural Network Exchange format - "*.gguf", // GGML Universal Format "*.msgpack", // MessagePack serialization "*.model", // Some NLP frameworks "*.pkl", // Pickle format "*.pickle", // Pickle format (alternative extension) - "*.ckpt", // Checkpoint format - "*.checkpoint", // Checkpoint format (alternative extension) + "*.keras", // Keras native format + "*.joblib", // Joblib serialization (scikit-learn) + "*.npy", // NumPy array format + "*.npz", // NumPy compressed archive + "*.nc", // NetCDF format + "*.mlmodel", // Apple Core ML format + "*.coreml", // Apple Core ML format (alternative) + "*.mleap", // MLeap format (Spark ML) + "*.surml", // SurrealML format + "*.llamafile", // Llamafile format + "*.caffemodel", // Caffe model format + "*.prototxt", // Caffe model definition + "*.dlc", // Qualcomm Deep Learning Container + "*.circle", // Samsung Circle format + "*.nb", // Neural Network Binary format } // Code file patterns - supported script and notebook files. CodeFilePatterns = []string{ + // language source files "*.py", // Python source files "*.ipynb", // Jupyter notebooks "*.sh", // Shell scripts @@ -88,11 +144,18 @@ var ( "*.hxx", // C++ header files "*.cpp", // C++ source files "*.cc", // C++ source files + "*.cxx", // C++ source files (alternative) + "*.c++", // C++ source files (alternative) "*.hpp", // C++ header files "*.hh", // C++ header files + "*.h++", // C++ header files (alternative) "*.java", // Java source files "*.js", // JavaScript source files + "*.mjs", // JavaScript ES6 modules + "*.cjs", // CommonJS modules + "*.jsx", // React JSX files "*.ts", // TypeScript source files + "*.tsx", // TypeScript JSX files "*.go", // Go source files "*.rs", // Rust source files "*.swift", // Swift source files @@ -100,30 +163,118 @@ var ( "*.php", // PHP source files "*.scala", // Scala source files "*.kt", // Kotlin source files + "*.kts", // Kotlin script files "*.r", // R source files + "*.R", // R source files (alternative) "*.m", // MATLAB/Objective-C source files + "*.mm", // Objective-C++ source files "*.f", // Fortran source files "*.f90", // Fortran 90 source files + "*.f95", // Fortran 95 source files + "*.f03", // Fortran 2003 source files + "*.f08", // Fortran 2008 source files "*.jl", // Julia source files "*.lua", // Lua source files "*.pl", // Perl source files + "*.pm", // Perl modules "*.cs", // C# source files "*.vb", // Visual Basic source files "*.dart", // Dart source files "*.groovy", // Groovy source files "*.elm", // Elm source files "*.erl", // Erlang source files + "*.hrl", // Erlang header files "*.ex", // Elixir source files + "*.exs", // Elixir script files "*.hs", // Haskell source files + "*.lhs", // Literate Haskell source files "*.clj", // Clojure source files "*.cljs", // ClojureScript source files - "*.cljc", // Clojure Common Lisp source files + "*.cljc", // Clojure Common source files "*.cl", // Common Lisp source files "*.lisp", // Lisp source files + "*.lsp", // Lisp source files (alternative) "*.scm", // Scheme source files + "*.ss", // Scheme source files (alternative) + "*.rkt", // Racket source files + "*.sql", // SQL files + "*.psql", // PostgreSQL files + "*.mysql", // MySQL files + "*.sqlite", // SQLite files + "*.zig", // Zig source files "*.cu", // CUDA source files "*.cuh", // CUDA header files + // Scripting and automation + "*.bash", // Bash scripts + "*.zsh", // Zsh scripts + "*.fish", // Fish shell scripts + "*.csh", // C shell scripts + "*.tcsh", // TC shell scripts + "*.ksh", // Korn shell scripts + "*.ps1", // PowerShell scripts + "*.psm1", // PowerShell modules + "*.psd1", // PowerShell data files + "*.bat", // Windows batch files + "*.cmd", // Windows command files + "*.vbs", // VBScript files + "*.wsf", // Windows Script Files + "*.applescript", // AppleScript files + "*.scpt", // AppleScript compiled files + "*.awk", // AWK scripts + "*.sed", // sed scripts + "*.expect", // Expect scripts + + // Build and project files + "*.env", // Environment variable files + "*.env.*", // Environment files with suffixes + ".env*", // Environment files (hidden) + "Makefile*", // Makefile variants + "*.dockerfile", // Dockerfile configurations + "Dockerfile*", // Dockerfile variants + "*.mk", // Make include files + "*.cmake", // CMake files + "CMakeLists.txt", // CMake configuration + "*.gradle", // Gradle build files + "*.gradle.kts", // Kotlin DSL Gradle files + "build.gradle*", // Gradle build files + "settings.gradle*", // Gradle settings files + "*.sbt", // SBT build files + "*.mill", // Mill build files + "*.bazel", // Bazel build files + "*.bzl", // Bazel extension files + "BUILD*", // Bazel BUILD files + "WORKSPACE*", // Bazel WORKSPACE files + "*.buck", // Buck build files + "BUCK*", // Buck BUILD files + "*.ninja", // Ninja build files + "*.gyp", // GYP build files + "*.gypi", // GYP include files + "*.waf", // Waf build files + "wscript*", // Waf build scripts + "package.json", // Node.js package file + "package-lock.json", // Node.js lock file + "yarn.lock", // Yarn lock file + "pnpm-lock.yaml", // PNPM lock file + "requirements*.txt", // Python requirements + "Pipfile*", // Python Pipenv files + "pyproject.toml", // Python project configuration + "setup.cfg", // Python setup configuration + "tox.ini", // Python tox configuration + "poetry.lock", // Python Poetry lock file + "Cargo.toml", // Rust package configuration + "Cargo.lock", // Rust lock file + "go.mod", // Go module file + "go.sum", // Go checksum file + "composer.json", // PHP Composer file + "composer.lock", // PHP Composer lock file + "Gemfile*", // Ruby Gemfile + "*.gemspec", // Ruby gem specification + "mix.exs", // Elixir Mix file + "mix.lock", // Elixir Mix lock file + "rebar.config", // Erlang Rebar config + "rebar.lock", // Erlang Rebar lock file + // Library files. "*.so", // Shared object files "*.dll", // Dynamic Link Library @@ -144,6 +295,93 @@ var ( "*requirements*", // Dependency specifications "*.log", // Log files + // Office documents + "*.doc", // Microsoft Word 97-2003 Document + "*.docx", // Microsoft Word Document + "*.docm", // Word Macro-Enabled Document + "*.dot", // Word 97-2003 Template + "*.dotx", // Word Template + "*.dotm", // Word Macro-Enabled Template + "*.rtf", // Rich Text Format + "*.odt", // OpenDocument Text + "*.ott", // OpenDocument Text Template + "*.fodt", // Flat OpenDocument Text + "*.pages", // Apple Pages document + "*.wpd", // WordPerfect document + + // Spreadsheet documents + "*.xls", // Microsoft Excel 97-2003 Workbook + "*.xlsx", // Microsoft Excel Workbook + "*.xlsm", // Excel Macro-Enabled Workbook + "*.xlsb", // Excel Binary Workbook + "*.xlt", // Excel 97-2003 Template + "*.xltx", // Excel Template + "*.xltm", // Excel Macro-Enabled Template + "*.ods", // OpenDocument Spreadsheet + "*.ots", // OpenDocument Spreadsheet Template + "*.fods", // Flat OpenDocument Spreadsheet + "*.numbers", // Apple Numbers spreadsheet + "*.csv", // Comma-Separated Values + + // Presentation documents + "*.ppt", // Microsoft PowerPoint 97-2003 Presentation + "*.pptx", // Microsoft PowerPoint Presentation + "*.pptm", // PowerPoint Macro-Enabled Presentation + "*.pps", // PowerPoint 97-2003 Show + "*.ppsx", // PowerPoint Show + "*.ppsm", // PowerPoint Macro-Enabled Show + "*.pot", // PowerPoint 97-2003 Template + "*.potx", // PowerPoint Template + "*.potm", // PowerPoint Macro-Enabled Template + "*.odp", // OpenDocument Presentation + "*.otp", // OpenDocument Presentation Template + "*.fodp", // Flat OpenDocument Presentation + "*.key", // Apple Keynote presentation + + // eBook formats + "*.epub", // Electronic Publication + "*.mobi", // Mobipocket eBook + "*.azw", // Amazon Kindle eBook + "*.azw3", // Amazon Kindle eBook (KF8) + "*.fb2", // FictionBook 2.0 + "*.fb3", // FictionBook 3.0 + "*.lit", // Microsoft Literature + "*.pdb", // Palm Database/Document File + "*.djvu", // DjVu document + "*.djv", // DjVu document (alternative extension) + + // Web and markup documents + "*.html", // HyperText Markup Language + "*.htm", // HyperText Markup Language (alternative) + "*.xhtml", // Extensible HyperText Markup Language + "*.mhtml", // MIME HTML (Web Archive) + "*.mht", // MIME HTML (Web Archive, alternative) + "*.xml", // eXtensible Markup Language + "*.xsl", // eXtensible Stylesheet Language + "*.xslt", // XSL Transformations + + // Technical documentation formats + "*.tex", // LaTeX document + "*.latex", // LaTeX document (alternative) + "*.ltx", // LaTeX document (alternative) + "*.bib", // BibTeX bibliography + "*.rst", // reStructuredText + "*.asciidoc", // AsciiDoc + "*.adoc", // AsciiDoc (alternative) + "*.textile", // Textile markup + "*.wiki", // Wiki markup + "*.mediawiki", // MediaWiki markup + "*.org", // Org-mode document + "*.texi", // Texinfo document + "*.texinfo", // Texinfo document (alternative) + "*.info", // GNU Info document + "*.man", // Manual page + + // Archive and compressed documents + "*.chm", // Compiled HTML Help + "*.hlp", // Windows Help File + "*.xps", // XML Paper Specification + // Image assets. "*.jpg", // JPEG image format "*.jpeg", // JPEG alternative extension @@ -182,6 +420,14 @@ var ( } ) +const ( + // File size thresholds and workspace limits + WeightFileSizeThreshold int64 = 128 * humanize.MByte // 128MB - threshold for considering file as weight file + MaxSingleFileSize int64 = 128 * humanize.GByte // 128GB - maximum size for a single file + MaxWorkspaceFileCount int = 2048 // 2048 files - maximum number of files in workspace + MaxTotalWorkspaceSize int64 = 8 * humanize.TByte // 8TB - maximum total workspace size +) + // IsFileType checks if the filename matches any of the given patterns func IsFileType(filename string, patterns []string) bool { // Convert filename to lowercase for case-insensitive comparison @@ -216,3 +462,13 @@ func isSkippable(filename string) bool { return false } + +// For large unknown file type, usually it is a weight file. +func SizeShouldBeWeightFile(size int64) bool { + return size > WeightFileSizeThreshold +} + +// formatBytes converts byte size to human-readable format using go-humanize +func formatBytes(bytes int64) string { + return humanize.Bytes(uint64(bytes)) +} diff --git a/pkg/modelfile/modelfile.go b/pkg/modelfile/modelfile.go index edccb59d..e5b15290 100644 --- a/pkg/modelfile/modelfile.go +++ b/pkg/modelfile/modelfile.go @@ -205,7 +205,11 @@ func NewModelfileByWorkspace(workspace string, config *configmodelfile.GenerateC doc: hashset.New(), } - if err := mf.generateByWorkspace(config.IgnoreUnrecognizedFileTypes); err != nil { + if err := mf.validateWorkspace(); err != nil { + return nil, err + } + + if err := mf.generateByWorkspace(); err != nil { return nil, err } @@ -217,8 +221,42 @@ func NewModelfileByWorkspace(workspace string, config *configmodelfile.GenerateC return mf, nil } +// validateWorkspace validates the workspace directory +func (mf *modelfile) validateWorkspace() error { + // check if the workspace is a directory, symbolic link, or empty + info, err := os.Lstat(mf.workspace) + if err != nil { + return fmt.Errorf("access to workspace failed: %s", err) + } + + // check if the workspace is a symbolic link + if info.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("for simplicity, the workspace should not be a symbolic link: %s", mf.workspace) + } + + // check if the workspace is a directory + if !info.IsDir() { + return fmt.Errorf("the workspace is not a directory: %s", mf.workspace) + } + + // check if the workspace is empty by reading directory contents + entries, err := os.ReadDir(mf.workspace) + if err != nil { + return fmt.Errorf("failed to read workspace directory: %s", err) + } + if len(entries) == 0 { + return fmt.Errorf("the workspace is empty: %s", mf.workspace) + } + + return nil +} + // generateByWorkspace generates the modelfile by the workspace's files. -func (mf *modelfile) generateByWorkspace(ignoreUnrecognizedFileTypes bool) error { +func (mf *modelfile) generateByWorkspace() error { + // Initialize counters for workspace limits validation + var fileCount int + var totalSize int64 + // Walk the path and get the files. if err := filepath.Walk(mf.workspace, func(path string, info os.FileInfo, err error) error { if err != nil { @@ -240,6 +278,26 @@ func (mf *modelfile) generateByWorkspace(ignoreUnrecognizedFileTypes bool) error return nil } + // Check workspace limits for regular files + fileCount++ + fileSize := info.Size() + totalSize += fileSize + + // Check single file size limit + if fileSize > MaxSingleFileSize { + return fmt.Errorf("file %s exceeds maximum single file size limit of %d bytes (%s)", path, MaxSingleFileSize, formatBytes(MaxSingleFileSize)) + } + + // Check file count limit + if fileCount > MaxWorkspaceFileCount { + return fmt.Errorf("workspace exceeds maximum file count limit of %d files", MaxWorkspaceFileCount) + } + + // Check total workspace size limit + if totalSize > MaxTotalWorkspaceSize { + return fmt.Errorf("workspace exceeds maximum total size limit of %d bytes (%s)", MaxTotalWorkspaceSize, formatBytes(MaxTotalWorkspaceSize)) + } + // Get relative path from the base directory. relPath, err := filepath.Rel(mf.workspace, path) if err != nil { @@ -256,12 +314,14 @@ func (mf *modelfile) generateByWorkspace(ignoreUnrecognizedFileTypes bool) error case IsFileType(filename, DocFilePatterns): mf.doc.Add(relPath) default: - // Skip unrecognized files if IgnoreUnrecognizedFileTypes is true. - if ignoreUnrecognizedFileTypes { - return nil + // If the file is large, usually it is a weight file. + if SizeShouldBeWeightFile(info.Size()) { + mf.model.Add(relPath) + } else { + mf.code.Add(relPath) } - return fmt.Errorf("unknown file type: %s", filename) + return nil } return nil @@ -269,8 +329,8 @@ func (mf *modelfile) generateByWorkspace(ignoreUnrecognizedFileTypes bool) error return err } - if mf.model.Size() == 0 { - return fmt.Errorf("no recognized model files found in directory - you may need to edit the Modelfile manually") + if mf.model.Size() == 0 && mf.code.Size() == 0 && mf.dataset.Size() == 0 { + return fmt.Errorf("no model/code/dataset found - you have to create the Modelfile by yourself") } return nil diff --git a/pkg/modelfile/modelfile_test.go b/pkg/modelfile/modelfile_test.go index a4c9a971..d71ef187 100644 --- a/pkg/modelfile/modelfile_test.go +++ b/pkg/modelfile/modelfile_test.go @@ -19,14 +19,17 @@ package modelfile import ( "encoding/json" "errors" + "fmt" "os" "path/filepath" "sort" + "strings" "testing" configmodelfile "github.com/CloudNativeAI/modctl/pkg/config/modelfile" "github.com/emirpasic/gods/sets/hashset" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewModelfile(t *testing.T) { @@ -269,25 +272,24 @@ NAME bar func TestNewModelfileByWorkspace(t *testing.T) { testcases := []struct { - name string - setupFiles map[string]string - setupDirs []string - configJson map[string]interface{} - genConfigJson map[string]interface{} - config *configmodelfile.GenerateConfig - ignoreUnrecognizedFileType bool - expectError bool - expectConfigs []string - expectModels []string - expectCodes []string - expectDocs []string - expectName string - expectArch string - expectFamily string - expectFormat string - expectParamsize string - expectPrecision string - expectQuantization string + name string + setupFiles map[string]string + setupDirs []string + configJson map[string]interface{} + genConfigJson map[string]interface{} + config *configmodelfile.GenerateConfig + expectError bool + expectConfigs []string + expectModels []string + expectCodes []string + expectDocs []string + expectName string + expectArch string + expectFamily string + expectFormat string + expectParamsize string + expectPrecision string + expectQuantization string }{ { name: "basic case", @@ -302,13 +304,12 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "test-model", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json"}, - expectModels: []string{"model.bin"}, - expectCodes: []string{"model.py", "tokenizer.py"}, - expectDocs: []string{"README.md", "LICENSE"}, - expectName: "test-model", + expectError: false, + expectConfigs: []string{"config.json"}, + expectModels: []string{"model.bin"}, + expectCodes: []string{"model.py", "tokenizer.py"}, + expectDocs: []string{"README.md", "LICENSE"}, + expectName: "test-model", }, { name: "empty workspace", @@ -316,12 +317,11 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "empty-model", }, - ignoreUnrecognizedFileType: false, - expectError: true, - expectConfigs: []string{}, - expectModels: []string{}, - expectCodes: []string{}, - expectName: "empty-model", + expectError: true, + expectConfigs: []string{}, + expectModels: []string{}, + expectCodes: []string{}, + expectName: "empty-model", }, { name: "with config.json values", @@ -337,15 +337,14 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "config-model", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json"}, - expectModels: []string{"model.bin"}, - expectCodes: []string{}, - expectName: "config-model", - expectArch: "transformer", - expectFamily: "llama", - expectPrecision: "float16", + expectError: false, + expectConfigs: []string{"config.json"}, + expectModels: []string{"model.bin"}, + expectCodes: []string{}, + expectName: "config-model", + expectArch: "transformer", + expectFamily: "llama", + expectPrecision: "float16", }, { name: "nested directory structure", @@ -370,8 +369,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "nested-model", }, - ignoreUnrecognizedFileType: false, - expectError: false, + expectError: false, expectConfigs: []string{ "config.json", "docs/config/parameters.yaml", @@ -406,12 +404,11 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "deep-nested", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"level1/config.json"}, - expectModels: []string{"level1/level2/level3/model.bin"}, - expectCodes: []string{"level1/level2/level3/level4/code.py"}, - expectName: "deep-nested", + expectError: false, + expectConfigs: []string{"level1/config.json"}, + expectModels: []string{"level1/level2/level3/model.bin"}, + expectCodes: []string{"level1/level2/level3/level4/code.py"}, + expectName: "deep-nested", }, { name: "hidden files and directories", @@ -429,12 +426,11 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "hidden-test", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json"}, - expectModels: []string{"model.bin"}, - expectCodes: []string{}, - expectName: "hidden-test", + expectError: false, + expectConfigs: []string{"config.json"}, + expectModels: []string{"model.bin"}, + expectCodes: []string{}, + expectName: "hidden-test", }, { name: "multiple config files in directories", @@ -453,16 +449,15 @@ func TestNewModelfileByWorkspace(t *testing.T) { Name: "multi-config", Format: "pytorch", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json", "models/config.json", "models/gen_config.json"}, - expectModels: []string{"model.bin"}, - expectCodes: []string{}, - expectName: "multi-config", - expectArch: "transformer", - expectFamily: "gpt2", - expectFormat: "pytorch", - expectPrecision: "float32", + expectError: false, + expectConfigs: []string{"config.json", "models/config.json", "models/gen_config.json"}, + expectModels: []string{"model.bin"}, + expectCodes: []string{}, + expectName: "multi-config", + expectArch: "transformer", + expectFamily: "gpt2", + expectFormat: "pytorch", + expectPrecision: "float32", }, { name: "special filename characters", @@ -482,8 +477,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "special-chars", }, - ignoreUnrecognizedFileType: false, - expectError: false, + expectError: false, expectConfigs: []string{ "config with spaces.json", "dir-with-hyphens/config.json", @@ -520,8 +514,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "mixed-types", }, - ignoreUnrecognizedFileType: false, - expectError: false, + expectError: false, expectConfigs: []string{ "configs/main.json", "configs/params.yaml", @@ -554,8 +547,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "same-names", }, - ignoreUnrecognizedFileType: false, - expectError: false, + expectError: false, expectConfigs: []string{ "dir1/config.json", "dir2/config.json", @@ -604,8 +596,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { Name: "llama-7b", ParamSize: "7B", }, - ignoreUnrecognizedFileType: false, - expectError: false, + expectError: false, expectConfigs: []string{ "config.json", "generation_config.json", @@ -650,14 +641,13 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "conflict-test", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json", "generation_config.json"}, - expectModels: []string{"model.bin"}, - expectCodes: []string{}, - expectName: "conflict-test", - expectFamily: "llama", - expectPrecision: "float32", + expectError: false, + expectConfigs: []string{"config.json", "generation_config.json"}, + expectModels: []string{"model.bin"}, + expectCodes: []string{}, + expectName: "conflict-test", + expectFamily: "llama", + expectPrecision: "float32", }, { name: "skipping internal directories", @@ -679,12 +669,11 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "skip-test", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json"}, - expectModels: []string{"normal/model.bin"}, - expectCodes: []string{"valid_dir/model.py"}, - expectName: "skip-test", + expectError: false, + expectConfigs: []string{"config.json"}, + expectModels: []string{"normal/model.bin"}, + expectCodes: []string{"valid_dir/model.py"}, + expectName: "skip-test", }, } @@ -734,7 +723,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { // Set workspace in config tc.config.Workspace = tempDir - tc.config.IgnoreUnrecognizedFileTypes = tc.ignoreUnrecognizedFileType + tc.config.IgnoreUnrecognizedFileTypes = false // Call the function being tested mf, err := NewModelfileByWorkspace(tempDir, tc.config) @@ -1202,6 +1191,631 @@ func createHashSet(items []string) *hashset.Set { for _, item := range items { set.Add(item) } - return set } + +// TestGenerateByModelConfig tests the generateByModelConfig method +func TestGenerateByModelConfig(t *testing.T) { + testcases := []struct { + name string + configFiles map[string]map[string]interface{} + expectedArch string + expectedFamily string + expectedPrecision string + expectError bool + }{ + { + name: "config.json with all fields", + configFiles: map[string]map[string]interface{}{ + "config.json": { + "model_type": "llama", + "torch_dtype": "float16", + "transformers_version": "4.30.0", + }, + }, + expectedArch: "transformer", + expectedFamily: "llama", + expectedPrecision: "float16", + expectError: false, + }, + { + name: "generation_config.json overrides config.json", + configFiles: map[string]map[string]interface{}{ + "config.json": { + "model_type": "llama", + "torch_dtype": "float16", + }, + "generation_config.json": { + "model_type": "gpt2", + "torch_dtype": "float32", + }, + }, + expectedFamily: "gpt2", + expectedPrecision: "float32", + expectError: false, + }, + { + name: "invalid json file", + configFiles: map[string]map[string]interface{}{ + "config.json": nil, // This will create invalid JSON + }, + expectError: false, // Invalid JSON is silently ignored + }, + { + name: "no config files", + configFiles: map[string]map[string]interface{}{}, + expectError: false, + }, + { + name: "partial config", + configFiles: map[string]map[string]interface{}{ + "config.json": { + "model_type": "bert", + }, + }, + expectedFamily: "bert", + expectError: false, + }, + { + name: "transformers version only", + configFiles: map[string]map[string]interface{}{ + "config.json": { + "transformers_version": "4.25.1", + }, + }, + expectedArch: "transformer", + expectError: false, + }, + } + + assert := assert.New(t) + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + tempDir, err := os.MkdirTemp("", "model-config-test-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + // Create config files + for filename, content := range tc.configFiles { + if content == nil { + // Create invalid JSON + err = os.WriteFile(filepath.Join(tempDir, filename), []byte("invalid json"), 0644) + } else { + data, err := json.Marshal(content) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(tempDir, filename), data, 0644) + } + require.NoError(t, err) + } + + mf := &modelfile{workspace: tempDir} + err = mf.generateByModelConfig() + + if tc.expectError { + assert.Error(err) + } else { + assert.NoError(err) + assert.Equal(tc.expectedArch, mf.arch) + assert.Equal(tc.expectedFamily, mf.family) + assert.Equal(tc.expectedPrecision, mf.precision) + } + }) + } +} + +// TestGenerateByConfig tests the generateByConfig method +func TestGenerateByConfig(t *testing.T) { + testcases := []struct { + name string + workspace string + config *configmodelfile.GenerateConfig + expectedName string + expectedArch string + expectedFamily string + expectedFormat string + expectedParamsize string + expectedPrecision string + expectedQuantization string + }{ + { + name: "default name from workspace", + workspace: "/path/to/my-model", + config: &configmodelfile.GenerateConfig{}, + expectedName: "my-model", + }, + { + name: "custom name overrides workspace", + workspace: "/path/to/workspace", + config: &configmodelfile.GenerateConfig{ + Name: "custom-model", + }, + expectedName: "custom-model", + }, + { + name: "all fields provided", + workspace: "/path/to/model", + config: &configmodelfile.GenerateConfig{ + Name: "test-model", + Arch: "transformer", + Family: "llama", + Format: "safetensors", + ParamSize: "7B", + Precision: "float16", + Quantization: "int8", + }, + expectedName: "test-model", + expectedArch: "transformer", + expectedFamily: "llama", + expectedFormat: "safetensors", + expectedParamsize: "7B", + expectedPrecision: "float16", + expectedQuantization: "int8", + }, + { + name: "empty config uses workspace name", + workspace: "/tmp/test-workspace", + config: &configmodelfile.GenerateConfig{}, + expectedName: "test-workspace", + }, + } + + assert := assert.New(t) + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + mf := &modelfile{workspace: tc.workspace} + mf.generateByConfig(tc.config) + + assert.Equal(tc.expectedName, mf.name) + assert.Equal(tc.expectedArch, mf.arch) + assert.Equal(tc.expectedFamily, mf.family) + assert.Equal(tc.expectedFormat, mf.format) + assert.Equal(tc.expectedParamsize, mf.paramsize) + assert.Equal(tc.expectedPrecision, mf.precision) + assert.Equal(tc.expectedQuantization, mf.quantization) + }) + } +} + +// TestValidateWorkspace tests the validateWorkspace method specifically +func TestValidateWorkspace(t *testing.T) { + testcases := []struct { + name string + setupFunc func() (string, func()) + expectError bool + errorMsg string + }{ + { + name: "valid directory workspace", + setupFunc: func() (string, func()) { + tempDir, err := os.MkdirTemp("", "valid-workspace-*") + require.NoError(t, err) + // Create a file to make it non-empty + err = os.WriteFile(filepath.Join(tempDir, "test.txt"), []byte("test"), 0644) + require.NoError(t, err) + return tempDir, func() { os.RemoveAll(tempDir) } + }, + expectError: false, + }, + { + name: "empty directory workspace", + setupFunc: func() (string, func()) { + tempDir, err := os.MkdirTemp("", "empty-workspace-*") + require.NoError(t, err) + return tempDir, func() { os.RemoveAll(tempDir) } + }, + expectError: true, + errorMsg: "the workspace is empty", + }, + { + name: "non-existent workspace", + setupFunc: func() (string, func()) { + return "/non/existent/path", func() {} + }, + expectError: true, + errorMsg: "access to workspace failed", + }, + { + name: "file instead of directory", + setupFunc: func() (string, func()) { + tempFile, err := os.CreateTemp("", "file-workspace-*") + require.NoError(t, err) + tempFile.Close() + return tempFile.Name(), func() { os.Remove(tempFile.Name()) } + }, + expectError: true, + errorMsg: "the workspace is not a directory", + }, + { + name: "symbolic link workspace", + setupFunc: func() (string, func()) { + tempDir, err := os.MkdirTemp("", "symlink-target-*") + require.NoError(t, err) + // Create a file in target + err = os.WriteFile(filepath.Join(tempDir, "test.txt"), []byte("test"), 0644) + require.NoError(t, err) + + symlinkPath := tempDir + "-symlink" + err = os.Symlink(tempDir, symlinkPath) + require.NoError(t, err) + + return symlinkPath, func() { + os.RemoveAll(tempDir) + os.Remove(symlinkPath) + } + }, + expectError: true, + errorMsg: "the workspace should not be a symbolic link", + }, + } + + assert := assert.New(t) + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + workspace, cleanup := tc.setupFunc() + defer cleanup() + + mf := &modelfile{workspace: workspace} + err := mf.validateWorkspace() + + if tc.expectError { + assert.Error(err) + assert.Contains(err.Error(), tc.errorMsg) + } else { + assert.NoError(err) + } + }) + } +} + +func TestWorkspaceLimits(t *testing.T) { + testcases := []struct { + name string + setupFunc func() (string, func()) + expectError bool + errorMsg string + }{ + { + name: "exceeds file count limit", + setupFunc: func() (string, func()) { + tempDir, err := os.MkdirTemp("", "file-count-test-*") + require.NoError(t, err) + + // Create more files than the limit (1024) + for i := 0; i < MaxWorkspaceFileCount+10; i++ { + filename := fmt.Sprintf("file_%d.txt", i) + err = os.WriteFile(filepath.Join(tempDir, filename), []byte("test"), 0644) + require.NoError(t, err) + } + + return tempDir, func() { os.RemoveAll(tempDir) } + }, + expectError: true, + errorMsg: "exceeds maximum file count limit", + }, + { + name: "normal sized files should pass", + setupFunc: func() (string, func()) { + tempDir, err := os.MkdirTemp("", "normal-file-test-*") + require.NoError(t, err) + + // Create normal sized files including a model file + normalPath := filepath.Join(tempDir, "model.bin") + err = os.WriteFile(normalPath, []byte("test model content"), 0644) + require.NoError(t, err) + + // Add a config file too + configPath := filepath.Join(tempDir, "config.json") + err = os.WriteFile(configPath, []byte(`{"model_type": "test"}`), 0644) + require.NoError(t, err) + + return tempDir, func() { os.RemoveAll(tempDir) } + }, + expectError: false, + }, + { + name: "within limits", + setupFunc: func() (string, func()) { + tempDir, err := os.MkdirTemp("", "within-limits-test-*") + require.NoError(t, err) + + // Create a reasonable number of files including valid model/code files + for i := 0; i < 8; i++ { + filename := fmt.Sprintf("file_%d.txt", i) + err = os.WriteFile(filepath.Join(tempDir, filename), []byte("test content"), 0644) + require.NoError(t, err) + } + + // Add valid model and config files + err = os.WriteFile(filepath.Join(tempDir, "model.bin"), []byte("model content"), 0644) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(tempDir, "config.json"), []byte(`{"model_type": "test"}`), 0644) + require.NoError(t, err) + + return tempDir, func() { os.RemoveAll(tempDir) } + }, + expectError: false, + }, + } + + assert := assert.New(t) + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + workspace, cleanup := tc.setupFunc() + defer cleanup() + + config := &configmodelfile.GenerateConfig{ + Name: "test-model", + } + + _, err := NewModelfileByWorkspace(workspace, config) + + if tc.expectError { + assert.Error(err) + assert.Contains(err.Error(), tc.errorMsg) + } else { + assert.NoError(err) + } + }) + } +} + +func TestFileTypeClassification(t *testing.T) { + testcases := []struct { + name string + files map[string]int64 // filename -> size + expectedConfigs []string + expectedModels []string + expectedCodes []string + expectedDocs []string + }{ + { + name: "various file types", + files: map[string]int64{ + "config.json": 1024, + "model.bin": 1024 * 1024 * 1024, // 1GB - large file + "script.py": 2048, + "README.md": 512, + "tokenizer.json": 1024, + "weights.safetensors": 2 * 1024 * 1024 * 1024, // 2GB - large file + "inference.py": 3072, + "LICENSE": 256, + }, + expectedConfigs: []string{"config.json", "tokenizer.json"}, + expectedModels: []string{"model.bin", "weights.safetensors"}, + expectedCodes: []string{"script.py", "inference.py"}, + expectedDocs: []string{"README.md", "LICENSE"}, + }, + { + name: "small unknown files treated as code files", + files: map[string]int64{ + "unknown_small_file": 1024, // 1KB - below threshold + "another_unknown.xyz": 50 * 1024, // 50KB - below threshold + "config.json": 1024, // Add a config to make workspace valid + }, + expectedConfigs: []string{"config.json"}, + expectedModels: []string{}, + expectedCodes: []string{"unknown_small_file", "another_unknown.xyz"}, + }, + { + name: "case insensitive file extensions", + files: map[string]int64{ + "CONFIG.JSON": 1024, + "Model.BIN": 1024, + "Script.PY": 1024, + "README.MD": 1024, + }, + expectedConfigs: []string{"CONFIG.JSON"}, + expectedModels: []string{"Model.BIN"}, + expectedCodes: []string{"Script.PY"}, + expectedDocs: []string{"README.MD"}, + }, + { + name: "nested directory files", + files: map[string]int64{ + "configs/model.json": 1024, + "models/pytorch_model.bin": 1024, + "src/utils.py": 1024, + "docs/guide.md": 1024, + }, + expectedConfigs: []string{"configs/model.json"}, + expectedModels: []string{"models/pytorch_model.bin"}, + expectedCodes: []string{"src/utils.py"}, + expectedDocs: []string{"docs/guide.md"}, + }, + } + + assert := assert.New(t) + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + tempDir, err := os.MkdirTemp("", "file-type-test-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + // Create files with specified sizes + for filename, size := range tc.files { + fullPath := filepath.Join(tempDir, filename) + + // Create directory if needed + dir := filepath.Dir(fullPath) + if dir != tempDir { + err = os.MkdirAll(dir, 0755) + require.NoError(t, err) + } + + // Create file with specified size + file, err := os.Create(fullPath) + require.NoError(t, err) + + if size > 0 { + // For large files, we'll write a smaller amount and then seek to create the size + // For testing purposes, just write some content + content := strings.Repeat("x", int(min(size, 1024))) + _, err = file.WriteString(content) + require.NoError(t, err) + } + file.Close() + } + + config := &configmodelfile.GenerateConfig{ + Name: "test-classification", + } + + mf, err := NewModelfileByWorkspace(tempDir, config) + require.NoError(t, err) + + assert.ElementsMatch(tc.expectedConfigs, mf.GetConfigs()) + assert.ElementsMatch(tc.expectedModels, mf.GetModels()) + assert.ElementsMatch(tc.expectedCodes, mf.GetCodes()) + assert.ElementsMatch(tc.expectedDocs, mf.GetDocs()) + }) + } +} + +func TestSkippedFiles(t *testing.T) { + tempDir, err := os.MkdirTemp("", "skip-test-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + // Create various files and directories that should be skipped + filesToCreate := []string{ + ".hidden_file", + ".git/config", + "__pycache__/cache.pyc", + "model.pyo", + "script.pyd", + "modelfile", + "normal_file.txt", + "valid_model.bin", + } + + dirsToCreate := []string{ + ".git", + "__pycache__", + ".hidden_dir", + "normal_dir", + } + + for _, dir := range dirsToCreate { + err = os.MkdirAll(filepath.Join(tempDir, dir), 0755) + require.NoError(t, err) + } + + for _, file := range filesToCreate { + fullPath := filepath.Join(tempDir, file) + dir := filepath.Dir(fullPath) + if dir != tempDir { + err = os.MkdirAll(dir, 0755) + require.NoError(t, err) + } + err = os.WriteFile(fullPath, []byte("content"), 0644) + require.NoError(t, err) + } + + config := &configmodelfile.GenerateConfig{ + Name: "skip-test", + } + + mf, err := NewModelfileByWorkspace(tempDir, config) + require.NoError(t, err) + + // Only normal_file.txt and valid_model.bin should be included + allFiles := append(append(append(mf.GetConfigs(), mf.GetModels()...), mf.GetCodes()...), mf.GetDocs()...) + + assert := assert.New(t) + + // Check that skipped files are not included + for _, file := range allFiles { + assert.NotContains(file, ".hidden") + assert.NotContains(file, ".git") + assert.NotContains(file, "__pycache__") + assert.NotContains(file, ".pyc") + assert.NotContains(file, ".pyo") + assert.NotContains(file, ".pyd") + assert.NotEqual(file, "modelfile") + } + + // Check that normal files are included + expectedFiles := []string{"normal_file.txt", "valid_model.bin"} + for _, expectedFile := range expectedFiles { + found := false + for _, file := range allFiles { + if strings.Contains(file, expectedFile) { + found = true + break + } + } + assert.True(found, "Expected file %s should be included", expectedFile) + } +} + +func TestEmptyWorkspaceHandling(t *testing.T) { + testcases := []struct { + name string + files []string + expectError bool + errorMsg string + }{ + { + name: "only documentation files", + files: []string{"README.md", "LICENSE", "docs.txt"}, + expectError: true, + errorMsg: "no model/code/dataset found", + }, + { + name: "only configuration files", + files: []string{"config.json", "settings.yaml"}, + expectError: true, + errorMsg: "no model/code/dataset found", + }, + { + name: "has model files", + files: []string{"model.bin", "config.json"}, + expectError: false, + }, + { + name: "has code files", + files: []string{"script.py", "config.json"}, + expectError: false, + }, + { + name: "mixed valid files", + files: []string{"model.bin", "script.py", "README.md"}, + expectError: false, + }, + } + + assert := assert.New(t) + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + tempDir, err := os.MkdirTemp("", "empty-workspace-test-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + for _, filename := range tc.files { + err = os.WriteFile(filepath.Join(tempDir, filename), []byte("content"), 0644) + require.NoError(t, err) + } + + config := &configmodelfile.GenerateConfig{ + Name: "test-model", + } + + _, err = NewModelfileByWorkspace(tempDir, config) + + if tc.expectError { + assert.Error(err) + assert.Contains(err.Error(), tc.errorMsg) + } else { + assert.NoError(err) + } + }) + } +} + +// min returns the minimum of two integers +func min(a, b int64) int64 { + if a < b { + return a + } + return b +}