Merged
2 changes: 2 additions & 0 deletions src/maxtext/configs/inference/vllm.yml
@@ -14,6 +14,8 @@

base_config: "base.yml"
attention: "vllm_rpa"
model_call_mode: "inference"
Collaborator

Where is this new config used below?

Collaborator Author

This is for MoE! It's unrelated, but I wanted to squeeze it into this PR since we should always be setting this in the vLLM codepath.
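For reference, a minimal sketch (not part of this diff) of how this inference config gets picked up: the pyconfig.initialize call mirrors the one added later in train_rl.py, while the import paths and the explicit model_call_mode override are assumptions based on MaxText's key=value override convention.

import os

from maxtext import pyconfig  # assumed import path
from maxtext.globals import MAXTEXT_CONFIGS_DIR  # assumed import path

# Load the vLLM inference config; trailing key=value args override YAML entries,
# so model_call_mode could also be forced here instead of in vllm.yml.
vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
argv_list = ["", str(vllm_config_path), "log_config=False", "model_call_mode=inference"]
vllm_config = pyconfig.initialize(argv_list)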

Collaborator

Is the entry point for this through train_rl?

Collaborator Author

# NNX required for vLLM integration
enable_nnx: True
# Avoid re-initializing JAX distributed system when using vLLM
180 changes: 96 additions & 84 deletions src/maxtext/trainers/post_train/rl/train_rl.py
@@ -282,39 +282,18 @@ def get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices):
return rollout_kwargs


def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
"""
Run RL training with the provided configuration.

Args:
trainer_config: MaxText configuration for the trainer.
sampler_config: MaxText configuration for the sampler.
trainer_devices: JAX devices for the trainer.
sampler_devices: JAX devices for the sampler.
"""
if not trainer_config.debug.rl:
# Apply filter to suppress noisy logs
noise_filter = max_logging.NoisyLogFilter()
logging.getLogger().addFilter(noise_filter)
absl_logging.get_absl_logger().addFilter(noise_filter)

max_logging.log("Starting RL Training")
max_logging.log(f"Ensuring TensorBoard log directory exists: {trainer_config.tensorboard_dir}")
if not epath.Path(trainer_config.tensorboard_dir).exists():
epath.Path(trainer_config.tensorboard_dir).mkdir(parents=True, exist_ok=True)

if not epath.Path(trainer_config.checkpoint_dir).exists():
epath.Path(trainer_config.checkpoint_dir).mkdir(parents=True)

# Number of training steps.
max_train_steps = int(
def get_max_train_steps(trainer_config):
"""Calculate the total number of training steps."""
return int(
trainer_config.num_batches
* trainer_config.rl.num_iterations
* trainer_config.train_fraction
* trainer_config.num_epoch
)
# ====== Data ======
# Setup data directories


def prepare_datasets(trainer_config, model_tokenizer):
"""Setup and return train and test datasets."""
home = os.path.expanduser("~") + "/"
train_data_dir = f"{home}/data/train"
test_data_dir = f"{home}/data/test"
@@ -323,9 +302,6 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
if not os.path.exists(test_data_dir):
os.makedirs(test_data_dir)

# Create model tokenizer
model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path)

# Load datasets
if trainer_config.dataset_name == "huggingface:nvidia/OpenMathInstruct-2":
import datasets # pylint: disable=import-outside-toplevel
@@ -334,7 +310,6 @@ def prepare_openinstructmath2_dataset(
split: str = "train_1M",
seed: int = 42,
test_size: float = 0.05,
output_key: str = "expected_answer",
):
"""Load and split the OpenMathInstruct-2 dataset into train and validation sets using HF's train_test_split."""
max_logging.log(
@@ -419,41 +394,16 @@ def _filter_long_prompts(x):
test_dataset = test_dataset[: trainer_config.num_test_batches * trainer_config.batch_size]

test_dataset = test_dataset.to_iter_dataset().batch(trainer_config.batch_size)
return train_dataset, test_dataset

if trainer_config.debug.rl:
# Let's see how one batch of the dataset looks like!
if trainer_config.debug.rl:
for i, ele in enumerate(train_dataset):
if i >= 5:
break
pprint(ele)
if trainer_config.debug.rl:
for i, ele in enumerate(test_dataset):
if i >= 5:
break
pprint(ele)

# Load reference model

def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sampler_devices):
"""Create reference and actor models and their respective meshes."""
max_logging.log("Creating reference model and also meshes for reference and rollout")
reference_model, reference_mesh = get_maxtext_model(trainer_config, trainer_devices)
devices_array = maxtext_utils.create_device_mesh(sampler_config, sampler_devices)
# if trainer_devices=sampler_devices, then rollout_mesh=reference_mesh
# else rollout_mesh uses sampler_devices
rollout_mesh = Mesh(devices_array, sampler_config.mesh_axes)
if trainer_config.debug.rl:
max_logging.log("Reference Model initialized successfully")
nnx.display(reference_model)
max_logging.log(f"Reference mesh shape: {reference_mesh.shape}")

# Sanity check that weights are loaded correctly.
_maxtext_state_flatten = nnx.state(reference_model).flat_state()
maxtext_state_flatten = {".".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten}
max_logging.log(
f"maxtext_state_flatten[base.token_embedder.embedding].value=\
{maxtext_state_flatten['base.token_embedder.embedding'][...]}"
)

# TODO: @mazumdera: change this to use lora
if trainer_config.load_checkpoint_only_once:
max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.")
with reference_mesh:
@@ -466,11 +416,22 @@ def _filter_long_prompts(x):
max_logging.log("Creating policy model with same config as reference model on trainer mesh")
actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)

if trainer_config.debug.rl:
max_logging.log("Policy Model initialized successfully")
nnx.display(actor_model)
max_logging.log(f"Policy mesh shape: {actor_mesh.shape}")

return reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh


def create_rl_components(
trainer_config,
sampler_config,
sampler_devices,
actor_model,
actor_mesh,
reference_model,
reference_mesh,
rollout_mesh,
model_tokenizer,
max_train_steps,
):
"""Setup RL cluster, trainer, and optimizer."""
# Setup optimizer
optimizer = utils_rl.get_optimizer(trainer_config, max_train_steps)

@@ -483,7 +444,6 @@ def _filter_long_prompts(x):
micro_batch_size = None if trainer_config.micro_batch_size == -1 else trainer_config.micro_batch_size

# Setup metrics logging
max_logging.log(f"Tensorboard logs directory: {trainer_config.tensorboard_dir}")
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
log_dir=trainer_config.tensorboard_dir, flush_every_n_steps=trainer_config.log_period
)
@@ -501,25 +461,18 @@ def _filter_long_prompts(x):
rollout_additional_config = None
if trainer_config.vllm_additional_config:
if isinstance(trainer_config.vllm_additional_config, dict):
# It's already parsed into a dict
rollout_additional_config = trainer_config.vllm_additional_config
elif isinstance(trainer_config.vllm_additional_config, str):
# It's a string, so we need to parse it
try:
rollout_additional_config = json.loads(trainer_config.vllm_additional_config)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse additional_config JSON: {e}") from e

max_logging.log(f"Parsed additional config: {rollout_additional_config}")

# We need to parse vLLM config to get the logical axis rules for the sampler config.
vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
argv_list = ["", str(vllm_config_path), "log_config=False"]
vllm_config = pyconfig.initialize(argv_list)

# RL Cluster config
# Note that we use vLLM as the rollout engine.
# and we are using Tensor Parallelism for rollout
cluster_config = rl_cluster_lib.ClusterConfig(
role_to_mesh={
rl_cluster_lib.Role.ACTOR: actor_mesh,
@@ -537,15 +490,11 @@ def _filter_long_prompts(x):
actor_optimizer=optimizer,
eval_every_n_steps=trainer_config.eval_interval,
max_steps=max_train_steps,
# Micro batching
mini_batch_size=trainer_config.batch_size,
train_micro_batch_size=micro_batch_size,
rollout_micro_batch_size=micro_batch_size,
# Metrics logging
metrics_logging_options=metrics_logging_options,
# Profiling
profiler_options=profiler_options,
# Checkpoint saving
checkpoint_root_directory=trainer_config.checkpoint_dir,
checkpointing_options=checkpointing_options,
),
@@ -579,6 +528,7 @@ def _filter_long_prompts(x):
**get_rollout_kwargs_for_parallelism(sampler_config, len(sampler_devices)),
),
)

grpo_config = GrpoConfig(
num_generations=trainer_config.rl.num_generations,
num_iterations=trainer_config.rl.num_iterations,
@@ -595,9 +545,6 @@ def _filter_long_prompts(x):
from tunix.perf import export as perf_export # pylint: disable=import-outside-toplevel
from tunix.perf import metrics as perf_metrics # pylint: disable=import-outside-toplevel

max_logging.log(
"enable_tunix_perf_metrics is True and tunix.perf modules are available, enabling Tunix-managed metrics."
)
perf_config = perf_metrics.PerfMetricsConfig()
perf_config.custom_export_fn = perf_export.PerfMetricsExport.create_metrics_export_fn(cluster_config)
rl_cluster_kwargs["perf_config"] = perf_config
@@ -627,9 +574,76 @@ def _filter_long_prompts(x):
algo_config=grpo_config,
)

return rl_cluster, rl_trainer, optimizer


def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
"""
Run RL training with the provided configuration.

Args:
trainer_config: MaxText configuration for the trainer.
sampler_config: MaxText configuration for the sampler.
trainer_devices: JAX devices for the trainer.
sampler_devices: JAX devices for the sampler.
"""
if not trainer_config.debug.rl:
# Apply filter to suppress noisy logs
noise_filter = max_logging.NoisyLogFilter()
logging.getLogger().addFilter(noise_filter)
absl_logging.get_absl_logger().addFilter(noise_filter)

max_logging.log("Starting RL Training")
if not epath.Path(trainer_config.tensorboard_dir).exists():
epath.Path(trainer_config.tensorboard_dir).mkdir(parents=True, exist_ok=True)

if not epath.Path(trainer_config.checkpoint_dir).exists():
epath.Path(trainer_config.checkpoint_dir).mkdir(parents=True)

max_train_steps = get_max_train_steps(trainer_config)

# Create model tokenizer
model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path)

train_dataset, test_dataset = prepare_datasets(trainer_config, model_tokenizer)

if trainer_config.debug.rl:
for i, ele in enumerate(train_dataset):
if i >= 5:
break
pprint(ele)
for i, ele in enumerate(test_dataset):
if i >= 5:
break
pprint(ele)

reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = create_models_and_meshes(
trainer_config, sampler_config, trainer_devices, sampler_devices
)

if trainer_config.debug.rl:
max_logging.log("Reference Model initialized successfully")
nnx.display(reference_model)
max_logging.log(f"Reference mesh shape: {reference_mesh.shape}")
max_logging.log("Policy Model initialized successfully")
nnx.display(actor_model)
max_logging.log(f"Policy mesh shape: {actor_mesh.shape}")

rl_cluster, rl_trainer, _ = create_rl_components(
trainer_config,
sampler_config,
sampler_devices,
actor_model,
actor_mesh,
reference_model,
reference_mesh,
rollout_mesh,
model_tokenizer,
max_train_steps,
)

# Before we train the model, let's evaluate the model on the test set so we can
# see the improvement post training.
#
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
trainer_config,
test_dataset,
@@ -638,11 +652,9 @@ def _filter_long_prompts(x):
corr_lst=trainer_config.eval_corr_lst,
make_lst=trainer_config.eval_make_lst,
)
# TODO: @mazumdera: Change this to max_logging.log once b/473703277 is resolved
max_logging.warning(f"Pre RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")

# Start training

if trainer_config.load_checkpoint_only_once:
max_logging.log("Capturing reference model state before training.")
ref_state_before = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
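As a quick sanity check of the new get_max_train_steps helper, here is a hedged worked example; the attribute names match the helper in this diff, but the SimpleNamespace stand-in config and the numeric values are purely illustrative.

from types import SimpleNamespace

# Stand-in config with hypothetical values (illustrative only).
trainer_config = SimpleNamespace(
    num_batches=200,                       # batches per epoch
    rl=SimpleNamespace(num_iterations=2),  # GRPO iterations per batch
    train_fraction=0.9,                    # assumed: fraction of data used for training
    num_epoch=1,
)

# Same arithmetic as get_max_train_steps: 200 * 2 * 0.9 * 1 = 360 steps.
max_train_steps = int(
    trainer_config.num_batches
    * trainer_config.rl.num_iterations
    * trainer_config.train_fraction
    * trainer_config.num_epoch
)
assert max_train_steps == 360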