Skip to content

Commit 9d2f577

Browse files
author
Han Wang
committed
test(cc): regression test for atomic_virial fail-fast guard
Closes the coverage gap left by commit c80db58: the 'atomic && !do_atomic_virial' throw branch in DeepPotPTExpt::compute was not exercised by any existing test. Approach: * gen_sea.py: after exporting deeppot_sea.pt2 with do_atomic_virial= True, copy the archive to deeppot_sea_no_atomic_virial.pt2 and patch its metadata.json so do_atomic_virial=False. Cheap (ZIP rewrite, no AOTInductor recompile — adds <1s to gen time). Use ZIP_STORED to match the format expected by the C++ read_zip_entry in commonPTExpt.h. * test_deeppot_ptexpt.cc: new TYPED_TEST cpu_atomic_throws_when_ disabled — load the patched .pt2, call compute() with atomic=true, expect deepmd_exception. Also verify atomic=false on the same model still works (sanity check that the guard fires only when actually requested). Covers the most-impactful gap from PR #5407's review: the change in .pt2 default (do_atomic_virial off) made this code path the typical failure mode for users who'd previously been getting per-atom virial "for free". Without this test, regressions in the throw or in the metadata round-trip would slip through.
1 parent e30206d commit 9d2f577

2 files changed

Lines changed: 88 additions & 0 deletions

File tree

source/api_cc/tests/test_deeppot_ptexpt.cc

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,52 @@ TYPED_TEST(TestInferDeepPotAPtExpt, print_summary) {
581581
dp.print_summary("");
582582
}
583583

584+
// Regression test for the fail-fast guard hoisted in commit c80db58d.
585+
// `deeppot_sea_no_atomic_virial.pt2` is a copy of deeppot_sea.pt2 with
586+
// the do_atomic_virial=false flag patched into its metadata.json.
587+
// Calling compute() with atomic=true on this model must throw before
588+
// any tensors are allocated.
589+
TYPED_TEST(TestInferDeepPotAPtExpt, cpu_atomic_throws_when_disabled) {
590+
using VALUETYPE = TypeParam;
591+
deepmd::DeepPot dp_no_av;
592+
ASSERT_NO_THROW(
593+
dp_no_av.init("../../tests/infer/deeppot_sea_no_atomic_virial.pt2"));
594+
595+
std::vector<VALUETYPE>& coord = this->coord;
596+
std::vector<int>& atype = this->atype;
597+
std::vector<VALUETYPE>& box = this->box;
598+
int& natoms = this->natoms;
599+
600+
// Build an LMP-style nlist so we exercise the nlist-overload of
601+
// compute(); the no-nlist overload has the same guard but is
602+
// covered by symmetry.
603+
float rc = dp_no_av.cutoff();
604+
int nloc = coord.size() / 3;
605+
std::vector<VALUETYPE> coord_cpy;
606+
std::vector<int> atype_cpy, mapping;
607+
std::vector<std::vector<int> > nlist_data;
608+
_build_nlist<VALUETYPE>(nlist_data, coord_cpy, atype_cpy, mapping, coord,
609+
atype, box, rc);
610+
int nall = coord_cpy.size() / 3;
611+
std::vector<int> ilist(nloc), numneigh(nloc);
612+
std::vector<int*> firstneigh(nloc);
613+
deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]);
614+
convert_nlist(inlist, nlist_data);
615+
616+
double ener;
617+
std::vector<VALUETYPE> force(nall * 3, 0.0), virial(9, 0.0), atom_ener,
618+
atom_vir;
619+
// atomic=true => guard must trip and throw deepmd_exception.
620+
EXPECT_THROW(
621+
dp_no_av.compute(ener, force, virial, atom_ener, atom_vir, coord_cpy,
622+
atype_cpy, box, nall - nloc, inlist, 0),
623+
deepmd::deepmd_exception);
624+
// atomic=false on the same model must work normally (sanity check
625+
// that the guard fires only when actually requested).
626+
EXPECT_NO_THROW(dp_no_av.compute(ener, force, virial, coord_cpy, atype_cpy,
627+
box, nall - nloc, inlist, 0));
628+
}
629+
584630
template <class VALUETYPE>
585631
class TestInferDeepPotAPtExptNoPbc : public ::testing::Test {
586632
protected:

source/tests/infer/gen_sea.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,50 @@ def main():
5959
print(f"Exporting to {pt2_path} ...") # noqa: T201
6060
deserialize_to_file(pt2_path, data, do_atomic_virial=True)
6161

62+
# Produce a variant for regression-testing the C++ "atomic &&
63+
# !do_atomic_virial" throw path by copying the .pt2 archive and
64+
# flipping the do_atomic_virial flag in its metadata.json — much
65+
# cheaper than running a second AOTInductor compile. The compiled
66+
# graph itself supports atomic virial; only the C++ guard differs.
67+
import shutil
68+
69+
pt2_no_aviral = os.path.join(base_dir, "deeppot_sea_no_atomic_virial.pt2")
70+
print(f"Patching to {pt2_no_aviral} ...") # noqa: T201
71+
shutil.copyfile(pt2_path, pt2_no_aviral)
72+
_patch_no_atomic_virial(pt2_no_aviral)
73+
6274
print("Done!") # noqa: T201
6375

6476

77+
def _patch_no_atomic_virial(pt2_path: str) -> None:
78+
"""Flip do_atomic_virial=False in the metadata.json of a .pt2 archive.
79+
80+
The .pt2 is a ZIP archive; the metadata blob lives at
81+
``extra/metadata.json``. We rewrite the archive with that one entry
82+
replaced and all other entries preserved verbatim.
83+
"""
84+
import json
85+
import zipfile
86+
87+
metadata_name = "extra/metadata.json"
88+
tmp_path = pt2_path + ".tmp"
89+
# PyTorch .pt2 archives use ZIP_STORED (uncompressed) so that the C++
90+
# reader (read_zip_entry in commonPTExpt.h) and torch's mmap-based
91+
# tensor loader can read entries without decompression. Preserve
92+
# that on rewrite — using ZIP_DEFLATED would yield bytes the C++
93+
# reader treats as raw, resulting in JSON parse errors.
94+
with zipfile.ZipFile(pt2_path, "r") as src:
95+
names = src.namelist()
96+
meta = json.loads(src.read(metadata_name).decode("utf-8"))
97+
meta["do_atomic_virial"] = False
98+
with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_STORED) as dst:
99+
for name in names:
100+
if name == metadata_name:
101+
dst.writestr(name, json.dumps(meta))
102+
else:
103+
dst.writestr(name, src.read(name))
104+
os.replace(tmp_path, pt2_path)
105+
106+
65107
if __name__ == "__main__":
66108
main()

0 commit comments

Comments
 (0)