Skip to content

Commit 0e760e8

Browse files
authored
Merge pull request #2912 from devitocodes/misc-sparse-fix
compiler: misc sparse bug fixes
2 parents 3d5bb25 + a54b3c6 commit 0e760e8

9 files changed

Lines changed: 127 additions & 62 deletions

File tree

devito/operations/interpolators.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414
from devito.logger import warning
1515
from devito.symbolics import INT, retrieve_function_carriers, retrieve_functions
1616
from devito.tools import Pickable, as_tuple, filter_ordered, flatten, memoized_meth
17-
from devito.types import (
18-
ConditionalDimension, CustomDimension, Eq, Evaluable, Inc, SubFunction, Symbol
19-
)
17+
from devito.types import Eq, Evaluable, Inc, SubFunction, Symbol
2018
from devito.types.utils import DimensionTuple
2119

2220
__all__ = ['LinearInterpolator', 'PrecomputedInterpolator', 'SincInterpolator']
@@ -239,13 +237,10 @@ def _weights(self, subdomain=None):
239237
def _gdims(self):
240238
return self.grid.dimensions
241239

242-
@cached_property
240+
@property
243241
def _cdim(self):
244242
"""Base CustomDimensions used to construct _rdim"""
245-
parent = self.sfunction._sparse_dim
246-
dims = [CustomDimension(f"r{self.sfunction.name}{d.name}",
247-
-self.r+1, self.r, 2*self.r, parent)
248-
for d in self._gdims]
243+
dims = [self.sfunction._crdim(d) for d in self._gdims]
249244
return dims
250245

251246
@memoized_meth
@@ -274,8 +269,7 @@ def _rdim(self, subdomain=None):
274269
rank_populated = subdomain.distributor.rank_populated
275270
cond = sympy.And(rank_populated, cond)
276271

277-
rdims.append(ConditionalDimension(rd.name, rd, condition=cond,
278-
indirect=True))
272+
rdims.append(self.sfunction._cond_rdim(d.root, cond))
279273

280274
return DimensionTuple(*rdims, getters=gdims)
281275

@@ -425,7 +419,7 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None
425419
subdomain=subdomain)
426420

427421
# Accumulate point-wise contributions into a temporary
428-
rhs = Symbol(name='sum', dtype=self.sfunction.dtype)
422+
rhs = Symbol(name=f'sum{self.sfunction.name}', dtype=self.sfunction.dtype)
429423
summands = [Eq(rhs, 0., implicit_dims=implicit_dims)]
430424
# Substitute coordinate base symbols into the interpolation coefficients
431425
weights = self._weights(subdomain=subdomain)

devito/passes/iet/languages/CXX.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ def std_arith(prefix=''):
1919
prefix = prefix if prefix.endswith(' ') else f'{prefix} '
2020
return f"""
2121
#include <complex>
22+
#include <type_traits>
23+
24+
// ---- scalar <op> complex<T> (scalar promoted to T) --------------------
2225
2326
template<typename _Tp, typename _Ti>
2427
{prefix}std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){{
@@ -32,7 +35,7 @@ def std_arith(prefix=''):
3235
3336
template<typename _Tp, typename _Ti>
3437
{prefix}std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){{
35-
_Tp denom = b.real() * b.real () + b.imag() * b.imag();
38+
_Tp denom = b.real() * b.real() + b.imag() * b.imag();
3639
return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom);
3740
}}
3841
@@ -53,14 +56,37 @@ def std_arith(prefix=''):
5356
5457
template<typename _Tp, typename _Ti>
5558
{prefix}std::complex<_Tp> operator - (const _Ti & a, const std::complex<_Tp> & b){{
56-
return std::complex<_Tp>(a = b.real(), b.imag());
59+
return std::complex<_Tp>(a - b.real(), -b.imag());
5760
}}
5861
5962
template<typename _Tp, typename _Ti>
6063
{prefix}std::complex<_Tp> operator - (const std::complex<_Tp> & b, const _Ti & a){{
6164
return std::complex<_Tp>(b.real() - a, b.imag());
6265
}}
6366
67+
// ---- mixed-precision complex<T1> <op> complex<T2> ----------------------
68+
// Promote both sides to std::complex<common_type_t<T1,T2>> and delegate to
69+
// the standard library's same-type operator. The enable_if disables the
70+
// overload when T1 == T2 so we don't collide with std::complex's own ops.
71+
72+
#define _MIXED_COMPLEX_OP(OP) \\
73+
template<typename _Tp1, typename _Tp2, \\
74+
typename _Tr = std::common_type_t<_Tp1, _Tp2>, \\
75+
typename = std::enable_if_t<!std::is_same<_Tp1, _Tp2>::value>> \\
76+
{prefix}std::complex<_Tr> \\
77+
operator OP (const std::complex<_Tp1> & a, \\
78+
const std::complex<_Tp2> & b) {{ \\
79+
return std::complex<_Tr>(a.real(), a.imag()) \\
80+
OP std::complex<_Tr>(b.real(), b.imag()); \\
81+
}}
82+
83+
_MIXED_COMPLEX_OP(*)
84+
_MIXED_COMPLEX_OP(/)
85+
_MIXED_COMPLEX_OP(+)
86+
_MIXED_COMPLEX_OP(-)
87+
88+
#undef _MIXED_COMPLEX_OP
89+
6490
"""
6591

