Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion config/default/configmap.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ metadata:
data:
config.data: |
scheduler-name: default-scheduler
# init-container-image: inftyai/model-loader:v0.0.10
init-container-image: inftyai/model-loader:v0.0.10
13 changes: 7 additions & 6 deletions pkg/controller/inference/service_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,11 @@ func (r *ServiceReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct
return ctrl.Result{}, err
}

workloadApplyConfiguration, err := buildWorkloadApplyConfiguration(service, models)
workloadApplyConfiguration, err := buildWorkloadApplyConfiguration(service, models, configs)
if err != nil {
return ctrl.Result{}, err
}

if err := setControllerReferenceForWorkload(service, workloadApplyConfiguration, r.Scheme); err != nil {
return ctrl.Result{}, err
}
Expand Down Expand Up @@ -162,7 +163,7 @@ func (r *ServiceReconciler) SetupWithManager(mgr ctrl.Manager) error {
Complete(r)
}

func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*coreapi.OpenModel) (*applyconfigurationv1.LeaderWorkerSetApplyConfiguration, error) {
func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*coreapi.OpenModel, configs *helper.GlobalConfigs) (*applyconfigurationv1.LeaderWorkerSetApplyConfiguration, error) {
workload := applyconfigurationv1.LeaderWorkerSet(service.Name, service.Namespace)

leaderWorkerTemplate := applyconfigurationv1.LeaderWorkerTemplate()
Expand Down Expand Up @@ -193,7 +194,7 @@ func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*co
leaderWorkerTemplate.WithWorkerTemplate(&podTemplateSpecApplyConfiguration)

// The core logic to inject additional configurations.
injectModelProperties(leaderWorkerTemplate, models, service)
injectModelProperties(leaderWorkerTemplate, models, service, configs)

spec := applyconfigurationv1.LeaderWorkerSetSpec()
spec.WithLeaderWorkerTemplate(leaderWorkerTemplate)
Expand All @@ -215,17 +216,17 @@ func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*co
return workload, nil
}

