Skip to content

Commit abf9575

Browse files
wanghan-iapcm and Han Wang authored
fix: unify compute_output_stats across dp, pt, and pd backends (#5267)
- Fix `elif` → `if` for index classification: A system can have both global and atomic labels simultaneously. The old `elif` logic meant a system with both `find_atom_energy` and `find_energy` would only be indexed for atomic, silently dropping its global label. Changed to two independent `if` checks in all three backends.
- Add `global_sampled_idx`/`atomic_sampled_idx` parameters: `compute_output_stats_global` and `compute_output_stats_atomic` in pt and pd backends now accept precomputed index dicts (matching dpmodel's signature) instead of re-scanning systems internally.
- Support mixed type in dpmodel: `compute_output_stats_global` in dpmodel now checks for `real_natoms_vec` (previously hardcoded `natoms_key = "natoms"`).
- Apply `atom_exclude_types` mask to natoms in dpmodel: dpmodel was missing the exclude-type mask on natoms that pt/pd already had.
- Fix in-place mutation of input data: All three backends were mutating `sampled[i]["natoms"]` (or `real_natoms_vec`) in-place when `atom_exclude_types` was present. Now the mask is applied to a local copy, leaving the caller's data untouched.
- Add cross-backend consistency tests: New test file `source/tests/consistent/utils/test_stat.py` with 48 tests covering dp-vs-pt and dp-vs-pd consistency for `compute_output_stats_global`, `compute_output_stats_atomic`, and the top-level `compute_output_stats`, plus no-mutation verification.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Support excluding atom types from global natoms calculations.
  * Allow a system to be counted in both global and atomic sampling simultaneously.
* **Refactor**
  * Switched statistics assembly to index-based gathering for robust mixed-type handling.
  * Standardized numeric assembly/reshaping across backends for consistent merging.
* **Tests**
  * Added comprehensive cross-backend consistency tests (global, atomic, mixed types, exclusions).
* **Chores**
  * Minor unpacking/cleanup to remove unused fields.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent dfb8c3f commit abf9575

7 files changed

Lines changed: 642 additions & 122 deletions

File tree

deepmd/dpmodel/utils/env_mat_stat.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,10 @@ def iter(
128128
device=array_api_compat.device(data[0]["coord"]),
129129
)
130130
for system in data:
131-
coord, atype, box, natoms = (
131+
coord, atype, box = (
132132
system["coord"],
133133
system["atype"],
134134
system["box"],
135-
system["natoms"],
136135
)
137136
(
138137
extended_coord,

deepmd/dpmodel/utils/stat.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
from deepmd.dpmodel.common import (
1515
to_numpy_array,
1616
)
17+
from deepmd.dpmodel.utils.exclude_mask import (
18+
AtomExcludeMask,
19+
)
1720
from deepmd.utils.out_stat import (
1821
compute_stats_do_not_distinguish_types,
1922
compute_stats_from_atomic,
@@ -245,10 +248,8 @@ def compute_output_stats(
245248
system["find_atom_" + kk] > 0.0
246249
):
247250
atomic_sampled_idx[kk].append(idx)
248-
elif (("find_" + kk) in system) and (system["find_" + kk] > 0.0):
251+
if (("find_" + kk) in system) and (system["find_" + kk] > 0.0):
249252
global_sampled_idx[kk].append(idx)
250-
else:
251-
continue
252253

253254
# use index to gather model predictions for the corresponding systems.
254255
model_pred_g = (
@@ -291,7 +292,7 @@ def compute_output_stats(
291292
)
292293

293294
# compute stat
294-
bias_atom_g, std_atom_g = compute_output_stats_global(
295+
bias_atom_g, std_atom_g = _compute_output_stats_global(
295296
sampled,
296297
ntypes,
297298
keys,
@@ -302,7 +303,7 @@ def compute_output_stats(
302303
intensive,
303304
model_pred_g,
304305
)
305-
bias_atom_a, std_atom_a = compute_output_stats_atomic(
306+
bias_atom_a, std_atom_a = _compute_output_stats_atomic(
306307
sampled,
307308
ntypes,
308309
keys,
@@ -335,7 +336,7 @@ def compute_output_stats(
335336
return bias_atom_e, std_atom_e
336337

337338

338-
def compute_output_stats_global(
339+
def _compute_output_stats_global(
339340
sampled: list[dict],
340341
ntypes: int,
341342
keys: list[str],
@@ -359,14 +360,21 @@ def compute_output_stats_global(
359360
for kk in keys
360361
}
361362

362-
natoms_key = "natoms"
363-
input_natoms = {
364-
kk: [
365-
to_numpy_array(sampled[idx][natoms_key])
366-
for idx in global_sampled_idx.get(kk, [])
367-
]
368-
for kk in keys
369-
}
363+
data_mixed_type = "real_natoms_vec" in sampled[0]
364+
natoms_key = "natoms" if not data_mixed_type else "real_natoms_vec"
365+
input_natoms = {}
366+
for kk in keys:
367+
kk_natoms = []
368+
for idx in global_sampled_idx.get(kk, []):
369+
nn = to_numpy_array(sampled[idx][natoms_key])
370+
if "atom_exclude_types" in sampled[idx]:
371+
nn = nn.copy()
372+
type_mask = AtomExcludeMask(
373+
ntypes, sampled[idx]["atom_exclude_types"]
374+
).get_type_mask()
375+
nn[:, 2:] *= type_mask.reshape(1, -1)
376+
kk_natoms.append(nn)
377+
input_natoms[kk] = kk_natoms
370378

371379
# shape: (nframes, ndim)
372380
merged_output = {
@@ -453,7 +461,7 @@ def rmse(x: np.ndarray) -> float:
453461
return bias_atom_e, std_atom_e
454462

455463

456-
def compute_output_stats_atomic(
464+
def _compute_output_stats_atomic(
457465
sampled: list[dict],
458466
ntypes: int,
459467
keys: list[str],

deepmd/pd/utils/env_mat_stat.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,10 @@ def iter(
107107
"last_dim should be 1 for raial-only or 4 for full descriptor."
108108
)
109109
for system in data:
110-
coord, atype, box, natoms = (
110+
coord, atype, box = (
111111
system["coord"],
112112
system["atype"],
113113
system["box"],
114-
system["natoms"],
115114
)
116115
(
117116
extended_coord,

deepmd/pd/utils/stat.py

Lines changed: 48 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -167,11 +167,10 @@ def _compute_model_predict(
167167
model_predict = {kk: [] for kk in keys}
168168
for system in sampled:
169169
nframes = system["coord"].shape[0]
170-
coord, atype, box, natoms = (
170+
coord, atype, box = (
171171
system["coord"],
172172
system["atype"],
173173
system["box"],
174-
system["natoms"],
175174
)
176175
fparam = system.get("fparam", None)
177176
aparam = system.get("aparam", None)
@@ -324,12 +323,9 @@ def compute_output_stats(
324323
system["find_atom_" + kk] > 0.0
325324
):
326325
atomic_sampled_idx[kk].append(idx)
327-
elif (("find_" + kk) in system) and (system["find_" + kk] > 0.0):
326+
if (("find_" + kk) in system) and (system["find_" + kk] > 0.0):
328327
global_sampled_idx[kk].append(idx)
329328

330-
else:
331-
continue
332-
333329
# use index to gather model predictions for the corresponding systems.
334330

335331
model_pred_g = (
@@ -372,20 +368,22 @@ def compute_output_stats(
372368
)
373369

374370
# compute stat
375-
bias_atom_g, std_atom_g = compute_output_stats_global(
371+
bias_atom_g, std_atom_g = _compute_output_stats_global(
376372
sampled,
377373
ntypes,
378374
keys,
379375
rcond,
380376
preset_bias,
381-
model_pred_g,
377+
global_sampled_idx,
382378
stats_distinguish_types,
383379
intensive,
380+
model_pred_g,
384381
)
385-
bias_atom_a, std_atom_a = compute_output_stats_atomic(
382+
bias_atom_a, std_atom_a = _compute_output_stats_atomic(
386383
sampled,
387384
ntypes,
388385
keys,
386+
atomic_sampled_idx,
389387
model_pred_a,
390388
)
391389

@@ -416,58 +414,52 @@ def compute_output_stats(
416414
return bias_atom_e, std_atom_e
417415

418416

419-
def compute_output_stats_global(
417+
def _compute_output_stats_global(
420418
sampled: list[dict],
421419
ntypes: int,
422420
keys: list[str],
423421
rcond: float | None = None,
424422
preset_bias: dict[str, list[paddle.Tensor | None]] | None = None,
425-
model_pred: dict[str, np.ndarray] | None = None,
423+
global_sampled_idx: dict | None = None,
426424
stats_distinguish_types: bool = True,
427425
intensive: bool = False,
426+
model_pred: dict[str, np.ndarray] | None = None,
428427
) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
429428
"""This function only handle stat computation from reduced global labels."""
430-
# return directly if model predict is empty for global
431-
if model_pred == {}:
429+
# return directly if no global samples
430+
if global_sampled_idx is None or all(
431+
len(v) == 0 for v in global_sampled_idx.values()
432+
):
432433
return {}, {}
433434

434435
# get label dict from sample; for each key, only picking the system with global labels.
435436
outputs = {
436-
kk: [
437-
system[kk]
438-
for system in sampled
439-
if kk in system and system.get(f"find_{kk}", 0) > 0
440-
]
437+
kk: [to_numpy_array(sampled[idx][kk]) for idx in global_sampled_idx.get(kk, [])]
441438
for kk in keys
442439
}
443440

444441
data_mixed_type = "real_natoms_vec" in sampled[0]
445442
natoms_key = "natoms" if not data_mixed_type else "real_natoms_vec"
446-
for system in sampled:
447-
if "atom_exclude_types" in system:
448-
type_mask = AtomExcludeMask(
449-
ntypes, system["atom_exclude_types"]
450-
).get_type_mask()
451-
system[natoms_key][:, 2:] *= type_mask.unsqueeze(0)
452-
453-
input_natoms = {
454-
kk: [
455-
item[natoms_key]
456-
for item in sampled
457-
if kk in item and item.get(f"find_{kk}", 0) > 0
458-
]
459-
for kk in keys
460-
}
443+
input_natoms = {}
444+
for kk in keys:
445+
kk_natoms = []
446+
for idx in global_sampled_idx.get(kk, []):
447+
nn = to_numpy_array(sampled[idx][natoms_key])
448+
if "atom_exclude_types" in sampled[idx]:
449+
nn = nn.copy()
450+
type_mask = AtomExcludeMask(
451+
ntypes, sampled[idx]["atom_exclude_types"]
452+
).get_type_mask()
453+
nn[:, 2:] *= to_numpy_array(type_mask).reshape(1, -1)
454+
kk_natoms.append(nn)
455+
input_natoms[kk] = kk_natoms
461456
# shape: (nframes, ndim)
462457
merged_output = {
463-
kk: to_numpy_array(paddle.concat(outputs[kk]))
464-
for kk in keys
465-
if len(outputs[kk]) > 0
458+
kk: np.concatenate(outputs[kk]) for kk in keys if len(outputs[kk]) > 0
466459
}
467460
# shape: (nframes, ntypes)
468-
469461
merged_natoms = {
470-
kk: to_numpy_array(paddle.concat(input_natoms[kk])[:, 2:])
462+
kk: np.concatenate(input_natoms[kk])[:, 2:]
471463
for kk in keys
472464
if len(input_natoms[kk]) > 0
473465
}
@@ -550,53 +542,55 @@ def rmse(x: np.ndarray) -> float:
550542
return bias_atom_e, std_atom_e
551543

552544

553-
def compute_output_stats_atomic(
545+
def _compute_output_stats_atomic(
554546
sampled: list[dict],
555547
ntypes: int,
556548
keys: list[str],
549+
atomic_sampled_idx: dict | None = None,
557550
model_pred: dict[str, np.ndarray] | None = None,
558551
) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
552+
"""Compute output statistics from atomic labels."""
553+
# return directly if no atomic samples
554+
if atomic_sampled_idx is None or all(
555+
len(v) == 0 for v in atomic_sampled_idx.values()
556+
):
557+
return {}, {}
558+
559559
# get label dict from sample; for each key, only picking the system with atomic labels.
560560
outputs = {
561561
kk: [
562-
system["atom_" + kk]
563-
for system in sampled
564-
if ("atom_" + kk) in system and system.get(f"find_atom_{kk}", 0) > 0
562+
to_numpy_array(sampled[idx]["atom_" + kk])
563+
for idx in atomic_sampled_idx.get(kk, [])
565564
]
566565
for kk in keys
567566
}
568567
natoms = {
569568
kk: [
570-
system["atype"]
571-
for system in sampled
572-
if ("atom_" + kk) in system and system.get(f"find_atom_{kk}", 0) > 0
569+
to_numpy_array(sampled[idx]["atype"])
570+
for idx in atomic_sampled_idx.get(kk, [])
573571
]
574572
for kk in keys
575573
}
576574
# reshape outputs [nframes, nloc * ndim] --> reshape to [nframes * nloc, 1, ndim] for concatenation
577575
# reshape natoms [nframes, nloc] --> reshape to [nframes * nloc, 1] for concatenation
578-
natoms = {k: [sys_v.reshape([-1, 1]) for sys_v in v] for k, v in natoms.items()}
576+
natoms = {k: [sys_v.reshape(-1, 1) for sys_v in v] for k, v in natoms.items()}
579577
outputs = {
580578
k: [
581-
sys.reshape([natoms[k][sys_idx].shape[0], 1, -1])
579+
sys.reshape(natoms[k][sys_idx].shape[0], 1, -1)
582580
for sys_idx, sys in enumerate(v)
583581
]
584582
for k, v in outputs.items()
585583
}
586584

587585
merged_output = {
588-
kk: to_numpy_array(paddle.concat(outputs[kk]))
589-
for kk in keys
590-
if len(outputs[kk]) > 0
586+
kk: np.concatenate(outputs[kk]) for kk in keys if len(outputs[kk]) > 0
591587
}
592588
merged_natoms = {
593-
kk: to_numpy_array(paddle.concat(natoms[kk]))
594-
for kk in keys
595-
if len(natoms[kk]) > 0
589+
kk: np.concatenate(natoms[kk]) for kk in keys if len(natoms[kk]) > 0
596590
}
597591
# reshape merged data to [nf, nloc, ndim]
598592
merged_output = {
599-
kk: merged_output[kk].reshape([*merged_natoms[kk].shape, -1])
593+
kk: merged_output[kk].reshape((*merged_natoms[kk].shape, -1))
600594
for kk in merged_output
601595
}
602596

0 commit comments

Comments
 (0)