Skip to content

Commit 8f2b3c9

Browse files
fix(pt, dpmodel): spin model finetune (#5281)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Spin-aware preprocessing for training data, plus bias-adjustment and type-mapping that handle spin variants during finetuning. * **Tests** * Added comprehensive unit and end-to-end tests for spin finetuning, bias updates, transformed spin data, default-parameter propagation, and post-finetune inference validation. * **Refactor** * Centralized, cached sampler wrapper to standardize sampling and default-parameter injection. * **Chores** * Truncated bias logging for more concise diagnostics. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 73bb1b7 commit 8f2b3c9

File tree

8 files changed

+1035
-38
lines changed

8 files changed

+1035
-38
lines changed

deepmd/dpmodel/model/spin_model.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import functools
3+
from collections.abc import (
4+
Callable,
5+
)
26
from copy import (
37
deepcopy,
48
)
@@ -332,6 +336,88 @@ def model_output_def(self) -> ModelOutputDef:
332336
backbone_model_atomic_output_def[var_name].magnetic = True
333337
return ModelOutputDef(backbone_model_atomic_output_def)
334338

339+
def _get_spin_sampled_func(
340+
self, sampled_func: Callable[[], list[dict]]
341+
) -> Callable[[], list[dict]]:
342+
"""Get a spin-aware sampled function that transforms spin data for the backbone model.
343+
344+
Parameters
345+
----------
346+
sampled_func
347+
A callable that returns a list of data dicts containing 'coord', 'atype', 'spin', etc.
348+
349+
Returns
350+
-------
351+
Callable
352+
A cached callable that returns spin-preprocessed data dicts.
353+
"""
354+
355+
@functools.lru_cache
356+
def spin_sampled_func() -> list[dict]:
357+
sampled = sampled_func()
358+
spin_sampled = []
359+
for sys in sampled:
360+
coord_updated, atype_updated = self.process_spin_input(
361+
sys["coord"], sys["atype"], sys["spin"]
362+
)
363+
tmp_dict = {
364+
"coord": coord_updated,
365+
"atype": atype_updated,
366+
}
367+
if "natoms" in sys:
368+
natoms = sys["natoms"]
369+
tmp_dict["natoms"] = np.concatenate(
370+
[2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], axis=-1
371+
)
372+
for item_key in sys.keys():
373+
if item_key not in ["coord", "atype", "spin", "natoms"]:
374+
tmp_dict[item_key] = sys[item_key]
375+
spin_sampled.append(tmp_dict)
376+
return spin_sampled
377+
378+
return self.backbone_model.atomic_model._make_wrapped_sampler(spin_sampled_func)
379+
380+
def change_out_bias(
381+
self,
382+
merged: Callable[[], list[dict]] | list[dict],
383+
bias_adjust_mode: str = "change-by-statistic",
384+
) -> None:
385+
"""Change the output bias of atomic model according to the input data and the pretrained model.
386+
387+
Parameters
388+
----------
389+
merged : Union[Callable[[], list[dict]], list[dict]]
390+
- list[dict]: A list of data samples from various data systems.
391+
Each element, `merged[i]`, is a data dictionary containing `keys`: `np.ndarray`
392+
originating from the `i`-th data system.
393+
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
394+
only when needed. Since the sampling process can be slow and memory-intensive,
395+
the lazy function helps by only sampling once.
396+
bias_adjust_mode : str
397+
The mode for changing output bias : ['change-by-statistic', 'set-by-statistic']
398+
'change-by-statistic' : perform predictions on labels of target dataset,
399+
and do least square on the errors to obtain the target shift as bias.
400+
'set-by-statistic' : directly use the statistic output bias in the target dataset.
401+
"""
402+
spin_sampled_func = self._get_spin_sampled_func(
403+
merged if callable(merged) else lambda: merged
404+
)
405+
self.backbone_model.change_out_bias(
406+
spin_sampled_func,
407+
bias_adjust_mode=bias_adjust_mode,
408+
)
409+
410+
def change_type_map(
    self, type_map: list[str], model_with_new_type_stat: Any = None
) -> None:
    """Change the type related params to new ones, according to `type_map` and the original one in the model.
    If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
    """
    # the backbone sees an extended map: one "*_spin" virtual type per real type
    spin_types = [f"{name}_spin" for name in type_map]
    self.backbone_model.change_type_map(
        type_map + spin_types, model_with_new_type_stat
    )
420+
335421
def __getattr__(self, name: str) -> Any:
336422
"""Get attribute from the wrapped model."""
337423
if "backbone_model" not in self.__dict__:

deepmd/pt/model/atomic_model/base_atomic_model.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22

3+
import functools
34
import logging
45
from collections.abc import (
56
Callable,
@@ -184,6 +185,58 @@ def has_default_fparam(self) -> bool:
184185
"""Check if the model has default frame parameters."""
185186
return False
186187

188+
def get_default_fparam(self) -> torch.Tensor | None:
189+
"""Get the default frame parameters."""
190+
return None
191+
192+
def _make_wrapped_sampler(
193+
self,
194+
sampled_func: Callable[[], list[dict]],
195+
) -> Callable[[], list[dict]]:
196+
"""Wrap the sampled function with exclusion types and default fparam.
197+
198+
The returned callable is cached so that the sampling (which may be
199+
expensive) is performed at most once.
200+
201+
Parameters
202+
----------
203+
sampled_func
204+
The lazy sampled function to get data frames from different data
205+
systems.
206+
207+
Returns
208+
-------
209+
Callable[[], list[dict]]
210+
A cached wrapper around *sampled_func* that additionally sets
211+
``pair_exclude_types``, ``atom_exclude_types`` and default
212+
``fparam`` on every sample dict when applicable.
213+
"""
214+
215+
@functools.lru_cache
216+
def wrapped_sampler() -> list[dict]:
217+
sampled = sampled_func()
218+
if self.pair_excl is not None:
219+
pair_exclude_types = self.pair_excl.get_exclude_types()
220+
for sample in sampled:
221+
sample["pair_exclude_types"] = list(pair_exclude_types)
222+
if self.atom_excl is not None:
223+
atom_exclude_types = self.atom_excl.get_exclude_types()
224+
for sample in sampled:
225+
sample["atom_exclude_types"] = list(atom_exclude_types)
226+
if (
227+
"find_fparam" not in sampled[0]
228+
and "fparam" not in sampled[0]
229+
and self.has_default_fparam()
230+
):
231+
default_fparam = self.get_default_fparam()
232+
if default_fparam is not None:
233+
for sample in sampled:
234+
nframe = sample["atype"].shape[0]
235+
sample["fparam"] = default_fparam.repeat(nframe, 1)
236+
return sampled
237+
238+
return wrapped_sampler
239+
187240
def reinit_atom_exclude(
188241
self,
189242
exclude_types: list[int] = [],

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
import functools
32
import logging
43
from collections.abc import (
54
Callable,
@@ -329,28 +328,7 @@ def compute_or_load_stat(
329328
# should not share the same parameters
330329
stat_file_path /= " ".join(self.type_map)
331330

332-
@functools.lru_cache
333-
def wrapped_sampler() -> list[dict]:
334-
sampled = sampled_func()
335-
if self.pair_excl is not None:
336-
pair_exclude_types = self.pair_excl.get_exclude_types()
337-
for sample in sampled:
338-
sample["pair_exclude_types"] = list(pair_exclude_types)
339-
if self.atom_excl is not None:
340-
atom_exclude_types = self.atom_excl.get_exclude_types()
341-
for sample in sampled:
342-
sample["atom_exclude_types"] = list(atom_exclude_types)
343-
if (
344-
"find_fparam" not in sampled[0]
345-
and "fparam" not in sampled[0]
346-
and self.has_default_fparam()
347-
):
348-
default_fparam = self.get_default_fparam()
349-
for sample in sampled:
350-
nframe = sample["atype"].shape[0]
351-
sample["fparam"] = default_fparam.repeat(nframe, 1)
352-
return sampled
353-
331+
wrapped_sampler = self._make_wrapped_sampler(sampled_func)
354332
self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
355333
self.compute_fitting_input_stat(wrapped_sampler, stat_file_path)
356334
if compute_or_load_out_stat:

deepmd/pt/model/atomic_model/linear_atomic_model.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
import functools
32
from collections.abc import (
43
Callable,
54
)
@@ -518,19 +517,7 @@ def compute_or_load_stat(
518517
# should not share the same parameters
519518
stat_file_path /= " ".join(self.type_map)
520519

521-
@functools.lru_cache
522-
def wrapped_sampler() -> list[dict[str, Any]]:
523-
sampled = sampled_func()
524-
if self.pair_excl is not None:
525-
pair_exclude_types = self.pair_excl.get_exclude_types()
526-
for sample in sampled:
527-
sample["pair_exclude_types"] = list(pair_exclude_types)
528-
if self.atom_excl is not None:
529-
atom_exclude_types = self.atom_excl.get_exclude_types()
530-
for sample in sampled:
531-
sample["atom_exclude_types"] = list(atom_exclude_types)
532-
return sampled
533-
520+
wrapped_sampler = self._make_wrapped_sampler(sampled_func)
534521
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
535522

536523

deepmd/pt/model/model/spin_model.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,75 @@ def __getattr__(self, name: str) -> Any:
388388
else:
389389
return getattr(self.backbone_model, name)
390390

391+
def _get_spin_sampled_func(
392+
self, sampled_func: Callable[[], list[dict]]
393+
) -> Callable[[], list[dict]]:
394+
@functools.lru_cache
395+
def spin_sampled_func() -> list[dict]:
396+
sampled = sampled_func()
397+
spin_sampled = []
398+
for sys in sampled:
399+
coord_updated, atype_updated, _ = self.process_spin_input(
400+
sys["coord"], sys["atype"], sys["spin"]
401+
)
402+
tmp_dict = {
403+
"coord": coord_updated,
404+
"atype": atype_updated,
405+
}
406+
if "natoms" in sys:
407+
natoms = sys["natoms"]
408+
tmp_dict["natoms"] = torch.cat(
409+
[2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], dim=-1
410+
)
411+
for item_key in sys.keys():
412+
if item_key not in ["coord", "atype", "spin", "natoms"]:
413+
tmp_dict[item_key] = sys[item_key]
414+
spin_sampled.append(tmp_dict)
415+
return spin_sampled
416+
417+
return self.backbone_model.atomic_model._make_wrapped_sampler(spin_sampled_func)
418+
419+
def change_out_bias(
420+
self,
421+
merged: Callable[[], list[dict]] | list[dict],
422+
bias_adjust_mode: str = "change-by-statistic",
423+
) -> None:
424+
"""Change the output bias of atomic model according to the input data and the pretrained model.
425+
426+
Parameters
427+
----------
428+
merged : Union[Callable[[], list[dict]], list[dict]]
429+
- list[dict]: A list of data samples from various data systems.
430+
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
431+
originating from the `i`-th data system.
432+
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
433+
only when needed. Since the sampling process can be slow and memory-intensive,
434+
the lazy function helps by only sampling once.
435+
bias_adjust_mode : str
436+
The mode for changing output bias : ['change-by-statistic', 'set-by-statistic']
437+
'change-by-statistic' : perform predictions on labels of target dataset,
438+
and do least square on the errors to obtain the target shift as bias.
439+
'set-by-statistic' : directly use the statistic output bias in the target dataset.
440+
"""
441+
spin_sampled_func = self._get_spin_sampled_func(
442+
merged if callable(merged) else lambda: merged
443+
)
444+
self.backbone_model.change_out_bias(
445+
spin_sampled_func,
446+
bias_adjust_mode=bias_adjust_mode,
447+
)
448+
449+
def change_type_map(
    self, type_map: list[str], model_with_new_type_stat: Any = None
) -> None:
    """Change the type related params to new ones, according to `type_map` and the original one in the model.
    If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
    """
    # extend the map with one "*_spin" virtual type per real type for the backbone
    extended_map = type_map + [f"{name}_spin" for name in type_map]
    self.backbone_model.change_type_map(extended_map, model_with_new_type_stat)
)
459+
391460
def compute_or_load_stat(
392461
self,
393462
sampled_func: Callable[[], list[dict[str, Any]]],

deepmd/pt/train/training.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1833,6 +1833,8 @@ def model_change_out_bias(
18331833

18341834
model_type_map = _model.get_type_map()
18351835
log.info(
1836-
f"Change output bias of {model_type_map!s} from {to_numpy_array(old_bias).reshape(-1)!s} to {to_numpy_array(new_bias).reshape(-1)!s}."
1836+
f"Change output bias of {model_type_map!s} "
1837+
f"from {to_numpy_array(old_bias).reshape(-1)[: len(model_type_map)]!s} "
1838+
f"to {to_numpy_array(new_bias).reshape(-1)[: len(model_type_map)]!s}."
18371839
)
18381840
return _model

0 commit comments

Comments
 (0)