Skip to content

Commit 50dfba9

Browse files
lukebaumanncopybara-github
authored andcommitted
Expose concatenate_by_mesh_axis in pathwaysutils.
This change flattens/unflattens the PyTree and calls the jaxlib API. It also adds tests for standard undo-split, partial concatenation, and mesh expansion. PiperOrigin-RevId: 903988823
1 parent fcd0d6f commit 50dfba9

2 files changed

Lines changed: 168 additions & 2 deletions

File tree

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# https://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
"""Pathwaysutils JAX concatenate_by_mesh_axis."""
14+
15+
from collections.abc import Sequence
16+
import itertools
17+
from typing import Any
18+
import jax
19+
import numpy as np
20+
from pathwaysutils import jax as pw_jax
21+
22+
23+
def concatenate_by_mesh_axis(
24+
array_trees: Sequence[Any],
25+
mesh_axis: str,
26+
) -> Any:
27+
"""Concatenates meshes by an axis. Returns arrays on the concatenated mesh.
28+
29+
Note: This API always donates the given arrays.
30+
31+
Args:
32+
array_trees: Sequence of PyTrees of JAX arrays with `NamedSharding`. All
33+
PyTrees in the sequence have the same structure. All arrays in each PyTree
34+
are sharded/replicated on the same mesh. The input arrays are always
35+
donated.
36+
mesh_axis: Mesh axis to concatenate.
37+
38+
Returns:
39+
A PyTree with the same structure as `array_trees[i]`. It has arrays with
40+
their shards concatenated to match a concatenated mesh.
41+
"""
42+
if not array_trees:
43+
return array_trees
44+
45+
def _get_named_sharding(array: jax.Array) -> jax.sharding.NamedSharding:
46+
if not isinstance(array, jax.Array):
47+
raise ValueError(f"Elements must be jax.Array. Got {type(array)}")
48+
sharding = array.sharding
49+
if not isinstance(sharding, jax.sharding.NamedSharding):
50+
raise ValueError(f"Expected NamedSharding. Got {type(sharding)}")
51+
return sharding
52+
53+
flats_and_defs = [jax.tree_util.tree_flatten(at) for at in array_trees]
54+
flattened_arrays = [fd[0] for fd in flats_and_defs]
55+
treedefs = [fd[1] for fd in flats_and_defs]
56+
input_treedef = treedefs[0]
57+
for td in treedefs[1:]:
58+
if td != input_treedef:
59+
raise ValueError(
60+
"All array trees must have the same treedef. Got"
61+
f" {td} vs. {input_treedef}"
62+
)
63+
64+
# Convert to have the output array structure in the outer list, and each entry
65+
# be a list of arrays from each shard for the concatenated output array.
66+
input_flat_arrays = list(zip(*flattened_arrays))
67+
68+
if not flattened_arrays[0]:
69+
return array_trees[0]
70+
71+
# Extract the shared mesh from each PyTree (from an arbitrary array in each).
72+
meshes_to_concatenate = [
73+
_get_named_sharding(array).mesh for array in input_flat_arrays[0]
74+
]
75+
76+
# Validate that the meshes are compatible.
77+
reference_mesh = meshes_to_concatenate[0]
78+
if mesh_axis not in reference_mesh.axis_names:
79+
raise ValueError(
80+
f"mesh_axis '{mesh_axis}' not found in mesh axis names:"
81+
f" {reference_mesh.axis_names}"
82+
)
83+
mesh_axis_idx = reference_mesh.axis_names.index(mesh_axis)
84+
for mesh in meshes_to_concatenate:
85+
if mesh.axis_names != reference_mesh.axis_names:
86+
raise ValueError(
87+
"Meshes must have the same axis names. Got"
88+
f" {mesh} vs. {reference_mesh}."
89+
)
90+
if (
91+
mesh.axis_sizes[:mesh_axis_idx]
92+
!= reference_mesh.axis_sizes[:mesh_axis_idx]
93+
or mesh.axis_sizes[mesh_axis_idx + 1 :]
94+
!= reference_mesh.axis_sizes[mesh_axis_idx + 1 :]
95+
):
96+
raise ValueError(
97+
"Arrays must have the same mesh axis sizes for all axes except"
98+
f" {mesh_axis}. Got {mesh} vs. {reference_mesh}."
99+
)
100+
101+
# Construct list of the mesh axis section boundaries.
102+
devices = [mesh.devices for mesh in meshes_to_concatenate]
103+
mesh_axis_sections = list(
104+
itertools.accumulate(
105+
mesh.axis_sizes[mesh_axis_idx] for mesh in meshes_to_concatenate
106+
)
107+
)
108+
109+
concatenated_mesh = jax.sharding.Mesh(
110+
np.concatenate(devices, mesh_axis_idx),
111+
axis_names=reference_mesh.axis_names,
112+
axis_types=reference_mesh.axis_types,
113+
)
114+
115+
def _get_output_sharding(
116+
arrays: Sequence[jax.Array],
117+
) -> jax.sharding.NamedSharding:
118+
reference_sharding = _get_named_sharding(arrays[0])
119+
reference_spec = reference_sharding.spec
120+
return jax.sharding.NamedSharding(concatenated_mesh, reference_spec)
121+
122+
def _sharded_dim_idx_for_sharding(
123+
sharding: jax.sharding.NamedSharding,
124+
) -> int:
125+
sharded_dim_idx = -1
126+
for dim_idx, dim_spec in enumerate(sharding.spec):
127+
flat_dim_spec, _ = jax.tree_util.tree_flatten(dim_spec)
128+
if mesh_axis in flat_dim_spec:
129+
sharded_dim_idx = dim_idx
130+
break
131+
return sharded_dim_idx
132+
133+
out_shardings = [_get_output_sharding(arrays) for arrays in input_flat_arrays]
134+
sharded_dim_idxs = [
135+
_sharded_dim_idx_for_sharding(sharding) for sharding in out_shardings
136+
]
137+
138+
flat_output_arrays = pw_jax.concatenate_by_mesh_axis(
139+
arrays=input_flat_arrays,
140+
sharded_dim_idxs=sharded_dim_idxs,
141+
mesh_axis_sizes=concatenated_mesh.axis_sizes,
142+
mesh_axis_idx=mesh_axis_idx,
143+
mesh_axis_sections=mesh_axis_sections,
144+
out_shardings=out_shardings,
145+
donate=True,
146+
)
147+
148+
return jax.tree_util.tree_unflatten(input_treedef, flat_output_arrays)

