Skip to content

Commit 1ccf225

Browse files
Chengqian-ZhangiProzd
authored andcommitted
delete torch.jit.export of get_default_fparam (#51)
1 parent 20ad044 commit 1ccf225

4 files changed

Lines changed: 7 additions & 8 deletions

File tree

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
Optional,
66
)
77

8-
import numpy as np
98
import torch
109

1110
from deepmd.dpmodel import (
@@ -366,7 +365,7 @@ def get_dim_fparam(self) -> int:
366365
def has_default_fparam(self) -> bool:
367366
return self.fitting_net.has_default_fparam()
368367

369-
def get_default_fparam(self) -> Optional[np.array]:
368+
def get_default_fparam(self) -> Optional[torch.Tensor]:
370369
return self.fitting_net.get_default_fparam()
371370

372371
def get_dim_aparam(self) -> int:

deepmd/pt/model/model/make_model.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
Optional,
44
)
55

6-
import numpy as np
76
import torch
87

98
from deepmd.dpmodel import (
@@ -531,8 +530,7 @@ def get_dim_fparam(self) -> int:
531530
def has_default_fparam(self) -> bool:
532531
return self.atomic_model.has_default_fparam()
533532

534-
@torch.jit.export
535-
def get_default_fparam(self) -> Optional[np.array]:
533+
def get_default_fparam(self) -> Optional[torch.Tensor]:
536534
return self.atomic_model.get_default_fparam()
537535

538536
@torch.jit.export

deepmd/pt/model/task/fitting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -619,8 +619,8 @@ def get_dim_fparam(self) -> int:
619619
def has_default_fparam(self) -> bool:
620620
return self.default_fparam is not None
621621

622-
def get_default_fparam(self) -> Optional[np.array]:
623-
return self.default_fparam_tensor.cpu().numpy()
622+
def get_default_fparam(self) -> Optional[torch.Tensor]:
623+
return self.default_fparam_tensor
624624

625625
def get_dim_aparam(self) -> int:
626626
"""Get the number (dimension) of atomic parameters of this atomic model."""

deepmd/pt/train/training.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1305,7 +1305,9 @@ def get_additional_data_requirement(_model):
13051305
additional_data_requirement = []
13061306
if _model.get_dim_fparam() > 0:
13071307
_fparam_default = (
1308-
_model.get_default_fparam() if _model.has_default_fparam() else 0.0
1308+
_model.get_default_fparam().cpu().numpy()
1309+
if _model.has_default_fparam()
1310+
else 0.0
13091311
)
13101312
fparam_requirement_items = [
13111313
DataRequirementItem(

0 commit comments

Comments
 (0)