Skip to content

Commit 6a80012

Browse files
committed
Surbhi's changes
1 parent 6639254 commit 6a80012

3 files changed

Lines changed: 172 additions & 172 deletions

File tree

src/maxtext/trainers/post_train/rl/evaluate_rl.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ def score_responses(tmvp_config, question, responses, answers):
101101
Tuple of (is_correct, is_partially_correct, has_correct_format)
102102
"""
103103

104-
answers = list(dict.fromkeys(answers))
105104
if tmvp_config.debug.rl:
106105
max_logging.log("========================================")
107106
max_logging.log(f"Evaluation Question: {question}")
@@ -114,7 +113,6 @@ def score_responses(tmvp_config, question, responses, answers):
114113
has_correct_format = False
115114

116115
for response in responses:
117-
# Check format correctness (requires the full <reasoning>...</reasoning><answer>...</answer> structure)
118116
match_format = utils_rl.get_match_format_regex(tmvp_config)
119117
if match_format.search(response) is not None:
120118
has_correct_format = True
@@ -182,9 +180,8 @@ def evaluate(
182180

183181
# Score each question-answer pair
184182
for question, responses, answer in zip(questions, multiple_call_responses, answers):
185-
answer = (
186-
json.loads(answer) if isinstance(answer, str) else answer
187-
) # decode the json-encoded list of acceptable answers
183+
# decode the json-encoded list of acceptable answers
184+
answer = list(dict.fromkeys(json.loads(answer)))
188185
is_correct, is_partially_correct, has_correct_format = score_responses(
189186
tmvp_config=tmvp_config,
190187
question=question,

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 87 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,23 @@
6464
from orbax import checkpoint as ocp
6565
from pprint import pprint
6666
from transformers import AutoTokenizer
67+
import functools
6768
from tunix.rl import rl_cluster as rl_cluster_lib
6869
from tunix.rl.rollout import base_rollout
6970
from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner
7071
from tunix.rl.agentic.agentic_grpo_learner import GrpoConfig as AgenticGrpoConfig, GrpoLearner as AgenticGrpoLearner
7172
from tunix.sft import metrics_logger, profiler
7273

7374
# for vLLM we can skip JAX precompilation with this flag, it makes startup faster
74-
os.environ["SKIP_JAX_PRECOMPILE"] = "1"
75+
os.environ["SKIP_JAX_PRECOMPILE"] = "0"
76+
os.environ["PHASED_PROFILING_DIR"] = "gs://mazumdera-bucket-tpu-prod-env-automated/qwen3-30b/0403"
77+
os.environ["TOKENIZERS_PARALLELISM"] = "0"
78+
# os.environ["TUNIX_DEBUG_REWARDS"] = "1"
7579

7680
from maxtext.configs import pyconfig
7781
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
7882
from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter
83+
from maxtext.integration.vllm.maxtext_vllm_rollout import MaxTextVllmRollout
7984
from maxtext.trainers.post_train.rl.evaluate_rl import evaluate
8085
from maxtext.trainers.post_train.rl import utils_rl
8186
from maxtext.input_pipeline.instruction_data_processing import load_template_from_file
@@ -86,10 +91,10 @@ def get_maxtext_model(config, devices=None):
8691
"""
8792
Load MaxText model with Tunix adapter.
8893
# Note: pass the path to your scanned checkpoint for 'load_parameters_path'.
89-
# To create a scanned checkpoint, you can use /maxtext/src/maxtext/checkpoint_conversion/to_maxtext.py and if
94+
# To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py and if
9095
# using Pathways, please set `USE_PATHWAYS=1` and use `$((1 - USE_PATHWAYS))` for storage flags:
9196
# export USE_PATHWAYS=1
92-
# python src/maxtext/checkpoint_conversion/to_maxtext.py \
97+
# python src/MaxText/checkpoint_conversion/to_maxtext.py \
9398
# --model_name="gemma2-2b" \
9499
# --base_output_directory="/path/to/your/output/directory" \
95100
# --scan_layers=True \
@@ -116,18 +121,29 @@ def get_dataset(
116121
if dataset_name is None:
117122
raise ValueError("dataset_name must be provided")
118123

119-
import datasets # pylint: disable=import-outside-toplevel
124+
if dataset_name.startswith("huggingface:"):
125+
import datasets # pylint: disable=import-outside-toplevel
120126

121-
if data_files is None:
122-
data = datasets.load_dataset(dataset_name, split=split, cache_dir=data_dir)
123-
if tmvp_config.debug.rl:
124-
max_logging.log(f"Loaded Hugging Face dataset {dataset_name} with split {split}. Size: {len(data)}")
125-
else: # data_files have been provided, useful for using slices of large datasets like nvidia/OpenMathInstruct-2
126-
data = datasets.load_dataset(
127-
"parquet",
128-
data_files={tmvp_config.train_split: data_files},
127+
if data_files is None:
128+
hf_dataset_name = dataset_name.replace("huggingface:", "")
129+
data = datasets.load_dataset(hf_dataset_name, split=split, cache_dir=data_dir)
130+
if tmvp_config.debug.rl:
131+
max_logging.log(f"Loaded Hugging Face dataset {hf_dataset_name} with split {split}. Size: {len(data)}")
132+
else: # data_files have been provided, useful for using slices of large datasets like nvidia/OpenMathInstruct-2
133+
data = datasets.load_dataset(
134+
"parquet",
135+
data_files={tmvp_config.train_split: data_files},
136+
split=split,
137+
cache_dir=data_dir,
138+
)
139+
else:
140+
builder_kwargs = {"file_format": tfds.core.FileFormat.ARRAY_RECORD}
141+
data = tfds.data_source(
142+
dataset_name,
129143
split=split,
130-
cache_dir=data_dir,
144+
data_dir=data_dir,
145+
builder_kwargs=builder_kwargs,
146+
download=True,
131147
)
132148

133149
template_config = load_template_from_file(tmvp_config.chat_template_path)
@@ -284,37 +300,6 @@ def get_max_train_steps(trainer_config):
284300
)
285301

286302

287-
def prepare_train_and_eval_dataset(
288-
trainer_config,
289-
seed: int = 42,
290-
test_size: float = 0.05,
291-
):
292-
"""Load and split the dataset into train and validation sets using HF's train_test_split."""
293-
import datasets # pylint: disable=import-outside-toplevel
294-
295-
max_logging.log(
296-
"WARNING: For reproducible experiments, preprocess the dataset once and "
297-
"define your own HfDataset subclass that directly uses the preprocessed datasets."
298-
)
299-
300-
original_ds = datasets.load_dataset(
301-
"parquet",
302-
data_files={trainer_config.train_split: trainer_config.hf_train_files},
303-
split=trainer_config.train_split,
304-
)
305-
306-
if "OpenMathReasoning" in trainer_config.dataset_name:
307-
original_ds = original_ds.filter(lambda x: x.get("problem_type") == "has_answer_extracted")
308-
309-
# Split into train and validation sets using HF's train_test_split
310-
split_ds = original_ds.train_test_split(test_size=test_size, seed=seed)
311-
312-
return {
313-
"train": split_ds["train"],
314-
"validation": split_ds["test"],
315-
}
316-
317-
318303
def prepare_datasets(trainer_config, model_tokenizer):
319304
"""Setup and return train and test datasets."""
320305
home = os.path.expanduser("~") + "/"
@@ -326,16 +311,39 @@ def prepare_datasets(trainer_config, model_tokenizer):
326311
os.makedirs(test_data_dir)
327312

328313
# Prepare train and test data from training data for certain datasets
329-
eval_dataset_name = getattr(trainer_config, "eval_dataset_name", None)
330-
if trainer_config.dataset_name in [
331-
"nvidia/OpenMathInstruct-2",
332-
"nvidia/OpenMathReasoning",
333-
"open-r1/OpenR1-Math-220k",
334-
"bethgelab/CuratedThoughts",
335-
] and (not eval_dataset_name or eval_dataset_name == trainer_config.dataset_name):
314+
if trainer_config.dataset_name in ["nvidia/OpenMathInstruct-2", "nvidia/OpenMathReasoning", "open-r1/OpenR1-Math-220k", "bethgelab/CuratedThoughts"]:
336315
import datasets # pylint: disable=import-outside-toplevel
337-
338-
splits = prepare_train_and_eval_dataset(trainer_config)
316+
317+
def prepare_train_and_eval_dataset(
318+
seed: int = 42,
319+
test_size: float = 0.05,
320+
):
321+
"""Load and split the dataset into train and validation sets using HF's train_test_split."""
322+
max_logging.log(
323+
"WARNING: For reproducible experiments, preprocess the dataset once and "
324+
"define your own HfDataset subclass that directly uses the preprocessed datasets."
325+
)
326+
327+
# Load the original dataset
328+
original_ds = datasets.load_dataset(
329+
"parquet",
330+
data_files={trainer_config.train_split: trainer_config.hf_train_files},
331+
split=trainer_config.train_split,
332+
)
333+
334+
if "OpenMathReasoning" in trainer_config.dataset_name:
335+
original_ds = original_ds.filter(lambda x: x.get("problem_type") == "has_answer_extracted")
336+
337+
338+
# Split into train and validation sets using HF's train_test_split
339+
split_ds = original_ds.train_test_split(test_size=test_size, seed=seed)
340+
341+
return {
342+
"train": split_ds["train"],
343+
"validation": split_ds["test"],
344+
}
345+
346+
splits = prepare_train_and_eval_dataset()
339347
template_config = load_template_from_file(trainer_config.chat_template_path)
340348

341349
train_dataset = (
@@ -398,6 +406,7 @@ def _use_raw_prompt(x):
398406
dataset_size = int(trainer_config.num_batches * trainer_config.batch_size * trainer_config.train_fraction)
399407
train_dataset = train_dataset[:dataset_size]
400408
train_dataset = train_dataset.repeat(trainer_config.num_epoch)
409+
401410
train_dataset = train_dataset.to_iter_dataset().batch(trainer_config.batch_size)
402411

403412
test_dataset = test_dataset.filter(_filter_long_prompts)
@@ -493,6 +502,11 @@ def create_rl_components(
493502
argv_list = ["", str(vllm_config_path), "log_config=False"]
494503
vllm_config = pyconfig.initialize(argv_list)
495504

505+
rl_rollout_engine = "vllm"
506+
model_name = trainer_config.model_name
507+
if model_name in ["qwen3-30b-a3b", "qwen3-30b-a3b-base", "qwen3-235b-a22b"]:
508+
rl_rollout_engine = functools.partial(MaxTextVllmRollout, maxtext_config=trainer_config)
509+
496510
cluster_config = rl_cluster_lib.ClusterConfig(
497511
role_to_mesh={
498512
rl_cluster_lib.Role.ACTOR: actor_mesh,
@@ -504,7 +518,7 @@ def create_rl_components(
504518
rl_cluster_lib.Role.REFERENCE: trainer_config.logical_axis_rules,
505519
rl_cluster_lib.Role.ROLLOUT: vllm_config.logical_axis_rules,
506520
},
507-
rollout_engine="vllm",
521+
rollout_engine=rl_rollout_engine,
508522
offload_to_cpu=False,
509523
training_config=rl_cluster_lib.RLTrainingConfig(
510524
actor_optimizer=optimizer,
@@ -708,17 +722,15 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
708722

709723
# Before we train the model, let's evaluate the model on the test set so we can
710724
# see the improvement post training.
711-
if trainer_config.num_test_batches > 0:
712-
max_logging.warning("Starting evaluation before RL training...")
713-
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
714-
trainer_config,
715-
test_dataset,
716-
rl_cluster=rl_cluster,
717-
num_passes=trainer_config.num_eval_passes,
718-
corr_lst=trainer_config.eval_corr_lst,
719-
make_lst=trainer_config.eval_make_lst,
720-
)
721-
max_logging.warning(f"Pre RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
725+
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
726+
trainer_config,
727+
test_dataset,
728+
rl_cluster=rl_cluster,
729+
num_passes=trainer_config.num_eval_passes,
730+
corr_lst=trainer_config.eval_corr_lst,
731+
make_lst=trainer_config.eval_make_lst,
732+
)
733+
max_logging.warning(f"Pre RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
722734

723735
# Start training
724736
if trainer_config.load_checkpoint_only_once:
@@ -739,17 +751,15 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
739751
max_logging.warning("RL Training Completed Successfully!")
740752

741753
# Let's evaluate our model!
742-
if trainer_config.num_test_batches > 0:
743-
max_logging.warning("Starting evaluation after RL training...")
744-
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
745-
trainer_config,
746-
test_dataset,
747-
rl_cluster=rl_cluster,
748-
num_passes=trainer_config.num_eval_passes,
749-
corr_lst=trainer_config.eval_corr_lst,
750-
make_lst=trainer_config.eval_make_lst,
751-
)
752-
max_logging.warning(f"Post RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
754+
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
755+
trainer_config,
756+
test_dataset,
757+
rl_cluster=rl_cluster,
758+
num_passes=trainer_config.num_eval_passes,
759+
corr_lst=trainer_config.eval_corr_lst,
760+
make_lst=trainer_config.eval_make_lst,
761+
)
762+
max_logging.warning(f"Post RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
753763

754764

755765
def main(argv: Sequence[str]) -> None:
@@ -758,6 +768,7 @@ def main(argv: Sequence[str]) -> None:
758768
Args:
759769
argv: Command-line arguments.
760770
"""
771+
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
761772
pathwaysutils.initialize()
762773
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
763774

0 commit comments

Comments (0)