Skip to content

Commit 619f5d1

Browse files
CopilotnjzjzCopilotgithub-advanced-security[bot]pre-commit-ci[bot]
authored andcommitted
style(dpmodel): enforce type annotations (deepmodeling#4953)
This PR fixes critical PyTorch JIT compilation errors that were preventing model serialization and deployment. The main issue was that PyTorch JIT cannot handle certain type annotations, specifically `Union[np.ndarray, Any]` (the `Array` type) and `KeysView[str]` return types. ## Problem PyTorch JIT compilation was failing with errors like: ``` AssertionError: Unsupported annotation typing.Union[numpy.ndarray, typing.Any] could not be resolved because None could not be resolved. ``` This prevented users from using `torch.jit.script()` on trained models, blocking deployment scenarios. ## Solution - **Removed incompatible type annotations** from methods used by PyTorch JIT compilation - **Added `# noqa:ANNxxx` comments** to suppress ruff linting warnings for missing type annotations - **Fixed `KeysView[str]` return types** that caused JIT compilation failures - **Maintained type safety** elsewhere in the codebase where JIT compatibility isn't required ## Files Changed - `deepmd/dpmodel/utils/network.py`: Removed `Array` type annotations from `call` methods and activation functions - `deepmd/dpmodel/output_def.py`: Removed `KeysView[str]` return type annotations from `keys` methods ## Validation All PyTorch JIT tests now pass: - ✅ TestEnergyModelSeA.test_jit - ✅ TestDOSModelSeA.test_jit - ✅ TestEnergyModelDPA1.test_jit - ✅ TestEnergyModelDPA2.test_jit - ✅ TestEnergyModelHybrid.test_jit - ✅ TestEnergyModelHybrid2.test_jit - ✅ TestEnergyModelDPA2IntRcut.test_jit Models can now be successfully JIT compiled for deployment while maintaining full functionality and backward compatibility. <!-- START COPILOT CODING AGENT TIPS --> --- ✨ Let Copilot coding agent [set things up for you](https://github.com/deepmodeling/deepmd-kit/issues/new?title=✨+Set+up+Copilot+instructions&body=Configure%20instructions%20for%20this%20repository%20as%20documented%20in%20%5BBest%20practices%20for%20Copilot%20coding%20agent%20in%20your%20repository%5D%28https://gh.io/copilot-coding-agent-tips%29%2E%0A%0A%3COnboard%20this%20repo%3E&assignees=copilot) — coding agent works faster and does higher quality work when set up for your repo. --------- Signed-off-by: Jinzhe Zeng <njzjz@qq.com> Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> Co-authored-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2312b38 commit 619f5d1

61 files changed

Lines changed: 1184 additions & 925 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

deepmd/dpmodel/array_api.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
"""Utilities for the array API."""
33

4+
from typing import (
5+
Any,
6+
Callable,
7+
Optional,
8+
Union,
9+
)
10+
411
import array_api_compat
512
import numpy as np
613
from packaging.version import (
714
Version,
815
)
916

17+
# Type alias for array_api compatible arrays
18+
Array = Union[np.ndarray, Any] # Any to support JAX, PyTorch, etc. arrays
19+
1020

11-
def support_array_api(version: str) -> callable:
21+
def support_array_api(version: str) -> Callable:
1222
"""Mark a function as supporting the specific version of the array API.
1323
1424
Parameters
@@ -18,7 +28,7 @@ def support_array_api(version: str) -> callable:
1828
1929
Returns
2030
-------
21-
callable
31+
Callable
2232
The decorated function
2333
2434
Examples
@@ -28,7 +38,7 @@ def support_array_api(version: str) -> callable:
2838
... pass
2939
"""
3040

31-
def set_version(func: callable) -> callable:
41+
def set_version(func: Callable) -> Callable:
3242
func.array_api_version = version
3343
return func
3444

@@ -39,15 +49,15 @@ def set_version(func: callable) -> callable:
3949
# but it hasn't been released yet
4050
# below is a pure Python implementation of take_along_axis
4151
# https://github.com/data-apis/array-api/issues/177#issuecomment-2093630595
42-
def xp_swapaxes(a, axis1, axis2):
52+
def xp_swapaxes(a: Array, axis1: int, axis2: int) -> Array:
4353
xp = array_api_compat.array_namespace(a)
4454
axes = list(range(a.ndim))
4555
axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
4656
a = xp.permute_dims(a, axes)
4757
return a
4858

4959

50-
def xp_take_along_axis(arr, indices, axis):
60+
def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array:
5161
xp = array_api_compat.array_namespace(arr)
5262
if Version(xp.__array_api_version__) >= Version("2024.12"):
5363
# see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39
@@ -76,7 +86,7 @@ def xp_take_along_axis(arr, indices, axis):
7686
return xp_swapaxes(out, axis, -1)
7787

7888

79-
def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray:
89+
def xp_scatter_sum(input: Array, dim: int, index: Array, src: Array) -> Array:
8090
"""Reduces all values from the src tensor to the indices specified in the index tensor."""
8191
# jax only
8292
if array_api_compat.is_jax_array(input):
@@ -94,7 +104,7 @@ def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray
94104
raise NotImplementedError("Only JAX arrays are supported.")
95105

96106

97-
def xp_add_at(x, indices, values):
107+
def xp_add_at(x: Array, indices: Array, values: Array) -> Array:
98108
"""Adds values to the specified indices of x in place or returns new x (for JAX)."""
99109
xp = array_api_compat.array_namespace(x, indices, values)
100110
if array_api_compat.is_numpy_array(x):
@@ -115,7 +125,7 @@ def xp_add_at(x, indices, values):
115125
return x
116126

117127

118-
def xp_bincount(x, weights=None, minlength=0):
128+
def xp_bincount(x: Array, weights: Optional[Array] = None, minlength: int = 0) -> Array:
119129
"""Counts the number of occurrences of each value in x."""
120130
xp = array_api_compat.array_namespace(x)
121131
if array_api_compat.is_numpy_array(x) or array_api_compat.is_jax_array(x):

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import math
33
from typing import (
4+
Any,
45
Optional,
56
)
67

78
import array_api_compat
89
import numpy as np
910

11+
from deepmd.dpmodel.array_api import (
12+
Array,
13+
)
1014
from deepmd.dpmodel.common import (
1115
NativeOP,
1216
to_numpy_array,
@@ -42,7 +46,7 @@ def __init__(
4246
atom_exclude_types: list[int] = [],
4347
pair_exclude_types: list[tuple[int, int]] = [],
4448
rcond: Optional[float] = None,
45-
preset_out_bias: Optional[dict[str, np.ndarray]] = None,
49+
preset_out_bias: Optional[dict[str, Array]] = None,
4650
) -> None:
4751
super().__init__()
4852
self.type_map = type_map
@@ -68,15 +72,15 @@ def init_out_stat(self) -> None:
6872
self.out_bias = out_bias_data
6973
self.out_std = out_std_data
7074

71-
def __setitem__(self, key, value) -> None:
75+
def __setitem__(self, key: str, value: Array) -> None:
7276
if key in ["out_bias"]:
7377
self.out_bias = value
7478
elif key in ["out_std"]:
7579
self.out_std = value
7680
else:
7781
raise KeyError(key)
7882

79-
def __getitem__(self, key):
83+
def __getitem__(self, key: str) -> Array:
8084
if key in ["out_bias"]:
8185
return self.out_bias
8286
elif key in ["out_std"]:
@@ -129,7 +133,7 @@ def atomic_output_def(self) -> FittingOutputDef:
129133
)
130134

131135
def change_type_map(
132-
self, type_map: list[str], model_with_new_type_stat=None
136+
self, type_map: list[str], model_with_new_type_stat: Optional[Any] = None
133137
) -> None:
134138
"""Change the type related params to new ones, according to `type_map` and the original one in the model.
135139
If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
@@ -147,13 +151,13 @@ def change_type_map(
147151

148152
def forward_common_atomic(
149153
self,
150-
extended_coord: np.ndarray,
151-
extended_atype: np.ndarray,
152-
nlist: np.ndarray,
153-
mapping: Optional[np.ndarray] = None,
154-
fparam: Optional[np.ndarray] = None,
155-
aparam: Optional[np.ndarray] = None,
156-
) -> dict[str, np.ndarray]:
154+
extended_coord: Array,
155+
extended_atype: Array,
156+
nlist: Array,
157+
mapping: Optional[Array] = None,
158+
fparam: Optional[Array] = None,
159+
aparam: Optional[Array] = None,
160+
) -> dict[str, Array]:
157161
"""Common interface for atomic inference.
158162
159163
This method accept extended coordinates, extended atom typs, neighbor list,
@@ -223,13 +227,13 @@ def forward_common_atomic(
223227

224228
def call(
225229
self,
226-
extended_coord: np.ndarray,
227-
extended_atype: np.ndarray,
228-
nlist: np.ndarray,
229-
mapping: Optional[np.ndarray] = None,
230-
fparam: Optional[np.ndarray] = None,
231-
aparam: Optional[np.ndarray] = None,
232-
) -> dict[str, np.ndarray]:
230+
extended_coord: Array,
231+
extended_atype: Array,
232+
nlist: Array,
233+
mapping: Optional[Array] = None,
234+
fparam: Optional[Array] = None,
235+
aparam: Optional[Array] = None,
236+
) -> dict[str, Array]:
233237
return self.forward_common_atomic(
234238
extended_coord,
235239
extended_atype,
@@ -264,9 +268,9 @@ def deserialize(cls, data: dict) -> "BaseAtomicModel":
264268

265269
def apply_out_stat(
266270
self,
267-
ret: dict[str, np.ndarray],
268-
atype: np.ndarray,
269-
):
271+
ret: dict[str, Array],
272+
atype: Array,
273+
) -> dict[str, Array]:
270274
"""Apply the stat to each atomic output.
271275
The developer may override the method to define how the bias is applied
272276
to the atomic output of the model.
@@ -309,7 +313,7 @@ def _get_bias_index(
309313
def _fetch_out_stat(
310314
self,
311315
keys: list[str],
312-
) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
316+
) -> tuple[dict[str, Array], dict[str, Array]]:
313317
ret_bias = {}
314318
ret_std = {}
315319
ntypes = self.get_ntypes()
Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
import numpy as np
2+
from typing import (
3+
Any,
4+
)
35

6+
from deepmd.dpmodel.array_api import (
7+
Array,
8+
)
9+
from deepmd.dpmodel.descriptor.base_descriptor import (
10+
BaseDescriptor,
11+
)
12+
from deepmd.dpmodel.fitting.base_fitting import (
13+
BaseFitting,
14+
)
415
from deepmd.dpmodel.fitting.dipole_fitting import (
516
DipoleFitting,
617
)
@@ -11,7 +22,13 @@
1122

1223

1324
class DPDipoleAtomicModel(DPAtomicModel):
14-
def __init__(self, descriptor, fitting, type_map, **kwargs):
25+
def __init__(
26+
self,
27+
descriptor: BaseDescriptor,
28+
fitting: BaseFitting,
29+
type_map: list[str],
30+
**kwargs: Any,
31+
) -> None:
1532
if not isinstance(fitting, DipoleFitting):
1633
raise TypeError(
1734
"fitting must be an instance of DipoleFitting for DPDipoleAtomicModel"
@@ -20,8 +37,8 @@ def __init__(self, descriptor, fitting, type_map, **kwargs):
2037

2138
def apply_out_stat(
2239
self,
23-
ret: dict[str, np.ndarray],
24-
atype: np.ndarray,
25-
):
40+
ret: dict[str, Array],
41+
atype: Array,
42+
) -> dict[str, Array]:
2643
# dipole not applying bias
2744
return ret

deepmd/dpmodel/atomic_model/dos_atomic_model.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Any,
4+
)
5+
6+
from deepmd.dpmodel.descriptor.base_descriptor import (
7+
BaseDescriptor,
8+
)
9+
from deepmd.dpmodel.fitting.base_fitting import (
10+
BaseFitting,
11+
)
212
from deepmd.dpmodel.fitting.dos_fitting import (
313
DOSFittingNet,
414
)
@@ -9,7 +19,13 @@
919

1020

1121
class DPDOSAtomicModel(DPAtomicModel):
12-
def __init__(self, descriptor, fitting, type_map, **kwargs):
22+
def __init__(
23+
self,
24+
descriptor: BaseDescriptor,
25+
fitting: BaseFitting,
26+
type_map: list[str],
27+
**kwargs: Any,
28+
) -> None:
1329
if not isinstance(fitting, DOSFittingNet):
1430
raise TypeError(
1531
"fitting must be an instance of DOSFittingNet for DPDOSAtomicModel"

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
from typing import (
3+
Any,
34
Optional,
45
)
56

6-
import numpy as np
7-
7+
from deepmd.dpmodel.array_api import (
8+
Array,
9+
)
810
from deepmd.dpmodel.descriptor.base_descriptor import (
911
BaseDescriptor,
1012
)
@@ -41,10 +43,10 @@ class DPAtomicModel(BaseAtomicModel):
4143

4244
def __init__(
4345
self,
44-
descriptor,
45-
fitting,
46+
descriptor: BaseDescriptor,
47+
fitting: BaseFitting,
4648
type_map: list[str],
47-
**kwargs,
49+
**kwargs: Any,
4850
) -> None:
4951
super().__init__(type_map, **kwargs)
5052
self.type_map = type_map
@@ -65,7 +67,7 @@ def get_sel(self) -> list[int]:
6567
"""Get the neighbor selection."""
6668
return self.descriptor.get_sel()
6769

68-
def set_case_embd(self, case_idx: int):
70+
def set_case_embd(self, case_idx: int) -> None:
6971
"""
7072
Set the case embedding of this atomic model by the given case_idx,
7173
typically concatenated with the output of the descriptor and fed into the fitting net.
@@ -125,13 +127,13 @@ def enable_compression(
125127

126128
def forward_atomic(
127129
self,
128-
extended_coord: np.ndarray,
129-
extended_atype: np.ndarray,
130-
nlist: np.ndarray,
131-
mapping: Optional[np.ndarray] = None,
132-
fparam: Optional[np.ndarray] = None,
133-
aparam: Optional[np.ndarray] = None,
134-
) -> dict[str, np.ndarray]:
130+
extended_coord: Array,
131+
extended_atype: Array,
132+
nlist: Array,
133+
mapping: Optional[Array] = None,
134+
fparam: Optional[Array] = None,
135+
aparam: Optional[Array] = None,
136+
) -> dict[str, Array]:
135137
"""Models' atomic predictions.
136138
137139
Parameters
@@ -175,7 +177,7 @@ def forward_atomic(
175177
return ret
176178

177179
def change_type_map(
178-
self, type_map: list[str], model_with_new_type_stat=None
180+
self, type_map: list[str], model_with_new_type_stat: Optional[Any] = None
179181
) -> None:
180182
"""Change the type related params to new ones, according to `type_map` and the original one in the model.
181183
If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
@@ -213,7 +215,7 @@ def serialize(self) -> dict:
213215
"""The base fitting class."""
214216

215217
@classmethod
216-
def deserialize(cls, data) -> "DPAtomicModel":
218+
def deserialize(cls, data: dict[str, Any]) -> "DPAtomicModel":
217219
data = data.copy()
218220
check_version_compatibility(data.pop("@version", 1), 2, 2)
219221
data.pop("@class")

deepmd/dpmodel/atomic_model/energy_atomic_model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Any,
4+
)
5+
26
from deepmd.dpmodel.fitting.ener_fitting import (
37
EnergyFittingNet,
48
InvarFitting,
@@ -10,7 +14,9 @@
1014

1115

1216
class DPEnergyAtomicModel(DPAtomicModel):
13-
def __init__(self, descriptor, fitting, type_map, **kwargs):
17+
def __init__(
18+
self, descriptor: Any, fitting: Any, type_map: list[str], **kwargs: Any
19+
) -> None:
1420
if not (
1521
isinstance(fitting, EnergyFittingNet) or isinstance(fitting, InvarFitting)
1622
):

0 commit comments

Comments
 (0)