File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -86,7 +86,7 @@ def ifrt_reshard_available() -> bool:
8686
8787 transfer_to_shardings (
8888 [jax .numpy .array ([0 ])],
89- [jax .sharding .SingleDeviceSharding (jax .devices ()[0 ])],
89+ [jax .sharding .make_single_device_sharding (jax .devices ()[0 ])],
9090 )
9191
9292 except (ImportError , NameError , jax .errors .JaxRuntimeError ):
Original file line number Diff line number Diff line change @@ -250,7 +250,9 @@ def stop_trace() -> None:
250250 and "xprofTraceOptions" in _profile_state .profile_request
251251 ):
252252 out_avals = [jax .core .ShapedArray ((1 ,), jnp .object_ )]
253- out_shardings = [jax .sharding .SingleDeviceSharding (jax .devices ()[0 ])]
253+ out_shardings = [
254+ jax .sharding .make_single_device_sharding (jax .devices ()[0 ])
255+ ]
254256 else :
255257 out_avals = ()
256258 out_shardings = ()
Original file line number Diff line number Diff line change @@ -37,7 +37,7 @@ def test_sidechannel_reshard_donate(
3737 ):
3838 x = jnp .array ([1 , 2 ])
3939 devices = jax .devices ()
40- sharding = jax .sharding .SingleDeviceSharding (devices [0 ])
40+ sharding = jax .sharding .make_single_device_sharding (devices [0 ])
4141
4242 mock_pe = self .enter_context (
4343 mock .patch .object (plugin_executable , "PluginExecutable" , autospec = True )
@@ -64,7 +64,7 @@ def test_sidechannel_reshard_cache_resharding_plans(
6464 ):
6565 x = jnp .array ([1 , 2 ])
6666 devices = jax .devices ()
67- sharding = jax .sharding .SingleDeviceSharding (devices [0 ])
67+ sharding = jax .sharding .make_single_device_sharding (devices [0 ])
6868
6969 mock_pe = self .enter_context (
7070 mock .patch .object (plugin_executable , "PluginExecutable" )
@@ -92,7 +92,7 @@ def test_sidechannel_reshard_cache_resharding_plans(
9292 def test_sidechannel_reshard_pytree (self ):
9393 x = {"a" : jnp .array ([1 ]), "b" : [jnp .array ([2 ])]}
9494 devices = jax .devices ()
95- sharding = jax .sharding .SingleDeviceSharding (devices [0 ])
95+ sharding = jax .sharding .make_single_device_sharding (devices [0 ])
9696 # Tree prefix sharding
9797 tree_sharding = {"a" : sharding , "b" : [sharding ]}
9898
Original file line number Diff line number Diff line change @@ -41,7 +41,7 @@ def test_ifrt_reshard_donate(
4141 ):
4242 x = jnp .array ([1 , 2 ])
4343 devices = jax .devices ()
44- sharding = jax .sharding .SingleDeviceSharding (devices [0 ])
44+ sharding = jax .sharding .make_single_device_sharding (devices [0 ])
4545
4646 mock_transfer = self .enter_context (
4747 mock .patch .object (pw_jax , "transfer_to_shardings" , autospec = True )
@@ -60,7 +60,7 @@ def test_ifrt_reshard_donate(
6060 def test_ifrt_reshard_pytree (self ):
6161 x = {"a" : jnp .array ([1 ]), "b" : [jnp .array ([2 ])]}
6262 devices = jax .devices ()
63- sharding = jax .sharding .SingleDeviceSharding (devices [0 ])
63+ sharding = jax .sharding .make_single_device_sharding (devices [0 ])
6464 # Tree prefix sharding
6565 tree_sharding = {"a" : sharding , "b" : [sharding ]}
6666
You can’t perform that action at this time.
0 commit comments