Skip to content

Commit 87b82b6

Browse files
lukebaumann authored and copybara-github committed
Add Head Job configuration to PathwaysJobSet
PiperOrigin-RevId: 926777681
1 parent ba762b6 commit 87b82b6

2 files changed

Lines changed: 329 additions & 29 deletions

File tree

pathwaysutils/experimental/gke/jobset.py

Lines changed: 254 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,31 @@
99
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
12-
"""Pathways JobSet generator and builder (Skeleton)."""
12+
"""Pathways JobSet generator and builder (Head Job Config)."""
13+
14+
import json
15+
import logging
1316
from typing import Any, Mapping
1417
from kubernetes import client
1518

19+
# GKE sidecar containers restartPolicy compatibility placeholder.
20+
21+
_logger = logging.getLogger(__name__)
22+
1623
# Core constants.
# NOTE(review): presumably the replicated-job names used for the head and
# worker jobs inside the JobSet; their use sites are outside this view.
PATHWAYS_HEAD_JOB_NAME = "pathways-head"
PATHWAYS_WORKER_JOB_NAME = "pathways-worker"

# Container image repositories, without a tag. The head job builder appends
# ":<image_tag>" to these when it creates the resource manager and proxy
# containers.
DEFAULT_PATHWAYS_RM_AND_WORKER_IMAGE = (
    "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server"
)
DEFAULT_PATHWAYS_PROXY_IMAGE = (
    "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server"
)

# Ports passed as --server_port to the proxy and resource manager containers
# and exposed as their container ports.
PATHWAYS_PROXY_PORT = 29000
PATHWAYS_RM_PORT = 29001
36+
2037
MACHINE_TYPE_TO_TPU_VERSION_MAP = {
2138
"tpu7x-standard-4t": "tpu7x",
2239
"tpu7x": "tpu7x",
@@ -48,18 +65,34 @@
4865
}
4966

5067

68+
def _deserialize_dict(
    api_client: client.ApiClient, data_dict: Mapping[str, Any], klass: Any
) -> Any:
  """Converts a plain mapping into a Kubernetes client model object.

  `api_client.deserialize` is handed an object exposing a `data` attribute,
  so the mapping is JSON-encoded and wrapped in a minimal stand-in that
  carries only that attribute.

  Args:
    api_client: Kubernetes API client whose `deserialize` method is used.
    data_dict: JSON-serializable mapping describing the object.
    klass: Target model class (or type name) passed through to `deserialize`.

  Returns:
    Whatever `api_client.deserialize` produces for `klass`.
  """

  class _JsonPayload:
    """Stand-in response object holding the JSON text in `data`."""

    def __init__(self, payload):
      self.data = payload

  encoded = json.dumps(data_dict)
  wrapped = _JsonPayload(encoded)
  return api_client.deserialize(wrapped, klass)
77+
78+
5179
class PathwaysJobSet:
52-
"""Generates JobSet configuration for Pathways (Skeleton)."""
80+
"""Generates JobSet configuration for Pathways (with Head Job Config)."""
5381

