diff --git a/src/dependencies/github_deps/post_train_base_deps.txt b/src/dependencies/github_deps/post_train_base_deps.txt index 5b4b9ce474..27396d442f 100644 --- a/src/dependencies/github_deps/post_train_base_deps.txt +++ b/src/dependencies/github_deps/post_train_base_deps.txt @@ -1 +1 @@ -google-tunix @ https://github.com/google/tunix/archive/336d102fe32ca0edbe42a8f66ff0fd533cebdf52.zip +google-tunix @ https://github.com/google/tunix/archive/f8d33d28fefdb490a7b0d2279cd906213d401bf4.zip diff --git a/src/dependencies/github_deps/post_train_deps.txt b/src/dependencies/github_deps/post_train_deps.txt index 7bd07a345e..f45ebac15a 100644 --- a/src/dependencies/github_deps/post_train_deps.txt +++ b/src/dependencies/github_deps/post_train_deps.txt @@ -1,5 +1,5 @@ -r post_train_base_deps.txt google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip -tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/0cae84fc9a883ba1bde02d4f07930e6af9e92958.zip -vllm @ git+https://github.com/vllm-project/vllm@ee8a29511fc69e3f0f6291fa6ff1cf6e47f7750d +tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/a99fb338f8810755851d7344b151c509b2719419.zip +vllm @ git+https://github.com/vllm-project/vllm@c012a8c477dd78b4444f22568b2bf1b08f2ad813 diff --git a/src/dependencies/requirements/generated_requirements/cuda12-requirements.txt b/src/dependencies/requirements/generated_requirements/cuda12-requirements.txt index 364465d935..5b6c06f4d1 100644 --- a/src/dependencies/requirements/generated_requirements/cuda12-requirements.txt +++ b/src/dependencies/requirements/generated_requirements/cuda12-requirements.txt @@ -51,7 +51,7 @@ fastapi>=0.122.0 filelock>=3.20.0 flatbuffers>=25.9.23 flax>=0.12.1 -fonttools>=4.60.1 +fonttools>=4.61.1 frozenlist>=1.8.0 fsspec>=2025.10.0 gast>=0.6.0 diff --git 
a/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt b/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt index 60dcb1eb86..093e96e591 100644 --- a/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt +++ b/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt @@ -76,7 +76,7 @@ fastjsonschema>=2.21.2 filelock>=3.20.0 flatbuffers>=25.9.23 flax>=0.12.4 -fonttools>=4.60.1 +fonttools>=4.61.1 frozenlist>=1.8.0 fsspec>=2026.1.0 gast>=0.6.0 diff --git a/src/dependencies/requirements/generated_requirements/tpu-requirements.txt b/src/dependencies/requirements/generated_requirements/tpu-requirements.txt index 08da4a3ab7..3e4dbb326a 100644 --- a/src/dependencies/requirements/generated_requirements/tpu-requirements.txt +++ b/src/dependencies/requirements/generated_requirements/tpu-requirements.txt @@ -52,7 +52,7 @@ fastapi>=0.122.0 filelock>=3.20.0 flatbuffers>=25.9.23 flax>=0.12.6 -fonttools>=4.60.1 +fonttools>=4.61.1 frozenlist>=1.8.0 fsspec>=2025.10.0 gast>=0.6.0 diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml index 9cc305d090..119d291b85 100644 --- a/src/maxtext/configs/post_train/rl.yml +++ b/src/maxtext/configs/post_train/rl.yml @@ -28,6 +28,8 @@ num_samplers_slices: -1 rollout_data_parallelism: -1 rollout_tensor_parallelism: -1 rollout_expert_parallelism: 1 +rollout_subslice_shape: "" # e.g. '2,2,1' for 4 chips with DP=2, TP=2, EP=1 +rollout_enable_single_controller: False # If True, use a single controller for rollout. This can help with stability when using more than 1 model replica in rollout. 
# ====== Reproducibility ====== data_shuffle_seed: 42 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 9ee1a7bb59..92e0205745 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1622,6 +1622,11 @@ class RLHardware(BaseModel): description="Tensor parallelism per replica for rollout. If not specified, it will be auto-determined.", ) rollout_expert_parallelism: int = Field(1, description="Expert parallelism per replica for rollout") + rollout_subslice_shape: str = Field("", description="Subslice shape for rollout in the form of 'x,y,z' for Pathways.") + rollout_enable_single_controller: bool = Field( + False, + description="Whether to enable single-controller mode for rollout. If True, the trainer will also run the rollout and sampling computations instead of launching separate processes. This is only recommended for debugging or if the rollout computation is very small and can be efficiently handled by a single controller.", + ) class VLLM(BaseModel): diff --git a/src/maxtext/trainers/post_train/rl/create_script.py b/src/maxtext/trainers/post_train/rl/create_script.py new file mode 100644 index 0000000000..8a0a73156f --- /dev/null +++ b/src/maxtext/trainers/post_train/rl/create_script.py @@ -0,0 +1,215 @@ +import os +from jinja2 import Template +import argparse + +def generate_rl_config( + metadata_name, + batch_size, + rollout_data_parallelism, + rollout_tensor_parallelism, + rollout_expert_parallelism, + trainer_devices_fraction, + subslice_shape, + enable_single_controller, + sampler_devices_fraction, + base_output_directory, + run_name, + hf_token, + extra_config, + ici_fsdp_parallelism, + ici_tensor_parallelism, + num_samplers_slices, + num_trainer_slices, + num_slices +): + script_template = """#!/bin/bash +CLUSTER_NAME=next-devx-1 +DEVICE_TYPE=tpu7x-4x4x4 +PROJECT=tpu-prod-env-automated +ZONE=us-central1 +IMAGE_DIR=gcr.io/cloud-tpu-multipod-dev/sanbao/maxtext_reshard_image:latest + +command="pip 
install --no-deps git+https://github.com/AI-Hypercomputer/pathways-utils.git@v0.1.4 && \\ +pip install src/maxtext/integration/vllm && \\ +HF_TOKEN={{ hf_token }} JAX_RANDOM_WEIGHTS=1 VLLM_ENABLE_V1_MULTIPROCESSING=0 NEW_MODEL_DESIGN=1 TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0 JAX_PLATFORMS=proxy,cpu JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 \\ +python3 -m src.maxtext.trainers.post_train.rl.reshard_debug src/maxtext/configs/post_train/rl.yml \\ +model_name=qwen3-30b-a3b \\ +tokenizer_path=Qwen/Qwen3-30B-A3B \\ +run_name={{ run_name }} \\ +base_output_directory={{ base_output_directory }} \\ +hf_access_token={{ hf_token }} \\ +batch_size={{ batch_size }} \\ +ici_fsdp_parallelism={{ ici_fsdp_parallelism }} \\ +ici_tensor_parallelism={{ ici_tensor_parallelism }} \\ +rl.num_generations=8 \\ +num_batches=10 \\ +rollout_data_parallelism={{ rollout_data_parallelism }} \\ +rollout_tensor_parallelism={{ rollout_tensor_parallelism }} \\ +rollout_expert_parallelism={{ rollout_expert_parallelism }} \\ +hbm_utilization_vllm=0.2 \\ +scan_layers=True \\ +allow_split_physical_axes=True \\ +vllm_hf_overrides='{architectures: [\\"MaxTextForCausalLM\\"]}' \\ +vllm_additional_config='{maxtext_config: {model_name: qwen3-30b-a3b, allow_split_physical_axes: true, log_config: false, weight_dtype: bfloat16}}' \\ +trainer_devices_fraction={{ trainer_devices_fraction }} \\ +subslice_shape='{{ subslice_shape }}' \\ +enable_single_controller={{ enable_single_controller }} \\ +sampler_devices_fraction={{ sampler_devices_fraction }} \\ +num_samplers_slices={{ num_samplers_slices }} \\ +num_trainer_slices={{ num_trainer_slices }} {{extra_config}}" + +xpk workload create-pathways --workload {{ metadata_name }} \\ +--docker-image ${IMAGE_DIR} \\ +--cluster ${CLUSTER_NAME} \\ +--tpu-type=${DEVICE_TYPE} \\ +--project=$PROJECT \\ +--zone=$ZONE \\ +--num-slices={{ num_slices }} \\ +--priority=very-high \\ 
+--custom-pathways-worker-args="--xprof_max_trace_buffers=16384" \\ +--command "${command}" +""" + + t = Template(script_template) + rendered_script = t.render( + metadata_name=metadata_name, + batch_size=batch_size, + rollout_data_parallelism=rollout_data_parallelism, + rollout_tensor_parallelism=rollout_tensor_parallelism, + rollout_expert_parallelism=rollout_expert_parallelism, + trainer_devices_fraction=trainer_devices_fraction, + subslice_shape=subslice_shape, + enable_single_controller=enable_single_controller, + sampler_devices_fraction=sampler_devices_fraction, + base_output_directory=base_output_directory, + run_name=run_name, + hf_token=hf_token, + extra_config=extra_config, + ici_fsdp_parallelism=ici_fsdp_parallelism, + ici_tensor_parallelism=ici_tensor_parallelism, + num_samplers_slices=num_samplers_slices, + num_trainer_slices=num_trainer_slices, + num_slices=num_slices + ) + + return rendered_script + +# Example Usage: +""" +python ./maxtext/src/maxtext/trainers/post_train/rl/create_script_235b.py \ + --metadata_name "${workload_name}" \ + --trainer_chips "${trainer_chips}" \ + --number_of_sampler_chips_per_replica "${sampler_chips}" \ + --sampler_replicas 1 \ + --base_output_directory "${base_output_directory}" \ + --hf_token "${hf_token}" \ + --store_directory "${store_path}" \ + --enable_tp "${enable_tp}" +""" + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--metadata_name", type=str, required=True) + parser.add_argument("--trainer_chips", type=int, required=True) + parser.add_argument("--number_of_sampler_chips_per_replica", type=int, required=True) + parser.add_argument("--sampler_replicas", type=int, required=True) + parser.add_argument("--base_output_directory", type=str, required=True) + parser.add_argument("--hf_token", type=str, required=True) + parser.add_argument("--store_directory", type=str, required=True) + parser.add_argument("--enable_tp", action="store_true", default=False, help="Enable 
tensor parallelism") + parser.add_argument("--enable_ep", action="store_true", default=False, help="Enable expert parallelism") + args = parser.parse_args() + print(vars(args)) + + + # for v7x-128 + extra_config = "" + number_of_chips = 64 + batch_size = args.trainer_chips * 2 + if args.trainer_chips >= 64: + ici_tensor_parallelism = 2 + ici_fsdp_parallelism = 64 + assert args.trainer_chips % 64 == 0, "trainer_chips must be a multiple of 64 when using multiple slices" + else: + ici_fsdp_parallelism = -1 + ici_tensor_parallelism = 1 + rollout_data_parallelism = args.sampler_replicas + sampler_chips = args.number_of_sampler_chips_per_replica * args.sampler_replicas + if args.enable_tp and args.enable_ep: + rollout_tensor_parallelism = args.number_of_sampler_chips_per_replica * 2 + rollout_expert_parallelism = rollout_tensor_parallelism // 4 if rollout_tensor_parallelism >= 4 else 1 + assert rollout_tensor_parallelism % rollout_expert_parallelism == 0, "rollout_tensor_parallelism must be divisible by rollout_expert_parallelism" + rollout_tensor_parallelism = 4 if rollout_tensor_parallelism >= 4 else rollout_tensor_parallelism + elif args.enable_ep: + rollout_tensor_parallelism = 1 + rollout_expert_parallelism = args.number_of_sampler_chips_per_replica * 2 + elif args.enable_tp: + rollout_tensor_parallelism = args.number_of_sampler_chips_per_replica * 2 + rollout_expert_parallelism = 1 + extra_config += " enable_dp_attention=True" if rollout_tensor_parallelism >= 4 else "" + else: + assert False, "At least one of tensor parallelism or expert parallelism must be enabled" + + if args.trainer_chips < number_of_chips: + trainer_devices_fraction = args.trainer_chips / number_of_chips + num_trainer_slices = -1 + else: + trainer_devices_fraction = 1.0 + num_trainer_slices = args.trainer_chips // number_of_chips + assert args.trainer_chips % number_of_chips == 0, "trainer_chips must be a multiple of available chips when trainer_devices_fraction is 1.0" + + if sampler_chips < 
number_of_chips: + sampler_devices_fraction = sampler_chips / number_of_chips + num_samplers_slices = -1 + else: + sampler_devices_fraction = 1.0 + num_samplers_slices = sampler_chips // number_of_chips + assert sampler_chips % number_of_chips == 0, "Total number of sampler chips must be a multiple of available chips when sampler_devices_fraction is 1.0" + + num_slices = max(1, num_trainer_slices + num_samplers_slices) + + if args.trainer_chips == 4: + enable_single_controller = "true" + else: + enable_single_controller = "false" + + subslice_shape_status = { + 1: "1,1,1", + 2: "2,1,1", + 4: "2,2,1", + 8: "2,2,2", + 16: "2,2,4", + 32: "2,4,4", + 64: "4,4,4", + 128: "4,4,8"} + subslice_shape = subslice_shape_status.get(args.trainer_chips, "") + + output_directory = os.path.join(args.base_output_directory, args.metadata_name) + + result = generate_rl_config( + metadata_name=args.metadata_name, + batch_size=batch_size, + rollout_data_parallelism=rollout_data_parallelism, + rollout_tensor_parallelism=rollout_tensor_parallelism, + rollout_expert_parallelism=rollout_expert_parallelism, + trainer_devices_fraction=trainer_devices_fraction, + subslice_shape=subslice_shape, + enable_single_controller=enable_single_controller, + sampler_devices_fraction=sampler_devices_fraction, + base_output_directory=output_directory, + run_name=args.metadata_name, + hf_token=args.hf_token, + extra_config=extra_config, + ici_fsdp_parallelism=ici_fsdp_parallelism, + ici_tensor_parallelism=ici_tensor_parallelism, + num_samplers_slices=num_samplers_slices, + num_trainer_slices=num_trainer_slices, + num_slices=num_slices + ) + # if the script directory does not exist, create it + if not os.path.exists(args.store_directory): + os.makedirs(args.store_directory) + output_script_path = os.path.join(args.store_directory, f"{args.metadata_name}.sh") + + with open(output_script_path, "w") as f: + f.write(result) \ No newline at end of file diff --git 
a/src/maxtext/trainers/post_train/rl/create_script_235b.py b/src/maxtext/trainers/post_train/rl/create_script_235b.py new file mode 100644 index 0000000000..6d31be84c6 --- /dev/null +++ b/src/maxtext/trainers/post_train/rl/create_script_235b.py @@ -0,0 +1,216 @@ +import os +from jinja2 import Template +import argparse + +def generate_rl_config( + metadata_name, + batch_size, + rollout_data_parallelism, + rollout_tensor_parallelism, + rollout_expert_parallelism, + trainer_devices_fraction, + subslice_shape, + enable_single_controller, + sampler_devices_fraction, + base_output_directory, + run_name, + hf_token, + extra_config, + ici_fsdp_parallelism, + ici_tensor_parallelism, + num_samplers_slices, + num_trainer_slices, + num_slices +): + script_template = """#!/bin/bash +CLUSTER_NAME=next-devx-1 +DEVICE_TYPE=tpu7x-4x4x4 +PROJECT=tpu-prod-env-automated +ZONE=us-central1 +IMAGE_DIR=gcr.io/cloud-tpu-multipod-dev/sanbao/maxtext_reshard_image:latest + +command="pip install --no-deps git+https://github.com/AI-Hypercomputer/pathways-utils.git@v0.1.4 && \\ +pip install src/maxtext/integration/vllm && \\ +HF_TOKEN={{ hf_token }} JAX_RANDOM_WEIGHTS=1 VLLM_ENABLE_V1_MULTIPROCESSING=0 NEW_MODEL_DESIGN=1 TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0 JAX_PLATFORMS=proxy,cpu JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 \\ +python3 -m src.maxtext.trainers.post_train.rl.reshard_debug src/maxtext/configs/post_train/rl.yml \\ +model_name=qwen3-235b-a22b \\ +tokenizer_path=Qwen/Qwen3-235B-A22B \\ +run_name={{ run_name }} \\ +base_output_directory={{ base_output_directory }} \\ +hf_access_token={{ hf_token }} \\ +async_scheduling=False max_target_length=640 \\ +batch_size={{ batch_size }} \\ +ici_fsdp_parallelism={{ ici_fsdp_parallelism }} \\ +ici_tensor_parallelism={{ ici_tensor_parallelism }} \\ +rl.num_generations=16 \\ +num_batches=10 \\ +rollout_data_parallelism={{ rollout_data_parallelism }} \\ +rollout_tensor_parallelism={{ 
rollout_tensor_parallelism }} \\ +rollout_expert_parallelism={{ rollout_expert_parallelism }} \\ +hbm_utilization_vllm=0.4 \\ +scan_layers=True \\ +allow_split_physical_axes=True \\ +vllm_hf_overrides='{architectures: [\\"MaxTextForCausalLM\\"]}' \\ +vllm_additional_config='{maxtext_config: {model_name: qwen3-235b-a22b, model_call_mode: inference, allow_split_physical_axes: true, log_config: false, weight_dtype: bfloat16}}' \\ +trainer_devices_fraction={{ trainer_devices_fraction }} \\ +subslice_shape='{{ subslice_shape }}' \\ +enable_single_controller={{ enable_single_controller }} \\ +sampler_devices_fraction={{ sampler_devices_fraction }} \\ +num_samplers_slices={{ num_samplers_slices }} \\ +num_trainer_slices={{ num_trainer_slices }} {{extra_config}}" + +xpk workload create-pathways --workload {{ metadata_name }} \\ +--docker-image ${IMAGE_DIR} \\ +--cluster ${CLUSTER_NAME} \\ +--tpu-type=${DEVICE_TYPE} \\ +--project=$PROJECT \\ +--zone=$ZONE \\ +--num-slices={{ num_slices }} \\ +--priority=very-high \\ +--custom-pathways-worker-args="--xprof_max_trace_buffers=16384" \\ +--command "${command}" +""" + + t = Template(script_template) + rendered_script = t.render( + metadata_name=metadata_name, + batch_size=batch_size, + rollout_data_parallelism=rollout_data_parallelism, + rollout_tensor_parallelism=rollout_tensor_parallelism, + rollout_expert_parallelism=rollout_expert_parallelism, + trainer_devices_fraction=trainer_devices_fraction, + subslice_shape=subslice_shape, + enable_single_controller=enable_single_controller, + sampler_devices_fraction=sampler_devices_fraction, + base_output_directory=base_output_directory, + run_name=run_name, + hf_token=hf_token, + extra_config=extra_config, + ici_fsdp_parallelism=ici_fsdp_parallelism, + ici_tensor_parallelism=ici_tensor_parallelism, + num_samplers_slices=num_samplers_slices, + num_trainer_slices=num_trainer_slices, + num_slices=num_slices + ) + + return rendered_script + +# Example Usage: +""" +python 
./maxtext/src/maxtext/trainers/post_train/rl/create_script.py \ + --metadata_name "${workload_name}" \ + --trainer_chips "${trainer_chips}" \ + --number_of_sampler_chips_per_replica "${sampler_chips}" \ + --sampler_replicas 1 \ + --base_output_directory "${base_output_directory}" \ + --hf_token "${hf_token}" \ + --store_directory "${store_path}" \ + --enable_tp "${enable_tp}" +""" + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--metadata_name", type=str, required=True) + parser.add_argument("--trainer_chips", type=int, required=True) + parser.add_argument("--number_of_sampler_chips_per_replica", type=int, required=True) + parser.add_argument("--sampler_replicas", type=int, required=True) + parser.add_argument("--base_output_directory", type=str, required=True) + parser.add_argument("--hf_token", type=str, required=True) + parser.add_argument("--store_directory", type=str, required=True) + parser.add_argument("--enable_tp", action="store_true", default=False, help="Enable tensor parallelism") + parser.add_argument("--enable_ep", action="store_true", default=False, help="Enable expert parallelism") + args = parser.parse_args() + print(vars(args)) + + + # for v7x-128 + extra_config = "" + number_of_chips = 64 + batch_size = args.trainer_chips * 2 + if args.trainer_chips >= 64: + ici_tensor_parallelism = 2 + ici_fsdp_parallelism = 64 + assert args.trainer_chips % 64 == 0, "trainer_chips must be a multiple of 64 when using multiple slices" + else: + ici_fsdp_parallelism = -1 + ici_tensor_parallelism = 1 + rollout_data_parallelism = args.sampler_replicas + sampler_chips = args.number_of_sampler_chips_per_replica * args.sampler_replicas + # assert sampler_chips + args.trainer_chips <= number_of_chips, "Total number of chips used by trainer and sampler must be less than or equal to available chips" + if args.enable_tp and args.enable_ep: + rollout_tensor_parallelism = args.number_of_sampler_chips_per_replica * 2 + 
rollout_expert_parallelism = rollout_tensor_parallelism // 4 if rollout_tensor_parallelism >= 4 else 1 + assert rollout_tensor_parallelism % rollout_expert_parallelism == 0, "rollout_tensor_parallelism must be divisible by rollout_expert_parallelism" + rollout_tensor_parallelism = 4 if rollout_tensor_parallelism >= 4 else rollout_tensor_parallelism + elif args.enable_ep: + rollout_tensor_parallelism = 1 + rollout_expert_parallelism = args.number_of_sampler_chips_per_replica * 2 + elif args.enable_tp: + rollout_tensor_parallelism = args.number_of_sampler_chips_per_replica * 2 + rollout_expert_parallelism = 1 + extra_config += " enable_dp_attention=True" if rollout_tensor_parallelism >= 4 else "" + else: + assert False, "At least one of tensor parallelism or expert parallelism must be enabled" + if args.trainer_chips < number_of_chips: + trainer_devices_fraction = args.trainer_chips / number_of_chips + num_trainer_slices = -1 + else: + trainer_devices_fraction = 1.0 + num_trainer_slices = args.trainer_chips // number_of_chips + assert args.trainer_chips % number_of_chips == 0, "trainer_chips must be a multiple of available chips when trainer_devices_fraction is 1.0" + + if sampler_chips < number_of_chips: + sampler_devices_fraction = sampler_chips / number_of_chips + num_samplers_slices = -1 + else: + sampler_devices_fraction = 1.0 + num_samplers_slices = sampler_chips // number_of_chips + assert sampler_chips % number_of_chips == 0, "Total number of sampler chips must be a multiple of available chips when sampler_devices_fraction is 1.0" + + num_slices = max(1, num_trainer_slices + num_samplers_slices) + + if args.trainer_chips == 4: + enable_single_controller = "true" + else: + enable_single_controller = "false" + + subslice_shape_status = { + 1: "1,1,1", + 2: "2,1,1", + 4: "2,2,1", + 8: "2,2,2", + 16: "2,2,4", + 32: "2,4,4", + 64: "4,4,4", + 128: "4,4,8"} + subslice_shape = subslice_shape_status.get(args.trainer_chips, "") + + output_directory = 
os.path.join(args.base_output_directory, args.metadata_name) + + result = generate_rl_config( + metadata_name=args.metadata_name, + batch_size=batch_size, + rollout_data_parallelism=rollout_data_parallelism, + rollout_tensor_parallelism=rollout_tensor_parallelism, + rollout_expert_parallelism=rollout_expert_parallelism, + trainer_devices_fraction=trainer_devices_fraction, + subslice_shape=subslice_shape, + enable_single_controller=enable_single_controller, + sampler_devices_fraction=sampler_devices_fraction, + base_output_directory=output_directory, + run_name=args.metadata_name, + hf_token=args.hf_token, + extra_config=extra_config, + ici_fsdp_parallelism=ici_fsdp_parallelism, + ici_tensor_parallelism=ici_tensor_parallelism, + num_samplers_slices=num_samplers_slices, + num_trainer_slices=num_trainer_slices, + num_slices=num_slices + ) + # if the script directory does not exist, create it + if not os.path.exists(args.store_directory): + os.makedirs(args.store_directory) + output_script_path = os.path.join(args.store_directory, f"{args.metadata_name}.sh") + + with open(output_script_path, "w") as f: + f.write(result) \ No newline at end of file diff --git a/src/maxtext/trainers/post_train/rl/create_yaml.py b/src/maxtext/trainers/post_train/rl/create_yaml.py new file mode 100644 index 0000000000..fce72e7047 --- /dev/null +++ b/src/maxtext/trainers/post_train/rl/create_yaml.py @@ -0,0 +1,398 @@ +import os +from jinja2 import Template +import argparse + +def generate_rl_config( + metadata_name, + batch_size, + rollout_data_parallelism, + rollout_tensor_parallelism, + rollout_expert_parallelism, + trainer_devices_fraction, + subslice_shape, + enable_single_controller, + sampler_devices_fraction, + base_output_directory, + run_name, + hf_token, + extra_config +): + yaml_template = """apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + labels: + kueue.x-k8s.io/queue-name: multislice-queue + name: {{ metadata_name }} + namespace: default +spec: + coordinator: + 
replicatedJob: pathways-head + failurePolicy: + maxRestarts: 1 + restartStrategy: Recreate + network: + enableDNSHostnames: true + publishNotReadyAddresses: true + replicatedJobs: + - name: pathways-head + replicas: 1 + template: + metadata: + annotations: + kueue.x-k8s.io/safe-to-forcefully-terminate: "true" + spec: + backoffLimit: 0 + completionMode: Indexed + completions: 1 + parallelism: 1 + template: + spec: + containers: + - command: + - bash + - -c + - | + echo XPK Start: $(date); + _sigterm() (kill -SIGTERM $! 2>/dev/null;); + trap _sigterm SIGTERM; + + (pip install --no-deps git+https://github.com/AI-Hypercomputer/pathways-utils.git@v0.1.4 && \\ + pip install src/maxtext/integration/vllm && \\ + HF_TOKEN={{ hf_token }} JAX_RANDOM_WEIGHTS=1 VLLM_ENABLE_V1_MULTIPROCESSING=0 NEW_MODEL_DESIGN=1 TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0 JAX_PLATFORMS=proxy,cpu JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 \\ + python3 -m src.maxtext.trainers.post_train.rl.reshard_debug src/maxtext/configs/post_train/rl.yml \\ + model_name=qwen3-30b-a3b \\ + tokenizer_path=Qwen/Qwen3-30B-A3B \\ + run_name={{ run_name }} \\ + base_output_directory={{ base_output_directory }} \\ + hf_access_token={{ hf_token }} \\ + batch_size={{ batch_size }} \\ + rl.num_generations=8 \\ + num_batches=10 \\ + rollout_data_parallelism={{ rollout_data_parallelism }} \\ + rollout_tensor_parallelism={{ rollout_tensor_parallelism }} \\ + rollout_expert_parallelism={{ rollout_expert_parallelism }} \\ + hbm_utilization_vllm=0.6 \\ + scan_layers=True \\ + allow_split_physical_axes=True \\ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \\ + vllm_additional_config='{maxtext_config: {model_name: qwen3-30b-a3b, allow_split_physical_axes: true, log_config: false, weight_dtype: bfloat16}}' \\ + trainer_devices_fraction={{ trainer_devices_fraction }} \\ + subslice_shape='{{ subslice_shape }}' \\ + enable_single_controller={{ 
enable_single_controller }} \\ + sampler_devices_fraction={{ sampler_devices_fraction }} {{extra_config}}) & PID=$!; + + while kill -0 $PID 2>/dev/null; + do sleep 5; + done; + wait $PID; + EXIT_CODE=$?; + + echo XPK End: $(date); + echo EXIT_CODE=$EXIT_CODE; + + exit $EXIT_CODE + env: + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + - name: JAX_PLATFORMS + value: proxy + - name: XCLOUD_ENVIRONMENT + value: GCP + - name: JAX_BACKEND_TARGET + value: grpc://$(PATHWAYS_HEAD):29000 + image: gcr.io/cloud-tpu-multipod-dev/sanbao/maxtext_reshard_image:latest + imagePullPolicy: Always + name: jax-tpu + resources: + limits: + cpu: "24" + memory: 100G + securityContext: + privileged: true + volumeMounts: + - mountPath: /tmp + name: shared-tmp + dnsPolicy: ClusterFirstWithHostNet + hostNetwork: true + initContainers: + - args: + - --server_port=29001 + - --gcs_scratch_location=gs://cloud-pathways-staging/tmp + - --node_type=resource_manager + - --instance_count=1 + - --instance_type=tpu7x:4x4x4 + env: + - name: REPLICATED_JOB_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name'] + - name: JOBSET_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name'] + - name: HOST_ADDRESS + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + - name: TPU_SKIP_MDS_QUERY + value: "true" + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest + imagePullPolicy: Always + name: pathways-rm + ports: + - containerPort: 29001 + protocol: TCP + - containerPort: 29002 + protocol: TCP + resources: + limits: + cpu: "8" + memory: 32G + restartPolicy: Always + - args: + - --server_port=29000 + - --resource_manager_address=$(PATHWAYS_HEAD):29001 + - --gcs_scratch_location=gs://cloud-pathways-staging/tmp + env: + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + fieldPath: 
metadata.labels['jobset.sigs.k8s.io/coordinator'] + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest + imagePullPolicy: Always + name: pathways-proxy + ports: + - containerPort: 29000 + protocol: TCP + resources: + limits: + cpu: "16" + memory: 100G + restartPolicy: Always + nodeSelector: + cloud.google.com/gke-nodepool: cpu-np + restartPolicy: Never + volumes: + - hostPath: + path: /tmp + type: DirectoryOrCreate + name: shared-tmp + - name: worker + replicas: 1 + template: + metadata: + annotations: + cloud.google.com/gke-tpu-slice-topology: 4x4x4 + spec: + backoffLimit: 32 + completionMode: Indexed + completions: 16 + parallelism: 16 + template: + metadata: + annotations: + cloud.google.com/gke-tpu-slice-topology: 4x4x4 + spec: + tolerations: + - key: "google.com/tpu" + operator: "Equal" + value: "present" + effect: "NoSchedule" + affinity: + nodeAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + nodeSelectorTerms: + - matchExpressions: + - key: cloud.google.com/gke-tpu-partition-4x4x4-state + operator: In + values: + - HEALTHY + - DEGRADED + containers: + - args: + - --server_port=29005 + - --resource_manager_address=$(PATHWAYS_HEAD):29001 + - --gcs_scratch_location=gs://cloud-pathways-staging/tmp + env: + - name: TPU_MIN_LOG_LEVEL + value: "0" + - name: TF_CPP_MIN_LOG_LEVEL + value: "0" + - name: XCLOUD_ENVIRONMENT + value: GCP + - name: MEGASCALE_GRPC_ENABLE_XOR_TRACER + value: "false" + - name: MEGASCALE_NUM_SLICES + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/replicatedjob-replicas'] + - name: JOBSET_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name'] + - name: REPLICATED_JOB_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name'] + - name: MEGASCALE_SLICE_ID + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/job-index'] + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + 
fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + - name: MEGASCALE_COORDINATOR_ADDRESS + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest + imagePullPolicy: Always + name: pathways-worker + ports: + - containerPort: 29005 + protocol: TCP + - containerPort: 29006 + protocol: TCP + - containerPort: 8471 + protocol: TCP + - containerPort: 8080 + protocol: TCP + resources: + limits: + google.com/tpu: "4" + volumeMounts: + - mountPath: /tmp + name: shared-tmp + dnsPolicy: ClusterFirstWithHostNet + hostNetwork: true + nodeSelector: + cloud.google.com/gke-tpu-accelerator: tpu7x + priorityClassName: medium + restartPolicy: OnFailure + terminationGracePeriodSeconds: 30 + volumes: + - hostPath: + path: /tmp + type: DirectoryOrCreate + name: shared-tmp + startupPolicy: + startupPolicyOrder: InOrder + successPolicy: + operator: All + targetReplicatedJobs: + - pathways-head""" + + t = Template(yaml_template) + rendered_yaml = t.render( + metadata_name=metadata_name, + batch_size=batch_size, + rollout_data_parallelism=rollout_data_parallelism, + rollout_tensor_parallelism=rollout_tensor_parallelism, + rollout_expert_parallelism=rollout_expert_parallelism, + trainer_devices_fraction=trainer_devices_fraction, + subslice_shape=subslice_shape, + enable_single_controller=enable_single_controller, + sampler_devices_fraction=sampler_devices_fraction, + base_output_directory=base_output_directory, + run_name=run_name, + hf_token=hf_token, + extra_config=extra_config + ) + + return rendered_yaml + +# Example Usage: +""" +python ./maxtext/src/maxtext/trainers/post_train/rl/create_yaml.py \ + --metadata_name "${workload_name}" \ + --trainer_chips "${trainer_chips}" \ + --number_of_sampler_chips_per_replica "${sampler_chips}" \ + --sampler_replicas 1 \ + --base_output_directory "${base_output_directory}" \ + --hf_token "${hf_token}" \ + --store_directory "${store_path}" \ 
# Example usage (--enable_tp / --enable_ep are value-less switches):
#   python ./maxtext/src/maxtext/trainers/post_train/rl/create_yaml.py \
#     --metadata_name "${workload_name}" \
#     --trainer_chips "${trainer_chips}" \
#     --number_of_sampler_chips_per_replica "${sampler_chips}" \
#     --sampler_replicas 1 \
#     --base_output_directory "${base_output_directory}" \
#     --hf_token "${hf_token}" \
#     --store_directory "${store_path}" \
#     --enable_tp

