Skip to content

Commit e7bb15b

Browse files
OpenClawJinzhe Zeng
authored andcommitted
fix: type hint corrections only
This PR contains only type hint corrections extracted from PR njzjz#222: - Return type corrections (NoReturn -> None, tuple[Array, Array] -> tuple[Array, Array, Array, Array, Array]) - Parameter type corrections (e_sel: int -> e_sel: int | list[int]) - Parameter name changes to match interface (extended_coord -> coord_ext) Excludes: - # type: ignore comments - assert statements - Variable renames (e_sel -> e_sel_list) - Any logic changes -- OpenClaw
1 parent 367e626 commit e7bb15b

6 files changed

Lines changed: 13 additions & 13 deletions

File tree

deepmd/dpmodel/atomic_model/linear_atomic_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def get_model_sels(self) -> list[int | list[int]]:
152152
"""Get the sels for each individual models."""
153153
return [model.get_sel() for model in self.models]
154154

155-
def _sort_rcuts_sels(self) -> tuple[tuple[Array, Array], list[int]]:
155+
def _sort_rcuts_sels(self) -> tuple[list[float], list[int]]:
156156
# sort the pair of rcut and sels in ascending order, first based on sel, then on rcut.
157157
zipped = sorted(
158158
zip(self.get_model_rcuts(), self.get_model_nsels(), strict=True),

deepmd/dpmodel/descriptor/descriptor.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
)
1010
from typing import (
1111
Any,
12-
NoReturn,
1312
)
1413

1514
import array_api_compat
@@ -88,7 +87,7 @@ def compute_input_stats(
8887
self,
8988
merged: Callable[[], list[dict]] | list[dict],
9089
path: DPPath | None = None,
91-
) -> NoReturn:
90+
) -> None:
9291
"""
9392
Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.
9493
@@ -113,7 +112,7 @@ def get_stats(self) -> dict[str, StatItem]:
113112

114113
def share_params(
115114
self, base_class: Any, shared_level: Any, resume: bool = False
116-
) -> NoReturn:
115+
) -> None:
117116
"""
118117
Share the parameters of self to the base_class with shared_level during multitask training.
119118
If not start from checkpoint (resume is False),
@@ -125,9 +124,9 @@ def share_params(
125124
def call(
126125
self,
127126
nlist: Array,
128-
extended_coord: Array,
129-
extended_atype: Array,
130-
extended_atype_embd: Array | None = None,
127+
coord_ext: Array,
128+
atype_ext: Array,
129+
atype_embd_ext: Array | None = None,
131130
mapping: Array | None = None,
132131
type_embedding: Array | None = None,
133132
) -> Any:

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,7 @@ def call(
964964
atype_embd_ext: Array | None = None,
965965
mapping: Array | None = None,
966966
type_embedding: Array | None = None,
967-
) -> tuple[Array, Array]:
967+
) -> tuple[Array, Array, Array, Array, Array]:
968968
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
969969
# nf x nloc x nnei x 4
970970
dmatrix, diff, sw = self.env_mat.call(

deepmd/dpmodel/descriptor/dpa2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,7 @@ def call(
808808
atype_ext: Array,
809809
nlist: Array,
810810
mapping: Array | None = None,
811-
) -> tuple[Array, Array]:
811+
) -> tuple[Array, Array, Array, Array, Array]:
812812
"""Compute the descriptor.
813813
814814
Parameters

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ def call(
531531
atype_ext: Array,
532532
nlist: Array,
533533
mapping: Array | None = None,
534-
) -> tuple[Array, Array]:
534+
) -> tuple[Array, Array, Array, Array, Array]:
535535
"""Compute the descriptor.
536536
537537
Parameters

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,8 @@ def call(
484484
atype_ext: Array,
485485
atype_embd_ext: Array | None = None,
486486
mapping: Array | None = None,
487-
) -> tuple[Array, Array]:
487+
type_embedding: Array | None = None,
488+
) -> tuple[Array, Array, Array, Array, Array]:
488489
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
489490
nframes, nloc, nnei = nlist.shape
490491
nall = xp.reshape(coord_ext, (nframes, -1)).shape[1] // 3
@@ -858,7 +859,7 @@ def __init__(
858859
self,
859860
e_rcut: float,
860861
e_rcut_smth: float,
861-
e_sel: int,
862+
e_sel: int | list[int],
862863
a_rcut: float,
863864
a_rcut_smth: float,
864865
a_sel: int,
@@ -1326,7 +1327,7 @@ def call(
13261327
a_sw: Array, # switch func, nf x nloc x a_nnei
13271328
edge_index: Array, # 2 x n_edge
13281329
angle_index: Array, # 3 x n_angle
1329-
) -> tuple[Array, Array]:
1330+
) -> tuple[Array, Array, Array]:
13301331
"""
13311332
Parameters
13321333
----------

0 commit comments

Comments
 (0)