Skip to content

Commit 7e5278d

Browse files
OutisLiChiahsinChu
authored andcommitted
Fix(pt): add comm_dict for zbl, linear, dipole, dos, polar model to fix bugs mentioned in issue deepmodeling#4906 (deepmodeling#4908)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - New Features - Added optional support to pass a communication dictionary through lower-level model computations across energy, dipole, DOS, polarization, and related models. This enables advanced workflows while remaining fully backward compatible. - Refactor - Standardized internal propagation of the communication dictionary across sub-models to ensure consistent behavior. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent ae51b72 commit 7e5278d

6 files changed

Lines changed: 11 additions & 0 deletions

File tree

deepmd/pt/model/atomic_model/linear_atomic_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def forward_atomic(
290290
mapping,
291291
fparam,
292292
aparam,
293+
comm_dict=comm_dict,
293294
)["energy"]
294295
)
295296
weights = self._compute_weight(extended_coord, extended_atype, nlists_)

deepmd/pt/model/model/dipole_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def forward_lower(
9898
fparam: Optional[torch.Tensor] = None,
9999
aparam: Optional[torch.Tensor] = None,
100100
do_atomic_virial: bool = False,
101+
comm_dict: Optional[dict[str, torch.Tensor]] = None,
101102
):
102103
model_ret = self.forward_common_lower(
103104
extended_coord,
@@ -107,6 +108,7 @@ def forward_lower(
107108
fparam=fparam,
108109
aparam=aparam,
109110
do_atomic_virial=do_atomic_virial,
111+
comm_dict=comm_dict,
110112
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
111113
)
112114
if self.get_fitting_net() is not None:

deepmd/pt/model/model/dos_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def forward_lower(
8888
fparam: Optional[torch.Tensor] = None,
8989
aparam: Optional[torch.Tensor] = None,
9090
do_atomic_virial: bool = False,
91+
comm_dict: Optional[dict[str, torch.Tensor]] = None,
9192
):
9293
model_ret = self.forward_common_lower(
9394
extended_coord,
@@ -97,6 +98,7 @@ def forward_lower(
9798
fparam=fparam,
9899
aparam=aparam,
99100
do_atomic_virial=do_atomic_virial,
101+
comm_dict=comm_dict,
100102
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
101103
)
102104
if self.get_fitting_net() is not None:

deepmd/pt/model/model/dp_linear_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def forward_lower(
9797
fparam: Optional[torch.Tensor] = None,
9898
aparam: Optional[torch.Tensor] = None,
9999
do_atomic_virial: bool = False,
100+
comm_dict: Optional[dict[str, torch.Tensor]] = None,
100101
):
101102
model_ret = self.forward_common_lower(
102103
extended_coord,
@@ -106,6 +107,7 @@ def forward_lower(
106107
fparam=fparam,
107108
aparam=aparam,
108109
do_atomic_virial=do_atomic_virial,
110+
comm_dict=comm_dict,
109111
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
110112
)
111113

deepmd/pt/model/model/dp_zbl_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def forward_lower(
9797
fparam: Optional[torch.Tensor] = None,
9898
aparam: Optional[torch.Tensor] = None,
9999
do_atomic_virial: bool = False,
100+
comm_dict: Optional[dict[str, torch.Tensor]] = None,
100101
):
101102
model_ret = self.forward_common_lower(
102103
extended_coord,
@@ -106,6 +107,7 @@ def forward_lower(
106107
fparam=fparam,
107108
aparam=aparam,
108109
do_atomic_virial=do_atomic_virial,
110+
comm_dict=comm_dict,
109111
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
110112
)
111113

deepmd/pt/model/model/polar_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def forward_lower(
8282
fparam: Optional[torch.Tensor] = None,
8383
aparam: Optional[torch.Tensor] = None,
8484
do_atomic_virial: bool = False,
85+
comm_dict: Optional[dict[str, torch.Tensor]] = None,
8586
):
8687
model_ret = self.forward_common_lower(
8788
extended_coord,
@@ -91,6 +92,7 @@ def forward_lower(
9192
fparam=fparam,
9293
aparam=aparam,
9394
do_atomic_virial=do_atomic_virial,
95+
comm_dict=comm_dict,
9496
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
9597
)
9698
if self.get_fitting_net() is not None:

0 commit comments

Comments
 (0)