Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 33 additions & 22 deletions deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@
def test(
*,
model: str,
system: str,
datafile: str,
input_json: Optional[str] = None,
use_train: bool = False,
system: Optional[str],
datafile: Optional[str],
train_json: Optional[str] = None,
valid_json: Optional[str] = None,
numb_test: int,
rand_seed: Optional[int],
shuffle_test: bool,
Expand All @@ -83,16 +83,16 @@ def test(
----------
model : str
path where model is stored
system : str
system : str, optional
system directory
datafile : str
datafile : str, optional
the path to the list of systems to test
input_json : Optional[str]
the training input.json file. Validation systems will be used if use_train is False.
use_train : bool
use training systems in the input.json file instead of validation systems
train_json : Optional[str]
Path to the input.json file provided via ``--train-data``. Training systems will be used for testing.
valid_json : Optional[str]
Path to the input.json file provided via ``--valid-data``. Validation systems will be used for testing.
numb_test : int
munber of tests to do. 0 means all data.
number of tests to do. 0 means all data.
rand_seed : Optional[int]
seed for random generator
shuffle_test : bool
Expand All @@ -114,30 +114,41 @@ def test(
if numb_test == 0:
# only float has inf, but should work for min
numb_test = float("inf")
if input_json is not None:
jdata = j_loader(input_json)
if train_json is not None:
jdata = j_loader(train_json)
jdata = update_deepmd_input(jdata)
data_key = "training_data" if use_train else "validation_data"
data_params = jdata.get("training", {}).get(data_key, {})
data_params = jdata.get("training", {}).get("training_data", {})
systems = data_params.get("systems")
if not systems:
raise RuntimeError(
f"No {'training' if use_train else 'validation'} data found in input json"
)
root = Path(input_json).parent
raise RuntimeError("No training data found in input json")
root = Path(train_json).parent
if isinstance(systems, str):
systems = str((root / Path(systems)).resolve())
else:
systems = [str((root / Path(ss)).resolve()) for ss in systems]
patterns = data_params.get("rglob_patterns", None)
all_sys = process_systems(systems, patterns=patterns)
elif valid_json is not None:
jdata = j_loader(valid_json)
jdata = update_deepmd_input(jdata)
data_params = jdata.get("training", {}).get("validation_data", {})
systems = data_params.get("systems")
if not systems:
raise RuntimeError("No validation data found in input json")
root = Path(valid_json).parent
if isinstance(systems, str):
systems = str((root / Path(systems)).resolve())
else:
systems = [str((root / Path(ss)).resolve()) for ss in systems]
patterns = data_params.get("rglob_patterns", None)
all_sys = process_systems(systems, patterns=patterns)
elif use_train:
raise RuntimeError("--train-data requires --input-json")
elif datafile is not None:
with open(datafile) as datalist:
all_sys = datalist.read().splitlines()
else:
elif system is not None:
all_sys = expand_sys_str(system)
else:
raise RuntimeError("No data source specified for testing")

if len(all_sys) == 0:
raise RuntimeError("Did not find valid system")
Expand Down
22 changes: 13 additions & 9 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,18 +372,22 @@ def main_parser() -> argparse.ArgumentParser:
help="The path to the datafile, each line of which is a path to one data system.",
)
parser_tst_subgroup.add_argument(
"-i",
"--input-json",
"--train-data",
dest="train_json",
default=None,
type=str,
help="The training input json file. Validation data in the training script will be used for testing.",
help=(
"The input json file. Training data in the file will be used for testing."
),
)
parser_tst.add_argument(
"--train-data",
"--train",
action="store_true",
dest="use_train",
help="Use training data in the input json file instead of validation data.",
parser_tst_subgroup.add_argument(
"--valid-data",
dest="valid_json",
default=None,
type=str,
help=(
"The input json file. Validation data in the file will be used for testing."
),
)
parser_tst.add_argument(
"-S",
Expand Down
26 changes: 26 additions & 0 deletions source/tests/common/test_argument_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,32 @@ def test_parser_test(self) -> None:

self.run_test(command="test", mapping=ARGS)

def test_parser_test_train_data(self) -> None:
"""Test test subparser with train-data."""
ARGS = {
"--model": {"type": str, "value": "MODEL.PB"},
"--train-data": {
"type": (str, type(None)),
"value": "INPUT.JSON",
"dest": "train_json",
},
}

self.run_test(command="test", mapping=ARGS)

def test_parser_test_valid_data(self) -> None:
"""Test test subparser with valid-data."""
ARGS = {
"--model": {"type": str, "value": "MODEL.PB"},
"--valid-data": {
"type": (str, type(None)),
"value": "INPUT.JSON",
"dest": "valid_json",
},
}

self.run_test(command="test", mapping=ARGS)

def test_parser_compress(self) -> None:
"""Test compress subparser."""
ARGS = {
Expand Down
15 changes: 7 additions & 8 deletions source/tests/pt/test_dp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ def _run_dp_test(
val_sys = val_sys[0]
dp_test(
model=tmp_model.name,
system=val_sys,
system=None if use_input_json else val_sys,
datafile=None,
input_json=self.input_json if use_input_json else None,
use_train=use_train,
train_json=self.input_json if use_input_json and use_train else None,
valid_json=self.input_json if use_input_json and not use_train else None,
set_prefix="set",
numb_test=numb_test,
rand_seed=None,
Expand Down Expand Up @@ -191,9 +191,9 @@ def test_dp_test_input_json_rglob(self) -> None:
torch.jit.save(model, tmp_model.name)
dp_test(
model=tmp_model.name,
system=self.config["training"]["validation_data"]["systems"],
system=None,
datafile=None,
input_json=self.input_json,
valid_json=self.input_json,
set_prefix="set",
numb_test=1,
rand_seed=None,
Expand Down Expand Up @@ -246,10 +246,9 @@ def test_dp_test_input_json_rglob_train(self) -> None:
torch.jit.save(model, tmp_model.name)
dp_test(
model=tmp_model.name,
system=self.config["training"]["validation_data"]["systems"],
system=None,
datafile=None,
input_json=self.input_json,
use_train=True,
train_json=self.input_json,
set_prefix="set",
numb_test=1,
rand_seed=None,
Expand Down
Loading