Skip to content

Commit 0b65baf

Browse files
committed
port tools and toys to arraycontext
1 parent 2aa656b commit 0b65baf

2 files changed

Lines changed: 201 additions & 191 deletions

File tree

sumpy/tools.py

Lines changed: 67 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -36,25 +36,24 @@
3636
.. autoclass:: DifferentiatedExprDerivativeTaker
3737
"""
3838

39-
from pytools import memoize_method
40-
from pytools.tag import Tag, tag_dataclass
41-
import numbers
42-
import warnings
4339
import os
4440
import sys
4541
import enum
46-
import platform
47-
from collections import defaultdict, namedtuple
48-
from pymbolic.mapper import WalkMapper
49-
import pymbolic
42+
import numbers
43+
import warnings
44+
from dataclasses import dataclass
45+
from typing import Any, Dict, List, Tuple
5046

5147
import numpy as np
52-
import sumpy.symbolic as sym
53-
import pyopencl as cl
54-
import pyopencl.array as cla
5548

5649
import loopy as lp
57-
from typing import Dict, Tuple, Any
50+
from pytools import memoize_method
51+
from pytools.tag import Tag, tag_dataclass
52+
from pymbolic.mapper import WalkMapper
53+
from arraycontext import Array
54+
55+
import sumpy.symbolic as sym
56+
from sumpy.array_context import PyOpenCLArrayContext
5857

5958
import logging
6059
logger = logging.getLogger(__name__)
@@ -452,7 +451,9 @@ def diff_derivative_coeff_dict(derivative_coeff_dict: DerivativeCoeffDict,
452451
*derivative_coeff_dict* using the variable given by **variable_idx**
453452
and return a new derivative transformation dictionary.
454453
"""
454+
from collections import defaultdict
455455
new_derivative_coeff_dict = defaultdict(lambda: 0)
456+
456457
for mi, coeff in derivative_coeff_dict.items():
457458
# In the case where we have x * u.diff(x), the result should
458459
# be x.diff(x) + x * u.diff(x, x)
@@ -511,31 +512,6 @@ def build_matrix(op, dtype=None, shape=None):
511512
return mat
512513

513514

514-
def vector_to_device(queue, vec):
515-
from pytools.obj_array import obj_array_vectorize
516-
517-
from pyopencl.array import to_device
518-
519-
def to_dev(ary):
520-
return to_device(queue, ary)
521-
522-
return obj_array_vectorize(to_dev, vec)
523-
524-
525-
def vector_from_device(queue, vec):
526-
from pytools.obj_array import obj_array_vectorize
527-
528-
def from_dev(ary):
529-
from numbers import Number
530-
if isinstance(ary, (np.number, Number)):
531-
# zero, most likely
532-
return ary
533-
534-
return ary.get(queue=queue)
535-
536-
return obj_array_vectorize(from_dev, vec)
537-
538-
539515
def _merge_kernel_arguments(dictionary, arg):
540516
# Check for strict equality until there's a usecase
541517
if dictionary.setdefault(arg.name, arg) != arg:
@@ -716,7 +692,7 @@ def __eq__(self, other):
716692
# }}}
717693

718694

719-
class KernelCacheWrapper:
695+
class KernelCacheMixin:
720696
@memoize_method
721697
def get_cached_optimized_kernel(self, **kwargs):
722698
from sumpy import code_cache, CACHING_ENABLED, OPT_ENABLED
@@ -763,6 +739,9 @@ def _allow_redundant_execution_of_knl_scaling(knl):
763739
knl, within=ObjTagged(ScalingAssignmentTag()))
764740

765741