if __name__ == "__main__":
  # Command-line driver: derive the rollout/trainer sizing from the chip counts,
  # render the JobSet YAML with generate_rl_config() and write it to
  # <store_directory>/<metadata_name>.yaml.
  parser = argparse.ArgumentParser()
  parser.add_argument("--metadata_name", type=str, required=True)
  parser.add_argument("--trainer_chips", type=int, required=True)
  parser.add_argument("--number_of_sampler_chips_per_replica", type=int, required=True)
  parser.add_argument("--sampler_replicas", type=int, required=True)
  parser.add_argument("--base_output_directory", type=str, required=True)
  parser.add_argument("--hf_token", type=str, required=True)
  parser.add_argument("--store_directory", type=str, required=True)
  parser.add_argument("--enable_tp", action="store_true", default=False, help="Enable tensor parallelism")
  parser.add_argument("--enable_ep", action="store_true", default=False, help="Enable expert parallelism")
  args = parser.parse_args()
  # Echo the parsed arguments for the run log, but never the access token.
  print({**vars(args), "hf_token": "<redacted>"})

  # Sizing is hard-wired to a 64-chip budget (the original note reads "for v7x-128").
  number_of_chips = 64
  chips_per_replica = args.number_of_sampler_chips_per_replica
  sampler_chips = chips_per_replica * args.sampler_replicas

  # Validate with parser.error() rather than assert: asserts vanish under `python -O`,
  # which would let an invalid request through (or fail later with a NameError).
  if sampler_chips + args.trainer_chips > number_of_chips:
    parser.error("Total number of chips used by trainer and sampler must be less than or equal to available chips")
  if not (args.enable_tp or args.enable_ep):
    parser.error("At least one of tensor parallelism or expert parallelism must be enabled")

  extra_config = ""
  batch_size = args.trainer_chips * 2
  trainer_devices_fraction = args.trainer_chips / number_of_chips
  sampler_devices_fraction = sampler_chips / number_of_chips
  rollout_data_parallelism = args.sampler_replicas

  # NOTE(review): the factor 2 presumably reflects two devices per chip - confirm.
  devices_per_replica = chips_per_replica * 2
  if args.enable_tp and args.enable_ep:
    # Expert parallelism takes a quarter of the devices; tensor parallelism is capped at 4.
    rollout_expert_parallelism = devices_per_replica // 4 if devices_per_replica >= 4 else 1
    if devices_per_replica % rollout_expert_parallelism != 0:
      parser.error("rollout_tensor_parallelism must be divisible by rollout_expert_parallelism")
    rollout_tensor_parallelism = min(devices_per_replica, 4)
  elif args.enable_ep:
    rollout_tensor_parallelism = 1
    rollout_expert_parallelism = devices_per_replica
  else:
    # Tensor parallelism only.
    rollout_tensor_parallelism = devices_per_replica
    rollout_expert_parallelism = 1
    if rollout_tensor_parallelism >= 4:
      extra_config += " enable_dp_attention=True"

  # Rendered verbatim into the command line, hence the lowercase string form.
  enable_single_controller = "true" if args.trainer_chips == 4 else "false"

  # Trainer chip count -> subslice shape; chip counts not listed render as an empty shape.
  subslice_shape_status = {
      1: "1,1,1",
      2: "2,1,1",
      4: "2,2,1",
      8: "2,2,2",
      16: "2,2,4",
      32: "2,4,4",
      64: "4,4,4",
      128: "4,4,8",
  }
  subslice_shape = subslice_shape_status.get(args.trainer_chips, "")

  output_directory = os.path.join(args.base_output_directory, args.metadata_name)

  result = generate_rl_config(
      metadata_name=args.metadata_name,
      batch_size=batch_size,
      rollout_data_parallelism=rollout_data_parallelism,
      rollout_tensor_parallelism=rollout_tensor_parallelism,
      rollout_expert_parallelism=rollout_expert_parallelism,
      trainer_devices_fraction=trainer_devices_fraction,
      subslice_shape=subslice_shape,
      enable_single_controller=enable_single_controller,
      sampler_devices_fraction=sampler_devices_fraction,
      base_output_directory=output_directory,
      run_name=args.metadata_name,
      hf_token=args.hf_token,
      extra_config=extra_config,
  )

  # exist_ok avoids the check-then-create race of os.path.exists() + os.makedirs().
  os.makedirs(args.store_directory, exist_ok=True)
  output_yaml_path = os.path.join(args.store_directory, f"{args.metadata_name}.yaml")

  with open(output_yaml_path, "w", encoding="utf-8") as f:
    f.write(result)
