Skip to content

Commit bf01e3c

Browse files
committed
Update mlp.py
1 parent 7e67cf3 commit bf01e3c

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

deepmd/pt/model/network/mlp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
521521
P_list.append(Plp1)
522522

523523
P = torch.concat(P_list, dim=-1) # (..., L+1)
524-
return P * self.norm
524+
return P * self.norm.type(x.dtype)
525525

526526

527527
def find_normalization(name: str, dim: int | None = None) -> nn.Module | None:

0 commit comments

Comments
 (0)