5252)
5353
5454
55- def _build_model_and_params (rcut : float = 4.0 ) -> tuple [EnergyModel , dict ]:
56- """Build a small pt_expt EnergyModel and the matching ``model_params`` dict."""
55+ def _build_model_and_params (
56+ rcut : float = 4.0 , seed : int = GLOBAL_SEED
57+ ) -> tuple [EnergyModel , dict ]:
58+ """Build a small pt_expt EnergyModel and the matching ``model_params`` dict.
59+
60+ The ``seed`` parameter lets callers build distinguishable models when
61+ they need head-selection tests to produce different outputs per head.
62+ """
5763 type_map = ["foo" , "bar" ]
5864 sel = [8 , 6 ]
5965 descriptor_args = {
@@ -64,13 +70,13 @@ def _build_model_and_params(rcut: float = 4.0) -> tuple[EnergyModel, dict]:
6470 "neuron" : [4 , 8 ],
6571 "axis_neuron" : 4 ,
6672 "type_one_side" : True ,
67- "seed" : GLOBAL_SEED ,
73+ "seed" : seed ,
6874 }
6975 fitting_args = {
7076 "type" : "ener" ,
7177 "neuron" : [8 , 8 ],
7278 "resnet_dt" : True ,
73- "seed" : GLOBAL_SEED ,
79+ "seed" : seed ,
7480 }
7581
7682 ds = DescrptSeA (
@@ -80,15 +86,15 @@ def _build_model_and_params(rcut: float = 4.0) -> tuple[EnergyModel, dict]:
8086 neuron = [4 , 8 ],
8187 axis_neuron = 4 ,
8288 type_one_side = True ,
83- seed = GLOBAL_SEED ,
89+ seed = seed ,
8490 )
8591 ft = EnergyFittingNet (
8692 len (type_map ),
8793 ds .get_dim_out (),
8894 neuron = [8 , 8 ],
8995 resnet_dt = True ,
9096 mixed_types = ds .mixed_types (),
91- seed = GLOBAL_SEED ,
97+ seed = seed ,
9298 )
9399 model = EnergyModel (ds , ft , type_map = type_map ).to (torch .float64 ).eval ()
94100
@@ -388,9 +394,11 @@ class TestPtExptLoadPtMultiTask(unittest.TestCase):
388394 @classmethod
389395 def setUpClass (cls ) -> None :
390396 # Build two single-task models with the same architecture but
391- # different seeds, then save a multi-task-style checkpoint.
392- cls .model_a , params_a = _build_model_and_params (rcut = 4.0 )
393- cls .model_b , params_b = _build_model_and_params (rcut = 4.0 )
397+ # different seeds. Distinct seeds matter so that a head-routing
398+ # bug (loading head_b's weights when head_a is requested, or
399+ # vice versa) actually shows up as an assertion failure.
400+ cls .model_a , params_a = _build_model_and_params (rcut = 4.0 , seed = 42 )
401+ cls .model_b , params_b = _build_model_and_params (rcut = 4.0 , seed = 7 )
394402 cls .models = {"head_a" : cls .model_a , "head_b" : cls .model_b }
395403 cls .model_params = {"model_dict" : {"head_a" : params_a , "head_b" : params_b }}
396404
@@ -423,7 +431,7 @@ def test_select_head_matches_single_task_forward(self) -> None:
423431 # Build a DeepPot wrapping this DeepEval for end-to-end eval.
424432 dp = DeepPot (self .pt_path , head = head )
425433 de = dp .deep_eval
426- e , f , v = dp .eval (coords , cells , atom_types , atomic = False )
434+ e , f , _v = dp .eval (coords , cells , atom_types , atomic = False )
427435
428436 coord_t = torch .tensor (
429437 coords , dtype = torch .float64 , device = DEVICE
@@ -450,6 +458,25 @@ def test_select_head_matches_single_task_forward(self) -> None:
450458 )
451459 self .assertEqual (de .get_type_map (), src .get_type_map ())
452460
461+ def test_distinct_heads_produce_distinct_outputs (self ) -> None :
462+ """Sanity check that head_a and head_b really resolve to different weights."""
463+ rng = np .random .default_rng (GLOBAL_SEED + 2 )
464+ natoms = 4
465+ coords = rng .random ((1 , natoms , 3 )) * 8.0
466+ cells = np .eye (3 ).reshape (1 , 9 ) * 10.0
467+ atom_types = np .array ([i % 2 for i in range (natoms )], dtype = np .int32 )
468+ e_a = DeepPot (self .pt_path , head = "head_a" ).eval (
469+ coords , cells , atom_types , atomic = False
470+ )[0 ]
471+ e_b = DeepPot (self .pt_path , head = "head_b" ).eval (
472+ coords , cells , atom_types , atomic = False
473+ )[0 ]
474+ self .assertFalse (
475+ np .allclose (e_a , e_b ),
476+ "head_a and head_b produced identical outputs — head selection "
477+ "may be loading the wrong weights" ,
478+ )
479+
453480 def test_missing_head_raises (self ) -> None :
454481 with self .assertRaisesRegex (ValueError , "Head 'no_such_head' not found" ):
455482 DeepPot (self .pt_path , head = "no_such_head" )
@@ -469,7 +496,7 @@ def test_select_head_compiled_layout_matches(self) -> None:
469496
470497 for head , src in (("head_a" , self .model_a ), ("head_b" , self .model_b )):
471498 dp = DeepPot (self .pt_path_compiled , head = head )
472- e , f , v = dp .eval (coords , cells , atom_types , atomic = False )
499+ e , f , _v = dp .eval (coords , cells , atom_types , atomic = False )
473500
474501 coord_t = torch .tensor (
475502 coords , dtype = torch .float64 , device = DEVICE
@@ -610,7 +637,7 @@ def test_metadata_flags_spin(self) -> None:
610637
611638 def test_eval_pbc_atomic_matches_reference (self ) -> None :
612639 dp = DeepPot (self .files [".pt" ])
613- e , f , v , ae , av , fm , mm = dp .eval (
640+ e , f , v , ae , _av , fm , _mm = dp .eval (
614641 self .COORD , self .BOX , self .ATYPE , atomic = True , spin = self .SPIN
615642 )
616643 np .testing .assert_allclose (
@@ -915,7 +942,7 @@ def test_each_head_matches_its_eager_reference(self) -> None:
915942 self .assertEqual (dp .use_spin , [True , False ], msg = f"head={ head } " )
916943
917944 ref = self ._eager_ref (src )
918- e , f , v , ae , av , fm , mm = dp .eval (
945+ e , f , v , _ae , _av , fm , _mm = dp .eval (
919946 self .COORD , self .BOX , self .ATYPE , atomic = True , spin = self .SPIN
920947 )
921948 np .testing .assert_allclose (
0 commit comments