Skip to content

Commit b0ff6f6

Browse files
Enable elastic training for RL
PiperOrigin-RevId: 901000964
1 parent e06745f commit b0ff6f6

1 file changed

Lines changed: 15 additions & 3 deletions

File tree

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
"""
4545

4646
from __future__ import annotations
47-
from functools import wraps
47+
from functools import wraps, partial
4848
from typing import Sequence
4949

5050
import collections
@@ -78,7 +78,7 @@
7878
from maxtext.trainers.post_train.rl.evaluate_rl import evaluate
7979
from maxtext.trainers.post_train.rl import utils_rl
8080
from maxtext.input_pipeline.instruction_data_processing import load_template_from_file
81-
from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils
81+
from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils, elastic_utils
8282

8383

8484
def get_maxtext_model(config, devices=None):
@@ -770,7 +770,19 @@ def main(argv: Sequence[str]) -> None:
770770

771771
max_utils.print_system_information()
772772
trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv)
773-
rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)
773+
774+
if trainer_config.elastic_enabled:
775+
max_logging.log("Elastic utils: Elastic training enabled.")
776+
777+
def elastic_train_wrapper(argv: Sequence[str]) -> None:
778+
"""Wrapper for elastic training initializes variables and runs the train loop."""
779+
t_config, s_config, t_devices, s_devices = setup_configs_and_devices(argv)
780+
rl_train(t_config, s_config, t_devices, s_devices)
781+
782+
train_func = elastic_utils.elastic_retry(trainer_config)(functools.partial(elastic_train_wrapper, argv=argv))
783+
train_func()
784+
else:
785+
rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)
774786

775787

776788
if __name__ == "__main__":

0 commit comments

Comments
 (0)