Skip to content

Commit 7cd6be1

Browse files
committed
refactor: streamline process_systems function and improve input handling
1 parent 583a435 commit 7cd6be1

1 file changed

Lines changed: 21 additions & 17 deletions

File tree

deepmd/utils/data_system.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ def process_systems(
790790
"""Process the user-input systems.
791791
792792
If it is a single directory, search for all the systems in the directory.
793-
If it is a list, each item can be either a system path or a directory to search.
793+
If it is a list, each item in the list is treated as a directory to search.
794794
Check if the systems are valid.
795795
796796
Parameters
@@ -802,27 +802,31 @@ def process_systems(
802802
803803
Returns
804804
-------
805-
list of str
805+
result_systems: list of str
806806
The valid systems
807807
"""
808+
# Normalize input to a list of paths to search
808809
if isinstance(systems, str):
809-
if patterns is None:
810-
systems = expand_sys_str(systems)
811-
else:
812-
systems = rglob_sys_str(systems, patterns)
810+
search_paths = [systems]
813811
elif isinstance(systems, list):
814-
result_systems = []
815-
for system in systems:
816-
if isinstance(system, str):
817-
# Try to expand as directory
818-
expanded = expand_sys_str(system)
819-
result_systems.extend(expanded)
820-
else:
821-
result_systems.append(system)
822-
systems = result_systems
812+
search_paths = systems
823813
else:
824-
raise ValueError(f"Invalid systems: {systems}")
825-
return systems
814+
# Handle unsupported input types
815+
raise ValueError(
816+
f"Invalid systems type: {type(systems)}. Must be str or list[str]."
817+
)
818+
819+
# Iterate over the search_paths list and apply
820+
result_systems = []
821+
for path in search_paths:
822+
if patterns is None:
823+
expanded_paths = expand_sys_str(path)
824+
else:
825+
expanded_paths = rglob_sys_str(path, patterns)
826+
827+
result_systems.extend(expanded_paths)
828+
829+
return result_systems
826830

827831

828832
def get_data(

0 commit comments

Comments
 (0)