Skip to content

Commit f91ba7d

Browse files
committed
resolve dtype
1 parent d163b26 commit f91ba7d

1 file changed

Lines changed: 114 additions & 20 deletions

File tree

deepmd/dpmodel/utils/lmdb_data.py

Lines changed: 114 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
import msgpack
2222
import numpy as np
2323

24+
from deepmd.env import (
25+
GLOBAL_ENER_FLOAT_PRECISION,
26+
GLOBAL_NP_FLOAT_PRECISION,
27+
)
2428
from deepmd.utils.data import (
2529
DataRequirementItem,
2630
)
@@ -37,6 +41,10 @@
3741
"virials": "virial",
3842
}
3943

44+
# Keys whose high_prec is always True in the standard pipeline
45+
# (energy is set by Loss DataRequirementItem; reduce() also sets high_prec=True)
46+
_HIGH_PREC_KEYS = frozenset({"energy"})
47+
4048

4149
def _open_lmdb(path: str) -> lmdb.Environment:
4250
"""Open LMDB environment readonly."""
@@ -267,6 +275,34 @@ def _compute_natoms_vec(self, atype: np.ndarray) -> np.ndarray:
267275
vec[2:] = counts
268276
return vec
269277

278+
def _resolve_dtype(self, key: str) -> np.dtype:
279+
"""Resolve the target numpy dtype for a given key.
280+
281+
Priority: DataRequirementItem.dtype > DataRequirementItem.high_prec >
282+
built-in defaults (energy=high, others=normal).
283+
"""
284+
if key in self._data_requirements:
285+
req = self._data_requirements[key]
286+
# Support both DataRequirementItem objects and plain dicts
287+
if isinstance(req, dict):
288+
dtype = req.get("dtype")
289+
if dtype is not None:
290+
return dtype
291+
if req.get("high_prec", False):
292+
return GLOBAL_ENER_FLOAT_PRECISION
293+
return GLOBAL_NP_FLOAT_PRECISION
294+
else:
295+
# DataRequirementItem object
296+
if hasattr(req, "dtype") and req.dtype is not None:
297+
return req.dtype
298+
if hasattr(req, "high_prec") and req.high_prec:
299+
return GLOBAL_ENER_FLOAT_PRECISION
300+
return GLOBAL_NP_FLOAT_PRECISION
301+
# Fall back to built-in defaults
302+
if key in _HIGH_PREC_KEYS:
303+
return GLOBAL_ENER_FLOAT_PRECISION
304+
return GLOBAL_NP_FLOAT_PRECISION
305+
270306
def get_batch_size_for_nloc(self, nloc: int) -> int:
271307
"""Get batch_size for a given nloc. Uses auto rule if configured."""
272308
if self._auto_rule is not None:
@@ -291,21 +327,29 @@ def __getitem__(self, index: int) -> dict[str, Any]:
291327

292328
# Flatten arrays to match DeePMD convention
293329
if "coord" in frame and isinstance(frame["coord"], np.ndarray):
294-
frame["coord"] = frame["coord"].reshape(-1, 3).astype(np.float64)
330+
frame["coord"] = (
331+
frame["coord"].reshape(-1, 3).astype(self._resolve_dtype("coord"))
332+
)
295333
if "box" in frame and isinstance(frame["box"], np.ndarray):
296-
frame["box"] = frame["box"].reshape(9).astype(np.float64)
334+
frame["box"] = frame["box"].reshape(9).astype(self._resolve_dtype("box"))
297335
if "energy" in frame:
298336
val = frame["energy"]
299337
if isinstance(val, np.ndarray):
300-
frame["energy"] = val.reshape(1).astype(np.float64)
338+
frame["energy"] = val.reshape(1).astype(self._resolve_dtype("energy"))
301339
else:
302-
frame["energy"] = np.array([float(val)], dtype=np.float64)
340+
frame["energy"] = np.array(
341+
[float(val)], dtype=self._resolve_dtype("energy")
342+
)
303343
if "force" in frame and isinstance(frame["force"], np.ndarray):
304-
frame["force"] = frame["force"].reshape(-1, 3).astype(np.float64)
344+
frame["force"] = (
345+
frame["force"].reshape(-1, 3).astype(self._resolve_dtype("force"))
346+
)
305347
if "atype" in frame and isinstance(frame["atype"], np.ndarray):
306348
frame["atype"] = frame["atype"].reshape(-1).astype(np.int64)
307349
if "virial" in frame and isinstance(frame["virial"], np.ndarray):
308-
frame["virial"] = frame["virial"].reshape(9).astype(np.float64)
350+
frame["virial"] = (
351+
frame["virial"].reshape(9).astype(self._resolve_dtype("virial"))
352+
)
309353

