Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,30 @@ def expand_sys_str(root_dir: Union[str, Path]) -> list[str]:
return matches


def rglob_sys_str(root_dir: str, patterns: list[str]) -> list[str]:
"""Recursively iterate over directories taking those that contain `type.raw` file.

Parameters
----------
root_dir : str, Path
starting directory
patterns : list[str]
list of glob patterns to match directories

Returns
-------
list[str]
list of string pointing to system directories
"""
root_dir = Path(root_dir)
matches = []
for pattern in patterns:
matches.extend(
[str(d) for d in root_dir.rglob(pattern) if (d / "type.raw").is_file()]
Comment thread
anyangml marked this conversation as resolved.
)
return list(set(matches)) # remove duplicates
Comment thread
anyangml marked this conversation as resolved.


def get_np_precision(precision: "_PRECISION") -> np.dtype:
"""Get numpy precision constant from string.

Expand Down
6 changes: 4 additions & 2 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,11 @@ def prepare_trainer_input_single(
validation_dataset_params["systems"] if validation_dataset_params else None
)
training_systems = training_dataset_params["systems"]
training_systems = process_systems(training_systems)
trn_patterns = training_dataset_params.get("rglob_patterns", None)
training_systems = process_systems(training_systems, patterns=trn_patterns)
Comment thread
anyangml marked this conversation as resolved.
if validation_systems is not None:
validation_systems = process_systems(validation_systems)
val_patterns = validation_dataset_params.get("rglob_patterns", None)
validation_systems = process_systems(validation_systems, val_patterns)
Comment thread
anyangml marked this conversation as resolved.

# stat files
stat_file_path_single = data_dict_single.get("stat_file", None)
Expand Down
20 changes: 20 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2926,6 +2926,9 @@ def training_data_args(): # ! added by Ziyao: new specification style for data
"This key can be provided with a list that specifies the systems, or be provided with a string "
"by which the prefix of all systems are given and the list of the systems is automatically generated."
)
doc_patterns = (
"The customized patterns used in `rglob` to collect all training systems. "
)
doc_batch_size = f'This key can be \n\n\
- list: the length of which is the same as the {link_sys}. The batch size of each system is given by the elements of the list.\n\n\
- int: all {link_sys} use the same batch size.\n\n\
Expand All @@ -2949,6 +2952,13 @@ def training_data_args(): # ! added by Ziyao: new specification style for data
Argument(
"systems", [list[str], str], optional=False, default=".", doc=doc_systems
),
Argument(
"rglob_patterns",
[list[str]],
optional=True,
default=None,
doc=doc_patterns + doc_only_pt_supported,
),
Argument(
"batch_size",
[list[int], int, str],
Expand Down Expand Up @@ -2995,6 +3005,9 @@ def validation_data_args(): # ! added by Ziyao: new specification style for dat
"This key can be provided with a list that specifies the systems, or be provided with a string "
"by which the prefix of all systems are given and the list of the systems is automatically generated."
)
doc_patterns = (
"The customized patterns used in `rglob` to collect all validation systems. "
)
doc_batch_size = f'This key can be \n\n\
- list: the length of which is the same as the {link_sys}. The batch size of each system is given by the elements of the list.\n\n\
- int: all {link_sys} use the same batch size.\n\n\
Expand All @@ -3015,6 +3028,13 @@ def validation_data_args(): # ! added by Ziyao: new specification style for dat
Argument(
"systems", [list[str], str], optional=False, default=".", doc=doc_systems
),
Argument(
"rglob_patterns",
[list[str]],
optional=True,
default=None,
doc=doc_patterns + doc_only_pt_supported,
),
Argument(
"batch_size",
[list[int], int, str],
Expand Down
15 changes: 12 additions & 3 deletions deepmd/utils/data_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from deepmd.common import (
expand_sys_str,
make_default_mesh,
rglob_sys_str,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
Expand Down Expand Up @@ -730,7 +731,9 @@ def prob_sys_size_ext(keywords, nsystems, nbatch):
return sys_probs


def process_systems(systems: Union[str, list[str]]) -> list[str]:
def process_systems(
systems: Union[str, list[str]], patterns: Optional[list[str]] = None
) -> list[str]:
"""Process the user-input systems.

If it is a single directory, search for all the systems in the directory.
Expand All @@ -740,14 +743,19 @@ def process_systems(systems: Union[str, list[str]]) -> list[str]:
----------
systems : str or list of str
The user-input systems
patterns : list of str, optional
The patterns to match the systems, by default None

Returns
-------
list of str
The valid systems
"""
if isinstance(systems, str):
systems = expand_sys_str(systems)
if patterns is None:
systems = expand_sys_str(systems)
else:
systems = rglob_sys_str(systems, patterns)
elif isinstance(systems, list):
systems = systems.copy()
return systems
Expand Down Expand Up @@ -777,7 +785,8 @@ def get_data(
The data system
"""
systems = jdata["systems"]
systems = process_systems(systems)
rglob_patterns = jdata.get("rglob_patterns", None)
systems = process_systems(systems, patterns=rglob_patterns)

batch_size = jdata["batch_size"]
sys_probs = jdata.get("sys_probs", None)
Expand Down
23 changes: 23 additions & 0 deletions source/tests/pt/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,5 +516,28 @@ def tearDown(self) -> None:
shutil.rmtree(f)


class TestCustomizedRGLOB(unittest.TestCase, DPTrainTest):
def setUp(self) -> None:
input_json = str(Path(__file__).parent / "water/se_atten.json")
with open(input_json) as f:
self.config = json.load(f)
self.config["training"]["training_data"]["rglob_patterns"] = [
"water/data/data_*"
]
self.config["training"]["training_data"]["systems"] = str(Path(__file__).parent)
self.config["training"]["validation_data"]["rglob_patterns"] = [
"water/*/data_0"
]
self.config["training"]["validation_data"]["systems"] = str(
Path(__file__).parent
)
self.config["model"] = deepcopy(model_dpa1)
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1

def tearDown(self) -> None:
DPTrainTest.tearDown(self)


if __name__ == "__main__":
unittest.main()
Loading