From 8e1db1d2b859fe027feef1c6df57e3fc6a709f9b Mon Sep 17 00:00:00 2001
From: yifeng
Date: Mon, 30 Mar 2026 15:21:49 -0700
Subject: [PATCH] [Bugfix] fix spec/label/annotation changes can't trigger redownload

updateBaseModel and updateClusterBaseModel used to enqueue a download
override task from the download-policy check and then, independently,
from the label/annotation/spec diff. Both handlers now store the policy
check result in policyChanged and OR it with hasChanges, so one update
event reaches a single generateDownloadOverrideTask* call.

Details:
  * The diff passes cmpopts.IgnoreFields(StorageSpec{}, "DownloadPolicy")
    so a DownloadPolicy edit is not also counted as a Spec change.
  * The combined condition is gated by shouldDownloadModel, which calls
    shouldDownloadModelCommon(storageSpec, true). The former gate,
    shouldDownloadModelInUpdateEvent (same helper with false), is removed.
    The policy-triggered task is therefore subject to that gate as well.
  * New table tests cover policy-only, spec-only and combined changes.
    They drain len(ch) queued tasks instead of tc.expectedTasks, so a
    wrong task count fails the assertion without blocking on an empty
    channel or leaking tasks into the next subtest.

---
 pkg/modelagent/scout.go      |  29 +++---
 pkg/modelagent/scout_test.go | 181 +++++++++++++++++++++++++++++++++++
 2 files changed, 194 insertions(+), 16 deletions(-)

