Skip to content

Commit a3c9880

Browse files
committed
compiler: misc fixes from lowering multiple sparse operations with mixd prec
1 parent 931e137 commit a3c9880

4 files changed

Lines changed: 8 additions & 3 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,8 @@ def callback(self, clusters, prefix, seen=None):
484484
# `c` is scheduled
485485
index = 0
486486
for i in reversed(range(n)):
487-
if not processed[i].ispace.is_subset(c.ispace):
487+
if not processed[i].ispace.is_subset(c.ispace) and \
488+
not processed[i].is_sparse:
488489
index = i + 1
489490
break
490491
processed.insert(index, halo_touch)

devito/operations/interpolators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None
425425
subdomain=subdomain)
426426

427427
# Accumulate point-wise contributions into a temporary
428-
rhs = Symbol(name='sum', dtype=self.sfunction.dtype)
428+
rhs = Symbol(name=f'sum{self.sfunction.name}', dtype=self.sfunction.dtype)
429429
summands = [Eq(rhs, 0., implicit_dims=implicit_dims)]
430430
# Substitute coordinate base symbols into the interpolation coefficients
431431
weights = self._weights(subdomain=subdomain)

devito/symbolics/inspection.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,14 @@ def sympy_dtype(expr, base=None, default=None, smin=None):
316316
if expr is None:
317317
return default
318318

319-
dtypes = {base} - {None}
319+
dtypes = set()
320320
for i in expr.free_symbols:
321321
with suppress(AttributeError):
322322
dtypes.add(i.dtype)
323323

324+
if not dtypes:
325+
dtypes = {base} - {None}
326+
324327
dtype = infer_dtype(dtypes)
325328

326329
# Promote if we missed complex number, i.e f + I

devito/tools/dtypes_lowering.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,4 +371,5 @@ def extract_dtype(expr):
371371
"""Extract the "winning" dtype from an expression"""
372372
dtypes = {getattr(e, 'dtype', None)
373373
for e in expr.free_symbols}
374+
374375
return infer_dtype(dtypes - {None})

0 commit comments

Comments
 (0)