Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
4efd550
Add ixsimpl roundtrip helper and clean up debug prints
Hardcode84 Mar 15, 2026
452de1c
Replace sympy.simplify with ixs_simplify in graph_utils.py
Hardcode84 Mar 15, 2026
a8ed648
Replace sympy.simplify with ixs_simplify in index equality checks
Hardcode84 Mar 15, 2026
bfeb8c3
Replace sympy.simplify with ixs_simplify in preshuffle pattern detection
Hardcode84 Mar 15, 2026
750e381
Replace sympy.simplify with ixs_simplify in ensure_symbols_positive
Hardcode84 Mar 15, 2026
7855a38
Replace sympy.simplify with ixs_simplify for multicast mask in read_w…
Hardcode84 Mar 15, 2026
811a416
Replace sympy.simplify with ixs_simplify in water_emitter.py
Hardcode84 Mar 15, 2026
8eedd61
Replace sympy.simplify with ixs_simplify in multicast mask computation
Hardcode84 Mar 15, 2026
a3260e5
Replace custom simplify() wrapper with ixs_simplify primary path
Hardcode84 Mar 15, 2026
931965b
Update tests for ixsimpl canonical forms
Hardcode84 Mar 15, 2026
93f35bc
Add ixsimpl as git dependency pinned to specific commit
Hardcode84 Mar 19, 2026
1f7abb2
Make ixsimpl context thread-local and differentiate conversion errors
Hardcode84 Mar 19, 2026
e7d37d6
update sha
Hardcode84 Mar 19, 2026
4d41327
update sha
Hardcode84 Mar 20, 2026
190c20b
Fix nested _Rational in Pow codegen for negative exponents
Hardcode84 Mar 20, 2026
3828314
Add sympy-to-ixsimpl migration plan
Hardcode84 Mar 24, 2026
f754280
Update migration plan for new ixsimpl API (lambdify, has, eval, check)
Hardcode84 Mar 25, 2026
82fdcdd
update sha
Hardcode84 Mar 31, 2026
5d0dce1
update sha
Hardcode84 Mar 31, 2026
9317a4d
Remove ixsimpl git dep and update lit tests for new simplification forms
Hardcode84 Apr 7, 2026
6bc4fa1
Use ixsimpl simplify in annotate_iv_strides tests
Hardcode84 Apr 7, 2026
0c04e26
fix test
Hardcode84 Apr 7, 2026
c32eb4e
Add regression test for shifted Mod simplification gap
Hardcode84 Apr 9, 2026
024395c
Vendor newer ixsimpl to tighten Mod bounds
Hardcode84 Apr 9, 2026
7138f64
Update persistent gemm lit check for ixsimpl simplification
Hardcode84 Apr 9, 2026
099e6ac
pre-commit
Hardcode84 Apr 9, 2026
22a9ea1
Revert "Update persistent gemm lit check for ixsimpl simplification"
Hardcode84 Apr 9, 2026
c1f8d92
fix test
Hardcode84 Apr 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
375 changes: 375 additions & 0 deletions docs/sympy-to-ixsimpl-migration.md

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions lit_tests/kernel/wave/gather_to_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
print(scaled_gemm.asm)

# CHECK-LABEL: test_gather_to_shared_scaled_dims
# CHECK-DAG: #[[MAP_COL:.*]] = affine_map<()[s0] -> ((s0 floordiv 8) mod 8)>
# CHECK-DAG: #[[MAP_COL:.*]] = affine_map<()[s0] -> (s0 floordiv 8 - (s0 floordiv 64) * 8)>
# CHECK-DAG: #[[MAP_ROW:.*]] = affine_map<()[s0] -> (s0 mod 8)>
# CHECK-DAG: #[[MAP_COL_SCALE:.*]] = affine_map<()[s0] -> ((s0 floordiv 2) mod 2)>
# CHECK-DAG: #[[MAP_COL_SCALE:.*]] = affine_map<()[s0] -> (s0 floordiv 2 - (s0 floordiv 4) * 2)>
# CHECK-DAG: #[[MAP_ROW_SCALE:.*]] = affine_map<()[s0] -> (s0 mod 2)>
# CHECK-DAG: #[[MAP_ROW_SWIZZLED:.*]] = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)>
# CHECK-DAG: #[[MAP_ROW_SWIZZLED_2:.*]] = affine_map<()[s0] -> ((s0 mod 64) floordiv 16 + 4)>
Expand Down
2 changes: 1 addition & 1 deletion lit_tests/kernel/wave/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:

