Skip to content

Commit d22e4a1

Browse files
authored
feat(inference): add a proxy to easily switch models (#99)
1 parent b61d9bb commit d22e4a1

7 files changed

Lines changed: 260 additions & 5 deletions

File tree

modules/inference/example/dev/example_workspace.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
modules:
33
inference:
44
path: oci://ghcr.io/kusionstack/inference
5-
version: 0.1.0-beta.5
5+
version: 0.1.0
66
configs:
77
default: {}
88
network:

modules/inference/example/dev/kcl.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
name = "example"
33

44
[dependencies]
5-
inference = { oci = "oci://ghcr.io/kusionstack/inference", tag = "0.1.0-beta.5" }
5+
inference = { oci = "oci://ghcr.io/kusionstack/inference", tag = "0.1.0" }
66
service = {oci = "oci://ghcr.io/kusionstack/service", tag = "0.1.0" }
77
kam = { git = "https://github.com/KusionStack/kam.git", tag = "0.2.0" }
88
network = { oci = "oci://ghcr.io/kusionstack/network", tag = "0.2.0" }

modules/inference/kcl.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[package]
22
name = "inference"
3-
version = "0.1.0-beta.5"
3+
version = "0.1.0"

modules/inference/src/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ TEST?=$$(go list ./... | grep -v 'vendor')
22
###### change variables below according to your own modules ###
33
NAMESPACE=kusionstack
44
NAME=inference
5-
VERSION=0.1.0-beta.5
5+
VERSION=0.1.0
66
BINARY=../bin/kusion-module-${NAME}_${VERSION}
77

88
LOCAL_ARCH := $(shell uname -m)

modules/inference/src/inference.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ var (
5757
OllamaImage = "ollama/ollama"
5858
)
5959

60+
// proxy
61+
var (
62+
ProxyName = "proxy"
63+
ProxyPort = 5000
64+
ProxyImage = "kangy126/proxy"
65+
)
66+
6067
func main() {
6168
server.Start(&Inference{})
6269
}

modules/inference/src/ollama_frame.go

Lines changed: 138 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,21 @@ func (infer *Inference) GenerateOllamaResource(request *module.GeneratorRequest)
3030
}
3131
resources = append(resources, *svc)
3232

