from orbax import checkpoint as ocp
from pprint import pprint
from transformers import AutoTokenizer
+ import functools
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.rollout import base_rollout
from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner
from tunix.rl.agentic.agentic_grpo_learner import GrpoConfig as AgenticGrpoConfig, GrpoLearner as AgenticGrpoLearner
from tunix.sft import metrics_logger, profiler

# For vLLM, we can skip JAX precompilation with this flag; it makes startup faster.
- os.environ["SKIP_JAX_PRECOMPILE"] = "1"
+ os.environ["SKIP_JAX_PRECOMPILE"] = "0"
+ os.environ["PHASED_PROFILING_DIR"] = "gs://mazumdera-bucket-tpu-prod-env-automated/qwen3-30b/0403"
+ os.environ["TOKENIZERS_PARALLELISM"] = "0"
+ # os.environ["TUNIX_DEBUG_REWARDS"] = "1"

from maxtext.configs import pyconfig
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter
+ from maxtext.integration.vllm.maxtext_vllm_rollout import MaxTextVllmRollout
from maxtext.trainers.post_train.rl.evaluate_rl import evaluate
from maxtext.trainers.post_train.rl import utils_rl
from maxtext.input_pipeline.instruction_data_processing import load_template_from_file
@@ -86,10 +91,10 @@ def get_maxtext_model(config, devices=None):
  """
  Load MaxText model with Tunix adapter.
  # Note: pass the path to your scanned checkpoint for 'load_parameters_path'.
-   # To create a scanned checkpoint, you can use /maxtext/src/maxtext/checkpoint_conversion/to_maxtext.py and if
+   # To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py and if
  # using Pathways, please set `USE_PATHWAYS=1` and use `$((1 - USE_PATHWAYS))` for storage flags:
  # export USE_PATHWAYS=1
-   # python src/maxtext/checkpoint_conversion/to_maxtext.py \
+   # python src/MaxText/checkpoint_conversion/to_maxtext.py \
  #   --model_name="gemma2-2b" \
  #   --base_output_directory="/path/to/your/output/directory" \
  #   --scan_layers=True \
@@ -116,18 +121,29 @@ def get_dataset(
  if dataset_name is None:
    raise ValueError("dataset_name must be provided")

-   import datasets  # pylint: disable=import-outside-toplevel
+   if dataset_name.startswith("huggingface:"):
+     import datasets  # pylint: disable=import-outside-toplevel

-   if data_files is None:
-     data = datasets.load_dataset(dataset_name, split=split, cache_dir=data_dir)
-     if tmvp_config.debug.rl:
-       max_logging.log(f"Loaded Hugging Face dataset {dataset_name} with split {split}. Size: {len(data)}")
-   else:  # data_files have been provided, useful for using slices of large datasets like nvidia/OpenMathInstruct-2
-     data = datasets.load_dataset(
-         "parquet",
-         data_files={tmvp_config.train_split: data_files},
+     if data_files is None:
+       hf_dataset_name = dataset_name.replace("huggingface:", "")
+       data = datasets.load_dataset(hf_dataset_name, split=split, cache_dir=data_dir)
+       if tmvp_config.debug.rl:
+         max_logging.log(f"Loaded Hugging Face dataset {hf_dataset_name} with split {split}. Size: {len(data)}")
+     else:  # data_files have been provided, useful for using slices of large datasets like nvidia/OpenMathInstruct-2
+       data = datasets.load_dataset(
+           "parquet",
+           data_files={tmvp_config.train_split: data_files},
+           split=split,
+           cache_dir=data_dir,
+       )
+   else:
+     builder_kwargs = {"file_format": tfds.core.FileFormat.ARRAY_RECORD}
+     data = tfds.data_source(
+         dataset_name,
          split=split,
-         cache_dir=data_dir,
+         data_dir=data_dir,
+         builder_kwargs=builder_kwargs,
+         download=True,
      )

  template_config = load_template_from_file(tmvp_config.chat_template_path)
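
To make the new routing in get_dataset concrete: a "huggingface:"-prefixed dataset_name goes to datasets.load_dataset, while any other name falls through to tfds.data_source. A minimal sketch of that dispatch (editor's illustration, not part of the diff; the dataset names are examples only):

# Editor's sketch only -- mirrors the prefix dispatch above without touching real data.
def pick_loader(dataset_name: str) -> str:
  if dataset_name.startswith("huggingface:"):
    return "hf:" + dataset_name.replace("huggingface:", "")
  return "tfds:" + dataset_name

assert pick_loader("huggingface:nvidia/OpenMathInstruct-2") == "hf:nvidia/OpenMathInstruct-2"
assert pick_loader("gsm8k") == "tfds:gsm8k"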
@@ -284,37 +300,6 @@ def get_max_train_steps(trainer_config):
  )


- def prepare_train_and_eval_dataset(
-     trainer_config,
-     seed: int = 42,
-     test_size: float = 0.05,
- ):
-   """Load and split the dataset into train and validation sets using HF's train_test_split."""
-   import datasets  # pylint: disable=import-outside-toplevel
-
-   max_logging.log(
-       "WARNING: For reproducible experiments, preprocess the dataset once and "
-       "define your own HfDataset subclass that directly uses the preprocessed datasets."
-   )
-
-   original_ds = datasets.load_dataset(
-       "parquet",
-       data_files={trainer_config.train_split: trainer_config.hf_train_files},
-       split=trainer_config.train_split,
-   )
-
-   if "OpenMathReasoning" in trainer_config.dataset_name:
-     original_ds = original_ds.filter(lambda x: x.get("problem_type") == "has_answer_extracted")
-
-   # Split into train and validation sets using HF's train_test_split
-   split_ds = original_ds.train_test_split(test_size=test_size, seed=seed)
-
-   return {
-       "train": split_ds["train"],
-       "validation": split_ds["test"],
-   }
-
-
def prepare_datasets(trainer_config, model_tokenizer):
  """Setup and return train and test datasets."""
  home = os.path.expanduser("~") + "/"
@@ -326,16 +311,39 @@ def prepare_datasets(trainer_config, model_tokenizer):
    os.makedirs(test_data_dir)

  # Prepare train and test data from training data for certain datasets
-   eval_dataset_name = getattr(trainer_config, "eval_dataset_name", None)
-   if trainer_config.dataset_name in [
-       "nvidia/OpenMathInstruct-2",
-       "nvidia/OpenMathReasoning",
-       "open-r1/OpenR1-Math-220k",
-       "bethgelab/CuratedThoughts",
-   ] and (not eval_dataset_name or eval_dataset_name == trainer_config.dataset_name):
+   if trainer_config.dataset_name in ["nvidia/OpenMathInstruct-2", "nvidia/OpenMathReasoning", "open-r1/OpenR1-Math-220k", "bethgelab/CuratedThoughts"]:
    import datasets  # pylint: disable=import-outside-toplevel
-
-     splits = prepare_train_and_eval_dataset(trainer_config)
+
+     def prepare_train_and_eval_dataset(
+         seed: int = 42,
+         test_size: float = 0.05,
+     ):
+       """Load and split the dataset into train and validation sets using HF's train_test_split."""
+       max_logging.log(
+           "WARNING: For reproducible experiments, preprocess the dataset once and "
+           "define your own HfDataset subclass that directly uses the preprocessed datasets."
+       )
+
+       # Load the original dataset
+       original_ds = datasets.load_dataset(
+           "parquet",
+           data_files={trainer_config.train_split: trainer_config.hf_train_files},
+           split=trainer_config.train_split,
+       )
+
+       if "OpenMathReasoning" in trainer_config.dataset_name:
+         original_ds = original_ds.filter(lambda x: x.get("problem_type") == "has_answer_extracted")
+
+
+       # Split into train and validation sets using HF's train_test_split
+       split_ds = original_ds.train_test_split(test_size=test_size, seed=seed)
+
+       return {
+           "train": split_ds["train"],
+           "validation": split_ds["test"],
+       }
+
+     splits = prepare_train_and_eval_dataset()
    template_config = load_template_from_file(trainer_config.chat_template_path)

    train_dataset = (
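
The nested prepare_train_and_eval_dataset helper added above relies on Hugging Face's train_test_split with a fixed seed, so the 95/5 split is reproducible across runs. A minimal sketch of that behaviour (editor's illustration, not part of the diff; assumes the datasets package is installed and uses made-up toy rows):

from datasets import Dataset

# 100 toy rows; test_size=0.05 with a fixed seed gives a deterministic 95/5 split.
toy = Dataset.from_dict({"problem": [f"q{i}" for i in range(100)]})
split = toy.train_test_split(test_size=0.05, seed=42)
print(len(split["train"]), len(split["test"]))  # 95 5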
@@ -398,6 +406,7 @@ def _use_raw_prompt(x):
  dataset_size = int(trainer_config.num_batches * trainer_config.batch_size * trainer_config.train_fraction)
  train_dataset = train_dataset[:dataset_size]
  train_dataset = train_dataset.repeat(trainer_config.num_epoch)
+
  train_dataset = train_dataset.to_iter_dataset().batch(trainer_config.batch_size)

  test_dataset = test_dataset.filter(_filter_long_prompts)
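
For reference, the sizing arithmetic in the hunk above with purely illustrative numbers (editor's note; these values are not from any config):

# num_batches=100, batch_size=4, train_fraction=0.75  (example values only)
dataset_size = int(100 * 4 * 0.75)  # 300 prompts kept before .repeat() and .batch()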
@@ -493,6 +502,11 @@ def create_rl_components(
  argv_list = ["", str(vllm_config_path), "log_config=False"]
  vllm_config = pyconfig.initialize(argv_list)

+   rl_rollout_engine = "vllm"
+   model_name = trainer_config.model_name
+   if model_name in ["qwen3-30b-a3b", "qwen3-30b-a3b-base", "qwen3-235b-a22b"]:
+     rl_rollout_engine = functools.partial(MaxTextVllmRollout, maxtext_config=trainer_config)
+
  cluster_config = rl_cluster_lib.ClusterConfig(
      role_to_mesh={
          rl_cluster_lib.Role.ACTOR: actor_mesh,
@@ -504,7 +518,7 @@ def create_rl_components(
          rl_cluster_lib.Role.REFERENCE: trainer_config.logical_axis_rules,
          rl_cluster_lib.Role.ROLLOUT: vllm_config.logical_axis_rules,
      },
-       rollout_engine="vllm",
+       rollout_engine=rl_rollout_engine,
      offload_to_cpu=False,
      training_config=rl_cluster_lib.RLTrainingConfig(
          actor_optimizer=optimizer,
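
A note on the functools.partial pattern in the hunk above (editor's sketch; FakeRollout is a stand-in, and exactly how Tunix invokes the factory is an assumption): partial pre-binds maxtext_config now so the remaining constructor arguments can be supplied later, wherever the rollout engine is actually instantiated.

import functools

class FakeRollout:
  # Stand-in for MaxTextVllmRollout, only to show the binding.
  def __init__(self, maxtext_config, mesh=None):
    self.maxtext_config = maxtext_config
    self.mesh = mesh

factory = functools.partial(FakeRollout, maxtext_config={"model_name": "qwen3-30b-a3b"})
rollout = factory(mesh="rollout_mesh")  # remaining arguments provided at call time
print(rollout.maxtext_config["model_name"], rollout.mesh)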
@@ -708,17 +722,15 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):

  # Before we train the model, let's evaluate the model on the test set so we can
  # see the improvement post training.
-   if trainer_config.num_test_batches > 0:
-     max_logging.warning("Starting evaluation before RL training...")
-     (corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
-         trainer_config,
-         test_dataset,
-         rl_cluster=rl_cluster,
-         num_passes=trainer_config.num_eval_passes,
-         corr_lst=trainer_config.eval_corr_lst,
-         make_lst=trainer_config.eval_make_lst,
-     )
-     max_logging.warning(f"Pre RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
+   (corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
+       trainer_config,
+       test_dataset,
+       rl_cluster=rl_cluster,
+       num_passes=trainer_config.num_eval_passes,
+       corr_lst=trainer_config.eval_corr_lst,
+       make_lst=trainer_config.eval_make_lst,
+   )
+   max_logging.warning(f"Pre RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")

  # Start training
  if trainer_config.load_checkpoint_only_once:
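
The {corr=} pieces in the warnings above use Python's self-documenting f-string form (3.8+), which prints both the expression and its value. A tiny example (editor's note; the values are made up):

corr, total, accuracy = 42, 50, 84.0
print(f"Pre RL Training: {corr=}, {total=}, {accuracy=}%")
# Pre RL Training: corr=42, total=50, accuracy=84.0%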
@@ -739,17 +751,15 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
  max_logging.warning("RL Training Completed Successfully!")

  # Let's evaluate our model!
-   if trainer_config.num_test_batches > 0:
-     max_logging.warning("Starting evaluation after RL training...")
-     (corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
-         trainer_config,
-         test_dataset,
-         rl_cluster=rl_cluster,
-         num_passes=trainer_config.num_eval_passes,
-         corr_lst=trainer_config.eval_corr_lst,
-         make_lst=trainer_config.eval_make_lst,
-     )
-     max_logging.warning(f"Post RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
+   (corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
+       trainer_config,
+       test_dataset,
+       rl_cluster=rl_cluster,
+       num_passes=trainer_config.num_eval_passes,
+       corr_lst=trainer_config.eval_corr_lst,
+       make_lst=trainer_config.eval_make_lst,
+   )
+   max_logging.warning(f"Post RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")


def main(argv: Sequence[str]) -> None:
@@ -758,6 +768,7 @@ def main(argv: Sequence[str]) -> None:
  Args:
    argv: Command-line arguments.
  """
+   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
  pathwaysutils.initialize()
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

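On the jax_default_prng_impl line added above: "unsafe_rbg" replaces JAX's default threefry generator with the faster RBG-based one, a common speed-for-strictness trade-off on TPU. A minimal sketch (editor's note; the exact performance benefit depends on hardware):

import jax

jax.config.update("jax_default_prng_impl", "unsafe_rbg")
key = jax.random.PRNGKey(0)  # keys are created the usual way under the new impl
print(jax.random.uniform(key, (2,)))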