Commit e97967b

Authored by Han Wang (wanghan-iapcm)
feat(pt_expt): add dp change-bias support (#5330)
## Summary

Add a `dp --pt-expt change-bias` command to adjust the output bias (energy shift per atom type) of pt_expt models without retraining. This brings the pt_expt backend to parity with the pt/tf/pd backends for this feature.

### Supported input/output formats

| Input | Output | Notes |
|-------|--------|-------|
| `.pt` checkpoint | `.pt` checkpoint | Modify bias before freezing |
| `.pte` frozen model | `.pte` frozen model | Round-trip: deserialize → modify bias → re-export |

### Bias modes

- **Data-based** (`-s <data_dir>` or `-f <data_file>`): compute the new bias from data via linear regression (`change-by-statistic` or `set-by-statistic`)
- **User-defined** (`-b 0.1 3.2 ...`): set bias values directly

### Implementation details

**`deepmd/pt_expt/entrypoints/main.py`** — `change_bias()` function + CLI dispatch:

- `.pt` input: `torch.load` → extract `model_params` from `_extra_state` → `get_model()` → `ModelWrapper.load_state_dict()` → apply bias → `torch.save()`
- `.pte` input: `serialize_from_file()` → `BaseModel.deserialize()` → apply bias → `model.serialize()` → `deserialize_to_file()`
- Data loading uses pt_expt's own pipeline: `DeepmdDataSystem` + `make_stat_input` (numpy-based, from `deepmd.utils.model_stat`)
- When `numb_batch=0` (the default), the batch count is capped at `min(data.get_nbatches())`, so no system wraps around and short systems are not overweighted

**`deepmd/pt_expt/train/training.py`** — `model_change_out_bias()` helper:

- Logs old/new bias values after calling the dpmodel-inherited `change_out_bias()`
- For `set-by-statistic` on `DPModelCommon` models, also recomputes the fitting net's input statistics via `compute_input_stats()`

### Usage

```bash
# Change bias using data (checkpoint)
dp --pt-expt change-bias model.ckpt.pt -s /path/to/data -o updated.pt

# Change bias using a data file list
dp --pt-expt change-bias model.ckpt.pt -f systems.txt -o updated.pt

# Set bias to specific values
dp --pt-expt change-bias model.ckpt.pt -b 0.1 3.2 -o updated.pt

# Change bias on a frozen model
dp --pt-expt change-bias frozen.pte -s /path/to/data -o updated.pte
```

## Test plan

- [x] `python -m pytest source/tests/pt_expt/test_change_bias.py -v` — 4 end-to-end CLI tests:
  - `test_change_bias_with_data` — bias changes when using the `-s` flag
  - `test_change_bias_with_data_sys_file` — bias changes when using the `-f` flag
  - `test_change_bias_with_user_defined` — exact match with user-specified values
  - `test_change_bias_frozen_pte` — freeze → change-bias on `.pte` → verify bias changed
- [x] `python -m pytest source/tests/consistent/model/test_ener.py -k test_change_out_bias -v` — cross-backend consistency passes

## Summary by CodeRabbit

- **New Features**
  - Added a `change-bias` command to adjust a model's output bias by supplying values or computing statistics from data systems; supports checkpoint and frozen model formats while preserving model metadata.
  - Added a model bias-update helper to apply or statistically derive bias adjustments consistently.
- **Tests**
  - End-to-end tests for data-driven, file-list, user-specified, and frozen-model workflows.
  - Added tests ensuring fitting statistics are computed when expected, and improved fixtures to clear leaked device contexts.

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
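The `.pt` flow described above relies on PyTorch's `get_extra_state`/`set_extra_state` hooks, which is why `model_params` can be recovered from the checkpoint's `_extra_state` entry. A minimal standalone illustration of that mechanism (toy `Wrapper` class with made-up config, not deepmd's `ModelWrapper`):

```python
import torch


class Wrapper(torch.nn.Module):
    """Toy module showing how non-tensor config rides along in a state dict."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 1)
        self.model_params = {"type_map": ["O", "H"]}

    def get_extra_state(self):
        # Called by state_dict(); stored under the "_extra_state" key.
        return {"model_params": self.model_params}

    def set_extra_state(self, state):
        # Called by load_state_dict(); restores the stashed config.
        self.model_params = state["model_params"]


sd = Wrapper().state_dict()
print(sd["_extra_state"])  # {'model_params': {'type_map': ['O', 'H']}}
```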
1 parent 97e9c9e · commit e97967b

4 files changed: 537 additions & 3 deletions

### deepmd/pt_expt/entrypoints/main.py (192 additions, 0 deletions)

```diff
@@ -263,6 +263,187 @@ def freeze(
     log.info("Saved frozen model to %s", output)
 
 
+def change_bias(
+    input_file: str,
+    mode: str = "change",
+    bias_value: list | None = None,
+    datafile: str | None = None,
+    system: str = ".",
+    numb_batch: int = 0,
+    model_branch: str | None = None,
+    output: str | None = None,
+) -> None:
+    """Change the output bias of a pt_expt model.
+
+    Parameters
+    ----------
+    input_file : str
+        Path to the model file (.pt checkpoint or .pte frozen model).
+    mode : str
+        ``"change"`` or ``"set"``.
+    bias_value : list or None
+        User-defined bias values (one per type).
+    datafile : str or None
+        File listing data system paths.
+    system : str
+        Data system path (used when *datafile* is None).
+    numb_batch : int
+        Number of batches for statistics (0 = all).
+    model_branch : str or None
+        Branch name for multi-task models.
+    output : str or None
+        Output file path.
+    """
+    import torch
+
+    from deepmd.common import (
+        expand_sys_str,
+    )
+    from deepmd.dpmodel.common import (
+        to_numpy_array,
+    )
+    from deepmd.pt_expt.model.get_model import (
+        get_model,
+    )
+    from deepmd.pt_expt.train.training import (
+        get_additional_data_requirement,
+        get_loss,
+        model_change_out_bias,
+    )
+    from deepmd.pt_expt.train.wrapper import (
+        ModelWrapper,
+    )
+    from deepmd.pt_expt.utils.env import (
+        DEVICE,
+    )
+    from deepmd.pt_expt.utils.serialization import (
+        deserialize_to_file,
+        serialize_from_file,
+    )
+    from deepmd.pt_expt.utils.stat import (
+        make_stat_input,
+    )
+
+    if input_file.endswith(".pt"):
+        old_state_dict = torch.load(input_file, map_location=DEVICE, weights_only=True)
+        if "model" in old_state_dict:
+            model_state_dict = old_state_dict["model"]
+        else:
+            model_state_dict = old_state_dict
+        extra_state = model_state_dict.get("_extra_state")
+        if not isinstance(extra_state, dict) or "model_params" not in extra_state:
+            raise ValueError(
+                f"Unsupported checkpoint format at '{input_file}': missing "
+                "'_extra_state.model_params' in model state dict."
+            )
+        model_params = extra_state["model_params"]
+    elif input_file.endswith((".pte", ".pt2")):
+        pte_data = serialize_from_file(input_file)
+        from deepmd.pt_expt.model.model import (
+            BaseModel,
+        )
+
+        model_to_change = BaseModel.deserialize(pte_data["model"])
+        model_params = None
+    else:
+        raise RuntimeError(
+            "The model provided must be a checkpoint file with a .pt extension "
+            "or a frozen model with a .pte/.pt2 extension"
+        )
+
+    if mode == "change":
+        bias_adjust_mode = "change-by-statistic"
+    elif mode == "set":
+        bias_adjust_mode = "set-by-statistic"
+    else:
+        raise ValueError(f"Unsupported mode '{mode}'. Expected 'change' or 'set'.")
+
+    if input_file.endswith(".pt"):
+        multi_task = "model_dict" in model_params
+        if multi_task:
+            raise NotImplementedError(
+                "Multi-task change-bias is not yet supported for the pt_expt backend."
+            )
+        type_map = model_params["type_map"]
+        model = get_model(model_params)
+        wrapper = ModelWrapper(model)
+        wrapper.load_state_dict(model_state_dict)
+        model_to_change = model
+
+    if input_file.endswith((".pte", ".pt2")):
+        type_map = model_to_change.get_type_map()
+
+    if bias_value is not None:
+        if "energy" not in model_to_change.model_output_type():
+            raise ValueError("User-defined bias is only available for energy models!")
+        if len(bias_value) != len(type_map):
+            raise ValueError(
+                f"The number of elements in the bias ({len(bias_value)}) must match "
+                f"the number of types in type_map ({len(type_map)}): {type_map}."
+            )
+        old_bias = model_to_change.get_out_bias()
+        bias_to_set = torch.tensor(
+            bias_value, dtype=old_bias.dtype, device=old_bias.device
+        ).view(old_bias.shape)
+        model_to_change.set_out_bias(bias_to_set)
+        log.info(
+            f"Change output bias of {type_map!s} "
+            f"from {to_numpy_array(old_bias).reshape(-1)!s} "
+            f"to {to_numpy_array(bias_to_set).reshape(-1)!s}."
+        )
+    else:
+        if datafile is not None:
+            with open(datafile) as datalist:
+                all_sys = datalist.read().splitlines()
+        else:
+            all_sys = expand_sys_str(system)
+        data_systems = process_systems(all_sys)
+        data = DeepmdDataSystem(
+            systems=data_systems,
+            batch_size=1,
+            test_size=1,
+            rcut=model_to_change.get_rcut(),
+            type_map=type_map,
+        )
+        mock_loss = get_loss({"inference": True}, 1.0, len(type_map), model_to_change)
+        data.add_data_requirements(mock_loss.label_requirement)
+        data.add_data_requirements(get_additional_data_requirement(model_to_change))
+        if numb_batch != 0:
+            nbatches = numb_batch
+        else:
+            # Cap at the minimum across systems so no system wraps and
+            # overweights short systems (matching PT behavior).
+            nbatches = min(data.get_nbatches())
+        sampled_data = make_stat_input(data, nbatches)
+        model_to_change = model_change_out_bias(
+            model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode
+        )
+
+    if input_file.endswith(".pt"):
+        output_path = (
+            output if output is not None else input_file.replace(".pt", "_updated.pt")
+        )
+        wrapper = ModelWrapper(model_to_change)
+        if "model" in old_state_dict:
+            old_state_dict["model"] = wrapper.state_dict()
+            old_state_dict["model"]["_extra_state"] = extra_state
+        else:
+            old_state_dict = wrapper.state_dict()
+            old_state_dict["_extra_state"] = extra_state
+        torch.save(old_state_dict, output_path)
+    elif input_file.endswith((".pte", ".pt2")):
+        output_path = (
+            output
+            if output is not None
+            else input_file.replace(".pte", "_updated.pte").replace(
+                ".pt2", "_updated.pt2"
+            )
+        )
+        model_dict = model_to_change.serialize()
+        deserialize_to_file(output_path, {"model": model_dict})
+    log.info(f"Saved model to {output_path}")
+
+
 def main(args: list[str] | argparse.Namespace | None = None) -> None:
     """Entry point for the pt_expt backend CLI.
 
@@ -323,6 +504,17 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None:
         if not FLAGS.output.endswith((".pte", ".pt2")):
             FLAGS.output = str(Path(FLAGS.output).with_suffix(".pte"))
         freeze(model=FLAGS.model, output=FLAGS.output, head=FLAGS.head)
+    elif FLAGS.command == "change-bias":
+        change_bias(
+            input_file=FLAGS.INPUT,
+            mode=FLAGS.mode,
+            bias_value=FLAGS.bias_value,
+            datafile=FLAGS.datafile,
+            system=FLAGS.system,
+            numb_batch=FLAGS.numb_batch,
+            model_branch=FLAGS.model_branch,
+            output=FLAGS.output,
+        )
     elif FLAGS.command == "compress":
         from deepmd.pt_expt.entrypoints.compress import (
             enable_compression,
```
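For reference, the same operation can be driven from Python via the function defined above (the paths here are illustrative):

```python
from deepmd.pt_expt.entrypoints.main import change_bias

# Equivalent to: dp --pt-expt change-bias model.ckpt.pt -s /path/to/data -o updated.pt
change_bias(
    input_file="model.ckpt.pt",  # .pt checkpoint (or a .pte/.pt2 frozen model)
    mode="change",               # maps to "change-by-statistic"
    system="/path/to/data",      # used because datafile is None
    numb_batch=0,                # 0 = min(data.get_nbatches()) across systems
    output="updated.pt",
)
```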

### deepmd/pt_expt/train/training.py (8 additions, 0 deletions)

```diff
@@ -1029,6 +1029,14 @@ def model_change_out_bias(
         bias_adjust_mode=_bias_adjust_mode,
     )
     new_bias = deepcopy(_model.get_out_bias())
+
+    from deepmd.dpmodel.model.dp_model import (
+        DPModelCommon,
+    )
+
+    if isinstance(_model, DPModelCommon) and _bias_adjust_mode == "set-by-statistic":
+        _model.get_fitting_net().compute_input_stats(_sample_func)
+
     model_type_map = _model.get_type_map()
     log.info(
         f"Change output bias of {model_type_map!s} "
```

### source/tests/pt_expt/conftest.py (20 additions, 3 deletions)

```diff
@@ -32,9 +32,8 @@
 )
 
 
-@pytest.fixture(autouse=True)
-def _clear_leaked_device_context():
-    """Pop any stale ``DeviceContext`` before each test, restore after."""
+def _pop_device_contexts() -> list:
+    """Pop all stale DeviceContext modes from the torch function mode stack."""
     popped = []
     while True:
         modes = _get_current_function_mode_stack()
@@ -46,6 +45,24 @@ def _clear_leaked_device_context():
             popped.append(top)
         else:
             break
+    return popped
+
+
+@pytest.fixture(autouse=True, scope="session")
+def _clear_leaked_device_context_session():
+    """Pop any stale DeviceContext once at session start.
+
+    This runs before any setUpClass, preventing CUDA init errors
+    in tests that call trainer.run() during class setup.
+    """
+    _pop_device_contexts()
+    yield
+
+
+@pytest.fixture(autouse=True)
+def _clear_leaked_device_context():
+    """Pop any stale ``DeviceContext`` before each test, restore after."""
+    popped = _pop_device_contexts()
     yield
     # Restore in reverse order so the stack is back to its original state.
     for ctx in reversed(popped):
```