Skip to content

Commit 260df88

Browse files
committed
fix type checking within reporter.py and dihedrals
1 parent da86beb commit 260df88

2 files changed

Lines changed: 95 additions & 87 deletions

File tree

CodeEntropy/levels/dihedrals.py

Lines changed: 82 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,16 @@
55
conformational entropy.
66
"""
77

8+
from __future__ import annotations
9+
810
import logging
11+
from typing import Any
912

1013
import numpy as np
1114
from MDAnalysis.analysis.dihedrals import Dihedral
15+
from rich.progress import TaskID
16+
17+
from CodeEntropy.results.reporter import _RichProgressSink
1218

1319
logger = logging.getLogger(__name__)
1420

@@ -18,7 +24,7 @@
1824
class ConformationStateBuilder:
1925
"""Build conformational state labels from dihedral angles."""
2026

21-
def __init__(self, universe_operations=None):
27+
def __init__(self, universe_operations: Any) -> None:
2228
"""Initializes the analysis helper.
2329
2430
Args:
@@ -30,15 +36,15 @@ def __init__(self, universe_operations=None):
3036

3137
def build_conformational_states(
3238
self,
33-
data_container,
34-
levels,
35-
groups,
39+
data_container: Any,
40+
levels: dict[Any, list[str]],
41+
groups: dict[int, list[Any]],
3642
start: int,
3743
end: int,
3844
step: int,
3945
bin_width: float,
40-
progress: object | None = None,
41-
):
46+
progress: _RichProgressSink | None = None,
47+
) -> tuple[dict[UAKey, list[str]], list[list[str]]]:
4248
"""Build conformational state labels from trajectory dihedrals.
4349
4450
This method constructs discrete conformational state descriptors used in
@@ -61,8 +67,9 @@ def build_conformational_states(
6167
step: Frame stride.
6268
bin_width: Histogram bin width in degrees used when identifying peak
6369
dihedral populations.
64-
progress: Optional progress sink (e.g., from ResultsReporter.progress()).
65-
Must expose add_task(), update(), and advance().
70+
progress: Optional progress sink (e.g., from
71+
ResultsReporter.progress()). Must expose add_task(), update(),
72+
and advance().
6673
6774
Returns:
6875
tuple: (states_ua, states_res)
@@ -79,9 +86,9 @@ def build_conformational_states(
7986
"""
8087
number_groups = len(groups)
8188
states_ua: dict[UAKey, list[str]] = {}
82-
states_res: list[list[str]] = [None] * number_groups
89+
states_res: list[list[str]] = [[] for _ in range(number_groups)]
8390

