Skip to content

Commit 68f9f3c

Browse files
Copilotnjzjz
andcommitted
fix(paddle): Address code review feedback - Part 2
- Replaced all object types with Any in DPA1, DPA2, DPA3 descriptors - Replaced object with Any in se_t_tebd.py, ener_model.py, make_model.py - Replaced object with Any in task modules (ener.py, fitting.py, invar_fitting.py) - Added Any import to all files that needed it - Changed zip() strict=False to strict=True in se_a.py, transform_output.py, dataloader.py All type annotations now use Any instead of object as requested in code review. Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 4f4eed4 commit 68f9f3c

12 files changed

Lines changed: 52 additions & 24 deletions

File tree

deepmd/pd/model/descriptor/dpa1.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from collections.abc import (
33
Callable,
44
)
5+
from typing import (
6+
Any,
7+
)
58

69
import paddle
710

@@ -242,7 +245,7 @@ def __init__(
242245
use_tebd_bias: bool = False,
243246
type_map: list[str] | None = None,
244247
# not implemented
245-
spin: object = None,
248+
spin: Any = None,
246249
type: str | None = None,
247250
) -> None:
248251
super().__init__()
@@ -398,7 +401,7 @@ def get_env_protection(self) -> float:
398401
return self.se_atten.get_env_protection()
399402

400403
def share_params(
401-
self, base_class: object, shared_level: int, resume: bool = False
404+
self, base_class: Any, shared_level: int, resume: bool = False
402405
) -> None:
403406
"""
404407
Share the parameters of self to the base_class with shared_level during multitask training.
@@ -471,7 +474,7 @@ def get_stat_mean_and_stddev(self) -> tuple[paddle.Tensor, paddle.Tensor]:
471474
return self.se_atten.mean, self.se_atten.stddev
472475

473476
def change_type_map(
474-
self, type_map: list[str], model_with_new_type_stat: object = None
477+
self, type_map: list[str], model_with_new_type_stat: Any = None
475478
) -> None:
476479
"""Change the type related params to new ones, according to `type_map` and the original one in the model.
477480
If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
@@ -571,7 +574,7 @@ def deserialize(cls, data: dict) -> "DescrptDPA1":
571574
data["use_tebd_bias"] = True
572575
obj = cls(**data)
573576

574-
def t_cvt(xx: object) -> paddle.Tensor:
577+
def t_cvt(xx: Any) -> paddle.Tensor:
575578
return paddle.to_tensor(xx, dtype=obj.se_atten.prec).to(device=env.DEVICE)
576579

