Skip to content

Commit 5a69b3e

Browse files
Merge pull request #2863 from AI-Hypercomputer:chengnuojin-revert-sft
PiperOrigin-RevId: 847796543
2 parents c2574ab + 96e23e4 commit 5a69b3e

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

src/MaxText/rl/train_rl.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,7 @@ def get_maxtext_model(config, devices=None):
9494
# load_parameters_path=/path/to/your/output/directory/0/items
9595
"""
9696
model, mesh = model_creation_utils.create_nnx_model(config, devices=devices)
97-
with jax.set_mesh(mesh):
97+
with mesh:
9898
use_no_op_mappings = "maxtext_config" in config.vllm_additional_config
9999
tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings)
100100
tunix_model.config = None

src/MaxText/sft/sft_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
160160

161161
def train_model(mt_config, trainer, mesh):
162162
"""Runs the SFT training loop in Tunix."""
163-
with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules):
163+
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
164164
trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
165165
return trainer
166166

0 commit comments

Comments
 (0)