Skip to content

Commit 58e346a

Browse files
pd(feat): support python inference with DP class (#4987)
Summary: 1. Support Python inference with `DP(model=paddle_inference.json)` for ASE, and add a unit test. 2. Update the ASE documentation for the different backends. 3. Fix typos in `DeepPotPD.cc` --- This pull request introduces significant enhancements and refactoring to the Paddle backend implementation of `DeepEval` in `deepmd/pd/infer/deep_eval.py`, along with minor improvements to model freezing in `deepmd/pd/entrypoints/main.py`. The most important changes are the addition of support for static models loaded from `.json` files, expanded model type handling, and improved output shape and evaluation logic for various model branches and output variable categories. ### Static model and inference support * Added support for loading and evaluating static models from `.json` files, including Paddle inference engine integration and input/output handle management for efficient prediction. (`deepmd/pd/infer/deep_eval.py`) [[1]](diffhunk://#diff-8c2ffd525a36d0190726f1aca380b9a4e05e67cd8ba6fa5e3842f810c69e6c68L117-R181) [[2]](diffhunk://#diff-8c2ffd525a36d0190726f1aca380b9a4e05e67cd8ba6fa5e3842f810c69e6c68R470) [[3]](diffhunk://#diff-8c2ffd525a36d0190726f1aca380b9a4e05e67cd8ba6fa5e3842f810c69e6c68R481-R508) [[4]](diffhunk://#diff-8c2ffd525a36d0190726f1aca380b9a4e05e67cd8ba6fa5e3842f810c69e6c68R532-R547) [[5]](diffhunk://#diff-8c2ffd525a36d0190726f1aca380b9a4e05e67cd8ba6fa5e3842f810c69e6c68L420-R566) ### Expanded model type and branch handling * Enhanced model type detection to support additional output types such as DOS, dipole, polar, global polar, WFC, and property models, with corresponding evaluator selection. Also added methods for model branch information retrieval and default parameter checks. 
(`deepmd/pd/infer/deep_eval.py`) [[1]](diffhunk://#diff-8c2ffd525a36d0190726f1aca380b9a4e05e67cd8ba6fa5e3842f810c69e6c68L152-R259) [[2]](diffhunk://#diff-8c2ffd525a36d0190726f1aca380b9a4e05e67cd8ba6fa5e3842f810c69e6c68L187-R306) ### Output variable and evaluation improvements * Improved output shape determination for new output variable categories (e.g., `DERV_R_DERV_R`), and refactored evaluation logic to handle both static and dynamic models, including proper output conversion. (`deepmd/pd/infer/deep_eval.py`) [[1]](diffhunk://#diff-8c2ffd525a36d0190726f1aca380b9a4e05e67cd8ba6fa5e3842f810c69e6c68R412) [[2]](diffhunk://#diff-8c2ffd525a36d0190726f1aca380b9a4e05e67cd8ba6fa5e3842f810c69e6c68R681-R683) [[3]](diffhunk://#diff-8c2ffd525a36d0190726f1aca380b9a4e05e67cd8ba6fa5e3842f810c69e6c68L420-R566) * Implemented the previously unimplemented `_eval_model_spin` method to support spin-dependent model evaluation and output extraction. (`deepmd/pd/infer/deep_eval.py`) ### Type embedding evaluation * Added a new method `eval_typeebd` to extract and concatenate type embedding network outputs from the loaded model. (`deepmd/pd/infer/deep_eval.py`) ### Model freezing improvements * Updated the `freeze` function to support freezing additional model methods and atomic virial computation, and changed input specifications for better compatibility. (`deepmd/pd/entrypoints/main.py`) [[1]](diffhunk://#diff-e3f56cd14511cf86a0db88d6d9ee5b08cf45374edfdef0625a0f519d94c58507R345) [[2]](diffhunk://#diff-e3f56cd14511cf86a0db88d6d9ee5b08cf45374edfdef0625a0f519d94c58507L377-R378) [[3]](diffhunk://#diff-e3f56cd14511cf86a0db88d6d9ee5b08cf45374edfdef0625a0f519d94c58507L393-R400) [[4]](diffhunk://#diff-e3f56cd14511cf86a0db88d6d9ee5b08cf45374edfdef0625a0f519d94c58507R413-R432) Let me know if you'd like a walkthrough of any specific new functionality or code sections! 
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - New Features - Dual evaluation modes: static (.json) and dynamic (.pd) with optional no-JIT; new evaluation branches for additional model types and spin support. - Freeze can include atomic virial; frozen exports expose additional buffer-backed getters (type map, cutoffs, parameter dims, ntypes). - Improvements - Wider static-graph support via persistent buffers across descriptors/models; more consistent inference input handling and batching. - Documentation - ASE guide extended with backend-specific examples. - Tests - New end-to-end training, inference, and frozen-model validation tests. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: HydrogenSulfate <490868991@qq.com> Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 58f60c4 commit 58e346a

24 files changed

Lines changed: 1036 additions & 154 deletions

deepmd/pd/entrypoints/main.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Path,
88
)
99
from typing import (
10+
Any,
1011
Optional,
1112
Union,
1213
)
@@ -80,15 +81,15 @@
8081

8182

8283
def get_trainer(
83-
config,
84-
init_model=None,
85-
restart_model=None,
86-
finetune_model=None,
87-
force_load=False,
88-
init_frz_model=None,
89-
shared_links=None,
90-
finetune_links=None,
91-
):
84+
config: dict[str, Any],
85+
init_model: Optional[str] = None,
86+
restart_model: Optional[str] = None,
87+
finetune_model: Optional[str] = None,
88+
force_load: bool = False,
89+
init_frz_model: Optional[str] = None,
90+
shared_links: Optional[dict[str, Any]] = None,
91+
finetune_links: Optional[dict[str, Any]] = None,
92+
) -> training.Trainer:
9293
multi_task = "model_dict" in config.get("model", {})
9394

