diff --git a/.gitignore b/.gitignore index 4378c7122..f03dc4d4c 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,5 @@ virtualenv-[0-9]*[0-9] # Files used by run-pylint.sh .pylintrc.yml .run-pylint.py + +.codex diff --git a/examples/python/compute-examples/finite-difference-2-5D.py b/examples/python/compute-examples/finite-difference-2-5D.py new file mode 100644 index 000000000..43d09309e --- /dev/null +++ b/examples/python/compute-examples/finite-difference-2-5D.py @@ -0,0 +1,294 @@ +import time + +import namedisl as nisl +import numpy as np +import numpy.linalg as la + +import pyopencl as cl + +import loopy as lp +from loopy.transform.compute import compute +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 + + +def centered_second_derivative_coefficients(radius: int, dtype) -> np.ndarray: + offsets = np.arange(-radius, radius + 1, dtype=dtype) + powers = np.arange(2 * radius + 1) + vandermonde = offsets[np.newaxis, :] ** powers[:, np.newaxis] + rhs = np.zeros(2 * radius + 1, dtype=dtype) + rhs[2] = 2 + + return np.linalg.solve(vandermonde, rhs).astype(dtype) + + +# FIXME: choose a better test case +def f(x, y, z): + return x**2 + y**2 + z**2 + + +def laplacian_f(x, y, z): + return 6 * np.ones_like(x) + + +def benchmark_executor(ex, queue, args, warmup: int, iterations: int) -> float: + if iterations <= 0: + raise ValueError("iterations must be positive") + + evt = None + for _ in range(warmup): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + + start = time.perf_counter() + for _ in range(iterations): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + end = time.perf_counter() + + return (end - start) / iterations + + +def laplacian_flop_count(npts: int, stencil_width: int) -> int: + radius = stencil_width // 2 + output_points = (npts - 2 * radius) ** 3 + return 4 * stencil_width * output_points + + +def main( + npts: int = 64, + stencil_width: int = 5, + use_compute: bool = False, + print_device_code: bool = False, 
+ print_kernel: bool = False, + run_kernel: bool = False, + warmup: int = 3, + iterations: int = 10 + ) -> float | None: + if stencil_width <= 0 or stencil_width % 2 == 0: + raise ValueError("stencil_width must be a positive odd integer") + + pts = np.linspace(-1, 1, num=npts, endpoint=True) + h = pts[1] - pts[0] + + x, y, z = np.meshgrid(*(pts,)*3) + + dtype = np.float64 + x = x.reshape(*(npts,)*3).astype(dtype) + y = y.reshape(*(npts,)*3).astype(dtype) + z = z.reshape(*(npts,)*3).astype(dtype) + + m = stencil_width + r = m // 2 + c = (centered_second_derivative_coefficients(r, dtype) / h**2).astype(dtype) + + bm = bn = 16 + bk = 32 + + knl = lp.make_kernel( + "{ [i, j, k, l] : r <= i, j, k < npts - r and -r <= l < r + 1 }", + """ + u_(is, js, ks) := u[is, js, ks] + + lap_u[i,j,k] = sum( + [l], + c[l+r] * (u_(i-l,j,k) + u_(i,j-l,k) + u_(i,j,k-l)) + ) + """, + [ + lp.GlobalArg("u", dtype=dtype, shape=(npts, npts, npts)), + lp.GlobalArg("lap_u", dtype=dtype, shape=(npts, npts, npts), + is_output=True), + lp.GlobalArg("c", dtype=dtype, shape=(m,)) + ], + lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2 + ) + + knl = lp.fix_parameters(knl, npts=npts, r=r) + + knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") + knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") + knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") + + if use_compute: + plane_map = nisl.make_map(f"""{{ + [is, js, ks] -> [io, ii_s, jo, ji_s, ko, ki] : + is = io * {bm} + ii_s - {r} and + js = jo * {bn} + ji_s - {r} and + ks = ko * {bk} + ki + }}""") + + knl = compute( + knl, + "u_", + compute_map=plane_map, + storage_indices=["ii_s", "ji_s"], + temporal_inames=["io", "jo", "ko", "ki"], + + temporary_name="u_ij_plane", + temporary_address_space=lp.AddressSpace.LOCAL, + temporary_dtype=dtype, + + compute_insn_id="u_plane_compute" + ) + + ring_buffer_map = nisl.make_map(f"""{{ + [is, js, ks] -> [io, ii, jo, ji, ko, ki, kb] : + is = io * {bm} + ii and 
+ js = jo * {bn} + ji and + kb = ks - (ko * {bk} + ki) + {r} + }}""") + + knl = compute( + knl, + "u_", + compute_map=ring_buffer_map, + storage_indices=["kb"], + temporal_inames=["io", "ii", "jo", "ji", "ko", "ki"], + + temporary_name="u_k_buf", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + + compute_insn_id="u_ring_buf_compute", + inames_to_advance=["ki"] + ) + + nt = 16 + knl = lp.split_iname( + knl, "ii_s", nt, outer_iname="ii_s_tile", inner_iname="ii_s_local" + ) + + knl = lp.split_iname( + knl, "ji_s", nt, outer_iname="ji_s_tile", inner_iname="ji_s_local" + ) + + knl = lp.tag_inames(knl, { + # 2D plane compute storage loops + "ii_s_local": "l.1", + "ji_s_local": "l.0", + + # force the use of registers by unrolling + "kb": "unr" + }) + + knl = lp.tag_inames(knl, { + # outer block loops + "io": "g.2", + "jo": "g.1", + "ko": "g.0", + + # inner tile loops + "ii": "l.1", + "ji": "l.0", + }) + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + if print_kernel: + print(knl) + + if not run_kernel: + return None + + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + + ex = knl.executor(queue) + + f_vals = f(x, y, z) + + import pyopencl.array as cl_array + f_vals_cl = cl_array.to_device(queue, f_vals) + c_cl = cl_array.to_device(queue, c) + lap_u_cl = cl_array.zeros(queue, (npts,)*3, dtype=f_vals_cl.dtype) + avg_time_per_iter = benchmark_executor( + ex, queue, {"u": f_vals_cl, "c": c_cl, "lap_u": lap_u_cl}, + warmup=warmup, iterations=iterations) + avg_gflops = laplacian_flop_count(npts, stencil_width) / avg_time_per_iter / 1e9 + + _, lap_fd = ex(queue, u=f_vals_cl, c=c_cl, lap_u=lap_u_cl) + lap_true = laplacian_f(x, y, z) + sl = (slice(r, npts - r),)*3 + + rel_err = la.norm(lap_true[sl] - lap_fd[0].get()[sl]) / la.norm(lap_true[sl]) + + print(20 * "=", "Finite difference report", 20 * "=") + print(f"Variant : {'compute' if use_compute else 'baseline'}") + print(f"Grid points : {npts}^3") + 
print(f"Stencil width: {stencil_width}") + print(f"Iterations : warmup = {warmup}, timed = {iterations}") + print(f"Average time per iteration: {avg_time_per_iter:.6e} s") + print(f"Average throughput: {avg_gflops:.3f} GFLOP/s") + print(f"Relative error: {rel_err:.3e}") + print((40 + len(" Finite difference report ")) * "=") + + return avg_time_per_iter + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + + _ = parser.add_argument("--npoints", action="store", type=int, default=64) + _ = parser.add_argument("--stencil-width", action="store", type=int, default=5) + + _ = parser.add_argument("--compare", action="store_true") + _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--no-run-kernel", action="store_false", + dest="run_kernel") + _ = parser.add_argument("--print-device-code", action="store_true") + _ = parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--warmup", action="store", type=int, default=3) + _ = parser.add_argument("--iterations", action="store", type=int, default=10) + + args = parser.parse_args() + + if args.compare: + print("Running example without compute...") + no_compute_time = main( + npts=args.npoints, + stencil_width=args.stencil_width, + use_compute=False, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=True, + warmup=args.warmup, + iterations=args.iterations, + ) + print(50 * "=", "\n") + + print("Running example with compute...") + compute_time = main( + npts=args.npoints, + stencil_width=args.stencil_width, + use_compute=True, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=True, + warmup=args.warmup, + iterations=args.iterations, + ) + print(50 * "=", "\n") + + assert no_compute_time is not None + assert compute_time is not None + speedup = no_compute_time / compute_time + 
print(f"Speedup: {speedup:.3f}x") + time_reduction = (1 - compute_time / no_compute_time) * 100 + print(f"Relative time reduction: {time_reduction:.2f}%") + else: + _ = main( + npts=args.npoints, + stencil_width=args.stencil_width, + use_compute=args.compute, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=args.run_kernel, + warmup=args.warmup, + iterations=args.iterations, + ) diff --git a/examples/python/compute-examples/finite-difference-diamond.py b/examples/python/compute-examples/finite-difference-diamond.py new file mode 100644 index 000000000..94baf9183 --- /dev/null +++ b/examples/python/compute-examples/finite-difference-diamond.py @@ -0,0 +1,342 @@ +import time + +import namedisl as nisl +import numpy as np +import numpy.linalg as la + +import pyopencl as cl + +import loopy as lp +from loopy.transform.compute import compute +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 + + +def centered_second_derivative_coefficients(radius: int, dtype) -> np.ndarray: + offsets = np.arange(-radius, radius + 1, dtype=dtype) + powers = np.arange(2 * radius + 1) + vandermonde = offsets[np.newaxis, :] ** powers[:, np.newaxis] + rhs = np.zeros(2 * radius + 1, dtype=dtype) + rhs[2] = 2 + + return np.linalg.solve(vandermonde, rhs).astype(dtype) + + +def benchmark_executor(ex, queue, args, warmup: int, iterations: int) -> float: + if iterations <= 0: + raise ValueError("iterations must be positive") + + evt = None + for _ in range(warmup): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + + start = time.perf_counter() + for _ in range(iterations): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + end = time.perf_counter() + + return (end - start) / iterations + + +def fd_flop_count(ntime: int, nspace: int, stencil_width: int) -> int: + radius = stencil_width // 2 + output_points = (ntime - 1) * (nspace - 2 * radius) + return 2 * stencil_width * output_points + + +def 
make_initial_condition(nspace: int, dtype) -> np.ndarray: + x = np.linspace(-1, 1, num=nspace, endpoint=True, dtype=dtype) + wave_number = dtype(2 * np.pi) + return np.sin(wave_number * x).astype(dtype) + + +def reference_time_stepper( + u0: np.ndarray, + coeffs: np.ndarray, + ntime: int, + radius: int, +) -> np.ndarray: + result = np.zeros((ntime, u0.size), dtype=u0.dtype) + result[0] = u0 + result[1:, :radius] = u0[:radius] + result[1:, u0.size - radius :] = u0[u0.size - radius :] + + for t in range(ntime - 1): + for i in range(radius, u0.size - radius): + result[t + 1, i] = sum( + coeffs[ell + radius] * result[t, i - ell] + for ell in range(-radius, radius + 1) + ) + + return result + + +def offset_name(ell: int) -> str: + return f"u_p{ell}" if ell >= 0 else f"u_m{-ell}" + + +def main( + ntime: int = 128, + nspace: int = 4096, + stencil_width: int = 9, + time_block_size: int = 8, + space_block_size: int = 128, + use_compute: bool = False, + print_device_code: bool = False, + print_kernel: bool = False, + run_kernel: bool = False, + warmup: int = 3, + iterations: int = 10, +) -> float | None: + if stencil_width <= 0 or stencil_width % 2 == 0: + raise ValueError("stencil_width must be a positive odd integer") + if ntime <= stencil_width: + raise ValueError("ntime must be larger than stencil_width") + if nspace <= 2 * stencil_width: + raise ValueError("nspace must be larger than twice stencil_width") + + dtype = np.float64 + r = stencil_width // 2 + + u0 = make_initial_condition(nspace, dtype) + u_hist = np.zeros((ntime, nspace), dtype=dtype) + u_hist[0] = u0 + u_hist[1:, :r] = u0[:r] + u_hist[1:, nspace - r :] = u0[nspace - r :] + h = dtype(2 / (nspace - 1)) + dt = dtype(0.05 * h**2) + lap_coeffs = centered_second_derivative_coefficients(r, dtype) / h**2 + coeffs = (dt * lap_coeffs).astype(dtype) + coeffs[r] += 1 + + bt = time_block_size + bx = space_block_size + subst_rules = "\n".join( + f"{offset_name(ell)}(ts, is) := u_hist[ts, is " + f"{'+' if -ell >= 0 else 
'-'} {abs(ell)}]" + for ell in range(-r, r + 1) + ) + stencil_expr = " + ".join( + f"c[{ell + r}] * {offset_name(ell)}(t, i)" for ell in range(-r, r + 1) + ) + + knl = lp.make_kernel( + "{ [t, i] : 0 <= t < ntime - 1 and r <= i < nspace - r }", + f""" + {subst_rules} + + u_hist[t + 1, i] = {stencil_expr} {{id=step}} + """, + [ + lp.GlobalArg( + "u_hist", + dtype=dtype, + shape=(ntime, nspace), + is_input=True, + is_output=True, + ), + lp.GlobalArg("c", dtype=dtype, shape=(stencil_width,)), + ], + lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2, + ) + + knl = lp.fix_parameters(knl, ntime=ntime, nspace=nspace, r=r) + knl = lp.split_iname(knl, "t", bt, inner_iname="ti", outer_iname="to") + knl = lp.split_iname(knl, "i", bx, inner_iname="xi", outer_iname="xo") + + if use_compute: + raise NotImplementedError( + "The recurrent diamond time-stepper cannot currently be lowered " + "through compute(): Loopy represents the instance-wise dependence " + "between compute loads from u_hist[t] and writes to u_hist[t+1] as " + "an instruction-level dependency cycle." 
+ ) + compute_insn_ids = [] + for ell in range(-r, r + 1): + suffix = offset_name(ell) + compute_insn_id = f"{suffix}_diamond_compute" + compute_insn_ids.append(compute_insn_id) + storage_axis = f"xi_s_{suffix}" + diamond_map = nisl.make_map(f"""{{ + [ts, is] -> [to, xo, ti, {storage_axis}] : + ts = to * {bt} + ti and + is = xo * {bx} + {storage_axis} + ti - {bt - 1} + }}""") + + knl = compute( + knl, + suffix, + compute_map=diamond_map, + storage_indices=[storage_axis], + temporal_inames=["to", "xo", "ti"], + temporary_name=f"{suffix}_diamond", + temporary_address_space=lp.AddressSpace.LOCAL, + temporary_dtype=dtype, + compute_insn_id=compute_insn_id, + ) + knl = knl.with_kernel( + lp.map_instructions( + knl.default_entrypoint, + f"id:{suffix}_diamond_compute", + lambda insn: insn.copy(depends_on=frozenset()), + ) + ) + + knl = lp.split_iname( + knl, + storage_axis, + 128, + outer_iname=f"{storage_axis}_tile", + inner_iname=f"{storage_axis}_local", + ) + knl = lp.tag_inames(knl, {f"{storage_axis}_local": "l.0"}) + + no_sync_with_computes = frozenset( + (compute_insn_id, "global") for compute_insn_id in compute_insn_ids + ) + knl = knl.with_kernel( + lp.map_instructions( + knl.default_entrypoint, + "id:step", + lambda insn: insn.copy( + no_sync_with=insn.no_sync_with | no_sync_with_computes + ), + ) + ) + for compute_insn_id in compute_insn_ids: + knl = knl.with_kernel( + lp.map_instructions( + knl.default_entrypoint, + f"id:{compute_insn_id}", + lambda insn: insn.copy( + no_sync_with=insn.no_sync_with | frozenset([("step", "global")]) + ), + ) + ) + + knl = lp.tag_inames(knl, {"xi": "l.0"}) + knl = lp.prioritize_loops(knl, "to,ti,xo,xi") + knl = lp.set_options(knl, insert_gbarriers=True) + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + if print_kernel: + print(knl) + + if not run_kernel: + return None + + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + ex = knl.executor(queue) + + import pyopencl.array as cl_array + + 
u_hist_cl = cl_array.to_device(queue, u_hist) + coeffs_cl = cl_array.to_device(queue, coeffs) + + args = {"c": coeffs_cl, "u_hist": u_hist_cl} + avg_time_per_iter = benchmark_executor( + ex, queue, args, warmup=warmup, iterations=iterations + ) + avg_gflops = fd_flop_count(ntime, nspace, stencil_width) / avg_time_per_iter / 1e9 + + _, out = ex(queue, **args) + reference = reference_time_stepper(u0, coeffs, ntime, r) + sl = (slice(None), slice(r, nspace - r)) + rel_err = la.norm(reference[sl] - out[0].get()[sl]) / la.norm(reference[sl]) + + print(20 * "=", "Diamond finite difference report", 20 * "=") + print(f"Variant : {'compute' if use_compute else 'baseline'}") + print(f"Time steps : {ntime}") + print(f"Space points : {nspace}") + print(f"Stencil width: {stencil_width}") + print(f"Tile shape : bt = {bt}, bx = {bx}") + print(f"Iterations : warmup = {warmup}, timed = {iterations}") + print(f"Average time per iteration: {avg_time_per_iter:.6e} s") + print(f"Average throughput: {avg_gflops:.3f} GFLOP/s") + print(f"Relative error: {rel_err:.3e}") + print((40 + len(" Diamond finite difference report ")) * "=") + + return avg_time_per_iter + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + + _ = parser.add_argument("--ntime", action="store", type=int, default=128) + _ = parser.add_argument("--nspace", action="store", type=int, default=4096) + _ = parser.add_argument("--stencil-width", action="store", type=int, default=9) + _ = parser.add_argument("--time-block-size", action="store", type=int, default=8) + _ = parser.add_argument("--space-block-size", action="store", type=int, default=128) + + _ = parser.add_argument("--compare", action="store_true") + _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--no-run-kernel", action="store_false", dest="run_kernel") + _ = parser.add_argument("--print-device-code", action="store_true") + _ = 
parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--warmup", action="store", type=int, default=3) + _ = parser.add_argument("--iterations", action="store", type=int, default=10) + + args = parser.parse_args() + + if args.compare: + print("Running example without compute...") + no_compute_time = main( + ntime=args.ntime, + nspace=args.nspace, + stencil_width=args.stencil_width, + time_block_size=args.time_block_size, + space_block_size=args.space_block_size, + use_compute=False, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=True, + warmup=args.warmup, + iterations=args.iterations, + ) + print(50 * "=", "\n") + + print("Running example with compute...") + compute_time = main( + ntime=args.ntime, + nspace=args.nspace, + stencil_width=args.stencil_width, + time_block_size=args.time_block_size, + space_block_size=args.space_block_size, + use_compute=True, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=True, + warmup=args.warmup, + iterations=args.iterations, + ) + print(50 * "=", "\n") + + assert no_compute_time is not None + assert compute_time is not None + speedup = no_compute_time / compute_time + print(f"Speedup: {speedup:.3f}x") + time_reduction = (1 - compute_time / no_compute_time) * 100 + print(f"Relative time reduction: {time_reduction:.2f}%") + else: + _ = main( + ntime=args.ntime, + nspace=args.nspace, + stencil_width=args.stencil_width, + time_block_size=args.time_block_size, + space_block_size=args.space_block_size, + use_compute=args.compute, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=args.run_kernel, + warmup=args.warmup, + iterations=args.iterations, + ) diff --git a/examples/python/compute-examples/l2p-3d-tensor-product-compute.py b/examples/python/compute-examples/l2p-3d-tensor-product-compute.py new file mode 100644 index 000000000..820947ef4 --- /dev/null +++ 
b/examples/python/compute-examples/l2p-3d-tensor-product-compute.py @@ -0,0 +1,489 @@ +"""Benchmark a 3D Cartesian Taylor L2P microkernel with Loopy compute. + +FMM kernel class: kernel-independent Cartesian Taylor/asymptotic local +expansion evaluation in three spatial dimensions. This script does not +directly evaluate the 3D Laplace, Helmholtz, or biharmonic Green's function. +The dense local coefficient tensor ``gamma`` is assumed to have already been +formed by the relevant FMM translation machinery; this benchmark isolates the +target-side monomial contraction. + +The kernel evaluates a dense 3D tensor-product local expansion at many target +points: + + phi[itgt] = sum_{q0,q1,q2} gamma[q0, q1, q2] + * x[itgt]**q0 / q0! + * y[itgt]**q1 / q1! + * z[itgt]**q2 / q2! + +The baseline variants are GPU-parallel kernels over target blocks that expand +the basis substitutions inline. The compute variants use +:func:`loopy.transform.compute.compute` to materialize the x, y, and z basis +values into private temporaries. The script includes both a direct tiled +compute schedule and an optimized register-tiled schedule, following the style +of Loopy's compute matmul example. + +Use ``--compare`` to run the naive parallel baseline and the optimized compute +kernel, validate both against the NumPy reference, and report timing, modeled +GFLOP/s, speedup, and relative error. 
+""" + +import os +import time + +os.environ.setdefault("XDG_CACHE_HOME", "/tmp") + +import namedisl as nisl +import numpy as np +import numpy.linalg as la + +import loopy as lp +from loopy.transform.compute import compute +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 + + +def inv_factorials(order: int, dtype) -> np.ndarray: + result = np.empty(order + 1, dtype=dtype) + result[0] = 1 + for i in range(1, order + 1): + result[i] = result[i - 1] / i + return result + + +def reference_l2p_3d( + gamma: np.ndarray, + x: np.ndarray, + y: np.ndarray, + z: np.ndarray, + inv_fact: np.ndarray + ) -> np.ndarray: + order = gamma.shape[0] - 1 + result = np.empty_like(x) + + for itgt in range(x.size): + acc = 0 + for q0 in range(order + 1): + x_basis = x[itgt]**q0 * inv_fact[q0] + for q1 in range(order + 1): + y_basis = y[itgt]**q1 * inv_fact[q1] + for q2 in range(order + 1): + z_basis = z[itgt]**q2 * inv_fact[q2] + acc += gamma[q0, q1, q2] * x_basis * y_basis * z_basis + result[itgt] = acc + + return result + + +def make_kernel( + ntargets: int, + order: int, + dtype + ) -> lp.TranslationUnit: + knl = lp.make_kernel( + "{ [itgt, q0, q1, q2] : " + "0 <= itgt < ntargets and 0 <= q0, q1, q2 <= p }", + """ + x_basis_(itgt_arg, q0_arg) := ( + x[itgt_arg] ** q0_arg * inv_fact[q0_arg] + ) + + y_basis_(itgt_arg, q1_arg) := ( + y[itgt_arg] ** q1_arg * inv_fact[q1_arg] + ) + + z_basis_(itgt_arg, q2_arg) := ( + z[itgt_arg] ** q2_arg * inv_fact[q2_arg] + ) + + phi[itgt] = sum( + [q0, q1, q2], + gamma[q0, q1, q2] + * x_basis_(itgt, q0) + * y_basis_(itgt, q1) + * z_basis_(itgt, q2) + ) + """, + [ + lp.GlobalArg("x", dtype=dtype, shape=(ntargets,)), + lp.GlobalArg("y", dtype=dtype, shape=(ntargets,)), + lp.GlobalArg("z", dtype=dtype, shape=(ntargets,)), + lp.GlobalArg("inv_fact", dtype=dtype, shape=(order + 1,)), + lp.GlobalArg( + "gamma", + dtype=dtype, + shape=(order + 1, order + 1, order + 1), + ), + lp.GlobalArg("phi", dtype=dtype, shape=(ntargets,), is_output=True), + ], + 
lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2, + ) + return lp.fix_parameters(knl, ntargets=ntargets, p=order) + + +def split_targets( + knl: lp.TranslationUnit, + target_block_size: int + ) -> lp.TranslationUnit: + knl = lp.split_iname( + knl, + "itgt", + target_block_size, + inner_iname="itgt_inner", + outer_iname="itgt_block", + ) + return lp.tag_inames(knl, {"itgt_block": "g.0"}) + + +def block_private_l2p_3d( + knl: lp.TranslationUnit, + target_block_size: int, + dtype + ) -> lp.TranslationUnit: + knl = split_targets(knl, target_block_size) + + x_basis_map = nisl.make_map(f"""{{ + [itgt_arg, q0_arg] -> [itgt_block, itgt_s, q0_s] : + itgt_arg = itgt_block * {target_block_size} + itgt_s and + q0_arg = q0_s + }}""") + + knl = compute( + knl, + "x_basis_", + compute_map=x_basis_map, + storage_indices=["itgt_s", "q0_s"], + temporal_inames=["itgt_block"], + temporary_name="x_basis_tile", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="x_basis_compute", + ) + + y_basis_map = nisl.make_map(f"""{{ + [itgt_arg, q1_arg] -> [itgt_block, itgt_s, q1_s] : + itgt_arg = itgt_block * {target_block_size} + itgt_s and + q1_arg = q1_s + }}""") + + knl = compute( + knl, + "y_basis_", + compute_map=y_basis_map, + storage_indices=["itgt_s", "q1_s"], + temporal_inames=["itgt_block"], + temporary_name="y_basis_tile", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="y_basis_compute", + ) + + z_basis_map = nisl.make_map(f"""{{ + [itgt_arg, q2_arg] -> [itgt_block, itgt_s, q2_s] : + itgt_arg = itgt_block * {target_block_size} + itgt_s and + q2_arg = q2_s + }}""") + + return compute( + knl, + "z_basis_", + compute_map=z_basis_map, + storage_indices=["itgt_s", "q2_s"], + temporal_inames=["itgt_block"], + temporary_name="z_basis_tile", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="z_basis_compute", + ) + + +def register_tiled_l2p_3d( + knl: 
lp.TranslationUnit, + target_block_size: int, + dtype + ) -> lp.TranslationUnit: + knl = split_targets(knl, target_block_size) + + x_basis_map = nisl.make_map(f"""{{ + [itgt_arg, q0_arg] -> [itgt_block, itgt_inner, q0_s] : + itgt_arg = itgt_block * {target_block_size} + itgt_inner and + q0_arg = q0_s + }}""") + + knl = compute( + knl, + "x_basis_", + compute_map=x_basis_map, + storage_indices=["q0_s"], + temporal_inames=["itgt_block", "itgt_inner"], + temporary_name="x_basis_reg", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="x_basis_compute", + ) + + y_basis_map = nisl.make_map(f"""{{ + [itgt_arg, q1_arg] -> [itgt_block, itgt_inner, q1_s] : + itgt_arg = itgt_block * {target_block_size} + itgt_inner and + q1_arg = q1_s + }}""") + + knl = compute( + knl, + "y_basis_", + compute_map=y_basis_map, + storage_indices=["q1_s"], + temporal_inames=["itgt_block", "itgt_inner"], + temporary_name="y_basis_reg", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="y_basis_compute", + ) + + z_basis_map = nisl.make_map(f"""{{ + [itgt_arg, q2_arg] -> [itgt_block, itgt_inner, q2_s] : + itgt_arg = itgt_block * {target_block_size} + itgt_inner and + q2_arg = q2_s + }}""") + + knl = compute( + knl, + "z_basis_", + compute_map=z_basis_map, + storage_indices=["q2_s"], + temporal_inames=["itgt_block", "itgt_inner"], + temporary_name="z_basis_reg", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="z_basis_compute", + ) + + return lp.tag_inames(knl, { + "itgt_inner": "l.0", + "q0_s": "unr", + "q1_s": "unr", + "q2_s": "unr", + "q0": "unr", + "q1": "unr", + "q2": "unr", + }) + + +def operation_model( + ntargets: int, + order: int, + target_block_size: int + ) -> tuple[int, int]: + ncoeff = order + 1 + inline_basis_evals = 3 * ntargets * ncoeff**3 + tiled_compute_basis_evals = 3 * ntargets * ncoeff + return inline_basis_evals, tiled_compute_basis_evals + + 
+def l2p_3d_flop_count(ntargets: int, order: int, use_compute: bool) -> int: + ncoeff = order + 1 + + contraction_flops = 4 * ntargets * ncoeff**3 + if use_compute: + basis_scale_flops = 3 * ntargets * ncoeff + else: + basis_scale_flops = 3 * ntargets * ncoeff**3 + + return contraction_flops + basis_scale_flops + + +def benchmark_executor(ex, queue, args, warmup: int, iterations: int) -> float: + evt = None + for _ in range(warmup): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + + start = time.perf_counter() + for _ in range(iterations): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + end = time.perf_counter() + + return (end - start) / iterations + + +def run_kernel( + knl: lp.TranslationUnit, + x: np.ndarray, + y: np.ndarray, + z: np.ndarray, + inv_fact: np.ndarray, + gamma: np.ndarray, + warmup: int, + iterations: int + ) -> tuple[np.ndarray, float]: + import pyopencl as cl + import pyopencl.array as cl_array + + ctx = cl.create_some_context(interactive=False) + queue = cl.CommandQueue(ctx) + ex = knl.executor(queue) + + x_cl = cl_array.to_device(queue, x) + y_cl = cl_array.to_device(queue, y) + z_cl = cl_array.to_device(queue, z) + inv_fact_cl = cl_array.to_device(queue, inv_fact) + gamma_cl = cl_array.to_device(queue, gamma) + phi_cl = cl_array.zeros(queue, x.shape, dtype=x.dtype) + + elapsed = benchmark_executor( + ex, + queue, + { + "x": x_cl, + "y": y_cl, + "z": z_cl, + "inv_fact": inv_fact_cl, + "gamma": gamma_cl, + "phi": phi_cl, + }, + warmup=warmup, + iterations=iterations, + ) + + _, out = ex( + queue, x=x_cl, y=y_cl, z=z_cl, inv_fact=inv_fact_cl, + gamma=gamma_cl, phi=phi_cl) + return out[0].get(), elapsed + + +def main( + ntargets: int = 256, + order: int = 8, + target_block_size: int = 32, + use_compute: bool = False, + use_block_private_compute: bool = False, + compare: bool = False, + print_kernel: bool = False, + print_device_code: bool = False, + run: bool = False, + warmup: int = 3, + iterations: int = 10 + 
) -> None: + if ntargets % target_block_size: + raise ValueError("ntargets must be divisible by target_block_size") + + dtype = np.float64 + rng = np.random.default_rng(22) + x = rng.uniform(-0.25, 0.25, size=ntargets).astype(dtype) + y = rng.uniform(-0.25, 0.25, size=ntargets).astype(dtype) + z = rng.uniform(-0.25, 0.25, size=ntargets).astype(dtype) + inv_fact = inv_factorials(order, dtype) + gamma = rng.normal(size=(order + 1, order + 1, order + 1)).astype(dtype) + reference = reference_l2p_3d(gamma, x, y, z, inv_fact) + + inline_evals, compute_evals = operation_model( + ntargets, order, target_block_size) + + if compare: + variants = ["inline", "register-tiled compute"] + elif use_block_private_compute: + variants = ["block-private compute"] + elif use_compute: + variants = ["register-tiled compute"] + else: + variants = ["inline"] + + timings: dict[str, float] = {} + for variant in variants: + knl = make_kernel(ntargets, order, dtype) + + if variant == "inline": + knl = split_targets(knl, target_block_size) + knl = lp.tag_inames(knl, { + "itgt_inner": "l.0", + "q0": "unr", + "q1": "unr", + "q2": "unr", + }) + elif variant == "block-private compute": + knl = block_private_l2p_3d(knl, target_block_size, dtype) + elif variant == "register-tiled compute": + knl = register_tiled_l2p_3d(knl, target_block_size, dtype) + else: + raise ValueError(f"unknown variant '{variant}'") + + variant_uses_compute = variant != "inline" + modeled_flops = l2p_3d_flop_count( + ntargets, order, use_compute=variant_uses_compute) + + print(20 * "=", "3D L2P basis report", 20 * "=") + print(f"Variant : {variant}") + print(f"Targets : {ntargets}") + print(f"Order : {order}") + print(f"Target block: {target_block_size}") + print(f"Inline basis evaluations: {inline_evals}") + print(f"Tiled compute evaluations: {compute_evals}") + print(f"Modeled flop count: {modeled_flops}") + + if print_kernel: + print(knl) + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + if run 
or compare: + try: + result, elapsed = run_kernel( + knl, x, y, z, inv_fact, gamma, + warmup=warmup, iterations=iterations) + except Exception as exc: + print(f"Runtime execution unavailable: {exc}") + else: + rel_err = la.norm(result - reference) / la.norm(reference) + gflops = modeled_flops / elapsed * 1e-9 + timings[variant] = elapsed + print(f"Average time per iteration: {elapsed:.6e} s") + print(f"Modeled throughput: {gflops:.3f} GFLOP/s") + print(f"Relative error: {rel_err:.3e}") + + print((40 + len(" 3D L2P basis report ")) * "=") + + if ( + compare + and "inline" in timings + and "register-tiled compute" in timings): + speedup = timings["inline"] / timings["register-tiled compute"] + time_reduction = ( + 1 - timings["register-tiled compute"] / timings["inline"]) * 100 + print(f"Speedup: {speedup:.3f}x") + print(f"Relative time reduction: {time_reduction:.2f}%") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + _ = parser.add_argument("--ntargets", action="store", type=int, default=256) + _ = parser.add_argument("--order", action="store", type=int, default=8) + _ = parser.add_argument("--target-block-size", action="store", + type=int, default=32) + _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--block-private-compute", action="store_true") + _ = parser.add_argument("--compare", action="store_true") + _ = parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--print-device-code", action="store_true") + _ = parser.add_argument("--warmup", action="store", type=int, default=3) + _ = parser.add_argument("--iterations", action="store", type=int, default=10) + + args = parser.parse_args() + + main( + ntargets=args.ntargets, + order=args.order, + target_block_size=args.target_block_size, + use_compute=args.compute, + use_block_private_compute=args.block_private_compute, + compare=args.compare, 
+ print_kernel=args.print_kernel, + print_device_code=args.print_device_code, + run=args.run_kernel, + warmup=args.warmup, + iterations=args.iterations, + ) diff --git a/examples/python/compute-examples/l2p-tiled-basis-compute.py b/examples/python/compute-examples/l2p-tiled-basis-compute.py new file mode 100644 index 000000000..f11a1054a --- /dev/null +++ b/examples/python/compute-examples/l2p-tiled-basis-compute.py @@ -0,0 +1,351 @@ +"""Benchmark a 2D Cartesian Taylor L2P microkernel with Loopy compute. + +FMM kernel class: kernel-independent Cartesian Taylor/asymptotic local +expansion evaluation. This script does not evaluate a particular Green's +function such as the 2D Laplace or Helmholtz kernel. The local coefficients +``gamma`` are treated as already available; if they came from a Laplace FMM, +this kernel is the L2P monomial-contraction stage after the Laplace-specific +coefficient/derivative work has already happened. + +The kernel evaluates a tensor-product Taylor-like local expansion at many +target points: + + phi[itgt] = sum_{q0,q1} gamma[q0, q1] + * x[itgt]**q0 / q0! + * y[itgt]**q1 / q1! + +The inline variant is a parallel GPU kernel over target blocks that expands the +two basis substitutions at every use inside the coefficient sum. The compute +variant uses :func:`loopy.transform.compute.compute` to materialize the x and y +basis values into private temporaries for each target, so the powers/factorial +scalings are reused across the inner coefficient loops instead of recomputed. + +Use ``--compare`` to run both GPU-parallel variants, check against the NumPy +reference implementation, and report timing, modeled GFLOP/s, speedup, and +relative error. 
+""" + +import os +import time + +os.environ.setdefault("XDG_CACHE_HOME", "/tmp") + +import namedisl as nisl +import numpy as np +import numpy.linalg as la + +import loopy as lp +from loopy.transform.compute import compute +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 + + +def inv_factorials(order: int, dtype) -> np.ndarray: + result = np.empty(order + 1, dtype=dtype) + result[0] = 1 + for i in range(1, order + 1): + result[i] = result[i - 1] / i + return result + + +def reference_l2p( + gamma: np.ndarray, + x: np.ndarray, + y: np.ndarray, + inv_fact: np.ndarray + ) -> np.ndarray: + order = gamma.shape[0] - 1 + result = np.empty_like(x) + + for itgt in range(x.size): + acc = 0 + for q0 in range(order + 1): + x_basis = x[itgt]**q0 * inv_fact[q0] + for q1 in range(order + 1): + y_basis = y[itgt]**q1 * inv_fact[q1] + acc += gamma[q0, q1] * x_basis * y_basis + result[itgt] = acc + + return result + + +def make_kernel( + ntargets: int, + order: int, + target_block_size: int, + dtype, + use_compute: bool = False + ) -> lp.TranslationUnit: + if ntargets % target_block_size: + raise ValueError("ntargets must be divisible by target_block_size") + + knl = lp.make_kernel( + "{ [itgt, q0, q1] : 0 <= itgt < ntargets and 0 <= q0, q1 <= p }", + """ + x_basis_(itgt_arg, q0_arg) := ( + x[itgt_arg] ** q0_arg * inv_fact[q0_arg] + ) + + y_basis_(itgt_arg, q1_arg) := ( + y[itgt_arg] ** q1_arg * inv_fact[q1_arg] + ) + + phi[itgt] = sum( + [q0, q1], + gamma[q0, q1] * x_basis_(itgt, q0) * y_basis_(itgt, q1) + ) + """, + [ + lp.GlobalArg("x", dtype=dtype, shape=(ntargets,)), + lp.GlobalArg("y", dtype=dtype, shape=(ntargets,)), + lp.GlobalArg("inv_fact", dtype=dtype, shape=(order + 1,)), + lp.GlobalArg("gamma", dtype=dtype, shape=(order + 1, order + 1)), + lp.GlobalArg("phi", dtype=dtype, shape=(ntargets,), is_output=True), + ], + lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2, + ) + + knl = lp.fix_parameters(knl, ntargets=ntargets, p=order) + knl = lp.split_iname( + knl, + 
"itgt", + target_block_size, + inner_iname="itgt_inner", + outer_iname="itgt_block", + ) + + if use_compute: + x_basis_map = nisl.make_map(f"""{{ + [itgt_arg, q0_arg] -> [itgt_block, itgt_inner, q0_s] : + itgt_arg = itgt_block * {target_block_size} + itgt_inner and + q0_arg = q0_s + }}""") + + knl = compute( + knl, + "x_basis_", + compute_map=x_basis_map, + storage_indices=["q0_s"], + temporal_inames=["itgt_block", "itgt_inner"], + temporary_name="x_basis_reg", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="x_basis_compute", + ) + + y_basis_map = nisl.make_map(f"""{{ + [itgt_arg, q1_arg] -> [itgt_block, itgt_inner, q1_s] : + itgt_arg = itgt_block * {target_block_size} + itgt_inner and + q1_arg = q1_s + }}""") + + knl = compute( + knl, + "y_basis_", + compute_map=y_basis_map, + storage_indices=["q1_s"], + temporal_inames=["itgt_block", "itgt_inner"], + temporary_name="y_basis_reg", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="y_basis_compute", + ) + + iname_tags = { + "itgt_block": "g.0", + "itgt_inner": "l.0", + "q0": "unr", + "q1": "unr", + } + if use_compute: + iname_tags.update({ + "q0_s": "unr", + "q1_s": "unr", + }) + + knl = lp.tag_inames(knl, iname_tags) + return knl + + +def operation_model( + ntargets: int, + order: int, + target_block_size: int + ) -> tuple[int, int]: + ncoeff = order + 1 + inline_basis_evals = 2 * ntargets * ncoeff**2 + tiled_compute_basis_evals = 2 * ntargets * ncoeff + return inline_basis_evals, tiled_compute_basis_evals + + +def l2p_flop_count(ntargets: int, order: int, use_compute: bool) -> int: + ncoeff = order + 1 + + contraction_flops = 3 * ntargets * ncoeff**2 + if use_compute: + basis_scale_flops = 2 * ntargets * ncoeff + else: + basis_scale_flops = 2 * ntargets * ncoeff**2 + + return contraction_flops + basis_scale_flops + + +def benchmark_executor(ex, queue, args, warmup: int, iterations: int) -> float: + evt = None + for _ in 
range(warmup): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + + start = time.perf_counter() + for _ in range(iterations): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + end = time.perf_counter() + + return (end - start) / iterations + + +def run_kernel( + knl: lp.TranslationUnit, + x: np.ndarray, + y: np.ndarray, + inv_fact: np.ndarray, + gamma: np.ndarray, + warmup: int, + iterations: int + ) -> tuple[np.ndarray, float]: + import pyopencl as cl + import pyopencl.array as cl_array + + ctx = cl.create_some_context(interactive=False) + queue = cl.CommandQueue(ctx) + ex = knl.executor(queue) + + x_cl = cl_array.to_device(queue, x) + y_cl = cl_array.to_device(queue, y) + inv_fact_cl = cl_array.to_device(queue, inv_fact) + gamma_cl = cl_array.to_device(queue, gamma) + phi_cl = cl_array.zeros(queue, x.shape, dtype=x.dtype) + + elapsed = benchmark_executor( + ex, + queue, + { + "x": x_cl, + "y": y_cl, + "inv_fact": inv_fact_cl, + "gamma": gamma_cl, + "phi": phi_cl, + }, + warmup=warmup, + iterations=iterations, + ) + + _, out = ex( + queue, x=x_cl, y=y_cl, inv_fact=inv_fact_cl, + gamma=gamma_cl, phi=phi_cl) + return out[0].get(), elapsed + + +def main( + ntargets: int = 256, + order: int = 12, + target_block_size: int = 32, + use_compute: bool = False, + compare: bool = False, + print_kernel: bool = False, + print_device_code: bool = False, + run: bool = False, + warmup: int = 3, + iterations: int = 10 + ) -> None: + dtype = np.float64 + rng = np.random.default_rng(14) + x = rng.uniform(-0.25, 0.25, size=ntargets).astype(dtype) + y = rng.uniform(-0.25, 0.25, size=ntargets).astype(dtype) + inv_fact = inv_factorials(order, dtype) + gamma = rng.normal(size=(order + 1, order + 1)).astype(dtype) + reference = reference_l2p(gamma, x, y, inv_fact) + + inline_evals, compute_evals = operation_model( + ntargets, order, target_block_size) + + variants = [False, True] if compare else [use_compute] + timings: dict[bool, float] = {} + for 
variant_uses_compute in variants: + knl = make_kernel( + ntargets, order, target_block_size, dtype, + use_compute=variant_uses_compute) + modeled_flops = l2p_flop_count( + ntargets, order, use_compute=variant_uses_compute) + + print(20 * "=", "L2P basis report", 20 * "=") + print(f"Variant : {'tiled compute' if variant_uses_compute else 'inline'}") + print(f"Targets : {ntargets}") + print(f"Order : {order}") + print(f"Target block: {target_block_size}") + print(f"Inline basis evaluations: {inline_evals}") + print(f"Tiled compute evaluations: {compute_evals}") + print(f"Modeled flop count: {modeled_flops}") + + if print_kernel: + print(knl) + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + if run or compare: + try: + result, elapsed = run_kernel( + knl, x, y, inv_fact, gamma, + warmup=warmup, iterations=iterations) + except Exception as exc: + print(f"Runtime execution unavailable: {exc}") + else: + rel_err = la.norm(result - reference) / la.norm(reference) + gflops = modeled_flops / elapsed * 1e-9 + timings[variant_uses_compute] = elapsed + print(f"Average time per iteration: {elapsed:.6e} s") + print(f"Modeled throughput: {gflops:.3f} GFLOP/s") + print(f"Relative error: {rel_err:.3e}") + + print((40 + len(" L2P basis report ")) * "=") + + if compare and False in timings and True in timings: + speedup = timings[False] / timings[True] + time_reduction = (1 - timings[True] / timings[False]) * 100 + print(f"Speedup: {speedup:.3f}x") + print(f"Relative time reduction: {time_reduction:.2f}%") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + _ = parser.add_argument("--ntargets", action="store", type=int, default=256) + _ = parser.add_argument("--order", action="store", type=int, default=12) + _ = parser.add_argument("--target-block-size", action="store", + type=int, default=32) + _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--compare", action="store_true") + _ = 
parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--print-device-code", action="store_true") + _ = parser.add_argument("--warmup", action="store", type=int, default=3) + _ = parser.add_argument("--iterations", action="store", type=int, default=10) + + args = parser.parse_args() + + main( + ntargets=args.ntargets, + order=args.order, + target_block_size=args.target_block_size, + use_compute=args.compute, + compare=args.compare, + print_kernel=args.print_kernel, + print_device_code=args.print_device_code, + run=args.run_kernel, + warmup=args.warmup, + iterations=args.iterations, + ) diff --git a/examples/python/compute-examples/m2m-sum-factorization.py b/examples/python/compute-examples/m2m-sum-factorization.py new file mode 100644 index 000000000..d3618d27d --- /dev/null +++ b/examples/python/compute-examples/m2m-sum-factorization.py @@ -0,0 +1,636 @@ +"""Benchmark compressed Cartesian Taylor M2M sum factorization with compute. + +FMM kernel class: compressed Cartesian Taylor/asymptotic multipole-to-multipole +translation. This is the binomial center-shift part of an FMM translation, not +a direct evaluation of a Green's function such as Laplace, Helmholtz, or +biharmonic. The translation weights are powers of the center displacement +divided by factorials; PDE-specific derivative generation, compression +matrices, and recompression are outside this microbenchmark. + +The stored-index pattern here is intentionally simple. The 2D mode stores the +two coordinate axes, and the 3D mode stores the three coordinate axes. That +captures the sum-factorized structure from Section 4.2.3, but it is not the +full compressed 3D Laplace stored set, which would retain PDE-derived +hyperplane layers with O(p**2) stored coefficients rather than only O(p) axis +coefficients. 
+ +This script models a multipole-to-multipole-like translation in 2D or 3D where +the input expansion is stored only on the coordinate axes. For 2D, the stored +coefficients are ``beta[zeta0, 0]`` and ``beta[0, zeta1]``. For 3D, they are +``beta[zeta0, 0, 0]``, ``beta[0, zeta1, 0]``, and ``beta[0, 0, zeta2]``. The +output still fills the full tensor-product coefficient grid. + +The inline variant is a GPU-parallel kernel over output coefficient indices +that expands the one-dimensional translation sums at each output. The compute +variant uses :func:`loopy.transform.compute.compute` to materialize those axis +sums into private temporaries and reuses them across an ILP tile of the last +output axis. In 2D it tiles ``eta1``; in 3D it tiles ``eta2`` while reusing the +``eta0`` and ``eta1`` axis sums across that register tile. + +Use ``--dimension 2`` or ``--dimension 3`` to choose the kernel. Use +``--compare`` to run both GPU-parallel variants, check against the NumPy +reference implementation, and report timing, modeled GFLOP/s, speedup, and +relative error. 
+""" + +import os +import time + +os.environ.setdefault("XDG_CACHE_HOME", "/tmp") + +import namedisl as nisl +import numpy as np +import numpy.linalg as la +import pymbolic.primitives as p + +import loopy as lp +import loopy.transform.compute as compute_mod +from loopy.symbolic import DependencyMapper +from loopy.transform.compute import compute +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 + + +def enable_compute_for_reduction_substitutions() -> None: + """Let compute inspect substitution rules whose expressions are reductions.""" + + def gather_vars(expr): + deps = DependencyMapper()(expr) + var_names = set() + for dep in deps: + if isinstance(dep, p.Variable): + var_names.add(dep.name) + elif ( + isinstance(dep, p.Subscript) + and isinstance(dep.aggregate, p.Variable)): + var_names.add(dep.aggregate.name) + + return var_names + + compute_mod._gather_vars = gather_vars + + +def translation_weights(h: np.ndarray, order: int) -> np.ndarray: + weights = np.empty((len(h), order + 1), dtype=h.dtype) + weights[:, 0] = 1 + + for axis in range(len(h)): + for n in range(1, order + 1): + weights[axis, n] = weights[axis, n - 1] * h[axis] / n + + return weights + + +def make_axis_compressed_coefficients( + order: int, + dimension: int, + dtype + ) -> np.ndarray: + rng = np.random.default_rng(12) + beta = np.zeros(dimension * (order + 1,), dtype=dtype) + + for axis in range(dimension): + axis_slice = [0] * dimension + axis_slice[axis] = slice(None) + beta[tuple(axis_slice)] = rng.normal(size=order + 1) + + beta[(0,) * dimension] = rng.normal() + + return beta + + +def reference_axis_m2m_2d(beta: np.ndarray, weights: np.ndarray) -> np.ndarray: + order = beta.shape[0] - 1 + sigma = np.empty_like(beta) + + for eta0 in range(order + 1): + for eta1 in range(order + 1): + acc = 0 + + for zeta1 in range(eta1 + 1): + acc += ( + weights[0, eta0] + * weights[1, eta1 - zeta1] + * beta[0, zeta1] + ) + + for zeta0 in range(eta0 + 1): + acc += ( + weights[0, eta0 - zeta0] + 
* weights[1, eta1] + * beta[zeta0, 0] + ) + + acc -= weights[0, eta0] * weights[1, eta1] * beta[0, 0] + sigma[eta0, eta1] = acc + + return sigma + + +def reference_axis_m2m_3d(beta: np.ndarray, weights: np.ndarray) -> np.ndarray: + order = beta.shape[0] - 1 + sigma = np.empty_like(beta) + + for eta0 in range(order + 1): + for eta1 in range(order + 1): + for eta2 in range(order + 1): + acc = 0 + + for zeta0 in range(eta0 + 1): + acc += ( + weights[0, eta0 - zeta0] + * weights[1, eta1] + * weights[2, eta2] + * beta[zeta0, 0, 0] + ) + + for zeta1 in range(eta1 + 1): + acc += ( + weights[0, eta0] + * weights[1, eta1 - zeta1] + * weights[2, eta2] + * beta[0, zeta1, 0] + ) + + for zeta2 in range(eta2 + 1): + acc += ( + weights[0, eta0] + * weights[1, eta1] + * weights[2, eta2 - zeta2] + * beta[0, 0, zeta2] + ) + + acc -= ( + 2 + * weights[0, eta0] + * weights[1, eta1] + * weights[2, eta2] + * beta[0, 0, 0] + ) + sigma[eta0, eta1, eta2] = acc + + return sigma + + +def make_kernel_2d( + order: int, + eta_tile_size: int, + dtype, + use_compute: bool = False + ) -> lp.TranslationUnit: + if (order + 1) % eta_tile_size: + raise ValueError("order + 1 must be divisible by eta_tile_size") + + knl = lp.make_kernel( + "{ [eta0, eta1, zeta0, zeta1] : 0 <= eta0, eta1, zeta0, zeta1 <= p }", + """ + x_axis_sum_(eta0_arg) := sum( + [zeta0], + if( + zeta0 <= eta0_arg, + w[0, eta0_arg - zeta0] * beta[zeta0, 0], + 0 + ) + ) + + y_axis_sum_(eta1_arg) := sum( + [zeta1], + if( + zeta1 <= eta1_arg, + w[1, eta1_arg - zeta1] * beta[0, zeta1], + 0 + ) + ) + + sigma[eta0, eta1] = ( + w[0, eta0] * y_axis_sum_(eta1) + + w[1, eta1] * x_axis_sum_(eta0) + - w[0, eta0] * w[1, eta1] * beta[0, 0] + ) + """, + [ + lp.GlobalArg("beta", dtype=dtype, shape=(order + 1, order + 1)), + lp.GlobalArg("w", dtype=dtype, shape=(2, order + 1)), + lp.GlobalArg("sigma", dtype=dtype, shape=(order + 1, order + 1), + is_output=True), + ], + lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2, + ) + + knl = 
lp.fix_parameters(knl, p=order) + knl = lp.split_iname( + knl, + "eta1", + eta_tile_size, + inner_iname="eta1_inner", + outer_iname="eta1_block", + ) + + if use_compute: + x_axis_sum_map = nisl.make_map(f"""{{ + [eta0_arg] -> [eta0, eta1_block, x_slot] : + eta0_arg = eta0 and x_slot = 0 + }}""") + knl = compute( + knl, + "x_axis_sum_", + compute_map=x_axis_sum_map, + storage_indices=["x_slot"], + temporal_inames=["eta0", "eta1_block"], + temporary_name="x_axis_sum_vec", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="x_axis_sum_compute", + ) + + y_axis_sum_map = nisl.make_map(f"""{{ + [eta1_arg] -> [eta0, eta1_block, y_slot] : + eta1_arg = eta1_block * {eta_tile_size} + y_slot + }}""") + knl = compute( + knl, + "y_axis_sum_", + compute_map=y_axis_sum_map, + storage_indices=["y_slot"], + temporal_inames=["eta0", "eta1_block"], + temporary_name="y_axis_sum_vec", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="y_axis_sum_compute", + ) + knl = lp.tag_inames(knl, { + "x_slot": "unr", + "y_slot": "unr", + }) + + knl = lp.tag_inames(knl, { + "eta0": "g.1", + "eta1_block": "g.0", + "eta1_inner": "ilp", + }) + return knl + + +def make_kernel_3d( + order: int, + eta_tile_size: int, + dtype, + use_compute: bool = False + ) -> lp.TranslationUnit: + if (order + 1) % eta_tile_size: + raise ValueError("order + 1 must be divisible by eta_tile_size") + + knl = lp.make_kernel( + """ + { [eta0, eta1, eta2, zeta0, zeta1, zeta2] : + 0 <= eta0, eta1, eta2, zeta0, zeta1, zeta2 <= p } + """, + """ + x_axis_sum_(eta0_arg) := sum( + [zeta0], + if( + zeta0 <= eta0_arg, + w[0, eta0_arg - zeta0] * beta[zeta0, 0, 0], + 0 + ) + ) + + y_axis_sum_(eta1_arg) := sum( + [zeta1], + if( + zeta1 <= eta1_arg, + w[1, eta1_arg - zeta1] * beta[0, zeta1, 0], + 0 + ) + ) + + z_axis_sum_(eta2_arg) := sum( + [zeta2], + if( + zeta2 <= eta2_arg, + w[2, eta2_arg - zeta2] * beta[0, 0, zeta2], + 0 + ) + ) + + sigma[eta0, 
eta1, eta2] = ( + w[1, eta1] * w[2, eta2] * x_axis_sum_(eta0) + + w[0, eta0] * w[2, eta2] * y_axis_sum_(eta1) + + w[0, eta0] * w[1, eta1] * z_axis_sum_(eta2) + - 2 * w[0, eta0] * w[1, eta1] * w[2, eta2] * beta[0, 0, 0] + ) + """, + [ + lp.GlobalArg( + "beta", dtype=dtype, shape=(order + 1, order + 1, order + 1)), + lp.GlobalArg("w", dtype=dtype, shape=(3, order + 1)), + lp.GlobalArg( + "sigma", dtype=dtype, + shape=(order + 1, order + 1, order + 1), + is_output=True), + ], + lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2, + ) + + knl = lp.fix_parameters(knl, p=order) + knl = lp.split_iname( + knl, + "eta2", + eta_tile_size, + inner_iname="eta2_inner", + outer_iname="eta2_block", + ) + + if use_compute: + x_axis_sum_map = nisl.make_map(""" + { + [eta0_arg] -> [eta0, eta1, eta2_block, x_slot] : + eta0_arg = eta0 and x_slot = 0 + } + """) + knl = compute( + knl, + "x_axis_sum_", + compute_map=x_axis_sum_map, + storage_indices=["x_slot"], + temporal_inames=["eta0", "eta1", "eta2_block"], + temporary_name="x_axis_sum_vec", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="x_axis_sum_compute", + ) + + y_axis_sum_map = nisl.make_map(""" + { + [eta1_arg] -> [eta0, eta1, eta2_block, y_slot] : + eta1_arg = eta1 and y_slot = 0 + } + """) + knl = compute( + knl, + "y_axis_sum_", + compute_map=y_axis_sum_map, + storage_indices=["y_slot"], + temporal_inames=["eta0", "eta1", "eta2_block"], + temporary_name="y_axis_sum_vec", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="y_axis_sum_compute", + ) + + z_axis_sum_map = nisl.make_map(f"""{{ + [eta2_arg] -> [eta0, eta1, eta2_block, z_slot] : + eta2_arg = eta2_block * {eta_tile_size} + z_slot + }}""") + knl = compute( + knl, + "z_axis_sum_", + compute_map=z_axis_sum_map, + storage_indices=["z_slot"], + temporal_inames=["eta0", "eta1", "eta2_block"], + temporary_name="z_axis_sum_vec", + temporary_address_space=lp.AddressSpace.PRIVATE, + 
temporary_dtype=dtype, + compute_insn_id="z_axis_sum_compute", + ) + knl = lp.tag_inames(knl, { + "x_slot": "unr", + "y_slot": "unr", + "z_slot": "unr", + }) + + knl = lp.tag_inames(knl, { + "eta0": "g.2", + "eta1": "g.1", + "eta2_block": "g.0", + "eta2_inner": "ilp", + }) + return knl + + +def make_kernel( + order: int, + eta_tile_size: int, + dimension: int, + dtype, + use_compute: bool = False + ) -> lp.TranslationUnit: + if dimension == 2: + return make_kernel_2d(order, eta_tile_size, dtype, use_compute) + if dimension == 3: + return make_kernel_3d(order, eta_tile_size, dtype, use_compute) + raise ValueError("dimension must be 2 or 3") + + +def reference_axis_m2m(beta: np.ndarray, weights: np.ndarray) -> np.ndarray: + dimension = beta.ndim + if dimension == 2: + return reference_axis_m2m_2d(beta, weights) + if dimension == 3: + return reference_axis_m2m_3d(beta, weights) + raise ValueError("dimension must be 2 or 3") + + +def operation_model( + order: int, + eta_tile_size: int, + dimension: int + ) -> tuple[int, int]: + ncoeff = order + 1 + if dimension == 2: + inline_sum_terms = 2 * ncoeff**3 + tiled_compute_sum_terms = ncoeff**3 + ncoeff**3 // eta_tile_size + elif dimension == 3: + inline_sum_terms = 3 * ncoeff**4 + tiled_compute_sum_terms = ncoeff**4 + 2 * ncoeff**4 // eta_tile_size + else: + raise ValueError("dimension must be 2 or 3") + return inline_sum_terms, tiled_compute_sum_terms + + +def m2m_flop_count( + order: int, + eta_tile_size: int, + dimension: int, + use_compute: bool + ) -> int: + ncoeff = order + 1 + + if dimension == 2: + if use_compute: + sum_flops = 2 * ncoeff**3 + 2 * ncoeff**3 // eta_tile_size + else: + sum_flops = 4 * ncoeff**3 + correction_flops = 3 * ncoeff**2 + elif dimension == 3: + if use_compute: + sum_flops = 2 * ncoeff**4 + 4 * ncoeff**4 // eta_tile_size + else: + sum_flops = 6 * ncoeff**4 + correction_flops = 8 * ncoeff**3 + else: + raise ValueError("dimension must be 2 or 3") + + return sum_flops + correction_flops + + +def 
benchmark_executor(ex, queue, args, warmup: int, iterations: int) -> float: + if iterations <= 0: + raise ValueError("iterations must be positive") + + evt = None + for _ in range(warmup): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + + start = time.perf_counter() + for _ in range(iterations): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + end = time.perf_counter() + + return (end - start) / iterations + + +def run_kernel( + knl: lp.TranslationUnit, + beta: np.ndarray, + weights: np.ndarray, + warmup: int, + iterations: int + ) -> tuple[np.ndarray, float]: + import pyopencl as cl + import pyopencl.array as cl_array + + ctx = cl.create_some_context(interactive=False) + queue = cl.CommandQueue(ctx) + ex = knl.executor(queue) + + beta_cl = cl_array.to_device(queue, beta) + weights_cl = cl_array.to_device(queue, weights) + sigma_cl = cl_array.zeros(queue, beta.shape, dtype=beta.dtype) + + elapsed = benchmark_executor( + ex, + queue, + {"beta": beta_cl, "w": weights_cl, "sigma": sigma_cl}, + warmup=warmup, + iterations=iterations, + ) + + _, out = ex(queue, beta=beta_cl, w=weights_cl, sigma=sigma_cl) + return out[0].get(), elapsed + + +def main( + order: int = 16, + eta_tile_size: int = 8, + dimension: int = 2, + use_compute: bool = False, + compare: bool = False, + print_kernel: bool = False, + print_device_code: bool = False, + run: bool = False, + warmup: int = 3, + iterations: int = 10 + ) -> None: + if order < 0: + raise ValueError("order must be nonnegative") + if dimension not in (2, 3): + raise ValueError("dimension must be 2 or 3") + + dtype = np.float64 + h = np.array([0.25, -0.2, 0.15][:dimension], dtype=dtype) + weights = translation_weights(h, order) + beta = make_axis_compressed_coefficients(order, dimension, dtype) + reference = reference_axis_m2m(beta, weights) + + inline_terms, compute_terms = operation_model( + order, eta_tile_size, dimension) + + variants = [False, True] if compare else [use_compute] + timings: 
dict[bool, float] = {} + for variant_uses_compute in variants: + knl = make_kernel( + order, eta_tile_size, dimension, dtype, + use_compute=variant_uses_compute) + modeled_flops = m2m_flop_count( + order, eta_tile_size, dimension, + use_compute=variant_uses_compute) + + print(20 * "=", "Compressed M2M report", 20 * "=") + print(f"Variant: {'compute sum-factorized' if variant_uses_compute else 'inline'}") + print(f"Dimension: {dimension}D") + print(f"Order : {order}") + print(f"Eta tile: {eta_tile_size}") + if dimension == 2: + print("Stored compressed set: zeta0 = 0 or zeta1 = 0") + else: + print("Stored compressed set: exactly one zeta axis may be nonzero") + print(f"Inline sum terms : {inline_terms}") + print(f"Tiled compute sum terms: {compute_terms}") + print(f"Modeled flop count : {modeled_flops}") + + if print_kernel: + print(knl) + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + if run or compare: + try: + result, elapsed = run_kernel( + knl, beta, weights, warmup=warmup, iterations=iterations) + except Exception as exc: + print(f"Runtime execution unavailable: {exc}") + else: + rel_err = la.norm(result - reference) / la.norm(reference) + gflops = modeled_flops / elapsed * 1e-9 + timings[variant_uses_compute] = elapsed + print(f"Average time per iteration: {elapsed:.6e} s") + print(f"Modeled throughput: {gflops:.3f} GFLOP/s") + print(f"Relative error: {rel_err:.3e}") + + print((40 + len(" Compressed M2M report ")) * "=") + + if compare and False in timings and True in timings: + speedup = timings[False] / timings[True] + time_reduction = (1 - timings[True] / timings[False]) * 100 + print(f"Speedup: {speedup:.3f}x") + print(f"Relative time reduction: {time_reduction:.2f}%") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + + _ = parser.add_argument("--order", action="store", type=int, default=16) + _ = parser.add_argument("--eta-tile-size", action="store", type=int, default=8) + _ = 
parser.add_argument("--dimension", action="store", type=int, choices=(2, 3), + default=2) + _ = parser.add_argument("--dim", action="store", type=int, choices=(2, 3), + dest="dimension") + _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--compare", action="store_true") + _ = parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--print-device-code", action="store_true") + _ = parser.add_argument("--warmup", action="store", type=int, default=3) + _ = parser.add_argument("--iterations", action="store", type=int, default=10) + + args = parser.parse_args() + + main( + order=args.order, + eta_tile_size=args.eta_tile_size, + dimension=args.dimension, + use_compute=args.compute, + compare=args.compare, + print_kernel=args.print_kernel, + print_device_code=args.print_device_code, + run=args.run_kernel, + warmup=args.warmup, + iterations=args.iterations, + ) diff --git a/examples/python/compute-examples/matmul.py b/examples/python/compute-examples/matmul.py new file mode 100644 index 000000000..45af6f431 --- /dev/null +++ b/examples/python/compute-examples/matmul.py @@ -0,0 +1,400 @@ +import time + +import namedisl as nisl +import numpy as np +import numpy.linalg as la + +import pyopencl as cl +import pyopencl.array as cl_array + +import loopy as lp +from loopy.transform.compute import compute +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 + + +def benchmark_kernel( + knl: lp.TranslationUnit, + queue: cl.CommandQueue, + a: np.ndarray, + b: np.ndarray, + nwarmup: int = 5, + niterations: int = 20 +): + ex = knl.executor(queue) + + a_cl = cl_array.to_device(queue, a) + b_cl = cl_array.to_device(queue, b) + c_cl = cl_array.zeros(queue, (a.shape[0], b.shape[1]), dtype=a_cl.dtype) + + start = cl.enqueue_marker(queue) + for _ in range(nwarmup): + ex(queue, a=a_cl, b=b_cl, c=c_cl) + end = cl.enqueue_marker(queue) + end.wait() + 
start.wait() + + start = cl.enqueue_marker(queue) + for _ in range(niterations): + ex(queue, a=a_cl, b=b_cl, c=c_cl) + end = cl.enqueue_marker(queue) + end.wait() + start.wait() + + total_ns = end.profile.end - start.profile.end + total_elapsed_s = total_ns * 1e-9 + s_per_iter = total_elapsed_s / niterations + + total_flops = 2 * a.shape[0] * a.shape[1] * b.shape[1] + gflops = (total_flops / s_per_iter) * 1e-9 + + c_ref = a @ b + _, c_res = ex(queue, a=a_cl, b=b_cl, c=c_cl) + + error = la.norm(c_res[0].get() - c_ref) / la.norm(c_ref) + + m, k = a.shape + _, n = b.shape + print(f"================= Results =================") + print(f"M = {m}, N = {n}, K = {k}") + print(f" Error = {error:.4}") + print(f" Total time (s): {total_elapsed_s:.4}") + print(f"Time per iter (s): {s_per_iter:.4}") + print(f" GFLOP/s: {gflops}") + print(f"===========================================") + + +def naive_matmul( + knl: lp.TranslationUnit, + bm: int, + bn: int, + bk: int + ) -> lp.TranslationUnit: + knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") + knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") + knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") + + iname_tags = { + "io": "g.1", + "jo": "g.0", + + "ii": "l.1", + "ji": "l.0" + } + + return lp.tag_inames(knl, iname_tags) + + +def shared_memory_tiled_matmul( + knl: lp.TranslationUnit, + bm: int, + bn: int, + bk: int + ) -> lp.TranslationUnit: + knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") + knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") + knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") + + compute_map_a = nisl.make_map(f"""{{ + [is, ks] -> [a_ii, io, a_ki, ko, jo] : + is = io * {bm} + a_ii and + ks = ko * {bk} + a_ki + }}""") + + compute_map_b = nisl.make_map(f"""{{ + [ks, js] -> [b_ki, ko, b_ji, jo, io] : + js = jo * {bn} + b_ji and + ks = ko * {bk} + b_ki + }}""") + + knl = compute( + knl, + 
"a_", + compute_map=compute_map_a, + storage_indices=["a_ii", "a_ki"], + temporal_inames=["io", "ko"], + temporary_name="a_tile", + temporary_address_space=lp.AddressSpace.LOCAL, + compute_insn_id="a_load" + ) + + knl = compute( + knl, + "b_", + compute_map=compute_map_b, + storage_indices=["b_ki", "b_ji"], + temporal_inames=["ko", "jo", "io"], + temporary_name="b_tile", + temporary_address_space=lp.AddressSpace.LOCAL, + compute_insn_id="b_load" + ) + + iname_tags = { + "io": "g.1", + "ii": "l.1", + + "jo": "g.0", + "ji": "l.0", + + "a_ii": "l.1", + "a_ki": "l.0", + + "b_ki": "l.1", + "b_ji": "l.0" + } + + return lp.tag_inames(knl, iname_tags) + + +def register_tiled_matmul( + knl: lp.TranslationUnit, + bm: int, + bn: int, + bk: int, + tm: int, + tn: int + ) -> lp.TranslationUnit: + + # {{{ shared-memory-level split / compute + + knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") + knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") + knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") + + compute_map_a = nisl.make_map(f"""{{ + [is, ks] -> [a_ii, io, a_ki, ko, jo] : + is = io * {bm} + a_ii and + ks = ko * {bk} + a_ki + }}""") + + compute_map_b = nisl.make_map(f"""{{ + [ks, js] -> [b_ki, ko, b_ji, jo, io] : + js = jo * {bn} + b_ji and + ks = ko * {bk} + b_ki + }}""") + + knl = compute( + knl, + "a_", + compute_map=compute_map_a, + storage_indices=["a_ii", "a_ki"], + temporal_inames=["io", "ko"], + temporary_name="a_smem", + temporary_address_space=lp.AddressSpace.LOCAL, + compute_insn_id="a_load" + ) + + knl = compute( + knl, + "b_", + compute_map=compute_map_b, + storage_indices=["b_ki", "b_ji"], + temporal_inames=["ko", "jo"], + temporary_name="b_smem", + temporary_address_space=lp.AddressSpace.LOCAL, + compute_insn_id="b_load" + ) + + wg_size_i = bm // tm + wg_size_j = bn // tn + knl = lp.split_iname( + knl, + "a_ii", + wg_size_i, + inner_iname="a_local", + outer_iname="a_tile" + ) + + knl = 
lp.split_iname( + knl, + "b_ji", + wg_size_j, + inner_iname="b_local", + outer_iname="b_tile" + ) + + # }}} + + # {{{ register-level split / compute + + knl = lp.extract_subst( + knl, + "a_smem_", + "a_smem[is, ks]", + parameters="is, ks" + ) + + knl = lp.extract_subst( + knl, + "b_smem_", + "b_smem[ks, js]", + parameters="ks, js" + ) + + knl = lp.split_iname(knl, "ii", tm, + inner_iname="ii_reg", + outer_iname="ii_thr") + + knl = lp.split_iname(knl, "ji", tn, + inner_iname="ji_reg", + outer_iname="ji_thr") + + knl = lp.split_iname(knl, "ki", 8, + inner_iname="dot", + outer_iname="ki_outer") + + a_reg_tile = nisl.make_map(f"""{{ + [is, ks] -> [a_reg_i, ii_thr, ji_thr, ki_outer, dot, io, jo, ko] : + is = ii_thr * {tm} + a_reg_i and + ks = ki_outer * 8 + dot + }}""") + + b_reg_tile = nisl.make_map(f"""{{ + [ks, js] -> [b_reg_j, ki_outer, dot, ii_thr, ji_thr, io, jo, ko] : + ks = ki_outer * 8 + dot and + js = ji_thr * {tn} + b_reg_j + }}""") + + knl = compute( + knl, + "a_smem_", + compute_map=a_reg_tile, + storage_indices=["a_reg_i"], + temporal_inames=["ii_thr", "ji_thr", "ki_outer", "dot", "io", "jo", "ko"], + temporary_name="a_reg", + temporary_address_space=lp.AddressSpace.PRIVATE, + compute_insn_id="a_reg_load" + ) + + knl = compute( + knl, + "b_smem_", + compute_map=b_reg_tile, + storage_indices=["b_reg_j"], + temporal_inames=["ii_thr", "ji_thr", "ki_outer", "dot", "io", "jo", "ko"], + temporary_name="b_reg", + temporary_address_space=lp.AddressSpace.PRIVATE, + compute_insn_id="b_reg_load" + ) + + # }}} + + iname_tags = { + # global tiles + "io" : "g.1", + "jo" : "g.0", + + # a local storage axes + "a_local": "l.1", + "a_ki" : "l.0", + + # b local storage axes + "b_local": "l.0", + "b_ki" : "l.1", + + # register tiles + "ii_thr": "l.1", + "ji_thr": "l.0", + + # register storage axes + "a_reg_i": "ilp", + "b_reg_j": "ilp", + + # compute axes + "ii_reg": "ilp", + "ji_reg": "ilp" + } + + return lp.tag_inames(knl, iname_tags) + + +def main( + m: int = 1024, + n: 
int = 1024, + k: int = 1024, + bm: int = 64, + bn: int = 64, + bk: int = 32, + tm: int = 4, + tn: int = 4, + shared_memory_tiled: bool = False, + register_tiled: bool = False, + dtype: lp.ToLoopyTypeConvertible = np.float32, + print_kernel: bool = False, + print_device_code: bool = False + ) -> None: + + knl = lp.make_kernel( + "{ [i, j, k] : 0 <= i < M and 0 <= j < N and 0 <= k < K }", + """ + a_(is, ks) := a[is, ks] + b_(ks, js) := b[ks, js] + + c[i, j] = sum([k], a_(i, k) * b_(k, j)) + """, + [ + lp.GlobalArg("a", shape=(m, k), dtype=dtype), + lp.GlobalArg("b", shape=(k, n), dtype=dtype), + lp.GlobalArg("c", shape=(m, n), is_output=True) + ] + ) + + knl = lp.fix_parameters(knl, M=m, N=n, K=k) + + if shared_memory_tiled: + knl = shared_memory_tiled_matmul(knl, bm, bn, bk) + elif register_tiled: + knl = register_tiled_matmul(knl, bm, bn, bk, tm, tn) + else: + knl = naive_matmul(knl, bm, bn, bk) + + ctx = cl.create_some_context() + queue = cl.CommandQueue( + ctx, + properties=cl.command_queue_properties.PROFILING_ENABLE + ) + + a = np.random.randn(m, k).astype(dtype) + b = np.random.randn(k, n).astype(dtype) + + benchmark_kernel(knl, queue, a, b) + + if print_kernel: + print(knl) + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + + _ = parser.add_argument("--m", action="store", type=int, default=1024) + _ = parser.add_argument("--n", action="store", type=int, default=1024) + _ = parser.add_argument("--k", action="store", type=int, default=1024) + + _ = parser.add_argument("--bm", action="store", type=int, default=64) + _ = parser.add_argument("--bn", action="store", type=int, default=64) + _ = parser.add_argument("--bk", action="store", type=int, default=16) + + _ = parser.add_argument("--tm", action="store", type=int, default=4) + _ = parser.add_argument("--tn", action="store", type=int, default=4) + + _ = parser.add_argument("--shared-memory-tiled", 
action="store_true") + _ = parser.add_argument("--register-tiled", action="store_true") + + _ = parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--print-device-code", action="store_true") + + args = parser.parse_args() + + main( + m=args.m, n=args.n, k=args.k, + bm=args.bm, bn=args.bn, bk=args.bk, + tm=args.tm, tn=args.tn, + shared_memory_tiled=args.shared_memory_tiled, + register_tiled=args.register_tiled, + print_kernel=args.print_kernel, + print_device_code=args.print_device_code + ) diff --git a/examples/python/compute-examples/p2m-basis-compute.py b/examples/python/compute-examples/p2m-basis-compute.py new file mode 100644 index 000000000..2b4b1ab82 --- /dev/null +++ b/examples/python/compute-examples/p2m-basis-compute.py @@ -0,0 +1,369 @@ +"""Benchmark a 2D Cartesian Taylor P2M microkernel with Loopy compute. + +FMM kernel class: kernel-independent Cartesian Taylor/asymptotic multipole +moment formation. This script does not evaluate a particular Green's function +such as the 2D Laplace or Helmholtz kernel. It builds the source monomial +moments that a Taylor FMM would later pair with kernel derivatives or translated +coefficients. In a Laplace FMM, the Laplace-specific derivative recurrence or +compressed representation lives outside this benchmark. + +The kernel forms tensor-product source moments from particle strengths: + + beta[q0, q1] = sum_{isrc} strength[isrc] + * x[isrc]**q0 / q0! + * y[isrc]**q1 / q1! + +The inline variant is a GPU-parallel reduction over sources for every output +coefficient. The compute variant splits the source and q1 loops and uses +:func:`loopy.transform.compute.compute` to precompute reusable x and y monomial +basis values in private temporaries. This tests whether compute can expose +source-tile and coefficient-tile reuse in a reduction-heavy P2M-like kernel. 
+ +Use ``--compare`` to run both GPU-parallel variants, compare with the NumPy +reference result, and print timing, modeled GFLOP/s, speedup, and relative +error. +""" + +import os +import time + +os.environ.setdefault("XDG_CACHE_HOME", "/tmp") + +import namedisl as nisl +import numpy as np +import numpy.linalg as la + +import loopy as lp +from loopy.transform.compute import compute +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 + + +def inv_factorials(order: int, dtype) -> np.ndarray: + result = np.empty(order + 1, dtype=dtype) + result[0] = 1 + for i in range(1, order + 1): + result[i] = result[i - 1] / i + return result + + +def reference_p2m( + strength: np.ndarray, + x: np.ndarray, + y: np.ndarray, + inv_fact: np.ndarray + ) -> np.ndarray: + order = inv_fact.size - 1 + beta = np.empty((order + 1, order + 1), dtype=x.dtype) + + for q0 in range(order + 1): + for q1 in range(order + 1): + acc = 0 + for isrc in range(x.size): + x_monom = x[isrc]**q0 * inv_fact[q0] + y_monom = y[isrc]**q1 * inv_fact[q1] + acc += strength[isrc] * x_monom * y_monom + beta[q0, q1] = acc + + return beta + + +def make_kernel( + nsources: int, + order: int, + q1_tile_size: int, + source_tile_size: int, + dtype, + use_compute: bool = False + ) -> lp.TranslationUnit: + if (order + 1) % q1_tile_size: + raise ValueError("order + 1 must be divisible by q1_tile_size") + if nsources % source_tile_size: + raise ValueError("nsources must be divisible by source_tile_size") + + knl = lp.make_kernel( + "{ [isrc, q0, q1] : 0 <= isrc < nsources and 0 <= q0, q1 <= p }", + """ + x_monom_(isrc_arg, q0_arg) := ( + x[isrc_arg] ** q0_arg * inv_fact[q0_arg] + ) + + y_monom_(isrc_arg, q1_arg) := ( + y[isrc_arg] ** q1_arg * inv_fact[q1_arg] + ) + + beta[q0, q1] = sum( + [isrc], + strength[isrc] * x_monom_(isrc, q0) * y_monom_(isrc, q1) + ) + """, + [ + lp.GlobalArg("x", dtype=dtype, shape=(nsources,)), + lp.GlobalArg("y", dtype=dtype, shape=(nsources,)), + lp.GlobalArg("strength", dtype=dtype, 
shape=(nsources,)), + lp.GlobalArg("inv_fact", dtype=dtype, shape=(order + 1,)), + lp.GlobalArg("beta", dtype=dtype, shape=(order + 1, order + 1), + is_output=True), + ], + lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2, + ) + + knl = lp.fix_parameters(knl, nsources=nsources, p=order) + knl = lp.split_iname( + knl, + "q1", + q1_tile_size, + inner_iname="q1_inner", + outer_iname="q1_outer", + ) + knl = lp.split_iname( + knl, + "isrc", + source_tile_size, + inner_iname="isrc_inner", + outer_iname="isrc_outer", + ) + + if use_compute: + x_monom_map = nisl.make_map(f"""{{ + [isrc_arg, q0_arg] -> [q0, q1_outer, isrc_outer, isrc_s] : + isrc_arg = isrc_outer * {source_tile_size} + isrc_s and + q0_arg = q0 + }}""") + + knl = compute( + knl, + "x_monom_", + compute_map=x_monom_map, + storage_indices=["isrc_s"], + temporal_inames=["q0", "q1_outer", "isrc_outer"], + temporary_name="x_monom_for_q1_tile", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="x_monom_compute", + ) + + y_monom_map = nisl.make_map(f"""{{ + [isrc_arg, q1_arg] -> [q0, q1_outer, q1_inner, isrc_outer, isrc_s] : + isrc_arg = isrc_outer * {source_tile_size} + isrc_s and + q1_arg = q1_outer * {q1_tile_size} + q1_inner + }}""") + + knl = compute( + knl, + "y_monom_", + compute_map=y_monom_map, + storage_indices=["isrc_s"], + temporal_inames=["q0", "q1_outer", "q1_inner", "isrc_outer"], + temporary_name="y_monom_for_coeff", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="y_monom_compute", + ) + + return lp.tag_inames(knl, { + "q0": "g.1", + "q1_outer": "g.0", + "q1_inner": "ilp", + }) + + +def operation_model( + nsources: int, + order: int, + q1_tile_size: int + ) -> tuple[int, int]: + ncoeff = order + 1 + inline_monomial_evals = 2 * nsources * ncoeff**2 + compute_monomial_evals = ( + nsources * ncoeff**2 + + nsources * ncoeff**2 // q1_tile_size + ) + return inline_monomial_evals, compute_monomial_evals + + +def 
p2m_flop_count( + nsources: int, + order: int, + q1_tile_size: int, + use_compute: bool + ) -> int: + ncoeff = order + 1 + + contraction_flops = 3 * nsources * ncoeff**2 + if use_compute: + monomial_scale_flops = ( + nsources * ncoeff**2 + + nsources * ncoeff**2 // q1_tile_size + ) + else: + monomial_scale_flops = 2 * nsources * ncoeff**2 + + return contraction_flops + monomial_scale_flops + + +def benchmark_executor(ex, queue, args, warmup: int, iterations: int) -> float: + evt = None + for _ in range(warmup): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + + start = time.perf_counter() + for _ in range(iterations): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + end = time.perf_counter() + + return (end - start) / iterations + + +def run_kernel( + knl: lp.TranslationUnit, + x: np.ndarray, + y: np.ndarray, + strength: np.ndarray, + inv_fact: np.ndarray, + warmup: int, + iterations: int + ) -> tuple[np.ndarray, float]: + import pyopencl as cl + import pyopencl.array as cl_array + + ctx = cl.create_some_context(interactive=False) + queue = cl.CommandQueue(ctx) + ex = knl.executor(queue) + + x_cl = cl_array.to_device(queue, x) + y_cl = cl_array.to_device(queue, y) + strength_cl = cl_array.to_device(queue, strength) + inv_fact_cl = cl_array.to_device(queue, inv_fact) + beta_cl = cl_array.zeros(queue, (inv_fact.size, inv_fact.size), + dtype=x.dtype) + + elapsed = benchmark_executor( + ex, + queue, + { + "x": x_cl, + "y": y_cl, + "strength": strength_cl, + "inv_fact": inv_fact_cl, + "beta": beta_cl, + }, + warmup=warmup, + iterations=iterations, + ) + + _, out = ex( + queue, x=x_cl, y=y_cl, strength=strength_cl, + inv_fact=inv_fact_cl, beta=beta_cl) + return out[0].get(), elapsed + + +def main( + nsources: int = 256, + order: int = 12, + q1_tile_size: int = 13, + source_tile_size: int = 128, + use_compute: bool = False, + compare: bool = False, + print_kernel: bool = False, + print_device_code: bool = False, + run: bool = False, + 
warmup: int = 3, + iterations: int = 10 + ) -> None: + dtype = np.float64 + rng = np.random.default_rng(18) + x = rng.uniform(-0.25, 0.25, size=nsources).astype(dtype) + y = rng.uniform(-0.25, 0.25, size=nsources).astype(dtype) + strength = rng.normal(size=nsources).astype(dtype) + inv_fact = inv_factorials(order, dtype) + reference = reference_p2m(strength, x, y, inv_fact) + + inline_evals, compute_evals = operation_model( + nsources, order, q1_tile_size) + + variants = [False, True] if compare else [use_compute] + timings: dict[bool, float] = {} + for variant_uses_compute in variants: + knl = make_kernel( + nsources, order, q1_tile_size, source_tile_size, dtype, + use_compute=variant_uses_compute) + modeled_flops = p2m_flop_count( + nsources, order, q1_tile_size, + use_compute=variant_uses_compute) + + print(20 * "=", "P2M basis report", 20 * "=") + print(f"Variant: {'compute' if variant_uses_compute else 'inline'}") + print(f"Sources: {nsources}") + print(f"Order : {order}") + print(f"q1 tile: {q1_tile_size}") + print(f"Source tile: {source_tile_size}") + print(f"Inline monomial evaluations: {inline_evals}") + print(f"Compute monomial evaluations: {compute_evals}") + print(f"Modeled flop count: {modeled_flops}") + + if print_kernel: + print(knl) + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + if run or compare: + try: + result, elapsed = run_kernel( + knl, x, y, strength, inv_fact, + warmup=warmup, iterations=iterations) + except Exception as exc: + print(f"Runtime execution unavailable: {exc}") + else: + rel_err = la.norm(result - reference) / la.norm(reference) + gflops = modeled_flops / elapsed * 1e-9 + timings[variant_uses_compute] = elapsed + print(f"Average time per iteration: {elapsed:.6e} s") + print(f"Modeled throughput: {gflops:.3f} GFLOP/s") + print(f"Relative error: {rel_err:.3e}") + + print((40 + len(" P2M basis report ")) * "=") + + if compare and False in timings and True in timings: + speedup = timings[False] / 
timings[True] + time_reduction = (1 - timings[True] / timings[False]) * 100 + print(f"Speedup: {speedup:.3f}x") + print(f"Relative time reduction: {time_reduction:.2f}%") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + _ = parser.add_argument("--nsources", action="store", type=int, default=256) + _ = parser.add_argument("--order", action="store", type=int, default=12) + _ = parser.add_argument("--q1-tile-size", action="store", type=int, default=13) + _ = parser.add_argument("--source-tile-size", action="store", + type=int, default=128) + _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--compare", action="store_true") + _ = parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--print-device-code", action="store_true") + _ = parser.add_argument("--warmup", action="store", type=int, default=3) + _ = parser.add_argument("--iterations", action="store", type=int, default=10) + + args = parser.parse_args() + + main( + nsources=args.nsources, + order=args.order, + q1_tile_size=args.q1_tile_size, + source_tile_size=args.source_tile_size, + use_compute=args.compute, + compare=args.compare, + print_kernel=args.print_kernel, + print_device_code=args.print_device_code, + run=args.run_kernel, + warmup=args.warmup, + iterations=args.iterations, + ) diff --git a/examples/python/compute-examples/run-compute-examples.sh b/examples/python/compute-examples/run-compute-examples.sh new file mode 100755 index 000000000..1121bc0c3 --- /dev/null +++ b/examples/python/compute-examples/run-compute-examples.sh @@ -0,0 +1,95 @@ +#!/usr/bin/env bash +set -euo pipefail + +PYTHON="$(which python)" +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" + +cd "$SCRIPT_DIR" + +run_example() { + echo + echo "===== $* =====" + "$PYTHON" "$@" +} + +run_example finite-difference-2-5D.py \ + --npoints 96 \ + 
--stencil-width 9 \ + --compute \ + --run-kernel \ + --warmup 2 \ + --iterations 3 + +run_example finite-difference-diamond.py \ + --ntime 96 \ + --nspace 4096 \ + --stencil-width 9 \ + --time-block-size 8 \ + --space-block-size 128 \ + --run-kernel \ + --warmup 2 \ + --iterations 3 + +run_example wave-equation-ring-buffer.py \ + --ntime 4096 \ + --compute \ + --run-kernel \ + --warmup 2 \ + --iterations 5 + +run_example matmul.py \ + --m 512 \ + --n 512 \ + --k 512 \ + --bm 32 \ + --bn 32 \ + --bk 16 \ + --shared-memory-tiled + +run_example matmul.py \ + --m 512 \ + --n 512 \ + --k 512 \ + --bm 64 \ + --bn 64 \ + --bk 16 \ + --tm 4 \ + --tn 4 \ + --register-tiled + +run_example l2p-tiled-basis-compute.py \ + --ntargets 512 \ + --order 12 \ + --target-block-size 64 \ + --compare \ + --run-kernel \ + --warmup 2 \ + --iterations 3 + +run_example p2m-basis-compute.py \ + --nsources 512 \ + --order 12 \ + --q1-tile-size 13 \ + --source-tile-size 128 \ + --compare \ + --run-kernel \ + --warmup 2 \ + --iterations 3 + +run_example m2m-sum-factorization.py \ + --order 23 \ + --eta-tile-size 8 \ + --dimension 3 \ + --compare \ + --run-kernel \ + --warmup 2 \ + --iterations 3 + +run_example l2p-3d-tensor-product-compute.py \ + --ntargets 512 \ + --order 8 \ + --target-block-size 64 \ + --compare \ + --run-kernel \ + --warmup 2 \ + --iterations 3 diff --git a/examples/python/compute-examples/wave-equation-ring-buffer.py b/examples/python/compute-examples/wave-equation-ring-buffer.py new file mode 100644 index 000000000..501f89502 --- /dev/null +++ b/examples/python/compute-examples/wave-equation-ring-buffer.py @@ -0,0 +1,211 @@ +import time + +import namedisl as nisl +import numpy as np +import numpy.linalg as la + +import pyopencl as cl + +import loopy as lp +from loopy.transform.compute import compute +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 + + +def benchmark_executor(ex, queue, args, warmup: int, iterations: int) -> float: + if iterations <= 0: + raise 
ValueError("iterations must be positive") + + evt = None + for _ in range(warmup): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + + start = time.perf_counter() + for _ in range(iterations): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + end = time.perf_counter() + + return (end - start) / iterations + + +def wave_flop_count(ntime: int) -> int: + return 5 * (ntime - 2) + + +def main( + ntime: int = 128, + use_compute: bool = False, + print_device_code: bool = False, + print_kernel: bool = False, + run_kernel: bool = False, + warmup: int = 3, + iterations: int = 10 + ) -> float | None: + dtype = np.float64 + + dt = dtype(1 / 512) + omega = dtype(2 * np.pi) + omega2 = dtype(omega**2) + + t = dt * np.arange(ntime, dtype=dtype) + u = np.cos(omega * t).astype(dtype) + + bt = 32 + + knl = lp.make_kernel( + "{ [t] : 1 <= t < ntime - 1 }", + """ + u_hist(ts) := u[ts] + + u_next[t + 1] = ( + 2 * u_hist(t) + - u_hist(t - 1) + - dt2 * omega2 * u_hist(t) + ) + """, + [ + lp.GlobalArg("u", dtype=dtype, shape=(ntime,)), + lp.GlobalArg("u_next", dtype=dtype, shape=(ntime,), + is_output=True), + lp.ValueArg("dt2", dtype=dtype), + lp.ValueArg("omega2", dtype=dtype), + ], + lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2, + ) + + knl = lp.fix_parameters(knl, ntime=ntime) + knl = lp.split_iname(knl, "t", bt, inner_iname="ti", outer_iname="to") + + if use_compute: + ring_buffer_map = nisl.make_map(f"""{{ + [ts] -> [to, ti, tb] : + tb = ts - (to * {bt} + ti) + 1 + }}""") + + knl = compute( + knl, + "u_hist", + compute_map=ring_buffer_map, + storage_indices=["tb"], + temporal_inames=["to", "ti"], + + temporary_name="u_time_buf", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + + compute_insn_id="u_time_buf_compute", + inames_to_advance=["ti"], + ) + + knl = lp.tag_inames(knl, {"tb": "unr"}) + + knl = lp.tag_inames(knl, {"to": "g.0"}) + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + if 
print_kernel: + print(knl) + + if not run_kernel: + return None + + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + + ex = knl.executor(queue) + + dt2 = dtype(dt**2) + avg_time_per_iter = benchmark_executor( + ex, queue, {"u": u, "dt2": dt2, "omega2": omega2}, + warmup=warmup, iterations=iterations) + avg_gflops = wave_flop_count(ntime) / avg_time_per_iter / 1e9 + + _, out = ex(queue, u=u, dt2=dt2, omega2=omega2) + + ref = np.zeros_like(u) + for time_idx in range(1, ntime - 1): + ref[time_idx + 1] = ( + 2 * u[time_idx] + - u[time_idx - 1] + - dt2 * omega2 * u[time_idx] + ) + + sl = slice(2, ntime) + rel_err = la.norm(ref[sl] - out[0][sl]) / la.norm(ref[sl]) + + print(20 * "=", "Wave recurrence report", 20 * "=") + print(f"Variant : {'compute' if use_compute else 'baseline'}") + print(f"Time steps : {ntime}") + print(f"Iterations : warmup = {warmup}, timed = {iterations}") + print(f"Average time per iteration: {avg_time_per_iter:.6e} s") + print(f"Average throughput: {avg_gflops:.3f} GFLOP/s") + print(f"Relative error: {rel_err:.3e}") + print((40 + len(" Wave recurrence report ")) * "=") + + return avg_time_per_iter + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + + _ = parser.add_argument("--ntime", action="store", type=int, default=128) + + _ = parser.add_argument("--compare", action="store_true") + _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--no-run-kernel", action="store_false", + dest="run_kernel") + _ = parser.add_argument("--print-device-code", action="store_true") + _ = parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--warmup", action="store", type=int, default=3) + _ = parser.add_argument("--iterations", action="store", type=int, default=10) + + args = parser.parse_args() + + if args.compare: + print("Running example without compute...") + no_compute_time = 
main( + ntime=args.ntime, + use_compute=False, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=True, + warmup=args.warmup, + iterations=args.iterations, + ) + print(50 * "=", "\n") + + print("Running example with compute...") + compute_time = main( + ntime=args.ntime, + use_compute=True, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=True, + warmup=args.warmup, + iterations=args.iterations, + ) + print(50 * "=", "\n") + + assert no_compute_time is not None + assert compute_time is not None + speedup = no_compute_time / compute_time + print(f"Speedup: {speedup:.3f}x") + time_reduction = (1 - compute_time / no_compute_time) * 100 + print(f"Relative time reduction: {time_reduction:.2f}%") + else: + _ = main( + ntime=args.ntime, + use_compute=args.compute, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=args.run_kernel, + warmup=args.warmup, + iterations=args.iterations, + ) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 4a67d69f5..5c587c4e7 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -49,6 +49,7 @@ from constantdict import constantdict from typing_extensions import Self, override +import namedisl as nisl import islpy as isl import pymbolic.primitives as p import pytools.lex @@ -2077,23 +2078,38 @@ def map_subscript(self, expr: p.Subscript, /) -> AbstractSet[p.Subscript]: # {{{ (pw)aff to expr conversion -def aff_to_expr(aff: isl.Aff) -> ArithmeticExpression: +def aff_to_expr(aff: isl.Aff | nisl.Aff) -> ArithmeticExpression: from pymbolic import var + # FIXME: remove this once namedisl is the standard in loopy denom = aff.get_denominator_val().to_python() - result = (aff.get_constant_val()*denom).to_python() - for dt in [isl.dim_type.in_, isl.dim_type.param]: - for i in range(aff.dim(dt)): - coeff = (aff.get_coefficient_val(dt, i)*denom).to_python() + if isinstance(aff, isl.Aff): + for dt in [isl.dim_type.in_, 
isl.dim_type.param]: + for i in range(aff.dim(dt)): + coeff = (aff.get_coefficient_val(dt, i)*denom).to_python() + if coeff: - dim_name = not_none(aff.get_dim_name(dt, i)) - result += coeff*var(dim_name) + dim_name = not_none(aff.get_dim_name(dt, i)) + result += coeff*var(dim_name) + + for i in range(aff.dim(isl.dim_type.div)): + coeff = (aff.get_coefficient_val(isl.dim_type.div, i)*denom).to_python() + if coeff: - result += coeff*aff_to_expr(aff.get_div(i)) + result += coeff*aff_to_expr(aff.get_div(i)) + + else: + in_names = set(aff.dim_type_names(isl.dim_type.in_)) + param_names = set(aff.dim_type_names(isl.dim_type.param)) - for i in range(aff.dim(isl.dim_type.div)): - coeff = (aff.get_coefficient_val(isl.dim_type.div, i)*denom).to_python() - if coeff: - result += coeff*aff_to_expr(aff.get_div(i)) + for name in in_names | param_names: + coeff = (aff.get_coefficient_val(name) * denom).to_python() + if coeff: + result += coeff * var(name) + + for name in aff.dim_type_names(isl.dim_type.div): + coeff = (aff.get_coefficient_val(name) * denom).to_python() + if coeff: + result += coeff * aff_to_expr(aff.get_div(name)) assert not isinstance(result, complex) return flatten(result // denom) @@ -2108,7 +2124,7 @@ def pw_aff_to_expr(pw_aff: isl.PwAff | isl.Aff, def pw_aff_to_expr( - pw_aff: int | isl.PwAff | isl.Aff, + pw_aff: int | isl.PwAff | isl.Aff | nisl.PwAff | nisl.Aff, int_ok: bool = False ) -> ArithmeticExpression: if isinstance(pw_aff, int): @@ -2548,6 +2564,26 @@ def constraint_to_cond_expr(cns: isl.Constraint) -> ArithmeticExpression: # }}} +# {{{ MultiPwAff from sequence of pymbolic exprs + +def multi_pw_aff_from_exprs( + exprs: Sequence[Expression], + space: isl.Space + ) -> isl.MultiPwAff: + + mpwaff = isl.MultiPwAff.zero(space) + for i in range(len(exprs)): + local_space = mpwaff.get_at(i).get_space().domain() + mpwaff = mpwaff.set_pw_aff( + i, + pwaff_from_expr(local_space, exprs[i]) + ) + + return mpwaff + +# }}} + + # {{{ isl_set_from_expr class 
ConditionExpressionToBooleanOpsExpression(IdentityMapper[[]]): diff --git a/loopy/target/c/compyte b/loopy/target/c/compyte index 80ed45de9..2b168ca39 160000 --- a/loopy/target/c/compyte +++ b/loopy/target/c/compyte @@ -1 +1 @@ -Subproject commit 80ed45de98b5432341763b9fa52a00fdac870b89 +Subproject commit 2b168ca396aec2259da408f441f5e38ac9f95cb6 diff --git a/loopy/transform/compute-old.py b/loopy/transform/compute-old.py new file mode 100644 index 000000000..7be2a9f26 --- /dev/null +++ b/loopy/transform/compute-old.py @@ -0,0 +1,716 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeAlias, override + +import namedisl as nisl + +import islpy as isl +import pymbolic.primitives as p +from pymbolic import var +from pymbolic.mapper.substitutor import make_subst_func +from pytools.tag import Tag + +import loopy as lp +from loopy.kernel.tools import DomainChanger +from loopy.match import StackMatch, parse_stack_match +from loopy.symbolic import ( + DependencyMapper, + ExpansionState, + RuleAwareIdentityMapper, + RuleAwareSubstitutionMapper, + SubstitutionRuleExpander, + SubstitutionRuleMappingContext, + multi_pw_aff_from_exprs, + pw_aff_to_expr, +) +from loopy.transform.precompute import contains_a_subst_rule_invocation +from loopy.translation_unit import for_each_kernel +from loopy.types import ToLoopyTypeConvertible, to_loopy_type + + +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence, Set + + from pymbolic.typing import Expression + + from loopy.kernel import LoopKernel + from loopy.kernel.data import AddressSpace + + +AccessTuple: TypeAlias = tuple[str, ...] 
+ + +def _access_key(args: Sequence[Expression]) -> AccessTuple: + return tuple(str(arg) for arg in args) + + +def _base_name(name: str) -> str: + return name.removesuffix("_") + + +def _cur_name(name: str) -> str: + return f"{_base_name(name)}_cur" + + +def _prev_name(name: str) -> str: + return f"{_base_name(name)}_prev" + + +def _basic_set_to_predicates(bset: nisl.BasicSet) -> frozenset[Expression]: + isl_bset = bset._reconstruct_isl_object() + + predicates = [] + for constraint in isl_bset.get_constraints(): + expr = pw_aff_to_expr(constraint.get_aff()) + if constraint.is_equality(): + predicates.append(p.Comparison(expr, "==", 0)) + else: + predicates.append(p.Comparison(expr, ">=", 0)) + + return frozenset(predicates) + + +def _set_to_predicate_options( + set_: nisl.Set | nisl.BasicSet + ) -> Sequence[frozenset[Expression]]: + if isinstance(set_, nisl.BasicSet): + if set_._reconstruct_isl_object().is_empty(): + return [] + return [_basic_set_to_predicates(set_)] + + predicate_options = [] + for bset in set_.get_basic_sets(): + if not bset._reconstruct_isl_object().is_empty(): + predicate_options.append(_basic_set_to_predicates(bset)) + + return predicate_options + + +# helper for gathering names of variables in pymbolic expressions +def _gather_vars(expr: Expression) -> set[str]: + deps = DependencyMapper()(expr) + var_names = set() + for dep in deps: + if isinstance(dep, p.Variable): + var_names.add(dep.name) + elif ( + isinstance(dep, p.Subscript) + and isinstance(dep.aggregate, p.Variable)): + var_names.add(dep.aggregate.name) + + return var_names + + +def _existing_name_mapping( + map_: nisl.Map | nisl.BasicMap, + name_mapping: Mapping[str, str] + ) -> Mapping[str, str]: + names = map_.names + return { + source: target + for source, target in name_mapping.items() + if source in names and target in names + } + + +def _normalize_renamed_dims( + map_: nisl.Map | nisl.BasicMap, + name_mapping: Mapping[str, str], + ) -> nisl.Map | nisl.BasicMap: + map_ = 
map_.equate_dims(_existing_name_mapping(map_, name_mapping)) + + names = map_.names + project_names = [ + renamed_name + for original_name, renamed_name in name_mapping.items() + if original_name in names and renamed_name in names + ] + map_ = map_.project_out(project_names) + + names = map_.names + rename_mapping = { + renamed_name: original_name + for original_name, renamed_name in name_mapping.items() + if original_name not in names and renamed_name in names + } + return map_.rename_dims(rename_mapping) + + +# {{{ gathering usage expressions + +class UsageSiteExpressionGatherer(RuleAwareIdentityMapper[[]]): + """ + Gathers all expressions used as inputs to a particular substitution rule, + identified by name. + """ + def __init__( + self, + rule_mapping_ctx: SubstitutionRuleMappingContext, + subst_expander: SubstitutionRuleExpander, + kernel: LoopKernel, + subst_name: str, + subst_tag: Set[Tag] | Tag | None = None + ) -> None: + + super().__init__(rule_mapping_ctx) + + self.subst_expander: SubstitutionRuleExpander = subst_expander + self.kernel: LoopKernel = kernel + self.subst_name: str = subst_name + self.subst_tag: Set[Tag] | None = ( + {subst_tag} if isinstance(subst_tag, Tag) else subst_tag + ) + + self.usage_expressions: list[Sequence[Expression]] = [] + + @override + def map_subst_rule( + self, + name: str, + tags: Set[Tag] | None, + arguments: Sequence[Expression], + expn_state: ExpansionState, + ) -> Expression: + + if name != self.subst_name: + return super().map_subst_rule( + name, tags, arguments, expn_state + ) + + if self.subst_tag is not None and self.subst_tag != tags: + return super().map_subst_rule( + name, tags, arguments, expn_state + ) + + rule = self.rule_mapping_context.old_subst_rules[name] + arg_ctx = self.make_new_arg_context( + name, rule.arguments, arguments, expn_state.arg_context + ) + + self.usage_expressions.append([ + arg_ctx[arg_name] for arg_name in rule.arguments + ]) + + return 0 + +# }}} + + +# {{{ substitution rule use 
replacement + +class RuleInvocationReplacer(RuleAwareIdentityMapper[[]]): + def __init__( + self, + ctx: SubstitutionRuleMappingContext, + subst_name: str, + subst_tag: Sequence[Tag] | None, + usage_descriptors: Mapping[AccessTuple, nisl.Map | nisl.BasicMap], + storage_indices: Sequence[str], + temporary_name: str, + compute_insn_ids: str | Sequence[str], + footprint: nisl.Set + ) -> None: + + super().__init__(ctx) + + self.subst_name: str = subst_name + self.subst_tag: Sequence[Tag] | None = subst_tag + + self.usage_descriptors: Mapping[AccessTuple, nisl.Map | nisl.BasicMap] = \ + usage_descriptors + self.storage_indices: Sequence[str] = storage_indices + self.footprint: nisl.Set = footprint + + self.temporary_name: str = temporary_name + self.compute_insn_ids: frozenset[str] = ( + frozenset([compute_insn_ids]) + if isinstance(compute_insn_ids, str) + else frozenset(compute_insn_ids) + ) + + self.replaced_something: bool = False + + # FIXME: may not always be the case (i.e. global barrier between + # compute insn and uses) + self.compute_dep_ids: frozenset[str] = self.compute_insn_ids + + @override + def map_subst_rule( + self, + name: str, + tags: Set[Tag] | None, + arguments: Sequence[Expression], + expn_state: ExpansionState + ) -> Expression: + + rule = self.rule_mapping_context.old_subst_rules[name] + arg_ctx = self.make_new_arg_context( + name, rule.arguments, arguments, expn_state.arg_context + ) + args = [arg_ctx[arg_name] for arg_name in rule.arguments] + + # {{{ validation checks + + if name != self.subst_name: + return super().map_subst_rule(name, tags, arguments, expn_state) + + access_key = _access_key(args) + if access_key not in self.usage_descriptors: + return super().map_subst_rule(name, tags, arguments, expn_state) + + if len(arguments) != len(rule.arguments): + raise ValueError( + f"Number of arguments passed to rule {name} " + f"does not match the signature of {name}." 
+ ) + + local_map = self.usage_descriptors[access_key] + temp_footprint = self.footprint.move_dims( + frozenset(self.footprint.names) - frozenset(self.storage_indices), + isl.dim_type.param + ) + + if not local_map.range() <= temp_footprint: + return super().map_subst_rule(name, tags, arguments, expn_state) + + # }}} + + # {{{ get index expression in terms of global inames + + local_pwmaff = self.usage_descriptors[access_key].as_pw_multi_aff() + + index_exprs: Sequence[Expression] = [] + for dim in range(local_pwmaff.dim(isl.dim_type.out)): + index_exprs.append(pw_aff_to_expr(local_pwmaff.get_at(dim))) + + new_expression = var(self.temporary_name)[tuple(index_exprs)] + + # }}} + + self.replaced_something = True + + return new_expression + + @override + def map_kernel( + self, + kernel: LoopKernel, + within: StackMatch = lambda knl, insn, stack: True, + map_args: bool = True, + map_tvs: bool = True + ) -> LoopKernel: + + new_insns: Sequence[lp.InstructionBase] = [] + for insn in kernel.instructions: + self.replaced_something = False + + if (isinstance(insn, lp.MultiAssignmentBase) and not + contains_a_subst_rule_invocation(kernel, insn)): + new_insns.append(insn) + continue + + insn = insn.with_transformed_expressions( + lambda expr, insn=insn: self(expr, kernel, insn) + ) + + if self.replaced_something: + insn = insn.copy( + depends_on=( + insn.depends_on | self.compute_dep_ids + ) + ) + + # FIXME: determine compute insn dependencies + + new_insns.append(insn) + + return kernel.copy(instructions=new_insns) + +# }}} + + +@for_each_kernel +def compute( + kernel: LoopKernel, + substitution: str, + compute_map: nisl.Map, + + storage_indices: Sequence[str], + + # FIXME: can these two be deduced? 
+ temporal_inames: Sequence[str], + inames_to_advance: Sequence[str] | None = None, + + temporary_name: str | None = None, + temporary_address_space: AddressSpace | None = None, + + temporary_dtype: ToLoopyTypeConvertible = None, + + compute_insn_id: str | None = None + ) -> LoopKernel: + """ + Inserts an instruction to compute an expression given by :arg:`substitution` + and replaces all invocations of :arg:`substitution` with the result of the + inserted compute instruction. + + :arg substitution: The substitution rule for which the compute + transform should be applied. + + :arg compute_map: An :class:`isl.Map` representing a relation between + substitution rule indices and tuples `(a, l)`, where `a` is a vector of + storage indices and `l` is a vector of "timestamps". + + :arg storage_indices: An ordered sequence of names of storage indices. Used + to create inames for the loops that cover the required set of compute points. + """ + + name_mapping = { + name: name + "_" + for name in compute_map.output_names + if name not in storage_indices + } + compute_map = compute_map.rename_dims(name_mapping) + + # {{{ setup and useful variables + + storage_set = frozenset(storage_indices) + temporal_set = frozenset(temporal_inames) + + ctx = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + expander = SubstitutionRuleExpander(kernel.substitutions) + expr_gatherer = UsageSiteExpressionGatherer( + ctx, expander, kernel, substitution, None + ) + + _ = expr_gatherer.map_kernel(kernel) + usage_exprs = expr_gatherer.usage_expressions + + all_exprs = [expr for usage in usage_exprs for expr in usage] + usage_inames: frozenset[str] = frozenset( + set.union(*(_gather_vars(expr) for expr in all_exprs)) + ) + + # }}} + + # {{{ construct necessary pieces; footprint, global usage map + + # add compute inames to domain / kernel + domain_changer = DomainChanger(kernel, kernel.all_inames()) + named_domain = 
nisl.make_basic_set(domain_changer.domain) + + # restrict domain to used inames + local_domain = named_domain.project_out_except(usage_inames) + + # FIXME: gross. find a cleaner way to generate a space for an empty map + global_usage_map = nisl.make_map_from_domain_and_range( + nisl.make_set(isl.Set.empty(local_domain.get_space())), + compute_map.domain() + ) + global_usage_map = nisl.make_map(isl.Map.empty(global_usage_map.get_space())) + + usage_substs: Mapping[AccessTuple, nisl.Map | nisl.BasicMap] = {} + for usage in usage_exprs: + + # {{{ compute local usage map, update global usage map + + local_usage_mpwaff = multi_pw_aff_from_exprs( + usage, + global_usage_map.get_space() + ) + + local_usage_map = nisl.make_map(local_usage_mpwaff.as_map()) + + local_usage_map = local_usage_map.intersect_domain(local_domain) + global_usage_map = global_usage_map | local_usage_map + + # }}} + + # {{{ compute storage map + + local_storage_map = local_usage_map.apply_range(compute_map) + local_storage_map = _normalize_renamed_dims( + local_storage_map, name_mapping) + + # check that no restrictions happened during composition (i.e. 
tile + # valid for a single point in the domain) + if not local_usage_map.domain() <= local_storage_map.domain(): + continue + + # clean up names + non_param_names = (usage_inames - temporal_set) | storage_set + parameter_names = frozenset(local_storage_map.names) - non_param_names + local_storage_map = local_storage_map.move_dims(parameter_names, + isl.dim_type.param) + + # }}} + + usage_substs[_access_key(usage)] = local_storage_map + + storage_map = global_usage_map.apply_range(compute_map) + storage_map = _normalize_renamed_dims(storage_map, name_mapping) + + # }}} + + # {{{ compute bounds and update kernel domain + + storage_map = storage_map.move_dims(temporal_set, isl.dim_type.param) + footprint = storage_map.range() + + # clean up ticked duplicate names + footprint = footprint.project_out_except(temporal_set | storage_set) + footprint = footprint.move_dims(temporal_set, isl.dim_type.set) + + # {{{ FIXME: use Sets instead of BasicSets when loopy is ready + + # FIXME: convex hull is not permanent + footprint_isl = footprint._reconstruct_isl_object() + footprint = nisl.make_set(isl.Set.from_basic_set(footprint_isl.convex_hull())) + named_domain = named_domain & footprint + + if len(named_domain.get_basic_sets()) != 1: + raise ValueError("New domain should be composed of a single basic set") + + # FIXME: use named object once loopy is name-ified + domain = named_domain.get_basic_sets()[0]._reconstruct_isl_object() + new_domains = domain_changer.get_domains_with(domain) + + # }}} + + kernel = kernel.copy(domains=new_domains) + + # }}} + + if not temporary_name: + temporary_name = substitution + "_temp" + + if not compute_insn_id: + compute_insn_id = substitution + "_compute" + + # {{{ reuse analysis + + update_insns: list[lp.InstructionBase] = [] + update_insn_ids: list[str] = [] + refill_predicate_options: Sequence[frozenset[Expression] | None] = [None] + current_update_deps: frozenset[str] = frozenset() + + if inames_to_advance is not None: + advancing_set = 
frozenset(inames_to_advance) + + compute_map_cur = compute_map.rename_dims({ + name: _cur_name(name) for name in compute_map.output_names + }) + compute_map_prev = compute_map.rename_dims({ + name: _prev_name(name) for name in compute_map.output_names + }) + + cur_storage = global_usage_map.apply_range(compute_map_cur) + prev_storage = global_usage_map.apply_range(compute_map_prev) + + reuse_map = prev_storage.reverse().apply_range(cur_storage) + reuse_map = reuse_map.add_constraint([ + ( + f"{name}_cur = {name}_prev + 1" + if name in advancing_set + else + f"{name}_cur = {name}_prev" + ) + for name in temporal_inames + ]) + + current_footprint = footprint.rename_dims({ + name: _cur_name(name) for name in footprint.names + }) + previous_footprint = footprint.rename_dims({ + name: _prev_name(name) for name in footprint.names + }) + + reuse_map = reuse_map.intersect_domain(previous_footprint) + reuse_map = reuse_map.intersect_range(current_footprint) + reuse_map = reuse_map - nisl.make_map( + "{ [" + + ", ".join(_prev_name(name) for name in footprint.names) + + "] -> [" + + ", ".join(_cur_name(name) for name in footprint.names) + + "] : " + + " and ".join( + f"{_cur_name(name)} = {_prev_name(name)}" + for name in storage_indices + ) + + " }" + ) + + reused_current = reuse_map.range() + refill = current_footprint - reused_current + + cur_to_normal = { + _cur_name(name): name + for name in footprint.names + } + reused_current = reused_current.rename_dims(cur_to_normal) + refill = refill.rename_dims(cur_to_normal) + + reused_context = named_domain.project_out_except(reused_current.names) + refill_context = named_domain.project_out_except(refill.names) + + reused_current = reused_current.gist(reused_context) + refill = refill.gist(refill_context) + + refill_predicate_options = _set_to_predicate_options(refill) + + storage_reuse_map = reuse_map.project_out_except( + frozenset(_prev_name(name) for name in storage_indices) + | frozenset(_cur_name(name) for name in 
storage_indices) + ) + storage_reuse_map = storage_reuse_map.rename_dims({ + _cur_name(name): name + for name in storage_indices + }) + cur_to_prev = storage_reuse_map.reverse() + cur_to_prev_pwma = cur_to_prev.as_pw_multi_aff() + prev_expr_by_name = { + cur_to_prev_pwma.get_dim_name(isl.dim_type.out, dim): + pw_aff_to_expr(cur_to_prev_pwma.get_at(dim)) + for dim in range(cur_to_prev_pwma.dim(isl.dim_type.out)) + } + prev_storage_exprs = [ + prev_expr_by_name[_prev_name(name)] + for name in storage_indices + ] + + shift_assignee = var(temporary_name)[ + tuple(var(idx) for idx in storage_indices) + ] + shift_expression = var(temporary_name)[tuple(prev_storage_exprs)] + + shift_predicate_options = _set_to_predicate_options(reused_current) + for i, predicates in enumerate(shift_predicate_options): + shift_insn_id = ( + f"{compute_insn_id}_shift" + if len(shift_predicate_options) == 1 + else f"{compute_insn_id}_shift_{i}" + ) + update_insns.append(lp.Assignment( + id=shift_insn_id, + assignee=shift_assignee, + expression=shift_expression, + within_inames=frozenset(temporal_inames) | storage_set, + predicates=predicates, + depends_on=current_update_deps, + )) + update_insn_ids.append(shift_insn_id) + current_update_deps = frozenset([shift_insn_id]) + + # }}} + + # {{{ create compute instruction in kernel + + # FIXME: maybe just keep original around? 
+ compute_map = compute_map.rename_dims({ + value: key for key, value in name_mapping.items() + }) + + compute_pw_aff = compute_map.reverse().as_pw_multi_aff() + storage_ax_to_global_expr = { + compute_pw_aff.get_dim_name(isl.dim_type.out, dim): + pw_aff_to_expr(compute_pw_aff.get_at(dim)) + for dim in range(compute_pw_aff.dim(isl.dim_type.out)) + } + + expr_subst_map = RuleAwareSubstitutionMapper( + ctx, + make_subst_func(storage_ax_to_global_expr), + within=parse_stack_match(None) + ) + + subst_expr = kernel.substitutions[substitution].expression + compute_expression = expr_subst_map( + subst_expr, + kernel, + None, + ) + compute_dep_ids = frozenset().union(*( + kernel.writer_map().get(var_name, frozenset()) + for var_name in _gather_vars(compute_expression) + )) + + assignee = var(temporary_name)[tuple(var(idx) for idx in storage_indices)] + + within_inames = compute_map.output_names + + new_insns = list(kernel.instructions) + new_insns.extend(update_insns) + + for i, predicates in enumerate(refill_predicate_options): + refill_insn_id = ( + compute_insn_id + if len(refill_predicate_options) == 1 + else f"{compute_insn_id}_refill_{i}" + ) + compute_insn = lp.Assignment( + id=refill_insn_id, + assignee=assignee, + expression=compute_expression, + within_inames=within_inames, + predicates=predicates, + depends_on=current_update_deps | compute_dep_ids, + ) + new_insns.append(compute_insn) + update_insn_ids.append(refill_insn_id) + current_update_deps = frozenset([refill_insn_id]) + + kernel = kernel.copy(instructions=new_insns) + + # }}} + + # {{{ replace invocations with new compute instruction + + ctx = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator() + ) + + replacer = RuleInvocationReplacer( + ctx, + substitution, + None, + usage_substs, + storage_indices, + temporary_name, + update_insn_ids, + footprint + ) + + kernel = replacer.map_kernel(kernel) + + # }}} + + # {{{ create temporary variable for result of compute + + 
loopy_type = to_loopy_type(temporary_dtype, allow_none=True) + + temp_shape = tuple( + pw_aff_to_expr(footprint.dim_max(dim)) + 1 + for dim in storage_indices + ) + + new_temp_vars = dict(kernel.temporary_variables) + + # FIXME: temp_var might already exist, handle the case where it does + temp_var = lp.TemporaryVariable( + name=temporary_name, + dtype=loopy_type, + base_indices=(0,)*len(temp_shape), + shape=temp_shape, + address_space=temporary_address_space, + dim_names=tuple(storage_indices) + ) + + new_temp_vars[temporary_name] = temp_var + + kernel = kernel.copy( + temporary_variables=new_temp_vars + ) + + # }}} + + return kernel diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py new file mode 100644 index 000000000..fd839d549 --- /dev/null +++ b/loopy/transform/compute.py @@ -0,0 +1,986 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Literal, TypeAlias, cast + +import namedisl as nisl +from typing_extensions import override + +import islpy as isl +import pymbolic.primitives as p +from pymbolic import var +from pymbolic.mapper.substitutor import make_subst_func +from pymbolic.typing import Expression +from pytools.tag import Tag + +from ..kernel.data import TemporaryVariable +from ..kernel.instruction import Assignment, InstructionBase, MultiAssignmentBase +from ..kernel.tools import DomainChanger +from ..match import StackMatch, parse_stack_match +from ..symbolic import ( + DependencyMapper, + ExpansionState, + RuleAwareIdentityMapper, + RuleAwareSubstitutionMapper, + SubstitutionRuleMappingContext, + multi_pw_aff_from_exprs, + pw_aff_to_expr, +) +from ..translation_unit import for_each_kernel +from ..types import ToLoopyTypeConvertible, to_loopy_type +from .precompute import contains_a_subst_rule_invocation + + +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence, Set + + from ..kernel import LoopKernel + from ..kernel.data import AddressSpace + + +UsageKey: 
TypeAlias = tuple[str, int]
# A predicate set is a conjunction of pymbolic conditions guarding an instruction.
PredicateSet: TypeAlias = frozenset[Expression]
PredicateOptions: TypeAlias = tuple[PredicateSet | None, ...]


@dataclass(frozen=True)
class UsageSite:
    """A single invocation of the targeted substitution rule in one instruction."""

    insn_id: str  # id of the instruction containing the invocation
    ordinal: int  # 0-based position of this invocation within that instruction
    args: tuple[Expression, ...]  # arguments the rule was invoked with
    predicates: PredicateSet  # predicates guarding the containing instruction

    @property
    def key(self) -> UsageKey:
        """(insn_id, ordinal) pair uniquely identifying this site."""
        return self.insn_id, self.ordinal

    @property
    def domain_names(self) -> frozenset[str]:
        """All variable names referenced by the arguments and predicates."""
        exprs = (*self.args, *self.predicates)
        return frozenset(set().union(*(_gather_vars(expr) for expr in exprs)))


@dataclass(frozen=True)
class NameState:
    """Bookkeeping for the internal renaming of non-storage output dims."""

    internal_compute_map: nisl.Map  # compute_map with renamed output dims
    renamed_to_original: Mapping[str, str]  # e.g. {"t_": "t"}
    original_to_renamed: Mapping[str, str]  # inverse of the above


@dataclass(frozen=True)
class UsageInfo:
    """Per-site storage maps plus their union over all usage sites."""

    global_usage_map: nisl.Map
    local_storage_maps: Mapping[UsageKey, nisl.Map | nisl.BasicMap]


@dataclass(frozen=True)
class FootprintInfo:
    """Footprint of the compute temporary and the domain extended with it."""

    # Convex-hull over-approximation usable as a loopy (basic-set) domain piece.
    loopy_footprint: nisl.Set
    named_domain: nisl.Set


@dataclass(frozen=True)
class ReuseRelations:
    """Relations describing which temporary entries shift forward vs. refill."""

    shift_relation: nisl.Map  # previous-iteration points -> current points they feed
    reusable_footprint: nisl.Set  # part of the footprint coverable by shifting
    refill_footprint: nisl.Set  # part that must be recomputed on each advance


@dataclass(frozen=True)
class ComputePlan:
    """Everything derived from the kernel needed to lower the compute transform."""

    name_state: NameState
    usage_info: UsageInfo
    footprint_info: FootprintInfo
    storage_indices: tuple[str, ...]
    temporal_inames: tuple[str, ...]
    reuse_relations: ReuseRelations | None  # None when no reuse/shift was requested


@dataclass(frozen=True)
class ComputeInstructionInfo:
    """Ingredients for the instruction that (re)fills the compute temporary."""

    expression: Expression  # substitution body rewritten in storage indices
    dependencies: frozenset[str]  # ids of instructions writing variables it reads
    within_inames: frozenset[str]


def _base_name(name: str) -> str:
    """Strip the trailing "_" marker introduced by :func:`_make_name_state`."""
    return name.removesuffix("_")


def _cur_name(name: str) -> str:
    """Name of the current-iteration copy of a dim in the reuse analysis."""
    return f"{_base_name(name)}_cur"


def _prev_name(name: str) -> str:
    """Name of the previous-iteration copy of a dim in the reuse analysis."""
    return f"{_base_name(name)}_prev"


def _make_name_state(
    compute_map: nisl.Map,
    storage_indices: Sequence[str],
) -> NameState:
    """Rename non-storage output dims of *compute_map* by appending "_".

    Returns the renamed map together with both directions of the renaming.
    NOTE(review): assumes no pre-existing dim is already named ``f"{name}_"``
    -- confirm against callers.
    """
    original_to_renamed = {
        name: f"{name}_"
        for name in compute_map.output_names
        if name not in storage_indices
    }
    renamed_to_original = {
        renamed: original for original, renamed in original_to_renamed.items()
    }
    return NameState(
        internal_compute_map=compute_map.rename_dims(original_to_renamed),
        renamed_to_original=renamed_to_original,
        original_to_renamed=original_to_renamed,
    )


def _infer_temporal_inames(
    compute_map: nisl.Map,
    storage_indices: Sequence[str],
) -> tuple[str, ...]:
    """Output dims of *compute_map* that are not storage indices, in map order."""
    storage_set = frozenset(storage_indices)
    return tuple(name for name in compute_map.output_names if name not in storage_set)


def _basic_set_to_predicates(bset: nisl.BasicSet) -> PredicateSet:
    """Turn each ISL constraint of *bset* into a pymbolic comparison with 0."""
    return frozenset(
        p.Comparison(
            pw_aff_to_expr(constraint.get_aff()),
            # equality constraints become "aff == 0", inequalities "aff >= 0"
            "==" if constraint.is_equality() else ">=",
            0,
        )
        for constraint in bset._reconstruct_isl_object().get_constraints()
    )


def _set_to_predicate_options(
    set_: nisl.Set | nisl.BasicSet,
) -> tuple[PredicateSet, ...]:
    """One predicate conjunction per non-empty basic set of *set_*.

    An empty tuple signals an empty set (nothing to guard or emit).
    """
    if isinstance(set_, nisl.BasicSet):
        if set_._reconstruct_isl_object().is_empty():
            return ()
        return (_basic_set_to_predicates(set_),)

    return tuple(
        _basic_set_to_predicates(bset)
        for bset in set_.get_basic_sets()
        if not bset._reconstruct_isl_object().is_empty()
    )


def _gather_vars(expr: Expression) -> set[str]:
    """Names of plain variables and subscript aggregates occurring in *expr*."""
    deps = DependencyMapper()(expr)
    var_names = set()
    for dep in deps:
        if isinstance(dep, p.Variable):
var_names.add(dep.name)
        elif isinstance(dep, p.Subscript) and isinstance(dep.aggregate, p.Variable):
            var_names.add(dep.aggregate.name)

    return var_names


def _gather_usage_inames(sites: Sequence[UsageSite]) -> frozenset[str]:
    """Union of all variable names referenced by any usage site."""
    return frozenset(set().union(*(site.domain_names for site in sites)))


def _next_ordinal(counters: dict[str, int], insn_id: str) -> int:
    """Post-incrementing per-instruction counter (mutates *counters* in place)."""
    ordinal = counters.get(insn_id, 0)
    counters[insn_id] = ordinal + 1
    return ordinal


def _normalize_subst_tag(
    tag: Set[Tag] | Sequence[Tag] | Tag | None,
) -> frozenset[Tag] | None:
    """Normalize a tag spec (None, a single Tag, or a collection) to a frozenset."""
    if tag is None:
        return None
    if isinstance(tag, Tag):
        return frozenset([tag])
    return frozenset(tag)


def _add_predicates_to_domain(
    domain: nisl.BasicSet,
    predicates: PredicateSet,
) -> nisl.BasicSet:
    """Constrain *domain* by the instruction predicates.

    The predicates are handed to namedisl in string form; this assumes
    ``str(predicate)`` is parseable constraint syntax for every comparison
    produced by :func:`_basic_set_to_predicates` -- TODO confirm.
    """
    predicate_constraints = [str(predicate) for predicate in predicates]
    if not predicate_constraints:
        return domain
    return domain.add_constraint(predicate_constraints)


def _existing_name_mapping(
    map_: nisl.Map | nisl.BasicMap,
    name_mapping: Mapping[str, str],
) -> Mapping[str, str]:
    """Restrict *name_mapping* to pairs where both names are dims of *map_*."""
    names = map_.names
    return {
        source: target
        for source, target in name_mapping.items()
        if source in names and target in names
    }


def _normalize_renamed_dims(
    map_: nisl.Map | nisl.BasicMap,
    name_mapping: Mapping[str, str],
) -> nisl.Map | nisl.BasicMap:
    """Fold the "_"-renamed dims of *map_* back onto their original names.

    Dims present under both names are equated and the renamed copy projected
    out; dims present only under the renamed name are simply renamed back.
    """
    map_ = map_.equate_dims(_existing_name_mapping(map_, name_mapping))

    names = map_.names
    map_ = map_.project_out([
        renamed_name
        for original_name, renamed_name in name_mapping.items()
        if original_name in names and renamed_name in names
    ])

    names = map_.names
    return map_.rename_dims({
        renamed_name: original_name
        for original_name, renamed_name in name_mapping.items()
        if original_name not in names and renamed_name in names
    })


def _empty_usage_map(local_domain: nisl.BasicSet, range_: nisl.Set) -> nisl.Map:
    """Empty map whose space pairs *local_domain*'s space with *range_*'s space."""
    map_space = nisl.make_map_from_domain_and_range(
nisl.make_set(isl.Set.empty(local_domain.get_space())),
        range_,
    ).get_space()
    return nisl.make_map(isl.Map.empty(map_space))


def _map_to_output_exprs(map_: nisl.Map | nisl.BasicMap) -> tuple[Expression, ...]:
    """Output dims of a (single-valued) map as pymbolic index expressions."""
    pwmaff = map_.as_pw_multi_aff()
    return tuple(
        pw_aff_to_expr(pwmaff.get_at(dim))
        for dim in range(pwmaff.dim(isl.dim_type.out))
    )


def _map_to_named_output_exprs(
    map_: nisl.Map | nisl.BasicMap,
) -> Mapping[str, Expression]:
    """Like :func:`_map_to_output_exprs`, but keyed by output dim name."""
    pwmaff = map_.as_pw_multi_aff()
    return {
        pwmaff.get_dim_name(isl.dim_type.out, dim): pw_aff_to_expr(pwmaff.get_at(dim))
        for dim in range(pwmaff.dim(isl.dim_type.out))
    }


class UsageSiteExpressionGatherer(RuleAwareIdentityMapper[[]]):
    """Collect every invocation of one substitution rule as a :class:`UsageSite`.

    The rewritten expression is a throwaway (invocations are replaced by the
    literal 0); callers run :meth:`map_kernel` purely for the side effect of
    filling :attr:`sites`.
    """

    def __init__(
        self,
        rule_mapping_ctx: SubstitutionRuleMappingContext,
        subst_name: str,
        subst_tag: Set[Tag] | Tag | None = None,
    ) -> None:
        super().__init__(rule_mapping_ctx)

        self.subst_name: str = subst_name
        self.subst_tag: frozenset[Tag] | None = _normalize_subst_tag(subst_tag)
        self.sites: list[UsageSite] = []
        # per-instruction invocation counter, feeds UsageSite.ordinal
        self._next_ordinal_by_insn: dict[str, int] = {}

    @override
    def map_subst_rule(
        self,
        name: str,
        tags: Set[Tag] | None,
        arguments: Sequence[Expression],
        expn_state: ExpansionState,
    ) -> Expression:
        # Only record invocations of the requested rule (and matching tag set,
        # if one was given); everything else passes through unchanged.
        if name != self.subst_name:
            return super().map_subst_rule(name, tags, arguments, expn_state)

        if self.subst_tag is not None and self.subst_tag != frozenset(tags or ()):
            return super().map_subst_rule(name, tags, arguments, expn_state)

        rule = self.rule_mapping_context.old_subst_rules[name]
        arg_ctx = self.make_new_arg_context(
            name, rule.arguments, arguments, expn_state.arg_context
        )

        self.sites.append(
            UsageSite(
                insn_id=expn_state.insn_id,
                ordinal=_next_ordinal(self._next_ordinal_by_insn, expn_state.insn_id),
                args=tuple(arg_ctx[arg_name] for arg_name in rule.arguments),
                predicates=frozenset(expn_state.instruction.predicates),
            )
        )

        # Result expression is discarded by the caller; any constant will do.
        return 0


class
RuleInvocationReplacer(RuleAwareIdentityMapper[[]]): + def __init__( + self, + ctx: SubstitutionRuleMappingContext, + subst_name: str, + subst_tag: Sequence[Tag] | None, + usage_descriptors: Mapping[UsageKey, nisl.Map | nisl.BasicMap], + storage_indices: Sequence[str], + temporary_name: str, + compute_insn_ids: str | Sequence[str], + footprint: nisl.Set, + ) -> None: + super().__init__(ctx) + + self.subst_name: str = subst_name + self.subst_tag: frozenset[Tag] | None = _normalize_subst_tag(subst_tag) + self.usage_descriptors: Mapping[UsageKey, nisl.Map | nisl.BasicMap] = ( + usage_descriptors + ) + self.storage_indices: tuple[str, ...] = tuple(storage_indices) + self.temp_footprint: nisl.Set = footprint.move_dims( + frozenset(footprint.names) - frozenset(self.storage_indices), + isl.dim_type.param, + ) + self.temporary_name: str = temporary_name + self.compute_dep_ids: frozenset[str] = ( + frozenset([compute_insn_ids]) + if isinstance(compute_insn_ids, str) + else frozenset(compute_insn_ids) + ) + self.replaced_something: bool = False + self._next_ordinal_by_insn: dict[str, int] = {} + + @override + def map_subst_rule( + self, + name: str, + tags: Set[Tag] | None, + arguments: Sequence[Expression], + expn_state: ExpansionState, + ) -> Expression: + if name != self.subst_name: + return super().map_subst_rule(name, tags, arguments, expn_state) + + if self.subst_tag is not None and self.subst_tag != frozenset(tags or ()): + return super().map_subst_rule(name, tags, arguments, expn_state) + + rule = self.rule_mapping_context.old_subst_rules[name] + if len(arguments) != len(rule.arguments): + raise ValueError( + f"Number of arguments passed to rule {name} " + f"does not match the signature of {name}." 
+ ) + + access_key = ( + expn_state.insn_id, + _next_ordinal(self._next_ordinal_by_insn, expn_state.insn_id), + ) + + local_map = self.usage_descriptors.get(access_key) + if local_map is None: + return super().map_subst_rule(name, tags, arguments, expn_state) + + if not local_map.range() <= self.temp_footprint: + return super().map_subst_rule(name, tags, arguments, expn_state) + + self.replaced_something = True + return var(self.temporary_name)[_map_to_output_exprs(local_map)] + + @override + def map_kernel( + self, + kernel: LoopKernel, + within: StackMatch = lambda knl, insn, stack: True, + map_args: bool = True, + map_tvs: bool = True, + ) -> LoopKernel: + new_insns = [] + for insn in kernel.instructions: + self.replaced_something = False + + if isinstance( + insn, MultiAssignmentBase + ) and not contains_a_subst_rule_invocation(kernel, insn): + new_insns.append(insn) + continue + + insn = insn.with_transformed_expressions( + lambda expr, insn=insn: self(expr, kernel, insn) + ) + + if self.replaced_something: + insn = insn.copy(depends_on=insn.depends_on | self.compute_dep_ids) + + new_insns.append(insn) + + return kernel.copy(instructions=new_insns) + + +def _gather_usage_sites( + kernel: LoopKernel, + substitution: str, +) -> tuple[UsageSite, ...]: + ctx = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator() + ) + gatherer = UsageSiteExpressionGatherer(ctx, substitution) + + _ = gatherer.map_kernel(kernel) + return tuple(gatherer.sites) + + +def _build_usage_info( + named_domain: nisl.BasicSet, + name_state: NameState, + storage_indices: Sequence[str], + temporal_inames: Sequence[str], + sites: Sequence[UsageSite], +) -> UsageInfo: + if not sites: + raise ValueError( + "compute() did not find any invocation of the requested substitution rule." 
+ ) + + usage_inames = _gather_usage_inames(sites) + local_domain = named_domain.project_out_except(usage_inames) + global_usage_map = _empty_usage_map( + local_domain, name_state.internal_compute_map.domain() + ) + + storage_set = frozenset(storage_indices) + temporal_set = frozenset(temporal_inames) + usage_descriptors: dict[UsageKey, nisl.Map | nisl.BasicMap] = {} + + for site in sites: + local_domain = _add_predicates_to_domain( + named_domain.project_out_except(site.domain_names), + site.predicates, + ) + usage_mpwaff = multi_pw_aff_from_exprs(site.args, global_usage_map.get_space()) + local_usage_map = nisl.make_map(usage_mpwaff.as_map()).intersect_domain( + local_domain + ) + global_usage_map = global_usage_map | local_usage_map + + local_storage_map = local_usage_map.apply_range(name_state.internal_compute_map) + local_storage_map = _normalize_renamed_dims( + local_storage_map, name_state.original_to_renamed + ) + if not local_usage_map.domain() <= local_storage_map.domain(): + continue + + non_param_names = (usage_inames - temporal_set) | storage_set + usage_descriptors[site.key] = local_storage_map.move_dims( + frozenset(local_storage_map.names) - non_param_names, + isl.dim_type.param, + ) + + return UsageInfo( + global_usage_map=global_usage_map, + local_storage_maps=usage_descriptors, + ) + + +def _build_footprint_info( + named_domain: nisl.BasicSet, + name_state: NameState, + usage_info: UsageInfo, + storage_indices: Sequence[str], + temporal_inames: Sequence[str], +) -> FootprintInfo: + storage_set = frozenset(storage_indices) + temporal_set = frozenset(temporal_inames) + + storage_map = usage_info.global_usage_map.apply_range( + name_state.internal_compute_map + ) + storage_map = _normalize_renamed_dims(storage_map, name_state.original_to_renamed) + + storage_map = storage_map.move_dims(temporal_set, isl.dim_type.param) + exact_footprint = storage_map.range() + exact_footprint = exact_footprint.project_out_except(temporal_set | storage_set) + 
exact_footprint = exact_footprint.move_dims(temporal_set, isl.dim_type.set) + + # Loopy domains are still restricted to a single BasicSet in this path. + footprint_isl = exact_footprint._reconstruct_isl_object() + loopy_footprint = nisl.make_set(isl.Set.from_basic_set(footprint_isl.convex_hull())) + + loopy_domain = named_domain & loopy_footprint + if len(loopy_domain.get_basic_sets()) != 1: + raise ValueError("New domain should be composed of a single basic set") + + return FootprintInfo(loopy_footprint=loopy_footprint, named_domain=loopy_domain) + + +def _build_compute_plan( + compute_map: nisl.Map, + named_domain: nisl.BasicSet, + sites: Sequence[UsageSite], + storage_indices: Sequence[str], + temporal_inames: Sequence[str], + inames_to_advance: Sequence[str] | Literal["auto"] | None, +) -> ComputePlan: + name_state = _make_name_state(compute_map, storage_indices) + usage_info = _build_usage_info( + named_domain, + name_state, + storage_indices, + temporal_inames, + sites, + ) + footprint_info = _build_footprint_info( + named_domain, + name_state, + usage_info, + storage_indices, + temporal_inames, + ) + + if inames_to_advance == "auto": + inames_to_advance = _detect_inames_to_advance( + name_state.internal_compute_map, + usage_info.global_usage_map, + footprint_info.loopy_footprint, + storage_indices, + temporal_inames, + ) + + reuse_relations = ( + None + if inames_to_advance is None + else _build_reuse_relations( + name_state.internal_compute_map, + usage_info.global_usage_map, + footprint_info.loopy_footprint, + footprint_info.named_domain, + storage_indices, + temporal_inames, + frozenset(inames_to_advance), + ) + ) + + return ComputePlan( + name_state=name_state, + usage_info=usage_info, + footprint_info=footprint_info, + storage_indices=tuple(storage_indices), + temporal_inames=tuple(temporal_inames), + reuse_relations=reuse_relations, + ) + + +def _build_reuse_relations( + compute_map: nisl.Map, + global_usage_map: nisl.Map, + footprint: nisl.Set, + 
named_domain: nisl.Set, + storage_indices: Sequence[str], + temporal_inames: Sequence[str], + advancing_set: frozenset[str], +) -> ReuseRelations: + predecessor_context = _build_predecessor_context(temporal_inames, advancing_set) + shift_relation = _build_shift_relation( + compute_map, + global_usage_map, + footprint, + storage_indices, + predecessor_context, + ) + reusable_footprint = shift_relation.range() + current_footprint = footprint.rename_dims({ + name: _cur_name(name) for name in footprint.names + }) + refill_footprint = current_footprint - reusable_footprint + + normal_names = {_cur_name(name): name for name in footprint.names} + reusable_footprint = reusable_footprint.rename_dims(normal_names) + refill_footprint = refill_footprint.rename_dims(normal_names) + + reusable_footprint = reusable_footprint.gist( + named_domain.project_out_except(reusable_footprint.names) + ) + refill_footprint = refill_footprint.gist( + named_domain.project_out_except(refill_footprint.names) + ) + + return ReuseRelations( + shift_relation=shift_relation, + reusable_footprint=reusable_footprint, + refill_footprint=refill_footprint, + ) + + +def _build_predecessor_context( + temporal_inames: Sequence[str], + advancing_set: frozenset[str], +) -> nisl.Map: + constraints = [ + ( + f"{_cur_name(name)} = {_prev_name(name)} + 1" + if name in advancing_set + else f"{_cur_name(name)} = {_prev_name(name)}" + ) + for name in temporal_inames + ] + + return nisl.make_map( + "{ [" + + ", ".join(_prev_name(name) for name in temporal_inames) + + "] -> [" + + ", ".join(_cur_name(name) for name in temporal_inames) + + "]" + + (f" : {' and '.join(constraints)}" if constraints else "") + + " }" + ) + + +def _build_shift_relation( + compute_map: nisl.Map, + global_usage_map: nisl.Map, + footprint: nisl.Set, + storage_indices: Sequence[str], + predecessor_context: nisl.Map, +) -> nisl.Map: + compute_map_cur = compute_map.rename_dims({ + name: _cur_name(name) for name in compute_map.output_names + }) 
# NOTE(review): the lines below are the tail of a function whose `def` lies
# above this view (it receives compute_map, global_usage_map, footprint,
# storage_indices, predecessor_context and compute_map_cur) — reproduced
# verbatim, do not edit in isolation.
    # Re-express the compute map with its output dims renamed to the
    # "previous iteration" naming scheme.
    compute_map_prev = compute_map.rename_dims({
        name: _prev_name(name) for name in compute_map.output_names
    })

    # Relate previous-iteration global points to current-iteration global
    # points through the usage map, restricted to the predecessor context.
    reuse_map = (
        global_usage_map
        .apply_range(compute_map_prev)
        .reverse()
        .apply_range(global_usage_map.apply_range(compute_map_cur))
    )
    reuse_map = reuse_map & predecessor_context

    current_footprint = footprint.rename_dims({
        name: _cur_name(name) for name in footprint.names
    })
    previous_footprint = footprint.rename_dims({
        name: _prev_name(name) for name in footprint.names
    })

    # Only points inside the (previous, current) footprints can be reused.
    reuse_map = reuse_map.intersect_domain(previous_footprint)
    reuse_map = reuse_map.intersect_range(current_footprint)

    # Subtract the identity so that "stays in the same storage slot" does not
    # count as a shift.
    return reuse_map - _identity_storage_map(footprint, storage_indices)


def _detect_inames_to_advance(
    compute_map: nisl.Map,
    global_usage_map: nisl.Map,
    footprint: nisl.Set,
    storage_indices: Sequence[str],
    temporal_inames: Sequence[str],
) -> tuple[str, ...]:
    """Return the temporal inames whose single-step advance yields a
    non-empty shift relation.

    Each candidate iname is tested by building the shift relation with a
    predecessor context that advances only that iname.  Raises
    :class:`ValueError` if more than one candidate is found; may return an
    empty tuple if none is.
    """
    candidates = []
    for name in temporal_inames:
        shift_relation = _build_shift_relation(
            compute_map,
            global_usage_map,
            footprint,
            storage_indices,
            _build_predecessor_context(temporal_inames, frozenset([name])),
        )
        # NOTE(review): reaches into a private namedisl API
        # (_reconstruct_isl_object) to get at isl's is_empty — confirm there
        # is no public emptiness check.
        if not shift_relation._reconstruct_isl_object().is_empty():
            candidates.append(name)

    if len(candidates) > 1:
        raise ValueError(
            "Could not infer a unique advancing iname. "
            f"Candidates are {candidates}; pass inames_to_advance explicitly."
        )

    return tuple(candidates)


def _identity_storage_map(
    footprint: nisl.Set,
    storage_indices: Sequence[str],
) -> nisl.Map:
    """Map previous-named footprint points to current-named points whose
    storage indices are equal (identity on the storage axes).

    NOTE(review): if *storage_indices* is empty the constraint string
    degenerates to "{ [...] -> [...] :  }" — presumably callers guarantee at
    least one storage index; verify.
    """
    return nisl.make_map(
        "{ ["
        + ", ".join(_prev_name(name) for name in footprint.names)
        + "] -> ["
        + ", ".join(_cur_name(name) for name in footprint.names)
        + "] : "
        + " and ".join(
            f"{_cur_name(name)} = {_prev_name(name)}" for name in storage_indices
        )
        + " }"
    )


def _make_shift_instructions(
    reuse_map: nisl.Map,
    reused_current: nisl.Set,
    storage_indices: Sequence[str],
    temporal_inames: Sequence[str],
    temporary_name: str,
    compute_insn_id: str,
) -> tuple[tuple[InstructionBase, ...], tuple[str, ...], frozenset[str]]:
    """Build the Assignment instructions that move reusable entries of
    *temporary_name* from their previous storage slot to the current one.

    Returns ``(instructions, instruction_ids, final_dependency_set)`` where
    the dependency set names the last shift instruction (or is empty if no
    predicate option produced one), for chaining by the caller.
    """
    # Project the reuse relation down to just the storage axes (prev and cur
    # names), then drop the cur_ prefix so output dims match the real
    # storage index names.
    storage_reuse_map = reuse_map.project_out_except(
        frozenset(_prev_name(name) for name in storage_indices)
        | frozenset(_cur_name(name) for name in storage_indices)
    )
    storage_reuse_map = storage_reuse_map.rename_dims({
        _cur_name(name): name for name in storage_indices
    })

    # For each current storage index, an expression giving the previous slot
    # it should be filled from.
    cur_to_prev_exprs = _map_to_named_output_exprs(storage_reuse_map.reverse())
    prev_storage_exprs = tuple(
        cur_to_prev_exprs[_prev_name(name)] for name in storage_indices
    )

    # temp[cur indices] = temp[prev indices]
    shift_assignee = var(temporary_name)[tuple(var(idx) for idx in storage_indices)]
    shift_expression = var(temporary_name)[prev_storage_exprs]

    update_insns = []
    update_ids = []
    current_deps: frozenset[str] = frozenset()
    shift_predicate_options = _set_to_predicate_options(reused_current)

    # One shift instruction per predicate option, serialized via depends_on
    # so each shift sees the previous one's result.
    for i, predicates in enumerate(shift_predicate_options):
        shift_insn_id = (
            f"{compute_insn_id}_shift"
            if len(shift_predicate_options) == 1
            else f"{compute_insn_id}_shift_{i}"
        )
        update_insns.append(
            Assignment(
                id=shift_insn_id,
                assignee=shift_assignee,
                expression=shift_expression,
                within_inames=frozenset(temporal_inames) | frozenset(storage_indices),
                predicates=predicates,
                depends_on=current_deps,
            )
        )
        update_ids.append(shift_insn_id)
        current_deps = frozenset([shift_insn_id])

    return tuple(update_insns), tuple(update_ids), current_deps


def _build_compute_instruction_info(
    kernel: LoopKernel,
    substitution: str,
    name_state: NameState,
) -> ComputeInstructionInfo:
    """Rewrite the substitution rule's body in terms of storage axes and
    collect the instruction ids it depends on.

    The compute map is inverted to obtain, for every storage axis, the global
    index expression it corresponds to; those expressions are substituted for
    the rule's arguments.
    """
    compute_map = name_state.internal_compute_map.rename_dims(
        name_state.renamed_to_original
    )
    compute_pw_aff = compute_map.reverse().as_pw_multi_aff()
    storage_axis_to_global_expr = {
        compute_pw_aff.get_dim_name(isl.dim_type.out, dim): pw_aff_to_expr(
            compute_pw_aff.get_at(dim)
        )
        for dim in range(compute_pw_aff.dim(isl.dim_type.out))
    }

    ctx = SubstitutionRuleMappingContext(
        kernel.substitutions, kernel.get_var_name_generator()
    )
    expr_subst_map = RuleAwareSubstitutionMapper(
        ctx,
        make_subst_func(storage_axis_to_global_expr),
        within=parse_stack_match(None),
    )

    # NOTE(review): the mapper's instruction argument is satisfied with a
    # double-cast None — presumably unused because within=... matches
    # everything; confirm against RuleAwareSubstitutionMapper.
    compute_expression = expr_subst_map(
        kernel.substitutions[substitution].expression,
        kernel,
        cast("InstructionBase", cast("object", None)),
    )

    # Depend on every instruction that writes a variable read by the
    # compute expression.
    dependencies = frozenset().union(
        *(
            kernel.writer_map().get(var_name, frozenset())
            for var_name in _gather_vars(compute_expression)
        )
    )

    return ComputeInstructionInfo(
        expression=compute_expression,
        dependencies=dependencies,
        within_inames=frozenset(compute_map.output_names),
    )


def _add_update_and_compute_instructions(
    kernel: LoopKernel,
    update_insns: Sequence[InstructionBase],
    update_ids: Sequence[str],
    refill_options: PredicateOptions,
    final_deps: frozenset[str],
    compute_info: ComputeInstructionInfo,
    storage_indices: Sequence[str],
    temporary_name: str,
    compute_insn_id: str,
) -> tuple[LoopKernel, tuple[str, ...]]:
    """Append the shift instructions plus one refill Assignment per
    predicate option, chaining dependencies so refills run after shifts.

    Returns the updated kernel and the full tuple of update/refill ids.
    """
    new_insns = [*kernel.instructions, *update_insns]
    update_ids = list(update_ids)
    current_deps = final_deps
    assignee = var(temporary_name)[tuple(var(idx) for idx in storage_indices)]

    for i, predicates in enumerate(refill_options):
        # Single option keeps the caller-visible compute_insn_id unsuffixed.
        refill_insn_id = (
            compute_insn_id
            if len(refill_options) == 1
            else f"{compute_insn_id}_refill_{i}"
        )
        new_insns.append(
            Assignment(
                id=refill_insn_id,
                assignee=assignee,
                expression=compute_info.expression,
                within_inames=compute_info.within_inames,
                predicates=predicates,
                depends_on=current_deps | compute_info.dependencies,
            )
        )
        update_ids.append(refill_insn_id)
        current_deps = frozenset([refill_insn_id])

    return kernel.copy(instructions=new_insns), tuple(update_ids)


def _add_temporary(
    kernel: LoopKernel,
    footprint: nisl.Set,
    storage_indices: Sequence[str],
    temporary_name: str,
    temporary_address_space: AddressSpace | None,
    temporary_dtype: ToLoopyTypeConvertible,
) -> LoopKernel:
    """Add a TemporaryVariable sized to cover *footprint* along the storage
    axes (shape = max - min + 1 per axis, with matching base_indices).
    """
    loopy_type = to_loopy_type(temporary_dtype, allow_none=True)
    # NOTE(review): dim_min/dim_max are passed the storage-axis *names* —
    # presumably namedisl accepts names where plain isl wants positions;
    # confirm.
    bounds = tuple(
        (pw_aff_to_expr(footprint.dim_min(dim)), pw_aff_to_expr(footprint.dim_max(dim)))
        for dim in storage_indices
    )
    base_indices = tuple(lower for lower, _upper in bounds)
    temp_shape = tuple(upper - lower + 1 for lower, upper in bounds)

    new_temp_vars = dict(kernel.temporary_variables)
    new_temp_vars[temporary_name] = TemporaryVariable(
        name=temporary_name,
        dtype=loopy_type,
        base_indices=base_indices,
        shape=temp_shape,
        address_space=temporary_address_space,
        dim_names=tuple(storage_indices),
    )
    return kernel.copy(temporary_variables=new_temp_vars)


def _lower_compute_plan(
    kernel: LoopKernel,
    substitution: str,
    plan: ComputePlan,
    domain_changer: DomainChanger,
    temporary_name: str,
    temporary_address_space: AddressSpace | None,
    temporary_dtype: ToLoopyTypeConvertible,
    compute_insn_id: str,
) -> LoopKernel:
    """Materialize *plan*: install the new domain, emit shift/refill
    instructions, rewrite rule invocations to read the temporary, and add the
    backing TemporaryVariable.
    """
    # NOTE(review): only the first basic set of the named domain is used —
    # confirm the footprint domain is always a single basic set.
    domain = plan.footprint_info.named_domain.get_basic_sets()[
        0
    ]._reconstruct_isl_object()
    kernel = kernel.copy(domains=domain_changer.get_domains_with(domain))

    update_insns: tuple[InstructionBase, ...] = ()
    update_insn_ids: tuple[str, ...] = ()
    refill_options: PredicateOptions = (None,)
    final_deps: frozenset[str] = frozenset()
    # Shift instructions are only needed when inter-iteration reuse exists.
    if plan.reuse_relations is not None:
        update_insns, update_insn_ids, final_deps = _make_shift_instructions(
            plan.reuse_relations.shift_relation,
            plan.reuse_relations.reusable_footprint,
            plan.storage_indices,
            plan.temporal_inames,
            temporary_name,
            compute_insn_id,
        )
        refill_options = _set_to_predicate_options(
            plan.reuse_relations.refill_footprint
        )

    compute_info = _build_compute_instruction_info(
        kernel, substitution, plan.name_state
    )
    kernel, update_insn_ids = _add_update_and_compute_instructions(
        kernel,
        update_insns,
        update_insn_ids,
        refill_options,
        final_deps,
        compute_info,
        plan.storage_indices,
        temporary_name,
        compute_insn_id,
    )

    # Replace every covered invocation of the substitution rule with a read
    # of the temporary.
    ctx = SubstitutionRuleMappingContext(
        kernel.substitutions, kernel.get_var_name_generator()
    )
    kernel = RuleInvocationReplacer(
        ctx,
        substitution,
        None,
        plan.usage_info.local_storage_maps,
        plan.storage_indices,
        temporary_name,
        update_insn_ids,
        plan.footprint_info.loopy_footprint,
    ).map_kernel(kernel)

    return _add_temporary(
        kernel,
        plan.footprint_info.loopy_footprint,
        plan.storage_indices,
        temporary_name,
        temporary_address_space,
        temporary_dtype,
    )


@for_each_kernel
def compute(
    kernel: LoopKernel,
    substitution: str,
    compute_map: nisl.Map,
    storage_indices: Sequence[str],
    temporal_inames: Sequence[str] | None = None,
    inames_to_advance: Sequence[str] | Literal["auto"] | None = None,
    temporary_name: str | None = None,
    temporary_address_space: AddressSpace | None = None,
    temporary_dtype: ToLoopyTypeConvertible = None,
    compute_insn_id: str | None = None,
) -> LoopKernel:
    """Compute a substitution into a temporary and replace covered uses.

    :arg substitution: name of the substitution rule to precompute.
    :arg compute_map: relation from the rule's arguments to storage/loop
        axes of the precomputed temporary.
    :arg storage_indices: output dims of *compute_map* that index the
        temporary.
    :arg temporal_inames: inames over which the temporary is carried;
        inferred from *compute_map* when not given.
    :arg temporary_name: defaults to ``<substitution>_temp``.
    :arg compute_insn_id: defaults to ``<substitution>_compute``.
    """
    temporary_name = temporary_name or f"{substitution}_temp"
    compute_insn_id = compute_insn_id or f"{substitution}_compute"
    if temporal_inames is None:
        temporal_inames = _infer_temporal_inames(compute_map, storage_indices)

    domain_changer = DomainChanger(kernel, kernel.all_inames())
    named_domain = nisl.make_basic_set(domain_changer.domain)

    plan = _build_compute_plan(
        compute_map,
        named_domain,
        _gather_usage_sites(kernel, substitution),
        tuple(storage_indices),
        temporal_inames,
        inames_to_advance,
    )
    return _lower_compute_plan(
        kernel,
        substitution,
        plan,
        domain_changer,
        temporary_name,
        temporary_address_space,
        temporary_dtype,
        compute_insn_id,
    )
diff --git a/test/test_compute.py b/test/test_compute.py
new file mode 100644
index 000000000..b5442133d
--- /dev/null
+++ b/test/test_compute.py
@@ -0,0 +1,101 @@
from __future__ import annotations

import namedisl as nisl
import numpy as np

import loopy as lp
# NOTE(review): this imports from loopy.transform.compute_stub while
# test_transform.py imports compute from loopy.transform.compute — confirm the
# stub module name is intentional.
from loopy.transform.compute_stub import _gather_usage_sites, compute
from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2


def test_compute_stub_simple_substitution_codegen() -> None:
    """Precompute a trivial identity substitution and check the generated
    device code reads through the temporary."""
    knl = lp.make_kernel(
        "{ [i] : 0 <= i < n }",
        """
        u_(is) := u[is]
        out[i] = u_(i)
        """,
        [
            lp.GlobalArg("u", shape=(16,), dtype=np.float32),
            lp.GlobalArg("out", shape=(16,), dtype=np.float32, is_output=True),
        ],
        lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2,
    )
    knl = lp.fix_parameters(knl, n=16)

    knl = compute(
        knl,
        "u_",
        compute_map=nisl.make_map("{ [is] -> [i_s] : is = i_s }"),
        storage_indices=["i_s"],
        temporal_inames=[],
        temporary_name="u_tmp",
        temporary_dtype=np.float32,
    )

    code = lp.generate_code_v2(knl).device_code()
    assert "float u_tmp[16]" in code
    assert "u_tmp[i_s] = u[i_s]" in code
    assert "out[i] = u_tmp[i]" in code


def test_compute_stub_repeated_substitution_uses_are_unique() -> None:
    """Two uses of the same rule in one instruction must be gathered as two
    distinct usage sites (keyed by (insn id, use index))."""
    knl = lp.make_kernel(
        "{ [i] : 0 <= i < n }",
        """
        u_(is) := u[is]
        out[i] = u_(i) + u_(i + 1) {id=write_out}
        """,
        [
            lp.GlobalArg("u", shape=(16,), dtype=np.float32),
            lp.GlobalArg("out", shape=(16,), dtype=np.float32, is_output=True),
        ],
        lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2,
    )
    knl = lp.fix_parameters(knl, n=16)

    sites = _gather_usage_sites(knl["loopy_kernel"], "u_")

    assert [site.key for site in sites] == [("write_out", 0), ("write_out", 1)]
    assert sites[0].args != sites[1].args


def test_compute_stub_ring_buffer_codegen() -> None:
    """Precompute a time history into a 2-entry ring buffer and check the
    generated shift + refill code."""
    ntime = 128
    block_size = 32
    knl = lp.make_kernel(
        "{ [t] : 1 <= t < ntime - 1 }",
        """
        u_hist(ts) := u[ts]
        u_next[t + 1] = 2*u_hist(t) - u_hist(t - 1)
        """,
        [
            lp.GlobalArg("u", dtype=np.float64, shape=(ntime,)),
            lp.GlobalArg("u_next", dtype=np.float64, shape=(ntime,), is_output=True),
        ],
        lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2,
    )
    knl = lp.fix_parameters(knl, ntime=ntime)
    knl = lp.split_iname(
        knl,
        "t",
        block_size,
        inner_iname="ti",
        outer_iname="to",
    )

    knl = compute(
        knl,
        "u_hist",
        compute_map=nisl.make_map("{ [ts] -> [to, ti, tb] : tb = 32*to + ti - ts }"),
        storage_indices=["tb"],
        inames_to_advance="auto",
        temporary_name="u_time_buf",
        temporary_dtype=np.float64,
    )

    code = lp.generate_code_v2(knl).device_code()
    assert "double u_time_buf[2]" in code
    assert "u_time_buf[tb] = u_time_buf[0]" in code
    assert "u_time_buf[tb] = u[-1 * tb + ti + 32 * to]" in code
    assert "u_next[1 + ti + 32 * to]" in code
diff --git a/test/test_transform.py b/test/test_transform.py
index b6f87c108..37edb61a6 100644
--- a/test/test_transform.py
+++ b/test/test_transform.py
@@ -20,6 +20,7 @@ THE SOFTWARE.
"""