33-
patcher, err := infer.GenerateEnv(svcName)
33+
// Build Kubernetes Deployment for proxy.
34+
deploymentProxy, err := infer.generateProxyDeployment(request, svcName)
35+
if err != nil {
36+
return nil, nil, err
37+
}
38+
resources = append(resources, *deploymentProxy)
39+
40+
// Build Kubernetes Service for proxy.
41+
svcProxy, svcNameProxy, err := infer.generateProxyService(request)
42+
if err != nil {
43+
return nil, nil, err
44+
}
45+
resources = append(resources, *svcProxy)
46+
47+
patcher, err := infer.GenerateEnv(svcNameProxy)
3448
if err != nil {
3549
return nil, nil, err
3650
}
@@ -200,3 +214,126 @@ func (infer *Inference) generateMatchLabels() map[string]string {
200214
"accessory": strings.ToLower(infer.Framework),
201215
}
202216
}
217+
218+
// generateMatchLabels generates the match labels for the Kubernetes resources of proxy.
219+
func (infer *Inference) generateMatchLabelsForProxy() map[string]string {
220+
return map[string]string{
221+
"accessory": strings.ToLower(ProxyName),
222+
}
223+
}
224+
225+
// generatePodSpec generates the Kubernetes PodSpec for proxy.
226+
func (infer *Inference) generateProxyPodSpec(_ *module.GeneratorRequest, svcName string) (v1.PodSpec, error) {
227+
portName := strings.ToLower(ProxyName) + inferContainerPortSuffix
228+
if len(portName) > 15 {
229+
portName = portName[:15]
230+
}
231+
containerPort := int32(ProxyPort)
232+
ports := []v1.ContainerPort{
233+
{
234+
Name: portName,
235+
ContainerPort: containerPort,
236+
},
237+
}
238+
239+
envVars := []v1.EnvVar{
240+
{
241+
Name: "MODEL",
242+
Value: infer.Model,
243+
},
244+
{
245+
Name: "FRAMEWORK_URL",
246+
Value: svcName,
247+
},
248+
}
249+
250+
image := ProxyImage
251+
podSpec := v1.PodSpec{
252+
Containers: []v1.Container{
253+
{
254+
Name: strings.ToLower(ProxyName) + inferContainerSuffix,
255+
Image: image,
256+
Ports: ports,
257+
Env: envVars,
258+
},
259+
},
260+
}
261+
return podSpec, nil
262+
}
263+
264+
// generateDeployment generates the Kubernetes Deployment resource for proxy.
265+
func (infer *Inference) generateProxyDeployment(request *module.GeneratorRequest, svcName string) (*apiv1.Resource, error) {
266+
podSpec, err := infer.generateProxyPodSpec(request, svcName)
267+
if err != nil {
268+
return nil, nil
269+
}
270+
271+
deployment := &appsv1.Deployment{
272+
TypeMeta: metav1.TypeMeta{
273+
Kind: "Deployment",
274+
APIVersion: appsv1.SchemeGroupVersion.String(),
275+
},
276+
ObjectMeta: metav1.ObjectMeta{
277+
Name: strings.ToLower(ProxyName) + inferDeploymentSuffix,
278+
Namespace: request.Project,
279+
},
280+
Spec: appsv1.DeploymentSpec{
281+
Selector: &metav1.LabelSelector{
282+
MatchLabels: infer.generateMatchLabelsForProxy(),
283+
},
284+
Template: v1.PodTemplateSpec{
285+
ObjectMeta: metav1.ObjectMeta{
286+
Labels: infer.generateMatchLabelsForProxy(),
287+
},
288+
Spec: podSpec,
289+
},
290+
},
291+
}
292+
293+
resourceID := module.KubernetesResourceID(deployment.TypeMeta, deployment.ObjectMeta)
294+
resource, err := module.WrapK8sResourceToKusionResource(resourceID, deployment)
295+
if err != nil {
296+
return nil, err
297+
}
298+
299+
return resource, nil
300+
}
301+
302+
// generateService generates the Kubernetes Service resource for proxy.
303+
func (infer *Inference) generateProxyService(request *module.GeneratorRequest) (*apiv1.Resource, string, error) {
304+
svcName := strings.ToLower(ProxyName) + inferServiceSuffix
305+
svcPort := []v1.ServicePort{
306+
{
307+
Port: int32(CalledPort),
308+
TargetPort: intstr.IntOrString{
309+
Type: intstr.Int,
310+
IntVal: int32(ProxyPort),
311+
},
312+
},
313+
}
314+
315+
service := &v1.Service{
316+
TypeMeta: metav1.TypeMeta{
317+
Kind: "Service",
318+
APIVersion: v1.SchemeGroupVersion.String(),
319+
},
320+
ObjectMeta: metav1.ObjectMeta{
321+
Name: svcName,
322+
Namespace: request.Project,
323+
Labels: infer.generateMatchLabelsForProxy(),
324+
},
325+
Spec: v1.ServiceSpec{
326+
Type: v1.ServiceTypeClusterIP,
327+
Ports: svcPort,
328+
Selector: infer.generateMatchLabelsForProxy(),
329+
},
330+
}
331+
332+
resourceID := module.KubernetesResourceID(service.TypeMeta, service.ObjectMeta)
333+
resource, err := module.WrapK8sResourceToKusionResource(resourceID, service)
334+
if err != nil {
335+
return nil, svcName, err
336+
}
337+
338+
return resource, svcName, nil
339+
}

