diff --git a/slice/cmd/main.go b/slice/cmd/main.go
index cdfe5de11..69290567b 100644
--- a/slice/cmd/main.go
+++ b/slice/cmd/main.go
@@ -298,6 +298,10 @@ func setupControllers(mgr ctrl.Manager, certsReady chan struct{}, activationTime
 		setupLog.Error(err, "Unable to create webhook", "webhook", "LeaderWorkerSet")
 		os.Exit(1)
 	}
+	if err := webhooks.SetupStatefulSetWebhookWithManager(mgr); err != nil {
+		setupLog.Error(err, "Unable to create webhook", "webhook", "StatefulSet")
+		os.Exit(1)
+	}
 	if failedCtrl, err := controller.SetupControllers(mgr, controller.Options{
 		ActivationTimeout:        activationTimeout,
 		RetryDelayOnSliceFailure: retryDelay}); err != nil {
diff --git a/slice/config/webhook/manifests.yaml b/slice/config/webhook/manifests.yaml
index 5b71139bb..38b175fb5 100644
--- a/slice/config/webhook/manifests.yaml
+++ b/slice/config/webhook/manifests.yaml
@@ -61,3 +61,22 @@ webhooks:
     resources:
     - leaderworkersets
   sideEffects: None
+- admissionReviewVersions:
+  - v1
+  clientConfig:
+    service:
+      name: webhook-service
+      namespace: system
+      path: /mutate-apps-v1-statefulset
+  failurePolicy: Fail
+  name: mstatefulset.slice-controller.kb.io
+  rules:
+  - apiGroups:
+    - apps
+    apiVersions:
+    - v1
+    operations:
+    - CREATE
+    resources:
+    - statefulsets
+  sideEffects: None
diff --git a/slice/internal/util/testingjobs/leaderworkerset/wrappers.go b/slice/internal/util/testingjobs/leaderworkerset/wrappers.go
index e96770de0..6fe566313 100644
--- a/slice/internal/util/testingjobs/leaderworkerset/wrappers.go
+++ b/slice/internal/util/testingjobs/leaderworkerset/wrappers.go
@@ -182,3 +182,74 @@ func (w *Wrapper) LeaderNodeAffinity(key string, values []string) *Wrapper {
 	)
 	return w
 }
