Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions deepmd/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import (
Any,
Optional,
Union,
)

import numpy as np
Expand Down Expand Up @@ -135,8 +136,7 @@ def __init__(
self.shuffle_test = shuffle_test
# set modifier
self.modifier = modifier
# calculate prefix sum for get_item method
frames_list = [self._get_nframes(item) for item in self.dirs]
frames_list = [self._get_nframes(set_name) for set_name in self.dirs]
self.nframes = np.sum(frames_list)
# The prefix sum stores the range of indices contained in each directory, which is needed by get_item method
self.prefix_sum = np.cumsum(frames_list).tolist()
Expand Down Expand Up @@ -338,8 +338,10 @@ def get_numb_set(self) -> int:

def get_numb_batch(self, batch_size: int, set_idx: int) -> int:
"""Get the number of batches in a set."""
data = self._load_set(self.dirs[set_idx])
ret = data["coord"].shape[0] // batch_size
set_name = self.dirs[set_idx]
# Directly obtain the number of frames to avoid loading the entire dataset
nframes = self._get_nframes(set_name)
ret = nframes // batch_size
if ret == 0:
ret = 1
return ret
Comment thread
OutisLi marked this conversation as resolved.
Expand Down Expand Up @@ -578,18 +580,27 @@ def _shuffle_data(self, data: dict[str, Any]) -> dict[str, Any]:
ret[kk] = data[kk]
return ret, idx

def _get_nframes(self, set_name: DPPath) -> int:
# get nframes
def _get_nframes(self, set_name: Union[DPPath, str]) -> int:
if not isinstance(set_name, DPPath):
set_name = DPPath(set_name)
path = set_name / "coord.npy"
if self.data_dict["coord"]["high_prec"]:
coord = path.load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION)
if isinstance(set_name, DPH5Path):
nframes = path.root[path._name].shape[0]
else:
coord = path.load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION)
if coord.ndim == 1:
coord = coord.reshape([1, -1])
nframes = coord.shape[0]
# Read only the header to get shape
with open(str(path), "rb") as f:
version = np.lib.format.read_magic(f)
if version[0] == 1:
shape, _fortran_order, _dtype = np.lib.format.read_array_header_1_0(
f
)
elif version[0] in [2, 3]:
shape, _fortran_order, _dtype = np.lib.format.read_array_header_2_0(
f
)
else:
raise ValueError(f"Unsupported .npy file version: {version}")
nframes = shape[0] if len(shape) > 1 else 1
return nframes

def reformat_data_torch(self, data: dict[str, Any]) -> dict[str, Any]:
Expand Down