Skip to content

Commit 8c356b4

Browse files
author
Han Wang
committed
fix bugs
1 parent c859576 commit 8c356b4

3 files changed

Lines changed: 6 additions & 2 deletions

File tree

deepmd/jax/infer/deep_eval.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,9 @@ def _eval_model(
388388

389389
results = []
390390
for odef in request_defs:
391-
dp_name = self._OUTDEF_DP2BACKEND[odef.name]
391+
# HLO and TFModelWrapper return raw internal keys (not translated),
392+
# so no key mapping is needed here.
393+
dp_name = odef.name
392394
if dp_name in batch_output:
393395
shape = self._get_output_shape(odef, nframes, natoms)
394396
if batch_output[dp_name] is not None:

source/tests/universal/dpmodel/backend.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ class DPTestCase(BackendTestCase):
2121
"""DP module to test."""
2222

2323
def forward_wrapper(self, x):
24+
if not hasattr(x, "forward_lower") and hasattr(x, "call_lower"):
25+
x.forward_lower = x.call_lower
2426
return x
2527

2628
def forward_wrapper_cpu_ref(self, x):

source/tests/universal/dpmodel/model/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def setUpClass(cls) -> None:
164164
ft,
165165
type_map=cls.expected_type_map,
166166
)
167-
cls.output_def = cls.module.model_output_def().get_data()
167+
cls.output_def = cls.module.translated_output_def()
168168
cls.expected_has_message_passing = ds.has_message_passing()
169169
cls.expected_sel_type = ft.get_sel_type()
170170
cls.expected_dim_fparam = ft.get_dim_fparam()

0 commit comments

Comments
 (0)