Skip to content

Commit 8e1db1d

Browse files
[Bugfix] fix spec/label/annotation changes can't trigger redownload
1 parent 957f541 commit 8e1db1d

2 files changed

Lines changed: 194 additions & 16 deletions

File tree

pkg/modelagent/scout.go

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"strconv"
77
"time"
88

9+
"github.com/google/go-cmp/cmp/cmpopts"
910
"go.uber.org/zap"
1011
v1 "k8s.io/api/core/v1"
1112
"k8s.io/apimachinery/pkg/api/errors"
@@ -340,9 +341,10 @@ func (w *Scout) updateBaseModel(old, new interface{}) {
340341
return
341342
}
342343

343-
if w.isToDownloadOverrideDueToDownloadPolicyBasedOnBM(oldBaseModel, newBaseModel) {
344-
w.generateDownloadOverrideTaskBasedOnBaseModel(newBaseModel)
345-
}
344+
policyChanged := w.isToDownloadOverrideDueToDownloadPolicyBasedOnBM(oldBaseModel, newBaseModel)
345+
346+
// Exclude DownloadPolicy from Spec diff — policy changes are detected separately above.
347+
ignoreDownloadPolicy := cmpopts.IgnoreFields(v1beta1.StorageSpec{}, "DownloadPolicy")
346348

