|
2 | 2 | from collections.abc import ( |
3 | 3 | Callable, |
4 | 4 | ) |
| 5 | +from typing import ( |
| 6 | + Any, |
| 7 | +) |
5 | 8 |
|
6 | 9 | import paddle |
7 | 10 |
|
@@ -149,7 +152,7 @@ def __init__( |
149 | 152 | """ |
150 | 153 | super().__init__() |
151 | 154 |
|
152 | | - def init_subclass_params(sub_data: dict | object, sub_class: type) -> object: |
| 155 | + def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any: |
153 | 156 | if isinstance(sub_data, dict): |
154 | 157 | return sub_class(**sub_data) |
155 | 158 | elif isinstance(sub_data, sub_class): |
@@ -401,7 +404,7 @@ def get_env_protection(self) -> float: |
401 | 404 | return self.repinit.get_env_protection() |
402 | 405 |
|
403 | 406 | def share_params( |
404 | | - self, base_class: object, shared_level: int, resume: bool = False |
| 407 | + self, base_class: Any, shared_level: int, resume: bool = False |
405 | 408 | ) -> None: |
406 | 409 | """ |
407 | 410 | Share the parameters of self to the base_class with shared_level during multitask training. |
@@ -438,7 +441,7 @@ def share_params( |
438 | 441 | raise NotImplementedError |
439 | 442 |
|
440 | 443 | def change_type_map( |
441 | | - self, type_map: list[str], model_with_new_type_stat: object = None |
| 444 | + self, type_map: list[str], model_with_new_type_stat: Any = None |
442 | 445 | ) -> None: |
443 | 446 | """Change the type related params to new ones, according to `type_map` and the original one in the model. |
444 | 447 | If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types. |
@@ -674,7 +677,7 @@ def deserialize(cls, data: dict) -> "DescrptDPA2": |
674 | 677 | if obj.repinit.dim_out != obj.repformers.dim_in: |
675 | 678 | obj.g1_shape_tranform = MLPLayer.deserialize(g1_shape_tranform) |
676 | 679 |
|
677 | | - def t_cvt(xx: object) -> paddle.Tensor: |
| 680 | + def t_cvt(xx: Any) -> paddle.Tensor: |
678 | 681 | return paddle.to_tensor(xx, dtype=obj.repinit.prec, place=env.DEVICE) |
679 | 682 |
|
680 | 683 | # deserialize repinit |
|
0 commit comments