Skip to content

Commit 1520fa2

Browse files
committed
api: prevent inconsistency in interp radius dim by tighing it to the sparse dim
1 parent 8e2ea64 commit 1520fa2

6 files changed

Lines changed: 43 additions & 30 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,8 +484,7 @@ def callback(self, clusters, prefix, seen=None):
484484
# `c` is scheduled
485485
index = 0
486486
for i in reversed(range(n)):
487-
if not processed[i].ispace.is_subset(c.ispace) and \
488-
not processed[i].is_sparse:
487+
if not processed[i].ispace.is_subset(c.ispace):
489488
index = i + 1
490489
break
491490
processed.insert(index, halo_touch)

devito/operations/interpolators.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414
from devito.logger import warning
1515
from devito.symbolics import INT, retrieve_function_carriers, retrieve_functions
1616
from devito.tools import Pickable, as_tuple, filter_ordered, flatten, memoized_meth
17-
from devito.types import (
18-
ConditionalDimension, CustomDimension, Eq, Evaluable, Inc, SubFunction, Symbol
19-
)
17+
from devito.types import Eq, Evaluable, Inc, SubFunction, Symbol
2018
from devito.types.utils import DimensionTuple
2119

2220
__all__ = ['LinearInterpolator', 'PrecomputedInterpolator', 'SincInterpolator']
@@ -239,13 +237,10 @@ def _weights(self, subdomain=None):
239237
def _gdims(self):
240238
return self.grid.dimensions
241239

242-
@cached_property
240+
@property
243241
def _cdim(self):
244242
"""Base CustomDimensions used to construct _rdim"""
245-
parent = self.sfunction._sparse_dim
246-
dims = [CustomDimension(f"r{self.sfunction.name}{d.name}",
247-
-self.r+1, self.r, 2*self.r, parent)
248-
for d in self._gdims]
243+
dims = [self.sfunction._crdim(d) for d in self._gdims]
249244
return dims
250245

251246
@memoized_meth
@@ -274,8 +269,7 @@ def _rdim(self, subdomain=None):
274269
rank_populated = subdomain.distributor.rank_populated
275270
cond = sympy.And(rank_populated, cond)
276271

277-
rdims.append(ConditionalDimension(rd.name, rd, condition=cond,
278-
indirect=True))
272+
rdims.append(self.sfunction._cond_rdim(d.root, cond))
279273

280274
return DimensionTuple(*rdims, getters=gdims)
281275

devito/types/sparse.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
)
1414
from devito.symbolics import indexify, retrieve_function_carriers
1515
from devito.tools import (
16-
ReducerMap, as_tuple, dtype_to_mpidtype, filter_ordered, flatten, is_integer, prod
16+
ReducerMap, as_tuple, dtype_to_mpidtype, filter_ordered, flatten, is_integer,
17+
memoized_meth, prod
1718
)
1819
from devito.types.basic import Symbol
1920
from devito.types.dense import DiscreteFunction, SubFunction
2021
from devito.types.dimension import (
21-
ConditionalDimension, DefaultDimension, Dimension, DynamicDimension
22+
ConditionalDimension, CustomDimension, DefaultDimension, Dimension, DynamicDimension
2223
)
2324
from devito.types.dimension import dimensions as mkdims
2425
from devito.types.equation import Eq, Inc
@@ -386,6 +387,24 @@ def _position_map(self):
386387
def dist_origin(self):
387388
return self._dist_origin
388389