# CHECK-LABEL: test_reordered_gemm
# CHECK-DAG: #[[MAP_LIN_A:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * 65536 + s2 * 16384 + s3 * 8 + s4 * 32 + (s1 floordiv 4) * 32768 + (s3 floordiv 4) * 480 - ((s2 * 32 + s3 floordiv 4) floordiv 64) * 32768 - ((s0 * 8 + s1) floordiv 32) * 262144)>
# CHECK-DAG: #[[MAP_LIN_B:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * 32768 + s2 * 16384 + s3 * 8 + s4 * 32 + ((s0 + s1 * 8) floordiv 32) * 131072 + (s3 floordiv 4) * 480 - ((s2 * 32 + s3 floordiv 4) floordiv 64) * 32768 - (s0 floordiv 4) * 131072)>
# CHECK-DAG: #[[MAP_LIN_B:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * 32768 + s2 * 16384 + s3 * 8 + s4 * 32 + ((s0 + s1 * 8) floordiv 32) * 131072 - (s0 floordiv 4) * 131072 + (s3 floordiv 4) * 480 - ((s2 * 32 + s3 floordiv 4) floordiv 64) * 32768)>
# CHECK-DAG: affine.apply #[[MAP_LIN_A]]()[%block_id_y, %block_id_x, %thread_id_y, %thread_id_x, {{.*}}]
# CHECK-DAG: affine.apply #[[MAP_LIN_B]]()[%block_id_x, %block_id_y, %thread_id_y, %thread_id_x, {{.*}}]
# CHECK-DAG: vector.load {{.*}} : memref<{{.*}}xf16, strided<[1]>>, vector<8xf16>
Expand Down
13 changes: 4 additions & 9 deletions lit_tests/kernel/wave/mma.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,15 +635,10 @@ def mma(
### make DMA base
# CHECK: %[[DMA_BASE0:.+]] = amdgpu.make_dma_base {{.*}}, %[[VIEW1]][{{.*}}]

# Cluster mask generation
# CHECK: %[[COND0:.*]] = arith.cmpi eq, %{{.*}}, %{{.*}} : index
# CHECK: %[[COND1:.*]] = arith.cmpi eq, %{{.*}}, %{{.*}} : index
# CHECK: %[[COND2:.*]] = arith.cmpi eq, %{{.*}}, %{{.*}} : index
# CHECK: %[[COND3:.*]] = arith.cmpi eq, %{{.*}}, %{{.*}} : index
# CHECK: %[[MASK1:.*]] = arith.select %[[COND3]], %{{.*}}, %[[C0]] : index
# CHECK: %[[MASK2:.*]] = arith.select %[[COND2]], %{{.*}}, %[[MASK1]] : index
# CHECK: %[[MASK3:.*]] = arith.select %[[COND1]], %{{.*}}, %[[MASK2]] : index
# CHECK: %[[MASK4:.*]] = arith.select %[[COND0]], %{{.*}}, %[[MASK3]] : index
# Cluster mask generation -- ixsimpl flattens the Piecewise into
# individual selects per condition rather than a cascaded chain.
# CHECK: arith.cmpi eq, %{{.*}}, %{{.*}} : index
# CHECK: arith.select %{{.*}}, %{{.*}}, %{{.*}} : index

# CHECK: %[[TENSOR_DESC_0:.*]] = amdgpu.make_dma_descriptor %[[DMA_BASE0:.+]] globalSize [%{{.*}}, %{{.*}}] globalStride [32, 1] sharedSize [%{{.*}}, %{{.*}}] padShared({{.*}}) workgroupMask %{{.*}}

Expand Down
2 changes: 1 addition & 1 deletion tests/kernel/wave_gemm_mxfp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path):
else:
vgpr_count = 164
vgpr_spill_count = 0
sgpr_count = 60
sgpr_count = 61
sgpr_spill_count = 0
waitcounts = [
"s_waitcnt lgkmcnt(0)",
Expand Down
2 changes: 1 addition & 1 deletion tests/kernel/wave_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_divide_shape_into_chunks():
def test_custom_sympy_simplifications():
a = sympy.Symbol("a", integer=True, nonnegative=True)
mod_expr = (sympy.floor(a) * 4 + 3) % 16
assert str(simplify(mod_expr)) == "4*(Mod(a, 4)) + 3"
assert str(simplify(mod_expr)) == "Mod(4*a, 16) + 3"

floor_expr = sympy.floor(sympy.floor(a) / 3 + sympy.sympify(1) / 6)
assert str(simplify(floor_expr)) == "floor(a/3)"
Expand Down
7 changes: 7 additions & 0 deletions tests/unittests/symbol_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ def test_simplify_floor_of_bounded_fraction(simplify_fn):
assert simplify_fn(expr) == 0


def test_simplify_floor_of_shifted_bounded_fraction():
# floor(Mod(4*x + 3,16)/16) should resolve to 0.
x = _sym("x")
expr = sympy.floor(sympy.Mod(4 * x + 3, 16, evaluate=False) / 16)
assert simplify(expr) == 0


def test_simplify_mod_elimination():
# Mod(Mod(x,8), 16) -> Mod(x,8) since range [0,7] < 16.
x = _sym("x")
Expand Down
9 changes: 5 additions & 4 deletions tests/unittests/test_annotate_iv_strides.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
_try_with_div_subs,
)
from wave_lang.kernel.wave.assumptions import Assumption, get_divisibility_subs
from wave_lang.kernel.wave.utils.symbol_utils import simplify


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -425,7 +426,7 @@ def test_base_preserves_original_structure(self):
assert result is not None
base, _, _ = result
expected_base = flat.subs({iv: sympy.Integer(0)})
assert sympy.simplify(base - expected_base) == 0
assert simplify(base - expected_base) == 0

