Skip to content

Commit ac6677e

Browse files
committed
refactor node_ebd_ext ctor
1 parent 5594c56 commit ac6677e

1 file changed

Lines changed: 6 additions & 5 deletions

File tree

deepmd/pt/model/descriptor/repflows.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -469,12 +469,13 @@ def forward(
469469
for idx, ll in enumerate(self.layers):
470470
# node_ebd: nb x nloc x n_dim
471471
# node_ebd_ext: nb x nall x n_dim
472-
if self.use_ext_ebd or comm_dict is not None:
473-
assert mapping is not None
474-
node_ebd_ext = torch.gather(node_ebd, 1, mapping)
472+
if comm_dict is None:
473+
if self.use_ext_ebd:
474+
assert mapping is not None
475+
node_ebd_ext = torch.gather(node_ebd, 1, mapping)
476+
else:
477+
node_ebd_ext = None
475478
else:
476-
node_ebd_ext = None
477-
if comm_dict is not None:
478479
has_spin = "has_spin" in comm_dict
479480
if not has_spin:
480481
n_padding = nall - nloc

0 commit comments

Comments
 (0)