Skip to content

Commit bcb020d

Browse files
committed
Make use of loopy.TranslationUnit.executor
This avoids long-lived references to CL kernels held by loopy caches
1 parent 2fbc5ee commit bcb020d

6 files changed

Lines changed: 31 additions & 22 deletions

File tree

sumpy/e2e.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(self, ctx, src_expansion, tgt_expansion,
8282
SourceTransformationRemover()(
8383
TargetTransformationRemover()(tgt_expansion.kernel)))
8484

85-
self.ctx = ctx
85+
self.context = ctx
8686
self.src_expansion = src_expansion
8787
self.tgt_expansion = tgt_expansion
8888
self.name = name or self.default_name
@@ -297,7 +297,7 @@ def __call__(self, queue, **kwargs):
297297
src_rscale = centers.dtype.type(kwargs.pop("src_rscale"))
298298
tgt_rscale = centers.dtype.type(kwargs.pop("tgt_rscale"))
299299

300-
knl = self.get_cached_optimized_kernel()
300+
knl = self.get_cached_kernel_executor()
301301

302302
return knl(queue,
303303
centers=centers,
@@ -537,7 +537,7 @@ def __call__(self, queue, **kwargs):
537537
tgt_rscale = centers.dtype.type(kwargs.pop("tgt_rscale"))
538538
src_expansions = kwargs.pop("src_expansions")
539539

540-
knl = self.get_cached_optimized_kernel(result_dtype=src_expansions.dtype)
540+
knl = self.get_cached_kernel_executor(result_dtype=src_expansions.dtype)
541541

542542
return knl(queue,
543543
src_expansions=src_expansions,
@@ -647,7 +647,7 @@ def __call__(self, queue, **kwargs):
647647
"m2l_translation_classes_dependent_data")
648648
result_dtype = m2l_translation_classes_dependent_data.dtype
649649

650-
knl = self.get_cached_optimized_kernel(result_dtype=result_dtype)
650+
knl = self.get_cached_kernel_executor(result_dtype=result_dtype)
651651

652652
return knl(queue,
653653
src_rscale=src_rscale,
@@ -741,7 +741,7 @@ def __call__(self, queue, **kwargs):
741741
"""
742742
preprocessed_src_expansions = kwargs.pop("preprocessed_src_expansions")
743743
result_dtype = preprocessed_src_expansions.dtype
744-
knl = self.get_cached_optimized_kernel(result_dtype=result_dtype)
744+
knl = self.get_cached_kernel_executor(result_dtype=result_dtype)
745745

746746
return knl(queue,
747747
preprocessed_src_expansions=preprocessed_src_expansions, **kwargs)
@@ -840,7 +840,7 @@ def __call__(self, queue, **kwargs):
840840
"""
841841
tgt_expansions = kwargs.pop("tgt_expansions")
842842
result_dtype = tgt_expansions.dtype
843-
knl = self.get_cached_optimized_kernel(result_dtype=result_dtype)
843+
knl = self.get_cached_kernel_executor(result_dtype=result_dtype)
844844

845845
return knl(queue, tgt_expansions=tgt_expansions, **kwargs)
846846

@@ -950,7 +950,7 @@ def __call__(self, queue, **kwargs):
950950
:arg tgt_rscale:
951951
:arg centers:
952952
"""
953-
knl = self.get_cached_optimized_kernel()
953+
knl = self.get_cached_kernel_executor()
954954

955955
centers = kwargs.pop("centers")
956956
# "1" may be passed for rscale, which won't have its type
@@ -1054,7 +1054,7 @@ def __call__(self, queue, **kwargs):
10541054
:arg tgt_rscale:
10551055
:arg centers:
10561056
"""
1057-
knl = self.get_cached_optimized_kernel()
1057+
knl = self.get_cached_kernel_executor()
10581058

10591059
centers = kwargs.pop("centers")
10601060
# "1" may be passed for rscale, which won't have its type

sumpy/e2p.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(self, ctx, expansion, kernels,
6868
for knl in kernels:
6969
assert txr(knl) == expansion.kernel
7070

71-
self.ctx = ctx
71+
self.context = ctx
7272
self.expansion = expansion
7373
self.kernels = kernels
7474
self.name = name or self.default_name
@@ -210,7 +210,7 @@ def __call__(self, queue, **kwargs):
210210
:arg centers:
211211
:arg targets:
212212
"""
213-
knl = self.get_cached_optimized_kernel()
213+
knl = self.get_cached_kernel_executor()
214214

215215
centers = kwargs.pop("centers")
216216
# "1" may be passed for rscale, which won't have its type
@@ -327,7 +327,7 @@ def get_optimized_kernel(self):
327327
return knl
328328

329329
def __call__(self, queue, **kwargs):
330-
knl = self.get_cached_optimized_kernel()
330+
knl = self.get_cached_kernel_executor()
331331

332332
centers = kwargs.pop("centers")
333333
# "1" may be passed for rscale, which won't have its type

sumpy/p2e.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def __call__(self, queue, **kwargs):
124124
from sumpy.tools import is_obj_array_like
125125
sources = kwargs.pop("sources")
126126
centers = kwargs.pop("centers")
127-
knl = self.get_cached_optimized_kernel(
127+
knl = self.get_cached_kernel_executor(
128128
sources_is_obj_array=is_obj_array_like(sources),
129129
centers_is_obj_array=is_obj_array_like(centers))
130130

sumpy/p2p.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def get_kernel(self):
256256
return loopy_knl
257257

258258
def __call__(self, queue, targets, sources, strength, **kwargs):
259-
knl = self.get_cached_optimized_kernel(
259+
knl = self.get_cached_kernel_executor(
260260
targets_is_obj_array=is_obj_array_like(targets),
261261
sources_is_obj_array=is_obj_array_like(sources))
262262

@@ -318,7 +318,7 @@ def get_kernel(self):
318318
return loopy_knl
319319

320320
def __call__(self, queue, targets, sources, **kwargs):
321-
knl = self.get_cached_optimized_kernel(
321+
knl = self.get_cached_kernel_executor(
322322
targets_is_obj_array=is_obj_array_like(targets),
323323
sources_is_obj_array=is_obj_array_like(sources))
324324

@@ -429,7 +429,7 @@ def __call__(self, queue, targets, sources, tgtindices, srcindices, **kwargs):
429429
:returns: a one-dimensional array of interactions, for each index pair
430430
in (*srcindices*, *tgtindices*)
431431
"""
432-
knl = self.get_cached_optimized_kernel(
432+
knl = self.get_cached_kernel_executor(
433433
targets_is_obj_array=is_obj_array_like(targets),
434434
sources_is_obj_array=is_obj_array_like(sources))
435435

@@ -731,7 +731,7 @@ def __call__(self, queue, **kwargs):
731731
else:
732732
dtype_size = None
733733

734-
knl = self.get_cached_optimized_kernel(
734+
knl = self.get_cached_kernel_executor(
735735
max_nsources_in_one_box=max_nsources_in_one_box,
736736
max_ntargets_in_one_box=max_ntargets_in_one_box,
737737
dtype_size=dtype_size,

sumpy/qbx.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def __call__(self, queue, targets, sources, centers, strengths, expansion_radii,
288288
already multiplied in.
289289
"""
290290

291-
knl = self.get_cached_optimized_kernel(
291+
knl = self.get_cached_kernel_executor(
292292
targets_is_obj_array=is_obj_array_like(targets),
293293
sources_is_obj_array=is_obj_array_like(sources),
294294
centers_is_obj_array=is_obj_array_like(centers))
@@ -359,7 +359,7 @@ def get_kernel(self):
359359
return loopy_knl
360360

361361
def __call__(self, queue, targets, sources, centers, expansion_radii, **kwargs):
362-
knl = self.get_cached_optimized_kernel(
362+
knl = self.get_cached_kernel_executor(
363363
targets_is_obj_array=is_obj_array_like(targets),
364364
sources_is_obj_array=is_obj_array_like(sources),
365365
centers_is_obj_array=is_obj_array_like(centers))
@@ -479,7 +479,7 @@ def __call__(self, queue, targets, sources, centers, expansion_radii,
479479
in (*srcindices*, *tgtindices*)
480480
"""
481481

482-
knl = self.get_cached_optimized_kernel(
482+
knl = self.get_cached_kernel_executor(
483483
targets_is_obj_array=is_obj_array_like(targets),
484484
sources_is_obj_array=is_obj_array_like(sources),
485485
centers_is_obj_array=is_obj_array_like(centers))

sumpy/tools.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -381,8 +381,17 @@ def __eq__(self, other):
381381

382382

383383
class KernelCacheMixin:
384-
@memoize_method
385384
def get_cached_optimized_kernel(self, **kwargs):
385+
from warnings import warn
386+
warn("get_cached_optimized_kernel is deprecated. "
387+
"Use get_cached_kernel_executor instead. "
388+
"This will stop working in October 2023.",
389+
DeprecationWarning, stacklevel=2)
390+
391+
return self.get_cached_kernel_executor(**kwargs)
392+
393+
@memoize_method
394+
def get_cached_kernel_executor(self, **kwargs) -> lp.ExecutorBase:
386395
from sumpy import (code_cache, CACHING_ENABLED, OPT_ENABLED,
387396
NO_CACHE_KERNELS)
388397

@@ -401,7 +410,7 @@ def get_cached_optimized_kernel(self, **kwargs):
401410
result = code_cache[cache_key]
402411
logger.debug("{}: kernel cache hit [key={}]".format(
403412
self.name, cache_key))
404-
return result
413+
return result.executor(self.context)
405414
except KeyError:
406415
pass
407416

@@ -422,7 +431,7 @@ def get_cached_optimized_kernel(self, **kwargs):
422431
NO_CACHE_KERNELS and self.name in NO_CACHE_KERNELS):
423432
code_cache.store_if_not_present(cache_key, knl)
424433

425-
return knl
434+
return knl.executor(self.context)
426435

427436
@staticmethod
428437
def _allow_redundant_execution_of_knl_scaling(knl):

0 commit comments

Comments
 (0)