Skip to content
6 changes: 5 additions & 1 deletion deepmd/jax/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.jax.common import (
ArrayAPIVariable,
to_jax_array,
Expand All @@ -9,7 +13,7 @@
)


def base_atomic_model_set_attr(name, value):
def base_atomic_model_set_attr(name: str, value: Any) -> Any:
if name in {"out_bias", "out_std"}:
value = to_jax_array(value)
if value is not None:
Expand Down
16 changes: 9 additions & 7 deletions deepmd/jax/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ def flax_module(
metas.add(type(nnx.Module))

class MixedMetaClass(*metas):
def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return type(nnx.Module).__call__(self, *args, **kwargs)

class FlaxModule(module, nnx.Module, metaclass=MixedMetaClass):
def __init_subclass__(cls, **kwargs) -> None:
def __init_subclass__(cls, **kwargs: Any) -> None:
return super().__init_subclass__(**kwargs)

def __setattr__(self, name: str, value: Any) -> None:
Expand All @@ -84,20 +84,22 @@ def __setattr__(self, name: str, value: Any) -> None:


class ArrayAPIVariable(nnx.Variable):
def __array__(self, *args, **kwargs):
def __array__(self, *args: Any, **kwargs: Any) -> np.ndarray:
return self.value.__array__(*args, **kwargs)

def __array_namespace__(self, *args, **kwargs):
def __array_namespace__(self, *args: Any, **kwargs: Any) -> Any:
return self.value.__array_namespace__(*args, **kwargs)

def __dlpack__(self, *args, **kwargs):
def __dlpack__(self, *args: Any, **kwargs: Any) -> Any:
return self.value.__dlpack__(*args, **kwargs)

def __dlpack_device__(self, *args, **kwargs):
def __dlpack_device__(self, *args: Any, **kwargs: Any) -> Any:
return self.value.__dlpack_device__(*args, **kwargs)


def scatter_sum(input, dim, index: jnp.ndarray, src: jnp.ndarray) -> jnp.ndarray:
def scatter_sum(
input: jnp.ndarray, dim: int, index: jnp.ndarray, src: jnp.ndarray
) -> jnp.ndarray:
"""Reduces all values from the src tensor to the indices specified in the index tensor."""
idx = jnp.arange(input.size, dtype=jnp.int64).reshape(input.shape)
new_idx = jnp.take_along_axis(idx, index, axis=dim).ravel()
Expand Down
8 changes: 5 additions & 3 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def _eval_func(self, inner_func: Callable, numb_test: int, natoms: int) -> Calla
"""
if self.auto_batch_size is not None:

def eval_func(*args, **kwargs):
def eval_func(*args: Any, **kwargs: Any) -> Any:
return self.auto_batch_size.execute_all(
inner_func, numb_test, natoms, *args, **kwargs
)
Expand Down Expand Up @@ -335,7 +335,7 @@ def _eval_model(
fparam: Optional[np.ndarray],
aparam: Optional[np.ndarray],
request_defs: list[OutputVariableDef],
):
) -> tuple[np.ndarray, ...]:
model = self.dp

nframes = coords.shape[0]
Expand Down Expand Up @@ -395,7 +395,9 @@ def _eval_model(
) # this is kinda hacky
return tuple(results)

def _get_output_shape(self, odef, nframes, natoms):
def _get_output_shape(
self, odef: OutputVariableDef, nframes: int, natoms: int
) -> list[int]:
if odef.category == OutputVariableCategory.DERV_C_REDU:
# virial
return [nframes, *odef.shape[:-1], 9]
Expand Down
2 changes: 1 addition & 1 deletion deepmd/jax/jax2tf/format_nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def format_nlist(
nlist: tnp.ndarray,
nsel: int,
rcut: float,
):
) -> tnp.ndarray:
"""Format neighbor list.

If nnei == nsel, do nothing;
Expand Down
2 changes: 1 addition & 1 deletion deepmd/jax/jax2tf/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def model_call_from_call_lower(
fparam: tnp.ndarray,
aparam: tnp.ndarray,
do_atomic_virial: bool = False,
):
) -> dict[str, tnp.ndarray]:
"""Return model prediction from lower interface.

Parameters
Expand Down
6 changes: 3 additions & 3 deletions deepmd/jax/jax2tf/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def nlist_distinguish_types(
nlist: tnp.ndarray,
atype: tnp.ndarray,
sel: list[int],
):
) -> tnp.ndarray:
"""Given a nlist that does not distinguish atom types, return a nlist that
distinguish atom types.

