Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion open_instruct/grpo_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@
from datetime import timedelta

os.environ["NCCL_CUMEM_ENABLE"] = "0" # NOQA
DeepSpeedCPUAdam = None
with contextlib.suppress(Exception):
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF
from deepspeed.utils import groups

Expand Down Expand Up @@ -372,7 +374,15 @@ def load(self, path: str, map_location=None):
optim_params = get_optimizer_grouped_parameters(self.policy, args.weight_decay)
else:
optim_params = self.policy.parameters()
self.optimizer = torch.optim.AdamW(optim_params, lr=args.learning_rate, fused=args.fused_optimizer)
if args.use_cpu_adam:
if DeepSpeedCPUAdam is None:
raise RuntimeError(
"`use_cpu_adam=True` requires DeepSpeed to be installed. "
"Install it with: pip install deepspeed"
)
self.optimizer = DeepSpeedCPUAdam(optim_params, lr=args.learning_rate, weight_decay=args.weight_decay)
else:
self.optimizer = torch.optim.AdamW(optim_params, lr=args.learning_rate, fused=args.fused_optimizer, weight_decay=args.weight_decay)
num_scheduler_steps = args.num_training_steps * args.num_epochs * args.num_mini_batches
warmup_steps = int(num_scheduler_steps * args.warmup_ratio)
scheduler = get_scheduler(
Expand Down
11 changes: 11 additions & 0 deletions open_instruct/grpo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ class GRPOExperimentConfig(
"""whether to offload parameters to CPU (reduces GPU memory usage)"""
deepspeed_offload_optimizer: bool = False
"""whether to offload optimizer states to CPU (reduces GPU memory usage)"""
use_cpu_adam: bool = False
"""whether to use DeepSpeedCPUAdam instead of torch.optim.AdamW; recommended when deepspeed_offload_optimizer=True"""
deepspeed_checkpoint_load_universal: bool = False
"""DeepSpeed checkpoint.load_universal: load checkpoints across different parallel configs"""
gather_whole_model: bool = True
Expand Down Expand Up @@ -317,6 +319,15 @@ def __post_init__(self):
)
if self.eval_top_p is not None and not (0.0 < self.eval_top_p <= 1.0):
raise ValueError(f"`eval_top_p` must be in (0, 1], got {self.eval_top_p}")
if self.deepspeed_offload_param and self.deepspeed_stage != 3:
raise ValueError("`deepspeed_offload_param` requires `deepspeed_stage` to be 3!")
if self.use_cpu_adam and self.deepspeed_stage == 0:
raise ValueError("`use_cpu_adam` requires a DeepSpeed stage (`deepspeed_stage` > 0)!")
if self.use_cpu_adam and not self.deepspeed_offload_optimizer:
logger.warning(
"`use_cpu_adam` is enabled but `deepspeed_offload_optimizer` is False. "
"Consider enabling `deepspeed_offload_optimizer` to fully benefit from CPU Adam."
)
if self.use_rho_correction:
if self.rho_mask_lower_bound > 0.0 and not (0.0 < self.rho_mask_lower_bound < 1.0):
raise ValueError(
Expand Down