metadata_name, + batch_size, + rollout_data_parallelism, + rollout_tensor_parallelism, + rollout_expert_parallelism, + trainer_devices_fraction, + subslice_shape, + enable_single_controller, + sampler_devices_fraction, + base_output_directory, + run_name, + hf_token, + extra_config +): + yaml_template = """apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + labels: + kueue.x-k8s.io/queue-name: multislice-queue + name: {{ metadata_name }} + namespace: default +spec: + coordinator: + replicatedJob: pathways-head + failurePolicy: + maxRestarts: 1 + restartStrategy: Recreate + network: + enableDNSHostnames: true + publishNotReadyAddresses: true + replicatedJobs: + - name: pathways-head + replicas: 1 + template: + metadata: + annotations: + kueue.x-k8s.io/safe-to-forcefully-terminate: "true" + spec: + backoffLimit: 0 + completionMode: Indexed + completions: 1 + parallelism: 1 + template: + spec: + containers: + - command: + - bash + - -c + - | + echo XPK Start: $(date); + _sigterm() (kill -SIGTERM $! 
2>/dev/null;); + trap _sigterm SIGTERM; + + (pip install --no-deps git+https://github.com/AI-Hypercomputer/pathways-utils.git@v0.1.4 && \\ + pip install src/maxtext/integration/vllm && \\ + HF_TOKEN={{ hf_token }} JAX_RANDOM_WEIGHTS=1 VLLM_ENABLE_V1_MULTIPROCESSING=0 NEW_MODEL_DESIGN=1 TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0 JAX_PLATFORMS=proxy,cpu JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 \\ + python3 -m src.maxtext.trainers.post_train.rl.reshard_debug src/maxtext/configs/post_train/rl.yml \\ + model_name=qwen3-235b-a22b \\ + tokenizer_path=Qwen/Qwen3-235B-A22B \\ + run_name={{ run_name }} \\ + base_output_directory={{ base_output_directory }} \\ + hf_access_token={{ hf_token }} \\ + async_scheduling=False max_target_length=640 \\ + batch_size={{ batch_size }} \\ + rl.num_generations=16 \\ + num_batches=10 \\ + rollout_data_parallelism={{ rollout_data_parallelism }} \\ + rollout_tensor_parallelism={{ rollout_tensor_parallelism }} \\ + rollout_expert_parallelism={{ rollout_expert_parallelism }} \\ + hbm_utilization_vllm=0.4 \\ + scan_layers=True \\ + allow_split_physical_axes=True \\ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \\ + vllm_additional_config='{maxtext_config: {model_name: qwen3-235b-a22b, model_call_mode: inference, allow_split_physical_axes: true, log_config: false, weight_dtype: bfloat16}}' \\ + trainer_devices_fraction={{ trainer_devices_fraction }} \\ + subslice_shape='{{ subslice_shape }}' \\ + enable_single_controller={{ enable_single_controller }} \\ + sampler_devices_fraction={{ sampler_devices_fraction }} {{extra_config}}) & PID=$!; + + while kill -0 $PID 2>/dev/null; + do sleep 5; + done; + wait $PID; + EXIT_CODE=$?; + + echo XPK End: $(date); + echo EXIT_CODE=$EXIT_CODE; + + exit $EXIT_CODE + env: + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + - name: JAX_PLATFORMS + value: proxy + - name: 
XCLOUD_ENVIRONMENT + value: GCP + - name: JAX_BACKEND_TARGET + value: grpc://$(PATHWAYS_HEAD):29000 + image: gcr.io/cloud-tpu-multipod-dev/sanbao/maxtext_reshard_image:latest + imagePullPolicy: Always + name: jax-tpu + resources: + limits: + cpu: "24" + memory: 100G + securityContext: + privileged: true + volumeMounts: + - mountPath: /tmp + name: shared-tmp + dnsPolicy: ClusterFirstWithHostNet + hostNetwork: true + initContainers: + - args: + - --server_port=29001 + - --gcs_scratch_location=gs://cloud-pathways-staging/tmp + - --node_type=resource_manager + - --instance_count=1 + - --instance_type=tpu7x:4x4x4 + env: + - name: REPLICATED_JOB_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name'] + - name: JOBSET_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name'] + - name: HOST_ADDRESS + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + - name: TPU_SKIP_MDS_QUERY + value: "true" + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest + imagePullPolicy: Always + name: pathways-rm + ports: + - containerPort: 29001 + protocol: TCP + - containerPort: 29002 + protocol: TCP + resources: + limits: + cpu: "8" + memory: 32G + restartPolicy: Always + - args: + - --server_port=29000 + - --resource_manager_address=$(PATHWAYS_HEAD):29001 + - --gcs_scratch_location=gs://cloud-pathways-staging/tmp + env: + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest + imagePullPolicy: Always + name: pathways-proxy + ports: + - containerPort: 29000 + protocol: TCP + resources: + limits: + cpu: "16" + memory: 100G + restartPolicy: Always + nodeSelector: + cloud.google.com/gke-nodepool: cpu-np + restartPolicy: Never + volumes: + - hostPath: + path: /tmp + type: DirectoryOrCreate + name: shared-tmp + - name: worker + 
replicas: 1 + template: + metadata: + annotations: + cloud.google.com/gke-tpu-slice-topology: 4x4x4 + spec: + backoffLimit: 32 + completionMode: Indexed + completions: 16 + parallelism: 16 + template: + metadata: + annotations: + cloud.google.com/gke-tpu-slice-topology: 4x4x4 + spec: + tolerations: + - key: "google.com/tpu" + operator: "Equal" + value: "present" + effect: "NoSchedule" + affinity: + nodeAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + nodeSelectorTerms: + - matchExpressions: + - key: cloud.google.com/gke-tpu-partition-4x4x4-state + operator: In + values: + - HEALTHY + - DEGRADED + containers: + - args: + - --server_port=29005 + - --resource_manager_address=$(PATHWAYS_HEAD):29001 + - --gcs_scratch_location=gs://cloud-pathways-staging/tmp + env: + - name: TPU_MIN_LOG_LEVEL + value: "0" + - name: TF_CPP_MIN_LOG_LEVEL + value: "0" + - name: XCLOUD_ENVIRONMENT + value: GCP + - name: MEGASCALE_GRPC_ENABLE_XOR_TRACER + value: "false" + - name: MEGASCALE_NUM_SLICES + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/replicatedjob-replicas'] + - name: JOBSET_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name'] + - name: REPLICATED_JOB_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name'] + - name: MEGASCALE_SLICE_ID + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/job-index'] + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + - name: MEGASCALE_COORDINATOR_ADDRESS + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest + imagePullPolicy: Always + name: pathways-worker + ports: + - containerPort: 29005 + protocol: TCP + - containerPort: 29006 + protocol: TCP + - containerPort: 8471 + protocol: TCP + - containerPort: 8080 + protocol: TCP + 
resources: + limits: + google.com/tpu: "4" + volumeMounts: + - mountPath: /tmp + name: shared-tmp + dnsPolicy: ClusterFirstWithHostNet + hostNetwork: true + nodeSelector: + cloud.google.com/gke-tpu-accelerator: tpu7x + priorityClassName: medium + restartPolicy: OnFailure + terminationGracePeriodSeconds: 30 + volumes: + - hostPath: + path: /tmp + type: DirectoryOrCreate + name: shared-tmp + startupPolicy: + startupPolicyOrder: InOrder + successPolicy: + operator: All + targetReplicatedJobs: + - pathways-head""" + + t = Template(yaml_template) + rendered_yaml = t.render( + metadata_name=metadata_name, + batch_size=batch_size, + rollout_data_parallelism=rollout_data_parallelism, + rollout_tensor_parallelism=rollout_tensor_parallelism, + rollout_expert_parallelism=rollout_expert_parallelism, + trainer_devices_fraction=trainer_devices_fraction, + subslice_shape=subslice_shape, + enable_single_controller=enable_single_controller, + sampler_devices_fraction=sampler_devices_fraction, + base_output_directory=base_output_directory, + run_name=run_name, + hf_token=hf_token, + extra_config=extra_config + ) + + return rendered_yaml + +# Example Usage: +""" +python ./maxtext/src/maxtext/trainers/post_train/rl/create_yaml.py \ + --metadata_name "${workload_name}" \ + --trainer_chips "${trainer_chips}" \ + --number_of_sampler_chips_per_replica "${sampler_chips}" \ + --sampler_replicas 1 \ + --base_output_directory "${base_output_directory}" \ + --hf_token "${hf_token}" \ + --store_directory "${store_path}" \ + --enable_tp "${enable_tp}" +""" + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--metadata_name", type=str, required=True) + parser.add_argument("--trainer_chips", type=int, required=True) + parser.add_argument("--number_of_sampler_chips_per_replica", type=int, required=True) + parser.add_argument("--sampler_replicas", type=int, required=True) + parser.add_argument("--base_output_directory", type=str, required=True) + 
# Example usage (--enable_tp / --enable_ep are value-less switches):
#   python ./maxtext/src/maxtext/trainers/post_train/rl/create_yaml_235b.py \
#     --metadata_name "${workload_name}" \
#     --trainer_chips "${trainer_chips}" \
#     --number_of_sampler_chips_per_replica "${sampler_chips}" \
#     --sampler_replicas 1 \
#     --base_output_directory "${base_output_directory}" \
#     --hf_token "${hf_token}" \
#     --store_directory "${store_path}" \
#     --enable_tp