def test_method_string_contains_divsubs(self):
"""The returned method string indicates div_subs was used."""
Expand Down Expand Up @@ -580,7 +581,7 @@ def test_new_offset_matches_original(self):
base, stride, _ = result
per_unit = stride // step
new_offset = base + iv * per_unit
assert sympy.simplify(new_offset - flat) == 0
assert simplify(new_offset - flat) == 0

def test_new_offset_matches_preshuffle(self):
"""Verify base + iv * per_unit equals original (preshuffle case)."""
Expand All @@ -591,7 +592,7 @@ def test_new_offset_matches_preshuffle(self):
base, stride, _ = result
per_unit = stride // step
new_offset = base + iv * per_unit
assert sympy.simplify(new_offset - flat) == 0
assert simplify(new_offset - flat) == 0

def test_b_scale_full_pipeline(self):
"""B-scale end-to-end: K=8192, step=2, per_unit = 256."""
Expand All @@ -605,4 +606,4 @@ def test_b_scale_full_pipeline(self):
assert method == "symbolic"
# Rewrite should match original.
new_offset = base + iv * per_unit
assert sympy.simplify(new_offset - flat) == 0
assert simplify(new_offset - flat) == 0
3 changes: 3 additions & 0 deletions third_party/ixsimpl/DESIGN.md
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,9 @@ This enables rules like:
- `floor(x/64)` where `0 <= x < 64` → `0`
- `floor(x)` → constant when `floor(lo) == floor(hi)` (same for ceiling)
- `Mod(x, m)` bounds tightened to dividend's bounds when `0 <= x < m`
- `Mod(x, m)` upper bound tightened to `m - gcd(d, m)` when `x` is
integer-valued and `d` is the gcd of its top-level integer coefficients
(e.g. `Mod(4*a, 16)` in `[0, 12]` instead of `[0, 15]`)
- `Max(1, expr)` where `expr >= 1` → `expr`

**Congruence-gated rewrites** (requires `Mod(sym, M) == R` assumption):
Expand Down
70 changes: 61 additions & 9 deletions third_party/ixsimpl/ixsimpl_amalg.c
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,34 @@ IXS_STATIC void ixs_bounds_add_assumption(ixs_bounds *b, ixs_node *a) {
}
}