Expand All @@ -140,7 +140,7 @@ def nlist_distinguish_types(
return ret


def tf_outer(a, b):
def tf_outer(a: tnp.ndarray, b: tnp.ndarray) -> tnp.ndarray:
return tf.einsum("i,j->ij", a, b)


Expand All @@ -150,7 +150,7 @@ def extend_coord_with_ghosts(
atype: tnp.ndarray,
cell: tnp.ndarray,
rcut: float,
):
) -> tuple[tnp.ndarray, tnp.ndarray, tnp.ndarray]:
"""Extend the coordinates of the atoms by appending peridoc images.
The number of images is large enough to ensure all the neighbors
within rcut are appended.
Expand Down
2 changes: 1 addition & 1 deletion deepmd/jax/jax2tf/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def to_face_distance(
return tnp.reshape(dist, tf.concat([cshape[:-2], [3]], axis=0))


def b_to_face_distance(cell):
def b_to_face_distance(cell: tnp.ndarray) -> tnp.ndarray:
volume = tf.linalg.det(cell)
c_yz = tf.linalg.cross(cell[:, 1, ...], cell[:, 2, ...])
_h2yz = volume / tf.linalg.norm(c_yz, axis=-1)
Expand Down
62 changes: 41 additions & 21 deletions deepmd/jax/jax2tf/serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
from typing import (
Callable,
Optional,
)

Expand Down Expand Up @@ -38,10 +39,17 @@ def deserialize_to_file(model_file: str, data: dict) -> None:

tf_model = tf.Module()

def exported_whether_do_atomic_virial(do_atomic_virial, has_ghost_atoms):
def exported_whether_do_atomic_virial(
do_atomic_virial: bool, has_ghost_atoms: bool
) -> Callable:
def call_lower_with_fixed_do_atomic_virial(
coord, atype, nlist, mapping, fparam, aparam
):
coord: tnp.ndarray,
atype: tnp.ndarray,
nlist: tnp.ndarray,
mapping: tnp.ndarray,
fparam: tnp.ndarray,
aparam: tnp.ndarray,
) -> dict[str, tnp.ndarray]:
return call_lower(
coord,
atype,
Expand Down Expand Up @@ -86,8 +94,13 @@ def call_lower_with_fixed_do_atomic_virial(
],
)
def call_lower_without_atomic_virial(
coord, atype, nlist, mapping, fparam, aparam
):
coord: tnp.ndarray,
atype: tnp.ndarray,
nlist: tnp.ndarray,
mapping: tnp.ndarray,
fparam: tnp.ndarray,
aparam: tnp.ndarray,
) -> dict[str, tnp.ndarray]:
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
return tf.cond(
tf.shape(coord)[1] == tf.shape(nlist)[1],
Expand All @@ -112,7 +125,14 @@ def call_lower_without_atomic_virial(
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
],
)
def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam):
def call_lower_with_atomic_virial(
coord: tnp.ndarray,
atype: tnp.ndarray,
nlist: tnp.ndarray,
mapping: tnp.ndarray,
fparam: tnp.ndarray,
aparam: tnp.ndarray,
) -> dict[str, tnp.ndarray]:
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
return tf.cond(
tf.shape(coord)[1] == tf.shape(nlist)[1],
Expand All @@ -126,7 +146,7 @@ def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam):

tf_model.call_lower_atomic_virial = call_lower_with_atomic_virial

def make_call_whether_do_atomic_virial(do_atomic_virial: bool):
def make_call_whether_do_atomic_virial(do_atomic_virial: bool) -> Callable:
if do_atomic_virial:
call_lower = call_lower_with_atomic_virial
else:
Expand All @@ -138,7 +158,7 @@ def call(
box: Optional[tnp.ndarray] = None,
fparam: Optional[tnp.ndarray] = None,
aparam: Optional[tnp.ndarray] = None,
):
) -> dict[str, tnp.ndarray]:
"""Return model prediction.

