Skip to content

Commit 2ef79c9

Browse files
committed
compiler: Tweak EqBlock/Cluster after rebasing
1 parent 8d985c2 commit 2ef79c9

2 files changed

Lines changed: 44 additions & 69 deletions

File tree

devito/ir/clusters/cluster.py

Lines changed: 41 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -65,33 +65,6 @@ def __hash__(self):
6565
return hash((self.exprs, self.ispace, self.guards, self.properties,
6666
self.syncs, self.halo_scheme))
6767

68-
def subs(self, mapper, compact=()):
69-
"""
70-
Build a new Cluster applying substitutions rules to `self`.
71-
"""
72-
if not mapper:
73-
return self
74-
75-
if self.halo_scheme:
76-
raise NotImplementedError
77-
78-
key0 = lambda i: i.is_Block
79-
subs0 = {d: self.ispace[d].promote(key0).dim for d in compact}
80-
81-
subs = {**mapper, **subs0}
82-
exprs = [uxreplace(e, subs) for e in self.exprs]
83-
84-
ispace = self.ispace.switch(mapper)
85-
key = lambda i: key0(i) and i in flatten(d._defines for d in subs0)
86-
ispace = ispace.promote(key, mode='total')
87-
88-
guards = self.guards.subs(mapper).promote(subs0)
89-
properties = self.properties.subs(mapper).promote(subs0)
90-
syncs = self.syncs.subs(mapper)
91-
92-
return self.__class__(exprs=exprs, ispace=ispace, guards=guards,
93-
properties=properties, syncs=syncs)
94-
9568
@property
9669
def exprs(self):
9770
return self._exprs
@@ -139,7 +112,7 @@ def dimensions(self):
139112
@cached_property
140113
def exprs_dimensions(self):
141114
"""
142-
The Dimensions that appear explicitly in the Cluster expressions.
115+
The Dimensions that appear explicitly in the expressions.
143116
"""
144117
dims_explicit = {i for i in self.free_symbols if i.is_Dimension}
145118
dims_implicit = {d for e in self.exprs for d in e.implicit_dims}
@@ -148,7 +121,7 @@ def exprs_dimensions(self):
148121
@cached_property
149122
def guards_dimensions(self):
150123
"""
151-
The Dimensions that appear explicitly in the Cluster guards.
124+
The Dimensions that appear explicitly in the guards.
152125
"""
153126
syms_guards = {d for e in self.guards.values() for d in e.free_symbols}
154127
dims_guards = {i for i in syms_guards if i.is_Dimension}
@@ -169,7 +142,7 @@ def used_dimensions(self):
169142
@cached_property
170143
def dist_dimensions(self):
171144
"""
172-
The Cluster's distributed Dimensions.
145+
The distributed Dimensions.
173146
"""
174147
ret = set()
175148
for f in self.functions:
@@ -195,7 +168,7 @@ def grid(self):
195168
elif len(grids) == 1:
196169
return grids.pop()
197170
else:
198-
raise ValueError("Cluster has no unique Grid")
171+
raise ValueError("Multiple Grids detected")
199172

200173
@cached_property
201174
def is_scalar(self):
@@ -323,31 +296,27 @@ def is_glb_load_to_mem_shared(self):
323296
@cached_property
324297
def is_async(self):
325298
"""
326-
True if an asynchronous Cluster, False otherwise.
299+
True if asynchronous, False otherwise.
327300
"""
328301
return any(isinstance(s, (WithLock, PrefetchUpdate))
329302
for s in flatten(self.syncs.values()))
330303

331304
@cached_property
332305
def is_wait(self):
333306
"""
334-
True if a Cluster waiting on a lock (that is a special synchronization
335-
operation), False otherwise.
307+
True if waiting on a lock (that is a special synchronization operation),
308+
False otherwise.
336309
"""
337310
return any(isinstance(s, WaitLock)
338311
for s in flatten(self.syncs.values()))
339312

