Skip to content

Commit 8db8f6f

Browse files
alexfiklinducer
authored andcommitted
tools: reorder non-top level imports
1 parent d1c7302 commit 8db8f6f

1 file changed

Lines changed: 39 additions & 27 deletions

File tree

sumpy/tools.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,8 @@ def build_matrix(op, dtype=None, shape=None):
205205

206206

207207
def vector_to_device(queue, vec):
208-
from pytools.obj_array import obj_array_vectorize
209-
210208
from pyopencl.array import to_device
209+
from pytools.obj_array import obj_array_vectorize
211210

212211
def to_dev(ary):
213212
return to_device(queue, ary)
@@ -449,12 +448,12 @@ def get_cached_optimized_kernel(self, **kwargs):
449448

450449
@memoize_method
451450
def get_cached_kernel_executor(self, **kwargs) -> lp.ExecutorBase:
452-
from sumpy import (code_cache, CACHING_ENABLED, OPT_ENABLED,
453-
NO_CACHE_KERNELS)
451+
from sumpy import CACHING_ENABLED, NO_CACHE_KERNELS, OPT_ENABLED, code_cache
454452

455453
if CACHING_ENABLED and not (
456454
NO_CACHE_KERNELS and self.name in NO_CACHE_KERNELS):
457455
import loopy.version
456+
458457
from sumpy.version import KERNEL_VERSION
459458
cache_key = (
460459
self.get_cache_key()
@@ -465,8 +464,7 @@ def get_cached_kernel_executor(self, **kwargs) -> lp.ExecutorBase:
465464

466465
try:
467466
result = code_cache[cache_key]
468-
logger.debug("{}: kernel cache hit [key={}]".format(
469-
self.name, cache_key))
467+
logger.debug("%s: kernel cache hit [key=%s]", self.name, cache_key)
470468
return result.executor(self.context)
471469
except KeyError:
472470
pass
@@ -678,7 +676,8 @@ class ProfileGetter:
678676

679677

680678
def get_native_event(evt):
681-
return evt if isinstance(evt, cl.Event) else evt.native_event
679+
from pyopencl import Event
680+
return evt if isinstance(evt, Event) else evt.native_event
682681

683682

684683
class AggregateProfilingEvent:
@@ -719,9 +718,11 @@ def wait(self):
719718

720719
def loopy_fft(shape, inverse, complex_dtype, index_dtype=None,
721720
name=None):
722-
from pymbolic.algorithm import find_factors
723721
from math import pi
724722

723+
from pymbolic import var
724+
from pymbolic.algorithm import find_factors
725+
725726
sign = 1 if not inverse else -1
726727
n = shape[-1]
727728

@@ -733,19 +734,19 @@ def loopy_fft(shape, inverse, complex_dtype, index_dtype=None,
733734

734735
nfft = n
735736

736-
broadcast_dims = tuple(pymbolic.var(f"j{d}") for d in range(len(shape) - 1))
737+
broadcast_dims = tuple(var(f"j{d}") for d in range(len(shape) - 1))
737738

738739
domains = [
739740
"{[i]: 0<=i<n}",
740741
"{[i2]: 0<=i2<n}",
741742
]
742743
domains += [f"{{[j{d}]: 0<=j{d}<{shape[d]} }}" for d in range(len(shape) - 1)]
743744

744-
x = pymbolic.var("x")
745-
y = pymbolic.var("y")
746-
i = pymbolic.var("i")
747-
i2 = pymbolic.var("i2")
748-
i3 = pymbolic.var("i3")
745+
x = var("x")
746+
y = var("y")
747+
i = var("i")
748+
i2 = var("i2")
749+
i3 = var("i3")
749750

750751
fixed_parameters = {"const": complex_dtype(sign*(-2j)*pi/n), "n": n}
751752

@@ -767,16 +768,16 @@ def loopy_fft(shape, inverse, complex_dtype, index_dtype=None,
767768
else:
768769
init_depends_on = f"update_{ilev-1}"
769770

770-
temp = pymbolic.var("temp")
771-
exp_table = pymbolic.var("exp_table")
772-
i = pymbolic.var(f"i_{ilev}")
773-
i2 = pymbolic.var(f"i2_{ilev}")
774-
ifft = pymbolic.var(f"ifft_{ilev}")
775-
iN1 = pymbolic.var(f"iN1_{ilev}") # noqa: N806
776-
iN1_sum = pymbolic.var(f"iN1_sum_{ilev}") # noqa: N806
777-
iN2 = pymbolic.var(f"iN2_{ilev}") # noqa: N806
778-
table_idx = pymbolic.var(f"table_idx_{ilev}")
779-
exp = pymbolic.var(f"exp_{ilev}")
771+
temp = var("temp")
772+
exp_table = var("exp_table")
773+
i = var(f"i_{ilev}")
774+
i2 = var(f"i2_{ilev}")
775+
ifft = var(f"ifft_{ilev}")
776+
iN1 = var(f"iN1_{ilev}") # noqa: N806
777+
iN1_sum = var(f"iN1_sum_{ilev}") # noqa: N806
778+
iN2 = var(f"iN2_{ilev}") # noqa: N806
779+
table_idx = var(f"table_idx_{ilev}")
780+
exp = var(f"exp_{ilev}")
780781

781782
insns += [
782783
lp.Assignment(
@@ -879,12 +880,16 @@ def loopy_fft(shape, inverse, complex_dtype, index_dtype=None,
879880

880881

881882
class FFTBackend(enum.Enum):
883+
#: FFT backend based on the vkFFT library.
882884
pyvkfft = 1
885+
#: FFT backend based on :mod:`loopy` used as a fallback.
883886
loopy = 2
884887

885888

886-
def _get_fft_backend(queue) -> FFTBackend:
887-
env_val = os.environ.get("SUMPY_FFT_BACKEND", None)
889+
def _get_fft_backend(queue: "cl.CommandQueue") -> FFTBackend:
890+
import os
891+
892+
env_val = os.environ.get("SUMPY_FFT_BACKEND")
888893
if env_val:
889894
if env_val not in ["loopy", "pyvkfft"]:
890895
raise ValueError("Expected 'loopy' or 'pyvkfft' for SUMPY_FFT_BACKEND. "
@@ -897,13 +902,17 @@ def _get_fft_backend(queue) -> FFTBackend:
897902
warnings.warn("VkFFT not found. FFT runs will be slower.", stacklevel=3)
898903
return FFTBackend.loopy
899904

900-
if queue.properties & cl.command_queue_properties.OUT_OF_ORDER_EXEC_MODE_ENABLE:
905+
from pyopencl import command_queue_properties
906+
907+
if queue.properties & command_queue_properties.OUT_OF_ORDER_EXEC_MODE_ENABLE:
901908
warnings.warn(
902909
"VkFFT does not support out of order queues yet. "
903910
"Falling back to slower implementation.", stacklevel=3)
904911
return FFTBackend.loopy
905912

906913
import platform
914+
import sys
915+
907916
if (sys.platform == "darwin"
908917
and platform.machine() == "x86_64"
909918
and queue.context.devices[0].platform.name
@@ -960,6 +969,9 @@ def run_opencl_fft(
960969
if wait_for is None:
961970
wait_for = []
962971

972+
import pyopencl as cl
973+
import pyopencl.array as cla
974+
963975
start_evt = cl.enqueue_marker(queue, wait_for=wait_for[:])
964976

965977
if app.inplace:

0 commit comments

Comments
 (0)