1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+
1415"""Utilities for partitioning."""
1516
1617import abc
2223import jax
2324import numpy as np
2425
25-
2626PyTree = Any
2727State = Any
2828CreateStateFn = Callable [[PyTree ], State ]
@@ -67,7 +67,8 @@ class DataParallelPartitioner(Partitioner):
6767 """Data parallel partitioner."""
6868
6969 def __init__ (self , data_axis : str = "batch" ):
70- self .mesh = jax .make_mesh ((jax .device_count (),), (data_axis ,))
70+ devices = jax .devices ()
71+ self .mesh = jax .sharding .Mesh (devices , (data_axis ,))
7172 self .data_sharding = jax .sharding .NamedSharding (
7273 self .mesh , jax .sharding .PartitionSpec (data_axis )
7374 )
@@ -107,7 +108,7 @@ def _shard(x: np.ndarray) -> jax.Array:
107108 def partition_init (
108109 self , init_fn : CreateStateFn , * , abstract_batch : PyTree | None = None
109110 ) -> CreateStateFn :
110- with jax . sharding . use_mesh ( self .mesh ) :
111+ with self .mesh :
111112 if abstract_batch is not None :
112113 abstract_state = jax .eval_shape (init_fn , abstract_batch )
113114 specs = nn .get_partition_spec (abstract_state )
@@ -117,7 +118,7 @@ def partition_init(
117118 init_fn = jax .jit (init_fn , out_shardings = self .state_sharding )
118119
119120 def _wrapped_init (batch : PyTree ) -> State :
120- with jax . sharding . use_mesh ( self .mesh ) :
121+ with self .mesh :
121122 state = init_fn (batch )
122123 state = _maybe_unbox_state (state )
123124 return state
@@ -130,15 +131,15 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
130131 jit_kws ["out_shardings" ] = (self .state_sharding , None )
131132 jit_kws ["donate_argnums" ] = (1 ,)
132133
133- with jax . sharding . use_mesh ( self .mesh ) :
134+ with self .mesh :
134135 step_fn = jax .jit (
135136 fn ,
136137 in_shardings = (self .data_sharding , self .state_sharding ),
137138 ** jit_kws ,
138139 )
139140
140141 def _wrapped_step (batch : PyTree , state : State ) -> Any :
141- with jax . sharding . use_mesh ( self .mesh ) :
142+ with self .mesh :
142143 return step_fn (batch , state )
143144
144145 return _wrapped_step
@@ -190,7 +191,7 @@ def __init__(
190191 if axis_sizes [0 ] == - 1 :
191192 axis_sizes [0 ] = len (devices ) // math .prod (axis_sizes [1 :])
192193
193- self .mesh = jax .make_mesh ( axis_sizes , axis_names , devices = devices )
194+ self.mesh = jax.sharding.Mesh(np.asarray(devices).reshape(axis_sizes), axis_names)
194195 self .rules = rules
195196 self .aot_compile = aot_compile
196197 self .options = options
@@ -213,12 +214,6 @@ def __init__(
213214 self .abstract_batch = None
214215 self .abstract_state = None
215216
216- @property
217- def mesh_context_manager (
218- self ,
219- ) -> Callable [[jax .sharding .Mesh ], ContextManager [None ]]:
220- return jax .sharding .use_mesh
221-
222217 def shard_inputs (self , inputs : PyTree ) -> PyTree :
223218 def _shard (x : np .ndarray ) -> jax .Array :
224219 return jax .make_array_from_process_local_data (self .data_sharding , x )
@@ -234,7 +229,7 @@ def partition_init(
234229 " model parallel partitioner."
235230 )
236231
237- with self .mesh_context_manager ( self . mesh ) :
232+ with self .mesh :
238233 abstract_state = jax .eval_shape (init_fn , abstract_batch )
239234 specs = nn .get_partition_spec (abstract_state )
240235
@@ -247,7 +242,7 @@ def partition_init(
247242 compiled_init_fn = jax .jit (init_fn , out_shardings = state_sharding )
248243
249244 def _init (batch : PyTree ) -> State :
250- with self .mesh_context_manager ( self . mesh ) :
245+ with self .mesh :
251246 state = compiled_init_fn (batch )
252247 state = _maybe_unbox_state (state )
253248 return state
@@ -265,7 +260,8 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
265260 else :
266261 jit_kws ["out_shardings" ] = None
267262
268- with self .mesh_context_manager (self .mesh ):
263+
264+ with self .mesh :
269265 step_fn = jax .jit (
270266 fn ,
271267 in_shardings = (self .data_sharding , self .state_sharding ),
@@ -286,7 +282,7 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
286282 )
287283
288284 def _step (batch : PyTree , state : State ) -> Any :
289- with self .mesh_context_manager ( self . mesh ) :
285+ with self .mesh :
290286 return step_fn (batch , state )
291287
292288 return _step
@@ -302,4 +298,4 @@ def _maybe_unbox(x: Any) -> Any:
302298 _maybe_unbox ,
303299 x ,
304300 is_leaf = lambda k : isinstance (k , nn .Partitioned ),
305- )
301+ )
0 commit comments