2525
2626import logging
2727from dataclasses import dataclass
28+ from functools import cached_property
2829from itertools import accumulate
2930from typing import TYPE_CHECKING , TypeAlias
3031
6263.. autoclass:: DerivativeIdentifier
6364.. autofunction:: make_identity_diff_op
6465.. autofunction:: as_scalar_pde
66+ .. autofunction:: to_fourier_matrix
6567"""
6668
6769
@@ -92,37 +94,43 @@ class DerivativeIdentifier:
9294class LinearPDESystemOperator :
9395 r"""
9496 Represents a constant-coefficient linear differential operator of a
95- vector-valued function with `dim` spatial variables.
97+ vector-valued function with :attr:`spatial_dim` spatial variables and
98+ additional temporal variables.
9699
97- It is represented by a tuple of immutable dictionaries. The dictionary maps a
98- :class:`DerivativeIdentifier` to the coefficient. Optionally supports a
99- time variable as the last variable in the multi-index of the
100- :class:`DerivativeIdentifier`.
100+ The operator is given by a tuple of immutable dictionaries. The dictionary
101+ maps a :class:`DerivativeIdentifier` to a (time- and space-independent)
102+ coefficient. In the :class:`DerivativeIdentifier`, each multi-index has
103+ :attr:`spatial_dim` indices for the spatial variables and the remaining
104+ ones represent temporal variables.
105+
106+ The class also supports basic arithmetic, i.e. multiplication and addition
107+ with other operators and constants.
101108
102- .. autoattribute:: dim
103109 .. autoattribute:: eqs
110+ .. autoattribute:: spatial_dim
111+ .. autoproperty:: total_dims
104112
105113 .. autoproperty:: order
106- .. autoproperty:: total_dims
114+ .. autoproperty:: nvariables
115+ .. autoproperty:: is_time_dependent
116+
107117 .. automethod:: to_sym
108118 """
109119
110- dim : int
120+ spatial_dim : int
121+ """The number of spatial dimensions of the PDE (use :attr:`total_dims`
122+ to include time).
123+ """
111124 eqs : tuple [Mapping [DerivativeIdentifier , sym .Expr ], ...]
125+ """A tuple of all the equations in the system."""
112126
113127 if __debug__ :
114128
115129 def __post_init__ (self ) -> None :
116130 # NOTE: this will raise a TypeError if it's not hashable
117131 _ = hash (self )
118132
119- @property
120- def order (self ) -> int :
121- deg = 0
122- for eq in self .eqs :
123- deg = max (deg , max (sum (ident .mi ) for ident in eq ))
124-
125- return deg
133+ # {{{ arithmetic
126134
127135 def __mul__ (self , other : Number | sym .Expr ) -> LinearPDESystemOperator :
128136 import numbers
@@ -137,7 +145,7 @@ def __mul__(self, other: Number | sym.Expr) -> LinearPDESystemOperator:
137145
138146 eqs .append (constantdict (deriv_ident_to_coeff ))
139147
140- return LinearPDESystemOperator (self .dim , tuple (eqs ))
148+ return LinearPDESystemOperator (self .spatial_dim , tuple (eqs ))
141149
142150 def __rmul__ (self , param : Number | sym .Expr ) -> LinearPDESystemOperator :
143151 return self .__mul__ (param )
@@ -146,7 +154,7 @@ def __add__(self, other: LinearPDESystemOperator) -> LinearPDESystemOperator:
146154 if not isinstance (other , LinearPDESystemOperator ):
147155 return NotImplemented
148156
149- assert self .dim == other .dim
157+ assert self .spatial_dim == other .spatial_dim
150158 assert len (self .eqs ) == len (other .eqs )
151159
152160 eqs : list [Mapping [DerivativeIdentifier , sym .Expr ]] = []
@@ -159,7 +167,7 @@ def __add__(self, other: LinearPDESystemOperator) -> LinearPDESystemOperator:
159167 res [k ] = v
160168 eqs .append (constantdict (res ))
161169
162- return LinearPDESystemOperator (self .dim , tuple (eqs ))
170+ return LinearPDESystemOperator (self .spatial_dim , tuple (eqs ))
163171
164172 def __radd__ (self , other : LinearPDESystemOperator ) -> LinearPDESystemOperator :
165173 return self .__add__ (other )
@@ -170,39 +178,79 @@ def __sub__(self, other: LinearPDESystemOperator) -> LinearPDESystemOperator:
170178 def __neg__ (self ) -> LinearPDESystemOperator :
171179 return (- 1 ) * self
172180
181+ # }}}
182+
173183 @override
174184 def __repr__ (self ) -> str :
175- return f"LinearPDESystemOperator({ self .dim } , { self .eqs !r} )"
185+ return f"LinearPDESystemOperator({ self .spatial_dim } , { self .eqs !r} )"
176186
177187 def __getitem__ (self , idx : int | slice ) -> LinearPDESystemOperator :
178188 item = self .eqs .__getitem__ (idx )
179189 eqs = item if isinstance (item , tuple ) else (item ,)
180- return LinearPDESystemOperator (self .dim , eqs )
190+ return LinearPDESystemOperator (self .spatial_dim , eqs )
181191
182192 @property
193+ def is_time_dependent (self ) -> bool :
194+ """Is *True* if the PDE operator has a time component."""
195+ return self .spatial_dim != self .total_dims
196+
197+ @cached_property
198+ def order (self ) -> int :
199+ """The order of the PDE operator (maximum order of all derivatives)."""
200+ deg = 0
201+ for eq in self .eqs :
202+ deg = max (deg , max (sum (ident .mi ) for ident in eq ))
203+
204+ return deg
205+
206+ @cached_property
183207 def total_dims (self ) -> int :
184- """
185- Returns the total number of dimensions including time
186- """
187- did = next (iter (self .eqs [0 ].keys ()))
208+ """The total number of dimensions (including time)."""
209+ did = next (iter (self .eqs [0 ]))
188210 return len (did .mi )
189211
190- def to_sym (self , fnames : Sequence [str ] | None = None ) -> list [sym .Expr ]:
191- x : list [sym .Expr ] = list (sym .make_sym_vector ("x" , self .dim ))
192- x .extend (sym .make_sym_vector ("t" , self .total_dims - self .dim ))
212+ @cached_property
213+ def nvariables (self ) -> int :
214+ """Number of variables in the system."""
215+ max_vec_idx = max ((did .vec_idx for eq in self .eqs for did in eq ), default = - 1 )
216+ return max_vec_idx + 1
217+
218+ def to_sym (
219+ self ,
220+ fnames : Sequence [str ] | None = None ,
221+ * ,
222+ x_var_name : str = "x" ,
223+ t_var_name : str = "t" ,
224+ ) -> list [sym .Expr ]:
225+ """Transform the system to a list of :mod:`sympy` expressions.
226+
227+ :arg fnames: the names of the variables in the system.
228+ (defaults to `["f0", "f1", ....]`)
229+ :arg x_var_name: the name of the spatial variables.
230+ (defaults to `["x0", "x1", ....]`)
231+ :arg t_var_name: the name of the temporal variables.
232+ """
233+ x : list [sym .Expr ] = list (sym .make_sym_vector (x_var_name , self .spatial_dim ))
234+ x .extend (sym .make_sym_vector (t_var_name , self .total_dims - self .spatial_dim ))
193235
194236 if fnames is None :
195237 noutputs = 0
196238 for eq in self .eqs :
197239 for deriv_ident in eq :
198240 noutputs = max (noutputs , deriv_ident .vec_idx )
241+
199242 fnames = [f"f{ i } " for i in range (noutputs + 1 )]
200243
201244 funcs = [sym .Function (fname )(* x ) for fname in fnames ]
245+ if len (funcs ) < self .nvariables :
246+ raise ValueError (
247+ f"'fnames' does not match system: { len (fnames )} names "
248+ f"(for a system of { self .nvariables } variables)"
249+ )
202250
203251 res : list [sym .Expr ] = []
204252 for eq in self .eqs :
205- sym_eq : sym .Expr = sym .sympify (0 )
253+ sym_eq : sym .Expr = sym .Integer (0 )
206254 for deriv_ident , coeff in eq .items ():
207255 expr = funcs [deriv_ident .vec_idx ]
208256 for i , val in enumerate (deriv_ident .mi ):
@@ -237,8 +285,8 @@ def _get_all_scalar_pdes(pde: LinearPDESystemOperator) -> list[LinearPDESystemOp
237285 import sympy as sp
238286 from sympy .polys .orderings import grevlex
239287
240- gens = [sp .Symbol (f"_x{ i } " ) for i in range (pde .dim )]
241- gens += [sp .Symbol (f"_t{ i } " ) for i in range (pde .total_dims - pde .dim )]
288+ gens = [sp .Symbol (f"_x{ i } " ) for i in range (pde .spatial_dim )]
289+ gens += [sp .Symbol (f"_t{ i } " ) for i in range (pde .total_dims - pde .spatial_dim )]
242290
243291 max_vec_idx = max (deriv_ident .vec_idx for eq in pde .eqs
244292 for deriv_ident in eq )
@@ -303,7 +351,8 @@ def intersect(a: SubModulePolyRing, b: SubModulePolyRing) -> SubModulePolyRing:
303351 for (mi , coeff ) in zip (scalar_pde .monoms (),
304352 scalar_pde .coeffs (), strict = True )
305353 }
306- results .append (LinearPDESystemOperator (pde .dim , (constantdict (pde_dict ),)))
354+ results .append (LinearPDESystemOperator (pde .spatial_dim ,
355+ (constantdict (pde_dict ),)))
307356
308357 return results
309358
@@ -382,8 +431,44 @@ def as_scalar_pde(
382431 return _get_all_scalar_pdes (pde )[comp_idx ]
383432
384433
434+ def to_fourier_matrix (
435+ pde : LinearPDESystemOperator ,
436+ ks : sym .Matrix ,
437+ ) -> sym .Matrix :
438+ r"""Return the Fourier (symbol) matrix of a constant-coefficient PDE system.
439+
440+ Each spatial derivative :math:`\partial / \partial x_j` is replaced by
441+ multiplication by :math:`i\,k_j`. The result is a
442+ :obj:`sympy.matrices.dense.Matrix` whose ``(row, col)`` entry is the
443+ polynomial in the frequency variables contributed by equation *row* acting
444+ on component *col*.
445+
446+ :returns: a matrix of size ``(len(pde.eqs), nvariables)``.
447+ """
448+ if pde .is_time_dependent :
449+ raise ValueError ("cannot compute Fourier symbol for time-dependent PDEs" )
450+
451+ ncols = pde .nvariables
452+
453+ mat = []
454+ for eq in pde .eqs :
455+ row = [sym .Integer (0 )] * ncols
456+
457+ for deriv_ident , coeff in eq .items ():
458+ factor : sym .Expr = sym .Integer (1 )
459+ for j , power in enumerate (deriv_ident .mi ):
460+ factor *= (sym .I * ks [j ]) ** power
461+
462+ row [deriv_ident .vec_idx ] += coeff * factor
463+
464+ assert len (row ) == ncols
465+ mat .append (row )
466+
467+ return sym .Matrix (mat )
468+
469+
385470def laplacian (diff_op : LinearPDESystemOperator ) -> LinearPDESystemOperator :
386- dim = diff_op .dim
471+ dim = diff_op .spatial_dim
387472 empty : tuple [Mapping [DerivativeIdentifier , sym .Expr ], ...] = (
388473 (constantdict (),) * len (diff_op .eqs ))
389474
@@ -409,17 +494,17 @@ def diff(
409494
410495 eqs .append (constantdict (res ))
411496
412- return LinearPDESystemOperator (diff_op .dim , tuple (eqs ))
497+ return LinearPDESystemOperator (diff_op .spatial_dim , tuple (eqs ))
413498
414499
415500def divergence (diff_op : LinearPDESystemOperator ) -> LinearPDESystemOperator :
416- if len (diff_op .eqs ) != diff_op .dim :
501+ if len (diff_op .eqs ) != diff_op .spatial_dim :
417502 raise ValueError (
418503 "number of equations does not match system dimension: "
419- f"got { len (diff_op .eqs )} equations for { diff_op .dim } d system" )
504+ f"got { len (diff_op .eqs )} equations for { diff_op .spatial_dim } d system" )
420505
421- res = LinearPDESystemOperator (diff_op .dim , (constantdict (),))
422- for i in range (diff_op .dim ):
506+ res = LinearPDESystemOperator (diff_op .spatial_dim , (constantdict (),))
507+ for i in range (diff_op .spatial_dim ):
423508 mi = [0 ]* diff_op .total_dims
424509 mi [i ] = 1
425510 res += diff (diff_op [i ], tuple (mi ))
@@ -433,7 +518,7 @@ def gradient(diff_op: LinearPDESystemOperator) -> LinearPDESystemOperator:
433518 f"can only take gradient of scalar system: got { len (diff_op .eqs )} d" )
434519
435520 eqs : list [Mapping [DerivativeIdentifier , sym .Expr ]] = []
436- dim = diff_op .dim
521+ dim = diff_op .spatial_dim
437522 for i in range (dim ):
438523 mi = [0 ]* diff_op .total_dims
439524 mi [i ] = 1
@@ -443,13 +528,13 @@ def gradient(diff_op: LinearPDESystemOperator) -> LinearPDESystemOperator:
443528
444529
445530def curl (diff_op : LinearPDESystemOperator ) -> LinearPDESystemOperator :
446- if len (diff_op .eqs ) != diff_op .dim :
531+ if len (diff_op .eqs ) != diff_op .spatial_dim :
447532 raise ValueError (
448533 "number of equations does not match system dimension: "
449- f"got { len (diff_op .eqs )} equations for { diff_op .dim } d system" )
534+ f"got { len (diff_op .eqs )} equations for { diff_op .spatial_dim } d system" )
450535
451- if diff_op .dim != 3 :
452- raise ValueError (f"can only take curl of 3d system: got { diff_op .dim } d" )
536+ if diff_op .spatial_dim != 3 :
537+ raise ValueError (f"can only take curl of 3d system: got { diff_op .spatial_dim } d" )
453538
454539 eqs : list [Mapping [DerivativeIdentifier , sym .Expr ]] = []
455540 mis : list [MultiIndex ] = []
@@ -464,7 +549,7 @@ def curl(diff_op: LinearPDESystemOperator) -> LinearPDESystemOperator:
464549 - diff (diff_op [(i + 1 ) % 3 ], mis [(i + 2 ) % 3 ]))
465550 eqs .append (new_pde .eqs [0 ])
466551
467- return LinearPDESystemOperator (diff_op .dim , tuple (eqs ))
552+ return LinearPDESystemOperator (diff_op .spatial_dim , tuple (eqs ))
468553
469554
470555def concat (* ops : LinearPDESystemOperator ) -> LinearPDESystemOperator :
@@ -474,8 +559,8 @@ def concat(*ops: LinearPDESystemOperator) -> LinearPDESystemOperator:
474559 if len (ops ) == 1 :
475560 return ops [0 ]
476561
477- dim = ops [0 ].dim
478- if not all (op .dim == dim for op in ops ):
562+ dim = ops [0 ].spatial_dim
563+ if not all (op .spatial_dim == dim for op in ops ):
479564 raise ValueError (f"operators must have the same dimension (expected { dim } d)" )
480565
481566 eqs = list (ops [0 ].eqs )
0 commit comments