Commit 981e6f9

ko3n1g and claude committed
feat(kubeflow): support pod-template annotations/labels (podTemplateOverrides metadata)
The executor's existing 'annotations' land on the TrainJob object. GKE multi-network attach (networking.gke.io/interfaces, for GPUDirect-RDMA/gIB) is read off the trainer POD, not the TrainJob. Add pod_annotations (and pod_labels) that flow into podTemplateOverrides[].metadata, which the Kubeflow Trainer v2 CRD supports.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: oliver könig <okoenig@nvidia.com>
1 parent 213ba39 commit 981e6f9

1 file changed

Lines changed: 17 additions & 4 deletions

nemo_run/core/execution/kubeflow.py

@@ -93,6 +93,11 @@ class KubeflowExecutor(Executor):
     volumes: list[dict[str, Any]] = field(default_factory=list)
     labels: dict[str, Any] = field(default_factory=dict)
     annotations: dict[str, Any] = field(default_factory=dict)
+    # pod_annotations land on the trainer POD template (podTemplateOverrides[].metadata),
+    # not the TrainJob object — needed for e.g. GKE multi-network attach
+    # (networking.gke.io/interfaces) which is read off the pod, not the TrainJob.
+    pod_annotations: dict[str, Any] = field(default_factory=dict)
+    pod_labels: dict[str, Any] = field(default_factory=dict)
     tolerations: list[dict[str, Any]] = field(default_factory=list)
     affinity: dict[str, Any] = field(default_factory=dict)
     # env_list accepts full env var dicts (e.g. valueFrom/secretKeyRef).
@@ -267,10 +272,18 @@ def get_job_body(self, name: str, command: list[str]) -> dict:
             "runtimeRef": {"name": self.runtime_ref},
             "trainer": trainer,
         }
-        if pod_spec_override:
-            spec["podTemplateOverrides"] = [
-                {"targetJobs": [{"name": "node"}], "spec": pod_spec_override}
-            ]
+        if pod_spec_override or self.pod_annotations or self.pod_labels:
+            override_entry: dict[str, Any] = {"targetJobs": [{"name": "node"}]}
+            if pod_spec_override:
+                override_entry["spec"] = pod_spec_override
+            pod_meta: dict[str, Any] = {}
+            if self.pod_labels:
+                pod_meta["labels"] = self.pod_labels
+            if self.pod_annotations:
+                pod_meta["annotations"] = self.pod_annotations
+            if pod_meta:
+                override_entry["metadata"] = pod_meta
+            spec["podTemplateOverrides"] = [override_entry]
         spec.update(self.spec_kwargs)
 
         metadata: dict[str, Any] = {"name": name, "namespace": self.namespace}
