Skip to content

Commit 24b9060

Browse files
committed
fix repeat keys
1 parent 6b275bf commit 24b9060

2 files changed

Lines changed: 238 additions & 545 deletions

File tree

deepmd/dpmodel/utils/lmdb_data.py

Lines changed: 84 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -406,54 +406,75 @@ def __getitem__(self, index: int) -> dict[str, Any]:
406406
frame["natoms"] = fallback
407407
frame["real_natoms_vec"] = fallback
408408

409-
# Add find_* flags for known label keys
410-
label_keys = [
411-
"energy",
412-
"force",
413-
"virial",
414-
"atom_ener",
415-
"atom_pref",
416-
"drdq",
417-
"atom_ener_coeff",
418-
"hessian",
419-
]
420-
for lk in label_keys:
421-
frame[f"find_{lk}"] = np.float32(1.0) if lk in frame else np.float32(0.0)
409+
# Add find_* flags for all data keys present in the frame.
410+
# Core structural keys and metadata are excluded — only label-like
411+
# and auxiliary data keys get find_* flags.
412+
_structural_keys = frozenset(
413+
{
414+
"coord",
415+
"box",
416+
"atype",
417+
"natoms",
418+
"real_natoms_vec",
419+
"fid",
420+
}
421+
)
422+
for fk in list(frame.keys()):
423+
if fk.startswith("find_") or fk in _structural_keys:
424+
continue
425+
# Skip keys handled by data_requirements (processed below)
426+
if fk in self._data_requirements:
427+
continue
428+
if f"find_{fk}" not in frame:
429+
frame[f"find_{fk}"] = np.float32(1.0)
422430

423-
# Handle registered data requirements: fill defaults for missing keys
431+
# Handle registered data requirements: fill defaults for missing keys,
432+
# apply repeat, and cast dtype.
424433
for req_key, req_item in self._data_requirements.items():
434+
# Extract requirement fields (support both dict and object)
435+
if isinstance(req_item, dict):
436+
ndof = req_item["ndof"]
437+
default = req_item["default"]
438+
atomic = req_item["atomic"]
439+
repeat = req_item.get("repeat", 1)
440+
req_dtype = req_item.get("dtype")
441+
if req_dtype is None:
442+
req_dtype = (
443+
GLOBAL_ENER_FLOAT_PRECISION
444+
if req_item.get("high_prec", False)
445+
else GLOBAL_NP_FLOAT_PRECISION
446+
)
447+
else:
448+
ndof = req_item.ndof
449+
default = req_item.default
450+
atomic = req_item.atomic
451+
repeat = getattr(req_item, "repeat", 1)
452+
req_dtype = req_item.dtype
453+
if req_dtype is None:
454+
req_dtype = (
455+
GLOBAL_ENER_FLOAT_PRECISION
456+
if req_item.high_prec
457+
else GLOBAL_NP_FLOAT_PRECISION
458+
)
459+
425460
if req_key not in frame:
426461
frame[f"find_{req_key}"] = np.float32(0.0)
427-
# Support both dict and DataRequirementItem object
428-
if isinstance(req_item, dict):
429-
ndof = req_item["ndof"]
430-
default = req_item["default"]
431-
atomic = req_item["atomic"]
432-
req_dtype = req_item.get("dtype")
433-
if req_dtype is None:
434-
req_dtype = (
435-
GLOBAL_ENER_FLOAT_PRECISION
436-
if req_item.get("high_prec", False)
437-
else GLOBAL_NP_FLOAT_PRECISION
438-
)
439-
else:
440-
ndof = req_item.ndof
441-
default = req_item.default
442-
atomic = req_item.atomic
443-
req_dtype = req_item.dtype
444-
if req_dtype is None:
445-
req_dtype = (
446-
GLOBAL_ENER_FLOAT_PRECISION
447-
if req_item.high_prec
448-
else GLOBAL_NP_FLOAT_PRECISION
449-
)
450462
if atomic:
451463
shape = (frame_natoms, ndof)
452464
else:
453465
shape = (ndof,)
454-
frame[req_key] = np.full(shape, default, dtype=req_dtype)
455-
elif f"find_{req_key}" not in frame:
456-
frame[f"find_{req_key}"] = np.float32(1.0)
466+
data = np.full(shape, default, dtype=req_dtype)
467+
if repeat != 1:
468+
data = np.repeat(data, repeat).reshape(-1)
469+
frame[req_key] = data
470+
else:
471+
if f"find_{req_key}" not in frame:
472+
frame[f"find_{req_key}"] = np.float32(1.0)
473+
# Apply repeat to existing data (e.g. atom_pref repeat=3)
474+
if repeat != 1 and isinstance(frame[req_key], np.ndarray):
475+
frame[req_key] = (
476+
np.repeat(frame[req_key], repeat).reshape(-1).astype(req_dtype)
477+
)
457478