if __name__ == "__main__":
  # Command-line driver: derive the rollout/trainer sizing from the chip counts,
  # render the JobSet YAML with generate_rl_config() and write it to
  # <store_directory>/<metadata_name>.yaml.
  parser = argparse.ArgumentParser()
  parser.add_argument("--metadata_name", type=str, required=True)
  parser.add_argument("--trainer_chips", type=int, required=True)
  parser.add_argument("--number_of_sampler_chips_per_replica", type=int, required=True)
  parser.add_argument("--sampler_replicas", type=int, required=True)
  parser.add_argument("--base_output_directory", type=str, required=True)
  parser.add_argument("--hf_token", type=str, required=True)
  parser.add_argument("--store_directory", type=str, required=True)
  parser.add_argument("--enable_tp", action="store_true", default=False, help="Enable tensor parallelism")
  parser.add_argument("--enable_ep", action="store_true", default=False, help="Enable expert parallelism")
  args = parser.parse_args()
  # Echo the parsed arguments for the run log, but never the access token.
  print({**vars(args), "hf_token": "<redacted>"})

  # Sizing is hard-wired to a 64-chip budget (the original note reads "for v7x-128").
  number_of_chips = 64
  chips_per_replica = args.number_of_sampler_chips_per_replica
  sampler_chips = chips_per_replica * args.sampler_replicas

  # Validate with parser.error() rather than assert: asserts vanish under `python -O`,
  # which would let an invalid request through (or fail later with a NameError).
  if sampler_chips + args.trainer_chips > number_of_chips:
    parser.error("Total number of chips used by trainer and sampler must be less than or equal to available chips")
  if not (args.enable_tp or args.enable_ep):
    parser.error("At least one of tensor parallelism or expert parallelism must be enabled")

  extra_config = ""
  batch_size = args.trainer_chips * 2
  trainer_devices_fraction = args.trainer_chips / number_of_chips
  sampler_devices_fraction = sampler_chips / number_of_chips
  rollout_data_parallelism = args.sampler_replicas

  # NOTE(review): the factor 2 presumably reflects two devices per chip - confirm.
  devices_per_replica = chips_per_replica * 2
  if args.enable_tp and args.enable_ep:
    # Expert parallelism takes a quarter of the devices; tensor parallelism is capped at 4.
    rollout_expert_parallelism = devices_per_replica // 4 if devices_per_replica >= 4 else 1
    if devices_per_replica % rollout_expert_parallelism != 0:
      parser.error("rollout_tensor_parallelism must be divisible by rollout_expert_parallelism")
    rollout_tensor_parallelism = min(devices_per_replica, 4)
  elif args.enable_ep:
    rollout_tensor_parallelism = 1
    rollout_expert_parallelism = devices_per_replica
  else:
    # Tensor parallelism only.
    rollout_tensor_parallelism = devices_per_replica
    rollout_expert_parallelism = 1
    if rollout_tensor_parallelism >= 4:
      extra_config += " enable_dp_attention=True"

  # Rendered verbatim into the command line, hence the lowercase string form.
  enable_single_controller = "true" if args.trainer_chips == 4 else "false"

  # Trainer chip count -> subslice shape; chip counts not listed render as an empty shape.
  subslice_shape_status = {
      1: "1,1,1",
      2: "2,1,1",
      4: "2,2,1",
      8: "2,2,2",
      16: "2,2,4",
      32: "2,4,4",
      64: "4,4,4",
      128: "4,4,8",
  }
  subslice_shape = subslice_shape_status.get(args.trainer_chips, "")

  output_directory = os.path.join(args.base_output_directory, args.metadata_name)

  result = generate_rl_config(
      metadata_name=args.metadata_name,
      batch_size=batch_size,
      rollout_data_parallelism=rollout_data_parallelism,
      rollout_tensor_parallelism=rollout_tensor_parallelism,
      rollout_expert_parallelism=rollout_expert_parallelism,
      trainer_devices_fraction=trainer_devices_fraction,
      subslice_shape=subslice_shape,
      enable_single_controller=enable_single_controller,
      sampler_devices_fraction=sampler_devices_fraction,
      base_output_directory=output_directory,
      run_name=args.metadata_name,
      hf_token=args.hf_token,
      extra_config=extra_config,
  )

  # exist_ok avoids the check-then-create race of os.path.exists() + os.makedirs().
  os.makedirs(args.store_directory, exist_ok=True)
  output_yaml_path = os.path.join(args.store_directory, f"{args.metadata_name}.yaml")

  with open(output_yaml_path, "w", encoding="utf-8") as f:
    f.write(result)
