3131from huggingface_hub import get_safetensors_metadata
3232from opacus import GradSampleModule , PrivacyEngine
3333from opacus .accountants import GaussianAccountant , PRVAccountant , RDPAccountant
34- from opacus .grad_sample import register_grad_sampler
34+ from opacus .grad_sample import GradSampleHooks , register_grad_sampler
3535from opacus .utils .batch_memory_manager import wrap_data_loader
3636from peft import LoraConfig , PeftModel
3737from torch import nn
@@ -626,6 +626,7 @@ def concat_prompt_and_response(x):
626626 # this can help accelerate GPU compute
627627 torch .backends .cudnn .benchmark = True
628628
629+ dp_grad_sample_hooks : GradSampleHooks | None = None
629630 if with_dp :
630631 if isinstance (differential_privacy , DifferentialPrivacyConfig ):
631632 dp_config = differential_privacy .model_dump ()
@@ -650,18 +651,21 @@ def concat_prompt_and_response(x):
650651 privacy_engine .accountant .load_state_dict (
651652 torch .load (workspace .model_dp_accountant_path , map_location = device , weights_only = True ),
652653 )
653- # Opacus will return the modified objects
654- # - model: wrapped in GradSampleModule and contains additional hooks for computing per-sample gradients
655- # - optimizer: wrapped in DPOptimizer and will do different operations during virtual steps and logical steps
656- # - dataloader: the dataloader with batch_sampler=UniformWithReplacementSampler (for Poisson sampling)
657- model , optimizer , trn_dataloader = privacy_engine .make_private (
654+ # Opacus returns GradSampleHooks when wrap_model=False: hooks attach to the original module so HF /
655+ # Transformers sees an unwrapped PreTrainedModel (requires Opacus >= 1.6).
656+ # - dp_grad_sample_hooks: must call .cleanup() after training to remove backward hooks and param attrs
657+ # - optimizer: wrapped in DPOptimizer (virtual vs logical steps)
658+ # - dataloader: UniformWithReplacementSampler when poisson_sampling=True
659+ dp_grad_sample_hooks , optimizer , trn_dataloader = privacy_engine .make_private (
658660 module = model ,
659661 optimizer = optimizer ,
660662 data_loader = trn_dataloader ,
661663 noise_multiplier = dp_config .get ("noise_multiplier" ),
662664 max_grad_norm = dp_config .get ("max_grad_norm" ),
663665 poisson_sampling = True ,
666+ wrap_model = False ,
664667 )
668+ model = dp_grad_sample_hooks ._module
665669 # this further wraps the dataloader with batch_sampler=BatchSplittingSampler to achieve gradient accumulation
666670 # it will split the sampled logical batches into smaller sub-batches with batch_size
667671 trn_dataloader = wrap_data_loader (
@@ -835,6 +839,9 @@ def concat_prompt_and_response(x):
835839 if total_training_time > max_training_time :
836840 do_stop = True
837841
842+ if dp_grad_sample_hooks is not None :
843+ dp_grad_sample_hooks .cleanup ()
844+
838845 # no checkpoint is saved yet because the training stopped before the first epoch ended
839846 if not model_checkpoint .has_saved_once ():
840847 _LOG .info ("saving model weights, as none were saved so far" )
0 commit comments