Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions pkg/modelagent/scout.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"strconv"
"time"

"github.com/google/go-cmp/cmp/cmpopts"
"go.uber.org/zap"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
Expand Down Expand Up @@ -340,9 +341,10 @@ func (w *Scout) updateBaseModel(old, new interface{}) {
return
}

if w.isToDownloadOverrideDueToDownloadPolicyBasedOnBM(oldBaseModel, newBaseModel) {
w.generateDownloadOverrideTaskBasedOnBaseModel(newBaseModel)
}
policyChanged := w.isToDownloadOverrideDueToDownloadPolicyBasedOnBM(oldBaseModel, newBaseModel)

// Exclude DownloadPolicy from Spec diff — policy changes are detected separately above.
ignoreDownloadPolicy := cmpopts.IgnoreFields(v1beta1.StorageSpec{}, "DownloadPolicy")

hasChanges := false
for _, diff := range []struct {
Expand All @@ -353,7 +355,7 @@ func (w *Scout) updateBaseModel(old, new interface{}) {
{"Annotations", oldBaseModel.Annotations, newBaseModel.Annotations},
{"Spec", oldBaseModel.Spec, newBaseModel.Spec},
} {
result, err := kmp.SafeDiff(diff.old, diff.new)
result, err := kmp.SafeDiff(diff.old, diff.new, ignoreDownloadPolicy)
if err != nil {
w.logger.Errorf("Failed to diff %s for BaseModel: %s in namespace %s",
diff.name, newBaseModel.Name, newBaseModel.Namespace)
Expand All @@ -362,7 +364,7 @@ func (w *Scout) updateBaseModel(old, new interface{}) {
hasChanges = hasChanges || (result != "")
}

if hasChanges && w.shouldDownloadModelInUpdateEvent(newBaseModel.Spec.Storage) {
if (policyChanged || hasChanges) && w.shouldDownloadModel(newBaseModel.Spec.Storage) {
w.logger.Infof("BaseModel %s needs refresh in namespace %s", newBaseModel.GetName(), newBaseModel.GetNamespace())
w.generateDownloadOverrideTaskBasedOnBaseModel(newBaseModel)
}
Expand Down Expand Up @@ -395,9 +397,10 @@ func (w *Scout) updateClusterBaseModel(old, new interface{}) {
return
}

if w.isToDownloadOverrideDueToDownloadPolicyBasedOnCBM(oldClusterBaseModel, newClusterBaseModel) {
w.generateDownloadOverrideTaskBasedOnClusterBaseModel(newClusterBaseModel)
}
policyChanged := w.isToDownloadOverrideDueToDownloadPolicyBasedOnCBM(oldClusterBaseModel, newClusterBaseModel)

// Exclude DownloadPolicy from Spec diff — policy changes are detected separately above.
ignoreDownloadPolicy := cmpopts.IgnoreFields(v1beta1.StorageSpec{}, "DownloadPolicy")

hasChanges := false
for _, diff := range []struct {
Expand All @@ -408,7 +411,7 @@ func (w *Scout) updateClusterBaseModel(old, new interface{}) {
{"Annotations", oldClusterBaseModel.Annotations, newClusterBaseModel.Annotations},
{"Spec", oldClusterBaseModel.Spec, newClusterBaseModel.Spec},
} {
result, err := kmp.SafeDiff(diff.old, diff.new)
result, err := kmp.SafeDiff(diff.old, diff.new, ignoreDownloadPolicy)
if err != nil {
w.logger.Errorf("Failed to diff %s for BaseModel: %s in namespace %s",
diff.name, newClusterBaseModel.Name, newClusterBaseModel.Namespace)
Expand All @@ -417,7 +420,7 @@ func (w *Scout) updateClusterBaseModel(old, new interface{}) {
hasChanges = hasChanges || (result != "")
}

if hasChanges && w.shouldDownloadModelInUpdateEvent(newClusterBaseModel.Spec.Storage) {
if (policyChanged || hasChanges) && w.shouldDownloadModel(newClusterBaseModel.Spec.Storage) {
w.logger.Infof("ClusterBaseModel %s need refresh", newClusterBaseModel.GetName())
w.generateDownloadOverrideTaskBasedOnClusterBaseModel(newClusterBaseModel)
}
Expand Down Expand Up @@ -547,12 +550,6 @@ func (w *Scout) shouldDownloadModel(storageSpec *v1beta1.StorageSpec) bool {
return w.shouldDownloadModelCommon(storageSpec, true)
}

// shouldDownloadModelInUpdateEvent mirrors shouldDownloadModel logic but uses a default false decision,
// allowing callers to opt-in specific cases for updates if needed.
func (w *Scout) shouldDownloadModelInUpdateEvent(storageSpec *v1beta1.StorageSpec) bool {
return w.shouldDownloadModelCommon(storageSpec, false)
}

func (w *Scout) nodeMatchesSelectorTerm(term v1.NodeSelectorTerm) bool {
// Check match expressions
for _, expr := range term.MatchExpressions {
Expand Down
181 changes: 181 additions & 0 deletions pkg/modelagent/scout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1431,3 +1431,184 @@ func TestGenerateDownloadOverrideTaskBasedOnBaseModel_TensorRTLLMAndMetadataType
assert.Equal(t, "a10", filter.ShapeAlias)
assert.Equal(t, "CustomType", filter.ModelType)
}

func TestUpdateBaseModel_NoDuplicateTaskOnPolicyOnlyChange(t *testing.T) {
logger, _ := zap.NewDevelopment()
sugaredLogger := logger.Sugar()
defer func(s *zap.SugaredLogger) { _ = s.Sync() }(sugaredLogger)

ch := make(chan *GopherTask, 10)
scout := &Scout{
nodeShapeAlias: "a10",
gopherChan: ch,
logger: sugaredLogger,
nodeInfo: &corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: "test-node",
Labels: map[string]string{"gpu-model": "a10"},
},
},
}

tests := []struct {
name string
oldModel *v1beta1.BaseModel
newModel *v1beta1.BaseModel
expectedTasks int
description string
}{
{
name: "policy-only change on HuggingFace - exactly 1 task from dedicated handler",
oldModel: newBaseModel("model-1", v1beta1.ReuseIfExists, "hf://meta-llama/llama-3"),
newModel: newBaseModel("model-1", v1beta1.AlwaysDownload, "hf://meta-llama/llama-3"),
expectedTasks: 1,
description: "Only the dedicated policy-change handler should fire",
},
{
name: "policy-only change on non-HuggingFace - 0 tasks",
oldModel: newBaseModel("model-2", v1beta1.ReuseIfExists, "oci://bucket/model"),
newModel: newBaseModel("model-2", v1beta1.AlwaysDownload, "oci://bucket/model"),
expectedTasks: 0,
description: "Dedicated handler skips non-HF, spec diff excludes DownloadPolicy",
},
{
name: "non-policy spec change - exactly 1 task from hasChanges block",
oldModel: &v1beta1.BaseModel{
ObjectMeta: metav1.ObjectMeta{Name: "model-3"},
Spec: v1beta1.BaseModelSpec{
ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("Old Name")},
Storage: &v1beta1.StorageSpec{
DownloadPolicy: ptr(v1beta1.AlwaysDownload),
StorageUri: ptr("hf://meta-llama/llama-3"),
},
},
},
newModel: &v1beta1.BaseModel{
ObjectMeta: metav1.ObjectMeta{Name: "model-3"},
Spec: v1beta1.BaseModelSpec{
ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("New Name")},
Storage: &v1beta1.StorageSpec{
DownloadPolicy: ptr(v1beta1.AlwaysDownload),
StorageUri: ptr("hf://meta-llama/llama-3"),
},
},
},
expectedTasks: 1,
description: "Only the hasChanges block should fire",
},
{
name: "policy change + other spec change on HuggingFace - exactly 1 task",
oldModel: &v1beta1.BaseModel{
ObjectMeta: metav1.ObjectMeta{Name: "model-4"},
Spec: v1beta1.BaseModelSpec{
ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("Old Name")},
Storage: &v1beta1.StorageSpec{
DownloadPolicy: ptr(v1beta1.ReuseIfExists),
StorageUri: ptr("hf://meta-llama/llama-3"),
},
},
},
newModel: &v1beta1.BaseModel{
ObjectMeta: metav1.ObjectMeta{Name: "model-4"},
Spec: v1beta1.BaseModelSpec{
ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("New Name")},
Storage: &v1beta1.StorageSpec{
DownloadPolicy: ptr(v1beta1.AlwaysDownload),
StorageUri: ptr("hf://meta-llama/llama-3"),
},
},
},
expectedTasks: 1,
description: "Single task even when both policy and spec change",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
scout.updateBaseModel(tc.oldModel, tc.newModel)
assert.Equal(t, tc.expectedTasks, len(ch), tc.description)
for range tc.expectedTasks {
task := <-ch
assert.Equal(t, DownloadOverride, task.TaskType)
}
})
}
}

func TestUpdateClusterBaseModel_NoDuplicateTaskOnPolicyOnlyChange(t *testing.T) {
logger, _ := zap.NewDevelopment()
sugaredLogger := logger.Sugar()
defer func(s *zap.SugaredLogger) { _ = s.Sync() }(sugaredLogger)

ch := make(chan *GopherTask, 10)
scout := &Scout{
nodeShapeAlias: "a10",
gopherChan: ch,
logger: sugaredLogger,
nodeInfo: &corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: "test-node",
Labels: map[string]string{"gpu-model": "a10"},
},
},
}

tests := []struct {
name string
oldModel *v1beta1.ClusterBaseModel
newModel *v1beta1.ClusterBaseModel
expectedTasks int
description string
}{
{
name: "policy-only change on HuggingFace - exactly 1 task",
oldModel: newClusterBaseModel("cbm-1", v1beta1.ReuseIfExists, "hf://meta-llama/llama-3"),
newModel: newClusterBaseModel("cbm-1", v1beta1.AlwaysDownload, "hf://meta-llama/llama-3"),
expectedTasks: 1,
description: "Only the dedicated policy-change handler should fire",
},
{
name: "policy-only change on non-HuggingFace - 0 tasks",
oldModel: newClusterBaseModel("cbm-2", v1beta1.ReuseIfExists, "oci://bucket/model"),
newModel: newClusterBaseModel("cbm-2", v1beta1.AlwaysDownload, "oci://bucket/model"),
expectedTasks: 0,
description: "Dedicated handler skips non-HF, spec diff excludes DownloadPolicy",
},
{
name: "non-policy spec change - exactly 1 task",
oldModel: &v1beta1.ClusterBaseModel{
ObjectMeta: metav1.ObjectMeta{Name: "cbm-3"},
Spec: v1beta1.BaseModelSpec{
ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("Old Name")},
Storage: &v1beta1.StorageSpec{
DownloadPolicy: ptr(v1beta1.AlwaysDownload),
StorageUri: ptr("hf://meta-llama/llama-3"),
},
},
},
newModel: &v1beta1.ClusterBaseModel{
ObjectMeta: metav1.ObjectMeta{Name: "cbm-3"},
Spec: v1beta1.BaseModelSpec{
ModelExtensionSpec: v1beta1.ModelExtensionSpec{DisplayName: ptr("New Name")},
Storage: &v1beta1.StorageSpec{
DownloadPolicy: ptr(v1beta1.AlwaysDownload),
StorageUri: ptr("hf://meta-llama/llama-3"),
},
},
},
expectedTasks: 1,
description: "Only the hasChanges block should fire",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
scout.updateClusterBaseModel(tc.oldModel, tc.newModel)
assert.Equal(t, tc.expectedTasks, len(ch), tc.description)
for range tc.expectedTasks {
task := <-ch
assert.Equal(t, DownloadOverride, task.TaskType)
}
})
}
}
Loading