3939"""
4040
4141import logging
42- from typing import Any
42+ from typing import TYPE_CHECKING
4343
4444import numpy as np
4545
4949from sumpy .tools import add_mi , add_to_sac
5050
5151
52+ if TYPE_CHECKING :
53+ import sympy as sp
54+
5255logger = logging .getLogger (__name__ )
5356
5457
@@ -341,7 +344,7 @@ def diff(self, mi, q=0):
341344
342345# {{{ DifferentiatedExprDerivativeTaker
343346
344- DerivativeCoeffDict = dict [tuple [int , ...], Any ]
347+ DerivativeCoeffDict = dict [tuple [int , ...], int | float | complex | sym . Expr ]
345348
346349
347350@tag_dataclass
@@ -387,8 +390,10 @@ def diff(self, mi, save_intermediate=lambda x: x):
387390
388391# {{{ Helper functions
389392
390- def diff_derivative_coeff_dict (derivative_coeff_dict : DerivativeCoeffDict ,
391- variable_idx , variables ):
393+ def diff_derivative_coeff_dict (
394+ derivative_coeff_dict : DerivativeCoeffDict ,
395+ variable_idx : int ,
396+ variables : sp .Matrix ) -> DerivativeCoeffDict :
392397 """Differentiate a derivative transformation dictionary given by
393398 *derivative_coeff_dict* using the variable given by **variable_idx**
394399 and return a new derivative transformation dictionary.
@@ -397,15 +402,17 @@ def diff_derivative_coeff_dict(derivative_coeff_dict: DerivativeCoeffDict,
397402 new_derivative_coeff_dict : DerivativeCoeffDict = defaultdict (lambda : 0 )
398403
399404 for mi , coeff in derivative_coeff_dict .items ():
400- # In the case where we have x * u.diff(x), the result should
401- # be x.diff(x) + x * u.diff(x, x)
405+ # In the case where we have x * u.diff(x), the result should be
406+ # x.diff(x) + x * u.diff(x, x)
402407 # Calculate the first term by differentiating the coefficients
403408 new_coeff = sym .sympify (coeff ).diff (variables [variable_idx ])
404409 new_derivative_coeff_dict [mi ] += new_coeff
410+
405411 # Next calculate the second term by differentiating the derivatives
406412 new_mi = list (mi )
407413 new_mi [variable_idx ] += 1
408414 new_derivative_coeff_dict [tuple (new_mi )] += coeff
415+
409416 return {derivative : coeff for derivative , coeff in
410417 new_derivative_coeff_dict .items () if coeff != 0 }
411418
0 commit comments