Skip to content

Commit c16512c

Browse files
lukebaumann authored and copybara-github committed
Move ifrt based reshard out of experimental. Leaving intermediate resharding and sidechannel resharding in experimental.
PiperOrigin-RevId: 875850174
1 parent 41666fa commit c16512c

4 files changed

Lines changed: 365 additions & 233 deletions

File tree

pathwaysutils/experimental/reshard.py

Lines changed: 112 additions & 168 deletions
Original file line number | Diff line number | Diff line change
@@ -15,17 +15,18 @@
1515

1616
import base64
1717
import collections
18+
from collections.abc import Mapping
1819
import json
1920
import logging
2021
import math
2122
import operator
2223
from typing import Any, Callable, Dict, Mapping, Sequence
23-
import warnings
2424

2525
import jax
2626
from pathwaysutils import jax as pw_jax
2727
from pathwaysutils import lru_cache
2828
from pathwaysutils import plugin_executable
29+
from pathwaysutils import reshard as pw_reshard
2930
from pathwaysutils.experimental import split_by_mesh_axis
3031

3132

@@ -116,78 +117,6 @@ def _get_resharding_plan(
116117
_get_resharding_plan_cached = lru_cache.lru_cache()(_get_resharding_plan)
117118

118119

119-
def _reshard(
120-
x: Any,
121-
sharding: jax.sharding.Sharding | Any,
122-
*,
123-
donate: bool,
124-
may_alias: bool | None,
125-
jax_array_reshard_fn: Callable[..., Any],
126-
**kwargs,
127-
) -> Any:
128-
"""Reshards `x` to `sharding`."""
129-
flat_x, tree_def = jax.tree.flatten(x)
130-
flat_sharding = jax.api_util.flatten_axes(
131-
"reshard sharding", tree_def, sharding
132-
)
133-
134-
# We must split the arrays into two groups:
135-
# 1. jax.Array
136-
# 2. non jax.Array
137-
# For jax.Array, we will use the ifrt client to get the resharding plan and
138-
# execute it.
139-
# These arrays must be further split into groups based on the device set of
140-
# the sharding, since plugin programs only supports execution on the same
141-
# device set.
142-
# For non jax.Array, we will use jax.device_put to put the array to the
143-
# destination devices.
144-
#
145-
# We need to track what index each array is in the original pytree, so we can
146-
# put them back together in the right order.
147-
array_info_lambda = lambda: {"arrays": [], "indices": [], "dst_shardings": []}
148-
jax_arrays = collections.defaultdict(array_info_lambda)
149-
non_reshardable_arrays = array_info_lambda()
150-
for index, (arr, dst_sharding) in enumerate(zip(flat_x, flat_sharding)):
151-
if not isinstance(dst_sharding, jax.sharding.Sharding):
152-
raise ValueError("`sharding` must contain only `jax.sharding.Sharding`")
153-
if not isinstance(arr, jax.Array) or (
154-
hasattr(arr, "dtype")
155-
and jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)
156-
):
157-
non_reshardable_arrays["arrays"].append(arr)
158-
non_reshardable_arrays["indices"].append(index)
159-
non_reshardable_arrays["dst_shardings"].append(dst_sharding)
160-
else:
161-
device_set = frozenset(arr.sharding.device_set)
162-
jax_arrays[device_set]["arrays"].append(arr)
163-
jax_arrays[device_set]["indices"].append(index)
164-
jax_arrays[device_set]["dst_shardings"].append(dst_sharding)
165-
166-
if non_reshardable_arrays["arrays"]:
167-
non_reshardable_arrays["arrays"] = jax.device_put(
168-
non_reshardable_arrays["arrays"],
169-
non_reshardable_arrays["dst_shardings"],
170-
donate=donate,
171-
may_alias=may_alias,
172-
)
173-
174-
for array_info in jax_arrays.values():
175-
array_info["arrays"] = jax_array_reshard_fn(
176-
array_info, donate=donate, **kwargs
177-
)
178-
179-
result = [None] * len(flat_x)
180-
for arr, idx in zip(
181-
non_reshardable_arrays["arrays"], non_reshardable_arrays["indices"]
182-
):
183-
result[idx] = arr
184-
for array_info in jax_arrays.values():
185-
for arr, idx in zip(array_info["arrays"], array_info["indices"]):
186-
result[idx] = arr
187-
188-
return jax.tree.unflatten(tree_def, result)
189-
190-
191120
def _sidechannel_jax_array_reshard(
192121
array_info: Mapping[str, Any], *, donate: bool, cache_resharding_plans: bool
193122
) -> Sequence[jax.Array]:
@@ -214,61 +143,6 @@ def _ifrt_jax_array_reshard(
214143
)
215144

