1515
1616import base64
1717import collections
18+ from collections .abc import Mapping
1819import json
1920import logging
2021import math
2122import operator
2223from typing import Any , Callable , Dict , Mapping , Sequence
23- import warnings
2424
2525import jax
2626from pathwaysutils import jax as pw_jax
2727from pathwaysutils import lru_cache
2828from pathwaysutils import plugin_executable
29+ from pathwaysutils import reshard as pw_reshard
2930from pathwaysutils .experimental import split_by_mesh_axis
3031
3132
@@ -116,78 +117,6 @@ def _get_resharding_plan(
116117_get_resharding_plan_cached = lru_cache .lru_cache ()(_get_resharding_plan )
117118
118119
119- def _reshard (
120- x : Any ,
121- sharding : jax .sharding .Sharding | Any ,
122- * ,
123- donate : bool ,
124- may_alias : bool | None ,
125- jax_array_reshard_fn : Callable [..., Any ],
126- ** kwargs ,
127- ) -> Any :
128- """Reshards `x` to `sharding`."""
129- flat_x , tree_def = jax .tree .flatten (x )
130- flat_sharding = jax .api_util .flatten_axes (
131- "reshard sharding" , tree_def , sharding
132- )
133-
134- # We must split the arrays into two groups:
135- # 1. jax.Array
136- # 2. non jax.Array
137- # For jax.Array, we will use the ifrt client to get the resharding plan and
138- # execute it.
139- # These arrays must be further split into groups based on the device set of
140- # the sharding, since plugin programs only supports execution on the same
141- # device set.
142- # For non jax.Array, we will use jax.device_put to put the array to the
143- # destination devices.
144- #
145- # We need to track what index each array is in the original pytree, so we can
146- # put them back together in the right order.
147- array_info_lambda = lambda : {"arrays" : [], "indices" : [], "dst_shardings" : []}
148- jax_arrays = collections .defaultdict (array_info_lambda )
149- non_reshardable_arrays = array_info_lambda ()
150- for index , (arr , dst_sharding ) in enumerate (zip (flat_x , flat_sharding )):
151- if not isinstance (dst_sharding , jax .sharding .Sharding ):
152- raise ValueError ("`sharding` must contain only `jax.sharding.Sharding`" )
153- if not isinstance (arr , jax .Array ) or (
154- hasattr (arr , "dtype" )
155- and jax .dtypes .issubdtype (arr .dtype , jax .dtypes .prng_key )
156- ):
157- non_reshardable_arrays ["arrays" ].append (arr )
158- non_reshardable_arrays ["indices" ].append (index )
159- non_reshardable_arrays ["dst_shardings" ].append (dst_sharding )
160- else :
161- device_set = frozenset (arr .sharding .device_set )
162- jax_arrays [device_set ]["arrays" ].append (arr )
163- jax_arrays [device_set ]["indices" ].append (index )
164- jax_arrays [device_set ]["dst_shardings" ].append (dst_sharding )
165-
166- if non_reshardable_arrays ["arrays" ]:
167- non_reshardable_arrays ["arrays" ] = jax .device_put (
168- non_reshardable_arrays ["arrays" ],
169- non_reshardable_arrays ["dst_shardings" ],
170- donate = donate ,
171- may_alias = may_alias ,
172- )
173-
174- for array_info in jax_arrays .values ():
175- array_info ["arrays" ] = jax_array_reshard_fn (
176- array_info , donate = donate , ** kwargs
177- )
178-
179- result = [None ] * len (flat_x )
180- for arr , idx in zip (
181- non_reshardable_arrays ["arrays" ], non_reshardable_arrays ["indices" ]
182- ):
183- result [idx ] = arr
184- for array_info in jax_arrays .values ():
185- for arr , idx in zip (array_info ["arrays" ], array_info ["indices" ]):
186- result [idx ] = arr
187-
188- return jax .tree .unflatten (tree_def , result )
189-
190-
191120def _sidechannel_jax_array_reshard (
192121 array_info : Mapping [str , Any ], * , donate : bool , cache_resharding_plans : bool
193122) -> Sequence [jax .Array ]:
@@ -214,61 +143,6 @@ def _ifrt_jax_array_reshard(
214143 )
215144
216145
def _reshard_with_sidechannel(
    x: Any,
    sharding: jax.sharding.Sharding | Any,
    *,
    donate: bool,
    may_alias: bool | None,
    cache_resharding_plans: bool,
) -> Any:
  """Reshards `x` to `sharding` using sidechannel.

  Thin wrapper around `_reshard` that selects
  `_sidechannel_jax_array_reshard` as the function used for `jax.Array`
  leaves.

  Args:
    x: An array, scalar, or (nested) standard Python container thereof.
    sharding: A `Sharding` or a (nested) `Sharding` in standard Python container
      (must be a tree prefix of `x`).
    donate: If `True`, donate all input arrays. Forwarded to `_reshard`.
    may_alias: Forwarded to `_reshard`.
    cache_resharding_plans: Forwarded as a keyword argument to
      `_sidechannel_jax_array_reshard`.

  Returns:
    The value returned by `_reshard`.
  """
  return _reshard(
      x,
      sharding,
      donate=donate,
      may_alias=may_alias,
      jax_array_reshard_fn=_sidechannel_jax_array_reshard,
      cache_resharding_plans=cache_resharding_plans,
  )
234-
235-
def _reshard_with_ifrt(
    x: Any,
    sharding: jax.sharding.Sharding | Any,
    *,
    donate: bool,
    may_alias: bool | None,
) -> Any:
  """Reshards `x` to `sharding` using IFRT.

  Thin wrapper around `_reshard` that selects `_ifrt_jax_array_reshard` as the
  function used for `jax.Array` leaves.

  Note: Resharding plan caching is not applicable to the IFRT implementation
  and is not supported by this function.

  Args:
    x: An array, scalar, or (nested) standard Python container thereof.
    sharding: A `Sharding` or a (nested) `Sharding` in standard Python container
      (must be a tree prefix of `x`), representing the device(s) and sharding to
      which `x` should be sharded to. The result will be committed to the
      device(s) of the sharding.
    donate: If `True`, donate all input arrays, which may reduce the amount of
      memory needed for resharding. Buffers donated to resharding should not be
      reused.
    may_alias: If `True`, may alias the input array with the output array. May
      reduce the amount of memory needed for resharding. `_reshard` forwards
      it to `jax.device_put` only, so it has no effect on leaves that go
      through `_ifrt_jax_array_reshard`.

  Returns:
    A copy of `x` whose sharding is `sharding`.
  """
  return _reshard(
      x,
      sharding,
      donate=donate,
      may_alias=may_alias,
      jax_array_reshard_fn=_ifrt_jax_array_reshard,
  )
270-
271-
272146def reshard (
273147 x : Any ,
274148 sharding : jax .sharding .Sharding | Any ,
@@ -279,6 +153,9 @@ def reshard(
279153) -> Any :
280154 """Reshards `x` to `sharding`.
281155
156+ This function is an alternative to `pathwaysutils.reshard` that uses the
157+ sidechannel resharding API for the final reshard.
158+
282159 Args:
283160 x: An array, scalar, or (nested) standard Python container thereof.
284161 sharding: A `Sharding` or a (nested) `Sharding` in standard Python container
@@ -291,38 +168,19 @@ def reshard(
291168 may_alias: If `True`, may alias the input array with the output array. May
292169 reduce the amount of memory needed for resharding. Not used at the moment.
293170 cache_resharding_plans: If `True`, uses a resharding plan cache to avoid
294- recreating plans for the same resharding operation. May improve
295- performance for use cases where the same resharding operation is done
296- many times. May degrade performance if most reshardings operations are
297- different, since the cache will cause Pathways Components to remain
298- loaded for each cached plan. `False` by default. This parameter is only
299- used when `pw_jax.ifrt_reshard_available()` is false.
171+ recreating plans for the same resharding operation.
300172
301173 Returns:
302174 A copy of `x` whose sharding is `sharding`.
303175 """
304- if pw_jax .ifrt_reshard_available ():
305- if cache_resharding_plans :
306- warnings .warn (
307- "`cache_resharding_plans` is only applicable when using the"
308- " sidechannel resharding implementation, but IFRT resharding is"
309- " available and will be used. The `cache_resharding_plans` argument"
310- " will be ignored."
311- )
312- return _reshard_with_ifrt (
313- x ,
314- sharding ,
315- donate = donate ,
316- may_alias = may_alias ,
317- )
318- else :
319- return _reshard_with_sidechannel (
320- x ,
321- sharding ,
322- donate = donate ,
323- may_alias = may_alias ,
324- cache_resharding_plans = cache_resharding_plans ,
325- )
176+ return pw_reshard .reshard_generic (
177+ x ,
178+ sharding ,
179+ donate = donate ,
180+ may_alias = may_alias ,
181+ jax_array_reshard_fn = _sidechannel_jax_array_reshard ,
182+ cache_resharding_plans = cache_resharding_plans ,
183+ )
326184
327185
328186class NoIntermediateShardingError (Exception ):
@@ -564,40 +422,38 @@ def find_intermediate_sharding(
564422 return intermediate_sharding , replicated_axes
565423
566424
567- def reshard_with_intermediate_sharding (
425+ def reshard_with_intermediate_sharding_generic (
568426 x : Any ,
569427 in_sharding : jax .sharding .Sharding ,
570428 out_sharding : jax .sharding .Sharding ,
571429 * ,
430+ jax_array_reshard_fn : Callable [..., Sequence [jax .Array ]],
572431 donate : bool = False ,
573432 may_alias : bool | None = None ,
574- cache_resharding_plans : bool = False ,
433+ ** kwargs : Any ,
575434) -> Any :
576435 """Reshards `x` to `out_sharding`, using an intermediate sharding if possible.
577436
578- This function is an alternative to `reshard` that may be faster and sometime
579- essential for certain sharding combinations by using an intermediate sharding
580- to avoid expensive all-gathers. If no beneficial intermediate sharding is
581- found, it falls back to standard resharding. See `find_intermediate_sharding`
582- for more details on when an intermediate sharding is used.
437+ This function is a generic version of `reshard_with_intermediate_sharding`
438+ that allows specifying the `jax_array_reshard_fn` to be used for the final
439+ reshard.
583440
584441 Args:
585442 x: An array, scalar, or (nested) standard Python container thereof.
586443 in_sharding: The source sharding of `x`.
587444 out_sharding: The target sharding for `x`.
445+ jax_array_reshard_fn: The function used for the final reshard of JAX arrays.
588446 donate: If `True`, donate all input arrays, which may reduce the amount of
589447 memory needed for resharding. Buffers donated to resharding should not be
590448 reused.
591449 may_alias: If `True`, may alias the input array with the output array. May
592450 reduce the amount of memory needed for resharding. Not used at the moment.
593- cache_resharding_plans: Only used when resharding with sidechannel. If
594- `True`, uses a resharding plan cache to avoid recreating plans for the
595- same resharding operation.
451+ **kwargs: Additional keyword arguments to be passed to
452+ `jax_array_reshard_fn`.
596453
597454 Returns:
598455 A copy of `x` whose sharding is `out_sharding`.
599456 """
600-
601457 try :
602458 intermediate_sharding , replicated_axes_names = find_intermediate_sharding (
603459 in_sharding , out_sharding
@@ -617,10 +473,98 @@ def reshard_with_intermediate_sharding(
617473 donate = donate ,
618474 )
619475
620- return reshard (
476+ return pw_reshard . reshard_generic (
621477 x_to_reshard ,
622478 out_sharding ,
623479 donate = donate ,
624480 may_alias = may_alias ,
481+ jax_array_reshard_fn = jax_array_reshard_fn ,
482+ ** kwargs ,
483+ )
484+
485+
def reshard_with_intermediate_sharding(
    x: Any,
    in_sharding: jax.sharding.Sharding,
    out_sharding: jax.sharding.Sharding,
    *,
    donate: bool = False,
    may_alias: bool | None = None,
    cache_resharding_plans: bool = False,
) -> Any:
  """Reshards `x` to `out_sharding`, using an intermediate sharding if possible.

  This function is an alternative to `reshard` that may be faster and sometimes
  essential for certain sharding combinations by using an intermediate sharding
  to avoid expensive all-gathers. If no beneficial intermediate sharding is
  found, it falls back to standard resharding. See `find_intermediate_sharding`
  for more details on when an intermediate sharding is used.

  Uses the IFRT resharding API for the final reshard.

  Args:
    x: An array, scalar, or (nested) standard Python container thereof.
    in_sharding: The source sharding of `x`.
    out_sharding: The target sharding for `x`.
    donate: If `True`, donate all input arrays, which may reduce the amount of
      memory needed for resharding. Buffers donated to resharding should not be
      reused.
    may_alias: If `True`, may alias the input array with the output array. May
      reduce the amount of memory needed for resharding. Not used at the moment.
    cache_resharding_plans: Ignored. Accepted only so that callers written
      against the previous signature of this function keep working; resharding
      plan caching only applies to the sidechannel implementation. Passing
      `True` emits a warning. Use
      `sidechannel_reshard_with_intermediate_sharding` to get plan caching.

  Returns:
    A copy of `x` whose sharding is `out_sharding`.
  """
  if cache_resharding_plans:
    # Imported locally because the module no longer imports `warnings` at the
    # top level.
    import warnings

    warnings.warn(
        "`cache_resharding_plans` is only applicable when using the"
        " sidechannel resharding implementation and is ignored by"
        " `reshard_with_intermediate_sharding`. Use"
        " `sidechannel_reshard_with_intermediate_sharding` instead.",
        stacklevel=2,
    )

  return reshard_with_intermediate_sharding_generic(
      x,
      in_sharding,
      out_sharding,
      donate=donate,
      may_alias=may_alias,
      jax_array_reshard_fn=_ifrt_jax_array_reshard,
  )
526+
527+
def sidechannel_reshard_with_intermediate_sharding(
    x: Any,
    in_sharding: jax.sharding.Sharding,
    out_sharding: jax.sharding.Sharding,
    *,
    donate: bool = False,
    may_alias: bool | None = None,
    cache_resharding_plans: bool = False,
) -> Any:
  """Reshards `x` to `out_sharding`, using an intermediate sharding if possible.

  This function is an alternative to `reshard` that may be faster and sometimes
  essential for certain sharding combinations by using an intermediate sharding
  to avoid expensive all-gathers. If no beneficial intermediate sharding is
  found, it falls back to standard resharding. See `find_intermediate_sharding`
  for more details on when an intermediate sharding is used.

  Uses the sidechannel resharding API for the final reshard: this is a thin
  wrapper that calls `reshard_with_intermediate_sharding_generic` with
  `_sidechannel_jax_array_reshard` as the `jax_array_reshard_fn`.

  Args:
    x: An array, scalar, or (nested) standard Python container thereof.
    in_sharding: The source sharding of `x`.
    out_sharding: The target sharding for `x`.
    donate: If `True`, donate all input arrays, which may reduce the amount of
      memory needed for resharding. Buffers donated to resharding should not be
      reused.
    may_alias: If `True`, may alias the input array with the output array. May
      reduce the amount of memory needed for resharding. Not used at the moment.
    cache_resharding_plans: If `True`, uses a resharding plan cache to avoid
      recreating plans for the same resharding operation. Forwarded as a
      keyword argument to `reshard_with_intermediate_sharding_generic`.

  Returns:
    A copy of `x` whose sharding is `out_sharding`.
  """
  return reshard_with_intermediate_sharding_generic(
      x,
      in_sharding,
      out_sharding,
      donate=donate,
      may_alias=may_alias,
      jax_array_reshard_fn=_sidechannel_jax_array_reshard,
      cache_resharding_plans=cache_resharding_plans,
  )
0 commit comments