@@ -256,6 +256,99 @@ def test_change_bias_frozen_pte(self) -> None:
256256 "Bias should have changed after change-bias on frozen model" ,
257257 )
258258
259+ def test_change_bias_frozen_pt2 (self ) -> None :
260+ """Change-bias on a .pt2 frozen model."""
261+ from deepmd .pt_expt .entrypoints .main import (
262+ freeze ,
263+ )
264+ from deepmd .pt_expt .model .model import (
265+ BaseModel ,
266+ )
267+ from deepmd .pt_expt .utils .serialization import (
268+ serialize_from_file ,
269+ )
270+
271+ pt2_path = os .path .join (self .tmpdir , "frozen.pt2" )
272+ freeze (model = self .model_path , output = pt2_path )
273+
274+ original_data = serialize_from_file (pt2_path )
275+ original_model = BaseModel .deserialize (original_data ["model" ])
276+ original_bias = to_numpy (original_model .get_out_bias ())
277+
278+ output_pt2 = os .path .join (self .tmpdir , "frozen_updated.pt2" )
279+ run_dp (
280+ f"dp --pt-expt change-bias { pt2_path } "
281+ f"-s { self .data_file [0 ]} -o { output_pt2 } "
282+ )
283+
284+ updated_data = serialize_from_file (output_pt2 )
285+ updated_model = BaseModel .deserialize (updated_data ["model" ])
286+ updated_bias = to_numpy (updated_model .get_out_bias ())
287+
288+ self .assertFalse (
289+ np .allclose (original_bias , updated_bias ),
290+ "Bias should have changed after change-bias on .pt2 model" ,
291+ )
292+
293+ def test_change_bias_frozen_pt2_user_defined (self ) -> None :
294+ """Change-bias with user-defined values on a .pt2 model."""
295+ from deepmd .pt_expt .entrypoints .main import (
296+ freeze ,
297+ )
298+ from deepmd .pt_expt .model .model import (
299+ BaseModel ,
300+ )
301+ from deepmd .pt_expt .utils .serialization import (
302+ serialize_from_file ,
303+ )
304+
305+ pt2_path = os .path .join (self .tmpdir , "frozen_ud.pt2" )
306+ freeze (model = self .model_path , output = pt2_path )
307+
308+ output_pt2 = os .path .join (self .tmpdir , "frozen_ud_updated.pt2" )
309+ run_dp (f"dp --pt-expt change-bias { pt2_path } -b 1.0 2.0 -o { output_pt2 } " )
310+
311+ updated_data = serialize_from_file (output_pt2 )
312+ updated_model = BaseModel .deserialize (updated_data ["model" ])
313+ updated_bias = to_numpy (updated_model .get_out_bias ())
314+
315+ np .testing .assert_allclose (updated_bias .flatten ()[:2 ], [1.0 , 2.0 ], atol = 1e-10 )
316+
317+ def test_change_bias_pt2_pte_consistency (self ) -> None :
318+ """Change-bias on .pte and .pt2 should produce same bias values."""
319+ from deepmd .pt_expt .entrypoints .main import (
320+ freeze ,
321+ )
322+ from deepmd .pt_expt .model .model import (
323+ BaseModel ,
324+ )
325+ from deepmd .pt_expt .utils .serialization import (
326+ serialize_from_file ,
327+ )
328+
329+ pte_path = os .path .join (self .tmpdir , "cons.pte" )
330+ pt2_path = os .path .join (self .tmpdir , "cons.pt2" )
331+ freeze (model = self .model_path , output = pte_path )
332+ freeze (model = self .model_path , output = pt2_path )
333+
334+ output_pte = os .path .join (self .tmpdir , "cons_updated.pte" )
335+ output_pt2 = os .path .join (self .tmpdir , "cons_updated.pt2" )
336+ run_dp (
337+ f"dp --pt-expt change-bias { pte_path } "
338+ f"-s { self .data_file [0 ]} -o { output_pte } "
339+ )
340+ run_dp (
341+ f"dp --pt-expt change-bias { pt2_path } "
342+ f"-s { self .data_file [0 ]} -o { output_pt2 } "
343+ )
344+
345+ pte_data = serialize_from_file (output_pte )
346+ pt2_data = serialize_from_file (output_pt2 )
347+ pte_bias = to_numpy (BaseModel .deserialize (pte_data ["model" ]).get_out_bias ())
348+ pt2_bias = to_numpy (BaseModel .deserialize (pt2_data ["model" ]).get_out_bias ())
349+
350+ np .testing .assert_allclose (pte_bias , pt2_bias , atol = 1e-10 )
351+
259352
260353class TestChangeBiasFittingStats (unittest .TestCase ):
261354 """Test that model_change_out_bias recomputes fitting stats for set-by-statistic."""
0 commit comments