Add use_cpu_adam CLI toggle and offload validation for GRPO #1654
Bhavyashah20 wants to merge 5 commits into
Adds a `use_cpu_adam` flag to `GRPOExperimentConfig` that switches the optimizer to `DeepSpeedCPUAdam` when enabled, making GRPO trainable on consumer GPUs without editing source code. Also adds validation that `deepspeed_offload_param` requires `deepspeed_stage=3`, and a warning when `use_cpu_adam` is set without `deepspeed_offload_optimizer`. Closes allenai#1031
Code Review
This pull request introduces the option to use the DeepSpeedCPUAdam optimizer to reduce GPU memory usage, including necessary configuration flags and validation logic. A critical inconsistency was identified regarding how weight_decay is handled between torch.optim.AdamW and DeepSpeedCPUAdam, which could lead to unintended training behavior; a suggestion was provided to explicitly pass the weight decay parameter to both optimizers to ensure consistency.
    if args.use_cpu_adam:
        self.optimizer = DeepSpeedCPUAdam(optim_params, lr=args.learning_rate)
    else:
        self.optimizer = torch.optim.AdamW(optim_params, lr=args.learning_rate, fused=args.fused_optimizer)
There's an inconsistency in how weight_decay is handled between torch.optim.AdamW and DeepSpeedCPUAdam when args.set_weight_decay_on_bias_and_norm is False.
In that scenario, optim_params is just the model parameters, and no weight_decay is passed to the optimizer's constructor. This leads to different behaviors:
- `torch.optim.AdamW` uses its default `weight_decay` of `0.01`.
- `DeepSpeedCPUAdam` uses its default `weight_decay` of `0`.
This means switching to use_cpu_adam could silently change the weight decay and affect training.
To ensure consistent behavior and correctly use the weight_decay from the configuration, I suggest explicitly passing weight_decay=args.weight_decay to both optimizer constructors. This also fixes a latent issue in the original code where args.weight_decay was ignored when set_weight_decay_on_bias_and_norm was False.
Suggested change:

      if args.use_cpu_adam:
    -     self.optimizer = DeepSpeedCPUAdam(optim_params, lr=args.learning_rate)
    +     self.optimizer = DeepSpeedCPUAdam(optim_params, lr=args.learning_rate, weight_decay=args.weight_decay)
      else:
    -     self.optimizer = torch.optim.AdamW(optim_params, lr=args.learning_rate, fused=args.fused_optimizer)
    +     self.optimizer = torch.optim.AdamW(optim_params, lr=args.learning_rate, fused=args.fused_optimizer, weight_decay=args.weight_decay)
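To make the review point concrete, here is a minimal sketch of why the two branches diverge when `weight_decay` is left implicit and why passing it explicitly removes the difference. `StandInAdamW`, `StandInCPUAdam`, and `build_optimizer` are hypothetical stand-ins written for this illustration; they only mirror the default values quoted in the comment above and are not the real `torch` or DeepSpeed classes.

```python
class StandInAdamW:
    """Hypothetical stand-in that mirrors AdamW's default weight_decay of 0.01."""

    def __init__(self, params, lr=1e-3, weight_decay=0.01):
        self.params = list(params)
        self.lr = lr
        self.weight_decay = weight_decay


class StandInCPUAdam:
    """Hypothetical stand-in that mirrors DeepSpeedCPUAdam's default weight_decay of 0."""

    def __init__(self, params, lr=1e-3, weight_decay=0.0):
        self.params = list(params)
        self.lr = lr
        self.weight_decay = weight_decay


def build_optimizer(params, lr, weight_decay, use_cpu_adam):
    """Pick the optimizer class, then pass the same hyperparameters to either one."""
    optimizer_cls = StandInCPUAdam if use_cpu_adam else StandInAdamW
    return optimizer_cls(params, lr=lr, weight_decay=weight_decay)


# Relying on defaults: the two classes silently disagree.
implicit_gap = StandInAdamW([]).weight_decay - StandInCPUAdam([]).weight_decay

# Passing the configured value: both branches agree.
gpu_opt = build_optimizer([], lr=1e-6, weight_decay=0.1, use_cpu_adam=False)
cpu_opt = build_optimizer([], lr=1e-6, weight_decay=0.1, use_cpu_adam=True)
```

Choosing the class first and calling it once keeps the hyperparameters in a single place, so a later flag cannot reintroduce the mismatch.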
Fixed in bfa3120 — both optimizer branches now explicitly pass weight_decay=args.weight_decay. Also noted that this fixes the latent issue in the original torch.optim.AdamW path where args.weight_decay was silently ignored when set_weight_decay_on_bias_and_norm=False.
- Initialize DeepSpeedCPUAdam to None before the contextlib.suppress block so a missing deepspeed install raises a clear RuntimeError instead of a confusing NameError when --use_cpu_adam is passed
- Collapse use_cpu_adam docstring to one line, lowercase to match sibling fields
- Pass weight_decay to DeepSpeedCPUAdam to match AdamW branch behavior
- Raise ValueError when use_cpu_adam=True with deepspeed_stage=0, consistent with how deepspeed_offload_param validates its stage requirement
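The sentinel pattern described in the first commit note can be sketched roughly as follows. This is an illustration, not the code in `grpo_fast.py`: the helper name `resolve_cpu_adam`, the error message, and the choice to suppress only `ImportError` are assumptions made for the example.

```python
import contextlib

# Sentinel: the name always exists, and stays None when DeepSpeed is missing.
DeepSpeedCPUAdam = None
with contextlib.suppress(ImportError):
    from deepspeed.ops.adam import DeepSpeedCPUAdam


def resolve_cpu_adam(cpu_adam_cls=DeepSpeedCPUAdam):
    """Return the CPU Adam class, or fail with an actionable message.

    Hypothetical helper: without the sentinel, referencing the class after a
    failed import would raise NameError instead of this RuntimeError.
    """
    if cpu_adam_cls is None:
        raise RuntimeError(
            "use_cpu_adam=True requires DeepSpeed. "
            "Install it with `pip install deepspeed`."
        )
    return cpu_adam_cls
```

Calling `resolve_cpu_adam()` only when the flag is set keeps the default `torch.optim.AdamW` path free of any DeepSpeed dependency.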
Just a note for reviewers: the Gemini bot flagged a … Let me know if there's anything else needed before this can move forward.
Summary
Closes #1031
Running GRPO on a 7B model typically requires 40–80 GB VRAM. DeepSpeed supports CPU offloading to reduce this, but the optimizer was hardcoded to `torch.optim.AdamW` with no way to switch to `DeepSpeedCPUAdam` via CLI, forcing researchers on smaller hardware to edit source directly.

Changes
`open_instruct/grpo_utils.py`
- Add `use_cpu_adam: bool = False` field to `GRPOExperimentConfig`
- `deepspeed_offload_param` requires `deepspeed_stage=3` (raises `ValueError`)
- `use_cpu_adam=True` requires `deepspeed_stage > 0` (raises `ValueError`)
- `use_cpu_adam=True` without `deepspeed_offload_optimizer=True` (logs warning, does not raise)

`open_instruct/grpo_fast.py`
- `DeepSpeedCPUAdam = None` sentinel + import inside existing `contextlib.suppress` block
- Optimizer selection on `args.use_cpu_adam`, with a clear `RuntimeError` if DeepSpeed is not installed
- Pass `weight_decay=args.weight_decay` to both optimizer branches for consistency

Usage
Test plan
- `use_cpu_adam=False` (default): existing `torch.optim.AdamW` path unchanged
- `use_cpu_adam=True` + `deepspeed_offload_optimizer=True`: `DeepSpeedCPUAdam` selected
- `use_cpu_adam=True` + `deepspeed_offload_optimizer=False`: warning logged, training proceeds
- `deepspeed_offload_param=True` + `deepspeed_stage != 3`: `ValueError` raised
- `use_cpu_adam=True` + `deepspeed_stage=0`: `ValueError` raised
- `use_cpu_adam=True` without DeepSpeed installed: clear `RuntimeError` with install hint

GPU_TESTS=bypass
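The validation rules exercised by the test plan can be sketched as a small dataclass. `OffloadConfigSketch` is a hypothetical stand-in that copies only the four relevant field names from `GRPOExperimentConfig`; placing the checks in `__post_init__` and the exact message wording are assumptions, not a description of the merged code.

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class OffloadConfigSketch:
    """Hypothetical subset of the offload-related config fields."""

    deepspeed_stage: int = 0
    deepspeed_offload_param: bool = False
    deepspeed_offload_optimizer: bool = False
    use_cpu_adam: bool = False

    def __post_init__(self) -> None:
        # Parameter offload is only valid with ZeRO stage 3.
        if self.deepspeed_offload_param and self.deepspeed_stage != 3:
            raise ValueError("deepspeed_offload_param requires deepspeed_stage=3")
        # CPU Adam needs DeepSpeed to be active at all.
        if self.use_cpu_adam and self.deepspeed_stage <= 0:
            raise ValueError("use_cpu_adam requires deepspeed_stage > 0")
        # Allowed, but unlikely to be what the user wants: warn, do not raise.
        if self.use_cpu_adam and not self.deepspeed_offload_optimizer:
            logger.warning(
                "use_cpu_adam is set without deepspeed_offload_optimizer; "
                "optimizer state will not be offloaded to CPU."
            )


def is_rejected(**fields) -> bool:
    """Return True when the given field combination raises ValueError."""
    try:
        OffloadConfigSketch(**fields)
    except ValueError:
        return True
    return False
```

With this sketch, `is_rejected(use_cpu_adam=True, deepspeed_stage=0)` and `is_rejected(deepspeed_offload_param=True, deepspeed_stage=2)` are both true, while `use_cpu_adam=True` with `deepspeed_stage=2` and no optimizer offload only logs a warning.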