@@ -18,10 +18,15 @@ limitations under the License.
1818package modelSource
1919
2020import (
21+ "strconv"
22+ "strings"
2123 "testing"
2224
25+ "github.com/google/go-cmp/cmp"
26+ "github.com/google/go-cmp/cmp/cmpopts"
2327 "github.com/stretchr/testify/assert"
2428 corev1 "k8s.io/api/core/v1"
29+ "k8s.io/utils/ptr"
2530
2631 "github.com/inftyai/llmaz/pkg"
2732)
@@ -32,12 +37,11 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
3237 allowPatterns := []string {"*.gguf" , "*.json" }
3338 ignorePatterns := []string {"*.tmp" }
3439
35- testCases := []struct {
36- name string
37- provider * ModelHubProvider
38- index int
39- expectMainModel bool
40- expectEnvContains []string
40+ tests := []struct {
41+ name string
42+ provider * ModelHubProvider
43+ index int
44+ expectMainModel bool
4145 }{
4246 {
4347 name : "inject full modelhub with fileName, revision, allow/ignore" ,
@@ -52,11 +56,6 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
5256 },
5357 index : 0 ,
5458 expectMainModel : true ,
55- expectEnvContains : []string {
56- "MODEL_SOURCE_TYPE" , "MODEL_ID" , "MODEL_HUB_NAME" , "MODEL_FILENAME" ,
57- "REVISION" , "MODEL_ALLOW_PATTERNS" , "MODEL_IGNORE_PATTERNS" ,
58- "HUGGING_FACE_HUB_TOKEN" , "HF_TOKEN" ,
59- },
6059 },
6160 {
6261 name : "inject with index > 0 skips volume/container mount" ,
@@ -67,15 +66,15 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
6766 },
6867 index : 1 ,
6968 expectMainModel : false ,
70- expectEnvContains : []string {
71- "MODEL_SOURCE_TYPE" , "MODEL_ID" , "MODEL_HUB_NAME" ,
72- "HUGGING_FACE_HUB_TOKEN" , "HF_TOKEN" ,
73- },
7469 },
7570 }
7671
77- for _ , tc := range testCases {
78- t .Run (tc .name , func (t * testing.T ) {
72+ envSortOpt := cmpopts .SortSlices (func (a , b corev1.EnvVar ) bool {
73+ return a .Name < b .Name
74+ })
75+
76+ for _ , tt := range tests {
77+ t .Run (tt .name , func (t * testing.T ) {
7978 template := & corev1.PodTemplateSpec {
8079 Spec : corev1.PodSpec {
8180 Containers : []corev1.Container {
@@ -89,57 +88,94 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
8988 },
9089 }
9190
92- tc .provider .InjectModelLoader (template , tc .index )
91+ tt .provider .InjectModelLoader (template , tt .index )
9392
9493 assert .Len (t , template .Spec .InitContainers , 1 )
9594 initContainer := template .Spec .InitContainers [0 ]
95+
9696 expectedName := MODEL_LOADER_CONTAINER_NAME
97- if tc .index != 0 {
98- expectedName += "-" + string ( rune ( '0' + tc .index ) )
97+ if tt .index != 0 {
98+ expectedName += "-" + strconv . Itoa ( tt .index )
9999 }
100100 assert .Equal (t , expectedName , initContainer .Name )
101101 assert .Equal (t , pkg .LOADER_IMAGE , initContainer .Image )
102102
103- // Check env vars exist
104- for _ , key := range tc .expectEnvContains {
105- found := false
106- for _ , env := range initContainer .Env {
107- if env .Name == key {
108- found = true
109- break
110- }
111- }
112- assert .True (t , found , "expected env %s not found" , key )
103+ wantEnv := buildExpectedEnv (tt .provider )
104+ if diff := cmp .Diff (wantEnv , initContainer .Env , envSortOpt ); diff != "" {
105+ t .Errorf ("InitContainer.Env mismatch (-want +got):\n %s" , diff )
113106 }
114107
115- // Main model should inject volume & container mount
116- if tc .expectMainModel {
117- // Volume should be present
118- foundVol := false
119- for _ , v := range template .Spec .Volumes {
120- if v .Name == MODEL_VOLUME_NAME {
121- foundVol = true
122- break
123- }
124- }
125- assert .True (t , foundVol , "volume not injected" )
126-
127- // Runner container mount should exist
128- foundMount := false
129- for _ , m := range template .Spec .Containers [0 ].VolumeMounts {
130- if m .Name == MODEL_VOLUME_NAME && m .ReadOnly && m .MountPath == CONTAINER_MODEL_PATH {
131- foundMount = true
132- }
133- }
134- assert .True (t , foundMount , "volume mount not injected to runner" )
108+ if tt .expectMainModel {
109+ assert .True (t , hasVolume (template .Spec .Volumes , MODEL_VOLUME_NAME ), "model volume missing" )
110+ assert .True (t , hasMount (template .Spec .Containers [0 ].VolumeMounts , MODEL_VOLUME_NAME ), "runner volumeMount missing" )
135111 } else {
136- // No volumes or mounts should be injected
137- assert .Empty (t , template .Spec .Volumes )
138- assert .Empty (t , template .Spec .Containers [0 ].VolumeMounts )
112+ assert .Empty (t , template .Spec .Volumes , "unexpected volumes for sub-model" )
113+ assert .Empty (t , template .Spec .Containers [0 ].VolumeMounts , "unexpected mounts for sub-model" )
139114 }
115+ })
116+ }
117+ }
118+
119+ func buildExpectedEnv (p * ModelHubProvider ) []corev1.EnvVar {
120+ envs := make ([]corev1.EnvVar , 0 , 10 )
121+
122+ envs = append (envs , corev1.EnvVar {Name : "HTTP_PROXY" , Value : "http://1.1.1.1" })
123+
124+ envs = append (envs ,
125+ corev1.EnvVar {Name : "MODEL_SOURCE_TYPE" , Value : MODEL_SOURCE_MODELHUB },
126+ corev1.EnvVar {Name : "MODEL_ID" , Value : p .modelID },
127+ corev1.EnvVar {Name : "MODEL_HUB_NAME" , Value : p .modelHub },
128+ )
129+
130+ if p .fileName != nil {
131+ envs = append (envs , corev1.EnvVar {Name : "MODEL_FILENAME" , Value : * p .fileName })
132+ }
133+ if p .modelRevision != nil {
134+ envs = append (envs , corev1.EnvVar {Name : "REVISION" , Value : * p .modelRevision })
135+ }
136+ if p .modelAllowPatterns != nil {
137+ envs = append (envs , corev1.EnvVar {
138+ Name : "MODEL_ALLOW_PATTERNS" ,
139+ Value : strings .Join (p .modelAllowPatterns , "," ),
140+ })
141+ }
142+ if p .modelIgnorePatterns != nil {
143+ envs = append (envs , corev1.EnvVar {
144+ Name : "MODEL_IGNORE_PATTERNS" ,
145+ Value : strings .Join (p .modelIgnorePatterns , "," ),
146+ })
147+ }
140148
141- // Should always carry over container env
142- assert .Contains (t , initContainer .Env , corev1.EnvVar {Name : "HTTP_PROXY" , Value : "http://1.1.1.1" })
149+ for _ , tokenName := range []string {"HUGGING_FACE_HUB_TOKEN" , "HF_TOKEN" } {
150+ envs = append (envs , corev1.EnvVar {
151+ Name : tokenName ,
152+ ValueFrom : & corev1.EnvVarSource {
153+ SecretKeyRef : & corev1.SecretKeySelector {
154+ LocalObjectReference : corev1.LocalObjectReference {Name : MODELHUB_SECRET_NAME },
155+ Key : HUGGINGFACE_TOKEN_KEY ,
156+ Optional : ptr .To (true ),
157+ },
158+ },
143159 })
144160 }
161+
162+ return envs
163+ }
164+
165+ func hasVolume (vols []corev1.Volume , name string ) bool {
166+ for _ , v := range vols {
167+ if v .Name == name {
168+ return true
169+ }
170+ }
171+ return false
172+ }
173+
174+ func hasMount (mounts []corev1.VolumeMount , name string ) bool {
175+ for _ , m := range mounts {
176+ if m .Name == name && m .ReadOnly && m .MountPath == CONTAINER_MODEL_PATH {
177+ return true
178+ }
179+ }
180+ return false
145181}
0 commit comments