From 8e81b1643d380afc5eaef11485ae1522e8379dc6 Mon Sep 17 00:00:00 2001 From: Kangyan Zhou Date: Sun, 15 Mar 2026 20:33:24 -0700 Subject: [PATCH 1/6] feat: skip re-downloading models on shared storage When multiple nodes mount the same filesystem (e.g., GPFS/NFS at /storage/models), the model-agent on each node would independently re-download from HuggingFace or OCI, causing rate-limiting and hours of unnecessary I/O. Add isModelAlreadyDownloaded() that checks: 1. config.json exists 2. If model.safetensors.index.json exists, ALL expected shards present 3. Otherwise, at least one weight file (.safetensors/.bin/.pt/.gguf) Only applies to fresh Download tasks (not DownloadOverride) so spec updates and failed retries still re-evaluate. Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/modelagent/gopher.go | 119 ++++++++++++++++++++++++++++++++++ pkg/modelagent/gopher_test.go | 82 +++++++++++++++++++++++ 2 files changed, 201 insertions(+) diff --git a/pkg/modelagent/gopher.go b/pkg/modelagent/gopher.go index ff48c041..73a965b6 100644 --- a/pkg/modelagent/gopher.go +++ b/pkg/modelagent/gopher.go @@ -2,6 +2,7 @@ package modelagent import ( "context" + "encoding/json" "fmt" "os" "path/filepath" @@ -324,6 +325,25 @@ func (s *Gopher) processTask(task *GopherTask) error { s.logger.Errorf("Failed to get target directory path for model %s: %v", modelInfo, err) return err } + + // Check if the model is already present on shared storage. + // Only for fresh Download tasks — DownloadOverride indicates a spec change + // or failed retry that must re-evaluate the model files. + if task.TaskType == Download && s.isModelAlreadyDownloaded(destPath) { + s.logger.Infof("Model %s already exists at %s (shared storage), skipping OCI download", modelInfo, destPath) + var baseModel *v1beta1.BaseModel + var clusterBaseModel *v1beta1.ClusterBaseModel + if task.BaseModel != nil { + baseModel = task.BaseModel + } else if task.ClusterBaseModel != nil { + clusterBaseModel = task.ClusterBaseModel + } + if err := s.safeParseAndUpdateModelConfig(destPath, baseModel, clusterBaseModel, nil); err != nil { + s.logger.Errorf("Failed to parse and update model config for pre-existing model: %v", err) + } + break + } + err = utils.Retry(s.downloadRetry, 100*time.Millisecond, func() error { downloadErr := s.downloadModel(ctx, osUri, destPath, task) if downloadErr != nil { @@ -971,6 +991,29 @@ func (s *Gopher) processHuggingFaceModel(ctx context.Context, task *GopherTask, // Create destination path destPath := getDestPath(&baseModelSpec, s.modelRootDir) + // Check if the model is already present on shared storage (e.g., another node + // already downloaded it to the same NFS/shared filesystem path). When storage is + // shared across nodes, each model-agent would otherwise independently re-download + // from HuggingFace, causing rate-limiting and hours of unnecessary I/O. + // Only for fresh Download tasks — DownloadOverride indicates a spec change + // or failed retry that must re-evaluate the model files. + if task.TaskType == Download && s.isModelAlreadyDownloaded(destPath) { + s.logger.Infof("Model %s already exists at %s (shared storage), skipping HuggingFace download", modelInfo, destPath) + + var baseModel *v1beta1.BaseModel + var clusterBaseModel *v1beta1.ClusterBaseModel + if task.BaseModel != nil { + baseModel = task.BaseModel + } else if task.ClusterBaseModel != nil { + clusterBaseModel = task.ClusterBaseModel + } + + if err := s.safeParseAndUpdateModelConfig(destPath, baseModel, clusterBaseModel, nil); err != nil { + s.logger.Errorf("Failed to parse and update model config for pre-existing model: %v", err) + } + return nil + } + // fetch sha value based on model ID from Huggingface model API shaStr, isShaAvailable := s.fetchSha(ctx, hfComponents.ModelID, name) isReuseEligible, matchedModelTypeAndModeName, parentPath := s.isEligibleForOptimization(ctx, task, baseModelSpec, modelType, namespace, isShaAvailable, shaStr, name) @@ -1440,3 +1483,79 @@ func (s *Gopher) isRemoveParentArtifactDirectory(ctx context.Context, hasChildre s.logger.Infof("parent entry %s:%s exists on node configmap: %v", parentName, parentDir, exists) return !exists } + +// isModelAlreadyDownloaded checks whether the model files are already present at +// destPath. This handles the shared-storage case: when multiple nodes mount the +// same filesystem (e.g., NFS at /storage/models), the first node that finishes an +// HF download writes the files once. Subsequent nodes should detect the existing +// files and skip re-downloading. +// +// The check requires model.safetensors.index.json to be present so that ALL +// expected shards can be verified. Without the index file, completeness cannot +// be determined and the method returns false (letting the normal download proceed). +func (s *Gopher) isModelAlreadyDownloaded(destPath string) bool { + // Check if directory exists + info, err := os.Stat(destPath) + if err != nil || !info.IsDir() { + return false + } + + // Check for config.json (primary indicator of a complete HF download). + // Use err != nil (not os.IsNotExist) so that NFS I/O errors and permission + // errors are treated conservatively as "not present" rather than silently + // falling through as "exists". + configPath := filepath.Join(destPath, "config.json") + if _, err := os.Stat(configPath); err != nil { + return false + } + + // Require model.safetensors.index.json for shard completeness verification. + // Without it we cannot determine if the download is complete, so fall through + // to let the normal download path handle it. + indexPath := filepath.Join(destPath, "model.safetensors.index.json") + indexData, err := os.ReadFile(indexPath) + if err != nil { + return false + } + + var index struct { + WeightMap map[string]string `json:"weight_map"` + } + if err := json.Unmarshal(indexData, &index); err != nil { + s.logger.Warnf("Failed to parse model index file %s: %v", indexPath, err) + return false + } + + if len(index.WeightMap) == 0 { + return false + } + + // Build a set of filenames for fast lookup + entries, err := os.ReadDir(destPath) + if err != nil { + return false + } + fileSet := make(map[string]bool, len(entries)) + for _, entry := range entries { + if !entry.IsDir() { + fileSet[entry.Name()] = true + } + } + + // Collect unique shard filenames and verify every one exists on disk + expectedShards := make(map[string]bool) + for _, shard := range index.WeightMap { + expectedShards[shard] = true + } + + for shard := range expectedShards { + if !fileSet[shard] { + s.logger.Infof("Model at %s is missing shard %s (expected %d shards), not treating as complete", + destPath, shard, len(expectedShards)) + return false + } + } + + s.logger.Infof("Model at %s has all %d expected shards from index", destPath, len(expectedShards)) + return true +} diff --git a/pkg/modelagent/gopher_test.go b/pkg/modelagent/gopher_test.go index b6d56f36..ad074a47 100644 --- a/pkg/modelagent/gopher_test.go +++ b/pkg/modelagent/gopher_test.go @@ -5,6 +5,8 @@ import ( "encoding/json" "errors" "fmt" + "os" + "path/filepath" "testing" "k8s.io/apimachinery/pkg/runtime/schema" @@ -961,6 +963,86 @@ func TestIsEligibleForOptimization_NoMatch(t *testing.T) { assert.Empty(t, parent) } +func TestIsModelAlreadyDownloaded(t *testing.T) { + logger, _ := zap.NewDevelopment() + sugaredLogger := logger.Sugar() + defer func() { _ = sugaredLogger.Sync() }() + + gopher := &Gopher{logger: sugaredLogger} + + t.Run("nonexistent directory returns false", func(t *testing.T) { + assert.False(t, gopher.isModelAlreadyDownloaded("/nonexistent/path/that/does/not/exist")) + }) + + t.Run("empty directory returns false", func(t *testing.T) { + dir := t.TempDir() + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("directory with only config.json returns false (no weights)", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("directory with only weights returns false (no config.json)", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("weight data"), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("config.json and weights but no index file returns false", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("weight data"), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(dir), "without index file, cannot verify completeness") + }) + + t.Run("file path instead of directory returns false", func(t *testing.T) { + dir := t.TempDir() + filePath := filepath.Join(dir, "somefile") + assert.NoError(t, os.WriteFile(filePath, []byte("data"), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(filePath)) + }) + + // Shard completeness tests using model.safetensors.index.json + + t.Run("index with all shards present returns true", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) + index := `{"metadata":{"total_size":100},"weight_map":{"w1":"model-00001-of-00002.safetensors","w2":"model-00002-of-00002.safetensors"}}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(index), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("data"), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("data"), 0644)) + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("index with missing shard returns false", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) + index := `{"metadata":{"total_size":100},"weight_map":{"w1":"model-00001-of-00002.safetensors","w2":"model-00002-of-00002.safetensors"}}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(index), 0644)) + // Only write shard 1, shard 2 is missing + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("data"), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("malformed index file returns false", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(`{invalid json`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00001.safetensors"), []byte("data"), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("index with empty weight_map returns false", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(`{"weight_map":{}}`), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) +} + func TestIsEligibleForOptimization_AlwaysDownloadNotEligible(t *testing.T) { nodeName := "node-1" sha := "123abc" From 3b1140ec075bc279d0334ad34948366de9747d04 Mon Sep 17 00:00:00 2001 From: Kangyan Zhou Date: Thu, 2 Apr 2026 23:00:04 -0700 Subject: [PATCH 2/6] feat: extend shared-storage skip to diffusion pipelines and single-file models Extend isModelAlreadyDownloaded() to handle three model layouts: 1. Sharded safetensors (existing): verify all shards via index 2. Diffusion pipelines (new): verify component dirs via model_index.json 3. Single-file fallback (new): config.json + weight file heuristic Also: - Propagate safeParseAndUpdateModelConfig errors instead of swallowing - Add path traversal guard for untrusted JSON keys - Add detailed logging at every decision point for debugging - Differentiate os.Stat permission errors from "not exist" Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/modelagent/gopher.go | 184 +++++++++++++++++++++++++++++----- pkg/modelagent/gopher_test.go | 169 +++++++++++++++++++++++++++---- 2 files changed, 311 insertions(+), 42 deletions(-) diff --git a/pkg/modelagent/gopher.go b/pkg/modelagent/gopher.go index 73a965b6..707bd26b 100644 --- a/pkg/modelagent/gopher.go +++ b/pkg/modelagent/gopher.go @@ -339,7 +339,7 @@ func (s *Gopher) processTask(task *GopherTask) error { clusterBaseModel = task.ClusterBaseModel } if err := s.safeParseAndUpdateModelConfig(destPath, baseModel, clusterBaseModel, nil); err != nil { - s.logger.Errorf("Failed to parse and update model config for pre-existing model: %v", err) + return fmt.Errorf("model files exist at %s but config update failed: %w", destPath, err) } break } @@ -1009,7 +1009,7 @@ func (s *Gopher) processHuggingFaceModel(ctx context.Context, task *GopherTask, } if err := s.safeParseAndUpdateModelConfig(destPath, baseModel, clusterBaseModel, nil); err != nil { - s.logger.Errorf("Failed to parse and update model config for pre-existing model: %v", err) + return fmt.Errorf("model files exist at %s but config update failed: %w", destPath, err) } return nil } @@ -1490,31 +1490,65 @@ func (s *Gopher) isRemoveParentArtifactDirectory(ctx context.Context, hasChildre // HF download writes the files once. Subsequent nodes should detect the existing // files and skip re-downloading. // -// The check requires model.safetensors.index.json to be present so that ALL -// expected shards can be verified. Without the index file, completeness cannot -// be determined and the method returns false (letting the normal download proceed). +// Supports three model layouts: +// 1. Sharded safetensors: model.safetensors.index.json lists all expected shards. +// 2. Diffusion pipelines: model_index.json lists component subdirectories, each +// containing its own config and weight files. +// 3. Single-file models: no index file, but config.json + at least one weight +// file (.safetensors, .bin, .pt, .gguf) present. Note: this fallback cannot +// verify shard completeness for multi-shard models without an index file. +// +// All filesystem checks treat errors conservatively as "not present" so that +// NFS I/O or permission errors fall through to the normal download path rather +// than silently skipping the download. func (s *Gopher) isModelAlreadyDownloaded(destPath string) bool { // Check if directory exists info, err := os.Stat(destPath) - if err != nil || !info.IsDir() { + if err != nil { + if os.IsNotExist(err) { + s.logger.Infof("isModelAlreadyDownloaded(%s): directory does not exist", destPath) + } else { + s.logger.Warnf("isModelAlreadyDownloaded(%s): failed to stat directory: %v", destPath, err) + } return false } - - // Check for config.json (primary indicator of a complete HF download). - // Use err != nil (not os.IsNotExist) so that NFS I/O errors and permission - // errors are treated conservatively as "not present" rather than silently - // falling through as "exists". - configPath := filepath.Join(destPath, "config.json") - if _, err := os.Stat(configPath); err != nil { + if !info.IsDir() { + s.logger.Warnf("isModelAlreadyDownloaded(%s): path exists but is not a directory", destPath) return false } - // Require model.safetensors.index.json for shard completeness verification. - // Without it we cannot determine if the download is complete, so fall through - // to let the normal download path handle it. + // Try each layout in order of specificity. + // When an index file exists, its verdict is authoritative — we don't fall + // through to a weaker check that might accept an incomplete download. + + // 1. Sharded safetensors with model.safetensors.index.json + indexPath := filepath.Join(destPath, "model.safetensors.index.json") + if _, err := os.Stat(indexPath); err == nil { + return s.checkSafetensorsIndex(destPath) + } + + // 2. Diffusion pipeline with model_index.json + diffIndexPath := filepath.Join(destPath, "model_index.json") + if _, err := os.Stat(diffIndexPath); err == nil { + return s.checkDiffusionIndex(destPath) + } + + // 3. Fallback: config.json + at least one weight file + if s.checkConfigAndWeights(destPath) { + return true + } + + s.logger.Infof("isModelAlreadyDownloaded(%s): no known layout matched, will proceed with download", destPath) + return false +} + +// checkSafetensorsIndex verifies a sharded safetensors model by reading +// model.safetensors.index.json and ensuring every listed shard file exists. +func (s *Gopher) checkSafetensorsIndex(destPath string) bool { indexPath := filepath.Join(destPath, "model.safetensors.index.json") indexData, err := os.ReadFile(indexPath) if err != nil { + s.logger.Warnf("checkSafetensorsIndex(%s): failed to read index file: %v", destPath, err) return false } @@ -1522,17 +1556,17 @@ func (s *Gopher) isModelAlreadyDownloaded(destPath string) bool { WeightMap map[string]string `json:"weight_map"` } if err := json.Unmarshal(indexData, &index); err != nil { - s.logger.Warnf("Failed to parse model index file %s: %v", indexPath, err) + s.logger.Warnf("checkSafetensorsIndex(%s): failed to parse index file (may be mid-write by another node on shared storage): %v", destPath, err) return false } - if len(index.WeightMap) == 0 { + s.logger.Infof("checkSafetensorsIndex(%s): index has empty weight_map", destPath) return false } - // Build a set of filenames for fast lookup entries, err := os.ReadDir(destPath) if err != nil { + s.logger.Warnf("checkSafetensorsIndex(%s): failed to read directory: %v", destPath, err) return false } fileSet := make(map[string]bool, len(entries)) @@ -1542,20 +1576,124 @@ func (s *Gopher) isModelAlreadyDownloaded(destPath string) bool { } } - // Collect unique shard filenames and verify every one exists on disk + // Verify every expected shard exists expectedShards := make(map[string]bool) for _, shard := range index.WeightMap { expectedShards[shard] = true } - for shard := range expectedShards { if !fileSet[shard] { - s.logger.Infof("Model at %s is missing shard %s (expected %d shards), not treating as complete", + s.logger.Infof("checkSafetensorsIndex(%s): missing shard %s (expected %d shards), not treating as complete", destPath, shard, len(expectedShards)) return false } } - s.logger.Infof("Model at %s has all %d expected shards from index", destPath, len(expectedShards)) + s.logger.Infof("checkSafetensorsIndex(%s): verified, all %d shards present", destPath, len(expectedShards)) return true } + +// checkDiffusionIndex verifies a diffusion pipeline model by reading +// model_index.json and ensuring every listed component subdirectory exists +// and is non-empty. +func (s *Gopher) checkDiffusionIndex(destPath string) bool { + indexPath := filepath.Join(destPath, "model_index.json") + indexData, err := os.ReadFile(indexPath) + if err != nil { + s.logger.Warnf("checkDiffusionIndex(%s): failed to read index file: %v", destPath, err) + return false + } + + var index map[string]interface{} + if err := json.Unmarshal(indexData, &index); err != nil { + s.logger.Warnf("checkDiffusionIndex(%s): failed to parse index file (may be mid-write by another node on shared storage): %v", destPath, err) + return false + } + + // Components are top-level keys that don't start with "_" (e.g. "transformer", + // "vae", "text_encoder"). Keys like "_class_name" and "_diffusers_version" are + // metadata. + componentCount := 0 + for key, val := range index { + if len(key) > 0 && key[0] == '_' { + continue + } + // Component values are arrays like ["diffusers", "ClassName"] or null + if val == nil { + continue + } + // Guard against path traversal from untrusted JSON keys + if filepath.Base(key) != key { + s.logger.Warnf("checkDiffusionIndex(%s): skipping suspicious component key %q", destPath, key) + return false + } + componentCount++ + compDir := filepath.Join(destPath, key) + dirInfo, err := os.Stat(compDir) + if err != nil { + s.logger.Warnf("checkDiffusionIndex(%s): failed to stat component directory %s: %v", destPath, key, err) + return false + } + if !dirInfo.IsDir() { + s.logger.Warnf("checkDiffusionIndex(%s): component %s exists but is not a directory", destPath, key) + return false + } + // Check that the component directory is not empty + compEntries, err := os.ReadDir(compDir) + if err != nil { + s.logger.Warnf("checkDiffusionIndex(%s): failed to read component directory %s: %v", destPath, key, err) + return false + } + if len(compEntries) == 0 { + s.logger.Infof("checkDiffusionIndex(%s): component directory %s is empty", destPath, key) + return false + } + } + + if componentCount == 0 { + s.logger.Infof("checkDiffusionIndex(%s): model_index.json has no components", destPath) + return false + } + + s.logger.Infof("checkDiffusionIndex(%s): verified, all %d components present", destPath, componentCount) + return true +} + +// checkConfigAndWeights is a fallback check for single-file models or models +// without an index file. It verifies config.json exists and at least one +// weight file (.safetensors, .bin, .pt, .gguf) is present. +func (s *Gopher) checkConfigAndWeights(destPath string) bool { + configPath := filepath.Join(destPath, "config.json") + if _, err := os.Stat(configPath); err != nil { + s.logger.Infof("checkConfigAndWeights(%s): no config.json found", destPath) + return false + } + + entries, err := os.ReadDir(destPath) + if err != nil { + s.logger.Warnf("checkConfigAndWeights(%s): failed to read directory: %v", destPath, err) + return false + } + + weightExtensions := map[string]bool{ + ".safetensors": true, + ".bin": true, + ".pt": true, + ".gguf": true, + } + for _, entry := range entries { + if entry.IsDir() { + continue + } + ext := filepath.Ext(entry.Name()) + if weightExtensions[ext] { + s.logger.Warnf("checkConfigAndWeights(%s): using fallback heuristic (no index file found). "+ + "config.json + weight file %s found, but shard completeness cannot be fully verified. "+ + "If the model fails to load, re-trigger a DownloadOverride to force re-download.", destPath, entry.Name()) + return true + } + } + + s.logger.Infof("checkConfigAndWeights(%s): config.json exists but no weight files found", destPath) + return false +} diff --git a/pkg/modelagent/gopher_test.go b/pkg/modelagent/gopher_test.go index ad074a47..2ae97813 100644 --- a/pkg/modelagent/gopher_test.go +++ b/pkg/modelagent/gopher_test.go @@ -985,19 +985,6 @@ func TestIsModelAlreadyDownloaded(t *testing.T) { assert.False(t, gopher.isModelAlreadyDownloaded(dir)) }) - t.Run("directory with only weights returns false (no config.json)", func(t *testing.T) { - dir := t.TempDir() - assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("weight data"), 0644)) - assert.False(t, gopher.isModelAlreadyDownloaded(dir)) - }) - - t.Run("config.json and weights but no index file returns false", func(t *testing.T) { - dir := t.TempDir() - assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) - assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("weight data"), 0644)) - assert.False(t, gopher.isModelAlreadyDownloaded(dir), "without index file, cannot verify completeness") - }) - t.Run("file path instead of directory returns false", func(t *testing.T) { dir := t.TempDir() filePath := filepath.Join(dir, "somefile") @@ -1005,9 +992,9 @@ func TestIsModelAlreadyDownloaded(t *testing.T) { assert.False(t, gopher.isModelAlreadyDownloaded(filePath)) }) - // Shard completeness tests using model.safetensors.index.json + // --- Sharded safetensors tests --- - t.Run("index with all shards present returns true", func(t *testing.T) { + t.Run("safetensors index with all shards present returns true", func(t *testing.T) { dir := t.TempDir() assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) index := `{"metadata":{"total_size":100},"weight_map":{"w1":"model-00001-of-00002.safetensors","w2":"model-00002-of-00002.safetensors"}}` @@ -1017,28 +1004,172 @@ func TestIsModelAlreadyDownloaded(t *testing.T) { assert.True(t, gopher.isModelAlreadyDownloaded(dir)) }) - t.Run("index with missing shard returns false", func(t *testing.T) { + t.Run("safetensors index with missing shard returns false", func(t *testing.T) { dir := t.TempDir() assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) index := `{"metadata":{"total_size":100},"weight_map":{"w1":"model-00001-of-00002.safetensors","w2":"model-00002-of-00002.safetensors"}}` assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(index), 0644)) - // Only write shard 1, shard 2 is missing assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("data"), 0644)) assert.False(t, gopher.isModelAlreadyDownloaded(dir)) }) - t.Run("malformed index file returns false", func(t *testing.T) { + t.Run("malformed safetensors index returns false (no fallback)", func(t *testing.T) { dir := t.TempDir() assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(`{invalid json`), 0644)) assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00001.safetensors"), []byte("data"), 0644)) + // Index exists but is malformed → authoritative false, no fallback assert.False(t, gopher.isModelAlreadyDownloaded(dir)) }) - t.Run("index with empty weight_map returns false", func(t *testing.T) { + t.Run("safetensors index with empty weight_map returns false (no fallback)", func(t *testing.T) { dir := t.TempDir() assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(`{"weight_map":{}}`), 0644)) + // Index exists but empty → authoritative false, no fallback + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + // --- Diffusion pipeline tests --- + + t.Run("diffusion model with all components returns true", func(t *testing.T) { + dir := t.TempDir() + index := `{"_class_name":"QwenImagePipeline","_diffusers_version":"0.36.0","scheduler":["diffusers","FlowMatchEulerDiscreteScheduler"],"transformer":["diffusers","QwenTransformer2DModel"],"vae":["diffusers","AutoencoderKL"]}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(index), 0644)) + // Create component subdirectories with at least one file each + for _, comp := range []string{"scheduler", "transformer", "vae"} { + compDir := filepath.Join(dir, comp) + assert.NoError(t, os.MkdirAll(compDir, 0755)) + assert.NoError(t, os.WriteFile(filepath.Join(compDir, "config.json"), []byte(`{}`), 0644)) + } + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("diffusion model with missing component returns false", func(t *testing.T) { + dir := t.TempDir() + index := `{"_class_name":"QwenImagePipeline","scheduler":["diffusers","Scheduler"],"transformer":["diffusers","Transformer"],"vae":["diffusers","VAE"]}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(index), 0644)) + // Only create scheduler and transformer, vae is missing + for _, comp := range []string{"scheduler", "transformer"} { + compDir := filepath.Join(dir, comp) + assert.NoError(t, os.MkdirAll(compDir, 0755)) + assert.NoError(t, os.WriteFile(filepath.Join(compDir, "config.json"), []byte(`{}`), 0644)) + } + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("diffusion model with empty component directory returns false", func(t *testing.T) { + dir := t.TempDir() + index := `{"_class_name":"Pipeline","transformer":["diffusers","Model"]}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(index), 0644)) + assert.NoError(t, os.MkdirAll(filepath.Join(dir, "transformer"), 0755)) + // transformer dir exists but is empty + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("diffusion model with null component is skipped", func(t *testing.T) { + dir := t.TempDir() + index := `{"_class_name":"Pipeline","transformer":["diffusers","Model"],"safety_checker":null}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(index), 0644)) + compDir := filepath.Join(dir, "transformer") + assert.NoError(t, os.MkdirAll(compDir, 0755)) + assert.NoError(t, os.WriteFile(filepath.Join(compDir, "model.safetensors"), []byte("data"), 0644)) + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + // --- Fallback: config.json + weight file --- + + t.Run("config.json and single safetensors file returns true", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("weight data"), 0644)) + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("config.json and .bin weight file returns true", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"bert"}`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "pytorch_model.bin"), []byte("weight data"), 0644)) + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("config.json and .gguf weight file returns true", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-q4.gguf"), []byte("weight data"), 0644)) + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("config.json and .pt weight file returns true", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.pt"), []byte("weight data"), 0644)) + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("only weights without config.json returns false", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("weight data"), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + // --- Strategy ordering tests --- + + t.Run("safetensors index takes priority over diffusion index", func(t *testing.T) { + dir := t.TempDir() + // Failing safetensors index (missing shard) + stIndex := `{"weight_map":{"w1":"shard-00001.safetensors","w2":"shard-00002.safetensors"}}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(stIndex), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "shard-00001.safetensors"), []byte("data"), 0644)) + // shard-00002 is missing + + // Passing diffusion index + diffIndex := `{"_class_name":"Pipeline","encoder":["diffusers","Encoder"]}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(diffIndex), 0644)) + assert.NoError(t, os.MkdirAll(filepath.Join(dir, "encoder"), 0755)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "encoder", "config.json"), []byte(`{}`), 0644)) + + // Safetensors check is authoritative — must return false despite valid diffusion layout + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("safetensors index with many weights mapping to same shard returns true", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644)) + index := `{"weight_map":{"layer1.weight":"model-00001-of-00001.safetensors","layer1.bias":"model-00001-of-00001.safetensors","layer2.weight":"model-00001-of-00001.safetensors"}}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(index), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00001.safetensors"), []byte("data"), 0644)) + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + // --- Additional diffusion edge cases --- + + t.Run("malformed diffusion index returns false", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(`{not valid json`), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("diffusion index with only metadata keys returns false", func(t *testing.T) { + dir := t.TempDir() + index := `{"_class_name":"Pipeline","_diffusers_version":"0.36.0"}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(index), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("diffusion component exists as file not directory returns false", func(t *testing.T) { + dir := t.TempDir() + index := `{"_class_name":"Pipeline","transformer":["diffusers","Model"]}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(index), 0644)) + // "transformer" is a file, not a directory + assert.NoError(t, os.WriteFile(filepath.Join(dir, "transformer"), []byte("not a directory"), 0644)) + assert.False(t, gopher.isModelAlreadyDownloaded(dir)) + }) + + t.Run("diffusion index with path traversal component key returns false", func(t *testing.T) { + dir := t.TempDir() + index := `{"_class_name":"Pipeline","../etc":["diffusers","Exploit"]}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(index), 0644)) assert.False(t, gopher.isModelAlreadyDownloaded(dir)) }) } From 8d0baf69396a453e6edc1dd7dba014a9f3464f51 Mon Sep 17 00:00:00 2001 From: Kangyan Zhou Date: Thu, 2 Apr 2026 23:51:04 -0700 Subject: [PATCH 3/6] feat: add K8s Lease-based leader election for shared storage downloads On shared filesystems (NFS/GPFS/CephFS/Lustre), only one agent should download model files per model. Others wait with jitter and recheck. - Detect shared storage via syscall.Statfs filesystem magic numbers - Per-model K8s Leases (model-download-) for parallel downloads of different models while preventing duplicate downloads of the same model - Non-leaders wait up to 5.5min with 15s jitter between rechecks - Handle expired leases, API errors (IsNotFound vs transient), context cancellation - Guard against nil HolderIdentity, lease renewal conflicts - Use time.NewTimer with explicit Stop() to avoid timer leaks - Fall back to downloading if leader times out Co-Authored-By: Claude Opus 4.6 (1M context) --- cmd/model-agent/main.go | 2 + pkg/modelagent/gopher.go | 266 +++++++++++++++++++++++++++++++++++---- 2 files changed, 246 insertions(+), 22 deletions(-) diff --git a/cmd/model-agent/main.go b/cmd/model-agent/main.go index 08909beb..0e005a6b 100644 --- a/cmd/model-agent/main.go +++ b/cmd/model-agent/main.go @@ -268,6 +268,8 @@ func initializeComponents( logger, baseModelInformer.Lister(), clusterBaseModelInformer.Lister(), + cfg.nodeName, + cfg.namespace, ) if err != nil { return nil, nil, fmt.Errorf("failed to create gopher: %w", err) diff --git a/pkg/modelagent/gopher.go b/pkg/modelagent/gopher.go index 707bd26b..f6edf947 100644 --- a/pkg/modelagent/gopher.go +++ b/pkg/modelagent/gopher.go @@ -4,17 +4,21 @@ import ( "context" "encoding/json" "fmt" + "math/rand" "os" "path/filepath" "strings" "sync" "sync/atomic" + "syscall" "time" "k8s.io/apimachinery/pkg/labels" "github.com/oracle/oci-go-sdk/v65/objectstorage" "go.uber.org/zap" + coordinationv1 "k8s.io/api/coordination/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" @@ -64,6 +68,13 @@ type Gopher struct { // Track active downloads for cancellation activeDownloads map[string]context.CancelFunc // key: model UID activeDownloadsMutex sync.RWMutex + + // Shared storage coordination: when modelRootDir is on a shared filesystem + // (NFS, GPFS, CephFS, Lustre), only the download leader should download. + // Other agents wait with jitter and recheck for files on disk. + isSharedStorage bool + nodeName string + namespace string } const ( @@ -84,12 +95,19 @@ func NewGopher( metrics *Metrics, logger *zap.SugaredLogger, baseModelLister omev1beta1lister.BaseModelLister, - clusterBaseModelLister omev1beta1lister.ClusterBaseModelLister) (*Gopher, error) { + clusterBaseModelLister omev1beta1lister.ClusterBaseModelLister, + nodeName string, + namespace string) (*Gopher, error) { if xetConfig == nil { return nil, fmt.Errorf("xet hugging face config cannot be nil") } + shared := isSharedFilesystem(modelRootDir, logger) + if shared { + logger.Infof("Detected shared filesystem at %s — download leader election enabled", modelRootDir) + } + return &Gopher{ modelConfigParser: modelConfigParser, configMapReconciler: configMapReconciler, @@ -106,6 +124,9 @@ func NewGopher( activeDownloads: make(map[string]context.CancelFunc), baseModelLister: baseModelLister, clusterBaseModelLister: clusterBaseModelLister, + isSharedStorage: shared, + nodeName: nodeName, + namespace: namespace, }, nil } @@ -331,15 +352,8 @@ func (s *Gopher) processTask(task *GopherTask) error { // or failed retry that must re-evaluate the model files. if task.TaskType == Download && s.isModelAlreadyDownloaded(destPath) { s.logger.Infof("Model %s already exists at %s (shared storage), skipping OCI download", modelInfo, destPath) - var baseModel *v1beta1.BaseModel - var clusterBaseModel *v1beta1.ClusterBaseModel - if task.BaseModel != nil { - baseModel = task.BaseModel - } else if task.ClusterBaseModel != nil { - clusterBaseModel = task.ClusterBaseModel - } - if err := s.safeParseAndUpdateModelConfig(destPath, baseModel, clusterBaseModel, nil); err != nil { - return fmt.Errorf("model files exist at %s but config update failed: %w", destPath, err) + if err := s.skipDownloadAndUpdateConfig(destPath, task); err != nil { + return err } break } @@ -997,21 +1011,25 @@ func (s *Gopher) processHuggingFaceModel(ctx context.Context, task *GopherTask, // from HuggingFace, causing rate-limiting and hours of unnecessary I/O. // Only for fresh Download tasks — DownloadOverride indicates a spec change // or failed retry that must re-evaluate the model files. - if task.TaskType == Download && s.isModelAlreadyDownloaded(destPath) { - s.logger.Infof("Model %s already exists at %s (shared storage), skipping HuggingFace download", modelInfo, destPath) - - var baseModel *v1beta1.BaseModel - var clusterBaseModel *v1beta1.ClusterBaseModel - if task.BaseModel != nil { - baseModel = task.BaseModel - } else if task.ClusterBaseModel != nil { - clusterBaseModel = task.ClusterBaseModel + if task.TaskType == Download { + if s.isModelAlreadyDownloaded(destPath) { + s.logger.Infof("Model %s already exists at %s (shared storage), skipping HuggingFace download", modelInfo, destPath) + return s.skipDownloadAndUpdateConfig(destPath, task) } - if err := s.safeParseAndUpdateModelConfig(destPath, baseModel, clusterBaseModel, nil); err != nil { - return fmt.Errorf("model files exist at %s but config update failed: %w", destPath, err) + // On shared storage, only the download leader should proceed. Non-leaders + // wait for the leader to finish and then recheck for files on disk. + if s.isSharedStorage && !s.isDownloadLeader(ctx, modelInfo) { + if s.waitForSharedStorageModel(ctx, destPath, modelInfo) { + s.logger.Infof("Model %s appeared on shared storage at %s after waiting for leader", modelInfo, destPath) + return s.skipDownloadAndUpdateConfig(destPath, task) + } + if ctx.Err() != nil { + return fmt.Errorf("download cancelled while waiting for shared storage leader: %w", ctx.Err()) + } + // Timed out waiting — fall through to download as a fallback + s.logger.Warnf("Model %s not found after waiting for leader, proceeding with own download", modelInfo) } - return nil } // fetch sha value based on model ID from Huggingface model API @@ -1484,6 +1502,24 @@ func (s *Gopher) isRemoveParentArtifactDirectory(ctx context.Context, hasChildre return !exists } +// skipDownloadAndUpdateConfig handles the case where model files already exist +// at the destination path (e.g., downloaded by another node on shared storage, +// or left from a previous run). It parses the model config and updates the +// ConfigMap, bypassing the download step. +func (s *Gopher) skipDownloadAndUpdateConfig(destPath string, task *GopherTask) error { + var baseModel *v1beta1.BaseModel + var clusterBaseModel *v1beta1.ClusterBaseModel + if task.BaseModel != nil { + baseModel = task.BaseModel + } else if task.ClusterBaseModel != nil { + clusterBaseModel = task.ClusterBaseModel + } + if err := s.safeParseAndUpdateModelConfig(destPath, baseModel, clusterBaseModel, nil); err != nil { + return fmt.Errorf("model files exist at %s but config update failed: %w", destPath, err) + } + return nil +} + // isModelAlreadyDownloaded checks whether the model files are already present at // destPath. This handles the shared-storage case: when multiple nodes mount the // same filesystem (e.g., NFS at /storage/models), the first node that finishes an @@ -1697,3 +1733,189 @@ func (s *Gopher) checkConfigAndWeights(destPath string) bool { s.logger.Infof("checkConfigAndWeights(%s): config.json exists but no weight files found", destPath) return false } + +// isSharedFilesystem detects whether the given path is on a shared/network +// filesystem by checking the filesystem type via syscall.Statfs. +// Known shared filesystem types: NFS, GPFS, CephFS, Lustre, GlusterFS, FUSE. +// Note: filesystem type detection via magic numbers only works on Linux. +// On macOS/Darwin, Statfs_t has a different layout and this will return false. +func isSharedFilesystem(path string, logger *zap.SugaredLogger) bool { + var stat syscall.Statfs_t + if err := syscall.Statfs(path, &stat); err != nil { + logger.Warnf("isSharedFilesystem(%s): syscall.Statfs failed: %v — shared storage detection disabled", path, err) + return false + } + // Filesystem magic numbers (from linux/magic.h and kernel sources) + switch stat.Type { + case 0x6969: // NFS_SUPER_MAGIC + return true + case 0x47504653: // GPFS (IBM Spectrum Scale) + return true + case 0x00C36400: // CEPH_SUPER_MAGIC + return true + case 0x0BD00BD0: // LUSTRE_SUPER_MAGIC + return true + case 0x65735546: // FUSE_SUPER_MAGIC (commonly used for network mounts) + return true + case 0x6A656A62: // GlusterFS + return true + default: + return false + } +} + +const ( + // downloadLeaderLeasePrefix is the prefix for per-model K8s Leases used for + // leader election on shared storage. Each model gets its own lease so different + // models can be downloaded in parallel by different nodes. + downloadLeaderLeasePrefix = "model-download-" + + // downloadLeaderLeaseDuration is how long a leader holds a per-model lease. + downloadLeaderLeaseDuration = 5 * time.Minute + + // sharedStorageRecheckInterval is how often non-leaders recheck for files. + sharedStorageRecheckInterval = 30 * time.Second + + // sharedStorageMaxJitter is the max random jitter added before rechecks. + sharedStorageMaxJitter = 15 * time.Second +) + +// sanitizeLeaseeName converts a model identifier (e.g., "google/gemma-4-31B-it") +// into a valid K8s resource name (lowercase, no slashes, max 253 chars). +func sanitizeLeaseName(modelInfo string) string { + name := strings.ToLower(modelInfo) + name = strings.ReplaceAll(name, "/", "-") + name = strings.ReplaceAll(name, "_", "-") + name = strings.ReplaceAll(name, " ", "-") + name = strings.ReplaceAll(name, ".", "-") + // Strip any remaining non-alphanumeric/dash characters + filtered := make([]byte, 0, len(name)) + for i := 0; i < len(name); i++ { + c := name[i] + if (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' { + filtered = append(filtered, c) + } + } + name = string(filtered) + // Trim leading/trailing dashes + name = strings.Trim(name, "-") + // K8s names must be <= 253 chars + if len(name) > 200 { + name = name[:200] + } + return downloadLeaderLeasePrefix + name +} + +// isDownloadLeader checks if this node currently holds the per-model download +// Lease. Each model gets its own lease so different models can be downloaded in +// parallel by different nodes. If the lease doesn't exist, it tries to create it +// (becoming leader). If held by this node, it renews and returns true. If held by +// another node, it checks for expiry and attempts to take over; otherwise returns false. +func (s *Gopher) isDownloadLeader(ctx context.Context, modelInfo string) bool { + leaseName := sanitizeLeaseName(modelInfo) + leasesClient := s.kubeClient.CoordinationV1().Leases(s.namespace) + now := metav1.NewMicroTime(time.Now()) + + lease, err := leasesClient.Get(ctx, leaseName, metav1.GetOptions{}) + if err != nil { + if !apierrors.IsNotFound(err) { + // API server error — coordination unavailable, download independently + s.logger.Warnf("Failed to check download leader lease (API error): %v — this node will download independently", err) + return true + } + // Lease doesn't exist — try to create it and become leader + leaseDuration := int32(downloadLeaderLeaseDuration.Seconds()) + newLease := &coordinationv1.Lease{ + ObjectMeta: metav1.ObjectMeta{ + Name: leaseName, + Namespace: s.namespace, + }, + Spec: coordinationv1.LeaseSpec{ + HolderIdentity: &s.nodeName, + LeaseDurationSeconds: &leaseDuration, + AcquireTime: &now, + RenewTime: &now, + }, + } + _, createErr := leasesClient.Create(ctx, newLease, metav1.CreateOptions{}) + if createErr != nil { + s.logger.Infof("Failed to acquire download leader lease for %s (another node won): %v", modelInfo, createErr) + return false + } + s.logger.Infof("Acquired download leader lease for %s — this node (%s) will download", modelInfo, s.nodeName) + return true + } + + // Lease exists — check if we hold it or if it's expired + if lease.Spec.HolderIdentity != nil && *lease.Spec.HolderIdentity == s.nodeName { + // We hold it — renew + lease.Spec.RenewTime = &now + _, err = leasesClient.Update(ctx, lease, metav1.UpdateOptions{}) + if err != nil { + if apierrors.IsConflict(err) || apierrors.IsNotFound(err) { + s.logger.Warnf("Lost download leader lease during renewal: %v — yielding leadership", err) + return false + } + s.logger.Warnf("Failed to renew download leader lease (transient error): %v — proceeding as leader", err) + } + return true + } + + // Another node holds it — check if expired + if lease.Spec.RenewTime != nil && lease.Spec.LeaseDurationSeconds != nil { + expiry := lease.Spec.RenewTime.Time.Add(time.Duration(*lease.Spec.LeaseDurationSeconds) * time.Second) + if time.Now().After(expiry) { + // Expired — take over + lease.Spec.HolderIdentity = &s.nodeName + lease.Spec.AcquireTime = &now + lease.Spec.RenewTime = &now + _, err = leasesClient.Update(ctx, lease, metav1.UpdateOptions{}) + if err != nil { + s.logger.Infof("Failed to take over expired download leader lease: %v", err) + return false + } + s.logger.Infof("Took over expired download leader lease for %s — this node (%s) will download", modelInfo, s.nodeName) + return true + } + } + + holderID := "" + if lease.Spec.HolderIdentity != nil { + holderID = *lease.Spec.HolderIdentity + } + s.logger.Infof("Download leader lease for %s held by %s — this node (%s) will wait for shared storage files", + modelInfo, holderID, s.nodeName) + return false +} + +// waitForSharedStorageModel waits for a model to appear on shared storage, +// with jitter and periodic rechecks. Returns true if the model appeared (another +// node downloaded it), false if the context was cancelled or max wait exceeded. +// Maximum wait time is downloadLeaderLeaseDuration + 30s (currently 5m30s). +func (s *Gopher) waitForSharedStorageModel(ctx context.Context, destPath string, modelInfo string) bool { + maxWait := downloadLeaderLeaseDuration + 30*time.Second + deadline := time.Now().Add(maxWait) + + for time.Now().Before(deadline) { + // Add random jitter to avoid thundering herd on recheck + jitter := time.Duration(rand.Int63n(int64(sharedStorageMaxJitter))) + s.logger.Infof("Shared storage: waiting %v before rechecking %s for model %s", sharedStorageRecheckInterval+jitter, destPath, modelInfo) + + timer := time.NewTimer(sharedStorageRecheckInterval + jitter) + select { + case <-ctx.Done(): + timer.Stop() + s.logger.Infof("Shared storage: wait cancelled for model %s at %s: %v", modelInfo, destPath, ctx.Err()) + return false + case <-timer.C: + } + + if s.isModelAlreadyDownloaded(destPath) { + s.logger.Infof("Shared storage: model %s appeared at %s (downloaded by leader)", modelInfo, destPath) + return true + } + } + + s.logger.Warnf("Shared storage: timed out waiting for model %s at %s — will attempt own download", modelInfo, destPath) + return false +} From 256c2e7af9be79e99be9d7d88cce5282b6697613 Mon Sep 17 00:00:00 2001 From: Kangyan Zhou Date: Fri, 3 Apr 2026 00:35:08 -0700 Subject: [PATCH 4/6] fix: handle non-component entries in diffusion model_index.json Real-world model_index.json files contain non-component entries like: - "boundary_ratio": 0.9 (float metadata) - "image_encoder": [null, null] (disabled component) Only treat entries as components if they are arrays with at least 2 elements where the first is a non-null string (library name). Also: add lease cleanup after download, fix lease name sanitization (spaces, dots), and improve sanitizeLeaseName for RFC 1123 compliance. Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/modelagent/gopher.go | 33 +++++++++++++++++++++++++++++++-- pkg/modelagent/gopher_test.go | 14 ++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/pkg/modelagent/gopher.go b/pkg/modelagent/gopher.go index f6edf947..62892711 100644 --- a/pkg/modelagent/gopher.go +++ b/pkg/modelagent/gopher.go @@ -1207,6 +1207,14 @@ func (s *Gopher) processHuggingFaceModel(ctx context.Context, task *GopherTask, s.logger.Infof("Successfully downloaded HuggingFace model %s to %s", modelInfo, downloadPath) + + // Clean up the per-model download leader lease so other agents detect + // "not found" → check files on disk → skip. Without cleanup, non-leaders + // must wait for the lease to expire (5 min) before they can take it over. + if s.isSharedStorage { + s.releaseDownloadLease(ctx, modelInfo) + } + artifact = s.modelConfigParser.buildArtifactAttribute(shaStr, s.configMapReconciler.getModelConfigMapKey(task.BaseModel, task.ClusterBaseModel), destPath, childrenPaths) } @@ -1654,8 +1662,16 @@ func (s *Gopher) checkDiffusionIndex(destPath string) bool { if len(key) > 0 && key[0] == '_' { continue } - // Component values are arrays like ["diffusers", "ClassName"] or null - if val == nil { + // Components are arrays like ["diffusers", "ClassName"]. Skip: + // - null values (disabled components) + // - non-array values like floats (e.g. "boundary_ratio": 0.9) + // - arrays of nulls (e.g. "image_encoder": [null, null]) + arr, isArray := val.([]interface{}) + if !isArray || len(arr) < 2 { + continue + } + // Check that the first element is a non-null string (library name) + if _, isStr := arr[0].(string); !isStr { continue } // Guard against path traversal from untrusted JSON keys @@ -1888,6 +1904,19 @@ func (s *Gopher) isDownloadLeader(ctx context.Context, modelInfo string) bool { return false } +// releaseDownloadLease deletes the per-model download leader lease after a +// successful download. This allows waiting agents to immediately detect "no lease" +// → check files on disk → skip, instead of waiting for the lease to expire. +func (s *Gopher) releaseDownloadLease(ctx context.Context, modelInfo string) { + leaseName := sanitizeLeaseName(modelInfo) + err := s.kubeClient.CoordinationV1().Leases(s.namespace).Delete(ctx, leaseName, metav1.DeleteOptions{}) + if err != nil { + s.logger.Warnf("Failed to release download leader lease %s: %v (non-critical, lease will expire)", leaseName, err) + } else { + s.logger.Infof("Released download leader lease %s after successful download", leaseName) + } +} + // waitForSharedStorageModel waits for a model to appear on shared storage, // with jitter and periodic rechecks. Returns true if the model appeared (another // node downloaded it), false if the context was cancelled or max wait exceeded. diff --git a/pkg/modelagent/gopher_test.go b/pkg/modelagent/gopher_test.go index 2ae97813..c81ba48a 100644 --- a/pkg/modelagent/gopher_test.go +++ b/pkg/modelagent/gopher_test.go @@ -1077,6 +1077,20 @@ func TestIsModelAlreadyDownloaded(t *testing.T) { assert.True(t, gopher.isModelAlreadyDownloaded(dir)) }) + t.Run("diffusion model with float and null-array metadata is skipped", func(t *testing.T) { + dir := t.TempDir() + // Real-world pattern: boundary_ratio is a float, image_encoder is [null, null] + index := `{"_class_name":"Pipeline","boundary_ratio":0.9,"image_encoder":[null,null],"transformer":["diffusers","Model"],"vae":["diffusers","VAE"]}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(index), 0644)) + for _, comp := range []string{"transformer", "vae"} { + compDir := filepath.Join(dir, comp) + assert.NoError(t, os.MkdirAll(compDir, 0755)) + assert.NoError(t, os.WriteFile(filepath.Join(compDir, "config.json"), []byte(`{}`), 0644)) + } + // boundary_ratio and image_encoder should NOT require directories + assert.True(t, gopher.isModelAlreadyDownloaded(dir)) + }) + // --- Fallback: config.json + weight file --- t.Run("config.json and single safetensors file returns true", func(t *testing.T) { From 2f482869e9e9093898adeb7d9f8cf31c96ed0f1d Mon Sep 17 00:00:00 2001 From: Kangyan Zhou Date: Mon, 6 Apr 2026 15:44:00 -0700 Subject: [PATCH 5/6] feat: add download timeout to prevent stuck workers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Downloads had no timeout — xet C library retried 403 errors forever, blocking all 5 gopher workers for days. Add configurable --download-timeout (default 6h) using context.WithTimeout so stuck downloads are cancelled and workers freed for new tasks. Also: - Release download leader lease on both success and failure paths - Classify context.Canceled (from Delete) separately from download errors - Use fresh context for lease release to avoid near-deadline failures - Short-circuit OCI retry loop when context is already expired Co-Authored-By: Claude Opus 4.6 (1M context) --- cmd/model-agent/main.go | 3 ++ cmd/model-agent/main_test.go | 3 ++ pkg/modelagent/gopher.go | 75 +++++++++++++++++++++++++----------- 3 files changed, 59 insertions(+), 22 deletions(-) diff --git a/cmd/model-agent/main.go b/cmd/model-agent/main.go index 0e005a6b..6e1b03da 100644 --- a/cmd/model-agent/main.go +++ b/cmd/model-agent/main.go @@ -43,6 +43,7 @@ type config struct { numDownloadWorker int namespace string logLevel string + downloadTimeout time.Duration } // Logger type alias for zap.SugaredLogger @@ -73,6 +74,7 @@ func init() { rootCmd.PersistentFlags().IntVar(&cfg.numDownloadWorker, "num-download-worker", 5, "Number of download workers") rootCmd.PersistentFlags().StringVar(&cfg.namespace, "namespace", "ome", "Kubernetes namespace to use") rootCmd.PersistentFlags().StringVar(&cfg.logLevel, "log-level", "info", "Log level (debug, info, warn, error)") + rootCmd.PersistentFlags().DurationVar(&cfg.downloadTimeout, "download-timeout", 6*time.Hour, "Maximum time allowed for a single model download before it is cancelled") _ = v.BindPFlags(rootCmd.PersistentFlags()) v.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) @@ -270,6 +272,7 @@ func initializeComponents( clusterBaseModelInformer.Lister(), cfg.nodeName, cfg.namespace, + cfg.downloadTimeout, ) if err != nil { return nil, nil, fmt.Errorf("failed to create gopher: %w", err) diff --git a/cmd/model-agent/main_test.go b/cmd/model-agent/main_test.go index db035df7..6b5b0507 100644 --- a/cmd/model-agent/main_test.go +++ b/cmd/model-agent/main_test.go @@ -5,6 +5,7 @@ import ( "net/http/httptest" "os" "testing" + "time" "github.com/prometheus/client_golang/prometheus" "github.com/spf13/cobra" @@ -113,6 +114,7 @@ func TestDefaultConfig(t *testing.T) { testCmd.Flags().StringVar(&cfg.downloadAuthType, "download-auth-type", "instance-principal", "authentication method for model download") testCmd.Flags().IntVar(&cfg.numDownloadWorker, "num-download-worker", 3, "number of download workers") testCmd.Flags().StringVar(&cfg.namespace, "namespace", "ome", "the namespace of the ome model agents daemon set") + testCmd.Flags().DurationVar(&cfg.downloadTimeout, "download-timeout", 6*time.Hour, "maximum time for a single model download") // Call initConfig to set cfg.nodeName initConfig(nil, nil) @@ -127,6 +129,7 @@ func TestDefaultConfig(t *testing.T) { assert.Equal(t, "instance-principal", cfg.downloadAuthType) assert.Equal(t, 3, cfg.numDownloadWorker) assert.Equal(t, "ome", cfg.namespace) + assert.Equal(t, 6*time.Hour, cfg.downloadTimeout) } func TestInitializeLogger(t *testing.T) { diff --git a/pkg/modelagent/gopher.go b/pkg/modelagent/gopher.go index 62892711..05fe2e66 100644 --- a/pkg/modelagent/gopher.go +++ b/pkg/modelagent/gopher.go @@ -75,6 +75,11 @@ type Gopher struct { isSharedStorage bool nodeName string namespace string + + // downloadTimeout is the maximum time allowed for a single model download. + // Prevents stuck downloads (e.g., xet retrying 403 errors forever) from + // blocking workers indefinitely. + downloadTimeout time.Duration } const ( @@ -97,11 +102,15 @@ func NewGopher( baseModelLister omev1beta1lister.BaseModelLister, clusterBaseModelLister omev1beta1lister.ClusterBaseModelLister, nodeName string, - namespace string) (*Gopher, error) { + namespace string, + downloadTimeout time.Duration) (*Gopher, error) { if xetConfig == nil { return nil, fmt.Errorf("xet hugging face config cannot be nil") } + if downloadTimeout <= 0 { + return nil, fmt.Errorf("downloadTimeout must be positive, got %v", downloadTimeout) + } shared := isSharedFilesystem(modelRootDir, logger) if shared { @@ -127,6 +136,7 @@ func NewGopher( isSharedStorage: shared, nodeName: nodeName, namespace: namespace, + downloadTimeout: downloadTimeout, }, nil } @@ -297,8 +307,9 @@ func (s *Gopher) processTask(task *GopherTask) error { // Continue with download anyway } - // Create a cancellable context for this download - ctx, cancel = context.WithCancel(context.Background()) + // Create a context with timeout for this download to prevent stuck + // downloads (e.g., xet retrying 403 errors) from blocking workers forever. + ctx, cancel = context.WithTimeout(context.Background(), s.downloadTimeout) // Register the cancel function s.activeDownloadsMutex.Lock() @@ -359,9 +370,13 @@ func (s *Gopher) processTask(task *GopherTask) error { } err = utils.Retry(s.downloadRetry, 100*time.Millisecond, func() error { + // Short-circuit if context is already done (timeout or cancel) + if ctx.Err() != nil { + return ctx.Err() + } downloadErr := s.downloadModel(ctx, osUri, destPath, task) if downloadErr != nil { - // Check if context was cancelled + // Check if context was cancelled during download if ctx.Err() != nil { s.logger.Infof("Download cancelled for model %s: %v", modelInfo, ctx.Err()) return ctx.Err() @@ -372,12 +387,19 @@ func (s *Gopher) processTask(task *GopherTask) error { return downloadErr }) if err != nil { - s.logger.Errorf("All download attempts failed for model %s: %v", modelInfo, err) - - // Record download failure in metrics + // Record download failure in metrics with specific error classification errorType := "download_error" - if strings.Contains(err.Error(), "MD5") { + if ctx.Err() == context.Canceled { + errorType = "download_cancelled" + s.logger.Infof("Download cancelled for OCI model %s: %v", modelInfo, err) + } else if ctx.Err() == context.DeadlineExceeded { + errorType = "download_timeout" + s.logger.Errorf("Download timed out for OCI model %s after %v: %v", modelInfo, s.downloadTimeout, err) + } else if strings.Contains(err.Error(), "MD5") { errorType = "md5_verification_error" + s.logger.Errorf("All download attempts failed for model %s: %v", modelInfo, err) + } else { + s.logger.Errorf("All download attempts failed for model %s: %v", modelInfo, err) } s.metrics.RecordFailedDownload(modelType, namespace, name, errorType) @@ -1190,9 +1212,25 @@ func (s *Gopher) processHuggingFaceModel(ctx context.Context, task *GopherTask, // when status becomes Ready/Failed, ensuring the controller sees the final progress downloadPath, err := xet.SnapshotDownloadWithProgress(ctx, config, progressHandler, progressThrottle) + // Always release the download leader lease after a download attempt + // (success or failure). On failure, this lets non-leaders detect "no lease" + // and try their own download instead of waiting the full 5-minute expiry. + // Use a fresh context since the download context may be near its deadline. + if s.isSharedStorage { + leaseCtx, leaseCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer leaseCancel() + s.releaseDownloadLease(leaseCtx, modelInfo) + } + if err != nil { // Check error type for better handling - if strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "rate limit") { + if ctx.Err() == context.Canceled { + s.logger.Infof("Download cancelled for HuggingFace model %s: %v", modelInfo, err) + s.metrics.RecordFailedDownload(modelType, namespace, name, "download_cancelled") + } else if ctx.Err() == context.DeadlineExceeded { + s.logger.Errorf("Download timed out for HuggingFace model %s after %v: %v", modelInfo, s.downloadTimeout, err) + s.metrics.RecordFailedDownload(modelType, namespace, name, "download_timeout") + } else if strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "rate limit") { s.logger.Warnf("Rate limited while downloading HuggingFace model %s: %v", modelInfo, err) s.metrics.RecordRateLimit(modelType, namespace, name, 30*time.Second) // Estimate s.metrics.RecordFailedDownload(modelType, namespace, name, "rate_limit_error") @@ -1208,13 +1246,6 @@ func (s *Gopher) processHuggingFaceModel(ctx context.Context, task *GopherTask, s.logger.Infof("Successfully downloaded HuggingFace model %s to %s", modelInfo, downloadPath) - // Clean up the per-model download leader lease so other agents detect - // "not found" → check files on disk → skip. Without cleanup, non-leaders - // must wait for the lease to expire (5 min) before they can take it over. - if s.isSharedStorage { - s.releaseDownloadLease(ctx, modelInfo) - } - artifact = s.modelConfigParser.buildArtifactAttribute(shaStr, s.configMapReconciler.getModelConfigMapKey(task.BaseModel, task.ClusterBaseModel), destPath, childrenPaths) } @@ -1763,17 +1794,17 @@ func isSharedFilesystem(path string, logger *zap.SugaredLogger) bool { } // Filesystem magic numbers (from linux/magic.h and kernel sources) switch stat.Type { - case 0x6969: // NFS_SUPER_MAGIC + case 0x6969: // NFS_SUPER_MAGIC return true - case 0x47504653: // GPFS (IBM Spectrum Scale) + case 0x47504653: // GPFS (IBM Spectrum Scale) return true - case 0x00C36400: // CEPH_SUPER_MAGIC + case 0x00C36400: // CEPH_SUPER_MAGIC return true - case 0x0BD00BD0: // LUSTRE_SUPER_MAGIC + case 0x0BD00BD0: // LUSTRE_SUPER_MAGIC return true - case 0x65735546: // FUSE_SUPER_MAGIC (commonly used for network mounts) + case 0x65735546: // FUSE_SUPER_MAGIC (commonly used for network mounts) return true - case 0x6A656A62: // GlusterFS + case 0x6A656A62: // GlusterFS return true default: return false From 51b3554c79c6cbd2fce7d2651cf3d74c248b1db5 Mon Sep 17 00:00:00 2001 From: Kangyan Zhou Date: Tue, 14 Apr 2026 20:20:00 -0700 Subject: [PATCH 6/6] fix --- pkg/modelagent/gopher.go | 190 +++++++++++++++++++++++++++++++--- pkg/modelagent/gopher_test.go | 71 +++++++++++++ 2 files changed, 248 insertions(+), 13 deletions(-) diff --git a/pkg/modelagent/gopher.go b/pkg/modelagent/gopher.go index 05fe2e66..6bd3e8ee 100644 --- a/pkg/modelagent/gopher.go +++ b/pkg/modelagent/gopher.go @@ -48,6 +48,24 @@ type GopherTask struct { TensorRTLLMShapeFilter *TensorRTLLMShapeFilter } +// ModelLayout describes the expected file layout of a HuggingFace model repo, +// determined by querying the HF API. Used to prevent false-positive readiness +// checks when the fallback heuristic would otherwise accept an incomplete +// download on shared storage. +type ModelLayout int + +const ( + // ModelLayoutUnknown means the layout could not be determined (API failure). + // Falls back to the existing heuristic for backward compatibility. + ModelLayoutUnknown ModelLayout = iota + // ModelLayoutSafetensorsSharded means the repo contains model.safetensors.index.json. + ModelLayoutSafetensorsSharded + // ModelLayoutDiffusion means the repo contains model_index.json (diffusion pipeline). + ModelLayoutDiffusion + // ModelLayoutSingleFile means the repo has no index file; the fallback heuristic is safe. + ModelLayoutSingleFile +) + type Gopher struct { modelConfigParser *ModelConfigParser configMapReconciler *ConfigMapReconciler @@ -80,6 +98,13 @@ type Gopher struct { // Prevents stuck downloads (e.g., xet retrying 403 errors forever) from // blocking workers indefinitely. downloadTimeout time.Duration + + // repoLayoutCache caches the expected file layout (ModelLayout) for HuggingFace + // repos, keyed by "modelID@revision". Prevents redundant HF API calls and + // ensures the layout-aware readiness check does not use the weak fallback + // heuristic for sharded or diffusion models whose index file hasn't been + // written to disk yet. + repoLayoutCache sync.Map } const ( @@ -1034,7 +1059,18 @@ func (s *Gopher) processHuggingFaceModel(ctx context.Context, task *GopherTask, // Only for fresh Download tasks — DownloadOverride indicates a spec change // or failed retry that must re-evaluate the model files. if task.TaskType == Download { - if s.isModelAlreadyDownloaded(destPath) { + // On shared storage, query the HF API for the expected model layout so + // isModelAlreadyDownloaded can avoid the false-positive fallback heuristic + // when the index file hasn't been written to disk yet by a concurrent + // downloader on another node. On non-shared storage there's no concurrent + // writer, so the extra API call is unnecessary. + var expectedLayout ModelLayout + if s.isSharedStorage { + hfToken := s.getHuggingFaceToken(task, baseModelSpec, modelInfo) + expectedLayout = s.fetchModelLayout(ctx, hfComponents.ModelID, hfComponents.Branch, hfToken) + } + + if s.isModelAlreadyDownloadedWithLayout(destPath, expectedLayout) { s.logger.Infof("Model %s already exists at %s (shared storage), skipping HuggingFace download", modelInfo, destPath) return s.skipDownloadAndUpdateConfig(destPath, task) } @@ -1042,7 +1078,7 @@ func (s *Gopher) processHuggingFaceModel(ctx context.Context, task *GopherTask, // On shared storage, only the download leader should proceed. Non-leaders // wait for the leader to finish and then recheck for files on disk. if s.isSharedStorage && !s.isDownloadLeader(ctx, modelInfo) { - if s.waitForSharedStorageModel(ctx, destPath, modelInfo) { + if s.waitForSharedStorageModel(ctx, destPath, modelInfo, expectedLayout) { s.logger.Infof("Model %s appeared on shared storage at %s after waiting for leader", modelInfo, destPath) return s.skipDownloadAndUpdateConfig(destPath, task) } @@ -1394,6 +1430,103 @@ func (s *Gopher) fetchSha(ctx context.Context, modelId string, modelName string) return shaStr, isShaAvailable } +// fetchModelLayout queries the HuggingFace API to determine the expected file layout +// of a model repo. The result is cached in repoLayoutCache so subsequent calls +// for the same model (e.g., when multiple download tasks target the same model) +// don't hit the API again. +// On any error (network, auth, rate limit, timeout), returns ModelLayoutUnknown +// so the caller falls back to the existing heuristic. +func (s *Gopher) fetchModelLayout(ctx context.Context, modelID, revision, hfToken string) ModelLayout { + cacheKey := modelID + "@" + revision + if cached, ok := s.repoLayoutCache.Load(cacheKey); ok { + if layout, isLayout := cached.(ModelLayout); isLayout { + return layout + } + s.logger.Warnf("fetchModelLayout(%s): unexpected type in cache, refetching", cacheKey) + } + + // Use a short timeout — this is a lightweight metadata query, not a download. + // Prevents a hanging HF API call from blocking a worker goroutine indefinitely. + fetchCtx, fetchCancel := context.WithTimeout(ctx, 30*time.Second) + defer fetchCancel() + + type fetchResult struct { + layout ModelLayout + } + ch := make(chan fetchResult, 1) + go func() { + ch <- fetchResult{layout: s.fetchModelLayoutFromAPI(modelID, revision, hfToken)} + }() + + select { + case result := <-ch: + return result.layout + case <-fetchCtx.Done(): + s.logger.Warnf("fetchModelLayout(%s): timed out after 30s — layout unknown", cacheKey) + return ModelLayoutUnknown + } +} + +// fetchModelLayoutFromAPI does the actual HF API call. Called by fetchModelLayout +// inside a timeout-guarded goroutine. +func (s *Gopher) fetchModelLayoutFromAPI(modelID, revision, hfToken string) ModelLayout { + cacheKey := modelID + "@" + revision + + xetConfig := &xet.Config{ + Endpoint: s.xetConfig.Endpoint, + Token: hfToken, + CacheDir: s.xetConfig.CacheDir, + MaxConcurrentDownloads: s.xetConfig.MaxConcurrentDownloads, + EnableDedup: true, + } + if xetConfig.Endpoint == "" { + xetConfig.Endpoint = "https://huggingface.co" + } + if xetConfig.MaxConcurrentDownloads == 0 { + xetConfig.MaxConcurrentDownloads = 4 + } + + client, err := xet.NewClient(xetConfig) + if err != nil { + s.logger.Warnf("fetchModelLayout(%s): failed to create xet client: %v — layout unknown", cacheKey, err) + return ModelLayoutUnknown + } + defer client.Close() + + rev := revision + if rev == "" { + rev = "main" + } + files, err := client.ListFiles(modelID, rev) + if err != nil { + s.logger.Warnf("fetchModelLayout(%s): ListFiles failed: %v — layout unknown", cacheKey, err) + return ModelLayoutUnknown + } + + layout := ModelLayoutSingleFile + for _, f := range files { + switch f.Path { + case "model.safetensors.index.json": + layout = ModelLayoutSafetensorsSharded + // Safetensors sharded takes highest priority — stop scanning. + s.logger.Infof("fetchModelLayout(%s): detected sharded safetensors layout", cacheKey) + s.repoLayoutCache.Store(cacheKey, layout) + return layout + case "model_index.json": + layout = ModelLayoutDiffusion + // Don't return yet — a model.safetensors.index.json might appear later in the list. + } + } + + if layout == ModelLayoutDiffusion { + s.logger.Infof("fetchModelLayout(%s): detected diffusion pipeline layout", cacheKey) + } else { + s.logger.Infof("fetchModelLayout(%s): no index file found — single-file layout", cacheKey) + } + s.repoLayoutCache.Store(cacheKey, layout) + return layout +} + /* isEligibleForOptimization determines whether a Hugging Face model can reuse an existing artifact. @@ -1559,24 +1692,35 @@ func (s *Gopher) skipDownloadAndUpdateConfig(destPath string, task *GopherTask) return nil } -// isModelAlreadyDownloaded checks whether the model files are already present at -// destPath. This handles the shared-storage case: when multiple nodes mount the -// same filesystem (e.g., NFS at /storage/models), the first node that finishes an -// HF download writes the files once. Subsequent nodes should detect the existing -// files and skip re-downloading. +// isModelAlreadyDownloaded checks whether model files are present at destPath. +// This is the backward-compatible entry point used by the OCI download path, +// where no HF repo metadata is available. +func (s *Gopher) isModelAlreadyDownloaded(destPath string) bool { + return s.isModelAlreadyDownloadedWithLayout(destPath, ModelLayoutUnknown) +} + +// isModelAlreadyDownloadedWithLayout checks whether the model files are already +// present at destPath, using expectedLayout (from the HF API) to avoid false +// positives from the fallback heuristic. +// +// On shared storage, the xet downloader writes files concurrently with no ordering +// guarantee. When config.json + one shard exist but model.safetensors.index.json +// hasn't been written yet, the fallback heuristic would incorrectly report the +// model as complete. The expectedLayout hint prevents this: if HF says the model +// should have an index file but it's not on disk yet, we return false instead of +// falling through to the weak fallback. // // Supports three model layouts: // 1. Sharded safetensors: model.safetensors.index.json lists all expected shards. // 2. Diffusion pipelines: model_index.json lists component subdirectories, each // containing its own config and weight files. // 3. Single-file models: no index file, but config.json + at least one weight -// file (.safetensors, .bin, .pt, .gguf) present. Note: this fallback cannot -// verify shard completeness for multi-shard models without an index file. +// file (.safetensors, .bin, .pt, .gguf) present. // // All filesystem checks treat errors conservatively as "not present" so that // NFS I/O or permission errors fall through to the normal download path rather // than silently skipping the download. -func (s *Gopher) isModelAlreadyDownloaded(destPath string) bool { +func (s *Gopher) isModelAlreadyDownloadedWithLayout(destPath string, expectedLayout ModelLayout) bool { // Check if directory exists info, err := os.Stat(destPath) if err != nil { @@ -1608,7 +1752,27 @@ func (s *Gopher) isModelAlreadyDownloaded(destPath string) bool { return s.checkDiffusionIndex(destPath) } - // 3. Fallback: config.json + at least one weight file + // If the HF API told us this model should have an index file but it's not on + // disk yet, the download is still in progress. Don't fall through to the weak + // fallback heuristic which would accept config.json + one shard as "complete". + switch expectedLayout { + case ModelLayoutSafetensorsSharded: + s.logger.Infof("isModelAlreadyDownloaded(%s): HF API indicates sharded safetensors "+ + "but model.safetensors.index.json not on disk yet — not complete", destPath) + return false + case ModelLayoutDiffusion: + s.logger.Infof("isModelAlreadyDownloaded(%s): HF API indicates diffusion pipeline "+ + "but model_index.json not on disk yet — not complete", destPath) + return false + case ModelLayoutUnknown, ModelLayoutSingleFile: + // Fall through to the fallback heuristic. + // Unknown: API failed, preserve backward-compatible behavior. + // SingleFile: model genuinely has no index file, fallback is safe. + default: + s.logger.Warnf("isModelAlreadyDownloaded(%s): unhandled layout %d, falling through to heuristic", destPath, expectedLayout) + } + + // 3. Fallback: config.json + at least one weight file. if s.checkConfigAndWeights(destPath) { return true } @@ -1952,7 +2116,7 @@ func (s *Gopher) releaseDownloadLease(ctx context.Context, modelInfo string) { // with jitter and periodic rechecks. Returns true if the model appeared (another // node downloaded it), false if the context was cancelled or max wait exceeded. // Maximum wait time is downloadLeaderLeaseDuration + 30s (currently 5m30s). -func (s *Gopher) waitForSharedStorageModel(ctx context.Context, destPath string, modelInfo string) bool { +func (s *Gopher) waitForSharedStorageModel(ctx context.Context, destPath string, modelInfo string, expectedLayout ModelLayout) bool { maxWait := downloadLeaderLeaseDuration + 30*time.Second deadline := time.Now().Add(maxWait) @@ -1970,7 +2134,7 @@ func (s *Gopher) waitForSharedStorageModel(ctx context.Context, destPath string, case <-timer.C: } - if s.isModelAlreadyDownloaded(destPath) { + if s.isModelAlreadyDownloadedWithLayout(destPath, expectedLayout) { s.logger.Infof("Shared storage: model %s appeared at %s (downloaded by leader)", modelInfo, destPath) return true } diff --git a/pkg/modelagent/gopher_test.go b/pkg/modelagent/gopher_test.go index c81ba48a..99abd470 100644 --- a/pkg/modelagent/gopher_test.go +++ b/pkg/modelagent/gopher_test.go @@ -1188,6 +1188,77 @@ func TestIsModelAlreadyDownloaded(t *testing.T) { }) } +func TestIsModelAlreadyDownloadedWithLayout(t *testing.T) { + logger, _ := zap.NewDevelopment() + sugaredLogger := logger.Sugar() + defer func() { _ = sugaredLogger.Sync() }() + + gopher := &Gopher{logger: sugaredLogger} + + t.Run("sharded layout hint blocks fallback when index not on disk", func(t *testing.T) { + // Simulates the race condition: config.json + one shard exist but + // model.safetensors.index.json hasn't been written yet by the concurrent downloader. + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"glm"}`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00010.safetensors"), []byte("data"), 0644)) + // Without layout hint, the fallback would return true (false positive) + assert.True(t, gopher.isModelAlreadyDownloadedWithLayout(dir, ModelLayoutUnknown)) + // With sharded layout hint, we know the index file should exist — return false + assert.False(t, gopher.isModelAlreadyDownloadedWithLayout(dir, ModelLayoutSafetensorsSharded)) + }) + + t.Run("diffusion layout hint blocks fallback when index not on disk", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"flux"}`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "unet.safetensors"), []byte("data"), 0644)) + // Without layout hint, the fallback would return true + assert.True(t, gopher.isModelAlreadyDownloadedWithLayout(dir, ModelLayoutUnknown)) + // With diffusion layout hint, we know model_index.json should exist — return false + assert.False(t, gopher.isModelAlreadyDownloadedWithLayout(dir, ModelLayoutDiffusion)) + }) + + t.Run("single-file layout hint allows fallback", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"bert"}`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("data"), 0644)) + // SingleFile layout: fallback is legitimate, model genuinely has no index + assert.True(t, gopher.isModelAlreadyDownloadedWithLayout(dir, ModelLayoutSingleFile)) + }) + + t.Run("unknown layout preserves backward-compat behavior", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("data"), 0644)) + // Unknown layout (API failed): same as before the fix — fallback returns true + assert.True(t, gopher.isModelAlreadyDownloadedWithLayout(dir, ModelLayoutUnknown)) + }) + + t.Run("sharded layout with index present and all shards returns true", func(t *testing.T) { + dir := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"glm"}`), 0644)) + index := `{"metadata":{"total_size":100},"weight_map":{"w1":"model-00001-of-00002.safetensors","w2":"model-00002-of-00002.safetensors"}}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(index), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("data"), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("data"), 0644)) + // Layout hint doesn't matter when index file is on disk — authoritative check succeeds + assert.True(t, gopher.isModelAlreadyDownloadedWithLayout(dir, ModelLayoutSafetensorsSharded)) + }) + + t.Run("sharded layout with index present but missing shard returns false", func(t *testing.T) { + dir := t.TempDir() + index := `{"metadata":{"total_size":100},"weight_map":{"w1":"model-00001-of-00002.safetensors","w2":"model-00002-of-00002.safetensors"}}` + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(index), 0644)) + assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("data"), 0644)) + // Index is on disk but shard 2 is missing — authoritative check correctly returns false + assert.False(t, gopher.isModelAlreadyDownloadedWithLayout(dir, ModelLayoutSafetensorsSharded)) + }) + + t.Run("nonexistent directory returns false regardless of layout", func(t *testing.T) { + assert.False(t, gopher.isModelAlreadyDownloadedWithLayout("/nonexistent/path", ModelLayoutSafetensorsSharded)) + assert.False(t, gopher.isModelAlreadyDownloadedWithLayout("/nonexistent/path", ModelLayoutUnknown)) + }) +} + func TestIsEligibleForOptimization_AlwaysDownloadNotEligible(t *testing.T) { nodeName := "node-1" sha := "123abc"