from collections.abc import Mapping
import logging

import numpy as np
@@ -1745,6 +1746,108 @@ def test_duplicate_iname_not_read_only_nested(ctx_factory: cl.CtxFactory):
    lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit)


@pytest.mark.parametrize("case", (
    {"M": 128, "N": 128, "K": 128, "BM": 32, "BN": 32, "BK": 16},
    {"M": 200, "N": 200, "K": 200, "BM": 32, "BN": 32, "BK": 16},
))
def test_compute_simple_tiled_matmul(
    ctx_factory: cl.CtxFactory,
    case: Mapping[str, int]
    ):
    """Tile a matmul, precompute the A and B tiles into local memory via
    compute(), and verify against numpy.  The 200-case exercises partial
    (non-divisible) tiles."""

    import namedisl as nisl

    M = case["M"]
    N = case["N"]
    K = case["K"]
    bm = case["BM"]
    bn = case["BN"]
    bk = case["BK"]

    knl = lp.make_kernel(
        "{ [i, j, k] : 0 <= i < M and 0 <= j < N and 0 <= k < K }",
        """
        a_(is, ks) := a[is, ks]
        b_(ks, js) := b[ks, js]
        c[i, j] = sum([k], a_(i, k) * b_(k, j))
        """,
        [
            lp.GlobalArg("a", shape=(M, K), dtype=np.float64),
            lp.GlobalArg("b", shape=(K, N), dtype=np.float64),
            lp.GlobalArg("c", shape=(M, N), dtype=np.float64,
                is_output=True)
        ]
    )

    knl = lp.fix_parameters(knl, M=M, N=N, K=K)

    # shared memory tile-level splitting
    knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io")
    knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo")
    knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko")

    # Tile of A indexed by (ii_s, ki_s), carried over (io, ko).
    compute_map_a = nisl.make_map(f"""{{
        [is, ks] -> [ii_s, io, ki_s, ko] :
        is = io * {bm} + ii_s and
        ks = ko * {bk} + ki_s
    }}""")

    # Tile of B indexed by (ki_s, ji_s), carried over (ko, jo).
    compute_map_b = nisl.make_map(f"""{{
        [ks, js] -> [ki_s, ko, ji_s, jo] :
        js = jo * {bn} + ji_s and
        ks = ko * {bk} + ki_s
    }}""")

    from loopy.transform.compute import compute
    knl = compute(
        knl,
        "a_",
        compute_map=compute_map_a,
        storage_indices=["ii_s", "ki_s"],
        temporal_inames=["io", "ko", "jo"],
        temporary_address_space=lp.AddressSpace.LOCAL,
        temporary_dtype=np.float64
    )

    knl = compute(
        knl,
        "b_",
        compute_map=compute_map_b,
        storage_indices=["ki_s", "ji_s"],
        temporal_inames=["io", "ko", "jo"],
        temporary_address_space=lp.AddressSpace.LOCAL,
        temporary_dtype=np.float64
    )

    knl = lp.tag_inames(
        knl, {
            "io" : "g.0", # outer block loop over block rows
            "jo" : "g.1", # outer block loop over block cols

            "ii" : "l.0", # inner block loop over rows
            "ji" : "l.1", # inner block loop over cols

            "ii_s" : "l.0", # inner storage loop over a rows
            "ji_s" : "l.0", # inner storage loop over b cols
            "ki_s" : "l.1" # inner storage loop over a cols / b rows
        }
    )

    knl = lp.add_inames_for_unused_hw_axes(knl)

    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    a = np.random.randn(M, K)
    b = np.random.randn(K, N)

    ex = knl.executor(ctx)
    _, out = ex(queue, a=a, b=b)

    import numpy.linalg as la
    # Relative error check against the numpy reference product.
    assert (la.norm((a @ b) - out) / la.norm(a @ b)) < 1e-15


if __name__ == "__main__":
    import sys
    if len(sys.argv) > 1: