Skip to content

Commit 631293f

Browse files
committed
feat(progress): add Rich progress bars to workflow and expand tests:
- Add ResultsReporter progress context manager
- Propagate optional progress sink through workflow orchestration
- Add progress reporting for:
  - Conformational state construction (per group)
  - Frame processing stage (per frame)
- Keep entropy graph execution silent due to fast runtime
- Update runtime tests to reflect wrapped RuntimeError behavior
1 parent 756a17d commit 631293f

12 files changed

Lines changed: 619 additions & 100 deletions

File tree

CodeEntropy/entropy/graph.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,23 +80,44 @@ def build(self) -> "EntropyGraph":
8080

8181
return self
8282

83-
def execute(self, shared_data: SharedData) -> Dict[str, Any]:
83+
def execute(
84+
self, shared_data: SharedData, *, progress: object | None = None
85+
) -> Dict[str, Any]:
8486
"""Execute the entropy graph in topological order.
8587
88+
Nodes are executed in dependency order (topological sort). Each node reads
89+
from and may mutate `shared_data`. Dict-like outputs returned by nodes are
90+
merged into a single results dictionary.
91+
92+
This method intentionally does *not* create a progress bar/task for the
93+
entropy graph itself because the graph is typically very fast. If a progress
94+
sink is provided, it is forwarded to nodes that accept it.
95+
8696
Args:
8797
shared_data: Mutable shared data dictionary passed to each node.
98+
progress: Optional progress sink (e.g., from ResultsReporter.progress()).
99+
Forwarded to node `run()` methods that accept a `progress` keyword.
88100
89101
Returns:
90-
Dictionary containing the merged outputs of all nodes (only including
91-
outputs that are dict-like).
102+
Dictionary containing merged dict outputs produced by nodes. On key
103+
collision, later nodes overwrite earlier keys.
92104
93105
Raises:
94106
KeyError: If a node name is missing from the internal node registry.
95107
"""
96108
results: Dict[str, Any] = {}
109+
97110
for node_name in nx.topological_sort(self._graph):
98111
node = self._nodes[node_name]
99-
out = node.run(shared_data)
112+
113+
if progress is not None:
114+
try:
115+
out = node.run(shared_data, progress=progress)
116+
except TypeError:
117+
out = node.run(shared_data)
118+
else:
119+
out = node.run(shared_data)
120+
100121
if isinstance(out, dict):
101122
results.update(out)
102123
return results

CodeEntropy/entropy/workflow.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,9 @@ def execute(self) -> None:
124124
traj=traj,
125125
)
126126

127-
self._run_level_dag(shared_data)
128-
self._run_entropy_graph(shared_data)
127+
with self._reporter.progress(transient=False) as p:
128+
self._run_level_dag(shared_data, progress=p)
129+
self._run_entropy_graph(shared_data, progress=p)
129130

130131
self._finalize_molecule_results()
131132
self._reporter.log_tables()
@@ -164,21 +165,29 @@ def _build_shared_data(
164165
}
165166
return shared_data
166167

167-
def _run_level_dag(self, shared_data: SharedData) -> None:
168+
def _run_level_dag(
169+
self, shared_data: SharedData, *, progress: object | None = None
170+
) -> None:
168171
"""Execute the structural/level DAG.
169172
170173
Args:
171174
shared_data: Shared data dict that will be mutated by the DAG.
175+
progress: Optional progress sink provided by ResultsReporter.progress().
172176
"""
173-
LevelDAG(self._universe_operations).build().execute(shared_data)
177+
LevelDAG(self._universe_operations).build().execute(
178+
shared_data, progress=progress
179+
)
174180

175-
def _run_entropy_graph(self, shared_data: SharedData) -> None:
181+
def _run_entropy_graph(
182+
self, shared_data: SharedData, *, progress: object | None = None
183+
) -> None:
176184
"""Execute the entropy calculation graph and merge results into shared_data.
177185
178186
Args:
179187
shared_data: Shared data dict that will be mutated by the graph.
188+
progress: Optional progress sink provided by ResultsReporter.progress().
180189
"""
181-
entropy_results = EntropyGraph().build().execute(shared_data)
190+
entropy_results = EntropyGraph().build().execute(shared_data, progress=progress)
182191
shared_data.update(entropy_results)
183192

184193
def _build_trajectory_slice(self) -> TrajectorySlice:

CodeEntropy/levels/dihedrals.py

