Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,27 @@ def expand_sys_str(root_dir: Union[str, Path]) -> list[str]:
matches.append(str(root_dir))
return matches

def rglob_sys_str(root_dir: Union[str, Path], patterns: list[str]) -> list[str]:
    """Recursively collect system directories under `root_dir` matching glob patterns.

    A path is kept when it matches at least one of `patterns` (searched
    recursively with ``Path.rglob``) and contains a `type.raw` file.

    Parameters
    ----------
    root_dir : str, Path
        starting directory
    patterns : list[str]
        list of glob patterns to match directories

    Returns
    -------
    list[str]
        list of string pointing to system directories; each directory is
        listed once, in the order it was first matched
    """
    root_dir = Path(root_dir)
    # Overlapping patterns (e.g. "*" and "sys_*") yield the same directory
    # several times; dict keys drop the repeats while keeping first-seen order.
    matches = dict.fromkeys(
        str(d)
        for pattern in patterns
        for d in root_dir.rglob(pattern)
        if (d / "type.raw").is_file()
    )
    return list(matches)
Comment thread
anyangml marked this conversation as resolved.
Outdated


def get_np_precision(precision: "_PRECISION") -> np.dtype:
"""Get numpy precision constant from string.
Expand Down
6 changes: 4 additions & 2 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,11 @@ def prepare_trainer_input_single(
validation_dataset_params["systems"] if validation_dataset_params else None
)
training_systems = training_dataset_params["systems"]
training_systems = process_systems(training_systems)
trn_patterns = training_dataset_params.get("rglob_patterns", None)
training_systems = process_systems(training_systems, patterns=trn_patterns)
Comment thread
anyangml marked this conversation as resolved.
if validation_systems is not None:
validation_systems = process_systems(validation_systems)
val_patterns = validation_dataset_params.get("rglob_patterns", None)
validation_systems = process_systems(validation_systems, val_patterns)
Comment thread
anyangml marked this conversation as resolved.

# stat files
stat_file_path_single = data_dict_single.get("stat_file", None)
Expand Down
10 changes: 8 additions & 2 deletions deepmd/utils/data_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import deepmd.utils.random as dp_random
from deepmd.common import (
expand_sys_str,
rglob_sys_str,
make_default_mesh,
)
from deepmd.env import (
Expand Down Expand Up @@ -730,7 +731,7 @@ def prob_sys_size_ext(keywords, nsystems, nbatch):
return sys_probs


def process_systems(
    systems: Union[str, list[str]], patterns: Optional[list[str]] = None
) -> list[str]:
    """Process the user-input systems.

    If it is a single directory, search for all the systems in the directory.
    If it is a list, a shallow copy of it is returned and `patterns` is not
    used.

    Parameters
    ----------
    systems : str or list of str
        The user-input systems
    patterns : list of str, optional
        The patterns to match the systems, by default None

    Returns
    -------
    list of str
        The valid systems
    """
    if isinstance(systems, str):
        # Exactly one search strategy runs on the root directory: the default
        # expansion when no patterns are given, the pattern-based one otherwise.
        if patterns is None:
            systems = expand_sys_str(systems)
        else:
            systems = rglob_sys_str(systems, patterns)
    elif isinstance(systems, list):
        # Copy so that callers mutating the result do not alter the user input.
        systems = systems.copy()
    return systems
Expand Down
Loading