diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index b4c88c5715..ee975e26c3 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -65,10 +65,10 @@ def test( *, model: str, - system: str, - datafile: str, - input_json: Optional[str] = None, - use_train: bool = False, + system: Optional[str], + datafile: Optional[str], + train_json: Optional[str] = None, + valid_json: Optional[str] = None, numb_test: int, rand_seed: Optional[int], shuffle_test: bool, @@ -83,16 +83,16 @@ def test( ---------- model : str path where model is stored - system : str + system : str, optional system directory - datafile : str + datafile : str, optional the path to the list of systems to test - input_json : Optional[str] - the training input.json file. Validation systems will be used if use_train is False. - use_train : bool - use training systems in the input.json file instead of validation systems + train_json : Optional[str] + Path to the input.json file provided via ``--train-data``. Training systems will be used for testing. + valid_json : Optional[str] + Path to the input.json file provided via ``--valid-data``. Validation systems will be used for testing. numb_test : int - munber of tests to do. 0 means all data. + number of tests to do. 0 means all data. rand_seed : Optional[int] seed for random generator shuffle_test : bool @@ -114,30 +114,41 @@ def test( if numb_test == 0: # only float has inf, but should work for min numb_test = float("inf") - if input_json is not None: - jdata = j_loader(input_json) + if train_json is not None: + jdata = j_loader(train_json) jdata = update_deepmd_input(jdata) - data_key = "training_data" if use_train else "validation_data" - data_params = jdata.get("training", {}).get(data_key, {}) + data_params = jdata.get("training", {}).get("training_data", {}) systems = data_params.get("systems") if not systems: - raise RuntimeError( - f"No {'training' if use_train else 'validation'} data found in input json" - ) - root = Path(input_json).parent + raise RuntimeError("No training data found in input json") + root = Path(train_json).parent + if isinstance(systems, str): + systems = str((root / Path(systems)).resolve()) + else: + systems = [str((root / Path(ss)).resolve()) for ss in systems] + patterns = data_params.get("rglob_patterns", None) + all_sys = process_systems(systems, patterns=patterns) + elif valid_json is not None: + jdata = j_loader(valid_json) + jdata = update_deepmd_input(jdata) + data_params = jdata.get("training", {}).get("validation_data", {}) + systems = data_params.get("systems") + if not systems: + raise RuntimeError("No validation data found in input json") + root = Path(valid_json).parent if isinstance(systems, str): systems = str((root / Path(systems)).resolve()) else: systems = [str((root / Path(ss)).resolve()) for ss in systems] patterns = data_params.get("rglob_patterns", None) all_sys = process_systems(systems, patterns=patterns) - elif use_train: - raise RuntimeError("--train-data requires --input-json") elif datafile is not None: with open(datafile) as datalist: all_sys = datalist.read().splitlines() - else: + elif system is not None: all_sys = expand_sys_str(system) + else: + raise RuntimeError("No data source specified for testing") if len(all_sys) == 0: raise RuntimeError("Did not find valid system") diff --git a/deepmd/main.py b/deepmd/main.py index 4a748a0f35..389ee10dfc 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -372,18 +372,22 @@ def main_parser() -> argparse.ArgumentParser: help="The path to the datafile, each line of which is a path to one data system.", ) parser_tst_subgroup.add_argument( - "-i", - "--input-json", + "--train-data", + dest="train_json", default=None, type=str, - help="The training input json file. Validation data in the training script will be used for testing.", + help=( + "The input json file. Training data in the file will be used for testing." + ), ) - parser_tst.add_argument( - "--train-data", - "--train", - action="store_true", - dest="use_train", - help="Use training data in the input json file instead of validation data.", + parser_tst_subgroup.add_argument( + "--valid-data", + dest="valid_json", + default=None, + type=str, + help=( + "The input json file. Validation data in the file will be used for testing." + ), ) parser_tst.add_argument( "-S", diff --git a/source/tests/common/test_argument_parser.py b/source/tests/common/test_argument_parser.py index 4e39df8659..4aebb7dafc 100644 --- a/source/tests/common/test_argument_parser.py +++ b/source/tests/common/test_argument_parser.py @@ -322,6 +322,32 @@ def test_parser_test(self) -> None: self.run_test(command="test", mapping=ARGS) + def test_parser_test_train_data(self) -> None: + """Test test subparser with train-data.""" + ARGS = { + "--model": {"type": str, "value": "MODEL.PB"}, + "--train-data": { + "type": (str, type(None)), + "value": "INPUT.JSON", + "dest": "train_json", + }, + } + + self.run_test(command="test", mapping=ARGS) + + def test_parser_test_valid_data(self) -> None: + """Test test subparser with valid-data.""" + ARGS = { + "--model": {"type": str, "value": "MODEL.PB"}, + "--valid-data": { + "type": (str, type(None)), + "value": "INPUT.JSON", + "dest": "valid_json", + }, + } + + self.run_test(command="test", mapping=ARGS) + def test_parser_compress(self) -> None: """Test compress subparser.""" ARGS = { diff --git a/source/tests/pt/test_dp_test.py b/source/tests/pt/test_dp_test.py index ca042fc5a3..5a8674e276 100644 --- a/source/tests/pt/test_dp_test.py +++ b/source/tests/pt/test_dp_test.py @@ -51,10 +51,10 @@ def _run_dp_test( val_sys = val_sys[0] dp_test( model=tmp_model.name, - system=val_sys, + system=None if use_input_json else val_sys, datafile=None, - input_json=self.input_json if use_input_json else None, - use_train=use_train, + train_json=self.input_json if use_input_json and use_train else None, + valid_json=self.input_json if use_input_json and not use_train else None, set_prefix="set", numb_test=numb_test, rand_seed=None, @@ -191,9 +191,9 @@ def test_dp_test_input_json_rglob(self) -> None: torch.jit.save(model, tmp_model.name) dp_test( model=tmp_model.name, - system=self.config["training"]["validation_data"]["systems"], + system=None, datafile=None, - input_json=self.input_json, + valid_json=self.input_json, set_prefix="set", numb_test=1, rand_seed=None, @@ -246,10 +246,9 @@ def test_dp_test_input_json_rglob_train(self) -> None: torch.jit.save(model, tmp_model.name) dp_test( model=tmp_model.name, - system=self.config["training"]["validation_data"]["systems"], + system=None, datafile=None, - input_json=self.input_json, - use_train=True, + train_json=self.input_json, set_prefix="set", numb_test=1, rand_seed=None,