diff --git a/pathwaysutils/jax/__init__.py b/pathwaysutils/jax/__init__.py index e5bc106..5f30285 100644 --- a/pathwaysutils/jax/__init__.py +++ b/pathwaysutils/jax/__init__.py @@ -86,7 +86,7 @@ def ifrt_reshard_available() -> bool: transfer_to_shardings( [jax.numpy.array([0])], - [jax.sharding.SingleDeviceSharding(jax.devices()[0])], + [jax.sharding.make_single_device_sharding(jax.devices()[0])], ) except (ImportError, NameError, jax.errors.JaxRuntimeError): diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index b4f4378..35bfb26 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -250,7 +250,9 @@ def stop_trace() -> None: and "xprofTraceOptions" in _profile_state.profile_request ): out_avals = [jax.core.ShapedArray((1,), jnp.object_)] - out_shardings = [jax.sharding.SingleDeviceSharding(jax.devices()[0])] + out_shardings = [ + jax.sharding.make_single_device_sharding(jax.devices()[0]) + ] else: out_avals = () out_shardings = () diff --git a/pathwaysutils/test/experimental/reshard_test.py b/pathwaysutils/test/experimental/reshard_test.py index 79c0434..f295403 100644 --- a/pathwaysutils/test/experimental/reshard_test.py +++ b/pathwaysutils/test/experimental/reshard_test.py @@ -37,7 +37,7 @@ def test_sidechannel_reshard_donate( ): x = jnp.array([1, 2]) devices = jax.devices() - sharding = jax.sharding.SingleDeviceSharding(devices[0]) + sharding = jax.sharding.make_single_device_sharding(devices[0]) mock_pe = self.enter_context( mock.patch.object(plugin_executable, "PluginExecutable", autospec=True) @@ -64,7 +64,7 @@ def test_sidechannel_reshard_cache_resharding_plans( ): x = jnp.array([1, 2]) devices = jax.devices() - sharding = jax.sharding.SingleDeviceSharding(devices[0]) + sharding = jax.sharding.make_single_device_sharding(devices[0]) mock_pe = self.enter_context( mock.patch.object(plugin_executable, "PluginExecutable") @@ -92,7 +92,7 @@ def test_sidechannel_reshard_cache_resharding_plans( def test_sidechannel_reshard_pytree(self): x = {"a": jnp.array([1]), "b": [jnp.array([2])]} devices = jax.devices() - sharding = jax.sharding.SingleDeviceSharding(devices[0]) + sharding = jax.sharding.make_single_device_sharding(devices[0]) # Tree prefix sharding tree_sharding = {"a": sharding, "b": [sharding]} diff --git a/pathwaysutils/test/reshard_test.py b/pathwaysutils/test/reshard_test.py index 6c80221..ab41aa7 100644 --- a/pathwaysutils/test/reshard_test.py +++ b/pathwaysutils/test/reshard_test.py @@ -41,7 +41,7 @@ def test_ifrt_reshard_donate( ): x = jnp.array([1, 2]) devices = jax.devices() - sharding = jax.sharding.SingleDeviceSharding(devices[0]) + sharding = jax.sharding.make_single_device_sharding(devices[0]) mock_transfer = self.enter_context( mock.patch.object(pw_jax, "transfer_to_shardings", autospec=True) @@ -60,7 +60,7 @@ def test_ifrt_reshard_donate( def test_ifrt_reshard_pytree(self): x = {"a": jnp.array([1]), "b": [jnp.array([2])]} devices = jax.devices() - sharding = jax.sharding.SingleDeviceSharding(devices[0]) + sharding = jax.sharding.make_single_device_sharding(devices[0]) # Tree prefix sharding tree_sharding = {"a": sharding, "b": [sharding]}