Skip to content

Commit 69cb883

Browse files
committed
refactor(levels): introduce dedicated ConformationDAG stage
1 parent 3efc3fa commit 69cb883

4 files changed

Lines changed: 136 additions & 55 deletions

File tree

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
"""Compute conformational states for configurational entropy calculations."""
1+
"""Conformational-state DAG orchestration.
2+
3+
This module owns the conformational stage between static structural setup and
4+
frame-local covariance/neighbour execution.
5+
"""
26

37
from __future__ import annotations
48

@@ -12,44 +16,34 @@
1216
FlexibleStates = dict[str, Any]
1317

1418

15-
class ComputeConformationalStatesNode:
16-
"""Static node that computes conformational states from selected frames.
19+
class ConformationDAG:
20+
"""Execute conformational-state construction for selected trajectory frames.
1721
18-
Produces:
19-
shared_data["conformational_states"] = {"ua": states_ua, "res": states_res}
20-
shared_data["flexible_dihedrals"] = {"ua": flexible_ua, "res": flexible_res}
22+
The first implementation intentionally preserves the existing serial
23+
ConformationStateBuilder behaviour. Later issues can replace this internal
24+
implementation with chunked map-reduce execution.
2125
"""
2226

2327
def __init__(self, universe_operations: Any | None = None) -> None:
24-
"""Initialise the conformational-state node.
25-
26-
Args:
27-
universe_operations: Optional universe-operation adapter passed to the
28-
underlying conformation-state builder.
29-
"""
3028
self._builder = ConformationStateBuilder(
3129
universe_operations=universe_operations
3230
)
3331

34-
def run(
32+
def build(self) -> ConformationDAG:
33+
"""Build the conformational DAG topology.
34+
35+
Returns:
36+
Self, to allow fluent construction.
37+
"""
38+
return self
39+
40+
def execute(
3541
self,
3642
shared_data: SharedData,
3743
*,
3844
progress: object | None = None,
3945
) -> dict[str, ConformationalStates]:
40-
"""Compute conformational states and store them in shared workflow data.
41-
42-
Args:
43-
shared_data: Shared workflow data containing ``reduced_universe``,
44-
``levels``, ``groups``, ``frame_selection``, and ``args.bin_width``.
45-
progress: Optional progress sink forwarded to the conformation builder.
46-
47-
Returns:
48-
A dictionary containing the computed ``conformational_states`` mapping.
49-
50-
Raises:
51-
KeyError: If required entries are missing from ``shared_data``.
52-
"""
46+
"""Compute conformational states and store them in shared workflow data."""
5347
universe = shared_data["reduced_universe"]
5448
levels = shared_data["levels"]
5549
groups = shared_data["groups"]

CodeEntropy/levels/level_dag.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Hierarchy-level DAG orchestration.
22
33
LevelDAG owns hierarchy-level workflow order. Static setup nodes prepare
4-
structural and conformational data, then frame-local covariance and neighbour
5-
observables are executed through deterministic frame map-reduce.
4+
structural data. ConformationDAG computes trajectory-series conformational
5+
states. FrameScheduler executes frame-local covariance and neighbour work.
66
"""
77

88
from __future__ import annotations
@@ -12,14 +12,14 @@
1212
import networkx as nx
1313

