@@ -406,54 +406,75 @@ def __getitem__(self, index: int) -> dict[str, Any]:
406406 frame ["natoms" ] = fallback
407407 frame ["real_natoms_vec" ] = fallback
408408
409- # Add find_* flags for known label keys
410- label_keys = [
411- "energy" ,
412- "force" ,
413- "virial" ,
414- "atom_ener" ,
415- "atom_pref" ,
416- "drdq" ,
417- "atom_ener_coeff" ,
418- "hessian" ,
419- ]
420- for lk in label_keys :
421- frame [f"find_{ lk } " ] = np .float32 (1.0 ) if lk in frame else np .float32 (0.0 )
409+ # Add find_* flags for all data keys present in the frame.
410+ # Core structural keys and metadata are excluded — only label-like
411+ # and auxiliary data keys get find_* flags.
412+ _structural_keys = frozenset (
413+ {
414+ "coord" ,
415+ "box" ,
416+ "atype" ,
417+ "natoms" ,
418+ "real_natoms_vec" ,
419+ "fid" ,
420+ }
421+ )
422+ for fk in list (frame .keys ()):
423+ if fk .startswith ("find_" ) or fk in _structural_keys :
424+ continue
425+ # Skip keys handled by data_requirements (processed below)
426+ if fk in self ._data_requirements :
427+ continue
428+ if f"find_{ fk } " not in frame :
429+ frame [f"find_{ fk } " ] = np .float32 (1.0 )
422430
423- # Handle registered data requirements: fill defaults for missing keys
431+ # Handle registered data requirements: fill defaults for missing keys,
432+ # apply repeat, and cast dtype.
424433 for req_key , req_item in self ._data_requirements .items ():
434+ # Extract requirement fields (support both dict and object)
435+ if isinstance (req_item , dict ):
436+ ndof = req_item ["ndof" ]
437+ default = req_item ["default" ]
438+ atomic = req_item ["atomic" ]
439+ repeat = req_item .get ("repeat" , 1 )
440+ req_dtype = req_item .get ("dtype" )
441+ if req_dtype is None :
442+ req_dtype = (
443+ GLOBAL_ENER_FLOAT_PRECISION
444+ if req_item .get ("high_prec" , False )
445+ else GLOBAL_NP_FLOAT_PRECISION
446+ )
447+ else :
448+ ndof = req_item .ndof
449+ default = req_item .default
450+ atomic = req_item .atomic
451+ repeat = getattr (req_item , "repeat" , 1 )
452+ req_dtype = req_item .dtype
453+ if req_dtype is None :
454+ req_dtype = (
455+ GLOBAL_ENER_FLOAT_PRECISION
456+ if req_item .high_prec
457+ else GLOBAL_NP_FLOAT_PRECISION
458+ )
459+
425460 if req_key not in frame :
426461 frame [f"find_{ req_key } " ] = np .float32 (0.0 )
427- # Support both dict and DataRequirementItem object
428- if isinstance (req_item , dict ):
429- ndof = req_item ["ndof" ]
430- default = req_item ["default" ]
431- atomic = req_item ["atomic" ]
432- req_dtype = req_item .get ("dtype" )
433- if req_dtype is None :
434- req_dtype = (
435- GLOBAL_ENER_FLOAT_PRECISION
436- if req_item .get ("high_prec" , False )
437- else GLOBAL_NP_FLOAT_PRECISION
438- )
439- else :
440- ndof = req_item .ndof
441- default = req_item .default
442- atomic = req_item .atomic
443- req_dtype = req_item .dtype
444- if req_dtype is None :
445- req_dtype = (
446- GLOBAL_ENER_FLOAT_PRECISION
447- if req_item .high_prec
448- else GLOBAL_NP_FLOAT_PRECISION
449- )
450462 if atomic :
451463 shape = (frame_natoms , ndof )
452464 else :
453465 shape = (ndof ,)
454- frame [req_key ] = np .full (shape , default , dtype = req_dtype )
455- elif f"find_{ req_key } " not in frame :
456- frame [f"find_{ req_key } " ] = np .float32 (1.0 )
466+ data = np .full (shape , default , dtype = req_dtype )
467+ if repeat != 1 :
468+ data = np .repeat (data , repeat ).reshape (- 1 )
469+ frame [req_key ] = data
470+ else :
471+ if f"find_{ req_key } " not in frame :
472+ frame [f"find_{ req_key } " ] = np .float32 (1.0 )
473+ # Apply repeat to existing data (e.g. atom_pref repeat=3)
474+ if repeat != 1 and isinstance (frame [req_key ], np .ndarray ):
475+ frame [req_key ] = (
476+ np .repeat (frame [req_key ], repeat ).reshape (- 1 ).astype (req_dtype )
477+ )
457478
458479 # Add find_* for fparam/aparam/spin if not already set
459480 for extra_key in ["fparam" , "aparam" , "spin" ]:
@@ -1268,22 +1289,17 @@ def _stack_frames(
12681289 np .stack (atypes ) if atypes else np .zeros ((0 , natoms ), dtype = np .int64 )
12691290 )
12701291
1271- # Label keys and registered requirements
1292+ # Dynamically discover all data keys present in frames, plus
1293+ # any registered requirements. Structural keys (coord, box, type)
1294+ # are excluded — they are already handled above.
1295+ _structural_keys = frozenset ({"coord" , "box" , "atype" })
12721296 all_keys : dict [str , dict [str , Any ]] = {}
1273- for key in [
1274- "energy" ,
1275- "force" ,
1276- "virial" ,
1277- "atom_ener" ,
1278- "atom_pref" ,
1279- "force_mag" ,
1280- "spin" ,
1281- "fparam" ,
1282- "aparam" ,
1283- "hessian" ,
1284- "efield" ,
1285- ]:
1286- all_keys [key ] = {"ndof" : None , "atomic" : False , "default" : 0.0 }
1297+ for f in frames :
1298+ for fk in f :
1299+ if fk in _structural_keys or fk .startswith ("find_" ):
1300+ continue
1301+ if fk not in all_keys :
1302+ all_keys [fk ] = {"ndof" : None , "atomic" : False , "default" : 0.0 }
12871303 for key , req in self ._requirements .items ():
12881304 all_keys [key ] = req
12891305
@@ -1293,12 +1309,20 @@ def _stack_frames(
12931309 )
12941310 result [f"find_{ key } " ] = 1.0 if has_key else 0.0
12951311
1312+ # Get repeat factor from registered requirements
1313+ repeat = 1
1314+ if key in self ._requirements :
1315+ repeat = self ._requirements [key ].get ("repeat" , 1 )
1316+
12961317 if has_key :
12971318 arrays = []
12981319 for frame in frames :
12991320 val = frame .get (key )
13001321 if isinstance (val , np .ndarray ):
1301- arrays .append (val .astype (self ._resolve_dtype (key )).ravel ())
1322+ arr = val .astype (self ._resolve_dtype (key )).ravel ()
1323+ if repeat != 1 :
1324+ arr = np .repeat (arr , repeat )
1325+ arrays .append (arr )
13021326 elif val is not None :
13031327 arrays .append (
13041328 np .array ([float (val )], dtype = self ._resolve_dtype (key ))
@@ -1313,8 +1337,9 @@ def _stack_frames(
13131337 None ,
13141338 )
13151339 if ref is not None :
1340+ size = ref .size * repeat if repeat != 1 else ref .size
13161341 arrays .append (
1317- np .zeros (ref . size , dtype = self ._resolve_dtype (key ))
1342+ np .zeros (size , dtype = self ._resolve_dtype (key ))
13181343 )
13191344 else :
13201345 arrays .append (np .zeros (1 , dtype = self ._resolve_dtype (key )))
@@ -1324,9 +1349,9 @@ def _stack_frames(
13241349 atomic = self ._requirements [key ]["atomic" ]
13251350 default = self ._requirements [key ]["default" ]
13261351 if atomic :
1327- shape = (nframes , natoms * ndof )
1352+ shape = (nframes , natoms * ndof * repeat )
13281353 else :
1329- shape = (nframes , ndof )
1354+ shape = (nframes , ndof * repeat )
13301355 result [key ] = np .full (shape , default , dtype = self ._resolve_dtype (key ))
13311356
13321357 return result
0 commit comments