|
44 | 44 | """ |
45 | 45 |
|
46 | 46 | from __future__ import annotations |
47 | | -from functools import wraps |
| 47 | +from functools import wraps, partial |
48 | 48 | from typing import Sequence |
49 | 49 |
|
50 | 50 | import collections |
|
78 | 78 | from maxtext.trainers.post_train.rl.evaluate_rl import evaluate |
79 | 79 | from maxtext.trainers.post_train.rl import utils_rl |
80 | 80 | from maxtext.input_pipeline.instruction_data_processing import load_template_from_file |
81 | | -from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils |
| 81 | +from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils, elastic_utils |
82 | 82 |
|
83 | 83 |
|
84 | 84 | def get_maxtext_model(config, devices=None): |
@@ -530,7 +530,6 @@ def create_rl_components( |
530 | 530 | rollout_vllm_model_version=trainer_config.tokenizer_path, |
531 | 531 | rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm, |
532 | 532 | rollout_vllm_tpu_backend_type="jax", |
533 | | - rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb, |
534 | 533 | rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path, |
535 | 534 | rollout_vllm_additional_config=rollout_additional_config, |
536 | 535 | rollout_vllm_init_with_random_weights=True, |
@@ -770,7 +769,19 @@ def main(argv: Sequence[str]) -> None: |
770 | 769 |
|
771 | 770 | max_utils.print_system_information() |
772 | 771 | trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv) |
773 | | - rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices) |
| 772 | + |
| 773 | + if trainer_config.elastic_enabled: |
| 774 | + max_logging.log("Elastic utils: Elastic training enabled.") |
| 775 | + |
| 776 | + def elastic_train_wrapper(argv: Sequence[str]) -> None: |
| 777 | + """Wrapper for elastic training initializes variables and runs the train loop.""" |
| 778 | + t_config, s_config, t_devices, s_devices = setup_configs_and_devices(argv) |
| 779 | + rl_train(t_config, s_config, t_devices, s_devices) |
| 780 | + |
| 781 | + train_func = elastic_utils.elastic_retry(trainer_config)(partial(elastic_train_wrapper, argv=argv)) |
| 782 | + train_func() |
| 783 | + else: |
| 784 | + rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices) |
774 | 785 |
|
775 | 786 |
|
776 | 787 | if __name__ == "__main__": |
|
0 commit comments