1414
from CodeEntropy.levels.axes import AxesCalculator
15+
from CodeEntropy.levels.conformation_dag import ConformationDAG
1516
from CodeEntropy.levels.execution.policy import ExecutionPolicy
1617
from CodeEntropy.levels.execution.reducers import NeighborReducer
1718
from CodeEntropy.levels.execution.scheduler import FrameScheduler
1819
from CodeEntropy.levels.frame_dag import FrameGraph
1920
from CodeEntropy.levels.neighbors import Neighbors
2021
from CodeEntropy.levels.nodes.accumulators import InitCovarianceAccumulatorsNode
2122
from CodeEntropy.levels.nodes.beads import BuildBeadsNode
22-
from CodeEntropy.levels.nodes.conformations import ComputeConformationalStatesNode
2323
from CodeEntropy.levels.nodes.detect_levels import DetectLevelsNode
2424
from CodeEntropy.levels.nodes.detect_molecules import DetectMoleculesNode
2525
from CodeEntropy.results.reporter import _RichProgressSink
@@ -38,15 +38,14 @@ def __init__(self, universe_operations: Any | None = None) -> None:
3838
self._universe_operations = universe_operations
3939
self._static_graph = nx.DiGraph()
4040
self._static_nodes: dict[str, Any] = {}
41+
self._conformation_dag = ConformationDAG(
42+
universe_operations=universe_operations
43+
)
4144
self._frame_dag = FrameGraph(universe_operations=universe_operations)
4245
self._policy = ExecutionPolicy()
4346

4447
def build(self) -> LevelDAG:
45-
"""Build static and frame-level DAG topology.
46-
47-
Returns:
48-
The current ``LevelDAG`` instance for fluent construction.
49-
"""
48+
"""Build the static, conformation, and frame DAG topology."""
5049
self._add_static("detect_molecules", DetectMoleculesNode())
5150
self._add_static("detect_levels", DetectLevelsNode(), deps=["detect_molecules"])
5251
self._add_static("build_beads", BuildBeadsNode(), deps=["detect_levels"])
@@ -55,12 +54,8 @@ def build(self) -> LevelDAG:
5554
InitCovarianceAccumulatorsNode(),
5655
deps=["detect_levels"],
5756
)
58-
self._add_static(
59-
"compute_conformational_states",
60-
ComputeConformationalStatesNode(self._universe_operations),
61-
deps=["detect_levels"],
62-
)
6357

58+
self._conformation_dag.build()
6459
self._frame_dag.build()
6560
return self
6661

