Skip to content

Commit 183cd9e

Browse files
committed
refactor(levels): replace incremental mean with deterministic sum/count reduction and update tests
1 parent a9a3a86 commit 183cd9e

9 files changed

Lines changed: 891 additions & 485 deletions

File tree

CodeEntropy/entropy/workflow.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -165,7 +165,7 @@ def _build_shared_data(
165165
"universe": self._universe,
166166
"reduced_universe": reduced_universe,
167167
"levels": levels,
168-
"groups": dict(groups),
168+
"groups": dict(sorted(groups.items())),
169169
"start": traj.start,
170170
"end": traj.end,
171171
"step": traj.step,

CodeEntropy/levels/level_dag.py

Lines changed: 161 additions & 78 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
1212
2) Frame stage (runs for each trajectory frame):
1313
- Execute the `FrameGraph` to produce frame-local covariance outputs.
14-
- Reduce frame-local outputs into running (incremental) means.
14+
- Reduce frame-local outputs into deterministic sums and counts.
1515
"""
1616

1717
from __future__ import annotations
@@ -41,10 +41,11 @@ class LevelDAG:
4141
The LevelDAG is responsible for:
4242
- Running a static DAG (once) to prepare shared inputs.
4343
- Running a per-frame DAG (for each frame) to compute frame-local outputs.
44-
- Reducing frame-local outputs into shared running means.
44+
- Reducing frame-local outputs into deterministic sums and counts.
4545
46-
The reduction performed here is an incremental mean across frames (and across
47-
molecules within a group when frame nodes average within-frame first).
46+
The reduction performed here is order-independent: frame-local sums and
47+
counts are accumulated across frames and final means are computed once after
48+
all frames have been processed.
4849
"""
4950