6692

devito/symbolics/inspection.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,14 @@ def sympy_dtype(expr, base=None, default=None, smin=None):
316316
if expr is None:
317317
return default
318318

319-
dtypes = {base} - {None}
319+
dtypes = set()
320320
for i in expr.free_symbols:
321321
with suppress(AttributeError):
322322
dtypes.add(i.dtype)
323323

324+
if not dtypes or not np.issubdtype(base, np.complexfloating):
325+
dtypes.update({base} - {None})
326+
324327
dtype = infer_dtype(dtypes)
325328

326329
# Promote if we missed complex number, i.e f + I

devito/tools/dtypes_lowering.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,4 +371,5 @@ def extract_dtype(expr):
371371
"""Extract the "winning" dtype from an expression"""
372372
dtypes = {getattr(e, 'dtype', None)
373373
for e in expr.free_symbols}
374+
374375
return infer_dtype(dtypes - {None})

devito/types/sparse.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
)
1414
from devito.symbolics import indexify, retrieve_function_carriers
1515
from devito.tools import (
16-
ReducerMap, as_tuple, dtype_to_mpidtype, filter_ordered, flatten, is_integer, prod
16+
ReducerMap, as_tuple, dtype_to_mpidtype, filter_ordered, flatten, is_integer,
17+
memoized_meth, prod
1718
)
1819
from devito.types.basic import Symbol
1920
from devito.types.dense import DiscreteFunction, SubFunction
2021
from devito.types.dimension import (
21-
ConditionalDimension, DefaultDimension, Dimension, DynamicDimension
22+
ConditionalDimension, CustomDimension, DefaultDimension, Dimension, DynamicDimension
2223
)
2324
from devito.types.dimension import dimensions as mkdims
2425
from devito.types.equation import Eq, Inc
@@ -386,6 +387,24 @@ def _position_map(self):
386387
def dist_origin(self):
387388
return self._dist_origin
388389

