Skip to content

Commit 670f346

Browse files
committed
fix ddp
1 parent 448baea commit 670f346

1 file changed

Lines changed: 10 additions & 3 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -693,10 +693,11 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
693693
# TODO add optimizers for multitask
694694
# author: iProzd
695695
if self.opt_type in ["Adam", "AdamW"]:
696+
base_model = self._get_wrapper_module()
696697
parameter_list = [
697-
{"params": self.wrapper.model.parameters()},
698-
{"params": self.wrapper.loss.parameters(), "lr": 1e-4},
699-
]
698+
{"params": base_model.model.parameters()},
699+
{"params": base_model.loss.parameters(), "lr": 1e-4},
700+
] # set a smaller lr for loss parameters
700701
if self.opt_type == "Adam":
701702
self.optimizer = torch.optim.Adam(
702703
# self.wrapper.parameters(),
@@ -1373,6 +1374,12 @@ def print_on_training(
13731374
fout.write(print_str)
13741375
fout.flush()
13751376

1377+
def _get_wrapper_module(self) -> torch.nn.Module:
1378+
"""Return the module held by ``self.wrapper``, unwrapping ``.module`` when the wrapper is a ``DDP`` instance."""
1379+
if isinstance(self.wrapper, DDP):
1380+
return self.wrapper.module
1381+
return self.wrapper
1382+
13761383

13771384
def get_additional_data_requirement(_model: Any) -> list[DataRequirementItem]:
13781385
additional_data_requirement = []

0 commit comments

Comments
 (0)