5482
def __init__(
5583
self,
5684
name: str,
5785
namespace: str,
86+
pathways_dir: str,
5887
tpu_type: str,
88+
topology: str,
5989
num_slices: int,
6090
user_pod_template: Mapping[str, Any] | None = None,
91+
main_container_name: str = "main",
6192
max_restarts: int = 0,
93+
pathways_version: str = "latest",
6294
jobset_api_version: str = "v1alpha2",
95+
elastic_slices: int = 0,
6396
labels: Mapping[str, str] | None = None,
6497
annotations: Mapping[str, str] | None = None,
6598
):
@@ -68,11 +101,16 @@ def __init__(
68101
Args:
69102
name: Name of the JobSet.
70103
namespace: Namespace of the JobSet.
104+
pathways_dir: GCS path for Pathways scratch space.
71105
tpu_type: TPU type (e.g., "v5e").
106+
topology: TPU topology (e.g., "2x2").
72107
num_slices: Number of slices.
73108
user_pod_template: Optional user pod template for the head job.
109+
main_container_name: Name of the main container in user_pod_template.
74110
max_restarts: Maximum number of restarts for the JobSet.
111+
pathways_version: Version tag for Pathways images.
75112
jobset_api_version: API version of JobSet.
113+
elastic_slices: Number of elastic slices.
76114
labels: Optional labels for the JobSet.
77115
annotations: Optional annotations for the JobSet.
78116
"""
@@ -88,8 +126,19 @@ def __init__(
88126
if not tpu_version:
89127
raise ValueError(f"Unsupported TPU type: {tpu_type}")
90128

91-
# Build minimal head template (placeholder)
92-
self._head_job_template = self._build_minimal_job_template("head")
129+
instance_type = f"{tpu_version}:{topology}"
130+
image_tag = pathways_version
131+
132+
# Build head template.
133+
self._head_job_template = self._build_head_job_template(
134+
pathways_dir=pathways_dir,
135+
num_slices=num_slices,
136+
instance_type=instance_type,
137+
image_tag=image_tag,
138+
user_pod_template=user_pod_template,
139+
main_container_name=main_container_name,
140+
elastic_slices=elastic_slices,
141+
)
93142

94143
# Build minimal worker template (placeholder)
95144
self._worker_job_template = self._build_minimal_job_template("worker")
@@ -115,6 +164,207 @@ def _build_minimal_job_template(self, role: str) -> client.V1JobTemplateSpec:
115164
)
116165
return client.V1JobTemplateSpec(spec=job_spec)
117166

167+
def _build_head_job_template(
    self,
    pathways_dir: str,
    num_slices: int,
    instance_type: str,
    image_tag: str,
    user_pod_template: Mapping[str, Any] | None,
    main_container_name: str,
    elastic_slices: int,
) -> client.V1JobTemplateSpec:
  """Builds the head job template for the JobSet.

  The head pod always carries a Pathways resource manager container and a
  Pathways proxy container. When `user_pod_template` is given, both are
  appended to the user's init containers with container-level
  `restart_policy="Always"`, and the Pathways/JAX environment variables are
  appended to the container named `main_container_name`. Without a user
  template ("headless" mode) the two Pathways containers are the pod's only
  containers.

  Args:
    pathways_dir: GCS path for Pathways scratch space.
    num_slices: Number of slices.
    instance_type: TPU instance type (e.g., "tpuv5:2x2").
    image_tag: Version tag for Pathways images.
    user_pod_template: Optional user pod template for the head job.
    main_container_name: Name of the main container in user_pod_template.
    elastic_slices: Number of elastic slices. The proxy only receives
      `--num_elastic_slices` when this is greater than zero.

  Returns:
    The head job template.

  Raises:
    ValueError: If `user_pod_template` is provided but has no pod spec.
  """
  coordinator_label_path = "metadata.labels['jobset.sigs.k8s.io/coordinator']"

  def field_env(name: str, field_path: str) -> client.V1EnvVar:
    """Returns an env var whose value is read from a pod field reference."""
    return client.V1EnvVar(
        name=name,
        value_from=client.V1EnvVarSource(
            field_ref=client.V1ObjectFieldSelector(field_path=field_path)
        ),
    )

  rm_container = client.V1Container(
      name="pathways-rm",
      image=f"{DEFAULT_PATHWAYS_RM_AND_WORKER_IMAGE}:{image_tag}",
      image_pull_policy="Always",
      args=[
          f"--server_port={PATHWAYS_RM_PORT}",
          f"--gcs_scratch_location={pathways_dir}",
          "--node_type=resource_manager",
          f"--instance_count={num_slices}",
          f"--instance_type={instance_type}",
      ],
      env=[
          field_env(
              "REPLICATED_JOB_NAME",
              "metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']",
          ),
          field_env(
              "JOBSET_NAME",
              "metadata.annotations['jobset.sigs.k8s.io/jobset-name']",
          ),
          field_env("HOST_ADDRESS", coordinator_label_path),
          client.V1EnvVar(name="TPU_SKIP_MDS_QUERY", value="true"),
      ],
      ports=[
          client.V1ContainerPort(
              container_port=PATHWAYS_RM_PORT, protocol="TCP"
          ),
          # NOTE(review): the purpose of this second port is not visible in
          # this file; kept as-is.
          client.V1ContainerPort(container_port=29002, protocol="TCP"),
      ],
      resources=client.V1ResourceRequirements(
          limits={"cpu": "8", "memory": "32G"}
      ),
  )

  proxy_args = [
      f"--server_port={PATHWAYS_PROXY_PORT}",
      f"--resource_manager_address=$(PATHWAYS_HEAD):{PATHWAYS_RM_PORT}",
      f"--gcs_scratch_location={pathways_dir}",
  ]
  if elastic_slices > 0:
    proxy_args.append(f"--num_elastic_slices={elastic_slices}")

  proxy_container = client.V1Container(
      name="pathways-proxy",
      image=f"{DEFAULT_PATHWAYS_PROXY_IMAGE}:{image_tag}",
      image_pull_policy="Always",
      args=proxy_args,
      # PATHWAYS_HEAD is referenced as $(PATHWAYS_HEAD) in proxy_args.
      env=[field_env("PATHWAYS_HEAD", coordinator_label_path)],
      ports=[
          client.V1ContainerPort(
              container_port=PATHWAYS_PROXY_PORT, protocol="TCP"
          )
      ],
      resources=client.V1ResourceRequirements(
          limits={"cpu": "16", "memory": "100G"}
      ),
  )

  if user_pod_template:
    # The client is only needed for deserialization; use it as a context
    # manager (as _compile_config does) so it is closed afterwards.
    with client.ApiClient() as api_client:
      user_template_obj = _deserialize_dict(
          api_client, user_pod_template, client.V1PodTemplateSpec
      )
    head_pod_spec = user_template_obj.spec
    if head_pod_spec is None:
      raise ValueError("user_pod_template must define a pod `spec`.")

    head_pod_spec.host_network = True
    head_pod_spec.dns_policy = "ClusterFirstWithHostNet"

    # Run the Pathways containers as restartable init containers next to the
    # user's containers.
    rm_container.restart_policy = "Always"
    proxy_container.restart_policy = "Always"
    head_pod_spec.init_containers = [
        *(head_pod_spec.init_containers or []),
        rm_container,
        proxy_container,
    ]

    # Inject JAX env vars into main container. PATHWAYS_HEAD must come
    # before JAX_BACKEND_TARGET because the latter references it via
    # $(PATHWAYS_HEAD).
    jax_env = [
        field_env("PATHWAYS_HEAD", coordinator_label_path),
        client.V1EnvVar(name="JAX_PLATFORMS", value="proxy"),
        client.V1EnvVar(name="XCLOUD_ENVIRONMENT", value="GCP"),
        client.V1EnvVar(
            name="JAX_BACKEND_TARGET",
            value=f"grpc://$(PATHWAYS_HEAD):{PATHWAYS_PROXY_PORT}",
        ),
    ]
    containers = head_pod_spec.containers or []
    for container in containers:
      if container.name == main_container_name:
        container.env = [*(container.env or []), *jax_env]
        break
    else:
      # Previously this case was silent, leaving the user container without
      # the proxy address.
      _logger.warning(
          "Container %r not found in user_pod_template; Pathways/JAX"
          " environment variables were not injected.",
          main_container_name,
      )
    head_pod_spec.containers = containers

    # Tolerate explicit nulls ("metadata: null", "annotations: null").
    user_metadata = user_pod_template.get("metadata") or {}
    annotations = user_metadata.get("annotations") or {}
    labels = user_metadata.get("labels") or {}
  else:
    # Headless mode.
    head_pod_spec = client.V1PodSpec(
        host_network=True,
        dns_policy="ClusterFirstWithHostNet",
        containers=[rm_container, proxy_container],
    )
    annotations = {}
    labels = {}

  if not head_pod_spec.restart_policy:
    head_pod_spec.restart_policy = "Never"

  # Default annotations; user-provided annotations win on key collisions.
  job_annotations = {
      "alpha.jobset.sigs.k8s.io/exclusive-topology": "kubernetes.io/hostname"
  }
  job_annotations.update(annotations)

  return client.V1JobTemplateSpec(
      metadata=client.V1ObjectMeta(annotations=job_annotations),
      spec=client.V1JobSpec(
          backoff_limit=0,
          completion_mode="Indexed",
          completions=1,
          parallelism=1,
          template=client.V1PodTemplateSpec(
              # Separate copy so later edits to the job metadata do not leak
              # into the pod metadata (and vice versa).
              metadata=client.V1ObjectMeta(
                  annotations=dict(job_annotations), labels=labels
              ),
              spec=head_pod_spec,
          ),
      ),
  )
367+
118368
def _compile_config(self) -> dict[str, Any]:
119369
"""Compiles the JobSet configuration into a dictionary."""
120370
with client.ApiClient() as api_client:

0 commit comments

Comments
 (0)