/* Conservative positive divisor of expr's value when integer-valued:
* MUL — absolute coefficient; ADD — gcd of constant and all term
* coefficients; everything else — 1. Deliberately ignores nested
* structure; the result is a lower bound on the true step. */
static int64_t mod_dividend_step(ixs_node *expr) {
int64_t p, q, g;
uint32_t i;
switch (expr->tag) {
case IXS_MUL:
ixs_node_get_rat(expr->u.mul.coeff, &p, &q);
return (q == 1) ? ixs_gcd(p, 0) : 1;
case IXS_ADD:
ixs_node_get_rat(expr->u.add.coeff, &p, &q);
if (q != 1)
return 1;
g = ixs_gcd(p, 0);
for (i = 0; i < expr->u.add.nterms; i++) {
ixs_node_get_rat(expr->u.add.terms[i].coeff, &p, &q);
if (q != 1)
return 1;
g = ixs_gcd(g, p);
}
return (g > 0) ? g : 1;
default:
return 1;
}
}

static ixs_interval bounds_get_propagated(ixs_bounds *b, ixs_node *expr) {
uint32_t i;
if (!expr)
Expand Down Expand Up @@ -571,15 +599,21 @@ static ixs_interval bounds_get_propagated(ixs_bounds *b, ixs_node *expr) {
case IXS_MOD: {
/* Mod(x, m) in [0, m-1] only when x is integer-valued and m is a
* positive integer. For non-integer dividends the range is the
* half-open [0, m) which we cannot represent tightly. */
* half-open [0, m) which we cannot represent tightly.
*
* Tighter: if x is always a multiple of d, Mod(x, m) is a multiple
* of gcd(d, m), so the upper bound drops to m - gcd(d, m). */
ixs_node *m = expr->u.binary.rhs;
if (m->tag == IXS_INT && m->u.ival > 0) {
ixs_interval pi = ixs_bounds_get(b, expr->u.binary.lhs);
if (pi.valid && pi.lo_q == 1 && pi.hi_q == 1 && pi.lo_p >= 0 &&
pi.hi_p < m->u.ival)
return pi;
if (ixs_node_is_integer_valued(expr->u.binary.lhs))
return ixs_interval_range(0, 1, m->u.ival - 1, 1);
if (ixs_node_is_integer_valued(expr->u.binary.lhs)) {
int64_t step = mod_dividend_step(expr->u.binary.lhs);
int64_t g = ixs_gcd(step, m->u.ival);
return ixs_interval_range(0, 1, m->u.ival - g, 1);
}
}
return ixs_interval_unknown();
}
Expand Down Expand Up @@ -4428,32 +4462,50 @@ static ixs_node *recognize_mod(ixs_ctx *ctx, ixs_addterm *terms,
int64_t const_q) {
uint32_t i;
int rc1, rc2;
ixs_node *result;
ixs_addterm *snap = NULL;

/* Snapshot terms so we can roll back if a pass hits OOM after
* partially rewriting entries (NULLing matched floor/ceil terms
* and replacing their partners with Mod nodes). */
if (nterms > 0) {
snap =
ixs_arena_alloc(&ctx->scratch, nterms * sizeof(*snap), sizeof(void *));
if (!snap)
return NULL;
memcpy(snap, terms, nterms * sizeof(*terms));
}

rc1 = recognize_mod_const_div(ctx, terms, nterms);
if (rc1 < 0)
return NULL;
goto rollback;
rc2 = recognize_mod_sym_div(ctx, terms, nterms);
if (rc2 < 0)
return NULL;
goto rollback;
if (!rc1 && !rc2)
return NULL;

IXS_STAT_HIT(ctx);
ixs_node *result = make_const(ctx, const_p, const_q);
result = make_const(ctx, const_p, const_q);
if (!result)
return NULL;
goto rollback;
for (i = 0; i < nterms; i++) {
ixs_node *t;
if (!terms[i].term)
continue;
t = simp_mul(ctx, terms[i].coeff, terms[i].term);
if (!t)
return NULL;
goto rollback;
result = simp_add(ctx, result, t);
if (!result)
return NULL;
goto rollback;
}
return result;

rollback:
if (snap)
memcpy(terms, snap, nterms * sizeof(*terms));
return NULL;
}

/*
Expand Down
40 changes: 37 additions & 3 deletions third_party/ixsimpl/src/bounds.c
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,34 @@ IXS_STATIC void ixs_bounds_add_assumption(ixs_bounds *b, ixs_node *a) {
}
}

/* Conservative positive divisor of expr's value when integer-valued:
* MUL — absolute coefficient; ADD — gcd of constant and all term
* coefficients; everything else — 1. Deliberately ignores nested
* structure; the result is a lower bound on the true step. */
static int64_t mod_dividend_step(ixs_node *expr) {
int64_t p, q, g;
uint32_t i;
switch (expr->tag) {
case IXS_MUL:
ixs_node_get_rat(expr->u.mul.coeff, &p, &q);
return (q == 1) ? ixs_gcd(p, 0) : 1;
case IXS_ADD:
ixs_node_get_rat(expr->u.add.coeff, &p, &q);
if (q != 1)
return 1;
g = ixs_gcd(p, 0);
for (i = 0; i < expr->u.add.nterms; i++) {
ixs_node_get_rat(expr->u.add.terms[i].coeff, &p, &q);
if (q != 1)
return 1;
g = ixs_gcd(g, p);
}
return (g > 0) ? g : 1;
default:
return 1;
}
}

static ixs_interval bounds_get_propagated(ixs_bounds *b, ixs_node *expr) {
uint32_t i;
if (!expr)
Expand Down Expand Up @@ -420,15 +448,21 @@ static ixs_interval bounds_get_propagated(ixs_bounds *b, ixs_node *expr) {
case IXS_MOD: {
/* Mod(x, m) in [0, m-1] only when x is integer-valued and m is a
* positive integer. For non-integer dividends the range is the
* half-open [0, m) which we cannot represent tightly. */
* half-open [0, m) which we cannot represent tightly.
*
* Tighter: if x is always a multiple of d, Mod(x, m) is a multiple
* of gcd(d, m), so the upper bound drops to m - gcd(d, m). */
ixs_node *m = expr->u.binary.rhs;
if (m->tag == IXS_INT && m->u.ival > 0) {
ixs_interval pi = ixs_bounds_get(b, expr->u.binary.lhs);
if (pi.valid && pi.lo_q == 1 && pi.hi_q == 1 && pi.lo_p >= 0 &&
pi.hi_p < m->u.ival)
return pi;
if (ixs_node_is_integer_valued(expr->u.binary.lhs))
return ixs_interval_range(0, 1, m->u.ival - 1, 1);
if (ixs_node_is_integer_valued(expr->u.binary.lhs)) {
int64_t step = mod_dividend_step(expr->u.binary.lhs);
int64_t g = ixs_gcd(step, m->u.ival);
return ixs_interval_range(0, 1, m->u.ival - g, 1);
}
}
return ixs_interval_unknown();
}
Expand Down
30 changes: 24 additions & 6 deletions third_party/ixsimpl/src/simplify.c
Original file line number Diff line number Diff line change
Expand Up @@ -619,32 +619,50 @@ static ixs_node *recognize_mod(ixs_ctx *ctx, ixs_addterm *terms,
int64_t const_q) {
uint32_t i;
int rc1, rc2;
ixs_node *result;
ixs_addterm *snap = NULL;

/* Snapshot terms so we can roll back if a pass hits OOM after
* partially rewriting entries (NULLing matched floor/ceil terms
* and replacing their partners with Mod nodes). */
if (nterms > 0) {
snap =
ixs_arena_alloc(&ctx->scratch, nterms * sizeof(*snap), sizeof(void *));
if (!snap)
return NULL;
memcpy(snap, terms, nterms * sizeof(*terms));
}

rc1 = recognize_mod_const_div(ctx, terms, nterms);
if (rc1 < 0)
return NULL;
goto rollback;
rc2 = recognize_mod_sym_div(ctx, terms, nterms);
if (rc2 < 0)
return NULL;
goto rollback;
if (!rc1 && !rc2)
return NULL;

IXS_STAT_HIT(ctx);
ixs_node *result = make_const(ctx, const_p, const_q);
result = make_const(ctx, const_p, const_q);
if (!result)
return NULL;
goto rollback;
for (i = 0; i < nterms; i++) {
ixs_node *t;
if (!terms[i].term)
continue;
t = simp_mul(ctx, terms[i].coeff, terms[i].term);
if (!t)
return NULL;
goto rollback;
result = simp_add(ctx, result, t);
if (!result)
return NULL;
goto rollback;
}
return result;

rollback:
if (snap)
memcpy(terms, snap, nterms * sizeof(*terms));
return NULL;
}

/*
Expand Down
Loading
Loading