3636 .. autoclass:: DifferentiatedExprDerivativeTaker
3737"""
3838
39- from pytools import memoize_method
40- from pytools .tag import Tag , tag_dataclass
41- import numbers
42- import warnings
4339import os
4440import sys
4541import enum
46- import platform
47- from collections import defaultdict , namedtuple
48- from pymbolic . mapper import WalkMapper
49- import pymbolic
42+ import numbers
43+ import warnings
44+ from dataclasses import dataclass
45+ from typing import Any , Dict , List , Tuple
5046
5147import numpy as np
52- import sumpy .symbolic as sym
53- import pyopencl as cl
54- import pyopencl .array as cla
5548
5649import loopy as lp
57- from typing import Dict , Tuple , Any
50+ from pytools import memoize_method
51+ from pytools .tag import Tag , tag_dataclass
52+ from pymbolic .mapper import WalkMapper
53+ from arraycontext import Array
54+
55+ import sumpy .symbolic as sym
56+ from sumpy .array_context import PyOpenCLArrayContext
5857
5958import logging
6059logger = logging .getLogger (__name__ )
@@ -452,7 +451,9 @@ def diff_derivative_coeff_dict(derivative_coeff_dict: DerivativeCoeffDict,
452451 *derivative_coeff_dict* using the variable given by **variable_idx**
453452 and return a new derivative transformation dictionary.
454453 """
454+ from collections import defaultdict
455455 new_derivative_coeff_dict = defaultdict (lambda : 0 )
456+
456457 for mi , coeff in derivative_coeff_dict .items ():
457458 # In the case where we have x * u.diff(x), the product rule gives
458459 # x.diff(x) * u.diff(x) + x * u.diff(x, x)
@@ -511,31 +512,6 @@ def build_matrix(op, dtype=None, shape=None):
511512 return mat
512513
513514
514- def vector_to_device (queue , vec ):
515- from pytools .obj_array import obj_array_vectorize
516-
517- from pyopencl .array import to_device
518-
519- def to_dev (ary ):
520- return to_device (queue , ary )
521-
522- return obj_array_vectorize (to_dev , vec )
523-
524-
525- def vector_from_device (queue , vec ):
526- from pytools .obj_array import obj_array_vectorize
527-
528- def from_dev (ary ):
529- from numbers import Number
530- if isinstance (ary , (np .number , Number )):
531- # zero, most likely
532- return ary
533-
534- return ary .get (queue = queue )
535-
536- return obj_array_vectorize (from_dev , vec )
537-
538-
539515def _merge_kernel_arguments (dictionary , arg ):
540516 # Check for strict equality until there's a usecase
541517 if dictionary .setdefault (arg .name , arg ) != arg :
@@ -716,7 +692,7 @@ def __eq__(self, other):
716692# }}}
717693
718694
719- class KernelCacheWrapper :
695+ class KernelCacheMixin :
720696 @memoize_method
721697 def get_cached_optimized_kernel (self , ** kwargs ):
722698 from sumpy import code_cache , CACHING_ENABLED , OPT_ENABLED
@@ -763,6 +739,9 @@ def _allow_redundant_execution_of_knl_scaling(knl):
763739 knl , within = ObjTagged (ScalingAssignmentTag ()))
764740
765741
742+ KernelCacheWrapper = KernelCacheMixin
743+
744+
766745def is_obj_array_like (ary ):
767746 return (
768747 isinstance (ary , (tuple , list ))
@@ -934,10 +913,14 @@ def to_complex_dtype(dtype):
934913 raise RuntimeError (f"Unknown dtype: { dtype } " )
935914
936915
937- ProfileGetter = namedtuple ("ProfileGetter" , "start, end" )
916+ @dataclass (frozen = True )
917+ class ProfileGetter :
918+ start : int
919+ end : int
938920
939921
940922def get_native_event (evt ):
923+ import pyopencl as cl
941924 return evt if isinstance (evt , cl .Event ) else evt .native_event
942925
943926
@@ -991,21 +974,22 @@ def loopy_fft(shape, inverse, complex_dtype, index_dtype=None,
991974 N1 , m = find_factors (m ) # noqa: N806
992975 factors .append (N1 )
993976
977+ import pymbolic as prim
994978 nfft = n
995979
996- broadcast_dims = tuple (pymbolic .var (f"j{ d } " ) for d in range (len (shape ) - 1 ))
980+ broadcast_dims = tuple (prim .var (f"j{ d } " ) for d in range (len (shape ) - 1 ))
997981
998982 domains = [
999983 "{[i]: 0<=i<n}" ,
1000984 "{[i2]: 0<=i2<n}" ,
1001985 ]
1002986 domains += [f"{{[j{ d } ]: 0<=j{ d } <{ shape [d ]} }}" for d in range (len (shape ) - 1 )]
1003987
1004- x = pymbolic .var ("x" )
1005- y = pymbolic .var ("y" )
1006- i = pymbolic .var ("i" )
1007- i2 = pymbolic .var ("i2" )
1008- i3 = pymbolic .var ("i3" )
988+ x = prim .var ("x" )
989+ y = prim .var ("y" )
990+ i = prim .var ("i" )
991+ i2 = prim .var ("i2" )
992+ i3 = prim .var ("i3" )
1009993
1010994 fixed_parameters = {"const" : complex_dtype (sign * (- 2j )* pi / n ), "n" : n }
1011995
@@ -1027,16 +1011,16 @@ def loopy_fft(shape, inverse, complex_dtype, index_dtype=None,
10271011 else :
10281012 init_depends_on = f"update_{ ilev - 1 } "
10291013
1030- temp = pymbolic .var ("temp" )
1031- exp_table = pymbolic .var ("exp_table" )
1032- i = pymbolic .var (f"i_{ ilev } " )
1033- i2 = pymbolic .var (f"i2_{ ilev } " )
1034- ifft = pymbolic .var (f"ifft_{ ilev } " )
1035- iN1 = pymbolic .var (f"iN1_{ ilev } " ) # noqa: N806
1036- iN1_sum = pymbolic .var (f"iN1_sum_{ ilev } " ) # noqa: N806
1037- iN2 = pymbolic .var (f"iN2_{ ilev } " ) # noqa: N806
1038- table_idx = pymbolic .var (f"table_idx_{ ilev } " )
1039- exp = pymbolic .var (f"exp_{ ilev } " )
1014+ temp = prim .var ("temp" )
1015+ exp_table = prim .var ("exp_table" )
1016+ i = prim .var (f"i_{ ilev } " )
1017+ i2 = prim .var (f"i2_{ ilev } " )
1018+ ifft = prim .var (f"ifft_{ ilev } " )
1019+ iN1 = prim .var (f"iN1_{ ilev } " ) # noqa: N806
1020+ iN1_sum = prim .var (f"iN1_sum_{ ilev } " ) # noqa: N806
1021+ iN2 = prim .var (f"iN2_{ ilev } " ) # noqa: N806
1022+ table_idx = prim .var (f"table_idx_{ ilev } " )
1023+ exp = prim .var (f"exp_{ ilev } " )
10401024
10411025 insns += [
10421026 lp .Assignment (
@@ -1122,15 +1106,16 @@ def loopy_fft(shape, inverse, complex_dtype, index_dtype=None,
11221106 else :
11231107 name = f"fft_{ n } "
11241108
1125- knl = lp .make_kernel (
1109+ from arraycontext import make_loopy_program
1110+ knl = make_loopy_program (
11261111 domains , insns ,
11271112 kernel_data = kernel_data ,
11281113 name = name ,
1129- fixed_parameters = fixed_parameters ,
1130- lang_version = lp .MOST_RECENT_LANGUAGE_VERSION ,
1131- index_dtype = index_dtype ,
11321114 )
11331115
1116+ # FIXME: set index_dtype?
1117+ knl = lp .fix_parameters (knl , ** fixed_parameters )
1118+
11341119 if broadcast_dims :
11351120 knl = lp .split_iname (knl , "j0" , 32 , inner_tag = "l.0" , outer_tag = "g.0" )
11361121 knl = lp .add_inames_for_unused_hw_axes (knl )
@@ -1143,7 +1128,7 @@ class FFTBackend(enum.Enum):
11431128 loopy = 2
11441129
11451130
1146- def _get_fft_backend (queue ) -> FFTBackend :
1131+ def _get_fft_backend (actx : PyOpenCLArrayContext ) -> FFTBackend :
11471132 env_val = os .environ .get ("SUMPY_FFT_BACKEND" , None )
11481133 if env_val :
11491134 if env_val not in ["loopy" , "pyvkfft" ]:
@@ -1157,11 +1142,15 @@ def _get_fft_backend(queue) -> FFTBackend:
11571142 warnings .warn ("VkFFT not found. FFT runs will be slower." )
11581143 return FFTBackend .loopy
11591144
1145+ import pyopencl as cl
1146+ queue = actx .queue
1147+
11601148 if queue .properties & cl .command_queue_properties .OUT_OF_ORDER_EXEC_MODE_ENABLE :
11611149 warnings .warn ("VkFFT does not support out of order queues yet. "
11621150 "Falling back to slower implementation." )
11631151 return FFTBackend .loopy
11641152
1153+ import platform
11651154 if (sys .platform == "darwin"
11661155 and platform .machine () == "x86_64"
11671156 and queue .context .devices [0 ].platform .name
@@ -1174,26 +1163,34 @@ def _get_fft_backend(queue) -> FFTBackend:
11741163 return FFTBackend .pyvkfft
11751164
11761165
1177- def get_opencl_fft_app (queue , shape , dtype , inverse ):
1166+ def get_opencl_fft_app (
1167+ actx : PyOpenCLArrayContext ,
1168+ shape : Tuple [int , ...], dtype : "np.dtype" , * ,
1169+ inverse : bool ) -> Any :
11781170 """Set up an object for out-of-place FFT with the given shape and dtype
11791171 on the queue of the given array context.
11801172 """
11811173 assert dtype .type in (np .float32 , np .float64 , np .complex64 ,
11821174 np .complex128 )
11831175
1184- backend = _get_fft_backend (queue )
1176+ backend = _get_fft_backend (actx )
11851177
11861178 if backend == FFTBackend .loopy :
11871179 return loopy_fft (shape , inverse = inverse , complex_dtype = dtype .type ), backend
11881180 elif backend == FFTBackend .pyvkfft :
11891181 from pyvkfft .opencl import VkFFTApp
1190- app = VkFFTApp (shape = shape , dtype = dtype , queue = queue , ndim = 1 , inplace = False )
1182+ app = VkFFTApp (
1183+ shape = shape , dtype = dtype ,
1184+ queue = actx .queue , ndim = 1 , inplace = False )
11911185 return app , backend
11921186 else :
11931187 raise RuntimeError (f"Unsupported FFT backend { backend } " )
11941188
11951189
1196- def run_opencl_fft (fft_app , queue , input_vec , inverse = False , wait_for = None ):
1190+ def run_opencl_fft (actx : PyOpenCLArrayContext ,
1191+ fft_app : Tuple [Any , FFTBackend ], input_vec : Array , * ,
1192+ inverse : bool = False ,
1193+ wait_for : List [Any ] = None ):
11971194 """Runs an FFT on input_vec and returns a :class:`MarkerBasedProfilingEvent`
11981195 that indicates the start and end of the operations carried out, together
11991196 with the output vector.
@@ -1202,18 +1199,19 @@ def run_opencl_fft(fft_app, queue, input_vec, inverse=False, wait_for=None):
12021199 app , backend = fft_app
12031200
12041201 if backend == FFTBackend .loopy :
1205- evt , (output_vec ,) = app (queue , y = input_vec , wait_for = wait_for )
1202+ evt , (output_vec ,) = app (actx . queue , y = input_vec , wait_for = wait_for )
12061203 return (evt , output_vec )
12071204 elif backend == FFTBackend .pyvkfft :
12081205 if wait_for is None :
12091206 wait_for = []
12101207
1211- start_evt = cl .enqueue_marker (queue , wait_for = wait_for [:])
1208+ import pyopencl as cl
1209+ start_evt = cl .enqueue_marker (actx .queue , wait_for = wait_for [:])
12121210
12131211 if app .inplace :
12141212 raise RuntimeError ("inplace fft is not supported" )
12151213 else :
1216- output_vec = cla . empty_like (input_vec , queue )
1214+ output_vec = actx . np . empty_like (input_vec )
12171215
12181216 # FIXME: use the public API once
12191217 # https://github.com/vincefn/pyvkfft/pull/17 is in
@@ -1224,9 +1222,9 @@ def run_opencl_fft(fft_app, queue, input_vec, inverse=False, wait_for=None):
12241222 meth = _vkfft_opencl .fft
12251223
12261224 meth (app .app , int (input_vec .data .int_ptr ),
1227- int (output_vec .data .int_ptr ), int (queue .int_ptr ))
1225+ int (output_vec .data .int_ptr ), int (actx . queue .int_ptr ))
12281226
1229- end_evt = cl .enqueue_marker (queue , wait_for = [start_evt ])
1227+ end_evt = cl .enqueue_marker (actx . queue , wait_for = [start_evt ])
12301228 output_vec .add_event (end_evt )
12311229
12321230 return (MarkerBasedProfilingEvent (end_event = end_evt , start_event = start_evt ),
0 commit comments