310354
# Per-frame natoms_vec from atype
311355
atype = frame.get("atype")
@@ -340,14 +384,34 @@ def __getitem__(self, index: int) -> dict[str, Any]:
340384
for req_key, req_item in self._data_requirements.items():
341385
if req_key not in frame:
342386
frame[f"find_{req_key}"] = np.float32(0.0)
343-
ndof = req_item["ndof"]
344-
default = req_item["default"]
345-
atomic = req_item["atomic"]
387+
# Support both dict and DataRequirementItem object
388+
if isinstance(req_item, dict):
389+
ndof = req_item["ndof"]
390+
default = req_item["default"]
391+
atomic = req_item["atomic"]
392+
req_dtype = req_item.get("dtype")
393+
if req_dtype is None:
394+
req_dtype = (
395+
GLOBAL_ENER_FLOAT_PRECISION
396+
if req_item.get("high_prec", False)
397+
else GLOBAL_NP_FLOAT_PRECISION
398+
)
399+
else:
400+
ndof = req_item.ndof
401+
default = req_item.default
402+
atomic = req_item.atomic
403+
req_dtype = req_item.dtype
404+
if req_dtype is None:
405+
req_dtype = (
406+
GLOBAL_ENER_FLOAT_PRECISION
407+
if req_item.high_prec
408+
else GLOBAL_NP_FLOAT_PRECISION
409+
)
346410
if atomic:
347411
shape = (frame_natoms, ndof)
348412
else:
349413
shape = (ndof,)
350-
frame[req_key] = np.full(shape, default, dtype=np.float64)
414+
frame[req_key] = np.full(shape, default, dtype=req_dtype)
351415
elif f"find_{req_key}" not in frame:
352416
frame[f"find_{req_key}"] = np.float32(1.0)
353417

@@ -679,6 +743,7 @@ def add(
679743
high_prec: bool = False,
680744
repeat: int = 1,
681745
default: float = 0.0,
746+
dtype: np.dtype | None = None,
682747
**kwargs: Any,
683748
) -> None:
684749
"""Register a data requirement (mirrors DeepmdData.add)."""
@@ -689,8 +754,23 @@ def add(
689754
"high_prec": high_prec,
690755
"repeat": repeat,
691756
"default": default,
757+
"dtype": dtype,
692758
}
693759

760+
def _resolve_dtype(self, key: str) -> np.dtype:
761+
"""Resolve target dtype for a key using registered requirements."""
762+
if key in self._requirements:
763+
req = self._requirements[key]
764+
dtype = req.get("dtype")
765+
if dtype is not None:
766+
return dtype
767+
if req.get("high_prec", False):
768+
return GLOBAL_ENER_FLOAT_PRECISION
769+
return GLOBAL_NP_FLOAT_PRECISION
770+
if key in _HIGH_PREC_KEYS:
771+
return GLOBAL_ENER_FLOAT_PRECISION
772+
return GLOBAL_NP_FLOAT_PRECISION
773+
694774
def get_test(self, nloc: int | None = None) -> dict[str, Any]:
695775
"""Return frames stacked as numpy arrays.
696776
@@ -741,18 +821,28 @@ def _stack_frames(
741821

742822
for frame in frames:
743823
if "coord" in frame and isinstance(frame["coord"], np.ndarray):
744-
coords.append(frame["coord"].reshape(natoms * 3).astype(np.float64))
824+
coords.append(
825+
frame["coord"]
826+
.reshape(natoms * 3)
827+
.astype(self._resolve_dtype("coord"))
828+
)
745829
if "box" in frame and isinstance(frame["box"], np.ndarray):
746-
boxes.append(frame["box"].reshape(9).astype(np.float64))
830+
boxes.append(frame["box"].reshape(9).astype(self._resolve_dtype("box")))
747831
else:
748-
boxes.append(np.zeros(9, dtype=np.float64))
832+
boxes.append(np.zeros(9, dtype=self._resolve_dtype("box")))
749833
if "atype" in frame and isinstance(frame["atype"], np.ndarray):
750834
atypes.append(frame["atype"].reshape(natoms).astype(np.int64))
751835

752836
result["coord"] = (
753-
np.stack(coords) if coords else np.zeros((0, natoms * 3), dtype=np.float64)
837+
np.stack(coords)
838+
if coords
839+
else np.zeros((0, natoms * 3), dtype=self._resolve_dtype("coord"))
840+
)
841+
result["box"] = (
842+
np.stack(boxes)
843+
if boxes
844+
else np.zeros((0, 9), dtype=self._resolve_dtype("box"))
754845
)
755-
result["box"] = np.stack(boxes) if boxes else np.zeros((0, 9), dtype=np.float64)
756846
result["type"] = (
757847
np.stack(atypes) if atypes else np.zeros((0, natoms), dtype=np.int64)
758848
)
@@ -787,9 +877,11 @@ def _stack_frames(
787877
for frame in frames:
788878
val = frame.get(key)
789879
if isinstance(val, np.ndarray):
790-
arrays.append(val.astype(np.float64).ravel())
880+
arrays.append(val.astype(self._resolve_dtype(key)).ravel())
791881
elif val is not None:
792-
arrays.append(np.array([float(val)], dtype=np.float64))
882+
arrays.append(
883+
np.array([float(val)], dtype=self._resolve_dtype(key))
884+
)
793885
else:
794886
ref = next(
795887
(
@@ -800,9 +892,11 @@ def _stack_frames(
800892
None,
801893
)
802894
if ref is not None:
803-
arrays.append(np.zeros(ref.size, dtype=np.float64))
895+
arrays.append(
896+
np.zeros(ref.size, dtype=self._resolve_dtype(key))
897+
)
804898
else:
805-
arrays.append(np.zeros(1, dtype=np.float64))
899+
arrays.append(np.zeros(1, dtype=self._resolve_dtype(key)))
806900
result[key] = np.stack(arrays)
807901
elif key in self._requirements:
808902
ndof = self._requirements[key]["ndof"]
@@ -812,7 +906,7 @@ def _stack_frames(
812906
shape = (nframes, natoms * ndof)
813907
else:
814908
shape = (nframes, ndof)
815-
result[key] = np.full(shape, default, dtype=np.float64)
909+
result[key] = np.full(shape, default, dtype=self._resolve_dtype(key))
816910

817911
return result
818912

0 commit comments

Comments
 (0)