2121import msgpack
2222import numpy as np
2323
24+ from deepmd .env import (
25+ GLOBAL_ENER_FLOAT_PRECISION ,
26+ GLOBAL_NP_FLOAT_PRECISION ,
27+ )
2428from deepmd .utils .data import (
2529 DataRequirementItem ,
2630)
3741 "virials" : "virial" ,
3842}
3943
44+ # Keys whose high_prec is always True in the standard pipeline
45+ # (energy is set by Loss DataRequirementItem; reduce() also sets high_prec=True)
46+ _HIGH_PREC_KEYS = frozenset ({"energy" })
47+
4048
4149def _open_lmdb (path : str ) -> lmdb .Environment :
4250 """Open LMDB environment readonly."""
@@ -267,6 +275,34 @@ def _compute_natoms_vec(self, atype: np.ndarray) -> np.ndarray:
267275 vec [2 :] = counts
268276 return vec
269277
def _resolve_dtype(self, key: str) -> np.dtype:
    """Return the numpy dtype that data stored under ``key`` is cast to.

    Resolution order:

    1. If ``key`` has an entry in ``self._data_requirements`` and that
       entry carries a non-``None`` ``dtype``, that dtype is returned.
    2. Otherwise, if the entry has a truthy ``high_prec``, the result is
       ``GLOBAL_ENER_FLOAT_PRECISION``; if not, ``GLOBAL_NP_FLOAT_PRECISION``.
    3. If ``key`` has no entry at all, keys listed in ``_HIGH_PREC_KEYS``
       resolve to ``GLOBAL_ENER_FLOAT_PRECISION`` and every other key to
       ``GLOBAL_NP_FLOAT_PRECISION``.

    Entries may be plain dicts or attribute-style requirement objects;
    a missing ``dtype``/``high_prec`` field is treated as ``None``/``False``.

    Parameters
    ----------
    key : str
        Name of the data item (e.g. ``"coord"``, ``"energy"``).

    Returns
    -------
    np.dtype
        The dtype to use for arrays stored under ``key``.
    """
    if key not in self._data_requirements:
        # Nothing registered for this key: use the built-in defaults.
        return (
            GLOBAL_ENER_FLOAT_PRECISION
            if key in _HIGH_PREC_KEYS
            else GLOBAL_NP_FLOAT_PRECISION
        )

    requirement = self._data_requirements[key]

    # One accessor for both representations so the priority logic below is
    # written once. Fields are read lazily: ``high_prec`` is only consulted
    # when no explicit dtype is present.
    if isinstance(requirement, dict):
        read_field = requirement.get
    else:

        def read_field(name, fallback=None):
            return getattr(requirement, name, fallback)

    explicit = read_field("dtype")
    if explicit is not None:
        return explicit
    if read_field("high_prec", False):
        return GLOBAL_ENER_FLOAT_PRECISION
    return GLOBAL_NP_FLOAT_PRECISION