390+
@memoized_meth
391+
def _crdim(self, dim):
392+
"""
393+
The CustomDimension associated with the Dimension `dim` for
394+
the radius of the interpolation/injection stencil
395+
"""
396+
sname = self._sparse_dim.name
397+
return CustomDimension(f"r{sname}{dim.name}", -self.r+1,
398+
self.r, 2*self.r, self._sparse_dim)
399+
400+
@memoized_meth
401+
def _cond_rdim(self, dim, cond):
402+
"""
403+
The interpolation/injection radius dimension with guard bounds
404+
"""
405+
parent = self._crdim(dim)
406+
return ConditionalDimension(parent.name, parent, condition=cond, indirect=True)
407+
389408
def interpolate(self, *args, **kwargs):
390409
"""
391410
Implement an interpolation operation from the grid onto the given sparse points

examples/userapi/06_sparse_operations.ipynb

Lines changed: 26 additions & 26 deletions
Large diffs are not rendered by default.

tests/test_dle.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,8 @@ def test_prec_inject(self):
314314
'openmp': True,
315315
'par-collapse-ncores': 1}))
316316

317-
assert_structure(op, ['t', 't,p_s0_blk0,p_s,rsx,rsy'],
318-
't,p_s0_blk0,p_s,rsx,rsy')
317+
assert_structure(op, ['t', 't,p_s0_blk0,p_s,rp_sx,rp_sy'],
318+
't,p_s0_blk0,p_s,rp_sx,rp_sy')
319319

320320

321321
class TestBlockingParTile:

tests/test_dtypes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,9 +311,9 @@ def test_complex_reduction(dtypeu: np.dtype[np.complexfloating]) -> None:
311311
op()
312312

313313
if op._options['linearize']:
314-
ustr = 'uL0(t1, rsx + posx + 2, rsy + posy + 2)'
314+
ustr = 'uL0(t1, rp_sx + posx + 2, rp_sy + posy + 2)'
315315
else:
316-
ustr = 'u[t1][rsx + posx + 2][rsy + posy + 2]'
316+
ustr = 'u[t1][rp_sx + posx + 2][rp_sy + posy + 2]'
317317

318318
compiler = configuration['compiler']
319319
gnu = isinstance(compiler, GNUCompiler) or \

tests/test_interpolation.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from conftest import assert_structure
88
from devito import (
99
DefaultDimension, Dimension, Eq, Function, Grid, MatrixSparseTimeFunction, Operator,
10-
PrecomputedSparseFunction, PrecomputedSparseTimeFunction, SparseFunction,
10+
PrecomputedSparseFunction, PrecomputedSparseTimeFunction, Real, SparseFunction,
1111
SparseTimeFunction, SubDomain, TimeFunction, switchconfig
1212
)
1313
from devito.operations.interpolators import LinearInterpolator, SincInterpolator
@@ -17,6 +17,11 @@
1717
from examples.seismic.acoustic import AcousticWaveSolver, acoustic_setup
1818

1919

20+
class SparseFirst(SparseFunction):
21+
""" Custom sparse class with the sparse dimension as the first one"""
22+
_sparse_position = 0
23+
24+
2025
def unit_box(name='a', shape=(11, 11), grid=None, space_order=1):
2126
"""Create a field with value 0. to 1. in each dimension"""
2227
grid = grid or Grid(shape=shape)
@@ -698,11 +703,6 @@ def test_sparse_first():
698703
"""
699704
Tests custom sparse function with sparse dimension as first index.
700705
"""
701-
702-
class SparseFirst(SparseFunction):
703-
""" Custom sparse class with the sparse dimension as the first one"""
704-
_sparse_position = 0
705-
706706
dr = Dimension("cd")
707707
ds = DefaultDimension("ps", default_value=3)
708708
grid = Grid((11, 11))
@@ -870,6 +870,28 @@ def test_interp_complex(dtype):
870870
assert np.isclose(sc.data[0], fc.data[5, 5, 5])
871871

872872

873+
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
874+
def test_interp_complex_and_real(dtype):
875+
grid = Grid((11, 11, 11))
876+
877+
sc = SparseFunction(name="sc", grid=grid, npoint=1, dtype=dtype)
878+
sc.coordinates.data[:] = [.5, .5, .5]
879+
scre = SparseFunction(name="sce", grid=grid, npoint=1, coordinates=sc.coordinates)
880+
881+
fc = Function(name="fc", grid=grid, npoint=2, dtype=dtype)
882+
fc.data[:] = np.random.randn(*grid.shape) + 1j * np.random.randn(*grid.shape)
883+
exprs = sc.interpolate(expr=fc) + scre.interpolate(expr=Real(fc))
884+
opC = Operator(exprs, name="OpC")
885+
opC()
886+
887+
assert np.isclose(sc.data[0], fc.data[5, 5, 5])
888+
assert np.isclose(scre.data[0], fc.data[5, 5, 5].real)
889+
890+
assert_structure(opC, ['p_sc', 'p_sc,rp_scx,rp_scy,rp_scz',
891+
'p_sc,rp_scx,rp_scy,rp_scz'],
892+
'p_sc,rp_scx,rp_scy,rp_scz,rp_scx,rp_scy,rp_scz')
893+
894+
873895
class SD0(SubDomain):
874896
name = 'sd0'
875897

@@ -989,9 +1011,9 @@ def test_interpolate_subdomain(self):
9891011
assert np.all(np.isclose(sr1.data, check1))
9901012
assert np.all(np.isclose(sr2.data, check2))
9911013
assert_structure(op,
992-
['p_sr0', 'p_sr0rsr0xrsr0y', 'p_sr1',
993-
'p_sr1rsr1xrsr1y', 'p_sr2', 'p_sr2rsr2xrsr2y'],
994-
'p_sr0rsr0xrsr0yp_sr1rsr1xrsr1yp_sr2rsr2xrsr2y')
1014+
['p_sr0', 'p_sr0rp_sr0xrp_sr0y', 'p_sr1',
1015+
'p_sr1rp_sr1xrp_sr1y', 'p_sr2', 'p_sr2rp_sr2xrp_sr2y'],
1016+
'p_sr0rp_sr0xrp_sr0yp_sr1rp_sr1xrp_sr1yp_sr2rp_sr2xrp_sr2y')
9951017

9961018
def test_interpolate_subdomain_sinc(self):
9971019
"""
@@ -1032,9 +1054,9 @@ def test_interpolate_subdomain_sinc(self):
10321054
assert np.all(np.isclose(sr0.data, sr2.data))
10331055
assert np.all(np.isclose(sr1.data, sr2.data))
10341056
assert_structure(op,
1035-
['p_sr0', 'p_sr0rsr0xrsr0y', 'p_sr1',
1036-
'p_sr1rsr1xrsr1y', 'p_sr2', 'p_sr2rsr2xrsr2y'],
1037-
'p_sr0rsr0xrsr0yp_sr1rsr1xrsr1yp_sr2rsr2xrsr2y')
1057+
['p_sr0', 'p_sr0rp_sr0xrp_sr0y', 'p_sr1',
1058+
'p_sr1rp_sr1xrp_sr1y', 'p_sr2', 'p_sr2rp_sr2xrp_sr2y'],
1059+
'p_sr0rp_sr0xrp_sr0yp_sr1rp_sr1xrp_sr1yp_sr2rp_sr2xrp_sr2y')
10381060