# Patterns that pull the numeric value out of the three log messages of interest.
_RESHARD_RE = re.compile(r"Reshard finished in (\d+\.?\d*)s")
_WEIGHT_SYNC_RE = re.compile(r"Weight Syncing Time taken: (\d+\.?\d*)s")
_HBM_RE = re.compile(r"Using (\d+\.?\d*) GiB on")

# The two earliest measurements are excluded from the means.
_WARMUP_STEPS = 2


def _payload_text(payload):
  """Return a log entry payload as text; dict payloads are JSON-encoded when possible."""
  if isinstance(payload, dict):
    try:
      return json.dumps(payload)
    except (TypeError, ValueError):
      # Not JSON-serialisable: fall back to the plain repr so the regexes can still run.
      return str(payload)
  return str(payload)


def _collect_matches(entries):
  """Scan log entries and return (reshard, weight_sync, hbm) lists of row dicts."""
  reshard_results, weight_sync_results, hbm_results = [], [], []
  targets = (
      (_RESHARD_RE, "reshard_sec", reshard_results),
      (_WEIGHT_SYNC_RE, "weight_sync_sec", weight_sync_results),
      (_HBM_RE, "hbm_gib", hbm_results),
  )
  for entry in entries:
    payload_str = _payload_text(entry.payload)
    if not payload_str:
      continue
    for pattern, column, rows in targets:
      match = pattern.search(payload_str)
      if match:
        rows.append({
            "timestamp": entry.timestamp,
            column: float(match.group(1)),
            "pod": entry.resource.labels.get("pod_name"),
        })
  return reshard_results, weight_sync_results, hbm_results


def _steady_state_means(reshard_results, weight_sync_results, max_steps):
  """Return (mean_reshard_sec, mean_weight_sync_sec) over the most recent steps.

  The first `_WARMUP_STEPS` measurements of each series are dropped, then at most
  `max_steps` of the latest remaining ones are averaged. Both values are NaN when
  either series is missing or too short.
  """
  nan = float("nan")
  if not reshard_results or not weight_sync_results:
    return nan, nan

  reshard_df = pd.DataFrame(reshard_results).sort_values("timestamp")
  weight_df = pd.DataFrame(weight_sync_results).sort_values("timestamp")

  # One more entry than the warm-up count is required, otherwise nothing is left to average.
  min_entries = _WARMUP_STEPS + 1
  if len(reshard_df) < min_entries or len(weight_df) < min_entries:
    print("Not enough log entries found to compute mean times. Need at least 3 entries for reshard and weight sync each.")
    print("Reshard results:")
    print(reshard_df)
    print("Weight sync results:")
    print(weight_df)
    return nan, nan

  print(reshard_df)
  print(weight_df)
  reshard_df = reshard_df.iloc[_WARMUP_STEPS:]
  weight_df = weight_df.iloc[_WARMUP_STEPS:]
  # tail(0) is empty, whereas iloc[-0:] would select every row; clamp negatives as well.
  length = max(0, min(len(reshard_df), len(weight_df), max_steps))
  selected_reshard_df = reshard_df.tail(length)
  selected_weight_df = weight_df.tail(length)
  print(selected_reshard_df)
  print(selected_weight_df)
  return selected_reshard_df["reshard_sec"].mean(), selected_weight_df["weight_sync_sec"].mean()


def get_reshard_data(args):
  """Collect reshard / weight-sync timings and HBM usage for a pod and append them to a CSV.

  Queries Cloud Logging for the last five days of entries from pods whose name
  contains `args.pod_name`, extracts the numbers from the matching messages,
  summarises them into one row and appends that row to `args.store_cvs_file`.

  Args:
    args: Namespace with `project_name`, `cluster_name`, `pod_name`, `max_steps`
      and `store_cvs_file`; an optional `location` overrides "us-central1".

  Returns:
    The DataFrame written to the CSV (existing rows followed by the new one).
  """
  client = logging.Client(project=args.project_name)
  location = getattr(args, "location", "us-central1")

  # 1. Define a narrow time window (last 5 days)
  # This prevents the API from searching the entire history of the project
  start_time = (datetime.now(timezone.utc) - timedelta(days=5)).strftime('%Y-%m-%dT%H:%M:%SZ')

  # 2. Build the filter to search for both reshard and weight sync times.
  # We replace SEARCH() with textPayload: which is the API equivalent
  log_filter = (
      f'resource.type="k8s_container" '
      f'resource.labels.location="{location}" '
      f'resource.labels.project_id="{args.project_name}" '
      f'resource.labels.cluster_name="{args.cluster_name}" '
      f'resource.labels.namespace_name="default" '
      f'resource.labels.pod_name:"{args.pod_name}" '
      f'severity>=DEFAULT '
      f'timestamp >= "{start_time}" '
      f'("Reshard finished in" OR "Weight Syncing Time taken:" OR ("Using" AND "GiB on"))'
  )

  print("Querying logs from the last 5 days (Newest first)...")

  # 3. Use order_by=DESCENDING to find recent logs immediately
  entries = client.list_entries(filter_=log_filter, order_by=DESCENDING)

  reshard_results, weight_sync_results, hbm_results = [], [], []
  try:
    reshard_results, weight_sync_results, hbm_results = _collect_matches(entries)
  except Exception as e:  # Best effort: an API failure still yields a (NaN) row below.
    print(f"Error during API call: {e}")

  if not reshard_results and not weight_sync_results and not hbm_results:
    print("Still no logs found. Try this final check:")
    print(f"1. Run: gcloud logging read '{log_filter}' --limit=1")
    print("2. If that returns nothing, your local gcloud credentials don't have permission for this project.")

  mean_reshard_time, mean_weight_sync_time = _steady_state_means(
      reshard_results, weight_sync_results, args.max_steps
  )

  # The earliest HBM report is recorded as the trainer's and the latest as the sampler's.
  # NOTE(review): this relies on the trainer logging its usage first - confirm.
  trainer_hbm = float('nan')
  sampler_hbm = float('nan')
  if hbm_results:
    df_hbm = pd.DataFrame(hbm_results).sort_values("timestamp")
    trainer_hbm = df_hbm.iloc[0]["hbm_gib"]
    sampler_hbm = df_hbm.iloc[-1]["hbm_gib"]

  # Console link showing everything the jax-tpu container logged in the same window.
  log_query = (
      f'resource.type="k8s_container" '
      f'resource.labels.project_id="{args.project_name}" '
      f'resource.labels.location="{location}" '
      f'resource.labels.cluster_name="{args.cluster_name}" '
      f'resource.labels.namespace_name="default" '
      f'resource.labels.pod_name:"{args.pod_name}" '
      f'severity>=DEFAULT '
      f'timestamp >= "{start_time}" '
      f'resource.labels.container_name="jax-tpu"'
  )
  log_link = f"https://console.cloud.google.com/logs/query;query={urllib.parse.quote(log_query)}?project={args.project_name}"

  result_df = pd.DataFrame([{
      "pod_name": args.pod_name,
      "mean_reshard_time": mean_reshard_time,
      "mean_weight_sync_time": mean_weight_sync_time,
      "trainer_hbm": trainer_hbm,
      "sampler_hbm": sampler_hbm,
      "log_link": log_link,
  }])

  output_csv_path = args.store_cvs_file

  # If the csv file already exists, append to it instead of overwriting.
  # A missing or zero-byte file simply means there is nothing to append to.
  try:
    existing_df = pd.read_csv(output_csv_path)
  except (FileNotFoundError, pd.errors.EmptyDataError):
    pass
  else:
    result_df = pd.concat([existing_df, result_df], ignore_index=True)

  # Save the results to a CSV file for later analysis
  result_df.to_csv(output_csv_path, index=False)
  print(result_df)
  return result_df


# Example usage:
#   python ./maxtext/src/maxtext/trainers/post_train/rl/extract_time.py \
#     --pod_name sanbao-reshard-10217 --store_cvs_file ./reshard/test.csv

if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--pod_name", type=str, required=True, help="Pod name")
  parser.add_argument("--max_steps", type=int, default=10, help="Max steps")
  parser.add_argument("--store_cvs_file", type=str, required=True)
  parser.add_argument("--cluster_name", type=str, default="next-devx-1", help="Cluster name")
  parser.add_argument("--project_name", type=str, default="tpu-prod-env-automated", help="Project name")
  parser.add_argument("--location", type=str, default="us-central1", help="Cluster location")
  args = parser.parse_args()
  get_reshard_data(args)
"trainer_chips:number_of_sampler_chips_per_replica" +configs=( + "1:1" + "4:1" + "4:2" + "4:4" + "8:1" + "8:2" + "8:4" + "8:8" + "16:1" + "16:2" + "16:4" + "16:8" + "16:16" + "32:2" + "32:4" + "32:8" + "32:16" + "32:32" +) + +# Global variables +base_output_directory="gs://sanbao-bucket/mlperf_rl/reshard" +store_path="./reshard" +project="cloud-tpu-multipod-dev" +zone="us-central1" +cluster="zxhe-super-xpk-bid" +timestamp="d3" + +mkdir -p ${store_path} + +# Function to handle errors and ensure cleanup +handle_error() { + echo "Error occurred during config ${workload_name}. Cleaning up..." + xpk workload delete --workload "${workload_name}" --cluster "${cluster}" --project "${project}" --zone "${zone}" + # Continue to next iteration rather than exiting the whole script +} + +store_cvs_file="${store_path}/reshard_stats_tp_ep.csv" + +## For EP + TP +for config in "${configs[@]}"; do + # Split the config string into variables + IFS=":" read -r trainer_chips sampler_chips <<< "$config" + + # Generate a unique workload name based on config and date + workload_name="sanbao-${trainer_chips}-${sampler_chips}-tpep${timestamp}" + + echo "----------------------------------------------------------" + echo "Running Config: Trainer=${trainer_chips}, Sampler=${sampler_chips}" + echo "Workload Name: ${workload_name}" + echo "----------------------------------------------------------" + + # Trap errors specifically for this iteration + trap 'handle_error' ERR + + # 1. Create the YAML + python ./maxtext/src/maxtext/trainers/post_train/rl/create_yaml.py \ + --metadata_name "${workload_name}" \ + --trainer_chips "${trainer_chips}" \ + --number_of_sampler_chips_per_replica "${sampler_chips}" \ + --sampler_replicas 1 \ + --base_output_directory "${base_output_directory}" \ + --hf_token "${HF_TOKEN}" \ + --store_directory "${store_path}" \ + --enable_tp \ + --enable_ep + + # 2. Apply Kubernetes YAML + echo "Applying Kubernetes YAML..." 
#!/bin/bash
#
# Runs the reshard benchmark for every trainer/sampler chip split in `configs`,
# once per rollout parallelism mode (TP+EP, EP only, TP only). Each run renders
# a JobSet YAML, applies it, waits, scrapes the timings into a per-mode CSV and
# deletes the workload; the CSV is uploaded when the mode's sweep is finished.

# Define your configurations: "trainer_chips:number_of_sampler_chips_per_replica"
configs=(
  "1:1"
  "4:1"
  "4:2"
  "4:4"
  "8:1"
  "8:2"
  "8:4"
  "8:8"
  "16:1"
  "16:2"
  "16:4"
  "16:8"
  "16:16"
  "32:2"
  "32:4"
  "32:8"
  "32:16"
  "32:32"
)

# Global variables
base_output_directory="gs://sanbao-bucket/mlperf_rl/reshard"
store_path="./reshard"
project="cloud-tpu-multipod-dev"
zone="us-central1"
cluster="zxhe-super-xpk-bid"
timestamp="d3"
results_location="gs://sanbao-bucket/mlperf_rl/results/"
script_dir="./maxtext/src/maxtext/trainers/post_train/rl"
buffer_seconds=120

# Every generated YAML embeds the token, so refuse to start without it.
: "${HF_TOKEN:?HF_TOKEN must be set}"

mkdir -p "${store_path}"

# Delete one workload from the cluster.
delete_workload() {
  xpk workload delete --workload "$1" --cluster "${cluster}" --project "${project}" --zone "${zone}"
}