305+
270306 def get_batch_size_for_nloc (self , nloc : int ) -> int :
271307 """Get batch_size for a given nloc. Uses auto rule if configured."""
272308 if self ._auto_rule is not None :
@@ -291,21 +327,29 @@ def __getitem__(self, index: int) -> dict[str, Any]:
291327
292328 # Flatten arrays to match DeePMD convention
293329 if "coord" in frame and isinstance (frame ["coord" ], np .ndarray ):
294- frame ["coord" ] = frame ["coord" ].reshape (- 1 , 3 ).astype (np .float64 )
330+ frame ["coord" ] = (
331+ frame ["coord" ].reshape (- 1 , 3 ).astype (self ._resolve_dtype ("coord" ))
332+ )
295333 if "box" in frame and isinstance (frame ["box" ], np .ndarray ):
296- frame ["box" ] = frame ["box" ].reshape (9 ).astype (np . float64 )
334+ frame ["box" ] = frame ["box" ].reshape (9 ).astype (self . _resolve_dtype ( "box" ) )
297335 if "energy" in frame :
298336 val = frame ["energy" ]
299337 if isinstance (val , np .ndarray ):
300- frame ["energy" ] = val .reshape (1 ).astype (np . float64 )
338+ frame ["energy" ] = val .reshape (1 ).astype (self . _resolve_dtype ( "energy" ) )
301339 else :
302- frame ["energy" ] = np .array ([float (val )], dtype = np .float64 )
340+ frame ["energy" ] = np .array (
341+ [float (val )], dtype = self ._resolve_dtype ("energy" )
342+ )
303343 if "force" in frame and isinstance (frame ["force" ], np .ndarray ):
304- frame ["force" ] = frame ["force" ].reshape (- 1 , 3 ).astype (np .float64 )
344+ frame ["force" ] = (
345+ frame ["force" ].reshape (- 1 , 3 ).astype (self ._resolve_dtype ("force" ))
346+ )
305347 if "atype" in frame and isinstance (frame ["atype" ], np .ndarray ):
306348 frame ["atype" ] = frame ["atype" ].reshape (- 1 ).astype (np .int64 )
307349 if "virial" in frame and isinstance (frame ["virial" ], np .ndarray ):
308- frame ["virial" ] = frame ["virial" ].reshape (9 ).astype (np .float64 )
350+ frame ["virial" ] = (
351+ frame ["virial" ].reshape (9 ).astype (self ._resolve_dtype ("virial" ))
352+ )
309353
310354 # Per-frame natoms_vec from atype
311355 atype = frame .get ("atype" )
@@ -340,14 +384,34 @@ def __getitem__(self, index: int) -> dict[str, Any]:
340384 for req_key , req_item in self ._data_requirements .items ():
341385 if req_key not in frame :
342386 frame [f"find_{ req_key } " ] = np .float32 (0.0 )
343- ndof = req_item ["ndof" ]
344- default = req_item ["default" ]
345- atomic = req_item ["atomic" ]
387+ # Support both dict and DataRequirementItem object
388+ if isinstance (req_item , dict ):
389+ ndof = req_item ["ndof" ]
390+ default = req_item ["default" ]
391+ atomic = req_item ["atomic" ]
392+ req_dtype = req_item .get ("dtype" )
393+ if req_dtype is None :
394+ req_dtype = (
395+ GLOBAL_ENER_FLOAT_PRECISION
396+ if req_item .get ("high_prec" , False )
397+ else GLOBAL_NP_FLOAT_PRECISION
398+ )
399+ else :
400+ ndof = req_item .ndof
401+ default = req_item .default
402+ atomic = req_item .atomic
403+ req_dtype = req_item .dtype
404+ if req_dtype is None :
405+ req_dtype = (
406+ GLOBAL_ENER_FLOAT_PRECISION
407+ if req_item .high_prec
408+ else GLOBAL_NP_FLOAT_PRECISION
409+ )
346410 if atomic :
347411 shape = (frame_natoms , ndof )
348412 else :
349413 shape = (ndof ,)
350- frame [req_key ] = np .full (shape , default , dtype = np . float64 )
414+ frame [req_key ] = np .full (shape , default , dtype = req_dtype )
351415 elif f"find_{ req_key } " not in frame :
352416 frame [f"find_{ req_key } " ] = np .float32 (1.0 )
353417
@@ -679,6 +743,7 @@ def add(
679743 high_prec : bool = False ,
680744 repeat : int = 1 ,
681745 default : float = 0.0 ,
746+ dtype : np .dtype | None = None ,
682747 ** kwargs : Any ,
683748 ) -> None :
684749 """Register a data requirement (mirrors DeepmdData.add)."""
@@ -689,8 +754,23 @@ def add(
689754 "high_prec" : high_prec ,
690755 "repeat" : repeat ,
691756 "default" : default ,
757+ "dtype" : dtype ,
692758 }
693759
def _resolve_dtype(self, key: str) -> np.dtype:
    """Return the numpy dtype to use for ``key`` based on ``self._requirements``.

    A registered requirement wins: its ``"dtype"`` value is returned when it
    is not ``None``; otherwise a truthy ``"high_prec"`` selects
    ``GLOBAL_ENER_FLOAT_PRECISION`` and anything else selects
    ``GLOBAL_NP_FLOAT_PRECISION``. Keys with no registered requirement fall
    back to ``GLOBAL_ENER_FLOAT_PRECISION`` if they appear in
    ``_HIGH_PREC_KEYS`` and to ``GLOBAL_NP_FLOAT_PRECISION`` otherwise.

    Parameters
    ----------
    key : str
        Name of the data item whose dtype is requested.

    Returns
    -------
    np.dtype
        The resolved dtype.
    """
    if key not in self._requirements:
        # Unregistered key: built-in defaults only.
        if key in _HIGH_PREC_KEYS:
            return GLOBAL_ENER_FLOAT_PRECISION
        return GLOBAL_NP_FLOAT_PRECISION

    entry = self._requirements[key]
    explicit = entry.get("dtype")
    if explicit is not None:
        return explicit
    use_high = entry.get("high_prec", False)
    return GLOBAL_ENER_FLOAT_PRECISION if use_high else GLOBAL_NP_FLOAT_PRECISION
