Skip to content

Commit 7c043f0

Browse files
njzjzpre-commit-ci[bot]Copilot
authored
style: enable B905 to prevent issues with zip (#5136)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Enforced strict length consistency when pairing related sequences across model runtime, data utilities, and inference paths to fail fast on mismatches. * **Tests** * Hardened tests to require sequence-length parity and added dtype checks, improving reliability and clearer failures. * **Chores** * Updated lint configuration to include an additional rule identifier for consistent code quality. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 88367a6 commit 7c043f0

38 files changed

Lines changed: 110 additions & 72 deletions

deepmd/calculator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __init__(
103103
self.type_dict = type_dict
104104
else:
105105
self.type_dict = dict(
106-
zip(self.dp.get_type_map(), range(self.dp.get_ntypes()))
106+
zip(self.dp.get_type_map(), range(self.dp.get_ntypes()), strict=True)
107107
)
108108

109109
def calculate(

deepmd/dpmodel/atomic_model/linear_atomic_model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def get_model_sels(self) -> list[int | list[int]]:
155155
def _sort_rcuts_sels(self) -> tuple[tuple[Array, Array], list[int]]:
156156
# sort the pair of rcut and sels in ascending order, first based on sel, then on rcut.
157157
zipped = sorted(
158-
zip(self.get_model_rcuts(), self.get_model_nsels()),
158+
zip(self.get_model_rcuts(), self.get_model_nsels(), strict=True),
159159
key=lambda x: (x[1], x[0]),
160160
)
161161
return [p[0] for p in zipped], [p[1] for p in zipped]
@@ -235,12 +235,14 @@ def forward_atomic(
235235
)
236236
raw_nlists = [
237237
nlists[get_multiple_nlist_key(rcut, sel)]
238-
for rcut, sel in zip(self.get_model_rcuts(), self.get_model_nsels())
238+
for rcut, sel in zip(
239+
self.get_model_rcuts(), self.get_model_nsels(), strict=True
240+
)
239241
]
240242
nlists_ = [
241243
nl if mt else nlist_distinguish_types(nl, extended_atype, sel)
242244
for mt, nl, sel in zip(
243-
self.mixed_types_list, raw_nlists, self.get_model_sels()
245+
self.mixed_types_list, raw_nlists, self.get_model_sels(), strict=True
244246
)
245247
]
246248
ener_list = []

deepmd/dpmodel/descriptor/hybrid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def __init__(
101101
start_idx = np.cumsum(np.pad(hybrid_sel, (1, 0), "constant"))[:-1]
102102
end_idx = start_idx + np.array(sub_sel)
103103
cut_idx = np.concatenate(
104-
[range(ss, ee) for ss, ee in zip(start_idx, end_idx)]
104+
[range(ss, ee) for ss, ee in zip(start_idx, end_idx, strict=True)]
105105
)
106106
nlist_cut_idx.append(cut_idx)
107107
self.nlist_cut_idx = nlist_cut_idx
@@ -310,7 +310,7 @@ def call(
310310
)
311311
else:
312312
nl_distinguish_types = None
313-
for descrpt, nci in zip(self.descrpt_list, self.nlist_cut_idx):
313+
for descrpt, nci in zip(self.descrpt_list, self.nlist_cut_idx, strict=True):
314314
# cut the nlist to the correct length
315315
if self.mixed_types() == descrpt.mixed_types():
316316
nl = xp.take(nlist, nci, axis=2)

deepmd/dpmodel/infer/deep_eval.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def eval(
229229
zip(
230230
[x.name for x in request_defs],
231231
out,
232+
strict=True,
232233
)
233234
)
234235

deepmd/dpmodel/utils/nlist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def build_multiple_neighbor_list(
255255
rr = xp.where(nlist_mask, xp.full_like(rr, float("inf")), rr)
256256
nlist0 = nlist
257257
ret = {}
258-
for rc, ns in zip(rcuts[::-1], nsels[::-1]):
258+
for rc, ns in zip(rcuts[::-1], nsels[::-1], strict=True):
259259
tnlist_1 = nlist0[:, :, :ns]
260260
tnlist_1 = xp.where(rr[:, :, :ns] > rc, xp.full_like(tnlist_1, -1), tnlist_1)
261261
ret[get_multiple_nlist_key(rc, ns)] = tnlist_1

deepmd/jax/infer/deep_eval.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ def eval(
249249
zip(
250250
[x.name for x in request_defs],
251251
out,
252+
strict=True,
252253
)
253254
)
254255

deepmd/tf/descriptor/loc_frame.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def compute_input_stats(
187187
sumn = []
188188
sumv2 = []
189189
for cc, bb, tt, nn, mm in zip(
190-
data_coord, data_box, data_atype, natoms_vec, mesh
190+
data_coord, data_box, data_atype, natoms_vec, mesh, strict=True
191191
):
192192
sysv, sysv2, sysn = self._compute_dstats_sys_nonsmth(cc, bb, tt, nn, mm)
193193
sumv.append(sysv)

deepmd/tf/descriptor/se_a.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def compute_input_stats(
374374
sumr2 = []
375375
suma2 = []
376376
for cc, bb, tt, nn, mm in zip(
377-
data_coord, data_box, data_atype, natoms_vec, mesh
377+
data_coord, data_box, data_atype, natoms_vec, mesh, strict=True
378378
):
379379
sysr, sysr2, sysa, sysa2, sysn = self._compute_dstats_sys_smth(
380380
cc, bb, tt, nn, mm
@@ -1331,7 +1331,7 @@ def init_variables(
13311331
start_index_old[0] = 0
13321332

13331333
for nn, oo, ii, jj in zip(
1334-
n_descpt, n_descpt_old, start_index, start_index_old
1334+
n_descpt, n_descpt_old, start_index, start_index_old, strict=True
13351335
):
13361336
if nn < oo:
13371337
# new size is smaller, copy part of std

deepmd/tf/descriptor/se_a_ef.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,13 @@ def compute_input_stats(
419419
sumr2 = []
420420
suma2 = []
421421
for cc, bb, tt, nn, mm, ee in zip(
422-
data_coord, data_box, data_atype, natoms_vec, mesh, data_efield
422+
data_coord,
423+
data_box,
424+
data_atype,
425+
natoms_vec,
426+
mesh,
427+
data_efield,
428+
strict=True,
423429
):
424430
sysr, sysr2, sysa, sysa2, sysn = self._compute_dstats_sys_smth(
425431
cc, bb, tt, nn, mm, ee

deepmd/tf/descriptor/se_atten.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,13 @@ def compute_input_stats(
379379
if mixed_type:
380380
sys_num = 0
381381
for cc, bb, tt, nn, mm, r_n in zip(
382-
data_coord, data_box, data_atype, natoms_vec, mesh, real_natoms_vec
382+
data_coord,
383+
data_box,
384+
data_atype,
385+
natoms_vec,
386+
mesh,
387+
real_natoms_vec,
388+
strict=True,
383389
):
384390
sysr, sysr2, sysa, sysa2, sysn = self._compute_dstats_sys_smth(
385391
cc, bb, tt, nn, mm, mixed_type, r_n
@@ -392,7 +398,7 @@ def compute_input_stats(
392398
suma2.append(sysa2)
393399
else:
394400
for cc, bb, tt, nn, mm in zip(
395-
data_coord, data_box, data_atype, natoms_vec, mesh
401+
data_coord, data_box, data_atype, natoms_vec, mesh, strict=True
396402
):
397403
sysr, sysr2, sysa, sysa2, sysn = self._compute_dstats_sys_smth(
398404
cc, bb, tt, nn, mm

0 commit comments

Comments
 (0)