Skip to content

Commit 99a6f3c

Browse files
Copilotnjzjz
andauthored
feat: remove unused support_array_api decorator (#5200)
The `support_array_api` decorator defined in `deepmd/dpmodel/array_api.py` was unused metadata that added no functional value. It set an `array_api_version` attribute on decorated functions, but this attribute was never checked anywhere in the codebase. ## Changes - **Removed decorator definition** from `deepmd/dpmodel/array_api.py` (30 lines) - **Removed decorator usages** from: - `deepmd/dpmodel/utils/env_mat.py` - `compute_smooth_weight()`, `compute_exp_sw()` - `deepmd/dpmodel/utils/network.py` - `NativeLayer.call()`, `get_activation_fn()` - `deepmd/dpmodel/utils/type_embed.py` - `TypeEmbedNet.call()` - **Removed unused test utility** `set_array_api_version()` from `source/tests/common/dpmodel/array_api/utils.py` - **Cleaned up imports** - removed `Callable` and `set_array_api_strict_flags` where no longer needed Functions previously decorated now have unchanged behavior: ```python # Before @support_array_api(version="2023.12") def compute_smooth_weight(distance: Array, rmin: float, rmax: float) -> Array: ... # After def compute_smooth_weight(distance: Array, rmin: float, rmax: float) -> Array: ... ``` Total: 45 lines deleted across 5 files. <!-- START COPILOT ORIGINAL PROMPT --> <details> <summary>Original prompt</summary> > > ---- > > *This section details on the original issue you should resolve* > > <issue_title>[Feature Request] Remove unused support_array_api decorator to simplify codebase</issue_title> > <issue_description>### Summary > > Propose to remove the rarely used `support_array_api` decorator defined at [deepmd/dpmodel/array_api.py#L21](https://github.com/deepmodeling/deepmd-kit/blob/ef0be638972739463813609729075bf38973dd9d/deepmd/dpmodel/array_api.py#L21) to simplify and streamline the codebase. > > ### Detailed Description > > The `support_array_api` decorator is present in the codebase but is used in only a few isolated functions and has not contributed to significant or widely used features. Its presence now causes unnecessary clutter and maintenance complexity. > > By removing this unused decorator and its usages, we can improve overall code maintainability and reduce confusion for contributors. If a future need for such a decorator arises, it can be revised and reintroduced when there is a broader adoption scenario. > > ### Further Information, Files, and Links > > _No response_</issue_description> > > ## Comments on the Issue (you are @copilot in this section) > > <comments> > </comments> > </details> <!-- START COPILOT CODING AGENT SUFFIX --> - Fixes #5199 <!-- START COPILOT CODING AGENT TIPS --> --- 💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more [Copilot coding agent tips](https://gh.io/copilot-coding-agent-tips) in the docs. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent ef0be63 commit 99a6f3c

5 files changed

Lines changed: 0 additions & 45 deletions

File tree

deepmd/dpmodel/array_api.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
"""Utilities for the array API."""
33

4-
from collections.abc import (
5-
Callable,
6-
)
74
from typing import (
85
Any,
96
)
@@ -18,33 +15,6 @@
1815
Array = np.ndarray | Any # Any to support JAX, PyTorch, etc. arrays
1916

2017

21-
def support_array_api(version: str) -> Callable:
22-
"""Mark a function as supporting the specific version of the array API.
23-
24-
Parameters
25-
----------
26-
version : str
27-
The version of the array API
28-
29-
Returns
30-
-------
31-
Callable
32-
The decorated function
33-
34-
Examples
35-
--------
36-
>>> @support_array_api(version="2022.12")
37-
... def f(x):
38-
... pass
39-
"""
40-
41-
def set_version(func: Callable) -> Callable:
42-
func.array_api_version = version
43-
return func
44-
45-
return set_version
46-
47-
4818
# array api adds take_along_axis in https://github.com/data-apis/array-api/pull/816
4919
# but it hasn't been released yet
5020
# below is a pure Python implementation of take_along_axis

deepmd/dpmodel/utils/env_mat.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,13 @@
1010
)
1111
from deepmd.dpmodel.array_api import (
1212
Array,
13-
support_array_api,
1413
xp_take_along_axis,
1514
)
1615
from deepmd.dpmodel.utils.safe_gradient import (
1716
safe_for_vector_norm,
1817
)
1918

2019

21-
@support_array_api(version="2023.12")
2220
def compute_smooth_weight(
2321
distance: Array,
2422
rmin: float,
@@ -35,7 +33,6 @@ def compute_smooth_weight(
3533
return vv
3634

3735

38-
@support_array_api(version="2023.12")
3936
def compute_exp_sw(
4037
distance: Array,
4138
rmin: float,

deepmd/dpmodel/utils/network.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
)
2424
from deepmd.dpmodel.array_api import (
2525
Array,
26-
support_array_api,
2726
xp_add_at,
2827
xp_bincount,
2928
)
@@ -259,7 +258,6 @@ def dim_in(self) -> int:
259258
def dim_out(self) -> int:
260259
return self.w.shape[1]
261260

262-
@support_array_api(version="2022.12")
263261
def call(self, x): # noqa: ANN001, ANN201
264262
"""Forward pass.
265263
@@ -296,7 +294,6 @@ def call(self, x): # noqa: ANN001, ANN201
296294
return y
297295

298296

299-
@support_array_api(version="2022.12")
300297
def get_activation_fn(activation_function: str) -> Callable[[np.ndarray], np.ndarray]:
301298
activation_function = activation_function.lower()
302299
if activation_function == "tanh":

deepmd/dpmodel/utils/type_embed.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from deepmd.dpmodel.array_api import (
1010
Array,
11-
support_array_api,
1211
)
1312
from deepmd.dpmodel.common import (
1413
PRECISION_DICT,
@@ -96,7 +95,6 @@ def __init__(
9695
trainable=trainable,
9796
)
9897

99-
@support_array_api(version="2022.12")
10098
def call(self) -> Array:
10199
"""Compute the type embedding network."""
102100
sample_array = self.embedding_net[0]["w"]

source/tests/common/dpmodel/array_api/utils.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,10 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import array_api_compat
3-
from array_api_strict import (
4-
set_array_api_strict_flags,
5-
)
63

74

85
class ArrayAPITest:
96
"""Utils for array API tests."""
107

11-
def set_array_api_version(self, func) -> None:
12-
"""Set the array API version for a function."""
13-
set_array_api_strict_flags(api_version=func.array_api_version)
14-
158
def assert_namespace_equal(self, a, b) -> None:
169
"""Assert two array has the same namespace."""
1710
self.assertEqual(

0 commit comments

Comments
 (0)