Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions src/maxtext/trainers/post_train/rl/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
"""

from __future__ import annotations
from functools import wraps
from functools import wraps, partial
from typing import Sequence

import collections
Expand Down Expand Up @@ -78,7 +78,7 @@
from maxtext.trainers.post_train.rl.evaluate_rl import evaluate
from maxtext.trainers.post_train.rl import utils_rl
from maxtext.input_pipeline.instruction_data_processing import load_template_from_file
from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils
from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils, elastic_utils


def get_maxtext_model(config, devices=None):
Expand Down Expand Up @@ -530,7 +530,6 @@ def create_rl_components(
rollout_vllm_model_version=trainer_config.tokenizer_path,
rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm,
rollout_vllm_tpu_backend_type="jax",
rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb,
rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path,
rollout_vllm_additional_config=rollout_additional_config,
rollout_vllm_init_with_random_weights=True,
Expand Down Expand Up @@ -770,7 +769,25 @@ def main(argv: Sequence[str]) -> None:

max_utils.print_system_information()
trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv)
rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)

if trainer_config.elastic_enabled:
max_logging.log("Elastic utils: Elastic training enabled.")

def elastic_train_wrapper(argv: Sequence[str]) -> None:
    """Set up configs and devices from ``argv`` and run one RL training attempt.

    The configs and devices are rebuilt via ``setup_configs_and_devices`` on
    every call rather than reused from the enclosing scope.

    Raises:
        jax.errors.JaxRuntimeError: Propagated from ``rl_train``. When the
            message contains "Connection to IFRT proxy server was terminated"
            or "UNAVAILABLE", a new error prefixed with
            "INTERNAL: IFRT connection lost: " is raised instead, chained to
            the original. Any other runtime error is re-raised unchanged.
    """
    train_cfg, sample_cfg, train_devs, sample_devs = setup_configs_and_devices(argv)
    try:
        rl_train(train_cfg, sample_cfg, train_devs, sample_devs)
    except jax.errors.JaxRuntimeError as err:
        message = str(err)
        lost_ifrt_connection = (
            "Connection to IFRT proxy server was terminated" in message
            or "UNAVAILABLE" in message
        )
        if not lost_ifrt_connection:
            raise
        # Workaround for unhandled IFRT proxy disconnection errors: re-raise
        # them under an "INTERNAL" prefix.
        # NOTE(review): presumably the retry decorator keys on that prefix --
        # confirm against elastic_utils.elastic_retry.
        raise jax.errors.JaxRuntimeError(f"INTERNAL: IFRT connection lost: {err}") from err

train_func = elastic_utils.elastic_retry(trainer_config)(partial(elastic_train_wrapper, argv=argv))
train_func()
else:
rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)


if __name__ == "__main__":
Expand Down
Loading