Skip to content

Commit 44cf239

Browse files
Enable elastic training for RL
PiperOrigin-RevId: 901000964
1 parent 7f78228 commit 44cf239

1 file changed

Lines changed: 21 additions & 4 deletions

File tree

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
"""
4545

4646
from __future__ import annotations
47-
from functools import wraps
47+
from functools import wraps, partial
4848
from typing import Sequence
4949

5050
import collections
@@ -78,7 +78,7 @@
7878
from maxtext.trainers.post_train.rl.evaluate_rl import evaluate
7979
from maxtext.trainers.post_train.rl import utils_rl
8080
from maxtext.input_pipeline.instruction_data_processing import load_template_from_file
81-
from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils
81+
from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils, elastic_utils
8282

8383

8484
def get_maxtext_model(config, devices=None):
@@ -530,7 +530,6 @@ def create_rl_components(
530530
rollout_vllm_model_version=trainer_config.tokenizer_path,
531531
rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm,
532532
rollout_vllm_tpu_backend_type="jax",
533-
rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb,
534533
rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path,
535534
rollout_vllm_additional_config=rollout_additional_config,
536535
rollout_vllm_init_with_random_weights=True,
@@ -770,7 +769,25 @@ def main(argv: Sequence[str]) -> None:
770769

771770
max_utils.print_system_information()
772771
trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv)
773-
rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)
772+
773+
if trainer_config.elastic_enabled:
774+
max_logging.log("Elastic utils: Elastic training enabled.")
775+
776+
def elastic_train_wrapper(argv: Sequence[str]) -> None:
777+
"""Wrapper for elastic training initializes variables and runs the train loop."""
778+
t_config, s_config, t_devices, s_devices = setup_configs_and_devices(argv)
779+
try:
780+
rl_train(t_config, s_config, t_devices, s_devices)
781+
except jax.errors.JaxRuntimeError as e:
782+
# Workaround for unhandled IFRT proxy disconnection errors
783+
if "Connection to IFRT proxy server was terminated" in str(e) or "UNAVAILABLE" in str(e):
784+
raise jax.errors.JaxRuntimeError(f"INTERNAL: IFRT connection lost: {e}") from e
785+
raise
786+
787+
train_func = elastic_utils.elastic_retry(trainer_config)(partial(elastic_train_wrapper, argv=argv))
788+
train_func()
789+
else:
790+
rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)
774791

775792

776793
if __name__ == "__main__":

0 commit comments

Comments
 (0)