Skip to content

Commit 6c139c3

Browse files
pschuhcopybara-github
authored andcommitted
Automated Code Change
PiperOrigin-RevId: 896750624
1 parent 66e9754 commit 6c139c3

4 files changed

Lines changed: 9 additions & 7 deletions

File tree

pathwaysutils/jax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def ifrt_reshard_available() -> bool:
8686

8787
transfer_to_shardings(
8888
[jax.numpy.array([0])],
89-
[jax.sharding.SingleDeviceSharding(jax.devices()[0])],
89+
[jax.sharding.make_single_device_sharding(jax.devices()[0])],
9090
)
9191

9292
except (ImportError, NameError, jax.errors.JaxRuntimeError):

pathwaysutils/profiling.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,9 @@ def stop_trace() -> None:
250250
and "xprofTraceOptions" in _profile_state.profile_request
251251
):
252252
out_avals = [jax.core.ShapedArray((1,), jnp.object_)]
253-
out_shardings = [jax.sharding.SingleDeviceSharding(jax.devices()[0])]
253+
out_shardings = [
254+
jax.sharding.make_single_device_sharding(jax.devices()[0])
255+
]
254256
else:
255257
out_avals = ()
256258
out_shardings = ()

pathwaysutils/test/experimental/reshard_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_sidechannel_reshard_donate(
3737
):
3838
x = jnp.array([1, 2])
3939
devices = jax.devices()
40-
sharding = jax.sharding.SingleDeviceSharding(devices[0])
40+
sharding = jax.sharding.make_single_device_sharding(devices[0])
4141

4242
mock_pe = self.enter_context(
4343
mock.patch.object(plugin_executable, "PluginExecutable", autospec=True)
@@ -64,7 +64,7 @@ def test_sidechannel_reshard_cache_resharding_plans(
6464
):
6565
x = jnp.array([1, 2])
6666
devices = jax.devices()
67-
sharding = jax.sharding.SingleDeviceSharding(devices[0])
67+
sharding = jax.sharding.make_single_device_sharding(devices[0])
6868

6969
mock_pe = self.enter_context(
7070
mock.patch.object(plugin_executable, "PluginExecutable")
@@ -92,7 +92,7 @@ def test_sidechannel_reshard_cache_resharding_plans(
9292
def test_sidechannel_reshard_pytree(self):
9393
x = {"a": jnp.array([1]), "b": [jnp.array([2])]}
9494
devices = jax.devices()
95-
sharding = jax.sharding.SingleDeviceSharding(devices[0])
95+
sharding = jax.sharding.make_single_device_sharding(devices[0])
9696
# Tree prefix sharding
9797
tree_sharding = {"a": sharding, "b": [sharding]}
9898

pathwaysutils/test/reshard_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_ifrt_reshard_donate(
4141
):
4242
x = jnp.array([1, 2])
4343
devices = jax.devices()
44-
sharding = jax.sharding.SingleDeviceSharding(devices[0])
44+
sharding = jax.sharding.make_single_device_sharding(devices[0])
4545

4646
mock_transfer = self.enter_context(
4747
mock.patch.object(pw_jax, "transfer_to_shardings", autospec=True)
@@ -60,7 +60,7 @@ def test_ifrt_reshard_donate(
6060
def test_ifrt_reshard_pytree(self):
6161
x = {"a": jnp.array([1]), "b": [jnp.array([2])]}
6262
devices = jax.devices()
63-
sharding = jax.sharding.SingleDeviceSharding(devices[0])
63+
sharding = jax.sharding.make_single_device_sharding(devices[0])
6464
# Tree prefix sharding
6565
tree_sharding = {"a": sharding, "b": [sharding]}
6666

0 commit comments

Comments
 (0)