Skip to content

Commit 7a4687b

Browse files
authored
Merge pull request #332 from DedalusProject/cardinal
Cardinal basis
2 parents 3c1c846 + d0b816f commit 7a4687b

5 files changed

Lines changed: 461 additions & 5 deletions

File tree

dedalus/core/basis.py

Lines changed: 162 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626

2727

2828
# Public interface
29-
__all__ = ['Jacobi',
29+
__all__ = ['CardinalBasis',
30+
'Jacobi',
3031
'Legendre',
3132
'Ultraspherical',
3233
'Chebyshev',
@@ -334,10 +335,166 @@ def enum_indices(tensorsig):
334335
# self._ncc_matrices = [self._ncc_matrix_recursion(ncc.data[ind], ncc.domain.full_bases, operand.domain.full_bases, separability, **kw) for ind in np.ndindex(*tshape)]
335336

336337

337-
class IntervalBasis(Basis):
338+
class CardinalBasis(Basis):
339+
"""Cardinal basis."""
338340

339341
dim = 1
342+
group_shape = (1,)
343+
subaxis_dependence = [False]
344+
345+
def __init__(self, coord, size):
346+
self.coord = coord
347+
self.coordsys = coord
348+
self.size = size
349+
self.shape = (size,)
350+
self.dealias = (1,)
351+
super().__init__(coord)
352+
353+
def __add__(self, other):
354+
if other is None or other is self:
355+
return self
356+
return NotImplemented
357+
358+
def __mul__(self, other):
359+
if other is None or other is self:
360+
return self
361+
return NotImplemented
362+
363+
def __rmatmul__(self, other):
364+
# NCC (other) * operand (self)
365+
if other is None or other is self:
366+
return self
367+
return NotImplemented
368+
369+
def elements_to_groups(self, grid_space, elements):
370+
# No permutations
371+
return elements
372+
373+
def valid_elements(self, tensorsig, grid_space, elements):
374+
# No invalid modes
375+
vshape = tuple(cs.dim for cs in tensorsig) + elements[0].shape
376+
return np.ones(shape=vshape, dtype=bool)
377+
378+
def matrix_dependence(self, matrix_coupling):
379+
return matrix_coupling
380+
381+
def global_grids(self, dist, scales):
382+
"""Global grids."""
383+
return (self.global_grid(dist, scales[0]),)
384+
385+
def global_grid(self, dist, scale):
386+
"""Global grid."""
387+
if scale != 1:
388+
raise NotImplementedError("Cardinal basis only supports scale=1.")
389+
return np.arange(self.size)
390+
391+
def local_grids(self, dist, scales):
392+
"""Local grids."""
393+
return (self.local_grid(dist, scales[0]),)
394+
395+
def local_grid(self, dist, scale):
396+
"""Local grid."""
397+
if scale != 1:
398+
raise NotImplementedError("Cardinal basis only supports scale=1.")
399+
local_elements = dist.grid_layout.local_elements(self.domain(dist), scales=scale)
400+
return np.arange(self.size)[local_elements[dist.get_basis_axis(self)]]
401+
402+
def local_modes(self, dist):
403+
"""Local grid."""
404+
local_elements = dist.coeff_layout.local_elements(self.domain(dist), scales=1)
405+
return reshape_vector(local_elements[dist.get_basis_axis(self)], dim=dist.dim, axis=dist.get_basis_axis(self))
406+
407+
def global_shape(self, grid_space, scales):
408+
return self.shape
409+
410+
def chunk_shape(self, grid_space):
411+
return (1,)
412+
413+
def forward_transform(self, field, axis, gdata, cdata):
414+
"""Forward transform field data."""
415+
np.copyto(cdata, gdata)
416+
417+
def backward_transform(self, field, axis, cdata, gdata):
418+
"""Backward transform field data."""
419+
np.copyto(gdata, cdata)
420+
421+
422+
class ConvertConstantCardinal(operators.ConvertConstant, operators.SpectralOperator1D):
423+
"""Convert constant to Cardinal basis."""
424+
425+
output_basis_type = CardinalBasis
426+
subaxis_dependence = [True]
427+
subaxis_coupling = [True]
428+
429+
@staticmethod
430+
def _full_matrix(input_basis, output_basis):
431+
return np.ones((output_basis.size, 1))
432+
433+
434+
class InterpolateCardinal(operators.Interpolate, operators.SpectralOperator1D):
435+
"""Interpolate Cardinal basis."""
436+
437+
input_basis_type = CardinalBasis
438+
basis_subaxis = 0
439+
subaxis_dependence = [True]
440+
subaxis_coupling = [True]
441+
442+
def __init__(self, operand, coord, position, out=None):
443+
if not isinstance(position, (int, np.integer)):
444+
raise TypeError("Cardinal interpolation position must be an integer")
445+
super().__init__(operand, coord, position, out=out)
446+
447+
@staticmethod
448+
def _output_basis(input_basis, position):
449+
return None
450+
451+
@staticmethod
452+
def _full_matrix(input_basis, output_basis, position):
453+
interp_vector = np.zeros(input_basis.size)
454+
interp_vector[position] = 1
455+
return interp_vector[None, :]
456+
457+
458+
class IntegrateCardinal(operators.Integrate, operators.SpectralOperator1D):
459+
"""Cardinal basis integration."""
460+
461+
input_coord_type = Coordinate
462+
input_basis_type = CardinalBasis
463+
subaxis_dependence = [True]
464+
subaxis_coupling = [True]
465+
466+
@staticmethod
467+
def _output_basis(input_basis):
468+
return None
469+
470+
@staticmethod
471+
def _full_matrix(input_basis, output_basis):
472+
integ_vector = np.ones(input_basis.size)
473+
return integ_vector[None, :]
474+
475+
476+
class AverageCardinal(operators.Average, operators.SpectralOperator1D):
477+
"""Cardinal basis averaging."""
478+
479+
input_coord_type = Coordinate
480+
input_basis_type = CardinalBasis
340481
subaxis_dependence = [True]
482+
subaxis_coupling = [True]
483+
484+
@staticmethod
485+
def _output_basis(input_basis):
486+
return None
487+
488+
@staticmethod
489+
def _full_matrix(input_basis, output_basis):
490+
ave_vector = np.ones(input_basis.size) / input_basis.size
491+
return ave_vector[None, :]
492+
493+
494+
class IntervalBasis(Basis):
495+
496+
dim = 1
497+
subaxis_dependence = [False]
341498

342499
def __init__(self, coord, size, bounds, dealias):
343500
self.coord = coord
@@ -6084,15 +6241,16 @@ def cfl_spacing(self):
60846241
velocity = self.operand
60856242
coordsys = velocity.tensorsig[0]
60866243
spacing = []
6087-
for i, c in enumerate(coordsys.coords):
6244+
for c in coordsys.coords:
60886245
basis = velocity.domain.get_basis(c)
60896246
if basis:
60906247
dealias = basis.dealias[0]
60916248
axis_spacing = basis.local_grid_spacing(self.dist, dealias) * dealias
60926249
N = basis.grid_shape((dealias,))[0]
60936250
if isinstance(basis, Jacobi) and basis.a == -1/2 and basis.b == -1/2:
60946251
#Special case for ChebyshevT (a=b=-1/2)
6095-
local_elements = self.dist.grid_layout.local_elements(basis.domain(self.dist), scales=dealias)[i]
6252+
axis = self.dist.get_basis_axis(basis)
6253+
local_elements = self.dist.grid_layout.local_elements(basis.domain(self.dist), scales=dealias)[axis]
60966254
i = np.arange(N)[local_elements].reshape(axis_spacing.shape)
60976255
theta = np.pi * (i + 1/2) / N
60986256
axis_spacing[:] = dealias * basis.COV.stretch * np.sin(theta) * np.pi / N
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""Test Cardinal basis operators: ConvertConstant, Interpolate, Integrate, Average."""
2+
3+
import pytest
4+
import numpy as np
5+
import dedalus.public as d3
6+
from dedalus.tools.cache import CachedMethod
7+
8+
9+
N_range = [5, 10]
10+
dtype_range = [np.float64, np.complex128]
11+
12+
13+
@CachedMethod
14+
def build_cardinal(N, dtype):
15+
c = d3.Coordinate('n')
16+
dist = d3.Distributor(c, dtype=dtype)
17+
b = d3.CardinalBasis(c, size=N)
18+
n = dist.local_grid(b, scale=1)
19+
return c, dist, b, n
20+
21+
22+
@pytest.mark.parametrize('N', N_range)
23+
@pytest.mark.parametrize('dtype', dtype_range)
24+
@pytest.mark.parametrize('layout', ['g', 'c'])
25+
def test_cardinal_convert_constant(N, dtype, layout):
26+
"""Test conversion from constant to Cardinal basis (broadcasts scalar to all N entries)."""
27+
c, dist, b, n = build_cardinal(N, dtype)
28+
f = dist.Field()
29+
f['g'] = 3
30+
f.change_layout(layout)
31+
g = d3.Convert(f, b).evaluate()
32+
assert np.allclose(g['g'], 3 * np.ones(N))
33+
34+
35+
@pytest.mark.parametrize('N', N_range)
36+
@pytest.mark.parametrize('dtype', dtype_range)
37+
@pytest.mark.parametrize('index', [0, 2, -1])
38+
def test_cardinal_interpolate(N, dtype, index):
39+
"""Test Interpolate extracts the value at the given integer index."""
40+
c, dist, b, n = build_cardinal(N, dtype)
41+
f = dist.Field(bases=b)
42+
f.fill_random('g')
43+
g = d3.Interpolate(f, c, index).evaluate()
44+
assert np.allclose(g['g'], f['g'][index])
45+
46+
47+
@pytest.mark.parametrize('N', N_range)
48+
@pytest.mark.parametrize('dtype', dtype_range)
49+
def test_cardinal_integrate(N, dtype):
50+
"""Test Integrate computes the discrete sum over all entries."""
51+
c, dist, b, n = build_cardinal(N, dtype)
52+
f = dist.Field(bases=b)
53+
f.fill_random('g')
54+
g = d3.Integrate(f, c).evaluate()
55+
assert np.allclose(g['g'], f['g'].sum())
56+
57+
58+
@pytest.mark.parametrize('N', N_range)
59+
@pytest.mark.parametrize('dtype', dtype_range)
60+
def test_cardinal_integrate_constant(N, dtype):
61+
"""Test Integrate of a uniform field equals N * value."""
62+
c, dist, b, n = build_cardinal(N, dtype)
63+
f = dist.Field(bases=b)
64+
f['g'] = 3
65+
g = d3.Integrate(f, c).evaluate()
66+
assert np.allclose(g['g'], 3 * N)
67+
68+
69+
@pytest.mark.parametrize('N', N_range)
70+
@pytest.mark.parametrize('dtype', dtype_range)
71+
def test_cardinal_average(N, dtype):
72+
"""Test Average computes the mean over all entries."""
73+
c, dist, b, n = build_cardinal(N, dtype)
74+
f = dist.Field(bases=b)
75+
f.fill_random('g')
76+
g = d3.Average(f, c).evaluate()
77+
assert np.allclose(g['g'], f['g'].mean())
78+
79+
80+
@pytest.mark.parametrize('N', N_range)
81+
@pytest.mark.parametrize('dtype', dtype_range)
82+
def test_cardinal_average_constant(N, dtype):
83+
"""Test Average of a uniform field returns that value."""
84+
c, dist, b, n = build_cardinal(N, dtype)
85+
f = dist.Field(bases=b)
86+
f['g'] = 7
87+
g = d3.Average(f, c).evaluate()
88+
assert np.allclose(g['g'], 7)
89+

docs/notebooks/dedalus_tutorial_1.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,10 @@
246246
"\n",
247247
"* `RealFourier` for real periodic functions on an interval using cosine & sine modes.\n",
248248
"* `ComplexFourier` for complex periodic functions on an interval using complex exponentials.\n",
249-
"* `Chebyshev` for functions on an interval.\n",
249+
"* `Chebyshev` for functions on an interval with fast tranforms (DCTs).\n",
250+
"* `Legendre` for functions on an interval with L2 orthogonality.\n",
250251
"* `Jacobi` for functions on an interval under a more general inner product (usually `Chebyshev` is best for performance).\n",
252+
"* `CardinalBasis` for a finite set of discrete elements (e.g., ensemble members or parameter sweeps).\n",
251253
"* `DiskBasis` for functions on a full disk in polar coordinates.\n",
252254
"* `AnnulusBasis` for functions on an annulus in polar coordinates.\n",
253255
"* `SphereBasis` for functions on the 2-sphere in S2 or spherical coordinates.\n",

0 commit comments

Comments
 (0)