diff --git a/pkg/modelagent/scout.go b/pkg/modelagent/scout.go
index 85b1ec5b..6c5f2d84 100644
--- a/pkg/modelagent/scout.go
+++ b/pkg/modelagent/scout.go
@@ -6,6 +6,7 @@ import (
 	"strconv"
 	"time"
 
+	"github.com/google/go-cmp/cmp/cmpopts"
 	"go.uber.org/zap"
 	v1 "k8s.io/api/core/v1"
 	"k8s.io/apimachinery/pkg/api/errors"
@@ -340,9 +341,10 @@ func (w *Scout) updateBaseModel(old, new interface{}) {
 		return
 	}
 
-	if w.isToDownloadOverrideDueToDownloadPolicyBasedOnBM(oldBaseModel, newBaseModel) {
-		w.generateDownloadOverrideTaskBasedOnBaseModel(newBaseModel)
-	}
+	policyChanged := w.isToDownloadOverrideDueToDownloadPolicyBasedOnBM(oldBaseModel, newBaseModel)
+
+	// Exclude DownloadPolicy from Spec diff — policy changes are detected separately above.
+	ignoreDownloadPolicy := cmpopts.IgnoreFields(v1beta1.StorageSpec{}, "DownloadPolicy")
 
 	hasChanges := false
 	for _, diff := range []struct {
@@ -353,7 +355,7 @@ func (w *Scout) updateBaseModel(old, new interface{}) {
 		{"Annotations", oldBaseModel.Annotations, newBaseModel.Annotations},
 		{"Spec", oldBaseModel.Spec, newBaseModel.Spec},
 	} {
-		result, err := kmp.SafeDiff(diff.old, diff.new)
+		result, err := kmp.SafeDiff(diff.old, diff.new, ignoreDownloadPolicy)
 		if err != nil {
 			w.logger.Errorf("Failed to diff %s for BaseModel: %s in namespace %s",
 				diff.name, newBaseModel.Name, newBaseModel.Namespace)
@@ -362,7 +364,7 @@ func (w *Scout) updateBaseModel(old, new interface{}) {
 		hasChanges = hasChanges || (result != "")
 	}
 
-	if hasChanges && w.shouldDownloadModelInUpdateEvent(newBaseModel.Spec.Storage) {
+	if (policyChanged || hasChanges) && w.shouldDownloadModel(newBaseModel.Spec.Storage) {
 		w.logger.Infof("BaseModel %s needs refresh in namespace %s", newBaseModel.GetName(), newBaseModel.GetNamespace())
 		w.generateDownloadOverrideTaskBasedOnBaseModel(newBaseModel)
 	}
@@ -395,9 +397,10 @@ func (w *Scout) updateClusterBaseModel(old, new interface{}) {
 		return
 	}
 
-	if w.isToDownloadOverrideDueToDownloadPolicyBasedOnCBM(oldClusterBaseModel, newClusterBaseModel) {
-		w.generateDownloadOverrideTaskBasedOnClusterBaseModel(newClusterBaseModel)
-	}
+	policyChanged := w.isToDownloadOverrideDueToDownloadPolicyBasedOnCBM(oldClusterBaseModel, newClusterBaseModel)
+
+	// Exclude DownloadPolicy from Spec diff — policy changes are detected separately above.
+	ignoreDownloadPolicy := cmpopts.IgnoreFields(v1beta1.StorageSpec{}, "DownloadPolicy")
 
 	hasChanges := false
 	for _, diff := range []struct {
@@ -408,7 +411,7 @@ func (w *Scout) updateClusterBaseModel(old, new interface{}) {
 		{"Annotations", oldClusterBaseModel.Annotations, newClusterBaseModel.Annotations},
 		{"Spec", oldClusterBaseModel.Spec, newClusterBaseModel.Spec},
 	} {
-		result, err := kmp.SafeDiff(diff.old, diff.new)
+		result, err := kmp.SafeDiff(diff.old, diff.new, ignoreDownloadPolicy)
 		if err != nil {
 			w.logger.Errorf("Failed to diff %s for BaseModel: %s in namespace %s",
 				diff.name, newClusterBaseModel.Name, newClusterBaseModel.Namespace)
@@ -417,7 +420,7 @@ func (w *Scout) updateClusterBaseModel(old, new interface{}) {
 		hasChanges = hasChanges || (result != "")
 	}
 
-	if hasChanges && w.shouldDownloadModelInUpdateEvent(newClusterBaseModel.Spec.Storage) {
+	if (policyChanged || hasChanges) && w.shouldDownloadModel(newClusterBaseModel.Spec.Storage) {
 		w.logger.Infof("ClusterBaseModel %s need refresh", newClusterBaseModel.GetName())
 		w.generateDownloadOverrideTaskBasedOnClusterBaseModel(newClusterBaseModel)
 	}
@@ -547,12 +550,6 @@ func (w *Scout) shouldDownloadModel(storageSpec *v1beta1.StorageSpec) bool {
 	return w.shouldDownloadModelCommon(storageSpec, true)
 }
 
-// shouldDownloadModelInUpdateEvent mirrors shouldDownloadModel logic but uses a default false decision,
-// allowing callers to opt-in specific cases for updates if needed.
-func (w *Scout) shouldDownloadModelInUpdateEvent(storageSpec *v1beta1.StorageSpec) bool {
-	return w.shouldDownloadModelCommon(storageSpec, false)
-}
-
 func (w *Scout) nodeMatchesSelectorTerm(term v1.NodeSelectorTerm) bool {
 	// Check match expressions
 	for _, expr := range term.MatchExpressions {
diff --git a/pkg/modelagent/scout_test.go b/pkg/modelagent/scout_test.go
index 7de9cb1c..832bac42 100644
--- a/pkg/modelagent/scout_test.go
+++ b/pkg/modelagent/scout_test.go
@@ -1431,3 +1431,184 @@ func TestGenerateDownloadOverrideTaskBasedOnBaseModel_TensorRTLLMAndMetadataType
 	assert.Equal(t, "a10", filter.ShapeAlias)
 	assert.Equal(t, "CustomType", filter.ModelType)
 }
+
+func TestUpdateBaseModel_NoDuplicateTaskOnPolicyOnlyChange(t *testing.T) {
+	logger, _ := zap.NewDevelopment()
+	sugaredLogger := logger.Sugar()
+	defer func(s *zap.SugaredLogger) { _ = s.Sync() }(sugaredLogger)
+
+	ch := make(chan *GopherTask, 10)
+	scout := &Scout{
+		nodeShapeAlias: "a10",
+		gopherChan:     ch,
+		logger:         sugaredLogger,
+		nodeInfo: &corev1.Node{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:   "test-node",
+				Labels: map[string]string{"gpu-model": "a10"},
+			},
+		},
+	}
+
+	tests := []struct {
+		name          string
+		oldModel      *v1beta1.BaseModel
+		newModel      *v1beta1.BaseModel
+		expectedTasks int
+		description   string
+	}{
+		{
+			name:          "policy-only change on HuggingFace - exactly 1 task from dedicated handler",
+			oldModel:      newBaseModel("model-1", v1beta1.ReuseIfExists, "hf://meta-llama/llama-3"),
+			newModel:      newBaseModel("model-1", v1beta1.AlwaysDownload, "hf://meta-llama/llama-3"),
+			expectedTasks: 1,
+			description:   "Only the dedicated policy-change handler should fire",
+		},
+		{
+			name:          "policy-only change on non-HuggingFace - 0 tasks",
+			oldModel:      newBaseModel("model-2", v1beta1.ReuseIfExists, "oci://bucket/model"),
+			newModel:      newBaseModel("model-2", v1beta1.AlwaysDownload, "oci://bucket/model"),
+			expectedTasks: 0,
+			description:   "Dedicated handler skips non-HF, spec diff excludes DownloadPolicy",
+		},
+		{
+			name: "non-policy spec change - exactly 1 task from hasChanges block",
+			oldModel: &v1beta1.BaseModel{
+				ObjectMeta: metav1.ObjectMeta{Name: "model-3"},
+				Spec: v1beta1.BaseModelSpec{
+					ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("Old Name")},
+					Storage: &v1beta1.StorageSpec{
+						DownloadPolicy: ptr(v1beta1.AlwaysDownload),
+						StorageUri:     ptr("hf://meta-llama/llama-3"),
+					},
+				},
+			},
+			newModel: &v1beta1.BaseModel{
+				ObjectMeta: metav1.ObjectMeta{Name: "model-3"},
+				Spec: v1beta1.BaseModelSpec{
+					ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("New Name")},
+					Storage: &v1beta1.StorageSpec{
+						DownloadPolicy: ptr(v1beta1.AlwaysDownload),
+						StorageUri:     ptr("hf://meta-llama/llama-3"),
+					},
+				},
+			},
+			expectedTasks: 1,
+			description:   "Only the hasChanges block should fire",
+		},
+		{
+			name: "policy change + other spec change on HuggingFace - exactly 1 task",
+			oldModel: &v1beta1.BaseModel{
+				ObjectMeta: metav1.ObjectMeta{Name: "model-4"},
+				Spec: v1beta1.BaseModelSpec{
+					ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("Old Name")},
+					Storage: &v1beta1.StorageSpec{
+						DownloadPolicy: ptr(v1beta1.ReuseIfExists),
+						StorageUri:     ptr("hf://meta-llama/llama-3"),
+					},
+				},
+			},
+			newModel: &v1beta1.BaseModel{
+				ObjectMeta: metav1.ObjectMeta{Name: "model-4"},
+				Spec: v1beta1.BaseModelSpec{
+					ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("New Name")},
+					Storage: &v1beta1.StorageSpec{
+						DownloadPolicy: ptr(v1beta1.AlwaysDownload),
+						StorageUri:     ptr("hf://meta-llama/llama-3"),
+					},
+				},
+			},
+			expectedTasks: 1,
+			description:   "Single task even when both policy and spec change",
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			scout.updateBaseModel(tc.oldModel, tc.newModel)
+			assert.Equal(t, tc.expectedTasks, len(ch), tc.description)
+			for range len(ch) { // drain what was queued; never blocks on a miscount
+				task := <-ch
+				assert.Equal(t, DownloadOverride, task.TaskType)
+			}
+		})
+	}
+}
+
+func TestUpdateClusterBaseModel_NoDuplicateTaskOnPolicyOnlyChange(t *testing.T) {
+	logger, _ := zap.NewDevelopment()
+	sugaredLogger := logger.Sugar()
+	defer func(s *zap.SugaredLogger) { _ = s.Sync() }(sugaredLogger)
+
+	ch := make(chan *GopherTask, 10)
+	scout := &Scout{
+		nodeShapeAlias: "a10",
+		gopherChan:     ch,
+		logger:         sugaredLogger,
+		nodeInfo: &corev1.Node{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:   "test-node",
+				Labels: map[string]string{"gpu-model": "a10"},
+			},
+		},
+	}
+
+	tests := []struct {
+		name          string
+		oldModel      *v1beta1.ClusterBaseModel
+		newModel      *v1beta1.ClusterBaseModel
+		expectedTasks int
+		description   string
+	}{
+		{
+			name:          "policy-only change on HuggingFace - exactly 1 task",
+			oldModel:      newClusterBaseModel("cbm-1", v1beta1.ReuseIfExists, "hf://meta-llama/llama-3"),
+			newModel:      newClusterBaseModel("cbm-1", v1beta1.AlwaysDownload, "hf://meta-llama/llama-3"),
+			expectedTasks: 1,
+			description:   "Only the dedicated policy-change handler should fire",
+		},
+		{
+			name:          "policy-only change on non-HuggingFace - 0 tasks",
+			oldModel:      newClusterBaseModel("cbm-2", v1beta1.ReuseIfExists, "oci://bucket/model"),
+			newModel:      newClusterBaseModel("cbm-2", v1beta1.AlwaysDownload, "oci://bucket/model"),
+			expectedTasks: 0,
+			description:   "Dedicated handler skips non-HF, spec diff excludes DownloadPolicy",
+		},
+		{
+			name: "non-policy spec change - exactly 1 task",
+			oldModel: &v1beta1.ClusterBaseModel{
+				ObjectMeta: metav1.ObjectMeta{Name: "cbm-3"},
+				Spec: v1beta1.BaseModelSpec{
+					ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("Old Name")},
+					Storage: &v1beta1.StorageSpec{
+						DownloadPolicy: ptr(v1beta1.AlwaysDownload),
+						StorageUri:     ptr("hf://meta-llama/llama-3"),
+					},
+				},
+			},
+			newModel: &v1beta1.ClusterBaseModel{
+				ObjectMeta: metav1.ObjectMeta{Name: "cbm-3"},
+				Spec: v1beta1.BaseModelSpec{
+					ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("New Name")},
+					Storage: &v1beta1.StorageSpec{
+						DownloadPolicy: ptr(v1beta1.AlwaysDownload),
+						StorageUri:     ptr("hf://meta-llama/llama-3"),
+					},
+				},
+			},
+			expectedTasks: 1,
+			description:   "Only the hasChanges block should fire",
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			scout.updateClusterBaseModel(tc.oldModel, tc.newModel)
+			assert.Equal(t, tc.expectedTasks, len(ch), tc.description)
+			for range len(ch) { // drain what was queued; never blocks on a miscount
+				task := <-ch
+				assert.Equal(t, DownloadOverride, task.TaskType)
+			}
+		})
+	}
+}