Skip to content

Commit d80d842

Browse files
committed
Pathways proxy and jax client shared memory
Also improves array_serialization to provide better performance
1 parent c789a0d commit d80d842

3 files changed

Lines changed: 57 additions & 17 deletions

File tree

Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ ENV UV_FIND_LINKS=https://storage.googleapis.com/jax-releases/libtpu_releases.ht
9797
# Jax will fallback to CPU when run on a machine without TPU.
9898
RUN uv pip install -qq --prerelease=allow .[core,tpu] && uv cache clean
9999
RUN if [ -n "$EXTRAS" ]; then uv pip install -qq .[$EXTRAS] && uv cache clean; fi
100+
RUN uv pip install --prerelease=allow "jaxlib==0.5.3.dev20250918" --find-links https://storage.googleapis.com/axlearn-wheels/wheels.html
100101
COPY . .
101102

102103
################################################################################

axlearn/cloud/gcp/pathways_utils.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
# There is no guarantee that this image will work with newer Jax releases.
4949
# This image version extends GRPC timeout for long context models, based on jax-0.5.3-patch060625
5050
# This image extends GRPC timeout for long context models.
51-
_PATHWAYS_IMAGE_TAG = "disable_settings_20250701"
51+
_PATHWAYS_IMAGE_TAG = "shm_proxy"
5252
# The docker image used by pathways proxy container.
5353
_PATHWAYS_PROXY_IMAGE = (
5454
f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:{_PATHWAYS_IMAGE_TAG}"
@@ -107,7 +107,7 @@ def get_pathways_tpu_version(gke_machine_type: str) -> str:
107107

108108

109109
def get_megascale_options(
110-
xla_options: dict[str, Union[str, bool, int]]
110+
xla_options: dict[str, Union[str, bool, int]],
111111
) -> dict[str, Union[str, bool, int]]:
112112
"""Filters XLA options for those pertaining to Megascale.
113113
@@ -122,7 +122,7 @@ def get_megascale_options(
122122

123123

124124
def get_xla_options(
125-
xla_options: dict[str, Union[str, bool, int]]
125+
xla_options: dict[str, Union[str, bool, int]],
126126
) -> dict[str, Union[str, bool, int]]:
127127
"""Filters XLA options for those starting with 'xla_'.
128128
@@ -275,7 +275,11 @@ def _build_pathways_head_container(self) -> dict:
275275
# In Jax 0.6.2 and beyond this flag can be renamed to
276276
# IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS as well.
277277
self._update_env_list(env_list, "TEST_UNDECLARED_OUTPUTS_DIR", "true")
278-
278+
# 1 byte
279+
self._update_env_list(env_list, "IFRT_PROXY_LARGE_TRANSFER_THRESHOLD", "1")
280+
self._update_env_list(
281+
env_list, "IFRT_PROXY_LARGE_TRANSFER_OPTIMIZATION_DIRECTORY", "/tmp/ifrt_proxy"
282+
)
279283
env_list.append(
280284
{
281285
"name": "HOST_ADDRESS",
@@ -315,10 +319,14 @@ def _build_pathways_head_container(self) -> dict:
315319
mem_req = f"{self.config.pathways_head_mem}Gi"
316320
resources = {
317321
"requests": {"cpu": cpu_req, "memory": mem_req},
318-
"limits": {"cpu": cpu_req, "memory": mem_req},
322+
# "limits": {"cpu": cpu_req, "memory": mem_req},
319323
}
320324
head_container["resources"] = resources
321325

326+
volume_mounts = head_container.get("volumeMounts", [])
327+
volume_mounts.append(dict(name="shared-memory", mountPath="/tmp/ifrt_proxy"))
328+
head_container["volumeMounts"] = volume_mounts
329+
322330
return head_container
323331

324332
def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
@@ -354,6 +362,8 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
354362
dict(
355363
name=_PATHWAYS_PROXY_CONTAINER_NAME,
356364
image=_PATHWAYS_PROXY_IMAGE,
365+
# Enables gRPC zero copy for improved performance.
366+
securityContext={"privileged": True},
357367
# https://kubernetes.io/docs/concepts/workloads/pods/sidecar-containers/#pod-sidecar-containers
358368
# SideCar container is an init container with restartPolicy as "Always".
359369
restartPolicy="Always",
@@ -363,9 +373,16 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
363373
# TODO(samos123): Remove this once this becomes the default.
364374
{"name": "IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS", "value": "true"},
365375
{"name": "XLA_FLAGS", "value": f"--xla_dump_to=/output/{cfg.name}/xla"},
376+
{
377+
"name": "IFRT_PROXY_LARGE_TRANSFER_OPTIMIZATION_DIRECTORY",
378+
"value": "/tmp/ifrt_proxy",
379+
},
366380
],
367381
ports=[dict(containerPort=_PATHWAYS_PROXY_PORT)],
368-
volumeMounts=[dict(name="shared-output", mountPath="/output")],
382+
volumeMounts=[
383+
dict(name="shared-output", mountPath="/output"),
384+
dict(name="shared-memory", mountPath="/tmp/ifrt_proxy"),
385+
],
369386
),
370387
dict(
371388
name=_PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME,
@@ -403,6 +420,7 @@ def _build_pathways_head_pod(self) -> Nested[Any]:
403420
labels.update({BASTION_JOB_VERSION_LABEL: os.environ.get(BASTION_JOB_VERSION_ENV_VAR)})
404421

405422
volumes.append(dict(name="shared-output", emptyDir={}))
423+
volumes.append(dict(name="shared-memory", emptyDir=dict(medium="Memory")))
406424

407425
if cfg.gcsfuse_mount:
408426
annotations.update(
@@ -537,6 +555,10 @@ def _build_pathways_worker_container(
537555
f"--resource_manager_address={pathways_head_address}:"
538556
+ f"{_PATHWAYS_RESOURCE_MANAGER_PORT}",
539557
f"--gcs_scratch_location={cfg.output_dir}/pathways-staging",
558+
# Recycling host memory gives a slight increase in performance.
559+
"--tpu_pinned_host_allocation_recycle=true",
560+
# The flag below is needed for better H2D performance.
561+
"--tpu_premapped_buffer_size=17179869184",
540562
]
541563
mega_scale_args = xla_flags_from_options(self._mxla_options).split()
542564
worker_container["args"].extend(mega_scale_args)
@@ -910,7 +932,8 @@ def _build_head_container(self) -> dict:
910932
mem_req = f"{self.config.pathways_head_mem}Gi"
911933
resources = {
912934
"requests": {"cpu": cpu_req, "memory": mem_req},
913-
"limits": {"cpu": cpu_req, "memory": mem_req},
935+
# Do not set a limit so full host memory can be used
936+
# "limits": {"cpu": cpu_req, "memory": mem_req},
914937
}
915938
return dict(
916939
name=cfg.name,
@@ -936,9 +959,9 @@ def _build_head_container(self) -> dict:
936959
],
937960
imagePullPolicy="Always",
938961
resources=resources,
939-
ports=[dict(containerPort=self.config.target_port)]
940-
if self.config.enable_service
941-
else [],
962+
ports=(
963+
[dict(containerPort=self.config.target_port)] if self.config.enable_service else []
964+
),
942965
)
943966

944967
def build_leader_pod(self) -> Nested[Any]:

axlearn/common/array_serialization.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,7 @@ def _fix_metadata(tspec: dict[str, Any], shard_infos: list[_ShardInfo]):
205205

206206

207207
class TensorstoreSpecModifier:
208-
def __call__(self, spec: dict[str, Any], *, shard_infos: list[_ShardInfo]):
209-
...
208+
def __call__(self, spec: dict[str, Any], *, shard_infos: list[_ShardInfo]): ...
210209

211210

212211
async def _async_serialize(
@@ -404,11 +403,22 @@ async def _async_deserialize(
404403
f" an instance of `jax.sharding.Sharding`. Got {in_sharding}"
405404
)
406405
dll = user_in_sharding.device_local_layout if isinstance(user_in_sharding, Layout) else None
406+
407+
# gcs_grpc improves performance.
408+
if tensorstore_spec.get("kvstore", {}).get("driver") == "gcs":
409+
tensorstore_spec["kvstore"]["driver"] = "gcs_grpc"
407410
t = await ts.open(
408411
tensorstore_spec,
409412
open=True,
410413
assume_metadata=False,
411-
context=serialization.TS_CONTEXT,
414+
# Improve GCS performance
415+
context=ts.Context(
416+
{
417+
"cache_pool": {"total_bytes_limit": 0},
418+
"data_copy_concurrency": {"limit": "shared"},
419+
"gcs_request_concurrency": {"limit": 480},
420+
}
421+
),
412422
)
413423
shape = tuple(t.shape if global_shape is None else global_shape)
414424
new_shard_shape = in_sharding.shard_shape(shape)
@@ -454,9 +464,12 @@ async def cb(index: array.Index, device: jax.Device):
454464
dll, jax.sharding.SingleDeviceSharding(device, memory_kind=in_sharding.memory_kind)
455465
)
456466
try:
457-
await h2d_limiter.wait_for_bytes(out_size)
458-
result = await loop.run_in_executor(None, _blocking_device_put, out, layout)
459-
await h2d_limiter.release_bytes(out_size)
467+
if os.getenv("JAX_PLATFORMS") == "proxy":
468+
result = await loop.run_in_executor(None, jax.device_put, out, layout)
469+
else:
470+
await h2d_limiter.wait_for_bytes(out_size)
471+
result = await loop.run_in_executor(None, _blocking_device_put, out, layout)
472+
await h2d_limiter.release_bytes(out_size)
460473
except ValueError as e:
461474
if "Requested more bytes than we reserved" not in str(e):
462475
raise e # Raise if it's not the type of error we expect.
@@ -589,6 +602,7 @@ def deserialize(
589602
concurrent_gb: int = 32,
590603
):
591604
self.wait_until_finished()
605+
start_time = time.time()
592606

593607
concurrent_bytes = concurrent_gb * 10**9
594608

@@ -613,7 +627,9 @@ async def _run_deserializer():
613627
return await asyncio.gather(*future_arrays)
614628

615629
fut = asyncio.run_coroutine_threadsafe(_run_deserializer(), self._loop)
616-
return fut.result()
630+
result = fut.result()
631+
logging.info("deserialize took %.4f seconds.", time.time() - start_time)
632+
return result
617633

618634

619635
class BoundedDataShardedAsyncCheckpointManager(GlobalAsyncCheckpointManager):

0 commit comments

Comments
 (0)