pathwaysutils/jax/__init__.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020

2121
import functools
22+
import jax
2223

2324

2425
class _FakeJaxFunction:
@@ -52,7 +53,7 @@ def __call__(self, *args, **kwargs):
5253
split_by_mesh_axis = _pathways._split_by_mesh_axis
5354
del _pathways
5455

55-
except ImportError:
56+
except (ImportError, AttributeError):
5657
# jax<0.8.0
5758

5859
split_by_mesh_axis = _FakeJaxFunction(
@@ -70,14 +71,31 @@ def __call__(self, *args, **kwargs):
7071

7172
del jaxlib_pathways
7273

73-
except ImportError:
74+
except (ImportError, AttributeError):
7475
# jax<0.8.3
7576
transfer_to_shardings = _FakeJaxFunction(
7677
"jax.jaxlib._pathways._transfer_to_shardings",
7778
"0.8.3",
7879
)
7980

8081

82+
try:
83+
# jax>=0.10.0
84+
# The import may fail if the JAX version is not new enough.
85+
from jaxlib import _pathways # pylint: disable=g-import-not-at-top
86+
87+
concatenate_by_mesh_axis = _pathways._concatenate_by_mesh_axis
88+
89+
del _pathways
90+
91+
except (ImportError, AttributeError):
92+
# jax<0.10.0
93+
concatenate_by_mesh_axis = _FakeJaxFunction(
94+
"jax.jaxlib._pathways._concatenate_by_mesh_axis",
95+
"0.10.0",
96+
)
97+
98+
8199
@functools.lru_cache(maxsize=1)
82100
def ifrt_reshard_available() -> bool:
83101
"""Checks if transfer_to_shardings is available."""

0 commit comments

Comments
 (0)