Skip to content

Commit 769aca9

Browse files
fisherxue and claude committed
Refactor temporal reuse: fold fill/drain into regular actions with post-hoc correction
Replace the complex in-line temporal reuse detection (parent-named BuffetStats attributes, halo factor, bypass-zone walking) with a simpler two-phase approach: analyze_storage folds fill/drain into regular read/write actions, then _apply_temporal_reuse_corrections walks per-tensor mappings to find contiguous irrelevant temporal loops and divides out inflated stats as a post-processing step.

- Remove ~200 lines from symbolic.py (parent attrs, halo logic, _has_temporal_reuse, _compute_overlap_fallback, partial_overlap_info)
- Add _apply_temporal_reuse_corrections, _get_parent_buffet, and _compute_buffet_tile_shapes to sparse_adjustment.py
- Simplify memory.py and run_model.py action counting
- Improve energy.py KeyError message for missing actions
- Update temporal reuse test to use structural reuse mapping
- Add spatial fanout temporal reuse test and regression comparison tool
- Update sparseloop reproduction notebook outputs and regression reference

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1e958e0 commit 769aca9

18 files changed

Lines changed: 9768 additions & 2601 deletions

accelforge/model/_looptree/energy.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -202,8 +202,11 @@ def compute_energy_from_actions(
202202
component_obj = components[key.level]
203203
try:
204204
energy_per_ac = component_obj.actions[key.action].energy
205-
except (KeyError, TypeError):
206-
energy_per_ac = 0
205+
except KeyError as e:
206+
raise KeyError(
207+
f"Action {key.action} not found in component {key.level}. Action occurred "
208+
f"{counts.total} times."
209+
) from None
207210
energy_result[key] = counts.total * energy_per_ac
208211

209212
for component_obj in spec.arch.get_nodes_of_type(arch.Component):

accelforge/model/_looptree/latency/memory.py

Lines changed: 8 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -66,37 +66,25 @@ def component_latency(
6666
actions[f"{action.name}_actions"] += 0
6767

6868
if isinstance(name2component[component], TensorHolder):
69-
read_actions_val = (
70-
buffet_stats.max_per_unit_read_actions
71-
+ buffet_stats.max_per_parent_drain_read_actions
72-
)
69+
# On main, max_per_unit_read_actions already includes drain reads
70+
# (folded in by analyze_storage).
71+
read_actions_val = buffet_stats.max_per_unit_read_actions
7372
actions["read_actions"] += read_actions_val
7473
per_tensor_reads[component][buffet.tensor] += read_actions_val
75-
# Per-unit computation-path reads only (no fill/drain).
76-
# Use for PE buffer BW where fills go through the parent's port.
74+
# Per-unit computation-path reads (on main, same as read_actions
75+
# since fill/drain are folded in).
7776
actions["pu_read_actions"] += buffet_stats.max_per_unit_read_actions
7877
# Total actions across all spatial instances (for BW throttling
7978
# of shared levels above spatial, e.g. shared_glb)
80-
total_read_actions_val = (
81-
buffet_stats.total_read_actions
82-
+ buffet_stats.total_parent_drain_read_actions
83-
)
84-
actions["total_read_actions"] += total_read_actions_val
79+
actions["total_read_actions"] += buffet_stats.total_read_actions
8580
if not isinstance(name2component[component], arch.Toll):
86-
write_actions_val = (
87-
buffet_stats.max_per_unit_write_actions
88-
+ buffet_stats.max_per_parent_fill_write_actions
89-
)
81+
write_actions_val = buffet_stats.max_per_unit_write_actions
9082
actions["write_actions"] += write_actions_val
9183
per_tensor_writes[component][buffet.tensor] += write_actions_val
9284
actions["pu_write_actions"] += (
9385
buffet_stats.max_per_unit_write_actions
9486
)
95-
total_write_actions_val = (
96-
buffet_stats.total_write_actions
97-
+ buffet_stats.total_parent_fill_write_actions
98-
)
99-
actions["total_write_actions"] += total_write_actions_val
87+
actions["total_write_actions"] += buffet_stats.total_write_actions
10088
elif isinstance(name2component[component], arch.Compute):
10189
pass
10290
else:

0 commit comments

Comments
 (0)