diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 7deba5c4b0..d7ff172ba9 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -35,8 +35,10 @@ from datetime import timedelta os.environ["NCCL_CUMEM_ENABLE"] = "0" # NOQA +DeepSpeedCPUAdam = None with contextlib.suppress(Exception): import deepspeed + from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF from deepspeed.utils import groups @@ -372,7 +374,15 @@ def load(self, path: str, map_location=None): optim_params = get_optimizer_grouped_parameters(self.policy, args.weight_decay) else: optim_params = self.policy.parameters() - self.optimizer = torch.optim.AdamW(optim_params, lr=args.learning_rate, fused=args.fused_optimizer) + if args.use_cpu_adam: + if DeepSpeedCPUAdam is None: + raise RuntimeError( + "`use_cpu_adam=True` requires DeepSpeed to be installed. " + "Install it with: pip install deepspeed" + ) + self.optimizer = DeepSpeedCPUAdam(optim_params, lr=args.learning_rate, weight_decay=args.weight_decay) + else: + self.optimizer = torch.optim.AdamW(optim_params, lr=args.learning_rate, fused=args.fused_optimizer, weight_decay=args.weight_decay) num_scheduler_steps = args.num_training_steps * args.num_epochs * args.num_mini_batches warmup_steps = int(num_scheduler_steps * args.warmup_ratio) scheduler = get_scheduler( diff --git a/open_instruct/grpo_utils.py b/open_instruct/grpo_utils.py index 73674a86b7..1ce414fdce 100644 --- a/open_instruct/grpo_utils.py +++ b/open_instruct/grpo_utils.py @@ -173,6 +173,8 @@ class GRPOExperimentConfig( """whether to offload parameters to CPU (reduces GPU memory usage)""" deepspeed_offload_optimizer: bool = False """whether to offload optimizer states to CPU (reduces GPU memory usage)""" + use_cpu_adam: bool = False + """whether to use DeepSpeedCPUAdam instead of torch.optim.AdamW; recommended when deepspeed_offload_optimizer=True""" deepspeed_checkpoint_load_universal: bool = False """DeepSpeed checkpoint.load_universal: load checkpoints across different parallel configs""" gather_whole_model: bool = True @@ -322,6 +324,15 @@ def __post_init__(self): ) if self.eval_top_p is not None and not (0.0 < self.eval_top_p <= 1.0): raise ValueError(f"`eval_top_p` must be in (0, 1], got {self.eval_top_p}") + if self.deepspeed_offload_param and self.deepspeed_stage != 3: + raise ValueError("`deepspeed_offload_param` requires `deepspeed_stage` to be 3!") + if self.use_cpu_adam and self.deepspeed_stage == 0: + raise ValueError("`use_cpu_adam` requires a DeepSpeed stage (`deepspeed_stage` > 0)!") + if self.use_cpu_adam and not self.deepspeed_offload_optimizer: + logger.warning( + "`use_cpu_adam` is enabled but `deepspeed_offload_optimizer` is False. " + "Consider enabling `deepspeed_offload_optimizer` to fully benefit from CPU Adam." + ) if self.use_rho_correction: if self.rho_mask_lower_bound > 0.0 and not (0.0 < self.rho_mask_lower_bound < 1.0): raise ValueError(