Skip to content

Commit 7569f78

Browse files
committed
not extend atype for training
1 parent 1888cd2 commit 7569f78

2 files changed

Lines changed: 11 additions & 15 deletions

File tree

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -469,16 +469,18 @@ def forward(
469469
# cast the input to internal precsion
470470
extended_coord = extended_coord.to(dtype=self.prec)
471471
nframes, nloc, nnei = nlist.shape
472-
nall = extended_coord.view(nframes, -1).shape[1] // 3
473-
474-
node_ebd_ext = self.type_embedding(extended_atype)
475-
node_ebd_inp = node_ebd_ext[:, :nloc, :]
472+
# nall = extended_coord.view(nframes, -1).shape[1] // 3
473+
if comm_dict is None:
474+
atype = extended_atype[:, :nloc]
475+
else:
476+
atype = extended_atype
477+
node_ebd_inp = self.type_embedding(atype)
476478
# repflows
477-
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
479+
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows.forward(
478480
nlist,
479481
extended_coord,
480482
extended_atype,
481-
node_ebd_ext,
483+
node_ebd_inp,
482484
mapping,
483485
comm_dict=comm_dict,
484486
)

deepmd/pt/model/descriptor/repflows.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -373,13 +373,12 @@ def forward(
373373
nlist: torch.Tensor,
374374
extended_coord: torch.Tensor,
375375
extended_atype: torch.Tensor,
376-
extended_atype_embd: Optional[torch.Tensor] = None,
376+
extended_atype_embd: torch.Tensor,
377377
mapping: Optional[torch.Tensor] = None,
378378
comm_dict: Optional[dict[str, torch.Tensor]] = None,
379379
):
380380
if comm_dict is None:
381381
assert mapping is not None
382-
assert extended_atype_embd is not None
383382
nframes, nloc, nnei = nlist.shape
384383
nall = extended_coord.view(nframes, -1).shape[1] // 3
385384
atype = extended_atype[:, :nloc]
@@ -402,14 +401,9 @@ def forward(
402401
# beyond the cutoff sw should be 0.0
403402
sw = sw.masked_fill(~nlist_mask, 0.0)
404403

404+
atype_embd = extended_atype_embd
405405
# [nframes, nloc, tebd_dim]
406-
if comm_dict is None:
407-
assert isinstance(extended_atype_embd, torch.Tensor) # for jit
408-
atype_embd = extended_atype_embd[:, :nloc, :]
409-
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
410-
else:
411-
atype_embd = extended_atype_embd
412-
assert isinstance(atype_embd, torch.Tensor) # for jit
406+
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
413407
node_ebd = self.act(atype_embd)
414408
n_dim = node_ebd.shape[-1]
415409
# nb x nloc x nnei x 1, nb x nloc x nnei x 3

0 commit comments

Comments
 (0)