55conformational entropy.
66"""
77
8+ from __future__ import annotations
9+
810import logging
11+ from typing import Any
912
1013import numpy as np
1114from MDAnalysis .analysis .dihedrals import Dihedral
15+ from rich .progress import TaskID
16+
17+ from CodeEntropy .results .reporter import _RichProgressSink
1218
1319logger = logging .getLogger (__name__ )
1420
1824class ConformationStateBuilder :
1925 """Build conformational state labels from dihedral angles."""
2026
21- def __init__ (self , universe_operations = None ) :
27+ def __init__ (self , universe_operations : Any ) -> None :
2228 """Initializes the analysis helper.
2329
2430 Args:
@@ -30,15 +36,15 @@ def __init__(self, universe_operations=None):
3036
3137 def build_conformational_states (
3238 self ,
33- data_container ,
34- levels ,
35- groups ,
39+ data_container : Any ,
40+ levels : dict [ Any , list [ str ]] ,
41+ groups : dict [ int , list [ Any ]] ,
3642 start : int ,
3743 end : int ,
3844 step : int ,
3945 bin_width : float ,
40- progress : object | None = None ,
41- ):
46+ progress : _RichProgressSink | None = None ,
47+ ) -> tuple [ dict [ UAKey , list [ str ]], list [ list [ str ]]] :
4248 """Build conformational state labels from trajectory dihedrals.
4349
4450 This method constructs discrete conformational state descriptors used in
@@ -61,8 +67,9 @@ def build_conformational_states(
6167 step: Frame stride.
6268 bin_width: Histogram bin width in degrees used when identifying peak
6369 dihedral populations.
64- progress: Optional progress sink (e.g., from ResultsReporter.progress()).
65- Must expose add_task(), update(), and advance().
70+ progress: Optional progress sink (e.g., from
71+ ResultsReporter.progress()). Must expose add_task(), update(),
72+ and advance().
6673
6774 Returns:
6875 tuple: (states_ua, states_res)
@@ -79,9 +86,9 @@ def build_conformational_states(
7986 """
8087 number_groups = len (groups )
8188 states_ua : dict [UAKey , list [str ]] = {}
82- states_res : list [list [str ]] = [None ] * number_groups
89+ states_res : list [list [str ]] = [[] for _ in range ( number_groups )]
8390
84- task = None
91+ task : TaskID | None = None
8592 if progress is not None :
8693 total = max (1 , len (groups ))
8794 task = progress .add_task (
@@ -91,20 +98,20 @@ def build_conformational_states(
9198 )
9299
93100 if not groups :
94- if task is not None :
101+ if progress is not None and task is not None :
95102 progress .update (task , title = "No groups" )
96103 progress .advance (task )
97104 return states_ua , states_res
98105
99106 for group_id in groups .keys ():
100107 molecules = groups [group_id ]
101108 if not molecules :
102- if task is not None :
109+ if progress is not None and task is not None :
103110 progress .update (task , title = f"Group { group_id } (empty)" )
104111 progress .advance (task )
105112 continue
106113
107- if task is not None :
114+ if progress is not None and task is not None :
108115 progress .update (task , title = f"Group { group_id } " )
109116
110117 mol = self ._universe_operations .extract_fragment (
@@ -144,12 +151,14 @@ def build_conformational_states(
144151 states_res = states_res ,
145152 )
146153
147- if task is not None :
154+ if progress is not None and task is not None :
148155 progress .advance (task )
149156
150157 return states_ua , states_res
151158
152- def _collect_dihedrals_for_group (self , mol , level_list ):
159+ def _collect_dihedrals_for_group (
160+ self , mol : Any , level_list : list [str ]
161+ ) -> tuple [list [list [Any ]], list [Any ]]:
153162 """Collect UA and residue dihedral AtomGroups for a group.
154163
155164 Args:
@@ -162,8 +171,8 @@ def _collect_dihedrals_for_group(self, mol, level_list):
162171 dihedrals_res: List of residue-level dihedral AtomGroups.
163172 """
164173 num_residues = len (mol .residues )
165- dihedrals_ua : list [list ] = [[] for _ in range (num_residues )]
166- dihedrals_res : list = []
174+ dihedrals_ua : list [list [ Any ] ] = [[] for _ in range (num_residues )]
175+ dihedrals_res : list [ Any ] = []
167176
168177 for level in level_list :
169178 if level == "united_atom" :
@@ -176,7 +185,7 @@ def _collect_dihedrals_for_group(self, mol, level_list):
176185
177186 return dihedrals_ua , dihedrals_res
178187
179- def _select_heavy_residue (self , mol , res_id : int ):
188+ def _select_heavy_residue (self , mol : Any , res_id : int ) -> Any :
180189 """Select heavy atoms in a residue by residue index.
181190
182191 Args:
@@ -194,7 +203,7 @@ def _select_heavy_residue(self, mol, res_id: int):
194203 )
195204 return self ._universe_operations .select_atoms (res_container , "prop mass > 1.1" )
196205
197- def _get_dihedrals (self , data_container , level : str ):
206+ def _get_dihedrals (self , data_container : Any , level : str ) -> list [ Any ] :
198207 """Return dihedral AtomGroups for a container at a given level.
199208
200209 Args:
@@ -204,7 +213,7 @@ def _get_dihedrals(self, data_container, level: str):
204213 Returns:
205214 List of AtomGroups (each representing a dihedral definition).
206215 """
207- atom_groups = []
216+ atom_groups : list [ Any ] = []
208217
209218 if level == "united_atom" :
210219 dihedrals = data_container .dihedrals
@@ -234,16 +243,16 @@ def _get_dihedrals(self, data_container, level: str):
234243
235244 def _collect_peaks_for_group (
236245 self ,
237- data_container ,
238- molecules ,
239- dihedrals_ua ,
240- dihedrals_res ,
241- bin_width ,
242- start ,
243- end ,
244- step ,
245- level_list ,
246- ):
246+ data_container : Any ,
247+ molecules : list [ Any ] ,
248+ dihedrals_ua : list [ list [ Any ]] ,
249+ dihedrals_res : list [ Any ] ,
250+ bin_width : float ,
251+ start : int ,
252+ end : int ,
253+ step : int ,
254+ level_list : list [ str ] ,
255+ ) -> tuple [ list [ list [ Any ]], list [ Any ]] :
247256 """Compute histogram peaks for UA and residue dihedral sets.
248257
249258 Returns:
@@ -252,8 +261,8 @@ def _collect_peaks_for_group(
252261 (each item is list-of-peaks per dihedral)
253262 peaks_res: list-of-peaks per dihedral for residue level (or [])
254263 """
255- peaks_ua = [{} for _ in range (len (dihedrals_ua ))]
256- peaks_res = {}
264+ peaks_ua : list [ list [ Any ]] = [[] for _ in range (len (dihedrals_ua ))]
265+ peaks_res : list [ Any ] = []
257266
258267 for level in level_list :
259268 if level == "united_atom" :
@@ -289,14 +298,14 @@ def _collect_peaks_for_group(
289298
290299 def _identify_peaks (
291300 self ,
292- data_container ,
293- molecules ,
294- dihedrals ,
295- bin_width ,
296- start ,
297- end ,
298- step ,
299- ):
301+ data_container : Any ,
302+ molecules : list [ Any ] ,
303+ dihedrals : list [ Any ] ,
304+ bin_width : float ,
305+ start : int ,
306+ end : int ,
307+ step : int ,
308+ ) -> list [ list [ float ]] :
300309 """Identify histogram peaks ("convex turning points") for each dihedral.
301310
302311 Important:
@@ -316,10 +325,10 @@ def _identify_peaks(
316325 Returns:
317326 List of peaks per dihedral (peak_values[dihedral_index] -> list of peaks).
318327 """
319- peak_values = [] * len ( dihedrals )
328+ peak_values : list [ list [ float ]] = []
320329
321330 for dihedral_index in range (len (dihedrals )):
322- phi = []
331+ phi : list [ float ] = []
323332
324333 for molecule in molecules :
325334 mol = self ._universe_operations .extract_fragment (
@@ -333,7 +342,7 @@ def _identify_peaks(
333342 value = dihedral_results .results .angles [timestep ][dihedral_index ]
334343 if value < 0 :
335344 value += 360
336- phi .append (value )
345+ phi .append (float ( value ) )
337346
338347 number_bins = int (360 / bin_width )
339348 popul , bin_edges = np .histogram (a = phi , bins = number_bins , range = (0 , 360 ))
@@ -349,7 +358,9 @@ def _identify_peaks(
349358 return peak_values
350359
351360 @staticmethod
352- def _find_histogram_peaks (popul , bin_value ):
361+ def _find_histogram_peaks (
362+ popul : np .ndarray [Any , Any ], bin_value : list [float ]
363+ ) -> list [float ]:
353364 """Return convex turning-point peaks from a histogram.
354365
355366 The selection of the population of the right adjacent bin takes into
@@ -364,7 +375,7 @@ def _find_histogram_peaks(popul, bin_value):
364375 peaks: list of values associated with peaks.
365376 """
366377 number_bins = len (popul )
367- peaks = []
378+ peaks : list [ float ] = []
368379
369380 for bin_index in range (number_bins ):
370381 if popul [bin_index ] == 0 :
@@ -380,20 +391,20 @@ def _find_histogram_peaks(popul, bin_value):
380391
381392 def _assign_states_for_group (
382393 self ,
383- data_container ,
384- group_id ,
385- molecules ,
386- dihedrals_ua ,
387- peaks_ua ,
388- dihedrals_res ,
389- peaks_res ,
390- start ,
391- end ,
392- step ,
393- level_list ,
394- states_ua ,
395- states_res ,
396- ):
394+ data_container : Any ,
395+ group_id : int ,
396+ molecules : list [ Any ] ,
397+ dihedrals_ua : list [ list [ Any ]] ,
398+ peaks_ua : list [ list [ Any ]] ,
399+ dihedrals_res : list [ Any ] ,
400+ peaks_res : list [ Any ] ,
401+ start : int ,
402+ end : int ,
403+ step : int ,
404+ level_list : list [ str ] ,
405+ states_ua : dict [ UAKey , list [ str ]] ,
406+ states_res : list [ list [ str ]] ,
407+ ) -> None :
397408 """Assign UA and residue states for a group into output containers."""
398409 for level in level_list :
399410 if level == "united_atom" :
@@ -428,14 +439,14 @@ def _assign_states_for_group(
428439
429440 def _assign_states (
430441 self ,
431- data_container ,
432- molecules ,
433- dihedrals ,
434- peaks ,
435- start ,
436- end ,
437- step ,
438- ):
442+ data_container : Any ,
443+ molecules : list [ Any ] ,
444+ dihedrals : list [ Any ] ,
445+ peaks : list [ list [ Any ]] ,
446+ start : int ,
447+ end : int ,
448+ step : int ,
449+ ) -> list [ str ] :
439450 """Assign discrete state labels for the provided dihedrals.
440451
441452 Important:
@@ -455,17 +466,17 @@ def _assign_states(
455466 Returns:
456467 List of state labels (strings).
457468 """
458- states = None
469+ states : list [ str ] = []
459470
460471 for molecule in molecules :
461- conformations = []
472+ conformations : list [ list [ Any ]] = []
462473 mol = self ._universe_operations .extract_fragment (data_container , molecule )
463474 number_frames = len (mol .trajectory )
464475
465476 dihedral_results = Dihedral (dihedrals ).run ()
466477
467478 for dihedral_index in range (len (dihedrals )):
468- conformation = []
479+ conformation : list [ Any ] = []
469480 for timestep in range (number_frames ):
470481 value = dihedral_results .results .angles [timestep ][dihedral_index ]
471482 if value < 0 :
@@ -487,10 +498,7 @@ def _assign_states(
487498 if state
488499 ]
489500
490- if states is None :
491- states = mol_states
492- else :
493- states .extend (mol_states )
501+ states .extend (mol_states )
494502
495503 logger .debug (f"States: { states } " )
496504 return states
0 commit comments