340313
@cached_property
341314
def dtype(self):
342315
"""
343-
The arithmetic data type of the Cluster.
344-
345-
If the Cluster performs floating point arithmetic, then the expressions
346-
performing integer arithmetic are ignored, assuming that they are only
347-
carrying out array index calculations.
316+
The arithmetic data type of the enclosed expressions.
348317
349-
If two expressions perform calculations with different precision,
350-
the data type with highest precision is returned.
318+
If two expressions perform calculations with different precision, the data
319+
type with highest precision is returned.
351320
"""
352321
dtypes = set()
353322
for i in self.exprs:
@@ -363,8 +332,8 @@ def dtype(self):
363332
@cached_property
364333
def dspace(self):
365334
"""
366-
Derive the DataSpace of the Cluster from its expressions,
367-
IterationSpace, and Guards.
335+
The DataSpace deriving from the enclosed expressions, IterationSpace,
336+
and Guards.
368337
"""
369338
accesses = detect_accesses(self.exprs)
370339

@@ -448,8 +417,8 @@ def ops(self):
448417
@cached_property
449418
def traffic(self):
450419
"""
451-
The Cluster compulsory traffic (number of reads/writes), as a mapper
452-
from Functions to IntervalGroups.
420+
The compulsory traffic (number of reads/writes), as a mapper from
421+
Functions to IntervalGroups.
453422
454423
Notes
455424
-----
@@ -536,30 +505,6 @@ def __getattr__(self, name):
536505
raise AttributeError(name) from None
537506
return getattr(block, name)
538507

539-
@property
540-
def exprs(self):
541-
return self._block.exprs
542-
543-
@property
544-
def ispace(self):
545-
return self._block.ispace
546-
547-
@property
548-
def guards(self):
549-
return self._block.guards
550-
551-
@property
552-
def properties(self):
553-
return self._block.properties
554-
555-
@property
556-
def syncs(self):
557-
return self._block.syncs
558-
559-
@property
560-
def halo_scheme(self):
561-
return self._block.halo_scheme
562-
563508
@classmethod
564509
def from_clusters(cls, *clusters):
565510
"""
@@ -639,6 +584,33 @@ def rebuild(self, *args, **kwargs):
639584
syncs=syncs,
640585
halo_scheme=halo_scheme)
641586

587+
def subs(self, mapper, compact=()):
588+
"""
589+
Build a new Cluster applying substitutions rules to `self`.
590+
"""
591+
if not mapper:
592+
return self
593+
594+
if self.halo_scheme:
595+
raise NotImplementedError
596+
597+
key0 = lambda i: i.is_Block
598+
subs0 = {d: self.ispace[d].promote(key0).dim for d in compact}
599+
600+
subs = {**mapper, **subs0}
601+
exprs = [uxreplace(e, subs) for e in self.exprs]
602+
603+
ispace = self.ispace.switch(mapper)
604+
key = lambda i: key0(i) and i in flatten(d._defines for d in subs0)
605+
ispace = ispace.promote(key, mode='total')
606+
607+
guards = self.guards.subs(mapper).promote(subs0)
608+
properties = self.properties.subs(mapper).promote(subs0)
609+
syncs = self.syncs.subs(mapper)
610+
611+
return self.__class__(exprs=exprs, ispace=ispace, guards=guards,
612+
properties=properties, syncs=syncs)
613+
642614

643615
class ClusterGroup(tuple):
644616

devito/symbolics/search.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ def retrieve_accesses(exprs, **kwargs):
195195
* ComponentAccess's are retained, but the wrapped Indexed are discarded;
196196
* TensorMove's are upcasted to the logical Indexed they represent.
197197
"""
198+
from .manipulation import uxreplace # noqa
199+
from devito.types import ComponentAccess, Symbol, TensorMove # noqa
200+
198201
kwargs['mode'] = 'unique'
199202

200203
compaccs = search(exprs, ComponentAccess)

0 commit comments

Comments
 (0)