diff --git a/config/default/configmap.yaml b/config/default/configmap.yaml index 26a5c3ee..442b6293 100644 --- a/config/default/configmap.yaml +++ b/config/default/configmap.yaml @@ -5,4 +5,4 @@ metadata: data: config.data: | scheduler-name: default-scheduler - # init-container-image: inftyai/model-loader:v0.0.10 + init-container-image: inftyai/model-loader:v0.0.10 diff --git a/pkg/controller/inference/service_controller.go b/pkg/controller/inference/service_controller.go index 1d2e9f3f..efd6b842 100644 --- a/pkg/controller/inference/service_controller.go +++ b/pkg/controller/inference/service_controller.go @@ -116,10 +116,11 @@ func (r *ServiceReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct return ctrl.Result{}, err } - workloadApplyConfiguration, err := buildWorkloadApplyConfiguration(service, models) + workloadApplyConfiguration, err := buildWorkloadApplyConfiguration(service, models, configs) if err != nil { return ctrl.Result{}, err } + if err := setControllerReferenceForWorkload(service, workloadApplyConfiguration, r.Scheme); err != nil { return ctrl.Result{}, err } @@ -162,7 +163,7 @@ func (r *ServiceReconciler) SetupWithManager(mgr ctrl.Manager) error { Complete(r) } -func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*coreapi.OpenModel) (*applyconfigurationv1.LeaderWorkerSetApplyConfiguration, error) { +func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*coreapi.OpenModel, configs *helper.GlobalConfigs) (*applyconfigurationv1.LeaderWorkerSetApplyConfiguration, error) { workload := applyconfigurationv1.LeaderWorkerSet(service.Name, service.Namespace) leaderWorkerTemplate := applyconfigurationv1.LeaderWorkerTemplate() @@ -193,7 +194,7 @@ func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*co leaderWorkerTemplate.WithWorkerTemplate(&podTemplateSpecApplyConfiguration) // The core logic to inject additional configurations. - injectModelProperties(leaderWorkerTemplate, models, service) + injectModelProperties(leaderWorkerTemplate, models, service, configs) spec := applyconfigurationv1.LeaderWorkerSetSpec() spec.WithLeaderWorkerTemplate(leaderWorkerTemplate) @@ -215,7 +216,7 @@ func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*co return workload, nil } -func injectModelProperties(template *applyconfigurationv1.LeaderWorkerTemplateApplyConfiguration, models []*coreapi.OpenModel, service *inferenceapi.Service) { +func injectModelProperties(template *applyconfigurationv1.LeaderWorkerTemplateApplyConfiguration, models []*coreapi.OpenModel, service *inferenceapi.Service, configs *helper.GlobalConfigs) { isMultiNodesInference := template.LeaderTemplate != nil for i, model := range models { @@ -223,9 +224,9 @@ func injectModelProperties(template *applyconfigurationv1.LeaderWorkerTemplateAp // Skip model-loader initContainer if llmaz.io/skip-model-loader annotation is set. if !helper.SkipModelLoader(service) { if isMultiNodesInference { - source.InjectModelLoader(template.LeaderTemplate, i) + source.InjectModelLoader(template.LeaderTemplate, i, configs.InitContainerImage) } - source.InjectModelLoader(template.WorkerTemplate, i) + source.InjectModelLoader(template.WorkerTemplate, i, configs.InitContainerImage) } else { if isMultiNodesInference { source.InjectModelEnvVars(template.LeaderTemplate) diff --git a/pkg/controller_helper/configmap.go b/pkg/controller_helper/configmap.go index bd4d9f45..89f05685 100644 --- a/pkg/controller_helper/configmap.go +++ b/pkg/controller_helper/configmap.go @@ -40,5 +40,16 @@ func ParseGlobalConfigmap(cm *corev1.ConfigMap) (*GlobalConfigs, error) { return nil, fmt.Errorf("failed to unmarshal config.data: %v", err) } + if err := configs.validate(); err != nil { + return nil, fmt.Errorf("invalid global config: %v", err) + } + return &configs, nil } + +func (c *GlobalConfigs) validate() error { + if c.InitContainerImage == "" { + return fmt.Errorf("init-container-image is required") + } + return nil +} diff --git a/pkg/controller_helper/configmap_test.go b/pkg/controller_helper/configmap_test.go new file mode 100644 index 00000000..f9520396 --- /dev/null +++ b/pkg/controller_helper/configmap_test.go @@ -0,0 +1,64 @@ +/* +Copyright 2025 The InftyAI Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package helper + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGlobalConfigs_validate(t *testing.T) { + tests := []struct { + name string + config *GlobalConfigs + expectError bool + errorMsg string + }{ + { + name: "valid config", + config: &GlobalConfigs{ + SchedulerName: "custom-scheduler", + InitContainerImage: "inftyai/model-loader:v0.0.10", + }, + expectError: false, + }, + { + name: "empty init container image", + config: &GlobalConfigs{ + SchedulerName: "custom-scheduler", + InitContainerImage: "", + }, + expectError: true, + errorMsg: "init-container-image is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.validate() + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/controller_helper/modelsource/modelhub.go b/pkg/controller_helper/modelsource/modelhub.go index 72df6297..36e990d4 100644 --- a/pkg/controller_helper/modelsource/modelhub.go +++ b/pkg/controller_helper/modelsource/modelhub.go @@ -21,8 +21,6 @@ import ( "strings" coreapplyv1 "k8s.io/client-go/applyconfigurations/core/v1" - - "github.com/inftyai/llmaz/pkg" ) var _ ModelSourceProvider = &ModelHubProvider{} @@ -62,7 +60,7 @@ func (p *ModelHubProvider) ModelPath(skipModelLoader bool) string { return CONTAINER_MODEL_PATH + "models--" + strings.ReplaceAll(p.modelID, "/", "--") } -func (p *ModelHubProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApplyConfiguration, index int) { +func (p *ModelHubProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApplyConfiguration, index int, initContainerImage string) { initContainerName := MODEL_LOADER_CONTAINER_NAME if index != 0 { initContainerName += "-" + strconv.Itoa(index) @@ -71,7 +69,7 @@ func (p *ModelHubProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSp // Handle initContainer. initContainer := coreapplyv1.Container(). WithName(initContainerName). - WithImage(pkg.LOADER_IMAGE). + WithImage(initContainerImage). WithVolumeMounts(coreapplyv1.VolumeMount().WithName(MODEL_VOLUME_NAME).WithMountPath(CONTAINER_MODEL_PATH)) // We have exactly one container in the template.Spec.Containers. diff --git a/pkg/controller_helper/modelsource/modelhub_test.go b/pkg/controller_helper/modelsource/modelhub_test.go index 697e2cf6..9165d803 100644 --- a/pkg/controller_helper/modelsource/modelhub_test.go +++ b/pkg/controller_helper/modelsource/modelhub_test.go @@ -27,8 +27,6 @@ import ( "github.com/stretchr/testify/assert" corev1 "k8s.io/api/core/v1" coreapplyv1 "k8s.io/client-go/applyconfigurations/core/v1" - - "github.com/inftyai/llmaz/pkg" ) func Test_ModelHubProvider_InjectModelLoader(t *testing.T) { @@ -38,10 +36,11 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) { ignorePatterns := []string{"*.tmp"} tests := []struct { - name string - provider *ModelHubProvider - index int - expectMainModel bool + name string + provider *ModelHubProvider + index int + expectMainModel bool + initContainerImage string }{ { name: "inject full modelhub with fileName, revision, allow/ignore", @@ -54,8 +53,9 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) { modelAllowPatterns: allowPatterns, modelIgnorePatterns: ignorePatterns, }, - index: 0, - expectMainModel: true, + index: 0, + expectMainModel: true, + initContainerImage: "model-loader:latest", }, { name: "inject with index > 0 skips volume/container mount", @@ -64,8 +64,20 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) { modelID: "some/model", modelHub: "Huggingface", }, - index: 1, - expectMainModel: false, + index: 1, + expectMainModel: false, + initContainerImage: "model-loader:latest", + }, + { + name: "inject with custom initContainerImage", + provider: &ModelHubProvider{ + modelName: "llama3", + modelID: "meta/llama-3", + modelHub: "Huggingface", + }, + index: 0, + expectMainModel: true, + initContainerImage: "custom-model-loader:latest", }, } @@ -83,7 +95,7 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) { ), ) - tt.provider.InjectModelLoader(template, tt.index) + tt.provider.InjectModelLoader(template, tt.index, tt.initContainerImage) assert.Len(t, template.Spec.InitContainers, 1) initContainer := template.Spec.InitContainers[0] @@ -92,8 +104,9 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) { if tt.index != 0 { expectedName += "-" + strconv.Itoa(tt.index) } + expectedImage := tt.initContainerImage assert.Equal(t, expectedName, *initContainer.Name) - assert.Equal(t, pkg.LOADER_IMAGE, *initContainer.Image) + assert.Equal(t, expectedImage, *initContainer.Image) wantEnv := buildExpectedEnv(tt.provider) if diff := cmp.Diff(wantEnv, initContainer.Env, envSortOpt); diff != "" { diff --git a/pkg/controller_helper/modelsource/modelsource.go b/pkg/controller_helper/modelsource/modelsource.go index 6de94911..7225972e 100644 --- a/pkg/controller_helper/modelsource/modelsource.go +++ b/pkg/controller_helper/modelsource/modelsource.go @@ -59,7 +59,7 @@ type ModelSourceProvider interface { ModelPath(skipModelLoader bool) string // InjectModelLoader will inject the model loader to the spec, // index refers to the suffix of the initContainer name, like model-loader, model-loader-1. - InjectModelLoader(spec *coreapplyv1.PodTemplateSpecApplyConfiguration, index int) + InjectModelLoader(spec *coreapplyv1.PodTemplateSpecApplyConfiguration, index int, initContainerImage string) // InjectModelEnvVars will inject the model credentials env to the model-runner container. // This is used when the model-loader initContainer is not injected, and the model loading is handled by the model-runner container. InjectModelEnvVars(spec *coreapplyv1.PodTemplateSpecApplyConfiguration) diff --git a/pkg/controller_helper/modelsource/modelsource_test.go b/pkg/controller_helper/modelsource/modelsource_test.go index 87ed2cb9..1a02b6fd 100644 --- a/pkg/controller_helper/modelsource/modelsource_test.go +++ b/pkg/controller_helper/modelsource/modelsource_test.go @@ -130,7 +130,7 @@ func TestEnvInjectModelLoader(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tt.provider.InjectModelLoader(tt.template, 0) + tt.provider.InjectModelLoader(tt.template, 0, "model-loader:latest") initContainer := tt.template.Spec.InitContainers[0] assert.Subset(t, initContainer.Env, tt.template.Spec.Containers[0].Env) }) diff --git a/pkg/controller_helper/modelsource/uri.go b/pkg/controller_helper/modelsource/uri.go index 870cef36..c60812ea 100644 --- a/pkg/controller_helper/modelsource/uri.go +++ b/pkg/controller_helper/modelsource/uri.go @@ -21,8 +21,6 @@ import ( "strings" coreapplyv1 "k8s.io/client-go/applyconfigurations/core/v1" - - "github.com/inftyai/llmaz/pkg" ) var _ ModelSourceProvider = &URIProvider{} @@ -80,7 +78,7 @@ func (p *URIProvider) ModelPath(skipModelLoader bool) string { return CONTAINER_MODEL_PATH + "models--" + splits[len(splits)-1] } -func (p *URIProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApplyConfiguration, index int) { +func (p *URIProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApplyConfiguration, index int, initContainerImage string) { // We don't have additional operations for Ollama, just load in runtime. if p.protocol == Ollama { return @@ -112,10 +110,11 @@ func (p *URIProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApp if index != 0 { initContainerName += "-" + strconv.Itoa(index) } + // Handle initContainer. initContainer := coreapplyv1.Container(). WithName(initContainerName). - WithImage(pkg.LOADER_IMAGE). + WithImage(initContainerImage). WithVolumeMounts( coreapplyv1.VolumeMount(). WithName(MODEL_VOLUME_NAME). diff --git a/pkg/defaults.go b/pkg/defaults.go deleted file mode 100644 index 82d54f5a..00000000 --- a/pkg/defaults.go +++ /dev/null @@ -1,21 +0,0 @@ -/* -Copyright 2024 The InftyAI Team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package pkg - -const ( - LOADER_IMAGE = "inftyai/model-loader:v0.0.10" -) diff --git a/test/util/validation/validate_service.go b/test/util/validation/validate_service.go index e3a1efaf..58a750ef 100644 --- a/test/util/validation/validate_service.go +++ b/test/util/validation/validate_service.go @@ -42,7 +42,6 @@ import ( coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" - "github.com/inftyai/llmaz/pkg" helper "github.com/inftyai/llmaz/pkg/controller_helper" modelSource "github.com/inftyai/llmaz/pkg/controller_helper/modelsource" pkgUtil "github.com/inftyai/llmaz/pkg/util" @@ -114,7 +113,7 @@ func ValidateService(ctx context.Context, k8sClient client.Client, service *infe return err } - if err := ValidateConfigmap(ctx, k8sClient, service); err != nil { + if err := ValidateConfigmap(ctx, k8sClient, service, &workload); err != nil { return err } @@ -143,8 +142,8 @@ func ValidateModelLoader(model *coreapi.OpenModel, index int, template corev1.Po if initContainer.Name != containerName { return fmt.Errorf("unexpected initContainer name, want %s, got %s", modelSource.MODEL_LOADER_CONTAINER_NAME, initContainer.Name) } - if initContainer.Image != pkg.LOADER_IMAGE { - return fmt.Errorf("unexpected initContainer image, want %s, got %s", pkg.LOADER_IMAGE, initContainer.Image) + if initContainer.Image == "" { + return fmt.Errorf("unexpected initContainer image, initContainer image should not be empty") } var envStrings []string @@ -440,7 +439,7 @@ func CheckServiceAvaliable() error { return nil } -func ValidateConfigmap(ctx context.Context, k8sClient client.Client, service *inferenceapi.Service) error { +func ValidateConfigmap(ctx context.Context, k8sClient client.Client, service *inferenceapi.Service, workload *lws.LeaderWorkerSet) error { cm := corev1.ConfigMap{} if err := k8sClient.Get(ctx, types.NamespacedName{Name: "llmaz-global-config", Namespace: "llmaz-system"}, &cm); err != nil { return err @@ -451,6 +450,7 @@ func ValidateConfigmap(ctx context.Context, k8sClient client.Client, service *in return fmt.Errorf("failed to parse global configmap: %v", err) } + // Validate scheduler name. if service.Spec.WorkloadTemplate.LeaderTemplate != nil { if service.Spec.WorkloadTemplate.LeaderTemplate.Spec.SchedulerName != data.SchedulerName { return fmt.Errorf("unexpected scheduler name %s, want %s", service.Spec.WorkloadTemplate.LeaderTemplate.Spec.SchedulerName, data.SchedulerName) @@ -461,5 +461,22 @@ func ValidateConfigmap(ctx context.Context, k8sClient client.Client, service *in return fmt.Errorf("unexpected scheduler name %s, want %s", service.Spec.WorkloadTemplate.WorkerTemplate.Spec.SchedulerName, data.SchedulerName) } + if !helper.SkipModelLoader(service) { + // Validate init container image. + if service.Spec.WorkloadTemplate.LeaderTemplate != nil { + for _, initContainer := range workload.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.InitContainers { + if initContainer.Image != data.InitContainerImage { + return fmt.Errorf("unexpected init container image %s in leader template, want %s", initContainer.Image, data.InitContainerImage) + } + } + } + + for _, initContainer := range workload.Spec.LeaderWorkerTemplate.WorkerTemplate.Spec.InitContainers { + if initContainer.Image != data.InitContainerImage { + return fmt.Errorf("unexpected init container image %s in worker template, want %s", initContainer.Image, data.InitContainerImage) + } + } + } + return nil }