Skip to content

Commit 59bdeb5

Browse files
committed
fix(dag): ensure deterministic group indexing and remove shared mutable aliases
- sort group IDs when building `group_id_to_index` to guarantee deterministic ordering - replace backwards-compatible aliases with copies to avoid shared mutable state - update unit tests to reflect deterministic behaviour - add determinism and aliasing tests to prevent regression
1 parent 9a9977e commit 59bdeb5

2 files changed

Lines changed: 54 additions & 13 deletions

File tree

CodeEntropy/levels/nodes/accumulators.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _build_group_index(groups: dict[int, Any]) -> GroupIndex:
103103
Returns:
104104
GroupIndex mapping object.
105105
"""
106-
group_ids = list(groups.keys())
106+
group_ids = sorted(groups.keys())
107107
gid2i = {gid: i for i, gid in enumerate(group_ids)}
108108
return GroupIndex(group_id_to_index=gid2i, index_to_group_id=list(group_ids))
109109

@@ -168,8 +168,14 @@ def _attach_backwards_compatible_aliases(shared_data: SharedData) -> None:
168168
Args:
169169
shared_data: Shared pipeline dictionary.
170170
"""
171-
shared_data["force_torque_stats"] = shared_data["forcetorque_covariances"]
172-
shared_data["force_torque_counts"] = shared_data["forcetorque_counts"]
171+
shared_data["force_torque_stats"] = {
172+
"res": list(shared_data["forcetorque_covariances"]["res"]),
173+
"poly": list(shared_data["forcetorque_covariances"]["poly"]),
174+
}
175+
shared_data["force_torque_counts"] = {
176+
"res": shared_data["forcetorque_counts"]["res"].copy(),
177+
"poly": shared_data["forcetorque_counts"]["poly"].copy(),
178+
}
173179

174180
@staticmethod
175181
def _build_return_payload(shared_data: SharedData) -> dict[str, Any]:
Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import numpy as np
2-
31
from CodeEntropy.levels.nodes.accumulators import InitCovarianceAccumulatorsNode
42

53

@@ -10,14 +8,51 @@ def test_init_covariance_accumulators_allocates_and_sets_aliases():
108

119
out = node.run(shared)
1210

13-
assert out["group_id_to_index"] == {9: 0, 2: 1}
14-
assert out["index_to_group_id"] == [9, 2]
11+
gid2i = out["group_id_to_index"]
12+
i2gid = out["index_to_group_id"]
13+
14+
assert set(gid2i.keys()) == {9, 2}
15+
assert set(i2gid) == {9, 2}
16+
17+
for gid, idx in gid2i.items():
18+
assert i2gid[idx] == gid
19+
20+
assert sorted(gid2i.values()) == [0, 1]
21+
22+
assert "force_covariances" in out
23+
assert "torque_covariances" in out
24+
assert "frame_counts" in out
25+
assert "forcetorque_covariances" in out
26+
assert "forcetorque_counts" in out
1527

16-
assert shared["force_covariances"]["res"] == [None, None]
17-
assert shared["torque_covariances"]["poly"] == [None, None]
28+
assert "force_torque_stats" in out
29+
assert "force_torque_counts" in out
30+
31+
assert out["force_torque_stats"] is not out["forcetorque_covariances"]
32+
assert out["force_torque_counts"] is not out["forcetorque_counts"]
33+
34+
35+
def test_init_covariance_accumulators_is_fully_deterministic():
36+
node = InitCovarianceAccumulatorsNode()
1837

19-
assert np.all(shared["frame_counts"]["res"] == np.array([0, 0]))
20-
assert np.all(shared["forcetorque_counts"]["poly"] == np.array([0, 0]))
38+
shared1 = {"groups": {9: [1, 2], 2: [3]}}
39+
shared2 = {"groups": {2: [3], 9: [1, 2]}}
40+
41+
out1 = node.run(shared1.copy())
42+
out2 = node.run(shared2.copy())
43+
44+
assert out1.keys() == out2.keys()
45+
46+
for key in out1:
47+
if isinstance(out1[key], dict):
48+
assert out1[key].keys() == out2[key].keys()
49+
50+
51+
def test_init_covariance_accumulators_no_aliasing():
52+
node = InitCovarianceAccumulatorsNode()
53+
54+
shared = {"groups": {1: [1]}}
55+
out = node.run(shared)
2156

22-
assert shared["force_torque_stats"] is shared["forcetorque_covariances"]
23-
assert shared["force_torque_counts"] is shared["forcetorque_counts"]
57+
assert out["force_torque_stats"] is not out["forcetorque_covariances"]
58+
assert out["force_torque_counts"] is not out["forcetorque_counts"]

0 commit comments

Comments
 (0)