Skip to content

Commit cbbce64

Browse files
committed
fix dpa3 spin lmp
1 parent 6496194 commit cbbce64

1 file changed

Lines changed: 3 additions & 8 deletions

File tree

deepmd/pt/model/descriptor/repflows.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,6 @@ def forward(
379379
):
380380
if comm_dict is None:
381381
assert mapping is not None
382-
assert extended_atype_embd is not None
383382
nframes, nloc, nnei = nlist.shape
384383
nall = extended_coord.view(nframes, -1).shape[1] // 3
385384
atype = extended_atype[:, :nloc]
@@ -403,13 +402,9 @@ def forward(
403402
sw = sw.masked_fill(~nlist_mask, 0.0)
404403

405404
# [nframes, nloc, tebd_dim]
406-
if comm_dict is None:
407-
assert isinstance(extended_atype_embd, torch.Tensor) # for jit
408-
atype_embd = extended_atype_embd[:, :nloc, :]
409-
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
410-
else:
411-
atype_embd = extended_atype_embd
412-
assert isinstance(atype_embd, torch.Tensor) # for jit
405+
assert extended_atype_embd is not None
406+
atype_embd = extended_atype_embd[:, :nloc, :]
407+
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
413408
node_ebd = self.act(atype_embd)
414409
n_dim = node_ebd.shape[-1]
415410
# nb x nloc x nnei x 1, nb x nloc x nnei x 3

0 commit comments

Comments
 (0)