Add algebraic simplifications to get rid of 1-based index IR bloat #156

Merged
maleadt merged 4 commits into main from tb/algebraic_simplfications
Mar 30, 2026

@maleadt maleadt commented Mar 30, 2026

This PR reworks the rewriter so that we can implement algebraic simplifications, with the goal of eliminating the IR bloat from 1-based indices (addi and subi everywhere). As an example, let's look at vadd from the README:

cuda_tile.module @kernels {
  entry @vadd(%arg0: tile<ptr<f32>>, %arg1: tile<i32>, %arg2: tile<i32>, %arg3: tile<ptr<f32>>, %arg4: tile<i32>, %arg5: tile<i32>, %arg6: tile<ptr<f32>>, %arg7: tile<i32>, %arg8: tile<i32>) {
    %cst_16_i64 = constant <i64: 16> : tile<i64>
    %assume = assume div_by<128>, %arg6 : tile<ptr<f32>>
    %assume_0 = assume bounded<0, ?>, %arg7 : tile<i32>
    %assume_assume = assume div_by<32>, %assume_0 : tile<i32>
    %tview = make_tensor_view %assume, shape = [%assume_assume], strides = [1] : tile<i32> -> tensor_view<?xf32, strides=[1]>
    %assume_1 = assume div_by<128>, %arg0 : tile<ptr<f32>>
    %assume_2 = assume bounded<0, ?>, %arg1 : tile<i32>
    %assume_assume_3 = assume div_by<32>, %assume_2 : tile<i32>
    %tview_4 = make_tensor_view %assume_1, shape = [%assume_assume_3], strides = [1] : tile<i32> -> tensor_view<?xf32, strides=[1]>
    %assume_5 = assume div_by<128>, %arg3 : tile<ptr<f32>>
    %assume_6 = assume bounded<0, ?>, %arg4 : tile<i32>
    %assume_assume_7 = assume div_by<32>, %assume_6 : tile<i32>
    %tview_8 = make_tensor_view %assume_5, shape = [%assume_assume_7], strides = [1] : tile<i32> -> tensor_view<?xf32, strides=[1]>
    %0 = make_token : token
    %blockId_x, %blockId_y, %blockId_z = get_tile_block_id : tile<i32>
    %cst_1_i32 = constant <i32: 1> : tile<i32>
    %1 = addi %blockId_x, %cst_1_i32 : tile<i32>
    %pview = make_partition_view %tview_4 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>
    %cst_1_i32_9 = constant <i32: 1> : tile<i32>
    %2 = subi %1, %cst_1_i32_9 : tile<i32>
    %tile, %result_token = load_view_tko weak %pview[%2] token = %0 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> tile<16xf32>, token
    %pview_10 = make_partition_view %tview_8 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>
    %cst_1_i32_11 = constant <i32: 1> : tile<i32>
    %3 = subi %1, %cst_1_i32_11 : tile<i32>
    %tile_12, %result_token_13 = load_view_tko weak %pview_10[%3] token = %0 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> tile<16xf32>, token
    %4 = addf %tile, %tile_12  : tile<16xf32>
    %cst_1_i32_14 = constant <i32: 1> : tile<i32>
    %5 = subi %1, %cst_1_i32_14 : tile<i32>
    %pview_15 = make_partition_view %tview : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>
    %6 = store_view_tko weak %4, %pview_15[%5] token = %0 : tile<16xf32>, partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> token
    return
  }
}
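The addi/subi pattern in this listing can be reproduced with a toy emitter. This is a minimal Python sketch for illustration only: the `Builder` class and its `bid` and `load` methods are hypothetical stand-ins, not cuTile's code generator. It assumes a 1-based block index is the 0-based hardware id plus one, and that each view access subtracts one again.

```python
# Toy illustration only: a tiny IR builder, not cuTile's actual code generator.
from itertools import count


class Builder:
    def __init__(self):
        self.ops = []          # emitted instructions, as text
        self._ids = count(1)   # fresh SSA value numbering

    def emit(self, text):
        name = f"%{next(self._ids)}"
        self.ops.append(f"{name} = {text}")
        return name

    def bid(self):
        # 1-based block index: the 0-based hardware id plus one.
        one = self.emit("constant 1")
        return self.emit(f"addi %blockId_x, {one}")

    def load(self, view, index):
        # Views are indexed 0-based, so every access subtracts one again.
        one = self.emit("constant 1")
        zero_based = self.emit(f"subi {index}, {one}")
        return self.emit(f"load {view}[{zero_based}]")


b = Builder()
i = b.bid()
b.load("%pview_a", i)
b.load("%pview_b", i)
print("\n".join(b.ops))
```

Two loads that share one index already cost one addi, two subi, and three constants, which is the shape visible in the vadd listing.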

#155 made it so that the subi happens later, and once per index, reducing the number of redundant operations:

--- /tmp/old    2026-03-30 18:15:00.159164458 +0200
+++ /tmp/normalize      2026-03-30 18:15:27.533894683 +0200
@@ -18,18 +18,12 @@
     %cst_1_i32 = constant <i32: 1> : tile<i32>
     %1 = addi %blockId_x, %cst_1_i32 : tile<i32>
     %pview = make_partition_view %tview_4 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>
-    %cst_1_i32_9 = constant <i32: 1> : tile<i32>
-    %2 = subi %1, %cst_1_i32_9 : tile<i32>
-    %tile, %result_token = load_view_tko weak %pview[%2] token = %0 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> tile<16xf32>, token
-    %pview_10 = make_partition_view %tview_8 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>
-    %cst_1_i32_11 = constant <i32: 1> : tile<i32>
-    %3 = subi %1, %cst_1_i32_11 : tile<i32>
-    %tile_12, %result_token_13 = load_view_tko weak %pview_10[%3] token = %0 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> tile<16xf32>, token
-    %4 = addf %tile, %tile_12  : tile<16xf32>
-    %cst_1_i32_14 = constant <i32: 1> : tile<i32>
-    %5 = subi %1, %cst_1_i32_14 : tile<i32>
-    %pview_15 = make_partition_view %tview : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>
-    %6 = store_view_tko weak %4, %pview_15[%5] token = %0 : tile<16xf32>, partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> token
+    %tile, %result_token = load_view_tko weak %pview[%1] token = %0 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> tile<16xf32>, token
+    %pview_9 = make_partition_view %tview_8 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>
+    %tile_10, %result_token_11 = load_view_tko weak %pview_9[%1] token = %0 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> tile<16xf32>, token
+    %2 = addf %tile, %tile_10  : tile<16xf32>
+    %pview_12 = make_partition_view %tview : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>
+    %3 = store_view_tko weak %2, %pview_12[%1] token = %0 : tile<16xf32>, partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> token
     return
   }
 }

However, the addi/subi pairs remained; they were just less numerous. This PR instead reworks the IR rewriter to support algebraic simplifications that allow us to eliminate the addi/subi pairs entirely, resulting in much cleaner IR:

@@ -15,21 +15,13 @@
     %tview_8 = make_tensor_view %assume_5, shape = [%assume_assume_7], strides = [1] : tile<i32> -> tensor_view<?xf32, strides=[1]>
     %0 = make_token : token
     %blockId_x, %blockId_y, %blockId_z = get_tile_block_id : tile<i32>
-    %cst_1_i32 = constant <i32: 1> : tile<i32>
-    %1 = addi %blockId_x, %cst_1_i32 : tile<i32>
     %pview = make_partition_view %tview_4 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>
-    %cst_1_i32_9 = constant <i32: 1> : tile<i32>
-    %2 = subi %1, %cst_1_i32_9 : tile<i32>
-    %tile, %result_token = load_view_tko weak %pview[%2] token = %0 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> tile<16xf32>, token
-    %pview_10 = make_partition_view %tview_8 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>
-    %cst_1_i32_11 = constant <i32: 1> : tile<i32>
-    %3 = subi %1, %cst_1_i32_11 : tile<i32>
-    %tile_12, %result_token_13 = load_view_tko weak %pview_10[%3] token = %0 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> tile<16xf32>, token
-    %4 = addf %tile, %tile_12  : tile<16xf32>
-    %cst_1_i32_14 = constant <i32: 1> : tile<i32>
-    %5 = subi %1, %cst_1_i32_14 : tile<i32>
-    %pview_15 = make_partition_view %tview : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>
-    %6 = store_view_tko weak %4, %pview_15[%5] token = %0 : tile<16xf32>, partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> token
+    %tile, %result_token = load_view_tko weak %pview[%blockId_x] token = %0 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> tile<16xf32>, token
+    %pview_9 = make_partition_view %tview_8 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>
+    %tile_10, %result_token_11 = load_view_tko weak %pview_9[%blockId_x] token = %0 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> tile<16xf32>, token
+    %1 = addf %tile, %tile_10  : tile<16xf32>
+    %pview_12 = make_partition_view %tview : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>
+    %2 = store_view_tko weak %1, %pview_12[%blockId_x] token = %0 : tile<16xf32>, partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> token
     return
   }
 }

maleadt and others added 4 commits March 30, 2026 15:29
Adds an algebra_pass! with two rewrite rules that cancel inverse
addi/subi pairs (x+c-c → x, x-c+c → x). This eliminates the redundant
subi instructions generated by the 1-based to 0-based index conversion
in load/store operations (e.g. bid() + One() - One()).

Two supporting changes to the rewrite framework:

- DefEntry no longer caches a stale operands copy; pattern matching now
  reads live operands from the instruction via resolve_call, so bindings
  reflect updates from prior rewrites within the same pass.

- RBind consumed tracking only marks the root instruction, leaving shared
  intermediates matchable by subsequent rules (e.g. a single addi used
  by multiple subi sites).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
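The two cancellation rules from this commit can be sketched on a toy SSA IR. This is a hedged Python illustration, not the actual `algebra_pass!`: the dict-based IR and the helper names (`const_value`, `cancel_inverse`, `algebra_pass`) are assumptions made for the example. It reads operands from the live IR on every match and deletes only the root instruction, so one addi shared by several subi sites stays matchable.

```python
# Toy SSA IR: name -> (opcode, operands). Illustration only, not the cuTile rewriter.
INVERSE = {"subi": "addi", "addi": "subi"}


def const_value(ir, name):
    op, args = ir.get(name, (None, ()))
    return args[0] if op == "const" else None


def cancel_inverse(ir, name):
    """Return x if `name` computes (x + c) - c or (x - c) + c, else None."""
    op, args = ir[name]
    if op not in INVERSE:
        return None
    inner, c_outer = args
    # Live lookup of the producer: no cached operand copy that could go stale.
    inner_op, inner_args = ir.get(inner, (None, ()))
    if inner_op != INVERSE[op]:
        return None
    x, c_inner = inner_args
    a, b = const_value(ir, c_outer), const_value(ir, c_inner)
    return x if a is not None and a == b else None


def algebra_pass(ir):
    for name in list(ir):
        x = cancel_inverse(ir, name)
        if x is None:
            continue
        # Redirect every use of the root to x, then drop only the root;
        # the inner addi may still feed other subi sites.
        for user, (op, args) in list(ir.items()):
            ir[user] = (op, tuple(x if a == name else a for a in args))
        del ir[name]
    return ir


ir = {
    "c1": ("const", (1,)),
    "i": ("addi", ("bid", "c1")),
    "c2": ("const", (1,)),
    "s1": ("subi", ("i", "c2")),
    "ld1": ("load", ("pview_a", "s1")),
    "c3": ("const", (1,)),
    "s2": ("subi", ("i", "c3")),
    "ld2": ("load", ("pview_b", "s2")),
}
algebra_pass(ir)
print(ir["ld1"], ir["ld2"])
```

Both loads end up indexed by `bid` directly. The now-unused addi and constants are left behind in this sketch; removing them is a separate dead-code step.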
Replace the single-pass linear scan driver with a LIFO worklist that
processes until fixpoint, inspired by MLIR's GreedyPatternRewriteDriver.

Key changes:
- No stale MatchContext: DefEntry reads live operands from the IR via
  resolve_call. Use counts computed on-demand via uses() — always fresh.
- Worklist with notifications: when a rewrite fires, affected instructions
  (users, operand-producers) are re-added to the worklist, enabling
  cascading rewrites across rule sets.
- Unified rule set: all rewrite rules (normalize, algebra, SVE, FMA) run
  in a single fixpoint invocation instead of separate passes.
- Trivial dead-op elimination on worklist pop: keeps use counts accurate
  for one_use patterns (e.g. FMA fusion after SVE removes transparent
  op chains). Full DCE still runs after for complex dead code.
- Safe intermediate deletion: substitution rewrites only delete matched
  intermediates that have no remaining uses, fixing a bug where
  transparent-op tracing could add multi-use intermediates to matched_ssas.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
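The worklist driver described in this commit can be sketched as follows. This is a simplified Python model under stated assumptions, not the real driver: the dict-based IR, the `KEEP` set of side-effecting ops, and the two toy rules (`cancel_inverse`, `add_zero`) are hypothetical. It shows a LIFO worklist, on-demand use lookup, re-queuing of users and operand producers after a rewrite, and trivial dead-op elimination on pop.

```python
# Toy SSA IR: name -> (opcode, operands). Hypothetical sketch, not the cuTile driver.
INVERSE = {"subi": "addi", "addi": "subi"}
KEEP = {"store"}  # ops with side effects are never treated as dead


def const_value(ir, name):
    op, args = ir.get(name, (None, ()))
    return args[0] if op == "const" else None


def users(ir, val):
    # Computed on demand from the live IR, so use information is never stale.
    return [n for n, (_, args) in ir.items() if val in args]


def cancel_inverse(ir, name):
    op, args = ir[name]
    if op not in INVERSE:
        return None
    inner_op, inner_args = ir.get(args[0], (None, ()))
    if inner_op != INVERSE[op]:
        return None
    c = const_value(ir, args[1])
    same = c is not None and c == const_value(ir, inner_args[1])
    return inner_args[0] if same else None


def add_zero(ir, name):
    op, args = ir[name]
    return args[0] if op == "addi" and const_value(ir, args[1]) == 0 else None


def run_to_fixpoint(ir, rules):
    worklist = list(ir)
    while worklist:
        name = worklist.pop()  # LIFO
        if name not in ir:
            continue
        op, args = ir[name]
        if op not in KEEP and not users(ir, name):
            # Trivial dead-op elimination; its producers may now be dead too.
            del ir[name]
            worklist.extend(a for a in args if a in ir)
            continue
        for rule in rules:
            repl = rule(ir, name)
            if repl is None:
                continue
            affected = users(ir, name)
            for u in affected:
                uop, uargs = ir[u]
                ir[u] = (uop, tuple(repl if a == name else a for a in uargs))
            del ir[name]
            # Notify: users see a new operand, producers may have lost a use.
            worklist.extend(affected)
            worklist.extend(a for a in args if a in ir)
            if repl in ir:
                worklist.append(repl)
            break
    return ir


ir = {
    "c1": ("const", (1,)),
    "i": ("addi", ("bid", "c1")),
    "c2": ("const", (1,)),
    "s": ("subi", ("i", "c2")),
    "z": ("const", (0,)),
    "u": ("addi", ("s", "z")),
    "st": ("store", ("u",)),
}
run_to_fixpoint(ir, [cancel_inverse, add_zero])
print(ir)
```

In this run `add_zero` fires first on `u`, which exposes `s` to `cancel_inverse`, and the dead constants and addi are then dropped as they are popped. A single linear scan would need separate passes to get the same result.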
Replace the manual defs-scanning workaround in _add_users_to_worklist!
with the new users(block, val) API from IRStructurizer.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The Worklist, DefEntry, defs dict, matched_ssas, and all notification/apply
functions now use SSAValue instead of Int, eliminating the same kind of
type confusion between IR references and literal integers that was fixed
in IRStructurizer's normalize_key.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
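The kind of confusion this commit guards against can be shown with a small Python sketch. The `SSAValue` dataclass and `substitute` helper here are illustrative stand-ins, not the Julia types used in the PR; the assumption is only that an SSA reference and a literal integer can appear side by side in an operand list.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SSAValue:
    """Reference to the result of instruction number `id` (toy stand-in)."""
    id: int


def substitute(operands, old, new):
    # Replace every operand equal to `old` with `new`.
    return tuple(new if op == old else op for op in operands)


# Operands of "subi %1, 1" with references stored as bare ints:
# the reference to %1 and the literal 1 are indistinguishable.
bare = substitute((1, 1), 1, 7)

# The same operands with a distinct reference type: only the reference changes.
typed = substitute((SSAValue(1), 1), SSAValue(1), SSAValue(7))

print(bare)
print(typed)
```

With bare integers the literal `1` is rewritten along with the reference; with a dedicated type the literal survives, because an `SSAValue` never compares equal to an `int`.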
@maleadt maleadt merged commit 2e0abf9 into main Mar 30, 2026
9 checks passed
@maleadt maleadt deleted the tb/algebraic_simplfications branch March 30, 2026 16:36