@@ -87,6 +82,8 @@ def execute(
8782
shared_data.setdefault("axes_manager", AxesCalculator())
8883

8984
self._run_static_stage(shared_data, progress=progress)
85+
self._run_conformation_stage(shared_data, progress=progress)
86+
9087
self._initialise_neighbor_metadata(shared_data)
9188
NeighborReducer.initialise(shared_data)
9289
self._run_frame_stage(shared_data, progress=progress)
@@ -137,6 +134,15 @@ def _add_static(
137134
for dep in deps or []:
138135
self._static_graph.add_edge(dep, name)
139136

137+
def _run_conformation_stage(
138+
self,
139+
shared_data: dict[str, Any],
140+
*,
141+
progress: _RichProgressSink | None = None,
142+
) -> None:
143+
"""Run conformational-state construction after static setup."""
144+
self._conformation_dag.execute(shared_data, progress=progress)
145+
140146
def _run_frame_stage(
141147
self,
142148
shared_data: dict[str, Any],

tests/unit/CodeEntropy/levels/nodes/test_conformations_node.py renamed to tests/unit/CodeEntropy/levels/test_conformation_dag.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
"""Unit tests for the conformational-state static node."""
1+
"""Unit tests for the conformational-state DAG stage."""
22

33
from __future__ import annotations
44

55
from types import SimpleNamespace
66

7-
from CodeEntropy.levels.nodes import conformations
8-
from CodeEntropy.levels.nodes.conformations import ComputeConformationalStatesNode
7+
from CodeEntropy.levels import conformation_dag
8+
from CodeEntropy.levels.conformation_dag import ConformationDAG
99

1010

1111
class FakeConformationStateBuilder:
@@ -43,7 +43,13 @@ def build_conformational_states(
4343
)
4444

4545

46-
def test_compute_conformational_states_node_runs_and_writes_shared_data(monkeypatch):
46+
def test_conformation_dag_build_returns_self():
47+
dag = ConformationDAG()
48+
49+
assert dag.build() is dag
50+
51+
52+
def test_conformation_dag_executes_builder_and_writes_shared_data(monkeypatch):
4753
builder_holder = {}
4854

4955
def builder_factory(universe_operations):
@@ -52,13 +58,13 @@ def builder_factory(universe_operations):
5258
return builder
5359

5460
monkeypatch.setattr(
55-
conformations,
61+
conformation_dag,
5662
"ConformationStateBuilder",
5763
builder_factory,
5864
)
5965

6066
universe_operations = object()
61-
node = ComputeConformationalStatesNode(universe_operations)
67+
dag = ConformationDAG(universe_operations=universe_operations)
6268

6369
universe = object()
6470
frame_selection = object()
@@ -72,7 +78,7 @@ def builder_factory(universe_operations):
7278
"args": SimpleNamespace(bin_width=30),
7379
}
7480

75-
result = node.run(shared_data, progress=progress)
81+
result = dag.execute(shared_data, progress=progress)
7682

7783
assert shared_data["conformational_states"] == {
7884
"ua": {"ua_key": ["state_a"]},
@@ -100,20 +106,24 @@ def builder_factory(universe_operations):
100106
]
101107

102108

103-
def test_compute_conformational_states_node_converts_bin_width_to_int(monkeypatch):
109+
def test_conformation_dag_converts_bin_width_to_int(monkeypatch):
104110
captured = {}
105111

106112
class Builder:
107113
def __init__(self, universe_operations):
108-
pass
114+
self.universe_operations = universe_operations
109115

110116
def build_conformational_states(self, **kwargs):
111117
captured.update(kwargs)
112118
return {}, [], {}, []
113119

114-
monkeypatch.setattr(conformations, "ConformationStateBuilder", Builder)
120+
monkeypatch.setattr(
121+
conformation_dag,
122+
"ConformationStateBuilder",
123+
Builder,
124+
)
115125

116-
node = ComputeConformationalStatesNode()
126+
dag = ConformationDAG()
117127
shared_data = {
118128
"reduced_universe": object(),
119129
"levels": [],
@@ -122,6 +132,6 @@ def build_conformational_states(self, **kwargs):
122132
"args": SimpleNamespace(bin_width="45"),
123133
}
124134

125-
node.run(shared_data)
135+
dag.execute(shared_data)
126136

127137
assert captured["bin_width"] == 45

tests/unit/CodeEntropy/levels/test_level_dag.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,17 @@
77
from CodeEntropy.levels.level_dag import LevelDAG
88

99

10-
def test_build_registers_static_nodes_and_builds_frame_dag():
10+
def test_build_registers_static_nodes_and_builds_stage_dags():
1111
with (
1212
patch("CodeEntropy.levels.level_dag.DetectMoleculesNode"),
1313
patch("CodeEntropy.levels.level_dag.DetectLevelsNode"),
1414
patch("CodeEntropy.levels.level_dag.BuildBeadsNode"),
1515
patch("CodeEntropy.levels.level_dag.InitCovarianceAccumulatorsNode"),
16-
patch("CodeEntropy.levels.level_dag.ComputeConformationalStatesNode"),
16+
patch("CodeEntropy.levels.level_dag.ConformationDAG"),
1717
):
1818
universe_operations = MagicMock()
1919
dag = LevelDAG(universe_operations=universe_operations)
20+
dag._conformation_dag.build = MagicMock()
2021
dag._frame_dag.build = MagicMock()
2122

2223
out = dag.build()
@@ -27,15 +28,19 @@ def test_build_registers_static_nodes_and_builds_frame_dag():
2728
"detect_levels",
2829
"build_beads",
2930
"init_covariance_accumulators",
30-
"compute_conformational_states",
3131
}
3232
assert "find_neighbors" not in dag._static_nodes
33+
assert "compute_conformational_states" not in dag._static_nodes
3334

3435
assert ("detect_molecules", "detect_levels") in dag._static_graph.edges
3536
assert ("detect_levels", "build_beads") in dag._static_graph.edges
3637
assert ("detect_levels", "init_covariance_accumulators") in dag._static_graph.edges
37-
assert ("detect_levels", "compute_conformational_states") in dag._static_graph.edges
38+
assert (
39+
"detect_levels",
40+
"compute_conformational_states",
41+
) not in dag._static_graph.edges
3842

43+
dag._conformation_dag.build.assert_called_once()
3944
dag._frame_dag.build.assert_called_once()
4045

4146

@@ -46,6 +51,7 @@ def test_execute_sets_default_axes_manager_and_runs_workflow_stages():
4651
progress = MagicMock()
4752

4853
dag._run_static_stage = MagicMock()
54+
dag._run_conformation_stage = MagicMock()
4955
dag._initialise_neighbor_metadata = MagicMock()
5056
dag._run_frame_stage = MagicMock()
5157

@@ -59,6 +65,10 @@ def test_execute_sets_default_axes_manager_and_runs_workflow_stages():
5965
assert "axes_manager" in shared_data
6066

6167
dag._run_static_stage.assert_called_once_with(shared_data, progress=progress)
68+
dag._run_conformation_stage.assert_called_once_with(
69+
shared_data,
70+
progress=progress,
71+
)
6272
dag._initialise_neighbor_metadata.assert_called_once_with(shared_data)
6373
initialise.assert_called_once_with(shared_data)
6474
dag._run_frame_stage.assert_called_once_with(shared_data, progress=progress)
@@ -126,6 +136,21 @@ def test_run_static_stage_falls_back_when_node_does_not_accept_progress():
126136
]
127137

