4848# There is no guarantee that this image will work with newer Jax releases.
4949# This image version extends GRPC timeout for long context models, based on jax-0.5.3-patch060625
5050# This image extends GRPC timeout for long context models.
51- _PATHWAYS_IMAGE_TAG = "disable_settings_20250701 "
51+ _PATHWAYS_IMAGE_TAG = "shm_proxy "
5252# The docker image used by pathways proxy container.
5353_PATHWAYS_PROXY_IMAGE = (
5454 f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:{ _PATHWAYS_IMAGE_TAG } "
@@ -107,7 +107,7 @@ def get_pathways_tpu_version(gke_machine_type: str) -> str:
107107
108108
109109def get_megascale_options (
110- xla_options : dict [str , Union [str , bool , int ]]
110+ xla_options : dict [str , Union [str , bool , int ]],
111111) -> dict [str , Union [str , bool , int ]]:
112112 """Filters XLA options for those pertaining to Megascale.
113113
@@ -122,7 +122,7 @@ def get_megascale_options(
122122
123123
124124def get_xla_options (
125- xla_options : dict [str , Union [str , bool , int ]]
125+ xla_options : dict [str , Union [str , bool , int ]],
126126) -> dict [str , Union [str , bool , int ]]:
127127 """Filters XLA options for those starting with 'xla_'.
128128
@@ -275,7 +275,11 @@ def _build_pathways_head_container(self) -> dict:
275275 # In Jax 0.6.2 and beyond this flag can be renamed to
276276 # IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS as well.
277277 self ._update_env_list (env_list , "TEST_UNDECLARED_OUTPUTS_DIR" , "true" )
278-
278+ # 1 byte
279+ self ._update_env_list (env_list , "IFRT_PROXY_LARGE_TRANSFER_THRESHOLD" , "1" )
280+ self ._update_env_list (
281+ env_list , "IFRT_PROXY_LARGE_TRANSFER_OPTIMIZATION_DIRECTORY" , "/tmp/ifrt_proxy"
282+ )
279283 env_list .append (
280284 {
281285 "name" : "HOST_ADDRESS" ,
@@ -315,10 +319,14 @@ def _build_pathways_head_container(self) -> dict:
315319 mem_req = f"{ self .config .pathways_head_mem } Gi"
316320 resources = {
317321 "requests" : {"cpu" : cpu_req , "memory" : mem_req },
318- "limits" : {"cpu" : cpu_req , "memory" : mem_req },
322+ # "limits": {"cpu": cpu_req, "memory": mem_req},
319323 }
320324 head_container ["resources" ] = resources
321325
326+ volume_mounts = head_container .get ("volumeMounts" , [])
327+ volume_mounts .append (dict (name = "shared-memory" , mountPath = "/tmp/ifrt_proxy" ))
328+ head_container ["volumeMounts" ] = volume_mounts
329+
322330 return head_container
323331
324332 def _build_pathways_head_sidecar_containers (self ) -> list [Nested [Any ]]:
@@ -354,6 +362,8 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
354362 dict (
355363 name = _PATHWAYS_PROXY_CONTAINER_NAME ,
356364 image = _PATHWAYS_PROXY_IMAGE ,
365+ # Enables gRPC zero copy for improved performance.
366+ securityContext = {"privileged" : True },
357367 # https://kubernetes.io/docs/concepts/workloads/pods/sidecar-containers/#pod-sidecar-containers
358368 # SideCar container is an init container with restartPolicy as "Always".
359369 restartPolicy = "Always" ,
@@ -363,9 +373,16 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
363373 # TODO(samos123): Remove this once this becomes the default.
364374 {"name" : "IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS" , "value" : "true" },
365375 {"name" : "XLA_FLAGS" , "value" : f"--xla_dump_to=/output/{ cfg .name } /xla" },
376+ {
377+ "name" : "IFRT_PROXY_LARGE_TRANSFER_OPTIMIZATION_DIRECTORY" ,
378+ "value" : "/tmp/ifrt_proxy" ,
379+ },
366380 ],
367381 ports = [dict (containerPort = _PATHWAYS_PROXY_PORT )],
368- volumeMounts = [dict (name = "shared-output" , mountPath = "/output" )],
382+ volumeMounts = [
383+ dict (name = "shared-output" , mountPath = "/output" ),
384+ dict (name = "shared-memory" , mountPath = "/tmp/ifrt_proxy" ),
385+ ],
369386 ),
370387 dict (
371388 name = _PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME ,
@@ -403,6 +420,7 @@ def _build_pathways_head_pod(self) -> Nested[Any]:
403420 labels .update ({BASTION_JOB_VERSION_LABEL : os .environ .get (BASTION_JOB_VERSION_ENV_VAR )})
404421
405422 volumes .append (dict (name = "shared-output" , emptyDir = {}))
423+ volumes .append (dict (name = "shared-memory" , emptyDir = dict (medium = "Memory" )))
406424
407425 if cfg .gcsfuse_mount :
408426 annotations .update (
@@ -537,6 +555,10 @@ def _build_pathways_worker_container(
537555 f"--resource_manager_address={ pathways_head_address } :"
538556 + f"{ _PATHWAYS_RESOURCE_MANAGER_PORT } " ,
539557 f"--gcs_scratch_location={ cfg .output_dir } /pathways-staging" ,
558+ # Recycling host memory gives a slight increase in performance.
559+ "--tpu_pinned_host_allocation_recycle=true" ,
560+ # The flag below is needed for better H2D performance.
561+ "--tpu_premapped_buffer_size=17179869184" ,
540562 ]
541563 mega_scale_args = xla_flags_from_options (self ._mxla_options ).split ()
542564 worker_container ["args" ].extend (mega_scale_args )
@@ -910,7 +932,8 @@ def _build_head_container(self) -> dict:
910932 mem_req = f"{ self .config .pathways_head_mem } Gi"
911933 resources = {
912934 "requests" : {"cpu" : cpu_req , "memory" : mem_req },
913- "limits" : {"cpu" : cpu_req , "memory" : mem_req },
935+ # Do not set a limit so full host memory can be used
936+ # "limits": {"cpu": cpu_req, "memory": mem_req},
914937 }
915938 return dict (
916939 name = cfg .name ,
@@ -936,9 +959,9 @@ def _build_head_container(self) -> dict:
936959 ],
937960 imagePullPolicy = "Always" ,
938961 resources = resources ,
939- ports = [ dict ( containerPort = self . config . target_port )]
940- if self .config .enable_service
941- else [] ,
962+ ports = (
963+ [ dict ( containerPort = self . config . target_port )] if self .config .enable_service else []
964+ ) ,
942965 )
943966
944967 def build_leader_pod (self ) -> Nested [Any ]:
0 commit comments