From 375066c4068c09ccf3c3c88147ab3fe5fc09d30e Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 20 Nov 2025 11:28:05 -0600 Subject: [PATCH 01/27] add compute in its current form --- loopy/target/c/compyte | 2 +- loopy/transform/compute.py | 204 +++++++++++++++++++++++++++++++++++++ 2 files changed, 205 insertions(+), 1 deletion(-) create mode 100644 loopy/transform/compute.py diff --git a/loopy/target/c/compyte b/loopy/target/c/compyte index 2b168ca39..955160ac2 160000 --- a/loopy/target/c/compyte +++ b/loopy/target/c/compyte @@ -1 +1 @@ -Subproject commit 2b168ca396aec2259da408f441f5e38ac9f95cb6 +Subproject commit 955160ac2f504dabcd8641471a56146fa1afe35d diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py new file mode 100644 index 000000000..4b325a7af --- /dev/null +++ b/loopy/transform/compute.py @@ -0,0 +1,204 @@ +# DomainChanger +# iname nesting order <=> tree +# loop transformations +# - traverse syntax tree +# - affine map inames +# +# index views for warp tiling + +from pymbolic.mapper.substitutor import make_subst_func +from loopy.kernel import LoopKernel +import islpy as isl + +import loopy as lp +from loopy.kernel.data import AddressSpace +from loopy.kernel.function_interface import CallableKernel, ScalarCallable +from loopy.kernel.instruction import MultiAssignmentBase +from loopy.kernel.tools import DomainChanger +from loopy.match import parse_stack_match +from loopy.symbolic import RuleAwareSubstitutionMapper, SubstitutionRuleMappingContext, aff_from_expr, aff_to_expr, pw_aff_to_expr +from loopy.transform.precompute import RuleInvocationGatherer, RuleInvocationReplacer, contains_a_subst_rule_invocation +from loopy.translation_unit import TranslationUnit + +import pymbolic.primitives as prim +from pymbolic import var + +from pytools.tag import Tag + + +def compute( + t_unit: TranslationUnit, +substitution: str, + *args, + **kwargs + ) -> TranslationUnit: + """ + Entrypoint for performing a compute transformation on all kernels in a + 
translation unit. See :func:`_compute_inner` for more details. + """ + + assert isinstance(t_unit, TranslationUnit) + new_callables = {} + + for id, callable in t_unit.callables_table.items(): + if isinstance(callable, CallableKernel): + kernel = _compute_inner( + callable.subkernel, + substitution, + *args, **kwargs + ) + + callable = callable.copy(subkernel=kernel) + elif isinstance(callable, ScalarCallable): + pass + else: + raise NotImplementedError() + + new_callables[id] = callable + + return t_unit + +def _compute_inner( + kernel: LoopKernel, + substitution: str, + transform_map: isl.Map, + compute_map: isl.Map, + storage_inames: list[str], + default_tag: Tag | str | None = None, + temporary_address_space: AddressSpace | None = None + ) -> LoopKernel: + """ + Inserts an instruction to compute an expression given by :arg:`substitution` + and replaces all invocations of :arg:`substitution` with the result of the + compute instruction. + + :arg substitution: The substitution rule for which the compute + transform should be applied. + + :arg transform_map: An :class:`isl.Map` representing the affine + transformation from the original iname domain to the transformed iname + domain. + + :arg compute_map: An :class:`isl.Map` representing a relation between + substitution rule indices and tuples `(a, l)`, where `a` is a vector of + storage indices and `l` is a vector of "timestamps". 
This map describes + """ + + if not temporary_address_space: + temporary_address_space = AddressSpace.GLOBAL + + # {{{ normalize names + + iname_to_storage_map = { + iname : (iname + "_store" if iname in kernel.all_inames() else iname) + for iname in storage_inames + } + + new_storage_axes = list(iname_to_storage_map.values()) + + for dim in range(compute_map.dim(isl.dim_type.out)): + for iname, storage_ax in iname_to_storage_map.items(): + if compute_map.get_dim_name(isl.dim_type.out, dim) == iname: + compute_map = compute_map.set_dim_name( + isl.dim_type.out, dim, storage_ax + ) + + # }}} + + # {{{ update kernel domain to contain storage inames + + storage_domain = compute_map.range().project_out_except( + new_storage_axes, [isl.dim_type.set] + ) + + # FIXME: likely need to do some more digging to find proper domain to update + new_domain = kernel.domains[0] + for ax in new_storage_axes: + new_domain = new_domain.add_dims(isl.dim_type.set, 1) + + new_domain = new_domain.set_dim_name( + isl.dim_type.set, + new_domain.dim(isl.dim_type.set) - 1, + ax + ) + + new_domain, storage_domain = isl.align_two(new_domain, storage_domain) + new_domain = new_domain & storage_domain + kernel = kernel.copy(domains=[new_domain]) + + # }}} + + # {{{ express substitution inputs as pw affs of (storage, time) names + + compute_pw_aff = compute_map.reverse().as_pw_multi_aff() + subst_inp_names = [ + compute_map.get_dim_name(isl.dim_type.in_, i) + for i in range(compute_map.dim(isl.dim_type.in_)) + ] + storage_ax_to_global_expr = dict.fromkeys(subst_inp_names) + for dim in range(compute_pw_aff.dim(isl.dim_type.out)): + subst_inp = compute_map.get_dim_name(isl.dim_type.in_, dim) + storage_ax_to_global_expr[subst_inp] = \ + pw_aff_to_expr(compute_pw_aff.get_at(dim)) + + # }}} + + # {{{ generate instruction from compute map + + rule_mapping_ctx = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + + expr_subst_map = RuleAwareSubstitutionMapper( + 
rule_mapping_ctx, + make_subst_func(storage_ax_to_global_expr), + within=parse_stack_match(None) + ) + + subst_expr = kernel.substitutions[substitution].expression + compute_expression = expr_subst_map(subst_expr, kernel, None) + + temporary_name = substitution + "_temp" + assignee = var(temporary_name)[tuple( + var(iname) for iname in new_storage_axes + )] + + compute_insn_id = substitution + "_compute" + compute_insn = lp.Assignment( + id=compute_insn_id, + assignee=assignee, + expression=compute_expression, + ) + + compute_dep_id = compute_insn_id + new_insns = [compute_insn] + + # add global sync if we are storing in global memory + if temporary_address_space == lp.AddressSpace.GLOBAL: + gbarrier_id = kernel.make_unique_instruction_id( + based_on=substitution + "_barrier" + ) + + from loopy.kernel.instruction import BarrierInstruction + barrier_insn = BarrierInstruction( + id=gbarrier_id, + depends_on=frozenset([compute_insn_id]), + synchronization_kind="global", + mem_kind="global" + ) + + compute_dep_id = gbarrier_id + + # }}} + + # {{{ replace substitution rule with newly created instruction + + # FIXME: get these properly (see `precompute`) + subst_name = substitution + subst_tag = None + within = None # do we need this? 
+ + + + # }}} + + return kernel From 745f841ca174d2cf97f1bcf0335330436e991b8f Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 20 Nov 2025 11:41:05 -0600 Subject: [PATCH 02/27] align compyte with inducer/main --- loopy/target/c/compyte | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/loopy/target/c/compyte b/loopy/target/c/compyte index 955160ac2..2b168ca39 160000 --- a/loopy/target/c/compyte +++ b/loopy/target/c/compyte @@ -1 +1 @@ -Subproject commit 955160ac2f504dabcd8641471a56146fa1afe35d +Subproject commit 2b168ca396aec2259da408f441f5e38ac9f95cb6 From 88966446dc586b0a0d2ef44b9f410bc330774f21 Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 20 Nov 2025 11:49:42 -0600 Subject: [PATCH 03/27] clean up comments and typos --- loopy/transform/compute.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 4b325a7af..e1e54fa2d 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -1,27 +1,19 @@ -# DomainChanger -# iname nesting order <=> tree -# loop transformations -# - traverse syntax tree -# - affine map inames -# -# index views for warp tiling - -from pymbolic.mapper.substitutor import make_subst_func -from loopy.kernel import LoopKernel import islpy as isl import loopy as lp +from loopy.kernel import LoopKernel from loopy.kernel.data import AddressSpace from loopy.kernel.function_interface import CallableKernel, ScalarCallable -from loopy.kernel.instruction import MultiAssignmentBase -from loopy.kernel.tools import DomainChanger from loopy.match import parse_stack_match -from loopy.symbolic import RuleAwareSubstitutionMapper, SubstitutionRuleMappingContext, aff_from_expr, aff_to_expr, pw_aff_to_expr -from loopy.transform.precompute import RuleInvocationGatherer, RuleInvocationReplacer, contains_a_subst_rule_invocation +from loopy.symbolic import ( + RuleAwareSubstitutionMapper, + SubstitutionRuleMappingContext, + 
pw_aff_to_expr +) from loopy.translation_unit import TranslationUnit -import pymbolic.primitives as prim from pymbolic import var +from pymbolic.mapper.substitutor import make_subst_func from pytools.tag import Tag @@ -81,7 +73,7 @@ def _compute_inner( :arg compute_map: An :class:`isl.Map` representing a relation between substitution rule indices and tuples `(a, l)`, where `a` is a vector of - storage indices and `l` is a vector of "timestamps". This map describes + storage indices and `l` is a vector of "timestamps". """ if not temporary_address_space: From 6230e11ff78e97bd074a19910c9094f38cbf0a69 Mon Sep 17 00:00:00 2001 From: Addison Date: Wed, 10 Dec 2025 00:46:09 -0600 Subject: [PATCH 04/27] switch to alpha namedisl usage --- loopy/transform/compute.py | 119 +++++++------------------------------ 1 file changed, 22 insertions(+), 97 deletions(-) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index e1e54fa2d..cb01401db 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -1,16 +1,18 @@ import islpy as isl +import namedisl as nisl import loopy as lp from loopy.kernel import LoopKernel from loopy.kernel.data import AddressSpace -from loopy.kernel.function_interface import CallableKernel, ScalarCallable +from loopy.kernel.instruction import MultiAssignmentBase from loopy.match import parse_stack_match from loopy.symbolic import ( RuleAwareSubstitutionMapper, SubstitutionRuleMappingContext, pw_aff_to_expr ) -from loopy.translation_unit import TranslationUnit +from loopy.transform.precompute import contains_a_subst_rule_invocation +from loopy.translation_unit import for_each_kernel from pymbolic import var from pymbolic.mapper.substitutor import make_subst_func @@ -18,43 +20,11 @@ from pytools.tag import Tag +@for_each_kernel def compute( - t_unit: TranslationUnit, -substitution: str, - *args, - **kwargs - ) -> TranslationUnit: - """ - Entrypoint for performing a compute transformation on all kernels in a - translation 
unit. See :func:`_compute_inner` for more details. - """ - - assert isinstance(t_unit, TranslationUnit) - new_callables = {} - - for id, callable in t_unit.callables_table.items(): - if isinstance(callable, CallableKernel): - kernel = _compute_inner( - callable.subkernel, - substitution, - *args, **kwargs - ) - - callable = callable.copy(subkernel=kernel) - elif isinstance(callable, ScalarCallable): - pass - else: - raise NotImplementedError() - - new_callables[id] = callable - - return t_unit - -def _compute_inner( kernel: LoopKernel, substitution: str, - transform_map: isl.Map, - compute_map: isl.Map, + compute_map: isl.Map | nisl.Map, storage_inames: list[str], default_tag: Tag | str | None = None, temporary_address_space: AddressSpace | None = None @@ -67,14 +37,12 @@ def _compute_inner( :arg substitution: The substitution rule for which the compute transform should be applied. - :arg transform_map: An :class:`isl.Map` representing the affine - transformation from the original iname domain to the transformed iname - domain. - :arg compute_map: An :class:`isl.Map` representing a relation between substitution rule indices and tuples `(a, l)`, where `a` is a vector of storage indices and `l` is a vector of "timestamps". 
""" + if isinstance(compute_map, isl.Map): + compute_map = nisl.make_map(compute_map) if not temporary_address_space: temporary_address_space = AddressSpace.GLOBAL @@ -86,52 +54,29 @@ def _compute_inner( for iname in storage_inames } - new_storage_axes = list(iname_to_storage_map.values()) - - for dim in range(compute_map.dim(isl.dim_type.out)): - for iname, storage_ax in iname_to_storage_map.items(): - if compute_map.get_dim_name(isl.dim_type.out, dim) == iname: - compute_map = compute_map.set_dim_name( - isl.dim_type.out, dim, storage_ax - ) + compute_map = compute_map.rename_dims(iname_to_storage_map) # }}} # {{{ update kernel domain to contain storage inames - storage_domain = compute_map.range().project_out_except( - new_storage_axes, [isl.dim_type.set] - ) + new_storage_axes = list(iname_to_storage_map.values()) - # FIXME: likely need to do some more digging to find proper domain to update + # FIXME: use DomainChanger to add domain to kernel + storage_domain = compute_map.range().project_out_except(new_storage_axes) new_domain = kernel.domains[0] - for ax in new_storage_axes: - new_domain = new_domain.add_dims(isl.dim_type.set, 1) - - new_domain = new_domain.set_dim_name( - isl.dim_type.set, - new_domain.dim(isl.dim_type.set) - 1, - ax - ) - - new_domain, storage_domain = isl.align_two(new_domain, storage_domain) - new_domain = new_domain & storage_domain - kernel = kernel.copy(domains=[new_domain]) # }}} # {{{ express substitution inputs as pw affs of (storage, time) names compute_pw_aff = compute_map.reverse().as_pw_multi_aff() - subst_inp_names = [ - compute_map.get_dim_name(isl.dim_type.in_, i) - for i in range(compute_map.dim(isl.dim_type.in_)) - ] - storage_ax_to_global_expr = dict.fromkeys(subst_inp_names) - for dim in range(compute_pw_aff.dim(isl.dim_type.out)): - subst_inp = compute_map.get_dim_name(isl.dim_type.in_, dim) - storage_ax_to_global_expr[subst_inp] = \ - pw_aff_to_expr(compute_pw_aff.get_at(dim)) + + # FIXME: remove PwAff._obj usage when 
ready + storage_ax_to_global_expr = { + dim_name : pw_aff_to_expr(compute_pw_aff.get_at(dim_name)._obj) + for dim_name in compute_map.dim_type_names(isl.dim_type.in_) + } # }}} @@ -161,34 +106,14 @@ def _compute_inner( expression=compute_expression, ) - compute_dep_id = compute_insn_id - new_insns = [compute_insn] - - # add global sync if we are storing in global memory - if temporary_address_space == lp.AddressSpace.GLOBAL: - gbarrier_id = kernel.make_unique_instruction_id( - based_on=substitution + "_barrier" - ) - - from loopy.kernel.instruction import BarrierInstruction - barrier_insn = BarrierInstruction( - id=gbarrier_id, - depends_on=frozenset([compute_insn_id]), - synchronization_kind="global", - mem_kind="global" - ) - - compute_dep_id = gbarrier_id - # }}} # {{{ replace substitution rule with newly created instruction - # FIXME: get these properly (see `precompute`) - subst_name = substitution - subst_tag = None - within = None # do we need this? - + for insn in kernel.instructions: + if contains_a_subst_rule_invocation(kernel, insn) \ + and isinstance(insn, MultiAssignmentBase): + print(insn) # }}} From 80839bad239b9833b7804ab2f750a7bd03fb38a8 Mon Sep 17 00:00:00 2001 From: Addison Date: Wed, 10 Dec 2025 12:40:49 -0600 Subject: [PATCH 05/27] start using namedisl in places other than compute --- loopy/symbolic.py | 37 ++++++++++++++++++++++++++----------- loopy/transform/compute.py | 2 +- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index ba6d71a80..442eb8572 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -48,6 +48,7 @@ from constantdict import constantdict from typing_extensions import Self, override +import namedisl as nisl import islpy as isl import pymbolic.primitives as p import pytools.lex @@ -2044,23 +2045,37 @@ def map_subscript(self, expr: p.Subscript) -> Set[p.Subscript]: # {{{ (pw)aff to expr conversion -def aff_to_expr(aff: isl.Aff) -> ArithmeticExpression: +def 
aff_to_expr(aff: isl.Aff | nisl.Aff) -> ArithmeticExpression: from pymbolic import var denom = aff.get_denominator_val().to_python() - result = (aff.get_constant_val()*denom).to_python() - for dt in [isl.dim_type.in_, isl.dim_type.param]: - for i in range(aff.dim(dt)): - coeff = (aff.get_coefficient_val(dt, i)*denom).to_python() + if isinstance(aff, isl.Aff): + for dt in [isl.dim_type.in_, isl.dim_type.param]: + for i in range(aff.dim(dt)): + coeff = (aff.get_coefficient_val(dt, i)*denom).to_python() + if coeff: + dim_name = not_none(aff.get_dim_name(dt, i)) + result += coeff*var(dim_name) + + for i in range(aff.dim(isl.dim_type.div)): + coeff = (aff.get_coefficient_val(isl.dim_type.div, i)*denom).to_python() + if coeff: + result += coeff*aff_to_expr(aff.get_div(i)) + + else: + in_names = set(aff.dim_type_names(isl.dim_type.in_)) + param_names = set(aff.dim_type_names(isl.dim_type.param)) + + for name in in_names | param_names: + coeff = (aff.get_coefficient_val(name) * denom).to_python() if coeff: - dim_name = not_none(aff.get_dim_name(dt, i)) - result += coeff*var(dim_name) + result = coeff * var(name) - for i in range(aff.dim(isl.dim_type.div)): - coeff = (aff.get_coefficient_val(isl.dim_type.div, i)*denom).to_python() - if coeff: - result += coeff*aff_to_expr(aff.get_div(i)) + for name in aff.dim_type_names(isl.dim_type.div): + coeff = (aff.get_coefficient_val(name) * denom).to_python() + if coeff: + result += coeff * aff_to_expr(aff.get_div(name)) assert not isinstance(result, complex) return flatten(result // denom) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index cb01401db..b3e06f2e9 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -74,7 +74,7 @@ def compute( # FIXME: remove PwAff._obj usage when ready storage_ax_to_global_expr = { - dim_name : pw_aff_to_expr(compute_pw_aff.get_at(dim_name)._obj) + dim_name : pw_aff_to_expr(compute_pw_aff.get_at(dim_name)) for dim_name in 
compute_map.dim_type_names(isl.dim_type.in_) } From c1ba35bb2d20f093f86df785378b63b1092ba0dd Mon Sep 17 00:00:00 2001 From: Addison Date: Wed, 10 Dec 2025 12:43:44 -0600 Subject: [PATCH 06/27] add namedisl objects to a type signature --- loopy/symbolic.py | 3 ++- loopy/transform/compute.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 442eb8572..f47e32f9d 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -2048,6 +2048,7 @@ def map_subscript(self, expr: p.Subscript) -> Set[p.Subscript]: def aff_to_expr(aff: isl.Aff | nisl.Aff) -> ArithmeticExpression: from pymbolic import var + # FIXME: remove this once namedisl is the standard in loopy denom = aff.get_denominator_val().to_python() result = (aff.get_constant_val()*denom).to_python() if isinstance(aff, isl.Aff): @@ -2082,7 +2083,7 @@ def aff_to_expr(aff: isl.Aff | nisl.Aff) -> ArithmeticExpression: def pw_aff_to_expr( - pw_aff: int | isl.PwAff | isl.Aff, + pw_aff: int | isl.PwAff | isl.Aff | nisl.PwAff | nisl.Aff, int_ok: bool = False ) -> ArithmeticExpression: if isinstance(pw_aff, int): diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index b3e06f2e9..59ddf8a2e 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -72,7 +72,6 @@ def compute( compute_pw_aff = compute_map.reverse().as_pw_multi_aff() - # FIXME: remove PwAff._obj usage when ready storage_ax_to_global_expr = { dim_name : pw_aff_to_expr(compute_pw_aff.get_at(dim_name)) for dim_name in compute_map.dim_type_names(isl.dim_type.in_) From 7265536a3a1fc4cd24883946876b4f16a12bcac2 Mon Sep 17 00:00:00 2001 From: Addison Date: Mon, 16 Mar 2026 09:22:53 -0500 Subject: [PATCH 07/27] compute transform up to and including instruction creation + insertion; missing invocation replacement --- loopy/transform/compute.py | 176 ++++++++++++++++++++++++------------- 1 file changed, 116 insertions(+), 60 deletions(-) diff --git a/loopy/transform/compute.py 
b/loopy/transform/compute.py index 59ddf8a2e..5c3c02130 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -1,32 +1,55 @@ -import islpy as isl +import loopy as lp +from loopy.kernel.tools import DomainChanger import namedisl as nisl -import loopy as lp from loopy.kernel import LoopKernel from loopy.kernel.data import AddressSpace -from loopy.kernel.instruction import MultiAssignmentBase from loopy.match import parse_stack_match from loopy.symbolic import ( + RuleAwareIdentityMapper, RuleAwareSubstitutionMapper, SubstitutionRuleMappingContext, - pw_aff_to_expr + pw_aff_to_expr, + pwaff_from_expr +) +from loopy.transform.precompute import ( + RuleInvocationGatherer, + contains_a_subst_rule_invocation ) -from loopy.transform.precompute import contains_a_subst_rule_invocation from loopy.translation_unit import for_each_kernel - from pymbolic import var from pymbolic.mapper.substitutor import make_subst_func -from pytools.tag import Tag +import islpy as isl +import pymbolic.primitives as p +from pymbolic.mapper.dependency import DependencyMapper + +from pymbolic.mapper import IdentityMapper + + +def gather_vars(expr) -> set[str]: + deps = DependencyMapper()(expr) + return { + dep.name + for dep in deps + if isinstance(dep, p.Variable) + } +def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT): + names = sorted(set().union(*(gather_vars(expr) for expr in exprs))) + set_names = [name for name in names] + + return isl.Space.create_from_names( + ctx, + set=set_names + ) @for_each_kernel def compute( kernel: LoopKernel, substitution: str, - compute_map: isl.Map | nisl.Map, - storage_inames: list[str], - default_tag: Tag | str | None = None, + compute_map: nisl.Map, + storage_indices: frozenset[str], temporary_address_space: AddressSpace | None = None ) -> LoopKernel: """ @@ -40,52 +63,86 @@ def compute( :arg compute_map: An :class:`isl.Map` representing a relation between substitution rule indices and tuples `(a, l)`, where `a` is a vector of storage 
indices and `l` is a vector of "timestamps". - """ - if isinstance(compute_map, isl.Map): - compute_map = nisl.make_map(compute_map) - - if not temporary_address_space: - temporary_address_space = AddressSpace.GLOBAL - - # {{{ normalize names - - iname_to_storage_map = { - iname : (iname + "_store" if iname in kernel.all_inames() else iname) - for iname in storage_inames - } - - compute_map = compute_map.rename_dims(iname_to_storage_map) - - # }}} - # {{{ update kernel domain to contain storage inames - - new_storage_axes = list(iname_to_storage_map.values()) - - # FIXME: use DomainChanger to add domain to kernel - storage_domain = compute_map.range().project_out_except(new_storage_axes) - new_domain = kernel.domains[0] - - # }}} + :arg storage_indices: A :class:`frozenset` of names of storage indices. Used + to create inames for the loops that cover the required footprint. + """ + compute_map = compute_map._reconstruct_isl_object() - # {{{ express substitution inputs as pw affs of (storage, time) names + # construct union of usage footprints to determine bounds on compute inames + ctx = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + inv_gatherer = RuleInvocationGatherer( + ctx, kernel, substitution, None, parse_stack_match(None) + ) + for insn in kernel.instructions: + if (isinstance(insn, lp.MultiAssignmentBase) and + contains_a_subst_rule_invocation(kernel, insn)): + for assignee in insn.assignees: + _ = inv_gatherer(assignee, kernel, insn) + _ = inv_gatherer(insn.expression, kernel, insn) + + access_descriptors = inv_gatherer.access_descriptors + + acc_desc_exprs = [ + arg + for ad in access_descriptors + if ad.args is not None + for arg in ad.args + ] + + space = space_from_exprs(acc_desc_exprs) + + footprint = isl.Set.empty(isl.Space.create_from_names( + ctx=space.get_ctx(), + set=list(storage_indices) + )) + for ad in access_descriptors: + if not ad.args: + continue + + nout = len(ad.args) + + range_space = 
isl.Space.alloc(space.get_ctx(), 0, nout, 0).domain() + map_space = space.map_from_domain_and_range(range_space) + pw_multi_aff = isl.MultiPwAff.zero(map_space) + + for i, arg in enumerate(ad.args): + if arg is not None: + pw_multi_aff = pw_multi_aff.set_pw_aff( + i, + pwaff_from_expr(space, arg) + ) + + usage_map = pw_multi_aff.as_map() + iname_to_timespace = usage_map.apply_range(compute_map).coalesce() + iname_to_storage = iname_to_timespace.project_out_except( + storage_indices, [isl.dim_type.out] + ) + + footprint = footprint | iname_to_storage.range() + + # add compute inames to domain / kernel + domain_changer = DomainChanger(kernel, kernel.all_inames()) + domain = domain_changer.domain + + footprint, domain = isl.align_two(footprint, domain) + domain = domain & footprint + + new_domains = domain_changer.get_domains_with(domain) + kernel = kernel.copy(domains=new_domains) + + # create compute instruction in kernel compute_pw_aff = compute_map.reverse().as_pw_multi_aff() - storage_ax_to_global_expr = { - dim_name : pw_aff_to_expr(compute_pw_aff.get_at(dim_name)) - for dim_name in compute_map.dim_type_names(isl.dim_type.in_) + compute_pw_aff.get_dim_name(isl.dim_type.out, dim) : + pw_aff_to_expr(compute_pw_aff.get_at(dim)) + for dim in range(compute_pw_aff.dim(isl.dim_type.out)) } - # }}} - - # {{{ generate instruction from compute map - - rule_mapping_ctx = SubstitutionRuleMappingContext( - kernel.substitutions, kernel.get_var_name_generator()) - expr_subst_map = RuleAwareSubstitutionMapper( - rule_mapping_ctx, + ctx, make_subst_func(storage_ax_to_global_expr), within=parse_stack_match(None) ) @@ -95,26 +152,25 @@ def compute( temporary_name = substitution + "_temp" assignee = var(temporary_name)[tuple( - var(iname) for iname in new_storage_axes + var(iname) for iname in storage_indices )] + within_inames = frozenset( + compute_map.get_dim_name(isl.dim_type.out, dim) + for dim in range(compute_map.dim(isl.dim_type.out)) + ) + compute_insn_id = substitution + 
"_compute" compute_insn = lp.Assignment( id=compute_insn_id, assignee=assignee, expression=compute_expression, + within_inames=within_inames ) - # }}} - - # {{{ replace substitution rule with newly created instruction - - for insn in kernel.instructions: - if contains_a_subst_rule_invocation(kernel, insn) \ - and isinstance(insn, MultiAssignmentBase): - print(insn) - - - # }}} + new_insns = list(kernel.instructions) + new_insns.append(compute_insn) + kernel = kernel.copy(instructions=new_insns) + print(kernel) return kernel From 56af4fe810ca2c7a9813cb33a51f7e45f08498dc Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 17 Mar 2026 17:43:06 -0500 Subject: [PATCH 08/27] invocation replacement; dependencies still need handling --- loopy/transform/compute.py | 291 +++++++++++++++++++++++++++++++------ 1 file changed, 248 insertions(+), 43 deletions(-) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 5c3c02130..02df1eaa7 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -1,30 +1,35 @@ +from collections.abc import Sequence, Set +from dataclasses import dataclass +from typing import override import loopy as lp from loopy.kernel.tools import DomainChanger import namedisl as nisl from loopy.kernel import LoopKernel -from loopy.kernel.data import AddressSpace -from loopy.match import parse_stack_match +from loopy.kernel.data import AddressSpace, SubstitutionRule +from loopy.match import StackMatch, parse_stack_match from loopy.symbolic import ( + ExpansionState, RuleAwareIdentityMapper, RuleAwareSubstitutionMapper, + SubstitutionRuleExpander, SubstitutionRuleMappingContext, + get_dependencies, pw_aff_to_expr, pwaff_from_expr ) from loopy.transform.precompute import ( - RuleInvocationGatherer, contains_a_subst_rule_invocation ) from loopy.translation_unit import for_each_kernel -from pymbolic import var +from pymbolic import ArithmeticExpression, var from pymbolic.mapper.substitutor import make_subst_func import islpy as isl 
import pymbolic.primitives as p from pymbolic.mapper.dependency import DependencyMapper - -from pymbolic.mapper import IdentityMapper +from pymbolic.typing import Expression +from pytools.tag import Tag def gather_vars(expr) -> set[str]: @@ -35,6 +40,7 @@ def gather_vars(expr) -> set[str]: if isinstance(dep, p.Variable) } + def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT): names = sorted(set().union(*(gather_vars(expr) for expr in exprs))) set_names = [name for name in names] @@ -44,12 +50,181 @@ def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT): set=set_names ) + +@dataclass(frozen=True) +class UsageDescriptor: + usage: Sequence[Expression] + global_map: isl.Map + local_map: isl.Map + + @override + def __str__(self): + return ( + f"USAGE = {self.usage}\n" + + f"GLOBAL MAP = {self.global_map}\n" + + f"LOCAL MAP = {self.local_map}" + ) + + +class UsageSiteExpressionGatherer(RuleAwareIdentityMapper[[]]): + """ + Gathers all expressions used as inputs to a particular substitution rule, + identified by name. 
+ """ + def __init__( + self, + rule_mapping_ctx: SubstitutionRuleMappingContext, + subst_expander: SubstitutionRuleExpander, + kernel: LoopKernel, + subst_name: str, + subst_tag: Set[Tag] | Tag | None = None + ) -> None: + + super().__init__(rule_mapping_ctx) + + self.subst_expander: SubstitutionRuleExpander = subst_expander + self.kernel: LoopKernel = kernel + self.subst_name: str = subst_name + self.subst_tag: Set[Tag] | None = ( + {subst_tag} if isinstance(subst_tag, Tag) else subst_tag + ) + + self.usage_expressions: list[Sequence[Expression]] = [] + + + @override + def map_subst_rule( + self, + name: str, + tags: Set[Tag] | None, + arguments: Sequence[Expression], + expn_state: ExpansionState, + ) -> Expression: + + if name != self.subst_name: + return super().map_subst_rule( + name, tags, arguments, expn_state + ) + + if self.subst_tag is not None and self.subst_tag != tags: + return super().map_subst_rule( + name, tags, arguments, expn_state + ) + + rule = self.rule_mapping_context.old_subst_rules[name] + arg_ctx = self.make_new_arg_context( + name, rule.arguments, arguments, expn_state.arg_context + ) + + self.usage_expressions.append([ + arg_ctx[arg_name] for arg_name in rule.arguments + ]) + + return 0 + + +class RuleInvocationReplacer(RuleAwareIdentityMapper[[]]): + def __init__( + self, + ctx: SubstitutionRuleMappingContext, + subst_name: str, + subst_tag: Sequence[Tag] | None, + usage_descriptors: Sequence[UsageDescriptor], + storage_indices: Sequence[str], + temporary_name: str, + compute_insn_id: str, + compute_map: isl.Map + ) -> None: + + super().__init__(ctx) + + self.subst_name: str = subst_name + self.subst_tag: Sequence[Tag] | None = subst_tag + + self.usage_descriptors: Sequence[UsageDescriptor] = usage_descriptors + self.storage_indices: Sequence[str] = storage_indices + + self.temporary_name: str = temporary_name + self.compute_insn_id: str = compute_insn_id + + + @override + def map_subst_rule( + self, + name: str, + tags: Set[Tag] | None, 
+ arguments: Sequence[Expression], + expn_state: ExpansionState + ) -> Expression: + + if not name == self.subst_name: + return super().map_subst_rule(name, tags, arguments, expn_state) + + rule = self.rule_mapping_context.old_subst_rules[name] + arg_ctx = self.make_new_arg_context( + name, rule.arguments, arguments, expn_state.arg_context + ) + args = [arg_ctx[arg_name] for arg_name in rule.arguments] + + # FIXME: footprint check? likely required if user supplies bounds on + # storage indices because we are not guaranteed to capture the footprint + # of all usage sites + + if not len(arguments) == len(rule.arguments): + raise ValueError("Number of arguments passed to rule {name} ", + "does not match the signature of {name}.") + + index_exprs: Sequence[Expression] = [] + for usage_descr in self.usage_descriptors: + if args == usage_descr.usage: + local_pwmaff = usage_descr.local_map.as_pw_multi_aff() + + for dim in range(local_pwmaff.dim(isl.dim_type.out)): + index_exprs.append(pw_aff_to_expr(local_pwmaff.get_at(dim))) + + break + + new_expression = var(self.temporary_name)[tuple(index_exprs)] + + return new_expression + + + @override + def map_kernel( + self, + kernel: LoopKernel, + within: StackMatch = lambda knl, insn, stack: True, + map_args: bool = True, + map_tvs: bool = True + ) -> LoopKernel: + + new_insns = [] + for insn in kernel.instructions: + if (isinstance(insn, lp.MultiAssignmentBase) and not + contains_a_subst_rule_invocation(kernel, insn)): + new_insns.append(insn) + continue + + insn = insn.with_transformed_expressions( + lambda expr: self(expr, kernel, insn) + ) + + new_insns.append(insn) + + return kernel.copy(instructions=new_insns) + + @for_each_kernel def compute( kernel: LoopKernel, substitution: str, compute_map: nisl.Map, - storage_indices: frozenset[str], + storage_indices: Sequence[str], + + # NOTE: how can we deduce this? 
+ temporal_inames: Sequence[str], + + temporary_name: str | None = None, temporary_address_space: AddressSpace | None = None ) -> LoopKernel: """ @@ -64,65 +239,76 @@ def compute( substitution rule indices and tuples `(a, l)`, where `a` is a vector of storage indices and `l` is a vector of "timestamps". - :arg storage_indices: A :class:`frozenset` of names of storage indices. Used - to create inames for the loops that cover the required footprint. + :arg storage_indices: An ordered sequence of names of storage indices. Used + to create inames for the loops that cover the required set of compute points. """ compute_map = compute_map._reconstruct_isl_object() # construct union of usage footprints to determine bounds on compute inames ctx = SubstitutionRuleMappingContext( kernel.substitutions, kernel.get_var_name_generator()) - inv_gatherer = RuleInvocationGatherer( - ctx, kernel, substitution, None, parse_stack_match(None) + expander = SubstitutionRuleExpander(kernel.substitutions) + expr_gatherer = UsageSiteExpressionGatherer( + ctx, expander, kernel, substitution, None ) - for insn in kernel.instructions: - if (isinstance(insn, lp.MultiAssignmentBase) and - contains_a_subst_rule_invocation(kernel, insn)): - for assignee in insn.assignees: - _ = inv_gatherer(assignee, kernel, insn) - _ = inv_gatherer(insn.expression, kernel, insn) + _ = expr_gatherer.map_kernel(kernel) + usage_exprs = expr_gatherer.usage_expressions - access_descriptors = inv_gatherer.access_descriptors - - acc_desc_exprs = [ - arg - for ad in access_descriptors - if ad.args is not None - for arg in ad.args + all_exprs = [ + expr + for usage in usage_exprs + for expr in usage ] - space = space_from_exprs(acc_desc_exprs) + space = space_from_exprs(all_exprs) - footprint = isl.Set.empty(isl.Space.create_from_names( - ctx=space.get_ctx(), - set=list(storage_indices) - )) - for ad in access_descriptors: - if not ad.args: - continue + footprint = isl.Set.empty( + isl.Space.create_from_names( + 
ctx=space.get_ctx(), + set=list(storage_indices) + ) + ) - nout = len(ad.args) + usage_descrs: Sequence[UsageDescriptor] = [] + for usage in usage_exprs: - range_space = isl.Space.alloc(space.get_ctx(), 0, nout, 0).domain() + range_space = isl.Space.create_from_names( + ctx=space.get_ctx(), + set=list(storage_indices) + ) map_space = space.map_from_domain_and_range(range_space) + pw_multi_aff = isl.MultiPwAff.zero(map_space) - for i, arg in enumerate(ad.args): - if arg is not None: - pw_multi_aff = pw_multi_aff.set_pw_aff( - i, - pwaff_from_expr(space, arg) - ) + for i, arg in enumerate(usage): + pw_multi_aff = pw_multi_aff.set_pw_aff( + i, + pwaff_from_expr(space, arg) + ) usage_map = pw_multi_aff.as_map() - iname_to_timespace = usage_map.apply_range(compute_map).coalesce() + + iname_to_timespace = usage_map.apply_range(compute_map) iname_to_storage = iname_to_timespace.project_out_except( storage_indices, [isl.dim_type.out] ) + local_map = iname_to_storage.project_out_except( + kernel.all_inames() - frozenset(temporal_inames), + [isl.dim_type.in_] + ) + footprint = footprint | iname_to_storage.range() + usage_descrs.append( + UsageDescriptor( + usage, + iname_to_storage, + local_map + ) + ) + # add compute inames to domain / kernel domain_changer = DomainChanger(kernel, kernel.all_inames()) domain = domain_changer.domain @@ -138,7 +324,7 @@ def compute( storage_ax_to_global_expr = { compute_pw_aff.get_dim_name(isl.dim_type.out, dim) : pw_aff_to_expr(compute_pw_aff.get_at(dim)) - for dim in range(compute_pw_aff.dim(isl.dim_type.out)) + for dim in range(compute_pw_aff.dim(isl.dim_type.out)) } expr_subst_map = RuleAwareSubstitutionMapper( @@ -150,7 +336,9 @@ def compute( subst_expr = kernel.substitutions[substitution].expression compute_expression = expr_subst_map(subst_expr, kernel, None) - temporary_name = substitution + "_temp" + if not temporary_name: + temporary_name = substitution + "_temp" + assignee = var(temporary_name)[tuple( var(iname) for iname in 
storage_indices )] @@ -172,5 +360,22 @@ def compute( new_insns.append(compute_insn) kernel = kernel.copy(instructions=new_insns) + ctx = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator() + ) + + replacer = RuleInvocationReplacer( + ctx, + substitution, + None, + usage_descrs, + storage_indices, + temporary_name, + compute_insn_id, + compute_map + ) + + kernel = replacer.map_kernel(kernel) + print(kernel) return kernel From 9d121834435cd1a7ff7bf95805bcef13be6d356e Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 17 Mar 2026 22:59:58 -0500 Subject: [PATCH 09/27] rough sketch of compute transform; inames not schedulable because of duplicates --- loopy/transform/compute.py | 100 +++++++++++++++++++++++-------------- 1 file changed, 63 insertions(+), 37 deletions(-) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 02df1eaa7..342a3911c 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -1,8 +1,9 @@ -from collections.abc import Sequence, Set +from collections.abc import Mapping, Sequence, Set from dataclasses import dataclass from typing import override import loopy as lp from loopy.kernel.tools import DomainChanger +from loopy.types import to_loopy_type import namedisl as nisl from loopy.kernel import LoopKernel @@ -50,22 +51,6 @@ def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT): set=set_names ) - -@dataclass(frozen=True) -class UsageDescriptor: - usage: Sequence[Expression] - global_map: isl.Map - local_map: isl.Map - - @override - def __str__(self): - return ( - f"USAGE = {self.usage}\n" + - f"GLOBAL MAP = {self.global_map}\n" + - f"LOCAL MAP = {self.local_map}" - ) - - class UsageSiteExpressionGatherer(RuleAwareIdentityMapper[[]]): """ Gathers all expressions used as inputs to a particular substitution rule, @@ -129,7 +114,7 @@ def __init__( ctx: SubstitutionRuleMappingContext, subst_name: str, subst_tag: Sequence[Tag] | None, - usage_descriptors: Sequence[UsageDescriptor], + 
usage_descriptors: Mapping[tuple[Expression, ...], isl.Map], storage_indices: Sequence[str], temporary_name: str, compute_insn_id: str, @@ -141,12 +126,19 @@ def __init__( self.subst_name: str = subst_name self.subst_tag: Sequence[Tag] | None = subst_tag - self.usage_descriptors: Sequence[UsageDescriptor] = usage_descriptors + self.usage_descriptors: Mapping[tuple[Expression, ...], isl.Map] = \ + usage_descriptors self.storage_indices: Sequence[str] = storage_indices self.temporary_name: str = temporary_name self.compute_insn_id: str = compute_insn_id + # FIXME: may not always be the case (i.e. global barrier between + # compute insn and uses) + self.compute_dep_id: str = compute_insn_id + + self.replaced_something: bool = False + @override def map_subst_rule( @@ -175,17 +167,17 @@ def map_subst_rule( "does not match the signature of {name}.") index_exprs: Sequence[Expression] = [] - for usage_descr in self.usage_descriptors: - if args == usage_descr.usage: - local_pwmaff = usage_descr.local_map.as_pw_multi_aff() - for dim in range(local_pwmaff.dim(isl.dim_type.out)): - index_exprs.append(pw_aff_to_expr(local_pwmaff.get_at(dim))) + # FIXME: make self.usage_descriptors a constantdict + local_pwmaff = self.usage_descriptors[tuple(args)].as_pw_multi_aff() - break + for dim in range(local_pwmaff.dim(isl.dim_type.out)): + index_exprs.append(pw_aff_to_expr(local_pwmaff.get_at(dim))) new_expression = var(self.temporary_name)[tuple(index_exprs)] + self.replaced_something = True + return new_expression @@ -198,8 +190,10 @@ def map_kernel( map_tvs: bool = True ) -> LoopKernel: - new_insns = [] + new_insns: Sequence[lp.InstructionBase] = [] for insn in kernel.instructions: + self.replaced_something = False + if (isinstance(insn, lp.MultiAssignmentBase) and not contains_a_subst_rule_invocation(kernel, insn)): new_insns.append(insn) @@ -209,6 +203,15 @@ def map_kernel( lambda expr: self(expr, kernel, insn) ) + if self.replaced_something: + insn = insn.copy( + depends_on=( + 
insn.depends_on | frozenset([self.compute_insn_id]) + ) + ) + + # FIXME: determine compute insn dependencies + new_insns.append(insn) return kernel.copy(instructions=new_insns) @@ -270,7 +273,7 @@ def compute( ) ) - usage_descrs: Sequence[UsageDescriptor] = [] + usage_descrs: Mapping[tuple[Expression, ...], isl.Map] = {} for usage in usage_exprs: range_space = isl.Space.create_from_names( @@ -301,20 +304,14 @@ def compute( footprint = footprint | iname_to_storage.range() - usage_descrs.append( - UsageDescriptor( - usage, - iname_to_storage, - local_map - ) - ) + usage_descrs[tuple(usage)] = local_map # add compute inames to domain / kernel domain_changer = DomainChanger(kernel, kernel.all_inames()) domain = domain_changer.domain - footprint, domain = isl.align_two(footprint, domain) - domain = domain & footprint + footprint_tmp, domain = isl.align_two(footprint, domain) + domain = domain & footprint_tmp new_domains = domain_changer.get_domains_with(domain) kernel = kernel.copy(domains=new_domains) @@ -377,5 +374,34 @@ def compute( kernel = replacer.map_kernel(kernel) - print(kernel) + # FIXME: accept dtype as an argument + import numpy as np + loopy_type = to_loopy_type(np.float64, allow_none=True) + + # WARNING: this can result in symbolic shapes, is that allowed? 
+ temp_shape = tuple( + pw_aff_to_expr(footprint.dim_max(dim)) + 1 + for dim in range(footprint.dim(isl.dim_type.out)) + ) + + new_temp_vars = dict(kernel.temporary_variables) + + # FIXME: temp_var might already exist, handle the case where it does + temp_var = lp.TemporaryVariable( + name=temporary_name, + dtype=loopy_type, + base_indices=(0,)*len(temp_shape), + shape=temp_shape, + address_space=temporary_address_space, + dim_names=tuple(storage_indices) + ) + + new_temp_vars[temporary_name] = temp_var + + kernel = kernel.copy( + temporary_variables=new_temp_vars + ) + + # FIXME: handle iname tagging + return kernel From 8667a01f196bf6f580b5014e74d522a64534e5e1 Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 17 Mar 2026 23:14:44 -0500 Subject: [PATCH 10/27] compute working for tiled matmul; write race condition warning --- loopy/transform/compute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 342a3911c..afdad1bf4 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -311,7 +311,7 @@ def compute( domain = domain_changer.domain footprint_tmp, domain = isl.align_two(footprint, domain) - domain = domain & footprint_tmp + domain = (domain & footprint_tmp).get_basic_sets()[0] new_domains = domain_changer.get_domains_with(domain) kernel = kernel.copy(domains=new_domains) From 959e68b78ddc044b68d0da8bad5e7885a480e3d9 Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 17 Mar 2026 23:15:57 -0500 Subject: [PATCH 11/27] add compute matmul example --- examples/compute-tiled-matmul.py | 120 +++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 examples/compute-tiled-matmul.py diff --git a/examples/compute-tiled-matmul.py b/examples/compute-tiled-matmul.py new file mode 100644 index 000000000..979f93421 --- /dev/null +++ b/examples/compute-tiled-matmul.py @@ -0,0 +1,120 @@ +import namedisl as nisl + +import loopy as lp +from loopy.version 
import LOOPY_USE_LANGUAGE_VERSION_2018_2 +from loopy.transform.compute import compute + +import numpy as np +import numpy.linalg as la +import pyopencl as cl + + +def main( + use_precompute: bool = False, + use_compute: bool = False, + run_kernel: bool = False + ) -> None: + + knl = lp.make_kernel( + "{ [i, j, k] : 0 <= i, j, k < 128 }", + """ + a_(is, ks) := a[is, ks] + b_(ks, js) := b[ks, js] + out[i, j] = sum([k], a_(i, k) * b_(k, j)) + """, + [ + lp.GlobalArg("a", shape=(128, 128), dtype=np.float64), + lp.GlobalArg("b", shape=(128, 128), dtype=np.float64), + lp.GlobalArg("out", shape=(128, 128), dtype=np.float64, + is_output=True) + ] + ) + + bm = bn = 32 + bk = 16 + + knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") + knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") + knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") + + compute_map_a = nisl.make_map(f"""{{ + [is, ks] -> [io, ii_s, ko, ki_s] : + 0 <= ii_s < {bm} and 0 <= ki_s < {bk} and + is = io * {bm} + ii_s and + ks = ko * {bk} + ki_s + }}""") + + compute_map_b = nisl.make_map(f"""{{ + [ks, js] -> [ko, ki_s, jo, ji_s] : + 0 <= ji_s < {bn} and 0 <= ki_s < {bk} and + js = jo * {bn} + ji_s and + ks = ko * {bk} + ki_s + }}""") + + if use_compute: + knl = compute( + knl, + "a_", + compute_map=compute_map_a, + storage_indices=["ii_s", "ki_s"], + temporal_inames=["io", "ko", "jo"], + temporary_address_space=lp.AddressSpace.LOCAL + ) + + knl = compute( + knl, + "b_", + compute_map=compute_map_b, + storage_indices=["ki_s", "ji_s"], + temporal_inames=["io", "ko", "jo"], + temporary_address_space=lp.AddressSpace.LOCAL + ) + + knl = lp.tag_inames( + knl, { + "io" : "g.0", + "jo" : "g.1", + "ii" : "l.0", + "ji" : "l.1", + } + ) + + knl = lp.add_inames_for_unused_hw_axes(knl) + + if use_precompute: + knl = lp.precompute( + knl, + "a_", + sweep_inames=["ii", "ki"], + ) + + if run_kernel: + a = np.random.randn(128, 128) + b = np.random.randn(128, 128) + + ctx 
= cl.create_some_context() + queue = cl.CommandQueue(ctx) + + ex = knl.executor(ctx) + _, out = ex(queue, a=a, b=b) + + print(la.norm((a @ b) - out) / la.norm(out)) + + knl = lp.generate_code_v2(knl).device_code() + + print(knl) + + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + + _ = parser.add_argument("--precompute", action="store_true") + _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--run-kernel", action="store_true") + + args = parser.parse_args() + + main(args.precompute, args.compute, args.run_kernel) From a566b6ee3e51acdaddb4b96516250c33dbbae2f3 Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 17 Mar 2026 23:20:05 -0500 Subject: [PATCH 12/27] clean up compute matmul example --- examples/compute-tiled-matmul.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/examples/compute-tiled-matmul.py b/examples/compute-tiled-matmul.py index 979f93421..17620b41d 100644 --- a/examples/compute-tiled-matmul.py +++ b/examples/compute-tiled-matmul.py @@ -12,7 +12,9 @@ def main( use_precompute: bool = False, use_compute: bool = False, - run_kernel: bool = False + run_kernel: bool = False, + print_kernel: bool = False, + print_device_code: bool = False ) -> None: knl = lp.make_kernel( @@ -98,11 +100,12 @@ def main( ex = knl.executor(ctx) _, out = ex(queue, a=a, b=b) - print(la.norm((a @ b) - out) / la.norm(out)) + print(f"Relative error = {la.norm((a @ b) - out) / la.norm(out)}") - knl = lp.generate_code_v2(knl).device_code() - - print(knl) + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + elif print_kernel: + print(knl) @@ -114,7 +117,15 @@ def main( _ = parser.add_argument("--precompute", action="store_true") _ = parser.add_argument("--compute", action="store_true") _ = parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--print-kernel", action="store_true") + _ = 
parser.add_argument("--print-device-code", action="store_true") args = parser.parse_args() - main(args.precompute, args.compute, args.run_kernel) + main( + args.precompute, + args.compute, + args.run_kernel, + args.print_kernel, + args.print_device_code + ) From cf920b597ec8fac3ea1d2f69856e506a9a9d863e Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 19 Mar 2026 08:23:43 -0500 Subject: [PATCH 13/27] improve matmul example with more parameters, better post-compute transformations --- examples/compute-tiled-matmul.py | 71 +++++--- .../compute-examples/compute-tiled-matmul.py | 156 ++++++++++++++++++ .../finite-difference-2-5D.py | 0 loopy/transform/compute.py | 6 +- 4 files changed, 203 insertions(+), 30 deletions(-) create mode 100644 examples/python/compute-examples/compute-tiled-matmul.py create mode 100644 examples/python/compute-examples/finite-difference-2-5D.py diff --git a/examples/compute-tiled-matmul.py b/examples/compute-tiled-matmul.py index 17620b41d..07dad6e1a 100644 --- a/examples/compute-tiled-matmul.py +++ b/examples/compute-tiled-matmul.py @@ -10,31 +10,34 @@ def main( - use_precompute: bool = False, - use_compute: bool = False, - run_kernel: bool = False, - print_kernel: bool = False, - print_device_code: bool = False - ) -> None: + M: int = 128, + N: int = 128, + K: int = 128, + bm: int = 32, + bn: int = 32, + bk: int = 16, + use_precompute: bool = False, + use_compute: bool = False, + run_kernel: bool = False, + print_kernel: bool = False, + print_device_code: bool = False +) -> None: knl = lp.make_kernel( - "{ [i, j, k] : 0 <= i, j, k < 128 }", + "{ [i, j, k] : 0 <= i < M and 0 <= j < N and 0 <= k < K }", """ a_(is, ks) := a[is, ks] b_(ks, js) := b[ks, js] - out[i, j] = sum([k], a_(i, k) * b_(k, j)) + c[i, j] = sum([k], a_(i, k) * b_(k, j)) """, [ - lp.GlobalArg("a", shape=(128, 128), dtype=np.float64), - lp.GlobalArg("b", shape=(128, 128), dtype=np.float64), - lp.GlobalArg("out", shape=(128, 128), dtype=np.float64, + lp.GlobalArg("a", 
shape=(M, K), dtype=np.float64), + lp.GlobalArg("b", shape=(K, N), dtype=np.float64), + lp.GlobalArg("c", shape=(M, N), dtype=np.float64, is_output=True) ] ) - bm = bn = 32 - bk = 16 - knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") @@ -72,16 +75,7 @@ def main( temporary_address_space=lp.AddressSpace.LOCAL ) - knl = lp.tag_inames( - knl, { - "io" : "g.0", - "jo" : "g.1", - "ii" : "l.0", - "ji" : "l.1", - } - ) - - knl = lp.add_inames_for_unused_hw_axes(knl) + # knl = lp.add_inames_for_unused_hw_axes(knl) if use_precompute: knl = lp.precompute( @@ -90,9 +84,20 @@ def main( sweep_inames=["ii", "ki"], ) + # knl = lp.tag_inames( + # knl, { + # "io" : "g.0", + # "jo" : "g.1", + # "ii" : "l.0", + # "ji" : "l.1", + # "ii_s": "l.0", + # "ji_s": "l.1" + # } + # ) + if run_kernel: - a = np.random.randn(128, 128) - b = np.random.randn(128, 128) + a = np.random.randn(M, K) + b = np.random.randn(K, N) ctx = cl.create_some_context() queue = cl.CommandQueue(ctx) @@ -120,9 +125,23 @@ def main( _ = parser.add_argument("--print-kernel", action="store_true") _ = parser.add_argument("--print-device-code", action="store_true") + _ = parser.add_argument("--m", action="store", type=int, default=128) + _ = parser.add_argument("--n", action="store", type=int, default=128) + _ = parser.add_argument("--k", action="store", type=int, default=128) + + _ = parser.add_argument("--bm", action="store", type=int, default=32) + _ = parser.add_argument("--bn", action="store", type=int, default=32) + _ = parser.add_argument("--bk", action="store", type=int, default=16) + args = parser.parse_args() main( + args.m, + args.n, + args.k, + args.bm, + args.bn, + args.bk, args.precompute, args.compute, args.run_kernel, diff --git a/examples/python/compute-examples/compute-tiled-matmul.py b/examples/python/compute-examples/compute-tiled-matmul.py 
new file mode 100644 index 000000000..67536f701 --- /dev/null +++ b/examples/python/compute-examples/compute-tiled-matmul.py @@ -0,0 +1,156 @@ +import namedisl as nisl + +import loopy as lp +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 +from loopy.transform.compute import compute + +import numpy as np +import numpy.linalg as la +import pyopencl as cl + + +def main( + M: int = 128, + N: int = 128, + K: int = 128, + bm: int = 32, + bn: int = 32, + bk: int = 16, + use_precompute: bool = False, + use_compute: bool = False, + run_kernel: bool = False, + print_kernel: bool = False, + print_device_code: bool = False + ) -> None: + + knl = lp.make_kernel( + "{ [i, j, k] : 0 <= i < M and 0 <= j < N and 0 <= k < K }", + """ + a_(is, ks) := a[is, ks] + b_(ks, js) := b[ks, js] + c[i, j] = sum([k], a_(i, k) * b_(k, j)) + """, + [ + lp.GlobalArg("a", shape=(M, K), dtype=np.float64), + lp.GlobalArg("b", shape=(K, N), dtype=np.float64), + lp.GlobalArg("c", shape=(M, N), dtype=np.float64, + is_output=True) + ] + ) + + # FIXME: without this, there are complaints about in-bounds access guarantees + knl = lp.fix_parameters(knl, M=M, N=N, K=K) + + knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") + knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") + knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") + + compute_map_a = nisl.make_map(f"""{{ + [is, ks] -> [io, ii_s, ko, ki_s] : + 0 <= ii_s < {bm} and 0 <= ki_s < {bk} and + is = io * {bm} + ii_s and + ks = ko * {bk} + ki_s + }}""") + + compute_map_b = nisl.make_map(f"""{{ + [ks, js] -> [ko, ki_s, jo, ji_s] : + 0 <= ji_s < {bn} and 0 <= ki_s < {bk} and + js = jo * {bn} + ji_s and + ks = ko * {bk} + ki_s + }}""") + + if use_compute: + knl = compute( + knl, + "a_", + compute_map=compute_map_a, + storage_indices=["ii_s", "ki_s"], + temporal_inames=["io", "ko", "jo"], + temporary_address_space=lp.AddressSpace.LOCAL + ) + + knl = compute( + knl, + "b_", + 
compute_map=compute_map_b, + storage_indices=["ki_s", "ji_s"], + temporal_inames=["io", "ko", "jo"], + temporary_address_space=lp.AddressSpace.LOCAL + ) + + if use_precompute: + knl = lp.precompute( + knl, + "a_", + sweep_inames=["ii", "ki"], + ) + + knl = lp.tag_inames( + knl, { + "io" : "g.0", # outer block loop over block rows + "jo" : "g.1", # outer block loop over block cols + + "ii" : "l.0", # inner block loop over rows + "ji" : "l.1", # inner block loop over cols + + "ii_s" : "l.0", # inner storage loop over a rows + "ji_s" : "l.0", # inner storage loop over b cols + "ki_s" : "l.1" # inner storage loop over a cols / b rows + } + ) + + knl = lp.add_inames_for_unused_hw_axes(knl) + + if run_kernel: + a = np.random.randn(M, K) + b = np.random.randn(K, N) + + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + + ex = knl.executor(ctx) + _, out = ex(queue, a=a, b=b) + + print(f"Relative error = {la.norm((a @ b) - out) / la.norm(out)}") + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + if print_kernel: + print(knl) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + + _ = parser.add_argument("--precompute", action="store_true") + _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--print-device-code", action="store_true") + + _ = parser.add_argument("--m", action="store", type=int, default=128) + _ = parser.add_argument("--n", action="store", type=int, default=128) + _ = parser.add_argument("--k", action="store", type=int, default=128) + + _ = parser.add_argument("--bm", action="store", type=int, default=32) + _ = parser.add_argument("--bn", action="store", type=int, default=32) + _ = parser.add_argument("--bk", action="store", type=int, default=16) + + args = parser.parse_args() + + main( + args.m, + args.n, + args.k, + 
args.bm, + args.bn, + args.bk, + args.precompute, + args.compute, + args.run_kernel, + args.print_kernel, + args.print_device_code + ) diff --git a/examples/python/compute-examples/finite-difference-2-5D.py b/examples/python/compute-examples/finite-difference-2-5D.py new file mode 100644 index 000000000..e69de29bb diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index afdad1bf4..6312f596a 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -1,5 +1,4 @@ from collections.abc import Mapping, Sequence, Set -from dataclasses import dataclass from typing import override import loopy as lp from loopy.kernel.tools import DomainChanger @@ -7,7 +6,7 @@ import namedisl as nisl from loopy.kernel import LoopKernel -from loopy.kernel.data import AddressSpace, SubstitutionRule +from loopy.kernel.data import AddressSpace from loopy.match import StackMatch, parse_stack_match from loopy.symbolic import ( ExpansionState, @@ -15,7 +14,6 @@ RuleAwareSubstitutionMapper, SubstitutionRuleExpander, SubstitutionRuleMappingContext, - get_dependencies, pw_aff_to_expr, pwaff_from_expr ) @@ -23,7 +21,7 @@ contains_a_subst_rule_invocation ) from loopy.translation_unit import for_each_kernel -from pymbolic import ArithmeticExpression, var +from pymbolic import var from pymbolic.mapper.substitutor import make_subst_func import islpy as isl From 4c7d1ac8dd764ec7a12f51eb7f485dcb04e9a513 Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 19 Mar 2026 19:07:53 -0500 Subject: [PATCH 14/27] add 2.5D FD example base; minor stylistic changes --- examples/compute-tiled-matmul.py | 150 ------------------ .../compute-examples/compute-tiled-matmul.py | 52 +++--- .../finite-difference-2-5D.py | 71 +++++++++ loopy/transform/compute.py | 11 +- 4 files changed, 109 insertions(+), 175 deletions(-) delete mode 100644 examples/compute-tiled-matmul.py diff --git a/examples/compute-tiled-matmul.py b/examples/compute-tiled-matmul.py deleted file mode 100644 index 
07dad6e1a..000000000 --- a/examples/compute-tiled-matmul.py +++ /dev/null @@ -1,150 +0,0 @@ -import namedisl as nisl - -import loopy as lp -from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 -from loopy.transform.compute import compute - -import numpy as np -import numpy.linalg as la -import pyopencl as cl - - -def main( - M: int = 128, - N: int = 128, - K: int = 128, - bm: int = 32, - bn: int = 32, - bk: int = 16, - use_precompute: bool = False, - use_compute: bool = False, - run_kernel: bool = False, - print_kernel: bool = False, - print_device_code: bool = False -) -> None: - - knl = lp.make_kernel( - "{ [i, j, k] : 0 <= i < M and 0 <= j < N and 0 <= k < K }", - """ - a_(is, ks) := a[is, ks] - b_(ks, js) := b[ks, js] - c[i, j] = sum([k], a_(i, k) * b_(k, j)) - """, - [ - lp.GlobalArg("a", shape=(M, K), dtype=np.float64), - lp.GlobalArg("b", shape=(K, N), dtype=np.float64), - lp.GlobalArg("c", shape=(M, N), dtype=np.float64, - is_output=True) - ] - ) - - knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") - knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") - knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") - - compute_map_a = nisl.make_map(f"""{{ - [is, ks] -> [io, ii_s, ko, ki_s] : - 0 <= ii_s < {bm} and 0 <= ki_s < {bk} and - is = io * {bm} + ii_s and - ks = ko * {bk} + ki_s - }}""") - - compute_map_b = nisl.make_map(f"""{{ - [ks, js] -> [ko, ki_s, jo, ji_s] : - 0 <= ji_s < {bn} and 0 <= ki_s < {bk} and - js = jo * {bn} + ji_s and - ks = ko * {bk} + ki_s - }}""") - - if use_compute: - knl = compute( - knl, - "a_", - compute_map=compute_map_a, - storage_indices=["ii_s", "ki_s"], - temporal_inames=["io", "ko", "jo"], - temporary_address_space=lp.AddressSpace.LOCAL - ) - - knl = compute( - knl, - "b_", - compute_map=compute_map_b, - storage_indices=["ki_s", "ji_s"], - temporal_inames=["io", "ko", "jo"], - temporary_address_space=lp.AddressSpace.LOCAL - ) - - # knl = 
lp.add_inames_for_unused_hw_axes(knl) - - if use_precompute: - knl = lp.precompute( - knl, - "a_", - sweep_inames=["ii", "ki"], - ) - - # knl = lp.tag_inames( - # knl, { - # "io" : "g.0", - # "jo" : "g.1", - # "ii" : "l.0", - # "ji" : "l.1", - # "ii_s": "l.0", - # "ji_s": "l.1" - # } - # ) - - if run_kernel: - a = np.random.randn(M, K) - b = np.random.randn(K, N) - - ctx = cl.create_some_context() - queue = cl.CommandQueue(ctx) - - ex = knl.executor(ctx) - _, out = ex(queue, a=a, b=b) - - print(f"Relative error = {la.norm((a @ b) - out) / la.norm(out)}") - - if print_device_code: - print(lp.generate_code_v2(knl).device_code()) - elif print_kernel: - print(knl) - - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - - _ = parser.add_argument("--precompute", action="store_true") - _ = parser.add_argument("--compute", action="store_true") - _ = parser.add_argument("--run-kernel", action="store_true") - _ = parser.add_argument("--print-kernel", action="store_true") - _ = parser.add_argument("--print-device-code", action="store_true") - - _ = parser.add_argument("--m", action="store", type=int, default=128) - _ = parser.add_argument("--n", action="store", type=int, default=128) - _ = parser.add_argument("--k", action="store", type=int, default=128) - - _ = parser.add_argument("--bm", action="store", type=int, default=32) - _ = parser.add_argument("--bn", action="store", type=int, default=32) - _ = parser.add_argument("--bk", action="store", type=int, default=16) - - args = parser.parse_args() - - main( - args.m, - args.n, - args.k, - args.bm, - args.bn, - args.bk, - args.precompute, - args.compute, - args.run_kernel, - args.print_kernel, - args.print_device_code - ) diff --git a/examples/python/compute-examples/compute-tiled-matmul.py b/examples/python/compute-examples/compute-tiled-matmul.py index 67536f701..5b5f47c2f 100644 --- a/examples/python/compute-examples/compute-tiled-matmul.py +++ 
b/examples/python/compute-examples/compute-tiled-matmul.py @@ -16,6 +16,7 @@ def main( bm: int = 32, bn: int = 32, bk: int = 16, + run_sequentially: bool = False, use_precompute: bool = False, use_compute: bool = False, run_kernel: bool = False, @@ -85,19 +86,20 @@ def main( sweep_inames=["ii", "ki"], ) - knl = lp.tag_inames( - knl, { - "io" : "g.0", # outer block loop over block rows - "jo" : "g.1", # outer block loop over block cols + if not run_sequentially: + knl = lp.tag_inames( + knl, { + "io" : "g.0", # outer block loop over block rows + "jo" : "g.1", # outer block loop over block cols - "ii" : "l.0", # inner block loop over rows - "ji" : "l.1", # inner block loop over cols + "ii" : "l.0", # inner block loop over rows + "ji" : "l.1", # inner block loop over cols - "ii_s" : "l.0", # inner storage loop over a rows - "ji_s" : "l.0", # inner storage loop over b cols - "ki_s" : "l.1" # inner storage loop over a cols / b rows - } - ) + "ii_s" : "l.0", # inner storage loop over a rows + "ji_s" : "l.0", # inner storage loop over b cols + "ki_s" : "l.1" # inner storage loop over a cols / b rows + } + ) knl = lp.add_inames_for_unused_hw_axes(knl) @@ -111,7 +113,11 @@ def main( ex = knl.executor(ctx) _, out = ex(queue, a=a, b=b) + print(20*"=", "Tiled matmul report", 20*"=") + print(f"Problem size: M = {M:-4}, N = {N:-4}, K = {K:-4}") + print(f"Tile size : BM = {bm:-4}, BN = {bn:-4}, BK = {bk:-4}") print(f"Relative error = {la.norm((a @ b) - out) / la.norm(out)}") + print((40 + len(" Tiled matmul report "))*"=") if print_device_code: print(lp.generate_code_v2(knl).device_code()) @@ -130,6 +136,7 @@ def main( _ = parser.add_argument("--run-kernel", action="store_true") _ = parser.add_argument("--print-kernel", action="store_true") _ = parser.add_argument("--print-device-code", action="store_true") + _ = parser.add_argument("--run-sequentially", action="store_true") _ = parser.add_argument("--m", action="store", type=int, default=128) _ = parser.add_argument("--n", 
action="store", type=int, default=128) @@ -142,15 +149,16 @@ def main( args = parser.parse_args() main( - args.m, - args.n, - args.k, - args.bm, - args.bn, - args.bk, - args.precompute, - args.compute, - args.run_kernel, - args.print_kernel, - args.print_device_code + M=args.m, + N=args.n, + K=args.k, + bm=args.bm, + bn=args.bn, + bk=args.bk, + use_precompute=args.precompute, + use_compute=args.compute, + run_kernel=args.run_kernel, + print_kernel=args.print_kernel, + print_device_code=args.print_device_code, + run_sequentially=args.run_sequentially ) diff --git a/examples/python/compute-examples/finite-difference-2-5D.py b/examples/python/compute-examples/finite-difference-2-5D.py index e69de29bb..b1cb5c7c0 100644 --- a/examples/python/compute-examples/finite-difference-2-5D.py +++ b/examples/python/compute-examples/finite-difference-2-5D.py @@ -0,0 +1,71 @@ +import loopy as lp +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 +import numpy as np +import numpy.linalg as la +import pyopencl as cl + + +# FIXME: more complicated function, or better yet define a set of functions +# with sympy and get the exact laplacian symbolically +def f(x, y, z): + return x**2 + y**2 + z**2 + + +def laplacian_f(x, y, z): + return 6 * np.ones_like(x) + + +def main(use_compute: bool = False) -> None: + knl = lp.make_kernel( + "{ [i, j, k, l] : r <= i, j, k < npts - r and -r <= l < r + 1 }", + """ + u_(i, j, k) := u[i, j, k] + + lap_u[i,j,k] = sum([l], c[l+2] * (u[i-l,j,k] + u[i,j-l,k] + u[i,j,k-l])) + """ + ) + + if use_compute: + raise NotImplementedError("WIP") + + npts = 50 + pts = np.linspace(-1, 1, num=npts, endpoint=True) + h = pts[1] - pts[0] + + x, y, z = np.meshgrid(*(pts,)*3) + + x = x.reshape(*(npts,)*3) + y = y.reshape(*(npts,)*3) + z = z.reshape(*(npts,)*3) + + f_ = f(x, y, z) + lap_fd = np.zeros_like(f_) + c = np.array([-1/12, 4/3, -5/2, 4/3, -1/12]) / h**2 + + m = 5 + r = m // 2 + + knl = lp.fix_parameters(knl, npts=npts, r=r) + + ctx = 
cl.create_some_context() + queue = cl.CommandQueue(ctx) + + ex = knl.executor(queue) + _, lap_fd = ex(queue, u=f(x, y, z), c=c) + + lap_true = laplacian_f(x, y, z) + sl = (slice(r, npts - r),)*3 + + print(la.norm(lap_true[sl] - lap_fd[0][sl]) / la.norm(lap_true[sl])) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + + _ = parser.add_argument("--compute", action="store_true") + + args = parser.parse_args() + + main(use_compute=args.compute) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 6312f596a..8ff7a6637 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -1,5 +1,6 @@ from collections.abc import Mapping, Sequence, Set from typing import override +from typing_extensions import TypeAlias import loopy as lp from loopy.kernel.tools import DomainChanger from loopy.types import to_loopy_type @@ -31,6 +32,9 @@ from pytools.tag import Tag +AccessTuple: TypeAlias = AccessTuple + + def gather_vars(expr) -> set[str]: deps = DependencyMapper()(expr) return { @@ -112,7 +116,7 @@ def __init__( ctx: SubstitutionRuleMappingContext, subst_name: str, subst_tag: Sequence[Tag] | None, - usage_descriptors: Mapping[tuple[Expression, ...], isl.Map], + usage_descriptors: Mapping[AccessTuple, isl.Map], storage_indices: Sequence[str], temporary_name: str, compute_insn_id: str, @@ -124,7 +128,7 @@ def __init__( self.subst_name: str = subst_name self.subst_tag: Sequence[Tag] | None = subst_tag - self.usage_descriptors: Mapping[tuple[Expression, ...], isl.Map] = \ + self.usage_descriptors: Mapping[AccessTuple, isl.Map] = \ usage_descriptors self.storage_indices: Sequence[str] = storage_indices @@ -271,7 +275,7 @@ def compute( ) ) - usage_descrs: Mapping[tuple[Expression, ...], isl.Map] = {} + usage_descrs: Mapping[AccessTuple, isl.Map] = {} for usage in usage_exprs: range_space = isl.Space.create_from_names( @@ -291,6 +295,7 @@ def compute( usage_map = pw_multi_aff.as_map() iname_to_timespace = 
usage_map.apply_range(compute_map) + iname_to_storage = iname_to_timespace.project_out_except( storage_indices, [isl.dim_type.out] ) From 781df54974b549082491c703d757d317c118a2eb Mon Sep 17 00:00:00 2001 From: Addison Date: Mon, 23 Mar 2026 09:28:24 -0500 Subject: [PATCH 15/27] improvements to 2.5D example; bug fixes in compute transform --- .../finite-difference-2-5D.py | 99 +++++++++++++++---- loopy/transform/compute.py | 21 ++-- 2 files changed, 91 insertions(+), 29 deletions(-) diff --git a/examples/python/compute-examples/finite-difference-2-5D.py b/examples/python/compute-examples/finite-difference-2-5D.py index b1cb5c7c0..f1525da12 100644 --- a/examples/python/compute-examples/finite-difference-2-5D.py +++ b/examples/python/compute-examples/finite-difference-2-5D.py @@ -1,7 +1,12 @@ import loopy as lp from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 +from loopy.transform.compute import compute + +import namedisl as nisl + import numpy as np import numpy.linalg as la + import pyopencl as cl @@ -15,38 +20,88 @@ def laplacian_f(x, y, z): return 6 * np.ones_like(x) -def main(use_compute: bool = False) -> None: - knl = lp.make_kernel( - "{ [i, j, k, l] : r <= i, j, k < npts - r and -r <= l < r + 1 }", - """ - u_(i, j, k) := u[i, j, k] - - lap_u[i,j,k] = sum([l], c[l+2] * (u[i-l,j,k] + u[i,j-l,k] + u[i,j,k-l])) - """ - ) - - if use_compute: - raise NotImplementedError("WIP") - - npts = 50 +def main( + use_compute: bool = False, + print_device_code: bool = False, + print_kernel: bool = False + ) -> None: + npts = 64 pts = np.linspace(-1, 1, num=npts, endpoint=True) h = pts[1] - pts[0] x, y, z = np.meshgrid(*(pts,)*3) - x = x.reshape(*(npts,)*3) - y = y.reshape(*(npts,)*3) - z = z.reshape(*(npts,)*3) + dtype = np.float32 + x = x.reshape(*(npts,)*3).astype(np.float32) + y = y.reshape(*(npts,)*3).astype(np.float32) + z = z.reshape(*(npts,)*3).astype(np.float32) f_ = f(x, y, z) lap_fd = np.zeros_like(f_) - c = np.array([-1/12, 4/3, -5/2, 4/3, -1/12]) / h**2 
+ c = (np.array([-1/12, 4/3, -5/2, 4/3, -1/12]) / h**2).astype(dtype) m = 5 r = m // 2 + bm = bn = m + + # FIXME: the usage on the k dimension is incorrect since we are only testing + # tiling (i, j) planes + knl = lp.make_kernel( + "{ [i, j, k, l] : r <= i, j, k < npts - r and -r <= l < r + 1 }", + """ + u_(is, js, ks) := u[is, js, ks] + + lap_u[i,j,k] = sum( + [l], + c[l+r] * (u_(i-l,j,k) + u_(i,j-l,k) + u[i,j,k-l]) + ) + """, + [ + lp.GlobalArg("u", dtype=dtype, shape=(npts,npts,npts)), + lp.GlobalArg("lap_u", dtype=dtype, shape=(npts,npts,npts), + is_output=True), + lp.GlobalArg("c", dtype=dtype, shape=(m)) + ] + ) + knl = lp.fix_parameters(knl, npts=npts, r=r) + knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") + knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") + + # FIXME: need to split k dimension + + if use_compute: + compute_map = nisl.make_map( + f""" + {{ + [is, js, ks] -> [io, ii_s, jo, ji_s, k_s] : + 0 <= ii_s < {bm} and 0 <= ji_s < {bn} and 0 <= k_s < {npts} and + is = io * {bm} + ii_s and + js = jo * {bn} + ji_s and + ks = k_s + }} + """ + ) + + knl = compute( + knl, + "u_", + compute_map=compute_map, + storage_indices=["ii_s", "ji_s", "k_s"], + temporal_inames=["io", "jo"], + temporary_name="u_compute", + temporary_address_space=lp.AddressSpace.LOCAL, + temporary_dtype=np.float32 + ) + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + if print_kernel: + print(knl) + ctx = cl.create_some_context() queue = cl.CommandQueue(ctx) @@ -65,7 +120,13 @@ def main(use_compute: bool = False) -> None: parser = argparse.ArgumentParser() _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--print-device-code", action="store_true") + _ = parser.add_argument("--print-kernel", action="store_true") args = parser.parse_args() - main(use_compute=args.compute) + main( + use_compute=args.compute, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel 
+ ) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 8ff7a6637..78c02aed7 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -32,7 +32,7 @@ from pytools.tag import Tag -AccessTuple: TypeAlias = AccessTuple +AccessTuple: TypeAlias = tuple[Expression, ...] def gather_vars(expr) -> set[str]: @@ -230,7 +230,10 @@ def compute( temporal_inames: Sequence[str], temporary_name: str | None = None, - temporary_address_space: AddressSpace | None = None + temporary_address_space: AddressSpace | None = None, + + # FIXME: typing + temporary_dtype = None ) -> LoopKernel: """ Inserts an instruction to compute an expression given by :arg:`substitution` @@ -277,7 +280,6 @@ def compute( usage_descrs: Mapping[AccessTuple, isl.Map] = {} for usage in usage_exprs: - range_space = isl.Space.create_from_names( ctx=space.get_ctx(), set=list(storage_indices) @@ -286,27 +288,26 @@ def compute( pw_multi_aff = isl.MultiPwAff.zero(map_space) - for i, arg in enumerate(usage): + # FIXME: this will not work if usages are not ordered properly + for i in range(len(storage_indices)): pw_multi_aff = pw_multi_aff.set_pw_aff( i, - pwaff_from_expr(space, arg) + pwaff_from_expr(space, usage[i]) ) usage_map = pw_multi_aff.as_map() iname_to_timespace = usage_map.apply_range(compute_map) - iname_to_storage = iname_to_timespace.project_out_except( storage_indices, [isl.dim_type.out] ) + footprint = footprint | iname_to_storage.range() + local_map = iname_to_storage.project_out_except( kernel.all_inames() - frozenset(temporal_inames), [isl.dim_type.in_] ) - - footprint = footprint | iname_to_storage.range() - usage_descrs[tuple(usage)] = local_map # add compute inames to domain / kernel @@ -379,7 +380,7 @@ def compute( # FIXME: accept dtype as an argument import numpy as np - loopy_type = to_loopy_type(np.float64, allow_none=True) + loopy_type = to_loopy_type(temporary_dtype, allow_none=True) # WARNING: this can result in symbolic shapes, is that allowed? 
temp_shape = tuple( From cccc6a100654b6ecf1085068addf375836c01245 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 23 Mar 2026 11:02:12 -0500 Subject: [PATCH 16/27] Feedback from meeting --- examples/python/compute-examples/compute-tiled-matmul.py | 5 ++--- loopy/transform/compute.py | 3 +++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/python/compute-examples/compute-tiled-matmul.py b/examples/python/compute-examples/compute-tiled-matmul.py index 5b5f47c2f..d204a1922 100644 --- a/examples/python/compute-examples/compute-tiled-matmul.py +++ b/examples/python/compute-examples/compute-tiled-matmul.py @@ -40,22 +40,21 @@ def main( ) # FIXME: without this, there are complaints about in-bounds access guarantees - knl = lp.fix_parameters(knl, M=M, N=N, K=K) + # knl = lp.fix_parameters(knl, M=M, N=N, K=K) knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") + # FIXME: Given the input is already tiled, we shouldn't have to supply compute bounds here. 
compute_map_a = nisl.make_map(f"""{{ [is, ks] -> [io, ii_s, ko, ki_s] : - 0 <= ii_s < {bm} and 0 <= ki_s < {bk} and is = io * {bm} + ii_s and ks = ko * {bk} + ki_s }}""") compute_map_b = nisl.make_map(f"""{{ [ks, js] -> [ko, ki_s, jo, ji_s] : - 0 <= ji_s < {bn} and 0 <= ki_s < {bk} and js = jo * {bn} + ji_s and ks = ko * {bk} + ki_s }}""") diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 78c02aed7..eb966003f 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -286,6 +286,7 @@ def compute( ) map_space = space.map_from_domain_and_range(range_space) + # FIXME package sequence of pymbolic exprs -> multipwaff up as a function in loopy.symbolic pw_multi_aff = isl.MultiPwAff.zero(map_space) # FIXME: this will not work if usages are not ordered properly @@ -295,8 +296,10 @@ def compute( pwaff_from_expr(space, usage[i]) ) + # FIXME intersect the (kernel) domain with the domain (of the map) here. usage_map = pw_multi_aff.as_map() + # FIXME defer as much of this project-y work to be done once, later iname_to_timespace = usage_map.apply_range(compute_map) iname_to_storage = iname_to_timespace.project_out_except( storage_indices, [isl.dim_type.out] From 6a7f7192a164e2c6b2049f3406d2192a8b70f067 Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 24 Mar 2026 00:39:42 -0500 Subject: [PATCH 17/27] update footprint finding to minimize projections, find bounds --- .../compute-examples/compute-tiled-matmul.py | 9 +- loopy/transform/compute.py | 155 ++++++++++++++---- 2 files changed, 132 insertions(+), 32 deletions(-) diff --git a/examples/python/compute-examples/compute-tiled-matmul.py b/examples/python/compute-examples/compute-tiled-matmul.py index d204a1922..25c33d46e 100644 --- a/examples/python/compute-examples/compute-tiled-matmul.py +++ b/examples/python/compute-examples/compute-tiled-matmul.py @@ -39,8 +39,9 @@ def main( ] ) - # FIXME: without this, there are complaints about in-bounds access guarantees - # knl = 
lp.fix_parameters(knl, M=M, N=N, K=K) + # FIXME: without this, there are complaints about in-bounds access + # guarantees for the instruction that stores into c + knl = lp.fix_parameters(knl, M=M, N=N, K=K) knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") @@ -48,13 +49,13 @@ def main( # FIXME: Given the input is already tiled, we shouldn't have to supply compute bounds here. compute_map_a = nisl.make_map(f"""{{ - [is, ks] -> [io, ii_s, ko, ki_s] : + [is, ks] -> [ii_s, io, ki_s, ko] : is = io * {bm} + ii_s and ks = ko * {bk} + ki_s }}""") compute_map_b = nisl.make_map(f"""{{ - [ks, js] -> [ko, ki_s, jo, ji_s] : + [ks, js] -> [ki_s, ko, ji_s, jo] : js = jo * {bn} + ji_s and ks = ko * {bk} + ki_s }}""") diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index eb966003f..d1823f31a 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -53,6 +53,40 @@ def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT): set=set_names ) + +def align_map_domain_to_set(m: isl.Map, s: isl.Set) -> isl.Map: + """ + Permute the domain dimensions of `m` to match the ordering of `s`, + by routing through the parameter space to preserve constraints. 
+ + Example: + m = { [a, b, c, d] -> [e, f, g, h] } + s = { [d, c, b, a] } + result = { [d, c, b, a] -> [e, f, g, h] } + """ + dom_space = m.get_space().domain() + set_space = s.get_space() + + n = dom_space.dim(isl.dim_type.set) + assert set_space.dim(isl.dim_type.set) == n, "dimension count mismatch" + + dom_names = [dom_space.get_dim_name(isl.dim_type.set, i) for i in range(n)] + set_names = [set_space.get_dim_name(isl.dim_type.set, i) for i in range(n)] + assert set(dom_names) == set(set_names), "dimension names must be the same set" + + # Step 1: move all domain dims into the parameter space + n_params = m.dim(isl.dim_type.param) + m = m.move_dims(isl.dim_type.param, n_params, isl.dim_type.in_, 0, n) + + # Step 2: move each param back to in_ in the order dictated by set_names. + # find_dim_by_name accounts for shifting indices as dims are moved out. + for i, name in enumerate(set_names): + param_idx = m.find_dim_by_name(isl.dim_type.param, name) + m = m.move_dims(isl.dim_type.in_, i, isl.dim_type.param, param_idx, 1) + + return m + + class UsageSiteExpressionGatherer(RuleAwareIdentityMapper[[]]): """ Gathers all expressions used as inputs to a particular substitution rule, @@ -120,7 +154,7 @@ def __init__( storage_indices: Sequence[str], temporary_name: str, compute_insn_id: str, - compute_map: isl.Map + global_usage_map: isl.Map ) -> None: super().__init__(ctx) @@ -278,51 +312,113 @@ def compute( ) ) - usage_descrs: Mapping[AccessTuple, isl.Map] = {} + # add compute inames to domain / kernel + domain_changer = DomainChanger(kernel, kernel.all_inames()) + domain = domain_changer.domain + + range_space = isl.Space.create_from_names( + ctx=space.get_ctx(), + set=list(storage_indices) + ) + map_space = space.map_from_domain_and_range(range_space) + global_usage_map = isl.Map.empty(map_space) + for usage in usage_exprs: - range_space = isl.Space.create_from_names( - ctx=space.get_ctx(), - set=list(storage_indices) - ) - map_space = 
space.map_from_domain_and_range(range_space) # FIXME package sequence of pymbolic exprs -> multipwaff up as a function in loopy.symbolic - pw_multi_aff = isl.MultiPwAff.zero(map_space) + local_usage_mpwaff = isl.MultiPwAff.zero(map_space) # FIXME: this will not work if usages are not ordered properly for i in range(len(storage_indices)): - pw_multi_aff = pw_multi_aff.set_pw_aff( + local_usage_mpwaff = local_usage_mpwaff.set_pw_aff( i, pwaff_from_expr(space, usage[i]) ) # FIXME intersect the (kernel) domain with the domain (of the map) here. - usage_map = pw_multi_aff.as_map() + local_usage_map = local_usage_mpwaff.as_map() - # FIXME defer as much of this project-y work to be done once, later - iname_to_timespace = usage_map.apply_range(compute_map) - iname_to_storage = iname_to_timespace.project_out_except( - storage_indices, [isl.dim_type.out] + # FIXME: fix with namedisl + # remove unnecessary names from domain and intersect with usage map + usage_names = frozenset( + local_usage_map.get_dim_name(isl.dim_type.in_, dim) + for dim in range(local_usage_map.dim(isl.dim_type.in_)) ) - footprint = footprint | iname_to_storage.range() + domain_names = frozenset( + domain.get_dim_name(isl.dim_type.set, dim) + for dim in range(domain.dim(isl.dim_type.set)) + ) - local_map = iname_to_storage.project_out_except( - kernel.all_inames() - frozenset(temporal_inames), - [isl.dim_type.in_] + domain_tmp = domain.project_out_except( + usage_names & domain_names, [isl.dim_type.set] ) - usage_descrs[tuple(usage)] = local_map - # add compute inames to domain / kernel - domain_changer = DomainChanger(kernel, kernel.all_inames()) - domain = domain_changer.domain + local_usage_map = align_map_domain_to_set(local_usage_map, domain_tmp) + + local_usage_map = local_usage_map.intersect_domain(domain_tmp) + + # U : G -> S + # C : S -> I + # C o U : G -> I + global_usage_map = global_usage_map | local_usage_map + + + # {{{ FIXME: this shouldn't need to be done here; will be handled by namedisl 
+ + global_usage_map = global_usage_map.apply_range(compute_map) + common_dims = { + dim1 : dim2 + for dim1 in range(global_usage_map.dim(isl.dim_type.in_)) + for dim2 in range(global_usage_map.dim(isl.dim_type.out)) + if ( + global_usage_map.get_dim_name(isl.dim_type.in_, dim1) + == + global_usage_map.get_dim_name(isl.dim_type.out, dim2) + ) + } + + for pos1, pos2 in common_dims.items(): + global_usage_map = global_usage_map.equate( + isl.dim_type.in_, pos1, + isl.dim_type.out, pos2 + ) + + # }}} + print(domain) + footprint = global_usage_map.range() footprint_tmp, domain = isl.align_two(footprint, domain) domain = (domain & footprint_tmp).get_basic_sets()[0] new_domains = domain_changer.get_domains_with(domain) kernel = kernel.copy(domains=new_domains) + # {{{ find all index expressions + + usage_substs: Mapping[AccessTuple, isl.Map] = {} + for usage in usage_exprs: + # find the relevant names + relevant_names = gather_vars(usage) + + # project out irrelevant names + relevant_names = frozenset(relevant_names) - frozenset(temporal_inames) + + local_iname_to_storage = global_usage_map.project_out_except( + relevant_names, + [isl.dim_type.in_] + ) + + local_iname_to_storage = local_iname_to_storage.project_out_except( + storage_indices, + [isl.dim_type.out] + ) + + # map usage -> resulting map + usage_substs[tuple(usage)] = local_iname_to_storage + + # }}} + # create compute instruction in kernel compute_pw_aff = compute_map.reverse().as_pw_multi_aff() storage_ax_to_global_expr = { @@ -372,23 +468,26 @@ def compute( ctx, substitution, None, - usage_descrs, + usage_substs, storage_indices, temporary_name, compute_insn_id, - compute_map + global_usage_map ) kernel = replacer.map_kernel(kernel) # FIXME: accept dtype as an argument - import numpy as np loopy_type = to_loopy_type(temporary_dtype, allow_none=True) - # WARNING: this can result in symbolic shapes, is that allowed? 
+ # FIXME: need a better way to determine the shape + shape_domain = footprint.project_out_except(storage_indices, + [isl.dim_type.set]) + shape_domain = shape_domain.project_out_except("", [isl.dim_type.param]) + temp_shape = tuple( - pw_aff_to_expr(footprint.dim_max(dim)) + 1 - for dim in range(footprint.dim(isl.dim_type.out)) + pw_aff_to_expr(shape_domain.dim_max(dim)) + 1 + for dim in range(shape_domain.dim(isl.dim_type.out)) ) new_temp_vars = dict(kernel.temporary_variables) From aa8b612997680836b29f442847a1ebe4914d4601 Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 24 Mar 2026 00:48:58 -0500 Subject: [PATCH 18/27] add/remove FIXMEs --- loopy/transform/compute.py | 52 ++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index d1823f31a..e67d583c2 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -35,6 +35,7 @@ AccessTuple: TypeAlias = tuple[Expression, ...] 
+# FIXME: move to loopy/symbolic.py def gather_vars(expr) -> set[str]: deps = DependencyMapper()(expr) return { @@ -44,6 +45,7 @@ def gather_vars(expr) -> set[str]: } +# FIXME: move to loopy/symbolic.py def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT): names = sorted(set().union(*(gather_vars(expr) for expr in exprs))) set_names = [name for name in names] @@ -54,6 +56,7 @@ def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT): ) +# FIXME: remove this and rely on namedisl def align_map_domain_to_set(m: isl.Map, s: isl.Set) -> isl.Map: """ Permute the domain dimensions of `m` to match the ordering of `s`, @@ -74,12 +77,9 @@ def align_map_domain_to_set(m: isl.Map, s: isl.Set) -> isl.Map: set_names = [set_space.get_dim_name(isl.dim_type.set, i) for i in range(n)] assert set(dom_names) == set(set_names), "dimension names must be the same set" - # Step 1: move all domain dims into the parameter space n_params = m.dim(isl.dim_type.param) m = m.move_dims(isl.dim_type.param, n_params, isl.dim_type.in_, 0, n) - # Step 2: move each param back to in_ in the order dictated by set_names. - # find_dim_by_name accounts for shifting indices as dims are moved out. for i, name in enumerate(set_names): param_idx = m.find_dim_by_name(isl.dim_type.param, name) m = m.move_dims(isl.dim_type.in_, i, isl.dim_type.param, param_idx, 1) @@ -194,9 +194,9 @@ def map_subst_rule( ) args = [arg_ctx[arg_name] for arg_name in rule.arguments] - # FIXME: footprint check? likely required if user supplies bounds on - # storage indices because we are not guaranteed to capture the footprint - # of all usage sites + # FIXME: usage within footprint check? 
likely required if user supplies + # bounds on storage indices because we are not guaranteed to capture the + # footprint of all usage sites if not len(arguments) == len(rule.arguments): raise ValueError("Number of arguments passed to rule {name} ", @@ -204,7 +204,6 @@ def map_subst_rule( index_exprs: Sequence[Expression] = [] - # FIXME: make self.usage_descriptors a constantdict local_pwmaff = self.usage_descriptors[tuple(args)].as_pw_multi_aff() for dim in range(local_pwmaff.dim(isl.dim_type.out)): @@ -284,9 +283,11 @@ def compute( :arg storage_indices: An ordered sequence of names of storage indices. Used to create inames for the loops that cover the required set of compute points. """ + # FIXME: use namedisl directly compute_map = compute_map._reconstruct_isl_object() - # construct union of usage footprints to determine bounds on compute inames + # {{{ construct necessary pieces; footprint, global usage map + ctx = SubstitutionRuleMappingContext( kernel.substitutions, kernel.get_var_name_generator()) expander = SubstitutionRuleExpander(kernel.substitutions) @@ -328,14 +329,12 @@ def compute( # FIXME package sequence of pymbolic exprs -> multipwaff up as a function in loopy.symbolic local_usage_mpwaff = isl.MultiPwAff.zero(map_space) - # FIXME: this will not work if usages are not ordered properly for i in range(len(storage_indices)): local_usage_mpwaff = local_usage_mpwaff.set_pw_aff( i, pwaff_from_expr(space, usage[i]) ) - # FIXME intersect the (kernel) domain with the domain (of the map) here. 
local_usage_map = local_usage_mpwaff.as_map() # FIXME: fix with namedisl @@ -355,15 +354,9 @@ def compute( ) local_usage_map = align_map_domain_to_set(local_usage_map, domain_tmp) - local_usage_map = local_usage_map.intersect_domain(domain_tmp) - - # U : G -> S - # C : S -> I - # C o U : G -> I global_usage_map = global_usage_map | local_usage_map - # {{{ FIXME: this shouldn't need to be done here; will be handled by namedisl global_usage_map = global_usage_map.apply_range(compute_map) @@ -386,7 +379,10 @@ def compute( # }}} - print(domain) + # }}} + + # {{{ compute bounds and update kernel domain + footprint = global_usage_map.range() footprint_tmp, domain = isl.align_two(footprint, domain) domain = (domain & footprint_tmp).get_basic_sets()[0] @@ -394,7 +390,9 @@ def compute( new_domains = domain_changer.get_domains_with(domain) kernel = kernel.copy(domains=new_domains) - # {{{ find all index expressions + # }}} + + # {{{ compute index expressions usage_substs: Mapping[AccessTuple, isl.Map] = {} for usage in usage_exprs: @@ -419,7 +417,8 @@ def compute( # }}} - # create compute instruction in kernel + # {{{ create compute instruction in kernel + compute_pw_aff = compute_map.reverse().as_pw_multi_aff() storage_ax_to_global_expr = { compute_pw_aff.get_dim_name(isl.dim_type.out, dim) : @@ -460,6 +459,10 @@ def compute( new_insns.append(compute_insn) kernel = kernel.copy(instructions=new_insns) + # }}} + + # {{{ replace invocations with new compute instruction + ctx = SubstitutionRuleMappingContext( kernel.substitutions, kernel.get_var_name_generator() ) @@ -477,10 +480,13 @@ def compute( kernel = replacer.map_kernel(kernel) - # FIXME: accept dtype as an argument + # }}} + + # {{{ create temporary variable for result of compute + loopy_type = to_loopy_type(temporary_dtype, allow_none=True) - # FIXME: need a better way to determine the shape + # FIXME: fix with namedisl? 
shape_domain = footprint.project_out_except(storage_indices, [isl.dim_type.set]) shape_domain = shape_domain.project_out_except("", [isl.dim_type.param]) @@ -508,6 +514,8 @@ def compute( temporary_variables=new_temp_vars ) - # FIXME: handle iname tagging + # }}} + + # FIXME: anything else? return kernel From 907873c6d8d8578a27e23fc23bd843226606df94 Mon Sep 17 00:00:00 2001 From: Addison Date: Mon, 6 Apr 2026 09:18:05 -0500 Subject: [PATCH 19/27] add tiled matmul example as test; islpy -> namedisl --- .../compute-examples/compute-tiled-matmul.py | 10 +- .../finite-difference-2-5D.py | 13 +- loopy/transform/compute.py | 129 +++++------------- test/test_transform.py | 103 ++++++++++++++ 4 files changed, 152 insertions(+), 103 deletions(-) diff --git a/examples/python/compute-examples/compute-tiled-matmul.py b/examples/python/compute-examples/compute-tiled-matmul.py index 25c33d46e..4780f2637 100644 --- a/examples/python/compute-examples/compute-tiled-matmul.py +++ b/examples/python/compute-examples/compute-tiled-matmul.py @@ -43,11 +43,11 @@ def main( # guarantees for the instruction that stores into c knl = lp.fix_parameters(knl, M=M, N=N, K=K) + # shared memory tile-level splitting knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") - # FIXME: Given the input is already tiled, we shouldn't have to supply compute bounds here. 
compute_map_a = nisl.make_map(f"""{{ [is, ks] -> [ii_s, io, ki_s, ko] : is = io * {bm} + ii_s and @@ -67,7 +67,8 @@ def main( compute_map=compute_map_a, storage_indices=["ii_s", "ki_s"], temporal_inames=["io", "ko", "jo"], - temporary_address_space=lp.AddressSpace.LOCAL + temporary_address_space=lp.AddressSpace.LOCAL, + temporary_dtype=np.float64 ) knl = compute( @@ -76,7 +77,8 @@ def main( compute_map=compute_map_b, storage_indices=["ki_s", "ji_s"], temporal_inames=["io", "ko", "jo"], - temporary_address_space=lp.AddressSpace.LOCAL + temporary_address_space=lp.AddressSpace.LOCAL, + temporary_dtype=np.float64 ) if use_precompute: @@ -116,7 +118,7 @@ def main( print(20*"=", "Tiled matmul report", 20*"=") print(f"Problem size: M = {M:-4}, N = {N:-4}, K = {K:-4}") print(f"Tile size : BM = {bm:-4}, BN = {bn:-4}, BK = {bk:-4}") - print(f"Relative error = {la.norm((a @ b) - out) / la.norm(out)}") + print(f"Relative error = {la.norm((a @ b) - out) / la.norm(a @ b)}") print((40 + len(" Tiled matmul report "))*"=") if print_device_code: diff --git a/examples/python/compute-examples/finite-difference-2-5D.py b/examples/python/compute-examples/finite-difference-2-5D.py index f1525da12..233ccf21d 100644 --- a/examples/python/compute-examples/finite-difference-2-5D.py +++ b/examples/python/compute-examples/finite-difference-2-5D.py @@ -76,11 +76,10 @@ def main( compute_map = nisl.make_map( f""" {{ - [is, js, ks] -> [io, ii_s, jo, ji_s, k_s] : - 0 <= ii_s < {bm} and 0 <= ji_s < {bn} and 0 <= k_s < {npts} and - is = io * {bm} + ii_s and - js = jo * {bn} + ji_s and - ks = k_s + [is, js, ks] -> [io, ii_s, jo, ji_s, k] : + is = io * {bm} + ii_s - {r} and + js = jo * {bn} + ji_s - {r} and + ks = k }} """ ) @@ -89,8 +88,8 @@ def main( knl, "u_", compute_map=compute_map, - storage_indices=["ii_s", "ji_s", "k_s"], - temporal_inames=["io", "jo"], + storage_indices=["ii_s", "ji_s"], + temporal_inames=["io", "jo", "k"], temporary_name="u_compute", 
temporary_address_space=lp.AddressSpace.LOCAL, temporary_dtype=np.float32 diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index e67d583c2..3d6bb7e8f 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -271,7 +271,7 @@ def compute( """ Inserts an instruction to compute an expression given by :arg:`substitution` and replaces all invocations of :arg:`substitution` with the result of the - compute instruction. + inserted compute instruction. :arg substitution: The substitution rule for which the compute transform should be applied. @@ -283,8 +283,6 @@ def compute( :arg storage_indices: An ordered sequence of names of storage indices. Used to create inames for the loops that cover the required set of compute points. """ - # FIXME: use namedisl directly - compute_map = compute_map._reconstruct_isl_object() # {{{ construct necessary pieces; footprint, global usage map @@ -295,125 +293,74 @@ def compute( ctx, expander, kernel, substitution, None ) + # add compute inames to domain / kernel + domain_changer = DomainChanger(kernel, kernel.all_inames()) + named_domain = nisl.make_basic_set(domain_changer.domain) + _ = expr_gatherer.map_kernel(kernel) usage_exprs = expr_gatherer.usage_expressions - all_exprs = [ - expr - for usage in usage_exprs - for expr in usage - ] + all_exprs = [expr for usage in usage_exprs for expr in usage] + usage_inames = set.union(*(gather_vars(expr) for expr in all_exprs)) - space = space_from_exprs(all_exprs) + usage_domain = nisl.make_set(f"{{ [{",".join(iname for iname in usage_inames)}] }}") + footprint = nisl.make_set(f"{{ [{",".join(idx for idx in storage_indices)}] }}") - footprint = isl.Set.empty( - isl.Space.create_from_names( - ctx=space.get_ctx(), - set=list(storage_indices) - ) + global_usage_map = nisl.make_map_from_domain_and_range( + usage_domain, + compute_map.domain() ) - # add compute inames to domain / kernel - domain_changer = DomainChanger(kernel, kernel.all_inames()) - domain = 
domain_changer.domain - - range_space = isl.Space.create_from_names( - ctx=space.get_ctx(), - set=list(storage_indices) - ) - map_space = space.map_from_domain_and_range(range_space) - global_usage_map = isl.Map.empty(map_space) + global_usage_map = nisl.make_map(isl.Map.empty(global_usage_map.get_space())) + usage_substs: Mapping[AccessTuple, nisl.Map] = {} for usage in usage_exprs: # FIXME package sequence of pymbolic exprs -> multipwaff up as a function in loopy.symbolic - local_usage_mpwaff = isl.MultiPwAff.zero(map_space) + local_usage_mpwaff = isl.MultiPwAff.zero(global_usage_map.get_space()) for i in range(len(storage_indices)): + local_space = local_usage_mpwaff.get_at(i).get_space().domain() local_usage_mpwaff = local_usage_mpwaff.set_pw_aff( i, - pwaff_from_expr(space, usage[i]) + pwaff_from_expr(local_space, usage[i]) ) - local_usage_map = local_usage_mpwaff.as_map() + local_usage_map = nisl.make_map(local_usage_mpwaff.as_map()) - # FIXME: fix with namedisl - # remove unnecessary names from domain and intersect with usage map - usage_names = frozenset( - local_usage_map.get_dim_name(isl.dim_type.in_, dim) - for dim in range(local_usage_map.dim(isl.dim_type.in_)) - ) - - domain_names = frozenset( - domain.get_dim_name(isl.dim_type.set, dim) - for dim in range(domain.dim(isl.dim_type.set)) - ) - - domain_tmp = domain.project_out_except( - usage_names & domain_names, [isl.dim_type.set] - ) - - local_usage_map = align_map_domain_to_set(local_usage_map, domain_tmp) - local_usage_map = local_usage_map.intersect_domain(domain_tmp) + local_usage_map = local_usage_map.intersect_domain(named_domain) global_usage_map = global_usage_map | local_usage_map - # {{{ FIXME: this shouldn't need to be done here; will be handled by namedisl + local_storage_map = local_usage_map.apply_range(compute_map) + relevant_names = gather_vars(usage) - global_usage_map = global_usage_map.apply_range(compute_map) - common_dims = { - dim1 : dim2 - for dim1 in 
range(global_usage_map.dim(isl.dim_type.in_)) - for dim2 in range(global_usage_map.dim(isl.dim_type.out)) - if ( - global_usage_map.get_dim_name(isl.dim_type.in_, dim1) - == - global_usage_map.get_dim_name(isl.dim_type.out, dim2) + local_storage_map = local_storage_map.project_out_except( + (relevant_names - frozenset(temporal_inames)) | frozenset(storage_indices) ) - } - for pos1, pos2 in common_dims.items(): - global_usage_map = global_usage_map.equate( - isl.dim_type.in_, pos1, - isl.dim_type.out, pos2 - ) + usage_substs[tuple(usage)] = local_storage_map - # }}} + global_usage_map = global_usage_map.apply_range(compute_map) # }}} # {{{ compute bounds and update kernel domain footprint = global_usage_map.range() - footprint_tmp, domain = isl.align_two(footprint, domain) - domain = (domain & footprint_tmp).get_basic_sets()[0] - - new_domains = domain_changer.get_domains_with(domain) - kernel = kernel.copy(domains=new_domains) - - # }}} - - # {{{ compute index expressions - - usage_substs: Mapping[AccessTuple, isl.Map] = {} - for usage in usage_exprs: - # find the relevant names - relevant_names = gather_vars(usage) - - # project out irrelevant names - relevant_names = frozenset(relevant_names) - frozenset(temporal_inames) + footprint = footprint.project_out_except( + frozenset(temporal_inames) | frozenset(storage_indices) + ) - local_iname_to_storage = global_usage_map.project_out_except( - relevant_names, - [isl.dim_type.in_] - ) + # FIXME: probably do not want this permanently here + footprint = nisl.make_set(footprint._reconstruct_isl_object().convex_hull()) + named_domain = named_domain & footprint - local_iname_to_storage = local_iname_to_storage.project_out_except( - storage_indices, - [isl.dim_type.out] - ) + # FIXME: + if len(named_domain.get_basic_sets()) != 1: + raise ValueError("New domain should be composed of a single basic set") - # map usage -> resulting map - usage_substs[tuple(usage)] = local_iname_to_storage + new_domains = 
domain_changer.get_domains_with(named_domain.get_basic_sets()[0]) + kernel = kernel.copy(domains=new_domains) # }}} @@ -487,9 +434,7 @@ def compute( loopy_type = to_loopy_type(temporary_dtype, allow_none=True) # FIXME: fix with namedisl? - shape_domain = footprint.project_out_except(storage_indices, - [isl.dim_type.set]) - shape_domain = shape_domain.project_out_except("", [isl.dim_type.param]) + shape_domain = footprint.project_out_except(storage_indices) temp_shape = tuple( pw_aff_to_expr(shape_domain.dim_max(dim)) + 1 diff --git a/test/test_transform.py b/test/test_transform.py index b6f87c108..37edb61a6 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -20,6 +20,7 @@ THE SOFTWARE. """ +from collections.abc import Mapping import logging import numpy as np @@ -1745,6 +1746,108 @@ def test_duplicate_iname_not_read_only_nested(ctx_factory: cl.CtxFactory): lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit) +@pytest.mark.parametrize("case", ( + {"M": 128, "N": 128, "K": 128, "BM": 32, "BN": 32, "BK": 16}, + {"M": 200, "N": 200, "K": 200, "BM": 32, "BN": 32, "BK": 16}, +)) +def test_compute_simple_tiled_matmul( + ctx_factory: cl.CtxFactory, + case: Mapping[str, int] + ): + + import namedisl as nisl + + M = case["M"] + N = case["N"] + K = case["K"] + bm = case["BM"] + bn = case["BN"] + bk = case["BK"] + + knl = lp.make_kernel( + "{ [i, j, k] : 0 <= i < M and 0 <= j < N and 0 <= k < K }", + """ + a_(is, ks) := a[is, ks] + b_(ks, js) := b[ks, js] + c[i, j] = sum([k], a_(i, k) * b_(k, j)) + """, + [ + lp.GlobalArg("a", shape=(M, K), dtype=np.float64), + lp.GlobalArg("b", shape=(K, N), dtype=np.float64), + lp.GlobalArg("c", shape=(M, N), dtype=np.float64, + is_output=True) + ] + ) + + knl = lp.fix_parameters(knl, M=M, N=N, K=K) + + # shared memory tile-level splitting + knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") + knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") + knl = lp.split_iname(knl, "k", bk, 
inner_iname="ki", outer_iname="ko") + + compute_map_a = nisl.make_map(f"""{{ + [is, ks] -> [ii_s, io, ki_s, ko] : + is = io * {bm} + ii_s and + ks = ko * {bk} + ki_s + }}""") + + compute_map_b = nisl.make_map(f"""{{ + [ks, js] -> [ki_s, ko, ji_s, jo] : + js = jo * {bn} + ji_s and + ks = ko * {bk} + ki_s + }}""") + + from loopy.transform.compute import compute + knl = compute( + knl, + "a_", + compute_map=compute_map_a, + storage_indices=["ii_s", "ki_s"], + temporal_inames=["io", "ko", "jo"], + temporary_address_space=lp.AddressSpace.LOCAL, + temporary_dtype=np.float64 + ) + + knl = compute( + knl, + "b_", + compute_map=compute_map_b, + storage_indices=["ki_s", "ji_s"], + temporal_inames=["io", "ko", "jo"], + temporary_address_space=lp.AddressSpace.LOCAL, + temporary_dtype=np.float64 + ) + + knl = lp.tag_inames( + knl, { + "io" : "g.0", # outer block loop over block rows + "jo" : "g.1", # outer block loop over block cols + + "ii" : "l.0", # inner block loop over rows + "ji" : "l.1", # inner block loop over cols + + "ii_s" : "l.0", # inner storage loop over a rows + "ji_s" : "l.0", # inner storage loop over b cols + "ki_s" : "l.1" # inner storage loop over a cols / b rows + } + ) + + knl = lp.add_inames_for_unused_hw_axes(knl) + + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + a = np.random.randn(M, K) + b = np.random.randn(K, N) + + ex = knl.executor(ctx) + _, out = ex(queue, a=a, b=b) + + import numpy.linalg as la + assert (la.norm((a @ b) - out) / la.norm(a @ b)) < 1e-15 + + if __name__ == "__main__": import sys if len(sys.argv) > 1: From 0d3011f751f28debbc18b631d75cfe8a490ec811 Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 7 Apr 2026 10:45:14 -0500 Subject: [PATCH 20/27] compute working for 2D plane in FD example --- .../compute-examples/compute-tiled-matmul.py | 40 ++-- .../finite-difference-2-5D.py | 19 +- loopy/symbolic.py | 20 ++ loopy/target/c/compyte | 2 +- loopy/transform/compute.py | 212 +++++++++--------- 5 files changed, 171 insertions(+), 
122 deletions(-) diff --git a/examples/python/compute-examples/compute-tiled-matmul.py b/examples/python/compute-examples/compute-tiled-matmul.py index 4780f2637..4cc932bcf 100644 --- a/examples/python/compute-examples/compute-tiled-matmul.py +++ b/examples/python/compute-examples/compute-tiled-matmul.py @@ -49,15 +49,15 @@ def main( knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") compute_map_a = nisl.make_map(f"""{{ - [is, ks] -> [ii_s, io, ki_s, ko] : - is = io * {bm} + ii_s and - ks = ko * {bk} + ki_s + [is, ks] -> [a_ii, io, a_ki, ko] : + is = io * {bm} + a_ii and + ks = ko * {bk} + a_ki }}""") compute_map_b = nisl.make_map(f"""{{ - [ks, js] -> [ki_s, ko, ji_s, jo] : - js = jo * {bn} + ji_s and - ks = ko * {bk} + ki_s + [ks, js] -> [b_ki, ko, b_ji, jo] : + js = jo * {bn} + b_ji and + ks = ko * {bk} + b_ki }}""") if use_compute: @@ -65,20 +65,24 @@ def main( knl, "a_", compute_map=compute_map_a, - storage_indices=["ii_s", "ki_s"], - temporal_inames=["io", "ko", "jo"], + storage_indices=["a_ii", "a_ki"], + temporal_inames=["io", "ko"], + temporary_name="a_tile", temporary_address_space=lp.AddressSpace.LOCAL, - temporary_dtype=np.float64 + temporary_dtype=np.float64, + compute_insn_id="a_load" ) knl = compute( knl, "b_", compute_map=compute_map_b, - storage_indices=["ki_s", "ji_s"], - temporal_inames=["io", "ko", "jo"], + storage_indices=["b_ki", "b_ji"], + temporal_inames=["ko", "jo"], + temporary_name="b_tile", temporary_address_space=lp.AddressSpace.LOCAL, - temporary_dtype=np.float64 + temporary_dtype=np.float64, + compute_insn_id="b_load" ) if use_precompute: @@ -91,15 +95,21 @@ def main( if not run_sequentially: knl = lp.tag_inames( knl, { + # inames for tiles "io" : "g.0", # outer block loop over block rows "jo" : "g.1", # outer block loop over block cols + # inames for within tile "ii" : "l.0", # inner block loop over rows "ji" : "l.1", # inner block loop over cols - "ii_s" : "l.0", # inner storage loop over a rows - "ji_s" : "l.0", 
# inner storage loop over b cols - "ki_s" : "l.1" # inner storage loop over a cols / b rows + # inames for 'a' compute + "a_ii" : "l.0", # inner storage loop over a rows + "a_ki" : "l.1", # inner storage loop over a cols + + # inames for 'b' compute + "b_ki" : "l.0", # inner storage loop over b rows + "b_ji" : "l.1", # inner storage loop over b cols } ) diff --git a/examples/python/compute-examples/finite-difference-2-5D.py b/examples/python/compute-examples/finite-difference-2-5D.py index 233ccf21d..8715a174d 100644 --- a/examples/python/compute-examples/finite-difference-2-5D.py +++ b/examples/python/compute-examples/finite-difference-2-5D.py @@ -54,7 +54,7 @@ def main( lap_u[i,j,k] = sum( [l], - c[l+r] * (u_(i-l,j,k) + u_(i,j-l,k) + u[i,j,k-l]) + c[l+r] * (u_(i-l,j,k) + u_(i,j-l,k) + u_(i,j,k-l)) ) """, [ @@ -95,6 +95,23 @@ def main( temporary_dtype=np.float32 ) + knl = lp.tag_inames(knl, { + # outer block loops + "io" : "g.0", + "jo" : "g.1", + "k" : "g.2", + # "ko" : "g.2", + + # inner tile loops + "ii" : "l.0", + "ji" : "l.1", + # "ki" : "l.2", + + # 2D plane compute storage loops + "ii_s" : "l.0", + "ji_s" : "l.1" + }) + if print_device_code: print(lp.generate_code_v2(knl).device_code()) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 74e6dc136..5c587c4e7 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -2564,6 +2564,26 @@ def constraint_to_cond_expr(cns: isl.Constraint) -> ArithmeticExpression: # }}} +# {{{ MultiPwAff from sequence of pymbolic exprs + +def multi_pw_aff_from_exprs( + exprs: Sequence[Expression], + space: isl.Space + ) -> isl.MultiPwAff: + + mpwaff = isl.MultiPwAff.zero(space) + for i in range(len(exprs)): + local_space = mpwaff.get_at(i).get_space().domain() + mpwaff = mpwaff.set_pw_aff( + i, + pwaff_from_expr(local_space, exprs[i]) + ) + + return mpwaff + +# }}} + + # {{{ isl_set_from_expr class ConditionExpressionToBooleanOpsExpression(IdentityMapper[[]]): diff --git a/loopy/target/c/compyte b/loopy/target/c/compyte index 
80ed45de9..2b168ca39 160000 --- a/loopy/target/c/compyte +++ b/loopy/target/c/compyte @@ -1 +1 @@ -Subproject commit 80ed45de98b5432341763b9fa52a00fdac870b89 +Subproject commit 2b168ca396aec2259da408f441f5e38ac9f95cb6 diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 3d6bb7e8f..1346a9162 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -3,7 +3,7 @@ from typing_extensions import TypeAlias import loopy as lp from loopy.kernel.tools import DomainChanger -from loopy.types import to_loopy_type +from loopy.types import ToLoopyTypeConvertible, to_loopy_type import namedisl as nisl from loopy.kernel import LoopKernel @@ -15,8 +15,8 @@ RuleAwareSubstitutionMapper, SubstitutionRuleExpander, SubstitutionRuleMappingContext, + multi_pw_aff_from_exprs, pw_aff_to_expr, - pwaff_from_expr ) from loopy.transform.precompute import ( contains_a_subst_rule_invocation @@ -35,8 +35,8 @@ AccessTuple: TypeAlias = tuple[Expression, ...] -# FIXME: move to loopy/symbolic.py -def gather_vars(expr) -> set[str]: +# helper for gathering names of variables in pymbolic expressions +def _gather_vars(expr: Expression) -> set[str]: deps = DependencyMapper()(expr) return { dep.name @@ -45,47 +45,7 @@ def gather_vars(expr) -> set[str]: } -# FIXME: move to loopy/symbolic.py -def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT): - names = sorted(set().union(*(gather_vars(expr) for expr in exprs))) - set_names = [name for name in names] - - return isl.Space.create_from_names( - ctx, - set=set_names - ) - - -# FIXME: remove this and rely on namedisl -def align_map_domain_to_set(m: isl.Map, s: isl.Set) -> isl.Map: - """ - Permute the domain dimensions of `m` to match the ordering of `s`, - by routing through the parameter space to preserve constraints. 
- - Example: - m = { [a, b, c, d] -> [e, f, g, h] } - s = { [d, c, b, a] } - result = { [d, c, b, a] -> [e, f, g, h] } - """ - dom_space = m.get_space().domain() - set_space = s.get_space() - - n = dom_space.dim(isl.dim_type.set) - assert set_space.dim(isl.dim_type.set) == n, "dimension count mismatch" - - dom_names = [dom_space.get_dim_name(isl.dim_type.set, i) for i in range(n)] - set_names = [set_space.get_dim_name(isl.dim_type.set, i) for i in range(n)] - assert set(dom_names) == set(set_names), "dimension names must be the same set" - - n_params = m.dim(isl.dim_type.param) - m = m.move_dims(isl.dim_type.param, n_params, isl.dim_type.in_, 0, n) - - for i, name in enumerate(set_names): - param_idx = m.find_dim_by_name(isl.dim_type.param, name) - m = m.move_dims(isl.dim_type.in_, i, isl.dim_type.param, param_idx, 1) - - return m - +# {{{ gathering usage expressions class UsageSiteExpressionGatherer(RuleAwareIdentityMapper[[]]): """ @@ -143,6 +103,10 @@ def map_subst_rule( return 0 +# }}} + + +# {{{ substitution rule use replacement class RuleInvocationReplacer(RuleAwareIdentityMapper[[]]): def __init__( @@ -150,11 +114,11 @@ def __init__( ctx: SubstitutionRuleMappingContext, subst_name: str, subst_tag: Sequence[Tag] | None, - usage_descriptors: Mapping[AccessTuple, isl.Map], + usage_descriptors: Mapping[AccessTuple, nisl.Map], storage_indices: Sequence[str], temporary_name: str, compute_insn_id: str, - global_usage_map: isl.Map + footprint: nisl.Set ) -> None: super().__init__(ctx) @@ -162,19 +126,20 @@ def __init__( self.subst_name: str = subst_name self.subst_tag: Sequence[Tag] | None = subst_tag - self.usage_descriptors: Mapping[AccessTuple, isl.Map] = \ + self.usage_descriptors: Mapping[AccessTuple, nisl.Map] = \ usage_descriptors self.storage_indices: Sequence[str] = storage_indices + self.footprint: nisl.Set = footprint self.temporary_name: str = temporary_name self.compute_insn_id: str = compute_insn_id + self.replaced_something: bool = False + # FIXME: 
may not always be the case (i.e. global barrier between # compute insn and uses) self.compute_dep_id: str = compute_insn_id - self.replaced_something: bool = False - @override def map_subst_rule( @@ -185,32 +150,47 @@ def map_subst_rule( expn_state: ExpansionState ) -> Expression: - if not name == self.subst_name: - return super().map_subst_rule(name, tags, arguments, expn_state) - rule = self.rule_mapping_context.old_subst_rules[name] arg_ctx = self.make_new_arg_context( name, rule.arguments, arguments, expn_state.arg_context ) args = [arg_ctx[arg_name] for arg_name in rule.arguments] - # FIXME: usage within footprint check? likely required if user supplies - # bounds on storage indices because we are not guaranteed to capture the - # footprint of all usage sites + # {{{ validation checks + + if not name == self.subst_name: + return super().map_subst_rule(name, tags, arguments, expn_state) + + if not tuple(args) in self.usage_descriptors: + return super().map_subst_rule(name, tags, arguments, expn_state) if not len(arguments) == len(rule.arguments): raise ValueError("Number of arguments passed to rule {name} ", "does not match the signature of {name}.") - index_exprs: Sequence[Expression] = [] + local_map = self.usage_descriptors[tuple(args)] + temp_footprint = self.footprint.move_dims( + frozenset(self.footprint.names) - frozenset(self.storage_indices), + isl.dim_type.param + ) + + if not local_map.range() <= temp_footprint: + return super().map_subst_rule(name, tags, arguments, expn_state) + + # }}} + + # {{{ get index expression in terms of global inames local_pwmaff = self.usage_descriptors[tuple(args)].as_pw_multi_aff() + index_exprs: Sequence[Expression] = [] for dim in range(local_pwmaff.dim(isl.dim_type.out)): index_exprs.append(pw_aff_to_expr(local_pwmaff.get_at(dim))) new_expression = var(self.temporary_name)[tuple(index_exprs)] + # }}} + self.replaced_something = True return new_expression @@ -251,6 +231,8 @@ def map_kernel( return 
kernel.copy(instructions=new_insns) +# }}} + @for_each_kernel def compute( @@ -259,14 +241,14 @@ def compute( compute_map: nisl.Map, storage_indices: Sequence[str], - # NOTE: how can we deduce this? temporal_inames: Sequence[str], temporary_name: str | None = None, temporary_address_space: AddressSpace | None = None, - # FIXME: typing - temporary_dtype = None + temporary_dtype: ToLoopyTypeConvertible = None, + + compute_insn_id: str | None = None ) -> LoopKernel: """ Inserts an instruction to compute an expression given by :arg:`substitution` @@ -284,7 +266,10 @@ def compute( to create inames for the loops that cover the required set of compute points. """ - # {{{ construct necessary pieces; footprint, global usage map + # {{{ setup and useful items + + storage_set = frozenset(storage_indices) + temporal_set = frozenset(temporal_inames) ctx = SubstitutionRuleMappingContext( kernel.substitutions, kernel.get_var_name_generator()) @@ -293,50 +278,65 @@ def compute( ctx, expander, kernel, substitution, None ) - # add compute inames to domain / kernel - domain_changer = DomainChanger(kernel, kernel.all_inames()) - named_domain = nisl.make_basic_set(domain_changer.domain) - _ = expr_gatherer.map_kernel(kernel) usage_exprs = expr_gatherer.usage_expressions all_exprs = [expr for usage in usage_exprs for expr in usage] - usage_inames = set.union(*(gather_vars(expr) for expr in all_exprs)) + usage_inames: frozenset[str] = frozenset( + set.union(*(_gather_vars(expr) for expr in all_exprs)) + ) - usage_domain = nisl.make_set(f"{{ [{",".join(iname for iname in usage_inames)}] }}") - footprint = nisl.make_set(f"{{ [{",".join(idx for idx in storage_indices)}] }}") + # }}} + # {{{ construct necessary pieces; footprint, global usage map + + # add compute inames to domain / kernel + domain_changer = DomainChanger(kernel, kernel.all_inames()) + named_domain = nisl.make_basic_set(domain_changer.domain) + + # restrict domain to used inames + usage_domain = 
named_domain.project_out_except(usage_inames) + + # FIXME: gross. find a cleaner way to generate a space for an empty map global_usage_map = nisl.make_map_from_domain_and_range( - usage_domain, + nisl.make_set(isl.Set.empty(usage_domain.get_space())), compute_map.domain() ) - global_usage_map = nisl.make_map(isl.Map.empty(global_usage_map.get_space())) usage_substs: Mapping[AccessTuple, nisl.Map] = {} for usage in usage_exprs: - # FIXME package sequence of pymbolic exprs -> multipwaff up as a function in loopy.symbolic - local_usage_mpwaff = isl.MultiPwAff.zero(global_usage_map.get_space()) + # {{{ compute local usage map, update global usage map - for i in range(len(storage_indices)): - local_space = local_usage_mpwaff.get_at(i).get_space().domain() - local_usage_mpwaff = local_usage_mpwaff.set_pw_aff( - i, - pwaff_from_expr(local_space, usage[i]) - ) + local_usage_mpwaff = multi_pw_aff_from_exprs( + usage, + global_usage_map.get_space() + ) local_usage_map = nisl.make_map(local_usage_mpwaff.as_map()) - local_usage_map = local_usage_map.intersect_domain(named_domain) + local_usage_map = local_usage_map.intersect_domain(usage_domain) global_usage_map = global_usage_map | local_usage_map + # }}} + + # {{{ compute storage map + local_storage_map = local_usage_map.apply_range(compute_map) - relevant_names = gather_vars(usage) - local_storage_map = local_storage_map.project_out_except( - (relevant_names - frozenset(temporal_inames)) | frozenset(storage_indices) - ) + # check that no restrictions happened during composition (i.e. 
tile + # valid for a single point in the domain) + if not local_usage_map.domain() <= local_storage_map.domain(): + continue + + non_param_names = (usage_inames - temporal_set) | storage_set + parameter_names = frozenset(local_storage_map.names) - non_param_names + + local_storage_map = local_storage_map.move_dims(parameter_names, + isl.dim_type.param) + + # }}} usage_substs[tuple(usage)] = local_storage_map @@ -346,26 +346,37 @@ def compute( # {{{ compute bounds and update kernel domain - footprint = global_usage_map.range() - footprint = footprint.project_out_except( - frozenset(temporal_inames) | frozenset(storage_indices) + global_usage_map = global_usage_map.move_dims( + temporal_set, + isl.dim_type.param ) + footprint = global_usage_map.range() + + # clean up ticked duplicate names + footprint = footprint.project_out_except(temporal_set | storage_set) + footprint = footprint.move_dims(temporal_set, isl.dim_type.set) + + # {{{ FIXME: use Sets instead of BasicSets when loopy is ready - # FIXME: probably do not want this permanently here footprint = nisl.make_set(footprint._reconstruct_isl_object().convex_hull()) named_domain = named_domain & footprint - # FIXME: if len(named_domain.get_basic_sets()) != 1: raise ValueError("New domain should be composed of a single basic set") - new_domains = domain_changer.get_domains_with(named_domain.get_basic_sets()[0]) + # FIXME: use named object once loopy is name-ified + domain = named_domain.get_basic_sets()[0]._reconstruct_isl_object() + new_domains = domain_changer.get_domains_with(domain) + + # }}} + kernel = kernel.copy(domains=new_domains) # }}} # {{{ create compute instruction in kernel + # FIXME: compute_pw_aff = compute_map.reverse().as_pw_multi_aff() storage_ax_to_global_expr = { compute_pw_aff.get_dim_name(isl.dim_type.out, dim) : @@ -385,16 +396,12 @@ def compute( if not temporary_name: temporary_name = substitution + "_temp" - assignee = var(temporary_name)[tuple( - var(iname) for iname in storage_indices - 
)] + assignee = var(temporary_name)[tuple(var(idx) for idx in storage_indices)] - within_inames = frozenset( - compute_map.get_dim_name(isl.dim_type.out, dim) - for dim in range(compute_map.dim(isl.dim_type.out)) - ) + within_inames = compute_map.output_names - compute_insn_id = substitution + "_compute" + if not compute_insn_id: + compute_insn_id = substitution + "_compute" compute_insn = lp.Assignment( id=compute_insn_id, assignee=assignee, @@ -422,7 +429,7 @@ def compute( storage_indices, temporary_name, compute_insn_id, - global_usage_map + footprint ) kernel = replacer.map_kernel(kernel) @@ -433,12 +440,9 @@ def compute( loopy_type = to_loopy_type(temporary_dtype, allow_none=True) - # FIXME: fix with namedisl? - shape_domain = footprint.project_out_except(storage_indices) - temp_shape = tuple( - pw_aff_to_expr(shape_domain.dim_max(dim)) + 1 - for dim in range(shape_domain.dim(isl.dim_type.out)) + pw_aff_to_expr(footprint.dim_max(dim)) + 1 + for dim in storage_indices ) new_temp_vars = dict(kernel.temporary_variables) @@ -461,6 +465,4 @@ def compute( # }}} - # FIXME: anything else? 
- return kernel From 9d5c9aab9bb2c2418243128e123bb93b636722e9 Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 16 Apr 2026 09:34:56 -0500 Subject: [PATCH 21/27] rough version of shifting for ring-buffer-like compute transforms --- .../compute-examples/compute-tiled-matmul.py | 176 -------- .../finite-difference-2-5D.py | 263 +++++++++--- examples/python/compute-examples/matmul.py | 400 ++++++++++++++++++ .../wave-equation-ring-buffer.py | 211 +++++++++ loopy/transform/compute.py | 390 +++++++++++++---- 5 files changed, 1136 insertions(+), 304 deletions(-) delete mode 100644 examples/python/compute-examples/compute-tiled-matmul.py create mode 100644 examples/python/compute-examples/matmul.py create mode 100644 examples/python/compute-examples/wave-equation-ring-buffer.py diff --git a/examples/python/compute-examples/compute-tiled-matmul.py b/examples/python/compute-examples/compute-tiled-matmul.py deleted file mode 100644 index 4cc932bcf..000000000 --- a/examples/python/compute-examples/compute-tiled-matmul.py +++ /dev/null @@ -1,176 +0,0 @@ -import namedisl as nisl - -import loopy as lp -from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 -from loopy.transform.compute import compute - -import numpy as np -import numpy.linalg as la -import pyopencl as cl - - -def main( - M: int = 128, - N: int = 128, - K: int = 128, - bm: int = 32, - bn: int = 32, - bk: int = 16, - run_sequentially: bool = False, - use_precompute: bool = False, - use_compute: bool = False, - run_kernel: bool = False, - print_kernel: bool = False, - print_device_code: bool = False - ) -> None: - - knl = lp.make_kernel( - "{ [i, j, k] : 0 <= i < M and 0 <= j < N and 0 <= k < K }", - """ - a_(is, ks) := a[is, ks] - b_(ks, js) := b[ks, js] - c[i, j] = sum([k], a_(i, k) * b_(k, j)) - """, - [ - lp.GlobalArg("a", shape=(M, K), dtype=np.float64), - lp.GlobalArg("b", shape=(K, N), dtype=np.float64), - lp.GlobalArg("c", shape=(M, N), dtype=np.float64, - is_output=True) - ] - ) - - # FIXME: without 
this, there are complaints about in-bounds access - # guarantees for the instruction that stores into c - knl = lp.fix_parameters(knl, M=M, N=N, K=K) - - # shared memory tile-level splitting - knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") - knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") - knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") - - compute_map_a = nisl.make_map(f"""{{ - [is, ks] -> [a_ii, io, a_ki, ko] : - is = io * {bm} + a_ii and - ks = ko * {bk} + a_ki - }}""") - - compute_map_b = nisl.make_map(f"""{{ - [ks, js] -> [b_ki, ko, b_ji, jo] : - js = jo * {bn} + b_ji and - ks = ko * {bk} + b_ki - }}""") - - if use_compute: - knl = compute( - knl, - "a_", - compute_map=compute_map_a, - storage_indices=["a_ii", "a_ki"], - temporal_inames=["io", "ko"], - temporary_name="a_tile", - temporary_address_space=lp.AddressSpace.LOCAL, - temporary_dtype=np.float64, - compute_insn_id="a_load" - ) - - knl = compute( - knl, - "b_", - compute_map=compute_map_b, - storage_indices=["b_ki", "b_ji"], - temporal_inames=["ko", "jo"], - temporary_name="b_tile", - temporary_address_space=lp.AddressSpace.LOCAL, - temporary_dtype=np.float64, - compute_insn_id="b_load" - ) - - if use_precompute: - knl = lp.precompute( - knl, - "a_", - sweep_inames=["ii", "ki"], - ) - - if not run_sequentially: - knl = lp.tag_inames( - knl, { - # inames for tiles - "io" : "g.0", # outer block loop over block rows - "jo" : "g.1", # outer block loop over block cols - - # inames for within tile - "ii" : "l.0", # inner block loop over rows - "ji" : "l.1", # inner block loop over cols - - # inames for 'a' compute - "a_ii" : "l.0", # inner storage loop over a rows - "a_ki" : "l.1", # inner storage loop over a cols - - # inames for 'b' compute - "b_ki" : "l.0", # inner storage loop over b rows - "b_ji" : "l.1", # inner storage loop over b cols - } - ) - - knl = lp.add_inames_for_unused_hw_axes(knl) - - if run_kernel: - a = 
np.random.randn(M, K) - b = np.random.randn(K, N) - - ctx = cl.create_some_context() - queue = cl.CommandQueue(ctx) - - ex = knl.executor(ctx) - _, out = ex(queue, a=a, b=b) - - print(20*"=", "Tiled matmul report", 20*"=") - print(f"Problem size: M = {M:-4}, N = {N:-4}, K = {K:-4}") - print(f"Tile size : BM = {bm:-4}, BN = {bn:-4}, BK = {bk:-4}") - print(f"Relative error = {la.norm((a @ b) - out) / la.norm(a @ b)}") - print((40 + len(" Tiled matmul report "))*"=") - - if print_device_code: - print(lp.generate_code_v2(knl).device_code()) - - if print_kernel: - print(knl) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - - _ = parser.add_argument("--precompute", action="store_true") - _ = parser.add_argument("--compute", action="store_true") - _ = parser.add_argument("--run-kernel", action="store_true") - _ = parser.add_argument("--print-kernel", action="store_true") - _ = parser.add_argument("--print-device-code", action="store_true") - _ = parser.add_argument("--run-sequentially", action="store_true") - - _ = parser.add_argument("--m", action="store", type=int, default=128) - _ = parser.add_argument("--n", action="store", type=int, default=128) - _ = parser.add_argument("--k", action="store", type=int, default=128) - - _ = parser.add_argument("--bm", action="store", type=int, default=32) - _ = parser.add_argument("--bn", action="store", type=int, default=32) - _ = parser.add_argument("--bk", action="store", type=int, default=16) - - args = parser.parse_args() - - main( - M=args.m, - N=args.n, - K=args.k, - bm=args.bm, - bn=args.bn, - bk=args.bk, - use_precompute=args.precompute, - use_compute=args.compute, - run_kernel=args.run_kernel, - print_kernel=args.print_kernel, - print_device_code=args.print_device_code, - run_sequentially=args.run_sequentially - ) diff --git a/examples/python/compute-examples/finite-difference-2-5D.py b/examples/python/compute-examples/finite-difference-2-5D.py index 8715a174d..ff15afb49 100644 --- 
a/examples/python/compute-examples/finite-difference-2-5D.py +++ b/examples/python/compute-examples/finite-difference-2-5D.py @@ -1,17 +1,30 @@ -import loopy as lp -from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 -from loopy.transform.compute import compute +import time import namedisl as nisl - import numpy as np import numpy.linalg as la import pyopencl as cl +import loopy as lp +from loopy.transform.compute import compute +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 -# FIXME: more complicated function, or better yet define a set of functions -# with sympy and get the exact laplacian symbolically + +def centered_second_derivative_coefficients(radius: int, dtype) -> np.ndarray: + offsets = np.arange(-radius, radius + 1, dtype=dtype) + powers = np.arange(2 * radius + 1) + + # Enforce exactness for monomials through degree 2r. The second derivative + # of x**2 at zero is 2, all other monomial derivatives in this range vanish. + vandermonde = offsets[np.newaxis, :] ** powers[:, np.newaxis] + rhs = np.zeros(2 * radius + 1, dtype=dtype) + rhs[2] = 2 + + return np.linalg.solve(vandermonde, rhs).astype(dtype) + + +# FIXME: choose a better test case def f(x, y, z): return x**2 + y**2 + z**2 @@ -20,33 +33,62 @@ def laplacian_f(x, y, z): return 6 * np.ones_like(x) +def benchmark_executor(ex, queue, args, warmup: int, iterations: int) -> float: + if iterations <= 0: + raise ValueError("iterations must be positive") + + evt = None + for _ in range(warmup): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + + start = time.perf_counter() + for _ in range(iterations): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + end = time.perf_counter() + + return (end - start) / iterations + + +def laplacian_flop_count(npts: int, stencil_width: int) -> int: + radius = stencil_width // 2 + output_points = (npts - 2 * radius) ** 3 + return 4 * stencil_width * output_points + + def main( + npts: int = 64, + stencil_width: int = 5, 
use_compute: bool = False, print_device_code: bool = False, - print_kernel: bool = False - ) -> None: - npts = 64 + print_kernel: bool = False, + run_kernel: bool = False, + warmup: int = 3, + iterations: int = 10 + ) -> float | None: + if stencil_width <= 0 or stencil_width % 2 == 0: + raise ValueError("stencil_width must be a positive odd integer") + pts = np.linspace(-1, 1, num=npts, endpoint=True) h = pts[1] - pts[0] x, y, z = np.meshgrid(*(pts,)*3) - dtype = np.float32 - x = x.reshape(*(npts,)*3).astype(np.float32) - y = y.reshape(*(npts,)*3).astype(np.float32) - z = z.reshape(*(npts,)*3).astype(np.float32) - - f_ = f(x, y, z) - lap_fd = np.zeros_like(f_) - c = (np.array([-1/12, 4/3, -5/2, 4/3, -1/12]) / h**2).astype(dtype) + dtype = np.float64 + x = x.reshape(*(npts,)*3).astype(dtype) + y = y.reshape(*(npts,)*3).astype(dtype) + z = z.reshape(*(npts,)*3).astype(dtype) - m = 5 + m = stencil_width r = m // 2 + c = (centered_second_derivative_coefficients(r, dtype) / h**2).astype(dtype) - bm = bn = m + bm = bn = 16 + bk = 32 - # FIXME: the usage on the k dimension is incorrect since we are only testing - # tiling (i, j) planes knl = lp.make_kernel( "{ [i, j, k, l] : r <= i, j, k < npts - r and -r <= l < r + 1 }", """ @@ -58,58 +100,91 @@ def main( ) """, [ - lp.GlobalArg("u", dtype=dtype, shape=(npts,npts,npts)), - lp.GlobalArg("lap_u", dtype=dtype, shape=(npts,npts,npts), + lp.GlobalArg("u", dtype=dtype, shape=(npts, npts, npts)), + lp.GlobalArg("lap_u", dtype=dtype, shape=(npts, npts, npts), is_output=True), - lp.GlobalArg("c", dtype=dtype, shape=(m)) - ] + lp.GlobalArg("c", dtype=dtype, shape=(m,)) + ], + lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2 ) knl = lp.fix_parameters(knl, npts=npts, r=r) knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") - - # FIXME: need to split k dimension + knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") if 
use_compute: - compute_map = nisl.make_map( - f""" - {{ - [is, js, ks] -> [io, ii_s, jo, ji_s, k] : + plane_map = nisl.make_map(f"""{{ + [is, js, ks] -> [io, ii_s, jo, ji_s, ko, ki] : is = io * {bm} + ii_s - {r} and js = jo * {bn} + ji_s - {r} and - ks = k - }} - """ - ) + ks = ko * {bk} + ki + }}""") knl = compute( knl, "u_", - compute_map=compute_map, + compute_map=plane_map, storage_indices=["ii_s", "ji_s"], - temporal_inames=["io", "jo", "k"], - temporary_name="u_compute", + temporal_inames=["io", "jo", "ko", "ki"], + + temporary_name="u_ij_plane", temporary_address_space=lp.AddressSpace.LOCAL, - temporary_dtype=np.float32 + temporary_dtype=dtype, + + compute_insn_id="u_plane_compute" ) + ring_buffer_map = nisl.make_map(f"""{{ + [is, js, ks] -> [io, ii, jo, ji, ko, ki, kb] : + is = io * {bm} + ii and + js = jo * {bn} + ji and + kb = ks - (ko * {bk} + ki) + {r} + }}""") + + knl = compute( + knl, + "u_", + compute_map=ring_buffer_map, + storage_indices=["kb"], + temporal_inames=["io", "ii", "jo", "ji", "ko", "ki"], + + temporary_name="u_k_buf", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + + compute_insn_id="u_ring_buf_compute", + inames_to_advance=["ki"] + ) + + nt = 16 + knl = lp.split_iname( + knl, "ii_s", nt, outer_iname="ii_s_tile", inner_iname="ii_s_local" + ) + + knl = lp.split_iname( + knl, "ji_s", nt, outer_iname="ji_s_tile", inner_iname="ji_s_local" + ) + + knl = lp.tag_inames(knl, { + # 2D plane compute storage loops + "ii_s_local": "l.1", + "ji_s_local": "l.0", + + # force the use of registers by unrolling + "kb": "unr" + }) + knl = lp.tag_inames(knl, { # outer block loops - "io" : "g.0", - "jo" : "g.1", - "k" : "g.2", - # "ko" : "g.2", + "io": "g.2", + "jo": "g.1", + "ko": "g.0", # inner tile loops - "ii" : "l.0", - "ji" : "l.1", - # "ki" : "l.2", - - # 2D plane compute storage loops - "ii_s" : "l.0", - "ji_s" : "l.1" + "ii": "l.1", + "ji": "l.0", }) if print_device_code: @@ -118,16 +193,42 @@ def main( if 
print_kernel: print(knl) + if not run_kernel: + return None + ctx = cl.create_some_context() queue = cl.CommandQueue(ctx) ex = knl.executor(queue) - _, lap_fd = ex(queue, u=f(x, y, z), c=c) + f_vals = f(x, y, z) + + import pyopencl.array as cl_array + f_vals_cl = cl_array.to_device(queue, f_vals) + c_cl = cl_array.to_device(queue, c) + lap_u_cl = cl_array.zeros(queue, (npts,)*3, dtype=f_vals_cl.dtype) + avg_time_per_iter = benchmark_executor( + ex, queue, {"u": f_vals_cl, "c": c_cl, "lap_u": lap_u_cl}, + warmup=warmup, iterations=iterations) + avg_gflops = laplacian_flop_count(npts, stencil_width) / avg_time_per_iter / 1e9 + + _, lap_fd = ex(queue, u=f_vals_cl, c=c_cl, lap_u=lap_u_cl) lap_true = laplacian_f(x, y, z) sl = (slice(r, npts - r),)*3 - print(la.norm(lap_true[sl] - lap_fd[0][sl]) / la.norm(lap_true[sl])) + rel_err = la.norm(lap_true[sl] - lap_fd[0].get()[sl]) / la.norm(lap_true[sl]) + + print(20 * "=", "Finite difference report", 20 * "=") + print(f"Variant : {'compute' if use_compute else 'baseline'}") + print(f"Grid points : {npts}^3") + print(f"Stencil width: {stencil_width}") + print(f"Iterations : warmup = {warmup}, timed = {iterations}") + print(f"Average time per iteration: {avg_time_per_iter:.6e} s") + print(f"Average throughput: {avg_gflops:.3f} GFLOP/s") + print(f"Relative error: {rel_err:.3e}") + print((40 + len(" Finite difference report ")) * "=") + + return avg_time_per_iter if __name__ == "__main__": @@ -135,14 +236,62 @@ def main( parser = argparse.ArgumentParser() + _ = parser.add_argument("--npoints", action="store", type=int, default=64) + _ = parser.add_argument("--stencil-width", action="store", type=int, default=5) + + _ = parser.add_argument("--compare", action="store_true") _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--no-run-kernel", action="store_false", + dest="run_kernel") _ = parser.add_argument("--print-device-code", 
action="store_true") _ = parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--warmup", action="store", type=int, default=3) + _ = parser.add_argument("--iterations", action="store", type=int, default=10) args = parser.parse_args() - main( - use_compute=args.compute, - print_device_code=args.print_device_code, - print_kernel=args.print_kernel - ) + if args.compare: + print("Running example without compute...") + no_compute_time = main( + npts=args.npoints, + stencil_width=args.stencil_width, + use_compute=False, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=True, + warmup=args.warmup, + iterations=args.iterations, + ) + print(50 * "=", "\n") + + print("Running example with compute...") + compute_time = main( + npts=args.npoints, + stencil_width=args.stencil_width, + use_compute=True, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=True, + warmup=args.warmup, + iterations=args.iterations, + ) + print(50 * "=", "\n") + + assert no_compute_time is not None + assert compute_time is not None + speedup = no_compute_time / compute_time + print(f"Speedup: {speedup:.3f}x") + time_reduction = (1 - compute_time / no_compute_time) * 100 + print(f"Relative time reduction: {time_reduction:.2f}%") + else: + _ = main( + npts=args.npoints, + stencil_width=args.stencil_width, + use_compute=args.compute, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=args.run_kernel, + warmup=args.warmup, + iterations=args.iterations, + ) diff --git a/examples/python/compute-examples/matmul.py b/examples/python/compute-examples/matmul.py new file mode 100644 index 000000000..39acb9ef2 --- /dev/null +++ b/examples/python/compute-examples/matmul.py @@ -0,0 +1,400 @@ +import time + +import namedisl as nisl +import numpy as np +import numpy.linalg as la + +import pyopencl as cl +import pyopencl.array as cl_array + +import loopy as lp +from 
loopy.transform.compute import compute +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 + + +def benchmark_kernel( + knl: lp.TranslationUnit, + queue: cl.CommandQueue, + a: np.ndarray, + b: np.ndarray, + nwarmup: int = 5, + niterations: int = 20 +): + ex = knl.executor(queue) + + a_cl = cl_array.to_device(queue, a) + b_cl = cl_array.to_device(queue, b) + c_cl = cl_array.zeros(queue, (a.shape[0], b.shape[1]), dtype=a_cl.dtype) + + start = cl.enqueue_marker(queue) + for _ in range(nwarmup): + ex(queue, a=a_cl, b=b_cl, c=c_cl) + end = cl.enqueue_marker(queue) + end.wait() + start.wait() + + start = cl.enqueue_marker(queue) + for _ in range(niterations): + ex(queue, a=a_cl, b=b_cl, c=c_cl) + end = cl.enqueue_marker(queue) + end.wait() + start.wait() + + total_ns = end.profile.end - start.profile.end + total_elapsed_s = total_ns * 1e-9 + s_per_iter = total_elapsed_s / niterations + + total_flops = 2 * a.shape[0] * a.shape[1] * b.shape[1] + gflops = (total_flops / s_per_iter) * 1e-9 + + c_ref = a @ b + _, c_res = ex(queue, a=a_cl, b=b_cl, c=c_cl) + + error = la.norm(c_res[0].get() - c_ref) / la.norm(c_ref) + + m, k = a.shape + _, n = b.shape + print(f"================= Results =================") + print(f"M = {m}, N = {n}, K = {k}") + print(f" Error = {error:.4}") + print(f" Total time (s): {total_elapsed_s:.4}") + print(f"Time per iter (s): {s_per_iter:.4}") + print(f" GFLOP/s: {gflops}") + print(f"===========================================") + + +def naive_matmul( + knl: lp.TranslationUnit, + bm: int, + bn: int, + bk: int + ) -> lp.TranslationUnit: + knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") + knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") + knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") + + iname_tags = { + "io": "g.1", + "jo": "g.0", + + "ii": "l.1", + "ji": "l.0" + } + + return lp.tag_inames(knl, iname_tags) + + +def shared_memory_tiled_matmul( + knl: lp.TranslationUnit, + bm: 
int, + bn: int, + bk: int + ) -> lp.TranslationUnit: + knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") + knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") + knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") + + compute_map_a = nisl.make_map(f"""{{ + [is, ks] -> [a_ii, io, a_ki, ko, jo] : + is = io * {bm} + a_ii and + ks = ko * {bk} + a_ki + }}""") + + compute_map_b = nisl.make_map(f"""{{ + [ks, js] -> [b_ki, ko, b_ji, jo, io] : + js = jo * {bn} + b_ji and + ks = ko * {bk} + b_ki + }}""") + + knl = compute( + knl, + "a_", + compute_map=compute_map_a, + storage_indices=["a_ii", "a_ki"], + temporal_inames=["io", "ko"], + temporary_name="a_tile", + temporary_address_space=lp.AddressSpace.LOCAL, + compute_insn_id="a_load" + ) + + knl = compute( + knl, + "b_", + compute_map=compute_map_b, + storage_indices=["b_ki", "b_ji"], + temporal_inames=["ko", "jo", "io"], + temporary_name="b_tile", + temporary_address_space=lp.AddressSpace.LOCAL, + compute_insn_id="b_load" + ) + + iname_tags = { + "io": "g.1", + "ii": "l.1", + + "jo": "g.0", + "ji": "l.0", + + "a_ii": "l.1", + "a_ki": "l.0", + + "b_ki": "l.1", + "b_ji": "l.0" + } + + return lp.tag_inames(knl, iname_tags) + + +def register_tiled_matmul( + knl: lp.TranslationUnit, + bm: int, + bn: int, + bk: int, + tm: int, + tn: int + ) -> lp.TranslationUnit: + + # {{{ shared-memory-level split / compute + + knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") + knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") + knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") + + compute_map_a = nisl.make_map(f"""{{ + [is, ks] -> [a_ii, io, a_ki, ko, jo] : + is = io * {bm} + a_ii and + ks = ko * {bk} + a_ki + }}""") + + compute_map_b = nisl.make_map(f"""{{ + [ks, js] -> [b_ki, ko, b_ji, jo, io] : + js = jo * {bn} + b_ji and + ks = ko * {bk} + b_ki + }}""") + + knl = compute( + knl, + "a_", + compute_map=compute_map_a, + 
storage_indices=["a_ii", "a_ki"], + temporal_inames=["io", "ko"], + temporary_name="a_smem", + temporary_address_space=lp.AddressSpace.LOCAL, + compute_insn_id="a_load" + ) + + knl = compute( + knl, + "b_", + compute_map=compute_map_b, + storage_indices=["b_ki", "b_ji"], + temporal_inames=["ko", "jo"], + temporary_name="b_smem", + temporary_address_space=lp.AddressSpace.LOCAL, + compute_insn_id="b_load" + ) + + wg_size_i = bm // tm + wg_size_j = bn // tn + knl = lp.split_iname( + knl, + "a_ii", + wg_size_i, + inner_iname="a_local", + outer_iname="a_tile" + ) + + knl = lp.split_iname( + knl, + "b_ji", + wg_size_j, + inner_iname="b_local", + outer_iname="b_tile" + ) + + # }}} + + # {{{ register-level split / compute + + knl = lp.extract_subst( + knl, + "a_smem_", + "a_smem[is, ks]", + parameters="is, ks" + ) + + knl = lp.extract_subst( + knl, + "b_smem_", + "b_smem[ks, js]", + parameters="ks, js" + ) + + knl = lp.split_iname(knl, "ii", tm, + inner_iname="ii_reg", + outer_iname="ii_thr") + + knl = lp.split_iname(knl, "ji", tn, + inner_iname="ji_reg", + outer_iname="ji_thr") + + knl = lp.split_iname(knl, "ki", 8, + inner_iname="dot", + outer_iname="ki_outer") + + a_reg_tile = nisl.make_map(f"""{{ + [is, ks] -> [a_reg_i, ii_thr, ji_thr, ki_outer, dot, io, jo, ko] : + is = ii_thr * {tm} + a_reg_i and + ks = ki_outer * 8 + dot + }}""") + + b_reg_tile = nisl.make_map(f"""{{ + [ks, js] -> [b_reg_j, ki_outer, dot, ii_thr, ji_thr, io, jo, ko] : + ks = ki_outer * 8 + dot and + js = ji_thr * {tn} + b_reg_j + }}""") + + knl = compute( + knl, + "a_smem_", + compute_map=a_reg_tile, + storage_indices=["a_reg_i"], + temporal_inames=["ii_thr", "ji_thr", "ki_outer", "dot", "io", "jo", "ko"], + temporary_name="a_reg", + temporary_address_space=lp.AddressSpace.PRIVATE, + compute_insn_id="a_reg_load" + ) + + knl = compute( + knl, + "b_smem_", + compute_map=b_reg_tile, + storage_indices=["b_reg_j"], + temporal_inames=["ii_thr", "ji_thr", "ki_outer", "dot", "io", "jo", "ko"], + 
temporary_name="b_reg", + temporary_address_space=lp.AddressSpace.PRIVATE, + compute_insn_id="b_reg_load" + ) + + # }}} + + iname_tags = { + # global tiles + "io" : "g.1", + "jo" : "g.0", + + # a local storage axes + "a_local": "l.1", + "a_ki" : "l.0", + + # b local storage axes + "b_local": "l.0", + "b_ki" : "l.1", + + # register tiles + "ii_thr": "l.1", + "ji_thr": "l.0", + + # register storage axes + "a_reg_i": "ilp", + "b_reg_j": "ilp", + + # compute axes + "ii_reg": "ilp", + "ji_reg": "ilp" + } + + return lp.tag_inames(knl, iname_tags) + + +def main( + m: int = 1024, + n: int = 1024, + k: int = 1024, + bm: int = 64, + bn: int = 64, + bk: int = 32, + tm: int = 4, + tn: int = 4, + shared_memory_tiled: bool = False, + register_tiled: bool = False, + dtype=np.float32, + print_kernel: bool = False, + print_device_code: bool = False + ) -> None: + + knl = lp.make_kernel( + "{ [i, j, k] : 0 <= i < M and 0 <= j < N and 0 <= k < K }", + """ + a_(is, ks) := a[is, ks] + b_(ks, js) := b[ks, js] + + c[i, j] = sum([k], a_(i, k) * b_(k, j)) + """, + [ + lp.GlobalArg("a", shape=(m, k), dtype=dtype), + lp.GlobalArg("b", shape=(k, n), dtype=dtype), + lp.GlobalArg("c", shape=(m, n), is_output=True) + ] + ) + + knl = lp.fix_parameters(knl, M=m, N=n, K=k) + + if shared_memory_tiled: + knl = shared_memory_tiled_matmul(knl, bm, bn, bk) + elif register_tiled: + knl = register_tiled_matmul(knl, bm, bn, bk, tm, tn) + else: + knl = naive_matmul(knl, bm, bn, bk) + + ctx = cl.create_some_context() + queue = cl.CommandQueue( + ctx, + properties=cl.command_queue_properties.PROFILING_ENABLE + ) + + a = np.random.randn(m, k).astype(dtype) + b = np.random.randn(k, n).astype(dtype) + + benchmark_kernel(knl, queue, a, b) + + if print_kernel: + print(knl) + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + + _ = parser.add_argument("--m", action="store", type=int, default=1024) + _ = 
parser.add_argument("--n", action="store", type=int, default=1024) + _ = parser.add_argument("--k", action="store", type=int, default=1024) + + _ = parser.add_argument("--bm", action="store", type=int, default=64) + _ = parser.add_argument("--bn", action="store", type=int, default=64) + _ = parser.add_argument("--bk", action="store", type=int, default=16) + + _ = parser.add_argument("--tm", action="store", type=int, default=4) + _ = parser.add_argument("--tn", action="store", type=int, default=4) + + _ = parser.add_argument("--shared-memory-tiled", action="store_true") + _ = parser.add_argument("--register-tiled", action="store_true") + + _ = parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--print-device-code", action="store_true") + + args = parser.parse_args() + + main( + m=args.m, n=args.n, k=args.k, + bm=args.bm, bn=args.bn, bk=args.bk, + tm=args.tm, tn=args.tn, + shared_memory_tiled=args.shared_memory_tiled, + register_tiled=args.register_tiled, + print_kernel=args.print_kernel, + print_device_code=args.print_device_code + ) diff --git a/examples/python/compute-examples/wave-equation-ring-buffer.py b/examples/python/compute-examples/wave-equation-ring-buffer.py new file mode 100644 index 000000000..501f89502 --- /dev/null +++ b/examples/python/compute-examples/wave-equation-ring-buffer.py @@ -0,0 +1,211 @@ +import time + +import namedisl as nisl +import numpy as np +import numpy.linalg as la + +import pyopencl as cl + +import loopy as lp +from loopy.transform.compute import compute +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 + + +def benchmark_executor(ex, queue, args, warmup: int, iterations: int) -> float: + if iterations <= 0: + raise ValueError("iterations must be positive") + + evt = None + for _ in range(warmup): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + + start = time.perf_counter() + for _ in range(iterations): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() 
+ end = time.perf_counter() + + return (end - start) / iterations + + +def wave_flop_count(ntime: int) -> int: + return 5 * (ntime - 2) + + +def main( + ntime: int = 128, + use_compute: bool = False, + print_device_code: bool = False, + print_kernel: bool = False, + run_kernel: bool = False, + warmup: int = 3, + iterations: int = 10 + ) -> float | None: + dtype = np.float64 + + dt = dtype(1 / 512) + omega = dtype(2 * np.pi) + omega2 = dtype(omega**2) + + t = dt * np.arange(ntime, dtype=dtype) + u = np.cos(omega * t).astype(dtype) + + bt = 32 + + knl = lp.make_kernel( + "{ [t] : 1 <= t < ntime - 1 }", + """ + u_hist(ts) := u[ts] + + u_next[t + 1] = ( + 2 * u_hist(t) + - u_hist(t - 1) + - dt2 * omega2 * u_hist(t) + ) + """, + [ + lp.GlobalArg("u", dtype=dtype, shape=(ntime,)), + lp.GlobalArg("u_next", dtype=dtype, shape=(ntime,), + is_output=True), + lp.ValueArg("dt2", dtype=dtype), + lp.ValueArg("omega2", dtype=dtype), + ], + lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2, + ) + + knl = lp.fix_parameters(knl, ntime=ntime) + knl = lp.split_iname(knl, "t", bt, inner_iname="ti", outer_iname="to") + + if use_compute: + ring_buffer_map = nisl.make_map(f"""{{ + [ts] -> [to, ti, tb] : + tb = ts - (to * {bt} + ti) + 1 + }}""") + + knl = compute( + knl, + "u_hist", + compute_map=ring_buffer_map, + storage_indices=["tb"], + temporal_inames=["to", "ti"], + + temporary_name="u_time_buf", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + + compute_insn_id="u_time_buf_compute", + inames_to_advance=["ti"], + ) + + knl = lp.tag_inames(knl, {"tb": "unr"}) + + knl = lp.tag_inames(knl, {"to": "g.0"}) + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + if print_kernel: + print(knl) + + if not run_kernel: + return None + + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + + ex = knl.executor(queue) + + dt2 = dtype(dt**2) + avg_time_per_iter = benchmark_executor( + ex, queue, {"u": u, "dt2": dt2, "omega2": omega2}, + 
warmup=warmup, iterations=iterations) + avg_gflops = wave_flop_count(ntime) / avg_time_per_iter / 1e9 + + _, out = ex(queue, u=u, dt2=dt2, omega2=omega2) + + ref = np.zeros_like(u) + for time_idx in range(1, ntime - 1): + ref[time_idx + 1] = ( + 2 * u[time_idx] + - u[time_idx - 1] + - dt2 * omega2 * u[time_idx] + ) + + sl = slice(2, ntime) + rel_err = la.norm(ref[sl] - out[0][sl]) / la.norm(ref[sl]) + + print(20 * "=", "Wave recurrence report", 20 * "=") + print(f"Variant : {'compute' if use_compute else 'baseline'}") + print(f"Time steps : {ntime}") + print(f"Iterations : warmup = {warmup}, timed = {iterations}") + print(f"Average time per iteration: {avg_time_per_iter:.6e} s") + print(f"Average throughput: {avg_gflops:.3f} GFLOP/s") + print(f"Relative error: {rel_err:.3e}") + print((40 + len(" Wave recurrence report ")) * "=") + + return avg_time_per_iter + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + + _ = parser.add_argument("--ntime", action="store", type=int, default=128) + + _ = parser.add_argument("--compare", action="store_true") + _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--no-run-kernel", action="store_false", + dest="run_kernel") + _ = parser.add_argument("--print-device-code", action="store_true") + _ = parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--warmup", action="store", type=int, default=3) + _ = parser.add_argument("--iterations", action="store", type=int, default=10) + + args = parser.parse_args() + + if args.compare: + print("Running example without compute...") + no_compute_time = main( + ntime=args.ntime, + use_compute=False, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=True, + warmup=args.warmup, + iterations=args.iterations, + ) + print(50 * "=", "\n") + + print("Running example with compute...") + 
compute_time = main( + ntime=args.ntime, + use_compute=True, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=True, + warmup=args.warmup, + iterations=args.iterations, + ) + print(50 * "=", "\n") + + assert no_compute_time is not None + assert compute_time is not None + speedup = no_compute_time / compute_time + print(f"Speedup: {speedup:.3f}x") + time_reduction = (1 - compute_time / no_compute_time) * 100 + print(f"Relative time reduction: {time_reduction:.2f}%") + else: + _ = main( + ntime=args.ntime, + use_compute=args.compute, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=args.run_kernel, + warmup=args.warmup, + iterations=args.iterations, + ) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 1346a9162..6d7536b87 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -1,13 +1,18 @@ -from collections.abc import Mapping, Sequence, Set -from typing import override -from typing_extensions import TypeAlias -import loopy as lp -from loopy.kernel.tools import DomainChanger -from loopy.types import ToLoopyTypeConvertible, to_loopy_type +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeAlias, override + import namedisl as nisl -from loopy.kernel import LoopKernel -from loopy.kernel.data import AddressSpace +import islpy as isl +import pymbolic.primitives as p +from pymbolic import var +from pymbolic.mapper.dependency import DependencyMapper +from pymbolic.mapper.substitutor import make_subst_func +from pytools.tag import Tag + +import loopy as lp +from loopy.kernel.tools import DomainChanger from loopy.match import StackMatch, parse_stack_match from loopy.symbolic import ( ExpansionState, @@ -18,33 +23,119 @@ multi_pw_aff_from_exprs, pw_aff_to_expr, ) -from loopy.transform.precompute import ( - contains_a_subst_rule_invocation -) +from loopy.transform.precompute import contains_a_subst_rule_invocation from 
loopy.translation_unit import for_each_kernel -from pymbolic import var -from pymbolic.mapper.substitutor import make_subst_func +from loopy.types import ToLoopyTypeConvertible, to_loopy_type -import islpy as isl -import pymbolic.primitives as p -from pymbolic.mapper.dependency import DependencyMapper -from pymbolic.typing import Expression -from pytools.tag import Tag +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence, Set + + from pymbolic.typing import Expression -AccessTuple: TypeAlias = tuple[Expression, ...] + from loopy.kernel import LoopKernel + from loopy.kernel.data import AddressSpace + + +AccessTuple: TypeAlias = tuple[str, ...] + + +def _access_key(args: Sequence[Expression]) -> AccessTuple: + return tuple(str(arg) for arg in args) + + +def _base_name(name: str) -> str: + return name.removesuffix("_") + + +def _cur_name(name: str) -> str: + return f"{_base_name(name)}_cur" + + +def _prev_name(name: str) -> str: + return f"{_base_name(name)}_prev" + + +def _basic_set_to_predicates(bset: nisl.BasicSet) -> frozenset[Expression]: + isl_bset = bset._reconstruct_isl_object() + + predicates = [] + for constraint in isl_bset.get_constraints(): + expr = pw_aff_to_expr(constraint.get_aff()) + if constraint.is_equality(): + predicates.append(p.Comparison(expr, "==", 0)) + else: + predicates.append(p.Comparison(expr, ">=", 0)) + + return frozenset(predicates) + + +def _set_to_predicate_options( + set_: nisl.Set | nisl.BasicSet + ) -> Sequence[frozenset[Expression]]: + if isinstance(set_, nisl.BasicSet): + if set_._reconstruct_isl_object().is_empty(): + return [] + return [_basic_set_to_predicates(set_)] + + predicate_options = [] + for bset in set_.get_basic_sets(): + if not bset._reconstruct_isl_object().is_empty(): + predicate_options.append(_basic_set_to_predicates(bset)) + + return predicate_options # helper for gathering names of variables in pymbolic expressions def _gather_vars(expr: Expression) -> set[str]: deps = DependencyMapper()(expr) + 
var_names = set() + for dep in deps: + if isinstance(dep, p.Variable): + var_names.add(dep.name) + elif ( + isinstance(dep, p.Subscript) + and isinstance(dep.aggregate, p.Variable)): + var_names.add(dep.aggregate.name) + + return var_names + + +def _existing_name_mapping( + map_: nisl.Map | nisl.BasicMap, + name_mapping: Mapping[str, str] + ) -> Mapping[str, str]: + names = map_.names return { - dep.name - for dep in deps - if isinstance(dep, p.Variable) + source: target + for source, target in name_mapping.items() + if source in names and target in names } +def _normalize_renamed_dims( + map_: nisl.Map | nisl.BasicMap, + name_mapping: Mapping[str, str], + ) -> nisl.Map | nisl.BasicMap: + map_ = map_.equate_dims(_existing_name_mapping(map_, name_mapping)) + + names = map_.names + project_names = [ + renamed_name + for original_name, renamed_name in name_mapping.items() + if original_name in names and renamed_name in names + ] + map_ = map_.project_out(project_names) + + names = map_.names + rename_mapping = { + renamed_name: original_name + for original_name, renamed_name in name_mapping.items() + if original_name not in names and renamed_name in names + } + return map_.rename_dims(rename_mapping) + + # {{{ gathering usage expressions class UsageSiteExpressionGatherer(RuleAwareIdentityMapper[[]]): @@ -72,7 +163,6 @@ def __init__( self.usage_expressions: list[Sequence[Expression]] = [] - @override def map_subst_rule( self, @@ -114,10 +204,10 @@ def __init__( ctx: SubstitutionRuleMappingContext, subst_name: str, subst_tag: Sequence[Tag] | None, - usage_descriptors: Mapping[AccessTuple, nisl.Map], + usage_descriptors: Mapping[AccessTuple, nisl.Map | nisl.BasicMap], storage_indices: Sequence[str], temporary_name: str, - compute_insn_id: str, + compute_insn_ids: str | Sequence[str], footprint: nisl.Set ) -> None: @@ -126,20 +216,23 @@ def __init__( self.subst_name: str = subst_name self.subst_tag: Sequence[Tag] | None = subst_tag - self.usage_descriptors: 
Mapping[AccessTuple, nisl.Map] = \ + self.usage_descriptors: Mapping[AccessTuple, nisl.Map | nisl.BasicMap] = \ usage_descriptors self.storage_indices: Sequence[str] = storage_indices self.footprint: nisl.Set = footprint self.temporary_name: str = temporary_name - self.compute_insn_id: str = compute_insn_id + self.compute_insn_ids: frozenset[str] = ( + frozenset([compute_insn_ids]) + if isinstance(compute_insn_ids, str) + else frozenset(compute_insn_ids) + ) self.replaced_something: bool = False # FIXME: may not always be the case (i.e. global barrier between # compute insn and uses) - self.compute_dep_id: str = compute_insn_id - + self.compute_dep_ids: frozenset[str] = self.compute_insn_ids @override def map_subst_rule( @@ -158,17 +251,20 @@ def map_subst_rule( # {{{ validation checks - if not name == self.subst_name: + if name != self.subst_name: return super().map_subst_rule(name, tags, arguments, expn_state) - if not tuple(args) in self.usage_descriptors: + access_key = _access_key(args) + if access_key not in self.usage_descriptors: return super().map_subst_rule(name, tags, arguments, expn_state) - if not len(arguments) == len(rule.arguments): - raise ValueError("Number of arguments passed to rule {name} ", - "does not match the signature of {name}.") + if len(arguments) != len(rule.arguments): + raise ValueError( + f"Number of arguments passed to rule {name} " + f"does not match the signature of {name}." 
+ ) - local_map = self.usage_descriptors[tuple(args)] + local_map = self.usage_descriptors[access_key] temp_footprint = self.footprint.move_dims( frozenset(self.footprint.names) - frozenset(self.storage_indices), isl.dim_type.param @@ -181,7 +277,7 @@ def map_subst_rule( # {{{ get index expression in terms of global inames - local_pwmaff = self.usage_descriptors[tuple(args)].as_pw_multi_aff() + local_pwmaff = self.usage_descriptors[access_key].as_pw_multi_aff() index_exprs: Sequence[Expression] = [] for dim in range(local_pwmaff.dim(isl.dim_type.out)): @@ -195,7 +291,6 @@ def map_subst_rule( return new_expression - @override def map_kernel( self, @@ -215,13 +310,13 @@ def map_kernel( continue insn = insn.with_transformed_expressions( - lambda expr: self(expr, kernel, insn) + lambda expr, insn=insn: self(expr, kernel, insn) ) if self.replaced_something: insn = insn.copy( depends_on=( - insn.depends_on | frozenset([self.compute_insn_id]) + insn.depends_on | self.compute_dep_ids ) ) @@ -239,9 +334,12 @@ def compute( kernel: LoopKernel, substitution: str, compute_map: nisl.Map, + storage_indices: Sequence[str], + # FIXME: can these two be deduced? temporal_inames: Sequence[str], + inames_to_advance: Sequence[str] | None = None, temporary_name: str | None = None, temporary_address_space: AddressSpace | None = None, @@ -266,7 +364,14 @@ def compute( to create inames for the loops that cover the required set of compute points. 
""" - # {{{ setup and useful items + name_mapping = { + name: name + "_" + for name in compute_map.output_names + if name not in storage_indices + } + compute_map = compute_map.rename_dims(name_mapping) + + # {{{ setup and useful variables storage_set = frozenset(storage_indices) temporal_set = frozenset(temporal_inames) @@ -295,16 +400,16 @@ def compute( named_domain = nisl.make_basic_set(domain_changer.domain) # restrict domain to used inames - usage_domain = named_domain.project_out_except(usage_inames) + local_domain = named_domain.project_out_except(usage_inames) # FIXME: gross. find a cleaner way to generate a space for an empty map global_usage_map = nisl.make_map_from_domain_and_range( - nisl.make_set(isl.Set.empty(usage_domain.get_space())), + nisl.make_set(isl.Set.empty(local_domain.get_space())), compute_map.domain() ) global_usage_map = nisl.make_map(isl.Map.empty(global_usage_map.get_space())) - usage_substs: Mapping[AccessTuple, nisl.Map] = {} + usage_substs: Mapping[AccessTuple, nisl.Map | nisl.BasicMap] = {} for usage in usage_exprs: # {{{ compute local usage map, update global usage map @@ -316,7 +421,7 @@ def compute( local_usage_map = nisl.make_map(local_usage_mpwaff.as_map()) - local_usage_map = local_usage_map.intersect_domain(usage_domain) + local_usage_map = local_usage_map.intersect_domain(local_domain) global_usage_map = global_usage_map | local_usage_map # }}} @@ -324,33 +429,33 @@ def compute( # {{{ compute storage map local_storage_map = local_usage_map.apply_range(compute_map) + local_storage_map = _normalize_renamed_dims( + local_storage_map, name_mapping) # check that no restrictions happened during composition (i.e. 
tile # valid for a single point in the domain) if not local_usage_map.domain() <= local_storage_map.domain(): continue + # clean up names non_param_names = (usage_inames - temporal_set) | storage_set parameter_names = frozenset(local_storage_map.names) - non_param_names - local_storage_map = local_storage_map.move_dims(parameter_names, isl.dim_type.param) # }}} - usage_substs[tuple(usage)] = local_storage_map + usage_substs[_access_key(usage)] = local_storage_map - global_usage_map = global_usage_map.apply_range(compute_map) + storage_map = global_usage_map.apply_range(compute_map) + storage_map = _normalize_renamed_dims(storage_map, name_mapping) # }}} # {{{ compute bounds and update kernel domain - global_usage_map = global_usage_map.move_dims( - temporal_set, - isl.dim_type.param - ) - footprint = global_usage_map.range() + storage_map = storage_map.move_dims(temporal_set, isl.dim_type.param) + footprint = storage_map.range() # clean up ticked duplicate names footprint = footprint.project_out_except(temporal_set | storage_set) @@ -358,7 +463,9 @@ def compute( # {{{ FIXME: use Sets instead of BasicSets when loopy is ready - footprint = nisl.make_set(footprint._reconstruct_isl_object().convex_hull()) + # FIXME: convex hull is not permanent + footprint_isl = footprint._reconstruct_isl_object() + footprint = nisl.make_set(isl.Set.from_basic_set(footprint_isl.convex_hull())) named_domain = named_domain & footprint if len(named_domain.get_basic_sets()) != 1: @@ -374,14 +481,140 @@ def compute( # }}} + if not temporary_name: + temporary_name = substitution + "_temp" + + if not compute_insn_id: + compute_insn_id = substitution + "_compute" + + # {{{ reuse analysis + + update_insns: list[lp.InstructionBase] = [] + update_insn_ids: list[str] = [] + refill_predicate_options: Sequence[frozenset[Expression] | None] = [None] + current_update_deps: frozenset[str] = frozenset() + + if inames_to_advance is not None: + advancing_set = frozenset(inames_to_advance) + + 
compute_map_cur = compute_map.rename_dims({ + name: _cur_name(name) for name in compute_map.output_names + }) + compute_map_prev = compute_map.rename_dims({ + name: _prev_name(name) for name in compute_map.output_names + }) + + cur_storage = global_usage_map.apply_range(compute_map_cur) + prev_storage = global_usage_map.apply_range(compute_map_prev) + + reuse_map = prev_storage.reverse().apply_range(cur_storage) + reuse_map = reuse_map.add_constraint([ + ( + f"{name}_cur = {name}_prev + 1" + if name in advancing_set + else + f"{name}_cur = {name}_prev" + ) + for name in temporal_inames + ]) + + current_footprint = footprint.rename_dims({ + name: _cur_name(name) for name in footprint.names + }) + previous_footprint = footprint.rename_dims({ + name: _prev_name(name) for name in footprint.names + }) + + reuse_map = reuse_map.intersect_domain(previous_footprint) + reuse_map = reuse_map.intersect_range(current_footprint) + reuse_map = reuse_map - nisl.make_map( + "{ [" + + ", ".join(_prev_name(name) for name in footprint.names) + + "] -> [" + + ", ".join(_cur_name(name) for name in footprint.names) + + "] : " + + " and ".join( + f"{_cur_name(name)} = {_prev_name(name)}" + for name in storage_indices + ) + + " }" + ) + + reused_current = reuse_map.range() + refill = current_footprint - reused_current + + cur_to_normal = { + _cur_name(name): name + for name in footprint.names + } + reused_current = reused_current.rename_dims(cur_to_normal) + refill = refill.rename_dims(cur_to_normal) + + reused_context = named_domain.project_out_except(reused_current.names) + refill_context = named_domain.project_out_except(refill.names) + + reused_current = reused_current.gist(reused_context) + refill = refill.gist(refill_context) + + refill_predicate_options = _set_to_predicate_options(refill) + + storage_reuse_map = reuse_map.project_out_except( + frozenset(_prev_name(name) for name in storage_indices) + | frozenset(_cur_name(name) for name in storage_indices) + ) + storage_reuse_map = 
storage_reuse_map.rename_dims({ + _cur_name(name): name + for name in storage_indices + }) + cur_to_prev = storage_reuse_map.reverse() + cur_to_prev_pwma = cur_to_prev.as_pw_multi_aff() + prev_expr_by_name = { + cur_to_prev_pwma.get_dim_name(isl.dim_type.out, dim): + pw_aff_to_expr(cur_to_prev_pwma.get_at(dim)) + for dim in range(cur_to_prev_pwma.dim(isl.dim_type.out)) + } + prev_storage_exprs = [ + prev_expr_by_name[_prev_name(name)] + for name in storage_indices + ] + + shift_assignee = var(temporary_name)[ + tuple(var(idx) for idx in storage_indices) + ] + shift_expression = var(temporary_name)[tuple(prev_storage_exprs)] + + shift_predicate_options = _set_to_predicate_options(reused_current) + for i, predicates in enumerate(shift_predicate_options): + shift_insn_id = ( + f"{compute_insn_id}_shift" + if len(shift_predicate_options) == 1 + else f"{compute_insn_id}_shift_{i}" + ) + update_insns.append(lp.Assignment( + id=shift_insn_id, + assignee=shift_assignee, + expression=shift_expression, + within_inames=frozenset(temporal_inames) | storage_set, + predicates=predicates, + depends_on=current_update_deps, + )) + update_insn_ids.append(shift_insn_id) + current_update_deps = frozenset([shift_insn_id]) + + # }}} + # {{{ create compute instruction in kernel - # FIXME: + # FIXME: maybe just keep original around? 
+ compute_map = compute_map.rename_dims({ + value: key for key, value in name_mapping.items() + }) + compute_pw_aff = compute_map.reverse().as_pw_multi_aff() storage_ax_to_global_expr = { - compute_pw_aff.get_dim_name(isl.dim_type.out, dim) : + compute_pw_aff.get_dim_name(isl.dim_type.out, dim): pw_aff_to_expr(compute_pw_aff.get_at(dim)) - for dim in range(compute_pw_aff.dim(isl.dim_type.out)) + for dim in range(compute_pw_aff.dim(isl.dim_type.out)) } expr_subst_map = RuleAwareSubstitutionMapper( @@ -391,26 +624,41 @@ def compute( ) subst_expr = kernel.substitutions[substitution].expression - compute_expression = expr_subst_map(subst_expr, kernel, None) - - if not temporary_name: - temporary_name = substitution + "_temp" + compute_expression = expr_subst_map( + subst_expr, + kernel, + None, + ) + compute_dep_ids = frozenset().union(*( + kernel.writer_map().get(var_name, frozenset()) + for var_name in _gather_vars(compute_expression) + )) assignee = var(temporary_name)[tuple(var(idx) for idx in storage_indices)] within_inames = compute_map.output_names - if not compute_insn_id: - compute_insn_id = substitution + "_compute" - compute_insn = lp.Assignment( - id=compute_insn_id, - assignee=assignee, - expression=compute_expression, - within_inames=within_inames - ) - new_insns = list(kernel.instructions) - new_insns.append(compute_insn) + new_insns.extend(update_insns) + + for i, predicates in enumerate(refill_predicate_options): + refill_insn_id = ( + compute_insn_id + if len(refill_predicate_options) == 1 + else f"{compute_insn_id}_refill_{i}" + ) + compute_insn = lp.Assignment( + id=refill_insn_id, + assignee=assignee, + expression=compute_expression, + within_inames=within_inames, + predicates=predicates, + depends_on=current_update_deps | compute_dep_ids, + ) + new_insns.append(compute_insn) + update_insn_ids.append(refill_insn_id) + current_update_deps = frozenset([refill_insn_id]) + kernel = kernel.copy(instructions=new_insns) # }}} @@ -428,7 +676,7 @@ def 
compute( usage_substs, storage_indices, temporary_name, - compute_insn_id, + update_insn_ids, footprint ) From 33814e244b5eab10fc868a6b1d70947112ee8e1d Mon Sep 17 00:00:00 2001 From: Addison Date: Fri, 17 Apr 2026 13:45:30 -0500 Subject: [PATCH 22/27] add more examples --- .../finite-difference-2-5D.py | 3 - .../l2p-3d-tensor-product-compute.py | 489 ++++++++++++++ .../l2p-tiled-basis-compute.py | 351 ++++++++++ .../compute-examples/m2m-sum-factorization.py | 636 ++++++++++++++++++ examples/python/compute-examples/matmul.py | 2 +- .../compute-examples/p2m-basis-compute.py | 369 ++++++++++ loopy/transform/compute.py | 2 +- 7 files changed, 1847 insertions(+), 5 deletions(-) create mode 100644 examples/python/compute-examples/l2p-3d-tensor-product-compute.py create mode 100644 examples/python/compute-examples/l2p-tiled-basis-compute.py create mode 100644 examples/python/compute-examples/m2m-sum-factorization.py create mode 100644 examples/python/compute-examples/p2m-basis-compute.py diff --git a/examples/python/compute-examples/finite-difference-2-5D.py b/examples/python/compute-examples/finite-difference-2-5D.py index ff15afb49..43d09309e 100644 --- a/examples/python/compute-examples/finite-difference-2-5D.py +++ b/examples/python/compute-examples/finite-difference-2-5D.py @@ -14,9 +14,6 @@ def centered_second_derivative_coefficients(radius: int, dtype) -> np.ndarray: offsets = np.arange(-radius, radius + 1, dtype=dtype) powers = np.arange(2 * radius + 1) - - # Enforce exactness for monomials through degree 2r. The second derivative - # of x**2 at zero is 2, all other monomial derivatives in this range vanish. 
vandermonde = offsets[np.newaxis, :] ** powers[:, np.newaxis] rhs = np.zeros(2 * radius + 1, dtype=dtype) rhs[2] = 2 diff --git a/examples/python/compute-examples/l2p-3d-tensor-product-compute.py b/examples/python/compute-examples/l2p-3d-tensor-product-compute.py new file mode 100644 index 000000000..820947ef4 --- /dev/null +++ b/examples/python/compute-examples/l2p-3d-tensor-product-compute.py @@ -0,0 +1,489 @@ +"""Benchmark a 3D Cartesian Taylor L2P microkernel with Loopy compute. + +FMM kernel class: kernel-independent Cartesian Taylor/asymptotic local +expansion evaluation in three spatial dimensions. This script does not +directly evaluate the 3D Laplace, Helmholtz, or biharmonic Green's function. +The dense local coefficient tensor ``gamma`` is assumed to have already been +formed by the relevant FMM translation machinery; this benchmark isolates the +target-side monomial contraction. + +The kernel evaluates a dense 3D tensor-product local expansion at many target +points: + + phi[itgt] = sum_{q0,q1,q2} gamma[q0, q1, q2] + * x[itgt]**q0 / q0! + * y[itgt]**q1 / q1! + * z[itgt]**q2 / q2! + +The baseline variants are GPU-parallel kernels over target blocks that expand +the basis substitutions inline. The compute variants use +:func:`loopy.transform.compute.compute` to materialize the x, y, and z basis +values into private temporaries. The script includes both a direct tiled +compute schedule and an optimized register-tiled schedule, following the style +of Loopy's compute matmul example. + +Use ``--compare`` to run the naive parallel baseline and the optimized compute +kernel, validate both against the NumPy reference, and report timing, modeled +GFLOP/s, speedup, and relative error. 
+""" + +import os +import time + +os.environ.setdefault("XDG_CACHE_HOME", "/tmp") + +import namedisl as nisl +import numpy as np +import numpy.linalg as la + +import loopy as lp +from loopy.transform.compute import compute +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 + + +def inv_factorials(order: int, dtype) -> np.ndarray: + result = np.empty(order + 1, dtype=dtype) + result[0] = 1 + for i in range(1, order + 1): + result[i] = result[i - 1] / i + return result + + +def reference_l2p_3d( + gamma: np.ndarray, + x: np.ndarray, + y: np.ndarray, + z: np.ndarray, + inv_fact: np.ndarray + ) -> np.ndarray: + order = gamma.shape[0] - 1 + result = np.empty_like(x) + + for itgt in range(x.size): + acc = 0 + for q0 in range(order + 1): + x_basis = x[itgt]**q0 * inv_fact[q0] + for q1 in range(order + 1): + y_basis = y[itgt]**q1 * inv_fact[q1] + for q2 in range(order + 1): + z_basis = z[itgt]**q2 * inv_fact[q2] + acc += gamma[q0, q1, q2] * x_basis * y_basis * z_basis + result[itgt] = acc + + return result + + +def make_kernel( + ntargets: int, + order: int, + dtype + ) -> lp.TranslationUnit: + knl = lp.make_kernel( + "{ [itgt, q0, q1, q2] : " + "0 <= itgt < ntargets and 0 <= q0, q1, q2 <= p }", + """ + x_basis_(itgt_arg, q0_arg) := ( + x[itgt_arg] ** q0_arg * inv_fact[q0_arg] + ) + + y_basis_(itgt_arg, q1_arg) := ( + y[itgt_arg] ** q1_arg * inv_fact[q1_arg] + ) + + z_basis_(itgt_arg, q2_arg) := ( + z[itgt_arg] ** q2_arg * inv_fact[q2_arg] + ) + + phi[itgt] = sum( + [q0, q1, q2], + gamma[q0, q1, q2] + * x_basis_(itgt, q0) + * y_basis_(itgt, q1) + * z_basis_(itgt, q2) + ) + """, + [ + lp.GlobalArg("x", dtype=dtype, shape=(ntargets,)), + lp.GlobalArg("y", dtype=dtype, shape=(ntargets,)), + lp.GlobalArg("z", dtype=dtype, shape=(ntargets,)), + lp.GlobalArg("inv_fact", dtype=dtype, shape=(order + 1,)), + lp.GlobalArg( + "gamma", + dtype=dtype, + shape=(order + 1, order + 1, order + 1), + ), + lp.GlobalArg("phi", dtype=dtype, shape=(ntargets,), is_output=True), + ], + 
lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2, + ) + return lp.fix_parameters(knl, ntargets=ntargets, p=order) + + +def split_targets( + knl: lp.TranslationUnit, + target_block_size: int + ) -> lp.TranslationUnit: + knl = lp.split_iname( + knl, + "itgt", + target_block_size, + inner_iname="itgt_inner", + outer_iname="itgt_block", + ) + return lp.tag_inames(knl, {"itgt_block": "g.0"}) + + +def block_private_l2p_3d( + knl: lp.TranslationUnit, + target_block_size: int, + dtype + ) -> lp.TranslationUnit: + knl = split_targets(knl, target_block_size) + + x_basis_map = nisl.make_map(f"""{{ + [itgt_arg, q0_arg] -> [itgt_block, itgt_s, q0_s] : + itgt_arg = itgt_block * {target_block_size} + itgt_s and + q0_arg = q0_s + }}""") + + knl = compute( + knl, + "x_basis_", + compute_map=x_basis_map, + storage_indices=["itgt_s", "q0_s"], + temporal_inames=["itgt_block"], + temporary_name="x_basis_tile", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="x_basis_compute", + ) + + y_basis_map = nisl.make_map(f"""{{ + [itgt_arg, q1_arg] -> [itgt_block, itgt_s, q1_s] : + itgt_arg = itgt_block * {target_block_size} + itgt_s and + q1_arg = q1_s + }}""") + + knl = compute( + knl, + "y_basis_", + compute_map=y_basis_map, + storage_indices=["itgt_s", "q1_s"], + temporal_inames=["itgt_block"], + temporary_name="y_basis_tile", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="y_basis_compute", + ) + + z_basis_map = nisl.make_map(f"""{{ + [itgt_arg, q2_arg] -> [itgt_block, itgt_s, q2_s] : + itgt_arg = itgt_block * {target_block_size} + itgt_s and + q2_arg = q2_s + }}""") + + return compute( + knl, + "z_basis_", + compute_map=z_basis_map, + storage_indices=["itgt_s", "q2_s"], + temporal_inames=["itgt_block"], + temporary_name="z_basis_tile", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="z_basis_compute", + ) + + +def register_tiled_l2p_3d( + knl: 
lp.TranslationUnit, + target_block_size: int, + dtype + ) -> lp.TranslationUnit: + knl = split_targets(knl, target_block_size) + + x_basis_map = nisl.make_map(f"""{{ + [itgt_arg, q0_arg] -> [itgt_block, itgt_inner, q0_s] : + itgt_arg = itgt_block * {target_block_size} + itgt_inner and + q0_arg = q0_s + }}""") + + knl = compute( + knl, + "x_basis_", + compute_map=x_basis_map, + storage_indices=["q0_s"], + temporal_inames=["itgt_block", "itgt_inner"], + temporary_name="x_basis_reg", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="x_basis_compute", + ) + + y_basis_map = nisl.make_map(f"""{{ + [itgt_arg, q1_arg] -> [itgt_block, itgt_inner, q1_s] : + itgt_arg = itgt_block * {target_block_size} + itgt_inner and + q1_arg = q1_s + }}""") + + knl = compute( + knl, + "y_basis_", + compute_map=y_basis_map, + storage_indices=["q1_s"], + temporal_inames=["itgt_block", "itgt_inner"], + temporary_name="y_basis_reg", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="y_basis_compute", + ) + + z_basis_map = nisl.make_map(f"""{{ + [itgt_arg, q2_arg] -> [itgt_block, itgt_inner, q2_s] : + itgt_arg = itgt_block * {target_block_size} + itgt_inner and + q2_arg = q2_s + }}""") + + knl = compute( + knl, + "z_basis_", + compute_map=z_basis_map, + storage_indices=["q2_s"], + temporal_inames=["itgt_block", "itgt_inner"], + temporary_name="z_basis_reg", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="z_basis_compute", + ) + + return lp.tag_inames(knl, { + "itgt_inner": "l.0", + "q0_s": "unr", + "q1_s": "unr", + "q2_s": "unr", + "q0": "unr", + "q1": "unr", + "q2": "unr", + }) + + +def operation_model( + ntargets: int, + order: int, + target_block_size: int + ) -> tuple[int, int]: + ncoeff = order + 1 + inline_basis_evals = 3 * ntargets * ncoeff**3 + tiled_compute_basis_evals = 3 * ntargets * ncoeff + return inline_basis_evals, tiled_compute_basis_evals + + 
+def l2p_3d_flop_count(ntargets: int, order: int, use_compute: bool) -> int: + ncoeff = order + 1 + + contraction_flops = 4 * ntargets * ncoeff**3 + if use_compute: + basis_scale_flops = 3 * ntargets * ncoeff + else: + basis_scale_flops = 3 * ntargets * ncoeff**3 + + return contraction_flops + basis_scale_flops + + +def benchmark_executor(ex, queue, args, warmup: int, iterations: int) -> float: + evt = None + for _ in range(warmup): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + + start = time.perf_counter() + for _ in range(iterations): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + end = time.perf_counter() + + return (end - start) / iterations + + +def run_kernel( + knl: lp.TranslationUnit, + x: np.ndarray, + y: np.ndarray, + z: np.ndarray, + inv_fact: np.ndarray, + gamma: np.ndarray, + warmup: int, + iterations: int + ) -> tuple[np.ndarray, float]: + import pyopencl as cl + import pyopencl.array as cl_array + + ctx = cl.create_some_context(interactive=False) + queue = cl.CommandQueue(ctx) + ex = knl.executor(queue) + + x_cl = cl_array.to_device(queue, x) + y_cl = cl_array.to_device(queue, y) + z_cl = cl_array.to_device(queue, z) + inv_fact_cl = cl_array.to_device(queue, inv_fact) + gamma_cl = cl_array.to_device(queue, gamma) + phi_cl = cl_array.zeros(queue, x.shape, dtype=x.dtype) + + elapsed = benchmark_executor( + ex, + queue, + { + "x": x_cl, + "y": y_cl, + "z": z_cl, + "inv_fact": inv_fact_cl, + "gamma": gamma_cl, + "phi": phi_cl, + }, + warmup=warmup, + iterations=iterations, + ) + + _, out = ex( + queue, x=x_cl, y=y_cl, z=z_cl, inv_fact=inv_fact_cl, + gamma=gamma_cl, phi=phi_cl) + return out[0].get(), elapsed + + +def main( + ntargets: int = 256, + order: int = 8, + target_block_size: int = 32, + use_compute: bool = False, + use_block_private_compute: bool = False, + compare: bool = False, + print_kernel: bool = False, + print_device_code: bool = False, + run: bool = False, + warmup: int = 3, + iterations: int = 10 + 
) -> None: + if ntargets % target_block_size: + raise ValueError("ntargets must be divisible by target_block_size") + + dtype = np.float64 + rng = np.random.default_rng(22) + x = rng.uniform(-0.25, 0.25, size=ntargets).astype(dtype) + y = rng.uniform(-0.25, 0.25, size=ntargets).astype(dtype) + z = rng.uniform(-0.25, 0.25, size=ntargets).astype(dtype) + inv_fact = inv_factorials(order, dtype) + gamma = rng.normal(size=(order + 1, order + 1, order + 1)).astype(dtype) + reference = reference_l2p_3d(gamma, x, y, z, inv_fact) + + inline_evals, compute_evals = operation_model( + ntargets, order, target_block_size) + + if compare: + variants = ["inline", "register-tiled compute"] + elif use_block_private_compute: + variants = ["block-private compute"] + elif use_compute: + variants = ["register-tiled compute"] + else: + variants = ["inline"] + + timings: dict[str, float] = {} + for variant in variants: + knl = make_kernel(ntargets, order, dtype) + + if variant == "inline": + knl = split_targets(knl, target_block_size) + knl = lp.tag_inames(knl, { + "itgt_inner": "l.0", + "q0": "unr", + "q1": "unr", + "q2": "unr", + }) + elif variant == "block-private compute": + knl = block_private_l2p_3d(knl, target_block_size, dtype) + elif variant == "register-tiled compute": + knl = register_tiled_l2p_3d(knl, target_block_size, dtype) + else: + raise ValueError(f"unknown variant '{variant}'") + + variant_uses_compute = variant != "inline" + modeled_flops = l2p_3d_flop_count( + ntargets, order, use_compute=variant_uses_compute) + + print(20 * "=", "3D L2P basis report", 20 * "=") + print(f"Variant : {variant}") + print(f"Targets : {ntargets}") + print(f"Order : {order}") + print(f"Target block: {target_block_size}") + print(f"Inline basis evaluations: {inline_evals}") + print(f"Tiled compute evaluations: {compute_evals}") + print(f"Modeled flop count: {modeled_flops}") + + if print_kernel: + print(knl) + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + if run 
or compare: + try: + result, elapsed = run_kernel( + knl, x, y, z, inv_fact, gamma, + warmup=warmup, iterations=iterations) + except Exception as exc: + print(f"Runtime execution unavailable: {exc}") + else: + rel_err = la.norm(result - reference) / la.norm(reference) + gflops = modeled_flops / elapsed * 1e-9 + timings[variant] = elapsed + print(f"Average time per iteration: {elapsed:.6e} s") + print(f"Modeled throughput: {gflops:.3f} GFLOP/s") + print(f"Relative error: {rel_err:.3e}") + + print((40 + len(" 3D L2P basis report ")) * "=") + + if ( + compare + and "inline" in timings + and "register-tiled compute" in timings): + speedup = timings["inline"] / timings["register-tiled compute"] + time_reduction = ( + 1 - timings["register-tiled compute"] / timings["inline"]) * 100 + print(f"Speedup: {speedup:.3f}x") + print(f"Relative time reduction: {time_reduction:.2f}%") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + _ = parser.add_argument("--ntargets", action="store", type=int, default=256) + _ = parser.add_argument("--order", action="store", type=int, default=8) + _ = parser.add_argument("--target-block-size", action="store", + type=int, default=32) + _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--block-private-compute", action="store_true") + _ = parser.add_argument("--compare", action="store_true") + _ = parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--print-device-code", action="store_true") + _ = parser.add_argument("--warmup", action="store", type=int, default=3) + _ = parser.add_argument("--iterations", action="store", type=int, default=10) + + args = parser.parse_args() + + main( + ntargets=args.ntargets, + order=args.order, + target_block_size=args.target_block_size, + use_compute=args.compute, + use_block_private_compute=args.block_private_compute, + compare=args.compare, 
+ print_kernel=args.print_kernel, + print_device_code=args.print_device_code, + run=args.run_kernel, + warmup=args.warmup, + iterations=args.iterations, + ) diff --git a/examples/python/compute-examples/l2p-tiled-basis-compute.py b/examples/python/compute-examples/l2p-tiled-basis-compute.py new file mode 100644 index 000000000..f11a1054a --- /dev/null +++ b/examples/python/compute-examples/l2p-tiled-basis-compute.py @@ -0,0 +1,351 @@ +"""Benchmark a 2D Cartesian Taylor L2P microkernel with Loopy compute. + +FMM kernel class: kernel-independent Cartesian Taylor/asymptotic local +expansion evaluation. This script does not evaluate a particular Green's +function such as the 2D Laplace or Helmholtz kernel. The local coefficients +``gamma`` are treated as already available; if they came from a Laplace FMM, +this kernel is the L2P monomial-contraction stage after the Laplace-specific +coefficient/derivative work has already happened. + +The kernel evaluates a tensor-product Taylor-like local expansion at many +target points: + + phi[itgt] = sum_{q0,q1} gamma[q0, q1] + * x[itgt]**q0 / q0! + * y[itgt]**q1 / q1! + +The inline variant is a parallel GPU kernel over target blocks that expands the +two basis substitutions at every use inside the coefficient sum. The compute +variant uses :func:`loopy.transform.compute.compute` to materialize the x and y +basis values into private temporaries for each target, so the powers/factorial +scalings are reused across the inner coefficient loops instead of recomputed. + +Use ``--compare`` to run both GPU-parallel variants, check against the NumPy +reference implementation, and report timing, modeled GFLOP/s, speedup, and +relative error. 
+""" + +import os +import time + +os.environ.setdefault("XDG_CACHE_HOME", "/tmp") + +import namedisl as nisl +import numpy as np +import numpy.linalg as la + +import loopy as lp +from loopy.transform.compute import compute +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 + + +def inv_factorials(order: int, dtype) -> np.ndarray: + result = np.empty(order + 1, dtype=dtype) + result[0] = 1 + for i in range(1, order + 1): + result[i] = result[i - 1] / i + return result + + +def reference_l2p( + gamma: np.ndarray, + x: np.ndarray, + y: np.ndarray, + inv_fact: np.ndarray + ) -> np.ndarray: + order = gamma.shape[0] - 1 + result = np.empty_like(x) + + for itgt in range(x.size): + acc = 0 + for q0 in range(order + 1): + x_basis = x[itgt]**q0 * inv_fact[q0] + for q1 in range(order + 1): + y_basis = y[itgt]**q1 * inv_fact[q1] + acc += gamma[q0, q1] * x_basis * y_basis + result[itgt] = acc + + return result + + +def make_kernel( + ntargets: int, + order: int, + target_block_size: int, + dtype, + use_compute: bool = False + ) -> lp.TranslationUnit: + if ntargets % target_block_size: + raise ValueError("ntargets must be divisible by target_block_size") + + knl = lp.make_kernel( + "{ [itgt, q0, q1] : 0 <= itgt < ntargets and 0 <= q0, q1 <= p }", + """ + x_basis_(itgt_arg, q0_arg) := ( + x[itgt_arg] ** q0_arg * inv_fact[q0_arg] + ) + + y_basis_(itgt_arg, q1_arg) := ( + y[itgt_arg] ** q1_arg * inv_fact[q1_arg] + ) + + phi[itgt] = sum( + [q0, q1], + gamma[q0, q1] * x_basis_(itgt, q0) * y_basis_(itgt, q1) + ) + """, + [ + lp.GlobalArg("x", dtype=dtype, shape=(ntargets,)), + lp.GlobalArg("y", dtype=dtype, shape=(ntargets,)), + lp.GlobalArg("inv_fact", dtype=dtype, shape=(order + 1,)), + lp.GlobalArg("gamma", dtype=dtype, shape=(order + 1, order + 1)), + lp.GlobalArg("phi", dtype=dtype, shape=(ntargets,), is_output=True), + ], + lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2, + ) + + knl = lp.fix_parameters(knl, ntargets=ntargets, p=order) + knl = lp.split_iname( + knl, + 
"itgt", + target_block_size, + inner_iname="itgt_inner", + outer_iname="itgt_block", + ) + + if use_compute: + x_basis_map = nisl.make_map(f"""{{ + [itgt_arg, q0_arg] -> [itgt_block, itgt_inner, q0_s] : + itgt_arg = itgt_block * {target_block_size} + itgt_inner and + q0_arg = q0_s + }}""") + + knl = compute( + knl, + "x_basis_", + compute_map=x_basis_map, + storage_indices=["q0_s"], + temporal_inames=["itgt_block", "itgt_inner"], + temporary_name="x_basis_reg", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="x_basis_compute", + ) + + y_basis_map = nisl.make_map(f"""{{ + [itgt_arg, q1_arg] -> [itgt_block, itgt_inner, q1_s] : + itgt_arg = itgt_block * {target_block_size} + itgt_inner and + q1_arg = q1_s + }}""") + + knl = compute( + knl, + "y_basis_", + compute_map=y_basis_map, + storage_indices=["q1_s"], + temporal_inames=["itgt_block", "itgt_inner"], + temporary_name="y_basis_reg", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="y_basis_compute", + ) + + iname_tags = { + "itgt_block": "g.0", + "itgt_inner": "l.0", + "q0": "unr", + "q1": "unr", + } + if use_compute: + iname_tags.update({ + "q0_s": "unr", + "q1_s": "unr", + }) + + knl = lp.tag_inames(knl, iname_tags) + return knl + + +def operation_model( + ntargets: int, + order: int, + target_block_size: int + ) -> tuple[int, int]: + ncoeff = order + 1 + inline_basis_evals = 2 * ntargets * ncoeff**2 + tiled_compute_basis_evals = 2 * ntargets * ncoeff + return inline_basis_evals, tiled_compute_basis_evals + + +def l2p_flop_count(ntargets: int, order: int, use_compute: bool) -> int: + ncoeff = order + 1 + + contraction_flops = 3 * ntargets * ncoeff**2 + if use_compute: + basis_scale_flops = 2 * ntargets * ncoeff + else: + basis_scale_flops = 2 * ntargets * ncoeff**2 + + return contraction_flops + basis_scale_flops + + +def benchmark_executor(ex, queue, args, warmup: int, iterations: int) -> float: + evt = None + for _ in 
range(warmup): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + + start = time.perf_counter() + for _ in range(iterations): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + end = time.perf_counter() + + return (end - start) / iterations + + +def run_kernel( + knl: lp.TranslationUnit, + x: np.ndarray, + y: np.ndarray, + inv_fact: np.ndarray, + gamma: np.ndarray, + warmup: int, + iterations: int + ) -> tuple[np.ndarray, float]: + import pyopencl as cl + import pyopencl.array as cl_array + + ctx = cl.create_some_context(interactive=False) + queue = cl.CommandQueue(ctx) + ex = knl.executor(queue) + + x_cl = cl_array.to_device(queue, x) + y_cl = cl_array.to_device(queue, y) + inv_fact_cl = cl_array.to_device(queue, inv_fact) + gamma_cl = cl_array.to_device(queue, gamma) + phi_cl = cl_array.zeros(queue, x.shape, dtype=x.dtype) + + elapsed = benchmark_executor( + ex, + queue, + { + "x": x_cl, + "y": y_cl, + "inv_fact": inv_fact_cl, + "gamma": gamma_cl, + "phi": phi_cl, + }, + warmup=warmup, + iterations=iterations, + ) + + _, out = ex( + queue, x=x_cl, y=y_cl, inv_fact=inv_fact_cl, + gamma=gamma_cl, phi=phi_cl) + return out[0].get(), elapsed + + +def main( + ntargets: int = 256, + order: int = 12, + target_block_size: int = 32, + use_compute: bool = False, + compare: bool = False, + print_kernel: bool = False, + print_device_code: bool = False, + run: bool = False, + warmup: int = 3, + iterations: int = 10 + ) -> None: + dtype = np.float64 + rng = np.random.default_rng(14) + x = rng.uniform(-0.25, 0.25, size=ntargets).astype(dtype) + y = rng.uniform(-0.25, 0.25, size=ntargets).astype(dtype) + inv_fact = inv_factorials(order, dtype) + gamma = rng.normal(size=(order + 1, order + 1)).astype(dtype) + reference = reference_l2p(gamma, x, y, inv_fact) + + inline_evals, compute_evals = operation_model( + ntargets, order, target_block_size) + + variants = [False, True] if compare else [use_compute] + timings: dict[bool, float] = {} + for 
variant_uses_compute in variants: + knl = make_kernel( + ntargets, order, target_block_size, dtype, + use_compute=variant_uses_compute) + modeled_flops = l2p_flop_count( + ntargets, order, use_compute=variant_uses_compute) + + print(20 * "=", "L2P basis report", 20 * "=") + print(f"Variant : {'tiled compute' if variant_uses_compute else 'inline'}") + print(f"Targets : {ntargets}") + print(f"Order : {order}") + print(f"Target block: {target_block_size}") + print(f"Inline basis evaluations: {inline_evals}") + print(f"Tiled compute evaluations: {compute_evals}") + print(f"Modeled flop count: {modeled_flops}") + + if print_kernel: + print(knl) + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + if run or compare: + try: + result, elapsed = run_kernel( + knl, x, y, inv_fact, gamma, + warmup=warmup, iterations=iterations) + except Exception as exc: + print(f"Runtime execution unavailable: {exc}") + else: + rel_err = la.norm(result - reference) / la.norm(reference) + gflops = modeled_flops / elapsed * 1e-9 + timings[variant_uses_compute] = elapsed + print(f"Average time per iteration: {elapsed:.6e} s") + print(f"Modeled throughput: {gflops:.3f} GFLOP/s") + print(f"Relative error: {rel_err:.3e}") + + print((40 + len(" L2P basis report ")) * "=") + + if compare and False in timings and True in timings: + speedup = timings[False] / timings[True] + time_reduction = (1 - timings[True] / timings[False]) * 100 + print(f"Speedup: {speedup:.3f}x") + print(f"Relative time reduction: {time_reduction:.2f}%") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + _ = parser.add_argument("--ntargets", action="store", type=int, default=256) + _ = parser.add_argument("--order", action="store", type=int, default=12) + _ = parser.add_argument("--target-block-size", action="store", + type=int, default=32) + _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--compare", action="store_true") + _ = 
parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--print-device-code", action="store_true") + _ = parser.add_argument("--warmup", action="store", type=int, default=3) + _ = parser.add_argument("--iterations", action="store", type=int, default=10) + + args = parser.parse_args() + + main( + ntargets=args.ntargets, + order=args.order, + target_block_size=args.target_block_size, + use_compute=args.compute, + compare=args.compare, + print_kernel=args.print_kernel, + print_device_code=args.print_device_code, + run=args.run_kernel, + warmup=args.warmup, + iterations=args.iterations, + ) diff --git a/examples/python/compute-examples/m2m-sum-factorization.py b/examples/python/compute-examples/m2m-sum-factorization.py new file mode 100644 index 000000000..d3618d27d --- /dev/null +++ b/examples/python/compute-examples/m2m-sum-factorization.py @@ -0,0 +1,636 @@ +"""Benchmark compressed Cartesian Taylor M2M sum factorization with compute. + +FMM kernel class: compressed Cartesian Taylor/asymptotic multipole-to-multipole +translation. This is the binomial center-shift part of an FMM translation, not +a direct evaluation of a Green's function such as Laplace, Helmholtz, or +biharmonic. The translation weights are powers of the center displacement +divided by factorials; PDE-specific derivative generation, compression +matrices, and recompression are outside this microbenchmark. + +The stored-index pattern here is intentionally simple. The 2D mode stores the +two coordinate axes, and the 3D mode stores the three coordinate axes. That +captures the sum-factorized structure from Section 4.2.3, but it is not the +full compressed 3D Laplace stored set, which would retain PDE-derived +hyperplane layers with O(p**2) stored coefficients rather than only O(p) axis +coefficients. 
+ +This script models a multipole-to-multipole-like translation in 2D or 3D where +the input expansion is stored only on the coordinate axes. For 2D, the stored +coefficients are ``beta[zeta0, 0]`` and ``beta[0, zeta1]``. For 3D, they are +``beta[zeta0, 0, 0]``, ``beta[0, zeta1, 0]``, and ``beta[0, 0, zeta2]``. The +output still fills the full tensor-product coefficient grid. + +The inline variant is a GPU-parallel kernel over output coefficient indices +that expands the one-dimensional translation sums at each output. The compute +variant uses :func:`loopy.transform.compute.compute` to materialize those axis +sums into private temporaries and reuses them across an ILP tile of the last +output axis. In 2D it tiles ``eta1``; in 3D it tiles ``eta2`` while reusing the +``eta0`` and ``eta1`` axis sums across that register tile. + +Use ``--dimension 2`` or ``--dimension 3`` to choose the kernel. Use +``--compare`` to run both GPU-parallel variants, check against the NumPy +reference implementation, and report timing, modeled GFLOP/s, speedup, and +relative error. 
+""" + +import os +import time + +os.environ.setdefault("XDG_CACHE_HOME", "/tmp") + +import namedisl as nisl +import numpy as np +import numpy.linalg as la +import pymbolic.primitives as p + +import loopy as lp +import loopy.transform.compute as compute_mod +from loopy.symbolic import DependencyMapper +from loopy.transform.compute import compute +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 + + +def enable_compute_for_reduction_substitutions() -> None: + """Let compute inspect substitution rules whose expressions are reductions.""" + + def gather_vars(expr): + deps = DependencyMapper()(expr) + var_names = set() + for dep in deps: + if isinstance(dep, p.Variable): + var_names.add(dep.name) + elif ( + isinstance(dep, p.Subscript) + and isinstance(dep.aggregate, p.Variable)): + var_names.add(dep.aggregate.name) + + return var_names + + compute_mod._gather_vars = gather_vars + + +def translation_weights(h: np.ndarray, order: int) -> np.ndarray: + weights = np.empty((len(h), order + 1), dtype=h.dtype) + weights[:, 0] = 1 + + for axis in range(len(h)): + for n in range(1, order + 1): + weights[axis, n] = weights[axis, n - 1] * h[axis] / n + + return weights + + +def make_axis_compressed_coefficients( + order: int, + dimension: int, + dtype + ) -> np.ndarray: + rng = np.random.default_rng(12) + beta = np.zeros(dimension * (order + 1,), dtype=dtype) + + for axis in range(dimension): + axis_slice = [0] * dimension + axis_slice[axis] = slice(None) + beta[tuple(axis_slice)] = rng.normal(size=order + 1) + + beta[(0,) * dimension] = rng.normal() + + return beta + + +def reference_axis_m2m_2d(beta: np.ndarray, weights: np.ndarray) -> np.ndarray: + order = beta.shape[0] - 1 + sigma = np.empty_like(beta) + + for eta0 in range(order + 1): + for eta1 in range(order + 1): + acc = 0 + + for zeta1 in range(eta1 + 1): + acc += ( + weights[0, eta0] + * weights[1, eta1 - zeta1] + * beta[0, zeta1] + ) + + for zeta0 in range(eta0 + 1): + acc += ( + weights[0, eta0 - zeta0] + 
* weights[1, eta1] + * beta[zeta0, 0] + ) + + acc -= weights[0, eta0] * weights[1, eta1] * beta[0, 0] + sigma[eta0, eta1] = acc + + return sigma + + +def reference_axis_m2m_3d(beta: np.ndarray, weights: np.ndarray) -> np.ndarray: + order = beta.shape[0] - 1 + sigma = np.empty_like(beta) + + for eta0 in range(order + 1): + for eta1 in range(order + 1): + for eta2 in range(order + 1): + acc = 0 + + for zeta0 in range(eta0 + 1): + acc += ( + weights[0, eta0 - zeta0] + * weights[1, eta1] + * weights[2, eta2] + * beta[zeta0, 0, 0] + ) + + for zeta1 in range(eta1 + 1): + acc += ( + weights[0, eta0] + * weights[1, eta1 - zeta1] + * weights[2, eta2] + * beta[0, zeta1, 0] + ) + + for zeta2 in range(eta2 + 1): + acc += ( + weights[0, eta0] + * weights[1, eta1] + * weights[2, eta2 - zeta2] + * beta[0, 0, zeta2] + ) + + acc -= ( + 2 + * weights[0, eta0] + * weights[1, eta1] + * weights[2, eta2] + * beta[0, 0, 0] + ) + sigma[eta0, eta1, eta2] = acc + + return sigma + + +def make_kernel_2d( + order: int, + eta_tile_size: int, + dtype, + use_compute: bool = False + ) -> lp.TranslationUnit: + if (order + 1) % eta_tile_size: + raise ValueError("order + 1 must be divisible by eta_tile_size") + + knl = lp.make_kernel( + "{ [eta0, eta1, zeta0, zeta1] : 0 <= eta0, eta1, zeta0, zeta1 <= p }", + """ + x_axis_sum_(eta0_arg) := sum( + [zeta0], + if( + zeta0 <= eta0_arg, + w[0, eta0_arg - zeta0] * beta[zeta0, 0], + 0 + ) + ) + + y_axis_sum_(eta1_arg) := sum( + [zeta1], + if( + zeta1 <= eta1_arg, + w[1, eta1_arg - zeta1] * beta[0, zeta1], + 0 + ) + ) + + sigma[eta0, eta1] = ( + w[0, eta0] * y_axis_sum_(eta1) + + w[1, eta1] * x_axis_sum_(eta0) + - w[0, eta0] * w[1, eta1] * beta[0, 0] + ) + """, + [ + lp.GlobalArg("beta", dtype=dtype, shape=(order + 1, order + 1)), + lp.GlobalArg("w", dtype=dtype, shape=(2, order + 1)), + lp.GlobalArg("sigma", dtype=dtype, shape=(order + 1, order + 1), + is_output=True), + ], + lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2, + ) + + knl = 
lp.fix_parameters(knl, p=order) + knl = lp.split_iname( + knl, + "eta1", + eta_tile_size, + inner_iname="eta1_inner", + outer_iname="eta1_block", + ) + + if use_compute: + x_axis_sum_map = nisl.make_map(f"""{{ + [eta0_arg] -> [eta0, eta1_block, x_slot] : + eta0_arg = eta0 and x_slot = 0 + }}""") + knl = compute( + knl, + "x_axis_sum_", + compute_map=x_axis_sum_map, + storage_indices=["x_slot"], + temporal_inames=["eta0", "eta1_block"], + temporary_name="x_axis_sum_vec", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="x_axis_sum_compute", + ) + + y_axis_sum_map = nisl.make_map(f"""{{ + [eta1_arg] -> [eta0, eta1_block, y_slot] : + eta1_arg = eta1_block * {eta_tile_size} + y_slot + }}""") + knl = compute( + knl, + "y_axis_sum_", + compute_map=y_axis_sum_map, + storage_indices=["y_slot"], + temporal_inames=["eta0", "eta1_block"], + temporary_name="y_axis_sum_vec", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="y_axis_sum_compute", + ) + knl = lp.tag_inames(knl, { + "x_slot": "unr", + "y_slot": "unr", + }) + + knl = lp.tag_inames(knl, { + "eta0": "g.1", + "eta1_block": "g.0", + "eta1_inner": "ilp", + }) + return knl + + +def make_kernel_3d( + order: int, + eta_tile_size: int, + dtype, + use_compute: bool = False + ) -> lp.TranslationUnit: + if (order + 1) % eta_tile_size: + raise ValueError("order + 1 must be divisible by eta_tile_size") + + knl = lp.make_kernel( + """ + { [eta0, eta1, eta2, zeta0, zeta1, zeta2] : + 0 <= eta0, eta1, eta2, zeta0, zeta1, zeta2 <= p } + """, + """ + x_axis_sum_(eta0_arg) := sum( + [zeta0], + if( + zeta0 <= eta0_arg, + w[0, eta0_arg - zeta0] * beta[zeta0, 0, 0], + 0 + ) + ) + + y_axis_sum_(eta1_arg) := sum( + [zeta1], + if( + zeta1 <= eta1_arg, + w[1, eta1_arg - zeta1] * beta[0, zeta1, 0], + 0 + ) + ) + + z_axis_sum_(eta2_arg) := sum( + [zeta2], + if( + zeta2 <= eta2_arg, + w[2, eta2_arg - zeta2] * beta[0, 0, zeta2], + 0 + ) + ) + + sigma[eta0, 
eta1, eta2] = ( + w[1, eta1] * w[2, eta2] * x_axis_sum_(eta0) + + w[0, eta0] * w[2, eta2] * y_axis_sum_(eta1) + + w[0, eta0] * w[1, eta1] * z_axis_sum_(eta2) + - 2 * w[0, eta0] * w[1, eta1] * w[2, eta2] * beta[0, 0, 0] + ) + """, + [ + lp.GlobalArg( + "beta", dtype=dtype, shape=(order + 1, order + 1, order + 1)), + lp.GlobalArg("w", dtype=dtype, shape=(3, order + 1)), + lp.GlobalArg( + "sigma", dtype=dtype, + shape=(order + 1, order + 1, order + 1), + is_output=True), + ], + lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2, + ) + + knl = lp.fix_parameters(knl, p=order) + knl = lp.split_iname( + knl, + "eta2", + eta_tile_size, + inner_iname="eta2_inner", + outer_iname="eta2_block", + ) + + if use_compute: + x_axis_sum_map = nisl.make_map(""" + { + [eta0_arg] -> [eta0, eta1, eta2_block, x_slot] : + eta0_arg = eta0 and x_slot = 0 + } + """) + knl = compute( + knl, + "x_axis_sum_", + compute_map=x_axis_sum_map, + storage_indices=["x_slot"], + temporal_inames=["eta0", "eta1", "eta2_block"], + temporary_name="x_axis_sum_vec", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="x_axis_sum_compute", + ) + + y_axis_sum_map = nisl.make_map(""" + { + [eta1_arg] -> [eta0, eta1, eta2_block, y_slot] : + eta1_arg = eta1 and y_slot = 0 + } + """) + knl = compute( + knl, + "y_axis_sum_", + compute_map=y_axis_sum_map, + storage_indices=["y_slot"], + temporal_inames=["eta0", "eta1", "eta2_block"], + temporary_name="y_axis_sum_vec", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="y_axis_sum_compute", + ) + + z_axis_sum_map = nisl.make_map(f"""{{ + [eta2_arg] -> [eta0, eta1, eta2_block, z_slot] : + eta2_arg = eta2_block * {eta_tile_size} + z_slot + }}""") + knl = compute( + knl, + "z_axis_sum_", + compute_map=z_axis_sum_map, + storage_indices=["z_slot"], + temporal_inames=["eta0", "eta1", "eta2_block"], + temporary_name="z_axis_sum_vec", + temporary_address_space=lp.AddressSpace.PRIVATE, + 
temporary_dtype=dtype, + compute_insn_id="z_axis_sum_compute", + ) + knl = lp.tag_inames(knl, { + "x_slot": "unr", + "y_slot": "unr", + "z_slot": "unr", + }) + + knl = lp.tag_inames(knl, { + "eta0": "g.2", + "eta1": "g.1", + "eta2_block": "g.0", + "eta2_inner": "ilp", + }) + return knl + + +def make_kernel( + order: int, + eta_tile_size: int, + dimension: int, + dtype, + use_compute: bool = False + ) -> lp.TranslationUnit: + if dimension == 2: + return make_kernel_2d(order, eta_tile_size, dtype, use_compute) + if dimension == 3: + return make_kernel_3d(order, eta_tile_size, dtype, use_compute) + raise ValueError("dimension must be 2 or 3") + + +def reference_axis_m2m(beta: np.ndarray, weights: np.ndarray) -> np.ndarray: + dimension = beta.ndim + if dimension == 2: + return reference_axis_m2m_2d(beta, weights) + if dimension == 3: + return reference_axis_m2m_3d(beta, weights) + raise ValueError("dimension must be 2 or 3") + + +def operation_model( + order: int, + eta_tile_size: int, + dimension: int + ) -> tuple[int, int]: + ncoeff = order + 1 + if dimension == 2: + inline_sum_terms = 2 * ncoeff**3 + tiled_compute_sum_terms = ncoeff**3 + ncoeff**3 // eta_tile_size + elif dimension == 3: + inline_sum_terms = 3 * ncoeff**4 + tiled_compute_sum_terms = ncoeff**4 + 2 * ncoeff**4 // eta_tile_size + else: + raise ValueError("dimension must be 2 or 3") + return inline_sum_terms, tiled_compute_sum_terms + + +def m2m_flop_count( + order: int, + eta_tile_size: int, + dimension: int, + use_compute: bool + ) -> int: + ncoeff = order + 1 + + if dimension == 2: + if use_compute: + sum_flops = 2 * ncoeff**3 + 2 * ncoeff**3 // eta_tile_size + else: + sum_flops = 4 * ncoeff**3 + correction_flops = 3 * ncoeff**2 + elif dimension == 3: + if use_compute: + sum_flops = 2 * ncoeff**4 + 4 * ncoeff**4 // eta_tile_size + else: + sum_flops = 6 * ncoeff**4 + correction_flops = 8 * ncoeff**3 + else: + raise ValueError("dimension must be 2 or 3") + + return sum_flops + correction_flops + + +def 
benchmark_executor(ex, queue, args, warmup: int, iterations: int) -> float: + if iterations <= 0: + raise ValueError("iterations must be positive") + + evt = None + for _ in range(warmup): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + + start = time.perf_counter() + for _ in range(iterations): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + end = time.perf_counter() + + return (end - start) / iterations + + +def run_kernel( + knl: lp.TranslationUnit, + beta: np.ndarray, + weights: np.ndarray, + warmup: int, + iterations: int + ) -> tuple[np.ndarray, float]: + import pyopencl as cl + import pyopencl.array as cl_array + + ctx = cl.create_some_context(interactive=False) + queue = cl.CommandQueue(ctx) + ex = knl.executor(queue) + + beta_cl = cl_array.to_device(queue, beta) + weights_cl = cl_array.to_device(queue, weights) + sigma_cl = cl_array.zeros(queue, beta.shape, dtype=beta.dtype) + + elapsed = benchmark_executor( + ex, + queue, + {"beta": beta_cl, "w": weights_cl, "sigma": sigma_cl}, + warmup=warmup, + iterations=iterations, + ) + + _, out = ex(queue, beta=beta_cl, w=weights_cl, sigma=sigma_cl) + return out[0].get(), elapsed + + +def main( + order: int = 16, + eta_tile_size: int = 8, + dimension: int = 2, + use_compute: bool = False, + compare: bool = False, + print_kernel: bool = False, + print_device_code: bool = False, + run: bool = False, + warmup: int = 3, + iterations: int = 10 + ) -> None: + if order < 0: + raise ValueError("order must be nonnegative") + if dimension not in (2, 3): + raise ValueError("dimension must be 2 or 3") + + dtype = np.float64 + h = np.array([0.25, -0.2, 0.15][:dimension], dtype=dtype) + weights = translation_weights(h, order) + beta = make_axis_compressed_coefficients(order, dimension, dtype) + reference = reference_axis_m2m(beta, weights) + + inline_terms, compute_terms = operation_model( + order, eta_tile_size, dimension) + + variants = [False, True] if compare else [use_compute] + timings: 
dict[bool, float] = {} + for variant_uses_compute in variants: + knl = make_kernel( + order, eta_tile_size, dimension, dtype, + use_compute=variant_uses_compute) + modeled_flops = m2m_flop_count( + order, eta_tile_size, dimension, + use_compute=variant_uses_compute) + + print(20 * "=", "Compressed M2M report", 20 * "=") + print(f"Variant: {'compute sum-factorized' if variant_uses_compute else 'inline'}") + print(f"Dimension: {dimension}D") + print(f"Order : {order}") + print(f"Eta tile: {eta_tile_size}") + if dimension == 2: + print("Stored compressed set: zeta0 = 0 or zeta1 = 0") + else: + print("Stored compressed set: exactly one zeta axis may be nonzero") + print(f"Inline sum terms : {inline_terms}") + print(f"Tiled compute sum terms: {compute_terms}") + print(f"Modeled flop count : {modeled_flops}") + + if print_kernel: + print(knl) + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + if run or compare: + try: + result, elapsed = run_kernel( + knl, beta, weights, warmup=warmup, iterations=iterations) + except Exception as exc: + print(f"Runtime execution unavailable: {exc}") + else: + rel_err = la.norm(result - reference) / la.norm(reference) + gflops = modeled_flops / elapsed * 1e-9 + timings[variant_uses_compute] = elapsed + print(f"Average time per iteration: {elapsed:.6e} s") + print(f"Modeled throughput: {gflops:.3f} GFLOP/s") + print(f"Relative error: {rel_err:.3e}") + + print((40 + len(" Compressed M2M report ")) * "=") + + if compare and False in timings and True in timings: + speedup = timings[False] / timings[True] + time_reduction = (1 - timings[True] / timings[False]) * 100 + print(f"Speedup: {speedup:.3f}x") + print(f"Relative time reduction: {time_reduction:.2f}%") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + + _ = parser.add_argument("--order", action="store", type=int, default=16) + _ = parser.add_argument("--eta-tile-size", action="store", type=int, default=8) + _ = 
parser.add_argument("--dimension", action="store", type=int, choices=(2, 3), + default=2) + _ = parser.add_argument("--dim", action="store", type=int, choices=(2, 3), + dest="dimension") + _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--compare", action="store_true") + _ = parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--print-device-code", action="store_true") + _ = parser.add_argument("--warmup", action="store", type=int, default=3) + _ = parser.add_argument("--iterations", action="store", type=int, default=10) + + args = parser.parse_args() + + main( + order=args.order, + eta_tile_size=args.eta_tile_size, + dimension=args.dimension, + use_compute=args.compute, + compare=args.compare, + print_kernel=args.print_kernel, + print_device_code=args.print_device_code, + run=args.run_kernel, + warmup=args.warmup, + iterations=args.iterations, + ) diff --git a/examples/python/compute-examples/matmul.py b/examples/python/compute-examples/matmul.py index 39acb9ef2..45af6f431 100644 --- a/examples/python/compute-examples/matmul.py +++ b/examples/python/compute-examples/matmul.py @@ -318,7 +318,7 @@ def main( tn: int = 4, shared_memory_tiled: bool = False, register_tiled: bool = False, - dtype=np.float32, + dtype: lp.ToLoopyTypeConvertible = np.float32, print_kernel: bool = False, print_device_code: bool = False ) -> None: diff --git a/examples/python/compute-examples/p2m-basis-compute.py b/examples/python/compute-examples/p2m-basis-compute.py new file mode 100644 index 000000000..2b4b1ab82 --- /dev/null +++ b/examples/python/compute-examples/p2m-basis-compute.py @@ -0,0 +1,369 @@ +"""Benchmark a 2D Cartesian Taylor P2M microkernel with Loopy compute. + +FMM kernel class: kernel-independent Cartesian Taylor/asymptotic multipole +moment formation. 
This script does not evaluate a particular Green's function +such as the 2D Laplace or Helmholtz kernel. It builds the source monomial +moments that a Taylor FMM would later pair with kernel derivatives or translated +coefficients. In a Laplace FMM, the Laplace-specific derivative recurrence or +compressed representation lives outside this benchmark. + +The kernel forms tensor-product source moments from particle strengths: + + beta[q0, q1] = sum_{isrc} strength[isrc] + * x[isrc]**q0 / q0! + * y[isrc]**q1 / q1! + +The inline variant is a GPU-parallel reduction over sources for every output +coefficient. The compute variant splits the source and q1 loops and uses +:func:`loopy.transform.compute.compute` to precompute reusable x and y monomial +basis values in private temporaries. This tests whether compute can expose +source-tile and coefficient-tile reuse in a reduction-heavy P2M-like kernel. + +Use ``--compare`` to run both GPU-parallel variants, compare with the NumPy +reference result, and print timing, modeled GFLOP/s, speedup, and relative +error. 
+""" + +import os +import time + +os.environ.setdefault("XDG_CACHE_HOME", "/tmp") + +import namedisl as nisl +import numpy as np +import numpy.linalg as la + +import loopy as lp +from loopy.transform.compute import compute +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 + + +def inv_factorials(order: int, dtype) -> np.ndarray: + result = np.empty(order + 1, dtype=dtype) + result[0] = 1 + for i in range(1, order + 1): + result[i] = result[i - 1] / i + return result + + +def reference_p2m( + strength: np.ndarray, + x: np.ndarray, + y: np.ndarray, + inv_fact: np.ndarray + ) -> np.ndarray: + order = inv_fact.size - 1 + beta = np.empty((order + 1, order + 1), dtype=x.dtype) + + for q0 in range(order + 1): + for q1 in range(order + 1): + acc = 0 + for isrc in range(x.size): + x_monom = x[isrc]**q0 * inv_fact[q0] + y_monom = y[isrc]**q1 * inv_fact[q1] + acc += strength[isrc] * x_monom * y_monom + beta[q0, q1] = acc + + return beta + + +def make_kernel( + nsources: int, + order: int, + q1_tile_size: int, + source_tile_size: int, + dtype, + use_compute: bool = False + ) -> lp.TranslationUnit: + if (order + 1) % q1_tile_size: + raise ValueError("order + 1 must be divisible by q1_tile_size") + if nsources % source_tile_size: + raise ValueError("nsources must be divisible by source_tile_size") + + knl = lp.make_kernel( + "{ [isrc, q0, q1] : 0 <= isrc < nsources and 0 <= q0, q1 <= p }", + """ + x_monom_(isrc_arg, q0_arg) := ( + x[isrc_arg] ** q0_arg * inv_fact[q0_arg] + ) + + y_monom_(isrc_arg, q1_arg) := ( + y[isrc_arg] ** q1_arg * inv_fact[q1_arg] + ) + + beta[q0, q1] = sum( + [isrc], + strength[isrc] * x_monom_(isrc, q0) * y_monom_(isrc, q1) + ) + """, + [ + lp.GlobalArg("x", dtype=dtype, shape=(nsources,)), + lp.GlobalArg("y", dtype=dtype, shape=(nsources,)), + lp.GlobalArg("strength", dtype=dtype, shape=(nsources,)), + lp.GlobalArg("inv_fact", dtype=dtype, shape=(order + 1,)), + lp.GlobalArg("beta", dtype=dtype, shape=(order + 1, order + 1), + is_output=True), 
+ ], + lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2, + ) + + knl = lp.fix_parameters(knl, nsources=nsources, p=order) + knl = lp.split_iname( + knl, + "q1", + q1_tile_size, + inner_iname="q1_inner", + outer_iname="q1_outer", + ) + knl = lp.split_iname( + knl, + "isrc", + source_tile_size, + inner_iname="isrc_inner", + outer_iname="isrc_outer", + ) + + if use_compute: + x_monom_map = nisl.make_map(f"""{{ + [isrc_arg, q0_arg] -> [q0, q1_outer, isrc_outer, isrc_s] : + isrc_arg = isrc_outer * {source_tile_size} + isrc_s and + q0_arg = q0 + }}""") + + knl = compute( + knl, + "x_monom_", + compute_map=x_monom_map, + storage_indices=["isrc_s"], + temporal_inames=["q0", "q1_outer", "isrc_outer"], + temporary_name="x_monom_for_q1_tile", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="x_monom_compute", + ) + + y_monom_map = nisl.make_map(f"""{{ + [isrc_arg, q1_arg] -> [q0, q1_outer, q1_inner, isrc_outer, isrc_s] : + isrc_arg = isrc_outer * {source_tile_size} + isrc_s and + q1_arg = q1_outer * {q1_tile_size} + q1_inner + }}""") + + knl = compute( + knl, + "y_monom_", + compute_map=y_monom_map, + storage_indices=["isrc_s"], + temporal_inames=["q0", "q1_outer", "q1_inner", "isrc_outer"], + temporary_name="y_monom_for_coeff", + temporary_address_space=lp.AddressSpace.PRIVATE, + temporary_dtype=dtype, + compute_insn_id="y_monom_compute", + ) + + return lp.tag_inames(knl, { + "q0": "g.1", + "q1_outer": "g.0", + "q1_inner": "ilp", + }) + + +def operation_model( + nsources: int, + order: int, + q1_tile_size: int + ) -> tuple[int, int]: + ncoeff = order + 1 + inline_monomial_evals = 2 * nsources * ncoeff**2 + compute_monomial_evals = ( + nsources * ncoeff**2 + + nsources * ncoeff**2 // q1_tile_size + ) + return inline_monomial_evals, compute_monomial_evals + + +def p2m_flop_count( + nsources: int, + order: int, + q1_tile_size: int, + use_compute: bool + ) -> int: + ncoeff = order + 1 + + contraction_flops = 3 * nsources * ncoeff**2 + 
if use_compute: + monomial_scale_flops = ( + nsources * ncoeff**2 + + nsources * ncoeff**2 // q1_tile_size + ) + else: + monomial_scale_flops = 2 * nsources * ncoeff**2 + + return contraction_flops + monomial_scale_flops + + +def benchmark_executor(ex, queue, args, warmup: int, iterations: int) -> float: + evt = None + for _ in range(warmup): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + + start = time.perf_counter() + for _ in range(iterations): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + end = time.perf_counter() + + return (end - start) / iterations + + +def run_kernel( + knl: lp.TranslationUnit, + x: np.ndarray, + y: np.ndarray, + strength: np.ndarray, + inv_fact: np.ndarray, + warmup: int, + iterations: int + ) -> tuple[np.ndarray, float]: + import pyopencl as cl + import pyopencl.array as cl_array + + ctx = cl.create_some_context(interactive=False) + queue = cl.CommandQueue(ctx) + ex = knl.executor(queue) + + x_cl = cl_array.to_device(queue, x) + y_cl = cl_array.to_device(queue, y) + strength_cl = cl_array.to_device(queue, strength) + inv_fact_cl = cl_array.to_device(queue, inv_fact) + beta_cl = cl_array.zeros(queue, (inv_fact.size, inv_fact.size), + dtype=x.dtype) + + elapsed = benchmark_executor( + ex, + queue, + { + "x": x_cl, + "y": y_cl, + "strength": strength_cl, + "inv_fact": inv_fact_cl, + "beta": beta_cl, + }, + warmup=warmup, + iterations=iterations, + ) + + _, out = ex( + queue, x=x_cl, y=y_cl, strength=strength_cl, + inv_fact=inv_fact_cl, beta=beta_cl) + return out[0].get(), elapsed + + +def main( + nsources: int = 256, + order: int = 12, + q1_tile_size: int = 13, + source_tile_size: int = 128, + use_compute: bool = False, + compare: bool = False, + print_kernel: bool = False, + print_device_code: bool = False, + run: bool = False, + warmup: int = 3, + iterations: int = 10 + ) -> None: + dtype = np.float64 + rng = np.random.default_rng(18) + x = rng.uniform(-0.25, 0.25, size=nsources).astype(dtype) + y = 
rng.uniform(-0.25, 0.25, size=nsources).astype(dtype) + strength = rng.normal(size=nsources).astype(dtype) + inv_fact = inv_factorials(order, dtype) + reference = reference_p2m(strength, x, y, inv_fact) + + inline_evals, compute_evals = operation_model( + nsources, order, q1_tile_size) + + variants = [False, True] if compare else [use_compute] + timings: dict[bool, float] = {} + for variant_uses_compute in variants: + knl = make_kernel( + nsources, order, q1_tile_size, source_tile_size, dtype, + use_compute=variant_uses_compute) + modeled_flops = p2m_flop_count( + nsources, order, q1_tile_size, + use_compute=variant_uses_compute) + + print(20 * "=", "P2M basis report", 20 * "=") + print(f"Variant: {'compute' if variant_uses_compute else 'inline'}") + print(f"Sources: {nsources}") + print(f"Order : {order}") + print(f"q1 tile: {q1_tile_size}") + print(f"Source tile: {source_tile_size}") + print(f"Inline monomial evaluations: {inline_evals}") + print(f"Compute monomial evaluations: {compute_evals}") + print(f"Modeled flop count: {modeled_flops}") + + if print_kernel: + print(knl) + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + if run or compare: + try: + result, elapsed = run_kernel( + knl, x, y, strength, inv_fact, + warmup=warmup, iterations=iterations) + except Exception as exc: + print(f"Runtime execution unavailable: {exc}") + else: + rel_err = la.norm(result - reference) / la.norm(reference) + gflops = modeled_flops / elapsed * 1e-9 + timings[variant_uses_compute] = elapsed + print(f"Average time per iteration: {elapsed:.6e} s") + print(f"Modeled throughput: {gflops:.3f} GFLOP/s") + print(f"Relative error: {rel_err:.3e}") + + print((40 + len(" P2M basis report ")) * "=") + + if compare and False in timings and True in timings: + speedup = timings[False] / timings[True] + time_reduction = (1 - timings[True] / timings[False]) * 100 + print(f"Speedup: {speedup:.3f}x") + print(f"Relative time reduction: {time_reduction:.2f}%") + + +if 
__name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + _ = parser.add_argument("--nsources", action="store", type=int, default=256) + _ = parser.add_argument("--order", action="store", type=int, default=12) + _ = parser.add_argument("--q1-tile-size", action="store", type=int, default=13) + _ = parser.add_argument("--source-tile-size", action="store", + type=int, default=128) + _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--compare", action="store_true") + _ = parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--print-device-code", action="store_true") + _ = parser.add_argument("--warmup", action="store", type=int, default=3) + _ = parser.add_argument("--iterations", action="store", type=int, default=10) + + args = parser.parse_args() + + main( + nsources=args.nsources, + order=args.order, + q1_tile_size=args.q1_tile_size, + source_tile_size=args.source_tile_size, + use_compute=args.compute, + compare=args.compare, + print_kernel=args.print_kernel, + print_device_code=args.print_device_code, + run=args.run_kernel, + warmup=args.warmup, + iterations=args.iterations, + ) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 6d7536b87..7be2a9f26 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -7,7 +7,6 @@ import islpy as isl import pymbolic.primitives as p from pymbolic import var -from pymbolic.mapper.dependency import DependencyMapper from pymbolic.mapper.substitutor import make_subst_func from pytools.tag import Tag @@ -15,6 +14,7 @@ from loopy.kernel.tools import DomainChanger from loopy.match import StackMatch, parse_stack_match from loopy.symbolic import ( + DependencyMapper, ExpansionState, RuleAwareIdentityMapper, RuleAwareSubstitutionMapper, From 35b78d6d72acbcfbcc55bb29ea2d19591ec328fa Mon Sep 17 00:00:00 2001 From: Addison Date: Fri, 17 Apr 2026 
#!/usr/bin/env bash
set -euo pipefail

# Driver that exercises every compute-transform example in this directory
# with one representative command-line configuration each.
#
# Interpreter selection: honor an explicit $PYTHON override, otherwise fall
# back to whatever `python3` is on PATH.  (Previously this defaulted to a
# hard-coded developer-specific conda path, which broke the script on every
# other machine.)
PYTHON="${PYTHON:-python3}"

# Resolve the directory containing this script so it can be invoked from
# anywhere; examples are addressed relative to it.
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"

cd "$SCRIPT_DIR"

# Print a banner naming the example + arguments, then run it.
run_example() {
    echo
    echo "===== $* ====="
    "$PYTHON" "$@"
}

run_example finite-difference-2-5D.py \
    --npoints 96 \
    --stencil-width 9 \
    --compute \
    --run-kernel \
    --warmup 2 \
    --iterations 3

run_example wave-equation-ring-buffer.py \
    --ntime 4096 \
    --compute \
    --run-kernel \
    --warmup 2 \
    --iterations 5

run_example matmul.py \
    --m 512 \
    --n 512 \
    --k 512 \
    --bm 32 \
    --bn 32 \
    --bk 16 \
    --shared-memory-tiled

run_example matmul.py \
    --m 512 \
    --n 512 \
    --k 512 \
    --bm 64 \
    --bn 64 \
    --bk 16 \
    --tm 4 \
    --tn 4 \
    --register-tiled

run_example l2p-tiled-basis-compute.py \
    --ntargets 512 \
    --order 12 \
    --target-block-size 64 \
    --compare \
    --run-kernel \
    --warmup 2 \
    --iterations 3

run_example p2m-basis-compute.py \
    --nsources 512 \
    --order 12 \
    --q1-tile-size 13 \
    --source-tile-size 128 \
    --compare \
    --run-kernel \
    --warmup 2 \
    --iterations 3

run_example m2m-sum-factorization.py \
    --order 23 \
    --eta-tile-size 8 \
    --dimension 3 \
    --compare \
    --run-kernel \
    --warmup 2 \
    --iterations 3

run_example l2p-3d-tensor-product-compute.py \
    --ntargets 512 \
    --order 8 \
    --target-block-size 64 \
    --compare \
    --run-kernel \
    --warmup 2 \
    --iterations 3
def _access_key(args: Sequence[Expression]) -> AccessTuple:
    """Key a substitution-rule invocation by the string form of its argument
    expressions.

    Invocations with textually identical argument lists collapse onto the
    same key and therefore share one usage descriptor downstream.
    """
    return tuple(str(arg) for arg in args)


def _base_name(name: str) -> str:
    # Strip the trailing "_" tick that compute() appends when renaming
    # compute-map output dims to avoid clashes with existing inames.
    return name.removesuffix("_")


def _cur_name(name: str) -> str:
    # "<name>_cur": dim name for the *current* step in the reuse analysis.
    return f"{_base_name(name)}_cur"


def _prev_name(name: str) -> str:
    # "<name>_prev": dim name for the *previous* step in the reuse analysis.
    return f"{_base_name(name)}_prev"


def _basic_set_to_predicates(bset: nisl.BasicSet) -> frozenset[Expression]:
    """Turn every constraint of *bset* into a pymbolic comparison predicate.

    Equality constraints become ``expr == 0``; all other constraints become
    ``expr >= 0`` (isl's canonical inequality form).
    """
    # NOTE(review): reaches into namedisl internals for the raw isl object;
    # presumably no public accessor exists yet -- confirm.
    isl_bset = bset._reconstruct_isl_object()

    predicates = []
    for constraint in isl_bset.get_constraints():
        expr = pw_aff_to_expr(constraint.get_aff())
        if constraint.is_equality():
            predicates.append(p.Comparison(expr, "==", 0))
        else:
            predicates.append(p.Comparison(expr, ">=", 0))

    return frozenset(predicates)


def _set_to_predicate_options(
        set_: nisl.Set | nisl.BasicSet
        ) -> Sequence[frozenset[Expression]]:
    """Convert a (possibly disjunctive) set into predicate alternatives.

    Each non-empty basic set contributes one frozenset of predicates; an
    empty set yields no options at all, in which case the caller emits no
    guarded instruction for it.
    """
    if isinstance(set_, nisl.BasicSet):
        if set_._reconstruct_isl_object().is_empty():
            return []
        return [_basic_set_to_predicates(set_)]

    predicate_options = []
    for bset in set_.get_basic_sets():
        if not bset._reconstruct_isl_object().is_empty():
            predicate_options.append(_basic_set_to_predicates(bset))

    return predicate_options


# helper for gathering names of variables in pymbolic expressions
def _gather_vars(expr: Expression) -> set[str]:
    """Return the names of all variables appearing in *expr*.

    For subscripted accesses (``a[i]``) the aggregate's name (``a``) is
    recorded; the index expressions are covered by the dependency walk.
    """
    deps = DependencyMapper()(expr)
    var_names = set()
    for dep in deps:
        if isinstance(dep, p.Variable):
            var_names.add(dep.name)
        elif (
                isinstance(dep, p.Subscript)
                and isinstance(dep.aggregate, p.Variable)):
            var_names.add(dep.aggregate.name)

    return var_names


def _existing_name_mapping(
        map_: nisl.Map | nisl.BasicMap,
        name_mapping: Mapping[str, str]
        ) -> Mapping[str, str]:
    """Restrict *name_mapping* to pairs whose source AND target both occur
    as dimension names of *map_*."""
    names = map_.names
    return {
        source: target
        for source, target in name_mapping.items()
        if source in names and target in names
    }


def _normalize_renamed_dims(
        map_: nisl.Map | nisl.BasicMap,
        name_mapping: Mapping[str, str],
        ) -> nisl.Map | nisl.BasicMap:
    """Undo the "ticked" dim renaming introduced by compute().

    Where both the original and the renamed dim are present they are
    equated and the renamed copy projected out; where only the renamed dim
    survives it is renamed back to the original name.
    """
    map_ = map_.equate_dims(_existing_name_mapping(map_, name_mapping))

    names = map_.names
    project_names = [
        renamed_name
        for original_name, renamed_name in name_mapping.items()
        if original_name in names and renamed_name in names
    ]
    map_ = map_.project_out(project_names)

    names = map_.names
    rename_mapping = {
        renamed_name: original_name
        for original_name, renamed_name in name_mapping.items()
        if original_name not in names and renamed_name in names
    }
    return map_.rename_dims(rename_mapping)


# {{{ gathering usage expressions

class UsageSiteExpressionGatherer(RuleAwareIdentityMapper[[]]):
    """
    Gathers all expressions used as inputs to a particular substitution rule,
    identified by name.

    After :meth:`map_kernel` runs, :attr:`usage_expressions` holds one entry
    per invocation site: the argument expressions, in the rule's declared
    argument order.

    .. note:: (review) The mapper returns ``0`` from matched invocations, so
       the mapped kernel it produces is not meaningful -- only the gathered
       side-channel data is.  Callers discard the mapping result.
    """
    def __init__(
            self,
            rule_mapping_ctx: SubstitutionRuleMappingContext,
            subst_expander: SubstitutionRuleExpander,
            kernel: LoopKernel,
            subst_name: str,
            subst_tag: Set[Tag] | Tag | None = None
            ) -> None:

        super().__init__(rule_mapping_ctx)

        self.subst_expander: SubstitutionRuleExpander = subst_expander
        self.kernel: LoopKernel = kernel
        self.subst_name: str = subst_name
        # Normalize a single Tag to a one-element set for comparison below.
        self.subst_tag: Set[Tag] | None = (
            {subst_tag} if isinstance(subst_tag, Tag) else subst_tag
        )

        # One argument-expression sequence per matched invocation site.
        self.usage_expressions: list[Sequence[Expression]] = []

    @override
    def map_subst_rule(
            self,
            name: str,
            tags: Set[Tag] | None,
            arguments: Sequence[Expression],
            expn_state: ExpansionState,
            ) -> Expression:

        # Only invocations of the requested rule are recorded; everything
        # else is handled by the base mapper.
        if name != self.subst_name:
            return super().map_subst_rule(
                name, tags, arguments, expn_state
            )

        if self.subst_tag is not None and self.subst_tag != tags:
            return super().map_subst_rule(
                name, tags, arguments, expn_state
            )

        rule = self.rule_mapping_context.old_subst_rules[name]
        # Resolve the actual argument expressions in the caller's context.
        arg_ctx = self.make_new_arg_context(
            name, rule.arguments, arguments, expn_state.arg_context
        )

        self.usage_expressions.append([
            arg_ctx[arg_name] for arg_name in rule.arguments
        ])

        return 0

# }}}
# {{{ substitution rule use replacement

class RuleInvocationReplacer(RuleAwareIdentityMapper[[]]):
    """Replaces invocations of one substitution rule with subscripted reads
    of the temporary filled by the compute instruction(s).

    Invocations whose usage descriptor is missing, or whose accessed range
    is not contained in the computed footprint, are left untouched (they
    fall through to the base mapper).
    """
    def __init__(
            self,
            ctx: SubstitutionRuleMappingContext,
            subst_name: str,
            subst_tag: Sequence[Tag] | None,
            usage_descriptors: Mapping[AccessTuple, nisl.Map | nisl.BasicMap],
            storage_indices: Sequence[str],
            temporary_name: str,
            compute_insn_ids: str | Sequence[str],
            footprint: nisl.Set
            ) -> None:

        super().__init__(ctx)

        self.subst_name: str = subst_name
        self.subst_tag: Sequence[Tag] | None = subst_tag

        # Maps an access key (stringified argument tuple) to the map from
        # usage inames to storage indices for that invocation site.
        self.usage_descriptors: Mapping[AccessTuple, nisl.Map | nisl.BasicMap] = \
            usage_descriptors
        self.storage_indices: Sequence[str] = storage_indices
        self.footprint: nisl.Set = footprint

        self.temporary_name: str = temporary_name
        self.compute_insn_ids: frozenset[str] = (
            frozenset([compute_insn_ids])
            if isinstance(compute_insn_ids, str)
            else frozenset(compute_insn_ids)
        )

        # Per-instruction flag: set when at least one invocation in the
        # current instruction was rewritten (checked by map_kernel).
        self.replaced_something: bool = False

        # FIXME: may not always be the case (i.e. global barrier between
        # compute insn and uses)
        self.compute_dep_ids: frozenset[str] = self.compute_insn_ids

    @override
    def map_subst_rule(
            self,
            name: str,
            tags: Set[Tag] | None,
            arguments: Sequence[Expression],
            expn_state: ExpansionState
            ) -> Expression:

        # NOTE(review): the arg context is built before the name check, so
        # this work is done even for unrelated rules -- harmless but
        # avoidable (the later refactor hoists the check).
        rule = self.rule_mapping_context.old_subst_rules[name]
        arg_ctx = self.make_new_arg_context(
            name, rule.arguments, arguments, expn_state.arg_context
        )
        args = [arg_ctx[arg_name] for arg_name in rule.arguments]

        # {{{ validation checks

        if name != self.subst_name:
            return super().map_subst_rule(name, tags, arguments, expn_state)

        access_key = _access_key(args)
        if access_key not in self.usage_descriptors:
            return super().map_subst_rule(name, tags, arguments, expn_state)

        if len(arguments) != len(rule.arguments):
            raise ValueError(
                f"Number of arguments passed to rule {name} "
                f"does not match the signature of {name}."
            )

        local_map = self.usage_descriptors[access_key]
        # Footprint with all non-storage dims moved to parameters, so it is
        # comparable against the descriptor's range (storage indices only).
        temp_footprint = self.footprint.move_dims(
            frozenset(self.footprint.names) - frozenset(self.storage_indices),
            isl.dim_type.param
        )

        # Only replace accesses provably covered by the computed footprint.
        if not local_map.range() <= temp_footprint:
            return super().map_subst_rule(name, tags, arguments, expn_state)

        # }}}

        # {{{ get index expression in terms of global inames

        local_pwmaff = self.usage_descriptors[access_key].as_pw_multi_aff()

        index_exprs: Sequence[Expression] = []
        for dim in range(local_pwmaff.dim(isl.dim_type.out)):
            index_exprs.append(pw_aff_to_expr(local_pwmaff.get_at(dim)))

        new_expression = var(self.temporary_name)[tuple(index_exprs)]

        # }}}

        self.replaced_something = True

        return new_expression

    @override
    def map_kernel(
            self,
            kernel: LoopKernel,
            within: StackMatch = lambda knl, insn, stack: True,
            map_args: bool = True,
            map_tvs: bool = True
            ) -> LoopKernel:

        new_insns: Sequence[lp.InstructionBase] = []
        for insn in kernel.instructions:
            self.replaced_something = False

            # Fast path: assignments that cannot contain a rule invocation
            # are passed through unchanged.
            if (isinstance(insn, lp.MultiAssignmentBase) and not
                    contains_a_subst_rule_invocation(kernel, insn)):
                new_insns.append(insn)
                continue

            insn = insn.with_transformed_expressions(
                lambda expr, insn=insn: self(expr, kernel, insn)
            )

            # Any instruction that now reads the temporary must run after
            # the compute instruction(s) that fill it.
            if self.replaced_something:
                insn = insn.copy(
                    depends_on=(
                        insn.depends_on | self.compute_dep_ids
                    )
                )

            # FIXME: determine compute insn dependencies

            new_insns.append(insn)

        return kernel.copy(instructions=new_insns)

# }}}
@for_each_kernel
def compute(
        kernel: LoopKernel,
        substitution: str,
        compute_map: nisl.Map,

        storage_indices: Sequence[str],

        # FIXME: can these two be deduced?
        temporal_inames: Sequence[str],
        inames_to_advance: Sequence[str] | None = None,

        temporary_name: str | None = None,
        temporary_address_space: AddressSpace | None = None,

        temporary_dtype: ToLoopyTypeConvertible = None,

        compute_insn_id: str | None = None
        ) -> LoopKernel:
    """
    Inserts an instruction to compute an expression given by *substitution*
    and replaces all invocations of *substitution* with the result of the
    inserted compute instruction.

    :arg substitution: The substitution rule for which the compute
        transform should be applied.

    :arg compute_map: A :class:`nisl.Map` representing a relation between
        substitution rule indices and tuples `(a, l)`, where `a` is a vector of
        storage indices and `l` is a vector of "timestamps".

    :arg storage_indices: An ordered sequence of names of storage indices. Used
        to create inames for the loops that cover the required set of compute
        points, and as the dim names/shape of the backing temporary.

    :arg temporal_inames: Names of the "timestamp" output dims of
        *compute_map*; treated as parameters when computing the storage
        footprint.

    :arg inames_to_advance: When not *None*, enables the reuse analysis:
        storage cells valid at step ``iname - 1`` (for each advancing iname)
        are shifted within the temporary instead of recomputed.

    :arg temporary_name: Name of the backing temporary; defaults to
        ``<substitution>_temp``.

    :arg temporary_address_space: Address space of the backing temporary.

    :arg temporary_dtype: Converted via :func:`to_loopy_type`
        (``allow_none=True``).

    :arg compute_insn_id: Id of the inserted compute instruction; defaults
        to ``<substitution>_compute``.
    """

    # "Tick" non-storage output names of the compute map with a trailing
    # underscore so they cannot clash with existing inames of the kernel.
    name_mapping = {
        name: name + "_"
        for name in compute_map.output_names
        if name not in storage_indices
    }
    compute_map = compute_map.rename_dims(name_mapping)

    # {{{ setup and useful variables

    storage_set = frozenset(storage_indices)
    temporal_set = frozenset(temporal_inames)

    ctx = SubstitutionRuleMappingContext(
        kernel.substitutions, kernel.get_var_name_generator())
    expander = SubstitutionRuleExpander(kernel.substitutions)
    expr_gatherer = UsageSiteExpressionGatherer(
        ctx, expander, kernel, substitution, None
    )

    # Run only for the side effect of collecting usage expressions.
    _ = expr_gatherer.map_kernel(kernel)
    usage_exprs = expr_gatherer.usage_expressions

    # All inames referenced by any argument at any usage site.
    all_exprs = [expr for usage in usage_exprs for expr in usage]
    usage_inames: frozenset[str] = frozenset(
        set.union(*(_gather_vars(expr) for expr in all_exprs))
    )

    # }}}

    # {{{ construct necessary pieces; footprint, global usage map

    # add compute inames to domain / kernel
    domain_changer = DomainChanger(kernel, kernel.all_inames())
    named_domain = nisl.make_basic_set(domain_changer.domain)

    # restrict domain to used inames
    local_domain = named_domain.project_out_except(usage_inames)

    # FIXME: gross. find a cleaner way to generate a space for an empty map
    global_usage_map = nisl.make_map_from_domain_and_range(
        nisl.make_set(isl.Set.empty(local_domain.get_space())),
        compute_map.domain()
    )
    global_usage_map = nisl.make_map(isl.Map.empty(global_usage_map.get_space()))

    usage_substs: Mapping[AccessTuple, nisl.Map | nisl.BasicMap] = {}
    for usage in usage_exprs:

        # {{{ compute local usage map, update global usage map

        # Map from the usage site's inames to the rule's index tuple.
        local_usage_mpwaff = multi_pw_aff_from_exprs(
            usage,
            global_usage_map.get_space()
        )

        local_usage_map = nisl.make_map(local_usage_mpwaff.as_map())

        local_usage_map = local_usage_map.intersect_domain(local_domain)
        global_usage_map = global_usage_map | local_usage_map

        # }}}

        # {{{ compute storage map

        local_storage_map = local_usage_map.apply_range(compute_map)
        local_storage_map = _normalize_renamed_dims(
            local_storage_map, name_mapping)

        # check that no restrictions happened during composition (i.e. tile
        # valid for a single point in the domain)
        if not local_usage_map.domain() <= local_storage_map.domain():
            continue

        # clean up names
        non_param_names = (usage_inames - temporal_set) | storage_set
        parameter_names = frozenset(local_storage_map.names) - non_param_names
        local_storage_map = local_storage_map.move_dims(parameter_names,
            isl.dim_type.param)

        # }}}

        usage_substs[_access_key(usage)] = local_storage_map

    storage_map = global_usage_map.apply_range(compute_map)
    storage_map = _normalize_renamed_dims(storage_map, name_mapping)

    # }}}

    # {{{ compute bounds and update kernel domain

    storage_map = storage_map.move_dims(temporal_set, isl.dim_type.param)
    footprint = storage_map.range()

    # clean up ticked duplicate names
    footprint = footprint.project_out_except(temporal_set | storage_set)
    footprint = footprint.move_dims(temporal_set, isl.dim_type.set)

    # {{{ FIXME: use Sets instead of BasicSets when loopy is ready

    # FIXME: convex hull is not permanent
    footprint_isl = footprint._reconstruct_isl_object()
    footprint = nisl.make_set(isl.Set.from_basic_set(footprint_isl.convex_hull()))
    named_domain = named_domain & footprint

    if len(named_domain.get_basic_sets()) != 1:
        raise ValueError("New domain should be composed of a single basic set")

    # FIXME: use named object once loopy is name-ified
    domain = named_domain.get_basic_sets()[0]._reconstruct_isl_object()
    new_domains = domain_changer.get_domains_with(domain)

    # }}}

    kernel = kernel.copy(domains=new_domains)

    # }}}

    if not temporary_name:
        temporary_name = substitution + "_temp"

    if not compute_insn_id:
        compute_insn_id = substitution + "_compute"

    # {{{ reuse analysis

    update_insns: list[lp.InstructionBase] = []
    update_insn_ids: list[str] = []
    # Default: a single unguarded compute instruction covers the footprint.
    refill_predicate_options: Sequence[frozenset[Expression] | None] = [None]
    # Threads a serial dependency chain through shift and refill insns.
    current_update_deps: frozenset[str] = frozenset()

    if inames_to_advance is not None:
        advancing_set = frozenset(inames_to_advance)

        # "_cur"/"_prev" copies of the compute map describe storage contents
        # at the current and previous value of the advancing inames.
        compute_map_cur = compute_map.rename_dims({
            name: _cur_name(name) for name in compute_map.output_names
        })
        compute_map_prev = compute_map.rename_dims({
            name: _prev_name(name) for name in compute_map.output_names
        })

        cur_storage = global_usage_map.apply_range(compute_map_cur)
        prev_storage = global_usage_map.apply_range(compute_map_prev)

        # Relate prev-step storage cells to cur-step cells holding the same
        # value, restricted to steps that differ by exactly one in each
        # advancing iname.
        reuse_map = prev_storage.reverse().apply_range(cur_storage)
        reuse_map = reuse_map.add_constraint([
            (
                f"{name}_cur = {name}_prev + 1"
                if name in advancing_set
                else
                f"{name}_cur = {name}_prev"
            )
            for name in temporal_inames
        ])

        current_footprint = footprint.rename_dims({
            name: _cur_name(name) for name in footprint.names
        })
        previous_footprint = footprint.rename_dims({
            name: _prev_name(name) for name in footprint.names
        })

        reuse_map = reuse_map.intersect_domain(previous_footprint)
        reuse_map = reuse_map.intersect_range(current_footprint)
        # Drop identity moves (cell reused in place needs no shift insn).
        reuse_map = reuse_map - nisl.make_map(
            "{ ["
            + ", ".join(_prev_name(name) for name in footprint.names)
            + "] -> ["
            + ", ".join(_cur_name(name) for name in footprint.names)
            + "] : "
            + " and ".join(
                f"{_cur_name(name)} = {_prev_name(name)}"
                for name in storage_indices
            )
            + " }"
        )

        reused_current = reuse_map.range()
        # Cells not covered by reuse must be recomputed ("refilled").
        refill = current_footprint - reused_current

        cur_to_normal = {
            _cur_name(name): name
            for name in footprint.names
        }
        reused_current = reused_current.rename_dims(cur_to_normal)
        refill = refill.rename_dims(cur_to_normal)

        # Simplify predicates against what the loop domain already implies.
        reused_context = named_domain.project_out_except(reused_current.names)
        refill_context = named_domain.project_out_except(refill.names)

        reused_current = reused_current.gist(reused_context)
        refill = refill.gist(refill_context)

        refill_predicate_options = _set_to_predicate_options(refill)

        # Source storage index (prev) as a function of the destination
        # storage index (cur), used to index the shift's right-hand side.
        storage_reuse_map = reuse_map.project_out_except(
            frozenset(_prev_name(name) for name in storage_indices)
            | frozenset(_cur_name(name) for name in storage_indices)
        )
        storage_reuse_map = storage_reuse_map.rename_dims({
            _cur_name(name): name
            for name in storage_indices
        })
        cur_to_prev = storage_reuse_map.reverse()
        cur_to_prev_pwma = cur_to_prev.as_pw_multi_aff()
        prev_expr_by_name = {
            cur_to_prev_pwma.get_dim_name(isl.dim_type.out, dim):
            pw_aff_to_expr(cur_to_prev_pwma.get_at(dim))
            for dim in range(cur_to_prev_pwma.dim(isl.dim_type.out))
        }
        prev_storage_exprs = [
            prev_expr_by_name[_prev_name(name)]
            for name in storage_indices
        ]

        # temp[cur storage idx] = temp[prev storage idx]
        shift_assignee = var(temporary_name)[
            tuple(var(idx) for idx in storage_indices)
        ]
        shift_expression = var(temporary_name)[tuple(prev_storage_exprs)]

        # One guarded shift instruction per disjunct of the reused set.
        shift_predicate_options = _set_to_predicate_options(reused_current)
        for i, predicates in enumerate(shift_predicate_options):
            shift_insn_id = (
                f"{compute_insn_id}_shift"
                if len(shift_predicate_options) == 1
                else f"{compute_insn_id}_shift_{i}"
            )
            update_insns.append(lp.Assignment(
                id=shift_insn_id,
                assignee=shift_assignee,
                expression=shift_expression,
                within_inames=frozenset(temporal_inames) | storage_set,
                predicates=predicates,
                depends_on=current_update_deps,
            ))
            update_insn_ids.append(shift_insn_id)
            # Serialize successive shifts to avoid overwriting cells that
            # are still to be read.  NOTE(review): assumed intent -- confirm.
            current_update_deps = frozenset([shift_insn_id])

    # }}}

    # {{{ create compute instruction in kernel

    # FIXME: maybe just keep original around?
    compute_map = compute_map.rename_dims({
        value: key for key, value in name_mapping.items()
    })

    # Express each rule index in terms of storage indices + timestamps, so
    # the rule body can be instantiated at the compute site.
    compute_pw_aff = compute_map.reverse().as_pw_multi_aff()
    storage_ax_to_global_expr = {
        compute_pw_aff.get_dim_name(isl.dim_type.out, dim):
        pw_aff_to_expr(compute_pw_aff.get_at(dim))
        for dim in range(compute_pw_aff.dim(isl.dim_type.out))
    }

    expr_subst_map = RuleAwareSubstitutionMapper(
        ctx,
        make_subst_func(storage_ax_to_global_expr),
        within=parse_stack_match(None)
    )

    subst_expr = kernel.substitutions[substitution].expression
    compute_expression = expr_subst_map(
        subst_expr,
        kernel,
        None,
    )
    # The compute insn must run after every writer of anything it reads.
    compute_dep_ids = frozenset().union(*(
        kernel.writer_map().get(var_name, frozenset())
        for var_name in _gather_vars(compute_expression)
    ))

    assignee = var(temporary_name)[tuple(var(idx) for idx in storage_indices)]

    within_inames = compute_map.output_names

    new_insns = list(kernel.instructions)
    new_insns.extend(update_insns)

    # One guarded refill instruction per disjunct (or a single unguarded
    # one when no reuse analysis ran).
    for i, predicates in enumerate(refill_predicate_options):
        refill_insn_id = (
            compute_insn_id
            if len(refill_predicate_options) == 1
            else f"{compute_insn_id}_refill_{i}"
        )
        compute_insn = lp.Assignment(
            id=refill_insn_id,
            assignee=assignee,
            expression=compute_expression,
            within_inames=within_inames,
            predicates=predicates,
            depends_on=current_update_deps | compute_dep_ids,
        )
        new_insns.append(compute_insn)
        update_insn_ids.append(refill_insn_id)
        current_update_deps = frozenset([refill_insn_id])

    kernel = kernel.copy(instructions=new_insns)

    # }}}

    # {{{ replace invocations with new compute instruction

    # Fresh context: the kernel's instruction list changed above.
    ctx = SubstitutionRuleMappingContext(
        kernel.substitutions, kernel.get_var_name_generator()
    )

    replacer = RuleInvocationReplacer(
        ctx,
        substitution,
        None,
        usage_substs,
        storage_indices,
        temporary_name,
        update_insn_ids,
        footprint
    )

    kernel = replacer.map_kernel(kernel)

    # }}}

    # {{{ create temporary variable for result of compute

    loopy_type = to_loopy_type(temporary_dtype, allow_none=True)

    # Shape = max storage index + 1 per dim (base indices are all zero).
    temp_shape = tuple(
        pw_aff_to_expr(footprint.dim_max(dim)) + 1
        for dim in storage_indices
    )

    new_temp_vars = dict(kernel.temporary_variables)

    # FIXME: temp_var might already exist, handle the case where it does
    temp_var = lp.TemporaryVariable(
        name=temporary_name,
        dtype=loopy_type,
        base_indices=(0,)*len(temp_shape),
        shape=temp_shape,
        address_space=temporary_address_space,
        dim_names=tuple(storage_indices)
    )

    new_temp_vars[temporary_name] = temp_var

    kernel = kernel.copy(
        temporary_variables=new_temp_vars
    )

    # }}}

    return kernel
-from loopy.types import ToLoopyTypeConvertible, to_loopy_type +from ..translation_unit import for_each_kernel +from ..types import ToLoopyTypeConvertible, to_loopy_type +from .precompute import contains_a_subst_rule_invocation if TYPE_CHECKING: from collections.abc import Mapping, Sequence, Set - from pymbolic.typing import Expression + from ..kernel import LoopKernel + from ..kernel.data import AddressSpace - from loopy.kernel import LoopKernel - from loopy.kernel.data import AddressSpace +UsageKey: TypeAlias = tuple[str, int] +PredicateSet: TypeAlias = frozenset[Expression] +PredicateOptions: TypeAlias = tuple[PredicateSet | None, ...] -AccessTuple: TypeAlias = tuple[str, ...] +@dataclass(frozen=True) +class UsageSite: + insn_id: str + ordinal: int + args: tuple[Expression, ...] + predicates: PredicateSet -def _access_key(args: Sequence[Expression]) -> AccessTuple: - return tuple(str(arg) for arg in args) + @property + def key(self) -> UsageKey: + return self.insn_id, self.ordinal + + @property + def domain_names(self) -> frozenset[str]: + exprs = (*self.args, *self.predicates) + return frozenset(set().union(*(_gather_vars(expr) for expr in exprs))) + + +@dataclass(frozen=True) +class NameState: + internal_compute_map: nisl.Map + renamed_to_original: Mapping[str, str] + original_to_renamed: Mapping[str, str] + + +@dataclass(frozen=True) +class UsageInfo: + global_usage_map: nisl.Map + local_storage_maps: Mapping[UsageKey, nisl.Map | nisl.BasicMap] + + +@dataclass(frozen=True) +class FootprintInfo: + loopy_footprint: nisl.Set + named_domain: nisl.Set + + +@dataclass(frozen=True) +class ReuseRelations: + shift_relation: nisl.Map + reusable_footprint: nisl.Set + refill_footprint: nisl.Set + + +@dataclass(frozen=True) +class ComputePlan: + name_state: NameState + usage_info: UsageInfo + footprint_info: FootprintInfo + storage_indices: tuple[str, ...] + temporal_inames: tuple[str, ...] 
+ reuse_relations: ReuseRelations | None + + +@dataclass(frozen=True) +class ComputeInstructionInfo: + expression: Expression + dependencies: frozenset[str] + within_inames: frozenset[str] def _base_name(name: str) -> str: @@ -56,55 +115,105 @@ def _prev_name(name: str) -> str: return f"{_base_name(name)}_prev" -def _basic_set_to_predicates(bset: nisl.BasicSet) -> frozenset[Expression]: - isl_bset = bset._reconstruct_isl_object() +def _make_name_state( + compute_map: nisl.Map, + storage_indices: Sequence[str], +) -> NameState: + original_to_renamed = { + name: f"{name}_" + for name in compute_map.output_names + if name not in storage_indices + } + renamed_to_original = { + renamed: original for original, renamed in original_to_renamed.items() + } + return NameState( + internal_compute_map=compute_map.rename_dims(original_to_renamed), + renamed_to_original=renamed_to_original, + original_to_renamed=original_to_renamed, + ) + + +def _infer_temporal_inames( + compute_map: nisl.Map, + storage_indices: Sequence[str], +) -> tuple[str, ...]: + storage_set = frozenset(storage_indices) + return tuple(name for name in compute_map.output_names if name not in storage_set) - predicates = [] - for constraint in isl_bset.get_constraints(): - expr = pw_aff_to_expr(constraint.get_aff()) - if constraint.is_equality(): - predicates.append(p.Comparison(expr, "==", 0)) - else: - predicates.append(p.Comparison(expr, ">=", 0)) - return frozenset(predicates) +def _basic_set_to_predicates(bset: nisl.BasicSet) -> PredicateSet: + return frozenset( + p.Comparison( + pw_aff_to_expr(constraint.get_aff()), + "==" if constraint.is_equality() else ">=", + 0, + ) + for constraint in bset._reconstruct_isl_object().get_constraints() + ) def _set_to_predicate_options( - set_: nisl.Set | nisl.BasicSet - ) -> Sequence[frozenset[Expression]]: + set_: nisl.Set | nisl.BasicSet, +) -> tuple[PredicateSet, ...]: if isinstance(set_, nisl.BasicSet): if set_._reconstruct_isl_object().is_empty(): - return [] - 
return [_basic_set_to_predicates(set_)] + return () + return (_basic_set_to_predicates(set_),) - predicate_options = [] - for bset in set_.get_basic_sets(): - if not bset._reconstruct_isl_object().is_empty(): - predicate_options.append(_basic_set_to_predicates(bset)) - - return predicate_options + return tuple( + _basic_set_to_predicates(bset) + for bset in set_.get_basic_sets() + if not bset._reconstruct_isl_object().is_empty() + ) -# helper for gathering names of variables in pymbolic expressions def _gather_vars(expr: Expression) -> set[str]: deps = DependencyMapper()(expr) var_names = set() for dep in deps: if isinstance(dep, p.Variable): var_names.add(dep.name) - elif ( - isinstance(dep, p.Subscript) - and isinstance(dep.aggregate, p.Variable)): + elif isinstance(dep, p.Subscript) and isinstance(dep.aggregate, p.Variable): var_names.add(dep.aggregate.name) return var_names +def _gather_usage_inames(sites: Sequence[UsageSite]) -> frozenset[str]: + return frozenset(set().union(*(site.domain_names for site in sites))) + + +def _next_ordinal(counters: dict[str, int], insn_id: str) -> int: + ordinal = counters.get(insn_id, 0) + counters[insn_id] = ordinal + 1 + return ordinal + + +def _normalize_subst_tag( + tag: Set[Tag] | Sequence[Tag] | Tag | None, +) -> frozenset[Tag] | None: + if tag is None: + return None + if isinstance(tag, Tag): + return frozenset([tag]) + return frozenset(tag) + + +def _add_predicates_to_domain( + domain: nisl.BasicSet, + predicates: PredicateSet, +) -> nisl.BasicSet: + predicate_constraints = [str(predicate) for predicate in predicates] + if not predicate_constraints: + return domain + return domain.add_constraint(predicate_constraints) + + def _existing_name_mapping( - map_: nisl.Map | nisl.BasicMap, - name_mapping: Mapping[str, str] - ) -> Mapping[str, str]: + map_: nisl.Map | nisl.BasicMap, + name_mapping: Mapping[str, str], +) -> Mapping[str, str]: names = map_.names return { source: target @@ -114,198 +223,181 @@ def 
_existing_name_mapping( def _normalize_renamed_dims( - map_: nisl.Map | nisl.BasicMap, - name_mapping: Mapping[str, str], - ) -> nisl.Map | nisl.BasicMap: + map_: nisl.Map | nisl.BasicMap, + name_mapping: Mapping[str, str], +) -> nisl.Map | nisl.BasicMap: map_ = map_.equate_dims(_existing_name_mapping(map_, name_mapping)) names = map_.names - project_names = [ + map_ = map_.project_out([ renamed_name for original_name, renamed_name in name_mapping.items() if original_name in names and renamed_name in names - ] - map_ = map_.project_out(project_names) + ]) names = map_.names - rename_mapping = { + return map_.rename_dims({ renamed_name: original_name for original_name, renamed_name in name_mapping.items() if original_name not in names and renamed_name in names - } - return map_.rename_dims(rename_mapping) + }) -# {{{ gathering usage expressions +def _empty_usage_map(local_domain: nisl.BasicSet, range_: nisl.Set) -> nisl.Map: + map_space = nisl.make_map_from_domain_and_range( + nisl.make_set(isl.Set.empty(local_domain.get_space())), + range_, + ).get_space() + return nisl.make_map(isl.Map.empty(map_space)) + + +def _map_to_output_exprs(map_: nisl.Map | nisl.BasicMap) -> tuple[Expression, ...]: + pwmaff = map_.as_pw_multi_aff() + return tuple( + pw_aff_to_expr(pwmaff.get_at(dim)) + for dim in range(pwmaff.dim(isl.dim_type.out)) + ) + + +def _map_to_named_output_exprs( + map_: nisl.Map | nisl.BasicMap, +) -> Mapping[str, Expression]: + pwmaff = map_.as_pw_multi_aff() + return { + pwmaff.get_dim_name(isl.dim_type.out, dim): pw_aff_to_expr(pwmaff.get_at(dim)) + for dim in range(pwmaff.dim(isl.dim_type.out)) + } + class UsageSiteExpressionGatherer(RuleAwareIdentityMapper[[]]): - """ - Gathers all expressions used as inputs to a particular substitution rule, - identified by name. 
- """ def __init__( - self, - rule_mapping_ctx: SubstitutionRuleMappingContext, - subst_expander: SubstitutionRuleExpander, - kernel: LoopKernel, - subst_name: str, - subst_tag: Set[Tag] | Tag | None = None - ) -> None: - + self, + rule_mapping_ctx: SubstitutionRuleMappingContext, + subst_name: str, + subst_tag: Set[Tag] | Tag | None = None, + ) -> None: super().__init__(rule_mapping_ctx) - self.subst_expander: SubstitutionRuleExpander = subst_expander - self.kernel: LoopKernel = kernel self.subst_name: str = subst_name - self.subst_tag: Set[Tag] | None = ( - {subst_tag} if isinstance(subst_tag, Tag) else subst_tag - ) - - self.usage_expressions: list[Sequence[Expression]] = [] + self.subst_tag: frozenset[Tag] | None = _normalize_subst_tag(subst_tag) + self.sites: list[UsageSite] = [] + self._next_ordinal_by_insn: dict[str, int] = {} @override def map_subst_rule( - self, - name: str, - tags: Set[Tag] | None, - arguments: Sequence[Expression], - expn_state: ExpansionState, - ) -> Expression: - + self, + name: str, + tags: Set[Tag] | None, + arguments: Sequence[Expression], + expn_state: ExpansionState, + ) -> Expression: if name != self.subst_name: - return super().map_subst_rule( - name, tags, arguments, expn_state - ) + return super().map_subst_rule(name, tags, arguments, expn_state) - if self.subst_tag is not None and self.subst_tag != tags: - return super().map_subst_rule( - name, tags, arguments, expn_state - ) + if self.subst_tag is not None and self.subst_tag != frozenset(tags or ()): + return super().map_subst_rule(name, tags, arguments, expn_state) rule = self.rule_mapping_context.old_subst_rules[name] arg_ctx = self.make_new_arg_context( name, rule.arguments, arguments, expn_state.arg_context ) - self.usage_expressions.append([ - arg_ctx[arg_name] for arg_name in rule.arguments - ]) + self.sites.append( + UsageSite( + insn_id=expn_state.insn_id, + ordinal=_next_ordinal(self._next_ordinal_by_insn, expn_state.insn_id), + args=tuple(arg_ctx[arg_name] for 
arg_name in rule.arguments), + predicates=frozenset(expn_state.instruction.predicates), + ) + ) return 0 -# }}} - - -# {{{ substitution rule use replacement class RuleInvocationReplacer(RuleAwareIdentityMapper[[]]): def __init__( - self, - ctx: SubstitutionRuleMappingContext, - subst_name: str, - subst_tag: Sequence[Tag] | None, - usage_descriptors: Mapping[AccessTuple, nisl.Map | nisl.BasicMap], - storage_indices: Sequence[str], - temporary_name: str, - compute_insn_ids: str | Sequence[str], - footprint: nisl.Set - ) -> None: - + self, + ctx: SubstitutionRuleMappingContext, + subst_name: str, + subst_tag: Sequence[Tag] | None, + usage_descriptors: Mapping[UsageKey, nisl.Map | nisl.BasicMap], + storage_indices: Sequence[str], + temporary_name: str, + compute_insn_ids: str | Sequence[str], + footprint: nisl.Set, + ) -> None: super().__init__(ctx) self.subst_name: str = subst_name - self.subst_tag: Sequence[Tag] | None = subst_tag - - self.usage_descriptors: Mapping[AccessTuple, nisl.Map | nisl.BasicMap] = \ + self.subst_tag: frozenset[Tag] | None = _normalize_subst_tag(subst_tag) + self.usage_descriptors: Mapping[UsageKey, nisl.Map | nisl.BasicMap] = ( usage_descriptors - self.storage_indices: Sequence[str] = storage_indices - self.footprint: nisl.Set = footprint - + ) + self.storage_indices: tuple[str, ...] = tuple(storage_indices) + self.temp_footprint: nisl.Set = footprint.move_dims( + frozenset(footprint.names) - frozenset(self.storage_indices), + isl.dim_type.param, + ) self.temporary_name: str = temporary_name - self.compute_insn_ids: frozenset[str] = ( + self.compute_dep_ids: frozenset[str] = ( frozenset([compute_insn_ids]) if isinstance(compute_insn_ids, str) else frozenset(compute_insn_ids) ) - self.replaced_something: bool = False - - # FIXME: may not always be the case (i.e. 
global barrier between - # compute insn and uses) - self.compute_dep_ids: frozenset[str] = self.compute_insn_ids + self._next_ordinal_by_insn: dict[str, int] = {} @override def map_subst_rule( - self, - name: str, - tags: Set[Tag] | None, - arguments: Sequence[Expression], - expn_state: ExpansionState - ) -> Expression: - - rule = self.rule_mapping_context.old_subst_rules[name] - arg_ctx = self.make_new_arg_context( - name, rule.arguments, arguments, expn_state.arg_context - ) - args = [arg_ctx[arg_name] for arg_name in rule.arguments] - - # {{{ validation checks - + self, + name: str, + tags: Set[Tag] | None, + arguments: Sequence[Expression], + expn_state: ExpansionState, + ) -> Expression: if name != self.subst_name: return super().map_subst_rule(name, tags, arguments, expn_state) - access_key = _access_key(args) - if access_key not in self.usage_descriptors: + if self.subst_tag is not None and self.subst_tag != frozenset(tags or ()): return super().map_subst_rule(name, tags, arguments, expn_state) + rule = self.rule_mapping_context.old_subst_rules[name] if len(arguments) != len(rule.arguments): raise ValueError( f"Number of arguments passed to rule {name} " f"does not match the signature of {name}." 
) - local_map = self.usage_descriptors[access_key] - temp_footprint = self.footprint.move_dims( - frozenset(self.footprint.names) - frozenset(self.storage_indices), - isl.dim_type.param + access_key = ( + expn_state.insn_id, + _next_ordinal(self._next_ordinal_by_insn, expn_state.insn_id), ) - if not local_map.range() <= temp_footprint: + local_map = self.usage_descriptors.get(access_key) + if local_map is None: return super().map_subst_rule(name, tags, arguments, expn_state) - # }}} - - # {{{ get index expression in terms of global inames - - local_pwmaff = self.usage_descriptors[access_key].as_pw_multi_aff() - - index_exprs: Sequence[Expression] = [] - for dim in range(local_pwmaff.dim(isl.dim_type.out)): - index_exprs.append(pw_aff_to_expr(local_pwmaff.get_at(dim))) - - new_expression = var(self.temporary_name)[tuple(index_exprs)] - - # }}} + if not local_map.range() <= self.temp_footprint: + return super().map_subst_rule(name, tags, arguments, expn_state) self.replaced_something = True - - return new_expression + return var(self.temporary_name)[_map_to_output_exprs(local_map)] @override def map_kernel( - self, - kernel: LoopKernel, - within: StackMatch = lambda knl, insn, stack: True, - map_args: bool = True, - map_tvs: bool = True - ) -> LoopKernel: - - new_insns: Sequence[lp.InstructionBase] = [] + self, + kernel: LoopKernel, + within: StackMatch = lambda knl, insn, stack: True, + map_args: bool = True, + map_tvs: bool = True, + ) -> LoopKernel: + new_insns = [] for insn in kernel.instructions: self.replaced_something = False - if (isinstance(insn, lp.MultiAssignmentBase) and not - contains_a_subst_rule_invocation(kernel, insn)): + if isinstance( + insn, MultiAssignmentBase + ) and not contains_a_subst_rule_invocation(kernel, insn): new_insns.append(insn) continue @@ -314,403 +406,581 @@ def map_kernel( ) if self.replaced_something: - insn = insn.copy( - depends_on=( - insn.depends_on | self.compute_dep_ids - ) - ) - - # FIXME: determine compute insn 
dependencies + insn = insn.copy(depends_on=insn.depends_on | self.compute_dep_ids) new_insns.append(insn) return kernel.copy(instructions=new_insns) -# }}} - - -@for_each_kernel -def compute( - kernel: LoopKernel, - substitution: str, - compute_map: nisl.Map, - - storage_indices: Sequence[str], - - # FIXME: can these two be deduced? - temporal_inames: Sequence[str], - inames_to_advance: Sequence[str] | None = None, - - temporary_name: str | None = None, - temporary_address_space: AddressSpace | None = None, - - temporary_dtype: ToLoopyTypeConvertible = None, - - compute_insn_id: str | None = None - ) -> LoopKernel: - """ - Inserts an instruction to compute an expression given by :arg:`substitution` - and replaces all invocations of :arg:`substitution` with the result of the - inserted compute instruction. - - :arg substitution: The substitution rule for which the compute - transform should be applied. - - :arg compute_map: An :class:`isl.Map` representing a relation between - substitution rule indices and tuples `(a, l)`, where `a` is a vector of - storage indices and `l` is a vector of "timestamps". - - :arg storage_indices: An ordered sequence of names of storage indices. Used - to create inames for the loops that cover the required set of compute points. 
- """ - - name_mapping = { - name: name + "_" - for name in compute_map.output_names - if name not in storage_indices - } - compute_map = compute_map.rename_dims(name_mapping) - - # {{{ setup and useful variables - - storage_set = frozenset(storage_indices) - temporal_set = frozenset(temporal_inames) +def _gather_usage_sites( + kernel: LoopKernel, + substitution: str, +) -> tuple[UsageSite, ...]: ctx = SubstitutionRuleMappingContext( - kernel.substitutions, kernel.get_var_name_generator()) - expander = SubstitutionRuleExpander(kernel.substitutions) - expr_gatherer = UsageSiteExpressionGatherer( - ctx, expander, kernel, substitution, None - ) - - _ = expr_gatherer.map_kernel(kernel) - usage_exprs = expr_gatherer.usage_expressions - - all_exprs = [expr for usage in usage_exprs for expr in usage] - usage_inames: frozenset[str] = frozenset( - set.union(*(_gather_vars(expr) for expr in all_exprs)) + kernel.substitutions, kernel.get_var_name_generator() ) + gatherer = UsageSiteExpressionGatherer(ctx, substitution) + + _ = gatherer.map_kernel(kernel) + return tuple(gatherer.sites) + + +def _build_usage_info( + named_domain: nisl.BasicSet, + name_state: NameState, + storage_indices: Sequence[str], + temporal_inames: Sequence[str], + sites: Sequence[UsageSite], +) -> UsageInfo: + if not sites: + raise ValueError( + "compute() did not find any invocation of the requested substitution rule." + ) - # }}} - - # {{{ construct necessary pieces; footprint, global usage map - - # add compute inames to domain / kernel - domain_changer = DomainChanger(kernel, kernel.all_inames()) - named_domain = nisl.make_basic_set(domain_changer.domain) - - # restrict domain to used inames + usage_inames = _gather_usage_inames(sites) local_domain = named_domain.project_out_except(usage_inames) - - # FIXME: gross. 
find a cleaner way to generate a space for an empty map - global_usage_map = nisl.make_map_from_domain_and_range( - nisl.make_set(isl.Set.empty(local_domain.get_space())), - compute_map.domain() + global_usage_map = _empty_usage_map( + local_domain, name_state.internal_compute_map.domain() ) - global_usage_map = nisl.make_map(isl.Map.empty(global_usage_map.get_space())) - usage_substs: Mapping[AccessTuple, nisl.Map | nisl.BasicMap] = {} - for usage in usage_exprs: - - # {{{ compute local usage map, update global usage map + storage_set = frozenset(storage_indices) + temporal_set = frozenset(temporal_inames) + usage_descriptors: dict[UsageKey, nisl.Map | nisl.BasicMap] = {} - local_usage_mpwaff = multi_pw_aff_from_exprs( - usage, - global_usage_map.get_space() + for site in sites: + local_domain = _add_predicates_to_domain( + named_domain.project_out_except(site.domain_names), + site.predicates, + ) + usage_mpwaff = multi_pw_aff_from_exprs(site.args, global_usage_map.get_space()) + local_usage_map = nisl.make_map(usage_mpwaff.as_map()).intersect_domain( + local_domain ) - - local_usage_map = nisl.make_map(local_usage_mpwaff.as_map()) - - local_usage_map = local_usage_map.intersect_domain(local_domain) global_usage_map = global_usage_map | local_usage_map - # }}} - - # {{{ compute storage map - - local_storage_map = local_usage_map.apply_range(compute_map) + local_storage_map = local_usage_map.apply_range(name_state.internal_compute_map) local_storage_map = _normalize_renamed_dims( - local_storage_map, name_mapping) - - # check that no restrictions happened during composition (i.e. 
tile - # valid for a single point in the domain) + local_storage_map, name_state.original_to_renamed + ) if not local_usage_map.domain() <= local_storage_map.domain(): continue - # clean up names non_param_names = (usage_inames - temporal_set) | storage_set - parameter_names = frozenset(local_storage_map.names) - non_param_names - local_storage_map = local_storage_map.move_dims(parameter_names, - isl.dim_type.param) - - # }}} + usage_descriptors[site.key] = local_storage_map.move_dims( + frozenset(local_storage_map.names) - non_param_names, + isl.dim_type.param, + ) - usage_substs[_access_key(usage)] = local_storage_map + return UsageInfo( + global_usage_map=global_usage_map, + local_storage_maps=usage_descriptors, + ) - storage_map = global_usage_map.apply_range(compute_map) - storage_map = _normalize_renamed_dims(storage_map, name_mapping) - # }}} +def _build_footprint_info( + named_domain: nisl.BasicSet, + name_state: NameState, + usage_info: UsageInfo, + storage_indices: Sequence[str], + temporal_inames: Sequence[str], +) -> FootprintInfo: + storage_set = frozenset(storage_indices) + temporal_set = frozenset(temporal_inames) - # {{{ compute bounds and update kernel domain + storage_map = usage_info.global_usage_map.apply_range( + name_state.internal_compute_map + ) + storage_map = _normalize_renamed_dims(storage_map, name_state.original_to_renamed) storage_map = storage_map.move_dims(temporal_set, isl.dim_type.param) - footprint = storage_map.range() + exact_footprint = storage_map.range() + exact_footprint = exact_footprint.project_out_except(temporal_set | storage_set) + exact_footprint = exact_footprint.move_dims(temporal_set, isl.dim_type.set) - # clean up ticked duplicate names - footprint = footprint.project_out_except(temporal_set | storage_set) - footprint = footprint.move_dims(temporal_set, isl.dim_type.set) + # Loopy domains are still restricted to a single BasicSet in this path. 
+ footprint_isl = exact_footprint._reconstruct_isl_object() + loopy_footprint = nisl.make_set(isl.Set.from_basic_set(footprint_isl.convex_hull())) - # {{{ FIXME: use Sets instead of BasicSets when loopy is ready + loopy_domain = named_domain & loopy_footprint + if len(loopy_domain.get_basic_sets()) != 1: + raise ValueError("New domain should be composed of a single basic set") - # FIXME: convex hull is not permanent - footprint_isl = footprint._reconstruct_isl_object() - footprint = nisl.make_set(isl.Set.from_basic_set(footprint_isl.convex_hull())) - named_domain = named_domain & footprint + return FootprintInfo(loopy_footprint=loopy_footprint, named_domain=loopy_domain) + + +def _build_compute_plan( + compute_map: nisl.Map, + named_domain: nisl.BasicSet, + sites: Sequence[UsageSite], + storage_indices: Sequence[str], + temporal_inames: Sequence[str], + inames_to_advance: Sequence[str] | Literal["auto"] | None, +) -> ComputePlan: + name_state = _make_name_state(compute_map, storage_indices) + usage_info = _build_usage_info( + named_domain, + name_state, + storage_indices, + temporal_inames, + sites, + ) + footprint_info = _build_footprint_info( + named_domain, + name_state, + usage_info, + storage_indices, + temporal_inames, + ) - if len(named_domain.get_basic_sets()) != 1: - raise ValueError("New domain should be composed of a single basic set") + if inames_to_advance == "auto": + inames_to_advance = _detect_inames_to_advance( + name_state.internal_compute_map, + usage_info.global_usage_map, + footprint_info.loopy_footprint, + storage_indices, + temporal_inames, + ) - # FIXME: use named object once loopy is name-ified - domain = named_domain.get_basic_sets()[0]._reconstruct_isl_object() - new_domains = domain_changer.get_domains_with(domain) + reuse_relations = ( + None + if inames_to_advance is None + else _build_reuse_relations( + name_state.internal_compute_map, + usage_info.global_usage_map, + footprint_info.loopy_footprint, + footprint_info.named_domain, + 
storage_indices, + temporal_inames, + frozenset(inames_to_advance), + ) + ) - # }}} + return ComputePlan( + name_state=name_state, + usage_info=usage_info, + footprint_info=footprint_info, + storage_indices=tuple(storage_indices), + temporal_inames=tuple(temporal_inames), + reuse_relations=reuse_relations, + ) - kernel = kernel.copy(domains=new_domains) - # }}} +def _build_reuse_relations( + compute_map: nisl.Map, + global_usage_map: nisl.Map, + footprint: nisl.Set, + named_domain: nisl.Set, + storage_indices: Sequence[str], + temporal_inames: Sequence[str], + advancing_set: frozenset[str], +) -> ReuseRelations: + predecessor_context = _build_predecessor_context(temporal_inames, advancing_set) + shift_relation = _build_shift_relation( + compute_map, + global_usage_map, + footprint, + storage_indices, + predecessor_context, + ) + reusable_footprint = shift_relation.range() + current_footprint = footprint.rename_dims({ + name: _cur_name(name) for name in footprint.names + }) + refill_footprint = current_footprint - reusable_footprint - if not temporary_name: - temporary_name = substitution + "_temp" + normal_names = {_cur_name(name): name for name in footprint.names} + reusable_footprint = reusable_footprint.rename_dims(normal_names) + refill_footprint = refill_footprint.rename_dims(normal_names) - if not compute_insn_id: - compute_insn_id = substitution + "_compute" + reusable_footprint = reusable_footprint.gist( + named_domain.project_out_except(reusable_footprint.names) + ) + refill_footprint = refill_footprint.gist( + named_domain.project_out_except(refill_footprint.names) + ) - # {{{ reuse analysis + return ReuseRelations( + shift_relation=shift_relation, + reusable_footprint=reusable_footprint, + refill_footprint=refill_footprint, + ) - update_insns: list[lp.InstructionBase] = [] - update_insn_ids: list[str] = [] - refill_predicate_options: Sequence[frozenset[Expression] | None] = [None] - current_update_deps: frozenset[str] = frozenset() - if inames_to_advance 
is not None: - advancing_set = frozenset(inames_to_advance) +def _build_predecessor_context( + temporal_inames: Sequence[str], + advancing_set: frozenset[str], +) -> nisl.Map: + constraints = [ + ( + f"{_cur_name(name)} = {_prev_name(name)} + 1" + if name in advancing_set + else f"{_cur_name(name)} = {_prev_name(name)}" + ) + for name in temporal_inames + ] - compute_map_cur = compute_map.rename_dims({ - name: _cur_name(name) for name in compute_map.output_names - }) - compute_map_prev = compute_map.rename_dims({ - name: _prev_name(name) for name in compute_map.output_names - }) + return nisl.make_map( + "{ [" + + ", ".join(_prev_name(name) for name in temporal_inames) + + "] -> [" + + ", ".join(_cur_name(name) for name in temporal_inames) + + "]" + + (f" : {' and '.join(constraints)}" if constraints else "") + + " }" + ) - cur_storage = global_usage_map.apply_range(compute_map_cur) - prev_storage = global_usage_map.apply_range(compute_map_prev) - reuse_map = prev_storage.reverse().apply_range(cur_storage) - reuse_map = reuse_map.add_constraint([ - ( - f"{name}_cur = {name}_prev + 1" - if name in advancing_set - else - f"{name}_cur = {name}_prev" - ) - for name in temporal_inames - ]) - - current_footprint = footprint.rename_dims({ - name: _cur_name(name) for name in footprint.names - }) - previous_footprint = footprint.rename_dims({ - name: _prev_name(name) for name in footprint.names - }) - - reuse_map = reuse_map.intersect_domain(previous_footprint) - reuse_map = reuse_map.intersect_range(current_footprint) - reuse_map = reuse_map - nisl.make_map( - "{ [" - + ", ".join(_prev_name(name) for name in footprint.names) - + "] -> [" - + ", ".join(_cur_name(name) for name in footprint.names) - + "] : " - + " and ".join( - f"{_cur_name(name)} = {_prev_name(name)}" - for name in storage_indices - ) - + " }" +def _build_shift_relation( + compute_map: nisl.Map, + global_usage_map: nisl.Map, + footprint: nisl.Set, + storage_indices: Sequence[str], + predecessor_context: 
nisl.Map, +) -> nisl.Map: + compute_map_cur = compute_map.rename_dims({ + name: _cur_name(name) for name in compute_map.output_names + }) + compute_map_prev = compute_map.rename_dims({ + name: _prev_name(name) for name in compute_map.output_names + }) + + reuse_map = ( + global_usage_map + .apply_range(compute_map_prev) + .reverse() + .apply_range(global_usage_map.apply_range(compute_map_cur)) + ) + reuse_map = reuse_map & predecessor_context + + current_footprint = footprint.rename_dims({ + name: _cur_name(name) for name in footprint.names + }) + previous_footprint = footprint.rename_dims({ + name: _prev_name(name) for name in footprint.names + }) + + reuse_map = reuse_map.intersect_domain(previous_footprint) + reuse_map = reuse_map.intersect_range(current_footprint) + + return reuse_map - _identity_storage_map(footprint, storage_indices) + + +def _detect_inames_to_advance( + compute_map: nisl.Map, + global_usage_map: nisl.Map, + footprint: nisl.Set, + storage_indices: Sequence[str], + temporal_inames: Sequence[str], +) -> tuple[str, ...]: + candidates = [] + for name in temporal_inames: + shift_relation = _build_shift_relation( + compute_map, + global_usage_map, + footprint, + storage_indices, + _build_predecessor_context(temporal_inames, frozenset([name])), ) + if not shift_relation._reconstruct_isl_object().is_empty(): + candidates.append(name) - reused_current = reuse_map.range() - refill = current_footprint - reused_current + if len(candidates) > 1: + raise ValueError( + "Could not infer a unique advancing iname. " + f"Candidates are {candidates}; pass inames_to_advance explicitly." 
+ ) + + return tuple(candidates) + + +def _identity_storage_map( + footprint: nisl.Set, + storage_indices: Sequence[str], +) -> nisl.Map: + return nisl.make_map( + "{ [" + + ", ".join(_prev_name(name) for name in footprint.names) + + "] -> [" + + ", ".join(_cur_name(name) for name in footprint.names) + + "] : " + + " and ".join( + f"{_cur_name(name)} = {_prev_name(name)}" for name in storage_indices + ) + + " }" + ) - cur_to_normal = { - _cur_name(name): name - for name in footprint.names - } - reused_current = reused_current.rename_dims(cur_to_normal) - refill = refill.rename_dims(cur_to_normal) - reused_context = named_domain.project_out_except(reused_current.names) - refill_context = named_domain.project_out_except(refill.names) +def _make_shift_instructions( + reuse_map: nisl.Map, + reused_current: nisl.Set, + storage_indices: Sequence[str], + temporal_inames: Sequence[str], + temporary_name: str, + compute_insn_id: str, +) -> tuple[tuple[InstructionBase, ...], tuple[str, ...], frozenset[str]]: + storage_reuse_map = reuse_map.project_out_except( + frozenset(_prev_name(name) for name in storage_indices) + | frozenset(_cur_name(name) for name in storage_indices) + ) + storage_reuse_map = storage_reuse_map.rename_dims({ + _cur_name(name): name for name in storage_indices + }) + + cur_to_prev_exprs = _map_to_named_output_exprs(storage_reuse_map.reverse()) + prev_storage_exprs = tuple( + cur_to_prev_exprs[_prev_name(name)] for name in storage_indices + ) - reused_current = reused_current.gist(reused_context) - refill = refill.gist(refill_context) + shift_assignee = var(temporary_name)[tuple(var(idx) for idx in storage_indices)] + shift_expression = var(temporary_name)[prev_storage_exprs] - refill_predicate_options = _set_to_predicate_options(refill) + update_insns = [] + update_ids = [] + current_deps: frozenset[str] = frozenset() + shift_predicate_options = _set_to_predicate_options(reused_current) - storage_reuse_map = reuse_map.project_out_except( - 
frozenset(_prev_name(name) for name in storage_indices) - | frozenset(_cur_name(name) for name in storage_indices) + for i, predicates in enumerate(shift_predicate_options): + shift_insn_id = ( + f"{compute_insn_id}_shift" + if len(shift_predicate_options) == 1 + else f"{compute_insn_id}_shift_{i}" ) - storage_reuse_map = storage_reuse_map.rename_dims({ - _cur_name(name): name - for name in storage_indices - }) - cur_to_prev = storage_reuse_map.reverse() - cur_to_prev_pwma = cur_to_prev.as_pw_multi_aff() - prev_expr_by_name = { - cur_to_prev_pwma.get_dim_name(isl.dim_type.out, dim): - pw_aff_to_expr(cur_to_prev_pwma.get_at(dim)) - for dim in range(cur_to_prev_pwma.dim(isl.dim_type.out)) - } - prev_storage_exprs = [ - prev_expr_by_name[_prev_name(name)] - for name in storage_indices - ] - - shift_assignee = var(temporary_name)[ - tuple(var(idx) for idx in storage_indices) - ] - shift_expression = var(temporary_name)[tuple(prev_storage_exprs)] - - shift_predicate_options = _set_to_predicate_options(reused_current) - for i, predicates in enumerate(shift_predicate_options): - shift_insn_id = ( - f"{compute_insn_id}_shift" - if len(shift_predicate_options) == 1 - else f"{compute_insn_id}_shift_{i}" - ) - update_insns.append(lp.Assignment( + update_insns.append( + Assignment( id=shift_insn_id, assignee=shift_assignee, expression=shift_expression, - within_inames=frozenset(temporal_inames) | storage_set, + within_inames=frozenset(temporal_inames) | frozenset(storage_indices), predicates=predicates, - depends_on=current_update_deps, - )) - update_insn_ids.append(shift_insn_id) - current_update_deps = frozenset([shift_insn_id]) - - # }}} + depends_on=current_deps, + ) + ) + update_ids.append(shift_insn_id) + current_deps = frozenset([shift_insn_id]) - # {{{ create compute instruction in kernel + return tuple(update_insns), tuple(update_ids), current_deps - # FIXME: maybe just keep original around? 
- compute_map = compute_map.rename_dims({ - value: key for key, value in name_mapping.items() - }) +def _build_compute_instruction_info( + kernel: LoopKernel, + substitution: str, + name_state: NameState, +) -> ComputeInstructionInfo: + compute_map = name_state.internal_compute_map.rename_dims( + name_state.renamed_to_original + ) compute_pw_aff = compute_map.reverse().as_pw_multi_aff() - storage_ax_to_global_expr = { - compute_pw_aff.get_dim_name(isl.dim_type.out, dim): - pw_aff_to_expr(compute_pw_aff.get_at(dim)) + storage_axis_to_global_expr = { + compute_pw_aff.get_dim_name(isl.dim_type.out, dim): pw_aff_to_expr( + compute_pw_aff.get_at(dim) + ) for dim in range(compute_pw_aff.dim(isl.dim_type.out)) } + ctx = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator() + ) expr_subst_map = RuleAwareSubstitutionMapper( ctx, - make_subst_func(storage_ax_to_global_expr), - within=parse_stack_match(None) + make_subst_func(storage_axis_to_global_expr), + within=parse_stack_match(None), ) - subst_expr = kernel.substitutions[substitution].expression compute_expression = expr_subst_map( - subst_expr, + kernel.substitutions[substitution].expression, kernel, - None, + cast("InstructionBase", cast("object", None)), ) - compute_dep_ids = frozenset().union(*( - kernel.writer_map().get(var_name, frozenset()) - for var_name in _gather_vars(compute_expression) - )) - assignee = var(temporary_name)[tuple(var(idx) for idx in storage_indices)] + dependencies = frozenset().union( + *( + kernel.writer_map().get(var_name, frozenset()) + for var_name in _gather_vars(compute_expression) + ) + ) + + return ComputeInstructionInfo( + expression=compute_expression, + dependencies=dependencies, + within_inames=frozenset(compute_map.output_names), + ) - within_inames = compute_map.output_names - new_insns = list(kernel.instructions) - new_insns.extend(update_insns) +def _add_update_and_compute_instructions( + kernel: LoopKernel, + update_insns: 
Sequence[InstructionBase], + update_ids: Sequence[str], + refill_options: PredicateOptions, + final_deps: frozenset[str], + compute_info: ComputeInstructionInfo, + storage_indices: Sequence[str], + temporary_name: str, + compute_insn_id: str, +) -> tuple[LoopKernel, tuple[str, ...]]: + new_insns = [*kernel.instructions, *update_insns] + update_ids = list(update_ids) + current_deps = final_deps + assignee = var(temporary_name)[tuple(var(idx) for idx in storage_indices)] - for i, predicates in enumerate(refill_predicate_options): + for i, predicates in enumerate(refill_options): refill_insn_id = ( compute_insn_id - if len(refill_predicate_options) == 1 + if len(refill_options) == 1 else f"{compute_insn_id}_refill_{i}" ) - compute_insn = lp.Assignment( - id=refill_insn_id, - assignee=assignee, - expression=compute_expression, - within_inames=within_inames, - predicates=predicates, - depends_on=current_update_deps | compute_dep_ids, + new_insns.append( + Assignment( + id=refill_insn_id, + assignee=assignee, + expression=compute_info.expression, + within_inames=compute_info.within_inames, + predicates=predicates, + depends_on=current_deps | compute_info.dependencies, + ) ) - new_insns.append(compute_insn) - update_insn_ids.append(refill_insn_id) - current_update_deps = frozenset([refill_insn_id]) + update_ids.append(refill_insn_id) + current_deps = frozenset([refill_insn_id]) - kernel = kernel.copy(instructions=new_insns) + return kernel.copy(instructions=new_insns), tuple(update_ids) - # }}} - # {{{ replace invocations with new compute instruction +def _add_temporary( + kernel: LoopKernel, + footprint: nisl.Set, + storage_indices: Sequence[str], + temporary_name: str, + temporary_address_space: AddressSpace | None, + temporary_dtype: ToLoopyTypeConvertible, +) -> LoopKernel: + loopy_type = to_loopy_type(temporary_dtype, allow_none=True) + bounds = tuple( + (pw_aff_to_expr(footprint.dim_min(dim)), pw_aff_to_expr(footprint.dim_max(dim))) + for dim in storage_indices + ) 
+ base_indices = tuple(lower for lower, _upper in bounds) + temp_shape = tuple(upper - lower + 1 for lower, upper in bounds) + + new_temp_vars = dict(kernel.temporary_variables) + new_temp_vars[temporary_name] = TemporaryVariable( + name=temporary_name, + dtype=loopy_type, + base_indices=base_indices, + shape=temp_shape, + address_space=temporary_address_space, + dim_names=tuple(storage_indices), + ) + return kernel.copy(temporary_variables=new_temp_vars) + + +def _lower_compute_plan( + kernel: LoopKernel, + substitution: str, + plan: ComputePlan, + domain_changer: DomainChanger, + temporary_name: str, + temporary_address_space: AddressSpace | None, + temporary_dtype: ToLoopyTypeConvertible, + compute_insn_id: str, +) -> LoopKernel: + domain = plan.footprint_info.named_domain.get_basic_sets()[ + 0 + ]._reconstruct_isl_object() + kernel = kernel.copy(domains=domain_changer.get_domains_with(domain)) + + update_insns: tuple[InstructionBase, ...] = () + update_insn_ids: tuple[str, ...] = () + refill_options: PredicateOptions = (None,) + final_deps: frozenset[str] = frozenset() + if plan.reuse_relations is not None: + update_insns, update_insn_ids, final_deps = _make_shift_instructions( + plan.reuse_relations.shift_relation, + plan.reuse_relations.reusable_footprint, + plan.storage_indices, + plan.temporal_inames, + temporary_name, + compute_insn_id, + ) + refill_options = _set_to_predicate_options( + plan.reuse_relations.refill_footprint + ) + + compute_info = _build_compute_instruction_info( + kernel, substitution, plan.name_state + ) + kernel, update_insn_ids = _add_update_and_compute_instructions( + kernel, + update_insns, + update_insn_ids, + refill_options, + final_deps, + compute_info, + plan.storage_indices, + temporary_name, + compute_insn_id, + ) ctx = SubstitutionRuleMappingContext( kernel.substitutions, kernel.get_var_name_generator() ) - - replacer = RuleInvocationReplacer( + kernel = RuleInvocationReplacer( ctx, substitution, None, - usage_substs, - 
storage_indices, + plan.usage_info.local_storage_maps, + plan.storage_indices, temporary_name, update_insn_ids, - footprint - ) - - kernel = replacer.map_kernel(kernel) - - # }}} - - # {{{ create temporary variable for result of compute + plan.footprint_info.loopy_footprint, + ).map_kernel(kernel) - loopy_type = to_loopy_type(temporary_dtype, allow_none=True) - - temp_shape = tuple( - pw_aff_to_expr(footprint.dim_max(dim)) + 1 - for dim in storage_indices + return _add_temporary( + kernel, + plan.footprint_info.loopy_footprint, + plan.storage_indices, + temporary_name, + temporary_address_space, + temporary_dtype, ) - new_temp_vars = dict(kernel.temporary_variables) - # FIXME: temp_var might already exist, handle the case where it does - temp_var = lp.TemporaryVariable( - name=temporary_name, - dtype=loopy_type, - base_indices=(0,)*len(temp_shape), - shape=temp_shape, - address_space=temporary_address_space, - dim_names=tuple(storage_indices) - ) +@for_each_kernel +def compute( + kernel: LoopKernel, + substitution: str, + compute_map: nisl.Map, + storage_indices: Sequence[str], + temporal_inames: Sequence[str] | None = None, + inames_to_advance: Sequence[str] | Literal["auto"] | None = None, + temporary_name: str | None = None, + temporary_address_space: AddressSpace | None = None, + temporary_dtype: ToLoopyTypeConvertible = None, + compute_insn_id: str | None = None, +) -> LoopKernel: + """Compute a substitution into a temporary and replace covered uses.""" + temporary_name = temporary_name or f"{substitution}_temp" + compute_insn_id = compute_insn_id or f"{substitution}_compute" + if temporal_inames is None: + temporal_inames = _infer_temporal_inames(compute_map, storage_indices) - new_temp_vars[temporary_name] = temp_var + domain_changer = DomainChanger(kernel, kernel.all_inames()) + named_domain = nisl.make_basic_set(domain_changer.domain) - kernel = kernel.copy( - temporary_variables=new_temp_vars + plan = _build_compute_plan( + compute_map, + named_domain, + 
_gather_usage_sites(kernel, substitution), + tuple(storage_indices), + temporal_inames, + inames_to_advance, + ) + return _lower_compute_plan( + kernel, + substitution, + plan, + domain_changer, + temporary_name, + temporary_address_space, + temporary_dtype, + compute_insn_id, ) - - # }}} - - return kernel diff --git a/test/test_compute.py b/test/test_compute.py new file mode 100644 index 000000000..b5442133d --- /dev/null +++ b/test/test_compute.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import namedisl as nisl +import numpy as np + +import loopy as lp +from loopy.transform.compute_stub import _gather_usage_sites, compute +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 + + +def test_compute_stub_simple_substitution_codegen() -> None: + knl = lp.make_kernel( + "{ [i] : 0 <= i < n }", + """ + u_(is) := u[is] + out[i] = u_(i) + """, + [ + lp.GlobalArg("u", shape=(16,), dtype=np.float32), + lp.GlobalArg("out", shape=(16,), dtype=np.float32, is_output=True), + ], + lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2, + ) + knl = lp.fix_parameters(knl, n=16) + + knl = compute( + knl, + "u_", + compute_map=nisl.make_map("{ [is] -> [i_s] : is = i_s }"), + storage_indices=["i_s"], + temporal_inames=[], + temporary_name="u_tmp", + temporary_dtype=np.float32, + ) + + code = lp.generate_code_v2(knl).device_code() + assert "float u_tmp[16]" in code + assert "u_tmp[i_s] = u[i_s]" in code + assert "out[i] = u_tmp[i]" in code + + +def test_compute_stub_repeated_substitution_uses_are_unique() -> None: + knl = lp.make_kernel( + "{ [i] : 0 <= i < n }", + """ + u_(is) := u[is] + out[i] = u_(i) + u_(i + 1) {id=write_out} + """, + [ + lp.GlobalArg("u", shape=(16,), dtype=np.float32), + lp.GlobalArg("out", shape=(16,), dtype=np.float32, is_output=True), + ], + lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2, + ) + knl = lp.fix_parameters(knl, n=16) + + sites = _gather_usage_sites(knl["loopy_kernel"], "u_") + + assert [site.key for site in sites] == [("write_out", 
0), ("write_out", 1)] + assert sites[0].args != sites[1].args + + +def test_compute_stub_ring_buffer_codegen() -> None: + ntime = 128 + block_size = 32 + knl = lp.make_kernel( + "{ [t] : 1 <= t < ntime - 1 }", + """ + u_hist(ts) := u[ts] + u_next[t + 1] = 2*u_hist(t) - u_hist(t - 1) + """, + [ + lp.GlobalArg("u", dtype=np.float64, shape=(ntime,)), + lp.GlobalArg("u_next", dtype=np.float64, shape=(ntime,), is_output=True), + ], + lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2, + ) + knl = lp.fix_parameters(knl, ntime=ntime) + knl = lp.split_iname( + knl, + "t", + block_size, + inner_iname="ti", + outer_iname="to", + ) + + knl = compute( + knl, + "u_hist", + compute_map=nisl.make_map("{ [ts] -> [to, ti, tb] : tb = 32*to + ti - ts }"), + storage_indices=["tb"], + inames_to_advance="auto", + temporary_name="u_time_buf", + temporary_dtype=np.float64, + ) + + code = lp.generate_code_v2(knl).device_code() + assert "double u_time_buf[2]" in code + assert "u_time_buf[tb] = u_time_buf[0]" in code + assert "u_time_buf[tb] = u[-1 * tb + ti + 32 * to]" in code + assert "u_next[1 + ti + 32 * to]" in code From 10317a109ba114d9e0e864033699138761cf760e Mon Sep 17 00:00:00 2001 From: Addison Date: Fri, 17 Apr 2026 18:29:00 -0500 Subject: [PATCH 24/27] change examples runner to use which python instead of hardcoded python path --- examples/python/compute-examples/run-compute-examples.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/python/compute-examples/run-compute-examples.sh b/examples/python/compute-examples/run-compute-examples.sh index d9c8aab93..096a6d6f3 100755 --- a/examples/python/compute-examples/run-compute-examples.sh +++ b/examples/python/compute-examples/run-compute-examples.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash set -euo pipefail -PYTHON="${PYTHON:-/home/aj/miniforge3/envs/dev/bin/python}" +PYTHON="$(which python)" SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" cd "$SCRIPT_DIR" From 
1fa9526fb3d19cd3872dc0e69ee6b68e8e0de7c1 Mon Sep 17 00:00:00 2001 From: Addison Date: Sat, 18 Apr 2026 13:35:02 -0500 Subject: [PATCH 25/27] add diamond tiling example --- .codex | 0 .../finite-difference-diamond.py | 342 ++++++++++++++++++ .../compute-examples/run-compute-examples.sh | 10 + 3 files changed, 352 insertions(+) create mode 100644 .codex create mode 100644 examples/python/compute-examples/finite-difference-diamond.py diff --git a/.codex b/.codex new file mode 100644 index 000000000..e69de29bb diff --git a/examples/python/compute-examples/finite-difference-diamond.py b/examples/python/compute-examples/finite-difference-diamond.py new file mode 100644 index 000000000..94baf9183 --- /dev/null +++ b/examples/python/compute-examples/finite-difference-diamond.py @@ -0,0 +1,342 @@ +import time + +import namedisl as nisl +import numpy as np +import numpy.linalg as la + +import pyopencl as cl + +import loopy as lp +from loopy.transform.compute import compute +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 + + +def centered_second_derivative_coefficients(radius: int, dtype) -> np.ndarray: + offsets = np.arange(-radius, radius + 1, dtype=dtype) + powers = np.arange(2 * radius + 1) + vandermonde = offsets[np.newaxis, :] ** powers[:, np.newaxis] + rhs = np.zeros(2 * radius + 1, dtype=dtype) + rhs[2] = 2 + + return np.linalg.solve(vandermonde, rhs).astype(dtype) + + +def benchmark_executor(ex, queue, args, warmup: int, iterations: int) -> float: + if iterations <= 0: + raise ValueError("iterations must be positive") + + evt = None + for _ in range(warmup): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + + start = time.perf_counter() + for _ in range(iterations): + evt, _ = ex(queue, **args) + if evt is not None: + evt.wait() + end = time.perf_counter() + + return (end - start) / iterations + + +def fd_flop_count(ntime: int, nspace: int, stencil_width: int) -> int: + radius = stencil_width // 2 + output_points = (ntime - 1) * (nspace - 
2 * radius) + return 2 * stencil_width * output_points + + +def make_initial_condition(nspace: int, dtype) -> np.ndarray: + x = np.linspace(-1, 1, num=nspace, endpoint=True, dtype=dtype) + wave_number = dtype(2 * np.pi) + return np.sin(wave_number * x).astype(dtype) + + +def reference_time_stepper( + u0: np.ndarray, + coeffs: np.ndarray, + ntime: int, + radius: int, +) -> np.ndarray: + result = np.zeros((ntime, u0.size), dtype=u0.dtype) + result[0] = u0 + result[1:, :radius] = u0[:radius] + result[1:, u0.size - radius :] = u0[u0.size - radius :] + + for t in range(ntime - 1): + for i in range(radius, u0.size - radius): + result[t + 1, i] = sum( + coeffs[ell + radius] * result[t, i - ell] + for ell in range(-radius, radius + 1) + ) + + return result + + +def offset_name(ell: int) -> str: + return f"u_p{ell}" if ell >= 0 else f"u_m{-ell}" + + +def main( + ntime: int = 128, + nspace: int = 4096, + stencil_width: int = 9, + time_block_size: int = 8, + space_block_size: int = 128, + use_compute: bool = False, + print_device_code: bool = False, + print_kernel: bool = False, + run_kernel: bool = False, + warmup: int = 3, + iterations: int = 10, +) -> float | None: + if stencil_width <= 0 or stencil_width % 2 == 0: + raise ValueError("stencil_width must be a positive odd integer") + if ntime <= stencil_width: + raise ValueError("ntime must be larger than stencil_width") + if nspace <= 2 * stencil_width: + raise ValueError("nspace must be larger than twice stencil_width") + + dtype = np.float64 + r = stencil_width // 2 + + u0 = make_initial_condition(nspace, dtype) + u_hist = np.zeros((ntime, nspace), dtype=dtype) + u_hist[0] = u0 + u_hist[1:, :r] = u0[:r] + u_hist[1:, nspace - r :] = u0[nspace - r :] + h = dtype(2 / (nspace - 1)) + dt = dtype(0.05 * h**2) + lap_coeffs = centered_second_derivative_coefficients(r, dtype) / h**2 + coeffs = (dt * lap_coeffs).astype(dtype) + coeffs[r] += 1 + + bt = time_block_size + bx = space_block_size + subst_rules = "\n".join( + 
f"{offset_name(ell)}(ts, is) := u_hist[ts, is " + f"{'+' if -ell >= 0 else '-'} {abs(ell)}]" + for ell in range(-r, r + 1) + ) + stencil_expr = " + ".join( + f"c[{ell + r}] * {offset_name(ell)}(t, i)" for ell in range(-r, r + 1) + ) + + knl = lp.make_kernel( + "{ [t, i] : 0 <= t < ntime - 1 and r <= i < nspace - r }", + f""" + {subst_rules} + + u_hist[t + 1, i] = {stencil_expr} {{id=step}} + """, + [ + lp.GlobalArg( + "u_hist", + dtype=dtype, + shape=(ntime, nspace), + is_input=True, + is_output=True, + ), + lp.GlobalArg("c", dtype=dtype, shape=(stencil_width,)), + ], + lang_version=LOOPY_USE_LANGUAGE_VERSION_2018_2, + ) + + knl = lp.fix_parameters(knl, ntime=ntime, nspace=nspace, r=r) + knl = lp.split_iname(knl, "t", bt, inner_iname="ti", outer_iname="to") + knl = lp.split_iname(knl, "i", bx, inner_iname="xi", outer_iname="xo") + + if use_compute: + raise NotImplementedError( + "The recurrent diamond time-stepper cannot currently be lowered " + "through compute(): Loopy represents the instance-wise dependence " + "between compute loads from u_hist[t] and writes to u_hist[t+1] as " + "an instruction-level dependency cycle." 
+ ) + compute_insn_ids = [] + for ell in range(-r, r + 1): + suffix = offset_name(ell) + compute_insn_id = f"{suffix}_diamond_compute" + compute_insn_ids.append(compute_insn_id) + storage_axis = f"xi_s_{suffix}" + diamond_map = nisl.make_map(f"""{{ + [ts, is] -> [to, xo, ti, {storage_axis}] : + ts = to * {bt} + ti and + is = xo * {bx} + {storage_axis} + ti - {bt - 1} + }}""") + + knl = compute( + knl, + suffix, + compute_map=diamond_map, + storage_indices=[storage_axis], + temporal_inames=["to", "xo", "ti"], + temporary_name=f"{suffix}_diamond", + temporary_address_space=lp.AddressSpace.LOCAL, + temporary_dtype=dtype, + compute_insn_id=compute_insn_id, + ) + knl = knl.with_kernel( + lp.map_instructions( + knl.default_entrypoint, + f"id:{suffix}_diamond_compute", + lambda insn: insn.copy(depends_on=frozenset()), + ) + ) + + knl = lp.split_iname( + knl, + storage_axis, + 128, + outer_iname=f"{storage_axis}_tile", + inner_iname=f"{storage_axis}_local", + ) + knl = lp.tag_inames(knl, {f"{storage_axis}_local": "l.0"}) + + no_sync_with_computes = frozenset( + (compute_insn_id, "global") for compute_insn_id in compute_insn_ids + ) + knl = knl.with_kernel( + lp.map_instructions( + knl.default_entrypoint, + "id:step", + lambda insn: insn.copy( + no_sync_with=insn.no_sync_with | no_sync_with_computes + ), + ) + ) + for compute_insn_id in compute_insn_ids: + knl = knl.with_kernel( + lp.map_instructions( + knl.default_entrypoint, + f"id:{compute_insn_id}", + lambda insn: insn.copy( + no_sync_with=insn.no_sync_with | frozenset([("step", "global")]) + ), + ) + ) + + knl = lp.tag_inames(knl, {"xi": "l.0"}) + knl = lp.prioritize_loops(knl, "to,ti,xo,xi") + knl = lp.set_options(knl, insert_gbarriers=True) + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + if print_kernel: + print(knl) + + if not run_kernel: + return None + + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + ex = knl.executor(queue) + + import pyopencl.array as cl_array + + 
u_hist_cl = cl_array.to_device(queue, u_hist) + coeffs_cl = cl_array.to_device(queue, coeffs) + + args = {"c": coeffs_cl, "u_hist": u_hist_cl} + avg_time_per_iter = benchmark_executor( + ex, queue, args, warmup=warmup, iterations=iterations + ) + avg_gflops = fd_flop_count(ntime, nspace, stencil_width) / avg_time_per_iter / 1e9 + + _, out = ex(queue, **args) + reference = reference_time_stepper(u0, coeffs, ntime, r) + sl = (slice(None), slice(r, nspace - r)) + rel_err = la.norm(reference[sl] - out[0].get()[sl]) / la.norm(reference[sl]) + + print(20 * "=", "Diamond finite difference report", 20 * "=") + print(f"Variant : {'compute' if use_compute else 'baseline'}") + print(f"Time steps : {ntime}") + print(f"Space points : {nspace}") + print(f"Stencil width: {stencil_width}") + print(f"Tile shape : bt = {bt}, bx = {bx}") + print(f"Iterations : warmup = {warmup}, timed = {iterations}") + print(f"Average time per iteration: {avg_time_per_iter:.6e} s") + print(f"Average throughput: {avg_gflops:.3f} GFLOP/s") + print(f"Relative error: {rel_err:.3e}") + print((40 + len(" Diamond finite difference report ")) * "=") + + return avg_time_per_iter + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + + _ = parser.add_argument("--ntime", action="store", type=int, default=128) + _ = parser.add_argument("--nspace", action="store", type=int, default=4096) + _ = parser.add_argument("--stencil-width", action="store", type=int, default=9) + _ = parser.add_argument("--time-block-size", action="store", type=int, default=8) + _ = parser.add_argument("--space-block-size", action="store", type=int, default=128) + + _ = parser.add_argument("--compare", action="store_true") + _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--no-run-kernel", action="store_false", dest="run_kernel") + _ = parser.add_argument("--print-device-code", action="store_true") + _ = 
parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--warmup", action="store", type=int, default=3) + _ = parser.add_argument("--iterations", action="store", type=int, default=10) + + args = parser.parse_args() + + if args.compare: + print("Running example without compute...") + no_compute_time = main( + ntime=args.ntime, + nspace=args.nspace, + stencil_width=args.stencil_width, + time_block_size=args.time_block_size, + space_block_size=args.space_block_size, + use_compute=False, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=True, + warmup=args.warmup, + iterations=args.iterations, + ) + print(50 * "=", "\n") + + print("Running example with compute...") + compute_time = main( + ntime=args.ntime, + nspace=args.nspace, + stencil_width=args.stencil_width, + time_block_size=args.time_block_size, + space_block_size=args.space_block_size, + use_compute=True, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=True, + warmup=args.warmup, + iterations=args.iterations, + ) + print(50 * "=", "\n") + + assert no_compute_time is not None + assert compute_time is not None + speedup = no_compute_time / compute_time + print(f"Speedup: {speedup:.3f}x") + time_reduction = (1 - compute_time / no_compute_time) * 100 + print(f"Relative time reduction: {time_reduction:.2f}%") + else: + _ = main( + ntime=args.ntime, + nspace=args.nspace, + stencil_width=args.stencil_width, + time_block_size=args.time_block_size, + space_block_size=args.space_block_size, + use_compute=args.compute, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel, + run_kernel=args.run_kernel, + warmup=args.warmup, + iterations=args.iterations, + ) diff --git a/examples/python/compute-examples/run-compute-examples.sh b/examples/python/compute-examples/run-compute-examples.sh index 096a6d6f3..1121bc0c3 100755 --- a/examples/python/compute-examples/run-compute-examples.sh +++ 
b/examples/python/compute-examples/run-compute-examples.sh @@ -20,6 +20,16 @@ run_example finite-difference-2-5D.py \ --warmup 2 \ --iterations 3 +run_example finite-difference-diamond.py \ + --ntime 96 \ + --nspace 4096 \ + --stencil-width 9 \ + --time-block-size 8 \ + --space-block-size 128 \ + --run-kernel \ + --warmup 2 \ + --iterations 3 + run_example wave-equation-ring-buffer.py \ --ntime 4096 \ --compute \ From bd8d2409dd3ca5c411d57edff3c35e96a3746a8d Mon Sep 17 00:00:00 2001 From: Addison Date: Sat, 18 Apr 2026 13:37:14 -0500 Subject: [PATCH 26/27] add .codex to gitignore --- .codex | 0 .gitignore | 2 + loopy/kernel/dependency.py | 222 +++++++++++++++++++++++++++++++++++++ 3 files changed, 224 insertions(+) delete mode 100644 .codex create mode 100644 loopy/kernel/dependency.py diff --git a/.codex b/.codex deleted file mode 100644 index e69de29bb..000000000 diff --git a/.gitignore b/.gitignore index 4378c7122..f03dc4d4c 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,5 @@ virtualenv-[0-9]*[0-9] # Files used by run-pylint.sh .pylintrc.yml .run-pylint.py + +.codex diff --git a/loopy/kernel/dependency.py b/loopy/kernel/dependency.py new file mode 100644 index 000000000..5512c62a6 --- /dev/null +++ b/loopy/kernel/dependency.py @@ -0,0 +1,222 @@ +from __future__ import annotations + + +__copyright__ = "Copyright (C) 2025 Addison Alvey-Blanco" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from typing import Mapping + +from constantdict import constantdict + +import islpy as isl +from islpy import dim_type + +from loopy import HappensAfter, LoopKernel, for_each_kernel +from loopy.kernel.instruction import ( + InstructionBase, + VariableSpecificHappensAfter, +) +from loopy.transform.dependency import AccessMapFinder + + +def _add_lexicographic_happens_after_inner(knl, after_insn, before_insn): + domain_before = knl.get_inames_domain(before_insn.within_inames) + domain_after = knl.get_inames_domain(after_insn.within_inames) + + happens_after = isl.Map.from_domain_and_range(domain_after, + domain_before) + for idim in range(happens_after.dim(dim_type.out)): + happens_after = happens_after.set_dim_name( + dim_type.out, + idim, + happens_after.get_dim_name(dim_type.out, idim) + "'" + ) + + shared_inames = before_insn.within_inames & after_insn.within_inames + + shared_inames_order_before = [ + domain_before.get_dim_name(dim_type.out, idim) + for idim in range(domain_before.dim(dim_type.out)) + if domain_before.get_dim_name(dim_type.out, idim) + in shared_inames + ] + + shared_inames_order_after = [ + domain_after.get_dim_name(dim_type.out, idim) + for idim in range(domain_after.dim(dim_type.out)) + if domain_after.get_dim_name(dim_type.out, idim) + in shared_inames + ] + + assert shared_inames_order_after == shared_inames_order_before + shared_inames_order = list(shared_inames_order_after) + + affs_in = isl.affs_from_space(happens_after.domain().space) + affs_out = 
isl.affs_from_space(happens_after.range().space) + + lex_map = isl.Map.empty(happens_after.space) + for iinnermost, innermost_iname in enumerate(shared_inames_order): + innermost_map = affs_in[innermost_iname].gt_map( + affs_out[innermost_iname + "'"] + ) + + for outer_iname in shared_inames_order[:iinnermost]: + innermost_map = innermost_map & ( + affs_in[outer_iname].eq_map( + affs_out[outer_iname + "'"] + ) + ) + + if before_insn != after_insn: + innermost_map = innermost_map | ( + affs_in[shared_inames_order[iinnermost]].eq_map( + affs_out[shared_inames_order[iinnermost] + "'"] + ) + ) + + lex_map = lex_map | innermost_map + + return happens_after & lex_map + + +@for_each_kernel +def add_lexicographic_happens_after(knl: LoopKernel) -> LoopKernel: + """ + Impose a sequential, top-down execution order to instructions in a program. + It is expected that this strict order will be relaxed with + :func:`reduce_strict_ordering_with_dependencies` using data dependencies. + """ + + rmap = knl.reader_map() + wmap_r: dict[str, set[str]] = {} + for var, insns in knl.writer_map().items(): + for insn in insns: + wmap_r.setdefault(insn, set()) + wmap_r[insn].add(var) + + new_insns = [] + for iafter, after_insn in enumerate(knl.instructions): + assert after_insn.id is not None + + new_happens_after = {} + + # check for self dependencies + for var in wmap_r[after_insn.id]: + if rmap.get(var) and after_insn.id in rmap[var]: + self_happens_after = _add_lexicographic_happens_after_inner( + knl, after_insn, after_insn + ) + new_happens_after[after_insn.id] = HappensAfter( + self_happens_after + ) + + if iafter != 0: + before_insn = knl.instructions[iafter - 1] + happens_after = _add_lexicographic_happens_after_inner( + knl, after_insn, before_insn + ) + new_happens_after[before_insn.id] = HappensAfter(happens_after) + + new_insns.append(after_insn.copy(happens_after=new_happens_after)) + + return knl.copy(instructions=new_insns) + + +@for_each_kernel +def 
reduce_strict_ordering(knl) -> LoopKernel: + def narrow_dependencies( + after: InstructionBase, + before: InstructionBase, + remaining_instances: isl.Set, # type: ignore + happens_afters: Mapping[str, VariableSpecificHappensAfter] = {}, + happens_after_map: isl.Map | None = None, # type: ignore + ) -> Mapping[str, VariableSpecificHappensAfter]: + # FIXME: can we get rid of all the "assert x is not None" stuff? + + assert isinstance(after.id, str) + assert isinstance(before.id, str) + + if remaining_instances.is_empty(): + return happens_afters + + for insn, happens_after in before.happens_after.items(): + if happens_after_map is None: + happens_after_map = happens_after.instances_rel + else: + assert happens_after.instances_rel is not None + happens_after_map = happens_after_map.apply_range( + happens_after.instances_rel) + + source_vars = access_mapper.get_accessed_variables(after.id) + common_vars = wmap_r[insn] & source_vars # type: ignore + for var in common_vars: + write_map = access_mapper.get_map(insn, var) + source_map = access_mapper.get_map(after.id, var) + + assert write_map is not None + assert source_map is not None + + source_to_writer = source_map.apply_range(write_map.reverse()) + dependency_map = source_to_writer & happens_after_map + remaining_instances = remaining_instances - dependency_map.domain() + if dependency_map is not None and not dependency_map.is_empty(): + happens_after_obj = VariableSpecificHappensAfter( + dependency_map, var + ) + + happens_afters = constantdict( + dict(happens_afters) | {insn: happens_after_obj}) + + if insn != after.id: + happens_afters = constantdict( + dict(happens_afters) | dict(narrow_dependencies( + after, + knl.id_to_insn[insn], + remaining_instances, + happens_afters, + happens_after_map, + )) + ) + + return happens_afters + + access_mapper = AccessMapFinder(knl) + for insn in knl.instructions: + access_mapper(insn.expression, insn.id) + access_mapper(insn.assignee, insn.id) + + wmap_r: dict[str, set[str]] = 
{} + for var, insns in knl.writer_map().items(): + for insn in insns: + wmap_r.setdefault(insn, set()) + wmap_r[insn].add(var) + + new_insns = [] + for insn in knl.instructions[::-1]: + new_insns.append( + insn.copy(happens_after=narrow_dependencies( + after=insn, + before=insn, + remaining_instances=knl.get_inames_domain(insn.within_inames))) + ) + + return knl.copy(instructions=new_insns[::-1]) From 70c15c88f63019dbf13e3688db1ea2f28dca597b Mon Sep 17 00:00:00 2001 From: Addison Date: Sat, 18 Apr 2026 13:37:48 -0500 Subject: [PATCH 27/27] remove dependency.py --- loopy/kernel/dependency.py | 222 ------------------------------------- 1 file changed, 222 deletions(-) delete mode 100644 loopy/kernel/dependency.py diff --git a/loopy/kernel/dependency.py b/loopy/kernel/dependency.py deleted file mode 100644 index 5512c62a6..000000000 --- a/loopy/kernel/dependency.py +++ /dev/null @@ -1,222 +0,0 @@ -from __future__ import annotations - - -__copyright__ = "Copyright (C) 2025 Addison Alvey-Blanco" - -__license__ = """ -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. -""" - -from typing import Mapping - -from constantdict import constantdict - -import islpy as isl -from islpy import dim_type - -from loopy import HappensAfter, LoopKernel, for_each_kernel -from loopy.kernel.instruction import ( - InstructionBase, - VariableSpecificHappensAfter, -) -from loopy.transform.dependency import AccessMapFinder - - -def _add_lexicographic_happens_after_inner(knl, after_insn, before_insn): - domain_before = knl.get_inames_domain(before_insn.within_inames) - domain_after = knl.get_inames_domain(after_insn.within_inames) - - happens_after = isl.Map.from_domain_and_range(domain_after, - domain_before) - for idim in range(happens_after.dim(dim_type.out)): - happens_after = happens_after.set_dim_name( - dim_type.out, - idim, - happens_after.get_dim_name(dim_type.out, idim) + "'" - ) - - shared_inames = before_insn.within_inames & after_insn.within_inames - - shared_inames_order_before = [ - domain_before.get_dim_name(dim_type.out, idim) - for idim in range(domain_before.dim(dim_type.out)) - if domain_before.get_dim_name(dim_type.out, idim) - in shared_inames - ] - - shared_inames_order_after = [ - domain_after.get_dim_name(dim_type.out, idim) - for idim in range(domain_after.dim(dim_type.out)) - if domain_after.get_dim_name(dim_type.out, idim) - in shared_inames - ] - - assert shared_inames_order_after == shared_inames_order_before - shared_inames_order = list(shared_inames_order_after) - - affs_in = isl.affs_from_space(happens_after.domain().space) - affs_out = isl.affs_from_space(happens_after.range().space) - - lex_map = isl.Map.empty(happens_after.space) - for iinnermost, innermost_iname in enumerate(shared_inames_order): - innermost_map = affs_in[innermost_iname].gt_map( - 
affs_out[innermost_iname + "'"] - ) - - for outer_iname in shared_inames_order[:iinnermost]: - innermost_map = innermost_map & ( - affs_in[outer_iname].eq_map( - affs_out[outer_iname + "'"] - ) - ) - - if before_insn != after_insn: - innermost_map = innermost_map | ( - affs_in[shared_inames_order[iinnermost]].eq_map( - affs_out[shared_inames_order[iinnermost] + "'"] - ) - ) - - lex_map = lex_map | innermost_map - - return happens_after & lex_map - - -@for_each_kernel -def add_lexicographic_happens_after(knl: LoopKernel) -> LoopKernel: - """ - Impose a sequential, top-down execution order to instructions in a program. - It is expected that this strict order will be relaxed with - :func:`reduce_strict_ordering_with_dependencies` using data dependencies. - """ - - rmap = knl.reader_map() - wmap_r: dict[str, set[str]] = {} - for var, insns in knl.writer_map().items(): - for insn in insns: - wmap_r.setdefault(insn, set()) - wmap_r[insn].add(var) - - new_insns = [] - for iafter, after_insn in enumerate(knl.instructions): - assert after_insn.id is not None - - new_happens_after = {} - - # check for self dependencies - for var in wmap_r[after_insn.id]: - if rmap.get(var) and after_insn.id in rmap[var]: - self_happens_after = _add_lexicographic_happens_after_inner( - knl, after_insn, after_insn - ) - new_happens_after[after_insn.id] = HappensAfter( - self_happens_after - ) - - if iafter != 0: - before_insn = knl.instructions[iafter - 1] - happens_after = _add_lexicographic_happens_after_inner( - knl, after_insn, before_insn - ) - new_happens_after[before_insn.id] = HappensAfter(happens_after) - - new_insns.append(after_insn.copy(happens_after=new_happens_after)) - - return knl.copy(instructions=new_insns) - - -@for_each_kernel -def reduce_strict_ordering(knl) -> LoopKernel: - def narrow_dependencies( - after: InstructionBase, - before: InstructionBase, - remaining_instances: isl.Set, # type: ignore - happens_afters: Mapping[str, VariableSpecificHappensAfter] = {}, - 
happens_after_map: isl.Map | None = None, # type: ignore - ) -> Mapping[str, VariableSpecificHappensAfter]: - # FIXME: can we get rid of all the "assert x is not None" stuff? - - assert isinstance(after.id, str) - assert isinstance(before.id, str) - - if remaining_instances.is_empty(): - return happens_afters - - for insn, happens_after in before.happens_after.items(): - if happens_after_map is None: - happens_after_map = happens_after.instances_rel - else: - assert happens_after.instances_rel is not None - happens_after_map = happens_after_map.apply_range( - happens_after.instances_rel) - - source_vars = access_mapper.get_accessed_variables(after.id) - common_vars = wmap_r[insn] & source_vars # type: ignore - for var in common_vars: - write_map = access_mapper.get_map(insn, var) - source_map = access_mapper.get_map(after.id, var) - - assert write_map is not None - assert source_map is not None - - source_to_writer = source_map.apply_range(write_map.reverse()) - dependency_map = source_to_writer & happens_after_map - remaining_instances = remaining_instances - dependency_map.domain() - if dependency_map is not None and not dependency_map.is_empty(): - happens_after_obj = VariableSpecificHappensAfter( - dependency_map, var - ) - - happens_afters = constantdict( - dict(happens_afters) | {insn: happens_after_obj}) - - if insn != after.id: - happens_afters = constantdict( - dict(happens_afters) | dict(narrow_dependencies( - after, - knl.id_to_insn[insn], - remaining_instances, - happens_afters, - happens_after_map, - )) - ) - - return happens_afters - - access_mapper = AccessMapFinder(knl) - for insn in knl.instructions: - access_mapper(insn.expression, insn.id) - access_mapper(insn.assignee, insn.id) - - wmap_r: dict[str, set[str]] = {} - for var, insns in knl.writer_map().items(): - for insn in insns: - wmap_r.setdefault(insn, set()) - wmap_r[insn].add(var) - - new_insns = [] - for insn in knl.instructions[::-1]: - new_insns.append( - 
insn.copy(happens_after=narrow_dependencies( - after=insn, - before=insn, - remaining_instances=knl.get_inames_domain(insn.within_inames))) - ) - - return knl.copy(instructions=new_insns[::-1])