Skip to content

Commit 7ea336f

Browse files
committed
misc: Weakly cache Scope instances
1 parent 6d068d3 commit 7ea336f

8 files changed

Lines changed: 21 additions & 12 deletions

File tree

devito/ir/clusters/cluster.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def dist_dimensions(self):
183183

184184
@cached_property
185185
def scope(self):
186-
return Scope(self.exprs)
186+
return Scope.maybe_cached(as_tuple(self.exprs))
187187

188188
@cached_property
189189
def functions(self):
@@ -473,7 +473,7 @@ def exprs(self):
473473

474474
@cached_property
475475
def scope(self):
476-
return Scope(exprs=self.exprs)
476+
return Scope.maybe_cached(exprs=as_tuple(self.exprs))
477477

478478
@cached_property
479479
def ispace(self):

devito/ir/clusters/visitors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def _fetch_scope(self, clusters):
134134
exprs = flatten(c.exprs for c in as_tuple(clusters))
135135
key = tuple(exprs)
136136
if key not in self.state.scopes:
137-
self.state.scopes[key] = Scope(exprs)
137+
self.state.scopes[key] = Scope.maybe_cached(as_tuple(exprs))
138138
return self.state.scopes[key]
139139

140140
def _fetch_properties(self, clusters, prefix):

devito/ir/support/basic.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,15 @@ def __init__(self, exprs, rules=None):
852852
self.rules = as_tuple(rules)
853853
assert all(callable(i) for i in self.rules)
854854

855+
@classmethod
856+
@weak_instance_cache
857+
def maybe_cached(cls, exprs, rules=None) -> 'Scope':
858+
"""
859+
Constructs a new Scope or retrieves one from the cache if it already/still
860+
exists.
861+
"""
862+
return cls(exprs, rules)
863+
855864
@memoized_generator
856865
def writes_gen(self):
857866
"""

devito/mpi/halo_scheme.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def classify(exprs, ispace):
519519
# performed before (reads) or after (writes) the OWNED region is computed
520520
loc_indices_from_reads = configuration['mpi'] not in ('dual',)
521521

522-
scope = Scope(exprs)
522+
scope = Scope.maybe_cached(as_tuple(exprs))
523523

524524
mapper = {}
525525
for f, r in scope.reads.items():

devito/passes/clusters/blocking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def callback(self, clusters, prefix):
279279
if len(clusters) > 1:
280280
# Heuristic: same as above if it induces dynamic bounds
281281
exprs = flatten(c.exprs for c in as_tuple(clusters))
282-
scope = Scope(exprs)
282+
scope = Scope.maybe_cached(as_tuple(exprs))
283283
if any(i.is_lex_non_stmt for i in scope.d_all_gen()):
284284
return clusters
285285
else:

devito/passes/clusters/cse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def _cse(maybe_exprs, make, min_cost=1, mode='basic'):
124124
maybe_exprs = as_list(maybe_exprs)
125125
if all(e.is_Equality for e in maybe_exprs):
126126
exprs = maybe_exprs
127-
scope = Scope(maybe_exprs)
127+
scope = Scope.maybe_cached(tuple(maybe_exprs))
128128
else:
129129
exprs = [Eq(make(e), e) for e in maybe_exprs]
130130
scope = Scope([])

devito/passes/clusters/misc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def is_cross(source, sink):
355355
for n1, cg1 in enumerate(cgroups[n+1:], start=n+1):
356356

357357
# A Scope to compute all cross-ClusterGroup anti-dependences
358-
scope = Scope(exprs=cg0.exprs + cg1.exprs, rules=is_cross)
358+
scope = Scope(exprs=tuple(cg0.exprs + cg1.exprs), rules=is_cross)
359359

360360
# Anti-dependences along `prefix` break the execution flow
361361
# (intuitively, "the loop nests are to be kept separated")
@@ -444,7 +444,7 @@ def callback(self, clusters, prefix):
444444
return clusters
445445

446446
# Analyze and abort if fissioning would break a dependence
447-
scope = Scope(flatten(c.exprs for c in clusters))
447+
scope = Scope.maybe_cached(tuple(flatten(c.exprs for c in clusters)))
448448
if any(d._defines & dep.cause or dep.is_reduce(d) for dep in scope.d_all_gen()):
449449
return clusters
450450

devito/passes/iet/mpi.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def _drop_reduction_halospots(iet):
4040

4141
# If all HaloSpot reads pertain to reductions, then the HaloSpot is useless
4242
for hs, expressions in MapNodes(HaloSpot, Expression).visit(iet).items():
43-
scope = Scope(i.expr for i in expressions)
43+
scope = Scope.maybe_cached(tuple(i.expr for i in expressions))
4444
for k, v in hs.fmapper.items():
4545
f = v.bundle or k
4646
if f not in scope.reads:
@@ -82,7 +82,7 @@ def _hoist_redundant_from_conditionals(iet):
8282

8383
mapper = HaloSpotMapper()
8484
for it, halo_spots in iter_mapper.items():
85-
scope = Scope(e.expr for e in FindNodes(Expression).visit(it))
85+
scope = Scope.maybe_cached(tuple(e.expr for e in FindNodes(Expression).visit(it)))
8686

8787
for hs0 in halo_spots:
8888
conditions = cond_mapper[hs0]
@@ -282,7 +282,7 @@ def _mark_overlappable(iet):
282282
if not expressions:
283283
continue
284284

285-
scope = Scope(i.expr for i in expressions)
285+
scope = Scope.maybe_cached(tuple(i.expr for i in expressions))
286286

287287
# Comp/comm overlaps is legal only if the OWNED regions can grow
288288
# arbitrarly, which means all of the dependences must be carried
@@ -462,7 +462,7 @@ def _derive_scope(it, hs0, hs1):
462462
and ends at the HaloSpot `hs1`.
463463
"""
464464
expressions = FindWithin(Expression, hs0, stop=hs1).visit(it)
465-
return Scope(e.expr for e in expressions)
465+
return Scope.maybe_cached(tuple(e.expr for e in expressions))
466466

467467

468468
def _check_control_flow(hs0, hs1, cond_mapper):

0 commit comments

Comments
 (0)