Skip to content

Commit 8e2ea64

Browse files
committed
tests: add test for sparse corner case
1 parent 17ee8f9 commit 8e2ea64

2 files changed

Lines changed: 29 additions & 8 deletions

File tree

devito/symbolics/inspection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,8 +321,8 @@ def sympy_dtype(expr, base=None, default=None, smin=None):
321321
with suppress(AttributeError):
322322
dtypes.add(i.dtype)
323323

324-
if not dtypes:
325-
dtypes = {base} - {None}
324+
if not dtypes or not np.issubdtype(base, np.complexfloating):
325+
dtypes.update({base} - {None})
326326

327327
dtype = infer_dtype(dtypes)
328328

tests/test_interpolation.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from conftest import assert_structure
88
from devito import (
99
DefaultDimension, Dimension, Eq, Function, Grid, MatrixSparseTimeFunction, Operator,
10-
PrecomputedSparseFunction, PrecomputedSparseTimeFunction, SparseFunction,
10+
PrecomputedSparseFunction, PrecomputedSparseTimeFunction, Real, SparseFunction,
1111
SparseTimeFunction, SubDomain, TimeFunction, switchconfig
1212
)
1313
from devito.operations.interpolators import LinearInterpolator, SincInterpolator
@@ -17,6 +17,11 @@
1717
from examples.seismic.acoustic import AcousticWaveSolver, acoustic_setup
1818

1919

20+
class SparseFirst(SparseFunction):
21+
""" Custom sparse class with the sparse dimension as the first one"""
22+
_sparse_position = 0
23+
24+
2025
def unit_box(name='a', shape=(11, 11), grid=None, space_order=1):
2126
"""Create a field with value 0. to 1. in each dimension"""
2227
grid = grid or Grid(shape=shape)
@@ -698,11 +703,6 @@ def test_sparse_first():
698703
"""
699704
Tests custom sprase function with sparse dimension as first index.
700705
"""
701-
702-
class SparseFirst(SparseFunction):
703-
""" Custom sparse class with the sparse dimension as the first one"""
704-
_sparse_position = 0
705-
706706
dr = Dimension("cd")
707707
ds = DefaultDimension("ps", default_value=3)
708708
grid = Grid((11, 11))
@@ -870,6 +870,27 @@ def test_interp_complex(dtype):
870870
assert np.isclose(sc.data[0], fc.data[5, 5, 5])
871871

872872

873+
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
874+
def test_interp_complex_and_real(dtype):
875+
grid = Grid((11, 11, 11))
876+
877+
sc = SparseFunction(name="sc", grid=grid, npoint=1, dtype=dtype)
878+
sc.coordinates.data[:] = [.5, .5, .5]
879+
scre = SparseFunction(name="sce", grid=grid, npoint=1, coordinates=sc.coordinates)
880+
881+
fc = Function(name="fc", grid=grid, npoint=2, dtype=dtype)
882+
fc.data[:] = np.random.randn(*grid.shape) + 1j * np.random.randn(*grid.shape)
883+
exprs = sc.interpolate(expr=fc) + scre.interpolate(expr=Real(fc))
884+
opC = Operator(exprs, name="OpC")
885+
opC()
886+
887+
assert np.isclose(sc.data[0], fc.data[5, 5, 5])
888+
assert np.isclose(scre.data[0], fc.data[5, 5, 5].real)
889+
890+
assert_structure(opC, ['p_sc', 'p_sc,rscx,rscy,rscz', 'p_sc,rscex,rscey,rscez'],
891+
'p_sc,rscx,rscy,rscz,rscex,rscey,rscez')
892+
893+
873894
class SD0(SubDomain):
874895
name = 'sd0'
875896

0 commit comments

Comments
 (0)