577580
obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize(

deepmd/pd/model/descriptor/dpa2.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from collections.abc import (
33
Callable,
44
)
5+
from typing import (
6+
Any,
7+
)
58

69
import paddle
710

@@ -149,7 +152,7 @@ def __init__(
149152
"""
150153
super().__init__()
151154

152-
def init_subclass_params(sub_data: dict | object, sub_class: type) -> object:
155+
def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
153156
if isinstance(sub_data, dict):
154157
return sub_class(**sub_data)
155158
elif isinstance(sub_data, sub_class):
@@ -401,7 +404,7 @@ def get_env_protection(self) -> float:
401404
return self.repinit.get_env_protection()
402405

403406
def share_params(
404-
self, base_class: object, shared_level: int, resume: bool = False
407+
self, base_class: Any, shared_level: int, resume: bool = False
405408
) -> None:
406409
"""
407410
Share the parameters of self to the base_class with shared_level during multitask training.
@@ -438,7 +441,7 @@ def share_params(
438441
raise NotImplementedError
439442

440443
def change_type_map(
441-
self, type_map: list[str], model_with_new_type_stat: object = None
444+
self, type_map: list[str], model_with_new_type_stat: Any = None
442445
) -> None:
443446
"""Change the type related params to new ones, according to `type_map` and the original one in the model.
444447
If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
@@ -674,7 +677,7 @@ def deserialize(cls, data: dict) -> "DescrptDPA2":
674677
if obj.repinit.dim_out != obj.repformers.dim_in:
675678
obj.g1_shape_tranform = MLPLayer.deserialize(g1_shape_tranform)
676679

677-
def t_cvt(xx: object) -> paddle.Tensor:
680+
def t_cvt(xx: Any) -> paddle.Tensor:
678681
return paddle.to_tensor(xx, dtype=obj.repinit.prec, place=env.DEVICE)
679682

680683
# deserialize repinit

deepmd/pd/model/descriptor/dpa3.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from collections.abc import (
33
Callable,
44
)
5+
from typing import (
6+
Any,
7+
)
58

69
import paddle
710

@@ -120,7 +123,7 @@ def __init__(
120123
) -> None:
121124
super().__init__()
122125

123-
def init_subclass_params(sub_data: dict | object, sub_class: type) -> object:
126+
def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
124127
if isinstance(sub_data, dict):
125128
return sub_class(**sub_data)
126129
elif isinstance(sub_data, sub_class):
@@ -303,7 +306,7 @@ def get_env_protection(self) -> float:
303306
return self.repflows.get_env_protection()
304307

305308
def share_params(
306-
self, base_class: object, shared_level: int, resume: bool = False
309+
self, base_class: Any, shared_level: int, resume: bool = False
307310
) -> None:
308311
"""
309312
Share the parameters of self to the base_class with shared_level during multitask training.
@@ -332,7 +335,7 @@ def share_params(
332335
raise NotImplementedError
333336

334337
def change_type_map(
335-
self, type_map: list[str], model_with_new_type_stat: object = None
338+
self, type_map: list[str], model_with_new_type_stat: Any = None
336339
) -> None:
337340
"""Change the type related params to new ones, according to `type_map` and the original one in the model.
338341
If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
@@ -465,7 +468,7 @@ def deserialize(cls, data: dict) -> "DescrptDPA3":
465468
type_embedding
466469
)
467470

468-
def t_cvt(xx: object) -> paddle.Tensor:
471+
def t_cvt(xx: Any) -> paddle.Tensor:
469472
return paddle.to_tensor(xx, dtype=obj.repflows.prec, place=env.DEVICE)
470473

471474
# deserialize repflow

deepmd/pd/model/descriptor/se_a.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ def forward(
769769
self.filter_layers.networks,
770770
self.compress_data,
771771
self.compress_info,
772-
strict=False,
772+
strict=True,
773773
)
774774
):
775775
if self.type_one_side:

deepmd/pd/model/descriptor/se_t_tebd.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from collections.abc import (
33
Callable,
44
)
5+
from typing import (
6+
Any,
7+
)
58

69
import paddle
710

@@ -331,7 +334,7 @@ def get_stat_mean_and_stddev(self) -> tuple[paddle.Tensor, paddle.Tensor]:
331334
return self.se_ttebd.mean, self.se_ttebd.stddev
332335

333336
def change_type_map(
334-
self, type_map: list[str], model_with_new_type_stat: object | None = None
337+
self, type_map: list[str], model_with_new_type_stat: Any | None = None
335338
) -> None:
336339
"""Change the type related params to new ones, according to `type_map` and the original one in the model.
337340
If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.

deepmd/pd/model/model/ener_model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22

3+
from typing import (
4+
Any,
5+
)
6+
37
import paddle
48

59
from deepmd.pd.model.atomic_model import (
@@ -25,8 +29,8 @@ class EnergyModel(DPModelCommon, DPEnergyModel_):
2529

2630
def __init__(
2731
self,
28-
*args: object,
29-
**kwargs: object,
32+
*args: Any,
33+
**kwargs: Any,
3034
) -> None:
3135
DPModelCommon.__init__(self)
3236
DPEnergyModel_.__init__(self, *args, **kwargs)

deepmd/pd/model/model/make_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from collections.abc import (
33
Callable,
44
)
5+
from typing import (
6+
Any,
7+
)
58

69
import paddle
710

@@ -65,10 +68,10 @@ def make_model(T_AtomicModel: type[BaseAtomicModel]) -> type[BaseModel]:
6568
class CM(BaseModel):
6669
def __init__(
6770
self,
68-
*args: object,
71+
*args: Any,
6972
# underscore to prevent conflict with normal inputs
7073
atomic_model_: T_AtomicModel | None = None,
71-
**kwargs: object,
74+
**kwargs: Any,
7275
) -> None:
7376
super().__init__(*args, **kwargs)
7477
if atomic_model_ is not None:

deepmd/pd/model/model/transform_output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def take_deriv(
122122
split_vv1 = paddle.split(vv1, [1] * size, axis=-1)
123123
split_svv1 = paddle.split(svv1, [1] * size, axis=-1)
124124
split_ff, split_avir = [], []
125-
for vvi, svvi in zip(split_vv1, split_svv1, strict=False):
125+
for vvi, svvi in zip(split_vv1, split_svv1, strict=True):
126126
# nf x nloc x 3, nf x nloc x 9
127127
ffi, aviri = task_deriv_one(
128128
vvi,

deepmd/pd/model/task/ener.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import copy
33
import logging
4+
from typing import (
5+
Any,
6+
)
47

58
import paddle
69

@@ -44,7 +47,7 @@ def __init__(
4447
mixed_types: bool = True,
4548
seed: int | list[int] | None = None,
4649
type_map: list[str] | None = None,
47-
**kwargs: object,
50+
**kwargs: Any,
4851
) -> None:
4952
super().__init__(
5053
"energy",

deepmd/pd/model/task/fitting.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from collections.abc import (
77
Callable,
88
)
9+
from typing import (
10+
Any,
11+
)
912

1013
import numpy as np
1114
import paddle
@@ -51,7 +54,7 @@
5154
class Fitting(paddle.nn.Layer, BaseFitting):
5255
# plugin moved to BaseFitting
5356

54-
def __new__(cls, *args: object, **kwargs: object) -> "Fitting":
57+
def __new__(cls, *args: Any, **kwargs: Any) -> "Fitting":
5558
if cls is Fitting:
5659
return BaseFitting.__new__(BaseFitting, *args, **kwargs)
5760
return super().__new__(cls)
@@ -244,7 +247,7 @@ def __init__(
244247
type_map: list[str] | None = None,
245248
use_aparam_as_mask: bool = False,
246249
default_fparam: list[float] | None = None,
247-
**kwargs: object,
250+
**kwargs: Any,
248251
) -> None:
249252
super().__init__()
250253
self.var_name = var_name

0 commit comments

Comments
 (0)