Skip to content

Commit 68b7e25

Browse files
committed
fix jit
1 parent 0c81876 commit 68b7e25

3 files changed

Lines changed: 7 additions & 5 deletions

File tree

deepmd/dpmodel/utils/network.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -646,8 +646,10 @@ def call_until_last(self, x):
646646
np.ndarray
647647
The output before last layer.
648648
"""
649-
for layer in self.layers[:-1]:
650-
x = layer(x)
649+
# avoid slice (self.layers[:-1]) for jit
650+
for ii, layer in enumerate(self.layers):
651+
if ii < len(self.layers) - 1:
652+
x = layer(x)
651653
return x
652654

653655
def clear(self) -> None:

deepmd/pt/infer/deep_eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ def __init__(
131131
state_dict = state_dict_head
132132
model = get_model(self.input_param).to(DEVICE)
133133
# TODO fix jit
134-
# if not self.input_param.get("hessian_mode"):
135-
# model = torch.jit.script(model)
134+
if not self.input_param.get("hessian_mode"):
135+
model = torch.jit.script(model)
136136
self.dp = ModelWrapper(model)
137137
self.dp.load_state_dict(state_dict)
138138
elif str(self.model_path).endswith(".pth"):

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def forward_atomic(
271271
)
272272
if self.enable_eval_fitting_last_layer_hook:
273273
assert "middle_output" in fit_ret, (
274-
f"eval_fitting_last_layer not supported for fitting net {type(self.fitting_net.__class__)}!"
274+
"eval_fitting_last_layer not supported for this fitting net!"
275275
)
276276
self.eval_fitting_last_layer_list.append(
277277
fit_ret.pop("middle_output").detach()

0 commit comments

Comments
 (0)