Skip to content

Commit 6039cec

Browse files
authored
Add decorator for sympy monkey patching (#3049)
Add decorator for sympy monkey patching and cover additional functions where this may be relevant.
1 parent 72a7a92 commit 6039cec

3 files changed

Lines changed: 27 additions & 9 deletions

File tree

python/sdist/amici/importers/sbml/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,11 @@
6464
toposort_symbols,
6565
)
6666
from amici.logging import get_logger, log_execution_time, set_log_level
67-
from amici.sympy_utils import smart_is_zero_matrix, smart_multiply
67+
from amici.sympy_utils import (
68+
_monkeypatch_sympy,
69+
smart_is_zero_matrix,
70+
smart_multiply,
71+
)
6872

6973
SymbolicFormula = dict[sp.Symbol, sp.Expr]
7074

@@ -537,6 +541,7 @@ def sbml2jax(
537541
)
538542
exporter.generate_model_code()
539543

544+
@_monkeypatch_sympy
540545
def _build_ode_model(
541546
self,
542547
fixed_parameters: Iterable[str] = None,

python/sdist/amici/jax/ode_export.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929
from amici.jax.nn import generate_equinox
3030
from amici.logging import get_logger, log_execution_time, set_log_level
3131
from amici.sympy_utils import (
32-
_custom_pow_eval_derivative,
33-
_monkeypatched,
32+
_monkeypatch_sympy,
3433
)
3534

3635
#: python log manager
@@ -168,17 +167,15 @@ def __init__(
168167

169168
self._code_printer = AmiciJaxCodePrinter()
170169

170+
@_monkeypatch_sympy
171171
@log_execution_time("generating jax code", logger)
172172
def generate_model_code(self) -> None:
173173
"""
174174
Generates the jax code for the loaded model
175175
"""
176-
with _monkeypatched(
177-
sp.Pow, "_eval_derivative", _custom_pow_eval_derivative
178-
):
179-
self._prepare_model_folder()
180-
self._generate_jax_code()
181-
self._generate_nn_code()
176+
self._prepare_model_folder()
177+
self._generate_jax_code()
178+
self._generate_nn_code()
182179

183180
def _prepare_model_folder(self) -> None:
184181
"""

python/sdist/amici/sympy_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import os
66
from collections.abc import Callable
7+
from functools import wraps
78
from itertools import starmap
89
from typing import Any
910

@@ -62,6 +63,21 @@ def _monkeypatched(obj: object, name: str, patch: Any):
6263
setattr(obj, name, pre_patched_value)
6364

6465

66+
def _monkeypatch_sympy(func):
67+
"""
68+
Decorator that temporarily monkeypatches sympy.Pow._eval_derivative.
69+
"""
70+
71+
@wraps(func)
72+
def wrapper(*args, **kwargs):
73+
with _monkeypatched(
74+
sp.Pow, "_eval_derivative", _custom_pow_eval_derivative
75+
):
76+
return func(*args, **kwargs)
77+
78+
return wrapper
79+
80+
6581
@log_execution_time("running smart_jacobian", logger)
6682
def smart_jacobian(
6783
eq: sp.MutableDenseMatrix, sym_var: sp.MutableDenseMatrix

0 commit comments

Comments
 (0)