Skip to content

Commit 2b69fdb

Browse files
lukebaumanncopybara-github
authored andcommitted
Expose concatenate_by_mesh_axis in pathwaysutils.
This change flattens/unflattens the PyTree and calls the jaxlib API. It also adds tests for standard undo-split, partial concatenation, and mesh expansion. PiperOrigin-RevId: 893777987
1 parent 66e9754 commit 2b69fdb

2 files changed

Lines changed: 163 additions & 2 deletions

File tree

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Pathwaysutils JAX concatenate_by_mesh_axis."""
15+
16+
from typing import Any, Sequence
17+
18+
import jax
19+
import numpy as np
20+
21+
from pathwaysutils import jax as pw_jax
22+
23+
24+
def concatenate_by_mesh_axis(
25+
array_trees: Sequence[Any],
26+
mesh_axis: str,
27+
) -> Any:
28+
"""Concatenates meshes by an axis. Returns arrays on the concatenated mesh.
29+
30+
Note: This API donates the given arrays.
31+
32+
Args:
33+
array_trees: Sequence of PyTrees of JAX arrays with `NamedSharding`. All
34+
PyTrees in the sequence have the same structure. All arrays in each PyTree
35+
are sharded/replicated on the same mesh.
36+
mesh_axis: Mesh axis to concatenate.
37+
38+
Returns:
39+
A PyTree with the same structure as `array_trees[i]`. It has arrays with
40+
their shards concatenated to match a concatenated mesh.
41+
"""
42+
if not array_trees:
43+
return array_trees
44+
45+
def _get_named_sharding(array: jax.Array) -> jax.sharding.NamedSharding:
46+
if not isinstance(array, jax.Array):
47+
raise ValueError(f"Elements must be jax.Array. Got {type(array)}")
48+
sharding = array.sharding
49+
if not isinstance(sharding, jax.sharding.NamedSharding):
50+
raise ValueError(f"Expected NamedSharding. Got {type(sharding)}")
51+
return sharding
52+
53+
flattened_arrays = []
54+
input_treedefs = None
55+
for array_tree in array_trees:
56+
flat, treedef = jax.tree_util.tree_flatten(array_tree)
57+
flattened_arrays.append(flat)
58+
if input_treedefs is not None:
59+
if treedef != input_treedefs:
60+
raise ValueError(
61+
"All array trees must have the same treedef. Got"
62+
f" {treedef} vs. {input_treedefs}"
63+
)
64+
else:
65+
input_treedefs = treedef
66+
67+
# Convert to have the output array structure in the outer list, and each entry
68+
# be a list of arrays from each shard for the concatenated output array.
69+
input_flat_arrays = list(zip(*flattened_arrays))
70+
71+
# Extract the shared mesh from each PyTree (from an arbitrary array in each).
72+
meshes_to_concatenate = [
73+
_get_named_sharding(array).mesh for array in input_flat_arrays[0]
74+
]
75+
76+
# Validate that the meshes are compatible.
77+
reference_mesh = meshes_to_concatenate[0]
78+
mesh_axis_idx = reference_mesh.axis_names.index(mesh_axis)
79+
for mesh in meshes_to_concatenate:
80+
if mesh.axis_names != reference_mesh.axis_names:
81+
raise ValueError(
82+
"Meshes must have the same axis names. Got"
83+
f" {mesh} vs. {reference_mesh}."
84+
)
85+
if (
86+
mesh.axis_sizes[:mesh_axis_idx]
87+
!= reference_mesh.axis_sizes[:mesh_axis_idx]
88+
or mesh.axis_sizes[mesh_axis_idx + 1 :]
89+
!= reference_mesh.axis_sizes[mesh_axis_idx + 1 :]
90+
):
91+
raise ValueError(
92+
"Arrays must have the same mesh axis sizes for all axes except"
93+
f" {mesh_axis}. Got {mesh} vs. {reference_mesh}."
94+
)
95+
96+
# Construct list of the mesh axis section boundaries.
97+
devices = []
98+
mesh_axis_sections = []
99+
cur_mesh_axis_size = 0
100+
for mesh in meshes_to_concatenate:
101+
devices.append(mesh.devices)
102+
cur_mesh_axis_size += mesh.axis_sizes[mesh_axis_idx]
103+
mesh_axis_sections.append(cur_mesh_axis_size)
104+
105+
concatenated_mesh = jax.sharding.Mesh(
106+
np.concatenate(devices, mesh_axis_idx),
107+
axis_names=reference_mesh.axis_names,
108+
axis_types=reference_mesh.axis_types,
109+
)
110+
111+
def _get_output_sharding(
112+
arrays: Sequence[jax.Array],
113+
) -> jax.sharding.NamedSharding:
114+
reference_sharding = _get_named_sharding(arrays[0])
115+
reference_spec = reference_sharding.spec
116+
return jax.sharding.NamedSharding(concatenated_mesh, reference_spec)
117+
118+
def _sharded_dim_idx_for_sharding(
119+
sharding: jax.sharding.NamedSharding,
120+
) -> int:
121+
sharded_dim_idx = -1
122+
for dim_idx, dim_spec in enumerate(sharding.spec):
123+
flat_dim_spec, _ = jax.tree_util.tree_flatten(dim_spec)
124+
if mesh_axis in flat_dim_spec:
125+
sharded_dim_idx = dim_idx
126+
break
127+
return sharded_dim_idx
128+
129+
out_shardings = [_get_output_sharding(arrays) for arrays in input_flat_arrays]
130+
sharded_dim_idxs = [
131+
_sharded_dim_idx_for_sharding(sharding) for sharding in out_shardings
132+
]
133+
134+
flat_output_arrays = pw_jax.concatenate_by_mesh_axis(
135+
arrays=input_flat_arrays,
136+
sharded_dim_idxs=sharded_dim_idxs,
137+
mesh_axis_sizes=concatenated_mesh.axis_sizes,
138+
mesh_axis_idx=mesh_axis_idx,
139+
mesh_axis_sections=mesh_axis_sections,
140+
out_shardings=out_shardings,
141+
donate=True,
142+
)
143+
144+
return jax.tree_util.tree_unflatten(input_treedefs, flat_output_arrays)

pathwaysutils/jax/__init__.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __call__(self, *args, **kwargs):
5252
split_by_mesh_axis = _pathways._split_by_mesh_axis
5353
del _pathways
5454

55-
except ImportError:
55+
except (ImportError, AttributeError):
5656
# jax<0.8.0
5757

5858
split_by_mesh_axis = _FakeJaxFunction(
@@ -70,14 +70,31 @@ def __call__(self, *args, **kwargs):
7070

7171
del jaxlib_pathways
7272

73-
except ImportError:
73+
except (ImportError, AttributeError):
7474
# jax<0.8.3
7575
transfer_to_shardings = _FakeJaxFunction(
7676
"jax.jaxlib._pathways._transfer_to_shardings",
7777
"0.8.3",
7878
)
7979

8080

81+
try:
82+
# jax>=0.9.3
83+
# The import may fail if the JAX version is not new enough.
84+
from jaxlib import _pathways as jaxlib_pathways_concat # pylint: disable=g-import-not-at-top
85+
86+
concatenate_by_mesh_axis = jaxlib_pathways_concat._concatenate_by_mesh_axis
87+
88+
del jaxlib_pathways_concat
89+
90+
except (ImportError, AttributeError):
91+
# jax<0.9.3
92+
concatenate_by_mesh_axis = _FakeJaxFunction(
93+
"jax.jaxlib._pathways._concatenate_by_mesh_axis",
94+
"0.9.3",
95+
)
96+
97+
8198
@functools.lru_cache(maxsize=1)
8299
def ifrt_reshard_available() -> bool:
83100
"""Checks if transfer_to_shardings is available."""

0 commit comments

Comments
 (0)