We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 5594c56 commit ac6677eCopy full SHA for ac6677e
1 file changed
deepmd/pt/model/descriptor/repflows.py
@@ -469,12 +469,13 @@ def forward(
469
for idx, ll in enumerate(self.layers):
470
# node_ebd: nb x nloc x n_dim
471
# node_ebd_ext: nb x nall x n_dim
472
- if self.use_ext_ebd or comm_dict is not None:
473
- assert mapping is not None
474
- node_ebd_ext = torch.gather(node_ebd, 1, mapping)
+ if comm_dict is None:
+ if self.use_ext_ebd:
+ assert mapping is not None
475
+ node_ebd_ext = torch.gather(node_ebd, 1, mapping)
476
+ else:
477
+ node_ebd_ext = None
478
else:
- node_ebd_ext = None
- if comm_dict is not None:
479
has_spin = "has_spin" in comm_dict
480
if not has_spin:
481
n_padding = nall - nloc
0 commit comments