Lines changed: 87 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -45,79 +45,114 @@ def build_conformational_states(
4545
end: int,
4646
step: int,
4747
bin_width: float,
48+
progress: object | None = None,
4849
):
49-
"""Build conformational state labels for UA and residue levels.
50+
"""Build conformational state labels from trajectory dihedrals.
51+
52+
This method constructs discrete conformational state descriptors used in
53+
configurational entropy calculations. It supports united-atom (UA) and
54+
residue-level state generation depending on which hierarchy levels are
55+
enabled per molecule.
56+
57+
Progress reporting is optional and UI-agnostic: if a progress sink is
58+
provided, the method will create a single task and advance it once per
59+
molecule group.
5060
5161
Args:
52-
data_container: MDAnalysis universe containing the system.
53-
levels: Mapping of molecule_id -> list of enabled levels.
62+
data_container: MDAnalysis Universe (or compatible container) used to
63+
extract fragments and compute dihedral time series.
64+
levels: Mapping of molecule_id -> iterable of enabled level names
65+
(e.g., ["united_atom", "residue"]).
5466
groups: Mapping of group_id -> list of molecule_ids.
55-
start: Start frame index (currently not applied in legacy sampling).
56-
end: End frame index (currently not applied in legacy sampling).
57-
step: Step size (currently not applied in legacy sampling).
58-
bin_width: Histogram bin width (degrees).
67+
start: Inclusive start frame index.
68+
end: Exclusive end frame index.
69+
step: Frame stride.
70+
bin_width: Histogram bin width in degrees used when identifying peak
71+
dihedral populations.
72+
progress: Optional progress sink (e.g., from ResultsReporter.progress()).
73+
Must expose add_task(), update(), and advance().
5974
6075
Returns:
6176
Tuple of:
62-
states_ua: Dict[(group_id, res_id)] -> list of state labels.
63-
states_res: List indexed by group_id -> list of state labels.
77+
states_ua: Dict mapping (group_id, local_residue_id) -> list of state
78+
labels (strings) across the analyzed trajectory.
79+
states_res: List-like structure indexed by group_id (or equivalent)
80+
containing residue-level state labels (strings) across the
81+
analyzed trajectory.
82+
83+
Notes:
84+
- This function advances progress once per group_id.
85+
- Frame slicing arguments (start/end/step) are forwarded to downstream
86+
helpers as implemented in this module.
6487
"""
6588
number_groups = len(groups)
6689
states_ua: Dict[UAKey, List[str]] = {}
6790
states_res: List[List[str]] = [None] * number_groups
6891

69-
total_items = self._count_total_items(levels=levels, groups=groups)
70-
71-
with self._progress_bar(total_items) as progress:
92+
task = None
93+
if progress is not None:
94+
total = max(1, len(groups))
7295
task = progress.add_task(
73-
"[green]Building Conformational States...",
74-
total=total_items,
75-
title="Starting...",
96+
"[green]Conformational states",
97+
total=total,
98+
title="Initializing",
7699
)
77100

78-
for group_id in groups.keys():
79-
molecules = groups[group_id]
80-
if not molecules:
101+
if not groups:
102+
if task is not None:
103+
progress.update(task, title="No groups")
104+
progress.advance(task)
105+
return states_ua, states_res
106+
107+
for group_id in groups.keys():
108+
molecules = groups[group_id]
109+
if not molecules:
110+
if task is not None:
111+
progress.update(task, title=f"Group {group_id} (empty)")
81112
progress.advance(task)
82-
continue
113+
continue
83114

84-
mol = self._universe_operations.extract_fragment(
85-
data_container, molecules[0]
86-
)
115+
if task is not None:
116+
progress.update(task, title=f"Group {group_id}")
87117

88-
dihedrals_ua, dihedrals_res = self._collect_dihedrals_for_group(
89-
mol=mol,
90-
level_list=levels[molecules[0]],
91-
)
118+
mol = self._universe_operations.extract_fragment(
119+
data_container, molecules[0]
120+
)
92121

93-
peaks_ua, peaks_res = self._collect_peaks_for_group(
94-
data_container=data_container,
95-
molecules=molecules,
96-
dihedrals_ua=dihedrals_ua,
97-
dihedrals_res=dihedrals_res,
98-
bin_width=bin_width,
99-
start=start,
100-
end=end,
101-
step=step,
102-
level_list=levels[molecules[0]],
103-
)
122+
dihedrals_ua, dihedrals_res = self._collect_dihedrals_for_group(
123+
mol=mol,
124+
level_list=levels[molecules[0]],
125+
)
104126

105-
self._assign_states_for_group(
106-
data_container=data_container,
107-
group_id=group_id,
108-
molecules=molecules,
109-
dihedrals_ua=dihedrals_ua,
110-
peaks_ua=peaks_ua,
111-
dihedrals_res=dihedrals_res,
112-
peaks_res=peaks_res,
113-
start=start,
114-
end=end,
115-
step=step,
116-
level_list=levels[molecules[0]],
117-
states_ua=states_ua,
118-
states_res=states_res,
119-
)
127+
peaks_ua, peaks_res = self._collect_peaks_for_group(
128+
data_container=data_container,
129+
molecules=molecules,
130+
dihedrals_ua=dihedrals_ua,
131+
dihedrals_res=dihedrals_res,
132+
bin_width=bin_width,
133+
start=start,
134+
end=end,
135+
step=step,
136+
level_list=levels[molecules[0]],
137+
)
138+
139+
self._assign_states_for_group(
140+
data_container=data_container,
141+
group_id=group_id,
142+
molecules=molecules,
143+
dihedrals_ua=dihedrals_ua,
144+
peaks_ua=peaks_ua,
145+
dihedrals_res=dihedrals_res,
146+
peaks_res=peaks_res,
147+
start=start,
148+
end=end,
149+
step=step,
150+
level_list=levels[molecules[0]],
151+
states_ua=states_ua,
152+
states_res=states_res,
153+
)
120154

155+
if task is not None:
121156
progress.advance(task)
122157

123158
return states_ua, states_res

CodeEntropy/levels/level_dag.py

Lines changed: 79 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -82,24 +82,28 @@ def build(self) -> "LevelDAG":
8282
self._frame_dag.build()
8383
return self
8484

85-
def execute(self, shared_data: Dict[str, Any]) -> Dict[str, Any]:
86-
"""Execute the full hierarchy workflow and mutate shared_data.
87-
88-
Args:
89-
shared_data: Shared workflow data dict that will be mutated by the DAG.
90-
91-
Returns:
92-
The mutated shared_data dict.
93-
"""
85+
def execute(
86+
self, shared_data: Dict[str, Any], *, progress: object | None = None
87+
) -> Dict[str, Any]:
88+
"""Execute the full hierarchy workflow and mutate shared_data."""
9489
shared_data.setdefault("axes_manager", AxesCalculator())
95-
self._run_static_stage(shared_data)
96-
self._run_frame_stage(shared_data)
90+
self._run_static_stage(shared_data, progress=progress)
91+
self._run_frame_stage(shared_data, progress=progress)
9792
return shared_data
9893

99-
def _run_static_stage(self, shared_data: Dict[str, Any]) -> None:
94+
def _run_static_stage(
95+
self, shared_data: Dict[str, Any], *, progress: object | None = None
96+
) -> None:
10097
"""Run all static nodes in dependency order."""
10198
for node_name in nx.topological_sort(self._static_graph):
102-
self._static_nodes[node_name].run(shared_data)
99+
node = self._static_nodes[node_name]
100+
if progress is not None:
101+
try:
102+
node.run(shared_data, progress=progress)
103+
continue
104+
except TypeError:
105+
pass
106+
node.run(shared_data)
103107

104108
def _add_static(
105109
self, name: str, node: Any, deps: Optional[list[str]] = None
@@ -110,16 +114,74 @@ def _add_static(
110114
for dep in deps or []:
111115
self._static_graph.add_edge(dep, name)
112116

113-
def _run_frame_stage(self, shared_data: Dict[str, Any]) -> None:
114-
"""Run the frame DAG for each selected trajectory frame and reduce outputs."""
117+
def _run_frame_stage(
118+
self, shared_data: Dict[str, Any], *, progress: object | None = None
119+
) -> None:
120+
"""Execute the per-frame DAG stage and reduce frame outputs.
121+
122+
This method iterates over the selected trajectory frames, executes the
123+
frame-local DAG for each frame, and reduces the resulting outputs into the
124+
shared accumulators stored in `shared_data`.
125+
126+
Progress reporting is optional. If a progress sink is provided, a task is
127+
always created. When the total number of frames cannot be determined, the
128+
task is created with total=None (indeterminate).
129+
130+
Args:
131+
shared_data: Shared data dictionary. Must contain:
132+
- "reduced_universe": MDAnalysis Universe providing the trajectory.
133+
- "start", "end", "step": frame slicing parameters.
134+
- any additional keys required by the frame DAG and reducer.
135+
progress: Optional progress sink (e.g., from ResultsReporter.progress()).
136+
Must expose add_task(), update(), and advance().
137+
138+
Returns:
139+
None. Mutates `shared_data` in-place via reduction.
140+
141+
Notes:
142+
The task title shows the current frame index being processed.
143+
"""
115144
u = shared_data["reduced_universe"]
116145
start, end, step = shared_data["start"], shared_data["end"], shared_data["step"]
117146

147+
task = None
148+
total_frames = None
149+
150+
if progress is not None:
151+
try:
152+
n_frames = len(u.trajectory)
153+
154+
s = 0 if start is None else int(start)
155+
e = n_frames if end is None else int(end)
156+
157+
if e < 0:
158+
e = n_frames + e
159+
160+
e = max(0, min(e, n_frames))
161+
s = max(0, min(s, e))
162+
163+
st = 1 if step is None else int(step)
164+
if st > 0:
165+
total_frames = max(0, (e - s + st - 1) // st)
166+
except Exception:
167+
total_frames = None
168+
169+
task = progress.add_task(
170+
"[green]Frame processing",
171+
total=total_frames,
172+
title="Initializing",
173+
)
174+
118175
for ts in u.trajectory[start:end:step]:
119-
frame_index = ts.frame
120-
frame_out = self._frame_dag.execute_frame(shared_data, frame_index)
176+
if task is not None:
177+
progress.update(task, title=f"Frame {ts.frame}")
178+
179+
frame_out = self._frame_dag.execute_frame(shared_data, ts.frame)
121180
self._reduce_one_frame(shared_data, frame_out)
122181

182+
if task is not None:
183+
progress.advance(task)
184+
123185
@staticmethod
124186
def _incremental_mean(old: Any, new: Any, n: int) -> Any:
125187
"""Compute an incremental mean.

0 commit comments

Comments
 (0)