File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -693,10 +693,11 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
693693 # TODO add optimizers for multitask
694694 # author: iProzd
695695 if self .opt_type in ["Adam" , "AdamW" ]:
696+ base_model = self ._get_wrapper_module ()
696697 parameter_list = [
697- {"params" : self . wrapper .model .parameters ()},
698- {"params" : self . wrapper .loss .parameters (), "lr" : 1e-4 },
699- ]
698+ {"params" : base_model .model .parameters ()},
699+ {"params" : base_model .loss .parameters (), "lr" : 1e-4 },
700+ ] # set a smaller lr for loss parameters
700701 if self .opt_type == "Adam" :
701702 self .optimizer = torch .optim .Adam (
702703 # self.wrapper.parameters(),
@@ -1373,6 +1374,12 @@ def print_on_training(
13731374 fout .write (print_str )
13741375 fout .flush ()
13751376
1377+ def _get_wrapper_module (self ) -> torch .nn .Module :
1378+ """Compatible with DDP and normal model."""
1379+ if isinstance (self .wrapper , DDP ):
1380+ return self .wrapper .module
1381+ return self .wrapper
1382+
13761383
13771384def get_additional_data_requirement (_model : Any ) -> list [DataRequirementItem ]:
13781385 additional_data_requirement = []
You can’t perform that action at this time.
0 commit comments