Skip to content

Commit 5d45f2e

Browse files
Merge pull request #4034 from AI-Hypercomputer:with_mesh_upgrade
PiperOrigin-RevId: 927407231
2 parents (334f936 + b051575), commit 5d45f2e

9 files changed

Lines changed: 14 additions & 11 deletions

File tree

src/maxtext/examples/sft_llama3_demo_gpu.ipynb

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -547,7 +547,7 @@
547547
" positions = jnp.arange(seq_len)[None, :]\n",
548548
" attention_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))[None, :]\n",
549549
"\n",
550-
" with mesh:\n",
550+
" with jax.set_mesh(mesh):\n",
551551
" output = model(tokens, positions, None, attention_mask)\n",
552552
" logits = output[0] if isinstance(output, tuple) else output\n",
553553
"\n",

src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -185,7 +185,7 @@ def generate_and_save_data(config, local_args):
185185

186186
multihost_utils.sync_global_devices("start_generation_loop")
187187

188-
with mesh:
188+
with jax.set_mesh(mesh):
189189
if jax.process_index() == 0:
190190
max_logging.log(f"Starting Distributed Top-K generation loop for {config.steps - start_step} steps...")
191191

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -686,7 +686,7 @@ def train_distill(
686686

687687
# Hardware Execution (Safe Context)
688688
max_logging.log("Applying logical axis rules for model initialization and training...")
689-
with mesh, nn_partitioning.axis_rules(student_config.logical_axis_rules):
689+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(student_config.logical_axis_rules):
690690
# 2. Load Models
691691
if is_offline:
692692
max_logging.log("Offline Distillation: Skipping Teacher Model loading.")

src/maxtext/trainers/post_train/dpo/train_dpo.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -159,7 +159,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
159159

160160
def train_model(mt_config: MaxTextConfig, trainer, mesh):
161161
"""Runs the DPO training loop in Tunix."""
162-
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
162+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules):
163163
trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
164164
return trainer
165165

src/maxtext/trainers/post_train/sft/train_sft.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -263,7 +263,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
263263

264264
def train_model(mt_config, trainer, mesh):
265265
"""Runs the SFT training loop in Tunix."""
266-
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
266+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules):
267267
trainer.train(
268268
trainer.data_hooks.train_data_iterator,
269269
trainer.data_hooks.eval_data_iterator,

src/maxtext/utils/lora_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -539,7 +539,7 @@ def apply_lora_to_model(
539539
)
540540

541541
if mesh is not None:
542-
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
542+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules):
543543
graph_def, state = nnx.split(lora_model)
544544

545545
# We handle explicit replication for LoRA to ensure safety and efficiency.

tests/integration/deepseek_scan_engram_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -145,7 +145,7 @@ def batch_decode(self, token_ids, *args, **kwargs):
145145

146146
shared_embedding = DummyEmbedding(emb_dim=config.emb_dim)
147147

148-
with mesh, jax.disable_jit():
148+
with jax.set_mesh(mesh), jax.disable_jit():
149149
variables = decoder.init(
150150
{"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1), "aqt": jax.random.PRNGKey(2)},
151151
shared_embedding=shared_embedding,

tests/integration/diloco_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ def test_diloco_training_simulation_with_mesh(self):
7979
]
8080
)
8181

82-
with mesh:
82+
with jax.set_mesh(mesh):
8383
tx = optax.sgd(learning_rate=0.1)
8484
rngs = nnx.Rngs(params=jax.random.key(seed=42))
8585
model = SimpleNNXModel(rngs=rngs)

tests/utils/attention_test_util.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -196,15 +196,18 @@ def forward_with_context_expert_parallelism(
196196
"inputs_segmentation": decoder_segment_ids,
197197
"inputs_position": decoder_positions,
198198
}
199-
with mesh_cp:
199+
# jax.set_mesh requires all sharding constraints inside the block to reference devices in the context mesh.
200+
with jax.set_mesh(mesh_cp):
201+
replicated = NamedSharding(mesh_cp, P())
202+
replicated_batch = {k: jax.device_put(v, replicated) for k, v in batch.items()}
200203
reordered_batch = maxtext_utils.get_reorder_callable(
201204
context_parallel_size, ShardMode.AUTO, hardware=cfg_cp.hardware
202-
)(batch)
205+
)(replicated_batch)
203206
lnx = reordered_batch["inputs"]
204207
decoder_segment_ids = reordered_batch["inputs_segmentation"]
205208
decoder_positions = reordered_batch["inputs_position"]
206209
# apply attention with sharding
207-
with mesh_cp, nn_partitioning.axis_rules(cfg_cp.logical_axis_rules):
210+
with jax.set_mesh(mesh_cp), nn_partitioning.axis_rules(cfg_cp.logical_axis_rules):
208211
batch_axis = "activation_batch"
209212
length_axis = "activation_length"
210213
lnx_spec = nn_partitioning.logical_to_mesh_axes(

0 commit comments

Comments
 (0)