Skip to content

Commit e7d15cc

Browse files
lukebaumann authored and copybara-github committed
Refactor resharding logic into helper functions.
This change introduces `_reshard_with_sidechannel` and `_reshard_with_ifrt` to encapsulate the different resharding mechanisms used by the `reshard` function. These are internal APIs and should not be depended on.

PiperOrigin-RevId: 857415781
1 parent dad55e8 commit e7d15cc

2 files changed

Lines changed: 218 additions & 10 deletions

File tree

pathwaysutils/experimental/reshard.py

Lines changed: 71 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import collections
1818
import json
1919
from typing import Any, Callable, Dict, Mapping, Sequence
20+
import warnings
2021

2122
import jax
2223
from pathwaysutils import jax as pw_jax
@@ -104,7 +105,7 @@ def _reshard(
104105
x: Any,
105106
sharding: jax.sharding.Sharding | Any,
106107
*,
107-
donate: bool = False,
108+
donate: bool,
108109
may_alias: bool | None,
109110
jax_array_reshard_fn: Callable[..., Any],
110111
**kwargs,
@@ -198,6 +199,61 @@ def _ifrt_jax_array_reshard(
198199
)
199200

200201

202+
def _reshard_with_sidechannel(
    x: Any,
    sharding: jax.sharding.Sharding | Any,
    *,
    donate: bool,
    may_alias: bool | None,
    cache_resharding_plans: bool,
) -> Any:
  """Reshards `x` to `sharding` using sidechannel.

  Delegates to `_reshard` with `_sidechannel_jax_array_reshard` selected as
  the per-array reshard function, passing `cache_resharding_plans` through as
  an additional keyword argument.

  Args:
    x: An array, scalar, or (nested) standard Python container thereof.
    sharding: A `Sharding` or a (nested) `Sharding` in standard Python container
      (must be a tree prefix of `x`), representing the device(s) and sharding to
      which `x` should be sharded to.
    donate: If `True`, donate all input arrays, which may reduce the amount of
      memory needed for resharding. Buffers donated to resharding should not be
      reused.
    may_alias: If `True`, may alias the input array with the output array. Not
      used at the moment.
    cache_resharding_plans: If `True`, uses a resharding plan cache to avoid
      recreating plans for the same resharding operation.

  Returns:
    A copy of `x` whose sharding is `sharding`.
  """
  # Collect every keyword-only option in one place so the delegation below is
  # a single, easy-to-scan call.
  sidechannel_options = dict(
      donate=donate,
      may_alias=may_alias,
      jax_array_reshard_fn=_sidechannel_jax_array_reshard,
      cache_resharding_plans=cache_resharding_plans,
  )
  return _reshard(x, sharding, **sidechannel_options)
219+
220+
221+
def _reshard_with_ifrt(
    x: Any,
    sharding: jax.sharding.Sharding | Any,
    *,
    donate: bool,
    may_alias: bool | None,
) -> Any:
  """Reshards `x` to `sharding` using IFRT.

  Delegates to `_reshard` with `_ifrt_jax_array_reshard` selected as the
  per-array reshard function.

  Note: Resharding plan caching is not applicable to the IFRT implementation
  and is not supported by this function.

  Args:
    x: An array, scalar, or (nested) standard Python container thereof.
    sharding: A `Sharding` or a (nested) `Sharding` in standard Python container
      (must be a tree prefix of `x`), representing the device(s) and sharding to
      which `x` should be sharded to. The result will be committed to the
      device(s) of the sharding.
    donate: If `True`, donate all input arrays, which may reduce the amount of
      memory needed for resharding. Buffers donated to resharding should not be
      reused.
    may_alias: If `True`, may alias the input array with the output array. May
      reduce the amount of memory needed for resharding. Not used at the moment.

  Returns:
    A copy of `x` whose sharding is `sharding`.
  """
  resharded = _reshard(
      x,
      sharding,
      jax_array_reshard_fn=_ifrt_jax_array_reshard,
      donate=donate,
      may_alias=may_alias,
  )
  return resharded
255+
256+
201257
def reshard(
202258
x: Any,
203259
sharding: jax.sharding.Sharding | Any,
@@ -221,29 +277,34 @@ def reshard(
221277
reduce the amount of memory needed for resharding. Not used at the moment.
222278
cache_resharding_plans: If `True`, uses a resharding plan cache to avoid
223279
recreating plans for the same resharding operation. May improve
224-
performance for use cases where the same resharding operation is done many
225-
times. May degrade performance if most reshardings operations are
226-
different, since the cache will cause Pathways Components to remain loaded
227-
for each cached plan. `False` by default. Only used when IFRT resharding
228-
is not available.
280+
performance for use cases where the same resharding operation is done
281+
many times. May degrade performance if most reshardings operations are
282+
different, since the cache will cause Pathways Components to remain
283+
loaded for each cached plan. `False` by default. This parameter is only
284+
used when `pw_jax.ifrt_reshard_available()` is false.
229285
230286
Returns:
231287
A copy of `x` whose sharding is `sharding`.
232288
"""
233289
if pw_jax.ifrt_reshard_available():
234-
return _reshard(
290+
if cache_resharding_plans:
291+
warnings.warn(
292+
"`cache_resharding_plans` is only applicable when using the"
293+
" sidechannel resharding implementation, but IFRT resharding is"
294+
" available and will be used. The `cache_resharding_plans` argument"
295+
" will be ignored."
296+
)
297+
return _reshard_with_ifrt(
235298
x,
236299
sharding,
237300
donate=donate,
238301
may_alias=may_alias,
239-
jax_array_reshard_fn=_ifrt_jax_array_reshard,
240302
)
241303
else:
242-
return _reshard(
304+
return _reshard_with_sidechannel(
243305
x,
244306
sharding,
245307
donate=donate,
246308
may_alias=may_alias,
247-
jax_array_reshard_fn=_sidechannel_jax_array_reshard,
248309
cache_resharding_plans=cache_resharding_plans,
249310
)
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from collections.abc import Mapping
16+
import json
17+
from typing import Any
18+
from unittest import mock
19+
20+
from absl.testing import absltest
21+
from absl.testing import parameterized
22+
import jax
23+
import jax.numpy as jnp
24+
from pathwaysutils import jax as pw_jax
25+
from pathwaysutils import plugin_executable
26+
from pathwaysutils.experimental import reshard
27+
28+
29+
class ReshardTest(parameterized.TestCase):
  """Checks how `reshard.reshard` behaves on the IFRT and sidechannel paths.

  Each test forces one path by patching `pw_jax.ifrt_reshard_available` and
  replaces the backend entry point with a mock so calls can be inspected.
  """

  @parameterized.parameters(
      dict(reshard_kwargs={"donate": True}, expected_donate=True),
      dict(reshard_kwargs={"donate": False}, expected_donate=False),
      dict(reshard_kwargs={}, expected_donate=False),
  )
  def test_ifrt_reshard_donate(
      self, reshard_kwargs: Mapping[str, Any], expected_donate: bool
  ):
    """On the IFRT path, `donate` is handed to `transfer_to_shardings`."""
    array = jnp.array([1, 2])
    target = jax.sharding.SingleDeviceSharding(jax.devices()[0])

    transfer_mock = self.enter_context(
        mock.patch.object(pw_jax, "transfer_to_shardings", autospec=True)
    )
    self.enter_context(
        mock.patch.object(
            pw_jax, "ifrt_reshard_available", return_value=True, autospec=True
        )
    )

    reshard.reshard(array, target, **reshard_kwargs)

    # Signature: transfer_to_shardings(arrays, shardings, donate)
    transfer_mock.assert_called_with(mock.ANY, mock.ANY, expected_donate)

  @parameterized.parameters(
      dict(reshard_kwargs={"donate": True}, expected_donate=True),
      dict(reshard_kwargs={"donate": False}, expected_donate=False),
      dict(reshard_kwargs={}, expected_donate=False),
  )
  def test_sidechannel_reshard_donate(
      self, reshard_kwargs: Mapping[str, Any], expected_donate: bool
  ):
    """On the sidechannel path, `donate` lands in the JSON reshard request."""
    array = jnp.array([1, 2])
    target = jax.sharding.SingleDeviceSharding(jax.devices()[0])

    self.enter_context(
        mock.patch.object(
            pw_jax, "ifrt_reshard_available", return_value=False, autospec=True
        )
    )
    executable_cls = self.enter_context(
        mock.patch.object(plugin_executable, "PluginExecutable", autospec=True)
    )
    executable_cls.return_value.call.return_value = (
        [mock.Mock()],
        mock.Mock(),
    )

    reshard.reshard(array, target, **reshard_kwargs)

    executable_cls.assert_called()
    # The executable is expected to be built from exactly one positional
    # argument: the serialized JSON request.
    positional_args, _ = executable_cls.call_args
    (json_request,) = positional_args
    payload = json.loads(json_request)
    self.assertEqual(payload["reshardRequest"]["donateInput"], expected_donate)

  @parameterized.parameters(True, False, None)
  def test_ifrt_reshard_cache_resharding_plans(self, cache: bool | None):
    """On the IFRT path, a truthy `cache_resharding_plans` triggers a warning.

    `None` means the argument is omitted entirely so the default is exercised.
    """
    array = jnp.array([1, 2])
    target = jax.sharding.SingleDeviceSharding(jax.devices()[0])

    transfer_mock = self.enter_context(
        mock.patch.object(pw_jax, "transfer_to_shardings")
    )
    self.enter_context(
        mock.patch.object(pw_jax, "ifrt_reshard_available", return_value=True)
    )

    call_kwargs = {} if cache is None else {"cache_resharding_plans": cache}
    if cache:
      with self.assertWarnsRegex(
          UserWarning, "cache_resharding_plans` is only applicable"
      ):
        reshard.reshard(array, target, **call_kwargs)
    else:
      reshard.reshard(array, target, **call_kwargs)

    transfer_mock.assert_called_once()

  @parameterized.parameters(
      dict(cache=True, expected_cache=True),
      dict(cache=False, expected_cache=False),
      dict(cache=None, expected_cache=False),
  )
  def test_sidechannel_reshard_cache_resharding_plans(
      self, cache, expected_cache
  ):
    """On the sidechannel path, caching picks the cached plan getter.

    With caching on, `_get_resharding_plan_cached` is called once and
    `PluginExecutable` is not called; with caching off (or omitted, when
    `cache` is `None`) the counts are reversed.
    """
    array = jnp.array([1, 2])
    target = jax.sharding.SingleDeviceSharding(jax.devices()[0])

    self.enter_context(
        mock.patch.object(pw_jax, "ifrt_reshard_available", return_value=False)
    )
    executable_cls = self.enter_context(
        mock.patch.object(plugin_executable, "PluginExecutable")
    )
    executable_cls.return_value.call.return_value = (
        [mock.Mock()],
        mock.Mock(),
    )

    cached_plan_getter = self.enter_context(
        mock.patch.object(reshard, "_get_resharding_plan_cached")
    )

    call_kwargs = {} if cache is None else {"cache_resharding_plans": cache}
    reshard.reshard(array, target, **call_kwargs)

    if expected_cache:
      expected_executable_calls, expected_cached_lookups = 0, 1
    else:
      expected_executable_calls, expected_cached_lookups = 1, 0

    self.assertEqual(executable_cls.call_count, expected_executable_calls)

    self.assertEqual(
        cached_plan_getter.call_count,
        expected_cached_lookups,
    )
147+
# Run the absl test runner only when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()

0 commit comments

Comments
 (0)