Skip to content

Commit 68c72a3

Browse files
author
Han Wang
committed
test(pt_expt): cover with-comm artifact load-failure dispatch guard
Add gtest cases that exercise the explicit ``use_with_comm && !with_comm_loader`` throw added to DeepPotPTExpt::compute and DeepSpinPTExpt::compute. Fixtures: copies of deeppot_dpa3_mpi.pt2 and deeppot_dpa3_spin_mpi.pt2 with the nested ``model/extra/forward_lower_with_comm.pt2`` entry replaced by garbage bytes, produced by gen_corrupt_with_comm.py via zip rewrite (no AOTI recompilation). Each variant asserts: - init() succeeds (catch path keeps regular artifact usable) - single-rank compute (nswap=0) succeeds (uses regular artifact) - multi-rank compute (nswap=1) throws deepmd::deepmd_exception
1 parent 7632db8 commit 68c72a3

2 files changed

Lines changed: 269 additions & 0 deletions

File tree

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
// SPDX-License-Identifier: LGPL-3.0-or-later
2+
// Tests for the dispatch-site fail-fast guard when the with-comm AOTI
3+
// artifact failed to load at init time. The fixtures are produced by
4+
// source/tests/infer/gen_corrupt_with_comm.py: copies of the valid
5+
// multi-rank .pt2 archives whose nested
6+
// ``model/extra/forward_lower_with_comm.pt2`` entry has been replaced
7+
// with garbage bytes. The outer metadata still claims
8+
// ``has_comm_artifact: true`` so the loader exercises the catch path.
9+
//
10+
// Expectations:
11+
// * init() succeeds (the loader logs and falls back instead of aborting).
12+
// * Single-rank dispatch (nswap == 0) keeps working through the regular
13+
// forward_lower artifact.
14+
// * Multi-rank dispatch (nswap > 0) throws a deepmd::deepmd_exception
15+
// instead of silently dropping the MPI ghost-embedding exchange.
16+
#include <gtest/gtest.h>
17+
18+
#include <fstream>
19+
#include <vector>
20+
21+
#include "DeepPot.h"
22+
// Include the PT_Expt headers so BUILD_PT_EXPT / BUILD_PT_EXPT_SPIN are
23+
// visible to the GTEST_SKIP guard below.
24+
#include "DeepPotPTExpt.h"
25+
#include "DeepSpin.h"
26+
#include "DeepSpinPTExpt.h"
27+
#include "common.h"
28+
#include "neighbor_list.h"
29+
#include "test_utils.h"
30+
31+
namespace {
// Locations of the corrupted fixtures written by
// source/tests/infer/gen_corrupt_with_comm.py. Both are relative paths, so
// they are resolved against the current working directory of the test run.
constexpr const char* kPotCorrupt =
    "../../tests/infer/deeppot_dpa3_mpi_corrupt_with_comm.pt2";
constexpr const char* kSpinCorrupt =
    "../../tests/infer/deeppot_dpa3_spin_mpi_corrupt_with_comm.pt2";

// Returns true when `path` can be opened for reading, false otherwise.
bool file_exists(const char* path) {
  return std::ifstream(path).good();
}
}  // namespace
42+
43+
// ============================================================================
44+
// DeepPot (non-spin) — corrupted with-comm artifact
45+
// ============================================================================
46+
47+
class TestDeepPotPTExptWithCommLoadFailure : public ::testing::Test {
48+
protected:
49+
// Coordinates / atype / box copied from gen_dpa3.py so the regular
50+
// forward_lower artifact has well-formed inputs to evaluate.
51+
std::vector<double> coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
52+
00.25, 3.32, 1.68, 3.36, 3.00, 1.81,
53+
3.51, 2.51, 2.60, 4.27, 3.22, 1.56};
54+
std::vector<int> atype = {0, 1, 1, 0, 1, 1};
55+
std::vector<double> box = {13., 0., 0., 0., 13., 0., 0., 0., 13.};
56+
57+
deepmd::DeepPot dp;
58+
59+
void SetUp() override {
60+
#if !defined(BUILD_PYTORCH) || !BUILD_PT_EXPT
61+
GTEST_SKIP() << "Skip because PyTorch / pt_expt support is not enabled.";
62+
#endif
63+
if (!file_exists(kPotCorrupt)) {
64+
GTEST_SKIP() << "Skipping: " << kPotCorrupt
65+
<< " not found. Run source/tests/infer/"
66+
"gen_corrupt_with_comm.py first.";
67+
}
68+
// Init must succeed: the with-comm loader fails internally and the
69+
// catch block keeps the regular single-rank artifact usable.
70+
ASSERT_NO_THROW(dp.init(kPotCorrupt));
71+
}
72+
};
73+
74+
TEST_F(TestDeepPotPTExptWithCommLoadFailure, single_rank_compute_succeeds) {
  // nswap == 0 (default InputNlist) routes through the regular
  // forward_lower artifact; the broken with-comm artifact is not
  // consulted, so compute must succeed.
  const float rc = dp.cutoff();
  // Explicit casts keep the atom counts as int (the type the neighbor-list
  // helpers and compute() take) without an implicit size_t -> int narrowing.
  const int nloc = static_cast<int>(coord.size() / 3);
  std::vector<double> coord_cpy;
  std::vector<int> atype_cpy, mapping;
  std::vector<std::vector<int>> nlist_data;
  _build_nlist<double>(nlist_data, coord_cpy, atype_cpy, mapping, coord, atype,
                       box, rc);
  const int nall = static_cast<int>(coord_cpy.size() / 3);
  std::vector<int> ilist(nloc), numneigh(nloc);
  std::vector<int*> firstneigh(nloc);
  deepmd::InputNlist inlist(nloc, ilist.data(), numneigh.data(),
                            firstneigh.data());
  convert_nlist(inlist, nlist_data);
  inlist.mapping = mapping.data();
  ASSERT_EQ(inlist.nswap, 0);  // pre-condition: single-rank dispatch

  double ener = 0.0;
  std::vector<double> force_, virial;
  // Fatal on failure: if compute() throws, the outputs stay empty and the
  // size checks below would only add misleading follow-up failures.
  ASSERT_NO_THROW(dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box,
                             nall - nloc, inlist, 0));
  // Compare unsigned with unsigned so the assertions do not trigger
  // sign-compare warnings inside the gtest comparison templates.
  EXPECT_EQ(force_.size(),
            static_cast<std::vector<double>::size_type>(nall) * 3);
  EXPECT_EQ(virial.size(), 9u);
}
101+
102+
TEST_F(TestDeepPotPTExptWithCommLoadFailure, multi_rank_compute_throws) {
  // With nswap > 0 the dispatch site selects ``run_model_with_comm``. The
  // load-failure guard added by PR #5430 has to throw there instead of
  // quietly taking the single-rank path. The send/recv arrays are left
  // null on purpose: the guard fires before any of them is dereferenced.
  const float cutoff = dp.cutoff();
  const int n_local = coord.size() / 3;

  // Extended (local + ghost) system and its neighbor table.
  std::vector<std::vector<int>> neighbors;
  std::vector<double> ext_coord;
  std::vector<int> ext_atype;
  std::vector<int> ext_mapping;
  _build_nlist<double>(neighbors, ext_coord, ext_atype, ext_mapping, coord,
                       atype, box, cutoff);
  const int n_all = ext_coord.size() / 3;
  const int n_ghost = n_all - n_local;

  std::vector<int> ilist(n_local);
  std::vector<int> numneigh(n_local);
  std::vector<int*> firstneigh(n_local);
  deepmd::InputNlist inlist(n_local, ilist.data(), numneigh.data(),
                            firstneigh.data());
  convert_nlist(inlist, neighbors);
  inlist.mapping = ext_mapping.data();
  inlist.nswap = 1;  // simulate multi-rank without populating send/recv

  double ener;
  std::vector<double> force_;
  std::vector<double> virial;
  EXPECT_THROW(dp.compute(ener, force_, virial, ext_coord, ext_atype, box,
                          n_ghost, inlist, 0),
               deepmd::deepmd_exception);
}
129+
130+
// ============================================================================
131+
// DeepSpin — corrupted with-comm artifact
132+
// ============================================================================
133+
134+
class TestDeepSpinPTExptWithCommLoadFailure : public ::testing::Test {
135+
protected:
136+
std::vector<double> coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
137+
00.25, 3.32, 1.68, 3.36, 3.00, 1.81,
138+
3.51, 2.51, 2.60, 4.27, 3.22, 1.56};
139+
// Match deeppot_dpa3_spin_mpi.pt2 spin layout (type 0 has spin, types
140+
// 1+ do not) — spin vector packed alongside coord.
141+
std::vector<double> spin = {0.13, 0.02, 0.03, 0., 0., 0., 0., 0., 0.,
142+
0.14, 0.10, 0.12, 0., 0., 0., 0., 0., 0.};
143+
std::vector<int> atype = {0, 1, 1, 0, 1, 1};
144+
std::vector<double> box = {13., 0., 0., 0., 13., 0., 0., 0., 13.};
145+
146+
deepmd::DeepSpin dp;
147+
148+
void SetUp() override {
149+
#if !defined(BUILD_PYTORCH) || !BUILD_PT_EXPT_SPIN
150+
GTEST_SKIP() << "Skip because PyTorch / pt_expt spin support is not "
151+
"enabled.";
152+
#endif
153+
if (!file_exists(kSpinCorrupt)) {
154+
GTEST_SKIP() << "Skipping: " << kSpinCorrupt
155+
<< " not found. Run source/tests/infer/"
156+
"gen_corrupt_with_comm.py first.";
157+
}
158+
ASSERT_NO_THROW(dp.init(kSpinCorrupt));
159+
}
160+
};
161+
162+
TEST_F(TestDeepSpinPTExptWithCommLoadFailure, single_rank_compute_succeeds) {
  // NoPBC with an all-pairs neighbor list, mirroring the ``cpu_lmp_nlist``
  // pattern in test_deeppot_dpa_ptexpt_spin.cc: nloc == natoms == nall and
  // there are no ghost atoms.
  const int natoms = static_cast<int>(atype.size());
  const std::vector<double> no_box;

  // Row i lists every atom index except i, in ascending order.
  std::vector<std::vector<int>> all_pairs(natoms);
  for (int ii = 0; ii < natoms; ++ii) {
    for (int jj = 0; jj < natoms; ++jj) {
      if (jj != ii) {
        all_pairs[ii].push_back(jj);
      }
    }
  }

  std::vector<int> ilist(natoms);
  std::vector<int> numneigh(natoms);
  std::vector<int*> firstneigh(natoms);
  deepmd::InputNlist inlist(natoms, ilist.data(), numneigh.data(),
                            firstneigh.data());
  convert_nlist(inlist, all_pairs);
  ASSERT_EQ(inlist.nswap, 0);

  double ener;
  std::vector<double> force_;
  std::vector<double> force_mag;
  std::vector<double> virial;
  EXPECT_NO_THROW(dp.compute(ener, force_, force_mag, virial, coord, spin,
                             atype, no_box, 0, inlist, 0));
}
183+
184+
TEST_F(TestDeepSpinPTExptWithCommLoadFailure, multi_rank_compute_throws) {
  // Same NoPBC / all-pairs setup as the single-rank case, but with nswap
  // forced to 1 so that compute() is expected to throw.
  const int natoms = static_cast<int>(atype.size());
  const std::vector<double> no_box;

  // Row i lists every atom index except i, in ascending order.
  std::vector<std::vector<int>> all_pairs(natoms);
  for (int ii = 0; ii < natoms; ++ii) {
    for (int jj = 0; jj < natoms; ++jj) {
      if (jj != ii) {
        all_pairs[ii].push_back(jj);
      }
    }
  }

  std::vector<int> ilist(natoms);
  std::vector<int> numneigh(natoms);
  std::vector<int*> firstneigh(natoms);
  deepmd::InputNlist inlist(natoms, ilist.data(), numneigh.data(),
                            firstneigh.data());
  convert_nlist(inlist, all_pairs);
  inlist.nswap = 1;  // simulate multi-rank without populating send/recv

  double ener;
  std::vector<double> force_;
  std::vector<double> force_mag;
  std::vector<double> virial;
  EXPECT_THROW(dp.compute(ener, force_, force_mag, virial, coord, spin, atype,
                          no_box, 0, inlist, 0),
               deepmd::deepmd_exception);
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#!/usr/bin/env python3
2+
# SPDX-License-Identifier: LGPL-3.0-or-later
3+
"""Generate ``deeppot_*_corrupt_with_comm.pt2`` fixtures.
4+
5+
The fixtures are copies of the corresponding multi-rank ``.pt2`` archives
6+
in which the nested ``model/extra/forward_lower_with_comm.pt2`` entry has
7+
been overwritten with garbage bytes. The outer metadata still claims
8+
``has_comm_artifact: true``, so:
9+
10+
- ``DeepPotPTExpt::init`` / ``DeepSpinPTExpt::init`` exercise the
11+
try/catch fallback path on the with-comm AOTI loader.
12+
- Single-rank dispatch (``nswap == 0``) keeps working via the regular
13+
artifact.
14+
- Multi-rank dispatch (``nswap > 0``) hits the explicit dispatch-site
15+
throw added in PR #5430, instead of silently dropping the MPI
16+
ghost-embedding exchange.
17+
18+
Consumed by ``source/api_cc/tests/test_with_comm_load_failure_ptexpt.cc``.
19+
"""
20+
21+
import os
22+
import zipfile
23+
24+
# Archive member that holds the nested with-comm AOTI artifact.
WITH_COMM_ENTRY = "model/extra/forward_lower_with_comm.pt2"
# Replacement payload written in place of the nested artifact.
GARBAGE = b"NOT_A_VALID_AOTI_ARCHIVE_" * 32


def corrupt_with_comm(src: str, dst: str) -> None:
    """Copy ``src`` to ``dst`` with the nested with-comm entry replaced.

    Every member of the ``src`` zip archive is copied to ``dst`` unchanged,
    except ``WITH_COMM_ENTRY``, whose content is replaced by ``GARBAGE``.

    Parameters
    ----------
    src : str
        Path of the existing archive to read.
    dst : str
        Path of the archive to create (overwritten if it already exists).

    Raises
    ------
    RuntimeError
        If ``src`` has no ``WITH_COMM_ENTRY`` member. The check runs before
        ``dst`` is opened, so no output file is created or truncated in that
        case.
    """
    with zipfile.ZipFile(src, "r") as zin:
        # Validate first: writing dst and only then discovering the entry is
        # missing would leave an uncorrupted copy under the "corrupt" name.
        if WITH_COMM_ENTRY not in zin.namelist():
            raise RuntimeError(
                f"{src} does not contain {WITH_COMM_ENTRY}; cannot corrupt."
            )
        with zipfile.ZipFile(dst, "w", compression=zipfile.ZIP_STORED) as zout:
            for info in zin.infolist():
                if info.filename == WITH_COMM_ENTRY:
                    data = GARBAGE
                else:
                    data = zin.read(info.filename)
                zout.writestr(info, data)
45+
46+
47+
def main() -> None:
    """Write the corrupted fixtures next to this script.

    For each (source, destination) file-name pair, the source archive is
    looked up in this script's directory. If it exists, the corrupted copy is
    written beside it; otherwise the pair is skipped with a message.
    """
    base_dir = os.path.dirname(__file__)
    fixtures = (
        ("deeppot_dpa3_mpi.pt2", "deeppot_dpa3_mpi_corrupt_with_comm.pt2"),
        (
            "deeppot_dpa3_spin_mpi.pt2",
            "deeppot_dpa3_spin_mpi_corrupt_with_comm.pt2",
        ),
    )
    for src_name, dst_name in fixtures:
        src = os.path.join(base_dir, src_name)
        dst = os.path.join(base_dir, dst_name)
        if os.path.exists(src):
            corrupt_with_comm(src, dst)
            print(f"Wrote {dst}")  # noqa: T201
        else:
            print(f"Skipping {dst_name}: source {src_name} not found.")  # noqa: T201


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)