Skip to content

Commit cea48f3

Browse files
committed
compiler: prevent hosted per-thread arrays are dereferenced within partree at read
1 parent 2d32230 commit cea48f3

4 files changed

Lines changed: 14 additions & 8 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,9 @@ def __rfloordiv__(self, other):
259259
from .elementary import floor
260260
return floor(other / self)
261261

262+
def safe_inv(self, ref):
263+
return SafeInv(self, ref or self)
264+
262265
def __mod__(self, other):
263266
return Mod(self, other)
264267

devito/passes/iet/parpragma.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,9 @@ def _make_parregion(self, partree, parrays):
317317
i = n.write
318318
if not (i.is_Array or i.is_TempFunction):
319319
continue
320+
elif partree.dim in i.dimensions:
321+
# Non-local Array (full iteration space): no need to vector-expand
322+
continue
320323
elif i in parrays:
321324
pi = parrays[i]
322325
else:

devito/symbolics/inspection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from sympy import (Function, Indexed, Integer, Mul, Number,
55
Pow, S, Symbol, Tuple)
66
from sympy.core.numbers import ImaginaryUnit
7+
from sympy.core.function import Application
78

89
from devito.finite_differences import Derivative
910
from devito.finite_differences.differentiable import IndexDerivative
@@ -116,7 +117,7 @@ def estimate_cost(exprs, estimate=False):
116117
estimate_values = {
117118
'elementary': 100,
118119
'pow': 50,
119-
'SafeInv': 10,
120+
'SafeInv': 50,
120121
'div': 5,
121122
'Abs': 5,
122123
'floor': 1,
@@ -211,6 +212,7 @@ def _(expr, estimate, seen):
211212

212213

213214
@_estimate_cost.register(Function)
215+
@_estimate_cost.register(Application)
214216
def _(expr, estimate, seen):
215217
if q_routine(expr):
216218
flops, _ = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])

devito/types/basic.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -989,7 +989,7 @@ def c0(self):
989989
def _eval_deriv(self):
990990
return self
991991

992-
@cached_property
992+
@property
993993
def _grid_map(self):
994994
"""
995995
Mapper of off-grid interpolation points indices for each dimension.
@@ -1049,14 +1049,13 @@ def _evaluate(self, **kwargs):
10491049
return self
10501050

10511051
io = self.interp_order
1052+
retval = self.subs({i.subs(subs): self.indices_ref[d]
1053+
for d, i in mapper.items()})
10521054
if self._avg_mode == 'harmonic':
1053-
retval = 1 / self
1054-
else:
1055-
retval = self
1055+
retval = retval.safe_inv(retval)
10561056

10571057
# Apply interpolation from inner most dim
10581058
for d, i in mapper.items():
1059-
retval = retval._subs(i.subs(subs), self.indices_ref[d])
10601059
retval = retval.diff(d, deriv_order=0, fd_order=io, x0={d: i})
10611060

10621061
# Evaluate. Since we used `self.function` it will be on the grid when
@@ -1066,8 +1065,7 @@ def _evaluate(self, **kwargs):
10661065

10671066
# If harmonic averaging, invert at the end
10681067
if self._avg_mode == 'harmonic':
1069-
from devito.finite_differences.differentiable import SafeInv
1070-
retval = SafeInv(retval, self.function.subs(subs))
1068+
retval = retval.safe_inv(self.function.subs(subs))
10711069

10721070
return retval
10731071

0 commit comments

Comments
 (0)