diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index 8663add55f..561f12fbf6 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -44,7 +44,7 @@ """ from __future__ import annotations -from functools import wraps +from functools import wraps, partial from typing import Sequence import collections @@ -78,7 +78,7 @@ from maxtext.trainers.post_train.rl.evaluate_rl import evaluate from maxtext.trainers.post_train.rl import utils_rl from maxtext.input_pipeline.instruction_data_processing import load_template_from_file -from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils +from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils, elastic_utils def get_maxtext_model(config, devices=None): @@ -530,7 +530,6 @@ def create_rl_components( rollout_vllm_model_version=trainer_config.tokenizer_path, rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm, rollout_vllm_tpu_backend_type="jax", - rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb, rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path, rollout_vllm_additional_config=rollout_additional_config, rollout_vllm_init_with_random_weights=True, @@ -770,7 +769,25 @@ def main(argv: Sequence[str]) -> None: max_utils.print_system_information() trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv) - rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices) + + if trainer_config.elastic_enabled: + max_logging.log("Elastic utils: Elastic training enabled.") + + def elastic_train_wrapper(argv: Sequence[str]) -> None: + """Wrapper for elastic training initializes variables and runs the train loop.""" + t_config, s_config, t_devices, s_devices = setup_configs_and_devices(argv) + try: + rl_train(t_config, s_config, t_devices, s_devices) + except jax.errors.JaxRuntimeError as e: + # Workaround for unhandled IFRT proxy disconnection errors + if "Connection to IFRT proxy server was terminated" in str(e) or "UNAVAILABLE" in str(e): + raise jax.errors.JaxRuntimeError(f"INTERNAL: IFRT connection lost: {e}") from e + raise + + train_func = elastic_utils.elastic_retry(trainer_config)(partial(elastic_train_wrapper, argv=argv)) + train_func() + else: + rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices) if __name__ == "__main__":