@@ -37,7 +37,11 @@ def test_sidechannel_reshard_donate(
3737 ):
3838 x = jnp .array ([1 , 2 ])
3939 devices = jax .devices ()
40- sharding = jax .sharding .SingleDeviceSharding (devices [0 ])
40+ sharding = getattr (
41+ jax .sharding ,
42+ "make_single_device_sharding" ,
43+ lambda x : jax .sharding .SingleDeviceSharding (x ),
44+ )(devices [0 ])
4145
4246 mock_pe = self .enter_context (
4347 mock .patch .object (plugin_executable , "PluginExecutable" , autospec = True )
@@ -64,7 +68,11 @@ def test_sidechannel_reshard_cache_resharding_plans(
6468 ):
6569 x = jnp .array ([1 , 2 ])
6670 devices = jax .devices ()
67- sharding = jax .sharding .SingleDeviceSharding (devices [0 ])
71+ sharding = getattr (
72+ jax .sharding ,
73+ "make_single_device_sharding" ,
74+ lambda x : jax .sharding .SingleDeviceSharding (x ),
75+ )(devices [0 ])
6876
6977 mock_pe = self .enter_context (
7078 mock .patch .object (plugin_executable , "PluginExecutable" )
@@ -92,7 +100,11 @@ def test_sidechannel_reshard_cache_resharding_plans(
92100 def test_sidechannel_reshard_pytree (self ):
93101 x = {"a" : jnp .array ([1 ]), "b" : [jnp .array ([2 ])]}
94102 devices = jax .devices ()
95- sharding = jax .sharding .SingleDeviceSharding (devices [0 ])
103+ sharding = getattr (
104+ jax .sharding ,
105+ "make_single_device_sharding" ,
106+ lambda x : jax .sharding .SingleDeviceSharding (x ),
107+ )(devices [0 ])
96108 # Tree prefix sharding
97109 tree_sharding = {"a" : sharding , "b" : [sharding ]}
98110
0 commit comments