Skip to content

Commit 711f4a0

Browse files
committed
feat: support rglob patterns for dp test
1 parent 43504d3 commit 711f4a0

3 files changed

Lines changed: 96 additions & 4 deletions

File tree

deepmd/entrypoints/test.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from deepmd.common import (
1616
expand_sys_str,
17+
j_loader,
1718
)
1819
from deepmd.infer.deep_dipole import (
1920
DeepDipole,
@@ -41,6 +42,9 @@
4142
from deepmd.utils.data import (
4243
DeepmdData,
4344
)
45+
from deepmd.utils.data_system import (
46+
process_systems,
47+
)
4448
from deepmd.utils.weight_avg import (
4549
weighted_average,
4650
)
@@ -60,6 +64,7 @@ def test(
6064
model: str,
6165
system: str,
6266
datafile: str,
67+
input_json: Optional[str] = None,
6368
numb_test: int,
6469
rand_seed: Optional[int],
6570
shuffle_test: bool,
@@ -78,6 +83,8 @@ def test(
7883
system directory
7984
datafile : str
8085
the path to the list of systems to test
86+
input_json : Optional[str]
87+
the training input json file. Validation systems in this file will be used.
8188
numb_test : int
8289
munber of tests to do. 0 means all data.
8390
rand_seed : Optional[int]
@@ -101,7 +108,20 @@ def test(
101108
if numb_test == 0:
102109
# only float has inf, but should work for min
103110
numb_test = float("inf")
104-
if datafile is not None:
111+
if input_json is not None:
112+
jdata = j_loader(input_json)
113+
val_params = jdata.get("training", {}).get("validation_data", {})
114+
validation = val_params.get("systems")
115+
if not validation:
116+
raise RuntimeError("No validation data found in input json")
117+
root = Path(input_json).parent
118+
if isinstance(validation, str):
119+
validation = str((root / Path(validation)).resolve())
120+
else:
121+
validation = [str((root / Path(ss)).resolve()) for ss in validation]
122+
patterns = val_params.get("rglob_patterns", None)
123+
all_sys = process_systems(validation, patterns=patterns)
124+
elif datafile is not None:
105125
with open(datafile) as datalist:
106126
all_sys = datalist.read().splitlines()
107127
else:

deepmd/main.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,13 @@ def main_parser() -> argparse.ArgumentParser:
371371
type=str,
372372
help="The path to the datafile, each line of which is a path to one data system.",
373373
)
374+
parser_tst_subgroup.add_argument(
375+
"-i",
376+
"--input-json",
377+
default=None,
378+
type=str,
379+
help="The training input json file. Validation data in the training script will be used for testing.",
380+
)
374381
parser_tst.add_argument(
375382
"-S",
376383
"--set-prefix",

source/tests/pt/test_dp_test.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131

3232
class DPTest:
33-
def test_dp_test_1_frame(self) -> None:
33+
def _run_dp_test(self, use_input_json: bool, numb_test: int = 0) -> None:
3434
trainer = get_trainer(deepcopy(self.config))
3535
with torch.device("cpu"):
3636
input_dict, label_dict, _ = trainer.get_data(is_train=False)
@@ -44,12 +44,16 @@ def test_dp_test_1_frame(self) -> None:
4444
model = torch.jit.script(trainer.model)
4545
tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth")
4646
torch.jit.save(model, tmp_model.name)
47+
val_sys = self.config["training"]["validation_data"]["systems"]
48+
if isinstance(val_sys, list):
49+
val_sys = val_sys[0]
4750
dp_test(
4851
model=tmp_model.name,
49-
system=self.config["training"]["validation_data"]["systems"][0],
52+
system=val_sys,
5053
datafile=None,
54+
input_json=self.input_json if use_input_json else None,
5155
set_prefix="set",
52-
numb_test=0,
56+
numb_test=numb_test,
5357
rand_seed=None,
5458
shuffle_test=False,
5559
detail_file=self.detail_file,
@@ -93,6 +97,12 @@ def test_dp_test_1_frame(self) -> None:
9397
).reshape(-1, 3),
9498
)
9599

100+
def test_dp_test_1_frame(self) -> None:
101+
self._run_dp_test(False)
102+
103+
def test_dp_test_input_json(self) -> None:
104+
self._run_dp_test(True)
105+
96106
def tearDown(self) -> None:
97107
for f in os.listdir("."):
98108
if f.startswith("model") and f.endswith(".pt"):
@@ -140,6 +150,61 @@ def setUp(self) -> None:
140150
json.dump(self.config, fp, indent=4)
141151

142152

153+
class TestDPTestSeARglob(unittest.TestCase):
154+
def setUp(self) -> None:
155+
self.detail_file = "test_dp_test_ener_rglob_detail"
156+
input_json = str(Path(__file__).parent / "water/se_atten.json")
157+
with open(input_json) as f:
158+
self.config = json.load(f)
159+
self.config["training"]["numb_steps"] = 1
160+
self.config["training"]["save_freq"] = 1
161+
data_file = [str(Path(__file__).parent / "water/data/single")]
162+
self.config["training"]["training_data"]["systems"] = data_file
163+
root_dir = str(Path(__file__).parent)
164+
self.config["training"]["validation_data"]["systems"] = root_dir
165+
self.config["training"]["validation_data"]["rglob_patterns"] = [
166+
"water/data/single"
167+
]
168+
self.config["model"] = deepcopy(model_se_e2_a)
169+
self.input_json = "test_dp_test_rglob.json"
170+
with open(self.input_json, "w") as fp:
171+
json.dump(self.config, fp, indent=4)
172+
173+
def test_dp_test_input_json_rglob(self) -> None:
174+
trainer = get_trainer(deepcopy(self.config))
175+
with torch.device("cpu"):
176+
input_dict, _, _ = trainer.get_data(is_train=False)
177+
input_dict.pop("spin", None)
178+
model = torch.jit.script(trainer.model)
179+
tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth")
180+
torch.jit.save(model, tmp_model.name)
181+
dp_test(
182+
model=tmp_model.name,
183+
system=self.config["training"]["validation_data"]["systems"],
184+
datafile=None,
185+
input_json=self.input_json,
186+
set_prefix="set",
187+
numb_test=1,
188+
rand_seed=None,
189+
shuffle_test=False,
190+
detail_file=self.detail_file,
191+
atomic=False,
192+
)
193+
os.unlink(tmp_model.name)
194+
self.assertTrue(os.path.exists(self.detail_file + ".e.out"))
195+
196+
def tearDown(self) -> None:
197+
for f in os.listdir("."):
198+
if f.startswith("model") and f.endswith(".pt"):
199+
os.remove(f)
200+
if f.startswith(self.detail_file):
201+
os.remove(f)
202+
if f in ["lcurve.out", self.input_json]:
203+
os.remove(f)
204+
if f in ["stat_files"]:
205+
shutil.rmtree(f)
206+
207+
143208
class TestDPTestPropertySeA(unittest.TestCase):
144209
def setUp(self) -> None:
145210
self.detail_file = "test_dp_test_property_detail"

0 commit comments

Comments
 (0)