+
+// StartupPolicy sets the startup policy of the LeaderWorkerSet.
+func (w *Wrapper) StartupPolicy(policy leaderworkersetv1.StartupPolicyType) *Wrapper {
+	w.Spec.StartupPolicy = policy
+	return w
+}
+
+// WorkerName names the first worker container, adding an empty container
+// first when the worker template has none.
+func (w *Wrapper) WorkerName(name string) *Wrapper {
+	if len(w.Spec.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers) == 0 {
+		w.Spec.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers = []corev1.Container{{}}
+	}
+	w.Spec.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Name = name
+	return w
+}
+
+// leaderContainer returns the first container of the leader template,
+// creating the template and an empty container when either is missing.
+func (w *Wrapper) leaderContainer() *corev1.Container {
+	lwt := &w.Spec.LeaderWorkerTemplate
+	if lwt.LeaderTemplate == nil {
+		lwt.LeaderTemplate = &corev1.PodTemplateSpec{}
+	}
+	if len(lwt.LeaderTemplate.Spec.Containers) == 0 {
+		lwt.LeaderTemplate.Spec.Containers = []corev1.Container{{}}
+	}
+	return &lwt.LeaderTemplate.Spec.Containers[0]
+}
+
+// LeaderName names the first leader container.
+func (w *Wrapper) LeaderName(name string) *Wrapper {
+	w.leaderContainer().Name = name
+	return w
+}
+
+// LeaderImage sets the image of the first leader container.
+func (w *Wrapper) LeaderImage(img string) *Wrapper {
+	w.leaderContainer().Image = img
+	return w
+}
+
+// LeaderArgs sets the arguments of the first leader container.
+func (w *Wrapper) LeaderArgs(args ...string) *Wrapper {
+	w.leaderContainer().Args = args
+	return w
+}
+
+// LeaderLimit sets the limit of the given resource on the first leader
+// container. It panics if quantity cannot be parsed.
+func (w *Wrapper) LeaderLimit(resourceName corev1.ResourceName, quantity string) *Wrapper {
+	c := w.leaderContainer()
+	if c.Resources.Limits == nil {
+		c.Resources.Limits = make(corev1.ResourceList)
+	}
+	c.Resources.Limits[resourceName] = resource.MustParse(quantity)
+	return w
+}
+
+// LeaderRequestAndLimit sets both the request and the limit of the given
+// resource on the first leader container. It panics if quantity cannot be
+// parsed.
+func (w *Wrapper) LeaderRequestAndLimit(resourceName corev1.ResourceName, quantity string) *Wrapper {
+	w.LeaderLimit(resourceName, quantity)
+	c := w.leaderContainer()
+	if c.Resources.Requests == nil {
+		c.Resources.Requests = make(corev1.ResourceList)
+	}
+	c.Resources.Requests[resourceName] = resource.MustParse(quantity)
+	return w
+}
diff --git a/slice/internal/webhooks/statefulset_webhook.go b/slice/internal/webhooks/statefulset_webhook.go
new file mode 100644
index 000000000..de4cb6734
--- /dev/null
+++ b/slice/internal/webhooks/statefulset_webhook.go
@@ -0,0 +1,69 @@
+/*
+Copyright The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package webhooks
+
+import (
+	"context"
+
+	appsv1 "k8s.io/api/apps/v1"
+	ctrl "sigs.k8s.io/controller-runtime"
+	"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
+
+	"tpu-slice-controller/internal/core"
+)
+
+// StatefulSetWebhook is the mutating admission webhook that defaults
+// apps/v1 StatefulSets.
+type StatefulSetWebhook struct{}
+
+// SetupStatefulSetWebhookWithManager registers StatefulSetWebhook with mgr as
+// the defaulter for apps/v1 StatefulSets.
+func SetupStatefulSetWebhookWithManager(mgr ctrl.Manager) error {
+	return ctrl.NewWebhookManagedBy(mgr, &appsv1.StatefulSet{}).
+		WithDefaulter(&StatefulSetWebhook{}).
+		Complete()
+}
+
+// +kubebuilder:webhook:path=/mutate-apps-v1-statefulset,mutating=true,failurePolicy=fail,sideEffects=None,groups=apps,resources=statefulsets,verbs=create,versions=v1,name=mstatefulset.slice-controller.kb.io,admissionReviewVersions=v1
+
+var _ admission.Defaulter[*appsv1.StatefulSet] = &StatefulSetWebhook{}
+
+// Default adds the topology reported by core.GetTPUTopology to the node
+// selector of the pod template when core.IsRelevantPodTemplateSpec accepts
+// the template. Non-relevant StatefulSets and templates that already carry
+// the selector key are left unchanged. It always returns nil.
+func (r *StatefulSetWebhook) Default(ctx context.Context, sts *appsv1.StatefulSet) error {
+	log := ctrl.LoggerFrom(ctx).WithName("statefulset-webhook")
+
+	if !core.IsRelevantPodTemplateSpec(sts.Spec.Template) {
+		log.V(5).Info("Skipping non-relevant StatefulSet")
+		return nil
+	}
+
+	log.V(5).Info("Defaulting StatefulSet")
+	tpuTopology := core.GetTPUTopology(sts.Spec.Template)
+	podSpec := &sts.Spec.Template.Spec
+	if podSpec.NodeSelector == nil {
+		podSpec.NodeSelector = make(map[string]string)
+	}
+	// Keep a topology selector that is already set on the template.
+	if _, ok := podSpec.NodeSelector[core.TPUTopologyAnnotation]; !ok {
+		podSpec.NodeSelector[core.TPUTopologyAnnotation] = tpuTopology
+	}
+
+	return nil
+}
diff --git a/slice/internal/webhooks/statefulset_webhook_test.go b/slice/internal/webhooks/statefulset_webhook_test.go
new file mode 100644
index 000000000..b7506f984
--- /dev/null
+++ b/slice/internal/webhooks/statefulset_webhook_test.go
@@ -0,0 +1,135 @@
+/*
+Copyright The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package webhooks
+
+import (
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	appsv1 "k8s.io/api/apps/v1"
+	corev1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+
+	slice "tpu-slice-controller/api/v1beta1"
+	"tpu-slice-controller/internal/core"
+	utiltesting "tpu-slice-controller/internal/util/testing"
+)
+
+// TestStatefulSetDefault verifies that Default leaves non-relevant
+// StatefulSets untouched and adds the TPU topology node selector to relevant
+// ones without overwriting existing selectors.
+func TestStatefulSetDefault(t *testing.T) {
+	const (
+		baseName      = "sts"
+		baseNamespace = "default"
+		topology      = "4x4x12"
+	)
+	accelerator := string(slice.TypeTpu7x)
+
+	// makeSts builds a StatefulSet whose pod template carries the given
+	// annotations and node selector; nil maps leave the fields unset.
+	makeSts := func(annotations, nodeSelector map[string]string) *appsv1.StatefulSet {
+		return &appsv1.StatefulSet{
+			ObjectMeta: metav1.ObjectMeta{Name: baseName, Namespace: baseNamespace},
+			Spec: appsv1.StatefulSetSpec{
+				Template: corev1.PodTemplateSpec{
+					ObjectMeta: metav1.ObjectMeta{Annotations: annotations},
+					Spec:       corev1.PodSpec{NodeSelector: nodeSelector},
+				},
+			},
+		}
+	}
+	// topologyAnnotation returns a fresh map on every call so that the input
+	// and the expected object never share state.
+	topologyAnnotation := func() map[string]string {
+		return map[string]string{core.TPUSliceTopologyAnnotation: topology}
+	}
+
+	testCases := map[string]struct {
+		sts     *appsv1.StatefulSet
+		wantSts *appsv1.StatefulSet
+		wantErr error
+	}{
+		"not a relevant statefulset (missing annotation and node selector)": {
+			sts:     makeSts(nil, nil),
+			wantSts: makeSts(nil, nil),
+		},
+		"not a relevant statefulset (missing node selector)": {
+			sts:     makeSts(topologyAnnotation(), nil),
+			wantSts: makeSts(topologyAnnotation(), nil),
+		},
+		"relevant statefulset, missing tpu topology node selector": {
+			sts: makeSts(topologyAnnotation(), map[string]string{
+				core.TPUAcceleratorLabel: accelerator,
+			}),
+			wantSts: makeSts(topologyAnnotation(), map[string]string{
+				core.TPUAcceleratorLabel:   accelerator,
+				core.TPUTopologyAnnotation: topology,
+			}),
+		},
+		"relevant statefulset, with tpu topology node selector already present": {
+			sts: makeSts(topologyAnnotation(), map[string]string{
+				core.TPUAcceleratorLabel:   accelerator,
+				core.TPUTopologyAnnotation: topology,
+			}),
+			wantSts: makeSts(topologyAnnotation(), map[string]string{
+				core.TPUAcceleratorLabel:   accelerator,
+				core.TPUTopologyAnnotation: topology,
+			}),
+		},
+		"relevant statefulset, missing tpu topology node selector and has other node selector": {
+			sts: makeSts(topologyAnnotation(), map[string]string{
+				core.TPUAcceleratorLabel: accelerator,
+				"kubernetes.io/os":       "linux",
+			}),
+			wantSts: makeSts(topologyAnnotation(), map[string]string{
+				core.TPUAcceleratorLabel:   accelerator,
+				"kubernetes.io/os":         "linux",
+				core.TPUTopologyAnnotation: topology,
+			}),
+		},
+		"relevant statefulset, with tpu topology node selector and other node selector": {
+			sts: makeSts(topologyAnnotation(), map[string]string{
+				core.TPUAcceleratorLabel:   accelerator,
+				core.TPUTopologyAnnotation: topology,
+				"kubernetes.io/os":         "linux",
+			}),
+			wantSts: makeSts(topologyAnnotation(), map[string]string{
+				core.TPUAcceleratorLabel:   accelerator,
+				core.TPUTopologyAnnotation: topology,
+				"kubernetes.io/os":         "linux",
+			}),
+		},
+	}
+
+	for name, tc := range testCases {
+		t.Run(name, func(t *testing.T) {
+			ctx := t.Context()
+			webhook := &StatefulSetWebhook{}
+
+			gotErr := webhook.Default(ctx, tc.sts)
+			if diff := cmp.Diff(tc.wantErr, gotErr, utiltesting.EquateErrors); diff != "" {
+				t.Errorf("Default() error mismatch (-want,+got):\n%s", diff)
+			}
+			if tc.wantSts != nil {
+				if diff := cmp.Diff(tc.wantSts, tc.sts); diff != "" {
+					t.Errorf("Default() mismatch (-want,+got):\n%s", diff)
+				}
+			}
+		})
+	}
+}
diff --git a/slice/test/e2e/subslice/lws_test.go b/slice/test/e2e/subslice/lws_test.go
index e6b82e29b..d6ee29e64 100644
--- a/slice/test/e2e/subslice/lws_test.go
+++ b/slice/test/e2e/subslice/lws_test.go
@@ -22,7 +22,6 @@ import (
 	"github.com/onsi/ginkgo/v2"
 	"github.com/onsi/gomega"
 	corev1 "k8s.io/api/core/v1"
-	"k8s.io/apimachinery/pkg/api/resource"
 	"sigs.k8s.io/controller-runtime/pkg/client"
 	kueue "sigs.k8s.io/kueue/apis/kueue/v1beta2"
 	"sigs.k8s.io/kueue/pkg/util/tas"
@@ -109,6 +108,8 @@ var _ = ginkgo.Describe("LWS Subslicing", func() {
 			wrapper := testingjobslws.MakeLeaderWorkerSet(name, ns.Name).
 				Queue(lq.Name).
 				Size(tc.size).
+				StartupPolicy(leaderworkersetv1.LeaderCreatedStartupPolicy).
+				WorkerName("worker").
 				WorkerImage(utils.E2eTestAgnHostImage).
 				WorkerArgs(utils.BehaviorWaitForDeletion...).
 				WorkerAnnotation(core.TPUSliceTopologyAnnotation, tc.topology).
@@ -117,31 +118,18 @@ var _ = ginkgo.Describe("LWS Subslicing", func() {
 
 			if tc.withLeader {
 				wrapper = wrapper.
+					LeaderName("leader").
+					LeaderImage(utils.E2eTestAgnHostImage).
+					LeaderArgs(utils.BehaviorWaitForDeletion...).
 					LeaderAnnotation(core.TPUSliceTopologyAnnotation, tc.topology).
 					LeaderNodeSelector("cloud.google.com/gke-tpu-accelerator", string(slice.TypeTpu7x))
-			}
-
-			lws := wrapper.Obj()
-			lws.Spec.StartupPolicy = leaderworkersetv1.LeaderCreatedStartupPolicy
-			if tc.withLeader {
-				container := corev1.Container{
-					Name:  "leader",
-					Image: utils.E2eTestAgnHostImage,
-					Args:  utils.BehaviorWaitForDeletion,
-				}
 				if tc.leaderRequiresTPUs {
-					container.Resources = corev1.ResourceRequirements{
-						Limits: corev1.ResourceList{
-							core.TPUResourceName: resource.MustParse("4"),
-						},
-						Requests: corev1.ResourceList{
-							core.TPUResourceName: resource.MustParse("4"),
-						},
-					}
+					wrapper = wrapper.LeaderRequestAndLimit(core.TPUResourceName, "4")
 				}
-				lws.Spec.LeaderWorkerTemplate.LeaderTemplate.Spec.Containers = []corev1.Container{container}
 			}
 
+			lws := wrapper.Obj()
+
 			ginkgo.By("Creating a LeaderWorkerSet", func() {
 				utils.MustCreate(ctx, k8sClient, lws)
 			})
@@ -184,6 +172,17 @@ var _ = ginkgo.Describe("LWS Subslicing", func() {
 				}, utils.Timeout, utils.Interval).Should(gomega.Succeed())
 			})
 
+			ginkgo.By("Checking that all pods are created with the topology node selector", func() {
+				pods := &corev1.PodList{}
+				gomega.Eventually(func(g gomega.Gomega) {
+					g.Expect(k8sClient.List(ctx, pods, client.InNamespace(ns.Name))).To(gomega.Succeed())
+					g.Expect(pods.Items).Should(gomega.HaveLen(int(tc.size)))
+					for _, pod := range pods.Items {
+						g.Expect(pod.Spec.NodeSelector).To(gomega.HaveKeyWithValue(core.TPUTopologyAnnotation, tc.topology))
+					}
+				}, utils.LongTimeout, utils.Interval).Should(gomega.Succeed())
+			})
+
 			createdWorkload := &kueue.Workload{}
 			ginkgo.By("Waiting for Admission of the Workload", func() {
 				gomega.Eventually(func(g gomega.Gomega) {