@@ -18,10 +18,15 @@ limitations under the License.
1818package modelSource
1919
2020import (
21+ "strconv"
22+ "strings"
2123 "testing"
2224
25+ "github.com/google/go-cmp/cmp"
26+ "github.com/google/go-cmp/cmp/cmpopts"
2327 "github.com/stretchr/testify/assert"
2428 corev1 "k8s.io/api/core/v1"
29+ "k8s.io/utils/ptr"
2530
2631 "github.com/inftyai/llmaz/pkg"
2732)
@@ -32,12 +37,11 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
3237 allowPatterns := []string {"*.gguf" , "*.json" }
3338 ignorePatterns := []string {"*.tmp" }
3439
35- testCases := []struct {
36- name string
37- provider * ModelHubProvider
38- index int
39- expectMainModel bool
40- expectEnvContains []string
40+ tests := []struct {
41+ name string
42+ provider * ModelHubProvider
43+ index int
44+ expectMainModel bool
4145 }{
4246 {
4347 name : "inject full modelhub with fileName, revision, allow/ignore" ,
@@ -52,11 +56,6 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
5256 },
5357 index : 0 ,
5458 expectMainModel : true ,
55- expectEnvContains : []string {
56- "MODEL_SOURCE_TYPE" , "MODEL_ID" , "MODEL_HUB_NAME" , "MODEL_FILENAME" ,
57- "REVISION" , "MODEL_ALLOW_PATTERNS" , "MODEL_IGNORE_PATTERNS" ,
58- "HUGGING_FACE_HUB_TOKEN" , "HF_TOKEN" ,
59- },
6059 },
6160 {
6261 name : "inject with index > 0 skips volume/container mount" ,
@@ -67,15 +66,16 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
6766 },
6867 index : 1 ,
6968 expectMainModel : false ,
70- expectEnvContains : []string {
71- "MODEL_SOURCE_TYPE" , "MODEL_ID" , "MODEL_HUB_NAME" ,
72- "HUGGING_FACE_HUB_TOKEN" , "HF_TOKEN" ,
73- },
7469 },
7570 }
7671
77- for _ , tc := range testCases {
78- t .Run (tc .name , func (t * testing.T ) {
72+ envSortOpt := cmpopts .SortSlices (func (a , b corev1.EnvVar ) bool {
73+ return a .Name < b .Name
74+ })
75+
76+ for _ , tt := range tests {
77+ // tt := tt
78+ t .Run (tt .name , func (t * testing.T ) {
7979 template := & corev1.PodTemplateSpec {
8080 Spec : corev1.PodSpec {
8181 Containers : []corev1.Container {
@@ -89,57 +89,94 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
8989 },
9090 }
9191
92- tc .provider .InjectModelLoader (template , tc .index )
92+ tt .provider .InjectModelLoader (template , tt .index )
9393
9494 assert .Len (t , template .Spec .InitContainers , 1 )
9595 initContainer := template .Spec .InitContainers [0 ]
96+
9697 expectedName := MODEL_LOADER_CONTAINER_NAME
97- if tc .index != 0 {
98- expectedName += "-" + string ( rune ( '0' + tc .index ) )
98+ if tt .index != 0 {
99+ expectedName += "-" + strconv . Itoa ( tt .index )
99100 }
100101 assert .Equal (t , expectedName , initContainer .Name )
101102 assert .Equal (t , pkg .LOADER_IMAGE , initContainer .Image )
102103
103- // Check env vars exist
104- for _ , key := range tc .expectEnvContains {
105- found := false
106- for _ , env := range initContainer .Env {
107- if env .Name == key {
108- found = true
109- break
110- }
111- }
112- assert .True (t , found , "expected env %s not found" , key )
104+ wantEnv := buildExpectedEnv (tt .provider )
105+ if diff := cmp .Diff (wantEnv , initContainer .Env , envSortOpt ); diff != "" {
106+ t .Errorf ("InitContainer.Env mismatch (-want +got):\n %s" , diff )
113107 }
114108
115- // Main model should inject volume & container mount
116- if tc .expectMainModel {
117- // Volume should be present
118- foundVol := false
119- for _ , v := range template .Spec .Volumes {
120- if v .Name == MODEL_VOLUME_NAME {
121- foundVol = true
122- break
123- }
124- }
125- assert .True (t , foundVol , "volume not injected" )
126-
127- // Runner container mount should exist
128- foundMount := false
129- for _ , m := range template .Spec .Containers [0 ].VolumeMounts {
130- if m .Name == MODEL_VOLUME_NAME && m .ReadOnly && m .MountPath == CONTAINER_MODEL_PATH {
131- foundMount = true
132- }
133- }
134- assert .True (t , foundMount , "volume mount not injected to runner" )
109+ if tt .expectMainModel {
110+ assert .True (t , hasVolume (template .Spec .Volumes , MODEL_VOLUME_NAME ), "model volume missing" )
111+ assert .True (t , hasMount (template .Spec .Containers [0 ].VolumeMounts , MODEL_VOLUME_NAME ), "runner volumeMount missing" )
135112 } else {
136- // No volumes or mounts should be injected
137- assert .Empty (t , template .Spec .Volumes )
138- assert .Empty (t , template .Spec .Containers [0 ].VolumeMounts )
113+ assert .Empty (t , template .Spec .Volumes , "unexpected volumes for sub-model" )
114+ assert .Empty (t , template .Spec .Containers [0 ].VolumeMounts , "unexpected mounts for sub-model" )
139115 }
116+ })
117+ }
118+ }
119+
120+ func buildExpectedEnv (p * ModelHubProvider ) []corev1.EnvVar {
121+ var envs []corev1.EnvVar
122+
123+ envs = append (envs , corev1.EnvVar {Name : "HTTP_PROXY" , Value : "http://1.1.1.1" })
124+
125+ envs = append (envs ,
126+ corev1.EnvVar {Name : "MODEL_SOURCE_TYPE" , Value : MODEL_SOURCE_MODELHUB },
127+ corev1.EnvVar {Name : "MODEL_ID" , Value : p .modelID },
128+ corev1.EnvVar {Name : "MODEL_HUB_NAME" , Value : p .modelHub },
129+ )
130+
131+ if p .fileName != nil {
132+ envs = append (envs , corev1.EnvVar {Name : "MODEL_FILENAME" , Value : * p .fileName })
133+ }
134+ if p .modelRevision != nil {
135+ envs = append (envs , corev1.EnvVar {Name : "REVISION" , Value : * p .modelRevision })
136+ }
137+ if p .modelAllowPatterns != nil {
138+ envs = append (envs , corev1.EnvVar {
139+ Name : "MODEL_ALLOW_PATTERNS" ,
140+ Value : strings .Join (p .modelAllowPatterns , "," ),
141+ })
142+ }
143+ if p .modelIgnorePatterns != nil {
144+ envs = append (envs , corev1.EnvVar {
145+ Name : "MODEL_IGNORE_PATTERNS" ,
146+ Value : strings .Join (p .modelIgnorePatterns , "," ),
147+ })
148+ }
140149
141- // Should always carry over container env
142- assert .Contains (t , initContainer .Env , corev1.EnvVar {Name : "HTTP_PROXY" , Value : "http://1.1.1.1" })
150+ for _ , tokenName := range []string {"HUGGING_FACE_HUB_TOKEN" , "HF_TOKEN" } {
151+ envs = append (envs , corev1.EnvVar {
152+ Name : tokenName ,
153+ ValueFrom : & corev1.EnvVarSource {
154+ SecretKeyRef : & corev1.SecretKeySelector {
155+ LocalObjectReference : corev1.LocalObjectReference {Name : MODELHUB_SECRET_NAME },
156+ Key : HUGGINGFACE_TOKEN_KEY ,
157+ Optional : ptr .To (true ),
158+ },
159+ },
143160 })
144161 }
162+
163+ return envs
164+ }
165+
166+ func hasVolume (vols []corev1.Volume , name string ) bool {
167+ for _ , v := range vols {
168+ if v .Name == name {
169+ return true
170+ }
171+ }
172+ return false
173+ }
174+
175+ func hasMount (mounts []corev1.VolumeMount , name string ) bool {
176+ for _ , m := range mounts {
177+ if m .Name == name && m .ReadOnly && m .MountPath == CONTAINER_MODEL_PATH {
178+ return true
179+ }
180+ }
181+ return false
145182}
0 commit comments