# Run one configuration end to end.
#   $1 trainer chips, $2 sampler chips per replica, $3 workload-name suffix,
#   $4 CSV to append to, remaining args: parallelism flags for create_yaml.py.
# Returns non-zero as soon as a step fails so the caller can move on to the next
# configuration instead of waiting on a workload that never started.
run_config() {
  local trainer_chips="$1" sampler_chips="$2" mode="$3" csv_file="$4"
  shift 4
  local workload_name="sanbao-${trainer_chips}-${sampler_chips}-${mode}${timestamp}"

  echo "----------------------------------------------------------"
  echo "Running Config: Trainer=${trainer_chips}, Sampler=${sampler_chips}"
  echo "Workload Name: ${workload_name}"
  echo "----------------------------------------------------------"

  # 1. Create the YAML
  if ! python "${script_dir}/create_yaml.py" \
      --metadata_name "${workload_name}" \
      --trainer_chips "${trainer_chips}" \
      --number_of_sampler_chips_per_replica "${sampler_chips}" \
      --sampler_replicas 1 \
      --base_output_directory "${base_output_directory}" \
      --hf_token "${HF_TOKEN}" \
      --store_directory "${store_path}" \
      "$@"; then
    echo "Error occurred during config ${workload_name}: YAML generation failed." >&2
    return 1
  fi

  # 2. Apply Kubernetes YAML
  echo "Applying Kubernetes YAML..."
  if ! kubectl apply -f "${store_path}/${workload_name}.yaml"; then
    echo "Error occurred during config ${workload_name}. Cleaning up..." >&2
    delete_workload "${workload_name}"
    return 1
  fi

  # 3. Wait for workload to run
  echo "Waiting 10 minutes for workload execution..."
  sleep 600

  # 4. Extract Timing Data
  echo "Extracting timing data..."
  if ! python "${script_dir}/extract_time.py" \
      --pod_name "${workload_name}" \
      --store_cvs_file "${csv_file}"; then
    echo "Error occurred during config ${workload_name}. Cleaning up..." >&2
    delete_workload "${workload_name}"
    return 1
  fi
  echo "Finished: ${workload_name}. Data in ${csv_file}"

  # Small buffer before tearing the workload down
  sleep "${buffer_seconds}"

  # 5. Cleanup Workload
  echo "Deleting workload..."
  delete_workload "${workload_name}"
}

# Run every configuration for one parallelism mode, then upload its CSV.
#   $1 workload-name suffix, $2 CSV path, remaining args: parallelism flags.
run_sweep() {
  local mode="$1" csv_file="$2"
  shift 2
  local config trainer_chips sampler_chips

  for config in "${configs[@]}"; do
    # Split the config string into variables
    IFS=":" read -r trainer_chips sampler_chips <<< "${config}"
    run_config "${trainer_chips}" "${sampler_chips}" "${mode}" "${csv_file}" "$@" \
      || echo "Skipping to the next configuration." >&2
  done

  gcloud storage cp "${csv_file}" "${results_location}"
  echo "All configurations completed."
}

## For EP + TP
run_sweep "tpep" "${store_path}/reshard_stats_tp_ep.csv" --enable_tp --enable_ep

## For EP only
run_sweep "ep" "${store_path}/reshard_stats_ep.csv" --enable_ep

## For TP only
run_sweep "tp" "${store_path}/reshard_stats_tp.csv" --enable_tp
#!/bin/bash
#
# Runs the reshard benchmark (create_yaml_235b.py variant) for every
# trainer/sampler chip split in `configs`, once per rollout parallelism mode
# (TP+EP, EP only, TP only). Each run renders a JobSet YAML, applies it, waits,
# scrapes the timings into a per-mode CSV and deletes the workload; the CSV is
# uploaded when the mode's sweep is finished.

# Define your configurations: "trainer_chips:number_of_sampler_chips_per_replica"
configs=(
  "4:4"
  "8:4"
  "8:8"
  "16:4"
  "16:8"
  "16:16"
  "32:4"
  "32:8"
  "32:16"
  "32:32"
)

# Global variables
base_output_directory="gs://sanbao-bucket/mlperf_rl/reshard"
store_path="./reshard"
project="cloud-tpu-multipod-dev"
zone="us-central1"
cluster="zxhe-super-xpk-bid"
timestamp="d5"
results_location="gs://sanbao-bucket/mlperf_rl/results/"
script_dir="./maxtext/src/maxtext/trainers/post_train/rl"
buffer_seconds=30

# Every generated YAML embeds the token, so refuse to start without it.
: "${HF_TOKEN:?HF_TOKEN must be set}"

mkdir -p "${store_path}"

# Delete one workload from the cluster.
delete_workload() {
  xpk workload delete --workload "$1" --cluster "${cluster}" --project "${project}" --zone "${zone}"
}

# Run one configuration end to end.
#   $1 trainer chips, $2 sampler chips per replica, $3 workload-name suffix,
#   $4 CSV to append to, remaining args: parallelism flags for create_yaml_235b.py.
# Returns non-zero as soon as a step fails so the caller can move on to the next
# configuration instead of waiting on a workload that never started.
run_config() {
  local trainer_chips="$1" sampler_chips="$2" mode="$3" csv_file="$4"
  shift 4
  local workload_name="sanbao-${trainer_chips}-${sampler_chips}-${mode}${timestamp}"

  echo "----------------------------------------------------------"
  echo "Running Config: Trainer=${trainer_chips}, Sampler=${sampler_chips}"
  echo "Workload Name: ${workload_name}"
  echo "----------------------------------------------------------"

  # 1. Create the YAML
  if ! python "${script_dir}/create_yaml_235b.py" \
      --metadata_name "${workload_name}" \
      --trainer_chips "${trainer_chips}" \
      --number_of_sampler_chips_per_replica "${sampler_chips}" \
      --sampler_replicas 1 \
      --base_output_directory "${base_output_directory}" \
      --hf_token "${HF_TOKEN}" \
      --store_directory "${store_path}" \
      "$@"; then
    echo "Error occurred during config ${workload_name}: YAML generation failed." >&2
    return 1
  fi

  # 2. Apply Kubernetes YAML
  echo "Applying Kubernetes YAML..."
  if ! kubectl apply -f "${store_path}/${workload_name}.yaml"; then
    echo "Error occurred during config ${workload_name}. Cleaning up..." >&2
    delete_workload "${workload_name}"
    return 1
  fi

  # 3. Wait for workload to run
  echo "Waiting 10 minutes for workload execution..."
  sleep 600

  # 4. Extract Timing Data
  echo "Extracting timing data..."
  if ! python "${script_dir}/extract_time.py" \
      --pod_name "${workload_name}" \
      --store_cvs_file "${csv_file}"; then
    echo "Error occurred during config ${workload_name}. Cleaning up..." >&2
    delete_workload "${workload_name}"
    return 1
  fi
  echo "Finished: ${workload_name}. Data in ${csv_file}"

  # Small buffer before tearing the workload down
  sleep "${buffer_seconds}"

  # 5. Cleanup Workload
  echo "Deleting workload..."
  delete_workload "${workload_name}"
}

# Run every configuration for one parallelism mode, then upload its CSV.
#   $1 workload-name suffix, $2 CSV path, remaining args: parallelism flags.
run_sweep() {
  local mode="$1" csv_file="$2"
  shift 2
  local config trainer_chips sampler_chips

  for config in "${configs[@]}"; do
    # Split the config string into variables
    IFS=":" read -r trainer_chips sampler_chips <<< "${config}"
    run_config "${trainer_chips}" "${sampler_chips}" "${mode}" "${csv_file}" "$@" \
      || echo "Skipping to the next configuration." >&2
  done

  gcloud storage cp "${csv_file}" "${results_location}"
  echo "All configurations completed."
}

## For EP + TP
run_sweep "tpep" "${store_path}/reshard_stats_235b_tp_ep.csv" --enable_tp --enable_ep

## For EP only
run_sweep "ep" "${store_path}/reshard_stats_235b_ep.csv" --enable_ep

