Skip to content

Commit 4e1cbb2

Browse files
committed
feat: allow testing with training data
1 parent 43504d3 commit 4e1cbb2

3 files changed

Lines changed: 178 additions & 5 deletions

File tree

deepmd/entrypoints/test.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from deepmd.common import (
1616
expand_sys_str,
17+
j_loader,
1718
)
1819
from deepmd.infer.deep_dipole import (
1920
DeepDipole,
@@ -41,6 +42,9 @@
4142
from deepmd.utils.data import (
4243
DeepmdData,
4344
)
45+
from deepmd.utils.data_system import (
46+
process_systems,
47+
)
4448
from deepmd.utils.weight_avg import (
4549
weighted_average,
4650
)
@@ -60,6 +64,8 @@ def test(
6064
model: str,
6165
system: str,
6266
datafile: str,
67+
input_json: Optional[str] = None,
68+
use_train: bool = False,
6369
numb_test: int,
6470
rand_seed: Optional[int],
6571
shuffle_test: bool,
@@ -78,6 +84,10 @@ def test(
7884
system directory
7985
datafile : str
8086
the path to the list of systems to test
87+
input_json : Optional[str]
88+
the training input json file. Validation systems in this file will be used.
89+
use_train : bool
90+
use training systems in the input json file instead of validation systems
8191
numb_test : int
8292
munber of tests to do. 0 means all data.
8393
rand_seed : Optional[int]
@@ -101,7 +111,25 @@ def test(
101111
if numb_test == 0:
102112
# only float has inf, but should work for min
103113
numb_test = float("inf")
104-
if datafile is not None:
114+
if input_json is not None:
115+
jdata = j_loader(input_json)
116+
data_key = "training_data" if use_train else "validation_data"
117+
data_params = jdata.get("training", {}).get(data_key, {})
118+
systems = data_params.get("systems")
119+
if not systems:
120+
raise RuntimeError(
121+
f"No {'training' if use_train else 'validation'} data found in input json"
122+
)
123+
root = Path(input_json).parent
124+
if isinstance(systems, str):
125+
systems = str((root / Path(systems)).resolve())
126+
else:
127+
systems = [str((root / Path(ss)).resolve()) for ss in systems]
128+
patterns = data_params.get("rglob_patterns", None)
129+
all_sys = process_systems(systems, patterns=patterns)
130+
elif use_train:
131+
raise RuntimeError("--train-data requires --input-json")
132+
elif datafile is not None:
105133
with open(datafile) as datalist:
106134
all_sys = datalist.read().splitlines()
107135
else:

deepmd/main.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def main_parser() -> argparse.ArgumentParser:
356356
type=str,
357357
help="Frozen model file (prefix) to import. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth.",
358358
)
359-
parser_tst_subgroup = parser_tst.add_mutually_exclusive_group()
359+
parser_tst_subgroup = parser_tst.add_mutually_exclusive_group(required=True)
360360
parser_tst_subgroup.add_argument(
361361
"-s",
362362
"--system",
@@ -371,6 +371,19 @@ def main_parser() -> argparse.ArgumentParser:
371371
type=str,
372372
help="The path to the datafile, each line of which is a path to one data system.",
373373
)
374+
parser_tst_subgroup.add_argument(
375+
"-i",
376+
"--input-json",
377+
default=None,
378+
type=str,
379+
help="The training input json file. Validation data in the training script will be used for testing.",
380+
)
381+
parser_tst.add_argument(
382+
"--train-data",
383+
action="store_true",
384+
dest="use_train",
385+
help="Use training data in the input json file instead of validation data.",
386+
)
374387
parser_tst.add_argument(
375388
"-S",
376389
"--set-prefix",

source/tests/pt/test_dp_test.py

Lines changed: 135 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030

3131

3232
class DPTest:
33-
def test_dp_test_1_frame(self) -> None:
33+
def _run_dp_test(
34+
self, use_input_json: bool, numb_test: int = 0, use_train: bool = False
35+
) -> None:
3436
trainer = get_trainer(deepcopy(self.config))
3537
with torch.device("cpu"):
3638
input_dict, label_dict, _ = trainer.get_data(is_train=False)
@@ -44,12 +46,17 @@ def test_dp_test_1_frame(self) -> None:
4446
model = torch.jit.script(trainer.model)
4547
tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth")
4648
torch.jit.save(model, tmp_model.name)
49+
val_sys = self.config["training"]["validation_data"]["systems"]
50+
if isinstance(val_sys, list):
51+
val_sys = val_sys[0]
4752
dp_test(
4853
model=tmp_model.name,
49-
system=self.config["training"]["validation_data"]["systems"][0],
54+
system=val_sys,
5055
datafile=None,
56+
input_json=self.input_json if use_input_json else None,
57+
use_train=use_train,
5158
set_prefix="set",
52-
numb_test=0,
59+
numb_test=numb_test,
5360
rand_seed=None,
5461
shuffle_test=False,
5562
detail_file=self.detail_file,
@@ -93,6 +100,20 @@ def test_dp_test_1_frame(self) -> None:
93100
).reshape(-1, 3),
94101
)
95102

103+
def test_dp_test_1_frame(self) -> None:
104+
self._run_dp_test(False)
105+
106+
def test_dp_test_input_json(self) -> None:
107+
self._run_dp_test(True)
108+
109+
def test_dp_test_input_json_train(self) -> None:
110+
with open(self.input_json) as f:
111+
cfg = json.load(f)
112+
cfg["training"]["validation_data"]["systems"] = ["non-existent"]
113+
with open(self.input_json, "w") as f:
114+
json.dump(cfg, f, indent=4)
115+
self._run_dp_test(True, use_train=True)
116+
96117
def tearDown(self) -> None:
97118
for f in os.listdir("."):
98119
if f.startswith("model") and f.endswith(".pt"):
@@ -140,6 +161,117 @@ def setUp(self) -> None:
140161
json.dump(self.config, fp, indent=4)
141162

142163

164+
class TestDPTestSeARglob(unittest.TestCase):
165+
def setUp(self) -> None:
166+
self.detail_file = "test_dp_test_ener_rglob_detail"
167+
input_json = str(Path(__file__).parent / "water/se_atten.json")
168+
with open(input_json) as f:
169+
self.config = json.load(f)
170+
self.config["training"]["numb_steps"] = 1
171+
self.config["training"]["save_freq"] = 1
172+
data_file = [str(Path(__file__).parent / "water/data/single")]
173+
self.config["training"]["training_data"]["systems"] = data_file
174+
root_dir = str(Path(__file__).parent)
175+
self.config["training"]["validation_data"]["systems"] = root_dir
176+
self.config["training"]["validation_data"]["rglob_patterns"] = [
177+
"water/data/single"
178+
]
179+
self.config["model"] = deepcopy(model_se_e2_a)
180+
self.input_json = "test_dp_test_rglob.json"
181+
with open(self.input_json, "w") as fp:
182+
json.dump(self.config, fp, indent=4)
183+
184+
def test_dp_test_input_json_rglob(self) -> None:
185+
trainer = get_trainer(deepcopy(self.config))
186+
with torch.device("cpu"):
187+
input_dict, _, _ = trainer.get_data(is_train=False)
188+
input_dict.pop("spin", None)
189+
model = torch.jit.script(trainer.model)
190+
tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth")
191+
torch.jit.save(model, tmp_model.name)
192+
dp_test(
193+
model=tmp_model.name,
194+
system=self.config["training"]["validation_data"]["systems"],
195+
datafile=None,
196+
input_json=self.input_json,
197+
set_prefix="set",
198+
numb_test=1,
199+
rand_seed=None,
200+
shuffle_test=False,
201+
detail_file=self.detail_file,
202+
atomic=False,
203+
)
204+
os.unlink(tmp_model.name)
205+
self.assertTrue(os.path.exists(self.detail_file + ".e.out"))
206+
207+
def tearDown(self) -> None:
208+
for f in os.listdir("."):
209+
if f.startswith("model") and f.endswith(".pt"):
210+
os.remove(f)
211+
if f.startswith(self.detail_file):
212+
os.remove(f)
213+
if f in ["lcurve.out", self.input_json]:
214+
os.remove(f)
215+
if f in ["stat_files"]:
216+
shutil.rmtree(f)
217+
218+
219+
class TestDPTestSeARglobTrain(unittest.TestCase):
220+
def setUp(self) -> None:
221+
self.detail_file = "test_dp_test_ener_rglob_train_detail"
222+
input_json = str(Path(__file__).parent / "water/se_atten.json")
223+
with open(input_json) as f:
224+
self.config = json.load(f)
225+
self.config["training"]["numb_steps"] = 1
226+
self.config["training"]["save_freq"] = 1
227+
root_dir = str(Path(__file__).parent)
228+
self.config["training"]["training_data"]["systems"] = root_dir
229+
self.config["training"]["training_data"]["rglob_patterns"] = [
230+
"water/data/single"
231+
]
232+
data_file = [str(Path(__file__).parent / "water/data/single")]
233+
self.config["training"]["validation_data"]["systems"] = data_file
234+
self.config["model"] = deepcopy(model_se_e2_a)
235+
self.input_json = "test_dp_test_rglob_train.json"
236+
with open(self.input_json, "w") as fp:
237+
json.dump(self.config, fp, indent=4)
238+
239+
def test_dp_test_input_json_rglob_train(self) -> None:
240+
trainer = get_trainer(deepcopy(self.config))
241+
with torch.device("cpu"):
242+
input_dict, _, _ = trainer.get_data(is_train=False)
243+
input_dict.pop("spin", None)
244+
model = torch.jit.script(trainer.model)
245+
tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth")
246+
torch.jit.save(model, tmp_model.name)
247+
dp_test(
248+
model=tmp_model.name,
249+
system=self.config["training"]["validation_data"]["systems"],
250+
datafile=None,
251+
input_json=self.input_json,
252+
use_train=True,
253+
set_prefix="set",
254+
numb_test=1,
255+
rand_seed=None,
256+
shuffle_test=False,
257+
detail_file=self.detail_file,
258+
atomic=False,
259+
)
260+
os.unlink(tmp_model.name)
261+
self.assertTrue(os.path.exists(self.detail_file + ".e.out"))
262+
263+
def tearDown(self) -> None:
264+
for f in os.listdir("."):
265+
if f.startswith("model") and f.endswith(".pt"):
266+
os.remove(f)
267+
if f.startswith(self.detail_file):
268+
os.remove(f)
269+
if f in ["lcurve.out", self.input_json]:
270+
os.remove(f)
271+
if f in ["stat_files"]:
272+
shutil.rmtree(f)
273+
274+
143275
class TestDPTestPropertySeA(unittest.TestCase):
144276
def setUp(self) -> None:
145277
self.detail_file = "test_dp_test_property_detail"

0 commit comments

Comments
 (0)