10391061
def test_inject_subdomain(self):
10401062
"""
@@ -1080,8 +1102,8 @@ def test_inject_subdomain(self):
10801102
assert np.all(np.isclose(f0.data, check0))
10811103
assert np.all(np.isclose(f1.data, check1))
10821104
assert_structure(op,
1083-
['p_sr0rsr0xrsr0y'],
1084-
'p_sr0rsr0xrsr0y')
1105+
['p_sr0rp_sr0xrp_sr0y'],
1106+
'p_sr0rp_sr0xrp_sr0y')
10851107

10861108
def test_inject_subdomain_sinc(self):
10871109
"""
@@ -1112,8 +1134,8 @@ def test_inject_subdomain_sinc(self):
11121134
assert np.all(np.isclose(f0.data, f2.data[:9, -9:]))
11131135
assert np.all(np.isclose(f1.data, f2.data[1:-1, 1:-1]))
11141136
assert_structure(op,
1115-
['p_sr0rsr0xrsr0y'],
1116-
'p_sr0rsr0xrsr0y')
1137+
['p_sr0rp_sr0xrp_sr0y'],
1138+
'p_sr0rp_sr0xrp_sr0y')
11171139

11181140
@pytest.mark.xfail(reason="OOB issue")
11191141
@pytest.mark.parallel(mode=4)

0 commit comments

Comments
 (0)