347349
hasChanges := false
348350
for _, diff := range []struct {
@@ -353,7 +355,7 @@ func (w *Scout) updateBaseModel(old, new interface{}) {
353355
{"Annotations", oldBaseModel.Annotations, newBaseModel.Annotations},
354356
{"Spec", oldBaseModel.Spec, newBaseModel.Spec},
355357
} {
356-
result, err := kmp.SafeDiff(diff.old, diff.new)
358+
result, err := kmp.SafeDiff(diff.old, diff.new, ignoreDownloadPolicy)
357359
if err != nil {
358360
w.logger.Errorf("Failed to diff %s for BaseModel: %s in namespace %s",
359361
diff.name, newBaseModel.Name, newBaseModel.Namespace)
@@ -362,7 +364,7 @@ func (w *Scout) updateBaseModel(old, new interface{}) {
362364
hasChanges = hasChanges || (result != "")
363365
}
364366

365-
if hasChanges && w.shouldDownloadModelInUpdateEvent(newBaseModel.Spec.Storage) {
367+
if (policyChanged || hasChanges) && w.shouldDownloadModel(newBaseModel.Spec.Storage) {
366368
w.logger.Infof("BaseModel %s needs refresh in namespace %s", newBaseModel.GetName(), newBaseModel.GetNamespace())
367369
w.generateDownloadOverrideTaskBasedOnBaseModel(newBaseModel)
368370
}
@@ -395,9 +397,10 @@ func (w *Scout) updateClusterBaseModel(old, new interface{}) {
395397
return
396398
}
397399

398-
if w.isToDownloadOverrideDueToDownloadPolicyBasedOnCBM(oldClusterBaseModel, newClusterBaseModel) {
399-
w.generateDownloadOverrideTaskBasedOnClusterBaseModel(newClusterBaseModel)
400-
}
400+
policyChanged := w.isToDownloadOverrideDueToDownloadPolicyBasedOnCBM(oldClusterBaseModel, newClusterBaseModel)
401+
402+
// Exclude DownloadPolicy from Spec diff — policy changes are detected separately above.
403+
ignoreDownloadPolicy := cmpopts.IgnoreFields(v1beta1.StorageSpec{}, "DownloadPolicy")
401404

402405
hasChanges := false
403406
for _, diff := range []struct {
@@ -408,7 +411,7 @@ func (w *Scout) updateClusterBaseModel(old, new interface{}) {
408411
{"Annotations", oldClusterBaseModel.Annotations, newClusterBaseModel.Annotations},
409412
{"Spec", oldClusterBaseModel.Spec, newClusterBaseModel.Spec},
410413
} {
411-
result, err := kmp.SafeDiff(diff.old, diff.new)
414+
result, err := kmp.SafeDiff(diff.old, diff.new, ignoreDownloadPolicy)
412415
if err != nil {
413416
w.logger.Errorf("Failed to diff %s for BaseModel: %s in namespace %s",
414417
diff.name, newClusterBaseModel.Name, newClusterBaseModel.Namespace)
@@ -417,7 +420,7 @@ func (w *Scout) updateClusterBaseModel(old, new interface{}) {
417420
hasChanges = hasChanges || (result != "")
418421
}
419422

420-
if hasChanges && w.shouldDownloadModelInUpdateEvent(newClusterBaseModel.Spec.Storage) {
423+
if (policyChanged || hasChanges) && w.shouldDownloadModel(newClusterBaseModel.Spec.Storage) {
421424
w.logger.Infof("ClusterBaseModel %s need refresh", newClusterBaseModel.GetName())
422425
w.generateDownloadOverrideTaskBasedOnClusterBaseModel(newClusterBaseModel)
423426
}
@@ -547,12 +550,6 @@ func (w *Scout) shouldDownloadModel(storageSpec *v1beta1.StorageSpec) bool {
547550
return w.shouldDownloadModelCommon(storageSpec, true)
548551
}
549552

550-
// shouldDownloadModelInUpdateEvent mirrors shouldDownloadModel logic but uses a default false decision,
551-
// allowing callers to opt-in specific cases for updates if needed.
552-
func (w *Scout) shouldDownloadModelInUpdateEvent(storageSpec *v1beta1.StorageSpec) bool {
553-
return w.shouldDownloadModelCommon(storageSpec, false)
554-
}
555-
556553
func (w *Scout) nodeMatchesSelectorTerm(term v1.NodeSelectorTerm) bool {
557554
// Check match expressions
558555
for _, expr := range term.MatchExpressions {

pkg/modelagent/scout_test.go

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,3 +1431,184 @@ func TestGenerateDownloadOverrideTaskBasedOnBaseModel_TensorRTLLMAndMetadataType
14311431
assert.Equal(t, "a10", filter.ShapeAlias)
14321432
assert.Equal(t, "CustomType", filter.ModelType)
14331433
}
1434+
1435+
func TestUpdateBaseModel_NoDuplicateTaskOnPolicyOnlyChange(t *testing.T) {
1436+
logger, _ := zap.NewDevelopment()
1437+
sugaredLogger := logger.Sugar()
1438+
defer func(s *zap.SugaredLogger) { _ = s.Sync() }(sugaredLogger)
1439+
1440+
ch := make(chan *GopherTask, 10)
1441+
scout := &Scout{
1442+
nodeShapeAlias: "a10",
1443+
gopherChan: ch,
1444+
logger: sugaredLogger,
1445+
nodeInfo: &corev1.Node{
1446+
ObjectMeta: metav1.ObjectMeta{
1447+
Name: "test-node",
1448+
Labels: map[string]string{"gpu-model": "a10"},
1449+
},
1450+
},
1451+
}
1452+
1453+
tests := []struct {
1454+
name string
1455+
oldModel *v1beta1.BaseModel
1456+
newModel *v1beta1.BaseModel
1457+
expectedTasks int
1458+
description string
1459+
}{
1460+
{
1461+
name: "policy-only change on HuggingFace - exactly 1 task from dedicated handler",
1462+
oldModel: newBaseModel("model-1", v1beta1.ReuseIfExists, "hf://meta-llama/llama-3"),
1463+
newModel: newBaseModel("model-1", v1beta1.AlwaysDownload, "hf://meta-llama/llama-3"),
1464+
expectedTasks: 1,
1465+
description: "Only the dedicated policy-change handler should fire",
1466+
},
1467+
{
1468+
name: "policy-only change on non-HuggingFace - 0 tasks",
1469+
oldModel: newBaseModel("model-2", v1beta1.ReuseIfExists, "oci://bucket/model"),
1470+
newModel: newBaseModel("model-2", v1beta1.AlwaysDownload, "oci://bucket/model"),
1471+
expectedTasks: 0,
1472+
description: "Dedicated handler skips non-HF, spec diff excludes DownloadPolicy",
1473+
},
1474+
{
1475+
name: "non-policy spec change - exactly 1 task from hasChanges block",
1476+
oldModel: &v1beta1.BaseModel{
1477+
ObjectMeta: metav1.ObjectMeta{Name: "model-3"},
1478+
Spec: v1beta1.BaseModelSpec{
1479+
ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("Old Name")},
1480+
Storage: &v1beta1.StorageSpec{
1481+
DownloadPolicy: ptr(v1beta1.AlwaysDownload),
1482+
StorageUri: ptr("hf://meta-llama/llama-3"),
1483+
},
1484+
},
1485+
},
1486+
newModel: &v1beta1.BaseModel{
1487+
ObjectMeta: metav1.ObjectMeta{Name: "model-3"},
1488+
Spec: v1beta1.BaseModelSpec{
1489+
ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("New Name")},
1490+
Storage: &v1beta1.StorageSpec{
1491+
DownloadPolicy: ptr(v1beta1.AlwaysDownload),
1492+
StorageUri: ptr("hf://meta-llama/llama-3"),
1493+
},
1494+
},
1495+
},
1496+
expectedTasks: 1,
1497+
description: "Only the hasChanges block should fire",
1498+
},
1499+
{
1500+
name: "policy change + other spec change on HuggingFace - exactly 1 task",
1501+
oldModel: &v1beta1.BaseModel{
1502+
ObjectMeta: metav1.ObjectMeta{Name: "model-4"},
1503+
Spec: v1beta1.BaseModelSpec{
1504+
ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("Old Name")},
1505+
Storage: &v1beta1.StorageSpec{
1506+
DownloadPolicy: ptr(v1beta1.ReuseIfExists),
1507+
StorageUri: ptr("hf://meta-llama/llama-3"),
1508+
},
1509+
},
1510+
},
1511+
newModel: &v1beta1.BaseModel{
1512+
ObjectMeta: metav1.ObjectMeta{Name: "model-4"},
1513+
Spec: v1beta1.BaseModelSpec{
1514+
ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("New Name")},
1515+
Storage: &v1beta1.StorageSpec{
1516+
DownloadPolicy: ptr(v1beta1.AlwaysDownload),
1517+
StorageUri: ptr("hf://meta-llama/llama-3"),
1518+
},
1519+
},
1520+
},
1521+
expectedTasks: 1,
1522+
description: "Single task even when both policy and spec change",
1523+
},
1524+
}
1525+
1526+
for _, tc := range tests {
1527+
t.Run(tc.name, func(t *testing.T) {
1528+
scout.updateBaseModel(tc.oldModel, tc.newModel)
1529+
assert.Equal(t, tc.expectedTasks, len(ch), tc.description)
1530+
for range tc.expectedTasks {
1531+
task := <-ch
1532+
assert.Equal(t, DownloadOverride, task.TaskType)
1533+
}
1534+
})
1535+
}
1536+
}
1537+
1538+
func TestUpdateClusterBaseModel_NoDuplicateTaskOnPolicyOnlyChange(t *testing.T) {
1539+
logger, _ := zap.NewDevelopment()
1540+
sugaredLogger := logger.Sugar()
1541+
defer func(s *zap.SugaredLogger) { _ = s.Sync() }(sugaredLogger)
1542+
1543+
ch := make(chan *GopherTask, 10)
1544+
scout := &Scout{
1545+
nodeShapeAlias: "a10",
1546+
gopherChan: ch,
1547+
logger: sugaredLogger,
1548+
nodeInfo: &corev1.Node{
1549+
ObjectMeta: metav1.ObjectMeta{
1550+
Name: "test-node",
1551+
Labels: map[string]string{"gpu-model": "a10"},
1552+
},
1553+
},
1554+
}
1555+
1556+
tests := []struct {
1557+
name string
1558+
oldModel *v1beta1.ClusterBaseModel
1559+
newModel *v1beta1.ClusterBaseModel
1560+
expectedTasks int
1561+
description string
1562+
}{
1563+
{
1564+
name: "policy-only change on HuggingFace - exactly 1 task",
1565+
oldModel: newClusterBaseModel("cbm-1", v1beta1.ReuseIfExists, "hf://meta-llama/llama-3"),
1566+
newModel: newClusterBaseModel("cbm-1", v1beta1.AlwaysDownload, "hf://meta-llama/llama-3"),
1567+
expectedTasks: 1,
1568+
description: "Only the dedicated policy-change handler should fire",
1569+
},
1570+
{
1571+
name: "policy-only change on non-HuggingFace - 0 tasks",
1572+
oldModel: newClusterBaseModel("cbm-2", v1beta1.ReuseIfExists, "oci://bucket/model"),
1573+
newModel: newClusterBaseModel("cbm-2", v1beta1.AlwaysDownload, "oci://bucket/model"),
1574+
expectedTasks: 0,
1575+
description: "Dedicated handler skips non-HF, spec diff excludes DownloadPolicy",
1576+
},
1577+
{
1578+
name: "non-policy spec change - exactly 1 task",
1579+
oldModel: &v1beta1.ClusterBaseModel{
1580+
ObjectMeta: metav1.ObjectMeta{Name: "cbm-3"},
1581+
Spec: v1beta1.BaseModelSpec{
1582+
ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("Old Name")},
1583+
Storage: &v1beta1.StorageSpec{
1584+
DownloadPolicy: ptr(v1beta1.AlwaysDownload),
1585+
StorageUri: ptr("hf://meta-llama/llama-3"),
1586+
},
1587+
},
1588+
},
1589+
newModel: &v1beta1.ClusterBaseModel{
1590+
ObjectMeta: metav1.ObjectMeta{Name: "cbm-3"},
1591+
Spec: v1beta1.BaseModelSpec{
1592+
ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("New Name")},
1593+
Storage: &v1beta1.StorageSpec{
1594+
DownloadPolicy: ptr(v1beta1.AlwaysDownload),
1595+
StorageUri: ptr("hf://meta-llama/llama-3"),
1596+
},
1597+
},
1598+
},
1599+
expectedTasks: 1,
1600+
description: "Only the hasChanges block should fire",
1601+
},
1602+
}
1603+
1604+
for _, tc := range tests {
1605+
t.Run(tc.name, func(t *testing.T) {
1606+
scout.updateClusterBaseModel(tc.oldModel, tc.newModel)
1607+
assert.Equal(t, tc.expectedTasks, len(ch), tc.description)
1608+
for range tc.expectedTasks {
1609+
task := <-ch
1610+
assert.Equal(t, DownloadOverride, task.TaskType)
1611+
}
1612+
})
1613+
}
1614+
}

0 commit comments

Comments
 (0)