Skip to content

Commit 59301c3

Browse files
authored
Merge branch 'master' into dependabot/github_actions/pypa/cibuildwheel-3.4
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
2 parents e8394f4 + 8f2b3c9 commit 59301c3

29 files changed

Lines changed: 3451 additions & 86 deletions

.github/workflows/build_wheel.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ jobs:
6868
CUDA_VERSION: ${{ matrix.cuda_version }}
6969
DP_PKG_NAME: ${{ matrix.dp_pkg_name }}
7070
CIBW_BUILD_FRONTEND: "uv"
71-
- uses: actions/upload-artifact@v6
71+
- uses: actions/upload-artifact@v7
7272
with:
7373
name: cibw-cp${{ matrix.python }}-${{ matrix.platform_id }}-cu${{ matrix.cuda_version }}-${{ strategy.job-index }}
7474
path: ./wheelhouse/*.whl
@@ -82,7 +82,7 @@ jobs:
8282
- name: Build sdist
8383
run: pipx run uv tool run --with build[uv] --from build python -m build --installer uv --sdist
8484

85-
- uses: actions/upload-artifact@v6
85+
- uses: actions/upload-artifact@v7
8686
with:
8787
name: cibw-sdist
8888
path: dist/*.tar.gz
@@ -95,7 +95,7 @@ jobs:
9595
id-token: write
9696
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/v')
9797
steps:
98-
- uses: actions/download-artifact@v7
98+
- uses: actions/download-artifact@v8
9999
with:
100100
pattern: cibw-*
101101
path: dist
@@ -124,13 +124,13 @@ jobs:
124124
swap-storage: true
125125
docker-images: true
126126
- uses: actions/checkout@v6
127-
- uses: actions/download-artifact@v7
127+
- uses: actions/download-artifact@v8
128128
with:
129129
path: source/install/docker/dist
130130
pattern: cibw-*-manylinux_x86_64-cu${{ matrix.cuda_version }}*
131131
merge-multiple: true
132132
- name: Log in to the Container registry
133-
uses: docker/login-action@v3
133+
uses: docker/login-action@v4
134134
with:
135135
registry: ghcr.io
136136
username: ${{ github.actor }}
@@ -157,7 +157,7 @@ jobs:
157157
needs: [build_wheels, build_sdist]
158158
runs-on: ubuntu-latest
159159
steps:
160-
- uses: actions/download-artifact@v7
160+
- uses: actions/download-artifact@v8
161161
with:
162162
path: dist/packages
163163
pattern: cibw-*

.github/workflows/package_c.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ jobs:
4747
if: matrix.filename != 'libdeepmd_c.tar.gz'
4848
# for download and debug
4949
- name: Upload artifact
50-
uses: actions/upload-artifact@v6
50+
uses: actions/upload-artifact@v7
5151
with:
5252
name: libdeepmd_c-${{ strategy.job-index }}-${{ matrix.filename }}
5353
path: ${{ matrix.filename }}
@@ -65,7 +65,7 @@ jobs:
6565
steps:
6666
- uses: actions/checkout@v6
6767
- name: Download artifact
68-
uses: actions/download-artifact@v7
68+
uses: actions/download-artifact@v8
6969
with:
7070
pattern: libdeepmd_c-*
7171
merge-multiple: true

.github/workflows/test_python.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ jobs:
7171
if: matrix.group == 1
7272
- run: mv .test_durations .test_durations_${{ matrix.group }}
7373
- name: Upload partial durations
74-
uses: actions/upload-artifact@v6
74+
uses: actions/upload-artifact@v7
7575
with:
7676
name: split-${{ matrix.python }}-${{ matrix.group }}
7777
path: .test_durations_${{ matrix.group }}
@@ -100,7 +100,7 @@ jobs:
100100
key: test2-durations-combined-${{ matrix.python }}-${{ github.sha }}
101101
restore-keys: test2-durations-combined-${{ matrix.python }}
102102
- name: Download artifacts
103-
uses: actions/download-artifact@v7
103+
uses: actions/download-artifact@v8
104104
with:
105105
pattern: split-${{ matrix.python }}-*
106106
merge-multiple: true

deepmd/backend/pt_expt.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,11 @@ def deep_eval(self) -> type["DeepEvalBackend"]:
7676
type[DeepEvalBackend]
7777
The Deep Eval backend of the backend.
7878
"""
79-
raise NotImplementedError
79+
from deepmd.pt_expt.infer.deep_eval import (
80+
DeepEval,
81+
)
82+
83+
return DeepEval
8084

8185
@property
8286
def neighbor_stat(self) -> type["NeighborStat"]:
@@ -87,7 +91,11 @@ def neighbor_stat(self) -> type["NeighborStat"]:
8791
type[NeighborStat]
8892
The neighbor statistics of the backend.
8993
"""
90-
raise NotImplementedError
94+
from deepmd.pt_expt.utils.neighbor_stat import (
95+
NeighborStat,
96+
)
97+
98+
return NeighborStat
9199

92100
@property
93101
def serialize_hook(self) -> Callable[[str], dict]:
@@ -98,7 +106,11 @@ def serialize_hook(self) -> Callable[[str], dict]:
98106
Callable[[str], dict]
99107
The serialize hook of the backend.
100108
"""
101-
raise NotImplementedError
109+
from deepmd.pt_expt.utils.serialization import (
110+
serialize_from_file,
111+
)
112+
113+
return serialize_from_file
102114

103115
@property
104116
def deserialize_hook(self) -> Callable[[str, dict], None]:
@@ -109,4 +121,8 @@ def deserialize_hook(self) -> Callable[[str, dict], None]:
109121
Callable[[str, dict], None]
110122
The deserialize hook of the backend.
111123
"""
112-
raise NotImplementedError
124+
from deepmd.pt_expt.utils.serialization import (
125+
deserialize_to_file,
126+
)
127+
128+
return deserialize_to_file

deepmd/dpmodel/array_api.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,15 @@ def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array:
3232
# torch.take_along_dim requires int64 indices
3333
if array_api_compat.is_torch_array(indices):
3434
indices = xp.astype(indices, xp.int64)
35+
if array_api_compat.is_torch_array(arr):
36+
# Use torch.gather directly for torch.export dynamic shape compatibility.
37+
# array_api_compat's take_along_axis / torch.take_along_dim specializes
38+
# the source dimension size to a constant during torch.export tracing,
39+
# breaking dynamic shape export. torch.gather is the underlying
40+
# primitive and handles symbolic shapes correctly.
41+
import torch
42+
43+
return torch.gather(arr, axis, indices)
3544
if Version(xp.__array_api_version__) >= Version("2024.12"):
3645
# see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39
3746
return xp.take_along_axis(arr, indices, axis=axis)
@@ -62,6 +71,24 @@ def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array:
6271
return xp_swapaxes(out, axis, -1)
6372

6473

74+
def xp_take_first_n(arr: Array, dim: int, n: int) -> Array:
    """Return the leading *n* entries of *arr* along dimension *dim*.

    Torch tensors go through ``torch.index_select`` rather than basic
    slicing: slicing makes ``torch.export`` emit a contiguity guard
    which would reject the ``nall == nloc`` (no-PBC) case.  All other
    array types (numpy, jax, ...) use plain slicing.
    """
    if not array_api_compat.is_torch_array(arr):
        # Generic path: restrict only the requested axis.
        selector = [slice(None)] * arr.ndim
        selector[dim] = slice(0, n)
        return arr[tuple(selector)]
    import torch

    first_n = torch.arange(n, dtype=torch.int64, device=arr.device)
    return torch.index_select(arr, dim, first_n)
90+
91+
6592
def xp_scatter_sum(input: Array, dim: int, index: Array, src: Array) -> Array:
6693
"""Reduces all values from the src tensor to the indices specified in the index tensor.
6794

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from deepmd.dpmodel.array_api import (
1515
Array,
16+
xp_take_first_n,
1617
)
1718
from deepmd.dpmodel.common import (
1819
NativeOP,
@@ -250,7 +251,7 @@ def forward_common_atomic(
250251
"""
251252
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
252253
_, nloc, _ = nlist.shape
253-
atype = extended_atype[:, :nloc]
254+
atype = xp_take_first_n(extended_atype, 1, nloc)
254255
if self.pair_excl is not None:
255256
pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype)
256257
# exclude neighbors in the nlist
@@ -268,7 +269,7 @@ def forward_common_atomic(
268269
ret_dict = self.apply_out_stat(ret_dict, atype)
269270

270271
# nf x nloc
271-
atom_mask = ext_atom_mask[:, :nloc]
272+
atom_mask = xp_take_first_n(ext_atom_mask, 1, nloc)
272273
if self.atom_excl is not None:
273274
atom_mask = xp.logical_and(
274275
atom_mask, self.atom_excl.build_type_exclude_mask(atype)

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from deepmd.dpmodel.array_api import (
1010
Array,
11+
xp_take_first_n,
1112
)
1213
from deepmd.dpmodel.descriptor.base_descriptor import (
1314
BaseDescriptor,
@@ -178,7 +179,7 @@ def forward_atomic(
178179
179180
"""
180181
nframes, nloc, nnei = nlist.shape
181-
atype = extended_atype[:, :nloc]
182+
atype = xp_take_first_n(extended_atype, 1, nloc)
182183
descriptor, rot_mat, g2, h2, sw = self.descriptor(
183184
extended_coord,
184185
extended_atype,
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import copy
3+
from typing import (
4+
Any,
5+
)
6+
7+
from deepmd.dpmodel.output_def import (
8+
FittingOutputDef,
9+
)
10+
11+
12+
def make_hessian_model(T_Model: type) -> type:
13+
"""Make a model that can compute Hessian.
14+
15+
With the JAX-mirrored approach, hessian is computed in
16+
``forward_common_atomic`` (in make_model.py) on extended coordinates.
17+
This wrapper only needs to override ``atomic_output_def()`` to set
18+
``r_hessian=True``, and ``communicate_extended_output`` in dpmodel
19+
naturally maps it from nall to nloc.
20+
21+
Parameters
22+
----------
23+
T_Model
24+
The model. Should provide the ``atomic_output_def`` method.
25+
26+
Returns
27+
-------
28+
The model that computes hessian.
29+
30+
"""
31+
32+
class CM(T_Model):
33+
def __init__(
34+
self,
35+
*args: Any,
36+
**kwargs: Any,
37+
) -> None:
38+
super().__init__(
39+
*args,
40+
**kwargs,
41+
)
42+
self.hess_fitting_def = copy.deepcopy(super().atomic_output_def())
43+
44+
def requires_hessian(
45+
self,
46+
keys: str | list[str],
47+
) -> None:
48+
"""Set which output variable(s) requires hessian."""
49+
if isinstance(keys, str):
50+
keys = [keys]
51+
for kk in self.hess_fitting_def.keys():
52+
if kk in keys:
53+
self.hess_fitting_def[kk].r_hessian = True
54+
55+
def atomic_output_def(self) -> FittingOutputDef:
56+
"""Get the fitting output def."""
57+
return self.hess_fitting_def
58+
59+
return CM

deepmd/dpmodel/model/spin_model.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import functools
3+
from collections.abc import (
4+
Callable,
5+
)
26
from copy import (
37
deepcopy,
48
)
@@ -332,6 +336,88 @@ def model_output_def(self) -> ModelOutputDef:
332336
backbone_model_atomic_output_def[var_name].magnetic = True
333337
return ModelOutputDef(backbone_model_atomic_output_def)
334338

339+
def _get_spin_sampled_func(
    self, sampled_func: Callable[[], list[dict]]
) -> Callable[[], list[dict]]:
    """Get a spin-aware sampled function that transforms spin data for the backbone model.

    The returned callable lazily fetches data from ``sampled_func`` and
    rewrites each system dict so that spin information is folded into
    ``coord``/``atype`` via ``process_spin_input``; the result is wrapped
    by the backbone atomic model's ``_make_wrapped_sampler``.

    Parameters
    ----------
    sampled_func
        A callable that returns a list of data dicts containing 'coord', 'atype', 'spin', etc.

    Returns
    -------
    Callable
        A cached callable that returns spin-preprocessed data dicts.
    """

    # lru_cache on a zero-argument closure: the (potentially expensive)
    # sampling + preprocessing runs at most once per returned sampler.
    @functools.lru_cache
    def spin_sampled_func() -> list[dict]:
        sampled = sampled_func()
        spin_sampled = []
        for sys in sampled:
            # Fold the per-atom spin into the coordinate/type arrays
            # expected by the (non-spin) backbone model.
            coord_updated, atype_updated = self.process_spin_input(
                sys["coord"], sys["atype"], sys["spin"]
            )
            tmp_dict = {
                "coord": coord_updated,
                "atype": atype_updated,
            }
            if "natoms" in sys:
                natoms = sys["natoms"]
                # First two columns are doubled — presumably total/local
                # atom counts grow 2x because each atom gains a virtual
                # spin partner; per-type counts are duplicated.
                # NOTE(review): confirm natoms column layout against the
                # data system definition.
                tmp_dict["natoms"] = np.concatenate(
                    [2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], axis=-1
                )
            # Pass every other key through untouched ('spin' is consumed
            # above and intentionally dropped).
            for item_key in sys.keys():
                if item_key not in ["coord", "atype", "spin", "natoms"]:
                    tmp_dict[item_key] = sys[item_key]
            spin_sampled.append(tmp_dict)
        return spin_sampled

    return self.backbone_model.atomic_model._make_wrapped_sampler(spin_sampled_func)
379+
380+
def change_out_bias(
381+
self,
382+
merged: Callable[[], list[dict]] | list[dict],
383+
bias_adjust_mode: str = "change-by-statistic",
384+
) -> None:
385+
"""Change the output bias of atomic model according to the input data and the pretrained model.
386+
387+
Parameters
388+
----------
389+
merged : Union[Callable[[], list[dict]], list[dict]]
390+
- list[dict]: A list of data samples from various data systems.
391+
Each element, `merged[i]`, is a data dictionary containing `keys`: `np.ndarray`
392+
originating from the `i`-th data system.
393+
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
394+
only when needed. Since the sampling process can be slow and memory-intensive,
395+
the lazy function helps by only sampling once.
396+
bias_adjust_mode : str
397+
The mode for changing output bias : ['change-by-statistic', 'set-by-statistic']
398+
'change-by-statistic' : perform predictions on labels of target dataset,
399+
and do least square on the errors to obtain the target shift as bias.
400+
'set-by-statistic' : directly use the statistic output bias in the target dataset.
401+
"""
402+
spin_sampled_func = self._get_spin_sampled_func(
403+
merged if callable(merged) else lambda: merged
404+
)
405+
self.backbone_model.change_out_bias(
406+
spin_sampled_func,
407+
bias_adjust_mode=bias_adjust_mode,
408+
)
409+
410+
def change_type_map(
    self, type_map: list[str], model_with_new_type_stat: Any = None
) -> None:
    """Change the type related params to new ones, according to `type_map` and the original one in the model.

    The spin model doubles the type list: every real type ``X`` gets a
    companion virtual type ``X_spin``.  The combined list — and, if
    given, ``model_with_new_type_stat`` providing statistics for newly
    introduced types — is forwarded to the backbone model.
    """
    spin_types = ["%s_spin" % item for item in type_map]
    self.backbone_model.change_type_map(
        type_map + spin_types, model_with_new_type_stat
    )
420+
335421
def __getattr__(self, name: str) -> Any:
336422
"""Get attribute from the wrapped model."""
337423
if "backbone_model" not in self.__dict__:

0 commit comments

Comments
 (0)