diff --git a/docs/sympy-to-ixsimpl-migration.md b/docs/sympy-to-ixsimpl-migration.md new file mode 100644 index 0000000000..c3867be7f0 --- /dev/null +++ b/docs/sympy-to-ixsimpl-migration.md @@ -0,0 +1,375 @@ +# Sympy to ixsimpl Migration Plan + +Status: **In Progress** (Phase 1 partially complete) + +## Current State + +**48 files** under `wave_lang/` import sympy directly. ixsimpl currently serves +as the simplification backend via roundtrip conversion in `symbol_utils.py`. +Everything else -- symbol creation, expression construction, type checking, +substitution, codegen, Piecewise logic, numeric probing -- still goes through +sympy despite ixsimpl having native support for most of these operations. + +The integration point is `wave_lang/kernel/wave/utils/symbol_utils.py`: +- `simplify(expr)` tries ixsimpl first, falls back to sympy expand/cancel loop. +- `ixs_simplify(expr)` does the roundtrip, returns original on conversion error. +- All 12 call sites go through these two functions. No code calls ixsimpl directly. + +## ixsimpl API Surface + +Source: `ixsimpl/_ixsimpl.pyi`, `ixsimpl/__init__.py`, `ixsimpl/sympy_conv.py`. + +### Expression types (tag-based dispatch) +`INT`, `RAT`, `SYM`, `ADD`, `MUL`, `FLOOR`, `CEIL`, `MOD`, `PIECEWISE`, `MAX`, +`MIN`, `XOR`, `CMP`, `AND`, `OR`, `NOT`, `TRUE`, `FALSE` + +### Expr methods +| Method | Purpose | +|--------|---------| +| `.tag` | Node type discriminator (replaces `isinstance(expr, sympy.X)`) | +| `.nchildren` / `.children` / `.child(i)` | Generic structural access (replaces `.args`) | +| `.sym_name` | Symbol name (replaces `str(sym)`) | +| `.rat_num` / `.rat_den` | Rational numerator/denominator (replaces `.p`/`.q`) | +| `.add_coeff` / `.add_nterms` / `.add_term(i)` / `.add_term_coeff(i)` | Add decomposition | +| `.mul_coeff` / `.mul_nfactors` / `.mul_factor_base(i)` / `.mul_factor_exp(i)` | Mul decomposition | +| `.pw_ncases` / `.pw_value(i)` / `.pw_cond(i)` | Piecewise branch access | +| `.cmp_op` | Comparison operator kind | +| `.subs(target, replacement)` / `.subs(mapping)` | Substitution (replaces `.subs()`/`.xreplace()`) | +| `.simplify(assumptions=)` | Simplification with assumption constraints | +| `.expand()` | Expression expansion | +| `.to_c()` | C code generation | +| `.free_symbols` | Free symbol set (cached property on `Expr` subclass) | +| `.has(sym)` | True if `sym` appears in the expression tree (replaces `sympy_expr.has()`) | +| `.eval(env)` | Evaluate with concrete values (replaces `expr.subs(dict)` + `int()`) | +| `.is_error` / `.is_parse_error` / `.is_domain_error` | Error detection | +| Arithmetic operators | `+`, `-`, `*`, `/`, `>=`, `>`, `<=`, `<`, `==`, `!=` | + +### Context methods +| Method | Purpose | +|--------|---------| +| `ctx.sym(name)` | Create symbol | +| `ctx.int_(val)` / `ctx.rat(p, q)` | Create constants | +| `ctx.true_()` / `ctx.false_()` | Boolean constants | +| `ctx.eq(a, b)` / `ctx.ne(a, b)` | Equality/inequality nodes | +| `ctx.parse(input)` | Parse expression from string | +| `ctx.check(expr, assumptions=)` | Entailment query: returns `True`/`False`/`None` (replaces some `sympy.solve` uses) | +| `ctx.simplify_batch(exprs, assumptions=)` | Batch simplification | +| `ctx.errors` / `ctx.clear_errors()` | Error reporting | +| `ctx.stats()` / `ctx.stats_reset()` | Performance stats | + +### Free functions +`floor`, `ceil`, `mod`, `max_`, `min_`, `abs_`, `xor_`, `and_`, `or_`, `not_`, +`pw`, `same_node`, `lambdify` + +### Sympy conversion layer (`ixsimpl.sympy_conv`) +`from_sympy(ctx, expr)`, `to_sympy(expr, symbols=, xor_fn=)`, +`extract_assumptions(ctx, expr)` + +Handles: all arithmetic, floor/ceil/mod, min/max, xor, piecewise, all +comparisons, and/or/not, true/false. Custom `xor` Function subclass matched by +name. `Abs` is the only notable missing conversion (emulable via piecewise). + +## Sympy Usage Categories + +### Already migrated to ixsimpl + +| Category | Notes | +|----------|-------| +| **Simplification** | `simplify()` and `ixs_simplify()` delegate to ixsimpl. Fallback to sympy `expand`/`cancel` only on conversion failure. | + +### Not yet migrated (gap analysis) + +| Category | Sympy API Used | ixsimpl Equivalent | Actual Gap | +|----------|---------------|-------------------|------------| +| **Symbol creation** | `sympy.Symbol(name, integer=True, nonneg=True)` | `ctx.sym(name)` + `extract_assumptions` | **Soft.** ixsimpl symbols carry no assumptions intrinsically; assumptions are passed separately to `simplify()`. Wave would need to track assumptions in a side table or always extract from sympy symbols during conversion. | +| **Expression construction** | `Add`, `Mul`, `Pow`, `Integer`, `Rational`, `Mod(evaluate=False)`, `floor`, `ceiling`, `Min`, `Max`, `Abs`, `Piecewise` | Arithmetic ops, `ctx.int_()`, `ctx.rat()`, `floor`, `ceil`, `mod`, `max_`, `min_`, `abs_`, `pw` | **None.** `Abs` via `abs_()` (piecewise under the hood). `evaluate=False` not needed (ixsimpl does not eagerly evaluate). `Pow` expanded to repeated multiplication (already done in `from_sympy`). | +| **Piecewise / conditionals** | `sympy.Piecewise((val, cond), ...)` | `pw(*branches)`, `.pw_ncases`, `.pw_value(i)`, `.pw_cond(i)` | **None.** Full piecewise support. The `piecewise_aware_subs()` workaround in indexing.py exists only because sympy's `.subs()` triggers expensive boolean simplification on Piecewise -- ixsimpl's `.subs()` does not have this problem. | +| **Structural inspection** | `isinstance(expr, sympy.X)`, `.is_number`, `.is_Atom`, `.args`, `.func` | `.tag` (tag-based dispatch), `.nchildren`, `.children`, `.child(i)`, `.has()`, specialized accessors | **None.** `isinstance` dispatch maps to `expr.tag == ixsimpl.ADD` etc. `.has()` is native. `.is_number` maps to `tag in (INT, RAT)`. Specialized accessors (`.add_nterms`, `.mul_nfactors`) give richer decomposition than sympy's flat `.args`. | +| **Substitution** | `.subs()`, `.xreplace()`, `.replace(pred, fn)` | `.subs(target, repl)`, `.subs(mapping)` | **Soft.** Direct substitution is supported. Missing: `.replace(pred, fn)` (predicate-based bottom-up rewrite). Used in `_custom_simplify_once` for 4 transform passes. Would need a Python-level tree walker. | +| **Free symbol inspection** | `.free_symbols`, `.has(sym)` | `.free_symbols` (cached property), `.has(sym)` | **None.** Both are native on the `Expr` class. | +| **Expression decomposition** | `.as_ordered_terms()`, `.as_numer_denom()` | `.add_nterms`/`.add_term(i)`/`.add_term_coeff(i)`, `.rat_num`/`.rat_den` | **Soft.** Add decomposition is richer in ixsimpl (coefficient + terms). Numer/denom split is only used for Rational nodes (`.rat_num`/`.rat_den`). | +| **Numeric probing** | `sympy.lambdify()` with custom modules | `ixsimpl.lambdify()`, `Expr.eval(env)` | **None.** `lambdify` is a near drop-in replacement (uses `.subs()` + constant folding). `Expr.eval(env)` handles single-point evaluation. No custom `modules` parameter needed -- ixsimpl handles floor/Mod/etc natively. | +| **Affine conversion** | Sympy -> MLIR AffineExpr pipeline | None (but `.tag`-based walk is possible) | **Soft.** The converter does a structural walk over the expression tree. This maps directly to ixsimpl's `.tag` + `.child(i)` API. The walk structure would be nearly identical. | +| **Constraint solving** | `sympy.solve()`, `sympy.Eq()` | `ctx.eq()`, `ctx.check(expr, assumptions=)` | **None.** The sole `sympy.solve()` call is `evaluate_with_assumptions()` in `general_utils.py`, which checks whether an inequality is entailed/contradicted by a set of constraints -- returning `True`/`False`/`None`. This maps directly to `ctx.check(expr, assumptions=)`. | +| **Codegen / printing** | `lambdastr()` for grid dim lambdas | `ixsimpl.lambdify()`, `.to_c()` | **None.** `ixsimpl.lambdify()` replaces the `lambdastr()` + `eval()` pattern directly. `.to_c()` available for C codegen if needed. | +| **Type system** | `sympy.Integer`, `sympy.Rational`, `.is_Integer`, `.is_Rational` | `tag == INT`, `tag == RAT`, `.rat_num`, `.rat_den` | **None.** Tag-based dispatch replaces isinstance checks. | + +### Summary of gaps + +**Hard gaps** (no ixsimpl equivalent): +1. `.replace(pred, fn)` -- predicate-based bottom-up rewriting (4 transforms in + `_custom_simplify_once`, but these exist only as fallback when ixsimpl + conversion fails, so they may become dead code as coverage improves). + +**Soft gaps** (ixsimpl has the feature, integration work needed): +2. Symbol assumptions -- tracked separately, not on the symbol node. +3. Affine converter -- structural walk needs porting from sympy isinstance to + ixsimpl tag dispatch. + +**No gap** (ready to use): +5. Piecewise construction and inspection (`pw`, `.pw_*` accessors). +6. Free symbol inspection (`.free_symbols`, `.has()`). +7. Substitution (`.subs()`; predicate-based `.replace` excluded). +8. Structural inspection (`.tag`, `.nchildren`, `.children`, `.child(i)`). +9. Expression decomposition (`.add_*`, `.mul_*`, `.rat_*` accessors). +10. Numeric probing and evaluation (`lambdify()`, `.eval()`). +11. Absolute value (`abs_()`). +12. Entailment queries (`ctx.check()` -- replaces `sympy.solve` in + `evaluate_with_assumptions`). +13. Codegen (`lambdify()`, `.to_c()`). + +## Migration Strategy + +ixsimpl covers every sympy feature used in Wave: symbol creation, expression +construction, structural inspection (`.tag` + accessors), substitution, free +symbol queries, piecewise, evaluation (`lambdify`, `.eval()`), entailment +checking (`ctx.check()`), expansion, and C codegen. The sole `sympy.solve()` +call is an entailment check that maps directly to `ctx.check()`. The only +remaining hard gap is `.replace(pred, fn)` (predicate-based rewriting), which +lives in the sympy fallback path and may become dead code. + +**Recommended end state: ixsimpl as the sole expression IR. Sympy removed +entirely as a runtime dependency.** + +### Incremental phases + +## Phase 1: Complete simplification migration (CURRENT) + +**Goal:** All simplification goes through ixsimpl. Zero calls to +`sympy.simplify()`, `sympy.expand()`, `sympy.cancel()` outside the fallback path +in `symbol_utils.simplify()`. + +**Status:** Mostly done. Remaining direct sympy simplification calls: + +| File | Call | Action | +|------|------|--------| +| `symbol_utils.py:498,502` | `sympy.expand()`, `sympy.cancel()` in fallback | Keep -- intentional fallback for unconvertible expressions. | +| `symbol_utils.py:303` | `sympy.cancel(t / divisor)` in `split_sum_by_divisibility` | Route through ixsimpl if possible, else keep. | +| `index_mapping_simplify.py:177` | `sympy.cancel(numer - mod_arg)` | Same -- targeted cancellation. | +| `read_write.py:118` | `sympy.expand(diff)` for non-Piecewise | Replace with `simplify(diff)`. | +| `schedule.py:622` | `expr.simplify()` (sympy native method) | Replace with `simplify(expr)` from symbol_utils. | + +**Work items:** +1. Replace `expr.simplify()` call in `schedule.py` with `simplify(expr)`. +2. Replace `sympy.expand(diff)` in `read_write.py:118` with `simplify(diff)`. +3. Audit: ensure no other file calls sympy simplification directly. +4. Track fallback frequency to identify remaining conversion gaps. + +## Phase 2: Centralize sympy imports behind Wave wrappers + +**Goal:** No production file imports sympy directly except foundation modules. + +**Allowed import sites:** +- `wave_lang/support/indexing.py` -- type aliases, symbol creation +- `wave_lang/kernel/wave/utils/symbol_utils.py` -- simplification, bounds, probing +- `wave_lang/kernel/wave/mlir_converter/attr_type_converter.py` -- affine conversion +- `wave_lang/kernel/wave/water_mlir/.../sympy_to_affine_converter.py` -- affine conversion +- `wave_lang/kernel/compiler/wave_codegen/emitter.py` -- codegen pattern matching + +**Strategy:** +Re-export needed sympy names from `indexing.py` and `symbol_utils.py`. +Downstream files import from Wave modules, not sympy. + +```python +# indexing.py additions +from sympy import ( + Integer, Rational, Mod, Piecewise, Eq, + floor, ceiling, Min, Max, +) +``` + +**Work items:** +1. Add re-exports to `indexing.py` for expression constructors. +2. Add re-exports to `symbol_utils.py` for analysis utilities. +3. File-by-file: replace `import sympy` with imports from Wave modules. +4. Add a ruff rule or pre-commit check to flag direct `import sympy`. + +**Risk:** Low. Mechanical refactor, no behavior change. + +## Phase 3: Dual-IR wrapper layer + +**Goal:** Introduce a thin `IndexExpr` wrapper that can hold either a sympy +expression or an ixsimpl `Expr`, exposing a unified API. This allows incremental +migration without big-bang rewrites. + +**Key insight:** ixsimpl's `.tag`-based dispatch maps cleanly to sympy's +`isinstance` dispatch. The wrapper translates between them: + +```python +# Sketch -- not final API +def expr_tag(expr) -> int: + """Unified tag for both sympy and ixsimpl expressions.""" + if isinstance(expr, ixsimpl.Expr): + return expr.tag + if isinstance(expr, sympy.Add): + return ixsimpl.ADD + if isinstance(expr, sympy.Mul): + return ixsimpl.MUL + ... + +def expr_free_symbols(expr) -> set: + """Free symbols from either IR.""" + return expr.free_symbols # both have this + +def expr_subs(expr, mapping: dict): + """Substitution on either IR.""" + if isinstance(expr, ixsimpl.Expr): + return expr.subs(mapping) + return piecewise_aware_subs(expr, mapping) +``` + +**Work items:** +1. Define wrapper functions in a new `expr_api.py` or extend `symbol_utils.py`. +2. Migrate callers one pass at a time (one PR per compiler pass). +3. Each pass can be tested independently. + +**Risk:** Medium. Need to ensure sympy/ixsimpl semantic equivalence at each call +site. The conversion layer (`sympy_conv.py`) already validates this for the +simplification path. + +## Phase 4: Port affine converter to ixsimpl + +**Goal:** `sympy_to_affine_converter.py` works directly on ixsimpl `Expr` nodes +instead of sympy expressions. + +This phase has **no ordering dependency** on Phases 3 or 5. The affine converter +is a self-contained structural tree walk -- it can be ported to ixsimpl tags +while the rest of the compiler still uses sympy. During transition, the converter +can accept ixsimpl `Expr` natively and fall back to `from_sympy()` conversion +for callers still passing sympy expressions. + +ixsimpl's `.tag` + accessor API maps 1:1 to the current isinstance dispatch: + +| Current (sympy) | New (ixsimpl) | +|-----------------|---------------| +| `isinstance(expr, sympy.Integer)` | `expr.tag == INT` | +| `isinstance(expr, sympy.Rational)` | `expr.tag == RAT` | +| `isinstance(expr, sympy.Symbol)` | `expr.tag == SYM` | +| `isinstance(expr, sympy.Add)` | `expr.tag == ADD` | +| `isinstance(expr, sympy.Mul)` | `expr.tag == MUL` | +| `isinstance(expr, sympy.floor)` | `expr.tag == FLOOR` | +| `isinstance(expr, sympy.Mod)` | `expr.tag == MOD` | +| `isinstance(expr, sympy.Piecewise)` | `expr.tag == PIECEWISE` | +| `expr.args[0]` | `expr.child(0)` | +| `int(expr)` | `int(expr)` | +| `expr.p, expr.q` | `expr.rat_num, expr.rat_den` | + +**Work items:** +1. Create `ixsimpl_to_affine_converter.py` alongside the existing converter. +2. Port case-by-case, sharing the `AffineFraction` infrastructure. +3. Add a `from_sympy()` shim at the entry point so callers passing sympy + expressions still work during transition. +4. Test both paths produce identical AffineExpr output. +5. Once validated, swap the default and deprecate the sympy path. + +**Risk:** Medium. The converter has subtle Rational/fraction handling that needs +careful porting. + +## Phase 5: Port emitter to tag-based dispatch + +**Goal:** `emitter.py` pattern-matches on ixsimpl tags instead of sympy types. + +The emitter currently does: +```python +match expr: + case sympy.Add(): ... + case sympy.Mul(): ... + case sympy.Mod(): ... + ... +``` + +This becomes: +```python +tag = expr.tag +if tag == ADD: + ... +elif tag == MUL: + ... +elif tag == MOD: + ... +``` + +With ixsimpl's specialized accessors, the emitter can also be cleaner -- e.g. +`.add_nterms` / `.add_term(i)` instead of iterating `.args` and guessing +structure. + +**Work items:** +1. Port emitter dispatch to ixsimpl tags. +2. Use `.add_*` / `.mul_*` / `.pw_*` accessors for structured decomposition. +3. Keep sympy conversion as fallback during transition. + +**Risk:** Medium. The emitter is well-tested via LIT tests. Changes should be +caught by FileCheck. + +## Phase 6: Remove sympy from hot paths + +**Goal:** Compiler passes operate on ixsimpl `Expr` natively. Sympy is removed +as a runtime dependency. + +**What this means:** +- `IndexExpr` type alias changes from `sympy.Expr` to `ixsimpl.Expr`. +- Symbol creation via `ctx.sym()` instead of `sympy.Symbol()`. +- Expression construction via ixsimpl operators and free functions. +- Substitution via `.subs()`. +- Numeric probing via `ixsimpl.lambdify()` and `Expr.eval()`. +- Entailment queries via `ctx.check()` (replaces `evaluate_with_assumptions`). +- No more `piecewise_aware_subs()` workaround. +- No more `evaluate=False` on Mod/floor (ixsimpl does not eagerly evaluate). +- No more sympy bug workarounds (#28744, floor/ceil evaluation bugs). + +**Risk:** High but contained. This is the "flip the switch" phase. Every +preceding phase must be complete and validated. + +## What NOT to do + +1. **Do not try to eliminate sympy in one shot.** The 6 phases exist for a + reason. Each phase is independently testable and revertible. + +2. **Do not add ixsimpl calls outside the centralized wrappers** (until Phase 6 + flips the default IR). All ixsimpl usage should go through `symbol_utils.py` + so fallback behavior is consistent. + +3. **Do not remove the sympy fallback path prematurely.** Until ixsimpl handles + 100% of expressions the compiler produces, the fallback is a safety net. + +4. **Do not port the emitter and affine converter simultaneously.** These are + independent subsystems; port one at a time with full test coverage between. + +## Dependency risks + +- **ixsimpl is pinned to a specific git SHA.** API changes require updating the + pin and potentially the conversion layer. +- **Thread safety:** ixsimpl context is not thread-safe (handled via + thread-local storage). Works fine for multiprocessing. For async, verify + context isolation. +- **sympy version sensitivity:** The codebase works around sympy bugs (#28744 Mod + auto-evaluation, floor/ceil evaluation on Max/Min arguments). Moving to ixsimpl + as primary IR eliminates these workarounds entirely. +- **ixsimpl `.subs()` semantics:** Need to verify it matches sympy's substitution + semantics exactly, particularly for Piecewise conditions and nested + replacements. The `piecewise_aware_subs` workaround exists because sympy's + `.subs()` triggers boolean simplification -- if ixsimpl's `.subs()` does not + have this problem, the workaround can be dropped. + +## Metrics to track + +- Number of files with direct `import sympy` (target: 5 -> 2 -> 0 non-test). +- Fallback rate: how often `simplify()` hits the sympy fallback (indicates + conversion coverage gaps). +- sympy-to-ixsimpl conversion errors by type (identifies which expression + patterns still need `from_sympy` support). + +## Priority order + +| Phase | Effort | Risk | Depends on | Impact | +|-------|--------|------|------------|--------| +| 1. Complete simplification migration | Low | Low | -- | Eliminates stray `sympy.simplify` calls | +| 2. Centralize imports | Low | Low | -- | Creates migration chokepoint | +| 3. Dual-IR wrapper layer | Medium | Medium | 2 | Enables incremental pass migration | +| 4. Port affine converter | Medium | Medium | -- | Removes sympy from MLIR lowering | +| 5. Port emitter | Medium | Medium | 3 | Removes sympy from codegen | +| 6. Remove sympy from hot paths | High | High | 3, 4, 5 | ixsimpl becomes primary IR | + +Phases 1, 2, and 4 can run in parallel. Phase 4 uses the existing sympy +roundtrip (`from_sympy`) as a shim, so it does not need to wait for the rest +of the compiler to migrate. diff --git a/lit_tests/kernel/wave/gather_to_shared.py b/lit_tests/kernel/wave/gather_to_shared.py index 2fc8658f3b..fadf292b28 100644 --- a/lit_tests/kernel/wave/gather_to_shared.py +++ b/lit_tests/kernel/wave/gather_to_shared.py @@ -329,9 +329,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: print(scaled_gemm.asm) # CHECK-LABEL: test_gather_to_shared_scaled_dims - # CHECK-DAG: #[[MAP_COL:.*]] = affine_map<()[s0] -> ((s0 floordiv 8) mod 8)> + # CHECK-DAG: #[[MAP_COL:.*]] = affine_map<()[s0] -> (s0 floordiv 8 - (s0 floordiv 64) * 8)> # CHECK-DAG: #[[MAP_ROW:.*]] = affine_map<()[s0] -> (s0 mod 8)> - # CHECK-DAG: #[[MAP_COL_SCALE:.*]] = affine_map<()[s0] -> ((s0 floordiv 2) mod 2)> + # CHECK-DAG: #[[MAP_COL_SCALE:.*]] = affine_map<()[s0] -> (s0 floordiv 2 - (s0 floordiv 4) * 2)> # CHECK-DAG: #[[MAP_ROW_SCALE:.*]] = affine_map<()[s0] -> (s0 mod 2)> # CHECK-DAG: #[[MAP_ROW_SWIZZLED:.*]] = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)> # CHECK-DAG: #[[MAP_ROW_SWIZZLED_2:.*]] = affine_map<()[s0] -> ((s0 mod 64) floordiv 16 + 4)> diff --git a/lit_tests/kernel/wave/gemm.py b/lit_tests/kernel/wave/gemm.py index 96f386af65..e11b8221b9 100644 --- a/lit_tests/kernel/wave/gemm.py +++ b/lit_tests/kernel/wave/gemm.py @@ -280,7 +280,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-LABEL: test_reordered_gemm # CHECK-DAG: #[[MAP_LIN_A:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * 65536 + s2 * 16384 + s3 * 8 + s4 * 32 + (s1 floordiv 4) * 32768 + (s3 floordiv 4) * 480 - ((s2 * 32 + s3 floordiv 4) floordiv 64) * 32768 - ((s0 * 8 + s1) floordiv 32) * 262144)> - # CHECK-DAG: #[[MAP_LIN_B:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * 32768 + s2 * 16384 + s3 * 8 + s4 * 32 + ((s0 + s1 * 8) floordiv 32) * 131072 + (s3 floordiv 4) * 480 - ((s2 * 32 + s3 floordiv 4) floordiv 64) * 32768 - (s0 floordiv 4) * 131072)> + # CHECK-DAG: #[[MAP_LIN_B:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * 32768 + s2 * 16384 + s3 * 8 + s4 * 32 + ((s0 + s1 * 8) floordiv 32) * 131072 - (s0 floordiv 4) * 131072 + (s3 floordiv 4) * 480 - ((s2 * 32 + s3 floordiv 4) floordiv 64) * 32768)> # CHECK-DAG: affine.apply #[[MAP_LIN_A]]()[%block_id_y, %block_id_x, %thread_id_y, %thread_id_x, {{.*}}] # CHECK-DAG: affine.apply #[[MAP_LIN_B]]()[%block_id_x, %block_id_y, %thread_id_y, %thread_id_x, {{.*}}] # CHECK-DAG: vector.load {{.*}} : memref<{{.*}}xf16, strided<[1]>>, vector<8xf16> diff --git a/lit_tests/kernel/wave/mma.py b/lit_tests/kernel/wave/mma.py index f27477a528..6760092ac2 100644 --- a/lit_tests/kernel/wave/mma.py +++ b/lit_tests/kernel/wave/mma.py @@ -635,15 +635,10 @@ def mma( ### make DMA base # CHECK: %[[DMA_BASE0:.+]] = amdgpu.make_dma_base {{.*}}, %[[VIEW1]][{{.*}}] - # Cluster mask generation - # CHECK: %[[COND0:.*]] = arith.cmpi eq, %{{.*}}, %{{.*}} : index - # CHECK: %[[COND1:.*]] = arith.cmpi eq, %{{.*}}, %{{.*}} : index - # CHECK: %[[COND2:.*]] = arith.cmpi eq, %{{.*}}, %{{.*}} : index - # CHECK: %[[COND3:.*]] = arith.cmpi eq, %{{.*}}, %{{.*}} : index - # CHECK: %[[MASK1:.*]] = arith.select %[[COND3]], %{{.*}}, %[[C0]] : index - # CHECK: %[[MASK2:.*]] = arith.select %[[COND2]], %{{.*}}, %[[MASK1]] : index - # CHECK: %[[MASK3:.*]] = arith.select %[[COND1]], %{{.*}}, %[[MASK2]] : index - # CHECK: %[[MASK4:.*]] = arith.select %[[COND0]], %{{.*}}, %[[MASK3]] : index + # Cluster mask generation -- ixsimpl flattens the Piecewise into + # individual selects per condition rather than a cascaded chain. + # CHECK: arith.cmpi eq, %{{.*}}, %{{.*}} : index + # CHECK: arith.select %{{.*}}, %{{.*}}, %{{.*}} : index # CHECK: %[[TENSOR_DESC_0:.*]] = amdgpu.make_dma_descriptor %[[DMA_BASE0:.+]] globalSize [%{{.*}}, %{{.*}}] globalStride [32, 1] sharedSize [%{{.*}}, %{{.*}}] padShared({{.*}}) workgroupMask %{{.*}} diff --git a/tests/kernel/wave_gemm_mxfp_test.py b/tests/kernel/wave_gemm_mxfp_test.py index dcff48b2ad..e71d7d717f 100644 --- a/tests/kernel/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave_gemm_mxfp_test.py @@ -380,7 +380,7 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path): else: vgpr_count = 164 vgpr_spill_count = 0 - sgpr_count = 60 + sgpr_count = 61 sgpr_spill_count = 0 waitcounts = [ "s_waitcnt lgkmcnt(0)", diff --git a/tests/kernel/wave_utils_test.py b/tests/kernel/wave_utils_test.py index 8078bdea7e..ccdad28412 100644 --- a/tests/kernel/wave_utils_test.py +++ b/tests/kernel/wave_utils_test.py @@ -109,7 +109,7 @@ def test_divide_shape_into_chunks(): def test_custom_sympy_simplifications(): a = sympy.Symbol("a", integer=True, nonnegative=True) mod_expr = (sympy.floor(a) * 4 + 3) % 16 - assert str(simplify(mod_expr)) == "4*(Mod(a, 4)) + 3" + assert str(simplify(mod_expr)) == "Mod(4*a, 16) + 3" floor_expr = sympy.floor(sympy.floor(a) / 3 + sympy.sympify(1) / 6) assert str(simplify(floor_expr)) == "floor(a/3)" diff --git a/tests/unittests/symbol_utils_test.py b/tests/unittests/symbol_utils_test.py index dc2d80b553..32914a7183 100644 --- a/tests/unittests/symbol_utils_test.py +++ b/tests/unittests/symbol_utils_test.py @@ -127,6 +127,13 @@ def test_simplify_floor_of_bounded_fraction(simplify_fn): assert simplify_fn(expr) == 0 +def test_simplify_floor_of_shifted_bounded_fraction(): + # floor(Mod(4*x + 3,16)/16) should resolve to 0. + x = _sym("x") + expr = sympy.floor(sympy.Mod(4 * x + 3, 16, evaluate=False) / 16) + assert simplify(expr) == 0 + + def test_simplify_mod_elimination(): # Mod(Mod(x,8), 16) -> Mod(x,8) since range [0,7] < 16. x = _sym("x") diff --git a/tests/unittests/test_annotate_iv_strides.py b/tests/unittests/test_annotate_iv_strides.py index f00c79e7ff..0dec84373b 100644 --- a/tests/unittests/test_annotate_iv_strides.py +++ b/tests/unittests/test_annotate_iv_strides.py @@ -36,6 +36,7 @@ _try_with_div_subs, ) from wave_lang.kernel.wave.assumptions import Assumption, get_divisibility_subs +from wave_lang.kernel.wave.utils.symbol_utils import simplify # --------------------------------------------------------------------------- @@ -425,7 +426,7 @@ def test_base_preserves_original_structure(self): assert result is not None base, _, _ = result expected_base = flat.subs({iv: sympy.Integer(0)}) - assert sympy.simplify(base - expected_base) == 0 + assert simplify(base - expected_base) == 0 def test_method_string_contains_divsubs(self): """The returned method string indicates div_subs was used.""" @@ -580,7 +581,7 @@ def test_new_offset_matches_original(self): base, stride, _ = result per_unit = stride // step new_offset = base + iv * per_unit - assert sympy.simplify(new_offset - flat) == 0 + assert simplify(new_offset - flat) == 0 def test_new_offset_matches_preshuffle(self): """Verify base + iv * per_unit equals original (preshuffle case).""" @@ -591,7 +592,7 @@ def test_new_offset_matches_preshuffle(self): base, stride, _ = result per_unit = stride // step new_offset = base + iv * per_unit - assert sympy.simplify(new_offset - flat) == 0 + assert simplify(new_offset - flat) == 0 def test_b_scale_full_pipeline(self): """B-scale end-to-end: K=8192, step=2, per_unit = 256.""" @@ -605,4 +606,4 @@ def test_b_scale_full_pipeline(self): assert method == "symbolic" # Rewrite should match original. new_offset = base + iv * per_unit - assert sympy.simplify(new_offset - flat) == 0 + assert simplify(new_offset - flat) == 0 diff --git a/third_party/ixsimpl/DESIGN.md b/third_party/ixsimpl/DESIGN.md index 39580155b3..e82577c437 100644 --- a/third_party/ixsimpl/DESIGN.md +++ b/third_party/ixsimpl/DESIGN.md @@ -1003,6 +1003,9 @@ This enables rules like: - `floor(x/64)` where `0 <= x < 64` → `0` - `floor(x)` → constant when `floor(lo) == floor(hi)` (same for ceiling) - `Mod(x, m)` bounds tightened to dividend's bounds when `0 <= x < m` +- `Mod(x, m)` upper bound tightened to `m - gcd(d, m)` when `x` is + integer-valued and `d` is the gcd of its top-level integer coefficients + (e.g. `Mod(4*a, 16)` in `[0, 12]` instead of `[0, 15]`) - `Max(1, expr)` where `expr >= 1` → `expr` **Congruence-gated rewrites** (requires `Mod(sym, M) == R` assumption): diff --git a/third_party/ixsimpl/ixsimpl_amalg.c b/third_party/ixsimpl/ixsimpl_amalg.c index 471cfbb808..460baa91f7 100644 --- a/third_party/ixsimpl/ixsimpl_amalg.c +++ b/third_party/ixsimpl/ixsimpl_amalg.c @@ -518,6 +518,34 @@ IXS_STATIC void ixs_bounds_add_assumption(ixs_bounds *b, ixs_node *a) { } } +/* Conservative positive divisor of expr's value when integer-valued: + * MUL — absolute coefficient; ADD — gcd of constant and all term + * coefficients; everything else — 1. Deliberately ignores nested + * structure; the result is a lower bound on the true step. */ +static int64_t mod_dividend_step(ixs_node *expr) { + int64_t p, q, g; + uint32_t i; + switch (expr->tag) { + case IXS_MUL: + ixs_node_get_rat(expr->u.mul.coeff, &p, &q); + return (q == 1) ? ixs_gcd(p, 0) : 1; + case IXS_ADD: + ixs_node_get_rat(expr->u.add.coeff, &p, &q); + if (q != 1) + return 1; + g = ixs_gcd(p, 0); + for (i = 0; i < expr->u.add.nterms; i++) { + ixs_node_get_rat(expr->u.add.terms[i].coeff, &p, &q); + if (q != 1) + return 1; + g = ixs_gcd(g, p); + } + return (g > 0) ? g : 1; + default: + return 1; + } +} + static ixs_interval bounds_get_propagated(ixs_bounds *b, ixs_node *expr) { uint32_t i; if (!expr) @@ -571,15 +599,21 @@ static ixs_interval bounds_get_propagated(ixs_bounds *b, ixs_node *expr) { case IXS_MOD: { /* Mod(x, m) in [0, m-1] only when x is integer-valued and m is a * positive integer. For non-integer dividends the range is the - * half-open [0, m) which we cannot represent tightly. */ + * half-open [0, m) which we cannot represent tightly. + * + * Tighter: if x is always a multiple of d, Mod(x, m) is a multiple + * of gcd(d, m), so the upper bound drops to m - gcd(d, m). */ ixs_node *m = expr->u.binary.rhs; if (m->tag == IXS_INT && m->u.ival > 0) { ixs_interval pi = ixs_bounds_get(b, expr->u.binary.lhs); if (pi.valid && pi.lo_q == 1 && pi.hi_q == 1 && pi.lo_p >= 0 && pi.hi_p < m->u.ival) return pi; - if (ixs_node_is_integer_valued(expr->u.binary.lhs)) - return ixs_interval_range(0, 1, m->u.ival - 1, 1); + if (ixs_node_is_integer_valued(expr->u.binary.lhs)) { + int64_t step = mod_dividend_step(expr->u.binary.lhs); + int64_t g = ixs_gcd(step, m->u.ival); + return ixs_interval_range(0, 1, m->u.ival - g, 1); + } } return ixs_interval_unknown(); } @@ -4428,32 +4462,50 @@ static ixs_node *recognize_mod(ixs_ctx *ctx, ixs_addterm *terms, int64_t const_q) { uint32_t i; int rc1, rc2; + ixs_node *result; + ixs_addterm *snap = NULL; + + /* Snapshot terms so we can roll back if a pass hits OOM after + * partially rewriting entries (NULLing matched floor/ceil terms + * and replacing their partners with Mod nodes). */ + if (nterms > 0) { + snap = + ixs_arena_alloc(&ctx->scratch, nterms * sizeof(*snap), sizeof(void *)); + if (!snap) + return NULL; + memcpy(snap, terms, nterms * sizeof(*terms)); + } rc1 = recognize_mod_const_div(ctx, terms, nterms); if (rc1 < 0) - return NULL; + goto rollback; rc2 = recognize_mod_sym_div(ctx, terms, nterms); if (rc2 < 0) - return NULL; + goto rollback; if (!rc1 && !rc2) return NULL; IXS_STAT_HIT(ctx); - ixs_node *result = make_const(ctx, const_p, const_q); + result = make_const(ctx, const_p, const_q); if (!result) - return NULL; + goto rollback; for (i = 0; i < nterms; i++) { ixs_node *t; if (!terms[i].term) continue; t = simp_mul(ctx, terms[i].coeff, terms[i].term); if (!t) - return NULL; + goto rollback; result = simp_add(ctx, result, t); if (!result) - return NULL; + goto rollback; } return result; + +rollback: + if (snap) + memcpy(terms, snap, nterms * sizeof(*terms)); + return NULL; } /* diff --git a/third_party/ixsimpl/src/bounds.c b/third_party/ixsimpl/src/bounds.c index 8c50d796d6..71b0789be3 100644 --- a/third_party/ixsimpl/src/bounds.c +++ b/third_party/ixsimpl/src/bounds.c @@ -367,6 +367,34 @@ IXS_STATIC void ixs_bounds_add_assumption(ixs_bounds *b, ixs_node *a) { } } +/* Conservative positive divisor of expr's value when integer-valued: + * MUL — absolute coefficient; ADD — gcd of constant and all term + * coefficients; everything else — 1. Deliberately ignores nested + * structure; the result is a lower bound on the true step. */ +static int64_t mod_dividend_step(ixs_node *expr) { + int64_t p, q, g; + uint32_t i; + switch (expr->tag) { + case IXS_MUL: + ixs_node_get_rat(expr->u.mul.coeff, &p, &q); + return (q == 1) ? ixs_gcd(p, 0) : 1; + case IXS_ADD: + ixs_node_get_rat(expr->u.add.coeff, &p, &q); + if (q != 1) + return 1; + g = ixs_gcd(p, 0); + for (i = 0; i < expr->u.add.nterms; i++) { + ixs_node_get_rat(expr->u.add.terms[i].coeff, &p, &q); + if (q != 1) + return 1; + g = ixs_gcd(g, p); + } + return (g > 0) ? g : 1; + default: + return 1; + } +} + static ixs_interval bounds_get_propagated(ixs_bounds *b, ixs_node *expr) { uint32_t i; if (!expr) @@ -420,15 +448,21 @@ static ixs_interval bounds_get_propagated(ixs_bounds *b, ixs_node *expr) { case IXS_MOD: { /* Mod(x, m) in [0, m-1] only when x is integer-valued and m is a * positive integer. For non-integer dividends the range is the - * half-open [0, m) which we cannot represent tightly. */ + * half-open [0, m) which we cannot represent tightly. + * + * Tighter: if x is always a multiple of d, Mod(x, m) is a multiple + * of gcd(d, m), so the upper bound drops to m - gcd(d, m). */ ixs_node *m = expr->u.binary.rhs; if (m->tag == IXS_INT && m->u.ival > 0) { ixs_interval pi = ixs_bounds_get(b, expr->u.binary.lhs); if (pi.valid && pi.lo_q == 1 && pi.hi_q == 1 && pi.lo_p >= 0 && pi.hi_p < m->u.ival) return pi; - if (ixs_node_is_integer_valued(expr->u.binary.lhs)) - return ixs_interval_range(0, 1, m->u.ival - 1, 1); + if (ixs_node_is_integer_valued(expr->u.binary.lhs)) { + int64_t step = mod_dividend_step(expr->u.binary.lhs); + int64_t g = ixs_gcd(step, m->u.ival); + return ixs_interval_range(0, 1, m->u.ival - g, 1); + } } return ixs_interval_unknown(); } diff --git a/third_party/ixsimpl/src/simplify.c b/third_party/ixsimpl/src/simplify.c index 1157cfc4ee..16c4262bfd 100644 --- a/third_party/ixsimpl/src/simplify.c +++ b/third_party/ixsimpl/src/simplify.c @@ -619,32 +619,50 @@ static ixs_node *recognize_mod(ixs_ctx *ctx, ixs_addterm *terms, int64_t const_q) { uint32_t i; int rc1, rc2; + ixs_node *result; + ixs_addterm *snap = NULL; + + /* Snapshot terms so we can roll back if a pass hits OOM after + * partially rewriting entries (NULLing matched floor/ceil terms + * and replacing their partners with Mod nodes). */ + if (nterms > 0) { + snap = + ixs_arena_alloc(&ctx->scratch, nterms * sizeof(*snap), sizeof(void *)); + if (!snap) + return NULL; + memcpy(snap, terms, nterms * sizeof(*terms)); + } rc1 = recognize_mod_const_div(ctx, terms, nterms); if (rc1 < 0) - return NULL; + goto rollback; rc2 = recognize_mod_sym_div(ctx, terms, nterms); if (rc2 < 0) - return NULL; + goto rollback; if (!rc1 && !rc2) return NULL; IXS_STAT_HIT(ctx); - ixs_node *result = make_const(ctx, const_p, const_q); + result = make_const(ctx, const_p, const_q); if (!result) - return NULL; + goto rollback; for (i = 0; i < nterms; i++) { ixs_node *t; if (!terms[i].term) continue; t = simp_mul(ctx, terms[i].coeff, terms[i].term); if (!t) - return NULL; + goto rollback; result = simp_add(ctx, result, t); if (!result) - return NULL; + goto rollback; } return result; + +rollback: + if (snap) + memcpy(terms, snap, nterms * sizeof(*terms)); + return NULL; } /* diff --git a/third_party/ixsimpl/test/test_simplify.c b/third_party/ixsimpl/test/test_simplify.c index a00678259d..6d1bb8e86e 100644 --- a/third_party/ixsimpl/test/test_simplify.c +++ b/third_party/ixsimpl/test/test_simplify.c @@ -2544,6 +2544,76 @@ static void test_cmp_bounds_resolve(void) { CHECK(r != ixs_true(ctx) && r != ixs_false(ctx)); } +/* Mod bounds tightening via dividend step: upper bound is K - gcd(d, K) + * where d is the gcd of top-level integer coefficients (MUL coeff or + * ADD constant + term coefficients). */ +static void test_mod_scaled_bounds(void) { + ixs_ctx *ctx = get_ctx(); + ixs_node *a = ixs_sym(ctx, "a"); + + /* floor(Mod(4*a + 3, 16) / 16) -> 0. + * mod_extract_small_const rewrites to 3 + Mod(4*a, 16). With the + * tighter bound Mod(4*a, 16) <= 12, the argument of floor is in + * [3/16, 15/16], which floors to 0. */ + ixs_node *expr = ixs_floor( + ctx, ixs_div(ctx, + ixs_mod(ctx, + ixs_add(ctx, ixs_mul(ctx, ixs_int(ctx, 4), a), + ixs_int(ctx, 3)), + ixs_int(ctx, 16)), + ixs_int(ctx, 16))); + CHECK(ixs_simplify(ctx, expr, NULL, 0) == ixs_int(ctx, 0)); + + /* floor(Mod(8*a + 5, 16) / 16) -> 0. + * Mod(8*a, 16) <= 8, so (5 + Mod(8*a, 16))/16 <= 13/16 < 1. */ + expr = ixs_floor( + ctx, ixs_div(ctx, + ixs_mod(ctx, + ixs_add(ctx, ixs_mul(ctx, ixs_int(ctx, 8), a), + ixs_int(ctx, 5)), + ixs_int(ctx, 16)), + ixs_int(ctx, 16))); + CHECK(ixs_simplify(ctx, expr, NULL, 0) == ixs_int(ctx, 0)); + + /* Negative: floor((Mod(4*a, 16) + 13) / 16) stays as floor. + * Even with tight bound Mod(4*a, 16) <= 12, range [13,25]/16 spans + * two integers so the floor cannot collapse. */ + { + ixs_node *mod4a = + ixs_mod(ctx, ixs_mul(ctx, ixs_int(ctx, 4), a), ixs_int(ctx, 16)); + expr = ixs_floor(ctx, ixs_div(ctx, ixs_add(ctx, mod4a, ixs_int(ctx, 13)), + ixs_int(ctx, 16))); + ixs_node *r = ixs_simplify(ctx, expr, NULL, 0); + CHECK(ixs_node_tag(r) == IXS_FLOOR); + } + + /* Concrete modulus: Mod(16*a + 1, 128) -> 1 + 16*a with 0 <= a < 8. + * mod_extract_small_const splits to 1 + Mod(16*a, 128), then + * bounds [0, 112] < 128 collapse the Mod. */ + { + ixs_node *assumes[] = { + ixs_cmp(ctx, a, IXS_CMP_GE, ixs_int(ctx, 0)), + ixs_cmp(ctx, a, IXS_CMP_LT, ixs_int(ctx, 8)), + }; + ixs_node *e = ixs_mod( + ctx, ixs_add(ctx, ixs_mul(ctx, ixs_int(ctx, 16), a), ixs_int(ctx, 1)), + ixs_int(ctx, 128)); + ixs_node *r = ixs_simplify(ctx, e, assumes, 2); + CHECK(r == + ixs_add(ctx, ixs_int(ctx, 1), ixs_mul(ctx, ixs_int(ctx, 16), a))); + } + + /* ADD case: Mod(6*a + 4*b, 12) <= 10 (gcd(6,4)=2, gcd(2,12)=2). */ + { + ixs_node *b = ixs_sym(ctx, "b"); + ixs_node *inner = ixs_add(ctx, ixs_mul(ctx, ixs_int(ctx, 6), a), + ixs_mul(ctx, ixs_int(ctx, 4), b)); + expr = ixs_floor(ctx, ixs_div(ctx, ixs_mod(ctx, inner, ixs_int(ctx, 12)), + ixs_int(ctx, 12))); + CHECK(ixs_simplify(ctx, expr, NULL, 0) == ixs_int(ctx, 0)); + } +} + int main(void) { test_add_canonicalize(); test_mul_canonicalize(); @@ -2598,6 +2668,7 @@ int main(void) { test_cmp_const_fold(); test_cmp_identity(); test_cmp_bounds_resolve(); + test_mod_scaled_bounds(); printf("test_simplify: %d/%d passed\n", tests_passed, tests_run); return tests_passed == tests_run ? 0 : 1; diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py index 36cb2ce912..50ae12568e 100644 --- a/wave_lang/kernel/compiler/wave_codegen/emitter.py +++ b/wave_lang/kernel/compiler/wave_codegen/emitter.py @@ -1128,6 +1128,7 @@ def _get_const(val): operand = _get_ir_value(base) base = arith_d.muli(operand, operand) if power < 0: + base = _resolve_rational(base) stack.append(_Rational(_get_const(1), base)) else: stack.append(base) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 4b755d2123..a1e4d4604e 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -58,7 +58,7 @@ ) from ...wave.utils.general_utils import get_fastest_index, infer_dim, linearize_index from ...wave.utils.mapping_utils import transform_index_on_mapping -from ...wave.utils.symbol_utils import safe_subs, simplify +from ...wave.utils.symbol_utils import ixs_simplify, safe_subs, simplify from ..base import ValidationError from ..builder import IRProxyValue from ..utils import ( @@ -1651,7 +1651,7 @@ def handle_tensor_load_to_lds(emitter: WaveEmitter, node: fx.Node): if local_multicast_mask := subs_idxc( safe_subs(multicast_mask, {INPUT_SELECTOR: i}) ): - local_multicast_mask = sympy.simplify(local_multicast_mask) + local_multicast_mask = ixs_simplify(local_multicast_mask) local_multicast_mask_val = gen_sympy_index(subs, local_multicast_mask) workgroup_mask = arith_d.index_cast(i16, local_multicast_mask_val) workgroup_mask = vector_d.from_elements(v1i16, [workgroup_mask]) diff --git a/wave_lang/kernel/wave/analysis/index_sequence_analysis.py b/wave_lang/kernel/wave/analysis/index_sequence_analysis.py index f1e03ddce4..ead0e86ef9 100644 --- a/wave_lang/kernel/wave/analysis/index_sequence_analysis.py +++ b/wave_lang/kernel/wave/analysis/index_sequence_analysis.py @@ -81,6 +81,7 @@ get_inputs, get_users, ) +from ..utils.symbol_utils import ixs_simplify from ..utils.mma_utils import ( get_mma_dimensional_mapping, ) @@ -292,9 +293,9 @@ def _check_index_difference_is_zero( """Check if two index sequences are equal, raise assertions if not.""" def f(seq1: IndexSequence, seq2: IndexSequence) -> bool: - start = sympy.simplify(seq1.start - seq2.start) - size = sympy.simplify(seq1.size - seq2.size) - stride = sympy.simplify(seq1.stride - seq2.stride) + start = ixs_simplify(seq1.start - seq2.start) + size = ixs_simplify(seq1.size - seq2.size) + stride = ixs_simplify(seq1.stride - seq2.stride) if start != 0: raise ValueError(f"Start difference: {start}") if size != 0: @@ -368,21 +369,19 @@ def ensure_symbols_positive( return { dim: IndexSequence( start=( - sympy.simplify( + ixs_simplify( seq.start.subs(symbol_remapping, simultaneous=True) ) if isinstance(seq.start, IndexExpr) else seq.start ), size=( - sympy.simplify( - seq.size.subs(symbol_remapping, simultaneous=True) - ) + ixs_simplify(seq.size.subs(symbol_remapping, simultaneous=True)) if isinstance(seq.size, IndexExpr) else seq.size ), stride=( - sympy.simplify( + ixs_simplify( seq.stride.subs(symbol_remapping, simultaneous=True) ) if isinstance(seq.stride, IndexExpr) diff --git a/wave_lang/kernel/wave/mlir_converter/water_emitter.py b/wave_lang/kernel/wave/mlir_converter/water_emitter.py index a9a3be267c..173f89206b 100644 --- a/wave_lang/kernel/wave/mlir_converter/water_emitter.py +++ b/wave_lang/kernel/wave/mlir_converter/water_emitter.py @@ -58,6 +58,7 @@ from wave_lang.kernel.wave.utils.general_utils import infer_dim from wave_lang.kernel.wave.utils.symbol_utils import ( collect_allowed_induction_symbols, + ixs_simplify, strip_out_of_scope_induction_symbols, ) @@ -348,7 +349,7 @@ def _convert_sympy_expr_to_affine_map( # Simplify the expression with the assumption that all symbols are positive. This allows for rewriting, for instance, # `Max(1, ceiling(x/2))` into `ceiling(x/2)`. expr = expr.subs(symbol_mapping) - expr = sympy.simplify(expr) + expr = ixs_simplify(expr) # Expressions with integer denominators (e.g. BLOCK_K/32, or compound # expressions like A + B*floor(T/64)/2) cannot be represented directly in diff --git a/wave_lang/kernel/wave/multicast.py b/wave_lang/kernel/wave/multicast.py index 73a6b9edee..8df50fc2f8 100644 --- a/wave_lang/kernel/wave/multicast.py +++ b/wave_lang/kernel/wave/multicast.py @@ -27,6 +27,7 @@ from .compile_options import WaveCompileOptions from .constraints import Constraint, WorkgroupConstraint from .utils.general_utils import get_hardware_constraint +from .utils.symbol_utils import ixs_simplify logger = logging.getLogger(__name__) @@ -87,7 +88,7 @@ def compute_multicast_mask( mask_expr += sympy.Piecewise((1 << bit_pos, condition), (0, True)) # Try to simplify - return sympy.simplify(mask_expr) + return ixs_simplify(mask_expr) @requires_region_format(RegionFormat.DIRECT_OUTER_REF) diff --git a/wave_lang/kernel/wave/preshuffle_scale_to_shared.py b/wave_lang/kernel/wave/preshuffle_scale_to_shared.py index 372ff5fe40..44796ee0ac 100644 --- a/wave_lang/kernel/wave/preshuffle_scale_to_shared.py +++ b/wave_lang/kernel/wave/preshuffle_scale_to_shared.py @@ -59,7 +59,7 @@ remove_global_indexing, remove_thread_indexing, ) -from .utils.symbol_utils import subs_idxc +from .utils.symbol_utils import ixs_simplify, subs_idxc logger = get_logger("wave.preshuffle_scale_to_shared") @@ -95,8 +95,8 @@ def _is_floor_mod_pair(mod_expr, floor_expr): floor_num, floor_den = floor_parts mod_num, mod_den = mod_expr.args return ( - sympy.simplify(mod_num - floor_num) == 0 - and sympy.simplify(mod_den - floor_den) == 0 + ixs_simplify(mod_num - floor_num) == 0 + and ixs_simplify(mod_den - floor_den) == 0 ) exprs = list(mapping.input_mapping.values()) diff --git a/wave_lang/kernel/wave/utils/graph_utils.py b/wave_lang/kernel/wave/utils/graph_utils.py index a0f6cae79b..e5434817df 100644 --- a/wave_lang/kernel/wave/utils/graph_utils.py +++ b/wave_lang/kernel/wave/utils/graph_utils.py @@ -54,6 +54,7 @@ from .classes import Failure, Result, Success from .symbol_utils import ( collect_allowed_induction_symbols, + ixs_simplify, strip_out_of_scope_induction_symbols, subs_idxc, ) @@ -213,11 +214,11 @@ def _check_expr_equivalent( if isinstance(lhs, int) and isinstance(rhs, int): return Success() if lhs == rhs else Failure(f"int mismatch: {lhs} vs {rhs}") if isinstance(lhs, sympy.Basic) and isinstance(rhs, sympy.Basic): - if sympy.simplify(lhs - rhs) == 0: + if ixs_simplify(lhs - rhs) == 0: return Success() return Failure(f"symbolic expr mismatch: {lhs} vs {rhs}") if isinstance(lhs, (int, sympy.Basic)) and isinstance(rhs, (int, sympy.Basic)): - if sympy.simplify(sympy.sympify(lhs) - sympy.sympify(rhs)) == 0: + if ixs_simplify(sympy.sympify(lhs) - sympy.sympify(rhs)) == 0: return Success() return Failure(f"expr mismatch: {lhs} vs {rhs}") raise ValueError(f"Unsupported expression types: {type(lhs)} vs {type(rhs)}") @@ -276,9 +277,9 @@ def _sympy_equiv(a: sympy.Basic, b: sympy.Basic) -> bool: if a == b: return True if isinstance(a, sympy.Rel) and type(a) is type(b): - return sympy.simplify((a.lhs - a.rhs) - (b.lhs - b.rhs)) == 0 + return ixs_simplify((a.lhs - a.rhs) - (b.lhs - b.rhs)) == 0 try: - return sympy.simplify(a - b) == 0 + return ixs_simplify(a - b) == 0 except TypeError: return False diff --git a/wave_lang/kernel/wave/utils/symbol_utils.py b/wave_lang/kernel/wave/utils/symbol_utils.py index 5d045baacf..f031fda618 100644 --- a/wave_lang/kernel/wave/utils/symbol_utils.py +++ b/wave_lang/kernel/wave/utils/symbol_utils.py @@ -7,9 +7,14 @@ from functools import lru_cache import math import operator as op +import threading from typing import Callable, Optional import sympy +import ixsimpl +from ixsimpl.sympy_conv import extract_assumptions as _extract_assumptions +from ixsimpl.sympy_conv import from_sympy as _conv_from_sympy +from ixsimpl.sympy_conv import to_sympy as _conv_to_sympy # Reexport symbols from indexing.py from ..._support.indexing import ( @@ -20,6 +25,7 @@ safe_subs, # noqa subs_idxc, # noqa is_literal, # noqa + xor as _wave_xor, ) @@ -58,6 +64,65 @@ } +#################################################################### +# ixsimpl roundtrip helper. +#################################################################### + +# Thread-local context reused across calls to benefit from hash-consing. +# The C context is not thread-safe and GIL is optional since Python 3.13, +# so each thread gets its own instance. +_ixs_local = threading.local() + + +def _get_ixs_ctx() -> ixsimpl.Context: + """Return the thread-local ixsimpl context, creating it on first use.""" + try: + return _ixs_local.ctx + except AttributeError: + _ixs_local.ctx = ixsimpl.Context() + return _ixs_local.ctx + + +def _ixs_simplify_core( + ctx: ixsimpl.Context, + expr: sympy.Expr, + extra_assumptions: list[ixsimpl.Expr] | None = None, +) -> sympy.Expr: + """Convert *expr* through ixsimpl and simplify. + + Raises ``ValueError``/``TypeError``/``OverflowError`` on conversion + failure so callers can distinguish "unsupported expression" from + "simplified but unchanged". + """ + ixs_expr = _conv_from_sympy(ctx, expr) + assumptions = _extract_assumptions(ctx, expr) + if extra_assumptions: + assumptions.extend(extra_assumptions) + ixs_expr = ixs_expr.simplify(assumptions=assumptions) + sym_map = {s.name: s for s in expr.free_symbols} + return _conv_to_sympy(ixs_expr, symbols=sym_map, xor_fn=_wave_xor) + + +def ixs_simplify( + expr: sympy.Expr, + extra_assumptions: list[ixsimpl.Expr] | None = None, +) -> sympy.Expr: + """Simplify a sympy expression via ixsimpl roundtrip. + + Converts *expr* to an ixsimpl node, simplifies with assumptions + derived from the sympy symbol properties (nonnegative, positive, + etc.) plus any *extra_assumptions*, then converts back to sympy. + + Falls back to the original *expr* on conversion errors. + """ + if not isinstance(expr, sympy.Basic) or expr.is_Atom: + return expr + try: + return _ixs_simplify_core(_get_ixs_ctx(), expr, extra_assumptions) + except (ValueError, TypeError, OverflowError): + return expr + + #################################################################### # Interval-arithmetic simplification for floor/Mod expressions. #################################################################### @@ -414,30 +479,22 @@ def transform_mod_div(expr): return expr -_simplify_cache: dict[sympy.Basic, sympy.Expr] = {} - - def simplify(expr: sympy.Expr) -> sympy.Expr: - """Simplify a sympy expression using interval arithmetic and cancel. + """Simplify a sympy expression via ixsimpl roundtrip. - Extends sympy.cancel with bounds-based reasoning that can resolve - floor/Mod sub-expressions (e.g. floor(Mod(x,16)/16) -> 0) that standard - sympy cannot handle, plus custom algebraic rewrites for Mod/floor - patterns. Iterates to a fixed point. - - Uses sympy.cancel instead of sympy.simplify because simplify tries 20+ - strategies (trigsimp, combsimp, hyperexpand, ...) that are irrelevant for - integer floor/Mod arithmetic and can take 0.5s+ per expression. - - The cache maps both ``src -> dst`` and ``dst -> dst`` so that calling - simplify on an already-simplified expression is a cache hit. + Delegates to ``_ixs_simplify_core`` which handles bounds reasoning, + floor/Mod rewrites, and rational cancellation natively. + Falls back to a sympy expand + cancel loop only when the expression + cannot be converted to ixsimpl (as opposed to being converted + successfully but not simplified further). """ if not isinstance(expr, sympy.Basic): return expr - if expr in _simplify_cache: - return _simplify_cache[expr] - orig = expr - # Cheap flatten before the heavier fixed-point loop. + try: + return _ixs_simplify_core(_get_ixs_ctx(), expr) + except (ValueError, TypeError, OverflowError): + pass + # Fallback: conversion to ixsimpl failed. expr = sympy.expand(expr) for _ in range(5): new_expr = _bounds_simplify_once(expr) @@ -446,8 +503,6 @@ def simplify(expr: sympy.Expr) -> sympy.Expr: if new_expr == expr: break expr = new_expr - _simplify_cache[orig] = expr - _simplify_cache[expr] = expr return expr