458479
# Add find_* for fparam/aparam/spin if not already set
459480
for extra_key in ["fparam", "aparam", "spin"]:
@@ -1268,22 +1289,17 @@ def _stack_frames(
12681289
np.stack(atypes) if atypes else np.zeros((0, natoms), dtype=np.int64)
12691290
)
12701291

1271-
# Label keys and registered requirements
1292+
# Dynamically discover all data keys present in frames, plus
1293+
# any registered requirements. Structural keys (coord, box, type)
1294+
# are excluded — they are already handled above.
1295+
_structural_keys = frozenset({"coord", "box", "atype"})
12721296
all_keys: dict[str, dict[str, Any]] = {}
1273-
for key in [
1274-
"energy",
1275-
"force",
1276-
"virial",
1277-
"atom_ener",
1278-
"atom_pref",
1279-
"force_mag",
1280-
"spin",
1281-
"fparam",
1282-
"aparam",
1283-
"hessian",
1284-
"efield",
1285-
]:
1286-
all_keys[key] = {"ndof": None, "atomic": False, "default": 0.0}
1297+
for f in frames:
1298+
for fk in f:
1299+
if fk in _structural_keys or fk.startswith("find_"):
1300+
continue
1301+
if fk not in all_keys:
1302+
all_keys[fk] = {"ndof": None, "atomic": False, "default": 0.0}
12871303
for key, req in self._requirements.items():
12881304
all_keys[key] = req
12891305

@@ -1293,12 +1309,20 @@ def _stack_frames(
12931309
)
12941310
result[f"find_{key}"] = 1.0 if has_key else 0.0
12951311

1312+
# Get repeat factor from registered requirements
1313+
repeat = 1
1314+
if key in self._requirements:
1315+
repeat = self._requirements[key].get("repeat", 1)
1316+
12961317
if has_key:
12971318
arrays = []
12981319
for frame in frames:
12991320
val = frame.get(key)
13001321
if isinstance(val, np.ndarray):
1301-
arrays.append(val.astype(self._resolve_dtype(key)).ravel())
1322+
arr = val.astype(self._resolve_dtype(key)).ravel()
1323+
if repeat != 1:
1324+
arr = np.repeat(arr, repeat)
1325+
arrays.append(arr)
13021326
elif val is not None:
13031327
arrays.append(
13041328
np.array([float(val)], dtype=self._resolve_dtype(key))
@@ -1313,8 +1337,9 @@ def _stack_frames(
13131337
None,
13141338
)
13151339
if ref is not None:
1340+
size = ref.size * repeat if repeat != 1 else ref.size
13161341
arrays.append(
1317-
np.zeros(ref.size, dtype=self._resolve_dtype(key))
1342+
np.zeros(size, dtype=self._resolve_dtype(key))
13181343
)
13191344
else:
13201345
arrays.append(np.zeros(1, dtype=self._resolve_dtype(key)))
@@ -1324,9 +1349,9 @@ def _stack_frames(
13241349
atomic = self._requirements[key]["atomic"]
13251350
default = self._requirements[key]["default"]
13261351
if atomic:
1327-
shape = (nframes, natoms * ndof)
1352+
shape = (nframes, natoms * ndof * repeat)
13281353
else:
1329-
shape = (nframes, ndof)
1354+
shape = (nframes, ndof * repeat)
13301355
result[key] = np.full(shape, default, dtype=self._resolve_dtype(key))
13311356

13321357
return result

0 commit comments

Comments
 (0)