Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/maxtext/examples/sft_llama3_demo_gpu.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@
" positions = jnp.arange(seq_len)[None, :]\n",
" attention_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))[None, :]\n",
"\n",
" with mesh:\n",
" with jax.set_mesh(mesh):\n",
" output = model(tokens, positions, None, attention_mask)\n",
" logits = output[0] if isinstance(output, tuple) else output\n",
"\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def generate_and_save_data(config, local_args):

multihost_utils.sync_global_devices("start_generation_loop")

with mesh:
with jax.set_mesh(mesh):
if jax.process_index() == 0:
max_logging.log(f"Starting Distributed Top-K generation loop for {config.steps - start_step} steps...")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def train_distill(

# Hardware Execution (Safe Context)
max_logging.log("Applying logical axis rules for model initialization and training...")
with mesh, nn_partitioning.axis_rules(student_config.logical_axis_rules):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(student_config.logical_axis_rules):
# 2. Load Models
if is_offline:
max_logging.log("Offline Distillation: Skipping Teacher Model loading.")
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/trainers/post_train/dpo/train_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None):

def train_model(mt_config: MaxTextConfig, trainer, mesh):
"""Runs the DPO training loop in Tunix."""
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules):
trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
return trainer

Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/trainers/post_train/sft/train_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None):

def train_model(mt_config, trainer, mesh):
"""Runs the SFT training loop in Tunix."""
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules):
trainer.train(
trainer.data_hooks.train_data_iterator,
trainer.data_hooks.eval_data_iterator,
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/utils/lora_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def apply_lora_to_model(
)

if mesh is not None:
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules):
graph_def, state = nnx.split(lora_model)

# We handle explicit replication for LoRA to ensure safety and efficiency.
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/deepseek_scan_engram_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def batch_decode(self, token_ids, *args, **kwargs):

shared_embedding = DummyEmbedding(emb_dim=config.emb_dim)

with mesh, jax.disable_jit():
with jax.set_mesh(mesh), jax.disable_jit():
variables = decoder.init(
{"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1), "aqt": jax.random.PRNGKey(2)},
shared_embedding=shared_embedding,
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/diloco_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_diloco_training_simulation_with_mesh(self):
]
)

with mesh:
with jax.set_mesh(mesh):
tx = optax.sgd(learning_rate=0.1)
rngs = nnx.Rngs(params=jax.random.key(seed=42))
model = SimpleNNXModel(rngs=rngs)
Expand Down
7 changes: 5 additions & 2 deletions tests/utils/attention_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,15 +196,18 @@ def forward_with_context_expert_parallelism(
"inputs_segmentation": decoder_segment_ids,
"inputs_position": decoder_positions,
}
with mesh_cp:
# jax.set_mesh requires all sharding constraints inside the block to reference devices in the context mesh
with jax.set_mesh(mesh_cp):
replicated = NamedSharding(mesh_cp, P())
batch = {k: jax.device_put(v, replicated) for k, v in batch.items()}
reordered_batch = maxtext_utils.get_reorder_callable(
context_parallel_size, ShardMode.AUTO, hardware=cfg_cp.hardware
)(batch)
lnx = reordered_batch["inputs"]
decoder_segment_ids = reordered_batch["inputs_segmentation"]
decoder_positions = reordered_batch["inputs_position"]
# apply attention with sharding
with mesh_cp, nn_partitioning.axis_rules(cfg_cp.logical_axis_rules):
with jax.set_mesh(mesh_cp), nn_partitioning.axis_rules(cfg_cp.logical_axis_rules):
batch_axis = "activation_batch"
length_axis = "activation_length"
lnx_spec = nn_partitioning.logical_to_mesh_axes(
Expand Down
Loading