Skip to content

Commit 87cb6ef

Browse files
authored
fix: remove hessian outdef if not necessary (#5045)
I found that the inference time per atom very weird using small models (both DPA3-L3 and DPA1 attn0) on very large systems (more than 1000 atoms): <img width="1034" height="695" alt="截屏2025-11-11 17 52 32" src="https://github.com/user-attachments/assets/71b12719-ee74-4f2b-bb50-9f5f7031ee16" /> Through profilling, I found some unnecessary memory allocation matters for keys not in the model outputs (such as hessian). After fix, the inference time seems good: <img width="1067" height="693" alt="截屏2025-11-11 17 56 26" src="https://github.com/user-attachments/assets/0fe6d430-3daa-43cd-b245-0889cd1311a8" /> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Improved internal handling of output definitions in model inference to ensure proper filtering for models without Hessian support. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 877147e commit 87cb6ef

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

deepmd/pt/infer/deep_eval.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,9 +397,9 @@ def _get_request_defs(self, atomic: bool) -> list[OutputVariableDef]:
397397
The requested output definitions.
398398
"""
399399
if atomic:
400-
return list(self.output_def.var_defs.values())
400+
output_defs = list(self.output_def.var_defs.values())
401401
else:
402-
return [
402+
output_defs = [
403403
x
404404
for x in self.output_def.var_defs.values()
405405
if x.category
@@ -411,6 +411,13 @@ def _get_request_defs(self, atomic: bool) -> list[OutputVariableDef]:
411411
OutputVariableCategory.DERV_R_DERV_R,
412412
)
413413
]
414+
if not self.get_has_hessian():
415+
output_defs = [
416+
x
417+
for x in output_defs
418+
if x.category != OutputVariableCategory.DERV_R_DERV_R
419+
]
420+
return output_defs
414421

415422
def _eval_func(self, inner_func: Callable, numb_test: int, natoms: int) -> Callable:
416423
"""Wrapper method with auto batch size.

0 commit comments

Comments
 (0)