Skip to content

Commit edade8f

Browse files
feat(language): Opacus DP training with wrap_model=False (#279)
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
1 parent b011d81 commit edade8f

1 file changed

Lines changed: 13 additions & 6 deletions

File tree

mostlyai/engine/_language/training.py

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
31 31
from huggingface_hub import get_safetensors_metadata
32 32
from opacus import GradSampleModule, PrivacyEngine
33 33
from opacus.accountants import GaussianAccountant, PRVAccountant, RDPAccountant
34-
from opacus.grad_sample import register_grad_sampler
34+
from opacus.grad_sample import GradSampleHooks, register_grad_sampler
35 35
from opacus.utils.batch_memory_manager import wrap_data_loader
36 36
from peft import LoraConfig, PeftModel
37 37
from torch import nn
@@ -626,6 +626,7 @@ def concat_prompt_and_response(x):
626 626
# this can help accelerate GPU compute
627 627
torch.backends.cudnn.benchmark = True
628 628

629+
dp_grad_sample_hooks: GradSampleHooks | None = None
629 630
if with_dp:
630 631
if isinstance(differential_privacy, DifferentialPrivacyConfig):
631 632
dp_config = differential_privacy.model_dump()
@@ -650,18 +651,21 @@ def concat_prompt_and_response(x):
650 651
privacy_engine.accountant.load_state_dict(
651 652
torch.load(workspace.model_dp_accountant_path, map_location=device, weights_only=True),
652 653
)
653-
# Opacus will return the modified objects
654-
# - model: wrapped in GradSampleModule and contains additional hooks for computing per-sample gradients
655-
# - optimizer: wrapped in DPOptimizer and will do different operations during virtual steps and logical steps
656-
# - dataloader: the dataloader with batch_sampler=UniformWithReplacementSampler (for Poisson sampling)
657-
model, optimizer, trn_dataloader = privacy_engine.make_private(
654+
# Opacus returns GradSampleHooks when wrap_model=False: hooks attach to the original module so HF /
655+
# Transformers sees an unwrapped PreTrainedModel (requires Opacus >= 1.6).
656+
# - dp_grad_sample_hooks: must call .cleanup() after training to remove backward hooks and param attrs
657+
# - optimizer: wrapped in DPOptimizer (virtual vs logical steps)
658+
# - dataloader: UniformWithReplacementSampler when poisson_sampling=True
659+
dp_grad_sample_hooks, optimizer, trn_dataloader = privacy_engine.make_private(
658 660
module=model,
659 661
optimizer=optimizer,
660 662
data_loader=trn_dataloader,
661 663
noise_multiplier=dp_config.get("noise_multiplier"),
662 664
max_grad_norm=dp_config.get("max_grad_norm"),
663 665
poisson_sampling=True,
666+
wrap_model=False,
664 667
)
668+
model = dp_grad_sample_hooks._module
665 669
# this further wraps the dataloader with batch_sampler=BatchSplittingSampler to achieve gradient accumulation
666 670
# it will split the sampled logical batches into smaller sub-batches with batch_size
667 671
trn_dataloader = wrap_data_loader(
@@ -835,6 +839,9 @@ def concat_prompt_and_response(x):
835 839
if total_training_time > max_training_time:
836 840
do_stop = True
837 841

842+
if dp_grad_sample_hooks is not None:
843+
dp_grad_sample_hooks.cleanup()
844+
838 845
# no checkpoint is saved yet because the training stopped before the first epoch ended
839 846
if not model_checkpoint.has_saved_once():
840 847
_LOG.info("saving model weights, as none were saved so far")

0 commit comments

Comments
 (0)