Skip to content

Commit 11b4e2b

Browse files
committed
Topology NodeSelector in STS webhook
1 parent 310b90c commit 11b4e2b

6 files changed

Lines changed: 350 additions & 20 deletions

File tree

slice/cmd/main.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,10 @@ func setupControllers(mgr ctrl.Manager, certsReady chan struct{}, activationTime
298298
setupLog.Error(err, "Unable to create webhook", "webhook", "LeaderWorkerSet")
299299
os.Exit(1)
300300
}
301+
if err := webhooks.SetupStatefulSetWebhookWithManager(mgr); err != nil {
302+
setupLog.Error(err, "Unable to create webhook", "webhook", "StatefulSet")
303+
os.Exit(1)
304+
}
301305

302306
if failedCtrl, err := controller.SetupControllers(mgr, controller.Options{
303307
ActivationTimeout: activationTimeout, RetryDelayOnSliceFailure: retryDelay}); err != nil {

slice/config/webhook/manifests.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,22 @@ webhooks:
6161
resources:
6262
- leaderworkersets
6363
sideEffects: None
64+
- admissionReviewVersions:
65+
- v1
66+
clientConfig:
67+
service:
68+
name: webhook-service
69+
namespace: system
70+
path: /mutate-apps-v1-statefulset
71+
failurePolicy: Fail
72+
name: mstatefulset.slice-controller.kb.io
73+
rules:
74+
- apiGroups:
75+
- apps
76+
apiVersions:
77+
- v1
78+
operations:
79+
- CREATE
80+
resources:
81+
- statefulsets
82+
sideEffects: None

slice/internal/util/testingjobs/leaderworkerset/wrappers.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,72 @@ func (w *Wrapper) LeaderNodeAffinity(key string, values []string) *Wrapper {
182182
)
183183
return w
184184
}
185+
186+
func (w *Wrapper) StartupPolicy(policy leaderworkersetv1.StartupPolicyType) *Wrapper {
187+
w.Spec.StartupPolicy = policy
188+
return w
189+
}
190+
191+
func (w *Wrapper) WorkerName(name string) *Wrapper {
192+
if len(w.Spec.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers) == 0 {
193+
w.Spec.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers = []corev1.Container{{}}
194+
}
195+
w.Spec.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Name = name
196+
return w
197+
}
198+
199+
func (w *Wrapper) LeaderName(name string) *Wrapper {
200+
if w.Spec.LeaderWorkerTemplate.LeaderTemplate == nil {
201+
w.Spec.LeaderWorkerTemplate.LeaderTemplate = &corev1.PodTemplateSpec{}
202+
}
203+
if len(w.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers) == 0 {
204+
w.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers = []corev1.Container{{}}
205+
}
206+
w.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers[0].Name = name
207+
return w
208+
}
209+
210+
func (w *Wrapper) LeaderImage(img string) *Wrapper {
211+
if w.Spec.LeaderWorkerTemplate.LeaderTemplate == nil {
212+
w.Spec.LeaderWorkerTemplate.LeaderTemplate = &corev1.PodTemplateSpec{}
213+
}
214+
if len(w.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers) == 0 {
215+
w.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers = []corev1.Container{{}}
216+
}
217+
w.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers[0].Image = img
218+
return w
219+
}
220+
221+
func (w *Wrapper) LeaderArgs(args ...string) *Wrapper {
222+
if w.Spec.LeaderWorkerTemplate.LeaderTemplate == nil {
223+
w.Spec.LeaderWorkerTemplate.LeaderTemplate = &corev1.PodTemplateSpec{}
224+
}
225+
if len(w.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers) == 0 {
226+
w.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers = []corev1.Container{{}}
227+
}
228+
w.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers[0].Args = args
229+
return w
230+
}
231+
232+
func (w *Wrapper) LeaderLimit(resourceName corev1.ResourceName, quantity string) *Wrapper {
233+
if w.Spec.LeaderWorkerTemplate.LeaderTemplate == nil {
234+
w.Spec.LeaderWorkerTemplate.LeaderTemplate = &corev1.PodTemplateSpec{}
235+
}
236+
if len(w.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers) == 0 {
237+
w.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers = []corev1.Container{{}}
238+
}
239+
if w.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers[0].Resources.Limits == nil {
240+
w.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers[0].Resources.Limits = make(corev1.ResourceList)
241+
}
242+
w.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers[0].Resources.Limits[resourceName] = resource.MustParse(quantity)
243+
return w
244+
}
245+
246+
func (w *Wrapper) LeaderRequestAndLimit(resourceName corev1.ResourceName, quantity string) *Wrapper {
247+
w.LeaderLimit(resourceName, quantity)
248+
if w.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers[0].Resources.Requests == nil {
249+
w.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers[0].Resources.Requests = make(corev1.ResourceList)
250+
}
251+
w.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers[0].Resources.Requests[resourceName] = resource.MustParse(quantity)
252+
return w
253+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
Copyright The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package webhooks
18+
19+
import (
20+
"context"
21+
22+
appsv1 "k8s.io/api/apps/v1"
23+
ctrl "sigs.k8s.io/controller-runtime"
24+
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
25+
26+
"tpu-slice-controller/internal/core"
27+
)
28+
29+
// StatefulSetWebhook is the defaulting webhook handler for apps/v1
// StatefulSets. It carries no state and is registered with the manager by
// SetupStatefulSetWebhookWithManager.
type StatefulSetWebhook struct{}
30+
31+
func SetupStatefulSetWebhookWithManager(mgr ctrl.Manager) error {
32+
return ctrl.NewWebhookManagedBy(mgr, &appsv1.StatefulSet{}).
33+
WithDefaulter(&StatefulSetWebhook{}).
34+
Complete()
35+
}
36+
37+
// +kubebuilder:webhook:path=/mutate-apps-v1-statefulset,mutating=true,failurePolicy=fail,sideEffects=None,groups=apps,resources=statefulsets,verbs=create,versions=v1,name=mstatefulset.slice-controller.kb.io,admissionReviewVersions=v1
38+
39+
// Compile-time check that *StatefulSetWebhook implements the typed Defaulter interface.
var _ admission.Defaulter[*appsv1.StatefulSet] = &StatefulSetWebhook{}
40+
41+
func (r *StatefulSetWebhook) Default(ctx context.Context, sts *appsv1.StatefulSet) error {
42+
log := ctrl.LoggerFrom(ctx).WithName("statefulset-webhook")
43+
log.V(5).Info("Defaulting StatefulSet")
44+
45+
if !core.IsRelevantPodTemplateSpec(sts.Spec.Template) {
46+
log.V(5).Info("Skipping non-TPUv7 StatefulSet")
47+
return nil
48+
}
49+
50+
tpuTopology := core.GetTPUTopology(sts.Spec.Template)
51+
if sts.Spec.Template.Spec.NodeSelector == nil {
52+
sts.Spec.Template.Spec.NodeSelector = make(map[string]string)
53+
}
54+
if _, ok := sts.Spec.Template.Spec.NodeSelector[core.TPUTopologyAnnotation]; !ok {
55+
sts.Spec.Template.Spec.NodeSelector[core.TPUTopologyAnnotation] = tpuTopology
56+
}
57+
58+
return nil
59+
}
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
/*
2+
Copyright The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package webhooks
18+
19+
import (
20+
"testing"
21+
22+
"github.com/google/go-cmp/cmp"
23+
appsv1 "k8s.io/api/apps/v1"
24+
corev1 "k8s.io/api/core/v1"
25+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
26+
27+
slice "tpu-slice-controller/api/v1beta1"
28+
"tpu-slice-controller/internal/core"
29+
utiltesting "tpu-slice-controller/internal/util/testing"
30+
)
31+
32+
func TestStatefulSetDefault(t *testing.T) {
33+
const (
34+
baseName = "sts"
35+
baseNamespace = "default"
36+
)
37+
38+
testCases := map[string]struct {
39+
sts *appsv1.StatefulSet
40+
wantSts *appsv1.StatefulSet
41+
wantErr error
42+
}{
43+
"not a relevant statefulset (missing annotation and node selector)": {
44+
sts: &appsv1.StatefulSet{
45+
ObjectMeta: metav1.ObjectMeta{Name: baseName, Namespace: baseNamespace},
46+
Spec: appsv1.StatefulSetSpec{},
47+
},
48+
wantSts: &appsv1.StatefulSet{
49+
ObjectMeta: metav1.ObjectMeta{Name: baseName, Namespace: baseNamespace},
50+
Spec: appsv1.StatefulSetSpec{},
51+
},
52+
},
53+
"not a relevant statefulset (missing node selector)": {
54+
sts: &appsv1.StatefulSet{
55+
ObjectMeta: metav1.ObjectMeta{
56+
Name: baseName,
57+
Namespace: baseNamespace,
58+
},
59+
Spec: appsv1.StatefulSetSpec{
60+
Template: corev1.PodTemplateSpec{
61+
ObjectMeta: metav1.ObjectMeta{
62+
Annotations: map[string]string{core.TPUSliceTopologyAnnotation: "4x4x12"},
63+
},
64+
Spec: corev1.PodSpec{},
65+
},
66+
},
67+
},
68+
wantSts: &appsv1.StatefulSet{
69+
ObjectMeta: metav1.ObjectMeta{
70+
Name: baseName,
71+
Namespace: baseNamespace,
72+
},
73+
Spec: appsv1.StatefulSetSpec{
74+
Template: corev1.PodTemplateSpec{
75+
ObjectMeta: metav1.ObjectMeta{
76+
Annotations: map[string]string{core.TPUSliceTopologyAnnotation: "4x4x12"},
77+
},
78+
Spec: corev1.PodSpec{},
79+
},
80+
},
81+
},
82+
},
83+
"relevant statefulset, missing tpu topology node selector": {
84+
sts: &appsv1.StatefulSet{
85+
ObjectMeta: metav1.ObjectMeta{
86+
Name: baseName,
87+
Namespace: baseNamespace,
88+
},
89+
Spec: appsv1.StatefulSetSpec{
90+
Template: corev1.PodTemplateSpec{
91+
ObjectMeta: metav1.ObjectMeta{
92+
Annotations: map[string]string{core.TPUSliceTopologyAnnotation: "4x4x12"},
93+
},
94+
Spec: corev1.PodSpec{
95+
NodeSelector: map[string]string{
96+
core.TPUAcceleratorLabel: string(slice.TypeTpu7x),
97+
},
98+
},
99+
},
100+
},
101+
},
102+
wantSts: &appsv1.StatefulSet{
103+
ObjectMeta: metav1.ObjectMeta{
104+
Name: baseName,
105+
Namespace: baseNamespace,
106+
},
107+
Spec: appsv1.StatefulSetSpec{
108+
Template: corev1.PodTemplateSpec{
109+
ObjectMeta: metav1.ObjectMeta{
110+
Annotations: map[string]string{core.TPUSliceTopologyAnnotation: "4x4x12"},
111+
},
112+
Spec: corev1.PodSpec{
113+
NodeSelector: map[string]string{
114+
core.TPUAcceleratorLabel: string(slice.TypeTpu7x),
115+
core.TPUTopologyAnnotation: "4x4x12",
116+
},
117+
},
118+
},
119+
},
120+
},
121+
},
122+
"relevant statefulset, with tpu topology node selector already present": {
123+
sts: &appsv1.StatefulSet{
124+
ObjectMeta: metav1.ObjectMeta{
125+
Name: baseName,
126+
Namespace: baseNamespace,
127+
},
128+
Spec: appsv1.StatefulSetSpec{
129+
Template: corev1.PodTemplateSpec{
130+
ObjectMeta: metav1.ObjectMeta{
131+
Annotations: map[string]string{core.TPUSliceTopologyAnnotation: "4x4x12"},
132+
},
133+
Spec: corev1.PodSpec{
134+
NodeSelector: map[string]string{
135+
core.TPUAcceleratorLabel: string(slice.TypeTpu7x),
136+
core.TPUTopologyAnnotation: "4x4x12",
137+
},
138+
},
139+
},
140+
},
141+
},
142+
wantSts: &appsv1.StatefulSet{
143+
ObjectMeta: metav1.ObjectMeta{
144+
Name: baseName,
145+
Namespace: baseNamespace,
146+
},
147+
Spec: appsv1.StatefulSetSpec{
148+
Template: corev1.PodTemplateSpec{
149+
ObjectMeta: metav1.ObjectMeta{
150+
Annotations: map[string]string{core.TPUSliceTopologyAnnotation: "4x4x12"},
151+
},
152+
Spec: corev1.PodSpec{
153+
NodeSelector: map[string]string{
154+
core.TPUAcceleratorLabel: string(slice.TypeTpu7x),
155+
core.TPUTopologyAnnotation: "4x4x12",
156+
},
157+
},
158+
},
159+
},
160+
},
161+
},
162+
}
163+
164+
for name, tc := range testCases {
165+
t.Run(name, func(t *testing.T) {
166+
ctx := t.Context()
167+
webhook := &StatefulSetWebhook{}
168+
169+
gotErr := webhook.Default(ctx, tc.sts)
170+
if diff := cmp.Diff(tc.wantErr, gotErr, utiltesting.EquateErrors); diff != "" {
171+
t.Errorf("Default() error mismatch (-want +got):\n%s", diff)
172+
}
173+
if tc.wantSts != nil {
174+
if diff := cmp.Diff(tc.wantSts, tc.sts); diff != "" {
175+
t.Errorf("Default() mismatch (-want,+got):\n%s", diff)
176+
}
177+
}
178+
})
179+
}
180+
}

0 commit comments

Comments
 (0)