Skip to content

Commit 78d5bde

Browse files
author
Han Wang
committed
test(pt_expt): add .pt2 change-bias tests
- test_change_bias_frozen_pt2: change-bias with data on .pt2 model - test_change_bias_frozen_pt2_user_defined: user-defined bias on .pt2 - test_change_bias_pt2_pte_consistency: .pte and .pt2 produce same bias
1 parent 80aa996 commit 78d5bde

1 file changed

Lines changed: 93 additions & 0 deletions

File tree

source/tests/pt_expt/test_change_bias.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,99 @@ def test_change_bias_frozen_pte(self) -> None:
256256
"Bias should have changed after change-bias on frozen model",
257257
)
258258

259+
def test_change_bias_frozen_pt2(self) -> None:
260+
"""Change-bias on a .pt2 frozen model."""
261+
from deepmd.pt_expt.entrypoints.main import (
262+
freeze,
263+
)
264+
from deepmd.pt_expt.model.model import (
265+
BaseModel,
266+
)
267+
from deepmd.pt_expt.utils.serialization import (
268+
serialize_from_file,
269+
)
270+
271+
pt2_path = os.path.join(self.tmpdir, "frozen.pt2")
272+
freeze(model=self.model_path, output=pt2_path)
273+
274+
original_data = serialize_from_file(pt2_path)
275+
original_model = BaseModel.deserialize(original_data["model"])
276+
original_bias = to_numpy(original_model.get_out_bias())
277+
278+
output_pt2 = os.path.join(self.tmpdir, "frozen_updated.pt2")
279+
run_dp(
280+
f"dp --pt-expt change-bias {pt2_path} "
281+
f"-s {self.data_file[0]} -o {output_pt2}"
282+
)
283+
284+
updated_data = serialize_from_file(output_pt2)
285+
updated_model = BaseModel.deserialize(updated_data["model"])
286+
updated_bias = to_numpy(updated_model.get_out_bias())
287+
288+
self.assertFalse(
289+
np.allclose(original_bias, updated_bias),
290+
"Bias should have changed after change-bias on .pt2 model",
291+
)
292+
293+
def test_change_bias_frozen_pt2_user_defined(self) -> None:
294+
"""Change-bias with user-defined values on a .pt2 model."""
295+
from deepmd.pt_expt.entrypoints.main import (
296+
freeze,
297+
)
298+
from deepmd.pt_expt.model.model import (
299+
BaseModel,
300+
)
301+
from deepmd.pt_expt.utils.serialization import (
302+
serialize_from_file,
303+
)
304+
305+
pt2_path = os.path.join(self.tmpdir, "frozen_ud.pt2")
306+
freeze(model=self.model_path, output=pt2_path)
307+
308+
output_pt2 = os.path.join(self.tmpdir, "frozen_ud_updated.pt2")
309+
run_dp(f"dp --pt-expt change-bias {pt2_path} -b 1.0 2.0 -o {output_pt2}")
310+
311+
updated_data = serialize_from_file(output_pt2)
312+
updated_model = BaseModel.deserialize(updated_data["model"])
313+
updated_bias = to_numpy(updated_model.get_out_bias())
314+
315+
np.testing.assert_allclose(updated_bias.flatten()[:2], [1.0, 2.0], atol=1e-10)
316+
317+
def test_change_bias_pt2_pte_consistency(self) -> None:
318+
"""Change-bias on .pte and .pt2 should produce same bias values."""
319+
from deepmd.pt_expt.entrypoints.main import (
320+
freeze,
321+
)
322+
from deepmd.pt_expt.model.model import (
323+
BaseModel,
324+
)
325+
from deepmd.pt_expt.utils.serialization import (
326+
serialize_from_file,
327+
)
328+
329+
pte_path = os.path.join(self.tmpdir, "cons.pte")
330+
pt2_path = os.path.join(self.tmpdir, "cons.pt2")
331+
freeze(model=self.model_path, output=pte_path)
332+
freeze(model=self.model_path, output=pt2_path)
333+
334+
output_pte = os.path.join(self.tmpdir, "cons_updated.pte")
335+
output_pt2 = os.path.join(self.tmpdir, "cons_updated.pt2")
336+
run_dp(
337+
f"dp --pt-expt change-bias {pte_path} "
338+
f"-s {self.data_file[0]} -o {output_pte}"
339+
)
340+
run_dp(
341+
f"dp --pt-expt change-bias {pt2_path} "
342+
f"-s {self.data_file[0]} -o {output_pt2}"
343+
)
344+
345+
pte_data = serialize_from_file(output_pte)
346+
pt2_data = serialize_from_file(output_pt2)
347+
pte_bias = to_numpy(BaseModel.deserialize(pte_data["model"]).get_out_bias())
348+
pt2_bias = to_numpy(BaseModel.deserialize(pt2_data["model"]).get_out_bias())
349+
350+
np.testing.assert_allclose(pte_bias, pt2_bias, atol=1e-10)
351+
259352

260353
class TestChangeBiasFittingStats(unittest.TestCase):
261354
"""Test that model_change_out_bias recomputes fitting stats for set-by-statistic."""

0 commit comments

Comments
 (0)