84-
task = None
91+
task: TaskID | None = None
8592
if progress is not None:
8693
total = max(1, len(groups))
8794
task = progress.add_task(
@@ -91,20 +98,20 @@ def build_conformational_states(
9198
)
9299

93100
if not groups:
94-
if task is not None:
101+
if progress is not None and task is not None:
95102
progress.update(task, title="No groups")
96103
progress.advance(task)
97104
return states_ua, states_res
98105

99106
for group_id in groups.keys():
100107
molecules = groups[group_id]
101108
if not molecules:
102-
if task is not None:
109+
if progress is not None and task is not None:
103110
progress.update(task, title=f"Group {group_id} (empty)")
104111
progress.advance(task)
105112
continue
106113

107-
if task is not None:
114+
if progress is not None and task is not None:
108115
progress.update(task, title=f"Group {group_id}")
109116

110117
mol = self._universe_operations.extract_fragment(
@@ -144,12 +151,14 @@ def build_conformational_states(
144151
states_res=states_res,
145152
)
146153

147-
if task is not None:
154+
if progress is not None and task is not None:
148155
progress.advance(task)
149156

150157
return states_ua, states_res
151158

152-
def _collect_dihedrals_for_group(self, mol, level_list):
159+
def _collect_dihedrals_for_group(
160+
self, mol: Any, level_list: list[str]
161+
) -> tuple[list[list[Any]], list[Any]]:
153162
"""Collect UA and residue dihedral AtomGroups for a group.
154163
155164
Args:
@@ -162,8 +171,8 @@ def _collect_dihedrals_for_group(self, mol, level_list):
162171
dihedrals_res: List of residue-level dihedral AtomGroups.
163172
"""
164173
num_residues = len(mol.residues)
165-
dihedrals_ua: list[list] = [[] for _ in range(num_residues)]
166-
dihedrals_res: list = []
174+
dihedrals_ua: list[list[Any]] = [[] for _ in range(num_residues)]
175+
dihedrals_res: list[Any] = []
167176

168177
for level in level_list:
169178
if level == "united_atom":
@@ -176,7 +185,7 @@ def _collect_dihedrals_for_group(self, mol, level_list):
176185

177186
return dihedrals_ua, dihedrals_res
178187

179-
def _select_heavy_residue(self, mol, res_id: int):
188+
def _select_heavy_residue(self, mol: Any, res_id: int) -> Any:
180189
"""Select heavy atoms in a residue by residue index.
181190
182191
Args:
@@ -194,7 +203,7 @@ def _select_heavy_residue(self, mol, res_id: int):
194203
)
195204
return self._universe_operations.select_atoms(res_container, "prop mass > 1.1")
196205

197-
def _get_dihedrals(self, data_container, level: str):
206+
def _get_dihedrals(self, data_container: Any, level: str) -> list[Any]:
198207
"""Return dihedral AtomGroups for a container at a given level.
199208
200209
Args:
@@ -204,7 +213,7 @@ def _get_dihedrals(self, data_container, level: str):
204213
Returns:
205214
List of AtomGroups (each representing a dihedral definition).
206215
"""
207-
atom_groups = []
216+
atom_groups: list[Any] = []
208217

209218
if level == "united_atom":
210219
dihedrals = data_container.dihedrals
@@ -234,16 +243,16 @@ def _get_dihedrals(self, data_container, level: str):
234243

235244
def _collect_peaks_for_group(
236245
self,
237-
data_container,
238-
molecules,
239-
dihedrals_ua,
240-
dihedrals_res,
241-
bin_width,
242-
start,
243-
end,
244-
step,
245-
level_list,
246-
):
246+
data_container: Any,
247+
molecules: list[Any],
248+
dihedrals_ua: list[list[Any]],
249+
dihedrals_res: list[Any],
250+
bin_width: float,
251+
start: int,
252+
end: int,
253+
step: int,
254+
level_list: list[str],
255+
) -> tuple[list[list[Any]], list[Any]]:
247256
"""Compute histogram peaks for UA and residue dihedral sets.
248257
249258
Returns:
@@ -252,8 +261,8 @@ def _collect_peaks_for_group(
252261
(each item is list-of-peaks per dihedral)
253262
peaks_res: list-of-peaks per dihedral for residue level (or [])
254263
"""
255-
peaks_ua = [{} for _ in range(len(dihedrals_ua))]
256-
peaks_res = {}
264+
peaks_ua: list[list[Any]] = [[] for _ in range(len(dihedrals_ua))]
265+
peaks_res: list[Any] = []
257266

258267
for level in level_list:
259268
if level == "united_atom":
@@ -289,14 +298,14 @@ def _collect_peaks_for_group(
289298

290299
def _identify_peaks(
291300
self,
292-
data_container,
293-
molecules,
294-
dihedrals,
295-
bin_width,
296-
start,
297-
end,
298-
step,
299-
):
301+
data_container: Any,
302+
molecules: list[Any],
303+
dihedrals: list[Any],
304+
bin_width: float,
305+
start: int,
306+
end: int,
307+
step: int,
308+
) -> list[list[float]]:
300309
"""Identify histogram peaks ("convex turning points") for each dihedral.
301310
302311
Important:
@@ -316,10 +325,10 @@ def _identify_peaks(
316325
Returns:
317326
List of peaks per dihedral (peak_values[dihedral_index] -> list of peaks).
318327
"""
319-
peak_values = [] * len(dihedrals)
328+
peak_values: list[list[float]] = []
320329

321330
for dihedral_index in range(len(dihedrals)):
322-
phi = []
331+
phi: list[float] = []
323332

324333
for molecule in molecules:
325334
mol = self._universe_operations.extract_fragment(
@@ -333,7 +342,7 @@ def _identify_peaks(
333342
value = dihedral_results.results.angles[timestep][dihedral_index]
334343
if value < 0:
335344
value += 360
336-
phi.append(value)
345+
phi.append(float(value))
337346

338347
number_bins = int(360 / bin_width)
339348
popul, bin_edges = np.histogram(a=phi, bins=number_bins, range=(0, 360))
@@ -349,7 +358,9 @@ def _identify_peaks(
349358
return peak_values
350359

351360
@staticmethod
352-
def _find_histogram_peaks(popul, bin_value):
361+
def _find_histogram_peaks(
362+
popul: np.ndarray[Any, Any], bin_value: list[float]
363+
) -> list[float]:
353364
"""Return convex turning-point peaks from a histogram.
354365
355366
The selection of the population of the right adjacent bin takes into
@@ -364,7 +375,7 @@ def _find_histogram_peaks(popul, bin_value):
364375
peaks: list of values associated with peaks.
365376
"""
366377
number_bins = len(popul)
367-
peaks = []
378+
peaks: list[float] = []
368379

369380
for bin_index in range(number_bins):
370381
if popul[bin_index] == 0:
@@ -380,20 +391,20 @@ def _find_histogram_peaks(popul, bin_value):
380391

381392
def _assign_states_for_group(
382393
self,
383-
data_container,
384-
group_id,
385-
molecules,
386-
dihedrals_ua,
387-
peaks_ua,
388-
dihedrals_res,
389-
peaks_res,
390-
start,
391-
end,
392-
step,
393-
level_list,
394-
states_ua,
395-
states_res,
396-
):
394+
data_container: Any,
395+
group_id: int,
396+
molecules: list[Any],
397+
dihedrals_ua: list[list[Any]],
398+
peaks_ua: list[list[Any]],
399+
dihedrals_res: list[Any],
400+
peaks_res: list[Any],
401+
start: int,
402+
end: int,
403+
step: int,
404+
level_list: list[str],
405+
states_ua: dict[UAKey, list[str]],
406+
states_res: list[list[str]],
407+
) -> None:
397408
"""Assign UA and residue states for a group into output containers."""
398409
for level in level_list:
399410
if level == "united_atom":
@@ -428,14 +439,14 @@ def _assign_states_for_group(
428439

429440
def _assign_states(
430441
self,
431-
data_container,
432-
molecules,
433-
dihedrals,
434-
peaks,
435-
start,
436-
end,
437-
step,
438-
):
442+
data_container: Any,
443+
molecules: list[Any],
444+
dihedrals: list[Any],
445+
peaks: list[list[Any]],
446+
start: int,
447+
end: int,
448+
step: int,
449+
) -> list[str]:
439450
"""Assign discrete state labels for the provided dihedrals.
440451
441452
Important:
@@ -455,17 +466,17 @@ def _assign_states(
455466
Returns:
456467
List of state labels (strings).
457468
"""
458-
states = None
469+
states: list[str] = []
459470

460471
for molecule in molecules:
461-
conformations = []
472+
conformations: list[list[Any]] = []
462473
mol = self._universe_operations.extract_fragment(data_container, molecule)
463474
number_frames = len(mol.trajectory)
464475

465476
dihedral_results = Dihedral(dihedrals).run()
466477

467478
for dihedral_index in range(len(dihedrals)):
468-
conformation = []
479+
conformation: list[Any] = []
469480
for timestep in range(number_frames):
470481
value = dihedral_results.results.angles[timestep][dihedral_index]
471482
if value < 0:
@@ -487,10 +498,7 @@ def _assign_states(
487498
if state
488499
]
489500

490-
if states is None:
491-
states = mol_states
492-
else:
493-
states.extend(mol_states)
501+
states.extend(mol_states)
494502

495503
logger.debug(f"States: {states}")
496504
return states

0 commit comments

Comments
 (0)