Skip to content

Commit c87c528

Browse files
committed
port e2p to arraycontext
1 parent 2285243 commit c87c528

1 file changed

Lines changed: 42 additions & 43 deletions

File tree

sumpy/e2p.py

Lines changed: 42 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@
2222

2323
import numpy as np
2424
import loopy as lp
25-
import sumpy.symbolic as sym
2625

27-
from sumpy.tools import KernelCacheWrapper
28-
from loopy.version import MOST_RECENT_LANGUAGE_VERSION
26+
from sumpy.array_context import PyOpenCLArrayContext, make_loopy_program
27+
from sumpy.tools import KernelCacheMixin
2928

3029

3130
__doc__ = """
@@ -42,7 +41,7 @@
4241

4342
# {{{ E2P base class
4443

45-
class E2PBase(KernelCacheWrapper):
44+
class E2PBase(KernelCacheMixin):
4645
def __init__(self, expansion, kernels, name=None):
4746
"""
4847
:arg expansion: a subclass of :class:`sumpy.expansion.ExpansionBase`
@@ -52,13 +51,12 @@ def __init__(self, expansion, kernels, name=None):
5251
Default: all kernels use the same strength.
5352
"""
5453

55-
from sumpy.kernel import (SourceTransformationRemover,
56-
TargetTransformationRemover)
54+
from sumpy.kernel import (
55+
SourceTransformationRemover, TargetTransformationRemover)
5756
sxr = SourceTransformationRemover()
5857
txr = TargetTransformationRemover()
59-
expansion = expansion.with_kernel(
60-
sxr(expansion.kernel))
6158

59+
expansion = expansion.with_kernel(sxr(expansion.kernel))
6260
kernels = [sxr(knl) for knl in kernels]
6361
for knl in kernels:
6462
assert txr(knl) == expansion.kernel
@@ -70,8 +68,8 @@ def __init__(self, expansion, kernels, name=None):
7068
self.dim = expansion.dim
7169

7270
def get_loopy_insns_and_result_names(self):
73-
from sumpy.symbolic import make_sym_vector
74-
bvec = make_sym_vector("b", self.dim)
71+
import sumpy.symbolic as sym
72+
bvec = sym.make_sym_vector("b", self.dim)
7573

7674
import sumpy.symbolic as sp
7775
rscale = sp.Symbol("rscale")
@@ -131,11 +129,10 @@ def get_kernel(self):
131129

132130
loopy_insns, result_names = self.get_loopy_insns_and_result_names()
133131

134-
loopy_knl = lp.make_kernel(
135-
[
136-
"{[itgt_box]: 0<=itgt_box<ntgt_boxes}",
137-
"{[itgt,idim]: itgt_start<=itgt<itgt_end and 0<=idim<dim}",
138-
],
132+
loopy_knl = make_loopy_program([
133+
"{[itgt_box]: 0<=itgt_box<ntgt_boxes}",
134+
"{[itgt,idim]: itgt_start<=itgt<itgt_end and 0<=idim<dim}",
135+
],
139136
self.get_kernel_scaling_assignment()
140137
+ ["""
141138
for itgt_box
@@ -163,7 +160,7 @@ def get_kernel(self):
163160
end
164161
end
165162
"""],
166-
[
163+
kernel_data=[
167164
lp.GlobalArg("targets", None, shape=(self.dim, "ntargets"),
168165
dim_tags="sep,C"),
169166
lp.GlobalArg("box_target_starts,box_target_counts_nonchild",
@@ -177,16 +174,17 @@ def get_kernel(self):
177174
lp.ValueArg("nsrc_level_boxes,naligned_boxes", np.int32),
178175
lp.ValueArg("src_base_ibox", np.int32),
179176
lp.ValueArg("ntargets", np.int32),
180-
"..."
177+
...
181178
] + [arg.loopy_arg for arg in self.expansion.get_args()],
182179
name=self.name,
183-
assumptions="ntgt_boxes>=1",
184180
silenced_warnings="write_race(write_result*)",
185-
default_offset=lp.auto,
186-
fixed_parameters=dict(dim=self.dim, nresults=len(result_names)),
187-
lang_version=MOST_RECENT_LANGUAGE_VERSION)
181+
)
188182

183+
loopy_knl = lp.assume(loopy_knl, "ntgt_boxes>=1")
184+
loopy_knl = lp.fix_parameters(loopy_knl,
185+
dim=self.dim, nresults=len(result_names))
189186
loopy_knl = lp.tag_inames(loopy_knl, "idim*:unr")
187+
190188
for knl in self.kernels:
191189
loopy_knl = knl.prepare_loopy_kernel(loopy_knl)
192190

@@ -202,7 +200,7 @@ def get_optimized_kernel(self):
202200

203201
return knl
204202

205-
def __call__(self, queue, **kwargs):
203+
def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
206204
"""
207205
:arg expansions:
208206
:arg target_boxes:
@@ -211,14 +209,15 @@ def __call__(self, queue, **kwargs):
211209
:arg centers:
212210
:arg targets:
213211
"""
214-
knl = self.get_cached_optimized_kernel()
215212

216213
centers = kwargs.pop("centers")
217214
# "1" may be passed for rscale, which won't have its type
218215
# meaningfully inferred. Make the type of rscale explicit.
219216
rscale = centers.dtype.type(kwargs.pop("rscale"))
220217

221-
return knl(queue, centers=centers, rscale=rscale, **kwargs)
218+
return actx.call_loopy(
219+
self.get_cached_optimized_kernel(),
220+
centers=centers, rscale=rscale, **kwargs)
222221

223222
# }}}
224223

@@ -233,13 +232,12 @@ def get_kernel(self):
233232

234233
loopy_insns, result_names = self.get_loopy_insns_and_result_names()
235234

236-
loopy_knl = lp.make_kernel(
237-
[
238-
"{[itgt_box]: 0<=itgt_box<ntgt_boxes}",
239-
"{[itgt]: itgt_start<=itgt<itgt_end}",
240-
"{[isrc_box]: isrc_box_start<=isrc_box<isrc_box_end }",
241-
"{[idim]: 0<=idim<dim}",
242-
],
235+
loopy_knl = make_loopy_program([
236+
"{[itgt_box]: 0<=itgt_box<ntgt_boxes}",
237+
"{[itgt]: itgt_start<=itgt<itgt_end}",
238+
"{[isrc_box]: isrc_box_start<=isrc_box<isrc_box_end }",
239+
"{[idim]: 0<=idim<dim}",
240+
],
243241
self.get_kernel_scaling_assignment()
244242
+ ["""
245243
for itgt_box
@@ -274,7 +272,7 @@ def get_kernel(self):
274272
end
275273
end
276274
"""],
277-
[
275+
kernel_data=[
278276
lp.GlobalArg("targets", None, shape=(self.dim, "ntargets"),
279277
dim_tags="sep,C"),
280278
lp.GlobalArg("box_target_starts,box_target_counts_nonchild",
@@ -289,19 +287,20 @@ def get_kernel(self):
289287
dim_tags="sep,C"),
290288
lp.GlobalArg("source_box_starts, source_box_lists,",
291289
None, shape=None, offset=lp.auto),
292-
"..."
290+
...
293291
] + [arg.loopy_arg for arg in self.expansion.get_args()],
294292
name=self.name,
295-
assumptions="ntgt_boxes>=1",
296293
silenced_warnings="write_race(write_result*)",
297-
default_offset=lp.auto,
298-
fixed_parameters=dict(
299-
dim=self.dim,
300-
nresults=len(result_names)),
301-
lang_version=MOST_RECENT_LANGUAGE_VERSION)
294+
)
295+
296+
loopy_knl = lp.assume(loopy_knl, "ntgt_boxes>=1")
297+
loopy_knl = lp.fix_parameters(loopy_knl,
298+
dim=self.dim,
299+
nresults=len(result_names))
302300

303301
loopy_knl = lp.tag_inames(loopy_knl, "idim*:unr")
304302
loopy_knl = lp.prioritize_loops(loopy_knl, "itgt_box,itgt,isrc_box")
303+
305304
for knl in self.kernels:
306305
loopy_knl = knl.prepare_loopy_kernel(loopy_knl)
307306

@@ -316,15 +315,15 @@ def get_optimized_kernel(self):
316315
enforce_variable_access_ordered="no_check")
317316
return knl
318317

319-
def __call__(self, queue, **kwargs):
320-
knl = self.get_cached_optimized_kernel()
321-
318+
def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
322319
centers = kwargs.pop("centers")
323320
# "1" may be passed for rscale, which won't have its type
324321
# meaningfully inferred. Make the type of rscale explicit.
325322
rscale = centers.dtype.type(kwargs.pop("rscale"))
326323

327-
return knl(queue, centers=centers, rscale=rscale, **kwargs)
324+
return actx.call_loopy(
325+
self.get_cached_optimized_kernel(),
326+
centers=centers, rscale=rscale, **kwargs)
328327

329328
# }}}
330329

0 commit comments

Comments
 (0)