# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pathwaysutils JAX concatenate_by_mesh_axis."""

from typing import Any, Sequence

import jax
import numpy as np

from pathwaysutils import jax as pw_jax


def concatenate_by_mesh_axis(
    array_trees: Sequence[Any],
    mesh_axis: str,
) -> Any:
  """Concatenates meshes by an axis. Returns arrays on the concatenated mesh.

  Note: This API donates the given arrays.

  Args:
    array_trees: Sequence of PyTrees of JAX arrays with `NamedSharding`. All
      PyTrees in the sequence have the same structure. All arrays in each PyTree
      are sharded/replicated on the same mesh.
    mesh_axis: Mesh axis to concatenate.

  Returns:
    A PyTree with the same structure as `array_trees[i]`. It has arrays with
    their shards concatenated to match a concatenated mesh.

  Raises:
    ValueError: If the PyTrees have different structures, contain elements that
      are not `jax.Array` with a `NamedSharding`, use meshes with mismatched
      axis names or incompatible axis sizes, or if `mesh_axis` is not an axis
      of the arrays' meshes.
  """
  if not array_trees:
    return array_trees

  def _get_named_sharding(array: jax.Array) -> jax.sharding.NamedSharding:
    """Returns the array's sharding, validating that it is a `NamedSharding`."""
    if not isinstance(array, jax.Array):
      raise ValueError(f"Elements must be jax.Array. Got {type(array)}")
    sharding = array.sharding
    if not isinstance(sharding, jax.sharding.NamedSharding):
      raise ValueError(f"Expected NamedSharding. Got {type(sharding)}")
    return sharding

  # Flatten every tree, validating that all trees share a single structure.
  flattened_arrays = []
  reference_treedef = None
  for array_tree in array_trees:
    flat, treedef = jax.tree_util.tree_flatten(array_tree)
    flattened_arrays.append(flat)
    if reference_treedef is None:
      reference_treedef = treedef
    elif treedef != reference_treedef:
      raise ValueError(
          "All array trees must have the same treedef. Got"
          f" {treedef} vs. {reference_treedef}"
      )

  # Convert to have the output array structure in the outer list, and each entry
  # be a list of arrays from each shard for the concatenated output array.
  input_flat_arrays = list(zip(*flattened_arrays))
  if not input_flat_arrays:
    # Leafless trees (e.g. all-None): nothing to concatenate; return the
    # shared structure unchanged instead of crashing below on an empty list.
    return jax.tree_util.tree_unflatten(reference_treedef, [])

  # Extract the shared mesh from each PyTree (from an arbitrary array in each).
  meshes_to_concatenate = [
      _get_named_sharding(array).mesh for array in input_flat_arrays[0]
  ]

  # Validate that the meshes are compatible.
  reference_mesh = meshes_to_concatenate[0]
  if mesh_axis not in reference_mesh.axis_names:
    # Raise a diagnostic error instead of tuple.index's generic ValueError.
    raise ValueError(
        f"mesh_axis {mesh_axis!r} is not an axis of the arrays' mesh."
        f" Available axes: {reference_mesh.axis_names}."
    )
  mesh_axis_idx = reference_mesh.axis_names.index(mesh_axis)
  for mesh in meshes_to_concatenate:
    if mesh.axis_names != reference_mesh.axis_names:
      raise ValueError(
          "Meshes must have the same axis names. Got"
          f" {mesh} vs. {reference_mesh}."
      )
    # Every axis except the concatenated one must match in size, or the device
    # arrays cannot be stacked into one well-formed mesh.
    if (
        mesh.axis_sizes[:mesh_axis_idx]
        != reference_mesh.axis_sizes[:mesh_axis_idx]
        or mesh.axis_sizes[mesh_axis_idx + 1 :]
        != reference_mesh.axis_sizes[mesh_axis_idx + 1 :]
    ):
      raise ValueError(
          "Arrays must have the same mesh axis sizes for all axes except"
          f" {mesh_axis}. Got {mesh} vs. {reference_mesh}."
      )

  # Construct list of the mesh axis section boundaries (cumulative sizes of
  # each input mesh along `mesh_axis`).
  devices = []
  mesh_axis_sections = []
  cur_mesh_axis_size = 0
  for mesh in meshes_to_concatenate:
    devices.append(mesh.devices)
    cur_mesh_axis_size += mesh.axis_sizes[mesh_axis_idx]
    mesh_axis_sections.append(cur_mesh_axis_size)

  # Stack the per-mesh device arrays along the concatenated axis to build the
  # output mesh.
  concatenated_mesh = jax.sharding.Mesh(
      np.concatenate(devices, mesh_axis_idx),
      axis_names=reference_mesh.axis_names,
      axis_types=reference_mesh.axis_types,
  )

  def _get_output_sharding(
      arrays: Sequence[jax.Array],
  ) -> jax.sharding.NamedSharding:
    """Returns the input spec re-homed onto the concatenated mesh."""
    reference_sharding = _get_named_sharding(arrays[0])
    reference_spec = reference_sharding.spec
    return jax.sharding.NamedSharding(concatenated_mesh, reference_spec)

  def _sharded_dim_idx_for_sharding(
      sharding: jax.sharding.NamedSharding,
  ) -> int:
    """Returns the array dim sharded over `mesh_axis`, or -1 if replicated."""
    sharded_dim_idx = -1
    for dim_idx, dim_spec in enumerate(sharding.spec):
      # A dim spec may be a single axis name or a (possibly nested) tuple of
      # axis names; flatten before membership testing.
      flat_dim_spec, _ = jax.tree_util.tree_flatten(dim_spec)
      if mesh_axis in flat_dim_spec:
        sharded_dim_idx = dim_idx
        break
    return sharded_dim_idx

  out_shardings = [_get_output_sharding(arrays) for arrays in input_flat_arrays]
  sharded_dim_idxs = [
      _sharded_dim_idx_for_sharding(sharding) for sharding in out_shardings
  ]

  # Delegate the device-level concatenation; donates the inputs.
  flat_output_arrays = pw_jax.concatenate_by_mesh_axis(
      arrays=input_flat_arrays,
      sharded_dim_idxs=sharded_dim_idxs,
      mesh_axis_sizes=concatenated_mesh.axis_sizes,
      mesh_axis_idx=mesh_axis_idx,
      mesh_axis_sections=mesh_axis_sections,
      out_shardings=out_shardings,
      donate=True,
  )

  return jax.tree_util.tree_unflatten(reference_treedef, flat_output_arrays)