Skip to content

Commit 8ea73a9

Browse files
committed
feat: support configuring init container image
Signed-off-by: rudeigerc <rudeigerc@gmail.com>
1 parent ea979b8 commit 8ea73a9

7 files changed

Lines changed: 74 additions & 39 deletions

File tree

pkg/controller/inference/service_controller.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,11 @@ func (r *ServiceReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct
116116
return ctrl.Result{}, err
117117
}
118118

119-
workloadApplyConfiguration, err := buildWorkloadApplyConfiguration(service, models)
119+
workloadApplyConfiguration, err := buildWorkloadApplyConfiguration(service, models, configs)
120120
if err != nil {
121121
return ctrl.Result{}, err
122122
}
123+
123124
if err := setControllerReferenceForWorkload(service, workloadApplyConfiguration, r.Scheme); err != nil {
124125
return ctrl.Result{}, err
125126
}
@@ -162,7 +163,7 @@ func (r *ServiceReconciler) SetupWithManager(mgr ctrl.Manager) error {
162163
Complete(r)
163164
}
164165

165-
func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*coreapi.OpenModel) (*applyconfigurationv1.LeaderWorkerSetApplyConfiguration, error) {
166+
func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*coreapi.OpenModel, configs *helper.GlobalConfigs) (*applyconfigurationv1.LeaderWorkerSetApplyConfiguration, error) {
166167
workload := applyconfigurationv1.LeaderWorkerSet(service.Name, service.Namespace)
167168

168169
leaderWorkerTemplate := applyconfigurationv1.LeaderWorkerTemplate()
@@ -193,7 +194,7 @@ func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*co
193194
leaderWorkerTemplate.WithWorkerTemplate(&podTemplateSpecApplyConfiguration)
194195

195196
// The core logic to inject additional configurations.
196-
injectModelProperties(leaderWorkerTemplate, models, service)
197+
injectModelProperties(leaderWorkerTemplate, models, service, configs)
197198

198199
spec := applyconfigurationv1.LeaderWorkerSetSpec()
199200
spec.WithLeaderWorkerTemplate(leaderWorkerTemplate)
@@ -215,17 +216,17 @@ func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*co
215216
return workload, nil
216217
}
217218