modules/inference/src/ollama_frame_test.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,114 @@ func TestInferenceModule_GenerateOllamaService(t *testing.T) {
135135
assert.Equal(t, strings.ToLower(infer.Framework)+inferServiceSuffix, svcName)
136136
assert.NoError(t, err)
137137
}
138+
139+
func TestInferenceModule_GenerateProxyPodSpec(t *testing.T) {
140+
r := &module.GeneratorRequest{
141+
Project: "test-project",
142+
Stack: "test-stack",
143+
App: "test-app",
144+
Workload: &v1.Workload{
145+
Header: v1.Header{
146+
Type: "Service",
147+
},
148+
Service: &v1.Service{},
149+
},
150+
}
151+
152+
infer := &Inference{
153+
Model: "qwen",
154+
Framework: "Ollama",
155+
System: "",
156+
Template: "",
157+
TopK: 40,
158+
TopP: 0.9,
159+
Temperature: 0.8,
160+
NumPredict: 128,
161+
NumCtx: 2048,
162+
}
163+
164+
res, err := infer.generateProxyPodSpec(r, "ollama-svc")
165+
166+
assert.NotNil(t, res)
167+
assert.NoError(t, err)
168+
}
169+
170+
func TestInferenceModule_GenerateProxyDeployment(t *testing.T) {
171+
r := &module.GeneratorRequest{
172+
Project: "test-project",
173+
Stack: "test-stack",
174+
App: "test-app",
175+
Workload: &v1.Workload{
176+
Header: v1.Header{
177+
Type: "Service",
178+
},
179+
Service: &v1.Service{},
180+
},
181+
}
182+
183+
infer := &Inference{
184+
Model: "qwen",
185+
Framework: "Ollama",
186+
System: "",
187+
Template: "",
188+
TopK: 40,
189+
TopP: 0.9,
190+
Temperature: 0.8,
191+
NumPredict: 128,
192+
NumCtx: 2048,
193+
}
194+
195+
res, err := infer.generateProxyDeployment(r, "ollama-svc")
196+
197+
assert.NotNil(t, res)
198+
assert.NoError(t, err)
199+
}
200+
201+
func TestInferenceModule_GenerateProxyService(t *testing.T) {
202+
r := &module.GeneratorRequest{
203+
Project: "test-project",
204+
Stack: "test-stack",
205+
App: "test-app",
206+
Workload: &v1.Workload{
207+
Header: v1.Header{
208+
Type: "Service",
209+
},
210+
Service: &v1.Service{},
211+
},
212+
}
213+
214+
infer := &Inference{
215+
Model: "qwen",
216+
Framework: "Ollama",
217+
System: "",
218+
Template: "",
219+
TopK: 40,
220+
TopP: 0.9,
221+
Temperature: 0.8,
222+
NumPredict: 128,
223+
NumCtx: 2048,
224+
}
225+
226+
res, svcName, err := infer.generateProxyService(r)
227+
228+
assert.NotNil(t, res)
229+
assert.NotNil(t, svcName)
230+
assert.Equal(t, strings.ToLower(ProxyName)+inferServiceSuffix, svcName)
231+
assert.NoError(t, err)
232+
}
233+
234+
func TestInferenceModule_GenerateMatchLabels(t *testing.T) {
235+
t.Run("generate matchLabels", func(t *testing.T) {
236+
infer := &Inference{Framework: "Ollama"}
237+
labels := infer.generateMatchLabels()
238+
assert.Equal(t, strings.ToLower(infer.Framework), labels["accessory"])
239+
})
240+
}
241+
242+
func TestInferenceModule_GenerateMatchLabelsForProxy(t *testing.T) {
243+
t.Run("generate matchLabels for proxy", func(t *testing.T) {
244+
infer := &Inference{}
245+
labels := infer.generateMatchLabelsForProxy()
246+
assert.Equal(t, strings.ToLower(ProxyName), labels["accessory"])
247+
})
248+
}

0 commit comments

Comments
 (0)