## For TP only
run_sweep "tp" "${store_path}/reshard_stats_235b_tp.csv" --enable_tp
+ xpk workload delete --workload "${workload_name}" --cluster "${cluster}" --project "${project}" --zone "${zone}" + # Continue to next iteration rather than exiting the whole script +} + +## For EP only +store_cvs_file="${store_path}/reshard_stats_ep.csv" + +for config in "${configs[@]}"; do + # Split the config string into variables + IFS=":" read -r trainer_chips sampler_chips <<< "$config" + + # Generate a unique workload name based on config and date + workload_name="sanbao-${trainer_chips}-${sampler_chips}-ep${timestamp}" + + echo "----------------------------------------------------------" + echo "Running Config: Trainer=${trainer_chips}, Sampler=${sampler_chips}" + echo "Workload Name: ${workload_name}" + echo "----------------------------------------------------------" + + # Trap errors specifically for this iteration + trap 'handle_error' ERR + + # 1. Create the Script + python ./maxtext/src/maxtext/trainers/post_train/rl/create_script.py \ + --metadata_name "${workload_name}" \ + --trainer_chips "${trainer_chips}" \ + --number_of_sampler_chips_per_replica "${sampler_chips}" \ + --sampler_replicas 1 \ + --base_output_directory "${base_output_directory}" \ + --hf_token "${HF_TOKEN}" \ + --store_directory "${store_path}" \ + --enable_ep + + # 2. Apply Script + echo "Applying Script..." + sh "${store_path}/${workload_name}.sh" + + # 3. Wait for workload to run + echo "Waiting 10 minutes for workload execution..." + sleep 600 + + # 5. Extract Timing Data + echo "Extracting timing data..." + python ./maxtext/src/maxtext/trainers/post_train/rl/extract_time.py \ + --pod_name "${workload_name}" \ + --store_cvs_file "${store_cvs_file}" \ + --cluster_name "${cluster}" \ + --project_name "${project}" + + echo "Finished: ${workload_name}. Data in ${store_cvs_file}" + + # Small buffer before starting the next config + sleep 10 + + # 4. Cleanup Workload + echo "Deleting workload..." 
+ xpk workload delete --workload "${workload_name}" --cluster "${cluster}" --project "${project}" --zone "${zone}" + + # Clear trap for next iteration + trap - ERR +done + +gcloud storage cp ${store_cvs_file} gs://sanbao-bucket/mlperf_rl/results/ + +echo "All configurations completed." + +## For TP only +store_cvs_file="${store_path}/reshard_stats_tp.csv" +for config in "${configs[@]}"; do + # Split the config string into variables + IFS=":" read -r trainer_chips sampler_chips <<< "$config" + + # Generate a unique workload name based on config and date + workload_name="sanbao-${trainer_chips}-${sampler_chips}-tp${timestamp}" + + echo "----------------------------------------------------------" + echo "Running Config: Trainer=${trainer_chips}, Sampler=${sampler_chips}" + echo "Workload Name: ${workload_name}" + echo "----------------------------------------------------------" + + # Trap errors specifically for this iteration + trap 'handle_error' ERR + + # 1. Create the Script + python ./maxtext/src/maxtext/trainers/post_train/rl/create_script.py \ + --metadata_name "${workload_name}" \ + --trainer_chips "${trainer_chips}" \ + --number_of_sampler_chips_per_replica "${sampler_chips}" \ + --sampler_replicas 1 \ + --base_output_directory "${base_output_directory}" \ + --hf_token "${HF_TOKEN}" \ + --store_directory "${store_path}" \ + --enable_tp + + # 2. Apply Script + echo "Applying Script..." + sh "${store_path}/${workload_name}.sh" + + # 3. Wait for workload to run + echo "Waiting 10 minutes for workload execution..." + sleep 600 + + # 5. Extract Timing Data + echo "Extracting timing data..." + python ./maxtext/src/maxtext/trainers/post_train/rl/extract_time.py \ + --pod_name "${workload_name}" \ + --store_cvs_file "${store_cvs_file}" \ + --cluster_name "${cluster}" \ + --project_name "${project}" + + echo "Finished: ${workload_name}. Data in ${store_cvs_file}" + + # Small buffer before starting the next config + sleep 10 + + # 4. 
Cleanup Workload + echo "Deleting workload..." + xpk workload delete --workload "${workload_name}" --cluster "${cluster}" --project "${project}" --zone "${zone}" + + # Clear trap for next iteration + trap - ERR +done + +gcloud storage cp ${store_cvs_file} gs://sanbao-bucket/mlperf_rl/results/ + +echo "All configurations completed." + +store_cvs_file="${store_path}/reshard_stats_tp_ep.csv" + +# For EP + TP +for config in "${configs[@]}"; do + # Split the config string into variables + IFS=":" read -r trainer_chips sampler_chips <<< "$config" + + # Generate a unique workload name based on config and date + workload_name="sanbao-${trainer_chips}-${sampler_chips}-tpep${timestamp}" + + echo "----------------------------------------------------------" + echo "Running Config: Trainer=${trainer_chips}, Sampler=${sampler_chips}" + echo "Workload Name: ${workload_name}" + echo "----------------------------------------------------------" + + # Trap errors specifically for this iteration + trap 'handle_error' ERR + + # 1. Create the Script + python ./maxtext/src/maxtext/trainers/post_train/rl/create_script.py \ + --metadata_name "${workload_name}" \ + --trainer_chips "${trainer_chips}" \ + --number_of_sampler_chips_per_replica "${sampler_chips}" \ + --sampler_replicas 1 \ + --base_output_directory "${base_output_directory}" \ + --hf_token "${HF_TOKEN}" \ + --store_directory "${store_path}" \ + --enable_tp \ + --enable_ep + + # 2. Apply Script + echo "Applying Script..." + sh "${store_path}/${workload_name}.sh" + + # 3. Wait for workload to run + echo "Waiting 10 minutes for workload execution..." + sleep 600 + + # 5. Extract Timing Data + echo "Extracting timing data..." + python ./maxtext/src/maxtext/trainers/post_train/rl/extract_time.py \ + --pod_name "${workload_name}" \ + --store_cvs_file "${store_cvs_file}" \ + --cluster_name "${cluster}" \ + --project_name "${project}" + + echo "Finished: ${workload_name}. 
Data in ${store_cvs_file}" + + # Small buffer before starting the next config + sleep 10 + + # 4. Cleanup Workload + echo "Deleting workload..." + xpk workload delete --workload "${workload_name}" --cluster "${cluster}" --project "${project}" --zone "${zone}" + + # Clear trap for next iteration + trap - ERR +done + +gcloud storage cp ${store_cvs_file} gs://sanbao-bucket/mlperf_rl/results/ + +echo "All configurations completed." \ No newline at end of file diff --git a/src/maxtext/trainers/post_train/rl/reshard_auto_script_235b.sh b/src/maxtext/trainers/post_train/rl/reshard_auto_script_235b.sh new file mode 100644 index 0000000000..8cb6f5d6a5 --- /dev/null +++ b/src/maxtext/trainers/post_train/rl/reshard_auto_script_235b.sh @@ -0,0 +1,218 @@ +#!/bin/bash + +# Define your configurations: "trainer_chips:number_of_sampler_chips_per_replica" +configs=( + "4:4" + "8:4" + "8:8" + "16:4" + "16:8" + "16:16" + "32:4" + "32:8" + "32:16" + "32:32" +) + +# Global variables +base_output_directory="gs://sanbao-bucket/mlperf_rl/reshard" +store_path="./reshard" +project="tpu-prod-env-automated" +zone="us-central1" +cluster="next-devx-1" +timestamp="d9" + +mkdir -p ${store_path} + +# Function to handle errors and ensure cleanup +handle_error() { + echo "Error occurred during config ${workload_name}. Cleaning up..." 
+ xpk workload delete --workload "${workload_name}" --cluster "${cluster}" --project "${project}" --zone "${zone}" + # Continue to next iteration rather than exiting the whole script +} + +## For EP only +store_cvs_file="${store_path}/reshard_stats_235b_ep.csv" + +for config in "${configs[@]}"; do + # Split the config string into variables + IFS=":" read -r trainer_chips sampler_chips <<< "$config" + + # Generate a unique workload name based on config and date + workload_name="sanbao-${trainer_chips}-${sampler_chips}-ep${timestamp}" + + echo "----------------------------------------------------------" + echo "Running Config: Trainer=${trainer_chips}, Sampler=${sampler_chips}" + echo "Workload Name: ${workload_name}" + echo "----------------------------------------------------------" + + # Trap errors specifically for this iteration + trap 'handle_error' ERR + + # 1. Create the Script + python ./maxtext/src/maxtext/trainers/post_train/rl/create_script_235b.py \ + --metadata_name "${workload_name}" \ + --trainer_chips "${trainer_chips}" \ + --number_of_sampler_chips_per_replica "${sampler_chips}" \ + --sampler_replicas 1 \ + --base_output_directory "${base_output_directory}" \ + --hf_token "${HF_TOKEN}" \ + --store_directory "${store_path}" \ + --enable_ep + + # 2. Apply XPK Script + echo "Applying XPK Script..." + sh "${store_path}/${workload_name}.sh" + + # 3. Wait for workload to run + echo "Waiting 10 minutes for workload execution..." + sleep 600 + + # 5. Extract Timing Data + echo "Extracting timing data..." + python ./maxtext/src/maxtext/trainers/post_train/rl/extract_time.py \ + --pod_name "${workload_name}" \ + --store_cvs_file "${store_cvs_file}" \ + --cluster_name "${cluster}" \ + --project_name "${project}" + + echo "Finished: ${workload_name}. Data in ${store_cvs_file}" + + # Small buffer before starting the next config + sleep 10 + + # 4. Cleanup Workload + echo "Deleting workload..." 
+ xpk workload delete --workload "${workload_name}" --cluster "${cluster}" --project "${project}" --zone "${zone}" + + # Clear trap for next iteration + trap - ERR +done + +gcloud storage cp ${store_cvs_file} gs://sanbao-bucket/mlperf_rl/results/ + +echo "All configurations completed." + +## For TP only +store_cvs_file="${store_path}/reshard_stats_235b_tp.csv" +for config in "${configs[@]}"; do + # Split the config string into variables + IFS=":" read -r trainer_chips sampler_chips <<< "$config" + + # Generate a unique workload name based on config and date + workload_name="sanbao-${trainer_chips}-${sampler_chips}-tp${timestamp}" + + echo "----------------------------------------------------------" + echo "Running Config: Trainer=${trainer_chips}, Sampler=${sampler_chips}" + echo "Workload Name: ${workload_name}" + echo "----------------------------------------------------------" + + # Trap errors specifically for this iteration + trap 'handle_error' ERR + + # 1. Create the Script + python ./maxtext/src/maxtext/trainers/post_train/rl/create_script_235b.py \ + --metadata_name "${workload_name}" \ + --trainer_chips "${trainer_chips}" \ + --number_of_sampler_chips_per_replica "${sampler_chips}" \ + --sampler_replicas 1 \ + --base_output_directory "${base_output_directory}" \ + --hf_token "${HF_TOKEN}" \ + --store_directory "${store_path}" \ + --enable_tp + + # 2. Apply XPK Script + echo "Applying XPK Script..." + sh "${store_path}/${workload_name}.sh" + + # 3. Wait for workload to run + echo "Waiting 10 minutes for workload execution..." + sleep 600 + + # 5. Extract Timing Data + echo "Extracting timing data..." + python ./maxtext/src/maxtext/trainers/post_train/rl/extract_time.py \ + --pod_name "${workload_name}" \ + --store_cvs_file "${store_cvs_file}" \ + --cluster_name "${cluster}" \ + --project_name "${project}" + + echo "Finished: ${workload_name}. Data in ${store_cvs_file}" + + # Small buffer before starting the next config + sleep 10 + + # 4. 
Cleanup Workload + echo "Deleting workload..." + xpk workload delete --workload "${workload_name}" --cluster "${cluster}" --project "${project}" --zone "${zone}" + + # Clear trap for next iteration + trap - ERR +done + +gcloud storage cp ${store_cvs_file} gs://sanbao-bucket/mlperf_rl/results/ + +echo "All configurations completed." + +store_cvs_file="${store_path}/reshard_stats_235b_tp_ep.csv" + +# For EP + TP +for config in "${configs[@]}"; do + # Split the config string into variables + IFS=":" read -r trainer_chips sampler_chips <<< "$config" + + # Generate a unique workload name based on config and date + workload_name="sanbao-${trainer_chips}-${sampler_chips}-tpep${timestamp}" + + echo "----------------------------------------------------------" + echo "Running Config: Trainer=${trainer_chips}, Sampler=${sampler_chips}" + echo "Workload Name: ${workload_name}" + echo "----------------------------------------------------------" + + # Trap errors specifically for this iteration + trap 'handle_error' ERR + + # 1. Create the Script + python ./maxtext/src/maxtext/trainers/post_train/rl/create_script_235b.py \ + --metadata_name "${workload_name}" \ + --trainer_chips "${trainer_chips}" \ + --number_of_sampler_chips_per_replica "${sampler_chips}" \ + --sampler_replicas 1 \ + --base_output_directory "${base_output_directory}" \ + --hf_token "${HF_TOKEN}" \ + --store_directory "${store_path}" \ + --enable_tp \ + --enable_ep + + # 2. Apply XPK Script + echo "Applying XPK Script..." + sh ${store_path}/${workload_name}.sh + + # 3. Wait for workload to run + echo "Waiting 10 minutes for workload execution..." + sleep 600 + + # 5. Extract Timing Data + echo "Extracting timing data..." + python ./maxtext/src/maxtext/trainers/post_train/rl/extract_time.py \ + --pod_name "${workload_name}" \ + --store_cvs_file "${store_cvs_file}" \ + --cluster_name "${cluster}" \ + --project_name "${project}" + + echo "Finished: ${workload_name}. 
Data in ${store_cvs_file}" + + # Small buffer before starting the next config + sleep 10 + + # 4. Cleanup Workload + echo "Deleting workload..." + xpk workload delete --workload "${workload_name}" --cluster "${cluster}" --project "${project}" --zone "${zone}" + + # Clear trap for next iteration + trap - ERR +done + +gcloud storage cp ${store_cvs_file} gs://sanbao-bucket/mlperf_rl/results/ + +echo "All configurations completed." \ No newline at end of file diff --git a/src/maxtext/trainers/post_train/rl/reshard_debug.py b/src/maxtext/trainers/post_train/rl/reshard_debug.py new file mode 100644 index 0000000000..40bca4e176 --- /dev/null +++ b/src/maxtext/trainers/post_train/rl/reshard_debug.py @@ -0,0 +1,448 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Resharding Benchmark for the RL Trainer + +This module provides a unified `rl_train` function that consolidates the common +RL training logic. It handles model loading, reward function setup, dataset +processing, and training orchestration. By default, we run Group Relative Policy Optimization (GRPO) on +GSM8K math reasoning benchmark. The script is also flexible enough to run Group Sequence Policy Optimization (GSPO). 
+ +Usage Examples: + +# GRPO on Qwen3-30B +python3 -m src.maxtext.trainers.post_train.rl.reshard_debug src/maxtext/configs/post_train/rl.yml \ + model_name=qwen3-30b-a3b \ + tokenizer_path=Qwen/Qwen3-30B-A3B \ + run_name=sanbao-rl-0310-1 \ + base_output_directory=gs://sanbao-bucket/mlperf_rl/qwen3/sanbao-rl-0310-1 \ + batch_size=16 \ + rl.num_generations=8 \ + num_batches=4 \ + rollout_data_parallelism=4 \ + rollout_tensor_parallelism=1 \ + rollout_expert_parallelism=4 \ + hbm_utilization_vllm=0.4 \ + scan_layers=True \ + allow_split_physical_axes=True \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + vllm_additional_config='{maxtext_config: {model_name: qwen3-30b-a3b, allow_split_physical_axes: true, log_config: false, weight_dtype: bfloat16}}' + +""" + +from __future__ import annotations +from typing import Sequence + +import collections +import jax +import json +import time +import logging +import os +import pathwaysutils + +from absl import app +from absl import logging as absl_logging +from etils import epath +from flax import nnx +from jax.sharding import Mesh +from orbax import checkpoint as ocp +from transformers import AutoTokenizer +from tunix.rl import rl_cluster as rl_cluster_lib +from tunix.rl.rollout import base_rollout +from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner +from tunix.sft import metrics_logger, profiler +from tunix.sft.utils import show_hbm_usage + +# for vLLM we can skip JAX precompilation with this flag, it makes startup faster +os.environ["SKIP_JAX_PRECOMPILE"] = "1" + +from maxtext.configs import pyconfig +from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR +from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter +from maxtext.trainers.post_train.rl import utils_rl +from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils + + +def get_maxtext_model(config, devices=None): + """ + Load MaxText model with Tunix adapter. 
+ # Note: pass the path to your scanned checkpoint for 'load_parameters_path'. + # To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py and if + # using Pathways, please set `checkpoint_storage_use_ocdbt=False checkpoint_storage_use_zarr3=False` + # python src/MaxText/checkpoint_conversion/to_maxtext.py \ + # --model_name="gemma2-2b" \ + # --base_output_directory="/path/to/your/output/directory" \ + # --scan_layers=True \ + # --checkpoint_storage_use_ocdbt=False\ + # checkpoint_storage_use_zarr3=False + # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e., + # load_parameters_path=/path/to/your/output/directory/0/items + """ + model, mesh = model_creation_utils.create_nnx_model(config, devices=devices) + with mesh: + use_no_op_mappings = "maxtext_config" in config.vllm_additional_config + tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings) + tunix_model.config = None + return tunix_model, mesh + + +def setup_configs_and_devices(argv: list[str]): + """Setup device allocation and configs for training and inference.""" + config = pyconfig.initialize_pydantic(argv) + devices = jax.devices() + if config.num_trainer_slices == -1 and config.num_samplers_slices == -1: + max_logging.log("Running RL on a single slice") + num_vms = len(devices) // config.chips_per_vm + trainer_devices = devices + sampler_devices = devices + if num_vms >= 2 and config.use_pathways: + # Multiple hosts with Pathways - potentially split devices for trainer and sampler + # based on trainer_devices_fraction and sampler_devices_fraction + max_logging.log(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.") + num_devices = len(devices) + num_trainer_devices = int(num_devices * config.trainer_devices_fraction) + num_sampler_devices = int(num_devices * config.sampler_devices_fraction) + trainer_devices = 
devices[:num_trainer_devices] + sampler_devices = devices[num_devices - num_sampler_devices :] + if config.trainer_devices_fraction != 1.0: + max_logging.log(f"Using first {len(trainer_devices)} devices as Trainer devices") + if config.sampler_devices_fraction != 1.0: + max_logging.log(f"Using last {len(sampler_devices)} devices as Sampler devices") + trainer_config = config.model_copy() + sampler_config = config.model_copy() + elif config.num_trainer_slices > 0 and config.num_samplers_slices > 0: + max_logging.log("Running RL with Multislice") + devices_by_slice = collections.defaultdict(list) + for d in devices: + devices_by_slice[d.slice_index].append(d) + slice_indices = sorted(devices_by_slice.keys()) + + if len(slice_indices) < config.num_trainer_slices + config.num_samplers_slices: + raise ValueError("Not enough slices for trainer and samplers") + + trainer_devices = [] + for i in range(config.num_trainer_slices): + trainer_devices.extend(devices_by_slice[slice_indices[i]]) + + sampler_devices = [] + for i in range(config.num_trainer_slices, config.num_trainer_slices + config.num_samplers_slices): + sampler_devices.extend(devices_by_slice[slice_indices[i]]) + + trainer_devices_per_slice = len(trainer_devices) // config.num_trainer_slices + trainer_fsdp = trainer_devices_per_slice + tp = config.ici_tensor_parallelism + if tp > 1: + if trainer_devices_per_slice % tp != 0: + raise ValueError( + f"trainer_devices_per_slice ({trainer_devices_per_slice}) must be divisible by tensor parallelism ({tp})" + ) + if config.ici_fsdp_parallelism != -1 and config.ici_fsdp_parallelism * tp != trainer_devices_per_slice: + raise ValueError( + f"ici_fsdp_parallelism ({config.ici_fsdp_parallelism}) * ici_tensor_parallelism ({tp}) must equal " + f"devices_per_slice ({trainer_devices_per_slice})" + ) + trainer_fsdp = trainer_devices_per_slice // tp + + trainer_update = { + "num_slices": config.num_trainer_slices, + "ici_fsdp_parallelism": trainer_fsdp, + "ici_tensor_parallelism": 
tp, + "dcn_data_parallelism": config.num_trainer_slices, + } + + sampler_update = { + "num_slices": config.num_samplers_slices, + "ici_fsdp_parallelism": len(sampler_devices) // config.num_samplers_slices, + "ici_tensor_parallelism": -1, + "dcn_data_parallelism": config.num_samplers_slices, + } + + trainer_config = pyconfig.initialize_pydantic(argv, **trainer_update) + sampler_config = pyconfig.initialize_pydantic(argv, **sampler_update) + + else: + raise ValueError("num_trainer_slices and num_samplers_slices should be both -1 or positive") + + sampler_config.subslice_shape = config.rollout_subslice_shape + sampler_config.enable_single_controller = config.rollout_enable_single_controller + return trainer_config, sampler_config, trainer_devices, sampler_devices + + +def get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices): + """Get rollout kwargs for vLLM rollout when using data parallelism.""" + dp = sampler_config.rollout_data_parallelism + tp = sampler_config.rollout_tensor_parallelism + ep = sampler_config.rollout_expert_parallelism + + # -1 means "auto-derive from the other two". At most one can be -1. + num_auto = sum(1 for x in [tp, dp, ep] if x == -1) + if num_auto > 1: + raise ValueError( + "At most one of rollout_tensor_parallelism, rollout_data_parallelism, " + "rollout_expert_parallelism can be -1 (auto-derived)." + ) + + if dp == -1: + if num_sampler_devices % (tp * ep) != 0: + raise ValueError( + f"num_sampler_devices({num_sampler_devices}) must be divisible by " + f"rollout_tensor_parallelism({tp}) * rollout_expert_parallelism({ep}) " + f"when rollout_data_parallelism is -1." + ) + dp = num_sampler_devices // tp // ep + elif tp == -1: + if num_sampler_devices % (dp * ep) != 0: + raise ValueError( + f"num_sampler_devices({num_sampler_devices}) must be divisible by " + f"rollout_data_parallelism({dp}) * rollout_expert_parallelism({ep}) " + f"when rollout_tensor_parallelism is -1." 
+ ) + tp = num_sampler_devices // dp // ep + elif ep == -1: + if num_sampler_devices % (tp * dp) != 0: + raise ValueError( + f"num_sampler_devices({num_sampler_devices}) must be divisible by " + f"rollout_tensor_parallelism({tp}) * rollout_data_parallelism({dp}) " + f"when rollout_expert_parallelism is -1." + ) + ep = num_sampler_devices // tp // dp + elif tp * dp * ep != num_sampler_devices: + raise ValueError( + f"rollout_tensor_parallelism({tp}) * " + f"rollout_data_parallelism({dp}) * " + f"rollout_expert_parallelism({ep}) " + f"!= len(sampler_devices)({num_sampler_devices})" + ) + + rollout_kwargs = {} + rollout_kwargs["tensor_parallel_size"] = tp + rollout_kwargs["data_parallel_size"] = dp + rollout_kwargs["expert_parallel_size"] = ep + + return rollout_kwargs + + +def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): + """ + Run RL training with the provided configuration. + + Args: + trainer_config: MaxText configuration for the trainer. + sampler_config: MaxText configuration for the sampler. + trainer_devices: JAX devices for the trainer. + sampler_devices: JAX devices for the sampler. + """ + if not trainer_config.debug.rl: + # Apply filter to suppress noisy logs + noise_filter = max_logging.NoisyLogFilter() + logging.getLogger().addFilter(noise_filter) + absl_logging.get_absl_logger().addFilter(noise_filter) + + max_logging.log("Starting RL Resharding Debug Script") + + # Number of training steps. 
+ max_train_steps = 1 + + # Create model tokenizer + model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path) + + # Load reference model + max_logging.log("Creating reference model and also meshes for reference and rollout") + reference_model, reference_mesh = get_maxtext_model(trainer_config, trainer_devices) + devices_array = maxtext_utils.create_device_mesh(sampler_config, sampler_devices) + # if trainer_devices=sampler_devices, then rollout_mesh=reference_mesh + # else rollout_mesh uses sampler_devices + rollout_mesh = Mesh(devices_array, sampler_config.mesh_axes) + if trainer_config.debug.rl: + max_logging.log("Reference Model initialized successfully") + nnx.display(reference_model) + max_logging.log(f"Reference mesh shape: {reference_mesh.shape}") + + # Sanity check that weights are loaded correctly. + _maxtext_state_flatten = nnx.state(reference_model).flat_state() + maxtext_state_flatten = {".".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten} + max_logging.log( + f"maxtext_state_flatten[base.token_embedder.embedding].value=\ + {maxtext_state_flatten['base.token_embedder.embedding'][...]}" + ) + + # TODO: @mazumdera: change this to use lora + if trainer_config.load_checkpoint_only_once: + max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.") + with reference_mesh: + actor_base_model = nnx.clone(reference_model.base) + use_no_op_mappings = "maxtext_config" in trainer_config.vllm_additional_config + actor_model = TunixMaxTextAdapter(base_model=actor_base_model, use_no_op_mappings=use_no_op_mappings) + actor_model.config = None + actor_mesh = reference_mesh + else: + max_logging.log("Creating policy model with same config as reference model on trainer mesh") + actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices) + + if trainer_config.debug.rl: + max_logging.log("Policy Model initialized successfully") + nnx.display(actor_model) + 
max_logging.log(f"Policy mesh shape: {actor_mesh.shape}") + + # Setup optimizer + optimizer = utils_rl.get_optimizer(trainer_config, max_train_steps) + + # Setup checkpointing + checkpointing_options = ocp.CheckpointManagerOptions( + save_interval_steps=trainer_config.checkpoint_period, max_to_keep=trainer_config.max_num_checkpoints_to_keep + ) + + # Set up micro batching + micro_batch_size = None if trainer_config.micro_batch_size == -1 else trainer_config.micro_batch_size + + # Parse vllm_additional_config + rollout_additional_config = None + if trainer_config.vllm_additional_config: + if isinstance(trainer_config.vllm_additional_config, dict): + # It's already parsed into a dict + rollout_additional_config = trainer_config.vllm_additional_config + elif isinstance(trainer_config.vllm_additional_config, str): + # It's a string, so we need to parse it + try: + rollout_additional_config = json.loads(trainer_config.vllm_additional_config) + except json.JSONDecodeError as e: + raise ValueError(f"Failed to parse additional_config JSON: {e}") from e + + max_logging.log(f"Parsed additional config: {rollout_additional_config}") + + # We need to parse vLLM config to get the logical axis rules for the sampler config. + vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml") + argv_list = ["", str(vllm_config_path), "log_config=False"] + vllm_config = pyconfig.initialize(argv_list) + + # RL Cluster config + # Note that we use vLLM as the rollout engine. 
+ # and we are using Tensor Parallelism for rollout + cluster_config = rl_cluster_lib.ClusterConfig( + role_to_mesh={ + rl_cluster_lib.Role.ACTOR: actor_mesh, + rl_cluster_lib.Role.REFERENCE: reference_mesh, + rl_cluster_lib.Role.ROLLOUT: rollout_mesh, + }, + role_to_logical_axis_rule={ + rl_cluster_lib.Role.ACTOR: trainer_config.logical_axis_rules, + rl_cluster_lib.Role.REFERENCE: trainer_config.logical_axis_rules, + rl_cluster_lib.Role.ROLLOUT: vllm_config.logical_axis_rules, + }, + rollout_engine="vllm", + offload_to_cpu=False, + training_config=rl_cluster_lib.RLTrainingConfig( + actor_optimizer=optimizer, + eval_every_n_steps=trainer_config.eval_interval, + max_steps=max_train_steps, + # Micro batching + mini_batch_size=trainer_config.batch_size, + train_micro_batch_size=micro_batch_size, + rollout_micro_batch_size=micro_batch_size, + # Checkpoint saving + checkpoint_root_directory=trainer_config.checkpoint_dir, + checkpointing_options=checkpointing_options, + ), + rollout_config=base_rollout.RolloutConfig( + max_tokens_to_generate=trainer_config.max_target_length - trainer_config.max_prefill_predict_length, + max_prompt_length=trainer_config.max_prefill_predict_length, + kv_cache_size=trainer_config.max_target_length + trainer_config.kv_cache_buffer, + temperature=trainer_config.decode_sampling_temperature, + top_p=trainer_config.decode_sampling_nucleus_p, + top_k=trainer_config.decode_sampling_top_k, + rollout_vllm_model_version=trainer_config.tokenizer_path, + rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm, + rollout_vllm_tpu_backend_type="jax", + rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb, + rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path, + rollout_vllm_additional_config=rollout_additional_config, + rollout_vllm_init_with_random_weights=True, + rollout_vllm_enable_dp_attention=trainer_config.enable_dp_attention, + rollout_vllm_max_num_batched_tokens=trainer_config.max_num_batched_tokens, + 
rollout_vllm_max_num_seqs=trainer_config.max_num_seqs, + rollout_vllm_async_scheduling=trainer_config.async_scheduling, + rollout_vllm_kwargs={ + "hf_overrides": trainer_config.vllm_hf_overrides, + "enable_expert_parallel": sampler_config.rollout_expert_parallelism > 1, + }, + rollout_vllm_sampling_kwargs={ + "stop": trainer_config.stop_strings, + "detokenize": trainer_config.stop_strings is not None, + "include_stop_str_in_output": trainer_config.stop_strings is not None, + }, + **get_rollout_kwargs_for_parallelism(sampler_config, len(sampler_devices)), + ), + ) + # Create RL cluster + max_logging.log("Creating RL cluster...") + + rl_cluster = rl_cluster_lib.RLCluster( + actor=actor_model, + reference=reference_model, + tokenizer=model_tokenizer, + cluster_config=cluster_config, + ) + + max_logging.log( + "Calling rl_cluster.sync_weights() to reshard actor weights to rollout mesh..." + ) + + if not epath.Path(trainer_config.tensorboard_dir).exists(): + epath.Path(trainer_config.tensorboard_dir).mkdir(parents=True, exist_ok=True) + + key = jax.random.PRNGKey(42) + for step in range(trainer_config.num_batches): + if step == trainer_config.skip_first_n_steps_for_profiler: + max_logging.log(f"Starting XProf trace in {trainer_config.tensorboard_dir}") + jax.profiler.start_trace(trainer_config.tensorboard_dir) + + key, subkey = jax.random.split(key) + noise = jax.random.normal(subkey, ()) * 1e-3 + # Update all actor weights to trigger full resharding + state = nnx.state(actor_model, nnx.Param) + new_state = jax.tree_util.tree_map(lambda x: x + noise, state) + nnx.update(actor_model, new_state) + + show_hbm_usage(f"HBM before step {step}:") + start_time = time.time() + rl_cluster.sync_weights() + jax.tree_util.tree_map(jax.block_until_ready, rl_cluster.rollout._sampler.transformer_state) + end_time = time.time() + show_hbm_usage(f"HBM after step {step}:") + max_logging.log(f"Weight Syncing Time taken: {end_time - start_time:.4f}s") + max_logging.log(f"Resharding via 
sync_weights() completed: step {step}") + + if step == trainer_config.skip_first_n_steps_for_profiler + trainer_config.profiler_steps - 1: + jax.profiler.stop_trace() + max_logging.log("Stopped XProf trace.") + + +def main(argv: Sequence[str]) -> None: + """Entry point: initialize pathwaysutils, set up configs and devices, then call rl_train. + + Args: + argv: Command-line arguments, passed unchanged to setup_configs_and_devices. + """ + pathwaysutils.initialize() + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + + max_utils.print_system_information() + trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv) + rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices) + + +if __name__ == "__main__": + app.run(main) \ No newline at end of file