diff --git a/python/sdist/amici/import_utils.py b/python/sdist/amici/import_utils.py index 4502e3d0b0..26c05feee0 100644 --- a/python/sdist/amici/import_utils.py +++ b/python/sdist/amici/import_utils.py @@ -531,6 +531,11 @@ def _parse_heaviside_trigger(trigger: sp.Expr) -> sp.Expr: # y >= x => not(x < y) => not(x - y < 0) => not r < 0 return sp.Heaviside(root) + # rewrite n-ary XOR to OR to be handled below: + trigger = trigger.replace(sp.Xor, _xor_to_or) + + # TODO: x == y + # or(x,y) = not(and(not(x),not(y)) if isinstance(trigger, sp.Or): return sp.Integer(1) - sp.Mul( @@ -540,8 +545,6 @@ def _parse_heaviside_trigger(trigger: sp.Expr) -> sp.Expr: ] ) - # TODO: x XOR y = (A & ~B) | (~A & B) - # TODO: x == y if isinstance(trigger, sp.And): return sp.Mul(*[_parse_heaviside_trigger(arg) for arg in trigger.args]) @@ -551,6 +554,25 @@ def _parse_heaviside_trigger(trigger: sp.Expr) -> sp.Expr: ) +def _xor_to_or(*args): + """ + Replace XOR by OR expression. + + ``xor(x, y, ...) = (x & ~y & ...) | (~x & y & ...) | ...``. + + to be used in ``trigger = trigger.replace(sp.Xor, _xor_to_or)``. + """ + res = sp.false + for i in range(len(args)): + res = sp.Or( + res, + sp.And( + *(arg if i == j else sp.Not(arg) for j, arg in enumerate(args)) + ), + ) + return res.simplify() + + def grouper( iterable: Iterable, n: int, fillvalue: Any = None ) -> Iterable[tuple[Any]]: diff --git a/python/sdist/amici/sbml_import.py b/python/sdist/amici/sbml_import.py index 4ad578098e..011c0ecb14 100644 --- a/python/sdist/amici/sbml_import.py +++ b/python/sdist/amici/sbml_import.py @@ -52,6 +52,7 @@ _default_simplify, generate_flux_symbol, _parse_piecewise_to_heaviside, + _xor_to_or, ) from .logging import get_logger, log_execution_time, set_log_level from .sbml_utils import SBMLException @@ -2976,11 +2977,6 @@ def subs_locals(expr: sp.Basic) -> sp.Basic: f"Unsupported input: {var_or_math}, type: {type(var_or_math)}" ) - if expr.has(sp.Xor): - raise SBMLException( - "Xor is currently not supported as logical operation." - ) - try: _check_unsupported_functions_sbml(expr, expression_type=ele_name) except SBMLException: @@ -3218,6 +3214,9 @@ def _parse_event_trigger(trigger: sp.Expr) -> sp.Expr: # y >= x or y > x return root + # rewrite n-ary XOR to OR to be handled below: + trigger = trigger.replace(sp.Xor, _xor_to_or) + # or(x,y): any of {x,y} is > 0: sp.Max(x, y) if isinstance(trigger, sp.Or): return sp.Max(*[_parse_event_trigger(arg) for arg in trigger.args])