Skip to content

Commit 3d33788

Browse files
Kangyan-Zhou and claude committed
feat: skip re-downloading models on shared storage
When multiple nodes mount the same filesystem (e.g., GPFS/NFS at /storage/models), the model-agent on each node would independently re-download from HuggingFace or OCI, causing rate-limiting and hours of unnecessary I/O. Add isModelAlreadyDownloaded() that checks: 1. config.json exists 2. model.safetensors.index.json exists and parses with a non-empty weight_map 3. ALL shards listed in the index are present. If the index file is missing, completeness cannot be verified, so the check returns false and the normal download proceeds. Only applies to fresh Download tasks (not DownloadOverride) so spec updates and failed retries still re-evaluate. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent aac400b commit 3d33788

2 files changed

Lines changed: 194 additions & 0 deletions

File tree

pkg/modelagent/gopher.go

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package modelagent
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"os"
78
"path/filepath"
@@ -324,6 +325,25 @@ func (s *Gopher) processTask(task *GopherTask) error {
324325
s.logger.Errorf("Failed to get target directory path for model %s: %v", modelInfo, err)
325326
return err
326327
}
328+
329+
// Check if the model is already present on shared storage.
330+
// Only for fresh Download tasks — DownloadOverride indicates a spec change
331+
// or failed retry that must re-evaluate the model files.
332+
if task.TaskType == Download && s.isModelAlreadyDownloaded(destPath) {
333+
s.logger.Infof("Model %s already exists at %s (shared storage), skipping OCI download", modelInfo, destPath)
334+
var baseModel *v1beta1.BaseModel
335+
var clusterBaseModel *v1beta1.ClusterBaseModel
336+
if task.BaseModel != nil {
337+
baseModel = task.BaseModel
338+
} else if task.ClusterBaseModel != nil {
339+
clusterBaseModel = task.ClusterBaseModel
340+
}
341+
if err := s.safeParseAndUpdateModelConfig(destPath, baseModel, clusterBaseModel, nil); err != nil {
342+
s.logger.Errorf("Failed to parse and update model config for pre-existing model: %v", err)
343+
}
344+
break
345+
}
346+
327347
err = utils.Retry(s.downloadRetry, 100*time.Millisecond, func() error {
328348
downloadErr := s.downloadModel(ctx, osUri, destPath, task)
329349
if downloadErr != nil {
@@ -971,6 +991,29 @@ func (s *Gopher) processHuggingFaceModel(ctx context.Context, task *GopherTask,
971991
// Create destination path
972992
destPath := getDestPath(&baseModelSpec, s.modelRootDir)
973993

994+
// Check if the model is already present on shared storage (e.g., another node
995+
// already downloaded it to the same NFS/shared filesystem path). When storage is
996+
// shared across nodes, each model-agent would otherwise independently re-download
997+
// from HuggingFace, causing rate-limiting and hours of unnecessary I/O.
998+
// Only for fresh Download tasks — DownloadOverride indicates a spec change
999+
// or failed retry that must re-evaluate the model files.
1000+
if task.TaskType == Download && s.isModelAlreadyDownloaded(destPath) {
1001+
s.logger.Infof("Model %s already exists at %s (shared storage), skipping HuggingFace download", modelInfo, destPath)
1002+
1003+
var baseModel *v1beta1.BaseModel
1004+
var clusterBaseModel *v1beta1.ClusterBaseModel
1005+
if task.BaseModel != nil {
1006+
baseModel = task.BaseModel
1007+
} else if task.ClusterBaseModel != nil {
1008+
clusterBaseModel = task.ClusterBaseModel
1009+
}
1010+
1011+
if err := s.safeParseAndUpdateModelConfig(destPath, baseModel, clusterBaseModel, nil); err != nil {
1012+
s.logger.Errorf("Failed to parse and update model config for pre-existing model: %v", err)
1013+
}
1014+
return nil
1015+
}
1016+
9741017
// fetch sha value based on model ID from Huggingface model API
9751018
shaStr, isShaAvailable := s.fetchSha(ctx, hfComponents.ModelID, name)
9761019
isReuseEligible, matchedModelTypeAndModeName, parentPath := s.isEligibleForOptimization(ctx, task, baseModelSpec, modelType, namespace, isShaAvailable, shaStr, name)
@@ -1440,3 +1483,79 @@ func (s *Gopher) isRemoveParentArtifactDirectory(ctx context.Context, hasChildre
14401483
s.logger.Infof("parent entry %s:%s exists on node configmap: %v", parentName, parentDir, exists)
14411484
return !exists
14421485
}
1486+
1487+
// isModelAlreadyDownloaded checks whether the model files are already present at
1488+
// destPath. This handles the shared-storage case: when multiple nodes mount the
1489+
// same filesystem (e.g., NFS at /storage/models), the first node that finishes an
1490+
// HF download writes the files once. Subsequent nodes should detect the existing
1491+
// files and skip re-downloading.
1492+
//
1493+
// The check requires model.safetensors.index.json to be present so that ALL
1494+
// expected shards can be verified. Without the index file, completeness cannot
1495+
// be determined and the method returns false (letting the normal download proceed).
1496+
func (s *Gopher) isModelAlreadyDownloaded(destPath string) bool {
1497+
// Check if directory exists
1498+
info, err := os.Stat(destPath)
1499+
if err != nil || !info.IsDir() {
1500+
return false
1501+
}
1502+
1503+
// Check for config.json (primary indicator of a complete HF download).
1504+
// Use err != nil (not os.IsNotExist) so that NFS I/O errors and permission
1505+
// errors are treated conservatively as "not present" rather than silently
1506+
// falling through as "exists".
1507+
configPath := filepath.Join(destPath, "config.json")
1508+
if _, err := os.Stat(configPath); err != nil {
1509+
return false
1510+
}
1511+
1512+
// Require model.safetensors.index.json for shard completeness verification.
1513+
// Without it we cannot determine if the download is complete, so fall through
1514+
// to let the normal download path handle it.
1515+
indexPath := filepath.Join(destPath, "model.safetensors.index.json")
1516+
indexData, err := os.ReadFile(indexPath)
1517+
if err != nil {
1518+
return false
1519+
}
1520+
1521+
var index struct {
1522+
WeightMap map[string]string `json:"weight_map"`
1523+
}
1524+
if err := json.Unmarshal(indexData, &index); err != nil {
1525+
s.logger.Warnf("Failed to parse model index file %s: %v", indexPath, err)
1526+
return false
1527+
}
1528+
1529+
if len(index.WeightMap) == 0 {
1530+
return false
1531+
}
1532+
1533+
// Build a set of filenames for fast lookup
1534+
entries, err := os.ReadDir(destPath)
1535+
if err != nil {
1536+
return false
1537+
}
1538+
fileSet := make(map[string]bool, len(entries))
1539+
for _, entry := range entries {
1540+
if !entry.IsDir() {
1541+
fileSet[entry.Name()] = true
1542+
}
1543+
}
1544+
1545+
// Collect unique shard filenames and verify every one exists on disk
1546+
expectedShards := make(map[string]bool)
1547+
for _, shard := range index.WeightMap {
1548+
expectedShards[shard] = true
1549+
}
1550+
1551+
for shard := range expectedShards {
1552+
if !fileSet[shard] {
1553+
s.logger.Infof("Model at %s is missing shard %s (expected %d shards), not treating as complete",
1554+
destPath, shard, len(expectedShards))
1555+
return false
1556+
}
1557+
}
1558+
1559+
s.logger.Infof("Model at %s has all %d expected shards from index", destPath, len(expectedShards))
1560+
return true
1561+
}

pkg/modelagent/gopher_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"encoding/json"
66
"errors"
77
"fmt"
8+
"os"
9+
"path/filepath"
810
"testing"
911

1012
"k8s.io/apimachinery/pkg/runtime/schema"
@@ -961,6 +963,79 @@ func TestIsEligibleForOptimization_NoMatch(t *testing.T) {
961963
assert.Empty(t, parent)
962964
}
963965

966+
func TestIsModelAlreadyDownloaded(t *testing.T) {
967+
logger, _ := zap.NewDevelopment()
968+
sugaredLogger := logger.Sugar()
969+
defer func() { _ = sugaredLogger.Sync() }()
970+
971+
gopher := &Gopher{logger: sugaredLogger}
972+
973+
t.Run("nonexistent directory returns false", func(t *testing.T) {
974+
assert.False(t, gopher.isModelAlreadyDownloaded("/nonexistent/path/that/does/not/exist"))
975+
})
976+
977+
t.Run("empty directory returns false", func(t *testing.T) {
978+
dir := t.TempDir()
979+
assert.False(t, gopher.isModelAlreadyDownloaded(dir))
980+
})
981+
982+
t.Run("directory with only config.json returns false (no weights)", func(t *testing.T) {
983+
dir := t.TempDir()
984+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644))
985+
assert.False(t, gopher.isModelAlreadyDownloaded(dir))
986+
})
987+
988+
t.Run("directory with only weights returns false (no config.json)", func(t *testing.T) {
989+
dir := t.TempDir()
990+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("weight data"), 0644))
991+
assert.False(t, gopher.isModelAlreadyDownloaded(dir))
992+
})
993+
994+
t.Run("config.json and weights but no index file returns false", func(t *testing.T) {
995+
dir := t.TempDir()
996+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644))
997+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("weight data"), 0644))
998+
assert.False(t, gopher.isModelAlreadyDownloaded(dir), "without index file, cannot verify completeness")
999+
})
1000+
1001+
t.Run("file path instead of directory returns false", func(t *testing.T) {
1002+
dir := t.TempDir()
1003+
filePath := filepath.Join(dir, "somefile")
1004+
assert.NoError(t, os.WriteFile(filePath, []byte("data"), 0644))
1005+
assert.False(t, gopher.isModelAlreadyDownloaded(filePath))
1006+
})
1007+
1008+
// Shard completeness tests using model.safetensors.index.json
1009+
1010+
t.Run("index with all shards present returns true", func(t *testing.T) {
1011+
dir := t.TempDir()
1012+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644))
1013+
index := `{"metadata":{"total_size":100},"weight_map":{"w1":"model-00001-of-00002.safetensors","w2":"model-00002-of-00002.safetensors"}}`
1014+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(index), 0644))
1015+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("data"), 0644))
1016+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("data"), 0644))
1017+
assert.True(t, gopher.isModelAlreadyDownloaded(dir))
1018+
})
1019+
1020+
t.Run("index with missing shard returns false", func(t *testing.T) {
1021+
dir := t.TempDir()
1022+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644))
1023+
index := `{"metadata":{"total_size":100},"weight_map":{"w1":"model-00001-of-00002.safetensors","w2":"model-00002-of-00002.safetensors"}}`
1024+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(index), 0644))
1025+
// Only write shard 1, shard 2 is missing
1026+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("data"), 0644))
1027+
assert.False(t, gopher.isModelAlreadyDownloaded(dir))
1028+
})
1029+
1030+
t.Run("malformed index file returns false", func(t *testing.T) {
1031+
dir := t.TempDir()
1032+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"llama"}`), 0644))
1033+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(`{invalid json`), 0644))
1034+
assert.NoError(t, os.WriteFile(filepath.Join(dir, "model-00001-of-00001.safetensors"), []byte("data"), 0644))
1035+
assert.False(t, gopher.isModelAlreadyDownloaded(dir))
1036+
})
1037+
}
1038+
9641039
func TestIsEligibleForOptimization_AlwaysDownloadNotEligible(t *testing.T) {
9651040
nodeName := "node-1"
9661041
sha := "123abc"

0 commit comments

Comments
 (0)