Skip to content

Commit 277b1d7

Browse files
Hardcode84claude
andcommitted
Make ixsimpl context thread-local and differentiate conversion errors
Address PR iree-org#1130 review feedback: - Use threading.local() for ixsimpl context instead of a module-level global, since the C context is not thread-safe and GIL is optional since Python 3.13. - Extract _ixs_simplify_core that propagates conversion exceptions, so simplify() only falls back to the sympy path on actual conversion failures, not when ixsimpl processed the expression successfully but could not simplify it further. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
1 parent 720c495 commit 277b1d7

1 file changed

Lines changed: 44 additions & 15 deletions

File tree

wave_lang/kernel/wave/utils/symbol_utils.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from functools import lru_cache
88
import math
99
import operator as op
10+
import threading
1011
from typing import Callable, Optional
1112

1213
import sympy
@@ -67,8 +68,39 @@
6768
# ixsimpl roundtrip helper.
6869
####################################################################
6970

70-
# Module-level context reused across calls to benefit from hash-consing.
71-
_ixs_ctx = ixsimpl.Context()
71+
# Thread-local context reused across calls to benefit from hash-consing.
72+
# The C context is not thread-safe and GIL is optional since Python 3.13,
73+
# so each thread gets its own instance.
74+
_ixs_local = threading.local()
75+
76+
77+
def _get_ixs_ctx() -> ixsimpl.Context:
78+
"""Return the thread-local ixsimpl context, creating it on first use."""
79+
try:
80+
return _ixs_local.ctx
81+
except AttributeError:
82+
_ixs_local.ctx = ixsimpl.Context()
83+
return _ixs_local.ctx
84+
85+
86+
def _ixs_simplify_core(
87+
ctx: ixsimpl.Context,
88+
expr: sympy.Expr,
89+
extra_assumptions: list[ixsimpl.Expr] | None = None,
90+
) -> sympy.Expr:
91+
"""Convert *expr* through ixsimpl and simplify.
92+
93+
Raises ``ValueError``/``TypeError``/``OverflowError`` on conversion
94+
failure so callers can distinguish "unsupported expression" from
95+
"simplified but unchanged".
96+
"""
97+
ixs_expr = _conv_from_sympy(ctx, expr)
98+
assumptions = _extract_assumptions(ctx, expr)
99+
if extra_assumptions:
100+
assumptions.extend(extra_assumptions)
101+
ixs_expr = ixs_expr.simplify(assumptions=assumptions)
102+
sym_map = {s.name: s for s in expr.free_symbols}
103+
return _conv_to_sympy(ixs_expr, symbols=sym_map, xor_fn=_wave_xor)
72104

73105

74106
def ixs_simplify(
@@ -86,13 +118,7 @@ def ixs_simplify(
86118
if not isinstance(expr, sympy.Basic) or expr.is_Atom:
87119
return expr
88120
try:
89-
ixs_expr = _conv_from_sympy(_ixs_ctx, expr)
90-
assumptions = _extract_assumptions(_ixs_ctx, expr)
91-
if extra_assumptions:
92-
assumptions.extend(extra_assumptions)
93-
ixs_expr = ixs_expr.simplify(assumptions=assumptions)
94-
sym_map = {s.name: s for s in expr.free_symbols}
95-
return _conv_to_sympy(ixs_expr, symbols=sym_map, xor_fn=_wave_xor)
121+
return _ixs_simplify_core(_get_ixs_ctx(), expr, extra_assumptions)
96122
except (ValueError, TypeError, OverflowError):
97123
return expr
98124

@@ -456,16 +482,19 @@ def transform_mod_div(expr):
456482
def simplify(expr: sympy.Expr) -> sympy.Expr:
457483
"""Simplify a sympy expression via ixsimpl roundtrip.
458484
459-
Delegates to ``ixs_simplify`` which handles bounds reasoning,
485+
Delegates to ``_ixs_simplify_core`` which handles bounds reasoning,
460486
floor/Mod rewrites, and rational cancellation natively.
461-
Falls back to a sympy expand + cancel loop on conversion errors.
487+
Falls back to a sympy expand + cancel loop only when the expression
488+
cannot be converted to ixsimpl (as opposed to being converted
489+
successfully but not simplified further).
462490
"""
463491
if not isinstance(expr, sympy.Basic):
464492
return expr
465-
result = ixs_simplify(expr)
466-
if result is not expr:
467-
return result
468-
# Fallback: ixs_simplify returned expr unchanged (conversion error).
493+
try:
494+
return _ixs_simplify_core(_get_ixs_ctx(), expr)
495+
except (ValueError, TypeError, OverflowError):
496+
pass
497+
# Fallback: conversion to ixsimpl failed.
469498
expr = sympy.expand(expr)
470499
for _ in range(5):
471500
new_expr = _bounds_simplify_once(expr)

0 commit comments

Comments
 (0)