9495
# Initialize DDP
@@ -98,17 +99,22 @@ def get_trainer(
9899
fleet.init(is_collective=True)
99100

100101
def prepare_trainer_input_single(
101-
model_params_single, data_dict_single, rank=0, seed=None
102-
):
102+
model_params_single: dict[str, Any],
103+
data_dict_single: dict[str, Any],
104+
rank: int = 0,
105+
seed: Optional[int] = None,
106+
) -> tuple[DpLoaderSet, Optional[DpLoaderSet], Optional[DPPath]]:
103107
training_dataset_params = data_dict_single["training_data"]
104108
validation_dataset_params = data_dict_single.get("validation_data", None)
105109
validation_systems = (
106110
validation_dataset_params["systems"] if validation_dataset_params else None
107111
)
108112
training_systems = training_dataset_params["systems"]
109-
training_systems = process_systems(training_systems)
113+
trn_patterns = training_dataset_params.get("rglob_patterns", None)
114+
training_systems = process_systems(training_systems, patterns=trn_patterns)
110115
if validation_systems is not None:
111-
validation_systems = process_systems(validation_systems)
116+
val_patterns = validation_dataset_params.get("rglob_patterns", None)
117+
validation_systems = process_systems(validation_systems, val_patterns)
112118

113119
# stat files
114120
stat_file_path_single = data_dict_single.get("stat_file", None)
@@ -342,6 +348,7 @@ def freeze(
342348
model: str,
343349
output: str = "frozen_model.json",
344350
head: Optional[str] = None,
351+
do_atomic_virial: bool = False,
345352
) -> None:
346353
paddle.set_flags(
347354
{
@@ -374,7 +381,7 @@ def freeze(
374381
None, # fparam
375382
None, # aparam
376383
# InputSpec([], dtype="bool", name="do_atomic_virial"), # do_atomic_virial
377-
False, # do_atomic_virial
384+
do_atomic_virial, # do_atomic_virial
378385
],
379386
full_graph=True,
380387
)
@@ -396,7 +403,7 @@ def freeze(
396403
None, # fparam
397404
None, # aparam
398405
# InputSpec([], dtype="bool", name="do_atomic_virial"), # do_atomic_virial
399-
False, # do_atomic_virial
406+
do_atomic_virial, # do_atomic_virial
400407
(
401408
InputSpec([-1], "int64", name="send_list"),
402409
InputSpec([-1], "int32", name="send_proc"),
@@ -409,6 +416,26 @@ def freeze(
409416
],
410417
full_graph=True,
411418
)
419+
for method_name in [
420+
"get_buffer_rcut",
421+
"get_buffer_type_map",
422+
"get_buffer_dim_fparam",
423+
"get_buffer_dim_aparam",
424+
"get_buffer_intensive",
425+
"get_buffer_sel_type",
426+
"get_buffer_numb_dos",
427+
"get_buffer_task_dim",
428+
]:
429+
if hasattr(model, method_name):
430+
setattr(
431+
model,
432+
method_name,
433+
paddle.jit.to_static(
434+
getattr(model, method_name),
435+
input_spec=[],
436+
full_graph=True,
437+
),
438+
)
412439
if output.endswith(".json"):
413440
output = output[:-5]
414441
paddle.jit.save(

0 commit comments

Comments
 (0)