Skip to content

Commit d3a57f2

Browse files
author
Han Wang
committed
test(pt_expt): distinct seeds for multi-task heads + RUF059 cleanup
CodeRabbit flagged that `TestPtExptLoadPtMultiTask` built both heads with the same `GLOBAL_SEED`, so `test_select_head_matches_single_task_forward` would still pass if `_load_pt` accidentally loaded the wrong head's weights. Mirror the spin variant: pass distinct seeds (42/7) to `_build_model_and_params` for the two heads, and add `test_distinct_heads_produce_distinct_outputs` as a sanity guard. Also prefix unused unpack vars with `_` to satisfy RUF059.
1 parent b3bab4b commit d3a57f2

1 file changed

Lines changed: 40 additions & 13 deletions

File tree

source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py

Lines changed: 40 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -52,8 +52,14 @@
5252
)
5353

5454

55-
def _build_model_and_params(rcut: float = 4.0) -> tuple[EnergyModel, dict]:
56-
"""Build a small pt_expt EnergyModel and the matching ``model_params`` dict."""
55+
def _build_model_and_params(
56+
rcut: float = 4.0, seed: int = GLOBAL_SEED
57+
) -> tuple[EnergyModel, dict]:
58+
"""Build a small pt_expt EnergyModel and the matching ``model_params`` dict.
59+
60+
The ``seed`` parameter lets callers build distinguishable models when
61+
they need head-selection tests to produce different outputs per head.
62+
"""
5763
type_map = ["foo", "bar"]
5864
sel = [8, 6]
5965
descriptor_args = {
@@ -64,13 +70,13 @@ def _build_model_and_params(rcut: float = 4.0) -> tuple[EnergyModel, dict]:
6470
"neuron": [4, 8],
6571
"axis_neuron": 4,
6672
"type_one_side": True,
67-
"seed": GLOBAL_SEED,
73+
"seed": seed,
6874
}
6975
fitting_args = {
7076
"type": "ener",
7177
"neuron": [8, 8],
7278
"resnet_dt": True,
73-
"seed": GLOBAL_SEED,
79+
"seed": seed,
7480
}
7581

7682
ds = DescrptSeA(
@@ -80,15 +86,15 @@ def _build_model_and_params(rcut: float = 4.0) -> tuple[EnergyModel, dict]:
8086
neuron=[4, 8],
8187
axis_neuron=4,
8288
type_one_side=True,
83-
seed=GLOBAL_SEED,
89+
seed=seed,
8490
)
8591
ft = EnergyFittingNet(
8692
len(type_map),
8793
ds.get_dim_out(),
8894
neuron=[8, 8],
8995
resnet_dt=True,
9096
mixed_types=ds.mixed_types(),
91-
seed=GLOBAL_SEED,
97+
seed=seed,
9298
)
9399
model = EnergyModel(ds, ft, type_map=type_map).to(torch.float64).eval()
94100

@@ -388,9 +394,11 @@ class TestPtExptLoadPtMultiTask(unittest.TestCase):
388394
@classmethod
389395
def setUpClass(cls) -> None:
390396
# Build two single-task models with the same architecture but
391-
# different seeds, then save a multi-task-style checkpoint.
392-
cls.model_a, params_a = _build_model_and_params(rcut=4.0)
393-
cls.model_b, params_b = _build_model_and_params(rcut=4.0)
397+
# different seeds. Distinct seeds matter so that a head-routing
398+
# bug (loading head_b's weights when head_a is requested, or
399+
# vice versa) actually shows up as an assertion failure.
400+
cls.model_a, params_a = _build_model_and_params(rcut=4.0, seed=42)
401+
cls.model_b, params_b = _build_model_and_params(rcut=4.0, seed=7)
394402
cls.models = {"head_a": cls.model_a, "head_b": cls.model_b}
395403
cls.model_params = {"model_dict": {"head_a": params_a, "head_b": params_b}}
396404

@@ -423,7 +431,7 @@ def test_select_head_matches_single_task_forward(self) -> None:
423431
# Build a DeepPot wrapping this DeepEval for end-to-end eval.
424432
dp = DeepPot(self.pt_path, head=head)
425433
de = dp.deep_eval
426-
e, f, v = dp.eval(coords, cells, atom_types, atomic=False)
434+
e, f, _v = dp.eval(coords, cells, atom_types, atomic=False)
427435

428436
coord_t = torch.tensor(
429437
coords, dtype=torch.float64, device=DEVICE
@@ -450,6 +458,25 @@ def test_select_head_matches_single_task_forward(self) -> None:
450458
)
451459
self.assertEqual(de.get_type_map(), src.get_type_map())
452460

461+
def test_distinct_heads_produce_distinct_outputs(self) -> None:
462+
"""Sanity check that head_a and head_b really resolve to different weights."""
463+
rng = np.random.default_rng(GLOBAL_SEED + 2)
464+
natoms = 4
465+
coords = rng.random((1, natoms, 3)) * 8.0
466+
cells = np.eye(3).reshape(1, 9) * 10.0
467+
atom_types = np.array([i % 2 for i in range(natoms)], dtype=np.int32)
468+
e_a = DeepPot(self.pt_path, head="head_a").eval(
469+
coords, cells, atom_types, atomic=False
470+
)[0]
471+
e_b = DeepPot(self.pt_path, head="head_b").eval(
472+
coords, cells, atom_types, atomic=False
473+
)[0]
474+
self.assertFalse(
475+
np.allclose(e_a, e_b),
476+
"head_a and head_b produced identical outputs — head selection "
477+
"may be loading the wrong weights",
478+
)
479+
453480
def test_missing_head_raises(self) -> None:
454481
with self.assertRaisesRegex(ValueError, "Head 'no_such_head' not found"):
455482
DeepPot(self.pt_path, head="no_such_head")
@@ -469,7 +496,7 @@ def test_select_head_compiled_layout_matches(self) -> None:
469496

470497
for head, src in (("head_a", self.model_a), ("head_b", self.model_b)):
471498
dp = DeepPot(self.pt_path_compiled, head=head)
472-
e, f, v = dp.eval(coords, cells, atom_types, atomic=False)
499+
e, f, _v = dp.eval(coords, cells, atom_types, atomic=False)
473500

474501
coord_t = torch.tensor(
475502
coords, dtype=torch.float64, device=DEVICE
@@ -610,7 +637,7 @@ def test_metadata_flags_spin(self) -> None:
610637

611638
def test_eval_pbc_atomic_matches_reference(self) -> None:
612639
dp = DeepPot(self.files[".pt"])
613-
e, f, v, ae, av, fm, mm = dp.eval(
640+
e, f, v, ae, _av, fm, _mm = dp.eval(
614641
self.COORD, self.BOX, self.ATYPE, atomic=True, spin=self.SPIN
615642
)
616643
np.testing.assert_allclose(
@@ -915,7 +942,7 @@ def test_each_head_matches_its_eager_reference(self) -> None:
915942
self.assertEqual(dp.use_spin, [True, False], msg=f"head={head}")
916943

917944
ref = self._eager_ref(src)
918-
e, f, v, ae, av, fm, mm = dp.eval(
945+
e, f, v, _ae, _av, fm, _mm = dp.eval(
919946
self.COORD, self.BOX, self.ATYPE, atomic=True, spin=self.SPIN
920947
)
921948
np.testing.assert_allclose(

0 commit comments

Comments
 (0)