Skip to content

Commit df20380

Browse files
lukebaumanncopybara-github
authored andcommitted
Add base PathwaysJobSet builder
PiperOrigin-RevId: 919912906
1 parent 3e12d60 commit df20380

3 files changed

Lines changed: 236 additions & 0 deletions

File tree

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# https://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# https://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
"""Pathways JobSet generator and builder (Skeleton)."""
13+
from typing import Any, Mapping
14+
from kubernetes import client
15+
16+
# Core constants.
17+
PATHWAYS_HEAD_JOB_NAME = "pathways-head"
18+
PATHWAYS_WORKER_JOB_NAME = "pathways-worker"
19+
20+
MACHINE_TYPE_TO_TPU_VERSION_MAP = {
21+
"tpu7x-standard-4t": "tpu7x",
22+
"tpu7x": "tpu7x",
23+
"ct6e-standard-4t": "tpuv6e",
24+
"v6e": "tpuv6e",
25+
"ct6e-standard-8t": "tpuv6e1t",
26+
"ct5p-hightpu-4t": "tpuv5",
27+
"v5p": "tpuv5",
28+
"ct5lp-hightpu-4t": "tpuv5e",
29+
"v5e": "tpuv5e",
30+
"ct5lp-hightpu-8t": "tpuv5e1t",
31+
"ct4p-hightpu-4t": "tpuv4",
32+
"v4": "tpuv4",
33+
}
34+
35+
MACHINE_TYPE_TO_GKE_ACCELERATOR_TYPE_MAP = {
36+
"tpu7x-standard-4t": "tpu7x",
37+
"tpu7x": "tpu7x",
38+
"ct6e-standard-4t": "tpu-v6e-slice",
39+
"v6e": "tpu-v6e-slice",
40+
"ct6e-standard-8t": "tpu-v6e-slice",
41+
"ct5p-hightpu-4t": "tpu-v5p-slice",
42+
"v5p": "tpu-v5p-slice",
43+
"ct5lp-hightpu-4t": "tpu-v5-lite-podslice",
44+
"v5e": "tpu-v5-lite-podslice",
45+
"ct5lp-hightpu-8t": "tpu-v5-lite-podslice",
46+
"ct4p-hightpu-4t": "tpu-v4-podslice",
47+
"v4": "tpu-v4-podslice",
48+
}
49+
50+
51+
class PathwaysJobSet:
52+
"""Generates JobSet configuration for Pathways (Skeleton)."""
53+
54+
def __init__(
55+
self,
56+
name: str,
57+
namespace: str,
58+
tpu_type: str,
59+
num_slices: int,
60+
user_pod_template: Mapping[str, Any] | None = None,
61+
max_restarts: int = 0,
62+
jobset_api_version: str = "v1alpha2",
63+
labels: Mapping[str, str] | None = None,
64+
annotations: Mapping[str, str] | None = None,
65+
):
66+
"""Initializes the instance.
67+
68+
Args:
69+
name: Name of the JobSet.
70+
namespace: Namespace of the JobSet.
71+
tpu_type: TPU type (e.g., "v5e").
72+
num_slices: Number of slices.
73+
user_pod_template: Optional user pod template for the head job.
74+
max_restarts: Maximum number of restarts for the JobSet.
75+
jobset_api_version: API version of JobSet.
76+
labels: Optional labels for the JobSet.
77+
annotations: Optional annotations for the JobSet.
78+
"""
79+
self._name = name
80+
self._namespace = namespace
81+
self._jobset_api_version = jobset_api_version
82+
self._max_restarts = max_restarts
83+
self._worker_replicas = num_slices
84+
self._labels = dict(labels) if labels else {}
85+
self._annotations = dict(annotations) if annotations else {}
86+
87+
tpu_version = MACHINE_TYPE_TO_TPU_VERSION_MAP.get(tpu_type.lower())
88+
if not tpu_version:
89+
raise ValueError(f"Unsupported TPU type: {tpu_type}")
90+
91+
# Build minimal head template (placeholder)
92+
self._head_job_template = self._build_minimal_job_template("head")
93+
94+
# Build minimal worker template (placeholder)
95+
self._worker_job_template = self._build_minimal_job_template("worker")
96+
97+
self._success_policy = None
98+
if user_pod_template:
99+
self._success_policy = {
100+
"operator": "All",
101+
"targetReplicatedJobs": [PATHWAYS_HEAD_JOB_NAME],
102+
}
103+
104+
def _build_minimal_job_template(self, role: str) -> client.V1JobTemplateSpec:
105+
"""Builds a minimal job template for a given role."""
106+
pod_spec = client.V1PodSpec(
107+
containers=[
108+
client.V1Container(name=f"placeholder-{role}", image="ubuntu")
109+
]
110+
)
111+
job_spec = client.V1JobSpec(
112+
template=client.V1PodTemplateSpec(
113+
metadata=client.V1ObjectMeta(labels={"role": role}), spec=pod_spec
114+
)
115+
)
116+
return client.V1JobTemplateSpec(spec=job_spec)
117+
118+
def _compile_config(self) -> dict[str, Any]:
119+
"""Compiles the JobSet configuration into a dictionary."""
120+
with client.ApiClient() as api_client:
121+
serialized_head = api_client.sanitize_for_serialization(
122+
self._head_job_template
123+
)
124+
serialized_worker = api_client.sanitize_for_serialization(
125+
self._worker_job_template
126+
)
127+
128+
replicated_jobs = [
129+
{
130+
"name": PATHWAYS_HEAD_JOB_NAME,
131+
"replicas": 1,
132+
"template": serialized_head,
133+
},
134+
{
135+
"name": PATHWAYS_WORKER_JOB_NAME,
136+
"replicas": self._worker_replicas,
137+
"template": serialized_worker,
138+
},
139+
]
140+
141+
jobset_config = {
142+
"apiVersion": f"jobset.sigs.k8s.io/{self._jobset_api_version}",
143+
"kind": "JobSet",
144+
"metadata": {
145+
"name": self._name,
146+
"namespace": self._namespace,
147+
},
148+
"spec": {
149+
"failurePolicy": {"maxRestarts": self._max_restarts},
150+
"replicatedJobs": replicated_jobs,
151+
},
152+
}
153+
if self._labels:
154+
jobset_config["metadata"]["labels"] = self._labels
155+
if self._annotations:
156+
jobset_config["metadata"]["annotations"] = self._annotations
157+
if self._success_policy:
158+
jobset_config["spec"]["successPolicy"] = self._success_policy
159+
160+
return jobset_config
161+
162+
def to_dict(self) -> dict[str, Any]:
163+
"""Returns the JobSet configuration as a dictionary."""
164+
return self._compile_config()
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from absl.testing import absltest
2+
from absl.testing import parameterized
3+
from pathwaysutils.experimental.gke import jobset
4+
5+
6+
class PathwaysJobSetTest(parameterized.TestCase):
7+
8+
def test_invalid_tpu_type(self):
9+
with self.assertRaisesRegex(ValueError, "Unsupported TPU type"):
10+
jobset.PathwaysJobSet(
11+
name="test-jobset",
12+
namespace="default",
13+
tpu_type="invalid-tpu",
14+
num_slices=1,
15+
)
16+
17+
def test_basic_jobset_structure(self):
18+
js = jobset.PathwaysJobSet(
19+
name="test-jobset",
20+
namespace="default",
21+
tpu_type="v5e",
22+
num_slices=2,
23+
labels={"app": "pathways"},
24+
annotations={"example.com/annotation": "value"},
25+
)
26+
config = js.to_dict()
27+
28+
self.assertEqual(config["apiVersion"], "jobset.sigs.k8s.io/v1alpha2")
29+
self.assertEqual(config["kind"], "JobSet")
30+
self.assertEqual(config["metadata"]["name"], "test-jobset")
31+
self.assertEqual(config["metadata"]["namespace"], "default")
32+
self.assertEqual(config["metadata"]["labels"]["app"], "pathways")
33+
self.assertEqual(
34+
config["metadata"]["annotations"]["example.com/annotation"], "value"
35+
)
36+
37+
self.assertEqual(config["spec"]["failurePolicy"]["maxRestarts"], 0)
38+
39+
replicated_jobs = config["spec"]["replicatedJobs"]
40+
self.assertLen(replicated_jobs, 2)
41+
42+
head_job = replicated_jobs[0]
43+
self.assertEqual(head_job["name"], "pathways-head")
44+
self.assertEqual(head_job["replicas"], 1)
45+
46+
# In K8s API models, V1JobTemplateSpec -> V1JobSpec -> V1PodTemplateSpec
47+
# -> V1PodSpec. When serialized, they match this structure.
48+
head_pod_spec = head_job["template"]["spec"]["template"]["spec"]
49+
self.assertEqual(head_pod_spec["containers"][0]["name"], "placeholder-head")
50+
51+
worker_job = replicated_jobs[1]
52+
self.assertEqual(worker_job["name"], "pathways-worker")
53+
self.assertEqual(worker_job["replicas"], 2)
54+
worker_pod_spec = worker_job["template"]["spec"]["template"]["spec"]
55+
self.assertEqual(
56+
worker_pod_spec["containers"][0]["name"], "placeholder-worker"
57+
)
58+
59+
60+
if __name__ == "__main__":
61+
absltest.main()

0 commit comments

Comments
 (0)