216145

217-
def _reshard_with_sidechannel(
218-
x: Any,
219-
sharding: jax.sharding.Sharding | Any,
220-
*,
221-
donate: bool,
222-
may_alias: bool | None,
223-
cache_resharding_plans: bool,
224-
) -> Any:
225-
"""Reshards `x` to `sharding` using sidechannel."""
226-
return _reshard(
227-
x,
228-
sharding,
229-
donate=donate,
230-
may_alias=may_alias,
231-
jax_array_reshard_fn=_sidechannel_jax_array_reshard,
232-
cache_resharding_plans=cache_resharding_plans,
233-
)
234-
235-
236-
def _reshard_with_ifrt(
237-
x: Any,
238-
sharding: jax.sharding.Sharding | Any,
239-
*,
240-
donate: bool,
241-
may_alias: bool | None,
242-
) -> Any:
243-
"""Reshards `x` to `sharding` using IFRT.
244-
245-
Note: Resharding plan caching is not applicable to the IFRT implementation
246-
and is not supported by this function.
247-
248-
Args:
249-
x: An array, scalar, or (nested) standard Python container thereof.
250-
sharding: A `Sharding` or a (nested) `Sharding` in standard Python container
251-
(must be a tree prefix of `x`), representing the device(s) and sharding to
252-
which `x` should be sharded to. The result will be committed to the
253-
device(s) of the sharding.
254-
donate: If `True`, donate all input arrays, which may reduce the amount of
255-
memory needed for resharding. Buffers donated to resharding should not be
256-
reused.
257-
may_alias: If `True`, may alias the input array with the output array. May
258-
reduce the amount of memory needed for resharding. Not used at the moment.
259-
260-
Returns:
261-
A copy of `x` whose sharding is `sharding`.
262-
"""
263-
return _reshard(
264-
x,
265-
sharding,
266-
donate=donate,
267-
may_alias=may_alias,
268-
jax_array_reshard_fn=_ifrt_jax_array_reshard,
269-
)
270-
271-
272146
def reshard(
273147
x: Any,
274148
sharding: jax.sharding.Sharding | Any,
@@ -279,6 +153,9 @@ def reshard(
279153
) -> Any:
280154
"""Reshards `x` to `sharding`.
281155
156+
This function is an alternative to `pathwaysutils.reshard` that uses the
157+
sidechannel resharding API for the final reshard.
158+
282159
Args:
283160
x: An array, scalar, or (nested) standard Python container thereof.
284161
sharding: A `Sharding` or a (nested) `Sharding` in standard Python container
@@ -291,38 +168,19 @@ def reshard(
291168
may_alias: If `True`, may alias the input array with the output array. May
292169
reduce the amount of memory needed for resharding. Not used at the moment.
293170
cache_resharding_plans: If `True`, uses a resharding plan cache to avoid
294-
recreating plans for the same resharding operation. May improve
295-
performance for use cases where the same resharding operation is done
296-
many times. May degrade performance if most reshardings operations are
297-
different, since the cache will cause Pathways Components to remain
298-
loaded for each cached plan. `False` by default. This parameter is only
299-
used when `pw_jax.ifrt_reshard_available()` is false.
171+
recreating plans for the same resharding operation.
300172
301173
Returns:
302174
A copy of `x` whose sharding is `sharding`.
303175
"""
304-
if pw_jax.ifrt_reshard_available():
305-
if cache_resharding_plans:
306-
warnings.warn(
307-
"`cache_resharding_plans` is only applicable when using the"
308-
" sidechannel resharding implementation, but IFRT resharding is"
309-
" available and will be used. The `cache_resharding_plans` argument"
310-
" will be ignored."
311-
)
312-
return _reshard_with_ifrt(
313-
x,
314-
sharding,
315-
donate=donate,
316-
may_alias=may_alias,
317-
)
318-
else:
319-
return _reshard_with_sidechannel(
320-
x,
321-
sharding,
322-
donate=donate,
323-
may_alias=may_alias,
324-
cache_resharding_plans=cache_resharding_plans,
325-
)
176+
return pw_reshard.reshard_generic(
177+
x,
178+
sharding,
179+
donate=donate,
180+
may_alias=may_alias,
181+
jax_array_reshard_fn=_sidechannel_jax_array_reshard,
182+
cache_resharding_plans=cache_resharding_plans,
183+
)
326184

327185

328186
class NoIntermediateShardingError(Exception):
@@ -564,40 +422,38 @@ def find_intermediate_sharding(
564422
return intermediate_sharding, replicated_axes
565423

566424

567-
def reshard_with_intermediate_sharding(
425+
def reshard_with_intermediate_sharding_generic(
568426
x: Any,
569427
in_sharding: jax.sharding.Sharding,
570428
out_sharding: jax.sharding.Sharding,
571429
*,
430+
jax_array_reshard_fn: Callable[..., Sequence[jax.Array]],
572431
donate: bool = False,
573432
may_alias: bool | None = None,
574-
cache_resharding_plans: bool = False,
433+
**kwargs: Any,
575434
) -> Any:
576435
"""Reshards `x` to `out_sharding`, using an intermediate sharding if possible.
577436
578-
This function is an alternative to `reshard` that may be faster and sometime
579-
essential for certain sharding combinations by using an intermediate sharding
580-
to avoid expensive all-gathers. If no beneficial intermediate sharding is
581-
found, it falls back to standard resharding. See `find_intermediate_sharding`
582-
for more details on when an intermediate sharding is used.
437+
This function is a generic version of `reshard_with_intermediate_sharding`
438+
that allows specifying the `jax_array_reshard_fn` to be used for the final
439+
reshard.
583440
584441
Args:
585442
x: An array, scalar, or (nested) standard Python container thereof.
586443
in_sharding: The source sharding of `x`.
587444
out_sharding: The target sharding for `x`.
445+
jax_array_reshard_fn: The function used for the final reshard of JAX arrays.
588446
donate: If `True`, donate all input arrays, which may reduce the amount of
589447
memory needed for resharding. Buffers donated to resharding should not be
590448
reused.
591449
may_alias: If `True`, may alias the input array with the output array. May
592450
reduce the amount of memory needed for resharding. Not used at the moment.
593-
cache_resharding_plans: Only used when resharding with sidechannel. If
594-
`True`, uses a resharding plan cache to avoid recreating plans for the
595-
same resharding operation.
451+
**kwargs: Additional keyword arguments to be passed to
452+
`jax_array_reshard_fn`.
596453
597454
Returns:
598455
A copy of `x` whose sharding is `out_sharding`.
599456
"""
600-
601457
try:
602458
intermediate_sharding, replicated_axes_names = find_intermediate_sharding(
603459
in_sharding, out_sharding
@@ -617,10 +473,98 @@ def reshard_with_intermediate_sharding(
617473
donate=donate,
618474
)
619475

620-
return reshard(
476+
return pw_reshard.reshard_generic(
621477
x_to_reshard,
622478
out_sharding,
623479
donate=donate,
624480
may_alias=may_alias,
481+
jax_array_reshard_fn=jax_array_reshard_fn,
482+
**kwargs,
483+
)
484+
485+
486+
def reshard_with_intermediate_sharding(
487+
x: Any,
488+
in_sharding: jax.sharding.Sharding,
489+
out_sharding: jax.sharding.Sharding,
490+
*,
491+
donate: bool = False,
492+
may_alias: bool | None = None,
493+
) -> Any:
494+
"""Reshards `x` to `out_sharding`, using an intermediate sharding if possible.
495+
496+
This function is an alternative to `reshard` that may be faster and sometimes
497+
essential for certain sharding combinations by using an intermediate sharding
498+
to avoid expensive all-gathers. If no beneficial intermediate sharding is
499+
found, it falls back to standard resharding. See `find_intermediate_sharding`
500+
for more details on when an intermediate sharding is used.
501+
502+
Uses the IFRT resharding API for the final reshard.
503+
504+
Args:
505+
x: An array, scalar, or (nested) standard Python container thereof.
506+
in_sharding: The source sharding of `x`.
507+
out_sharding: The target sharding for `x`.
508+
donate: If `True`, donate all input arrays, which may reduce the amount of
509+
memory needed for resharding. Buffers donated to resharding should not be
510+
reused.
511+
may_alias: If `True`, may alias the input array with the output array. May
512+
reduce the amount of memory needed for resharding. Not used at the moment.
513+
514+
Returns:
515+
A copy of `x` whose sharding is `out_sharding`.
516+
"""
517+
518+
return reshard_with_intermediate_sharding_generic(
519+
x,
520+
in_sharding,
521+
out_sharding,
522+
donate=donate,
523+
may_alias=may_alias,
524+
jax_array_reshard_fn=_ifrt_jax_array_reshard,
525+
)
526+
527+
528+
def sidechannel_reshard_with_intermediate_sharding(
529+
x: Any,
530+
in_sharding: jax.sharding.Sharding,
531+
out_sharding: jax.sharding.Sharding,
532+
*,
533+
donate: bool = False,
534+
may_alias: bool | None = None,
535+
cache_resharding_plans: bool = False,
536+
) -> Any:
537+
"""Reshards `x` to `out_sharding`, using an intermediate sharding if possible.
538+
539+
This function is an alternative to `reshard` that may be faster and sometimes
540+
essential for certain sharding combinations by using an intermediate sharding
541+
to avoid expensive all-gathers. If no beneficial intermediate sharding is
542+
found, it falls back to standard resharding. See `find_intermediate_sharding`
543+
for more details on when an intermediate sharding is used.
544+
545+
Uses the sidechannel resharding API for the final reshard.
546+
547+
Args:
548+
x: An array, scalar, or (nested) standard Python container thereof.
549+
in_sharding: The source sharding of `x`.
550+
out_sharding: The target sharding for `x`.
551+
donate: If `True`, donate all input arrays, which may reduce the amount of
552+
memory needed for resharding. Buffers donated to resharding should not be
553+
reused.
554+
may_alias: If `True`, may alias the input array with the output array. May
555+
reduce the amount of memory needed for resharding. Not used at the moment.
556+
cache_resharding_plans: If `True`, uses a resharding plan cache to avoid
557+
recreating plans for the same resharding operation.
558+
559+
Returns:
560+
A copy of `x` whose sharding is `out_sharding`.
561+
"""
562+
return reshard_with_intermediate_sharding_generic(
563+
x,
564+
in_sharding,
565+
out_sharding,
566+
donate=donate,
567+
may_alias=may_alias,
568+
jax_array_reshard_fn=_sidechannel_jax_array_reshard,
625569
cache_resharding_plans=cache_resharding_plans,
626570
)

0 commit comments

Comments
 (0)