diff --git a/cmd/modelfile/generate.go b/cmd/modelfile/generate.go index 985c88c7..3cad9361 100644 --- a/cmd/modelfile/generate.go +++ b/cmd/modelfile/generate.go @@ -62,7 +62,7 @@ func init() { flags.StringVar(&generateConfig.Precision, "precision", "", "specify model precision, such as bf16, fp16, int8, etc") flags.StringVar(&generateConfig.Quantization, "quantization", "", "specify model quantization, such as awq, gptq, etc") flags.StringVarP(&generateConfig.Output, "output", "O", ".", "specify the output path of modelfilem, must be a directory") - flags.BoolVar(&generateConfig.IgnoreUnrecognizedFileTypes, "ignore-unrecognized-file-types", false, "ignore the unrecognized file types in the workspace") + flags.BoolVar(&generateConfig.IgnoreUnrecognizedFileTypes, "ignore-unrecognized-file-types", false, "[deprecated] ignore the unrecognized file types in the workspace") flags.BoolVar(&generateConfig.Overwrite, "overwrite", false, "overwrite the existing modelfile") if err := viper.BindPFlags(flags); err != nil { diff --git a/pkg/config/modelfile/modelfile.go b/pkg/config/modelfile/modelfile.go index 2f1251a3..fe1c6f50 100644 --- a/pkg/config/modelfile/modelfile.go +++ b/pkg/config/modelfile/modelfile.go @@ -31,7 +31,7 @@ type GenerateConfig struct { Name string Version string Output string - IgnoreUnrecognizedFileTypes bool + IgnoreUnrecognizedFileTypes bool // [deprecated] will be removed in the next release Overwrite bool Arch string Family string @@ -83,7 +83,7 @@ func (g *GenerateConfig) Validate() error { // If the output path does not exist, we can create the modelfile. 
if _, err := os.Stat(g.Output); err == nil { if !g.Overwrite { - return fmt.Errorf("Modelfile already exists at %s - use --overwrite to overwrite", g.Output) + return fmt.Errorf("modelfile already exists at %s - use --overwrite to overwrite", g.Output) } } diff --git a/pkg/modelfile/constants.go b/pkg/modelfile/constants.go index 92d07aca..6cf001db 100644 --- a/pkg/modelfile/constants.go +++ b/pkg/modelfile/constants.go @@ -17,6 +17,7 @@ package modelfile import ( + "fmt" "path/filepath" "strings" ) @@ -26,6 +27,8 @@ var ( ConfigFilePatterns = []string{ "*.json", // JSON configuration files "*.jsonl", // JSON Lines format + "*.json5", // JSON5 files + "*.jsonc", // JSON with comments "*.yaml", // YAML configuration files "*.yml", // YAML alternative extension "*.toml", // TOML configuration files @@ -45,6 +48,12 @@ var ( "*.meta", // Model metadata "*tokenizer.model*", // Tokenizer files (e.g., Mistral v3) "config.json.*", // Model configuration variants + "*.hparams", // Hyperparameter files + "*.params", // Parameter files + "*.hyperparams", // Hyperparameter configuration + "*.wandb", // Weights & Biases configuration + "*.mlflow", // MLflow configuration + "*.tensorboard", // TensorBoard configuration } // Model file patterns - supported model file extensions. @@ -56,29 +65,75 @@ var ( "*.bin", // General binary format "*.pt", // PyTorch model "*.pth", // PyTorch model (alternative extension) + "*.mar", // PyTorch Model Archive + "*.pte", // PyTorch ExecuTorch format + "*.pt2", // PyTorch 2.0 export format + "*.ptl", // PyTorch Mobile format // TensorFlow formats. "*.tflite", // TensorFlow Lite "*.h5", // Keras HDF5 format "*.hdf", // Hierarchical Data Format "*.hdf5", // HDF5 (alternative extension) + "*.pb", // TensorFlow SavedModel/Frozen Graph + "*.meta", // TensorFlow checkpoint metadata + "*.data-*", // TensorFlow checkpoint data files + "*.index", // TensorFlow checkpoint index + + // GGML formats. 
+ "*.gguf", // GGML Universal Format + "*.ggml", // GGML format (legacy) + "*.ggmf", // GGMF format (deprecated) + "*.ggjt", // GGJT format (deprecated) + "*.q4_0", // GGML Q4_0 quantization + "*.q4_1", // GGML Q4_1 quantization + "*.q5_0", // GGML Q5_0 quantization + "*.q5_1", // GGML Q5_1 quantization + "*.q8_0", // GGML Q8_0 quantization + "*.f16", // GGML F16 format + "*.f32", // GGML F32 format + + // checkpoint formats. + "*.ckpt", // Checkpoint format + "*.checkpoint", // Checkpoint format (alternative extension) + "*.dist_ckpt", // Distributed checkpoint format + + // Semantics-specific formats + "*.tensor", // Generic tensor format + "*.weights", // Generic weights format + "*.state", // State files + "*.embedding", // Embedding files + "*.vocab", // Vocabulary files (when binary) // Other ML frameworks. "*.ot", // OpenVINO format "*.engine", // TensorRT format "*.trt", // TensorRT format (alternative extension) "*.onnx", // Open Neural Network Exchange format - "*.gguf", // GGML Universal Format "*.msgpack", // MessagePack serialization "*.model", // Some NLP frameworks "*.pkl", // Pickle format "*.pickle", // Pickle format (alternative extension) - "*.ckpt", // Checkpoint format - "*.checkpoint", // Checkpoint format (alternative extension) + "*.keras", // Keras native format + "*.joblib", // Joblib serialization (scikit-learn) + "*.npy", // NumPy array format + "*.npz", // NumPy compressed archive + "*.nc", // NetCDF format + "*.mlmodel", // Apple Core ML format + "*.coreml", // Apple Core ML format (alternative) + "*.mleap", // MLeap format (Spark ML) + "*.surml", // SurrealML format + "*.llamafile", // Llamafile format + "*.caffemodel", // Caffe model format + "*.prototxt", // Caffe model definition + "*.dlc", // Qualcomm Deep Learning Container + "*.circle", // Samsung Circle format + "*.nb", // Neural Network Binary format } // Code file patterns - supported script and notebook files. 
CodeFilePatterns = []string{ + // language source files "*.py", // Python source files "*.ipynb", // Jupyter notebooks "*.sh", // Shell scripts @@ -88,11 +143,18 @@ var ( "*.hxx", // C++ header files "*.cpp", // C++ source files "*.cc", // C++ source files + "*.cxx", // C++ source files (alternative) + "*.c++", // C++ source files (alternative) "*.hpp", // C++ header files "*.hh", // C++ header files + "*.h++", // C++ header files (alternative) "*.java", // Java source files "*.js", // JavaScript source files + "*.mjs", // JavaScript ES6 modules + "*.cjs", // CommonJS modules + "*.jsx", // React JSX files "*.ts", // TypeScript source files + "*.tsx", // TypeScript JSX files "*.go", // Go source files "*.rs", // Rust source files "*.swift", // Swift source files @@ -100,30 +162,118 @@ var ( "*.php", // PHP source files "*.scala", // Scala source files "*.kt", // Kotlin source files + "*.kts", // Kotlin script files "*.r", // R source files + "*.R", // R source files (alternative) "*.m", // MATLAB/Objective-C source files + "*.mm", // Objective-C++ source files "*.f", // Fortran source files "*.f90", // Fortran 90 source files + "*.f95", // Fortran 95 source files + "*.f03", // Fortran 2003 source files + "*.f08", // Fortran 2008 source files "*.jl", // Julia source files "*.lua", // Lua source files "*.pl", // Perl source files + "*.pm", // Perl modules "*.cs", // C# source files "*.vb", // Visual Basic source files "*.dart", // Dart source files "*.groovy", // Groovy source files "*.elm", // Elm source files "*.erl", // Erlang source files + "*.hrl", // Erlang header files "*.ex", // Elixir source files + "*.exs", // Elixir script files "*.hs", // Haskell source files + "*.lhs", // Literate Haskell source files "*.clj", // Clojure source files "*.cljs", // ClojureScript source files - "*.cljc", // Clojure Common Lisp source files + "*.cljc", // Clojure Common source files "*.cl", // Common Lisp source files "*.lisp", // Lisp source files + "*.lsp", // Lisp source 
files (alternative) "*.scm", // Scheme source files + "*.ss", // Scheme source files (alternative) + "*.rkt", // Racket source files + "*.sql", // SQL files + "*.psql", // PostgreSQL files + "*.mysql", // MySQL files + "*.sqlite", // SQLite files + "*.zig", // Zig source files "*.cu", // CUDA source files "*.cuh", // CUDA header files + // Scripting and automation + "*.bash", // Bash scripts + "*.zsh", // Zsh scripts + "*.fish", // Fish shell scripts + "*.csh", // C shell scripts + "*.tcsh", // TC shell scripts + "*.ksh", // Korn shell scripts + "*.ps1", // PowerShell scripts + "*.psm1", // PowerShell modules + "*.psd1", // PowerShell data files + "*.bat", // Windows batch files + "*.cmd", // Windows command files + "*.vbs", // VBScript files + "*.wsf", // Windows Script Files + "*.applescript", // AppleScript files + "*.scpt", // AppleScript compiled files + "*.awk", // AWK scripts + "*.sed", // sed scripts + "*.expect", // Expect scripts + + // Build and project files + "*.env", // Environment variable files + "*.env.*", // Environment files with suffixes + ".env*", // Environment files (hidden) + "Makefile*", // Makefile variants + "*.dockerfile", // Dockerfile configurations + "Dockerfile*", // Dockerfile variants + "*.mk", // Make include files + "*.cmake", // CMake files + "CMakeLists.txt", // CMake configuration + "*.gradle", // Gradle build files + "*.gradle.kts", // Kotlin DSL Gradle files + "build.gradle*", // Gradle build files + "settings.gradle*", // Gradle settings files + "*.sbt", // SBT build files + "*.mill", // Mill build files + "*.bazel", // Bazel build files + "*.bzl", // Bazel extension files + "BUILD*", // Bazel BUILD files + "WORKSPACE*", // Bazel WORKSPACE files + "*.buck", // Buck build files + "BUCK*", // Buck BUILD files + "*.ninja", // Ninja build files + "*.gyp", // GYP build files + "*.gypi", // GYP include files + "*.waf", // Waf build files + "wscript*", // Waf build scripts + "package.json", // Node.js package file + 
"package-lock.json", // Node.js lock file + "yarn.lock", // Yarn lock file + "pnpm-lock.yaml", // PNPM lock file + "requirements*.txt", // Python requirements + "Pipfile*", // Python Pipenv files + "pyproject.toml", // Python project configuration + "setup.cfg", // Python setup configuration + "tox.ini", // Python tox configuration + "poetry.lock", // Python Poetry lock file + "Cargo.toml", // Rust package configuration + "Cargo.lock", // Rust lock file + "go.mod", // Go module file + "go.sum", // Go checksum file + "composer.json", // PHP Composer file + "composer.lock", // PHP Composer lock file + "Gemfile*", // Ruby Gemfile + "*.gemspec", // Ruby gem specification + "mix.exs", // Elixir Mix file + "mix.lock", // Elixir Mix lock file + "rebar.config", // Erlang Rebar config + "rebar.lock", // Erlang Rebar lock file + // Library files. "*.so", // Shared object files "*.dll", // Dynamic Link Library @@ -144,6 +294,93 @@ var ( "*requirements*", // Dependency specifications "*.log", // Log files + // Office documents + "*.doc", // Microsoft Word 97-2003 Document + "*.docx", // Microsoft Word Document + "*.docm", // Word Macro-Enabled Document + "*.dot", // Word 97-2003 Template + "*.dotx", // Word Template + "*.dotm", // Word Macro-Enabled Template + "*.rtf", // Rich Text Format + "*.odt", // OpenDocument Text + "*.ott", // OpenDocument Text Template + "*.fodt", // Flat OpenDocument Text + "*.pages", // Apple Pages document + "*.wpd", // WordPerfect document + + // Spreadsheet documents + "*.xls", // Microsoft Excel 97-2003 Workbook + "*.xlsx", // Microsoft Excel Workbook + "*.xlsm", // Excel Macro-Enabled Workbook + "*.xlsb", // Excel Binary Workbook + "*.xlt", // Excel 97-2003 Template + "*.xltx", // Excel Template + "*.xltm", // Excel Macro-Enabled Template + "*.ods", // OpenDocument Spreadsheet + "*.ots", // OpenDocument Spreadsheet Template + "*.fods", // Flat OpenDocument Spreadsheet + "*.numbers", // Apple Numbers spreadsheet + "*.csv", // Comma-Separated 
Values + + // Presentation documents + "*.ppt", // Microsoft PowerPoint 97-2003 Presentation + "*.pptx", // Microsoft PowerPoint Presentation + "*.pptm", // PowerPoint Macro-Enabled Presentation + "*.pps", // PowerPoint 97-2003 Show + "*.ppsx", // PowerPoint Show + "*.ppsm", // PowerPoint Macro-Enabled Show + "*.pot", // PowerPoint 97-2003 Template + "*.potx", // PowerPoint Template + "*.potm", // PowerPoint Macro-Enabled Template + "*.odp", // OpenDocument Presentation + "*.otp", // OpenDocument Presentation Template + "*.fodp", // Flat OpenDocument Presentation + "*.key", // Apple Keynote presentation + + // eBook formats + "*.epub", // Electronic Publication + "*.mobi", // Mobipocket eBook + "*.azw", // Amazon Kindle eBook + "*.azw3", // Amazon Kindle eBook (KF8) + "*.fb2", // FictionBook 2.0 + "*.fb3", // FictionBook 3.0 + "*.lit", // Microsoft Literature + "*.pdb", // Palm Database/Document File + "*.djvu", // DjVu document + "*.djv", // DjVu document (alternative extension) + + // Web and markup documents + "*.html", // HyperText Markup Language + "*.htm", // HyperText Markup Language (alternative) + "*.xhtml", // Extensible HyperText Markup Language + "*.mhtml", // MIME HTML (Web Archive) + "*.mht", // MIME HTML (Web Archive, alternative) + "*.xml", // eXtensible Markup Language + "*.xsl", // eXtensible Stylesheet Language + "*.xslt", // XSL Transformations + + // Technical documentation formats + "*.tex", // LaTeX document + "*.latex", // LaTeX document (alternative) + "*.ltx", // LaTeX document (alternative) + "*.bib", // BibTeX bibliography + "*.rst", // reStructuredText + "*.asciidoc", // AsciiDoc + "*.adoc", // AsciiDoc (alternative) + "*.textile", // Textile markup + "*.wiki", // Wiki markup + "*.mediawiki", // MediaWiki markup + "*.org", // Org-mode document + "*.texi", // Texinfo document + "*.texinfo", // Texinfo document (alternative) + "*.info", // GNU Info document + "*.man", // Manual page + + // Archive and compressed documents + "*.chm", // 
Compiled HTML Help + "*.hlp", // Windows Help File + "*.xps", // XML Paper Specification + // Image assets. "*.jpg", // JPEG image format "*.jpeg", // JPEG alternative extension @@ -180,6 +417,14 @@ var ( "*.pyo", // Python optimized bytecode "*.pyd", // Python dynamic modules } + + // Large file size threshold (128MB) + WeightFileSizeThreshold int64 = 128 * 1024 * 1024 + + // Workspace limits + MaxSingleFileSize int64 = 128 * 1024 * 1024 * 1024 // 128GB + MaxWorkspaceFileCount int = 1024 // 1024 files + MaxTotalWorkspaceSize int64 = 8 * 1024 * 1024 * 1024 * 1024 // 8TB ) // IsFileType checks if the filename matches any of the given patterns @@ -216,3 +461,23 @@ func isSkippable(filename string) bool { return false } + +// SizeShouldBeWeightFile reports whether size exceeds WeightFileSizeThreshold; a large file of unknown type is usually a weight file. +func SizeShouldBeWeightFile(size int64) bool { + return size > WeightFileSizeThreshold +} + +// formatBytes converts byte size to human-readable format +func formatBytes(bytes int64) string { + const unit = 1024 + if bytes < unit { + return fmt.Sprintf("%d B", bytes) + } + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + units := []string{"B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"} + return fmt.Sprintf("%.1f%s", float64(bytes)/float64(div), units[exp+1]) +} diff --git a/pkg/modelfile/modelfile.go b/pkg/modelfile/modelfile.go index edccb59d..e5b15290 100644 --- a/pkg/modelfile/modelfile.go +++ b/pkg/modelfile/modelfile.go @@ -205,7 +205,11 @@ func NewModelfileByWorkspace(workspace string, config *configmodelfile.GenerateC doc: hashset.New(), } - if err := mf.generateByWorkspace(config.IgnoreUnrecognizedFileTypes); err != nil { + if err := mf.validateWorkspace(); err != nil { + return nil, err + } + + if err := mf.generateByWorkspace(); err != nil { return nil, err } @@ -217,8 +221,42 @@ func NewModelfileByWorkspace(workspace string, config *configmodelfile.GenerateC return mf, nil } +// validateWorkspace validates the workspace 
directory +func (mf *modelfile) validateWorkspace() error { + // check if the workspace is a directory, symbolic link, or empty + info, err := os.Lstat(mf.workspace) + if err != nil { + return fmt.Errorf("access to workspace failed: %s", err) + } + + // check if the workspace is a symbolic link + if info.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("for simplicity, the workspace should not be a symbolic link: %s", mf.workspace) + } + + // check if the workspace is a directory + if !info.IsDir() { + return fmt.Errorf("the workspace is not a directory: %s", mf.workspace) + } + + // check if the workspace is empty by reading directory contents + entries, err := os.ReadDir(mf.workspace) + if err != nil { + return fmt.Errorf("failed to read workspace directory: %s", err) + } + if len(entries) == 0 { + return fmt.Errorf("the workspace is empty: %s", mf.workspace) + } + + return nil +} + // generateByWorkspace generates the modelfile by the workspace's files. -func (mf *modelfile) generateByWorkspace(ignoreUnrecognizedFileTypes bool) error { +func (mf *modelfile) generateByWorkspace() error { + // Initialize counters for workspace limits validation + var fileCount int + var totalSize int64 + // Walk the path and get the files. 
if err := filepath.Walk(mf.workspace, func(path string, info os.FileInfo, err error) error { if err != nil { @@ -240,6 +278,26 @@ func (mf *modelfile) generateByWorkspace(ignoreUnrecognizedFileTypes bool) error return nil } + // Check workspace limits for regular files + fileCount++ + fileSize := info.Size() + totalSize += fileSize + + // Check single file size limit + if fileSize > MaxSingleFileSize { + return fmt.Errorf("file %s exceeds maximum single file size limit of %d bytes (%s)", path, MaxSingleFileSize, formatBytes(MaxSingleFileSize)) + } + + // Check file count limit + if fileCount > MaxWorkspaceFileCount { + return fmt.Errorf("workspace exceeds maximum file count limit of %d files", MaxWorkspaceFileCount) + } + + // Check total workspace size limit + if totalSize > MaxTotalWorkspaceSize { + return fmt.Errorf("workspace exceeds maximum total size limit of %d bytes (%s)", MaxTotalWorkspaceSize, formatBytes(MaxTotalWorkspaceSize)) + } + // Get relative path from the base directory. relPath, err := filepath.Rel(mf.workspace, path) if err != nil { @@ -256,12 +314,14 @@ func (mf *modelfile) generateByWorkspace(ignoreUnrecognizedFileTypes bool) error case IsFileType(filename, DocFilePatterns): mf.doc.Add(relPath) default: - // Skip unrecognized files if IgnoreUnrecognizedFileTypes is true. - if ignoreUnrecognizedFileTypes { - return nil + // If the file is large, usually it is a weight file. 
+ if SizeShouldBeWeightFile(info.Size()) { + mf.model.Add(relPath) + } else { + mf.code.Add(relPath) } - return fmt.Errorf("unknown file type: %s", filename) + return nil } return nil @@ -269,8 +329,8 @@ func (mf *modelfile) generateByWorkspace(ignoreUnrecognizedFileTypes bool) error return err } - if mf.model.Size() == 0 { - return fmt.Errorf("no recognized model files found in directory - you may need to edit the Modelfile manually") + if mf.model.Size() == 0 && mf.code.Size() == 0 && mf.dataset.Size() == 0 { + return fmt.Errorf("no model/code/dataset found - you have to create the Modelfile by yourself") } return nil diff --git a/pkg/modelfile/modelfile_test.go b/pkg/modelfile/modelfile_test.go index a4c9a971..8b5eb12b 100644 --- a/pkg/modelfile/modelfile_test.go +++ b/pkg/modelfile/modelfile_test.go @@ -19,6 +19,7 @@ package modelfile import ( "encoding/json" "errors" + "fmt" "os" "path/filepath" "sort" @@ -269,25 +270,24 @@ NAME bar func TestNewModelfileByWorkspace(t *testing.T) { testcases := []struct { - name string - setupFiles map[string]string - setupDirs []string - configJson map[string]interface{} - genConfigJson map[string]interface{} - config *configmodelfile.GenerateConfig - ignoreUnrecognizedFileType bool - expectError bool - expectConfigs []string - expectModels []string - expectCodes []string - expectDocs []string - expectName string - expectArch string - expectFamily string - expectFormat string - expectParamsize string - expectPrecision string - expectQuantization string + name string + setupFiles map[string]string + setupDirs []string + configJson map[string]interface{} + genConfigJson map[string]interface{} + config *configmodelfile.GenerateConfig + expectError bool + expectConfigs []string + expectModels []string + expectCodes []string + expectDocs []string + expectName string + expectArch string + expectFamily string + expectFormat string + expectParamsize string + expectPrecision string + expectQuantization string }{ { name: "basic case", 
@@ -302,13 +302,12 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "test-model", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json"}, - expectModels: []string{"model.bin"}, - expectCodes: []string{"model.py", "tokenizer.py"}, - expectDocs: []string{"README.md", "LICENSE"}, - expectName: "test-model", + expectError: false, + expectConfigs: []string{"config.json"}, + expectModels: []string{"model.bin"}, + expectCodes: []string{"model.py", "tokenizer.py"}, + expectDocs: []string{"README.md", "LICENSE"}, + expectName: "test-model", }, { name: "empty workspace", @@ -316,12 +315,11 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "empty-model", }, - ignoreUnrecognizedFileType: false, - expectError: true, - expectConfigs: []string{}, - expectModels: []string{}, - expectCodes: []string{}, - expectName: "empty-model", + expectError: true, + expectConfigs: []string{}, + expectModels: []string{}, + expectCodes: []string{}, + expectName: "empty-model", }, { name: "with config.json values", @@ -337,15 +335,14 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "config-model", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json"}, - expectModels: []string{"model.bin"}, - expectCodes: []string{}, - expectName: "config-model", - expectArch: "transformer", - expectFamily: "llama", - expectPrecision: "float16", + expectError: false, + expectConfigs: []string{"config.json"}, + expectModels: []string{"model.bin"}, + expectCodes: []string{}, + expectName: "config-model", + expectArch: "transformer", + expectFamily: "llama", + expectPrecision: "float16", }, { name: "nested directory structure", @@ -370,8 +367,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "nested-model", }, - 
ignoreUnrecognizedFileType: false, - expectError: false, + expectError: false, expectConfigs: []string{ "config.json", "docs/config/parameters.yaml", @@ -406,12 +402,11 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "deep-nested", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"level1/config.json"}, - expectModels: []string{"level1/level2/level3/model.bin"}, - expectCodes: []string{"level1/level2/level3/level4/code.py"}, - expectName: "deep-nested", + expectError: false, + expectConfigs: []string{"level1/config.json"}, + expectModels: []string{"level1/level2/level3/model.bin"}, + expectCodes: []string{"level1/level2/level3/level4/code.py"}, + expectName: "deep-nested", }, { name: "hidden files and directories", @@ -429,12 +424,11 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "hidden-test", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json"}, - expectModels: []string{"model.bin"}, - expectCodes: []string{}, - expectName: "hidden-test", + expectError: false, + expectConfigs: []string{"config.json"}, + expectModels: []string{"model.bin"}, + expectCodes: []string{}, + expectName: "hidden-test", }, { name: "multiple config files in directories", @@ -453,16 +447,15 @@ func TestNewModelfileByWorkspace(t *testing.T) { Name: "multi-config", Format: "pytorch", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json", "models/config.json", "models/gen_config.json"}, - expectModels: []string{"model.bin"}, - expectCodes: []string{}, - expectName: "multi-config", - expectArch: "transformer", - expectFamily: "gpt2", - expectFormat: "pytorch", - expectPrecision: "float32", + expectError: false, + expectConfigs: []string{"config.json", "models/config.json", "models/gen_config.json"}, + expectModels: []string{"model.bin"}, + expectCodes: 
[]string{}, + expectName: "multi-config", + expectArch: "transformer", + expectFamily: "gpt2", + expectFormat: "pytorch", + expectPrecision: "float32", }, { name: "special filename characters", @@ -482,8 +475,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "special-chars", }, - ignoreUnrecognizedFileType: false, - expectError: false, + expectError: false, expectConfigs: []string{ "config with spaces.json", "dir-with-hyphens/config.json", @@ -520,8 +512,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "mixed-types", }, - ignoreUnrecognizedFileType: false, - expectError: false, + expectError: false, expectConfigs: []string{ "configs/main.json", "configs/params.yaml", @@ -554,8 +545,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "same-names", }, - ignoreUnrecognizedFileType: false, - expectError: false, + expectError: false, expectConfigs: []string{ "dir1/config.json", "dir2/config.json", @@ -604,8 +594,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { Name: "llama-7b", ParamSize: "7B", }, - ignoreUnrecognizedFileType: false, - expectError: false, + expectError: false, expectConfigs: []string{ "config.json", "generation_config.json", @@ -650,14 +639,13 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "conflict-test", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json", "generation_config.json"}, - expectModels: []string{"model.bin"}, - expectCodes: []string{}, - expectName: "conflict-test", - expectFamily: "llama", - expectPrecision: "float32", + expectError: false, + expectConfigs: []string{"config.json", "generation_config.json"}, + expectModels: []string{"model.bin"}, + expectCodes: []string{}, + expectName: "conflict-test", + expectFamily: "llama", + expectPrecision: "float32", }, { name: "skipping 
internal directories", @@ -679,12 +667,11 @@ func TestNewModelfileByWorkspace(t *testing.T) { config: &configmodelfile.GenerateConfig{ Name: "skip-test", }, - ignoreUnrecognizedFileType: false, - expectError: false, - expectConfigs: []string{"config.json"}, - expectModels: []string{"normal/model.bin"}, - expectCodes: []string{"valid_dir/model.py"}, - expectName: "skip-test", + expectError: false, + expectConfigs: []string{"config.json"}, + expectModels: []string{"normal/model.bin"}, + expectCodes: []string{"valid_dir/model.py"}, + expectName: "skip-test", }, } @@ -734,7 +721,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { // Set workspace in config tc.config.Workspace = tempDir - tc.config.IgnoreUnrecognizedFileTypes = tc.ignoreUnrecognizedFileType + tc.config.IgnoreUnrecognizedFileTypes = false // Call the function being tested mf, err := NewModelfileByWorkspace(tempDir, tc.config) @@ -1202,6 +1189,331 @@ func createHashSet(items []string) *hashset.Set { for _, item := range items { set.Add(item) } - return set } + +func TestValidateWorkspace(t *testing.T) { + tests := []struct { + name string + setupFunc func() (string, func()) // returns workspace path and cleanup function + expectedError string + }{ + { + name: "valid_directory", + setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "validate_workspace_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + // Create a test file to make directory non-empty + testFile := filepath.Join(tmpDir, "test.txt") + err = os.WriteFile(testFile, []byte("test content"), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create test file: %v", err) + } + + return tmpDir, func() { os.RemoveAll(tmpDir) } + }, + expectedError: "", + }, + { + name: "non_existent_directory", + setupFunc: func() (string, func()) { + return "/non/existent/path", func() {} + }, + expectedError: "access to workspace failed:", + }, + { + name: "file_instead_of_directory", + 
setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "validate_workspace_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + testFile := filepath.Join(tmpDir, "test.txt") + err = os.WriteFile(testFile, []byte("test content"), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create test file: %v", err) + } + + return testFile, func() { os.RemoveAll(tmpDir) } + }, + expectedError: "the workspace is not a directory:", + }, + { + name: "empty_directory", + setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "validate_workspace_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + return tmpDir, func() { os.RemoveAll(tmpDir) } + }, + expectedError: "the workspace is empty:", + }, + { + name: "symbolic_link", + setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "validate_workspace_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + // Create target directory with content + targetDir := filepath.Join(tmpDir, "target") + err = os.Mkdir(targetDir, 0755) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create target dir: %v", err) + } + + testFile := filepath.Join(targetDir, "test.txt") + err = os.WriteFile(testFile, []byte("test content"), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create test file: %v", err) + } + + // Create symbolic link + linkPath := filepath.Join(tmpDir, "link") + err = os.Symlink(targetDir, linkPath) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create symlink: %v", err) + } + + return linkPath, func() { os.RemoveAll(tmpDir) } + }, + expectedError: "for simplicity, the workspace should not be a symbolic link:", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + workspace, cleanup := tt.setupFunc() + defer cleanup() + + mf := &modelfile{ + workspace: workspace, + config: hashset.New(), + model: 
hashset.New(), + code: hashset.New(), + dataset: hashset.New(), + doc: hashset.New(), + } + + err := mf.validateWorkspace() + + if tt.expectedError == "" { + assert.NoError(t, err, "Expected no error for test case: %s", tt.name) + } else { + assert.Error(t, err, "Expected error for test case: %s", tt.name) + assert.Contains(t, err.Error(), tt.expectedError, "Error message should contain expected text for test case: %s", tt.name) + } + }) + } +} + +func TestWorkspaceLimits(t *testing.T) { + tests := []struct { + name string + setupFunc func() (string, func()) // returns workspace path and cleanup function + expectedError string + }{ + { + name: "single_file_exceeds_128GB_limit", + setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "workspace_limits_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + // Create a test file that simulates exceeding 128GB + // We'll use a sparse file to avoid actually creating 128GB+ of data + testFile := filepath.Join(tmpDir, "large_model.bin") + file, err := os.Create(testFile) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create test file: %v", err) + } + + // Seek to position that would make file appear larger than 128GB + largeSize := MaxSingleFileSize + 1 + _, err = file.Seek(largeSize-1, 0) + if err != nil { + file.Close() + os.RemoveAll(tmpDir) + t.Fatalf("Failed to seek in file: %v", err) + } + + // Write one byte at the end to make the file that size + _, err = file.Write([]byte{0}) + if err != nil { + file.Close() + os.RemoveAll(tmpDir) + t.Fatalf("Failed to write to file: %v", err) + } + file.Close() + + return tmpDir, func() { os.RemoveAll(tmpDir) } + }, + expectedError: "exceeds maximum single file size limit", + }, + { + name: "file_count_exceeds_2048_limit", + setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "workspace_limits_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + // Create more than MaxWorkspaceFileCount (1024) 
files + for i := 0; i <= MaxWorkspaceFileCount; i++ { + testFile := filepath.Join(tmpDir, fmt.Sprintf("file_%d.txt", i)) + err = os.WriteFile(testFile, []byte("test"), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create test file %d: %v", i, err) + } + } + + return tmpDir, func() { os.RemoveAll(tmpDir) } + }, + expectedError: "exceeds maximum file count limit", + }, + { + name: "total_workspace_size_exceeds_8TB_limit", + setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "workspace_limits_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + // Create a few large files that together exceed 8TB + // Each file will be just under 128GB (single file limit) + // We'll create 70 files of ~120GB each to exceed 8TB total + fileSize := MaxSingleFileSize - (1024 * 1024 * 1024) // 127GB per file + numFiles := 70 // 70 * 127GB = ~8.9TB + + for i := 0; i < numFiles; i++ { + testFile := filepath.Join(tmpDir, fmt.Sprintf("file_%d.bin", i)) + file, err := os.Create(testFile) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create test file %d: %v", i, err) + } + + // Use sparse file technique + _, err = file.Seek(fileSize-1, 0) + if err != nil { + file.Close() + os.RemoveAll(tmpDir) + t.Fatalf("Failed to seek in file %d: %v", i, err) + } + + _, err = file.Write([]byte{0}) + if err != nil { + file.Close() + os.RemoveAll(tmpDir) + t.Fatalf("Failed to write to file %d: %v", i, err) + } + file.Close() + } + + return tmpDir, func() { os.RemoveAll(tmpDir) } + }, + expectedError: "exceeds maximum total size limit", + }, + { + name: "workspace_within_all_limits", + setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "workspace_limits_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + // Create a reasonable number of small files + for i := 0; i < 10; i++ { + testFile := filepath.Join(tmpDir, fmt.Sprintf("small_file_%d.txt", i)) + err = 
os.WriteFile(testFile, []byte("small content"), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create test file %d: %v", i, err) + } + } + + // Add a config file to make it a valid workspace + configFile := filepath.Join(tmpDir, "config.json") + err = os.WriteFile(configFile, []byte(`{"model_type": "test"}`), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create config file: %v", err) + } + + // Add a model file to make it a valid workspace + modelFile := filepath.Join(tmpDir, "model.safetensors") + err = os.WriteFile(modelFile, []byte("fake model data"), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create model file: %v", err) + } + + return tmpDir, func() { os.RemoveAll(tmpDir) } + }, + expectedError: "", + }, + { + name: "exactly_at_file_count_limit", + setupFunc: func() (string, func()) { + tmpDir, err := os.MkdirTemp("", "workspace_limits_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + // Create exactly 2048 files (should be allowed) + // Include one model file to make it valid + modelFile := filepath.Join(tmpDir, "model.safetensors") + err = os.WriteFile(modelFile, []byte("fake model data"), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create model file: %v", err) + } + + // Create the remaining files to reach exactly 2048 + for i := 1; i < MaxWorkspaceFileCount; i++ { + testFile := filepath.Join(tmpDir, fmt.Sprintf("file_%d.txt", i)) + err = os.WriteFile(testFile, []byte("test"), 0644) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create test file %d: %v", i, err) + } + } + + return tmpDir, func() { os.RemoveAll(tmpDir) } + }, + expectedError: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + workspace, cleanup := tt.setupFunc() + defer cleanup() + + // Create a modelfile instance and try to generate by workspace + config := &configmodelfile.GenerateConfig{} + _, err := 
NewModelfileByWorkspace(workspace, config) + + if tt.expectedError == "" { + assert.NoError(t, err, "Expected no error for test case: %s", tt.name) + } else { + assert.Error(t, err, "Expected error for test case: %s", tt.name) + assert.Contains(t, err.Error(), tt.expectedError, "Error message should contain expected text for test case: %s", tt.name) + } + }) + } +}