Skip to content

Commit 6e0973c

Browse files
committed
Introduce unittests for llama_stack_config.go
1 parent 48b5142 commit 6e0973c

1 file changed

Lines changed: 151 additions & 0 deletions

File tree

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
package controller
2+
3+
import (
4+
"context"
5+
"fmt"
6+
7+
apiv1beta1 "github.com/openstack-lightspeed/operator/api/v1beta1"
8+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
9+
10+
. "github.com/onsi/ginkgo/v2"
11+
. "github.com/onsi/gomega"
12+
)
13+
14+
func expectSentenceTransformersProvider(providers []interface{}) {
15+
sentenceTransformers := providers[0].(map[string]interface{})
16+
Expect(sentenceTransformers["provider_id"]).To(Equal("sentence-transformers"))
17+
Expect(sentenceTransformers["provider_type"]).To(Equal("inline::sentence-transformers"))
18+
}
19+
20+
func getOpenStackLightspeedProvidersInstance(provider string) *apiv1beta1.OpenStackLightspeed {
21+
instance := &apiv1beta1.OpenStackLightspeed{
22+
ObjectMeta: metav1.ObjectMeta{
23+
Name: "openstack-lightspeed",
24+
Namespace: "openstack-lightspeed",
25+
},
26+
}
27+
28+
switch provider {
29+
case "openai":
30+
instance.Spec.LLMEndpointType = "openai"
31+
instance.Spec.LLMEndpoint = "https://api.openai.com/v1"
32+
instance.Spec.ModelName = "gpt-4o"
33+
return instance
34+
case "gemini":
35+
instance.Spec.LLMEndpointType = "gemini"
36+
instance.Spec.ModelName = "gemini-2.0-flash"
37+
return instance
38+
case "rhoai_vllm":
39+
instance.Spec.LLMEndpointType = "rhoai_vllm"
40+
instance.Spec.LLMEndpoint = "https://vllm.example.com/v1"
41+
instance.Spec.ModelName = "meta-llama/Llama-3.1-70B-Instruct"
42+
return instance
43+
case "rhelai_vllm":
44+
instance.Spec.LLMEndpointType = "rhelai_vllm"
45+
instance.Spec.LLMEndpoint = "https://rhelai-vllm.example.com/v1"
46+
instance.Spec.ModelName = "meta-llama/Llama-3.1-70B-Instruct"
47+
return instance
48+
case "azure_openai":
49+
instance.Spec.LLMEndpointType = "azure_openai"
50+
instance.Spec.LLMEndpoint = "https://my-resource.openai.azure.com"
51+
instance.Spec.LLMDeploymentName = "gpt-4o-deployment"
52+
instance.Spec.LLMAPIVersion = "2024-02-01"
53+
instance.Spec.ModelName = "gpt-4o"
54+
return instance
55+
case "watsonx":
56+
instance.Spec.LLMEndpointType = "watsonx"
57+
instance.Spec.LLMEndpoint = "https://watsonx.example.com"
58+
instance.Spec.LLMProjectID = "test-project-id"
59+
instance.Spec.ModelName = "ibm/granite-13b-chat-v2"
60+
return instance
61+
default:
62+
Fail(fmt.Sprintf("Unknown provider %s", provider))
63+
}
64+
65+
return nil
66+
}
67+
68+
func checkModelCommonConfig(modelConfig map[string]interface{}, instance *apiv1beta1.OpenStackLightspeed) {
69+
Expect(modelConfig["model_id"]).To(Equal(instance.Spec.ModelName))
70+
Expect(modelConfig["model_type"]).To(Equal("llm"))
71+
Expect(modelConfig["provider_id"]).To(Equal(OpenStackLightspeedDefaultProvider))
72+
Expect(modelConfig["provider_model_id"]).To(Equal(instance.Spec.ModelName))
73+
Expect(modelConfig).NotTo(HaveKey("metadata"))
74+
}
75+
76+
var _ = Describe("Llama Stack config", func() {
77+
Describe("buildLlamaStackInferenceProviders", func() {
78+
DescribeTable("should return correct inference providers config",
79+
func(provider, providerType string, checkConfig func(map[string]interface{}, *apiv1beta1.OpenStackLightspeed)) {
80+
instance := getOpenStackLightspeedProvidersInstance(provider)
81+
inferenceProvidersConfig, err := buildLlamaStackInferenceProviders(nil, context.Background(), instance)
82+
83+
Expect(err).NotTo(HaveOccurred())
84+
Expect(inferenceProvidersConfig).To(HaveLen(2))
85+
86+
expectSentenceTransformersProvider(inferenceProvidersConfig)
87+
88+
inferenceProvider := inferenceProvidersConfig[1].(map[string]interface{})
89+
Expect(inferenceProvider["provider_id"]).To(Equal(OpenStackLightspeedDefaultProvider))
90+
Expect(inferenceProvider["provider_type"]).To(Equal(providerType))
91+
92+
checkConfig(inferenceProvider["config"].(map[string]interface{}), instance)
93+
},
94+
Entry("for openai", "openai", "remote::openai",
95+
func(config map[string]interface{}, _ *apiv1beta1.OpenStackLightspeed) {
96+
Expect(config["api_key"]).To(Equal("${env.OPENSTACK_LIGHTSPEED_PROVIDER_API_KEY}"))
97+
}),
98+
Entry("for gemini", "gemini", "remote::gemini",
99+
func(config map[string]interface{}, _ *apiv1beta1.OpenStackLightspeed) {
100+
Expect(config["api_key"]).To(Equal("${env.OPENSTACK_LIGHTSPEED_PROVIDER_API_KEY}"))
101+
Expect(config).NotTo(HaveKey("base_url"))
102+
}),
103+
Entry("for rhoai_vllm", "rhoai_vllm", "remote::vllm",
104+
func(config map[string]interface{}, instance *apiv1beta1.OpenStackLightspeed) {
105+
Expect(config["api_token"]).To(Equal("${env.OPENSTACK_LIGHTSPEED_PROVIDER_API_KEY}"))
106+
Expect(config["base_url"]).To(Equal(instance.Spec.LLMEndpoint))
107+
}),
108+
Entry("for rhelai_vllm", "rhelai_vllm", "remote::vllm",
109+
func(config map[string]interface{}, instance *apiv1beta1.OpenStackLightspeed) {
110+
Expect(config["api_token"]).To(Equal("${env.OPENSTACK_LIGHTSPEED_PROVIDER_API_KEY}"))
111+
Expect(config["base_url"]).To(Equal(instance.Spec.LLMEndpoint))
112+
}),
113+
Entry("for azure_openai", "azure_openai", "remote::azure",
114+
func(config map[string]interface{}, instance *apiv1beta1.OpenStackLightspeed) {
115+
Expect(config["api_key"]).To(Equal("${env.OPENSTACK_LIGHTSPEED_PROVIDER_API_KEY}"))
116+
Expect(config["client_id"]).To(Equal("${env.OPENSTACK_LIGHTSPEED_PROVIDER_CLIENT_ID:=}"))
117+
Expect(config["tenant_id"]).To(Equal("${env.OPENSTACK_LIGHTSPEED_PROVIDER_TENANT_ID:=}"))
118+
Expect(config["client_secret"]).To(Equal("${env.OPENSTACK_LIGHTSPEED_PROVIDER_CLIENT_SECRET:=}"))
119+
Expect(config["api_base"]).To(Equal(instance.Spec.LLMEndpoint))
120+
Expect(config["deployment_name"]).To(Equal(instance.Spec.LLMDeploymentName))
121+
Expect(config["api_version"]).To(Equal(instance.Spec.LLMAPIVersion))
122+
}),
123+
Entry("for watsonx", "watsonx", "remote::watsonx",
124+
func(config map[string]interface{}, instance *apiv1beta1.OpenStackLightspeed) {
125+
Expect(config["base_url"]).To(Equal(instance.Spec.LLMEndpoint))
126+
Expect(config["project_id"]).To(Equal(instance.Spec.LLMProjectID))
127+
Expect(config["api_key"]).To(Equal("${env.OPENSTACK_LIGHTSPEED_PROVIDER_API_KEY:=}"))
128+
}),
129+
)
130+
})
131+
132+
Describe("buildLlamaStackModels", func() {
133+
DescribeTable("should return correct models config",
134+
func(provider string) {
135+
instance := getOpenStackLightspeedProvidersInstance(provider)
136+
modelsConfig := buildLlamaStackModels(nil, instance)
137+
138+
Expect(modelsConfig).To(HaveLen(1))
139+
140+
modelConfig := modelsConfig[0].(map[string]interface{})
141+
checkModelCommonConfig(modelConfig, instance)
142+
},
143+
Entry("for openai", "openai"),
144+
Entry("for gemini", "gemini"),
145+
Entry("for rhoai_vllm", "rhoai_vllm"),
146+
Entry("for rhelai_vllm", "rhelai_vllm"),
147+
Entry("for azure_openai", "azure_openai"),
148+
Entry("for watsonx", "watsonx"),
149+
)
150+
})
151+
})

0 commit comments

Comments
 (0)