diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index ae4a484909..26a27c82d7 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -14,6 +14,7 @@ from typing import ( Any, Optional, + Union, ) import numpy as np @@ -135,8 +136,7 @@ def __init__( self.shuffle_test = shuffle_test # set modifier self.modifier = modifier - # calculate prefix sum for get_item method - frames_list = [self._get_nframes(item) for item in self.dirs] + frames_list = [self._get_nframes(set_name) for set_name in self.dirs] self.nframes = np.sum(frames_list) # The prefix sum stores the range of indices contained in each directory, which is needed by get_item method self.prefix_sum = np.cumsum(frames_list).tolist() @@ -338,8 +338,10 @@ def get_numb_set(self) -> int: def get_numb_batch(self, batch_size: int, set_idx: int) -> int: """Get the number of batches in a set.""" - data = self._load_set(self.dirs[set_idx]) - ret = data["coord"].shape[0] // batch_size + set_name = self.dirs[set_idx] + # Directly obtain the number of frames to avoid loading the entire dataset + nframes = self._get_nframes(set_name) + ret = nframes // batch_size if ret == 0: ret = 1 return ret @@ -578,18 +580,27 @@ def _shuffle_data(self, data: dict[str, Any]) -> dict[str, Any]: ret[kk] = data[kk] return ret, idx - def _get_nframes(self, set_name: DPPath) -> int: - # get nframes + def _get_nframes(self, set_name: Union[DPPath, str]) -> int: if not isinstance(set_name, DPPath): set_name = DPPath(set_name) path = set_name / "coord.npy" - if self.data_dict["coord"]["high_prec"]: - coord = path.load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION) + if isinstance(set_name, DPH5Path): + nframes = path.root[path._name].shape[0] else: - coord = path.load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION) - if coord.ndim == 1: - coord = coord.reshape([1, -1]) - nframes = coord.shape[0] + # Read only the header to get shape + with open(str(path), "rb") as f: + version = np.lib.format.read_magic(f) + if version[0] == 1: + shape, _fortran_order, _dtype = np.lib.format.read_array_header_1_0( + f + ) + elif version[0] in [2, 3]: + shape, _fortran_order, _dtype = np.lib.format.read_array_header_2_0( + f + ) + else: + raise ValueError(f"Unsupported .npy file version: {version}") + nframes = shape[0] if len(shape) > 1 else 1 return nframes def reformat_data_torch(self, data: dict[str, Any]) -> dict[str, Any]: