diff --git a/deepmd/jax/atomic_model/base_atomic_model.py b/deepmd/jax/atomic_model/base_atomic_model.py index ffd58daf5e..474fcb03c7 100644 --- a/deepmd/jax/atomic_model/base_atomic_model.py +++ b/deepmd/jax/atomic_model/base_atomic_model.py @@ -1,4 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + from deepmd.jax.common import ( ArrayAPIVariable, to_jax_array, @@ -9,7 +13,7 @@ ) -def base_atomic_model_set_attr(name, value): +def base_atomic_model_set_attr(name: str, value: Any) -> Any: if name in {"out_bias", "out_std"}: value = to_jax_array(value) if value is not None: diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py index 59f36d11ad..14ae1cad9d 100644 --- a/deepmd/jax/common.py +++ b/deepmd/jax/common.py @@ -70,11 +70,11 @@ def flax_module( metas.add(type(nnx.Module)) class MixedMetaClass(*metas): - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: return type(nnx.Module).__call__(self, *args, **kwargs) class FlaxModule(module, nnx.Module, metaclass=MixedMetaClass): - def __init_subclass__(cls, **kwargs) -> None: + def __init_subclass__(cls, **kwargs: Any) -> None: return super().__init_subclass__(**kwargs) def __setattr__(self, name: str, value: Any) -> None: @@ -84,20 +84,22 @@ def __setattr__(self, name: str, value: Any) -> None: class ArrayAPIVariable(nnx.Variable): - def __array__(self, *args, **kwargs): + def __array__(self, *args: Any, **kwargs: Any) -> np.ndarray: return self.value.__array__(*args, **kwargs) - def __array_namespace__(self, *args, **kwargs): + def __array_namespace__(self, *args: Any, **kwargs: Any) -> Any: return self.value.__array_namespace__(*args, **kwargs) - def __dlpack__(self, *args, **kwargs): + def __dlpack__(self, *args: Any, **kwargs: Any) -> Any: return self.value.__dlpack__(*args, **kwargs) - def __dlpack_device__(self, *args, **kwargs): + def __dlpack_device__(self, *args: Any, **kwargs: Any) -> Any: return self.value.__dlpack_device__(*args, **kwargs) -def scatter_sum(input, dim, index: jnp.ndarray, src: jnp.ndarray) -> jnp.ndarray: +def scatter_sum( + input: jnp.ndarray, dim: int, index: jnp.ndarray, src: jnp.ndarray +) -> jnp.ndarray: """Reduces all values from the src tensor to the indices specified in the index tensor.""" idx = jnp.arange(input.size, dtype=jnp.int64).reshape(input.shape) new_idx = jnp.take_along_axis(idx, index, axis=dim).ravel() diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index 2e74c15fff..92ed78a13e 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -301,7 +301,7 @@ def _eval_func(self, inner_func: Callable, numb_test: int, natoms: int) -> Calla """ if self.auto_batch_size is not None: - def eval_func(*args, **kwargs): + def eval_func(*args: Any, **kwargs: Any) -> Any: return self.auto_batch_size.execute_all( inner_func, numb_test, natoms, *args, **kwargs ) @@ -335,7 +335,7 @@ def _eval_model( fparam: Optional[np.ndarray], aparam: Optional[np.ndarray], request_defs: list[OutputVariableDef], - ): + ) -> tuple[np.ndarray, ...]: model = self.dp nframes = coords.shape[0] @@ -395,7 +395,9 @@ def _eval_model( ) # this is kinda hacky return tuple(results) - def _get_output_shape(self, odef, nframes, natoms): + def _get_output_shape( + self, odef: OutputVariableDef, nframes: int, natoms: int + ) -> list[int]: if odef.category == OutputVariableCategory.DERV_C_REDU: # virial return [nframes, *odef.shape[:-1], 9] diff --git a/deepmd/jax/jax2tf/format_nlist.py b/deepmd/jax/jax2tf/format_nlist.py index f0c630206f..5cf93610e7 100644 --- a/deepmd/jax/jax2tf/format_nlist.py +++ b/deepmd/jax/jax2tf/format_nlist.py @@ -9,7 +9,7 @@ def format_nlist( nlist: tnp.ndarray, nsel: int, rcut: float, -): +) -> tnp.ndarray: """Format neighbor list. If nnei == nsel, do nothing; diff --git a/deepmd/jax/jax2tf/make_model.py b/deepmd/jax/jax2tf/make_model.py index 29ed131f8e..341fdf0d1f 100644 --- a/deepmd/jax/jax2tf/make_model.py +++ b/deepmd/jax/jax2tf/make_model.py @@ -44,7 +44,7 @@ def model_call_from_call_lower( fparam: tnp.ndarray, aparam: tnp.ndarray, do_atomic_virial: bool = False, -): +) -> dict[str, tnp.ndarray]: """Return model prediction from lower interface. Parameters diff --git a/deepmd/jax/jax2tf/nlist.py b/deepmd/jax/jax2tf/nlist.py index 5a0ed58b63..f85526f1e9 100644 --- a/deepmd/jax/jax2tf/nlist.py +++ b/deepmd/jax/jax2tf/nlist.py @@ -115,7 +115,7 @@ def nlist_distinguish_types( nlist: tnp.ndarray, atype: tnp.ndarray, sel: list[int], -): +) -> tnp.ndarray: """Given a nlist that does not distinguish atom types, return a nlist that distinguish atom types. @@ -140,7 +140,7 @@ def nlist_distinguish_types( return ret -def tf_outer(a, b): +def tf_outer(a: tnp.ndarray, b: tnp.ndarray) -> tnp.ndarray: return tf.einsum("i,j->ij", a, b) @@ -150,7 +150,7 @@ def extend_coord_with_ghosts( atype: tnp.ndarray, cell: tnp.ndarray, rcut: float, -): +) -> tuple[tnp.ndarray, tnp.ndarray, tnp.ndarray]: """Extend the coordinates of the atoms by appending peridoc images. The number of images is large enough to ensure all the neighbors within rcut are appended. diff --git a/deepmd/jax/jax2tf/region.py b/deepmd/jax/jax2tf/region.py index 96024bd79a..a90e693478 100644 --- a/deepmd/jax/jax2tf/region.py +++ b/deepmd/jax/jax2tf/region.py @@ -93,7 +93,7 @@ def to_face_distance( return tnp.reshape(dist, tf.concat([cshape[:-2], [3]], axis=0)) -def b_to_face_distance(cell): +def b_to_face_distance(cell: tnp.ndarray) -> tnp.ndarray: volume = tf.linalg.det(cell) c_yz = tf.linalg.cross(cell[:, 1, ...], cell[:, 2, ...]) _h2yz = volume / tf.linalg.norm(c_yz, axis=-1) diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py index aac022ace9..096fc41e5a 100644 --- a/deepmd/jax/jax2tf/serialization.py +++ b/deepmd/jax/jax2tf/serialization.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import json from typing import ( + Callable, Optional, ) @@ -38,10 +39,17 @@ def deserialize_to_file(model_file: str, data: dict) -> None: tf_model = tf.Module() - def exported_whether_do_atomic_virial(do_atomic_virial, has_ghost_atoms): + def exported_whether_do_atomic_virial( + do_atomic_virial: bool, has_ghost_atoms: bool + ) -> Callable: def call_lower_with_fixed_do_atomic_virial( - coord, atype, nlist, mapping, fparam, aparam - ): + coord: tnp.ndarray, + atype: tnp.ndarray, + nlist: tnp.ndarray, + mapping: tnp.ndarray, + fparam: tnp.ndarray, + aparam: tnp.ndarray, + ) -> dict[str, tnp.ndarray]: return call_lower( coord, atype, @@ -86,8 +94,13 @@ def call_lower_with_fixed_do_atomic_virial( ], ) def call_lower_without_atomic_virial( - coord, atype, nlist, mapping, fparam, aparam - ): + coord: tnp.ndarray, + atype: tnp.ndarray, + nlist: tnp.ndarray, + mapping: tnp.ndarray, + fparam: tnp.ndarray, + aparam: tnp.ndarray, + ) -> dict[str, tnp.ndarray]: nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut()) return tf.cond( tf.shape(coord)[1] == tf.shape(nlist)[1], @@ -112,7 +125,14 @@ def call_lower_without_atomic_virial( tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), ], ) - def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam): + def call_lower_with_atomic_virial( + coord: tnp.ndarray, + atype: tnp.ndarray, + nlist: tnp.ndarray, + mapping: tnp.ndarray, + fparam: tnp.ndarray, + aparam: tnp.ndarray, + ) -> dict[str, tnp.ndarray]: nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut()) return tf.cond( tf.shape(coord)[1] == tf.shape(nlist)[1], @@ -126,7 +146,7 @@ def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam): tf_model.call_lower_atomic_virial = call_lower_with_atomic_virial - def make_call_whether_do_atomic_virial(do_atomic_virial: bool): + def make_call_whether_do_atomic_virial(do_atomic_virial: bool) -> Callable: if do_atomic_virial: call_lower = call_lower_with_atomic_virial else: @@ -138,7 +158,7 @@ def call( box: Optional[tnp.ndarray] = None, fparam: Optional[tnp.ndarray] = None, aparam: Optional[tnp.ndarray] = None, - ): + ) -> dict[str, tnp.ndarray]: """Return model prediction. Parameters @@ -194,7 +214,7 @@ def call_with_atomic_virial( box: tnp.ndarray, fparam: tnp.ndarray, aparam: tnp.ndarray, - ): + ) -> dict[str, tnp.ndarray]: return make_call_whether_do_atomic_virial(do_atomic_virial=True)( coord, atype, box, fparam, aparam ) @@ -217,7 +237,7 @@ def call_without_atomic_virial( box: tnp.ndarray, fparam: tnp.ndarray, aparam: tnp.ndarray, - ): + ) -> dict[str, tnp.ndarray]: return make_call_whether_do_atomic_virial(do_atomic_virial=False)( coord, atype, box, fparam, aparam ) @@ -226,49 +246,49 @@ def call_without_atomic_virial( # set functions to export other attributes @tf.function - def get_type_map(): + def get_type_map() -> tf.Tensor: return tf.constant(model.get_type_map(), dtype=tf.string) tf_model.get_type_map = get_type_map @tf.function - def get_rcut(): + def get_rcut() -> tf.Tensor: return tf.constant(model.get_rcut(), dtype=tf.double) tf_model.get_rcut = get_rcut @tf.function - def get_dim_fparam(): + def get_dim_fparam() -> tf.Tensor: return tf.constant(model.get_dim_fparam(), dtype=tf.int64) tf_model.get_dim_fparam = get_dim_fparam @tf.function - def get_dim_aparam(): + def get_dim_aparam() -> tf.Tensor: return tf.constant(model.get_dim_aparam(), dtype=tf.int64) tf_model.get_dim_aparam = get_dim_aparam @tf.function - def get_sel_type(): + def get_sel_type() -> tf.Tensor: return tf.constant(model.get_sel_type(), dtype=tf.int64) tf_model.get_sel_type = get_sel_type @tf.function - def is_aparam_nall(): + def is_aparam_nall() -> tf.Tensor: return tf.constant(model.is_aparam_nall(), dtype=tf.bool) tf_model.is_aparam_nall = is_aparam_nall @tf.function - def model_output_type(): + def model_output_type() -> tf.Tensor: return tf.constant(model.model_output_type(), dtype=tf.string) tf_model.model_output_type = model_output_type @tf.function - def mixed_types(): + def mixed_types() -> tf.Tensor: return tf.constant(model.mixed_types(), dtype=tf.bool) tf_model.mixed_types = mixed_types @@ -276,19 +296,19 @@ def mixed_types(): if model.get_min_nbor_dist() is not None: @tf.function - def get_min_nbor_dist(): + def get_min_nbor_dist() -> tf.Tensor: return tf.constant(model.get_min_nbor_dist(), dtype=tf.double) tf_model.get_min_nbor_dist = get_min_nbor_dist @tf.function - def get_sel(): + def get_sel() -> tf.Tensor: return tf.constant(model.get_sel(), dtype=tf.int64) tf_model.get_sel = get_sel @tf.function - def get_model_def_script(): + def get_model_def_script() -> tf.Tensor: return tf.constant( json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string ) diff --git a/deepmd/jax/jax2tf/tfmodel.py b/deepmd/jax/jax2tf/tfmodel.py index 0d7b13ba1f..61c83fa028 100644 --- a/deepmd/jax/jax2tf/tfmodel.py +++ b/deepmd/jax/jax2tf/tfmodel.py @@ -45,7 +45,7 @@ def decode_list_of_bytes(list_of_bytes: list[bytes]) -> list[str]: class TFModelWrapper(tf.Module): def __init__( self, - model, + model: str, ) -> None: self.model = tf.saved_model.load(model) self._call_lower = jax2tf.call_tf(self.model.call_lower) @@ -115,7 +115,7 @@ def call( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, - ): + ) -> dict[str, jnp.ndarray]: """Return model prediction. Parameters @@ -165,7 +165,7 @@ def call( aparam, ) - def model_output_def(self): + def model_output_def(self) -> ModelOutputDef: return ModelOutputDef( FittingOutputDef([OUTPUT_DEFS[tt] for tt in self.model_output_type()]) ) @@ -179,7 +179,7 @@ def call_lower( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, - ): + ) -> dict[str, jnp.ndarray]: if do_atomic_virial: call_lower = self._call_lower_atomic_virial else: @@ -207,15 +207,15 @@ def get_type_map(self) -> list[str]: """Get the type map.""" return self.type_map - def get_rcut(self): + def get_rcut(self) -> float: """Get the cut-off radius.""" return self.rcut - def get_dim_fparam(self): + def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" return self.dim_fparam - def get_dim_aparam(self): + def get_dim_aparam(self) -> int: """Get the number (dimension) of atomic parameters of this atomic model.""" return self.dim_aparam diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py index 34ee765459..203da40d07 100644 --- a/deepmd/jax/model/base_model.py +++ b/deepmd/jax/model/base_model.py @@ -20,7 +20,7 @@ def forward_common_atomic( - self, + self: "BaseModel", extended_coord: jnp.ndarray, extended_atype: jnp.ndarray, nlist: jnp.ndarray, @@ -28,7 +28,7 @@ def forward_common_atomic( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, -): +) -> dict[str, jnp.ndarray]: atomic_ret = self.atomic_model.forward_common_atomic( extended_coord, extended_atype, @@ -60,16 +60,16 @@ def forward_common_atomic( if vdef.r_differentiable: def eval_output( - cc_ext, - extended_atype, - nlist, - mapping, - fparam, - aparam, + cc_ext: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray], + fparam: Optional[jnp.ndarray], + aparam: Optional[jnp.ndarray], *, - _kk=kk, - _atom_axis=atom_axis, - ): + _kk: str = kk, + _atom_axis: int = atom_axis, + ) -> jnp.ndarray: atomic_ret = self.atomic_model.forward_common_atomic( cc_ext[None, ...], extended_atype[None, ...], @@ -117,16 +117,16 @@ def eval_output( if do_atomic_virial: def eval_ce( - cc_ext, - extended_atype, - nlist, - mapping, - fparam, - aparam, + cc_ext: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray], + fparam: Optional[jnp.ndarray], + aparam: Optional[jnp.ndarray], *, - _kk=kk, - _atom_axis=atom_axis - 1, - ): + _kk: str = kk, + _atom_axis: int = atom_axis - 1, + ) -> jnp.ndarray: # atomic_ret[_kk]: [nf, nloc, *def] atomic_ret = self.atomic_model.forward_common_atomic( cc_ext[None, ...], diff --git a/deepmd/jax/model/dp_model.py b/deepmd/jax/model/dp_model.py index 436582f22b..ee98a689e4 100644 --- a/deepmd/jax/model/dp_model.py +++ b/deepmd/jax/model/dp_model.py @@ -56,7 +56,7 @@ def forward_common_atomic( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, - ): + ) -> dict[str, jnp.ndarray]: return forward_common_atomic( self, extended_coord, @@ -74,7 +74,7 @@ def format_nlist( extended_atype: jnp.ndarray, nlist: jnp.ndarray, extra_nlist_sort: bool = False, - ): + ) -> jnp.ndarray: return dpmodel_model.format_nlist( self, jax.lax.stop_gradient(extended_coord), diff --git a/deepmd/jax/model/dp_zbl_model.py b/deepmd/jax/model/dp_zbl_model.py index babbc65233..065dbc7aa7 100644 --- a/deepmd/jax/model/dp_zbl_model.py +++ b/deepmd/jax/model/dp_zbl_model.py @@ -38,7 +38,7 @@ def forward_common_atomic( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, - ): + ) -> dict[str, jnp.ndarray]: return forward_common_atomic( self, extended_coord, @@ -56,7 +56,7 @@ def format_nlist( extended_atype: jnp.ndarray, nlist: jnp.ndarray, extra_nlist_sort: bool = False, - ): + ) -> jnp.ndarray: return DPZBLModelDP.format_nlist( self, jax.lax.stop_gradient(extended_coord), diff --git a/deepmd/jax/model/hlo.py b/deepmd/jax/model/hlo.py index 4d59957456..cbeb915329 100644 --- a/deepmd/jax/model/hlo.py +++ b/deepmd/jax/model/hlo.py @@ -44,21 +44,21 @@ class HLO(BaseModel): def __init__( self, - stablehlo, - stablehlo_atomic_virial, - stablehlo_no_ghost, - stablehlo_atomic_virial_no_ghost, - model_def_script, - type_map, - rcut, - dim_fparam, - dim_aparam, - sel_type, - is_aparam_nall, - model_output_type, - mixed_types, - min_nbor_dist, - sel, + stablehlo: bytearray, + stablehlo_atomic_virial: bytearray, + stablehlo_no_ghost: bytearray, + stablehlo_atomic_virial_no_ghost: bytearray, + model_def_script: str, + type_map: list[str], + rcut: float, + dim_fparam: int, + dim_aparam: int, + sel_type: list[int], + is_aparam_nall: bool, + model_output_type: str, + mixed_types: bool, + min_nbor_dist: Optional[float], + sel: list[int], ) -> None: self._call_lower = jax_export.deserialize(stablehlo).call self._call_lower_atomic_virial = jax_export.deserialize( @@ -125,7 +125,7 @@ def call( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, - ): + ) -> dict[str, jnp.ndarray]: """Return model prediction. Parameters @@ -165,7 +165,7 @@ def call( do_atomic_virial=do_atomic_virial, ) - def model_output_def(self): + def model_output_def(self) -> ModelOutputDef: return ModelOutputDef( FittingOutputDef([OUTPUT_DEFS[tt] for tt in self.model_output_type()]) ) @@ -179,7 +179,7 @@ def call_lower( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, - ): + ) -> dict[str, jnp.ndarray]: if extended_coord.shape[1] > nlist.shape[1]: if do_atomic_virial: call_lower = self._call_lower_atomic_virial @@ -203,15 +203,15 @@ def get_type_map(self) -> list[str]: """Get the type map.""" return self.type_map - def get_rcut(self): + def get_rcut(self) -> float: """Get the cut-off radius.""" return self.rcut - def get_dim_fparam(self): + def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" return self.dim_fparam - def get_dim_aparam(self): + def get_dim_aparam(self) -> int: """Get the number (dimension) of atomic parameters of this atomic model.""" return self.dim_aparam diff --git a/deepmd/jax/model/model.py b/deepmd/jax/model/model.py index dc350e968c..321f33b315 100644 --- a/deepmd/jax/model/model.py +++ b/deepmd/jax/model/model.py @@ -26,7 +26,7 @@ ) -def get_standard_model(data: dict): +def get_standard_model(data: dict) -> BaseModel: """Get a Model from a dictionary. Parameters @@ -103,7 +103,7 @@ def get_zbl_model(data: dict) -> DPZBLModel: ) -def get_model(data: dict): +def get_model(data: dict) -> BaseModel: """Get a model from a dictionary. Parameters diff --git a/deepmd/jax/utils/neighbor_stat.py b/deepmd/jax/utils/neighbor_stat.py index 6d9bc872e8..ddfc4199a3 100644 --- a/deepmd/jax/utils/neighbor_stat.py +++ b/deepmd/jax/utils/neighbor_stat.py @@ -82,7 +82,7 @@ def _execute( coord: np.ndarray, atype: np.ndarray, cell: Optional[np.ndarray], - ): + ) -> tuple[np.ndarray, np.ndarray]: """Execute the operation. Parameters diff --git a/deepmd/jax/utils/network.py b/deepmd/jax/utils/network.py index 78da4c96f5..5a42323b90 100644 --- a/deepmd/jax/utils/network.py +++ b/deepmd/jax/utils/network.py @@ -4,6 +4,8 @@ ClassVar, ) +import numpy as np + from deepmd.dpmodel.common import ( NativeOP, ) @@ -26,16 +28,16 @@ class ArrayAPIParam(nnx.Param): - def __array__(self, *args, **kwargs): + def __array__(self, *args: Any, **kwargs: Any) -> np.ndarray: return self.value.__array__(*args, **kwargs) - def __array_namespace__(self, *args, **kwargs): + def __array_namespace__(self, *args: Any, **kwargs: Any) -> Any: return self.value.__array_namespace__(*args, **kwargs) - def __dlpack__(self, *args, **kwargs): + def __dlpack__(self, *args: Any, **kwargs: Any) -> Any: return self.value.__dlpack__(*args, **kwargs) - def __dlpack_device__(self, *args, **kwargs): + def __dlpack_device__(self, *args: Any, **kwargs: Any) -> Any: return self.value.__dlpack_device__(*args, **kwargs) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 5d4da49e08..6a3c839608 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -55,10 +55,15 @@ def deserialize_to_file(model_file: str, data: dict) -> None: def exported_whether_do_atomic_virial( do_atomic_virial: bool, has_ghost_atoms: bool - ): + ) -> "jax_export.Exported": def call_lower_with_fixed_do_atomic_virial( - coord, atype, nlist, mapping, fparam, aparam - ): + coord: jnp.ndarray, + atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: jnp.ndarray, + fparam: jnp.ndarray, + aparam: jnp.ndarray, + ) -> dict[str, jnp.ndarray]: return call_lower( coord, atype, diff --git a/pyproject.toml b/pyproject.toml index 554a113ca2..3d0b7f1b9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -424,7 +424,7 @@ runtime-evaluated-base-classes = ["torch.nn.Module"] "data/**" = ["ANN"] "deepmd/tf/**" = ["TID253", "ANN"] "deepmd/pt/**" = ["TID253"] -"deepmd/jax/**" = ["TID253", "ANN"] +"deepmd/jax/**" = ["TID253"] "deepmd/pd/**" = ["TID253", "ANN"] "source/**" = ["ANN"]