File tree Expand file tree Collapse file tree
source/tests/universal/dpmodel Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -388,7 +388,9 @@ def _eval_model(
388388
389389 results = []
390390 for odef in request_defs :
391- dp_name = self ._OUTDEF_DP2BACKEND [odef .name ]
391+ # HLO and TFModelWrapper return raw internal keys (not translated),
392+ # so no key mapping is needed here.
393+ dp_name = odef .name
392394 if dp_name in batch_output :
393395 shape = self ._get_output_shape (odef , nframes , natoms )
394396 if batch_output [dp_name ] is not None :
Original file line number Diff line number Diff line change @@ -21,6 +21,8 @@ class DPTestCase(BackendTestCase):
2121 """DP module to test."""
2222
2323 def forward_wrapper (self , x ):
24+ if not hasattr (x , "forward_lower" ) and hasattr (x , "call_lower" ):
25+ x .forward_lower = x .call_lower
2426 return x
2527
2628 def forward_wrapper_cpu_ref (self , x ):
Original file line number Diff line number Diff line change @@ -164,7 +164,7 @@ def setUpClass(cls) -> None:
164164 ft ,
165165 type_map = cls .expected_type_map ,
166166 )
167- cls .output_def = cls .module .model_output_def (). get_data ()
167+ cls .output_def = cls .module .translated_output_def ()
168168 cls .expected_has_message_passing = ds .has_message_passing ()
169169 cls .expected_sel_type = ft .get_sel_type ()
170170 cls .expected_dim_fparam = ft .get_dim_fparam ()
You can’t perform that action at this time.
0 commit comments