Skip to content

Commit 1b134f7

Browse files
lukebaumanncopybara-github
authored andcommitted
Add base PathwaysJobSet builder
PiperOrigin-RevId: 919912906
1 parent 9e2bd20 commit 1b134f7

3 files changed

Lines changed: 204 additions & 0 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Empty __init__.py
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
"""Pathways JobSet generator and builder (Skeleton)."""
2+
3+
import math
4+
from typing import Any, Optional
5+
from kubernetes import client
6+
import yaml
7+
8+
# Core constants
9+
PATHWAYS_HEAD_JOB_NAME = "pathways-head"
10+
PATHWAYS_WORKER_JOB_NAME = "pathways-worker"
11+
12+
MACHINE_TYPE_TO_TPU_VERSION_MAP = {
13+
"tpu7x-standard-4t": "tpu7x",
14+
"tpu7x": "tpu7x",
15+
"ct6e-standard-4t": "tpuv6e",
16+
"v6e": "tpuv6e",
17+
"ct6e-standard-8t": "tpuv6e1t",
18+
"ct5p-hightpu-4t": "tpuv5",
19+
"v5p": "tpuv5",
20+
"ct5lp-hightpu-4t": "tpuv5e",
21+
"v5e": "tpuv5e",
22+
"ct5lp-hightpu-8t": "tpuv5e1t",
23+
"ct4p-hightpu-4t": "tpuv4",
24+
"v4": "tpuv4",
25+
}
26+
27+
MACHINE_TYPE_TO_GKE_ACCELERATOR_TYPE_MAP = {
28+
"tpu7x-standard-4t": "tpu7x",
29+
"tpu7x": "tpu7x",
30+
"ct6e-standard-4t": "tpu-v6e-slice",
31+
"v6e": "tpu-v6e-slice",
32+
"ct6e-standard-8t": "tpu-v6e-slice",
33+
"ct5p-hightpu-4t": "tpu-v5p-slice",
34+
"v5p": "tpu-v5p-slice",
35+
"ct5lp-hightpu-4t": "tpu-v5-lite-podslice",
36+
"v5e": "tpu-v5-lite-podslice",
37+
"ct5lp-hightpu-8t": "tpu-v5-lite-podslice",
38+
"ct4p-hightpu-4t": "tpu-v4-podslice",
39+
"v4": "tpu-v4-podslice",
40+
}
41+
42+
43+
class PathwaysJobSet:
44+
"""Generates JobSet configuration for Pathways (Skeleton)."""
45+
46+
def __init__(
47+
self,
48+
name: str,
49+
namespace: str,
50+
pathways_dir: str,
51+
tpu_type: str,
52+
topology: str,
53+
num_slices: int,
54+
user_pod_template: Optional[dict[str, Any]] = None,
55+
max_restarts: int = 0,
56+
jobset_api_version: str = "v1alpha2",
57+
labels: Optional[dict[str, str]] = None,
58+
annotations: Optional[dict[str, str]] = None,
59+
):
60+
"""Initializes the JobSet configuration using K8s API models."""
61+
self._name = name
62+
self._namespace = namespace
63+
self._jobset_api_version = jobset_api_version
64+
self._max_restarts = max_restarts
65+
self._worker_replicas = num_slices
66+
self._labels = labels or {}
67+
self._annotations = annotations or {}
68+
69+
tpu_version = MACHINE_TYPE_TO_TPU_VERSION_MAP.get(tpu_type.lower())
70+
if not tpu_version:
71+
raise ValueError(f"Unsupported TPU type: {tpu_type}")
72+
73+
# Build minimal head template (placeholder)
74+
self._head_job_template = self._build_minimal_job_template("head")
75+
76+
# Build minimal worker template (placeholder)
77+
self._worker_job_template = self._build_minimal_job_template("worker")
78+
79+
self._success_policy = None
80+
if user_pod_template:
81+
self._success_policy = {
82+
"operator": "All",
83+
"targetReplicatedJobs": [PATHWAYS_HEAD_JOB_NAME],
84+
}
85+
86+
def _build_minimal_job_template(self, role: str) -> client.V1JobTemplateSpec:
87+
pod_spec = client.V1PodSpec(
88+
containers=[client.V1Container(name=f"placeholder-{role}", image="ubuntu")]
89+
)
90+
job_spec = client.V1JobSpec(
91+
template=client.V1PodTemplateSpec(
92+
metadata=client.V1ObjectMeta(labels={"role": role}),
93+
spec=pod_spec
94+
)
95+
)
96+
return client.V1JobTemplateSpec(spec=job_spec)
97+
98+
def _compile_config(self) -> dict[str, Any]:
99+
api_client = client.ApiClient()
100+
serialized_head = api_client.sanitize_for_serialization(
101+
self._head_job_template
102+
)
103+
serialized_worker = api_client.sanitize_for_serialization(
104+
self._worker_job_template
105+
)
106+
107+
replicated_jobs = [
108+
{
109+
"name": PATHWAYS_HEAD_JOB_NAME,
110+
"replicas": 1,
111+
"template": serialized_head,
112+
},
113+
{
114+
"name": PATHWAYS_WORKER_JOB_NAME,
115+
"replicas": self._worker_replicas,
116+
"template": serialized_worker,
117+
},
118+
]
119+
120+
jobset_config = {
121+
"apiVersion": f"jobset.sigs.k8s.io/{self._jobset_api_version}",
122+
"kind": "JobSet",
123+
"metadata": {
124+
"name": self._name,
125+
"namespace": self._namespace,
126+
},
127+
"spec": {
128+
"failurePolicy": {"maxRestarts": self._max_restarts},
129+
"replicatedJobs": replicated_jobs,
130+
},
131+
}
132+
if self._labels:
133+
jobset_config["metadata"]["labels"] = self._labels
134+
if self._annotations:
135+
jobset_config["metadata"]["annotations"] = self._annotations
136+
if self._success_policy:
137+
jobset_config["spec"]["successPolicy"] = self._success_policy
138+
139+
return jobset_config
140+
141+
def to_dict(self) -> dict[str, Any]:
142+
return self._compile_config()
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from absl.testing import absltest
2+
from absl.testing import parameterized
3+
from pathwaysutils.experimental.gke import jobset
4+
5+
class PathwaysJobSetTest(parameterized.TestCase):
6+
7+
def test_invalid_tpu_type(self):
8+
with self.assertRaisesRegex(ValueError, "Unsupported TPU type"):
9+
jobset.PathwaysJobSet(
10+
name="test-jobset",
11+
namespace="default",
12+
pathways_dir="gs://test-bucket",
13+
tpu_type="invalid-tpu",
14+
topology="4x4",
15+
num_slices=1,
16+
)
17+
18+
def test_basic_jobset_structure(self):
19+
js = jobset.PathwaysJobSet(
20+
name="test-jobset",
21+
namespace="default",
22+
pathways_dir="gs://test-bucket",
23+
tpu_type="v5e",
24+
topology="4x8",
25+
num_slices=2,
26+
labels={"app": "pathways"},
27+
annotations={"example.com/annotation": "value"},
28+
)
29+
config = js.to_dict()
30+
31+
self.assertEqual(config["apiVersion"], "jobset.sigs.k8s.io/v1alpha2")
32+
self.assertEqual(config["kind"], "JobSet")
33+
self.assertEqual(config["metadata"]["name"], "test-jobset")
34+
self.assertEqual(config["metadata"]["namespace"], "default")
35+
self.assertEqual(config["metadata"]["labels"]["app"], "pathways")
36+
self.assertEqual(
37+
config["metadata"]["annotations"]["example.com/annotation"], "value"
38+
)
39+
40+
self.assertEqual(config["spec"]["failurePolicy"]["maxRestarts"], 0)
41+
42+
replicated_jobs = config["spec"]["replicatedJobs"]
43+
self.assertEqual(len(replicated_jobs), 2)
44+
45+
head_job = replicated_jobs[0]
46+
self.assertEqual(head_job["name"], "pathways-head")
47+
self.assertEqual(head_job["replicas"], 1)
48+
49+
# In K8s API models, V1JobTemplateSpec -> V1JobSpec -> V1PodTemplateSpec -> V1PodSpec
50+
# When serialized, they match this structure.
51+
head_pod_spec = head_job["template"]["spec"]["template"]["spec"]
52+
self.assertEqual(head_pod_spec["containers"][0]["name"], "placeholder-head")
53+
54+
worker_job = replicated_jobs[1]
55+
self.assertEqual(worker_job["name"], "pathways-worker")
56+
self.assertEqual(worker_job["replicas"], 2)
57+
worker_pod_spec = worker_job["template"]["spec"]["template"]["spec"]
58+
self.assertEqual(worker_pod_spec["containers"][0]["name"], "placeholder-worker")
59+
60+
if __name__ == "__main__":
61+
absltest.main()

0 commit comments

Comments
 (0)