Skip to content

Commit 74e7bdc

Browse files
committed
feat(typing): add annotations to diff_derivative_coeff_dict
1 parent 4692762 commit 74e7bdc

1 file changed

Lines changed: 13 additions & 6 deletions

File tree

sumpy/derivative_taker.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
"""
4040

4141
import logging
42-
from typing import Any
42+
from typing import TYPE_CHECKING
4343

4444
import numpy as np
4545

@@ -49,6 +49,9 @@
4949
from sumpy.tools import add_mi, add_to_sac
5050

5151

52+
if TYPE_CHECKING:
53+
import sympy as sp
54+
5255
logger = logging.getLogger(__name__)
5356

5457

@@ -341,7 +344,7 @@ def diff(self, mi, q=0):
341344

342345
# {{{ DifferentiatedExprDerivativeTaker
343346

344-
DerivativeCoeffDict = dict[tuple[int, ...], Any]
347+
DerivativeCoeffDict = dict[tuple[int, ...], int | float | complex | sym.Expr]
345348

346349

347350
@tag_dataclass
@@ -387,8 +390,10 @@ def diff(self, mi, save_intermediate=lambda x: x):
387390

388391
# {{{ Helper functions
389392

390-
def diff_derivative_coeff_dict(derivative_coeff_dict: DerivativeCoeffDict,
391-
variable_idx, variables):
393+
def diff_derivative_coeff_dict(
394+
derivative_coeff_dict: DerivativeCoeffDict,
395+
variable_idx: int,
396+
variables: sp.Matrix) -> DerivativeCoeffDict:
392397
"""Differentiate a derivative transformation dictionary given by
393398
*derivative_coeff_dict* using the variable given by **variable_idx**
394399
and return a new derivative transformation dictionary.
@@ -397,15 +402,17 @@ def diff_derivative_coeff_dict(derivative_coeff_dict: DerivativeCoeffDict,
397402
new_derivative_coeff_dict: DerivativeCoeffDict = defaultdict(lambda: 0)
398403

399404
for mi, coeff in derivative_coeff_dict.items():
400-
# In the case where we have x * u.diff(x), the result should
401-
# be x.diff(x) + x * u.diff(x, x)
405+
# In the case where we have x * u.diff(x), the result should be
406+
# x.diff(x) + x * u.diff(x, x)
402407
# Calculate the first term by differentiating the coefficients
403408
new_coeff = sym.sympify(coeff).diff(variables[variable_idx])
404409
new_derivative_coeff_dict[mi] += new_coeff
410+
405411
# Next calculate the second term by differentiating the derivatives
406412
new_mi = list(mi)
407413
new_mi[variable_idx] += 1
408414
new_derivative_coeff_dict[tuple(new_mi)] += coeff
415+
409416
return {derivative: coeff for derivative, coeff in
410417
new_derivative_coeff_dict.items() if coeff != 0}
411418

0 commit comments

Comments
 (0)