Skip to content

Commit 8bbdd1f

Browse files
committed
feat(paddle): Enable ANN rule for additional utility files and complete type annotations
- Enable ANN rule for 5 more utility files: learning_rate.py, dp_random.py, update_sel.py, preprocess.py, spin.py - Add missing type annotations to preprocess.py and spin.py functions - Update gradual enablement tracking: now 7 files fully completed with ANN rule enabled
1 parent 0079e2a commit 8bbdd1f

3 files changed

Lines changed: 10 additions & 5 deletions

File tree

deepmd/pd/utils/preprocess.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
log = logging.getLogger(__name__)
77

88

9-
def compute_smooth_weight(distance, rmin: float, rmax: float):
9+
def compute_smooth_weight(distance: paddle.Tensor, rmin: float, rmax: float) -> paddle.Tensor:
1010
"""Compute smooth weight for descriptor elements."""
1111
if rmin >= rmax:
1212
raise ValueError("rmin should be less than rmax.")
@@ -17,7 +17,7 @@ def compute_smooth_weight(distance, rmin: float, rmax: float):
1717
return vv
1818

1919

20-
def compute_exp_sw(distance, rmin: float, rmax: float):
20+
def compute_exp_sw(distance: paddle.Tensor, rmin: float, rmax: float) -> paddle.Tensor:
2121
"""Compute the exponential switch function for neighbor update."""
2222
if rmin >= rmax:
2323
raise ValueError("rmin should be less than rmax.")

deepmd/pd/utils/spin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55

66
def concat_switch_virtual(
7-
extended_tensor,
8-
extended_tensor_virtual,
7+
extended_tensor: paddle.Tensor,
8+
extended_tensor_virtual: paddle.Tensor,
99
nloc: int,
10-
):
10+
) -> paddle.Tensor:
1111
"""
1212
Concat real and virtual extended tensors, and switch all the local ones to the first nloc * 2 atoms.
1313
- [:, :nloc]: original nloc real atoms.

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,11 @@ runtime-evaluated-base-classes = ["torch.nn.Module"]
431431
# Completed files with full type annotations:
432432
"deepmd/pd/entrypoints/main.py" = ["TID253"] # ✅ Fully typed
433433
"deepmd/pd/train/wrapper.py" = ["TID253"] # ✅ Fully typed
434+
"deepmd/pd/utils/learning_rate.py" = ["TID253"] # ✅ Fully typed
435+
"deepmd/pd/utils/dp_random.py" = ["TID253"] # ✅ Fully typed
436+
"deepmd/pd/utils/update_sel.py" = ["TID253"] # ✅ Fully typed
437+
"deepmd/pd/utils/preprocess.py" = ["TID253"] # ✅ Fully typed
438+
"deepmd/pd/utils/spin.py" = ["TID253"] # ✅ Fully typed
434439
# TODO: Complete type hints and remove ANN exclusion for remaining files:
435440
"deepmd/pd/train/**" = ["TID253", "ANN"] # 🚧 Partial progress - training.py still needs work
436441
"deepmd/pd/utils/**" = ["TID253", "ANN"] # 🚧 Partial progress - utils.py partially done

0 commit comments

Comments
 (0)