Skip to content

Commit 4771380

Browse files
pschuhcopybara-github
authored andcommitted
Automated Code Change
PiperOrigin-RevId: 907289393
1 parent 6916a44 commit 4771380

4 files changed

Lines changed: 33 additions & 7 deletions

File tree

pathwaysutils/jax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def ifrt_reshard_available() -> bool:
104104

105105
transfer_to_shardings(
106106
[jax.numpy.array([0])],
107-
[jax.sharding.SingleDeviceSharding(jax.devices()[0])],
107+
[jax.sharding.make_single_device_sharding(jax.devices()[0])],
108108
)
109109

110110
except (ImportError, NameError, jax.errors.JaxRuntimeError):

pathwaysutils/profiling.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,13 @@ def stop_trace() -> None:
250250
and "xprofTraceOptions" in _profile_state.profile_request
251251
):
252252
out_avals = [jax.core.ShapedArray((1,), jnp.object_)]
253-
out_shardings = [jax.sharding.SingleDeviceSharding(jax.devices()[0])]
253+
out_shardings = [
254+
getattr(
255+
jax.sharding,
256+
"make_single_device_sharding",
257+
lambda x: jax.sharding.SingleDeviceSharding(x),
258+
)(jax.devices()[0])
259+
]
254260
else:
255261
out_avals = ()
256262
out_shardings = ()

pathwaysutils/test/experimental/reshard_test.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@ def test_sidechannel_reshard_donate(
3737
):
3838
x = jnp.array([1, 2])
3939
devices = jax.devices()
40-
sharding = jax.sharding.SingleDeviceSharding(devices[0])
40+
sharding = getattr(
41+
jax.sharding,
42+
"make_single_device_sharding",
43+
lambda x: jax.sharding.SingleDeviceSharding(x),
44+
)(devices[0])
4145

4246
mock_pe = self.enter_context(
4347
mock.patch.object(plugin_executable, "PluginExecutable", autospec=True)
@@ -64,7 +68,11 @@ def test_sidechannel_reshard_cache_resharding_plans(
6468
):
6569
x = jnp.array([1, 2])
6670
devices = jax.devices()
67-
sharding = jax.sharding.SingleDeviceSharding(devices[0])
71+
sharding = getattr(
72+
jax.sharding,
73+
"make_single_device_sharding",
74+
lambda x: jax.sharding.SingleDeviceSharding(x),
75+
)(devices[0])
6876

6977
mock_pe = self.enter_context(
7078
mock.patch.object(plugin_executable, "PluginExecutable")
@@ -92,7 +100,11 @@ def test_sidechannel_reshard_cache_resharding_plans(
92100
def test_sidechannel_reshard_pytree(self):
93101
x = {"a": jnp.array([1]), "b": [jnp.array([2])]}
94102
devices = jax.devices()
95-
sharding = jax.sharding.SingleDeviceSharding(devices[0])
103+
sharding = getattr(
104+
jax.sharding,
105+
"make_single_device_sharding",
106+
lambda x: jax.sharding.SingleDeviceSharding(x),
107+
)(devices[0])
96108
# Tree prefix sharding
97109
tree_sharding = {"a": sharding, "b": [sharding]}
98110

pathwaysutils/test/reshard_test.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,11 @@ def test_ifrt_reshard_donate(
4141
):
4242
x = jnp.array([1, 2])
4343
devices = jax.devices()
44-
sharding = jax.sharding.SingleDeviceSharding(devices[0])
44+
sharding = getattr(
45+
jax.sharding,
46+
"make_single_device_sharding",
47+
lambda x: jax.sharding.SingleDeviceSharding(x),
48+
)(devices[0])
4549

4650
mock_transfer = self.enter_context(
4751
mock.patch.object(pw_jax, "transfer_to_shardings", autospec=True)
@@ -60,7 +64,11 @@ def test_ifrt_reshard_donate(
6064
def test_ifrt_reshard_pytree(self):
6165
x = {"a": jnp.array([1]), "b": [jnp.array([2])]}
6266
devices = jax.devices()
63-
sharding = jax.sharding.SingleDeviceSharding(devices[0])
67+
sharding = getattr(
68+
jax.sharding,
69+
"make_single_device_sharding",
70+
lambda x: jax.sharding.SingleDeviceSharding(x),
71+
)(devices[0])
6472
# Tree prefix sharding
6573
tree_sharding = {"a": sharding, "b": [sharding]}
6674

0 commit comments

Comments
 (0)