99# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010# See the License for the specific language governing permissions and
1111# limitations under the License.
12- """Pathways JobSet generator and builder (Skeleton)."""
12+ """Pathways JobSet generator and builder (Head Job Config)."""
13+
14+ import json
15+ import logging
1316from typing import Any , Mapping
1417from kubernetes import client
1518
19+ # GKE sidecar containers restartPolicy compatibility placeholder.
20+
21+ _logger = logging .getLogger (__name__ )
22+
1623# Core constants.
1724PATHWAYS_HEAD_JOB_NAME = "pathways-head"
1825PATHWAYS_WORKER_JOB_NAME = "pathways-worker"
1926
27+ DEFAULT_PATHWAYS_RM_AND_WORKER_IMAGE = (
28+ "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server"
29+ )
30+ DEFAULT_PATHWAYS_PROXY_IMAGE = (
31+ "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server"
32+ )
33+
34+ PATHWAYS_PROXY_PORT = 29000
35+ PATHWAYS_RM_PORT = 29001
36+
2037MACHINE_TYPE_TO_TPU_VERSION_MAP = {
2138 "tpu7x-standard-4t" : "tpu7x" ,
2239 "tpu7x" : "tpu7x" ,
4865}
4966
5067
68+ def _deserialize_dict (
69+ api_client : client .ApiClient , data_dict : Mapping [str , Any ], klass : Any
70+ ) -> Any :
71+ class FakeResponse :
72+
73+ def __init__ (self , data ):
74+ self .data = data
75+
76+ return api_client .deserialize (FakeResponse (json .dumps (data_dict )), klass )
77+
78+
5179class PathwaysJobSet :
52- """Generates JobSet configuration for Pathways (Skeleton )."""
80+ """Generates JobSet configuration for Pathways (with Head Job Config )."""
5381
5482 def __init__ (
5583 self ,
5684 name : str ,
5785 namespace : str ,
86+ pathways_dir : str ,
5887 tpu_type : str ,
88+ topology : str ,
5989 num_slices : int ,
6090 user_pod_template : Mapping [str , Any ] | None = None ,
91+ main_container_name : str = "main" ,
6192 max_restarts : int = 0 ,
93+ pathways_version : str = "latest" ,
6294 jobset_api_version : str = "v1alpha2" ,
95+ elastic_slices : int = 0 ,
6396 labels : Mapping [str , str ] | None = None ,
6497 annotations : Mapping [str , str ] | None = None ,
6598 ):
@@ -68,11 +101,16 @@ def __init__(
68101 Args:
69102 name: Name of the JobSet.
70103 namespace: Namespace of the JobSet.
104+ pathways_dir: GCS path for Pathways scratch space.
71105 tpu_type: TPU type (e.g., "v5e").
106+ topology: TPU topology (e.g., "2x2").
72107 num_slices: Number of slices.
73108 user_pod_template: Optional user pod template for the head job.
109+ main_container_name: Name of the main container in user_pod_template.
74110 max_restarts: Maximum number of restarts for the JobSet.
111+ pathways_version: Version tag for Pathways images.
75112 jobset_api_version: API version of JobSet.
113+ elastic_slices: Number of elastic slices.
76114 labels: Optional labels for the JobSet.
77115 annotations: Optional annotations for the JobSet.
78116 """
@@ -88,8 +126,19 @@ def __init__(
88126 if not tpu_version :
89127 raise ValueError (f"Unsupported TPU type: { tpu_type } " )
90128
91- # Build minimal head template (placeholder)
92- self ._head_job_template = self ._build_minimal_job_template ("head" )
129+ instance_type = f"{ tpu_version } :{ topology } "
130+ image_tag = pathways_version
131+
132+ # Build head template.
133+ self ._head_job_template = self ._build_head_job_template (
134+ pathways_dir = pathways_dir ,
135+ num_slices = num_slices ,
136+ instance_type = instance_type ,
137+ image_tag = image_tag ,
138+ user_pod_template = user_pod_template ,
139+ main_container_name = main_container_name ,
140+ elastic_slices = elastic_slices ,
141+ )
93142
94143 # Build minimal worker template (placeholder)
95144 self ._worker_job_template = self ._build_minimal_job_template ("worker" )
@@ -115,6 +164,207 @@ def _build_minimal_job_template(self, role: str) -> client.V1JobTemplateSpec:
115164 )
116165 return client .V1JobTemplateSpec (spec = job_spec )
117166
167+ def _build_head_job_template (
168+ self ,
169+ pathways_dir : str ,
170+ num_slices : int ,
171+ instance_type : str ,
172+ image_tag : str ,
173+ user_pod_template : Mapping [str , Any ] | None ,
174+ main_container_name : str ,
175+ elastic_slices : int ,
176+ ) -> client .V1JobTemplateSpec :
177+ """Builds the head job template for the JobSet.
178+
179+ Args:
180+ pathways_dir: GCS path for Pathways scratch space.
181+ num_slices: Number of slices.
182+ instance_type: TPU instance type (e.g., "tpuv5:2x2").
183+ image_tag: Version tag for Pathways images.
184+ user_pod_template: Optional user pod template for the head job.
185+ main_container_name: Name of the main container in user_pod_template.
186+ elastic_slices: Number of elastic slices.
187+
188+ Returns:
189+ The head job template.
190+ """
191+ rm_image = f"{ DEFAULT_PATHWAYS_RM_AND_WORKER_IMAGE } :{ image_tag } "
192+ proxy_image = f"{ DEFAULT_PATHWAYS_PROXY_IMAGE } :{ image_tag } "
193+
194+ rm_args = [
195+ f"--server_port={ PATHWAYS_RM_PORT } " ,
196+ f"--gcs_scratch_location={ pathways_dir } " ,
197+ "--node_type=resource_manager" ,
198+ f"--instance_count={ num_slices } " ,
199+ f"--instance_type={ instance_type } " ,
200+ ]
201+ rm_env = [
202+ client .V1EnvVar (
203+ name = "REPLICATED_JOB_NAME" ,
204+ value_from = client .V1EnvVarSource (
205+ field_ref = client .V1ObjectFieldSelector (
206+ field_path = "metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']"
207+ )
208+ ),
209+ ),
210+ client .V1EnvVar (
211+ name = "JOBSET_NAME" ,
212+ value_from = client .V1EnvVarSource (
213+ field_ref = client .V1ObjectFieldSelector (
214+ field_path = (
215+ "metadata.annotations['jobset.sigs.k8s.io/jobset-name']"
216+ )
217+ )
218+ ),
219+ ),
220+ client .V1EnvVar (
221+ name = "HOST_ADDRESS" ,
222+ value_from = client .V1EnvVarSource (
223+ field_ref = client .V1ObjectFieldSelector (
224+ field_path = (
225+ "metadata.labels['jobset.sigs.k8s.io/coordinator']"
226+ )
227+ )
228+ ),
229+ ),
230+ client .V1EnvVar (name = "TPU_SKIP_MDS_QUERY" , value = "true" ),
231+ ]
232+ rm_container = client .V1Container (
233+ name = "pathways-rm" ,
234+ image = rm_image ,
235+ image_pull_policy = "Always" ,
236+ args = rm_args ,
237+ env = rm_env ,
238+ ports = [
239+ client .V1ContainerPort (
240+ container_port = PATHWAYS_RM_PORT , protocol = "TCP"
241+ ),
242+ client .V1ContainerPort (container_port = 29002 , protocol = "TCP" ),
243+ ],
244+ resources = client .V1ResourceRequirements (
245+ limits = {"cpu" : "8" , "memory" : "32G" }
246+ ),
247+ )
248+
249+ proxy_args = [
250+ f"--server_port={ PATHWAYS_PROXY_PORT } " ,
251+ f"--resource_manager_address=$(PATHWAYS_HEAD):{ PATHWAYS_RM_PORT } " ,
252+ f"--gcs_scratch_location={ pathways_dir } " ,
253+ ]
254+ if elastic_slices > 0 :
255+ proxy_args .append (f"--num_elastic_slices={ elastic_slices } " )
256+
257+ proxy_env = [
258+ client .V1EnvVar (
259+ name = "PATHWAYS_HEAD" ,
260+ value_from = client .V1EnvVarSource (
261+ field_ref = client .V1ObjectFieldSelector (
262+ field_path = (
263+ "metadata.labels['jobset.sigs.k8s.io/coordinator']"
264+ )
265+ )
266+ ),
267+ )
268+ ]
269+ proxy_container = client .V1Container (
270+ name = "pathways-proxy" ,
271+ image = proxy_image ,
272+ image_pull_policy = "Always" ,
273+ args = proxy_args ,
274+ env = proxy_env ,
275+ ports = [
276+ client .V1ContainerPort (
277+ container_port = PATHWAYS_PROXY_PORT , protocol = "TCP"
278+ )
279+ ],
280+ resources = client .V1ResourceRequirements (
281+ limits = {"cpu" : "16" , "memory" : "100G" }
282+ ),
283+ )
284+
285+ api_client = client .ApiClient ()
286+
287+ if user_pod_template :
288+ user_template_obj = _deserialize_dict (
289+ api_client , user_pod_template , client .V1PodTemplateSpec
290+ )
291+ head_pod_spec = user_template_obj .spec
292+ head_pod_spec .host_network = True
293+ head_pod_spec .dns_policy = "ClusterFirstWithHostNet"
294+
295+ rm_container .restart_policy = "Always"
296+ proxy_container .restart_policy = "Always"
297+
298+ init_containers = head_pod_spec .init_containers or []
299+ init_containers .extend ([rm_container , proxy_container ])
300+ head_pod_spec .init_containers = init_containers
301+
302+ # Inject JAX env vars into main container.
303+ jax_env = [
304+ client .V1EnvVar (
305+ name = "PATHWAYS_HEAD" ,
306+ value_from = client .V1EnvVarSource (
307+ field_ref = client .V1ObjectFieldSelector (
308+ field_path = (
309+ "metadata.labels['jobset.sigs.k8s.io/coordinator']"
310+ )
311+ )
312+ ),
313+ ),
314+ client .V1EnvVar (name = "JAX_PLATFORMS" , value = "proxy" ),
315+ client .V1EnvVar (name = "XCLOUD_ENVIRONMENT" , value = "GCP" ),
316+ client .V1EnvVar (
317+ name = "JAX_BACKEND_TARGET" ,
318+ value = f"grpc://$(PATHWAYS_HEAD):{ PATHWAYS_PROXY_PORT } " ,
319+ ),
320+ ]
321+ containers = head_pod_spec .containers or []
322+ for c in containers :
323+ if c .name == main_container_name :
324+ env = c .env or []
325+ env .extend (jax_env )
326+ c .env = env
327+ break
328+ head_pod_spec .containers = containers
329+
330+ annotations = user_pod_template .get ("metadata" , {}).get ("annotations" , {})
331+ labels = user_pod_template .get ("metadata" , {}).get ("labels" , {})
332+ else :
333+ # Headless mode.
334+ head_pod_spec = client .V1PodSpec (
335+ host_network = True ,
336+ dns_policy = "ClusterFirstWithHostNet" ,
337+ containers = [rm_container , proxy_container ],
338+ )
339+ annotations = {}
340+ labels = {}
341+
342+ if not head_pod_spec .restart_policy :
343+ head_pod_spec .restart_policy = "Never"
344+
345+ # Default annotations
346+ job_annotations = {
347+ "alpha.jobset.sigs.k8s.io/exclusive-topology" : "kubernetes.io/hostname"
348+ }
349+ job_annotations .update (annotations )
350+
351+ head_job_template = client .V1JobTemplateSpec (
352+ metadata = client .V1ObjectMeta (annotations = job_annotations ),
353+ spec = client .V1JobSpec (
354+ backoff_limit = 0 ,
355+ completion_mode = "Indexed" ,
356+ completions = 1 ,
357+ parallelism = 1 ,
358+ template = client .V1PodTemplateSpec (
359+ metadata = client .V1ObjectMeta (
360+ annotations = job_annotations , labels = labels
361+ ),
362+ spec = head_pod_spec ,
363+ ),
364+ ),
365+ )
366+ return head_job_template
367+
118368 def _compile_config (self ) -> dict [str , Any ]:
119369 """Compiles the JobSet configuration into a dictionary."""
120370 with client .ApiClient () as api_client :
0 commit comments