5051
def __init__(self, universe_operations: Any | None = None) -> None:
@@ -98,7 +99,7 @@ def execute(
9899
99100
This method ensures required shared components exist, runs the static stage
100101
once, then iterates through trajectory frames to run the per-frame stage and
101-
reduce outputs into running means.
102+
reduce outputs into deterministic sums and counts.
102103
103104
Args:
104105
shared_data: Shared workflow data dict. This mapping is mutated in-place
@@ -112,6 +113,7 @@ def execute(
112113
shared_data.setdefault("axes_manager", AxesCalculator())
113114
self._run_static_stage(shared_data, progress=progress)
114115
self._run_frame_stage(shared_data, progress=progress)
116+
self._finalize_means(shared_data)
115117
return shared_data
116118

117119
def _run_static_stage(
@@ -220,26 +222,10 @@ def _run_frame_stage(
220222
if progress is not None and task is not None:
221223
progress.advance(task)
222224

223-
@staticmethod
224-
def _incremental_mean(old: Any, new: Any, n: int) -> Any:
225-
"""Compute an incremental mean.
226-
227-
Args:
228-
old: Previous running mean (or None for first sample).
229-
new: New sample to incorporate.
230-
n: 1-based sample count after adding `new`.
231-
232-
Returns:
233-
Updated running mean.
234-
"""
235-
if old is None:
236-
return new.copy() if hasattr(new, "copy") else new
237-
return old + (new - old) / float(n)
238-
239225
def _reduce_one_frame(
240226
self, shared_data: dict[str, Any], frame_out: dict[str, Any]
241227
) -> None:
242-
"""Reduce one frame's covariance outputs into shared running means.
228+
"""Reduce one frame's covariance outputs into shared sum accumulators.
243229
244230
Args:
245231
shared_data: Shared workflow data dict containing accumulators.
@@ -251,94 +237,191 @@ def _reduce_one_frame(
251237
def _reduce_force_and_torque(
252238
self, shared_data: dict[str, Any], frame_out: dict[str, Any]
253239
) -> None:
254-
"""Reduce force/torque covariance outputs into shared accumulators.
240+
"""Reduce force/torque frame-local sums into shared accumulators.
255241
256242
Args:
257243
shared_data: Shared workflow data dict containing:
258-
- "force_covariances", "torque_covariances": accumulator structures.
259-
- "frame_counts": running sample counts for each accumulator slot.
244+
- "force_sums", "torque_sums": running sum accumulators.
245+
- "force_counts", "torque_counts": running sample counts.
260246
- "group_id_to_index": mapping from group id to accumulator index.
261-
frame_out: Frame-local outputs containing "force" and "torque" sections.
247+
frame_out: Frame-local outputs containing "force", "torque",
248+
"force_counts", and "torque_counts" sections.
262249
263250
Returns:
264-
None. Mutates accumulator values and counts in shared_data in-place.
251+
None. Mutates shared accumulators and counts in-place.
265252
"""
266-
f_cov = shared_data["force_covariances"]
267-
t_cov = shared_data["torque_covariances"]
268-
counts = shared_data["frame_counts"]
253+
f_sums = shared_data["force_sums"]
254+
t_sums = shared_data["torque_sums"]
255+
f_counts = shared_data["force_counts"]
256+
t_counts = shared_data["torque_counts"]
269257
gid2i = shared_data["group_id_to_index"]
270258

271259
f_frame = frame_out["force"]
272260
t_frame = frame_out["torque"]
273-
274-
for key, F in f_frame["ua"].items():
275-
counts["ua"][key] = counts["ua"].get(key, 0) + 1
276-
n = counts["ua"][key]
277-
f_cov["ua"][key] = self._incremental_mean(f_cov["ua"].get(key), F, n)
278-
279-
for key, T in t_frame["ua"].items():
280-
if key not in counts["ua"]:
281-
counts["ua"][key] = counts["ua"].get(key, 0) + 1
282-
n = counts["ua"][key]
283-
t_cov["ua"][key] = self._incremental_mean(t_cov["ua"].get(key), T, n)
284-
285-
for gid, F in f_frame["res"].items():
261+
f_frame_counts = frame_out["force_counts"]
262+
t_frame_counts = frame_out["torque_counts"]
263+
264+
for key in sorted(f_frame["ua"].keys()):
265+
F = f_frame["ua"][key]
266+
c = int(f_frame_counts["ua"].get(key, 0))
267+
if c <= 0:
268+
continue
269+
prev = f_sums["ua"].get(key)
270+
f_sums["ua"][key] = F.copy() if prev is None else prev + F
271+
f_counts["ua"][key] = f_counts["ua"].get(key, 0) + c
272+
273+
for key in sorted(t_frame["ua"].keys()):
274+
T = t_frame["ua"][key]
275+
c = int(t_frame_counts["ua"].get(key, 0))
276+
if c <= 0:
277+
continue
278+
prev = t_sums["ua"].get(key)
279+
t_sums["ua"][key] = T.copy() if prev is None else prev + T
280+
t_counts["ua"][key] = t_counts["ua"].get(key, 0) + c
281+
282+
for gid in sorted(f_frame["res"].keys()):
283+
F = f_frame["res"][gid]
286284
gi = gid2i[gid]
287-
counts["res"][gi] += 1
288-
n = counts["res"][gi]
289-
f_cov["res"][gi] = self._incremental_mean(f_cov["res"][gi], F, n)
290-
291-
for gid, T in t_frame["res"].items():
285+
c = int(f_frame_counts["res"].get(gid, 0))
286+
if c <= 0:
287+
continue
288+
prev = f_sums["res"][gi]
289+
f_sums["res"][gi] = F.copy() if prev is None else prev + F
290+
f_counts["res"][gi] += c
291+
292+
for gid in sorted(t_frame["res"].keys()):
293+
T = t_frame["res"][gid]
292294
gi = gid2i[gid]
293-
if counts["res"][gi] == 0:
294-
counts["res"][gi] += 1
295-
n = counts["res"][gi]
296-
t_cov["res"][gi] = self._incremental_mean(t_cov["res"][gi], T, n)
297-
298-
for gid, F in f_frame["poly"].items():
295+
c = int(t_frame_counts["res"].get(gid, 0))
296+
if c <= 0:
297+
continue
298+
prev = t_sums["res"][gi]
299+
t_sums["res"][gi] = T.copy() if prev is None else prev + T
300+
t_counts["res"][gi] += c
301+
302+
for gid in sorted(f_frame["poly"].keys()):
303+
F = f_frame["poly"][gid]
299304
gi = gid2i[gid]
300-
counts["poly"][gi] += 1
301-
n = counts["poly"][gi]
302-
f_cov["poly"][gi] = self._incremental_mean(f_cov["poly"][gi], F, n)
303-
304-
for gid, T in t_frame["poly"].items():
305+
c = int(f_frame_counts["poly"].get(gid, 0))
306+
if c <= 0:
307+
continue
308+
prev = f_sums["poly"][gi]
309+
f_sums["poly"][gi] = F.copy() if prev is None else prev + F
310+
f_counts["poly"][gi] += c
311+
312+
for gid in sorted(t_frame["poly"].keys()):
313+
T = t_frame["poly"][gid]
305314
gi = gid2i[gid]
306-
if counts["poly"][gi] == 0:
307-
counts["poly"][gi] += 1
308-
n = counts["poly"][gi]
309-
t_cov["poly"][gi] = self._incremental_mean(t_cov["poly"][gi], T, n)
315+
c = int(t_frame_counts["poly"].get(gid, 0))
316+
if c <= 0:
317+
continue
318+
prev = t_sums["poly"][gi]
319+
t_sums["poly"][gi] = T.copy() if prev is None else prev + T
320+
t_counts["poly"][gi] += c
310321

311322
def _reduce_forcetorque(
312323
self, shared_data: dict[str, Any], frame_out: dict[str, Any]
313324
) -> None:
314-
"""Reduce combined force-torque covariance outputs into shared accumulators.
325+
"""Reduce combined force-torque frame-local sums into shared accumulators.
315326
316327
Args:
317328
shared_data: Shared workflow data dict containing:
318-
- "forcetorque_covariances": accumulator structures.
319-
- "forcetorque_counts": running sample counts for each accumulator slot.
329+
- "forcetorque_sums": running sum accumulators.
330+
- "forcetorque_counts": running sample counts.
320331
- "group_id_to_index": mapping from group id to accumulator index.
321-
frame_out: Frame-local outputs that may include a "forcetorque" section.
332+
frame_out: Frame-local outputs that may include "forcetorque" and
333+
"forcetorque_counts" sections.
322334
323335
Returns:
324-
None. Mutates accumulator values and counts in shared_data in-place.
336+
None. Mutates shared accumulators and counts in-place.
325337
"""
326338
if "forcetorque" not in frame_out:
327339
return
328340

329-
ft_cov = shared_data["forcetorque_covariances"]
341+
ft_sums = shared_data["forcetorque_sums"]
330342
ft_counts = shared_data["forcetorque_counts"]
331343
gid2i = shared_data["group_id_to_index"]
344+
332345
ft_frame = frame_out["forcetorque"]
346+
ft_frame_counts = frame_out.get("forcetorque_counts", {"res": {}, "poly": {}})
333347

334-
for gid, M in ft_frame.get("res", {}).items():
348+
for gid in sorted(ft_frame.get("res", {}).keys()):
349+
M = ft_frame["res"][gid]
335350
gi = gid2i[gid]
336-
ft_counts["res"][gi] += 1
337-
n = ft_counts["res"][gi]
338-
ft_cov["res"][gi] = self._incremental_mean(ft_cov["res"][gi], M, n)
339-
340-
for gid, M in ft_frame.get("poly", {}).items():
351+
c = int(ft_frame_counts.get("res", {}).get(gid, 0))
352+
if c <= 0:
353+
continue
354+
prev = ft_sums["res"][gi]
355+
ft_sums["res"][gi] = M.copy() if prev is None else prev + M
356+
ft_counts["res"][gi] += c
357+
358+
for gid in sorted(ft_frame.get("poly", {}).keys()):
359+
M = ft_frame["poly"][gid]
341360
gi = gid2i[gid]
342-
ft_counts["poly"][gi] += 1
343-
n = ft_counts["poly"][gi]
344-
ft_cov["poly"][gi] = self._incremental_mean(ft_cov["poly"][gi], M, n)
361+
c = int(ft_frame_counts.get("poly", {}).get(gid, 0))
362+
if c <= 0:
363+
continue
364+
prev = ft_sums["poly"][gi]
365+
ft_sums["poly"][gi] = M.copy() if prev is None else prev + M
366+
ft_counts["poly"][gi] += c
367+
368+
def _finalize_means(self, shared_data: dict[str, Any]) -> None:
369+
"""Compute finalized mean matrices from accumulated sums and counts.
370+
371+
Args:
372+
shared_data: Shared workflow data dict containing running sums and counts.
373+
374+
Returns:
375+
None. Writes finalized mean matrices back into shared_data.
376+
"""
377+
378+
def _compute_means(
379+
sums: dict[str, Any],
380+
counts: dict[str, Any],
381+
) -> dict[str, Any]:
382+
out: dict[str, Any] = {}
383+
384+
for domain in sorted(sums.keys()):
385+
domain_sums = sums[domain]
386+
domain_counts = counts[domain]
387+
388+
if isinstance(domain_sums, dict):
389+
out[domain] = {}
390+
for key in sorted(domain_sums.keys()):
391+
total = domain_sums[key]
392+
count = int(domain_counts.get(key, 0))
393+
out[domain][key] = total / float(count) if count > 0 else None
394+
continue
395+
396+
mean_list: list[Any] = [None] * len(domain_sums)
397+
for idx, total in enumerate(domain_sums):
398+
if total is None:
399+
continue
400+
count = int(domain_counts[idx])
401+
mean_list[idx] = total / float(count) if count > 0 else None
402+
out[domain] = mean_list
403+
404+
return out
405+
406+
shared_data["force_covariances"] = _compute_means(
407+
shared_data["force_sums"],
408+
shared_data["force_counts"],
409+
)
410+
shared_data["torque_covariances"] = _compute_means(
411+
shared_data["torque_sums"],
412+
shared_data["torque_counts"],
413+
)
414+
shared_data["forcetorque_covariances"] = _compute_means(
415+
shared_data["forcetorque_sums"],
416+
shared_data["forcetorque_counts"],
417+
)
418+
419+
shared_data["frame_counts"] = shared_data["force_counts"]
420+
shared_data["force_torque_stats"] = {
421+
"res": list(shared_data["forcetorque_covariances"]["res"]),
422+
"poly": list(shared_data["forcetorque_covariances"]["poly"]),
423+
}
424+
shared_data["force_torque_counts"] = {
425+
"res": shared_data["forcetorque_counts"]["res"].copy(),
426+
"poly": shared_data["forcetorque_counts"]["poly"].copy(),
427+
}

0 commit comments

Comments
 (0)