Parameters
Expand Down Expand Up @@ -194,7 +214,7 @@ def call_with_atomic_virial(
box: tnp.ndarray,
fparam: tnp.ndarray,
aparam: tnp.ndarray,
):
) -> dict[str, tnp.ndarray]:
return make_call_whether_do_atomic_virial(do_atomic_virial=True)(
coord, atype, box, fparam, aparam
)
Expand All @@ -217,7 +237,7 @@ def call_without_atomic_virial(
box: tnp.ndarray,
fparam: tnp.ndarray,
aparam: tnp.ndarray,
):
) -> dict[str, tnp.ndarray]:
return make_call_whether_do_atomic_virial(do_atomic_virial=False)(
coord, atype, box, fparam, aparam
)
Expand All @@ -226,69 +246,69 @@ def call_without_atomic_virial(

# set functions to export other attributes
@tf.function
def get_type_map():
def get_type_map() -> tf.Tensor:
return tf.constant(model.get_type_map(), dtype=tf.string)

tf_model.get_type_map = get_type_map

@tf.function
def get_rcut():
def get_rcut() -> tf.Tensor:
return tf.constant(model.get_rcut(), dtype=tf.double)

tf_model.get_rcut = get_rcut

@tf.function
def get_dim_fparam():
def get_dim_fparam() -> tf.Tensor:
return tf.constant(model.get_dim_fparam(), dtype=tf.int64)

tf_model.get_dim_fparam = get_dim_fparam

@tf.function
def get_dim_aparam():
def get_dim_aparam() -> tf.Tensor:
return tf.constant(model.get_dim_aparam(), dtype=tf.int64)

tf_model.get_dim_aparam = get_dim_aparam

@tf.function
def get_sel_type():
def get_sel_type() -> tf.Tensor:
return tf.constant(model.get_sel_type(), dtype=tf.int64)

tf_model.get_sel_type = get_sel_type

@tf.function
def is_aparam_nall():
def is_aparam_nall() -> tf.Tensor:
return tf.constant(model.is_aparam_nall(), dtype=tf.bool)

tf_model.is_aparam_nall = is_aparam_nall

@tf.function
def model_output_type():
def model_output_type() -> tf.Tensor:
return tf.constant(model.model_output_type(), dtype=tf.string)

tf_model.model_output_type = model_output_type

@tf.function
def mixed_types():
def mixed_types() -> tf.Tensor:
return tf.constant(model.mixed_types(), dtype=tf.bool)

tf_model.mixed_types = mixed_types

if model.get_min_nbor_dist() is not None:

@tf.function
def get_min_nbor_dist():
def get_min_nbor_dist() -> tf.Tensor:
return tf.constant(model.get_min_nbor_dist(), dtype=tf.double)

tf_model.get_min_nbor_dist = get_min_nbor_dist

@tf.function
def get_sel():
def get_sel() -> tf.Tensor:
return tf.constant(model.get_sel(), dtype=tf.int64)

tf_model.get_sel = get_sel

@tf.function
def get_model_def_script():
def get_model_def_script() -> tf.Tensor:
return tf.constant(
json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string
)
Expand Down
14 changes: 7 additions & 7 deletions deepmd/jax/jax2tf/tfmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def decode_list_of_bytes(list_of_bytes: list[bytes]) -> list[str]:
class TFModelWrapper(tf.Module):
def __init__(
self,
model,
model: str,
) -> None:
self.model = tf.saved_model.load(model)
self._call_lower = jax2tf.call_tf(self.model.call_lower)
Expand Down Expand Up @@ -115,7 +115,7 @@ def call(
fparam: Optional[jnp.ndarray] = None,
aparam: Optional[jnp.ndarray] = None,
do_atomic_virial: bool = False,
):
) -> dict[str, jnp.ndarray]:
"""Return model prediction.

Parameters
Expand Down Expand Up @@ -165,7 +165,7 @@ def call(
aparam,
)

def model_output_def(self):
def model_output_def(self) -> ModelOutputDef:
return ModelOutputDef(
FittingOutputDef([OUTPUT_DEFS[tt] for tt in self.model_output_type()])
)
Expand All @@ -179,7 +179,7 @@ def call_lower(
fparam: Optional[jnp.ndarray] = None,
aparam: Optional[jnp.ndarray] = None,
do_atomic_virial: bool = False,
):
) -> dict[str, jnp.ndarray]:
if do_atomic_virial:
call_lower = self._call_lower_atomic_virial
else:
Expand Down Expand Up @@ -207,15 +207,15 @@ def get_type_map(self) -> list[str]:
"""Get the type map."""
return self.type_map

def get_rcut(self):
def get_rcut(self) -> float:
"""Get the cut-off radius."""
return self.rcut

def get_dim_fparam(self):
def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
return self.dim_fparam

def get_dim_aparam(self):
def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this atomic model."""
return self.dim_aparam

Expand Down
Loading
Loading