742+
KernelCacheWrapper = KernelCacheMixin
743+
744+
766745
def is_obj_array_like(ary):
767746
return (
768747
isinstance(ary, (tuple, list))
@@ -934,10 +913,14 @@ def to_complex_dtype(dtype):
934913
raise RuntimeError(f"Unknown dtype: {dtype}")
935914

936915

937-
ProfileGetter = namedtuple("ProfileGetter", "start, end")
916+
@dataclass(frozen=True)
917+
class ProfileGetter:
918+
start: int
919+
end: int
938920

939921

940922
def get_native_event(evt):
923+
import pyopencl as cl
941924
return evt if isinstance(evt, cl.Event) else evt.native_event
942925

943926

@@ -991,21 +974,22 @@ def loopy_fft(shape, inverse, complex_dtype, index_dtype=None,
991974
N1, m = find_factors(m) # noqa: N806
992975
factors.append(N1)
993976

977+
import pymbolic as prim
994978
nfft = n
995979

996-
broadcast_dims = tuple(pymbolic.var(f"j{d}") for d in range(len(shape) - 1))
980+
broadcast_dims = tuple(prim.var(f"j{d}") for d in range(len(shape) - 1))
997981

998982
domains = [
999983
"{[i]: 0<=i<n}",
1000984
"{[i2]: 0<=i2<n}",
1001985
]
1002986
domains += [f"{{[j{d}]: 0<=j{d}<{shape[d]} }}" for d in range(len(shape) - 1)]
1003987

1004-
x = pymbolic.var("x")
1005-
y = pymbolic.var("y")
1006-
i = pymbolic.var("i")
1007-
i2 = pymbolic.var("i2")
1008-
i3 = pymbolic.var("i3")
988+
x = prim.var("x")
989+
y = prim.var("y")
990+
i = prim.var("i")
991+
i2 = prim.var("i2")
992+
i3 = prim.var("i3")
1009993

1010994
fixed_parameters = {"const": complex_dtype(sign*(-2j)*pi/n), "n": n}
1011995

@@ -1027,16 +1011,16 @@ def loopy_fft(shape, inverse, complex_dtype, index_dtype=None,
10271011
else:
10281012
init_depends_on = f"update_{ilev-1}"
10291013

1030-
temp = pymbolic.var("temp")
1031-
exp_table = pymbolic.var("exp_table")
1032-
i = pymbolic.var(f"i_{ilev}")
1033-
i2 = pymbolic.var(f"i2_{ilev}")
1034-
ifft = pymbolic.var(f"ifft_{ilev}")
1035-
iN1 = pymbolic.var(f"iN1_{ilev}") # noqa: N806
1036-
iN1_sum = pymbolic.var(f"iN1_sum_{ilev}") # noqa: N806
1037-
iN2 = pymbolic.var(f"iN2_{ilev}") # noqa: N806
1038-
table_idx = pymbolic.var(f"table_idx_{ilev}")
1039-
exp = pymbolic.var(f"exp_{ilev}")
1014+
temp = prim.var("temp")
1015+
exp_table = prim.var("exp_table")
1016+
i = prim.var(f"i_{ilev}")
1017+
i2 = prim.var(f"i2_{ilev}")
1018+
ifft = prim.var(f"ifft_{ilev}")
1019+
iN1 = prim.var(f"iN1_{ilev}") # noqa: N806
1020+
iN1_sum = prim.var(f"iN1_sum_{ilev}") # noqa: N806
1021+
iN2 = prim.var(f"iN2_{ilev}") # noqa: N806
1022+
table_idx = prim.var(f"table_idx_{ilev}")
1023+
exp = prim.var(f"exp_{ilev}")
10401024

10411025
insns += [
10421026
lp.Assignment(
@@ -1122,15 +1106,16 @@ def loopy_fft(shape, inverse, complex_dtype, index_dtype=None,
11221106
else:
11231107
name = f"fft_{n}"
11241108

1125-
knl = lp.make_kernel(
1109+
from arraycontext import make_loopy_program
1110+
knl = make_loopy_program(
11261111
domains, insns,
11271112
kernel_data=kernel_data,
11281113
name=name,
1129-
fixed_parameters=fixed_parameters,
1130-
lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
1131-
index_dtype=index_dtype,
11321114
)
11331115

1116+
# FIXME: set index_dtype?
1117+
knl = lp.fix_parameters(knl, **fixed_parameters)
1118+
11341119
if broadcast_dims:
11351120
knl = lp.split_iname(knl, "j0", 32, inner_tag="l.0", outer_tag="g.0")
11361121
knl = lp.add_inames_for_unused_hw_axes(knl)
@@ -1143,7 +1128,7 @@ class FFTBackend(enum.Enum):
11431128
loopy = 2
11441129

11451130

1146-
def _get_fft_backend(queue) -> FFTBackend:
1131+
def _get_fft_backend(actx: PyOpenCLArrayContext) -> FFTBackend:
11471132
env_val = os.environ.get("SUMPY_FFT_BACKEND", None)
11481133
if env_val:
11491134
if env_val not in ["loopy", "pyvkfft"]:
@@ -1157,11 +1142,15 @@ def _get_fft_backend(queue) -> FFTBackend:
11571142
warnings.warn("VkFFT not found. FFT runs will be slower.")
11581143
return FFTBackend.loopy
11591144

1145+
import pyopencl as cl
1146+
queue = actx.queue
1147+
11601148
if queue.properties & cl.command_queue_properties.OUT_OF_ORDER_EXEC_MODE_ENABLE:
11611149
warnings.warn("VkFFT does not support out of order queues yet. "
11621150
"Falling back to slower implementation.")
11631151
return FFTBackend.loopy
11641152

1153+
import platform
11651154
if (sys.platform == "darwin"
11661155
and platform.machine() == "x86_64"
11671156
and queue.context.devices[0].platform.name
@@ -1174,26 +1163,34 @@ def _get_fft_backend(queue) -> FFTBackend:
11741163
return FFTBackend.pyvkfft
11751164

11761165

1177-
def get_opencl_fft_app(queue, shape, dtype, inverse):
1166+
def get_opencl_fft_app(
1167+
actx: PyOpenCLArrayContext,
1168+
shape: Tuple[int, ...], dtype: "np.dtype", *,
1169+
inverse: bool) -> Any:
11781170
"""Setup an object for out-of-place FFT on with given shape and dtype
11791171
on given queue.
11801172
"""
11811173
assert dtype.type in (np.float32, np.float64, np.complex64,
11821174
np.complex128)
11831175

1184-
backend = _get_fft_backend(queue)
1176+
backend = _get_fft_backend(actx)
11851177

11861178
if backend == FFTBackend.loopy:
11871179
return loopy_fft(shape, inverse=inverse, complex_dtype=dtype.type), backend
11881180
elif backend == FFTBackend.pyvkfft:
11891181
from pyvkfft.opencl import VkFFTApp
1190-
app = VkFFTApp(shape=shape, dtype=dtype, queue=queue, ndim=1, inplace=False)
1182+
app = VkFFTApp(
1183+
shape=shape, dtype=dtype,
1184+
queue=actx.queue, ndim=1, inplace=False)
11911185
return app, backend
11921186
else:
11931187
raise RuntimeError(f"Unsupported FFT backend {backend}")
11941188

11951189

1196-
def run_opencl_fft(fft_app, queue, input_vec, inverse=False, wait_for=None):
1190+
def run_opencl_fft(actx: PyOpenCLArrayContext,
1191+
fft_app: Tuple[Any, FFTBackend], input_vec: Array, *,
1192+
inverse: bool = False,
1193+
wait_for: List[Any] = None):
11971194
"""Runs an FFT on input_vec and returns a :class:`MarkerBasedProfilingEvent`
11981195
that indicate the end and start of the operations carried out and the output
11991196
vector.
@@ -1202,18 +1199,19 @@ def run_opencl_fft(fft_app, queue, input_vec, inverse=False, wait_for=None):
12021199
app, backend = fft_app
12031200

12041201
if backend == FFTBackend.loopy:
1205-
evt, (output_vec,) = app(queue, y=input_vec, wait_for=wait_for)
1202+
evt, (output_vec,) = app(actx.queue, y=input_vec, wait_for=wait_for)
12061203
return (evt, output_vec)
12071204
elif backend == FFTBackend.pyvkfft:
12081205
if wait_for is None:
12091206
wait_for = []
12101207

1211-
start_evt = cl.enqueue_marker(queue, wait_for=wait_for[:])
1208+
import pyopencl as cl
1209+
start_evt = cl.enqueue_marker(actx.queue, wait_for=wait_for[:])
12121210

12131211
if app.inplace:
12141212
raise RuntimeError("inplace fft is not supported")
12151213
else:
1216-
output_vec = cla.empty_like(input_vec, queue)
1214+
output_vec = actx.np.empty_like(input_vec)
12171215

12181216
# FIXME: use the public API once
12191217
# https://github.com/vincefn/pyvkfft/pull/17 is in
@@ -1224,9 +1222,9 @@ def run_opencl_fft(fft_app, queue, input_vec, inverse=False, wait_for=None):
12241222
meth = _vkfft_opencl.fft
12251223

12261224
meth(app.app, int(input_vec.data.int_ptr),
1227-
int(output_vec.data.int_ptr), int(queue.int_ptr))
1225+
int(output_vec.data.int_ptr), int(actx.queue.int_ptr))
12281226

1229-
end_evt = cl.enqueue_marker(queue, wait_for=[start_evt])
1227+
end_evt = cl.enqueue_marker(actx.queue, wait_for=[start_evt])
12301228
output_vec.add_event(end_evt)
12311229

12321230
return (MarkerBasedProfilingEvent(end_event=end_evt, start_event=start_evt),

0 commit comments

Comments
 (0)