-
Notifications
You must be signed in to change notification settings - Fork 612
Expand file tree
/
Copy patheval_desc.py
More file actions
137 lines (115 loc) · 3.37 KB
/
eval_desc.py
File metadata and controls
137 lines (115 loc) · 3.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Evaluate descriptors using trained DeePMD model."""
import logging
import os
from pathlib import (
Path,
)
from typing import (
Optional,
)
import numpy as np
from deepmd.common import (
expand_sys_str,
)
from deepmd.infer.deep_eval import (
DeepEval,
)
from deepmd.utils.data import (
DeepmdData,
)
# Public API of this module: only the ``eval_desc`` entry point is exported.
__all__ = ["eval_desc"]
# Module-level logger named after this module; used for all progress output.
log = logging.getLogger(__name__)
def eval_desc(
    *,
    model: str,
    system: str,
    datafile: Optional[str],
    output: str = "desc",
    head: Optional[str] = None,
    **kwargs,
) -> None:
    """Evaluate descriptors for given systems.

    For every system, the descriptors of all frames returned by
    ``DeepmdData.get_test()`` are evaluated with the model and written to
    ``<output>/<system basename>.npy``.

    Parameters
    ----------
    model : str
        path where model is stored
    system : str
        system directory; only used when ``datafile`` is None
    datafile : Optional[str]
        the path to the list of systems to process, one system per line.
        Blank lines are ignored. When None, ``system`` is expanded with
        ``expand_sys_str`` instead.
    output : str
        output directory for descriptor files (created if missing)
    head : Optional[str], optional
        (Supported backend: PyTorch) Task head if in multi-task mode.
    **kwargs
        additional arguments (accepted and ignored)

    Raises
    ------
    RuntimeError
        if no valid system was found
    """
    if datafile is not None:
        with open(datafile) as datalist:
            # Skip blank / whitespace-only lines so that stray empty lines in
            # the list are not treated as (invalid) system paths.
            all_sys = [line.strip() for line in datalist if line.strip()]
    else:
        all_sys = expand_sys_str(system)

    if not all_sys:
        raise RuntimeError("Did not find valid system")

    # init model
    dp = DeepEval(model, head=head)

    # Model properties do not depend on the system: query them once.
    tmap = dp.get_type_map()
    dim_fparam = dp.get_dim_fparam()
    dim_aparam = dp.get_dim_aparam()

    # create output directory
    output_dir = Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Output files are named after the system's basename only, so two systems
    # living in different directories can map to the same file.
    written = set()

    for system_path in all_sys:
        log.info("# -------output of dp eval_desc------- ")
        log.info(f"# processing system : {system_path}")

        # create data class
        data = DeepmdData(
            system_path,
            set_prefix="set",
            shuffle_test=False,
            type_map=tmap,
            sort_atoms=False,
        )

        # get test data
        test_data = data.get_test()
        nframes = test_data["box"].shape[0]

        # prepare input data: one flattened coordinate row per frame
        coord = test_data["coord"].reshape([nframes, -1])
        # non-periodic systems are evaluated without a box
        box = test_data["box"] if data.pbc else None
        if data.mixed_type:
            # per-frame atom types
            atype = test_data["type"].reshape([nframes, -1])
        else:
            # atom types are taken from the first frame
            atype = test_data["type"][0]

        # handle optional parameters
        # NOTE(review): no ``data.add("fparam"/"aparam", ...)`` request is
        # visible here -- confirm DeepmdData loads these keys without it,
        # otherwise both stay None even for models that expect them.
        fparam = None
        if dim_fparam > 0 and "fparam" in test_data:
            fparam = test_data["fparam"]
        aparam = None
        if dim_aparam > 0 and "aparam" in test_data:
            aparam = test_data["aparam"]

        # evaluate descriptors
        log.info(f"# evaluating descriptors for {nframes} frames")
        descriptors = dp.eval_descriptor(
            coord,
            box,
            atype,
            fparam=fparam,
            aparam=aparam,
        )

        # save descriptors
        system_name = os.path.basename(system_path.rstrip('/'))
        desc_file = output_dir / f"{system_name}.npy"
        if desc_file in written:
            log.warning(
                f"# {desc_file} was already written for another system with "
                "the same directory name and will be overwritten"
            )
        written.add(desc_file)
        np.save(desc_file, descriptors)

        log.info(f"# descriptors saved to {desc_file}")
        log.info(f"# descriptor shape: {descriptors.shape}")
        log.info("# ----------------------------------- ")

    log.info("# eval_desc completed successfully")