1111
12122) Frame stage (runs for each trajectory frame):
1313 - Execute the `FrameGraph` to produce frame-local covariance outputs.
14- - Reduce frame-local outputs into running (incremental) means .
14+ - Reduce frame-local outputs into deterministic sums and counts .
1515"""
1616
1717from __future__ import annotations
@@ -41,10 +41,11 @@ class LevelDAG:
4141 The LevelDAG is responsible for:
4242 - Running a static DAG (once) to prepare shared inputs.
4343 - Running a per-frame DAG (for each frame) to compute frame-local outputs.
44- - Reducing frame-local outputs into shared running means .
44+ - Reducing frame-local outputs into deterministic sums and counts .
4545
46- The reduction performed here is an incremental mean across frames (and across
47- molecules within a group when frame nodes average within-frame first).
46+ The reduction performed here is order-independent: frame-local sums and
47+ counts are accumulated across frames and final means are computed once after
48+ all frames have been processed.
4849 """
4950
5051 def __init__ (self , universe_operations : Any | None = None ) -> None :
@@ -98,7 +99,7 @@ def execute(
9899
99100 This method ensures required shared components exist, runs the static stage
100101 once, then iterates through trajectory frames to run the per-frame stage and
101- reduce outputs into running means .
102+ reduce outputs into deterministic sums and counts .
102103
103104 Args:
104105 shared_data: Shared workflow data dict. This mapping is mutated in-place
@@ -112,6 +113,7 @@ def execute(
112113 shared_data .setdefault ("axes_manager" , AxesCalculator ())
113114 self ._run_static_stage (shared_data , progress = progress )
114115 self ._run_frame_stage (shared_data , progress = progress )
116+ self ._finalize_means (shared_data )
115117 return shared_data
116118
117119 def _run_static_stage (
@@ -220,26 +222,10 @@ def _run_frame_stage(
220222 if progress is not None and task is not None :
221223 progress .advance (task )
222224
223- @staticmethod
224- def _incremental_mean (old : Any , new : Any , n : int ) -> Any :
225- """Compute an incremental mean.
226-
227- Args:
228- old: Previous running mean (or None for first sample).
229- new: New sample to incorporate.
230- n: 1-based sample count after adding `new`.
231-
232- Returns:
233- Updated running mean.
234- """
235- if old is None :
236- return new .copy () if hasattr (new , "copy" ) else new
237- return old + (new - old ) / float (n )
238-
239225 def _reduce_one_frame (
240226 self , shared_data : dict [str , Any ], frame_out : dict [str , Any ]
241227 ) -> None :
242- """Reduce one frame's covariance outputs into shared running means .
228+ """Reduce one frame's covariance outputs into shared sum accumulators .
243229
244230 Args:
245231 shared_data: Shared workflow data dict containing accumulators.
@@ -251,94 +237,191 @@ def _reduce_one_frame(
251237 def _reduce_force_and_torque (
252238 self , shared_data : dict [str , Any ], frame_out : dict [str , Any ]
253239 ) -> None :
254- """Reduce force/torque covariance outputs into shared accumulators.
240+ """Reduce force/torque frame-local sums into shared accumulators.
255241
256242 Args:
257243 shared_data: Shared workflow data dict containing:
258- - "force_covariances ", "torque_covariances ": accumulator structures .
259- - "frame_counts" : running sample counts for each accumulator slot .
244+ - "force_sums ", "torque_sums ": running sum accumulators .
245+ - "force_counts", "torque_counts" : running sample counts.
260246 - "group_id_to_index": mapping from group id to accumulator index.
261- frame_out: Frame-local outputs containing "force" and "torque" sections.
247+ frame_out: Frame-local outputs containing "force", "torque",
248+ "force_counts", and "torque_counts" sections.
262249
263250 Returns:
264- None. Mutates accumulator values and counts in shared_data in-place.
251+ None. Mutates shared accumulators and counts in-place.
265252 """
266- f_cov = shared_data ["force_covariances" ]
267- t_cov = shared_data ["torque_covariances" ]
268- counts = shared_data ["frame_counts" ]
253+ f_sums = shared_data ["force_sums" ]
254+ t_sums = shared_data ["torque_sums" ]
255+ f_counts = shared_data ["force_counts" ]
256+ t_counts = shared_data ["torque_counts" ]
269257 gid2i = shared_data ["group_id_to_index" ]
270258
271259 f_frame = frame_out ["force" ]
272260 t_frame = frame_out ["torque" ]
273-
274- for key , F in f_frame ["ua" ].items ():
275- counts ["ua" ][key ] = counts ["ua" ].get (key , 0 ) + 1
276- n = counts ["ua" ][key ]
277- f_cov ["ua" ][key ] = self ._incremental_mean (f_cov ["ua" ].get (key ), F , n )
278-
279- for key , T in t_frame ["ua" ].items ():
280- if key not in counts ["ua" ]:
281- counts ["ua" ][key ] = counts ["ua" ].get (key , 0 ) + 1
282- n = counts ["ua" ][key ]
283- t_cov ["ua" ][key ] = self ._incremental_mean (t_cov ["ua" ].get (key ), T , n )
284-
285- for gid , F in f_frame ["res" ].items ():
261+ f_frame_counts = frame_out ["force_counts" ]
262+ t_frame_counts = frame_out ["torque_counts" ]
263+
264+ for key in sorted (f_frame ["ua" ].keys ()):
265+ F = f_frame ["ua" ][key ]
266+ c = int (f_frame_counts ["ua" ].get (key , 0 ))
267+ if c <= 0 :
268+ continue
269+ prev = f_sums ["ua" ].get (key )
270+ f_sums ["ua" ][key ] = F .copy () if prev is None else prev + F
271+ f_counts ["ua" ][key ] = f_counts ["ua" ].get (key , 0 ) + c
272+
273+ for key in sorted (t_frame ["ua" ].keys ()):
274+ T = t_frame ["ua" ][key ]
275+ c = int (t_frame_counts ["ua" ].get (key , 0 ))
276+ if c <= 0 :
277+ continue
278+ prev = t_sums ["ua" ].get (key )
279+ t_sums ["ua" ][key ] = T .copy () if prev is None else prev + T
280+ t_counts ["ua" ][key ] = t_counts ["ua" ].get (key , 0 ) + c
281+
282+ for gid in sorted (f_frame ["res" ].keys ()):
283+ F = f_frame ["res" ][gid ]
286284 gi = gid2i [gid ]
287- counts ["res" ][gi ] += 1
288- n = counts ["res" ][gi ]
289- f_cov ["res" ][gi ] = self ._incremental_mean (f_cov ["res" ][gi ], F , n )
290-
291- for gid , T in t_frame ["res" ].items ():
285+ c = int (f_frame_counts ["res" ].get (gid , 0 ))
286+ if c <= 0 :
287+ continue
288+ prev = f_sums ["res" ][gi ]
289+ f_sums ["res" ][gi ] = F .copy () if prev is None else prev + F
290+ f_counts ["res" ][gi ] += c
291+
292+ for gid in sorted (t_frame ["res" ].keys ()):
293+ T = t_frame ["res" ][gid ]
292294 gi = gid2i [gid ]
293- if counts ["res" ][gi ] == 0 :
294- counts ["res" ][gi ] += 1
295- n = counts ["res" ][gi ]
296- t_cov ["res" ][gi ] = self ._incremental_mean (t_cov ["res" ][gi ], T , n )
297-
298- for gid , F in f_frame ["poly" ].items ():
295+ c = int (t_frame_counts ["res" ].get (gid , 0 ))
296+ if c <= 0 :
297+ continue
298+ prev = t_sums ["res" ][gi ]
299+ t_sums ["res" ][gi ] = T .copy () if prev is None else prev + T
300+ t_counts ["res" ][gi ] += c
301+
302+ for gid in sorted (f_frame ["poly" ].keys ()):
303+ F = f_frame ["poly" ][gid ]
299304 gi = gid2i [gid ]
300- counts ["poly" ][gi ] += 1
301- n = counts ["poly" ][gi ]
302- f_cov ["poly" ][gi ] = self ._incremental_mean (f_cov ["poly" ][gi ], F , n )
303-
304- for gid , T in t_frame ["poly" ].items ():
305+ c = int (f_frame_counts ["poly" ].get (gid , 0 ))
306+ if c <= 0 :
307+ continue
308+ prev = f_sums ["poly" ][gi ]
309+ f_sums ["poly" ][gi ] = F .copy () if prev is None else prev + F
310+ f_counts ["poly" ][gi ] += c
311+
312+ for gid in sorted (t_frame ["poly" ].keys ()):
313+ T = t_frame ["poly" ][gid ]
305314 gi = gid2i [gid ]
306- if counts ["poly" ][gi ] == 0 :
307- counts ["poly" ][gi ] += 1
308- n = counts ["poly" ][gi ]
309- t_cov ["poly" ][gi ] = self ._incremental_mean (t_cov ["poly" ][gi ], T , n )
315+ c = int (t_frame_counts ["poly" ].get (gid , 0 ))
316+ if c <= 0 :
317+ continue
318+ prev = t_sums ["poly" ][gi ]
319+ t_sums ["poly" ][gi ] = T .copy () if prev is None else prev + T
320+ t_counts ["poly" ][gi ] += c
310321
311322 def _reduce_forcetorque (
312323 self , shared_data : dict [str , Any ], frame_out : dict [str , Any ]
313324 ) -> None :
314- """Reduce combined force-torque covariance outputs into shared accumulators.
325+ """Reduce combined force-torque frame-local sums into shared accumulators.
315326
316327 Args:
317328 shared_data: Shared workflow data dict containing:
318- - "forcetorque_covariances ": accumulator structures .
319- - "forcetorque_counts": running sample counts for each accumulator slot .
329+ - "forcetorque_sums ": running sum accumulators .
330+ - "forcetorque_counts": running sample counts.
320331 - "group_id_to_index": mapping from group id to accumulator index.
321- frame_out: Frame-local outputs that may include a "forcetorque" section.
332+ frame_out: Frame-local outputs that may include "forcetorque" and
333+ "forcetorque_counts" sections.
322334
323335 Returns:
324- None. Mutates accumulator values and counts in shared_data in-place.
336+ None. Mutates shared accumulators and counts in-place.
325337 """
326338 if "forcetorque" not in frame_out :
327339 return
328340
329- ft_cov = shared_data ["forcetorque_covariances " ]
341+ ft_sums = shared_data ["forcetorque_sums " ]
330342 ft_counts = shared_data ["forcetorque_counts" ]
331343 gid2i = shared_data ["group_id_to_index" ]
344+
332345 ft_frame = frame_out ["forcetorque" ]
346+ ft_frame_counts = frame_out .get ("forcetorque_counts" , {"res" : {}, "poly" : {}})
333347
334- for gid , M in ft_frame .get ("res" , {}).items ():
348+ for gid in sorted (ft_frame .get ("res" , {}).keys ()):
349+ M = ft_frame ["res" ][gid ]
335350 gi = gid2i [gid ]
336- ft_counts ["res" ][gi ] += 1
337- n = ft_counts ["res" ][gi ]
338- ft_cov ["res" ][gi ] = self ._incremental_mean (ft_cov ["res" ][gi ], M , n )
339-
340- for gid , M in ft_frame .get ("poly" , {}).items ():
351+ c = int (ft_frame_counts .get ("res" , {}).get (gid , 0 ))
352+ if c <= 0 :
353+ continue
354+ prev = ft_sums ["res" ][gi ]
355+ ft_sums ["res" ][gi ] = M .copy () if prev is None else prev + M
356+ ft_counts ["res" ][gi ] += c
357+
358+ for gid in sorted (ft_frame .get ("poly" , {}).keys ()):
359+ M = ft_frame ["poly" ][gid ]
341360 gi = gid2i [gid ]
342- ft_counts ["poly" ][gi ] += 1
343- n = ft_counts ["poly" ][gi ]
344- ft_cov ["poly" ][gi ] = self ._incremental_mean (ft_cov ["poly" ][gi ], M , n )
361+ c = int (ft_frame_counts .get ("poly" , {}).get (gid , 0 ))
362+ if c <= 0 :
363+ continue
364+ prev = ft_sums ["poly" ][gi ]
365+ ft_sums ["poly" ][gi ] = M .copy () if prev is None else prev + M
366+ ft_counts ["poly" ][gi ] += c
367+
368+ def _finalize_means (self , shared_data : dict [str , Any ]) -> None :
369+ """Compute finalized mean matrices from accumulated sums and counts.
370+
371+ Args:
372+ shared_data: Shared workflow data dict containing running sums and counts.
373+
374+ Returns:
375+ None. Writes finalized mean matrices back into shared_data.
376+ """
377+
378+ def _compute_means (
379+ sums : dict [str , Any ],
380+ counts : dict [str , Any ],
381+ ) -> dict [str , Any ]:
382+ out : dict [str , Any ] = {}
383+
384+ for domain in sorted (sums .keys ()):
385+ domain_sums = sums [domain ]
386+ domain_counts = counts [domain ]
387+
388+ if isinstance (domain_sums , dict ):
389+ out [domain ] = {}
390+ for key in sorted (domain_sums .keys ()):
391+ total = domain_sums [key ]
392+ count = int (domain_counts .get (key , 0 ))
393+ out [domain ][key ] = total / float (count ) if count > 0 else None
394+ continue
395+
396+ mean_list : list [Any ] = [None ] * len (domain_sums )
397+ for idx , total in enumerate (domain_sums ):
398+ if total is None :
399+ continue
400+ count = int (domain_counts [idx ])
401+ mean_list [idx ] = total / float (count ) if count > 0 else None
402+ out [domain ] = mean_list
403+
404+ return out
405+
406+ shared_data ["force_covariances" ] = _compute_means (
407+ shared_data ["force_sums" ],
408+ shared_data ["force_counts" ],
409+ )
410+ shared_data ["torque_covariances" ] = _compute_means (
411+ shared_data ["torque_sums" ],
412+ shared_data ["torque_counts" ],
413+ )
414+ shared_data ["forcetorque_covariances" ] = _compute_means (
415+ shared_data ["forcetorque_sums" ],
416+ shared_data ["forcetorque_counts" ],
417+ )
418+
419+ shared_data ["frame_counts" ] = shared_data ["force_counts" ]
420+ shared_data ["force_torque_stats" ] = {
421+ "res" : list (shared_data ["forcetorque_covariances" ]["res" ]),
422+ "poly" : list (shared_data ["forcetorque_covariances" ]["poly" ]),
423+ }
424+ shared_data ["force_torque_counts" ] = {
425+ "res" : shared_data ["forcetorque_counts" ]["res" ].copy (),
426+ "poly" : shared_data ["forcetorque_counts" ]["poly" ].copy (),
427+ }
0 commit comments