func injectModelProperties(template *applyconfigurationv1.LeaderWorkerTemplateApplyConfiguration, models []*coreapi.OpenModel, service *inferenceapi.Service) {
func injectModelProperties(template *applyconfigurationv1.LeaderWorkerTemplateApplyConfiguration, models []*coreapi.OpenModel, service *inferenceapi.Service, configs *helper.GlobalConfigs) {
isMultiNodesInference := template.LeaderTemplate != nil

for i, model := range models {
source := modelSource.NewModelSourceProvider(model)
// Skip model-loader initContainer if llmaz.io/skip-model-loader annotation is set.
if !helper.SkipModelLoader(service) {
if isMultiNodesInference {
source.InjectModelLoader(template.LeaderTemplate, i)
source.InjectModelLoader(template.LeaderTemplate, i, configs.InitContainerImage)
}
source.InjectModelLoader(template.WorkerTemplate, i)
source.InjectModelLoader(template.WorkerTemplate, i, configs.InitContainerImage)
} else {
if isMultiNodesInference {
source.InjectModelEnvVars(template.LeaderTemplate)
Expand Down
11 changes: 11 additions & 0 deletions pkg/controller_helper/configmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,16 @@ func ParseGlobalConfigmap(cm *corev1.ConfigMap) (*GlobalConfigs, error) {
return nil, fmt.Errorf("failed to unmarshal config.data: %v", err)
}

if err := configs.validate(); err != nil {
return nil, fmt.Errorf("invalid global config: %v", err)
}

return &configs, nil
}

func (c *GlobalConfigs) validate() error {
if c.InitContainerImage == "" {
return fmt.Errorf("init-container-image is required")
}
return nil
}
64 changes: 64 additions & 0 deletions pkg/controller_helper/configmap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
Copyright 2025 The InftyAI Team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package helper

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGlobalConfigs_validate(t *testing.T) {
tests := []struct {
name string
config *GlobalConfigs
expectError bool
errorMsg string
}{
{
name: "valid config",
config: &GlobalConfigs{
SchedulerName: "custom-scheduler",
InitContainerImage: "inftyai/model-loader:v0.0.10",
},
expectError: false,
},
{
name: "empty init container image",
config: &GlobalConfigs{
SchedulerName: "custom-scheduler",
InitContainerImage: "",
},
expectError: true,
errorMsg: "init-container-image is required",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.validate()

if tt.expectError {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
} else {
require.NoError(t, err)
}
})
}
}
6 changes: 2 additions & 4 deletions pkg/controller_helper/modelsource/modelhub.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ import (
"strings"

coreapplyv1 "k8s.io/client-go/applyconfigurations/core/v1"

"github.com/inftyai/llmaz/pkg"
)

var _ ModelSourceProvider = &ModelHubProvider{}
Expand Down Expand Up @@ -62,7 +60,7 @@ func (p *ModelHubProvider) ModelPath(skipModelLoader bool) string {
return CONTAINER_MODEL_PATH + "models--" + strings.ReplaceAll(p.modelID, "/", "--")
}

func (p *ModelHubProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApplyConfiguration, index int) {
func (p *ModelHubProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApplyConfiguration, index int, initContainerImage string) {
initContainerName := MODEL_LOADER_CONTAINER_NAME
if index != 0 {
initContainerName += "-" + strconv.Itoa(index)
Expand All @@ -71,7 +69,7 @@ func (p *ModelHubProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSp
// Handle initContainer.
initContainer := coreapplyv1.Container().
WithName(initContainerName).
WithImage(pkg.LOADER_IMAGE).
WithImage(initContainerImage).
WithVolumeMounts(coreapplyv1.VolumeMount().WithName(MODEL_VOLUME_NAME).WithMountPath(CONTAINER_MODEL_PATH))

// We have exactly one container in the template.Spec.Containers.
Expand Down
37 changes: 25 additions & 12 deletions pkg/controller_helper/modelsource/modelhub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ import (
"github.com/stretchr/testify/assert"
corev1 "k8s.io/api/core/v1"
coreapplyv1 "k8s.io/client-go/applyconfigurations/core/v1"

"github.com/inftyai/llmaz/pkg"
)

func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
Expand All @@ -38,10 +36,11 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
ignorePatterns := []string{"*.tmp"}

tests := []struct {
name string
provider *ModelHubProvider
index int
expectMainModel bool
name string
provider *ModelHubProvider
index int
expectMainModel bool
initContainerImage string
}{
{
name: "inject full modelhub with fileName, revision, allow/ignore",
Expand All @@ -54,8 +53,9 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
modelAllowPatterns: allowPatterns,
modelIgnorePatterns: ignorePatterns,
},
index: 0,
expectMainModel: true,
index: 0,
expectMainModel: true,
initContainerImage: "model-loader:latest",
},
{
name: "inject with index > 0 skips volume/container mount",
Expand All @@ -64,8 +64,20 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
modelID: "some/model",
modelHub: "Huggingface",
},
index: 1,
expectMainModel: false,
index: 1,
expectMainModel: false,
initContainerImage: "model-loader:latest",
},
{
name: "inject with custom initContainerImage",
provider: &ModelHubProvider{
modelName: "llama3",
modelID: "meta/llama-3",
modelHub: "Huggingface",
},
index: 0,
expectMainModel: true,
initContainerImage: "custom-model-loader:latest",
},
}

Expand All @@ -83,7 +95,7 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
),
)

tt.provider.InjectModelLoader(template, tt.index)
tt.provider.InjectModelLoader(template, tt.index, tt.initContainerImage)

assert.Len(t, template.Spec.InitContainers, 1)
initContainer := template.Spec.InitContainers[0]
Expand All @@ -92,8 +104,9 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
if tt.index != 0 {
expectedName += "-" + strconv.Itoa(tt.index)
}
expectedImage := tt.initContainerImage
assert.Equal(t, expectedName, *initContainer.Name)
assert.Equal(t, pkg.LOADER_IMAGE, *initContainer.Image)
assert.Equal(t, expectedImage, *initContainer.Image)

wantEnv := buildExpectedEnv(tt.provider)
if diff := cmp.Diff(wantEnv, initContainer.Env, envSortOpt); diff != "" {
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller_helper/modelsource/modelsource.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ type ModelSourceProvider interface {
ModelPath(skipModelLoader bool) string
// InjectModelLoader will inject the model loader to the spec,
// index refers to the suffix of the initContainer name, like model-loader, model-loader-1.
InjectModelLoader(spec *coreapplyv1.PodTemplateSpecApplyConfiguration, index int)
InjectModelLoader(spec *coreapplyv1.PodTemplateSpecApplyConfiguration, index int, initContainerImage string)
// InjectModelEnvVars will inject the model credentials env to the model-runner container.
// This is used when the model-loader initContainer is not injected, and the model loading is handled by the model-runner container.
InjectModelEnvVars(spec *coreapplyv1.PodTemplateSpecApplyConfiguration)
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller_helper/modelsource/modelsource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func TestEnvInjectModelLoader(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.provider.InjectModelLoader(tt.template, 0)
tt.provider.InjectModelLoader(tt.template, 0, "model-loader:latest")
initContainer := tt.template.Spec.InitContainers[0]
assert.Subset(t, initContainer.Env, tt.template.Spec.Containers[0].Env)
})
Expand Down
7 changes: 3 additions & 4 deletions pkg/controller_helper/modelsource/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ import (
"strings"

coreapplyv1 "k8s.io/client-go/applyconfigurations/core/v1"

"github.com/inftyai/llmaz/pkg"
)

var _ ModelSourceProvider = &URIProvider{}
Expand Down Expand Up @@ -80,7 +78,7 @@ func (p *URIProvider) ModelPath(skipModelLoader bool) string {
return CONTAINER_MODEL_PATH + "models--" + splits[len(splits)-1]
}

func (p *URIProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApplyConfiguration, index int) {
func (p *URIProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApplyConfiguration, index int, initContainerImage string) {
// We don't have additional operations for Ollama, just load in runtime.
if p.protocol == Ollama {
return
Expand Down Expand Up @@ -112,10 +110,11 @@ func (p *URIProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApp
if index != 0 {
initContainerName += "-" + strconv.Itoa(index)
}

// Handle initContainer.
initContainer := coreapplyv1.Container().
WithName(initContainerName).
WithImage(pkg.LOADER_IMAGE).
WithImage(initContainerImage).
WithVolumeMounts(
coreapplyv1.VolumeMount().
WithName(MODEL_VOLUME_NAME).
Expand Down
21 changes: 0 additions & 21 deletions pkg/defaults.go

This file was deleted.

27 changes: 22 additions & 5 deletions test/util/validation/validate_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ import (

coreapi "github.com/inftyai/llmaz/api/core/v1alpha1"
inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1"
"github.com/inftyai/llmaz/pkg"
helper "github.com/inftyai/llmaz/pkg/controller_helper"
modelSource "github.com/inftyai/llmaz/pkg/controller_helper/modelsource"
pkgUtil "github.com/inftyai/llmaz/pkg/util"
Expand Down Expand Up @@ -114,7 +113,7 @@ func ValidateService(ctx context.Context, k8sClient client.Client, service *infe
return err
}

if err := ValidateConfigmap(ctx, k8sClient, service); err != nil {
if err := ValidateConfigmap(ctx, k8sClient, service, &workload); err != nil {
return err
}

Expand Down Expand Up @@ -143,8 +142,8 @@ func ValidateModelLoader(model *coreapi.OpenModel, index int, template corev1.Po
if initContainer.Name != containerName {
return fmt.Errorf("unexpected initContainer name, want %s, got %s", modelSource.MODEL_LOADER_CONTAINER_NAME, initContainer.Name)
}
if initContainer.Image != pkg.LOADER_IMAGE {
return fmt.Errorf("unexpected initContainer image, want %s, got %s", pkg.LOADER_IMAGE, initContainer.Image)
if initContainer.Image == "" {
return fmt.Errorf("unexpected initContainer image, initContainer image should not be empty")
}

var envStrings []string
Expand Down Expand Up @@ -440,7 +439,7 @@ func CheckServiceAvaliable() error {
return nil
}

func ValidateConfigmap(ctx context.Context, k8sClient client.Client, service *inferenceapi.Service) error {
func ValidateConfigmap(ctx context.Context, k8sClient client.Client, service *inferenceapi.Service, workload *lws.LeaderWorkerSet) error {
cm := corev1.ConfigMap{}
if err := k8sClient.Get(ctx, types.NamespacedName{Name: "llmaz-global-config", Namespace: "llmaz-system"}, &cm); err != nil {
return err
Expand All @@ -451,6 +450,7 @@ func ValidateConfigmap(ctx context.Context, k8sClient client.Client, service *in
return fmt.Errorf("failed to parse global configmap: %v", err)
}

// Validate scheduler name.
if service.Spec.WorkloadTemplate.LeaderTemplate != nil {
if service.Spec.WorkloadTemplate.LeaderTemplate.Spec.SchedulerName != data.SchedulerName {
return fmt.Errorf("unexpected scheduler name %s, want %s", service.Spec.WorkloadTemplate.LeaderTemplate.Spec.SchedulerName, data.SchedulerName)
Expand All @@ -461,5 +461,22 @@ func ValidateConfigmap(ctx context.Context, k8sClient client.Client, service *in
return fmt.Errorf("unexpected scheduler name %s, want %s", service.Spec.WorkloadTemplate.WorkerTemplate.Spec.SchedulerName, data.SchedulerName)
}

if !helper.SkipModelLoader(service) {
// Validate init container image.
if service.Spec.WorkloadTemplate.LeaderTemplate != nil {
for _, initContainer := range workload.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.InitContainers {
if initContainer.Image != data.InitContainerImage {
return fmt.Errorf("unexpected init container image %s in leader template, want %s", initContainer.Image, data.InitContainerImage)
}
}
}

for _, initContainer := range workload.Spec.LeaderWorkerTemplate.WorkerTemplate.Spec.InitContainers {
if initContainer.Image != data.InitContainerImage {
return fmt.Errorf("unexpected init container image %s in worker template, want %s", initContainer.Image, data.InitContainerImage)
}
}
}

return nil
}
Loading