Skip to content

Commit 478f35e

Browse files
committed
minor update
1 parent ce7fe95 commit 478f35e

1 file changed

Lines changed: 20 additions & 13 deletions

File tree

deepmd/utils/data.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -589,21 +589,28 @@ def _shuffle_data(self, data: dict[str, Any]) -> dict[str, Any]:
589589
ret[kk] = data[kk]
590590
return ret, idx
591591

592-
def _get_nframes(self, set_name: DPPath) -> int:
593-
# get nframes
592+
def _get_nframes(self, set_name: DPPath | str) -> int:
594593
if not isinstance(set_name, DPPath):
595594
set_name = DPPath(set_name)
596-
path = set_name / "coord.npy"
597-
# Read only the header to get shape
598-
with open(str(path), "rb") as f:
599-
version = np.lib.format.read_magic(f)
600-
if version[0] == 1:
601-
shape, fortran_order, dtype = np.lib.format.read_array_header_1_0(f)
602-
elif version[0] in [2, 3]:
603-
shape, fortran_order, dtype = np.lib.format.read_array_header_2_0(f)
604-
else:
605-
raise ValueError(f"Unsupported .npy file version: {version}")
606-
nframes = shape[0] if (len(shape) if isinstance(shape, tuple) else 0) > 1 else 1
595+
if isinstance(set_name, DPH5Path):
596+
path = set_name / "coord.npy"
597+
nframes = path.root[path._name].shape[0]
598+
else:
599+
path = set_name / "coord.npy"
600+
# Read only the header to get shape
601+
with open(str(path), "rb") as f:
602+
version = np.lib.format.read_magic(f)
603+
if version[0] == 1:
604+
shape, _fortran_order, _dtype = np.lib.format.read_array_header_1_0(
605+
f
606+
)
607+
elif version[0] in [2, 3]:
608+
shape, _fortran_order, _dtype = np.lib.format.read_array_header_2_0(
609+
f
610+
)
611+
else:
612+
raise ValueError(f"Unsupported .npy file version: {version}")
613+
nframes = shape[0] if len(shape) > 1 else 1
607614
return nframes
608615

609616
def reformat_data_torch(self, data: dict[str, Any]) -> dict[str, Any]:

0 commit comments

Comments
 (0)