390+
@memoized_meth
391+
def _crdim(self, dim):
392+
"""
393+
The CustomDimension associated with the Dimension `dim` for
394+
the radius of the interpolation/injection stencil
395+
"""
396+
sname = self._sparse_dim.name
397+
return CustomDimension(f"r{sname}{dim.name}", -self.r+1,
398+
self.r, 2*self.r, self._sparse_dim)
399+
400+
@memoized_meth
401+
def _cond_rdim(self, dim, cond):
402+
"""
403+
The interpolation/injection radius dimension with guard bounds
404+
"""
405+
parent = self._crdim(dim)
406+
return ConditionalDimension(parent.name, parent, condition=cond, indirect=True)
407+
389408
def interpolate(self, *args, **kwargs):
390409
"""
391410
Implement an interpolation operation from the grid onto the given sparse points

tests/test_dle.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,8 @@ def test_prec_inject(self):
314314
'openmp': True,
315315
'par-collapse-ncores': 1}))
316316

317-
assert_structure(op, ['t', 't,p_s0_blk0,p_s,rsx,rsy'],
318-
't,p_s0_blk0,p_s,rsx,rsy')
317+
assert_structure(op, ['t', 't,p_s0_blk0,p_s,rp_sx,rp_sy'],
318+
't,p_s0_blk0,p_s,rp_sx,rp_sy')
319319

320320

321321
class TestBlockingParTile:

tests/test_dtypes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,9 +311,9 @@ def test_complex_reduction(dtypeu: np.dtype[np.complexfloating]) -> None:
311311
op()
312312

313313
if op._options['linearize']:
314-
ustr = 'uL0(t1, rsx + posx + 2, rsy + posy + 2)'
314+
ustr = 'uL0(t1, rp_sx + posx + 2, rp_sy + posy + 2)'
315315
else:
316-
ustr = 'u[t1][rsx + posx + 2][rsy + posy + 2]'
316+
ustr = 'u[t1][rp_sx + posx + 2][rp_sy + posy + 2]'
317317

318318
compiler = configuration['compiler']
319319
gnu = isinstance(compiler, GNUCompiler) or \

tests/test_interpolation.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -887,8 +887,9 @@ def test_interp_complex_and_real(dtype):
887887
assert np.isclose(sc.data[0], fc.data[5, 5, 5])
888888
assert np.isclose(scre.data[0], fc.data[5, 5, 5].real)
889889

890-
assert_structure(opC, ['p_sc', 'p_sc,rscx,rscy,rscz', 'p_sc,rscex,rscey,rscez'],
891-
'p_sc,rscx,rscy,rscz,rscex,rscey,rscez')
890+
assert_structure(opC, ['p_sc', 'p_sc,rp_scx,rp_scy,rp_scz',
891+
'p_sc,rp_scx,rp_scy,rp_scz'],
892+
'p_sc,rp_scx,rp_scy,rp_scz,rp_scx,rp_scy,rp_scz')
892893

893894

894895
class SD0(SubDomain):
@@ -1010,9 +1011,9 @@ def test_interpolate_subdomain(self):
10101011
assert np.all(np.isclose(sr1.data, check1))
10111012
assert np.all(np.isclose(sr2.data, check2))
10121013
assert_structure(op,
1013-
['p_sr0', 'p_sr0rsr0xrsr0y', 'p_sr1',
1014-
'p_sr1rsr1xrsr1y', 'p_sr2', 'p_sr2rsr2xrsr2y'],
1015-
'p_sr0rsr0xrsr0yp_sr1rsr1xrsr1yp_sr2rsr2xrsr2y')
1014+
['p_sr0', 'p_sr0rp_sr0xrp_sr0y', 'p_sr1',
1015+
'p_sr1rp_sr1xrp_sr1y', 'p_sr2', 'p_sr2rp_sr2xrp_sr2y'],
1016+
'p_sr0rp_sr0xrp_sr0yp_sr1rp_sr1xrp_sr1yp_sr2rp_sr2xrp_sr2y')
10161017

10171018
def test_interpolate_subdomain_sinc(self):
10181019
"""
@@ -1053,9 +1054,9 @@ def test_interpolate_subdomain_sinc(self):
10531054
assert np.all(np.isclose(sr0.data, sr2.data))
10541055
assert np.all(np.isclose(sr1.data, sr2.data))
10551056
assert_structure(op,
1056-
['p_sr0', 'p_sr0rsr0xrsr0y', 'p_sr1',
1057-
'p_sr1rsr1xrsr1y', 'p_sr2', 'p_sr2rsr2xrsr2y'],
1058-
'p_sr0rsr0xrsr0yp_sr1rsr1xrsr1yp_sr2rsr2xrsr2y')
1057+
['p_sr0', 'p_sr0rp_sr0xrp_sr0y', 'p_sr1',
1058+
'p_sr1rp_sr1xrp_sr1y', 'p_sr2', 'p_sr2rp_sr2xrp_sr2y'],
1059+
'p_sr0rp_sr0xrp_sr0yp_sr1rp_sr1xrp_sr1yp_sr2rp_sr2xrp_sr2y')
10591060

10601061
def test_inject_subdomain(self):
10611062
"""
@@ -1101,8 +1102,8 @@ def test_inject_subdomain(self):
11011102
assert np.all(np.isclose(f0.data, check0))
11021103
assert np.all(np.isclose(f1.data, check1))
11031104
assert_structure(op,
1104-
['p_sr0rsr0xrsr0y'],
1105-
'p_sr0rsr0xrsr0y')
1105+
['p_sr0rp_sr0xrp_sr0y'],
1106+
'p_sr0rp_sr0xrp_sr0y')
11061107

11071108
def test_inject_subdomain_sinc(self):
11081109
"""
@@ -1133,8 +1134,8 @@ def test_inject_subdomain_sinc(self):
11331134
assert np.all(np.isclose(f0.data, f2.data[:9, -9:]))
11341135
assert np.all(np.isclose(f1.data, f2.data[1:-1, 1:-1]))
11351136
assert_structure(op,
1136-
['p_sr0rsr0xrsr0y'],
1137-
'p_sr0rsr0xrsr0y')
1137+
['p_sr0rp_sr0xrp_sr0y'],
1138+
'p_sr0rp_sr0xrp_sr0y')
11381139

11391140
@pytest.mark.xfail(reason="OOB issue")
11401141
@pytest.mark.parallel(mode=4)

0 commit comments

Comments
 (0)