Skip to content

Commit 108393f

Browse files
Copilotanyangml
andcommitted
feat(data): support HDF5 multisystem files in training data lists
Co-authored-by: anyangml <137014849+anyangml@users.noreply.github.com>
1 parent be740ab commit 108393f

4 files changed

Lines changed: 358 additions & 6 deletions

File tree

deepmd/pd/utils/dataloader.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,22 @@ def __init__(
9191
if seed is not None:
9292
setup_seed(seed)
9393
if isinstance(systems, str):
94-
with h5py.File(systems) as file:
95-
systems = [os.path.join(systems, item) for item in file.keys()]
94+
# Check if this is a multisystem HDF5 file that should be expanded
95+
try:
96+
with h5py.File(systems, "r") as file:
97+
# Check if this looks like a single system (has type.raw and set.* groups)
98+
has_type_raw = "type.raw" in file
99+
has_sets = any(key.startswith("set.") for key in file.keys())
100+
101+
if has_type_raw and has_sets:
102+
# This is a single system HDF5 file, don't expand
103+
systems = [systems]
104+
else:
105+
# This might be a multisystem file, expand it
106+
systems = [f"{systems}#{item}" for item in file.keys()]
107+
except OSError:
108+
# If we can't read as HDF5, treat as regular path
109+
systems = [systems]
96110

97111
self.systems: list[DeepmdDataSetForLoader] = []
98112
if len(systems) >= 100:

deepmd/pt/utils/dataloader.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import logging
3-
import os
43
from multiprocessing.dummy import (
54
Pool,
65
)
@@ -89,8 +88,22 @@ def __init__(
8988
if seed is not None:
9089
setup_seed(seed)
9190
if isinstance(systems, str):
92-
with h5py.File(systems) as file:
93-
systems = [os.path.join(systems, item) for item in file.keys()]
91+
# Check if this is a multisystem HDF5 file that should be expanded
92+
try:
93+
with h5py.File(systems, "r") as file:
94+
# Check if this looks like a single system (has type.raw and set.* groups)
95+
has_type_raw = "type.raw" in file
96+
has_sets = any(key.startswith("set.") for key in file.keys())
97+
98+
if has_type_raw and has_sets:
99+
# This is a single system HDF5 file, don't expand
100+
systems = [systems]
101+
else:
102+
# This might be a multisystem file, expand it
103+
systems = [f"{systems}#{item}" for item in file.keys()]
104+
except OSError:
105+
# If we can't read as HDF5, treat as regular path
106+
systems = [systems]
94107

95108
def construct_dataset(system: str) -> DeepmdDataSetForLoader:
96109
return DeepmdDataSetForLoader(

deepmd/utils/data_system.py

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import collections
33
import logging
4+
import os
45
import warnings
56
from functools import (
67
cached_property,
@@ -11,6 +12,7 @@
1112
Union,
1213
)
1314

15+
import h5py
1416
import numpy as np
1517

1618
import deepmd.utils.random as dp_random
@@ -790,6 +792,7 @@ def process_systems(
790792
"""Process the user-input systems.
791793
792794
If it is a single directory, search for all the systems in the directory.
795+
If it's a list, handle HDF5 files by expanding their internal systems.
793796
Check if the systems are valid.
794797
795798
Parameters
@@ -810,10 +813,143 @@ def process_systems(
810813
else:
811814
systems = rglob_sys_str(systems, patterns)
812815
elif isinstance(systems, list):
813-
systems = systems.copy()
816+
expanded_systems = []
817+
for system in systems:
818+
# Check if this is an HDF5 file without explicit system specification
819+
if _is_hdf5_file(system) and "#" not in system:
820+
# Only expand if it's a multisystem HDF5 file
821+
if _is_hdf5_multisystem(system):
822+
# Expand HDF5 file to include all systems within it
823+
try:
824+
with h5py.File(system, "r") as file:
825+
for key in file.keys():
826+
if isinstance(file[key], h5py.Group):
827+
# Check if this group looks like a system
828+
group = file[key]
829+
group_has_type = "type.raw" in group
830+
group_has_sets = any(
831+
subkey.startswith("set.")
832+
for subkey in group.keys()
833+
)
834+
if group_has_type and group_has_sets:
835+
expanded_systems.append(f"{system}#{key}")
836+
except OSError as e:
837+
log.warning(f"Could not read HDF5 file {system}: {e}")
838+
# If we can't read as HDF5, treat as regular system
839+
expanded_systems.append(system)
840+
else:
841+
# Single system HDF5 file, don't expand
842+
expanded_systems.append(system)
843+
else:
844+
# Regular system or HDF5 with explicit system specification
845+
expanded_systems.append(system)
846+
systems = expanded_systems
814847
return systems
815848

816849

850+
def _is_hdf5_file(path: str) -> bool:
851+
"""Check if a path points to an HDF5 file.
852+
853+
Parameters
854+
----------
855+
path : str
856+
Path to check
857+
858+
Returns
859+
-------
860+
bool
861+
True if the path is an HDF5 file
862+
"""
863+
# Extract the actual file path (before any # separator for HDF5 internal paths)
864+
file_path = path.split("#")[0]
865+
return os.path.isfile(file_path) and (
866+
file_path.endswith((".h5", ".hdf5")) or _is_hdf5_format(file_path)
867+
)
868+
869+
870+
def _is_hdf5_multisystem(file_path: str) -> bool:
871+
"""Check if an HDF5 file contains multiple systems vs being a single system.
872+
873+
Parameters
874+
----------
875+
file_path : str
876+
Path to the HDF5 file
877+
878+
Returns
879+
-------
880+
bool
881+
True if the file contains multiple systems, False if it's a single system
882+
"""
883+
try:
884+
with h5py.File(file_path, "r") as f:
885+
# Check if this looks like a single system (has type.raw and set.* groups)
886+
has_type_raw = "type.raw" in f
887+
has_sets = any(key.startswith("set.") for key in f.keys())
888+
889+
if has_type_raw and has_sets:
890+
# This looks like a single system
891+
return False
892+
893+
# Check if it contains multiple groups that could be systems
894+
system_groups = []
895+
for key in f.keys():
896+
if isinstance(f[key], h5py.Group):
897+
group = f[key]
898+
# Check if this group looks like a system (has type.raw and sets)
899+
group_has_type = "type.raw" in group
900+
group_has_sets = any(
901+
subkey.startswith("set.") for subkey in group.keys()
902+
)
903+
if group_has_type and group_has_sets:
904+
system_groups.append(key)
905+
906+
# If we found multiple system-like groups, it's a multisystem file
907+
return len(system_groups) > 1
908+
909+
except OSError:
910+
return False
911+
912+
913+
def _is_hdf5_file(path: str) -> bool:
914+
"""Check if a path points to an HDF5 file.
915+
916+
Parameters
917+
----------
918+
path : str
919+
Path to check
920+
921+
Returns
922+
-------
923+
bool
924+
True if the path is an HDF5 file
925+
"""
926+
# Extract the actual file path (before any # separator for HDF5 internal paths)
927+
file_path = path.split("#")[0]
928+
return os.path.isfile(file_path) and (
929+
file_path.endswith((".h5", ".hdf5")) or _is_hdf5_format(file_path)
930+
)
931+
932+
933+
def _is_hdf5_format(file_path: str) -> bool:
934+
"""Check if a file is in HDF5 format by trying to open it.
935+
936+
Parameters
937+
----------
938+
file_path : str
939+
Path to the file
940+
941+
Returns
942+
-------
943+
bool
944+
True if the file is in HDF5 format
945+
"""
946+
try:
947+
with h5py.File(file_path, "r"):
948+
return True
949+
except OSError:
950+
return False
951+
952+
817953
def get_data(
818954
jdata: dict[str, Any],
819955
rcut: float,

0 commit comments

Comments
 (0)