@@ -110,7 +110,6 @@ def _shard(x: np.ndarray) -> jax.Array:
110110 def partition_init (
111111 self , init_fn : CreateStateFn , * , abstract_batch : PyTree | None = None
112112 ) -> CreateStateFn :
113- # FIXED: Use 'with self.mesh'
114113 with self .mesh :
115114 if abstract_batch is not None :
116115 mesh_context .set_global_mesh (self .mesh )
@@ -122,7 +121,6 @@ def partition_init(
122121 init_fn = jax .jit (init_fn , out_shardings = self .state_sharding )
123122
124123 def _wrapped_init (batch : PyTree ) -> State :
125- # FIXED: Use 'with self.mesh'
126124 with self .mesh :
127125 state = init_fn (batch )
128126 state = _maybe_unbox_state (state )
@@ -136,7 +134,6 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
136134 jit_kws ["out_shardings" ] = (self .state_sharding , None )
137135 jit_kws ["donate_argnums" ] = (1 ,)
138136
139- # FIXED: Use 'with self.mesh' and legacy bridge
140137 with self .mesh :
141138 mesh_context .set_global_mesh (self .mesh )
142139 step_fn = jax .jit (
@@ -146,7 +143,6 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
146143 )
147144
148145 def _wrapped_step (batch : PyTree , state : State ) -> Any :
149- # FIXED: Use 'with self.mesh'
150146 with self .mesh :
151147 return step_fn (batch , state )
152148
@@ -238,9 +234,7 @@ def partition_init(
238234 " model parallel partitioner."
239235 )
240236
241- # FIXED: Use 'with self.mesh' directly
242237 with self .mesh :
243- # FIXED: Legacy bridge
244238 mesh_context .set_global_mesh (self .mesh )
245239 abstract_state = jax .eval_shape (init_fn , abstract_batch )
246240 specs = nn .get_partition_spec (abstract_state )
@@ -254,7 +248,6 @@ def partition_init(
254248 compiled_init_fn = jax .jit (init_fn , out_shardings = state_sharding )
255249
256250 def _init (batch : PyTree ) -> State :
257- # FIXED: Use 'with self.mesh' directly
258251 with self .mesh :
259252 state = compiled_init_fn (batch )
260253 state = _maybe_unbox_state (state )
@@ -273,7 +266,7 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
273266 else :
274267 jit_kws ["out_shardings" ] = None
275268
276- # FIXED: Use 'with self.mesh' directly and legacy bridge
269+
277270 with self .mesh :
278271 mesh_context .set_global_mesh (self .mesh )
279272 step_fn = jax .jit (
@@ -296,7 +289,6 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
296289 )
297290
298291 def _step (batch : PyTree , state : State ) -> Any :
299- # FIXED: Use 'with self.mesh' directly
300292 with self .mesh :
301293 return step_fn (batch , state )
302294
0 commit comments