128138

139+
def test_run_conformation_stage_delegates_to_conformation_dag():
140+
dag = LevelDAG()
141+
shared_data = {}
142+
progress = MagicMock()
143+
144+
dag._conformation_dag.execute = MagicMock()
145+
146+
dag._run_conformation_stage(shared_data, progress=progress)
147+
148+
dag._conformation_dag.execute.assert_called_once_with(
149+
shared_data,
150+
progress=progress,
151+
)
152+
153+
129154
def test_run_frame_stage_collects_frame_indices_and_delegates_to_scheduler():
130155
universe_operations = MagicMock()
131156
dag = LevelDAG(universe_operations=universe_operations)
@@ -183,3 +208,49 @@ def test_initialise_neighbor_metadata_falls_back_to_universe_key():
183208
LevelDAG._initialise_neighbor_metadata(shared_data)
184209

185210
helper.get_symmetry.assert_called_once_with(universe=universe, groups={0: [0]})
211+
212+
213+
def test_level_dag_runs_static_conformation_then_frame(monkeypatch):
214+
dag = LevelDAG(universe_operations=object())
215+
calls = []
216+
217+
monkeypatch.setattr(
218+
dag,
219+
"_run_static_stage",
220+
lambda shared_data, progress=None: calls.append("static"),
221+
)
222+
monkeypatch.setattr(
223+
dag,
224+
"_run_conformation_stage",
225+
lambda shared_data, progress=None: calls.append("conformation"),
226+
)
227+
monkeypatch.setattr(
228+
dag,
229+
"_initialise_neighbor_metadata",
230+
lambda shared_data: calls.append("neighbor_metadata"),
231+
)
232+
monkeypatch.setattr(
233+
dag,
234+
"_run_frame_stage",
235+
lambda shared_data, progress=None: calls.append("frame"),
236+
)
237+
238+
monkeypatch.setattr(
239+
"CodeEntropy.levels.level_dag.NeighborReducer.initialise",
240+
lambda shared_data: calls.append("neighbor_initialise"),
241+
)
242+
monkeypatch.setattr(
243+
"CodeEntropy.levels.level_dag.NeighborReducer.finalise",
244+
lambda shared_data: calls.append("neighbor_finalise"),
245+
)
246+
247+
dag.execute({})
248+
249+
assert calls == [
250+
"static",
251+
"conformation",
252+
"neighbor_metadata",
253+
"neighbor_initialise",
254+
"frame",
255+
"neighbor_finalise",
256+
]

0 commit comments

Comments
 (0)