218-
func injectModelProperties(template *applyconfigurationv1.LeaderWorkerTemplateApplyConfiguration, models []*coreapi.OpenModel, service *inferenceapi.Service) {
219+
func injectModelProperties(template *applyconfigurationv1.LeaderWorkerTemplateApplyConfiguration, models []*coreapi.OpenModel, service *inferenceapi.Service, configs *helper.GlobalConfigs) {
219220
isMultiNodesInference := template.LeaderTemplate != nil
220221

221222
for i, model := range models {
222223
source := modelSource.NewModelSourceProvider(model)
223224
// Skip model-loader initContainer if llmaz.io/skip-model-loader annotation is set.
224225
if !helper.SkipModelLoader(service) {
225226
if isMultiNodesInference {
226-
source.InjectModelLoader(template.LeaderTemplate, i)
227+
source.InjectModelLoader(template.LeaderTemplate, i, configs.InitContainerImage)
227228
}
228-
source.InjectModelLoader(template.WorkerTemplate, i)
229+
source.InjectModelLoader(template.WorkerTemplate, i, configs.InitContainerImage)
229230
} else {
230231
if isMultiNodesInference {
231232
source.InjectModelEnvVars(template.LeaderTemplate)

pkg/controller_helper/modelsource/modelhub.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,20 @@ func (p *ModelHubProvider) ModelPath(skipModelLoader bool) string {
6262
return CONTAINER_MODEL_PATH + "models--" + strings.ReplaceAll(p.modelID, "/", "--")
6363
}
6464

65-
func (p *ModelHubProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApplyConfiguration, index int) {
65+
func (p *ModelHubProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApplyConfiguration, index int, initContainerImage string) {
6666
initContainerName := MODEL_LOADER_CONTAINER_NAME
6767
if index != 0 {
6868
initContainerName += "-" + strconv.Itoa(index)
6969
}
7070

71+
if initContainerImage == "" {
72+
initContainerImage = pkg.LOADER_IMAGE
73+
}
74+
7175
// Handle initContainer.
7276
initContainer := coreapplyv1.Container().
7377
WithName(initContainerName).
74-
WithImage(pkg.LOADER_IMAGE).
78+
WithImage(initContainerImage).
7579
WithVolumeMounts(coreapplyv1.VolumeMount().WithName(MODEL_VOLUME_NAME).WithMountPath(CONTAINER_MODEL_PATH))
7680

7781
// We have exactly one container in the template.Spec.Containers.

pkg/controller_helper/modelsource/modelhub_test.go

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,11 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
3838
ignorePatterns := []string{"*.tmp"}
3939

4040
tests := []struct {
41-
name string
42-
provider *ModelHubProvider
43-
index int
44-
expectMainModel bool
41+
name string
42+
provider *ModelHubProvider
43+
index int
44+
expectMainModel bool
45+
initContainerImage string
4546
}{
4647
{
4748
name: "inject full modelhub with fileName, revision, allow/ignore",
@@ -67,6 +68,17 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
6768
index: 1,
6869
expectMainModel: false,
6970
},
71+
{
72+
name: "inject with custom initContainerImage",
73+
provider: &ModelHubProvider{
74+
modelName: "llama3",
75+
modelID: "meta/llama-3",
76+
modelHub: "Huggingface",
77+
},
78+
index: 0,
79+
expectMainModel: true,
80+
initContainerImage: "custom-loader-image:latest",
81+
},
7082
}
7183

7284
envSortOpt := cmpopts.SortSlices(func(a, b corev1.EnvVar) bool {
@@ -83,7 +95,7 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
8395
),
8496
)
8597

86-
tt.provider.InjectModelLoader(template, tt.index)
98+
tt.provider.InjectModelLoader(template, tt.index, tt.initContainerImage)
8799

88100
assert.Len(t, template.Spec.InitContainers, 1)
89101
initContainer := template.Spec.InitContainers[0]
@@ -92,8 +104,12 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
92104
if tt.index != 0 {
93105
expectedName += "-" + strconv.Itoa(tt.index)
94106
}
107+
expectedImage := tt.initContainerImage
108+
if expectedImage == "" {
109+
expectedImage = pkg.LOADER_IMAGE
110+
}
95111
assert.Equal(t, expectedName, *initContainer.Name)
96-
assert.Equal(t, pkg.LOADER_IMAGE, *initContainer.Image)
112+
assert.Equal(t, expectedImage, *initContainer.Image)
97113

98114
wantEnv := buildExpectedEnv(tt.provider)
99115
if diff := cmp.Diff(wantEnv, initContainer.Env, envSortOpt); diff != "" {

pkg/controller_helper/modelsource/modelsource.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ type ModelSourceProvider interface {
5959
ModelPath(skipModelLoader bool) string
6060
// InjectModelLoader will inject the model loader to the spec,
6161
// index refers to the suffix of the initContainer name, like model-loader, model-loader-1.
62-
InjectModelLoader(spec *coreapplyv1.PodTemplateSpecApplyConfiguration, index int)
62+
InjectModelLoader(spec *coreapplyv1.PodTemplateSpecApplyConfiguration, index int, initContainerImage string)
6363
// InjectModelEnvVars will inject the model credentials env to the model-runner container.
6464
// This is used when the model-loader initContainer is not injected, and the model loading is handled by the model-runner container.
6565
InjectModelEnvVars(spec *coreapplyv1.PodTemplateSpecApplyConfiguration)

pkg/controller_helper/modelsource/modelsource_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
coreapplyv1 "k8s.io/client-go/applyconfigurations/core/v1"
2424

2525
coreapi "github.com/inftyai/llmaz/api/core/v1alpha1"
26+
"github.com/inftyai/llmaz/pkg"
2627
"github.com/inftyai/llmaz/test/util"
2728
"github.com/inftyai/llmaz/test/util/wrapper"
2829
)
@@ -130,7 +131,7 @@ func TestEnvInjectModelLoader(t *testing.T) {
130131

131132
for _, tt := range tests {
132133
t.Run(tt.name, func(t *testing.T) {
133-
tt.provider.InjectModelLoader(tt.template, 0)
134+
tt.provider.InjectModelLoader(tt.template, 0, pkg.LOADER_IMAGE)
134135
initContainer := tt.template.Spec.InitContainers[0]
135136
assert.Subset(t, initContainer.Env, tt.template.Spec.Containers[0].Env)
136137
})

pkg/controller_helper/modelsource/uri.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func (p *URIProvider) ModelPath(skipModelLoader bool) string {
8080
return CONTAINER_MODEL_PATH + "models--" + splits[len(splits)-1]
8181
}
8282

83-
func (p *URIProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApplyConfiguration, index int) {
83+
func (p *URIProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApplyConfiguration, index int, initContainerImage string) {
8484
// We don't have additional operations for Ollama, just load in runtime.
8585
if p.protocol == Ollama {
8686
return
@@ -112,10 +112,16 @@ func (p *URIProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApp
112112
if index != 0 {
113113
initContainerName += "-" + strconv.Itoa(index)
114114
}
115+
116+
// Handle the image of initContainer.
117+
if initContainerImage == "" {
118+
initContainerImage = pkg.LOADER_IMAGE
119+
}
120+
115121
// Handle initContainer.
116122
initContainer := coreapplyv1.Container().
117123
WithName(initContainerName).
118-
WithImage(pkg.LOADER_IMAGE).
124+
WithImage(initContainerImage).
119125
WithVolumeMounts(
120126
coreapplyv1.VolumeMount().
121127
WithName(MODEL_VOLUME_NAME).

test/util/validation/validate_service.go

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,22 @@ func ValidateService(ctx context.Context, k8sClient client.Client, service *infe
7272
models = append(models, model)
7373
}
7474

75+
// Fetch global configuration.
76+
cm := corev1.ConfigMap{}
77+
if err := k8sClient.Get(ctx, types.NamespacedName{Name: "llmaz-global-config", Namespace: "llmaz-system"}, &cm); err != nil {
78+
return err
79+
}
80+
81+
data, err := helper.ParseGlobalConfigmap(&cm)
82+
if err != nil {
83+
return fmt.Errorf("failed to parse global configmap: %v", err)
84+
}
85+
86+
initContainerImage := data.InitContainerImage
87+
if initContainerImage == "" {
88+
initContainerImage = pkg.LOADER_IMAGE
89+
}
90+
7591
for index, model := range models {
7692
if helper.SkipModelLoader(service) {
7793
if service.Spec.WorkloadTemplate.LeaderTemplate != nil {
@@ -85,11 +101,11 @@ func ValidateService(ctx context.Context, k8sClient client.Client, service *infe
85101
} else {
86102
// Validate injecting modelLoaders
87103
if service.Spec.WorkloadTemplate.LeaderTemplate != nil {
88-
if err := ValidateModelLoader(model, index, *workload.Spec.LeaderWorkerTemplate.LeaderTemplate, service); err != nil {
104+
if err := ValidateModelLoader(model, index, *workload.Spec.LeaderWorkerTemplate.LeaderTemplate, service, initContainerImage); err != nil {
89105
return err
90106
}
91107
}
92-
if err := ValidateModelLoader(model, index, workload.Spec.LeaderWorkerTemplate.WorkerTemplate, service); err != nil {
108+
if err := ValidateModelLoader(model, index, workload.Spec.LeaderWorkerTemplate.WorkerTemplate, service, initContainerImage); err != nil {
93109
return err
94110
}
95111
}
@@ -114,15 +130,15 @@ func ValidateService(ctx context.Context, k8sClient client.Client, service *infe
114130
return err
115131
}
116132

117-
if err := ValidateConfigmap(ctx, k8sClient, service); err != nil {
133+
if err := ValidateSchedulerName(data.SchedulerName, service); err != nil {
118134
return err
119135
}
120136

121137
return nil
122138
}, util.IntegrationTimeout, util.Interval).Should(gomega.Succeed())
123139
}
124140

125-
func ValidateModelLoader(model *coreapi.OpenModel, index int, template corev1.PodTemplateSpec, service *inferenceapi.Service) error {
141+
func ValidateModelLoader(model *coreapi.OpenModel, index int, template corev1.PodTemplateSpec, service *inferenceapi.Service, initContainerImage string) error {
126142
if model.Spec.Source.URI != nil {
127143
protocol, _, _ := pkgUtil.ParseURI(string(*model.Spec.Source.URI))
128144
if protocol == modelSource.Ollama {
@@ -143,8 +159,9 @@ func ValidateModelLoader(model *coreapi.OpenModel, index int, template corev1.Po
143159
if initContainer.Name != containerName {
144160
return fmt.Errorf("unexpected initContainer name, want %s, got %s", modelSource.MODEL_LOADER_CONTAINER_NAME, initContainer.Name)
145161
}
146-
if initContainer.Image != pkg.LOADER_IMAGE {
147-
return fmt.Errorf("unexpected initContainer image, want %s, got %s", pkg.LOADER_IMAGE, initContainer.Image)
162+
163+
if initContainer.Image != initContainerImage {
164+
return fmt.Errorf("unexpected initContainer image, want %s, got %s", initContainerImage, initContainer.Image)
148165
}
149166

150167
var envStrings []string
@@ -437,25 +454,15 @@ func CheckServiceAvaliable() error {
437454
return nil
438455
}
439456

440-
func ValidateConfigmap(ctx context.Context, k8sClient client.Client, service *inferenceapi.Service) error {
441-
cm := corev1.ConfigMap{}
442-
if err := k8sClient.Get(ctx, types.NamespacedName{Name: "llmaz-global-config", Namespace: "llmaz-system"}, &cm); err != nil {
443-
return err
444-
}
445-
446-
data, err := helper.ParseGlobalConfigmap(&cm)
447-
if err != nil {
448-
return fmt.Errorf("failed to parse global configmap: %v", err)
449-
}
450-
457+
func ValidateSchedulerName(schedulerName string, service *inferenceapi.Service) error {
451458
if service.Spec.WorkloadTemplate.LeaderTemplate != nil {
452-
if service.Spec.WorkloadTemplate.LeaderTemplate.Spec.SchedulerName != data.SchedulerName {
453-
return fmt.Errorf("unexpected scheduler name %s, want %s", service.Spec.WorkloadTemplate.LeaderTemplate.Spec.SchedulerName, data.SchedulerName)
459+
if service.Spec.WorkloadTemplate.LeaderTemplate.Spec.SchedulerName != schedulerName {
460+
return fmt.Errorf("unexpected scheduler name %s, want %s", service.Spec.WorkloadTemplate.LeaderTemplate.Spec.SchedulerName, schedulerName)
454461
}
455462
}
456463

457-
if service.Spec.WorkloadTemplate.WorkerTemplate.Spec.SchedulerName != data.SchedulerName {
458-
return fmt.Errorf("unexpected scheduler name %s, want %s", service.Spec.WorkloadTemplate.WorkerTemplate.Spec.SchedulerName, data.SchedulerName)
464+
if service.Spec.WorkloadTemplate.WorkerTemplate.Spec.SchedulerName != schedulerName {
465+
return fmt.Errorf("unexpected scheduler name %s, want %s", service.Spec.WorkloadTemplate.WorkerTemplate.Spec.SchedulerName, schedulerName)
459466
}
460467

461468
return nil

0 commit comments

Comments
 (0)