Skip to content

Commit b6e7db3

Browse files
author
Han Wang
committed
fix: address reviewer comments on PR #5334
- deep_eval.py: add .contiguous() on expanded default fparam; raise ValueError when no default available instead of silent zero injection - test_deeppot_ptexpt.cc: use distinct coord_alt for frame 1 in multi-frame tests; assert temp file creation in parser error tests - test_model_compression.py: assert force/virial on .pt2 compress path - test_dp_freeze.py: assert virial shape in .pt2 smoke test - test_finetune.py: copy pretrained checkpoint before finetune overwrites
1 parent 78d5bde commit b6e7db3

5 files changed

Lines changed: 45 additions & 22 deletions

File tree

deepmd/pt_expt/infer/deep_eval.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -630,10 +630,12 @@ def _eval_model(
630630
torch.tensor(default_fp, dtype=torch.float64, device=DEVICE)
631631
.unsqueeze(0)
632632
.expand(nframes, -1)
633+
.contiguous()
633634
)
634635
else:
635-
fparam_t = torch.zeros(
636-
nframes, self.get_dim_fparam(), dtype=torch.float64, device=DEVICE
636+
raise ValueError(
637+
f"fparam is required for this model (dim_fparam={self.get_dim_fparam()}) "
638+
"but was not provided, and no default_fparam is stored in the model."
637639
)
638640
else:
639641
fparam_t = None

source/api_cc/tests/test_deeppot_ptexpt.cc

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ class TestInferDeepPotAPtExpt : public ::testing::Test {
1919
std::vector<VALUETYPE> coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
2020
00.25, 3.32, 1.68, 3.36, 3.00, 1.81,
2121
3.51, 2.51, 2.60, 4.27, 3.22, 1.56};
22+
// Alternative coords for multi-frame tests (must give different energy)
23+
std::vector<VALUETYPE> coord_alt = {10.06, 5.71, 11.16, 9.07, 1.22, 12.68,
24+
9.89, 10.22, 1.67, 5.86, 4.82, 12.05,
25+
8.37, 10.70, 5.76, 2.95, 7.21, 0.83};
2226
std::vector<int> atype = {0, 1, 1, 0, 1, 1};
2327
std::vector<VALUETYPE> box = {13., 0., 0., 0., 13., 0., 0., 0., 13.};
2428
// Same reference values as test_deeppot_pt.cc (model converted from .pth)
@@ -414,6 +418,9 @@ class TestInferDeepPotAPtExptNoPbc : public ::testing::Test {
414418
std::vector<VALUETYPE> coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
415419
00.25, 3.32, 1.68, 3.36, 3.00, 1.81,
416420
3.51, 2.51, 2.60, 4.27, 3.22, 1.56};
421+
std::vector<VALUETYPE> coord_alt = {10.06, 5.71, 11.16, 9.07, 1.22, 12.68,
422+
9.89, 10.22, 1.67, 5.86, 4.82, 12.05,
423+
8.37, 10.70, 5.76, 2.95, 7.21, 0.83};
417424
std::vector<int> atype = {0, 1, 1, 0, 1, 1};
418425
std::vector<VALUETYPE> box = {};
419426
// Same reference values as TestInferDeepPotAPtNoPbc in test_deeppot_pt.cc
@@ -541,6 +548,7 @@ TYPED_TEST(TestInferDeepPotAPtExptNoPbc, cpu_build_nlist_atomic) {
541548
TYPED_TEST(TestInferDeepPotAPtExpt, cpu_build_nlist_nframes) {
542549
using VALUETYPE = TypeParam;
543550
std::vector<VALUETYPE>& coord = this->coord;
551+
std::vector<VALUETYPE>& coord_alt = this->coord_alt;
544552
std::vector<int>& atype = this->atype;
545553
std::vector<VALUETYPE>& box = this->box;
546554
std::vector<VALUETYPE>& expected_f = this->expected_f;
@@ -550,8 +558,9 @@ TYPED_TEST(TestInferDeepPotAPtExpt, cpu_build_nlist_nframes) {
550558
deepmd::DeepPot& dp = this->dp;
551559

552560
int nframes = 2;
561+
// Frame 0: original coords. Frame 1: alternative coords (coord_alt).
553562
std::vector<VALUETYPE> coord_2f(coord);
554-
coord_2f.insert(coord_2f.end(), coord.begin(), coord.end());
563+
coord_2f.insert(coord_2f.end(), coord_alt.begin(), coord_alt.end());
555564
std::vector<int> atype_2f(atype);
556565
atype_2f.insert(atype_2f.end(), atype.begin(), atype.end());
557566
std::vector<VALUETYPE> box_2f(box);
@@ -566,22 +575,21 @@ TYPED_TEST(TestInferDeepPotAPtExpt, cpu_build_nlist_nframes) {
566575
EXPECT_EQ(force.size(), nframes * natoms * 3);
567576
EXPECT_EQ(virial.size(), nframes * 9);
568577

569-
for (int ff = 0; ff < nframes; ++ff) {
570-
EXPECT_LT(fabs(ener[ff] - expected_tot_e), EPSILON);
571-
for (int ii = 0; ii < natoms * 3; ++ii) {
572-
EXPECT_LT(fabs(force[ff * natoms * 3 + ii] - expected_f[ii]), EPSILON);
573-
}
574-
for (int ii = 0; ii < 9; ++ii) {
575-
EXPECT_LT(fabs(virial[ff * 9 + ii] - expected_tot_v[ii]), EPSILON);
576-
}
578+
// Frame 0 should match reference
579+
EXPECT_LT(fabs(ener[0] - expected_tot_e), EPSILON);
580+
for (int ii = 0; ii < natoms * 3; ++ii) {
581+
EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
577582
}
583+
// Frame 1 should be different (perturbed coords)
584+
EXPECT_GT(fabs(ener[1] - ener[0]), 1e-10);
578585
}
579586

580587
// Multi-frame NoPBC test via compute_mixed_type
581588
TYPED_TEST(TestInferDeepPotAPtExptNoPbc, cpu_build_nlist_nframes) {
582589
using VALUETYPE = TypeParam;
583590
std::vector<VALUETYPE>& coord = this->coord;
584591
std::vector<int>& atype = this->atype;
592+
std::vector<VALUETYPE>& coord_alt = this->coord_alt;
585593
std::vector<VALUETYPE>& box = this->box; // empty
586594
std::vector<VALUETYPE>& expected_f = this->expected_f;
587595
int& natoms = this->natoms;
@@ -591,7 +599,7 @@ TYPED_TEST(TestInferDeepPotAPtExptNoPbc, cpu_build_nlist_nframes) {
591599

592600
int nframes = 2;
593601
std::vector<VALUETYPE> coord_2f(coord);
594-
coord_2f.insert(coord_2f.end(), coord.begin(), coord.end());
602+
coord_2f.insert(coord_2f.end(), coord_alt.begin(), coord_alt.end());
595603
std::vector<int> atype_2f(atype);
596604
atype_2f.insert(atype_2f.end(), atype.begin(), atype.end());
597605

@@ -603,15 +611,13 @@ TYPED_TEST(TestInferDeepPotAPtExptNoPbc, cpu_build_nlist_nframes) {
603611
EXPECT_EQ(force.size(), nframes * natoms * 3);
604612
EXPECT_EQ(virial.size(), nframes * 9);
605613

606-
for (int ff = 0; ff < nframes; ++ff) {
607-
EXPECT_LT(fabs(ener[ff] - expected_tot_e), EPSILON);
608-
for (int ii = 0; ii < natoms * 3; ++ii) {
609-
EXPECT_LT(fabs(force[ff * natoms * 3 + ii] - expected_f[ii]), EPSILON);
610-
}
611-
for (int ii = 0; ii < 9; ++ii) {
612-
EXPECT_LT(fabs(virial[ff * 9 + ii] - expected_tot_v[ii]), EPSILON);
613-
}
614+
// Frame 0 should match reference
615+
EXPECT_LT(fabs(ener[0] - expected_tot_e), EPSILON);
616+
for (int ii = 0; ii < natoms * 3; ++ii) {
617+
EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
614618
}
619+
// Frame 1 should be different (perturbed coords)
620+
EXPECT_GT(fabs(ener[1] - ener[0]), 1e-10);
615621
}
616622

617623
// ========== Parser / metadata coverage tests ==========
@@ -631,6 +637,7 @@ TEST(TestDeepPotPTExptParser, load_invalid_zip) {
631637
std::string tmpfile = "test_invalid.pt2";
632638
{
633639
std::ofstream ofs(tmpfile, std::ios::binary);
640+
ASSERT_TRUE(ofs.is_open()) << "Failed to create temp file";
634641
ofs << "not a zip file at all";
635642
}
636643
deepmd::DeepPot dp;
@@ -645,6 +652,7 @@ TEST(TestDeepPotPTExptParser, load_tiny_file) {
645652
std::string tmpfile = "test_tiny.pt2";
646653
{
647654
std::ofstream ofs(tmpfile, std::ios::binary);
655+
ASSERT_TRUE(ofs.is_open()) << "Failed to create temp file";
648656
ofs << "abc";
649657
}
650658
deepmd::DeepPot dp;

source/tests/pt_expt/model/test_model_compression.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,12 @@ def test_freeze_compress_eval_pt2(self) -> None:
341341
np.testing.assert_allclose(
342342
ret_frozen["energy"], e_pt2, atol=1e-7, err_msg="energy"
343343
)
344+
np.testing.assert_allclose(
345+
ret_frozen["force"], f_pt2, atol=1e-7, err_msg="force"
346+
)
347+
np.testing.assert_allclose(
348+
ret_frozen["virial"], v_pt2, atol=1e-7, err_msg="virial"
349+
)
344350
finally:
345351
os.unlink(frozen_path)
346352
if os.path.exists(compressed_pt2_path):

source/tests/pt_expt/test_dp_freeze.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def test_freeze_pt2(self) -> None:
124124
e, f, v = dp.eval(coord, box, atype)
125125
self.assertEqual(e.shape, (1, 1))
126126
self.assertEqual(f.shape, (1, 3, 3))
127+
self.assertEqual(v.shape, (1, 9))
127128

128129
def test_freeze_pt2_eval_consistency(self) -> None:
129130
"""Verify .pte and .pt2 produce identical results."""

source/tests/pt_expt/test_finetune.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,10 @@ def test_finetune_from_pt2(self) -> None:
899899
freeze(model=ckpt_path, output=pt2_path)
900900
self.assertTrue(os.path.exists(pt2_path))
901901

902+
# Save pretrained weights before finetune overwrites model.ckpt.pt
903+
pre_ckpt_copy = os.path.join(tmpdir, "pretrained_copy.pt")
904+
shutil.copy2(ckpt_path, pre_ckpt_copy)
905+
902906
# Phase 3: finetune from .pt2 via CLI (lr=0 so weights stay unchanged)
903907
ft_config = _make_config(self.data_dir, model_se_e2_a, numb_steps=1)
904908
ft_config["learning_rate"]["start_lr"] = 1e-30
@@ -926,8 +930,10 @@ def test_finetune_from_pt2(self) -> None:
926930
ft_model_state = ft_state["model"] if "model" in ft_state else ft_state
927931
self.assertIn("_extra_state", ft_model_state)
928932

929-
# Load pretrained from .pt for weight comparison
930-
pre_state = torch.load(ckpt_path, map_location=DEVICE, weights_only=True)
933+
# Load pretrained from saved copy (finetune overwrites model.ckpt.pt)
934+
pre_state = torch.load(
935+
pre_ckpt_copy, map_location=DEVICE, weights_only=True
936+
)
931937
pre_model_state = pre_state["model"] if "model" in pre_state else pre_state
932938

933939
# Inherited weights must match pretrained

0 commit comments

Comments
 (0)