Skip to content

Commit 2ce2ea4

Browse files
committed
bug fix for non-atomic data
1 parent 97ddb72 commit 2ce2ea4

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

deepmd/utils/data.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -853,7 +853,12 @@ def _load_single_data(
853853
# Branch 2: File exists, use memmap
854854
mmap_obj = self._get_memmap(path)
855855
# Slice the single frame and make an in-memory copy for modification
856-
data = mmap_obj[frame_idx].copy().astype(dtype, copy=False)
856+
if mmap_obj.ndim == 0:
857+
# Handle scalar data (0-dimensional array)
858+
data = mmap_obj.copy().astype(dtype, copy=False)
859+
else:
860+
# Handle array data that can be indexed by frame
861+
data = mmap_obj[frame_idx].copy().astype(dtype, copy=False)
857862

858863
try:
859864
if vv["atomic"]:
@@ -900,7 +905,9 @@ def _load_single_data(
900905
data = data.reshape([natoms, -1])
901906
data = data[idx_map, :]
902907

903-
return np.float32(1.0), data
908+
# Handle non-atomic data
909+
# For non-atomic data, reshape to (ndof,) shape
910+
return np.float32(1.0), data.reshape([ndof])
904911

905912
except ValueError as err_message:
906913
explanation = "This error may occur when your label mismatch its name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`."

0 commit comments

Comments
 (0)