773+
694774 def get_test (self , nloc : int | None = None ) -> dict [str , Any ]:
695775 """Return frames stacked as numpy arrays.
696776
@@ -741,18 +821,28 @@ def _stack_frames(
741821
742822 for frame in frames :
743823 if "coord" in frame and isinstance (frame ["coord" ], np .ndarray ):
744- coords .append (frame ["coord" ].reshape (natoms * 3 ).astype (np .float64 ))
824+ coords .append (
825+ frame ["coord" ]
826+ .reshape (natoms * 3 )
827+ .astype (self ._resolve_dtype ("coord" ))
828+ )
745829 if "box" in frame and isinstance (frame ["box" ], np .ndarray ):
746- boxes .append (frame ["box" ].reshape (9 ).astype (np . float64 ))
830+ boxes .append (frame ["box" ].reshape (9 ).astype (self . _resolve_dtype ( "box" ) ))
747831 else :
748- boxes .append (np .zeros (9 , dtype = np . float64 ))
832+ boxes .append (np .zeros (9 , dtype = self . _resolve_dtype ( "box" ) ))
749833 if "atype" in frame and isinstance (frame ["atype" ], np .ndarray ):
750834 atypes .append (frame ["atype" ].reshape (natoms ).astype (np .int64 ))
751835
752836 result ["coord" ] = (
753- np .stack (coords ) if coords else np .zeros ((0 , natoms * 3 ), dtype = np .float64 )
837+ np .stack (coords )
838+ if coords
839+ else np .zeros ((0 , natoms * 3 ), dtype = self ._resolve_dtype ("coord" ))
840+ )
841+ result ["box" ] = (
842+ np .stack (boxes )
843+ if boxes
844+ else np .zeros ((0 , 9 ), dtype = self ._resolve_dtype ("box" ))
754845 )
755- result ["box" ] = np .stack (boxes ) if boxes else np .zeros ((0 , 9 ), dtype = np .float64 )
756846 result ["type" ] = (
757847 np .stack (atypes ) if atypes else np .zeros ((0 , natoms ), dtype = np .int64 )
758848 )
@@ -787,9 +877,11 @@ def _stack_frames(
787877 for frame in frames :
788878 val = frame .get (key )
789879 if isinstance (val , np .ndarray ):
790- arrays .append (val .astype (np . float64 ).ravel ())
880+ arrays .append (val .astype (self . _resolve_dtype ( key ) ).ravel ())
791881 elif val is not None :
792- arrays .append (np .array ([float (val )], dtype = np .float64 ))
882+ arrays .append (
883+ np .array ([float (val )], dtype = self ._resolve_dtype (key ))
884+ )
793885 else :
794886 ref = next (
795887 (
@@ -800,9 +892,11 @@ def _stack_frames(
800892 None ,
801893 )
802894 if ref is not None :
803- arrays .append (np .zeros (ref .size , dtype = np .float64 ))
895+ arrays .append (
896+ np .zeros (ref .size , dtype = self ._resolve_dtype (key ))
897+ )
804898 else :
805- arrays .append (np .zeros (1 , dtype = np . float64 ))
899+ arrays .append (np .zeros (1 , dtype = self . _resolve_dtype ( key ) ))
806900 result [key ] = np .stack (arrays )
807901 elif key in self ._requirements :
808902 ndof = self ._requirements [key ]["ndof" ]
@@ -812,7 +906,7 @@ def _stack_frames(
812906 shape = (nframes , natoms * ndof )
813907 else :
814908 shape = (nframes , ndof )
815- result [key ] = np .full (shape , default , dtype = np . float64 )
909+ result [key ] = np .full (shape , default , dtype = self . _resolve_dtype ( key ) )
816910
817911 return result
818912
0 commit comments