Skip to content

Commit 2e9fd66

Browse files
committed
Replace 'with mesh' with 'with jax.set_mesh(mesh)'
1 parent 0747df9 commit 2e9fd66

9 files changed

Lines changed: 10 additions & 10 deletions

File tree

src/maxtext/examples/sft_llama3_demo_gpu.ipynb

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -547,7 +547,7 @@
547547
" positions = jnp.arange(seq_len)[None, :]\n",
548548
" attention_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))[None, :]\n",
549549
"\n",
550-
" with mesh:\n",
550+
" with jax.set_mesh(mesh):\n",
551551
" output = model(tokens, positions, None, attention_mask)\n",
552552
" logits = output[0] if isinstance(output, tuple) else output\n",
553553
"\n",

src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -185,7 +185,7 @@ def generate_and_save_data(config, local_args):
185185

186186
multihost_utils.sync_global_devices("start_generation_loop")
187187

188-
with mesh:
188+
with jax.set_mesh(mesh):
189189
if jax.process_index() == 0:
190190
max_logging.log(f"Starting Distributed Top-K generation loop for {config.steps - start_step} steps...")
191191

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -687,7 +687,7 @@ def train_distill(
687687

688688
# Hardware Execution (Safe Context)
689689
max_logging.log("Applying logical axis rules for model initialization and training...")
690-
with mesh, nn_partitioning.axis_rules(student_config.logical_axis_rules):
690+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(student_config.logical_axis_rules):
691691
# 2. Load Models
692692
if is_offline:
693693
max_logging.log("Offline Distillation: Skipping Teacher Model loading.")

src/maxtext/trainers/post_train/dpo/train_dpo.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -145,7 +145,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
145145

146146
def train_model(mt_config: MaxTextConfig, trainer, mesh):
147147
"""Runs the DPO training loop in Tunix."""
148-
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
148+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules):
149149
trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
150150
return trainer
151151

src/maxtext/trainers/post_train/sft/train_sft.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -263,7 +263,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
263263

264264
def train_model(mt_config, trainer, mesh):
265265
"""Runs the SFT training loop in Tunix."""
266-
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
266+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules):
267267
trainer.train(
268268
trainer.data_hooks.train_data_iterator,
269269
trainer.data_hooks.eval_data_iterator,

src/maxtext/utils/lora_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -504,7 +504,7 @@ def apply_lora_to_model(
504504
)
505505

506506
if mesh is not None:
507-
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
507+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules):
508508
graph_def, state = nnx.split(lora_model)
509509

510510
# We handle explicit replication for LoRA to ensure safety and efficiency.

tests/integration/deepseek_scan_engram_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def batch_decode(self, token_ids, *args, **kwargs):
145145

146146
shared_embedding = DummyEmbedding(emb_dim=config.emb_dim)
147147

148-
with mesh, jax.disable_jit():
148+
with jax.set_mesh(mesh), jax.disable_jit():
149149
variables = decoder.init(
150150
{"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1), "aqt": jax.random.PRNGKey(2)},
151151
shared_embedding=shared_embedding,

tests/integration/diloco_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ def test_diloco_training_simulation_with_mesh(self):
7979
]
8080
)
8181

82-
with mesh:
82+
with jax.set_mesh(mesh):
8383
tx = optax.sgd(learning_rate=0.1)
8484
rngs = nnx.Rngs(params=jax.random.key(seed=42))
8585
model = SimpleNNXModel(rngs=rngs)

tests/utils/attention_test_util.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -196,15 +196,15 @@ def forward_with_context_expert_parallelism(
196196
"inputs_segmentation": decoder_segment_ids,
197197
"inputs_position": decoder_positions,
198198
}
199-
with mesh_cp:
199+
with jax.set_mesh(mesh_cp):
200200
reordered_batch = maxtext_utils.get_reorder_callable(
201201
context_parallel_size, ShardMode.AUTO, hardware=cfg_cp.hardware
202202
)(batch)
203203
lnx = reordered_batch["inputs"]
204204
decoder_segment_ids = reordered_batch["inputs_segmentation"]
205205
decoder_positions = reordered_batch["inputs_position"]
206206
# apply attention with sharding
207-
with mesh_cp, nn_partitioning.axis_rules(cfg_cp.logical_axis_rules):
207+
with jax.set_mesh(mesh_cp), nn_partitioning.axis_rules(cfg_cp.logical_axis_rules):
208208
batch_axis = "activation_batch"
209209
length_axis = "activation_length"
210210
lnx_spec = nn_partitioning.logical_to_mesh_axes(

0 commit comments

Comments
 (0)