Add use_cpu_adam CLI toggle and offload validation for GRPO #1654

Open
Bhavyashah20 wants to merge 5 commits into allenai:main from Bhavyashah20:feat/cpu-offload-cli-toggle

@Bhavyashah20 Bhavyashah20 commented May 3, 2026

Summary

Closes #1031

Running GRPO on a 7B model typically requires 40–80 GB VRAM. DeepSpeed supports CPU offloading to reduce this, but the optimizer was hardcoded to torch.optim.AdamW with no way to switch to DeepSpeedCPUAdam from the CLI, which forced researchers on smaller hardware to edit the source directly.

Changes

open_instruct/grpo_utils.py

  • Adds use_cpu_adam: bool = False field to GRPOExperimentConfig
  • Adds validation: deepspeed_offload_param requires deepspeed_stage=3 (raises ValueError)
  • Adds validation: use_cpu_adam=True requires deepspeed_stage > 0 (raises ValueError)
  • Adds warning: use_cpu_adam=True without deepspeed_offload_optimizer=True (logs warning, does not raise)
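
The three checks listed above could look roughly like the following. This is a minimal sketch, not the code from the PR: the class name `OffloadConfigSketch`, the use of `__post_init__`, and the message wording are assumptions made for illustration. Only the field names and the raise/warn behavior come from the description.

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class OffloadConfigSketch:
    """Hypothetical stand-in for the offload-related fields of GRPOExperimentConfig."""

    deepspeed_stage: int = 0
    deepspeed_offload_param: bool = False
    deepspeed_offload_optimizer: bool = False
    use_cpu_adam: bool = False

    def __post_init__(self) -> None:
        # Parameter offload is a ZeRO stage 3 feature, so reject other stages.
        if self.deepspeed_offload_param and self.deepspeed_stage != 3:
            raise ValueError("deepspeed_offload_param requires deepspeed_stage=3")
        # The CPU optimizer only makes sense when DeepSpeed ZeRO is active.
        if self.use_cpu_adam and self.deepspeed_stage <= 0:
            raise ValueError("use_cpu_adam=True requires deepspeed_stage > 0")
        # Legal but probably unintended: warn and let training proceed.
        if self.use_cpu_adam and not self.deepspeed_offload_optimizer:
            logger.warning(
                "use_cpu_adam=True without deepspeed_offload_optimizer=True; "
                "optimizer state will not be offloaded to CPU."
            )
```

Putting the hard requirements in the config means an invalid combination fails at argument-parsing time rather than after model loading.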

open_instruct/grpo_fast.py

  • Adds DeepSpeedCPUAdam = None sentinel + import inside existing contextlib.suppress block
  • Branches optimizer instantiation on args.use_cpu_adam with a clear RuntimeError if DeepSpeed is not installed
  • Passes weight_decay=args.weight_decay to both optimizer branches for consistency
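
The sentinel-plus-branch pattern described above can be sketched as follows. The helper `build_optimizer` and its `cpu_adam_cls` parameter are hypothetical and exist only to keep the example self-contained; the PR itself branches inline on `args.use_cpu_adam`. The exception type passed to `contextlib.suppress` is also an assumption, since the description does not say what the existing block suppresses.

```python
import contextlib

# Sentinel: stays None when DeepSpeed cannot be imported, so a later check
# can raise a clear RuntimeError instead of a NameError.
DeepSpeedCPUAdam = None
with contextlib.suppress(ImportError):  # assumption: the real block may suppress something else
    from deepspeed.ops.adam import DeepSpeedCPUAdam


def build_optimizer(params, *, use_cpu_adam, lr, weight_decay, fused=False,
                    cpu_adam_cls=DeepSpeedCPUAdam):
    """Hypothetical helper mirroring the optimizer branch in grpo_fast.py."""
    if use_cpu_adam:
        if cpu_adam_cls is None:
            raise RuntimeError(
                "use_cpu_adam=True requires DeepSpeed, which could not be imported. "
                "Install it with `pip install deepspeed`."
            )
        return cpu_adam_cls(params, lr=lr, weight_decay=weight_decay)
    import torch  # imported lazily so this sketch loads without PyTorch

    # weight_decay is passed explicitly so both branches honor the configured value.
    return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay, fused=fused)
```

Passing `weight_decay` in both branches keeps the two optimizers from falling back to their differing constructor defaults.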

Usage

python open_instruct/grpo_fast.py \
  --use_cpu_adam \
  --deepspeed_offload_optimizer \
  --deepspeed_stage 2 \
  ...

Test plan

  • use_cpu_adam=False (default): existing torch.optim.AdamW path unchanged
  • use_cpu_adam=True + deepspeed_offload_optimizer=True: DeepSpeedCPUAdam selected
  • use_cpu_adam=True + deepspeed_offload_optimizer=False: warning logged, training proceeds
  • deepspeed_offload_param=True + deepspeed_stage != 3: ValueError raised
  • use_cpu_adam=True + deepspeed_stage=0: ValueError raised
  • DeepSpeed not installed + use_cpu_adam=True: clear RuntimeError with install hint

GPU_TESTS=bypass

Adds a `use_cpu_adam` flag to `GRPOExperimentConfig` that switches the
optimizer to `DeepSpeedCPUAdam` when enabled, making GRPO trainable on
consumer GPUs without editing source code. Also adds validation that
`deepspeed_offload_param` requires `deepspeed_stage=3`, and a warning
when `use_cpu_adam` is set without `deepspeed_offload_optimizer`.

Closes allenai#1031
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces the option to use the DeepSpeedCPUAdam optimizer to reduce GPU memory usage, including necessary configuration flags and validation logic. A critical inconsistency was identified regarding how weight_decay is handled between torch.optim.AdamW and DeepSpeedCPUAdam, which could lead to unintended training behavior; a suggestion was provided to explicitly pass the weight decay parameter to both optimizers to ensure consistency.

Comment thread open_instruct/grpo_fast.py Outdated
Comment on lines +378 to +381
if args.use_cpu_adam:
    self.optimizer = DeepSpeedCPUAdam(optim_params, lr=args.learning_rate)
else:
    self.optimizer = torch.optim.AdamW(optim_params, lr=args.learning_rate, fused=args.fused_optimizer)
Contributor

high

There's an inconsistency in how weight_decay is handled between torch.optim.AdamW and DeepSpeedCPUAdam when args.set_weight_decay_on_bias_and_norm is False.

In that scenario, optim_params is just the model parameters, and no weight_decay is passed to the optimizer's constructor. This leads to different behaviors:

  • torch.optim.AdamW uses its default weight_decay of 0.01.
  • DeepSpeedCPUAdam uses its default weight_decay of 0.

This means switching to use_cpu_adam could silently change the weight decay and affect training.
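
To make the size of that silent change concrete, here is a small pure-Python sketch of a single decoupled-weight-decay (AdamW-style) step on one scalar parameter, run once with weight_decay=0.01 and once with 0. The function `adamw_first_step` and the numeric inputs are illustrative assumptions, not code from either library.

```python
import math


def adamw_first_step(param, grad, *, lr, weight_decay, betas=(0.9, 0.999), eps=1e-8):
    """One AdamW step from zero-initialized moments (illustrative only)."""
    beta1, beta2 = betas
    # Decoupled weight decay: shrink the parameter before the Adam update.
    param = param * (1.0 - lr * weight_decay)
    # First and second moment estimates after the first gradient.
    m = (1.0 - beta1) * grad
    v = (1.0 - beta2) * grad * grad
    # Bias correction for step 1.
    m_hat = m / (1.0 - beta1)
    v_hat = v / (1.0 - beta2)
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)


# Same parameter, gradient, and learning rate; only the decay default differs.
with_decay = adamw_first_step(1.0, 0.5, lr=0.1, weight_decay=0.01)
without_decay = adamw_first_step(1.0, 0.5, lr=0.1, weight_decay=0.0)
gap = without_decay - with_decay  # equals lr * weight_decay * param for this step
```

The per-step gap is lr * weight_decay * param, which is small for one step but is applied on every step, so two runs that differ only in the optimizer default can drift apart over training.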

To ensure consistent behavior and correctly use the weight_decay from the configuration, I suggest explicitly passing weight_decay=args.weight_decay to both optimizer constructors. This also fixes a latent issue in the original code where args.weight_decay was ignored when set_weight_decay_on_bias_and_norm was False.

Suggested change
- if args.use_cpu_adam:
-     self.optimizer = DeepSpeedCPUAdam(optim_params, lr=args.learning_rate)
- else:
-     self.optimizer = torch.optim.AdamW(optim_params, lr=args.learning_rate, fused=args.fused_optimizer)
+ if args.use_cpu_adam:
+     self.optimizer = DeepSpeedCPUAdam(optim_params, lr=args.learning_rate, weight_decay=args.weight_decay)
+ else:
+     self.optimizer = torch.optim.AdamW(optim_params, lr=args.learning_rate, fused=args.fused_optimizer, weight_decay=args.weight_decay)

Author

Fixed in bfa3120 — both optimizer branches now explicitly pass weight_decay=args.weight_decay. Also noted that this fixes the latent issue in the original torch.optim.AdamW path where args.weight_decay was silently ignored when set_weight_decay_on_bias_and_norm=False.

- Initialize DeepSpeedCPUAdam to None before the contextlib.suppress block
  so a missing deepspeed install raises a clear RuntimeError instead of
  a confusing NameError when --use_cpu_adam is passed
- Collapse use_cpu_adam docstring to one line, lowercase to match sibling fields
- Pass weight_decay to DeepSpeedCPUAdam to match AdamW branch behavior
- Raise ValueError when use_cpu_adam=True with deepspeed_stage=0, consistent
  with how deepspeed_offload_param validates its stage requirement
@Bhavyashah20
Author

Just a note for reviewers: the Gemini bot flagged a weight_decay inconsistency between the two optimizer branches — this was already addressed in the follow-up commit (bfa3120). Both DeepSpeedCPUAdam and torch.optim.AdamW now explicitly pass weight_decay=args.weight_decay.

Let me know if there's anything else needed before this can move forward.

Development

Successfully merging this pull request may close these issues.

Feature request: Add CLI toggles for CPU offloading in grpo_fast.py

1 participant