Skip to content

Commit 82d6d43

Browse files
authored
Optimize preprocess multipole and postprocess local (#156)
1 parent e7479ff commit 82d6d43

4 files changed

Lines changed: 220 additions & 43 deletions

File tree

sumpy/e2e.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -668,14 +668,20 @@ class M2LPreprocessMultipole(E2EBase):
668668
def default_name(self):
669669
return "m2l_preprocess_multipole"
670670

671+
@memoize_method
672+
def get_inner_knl_and_optimizations(self, result_dtype):
673+
m2l_translation = self.tgt_expansion.m2l_translation
674+
return m2l_translation.preprocess_multipole_loopy_knl(
675+
self.tgt_expansion, self.src_expansion, result_dtype)
676+
671677
def get_kernel(self, result_dtype):
672678
m2l_translation = self.tgt_expansion.m2l_translation
673679
nsrc_coeffs = len(self.src_expansion)
674680
npreprocessed_src_coeffs = \
675681
m2l_translation.preprocess_multipole_nexprs(self.tgt_expansion,
676682
self.src_expansion)
677-
single_box_preprocess_knl = m2l_translation.preprocess_multipole_loopy_knl(
678-
self.tgt_expansion, self.src_expansion, result_dtype)
683+
single_box_preprocess_knl, _ = self.get_inner_knl_and_optimizations(
684+
result_dtype)
679685

680686
from sumpy.tools import gather_loopy_arguments
681687
loopy_knl = lp.make_kernel(
@@ -721,11 +727,11 @@ def get_kernel(self, result_dtype):
721727
return loopy_knl
722728

723729
def get_optimized_kernel(self, result_dtype):
724-
# FIXME
725730
knl = self.get_kernel(result_dtype)
726-
knl = lp.split_iname(knl, "isrc_box", 64, outer_tag="g.0",
727-
within=f"in_kernel:{self.name}")
728-
knl = lp.add_inames_for_unused_hw_axes(knl)
731+
knl = lp.tag_inames(knl, "isrc_box:g.0")
732+
_, optimizations = self.get_inner_knl_and_optimizations(result_dtype)
733+
for optimization in optimizations:
734+
knl = optimization(knl)
729735
return knl
730736

731737
def __call__(self, queue, **kwargs):
@@ -752,15 +758,21 @@ class M2LPostprocessLocal(E2EBase):
752758
def default_name(self):
753759
return "m2l_postprocess_local"
754760

761+
@memoize_method
762+
def get_inner_knl_and_optimizations(self, result_dtype):
763+
m2l_translation = self.tgt_expansion.m2l_translation
764+
return m2l_translation.postprocess_local_loopy_knl(
765+
self.tgt_expansion, self.src_expansion, result_dtype)
766+
755767
def get_kernel(self, result_dtype):
756768
m2l_translation = self.tgt_expansion.m2l_translation
757769
ntgt_coeffs = len(self.tgt_expansion)
758770
ntgt_coeffs_before_postprocessing = \
759771
m2l_translation.postprocess_local_nexprs(self.tgt_expansion,
760772
self.src_expansion)
761773

762-
single_box_postprocess_knl = m2l_translation.postprocess_local_loopy_knl(
763-
self.tgt_expansion, self.src_expansion, result_dtype)
774+
single_box_postprocess_knl, _ = self.get_inner_knl_and_optimizations(
775+
result_dtype)
764776

765777
from sumpy.tools import gather_loopy_arguments
766778
loopy_knl = lp.make_kernel(
@@ -813,9 +825,12 @@ def get_kernel(self, result_dtype):
813825
return loopy_knl
814826

815827
def get_optimized_kernel(self, result_dtype):
816-
# FIXME
817828
knl = self.get_kernel(result_dtype)
818-
knl = lp.split_iname(knl, "itgt_box", 16, outer_tag="g.0")
829+
knl = lp.tag_inames(knl, "itgt_box:g.0")
830+
_, optimizations = self.get_inner_knl_and_optimizations(result_dtype)
831+
for optimization in optimizations:
832+
knl = optimization(knl)
833+
knl = lp.add_inames_for_unused_hw_axes(knl)
819834
return knl
820835

821836
def __call__(self, queue, **kwargs):

sumpy/expansion/__init__.py

Lines changed: 87 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import sumpy.symbolic as sym
3131
from sumpy.kernel import Kernel
3232
from sumpy.tools import add_mi
33+
import pymbolic.primitives as prim
3334

3435
import logging
3536
logger = logging.getLogger(__name__)
@@ -377,6 +378,18 @@ def _split_coeffs_into_hyperplanes(
377378

378379

379380
class FullExpansionTermsWrangler(ExpansionTermsWrangler):
381+
382+
def get_storage_index(self, mi, order=None):
383+
if not order:
384+
order = sum(mi)
385+
if self.dim == 3:
386+
return (order*(order + 1)*(order + 2))//6 + \
387+
(order + 2)*mi[2] - (mi[2]*(mi[2] + 1))//2 + mi[1]
388+
elif self.dim == 2:
389+
return (order*(order + 1))//2 + mi[1]
390+
else:
391+
raise NotImplementedError
392+
380393
def get_coefficient_identifiers(self):
381394
return super().get_full_coefficient_identifiers()
382395

@@ -389,6 +402,26 @@ def get_stored_mpole_coefficients_from_full(self,
389402
return self.get_full_kernel_derivatives_from_stored(
390403
full_mpole_coefficients, rscale, sac=sac)
391404

405+
@memoize_method
406+
def _get_mi_ordering_key_and_axis_permutation(self):
407+
"""
408+
Returns a degree lexicographic order as a callable that can be used as a
409+
``sort`` key on multi-indices and a permutation of the axis ordered
410+
from the slowest varying axis to the fastest varying axis of the
411+
multi-indices when sorted.
412+
"""
413+
from sumpy.expansion.diff_op import DerivativeIdentifier
414+
415+
axis_permutation = list(reversed(list(range(self.dim))))
416+
417+
def mi_key(ident):
418+
if isinstance(ident, DerivativeIdentifier):
419+
mi = ident.mi
420+
else:
421+
mi = ident
422+
return tuple([sum(mi)] + list(reversed(mi)))
423+
424+
return mi_key, axis_permutation
392425
# }}}
393426

394427

@@ -520,7 +553,7 @@ def stored_identifiers(self):
520553
# the axes so that the axis with the on-axis coefficient comes first in the
521554
# multi-index tuple.
522555
@memoize_method
523-
def _get_mi_ordering_key(self):
556+
def _get_mi_ordering_key_and_axis_permutation(self):
524557
"""
525558
A degree lexicographic order with the slowest varying index depending on
526559
the PDE is used, returned as a callable that can be used as a
@@ -529,6 +562,9 @@ def _get_mi_ordering_key(self):
529562
multipole-to-multipole translation to get lower error bounds.
530563
The slowest varying index is chosen such that the multipole-to-local
531564
translation cost is optimized.
565+
566+
Also returns a permutation of the axis ordered from the slowest varying
567+
axis to the fastest varying axis of the multi-indices when sorted.
532568
"""
533569
dim = self.dim
534570
deriv_id_to_coeff, = self.knl.get_pde_as_diff_op().eqs
@@ -554,7 +590,7 @@ def mi_key(ident):
554590
key.append(mi[axis_permutation[i]])
555591
return tuple(key)
556592

557-
return mi_key
593+
return mi_key, axis_permutation
558594

559595
def _get_mi_hyperpplanes(self) -> List[Tuple[int, int]]:
560596
mis = self.get_full_coefficient_identifiers()
@@ -570,8 +606,8 @@ def _get_mi_hyperpplanes(self) -> List[Tuple[int, int]]:
570606
else:
571607
# Calculate the multi-index that appears last in in the PDE in
572608
# the degree lexicographic order given by
573-
# _get_mi_ordering_key.
574-
ordering_key = self._get_mi_ordering_key()
609+
# _get_mi_ordering_key_and_axis_permutation.
610+
ordering_key, _ = self._get_mi_ordering_key_and_axis_permutation()
575611
max_mi = max(deriv_id_to_coeff, key=ordering_key).mi
576612
hyperplanes = [(d, const)
577613
for d in range(self.dim)
@@ -581,9 +617,54 @@ def _get_mi_hyperpplanes(self) -> List[Tuple[int, int]]:
581617

582618
def get_full_coefficient_identifiers(self):
583619
identifiers = super().get_full_coefficient_identifiers()
584-
key = self._get_mi_ordering_key()
620+
key, _ = self._get_mi_ordering_key_and_axis_permutation()
585621
return sorted(identifiers, key=key)
586622

623+
def get_storage_index(self, mi, order=None):
624+
if not order:
625+
order = sum(mi)
626+
627+
ordering_key, axis_permutation = \
628+
self._get_mi_ordering_key_and_axis_permutation()
629+
deriv_id_to_coeff, = self.knl.get_pde_as_diff_op().eqs
630+
max_mi = max(deriv_id_to_coeff, key=ordering_key).mi
631+
632+
if all(m != 0 for m in max_mi):
633+
raise NotImplementedError("non-elliptic PDEs")
634+
635+
c = max_mi[axis_permutation[0]]
636+
637+
mi = list(mi)
638+
mi[axis_permutation[0]], mi[0] = mi[0], mi[axis_permutation[0]]
639+
640+
if self.dim == 3:
641+
if all(isinstance(axis, int) for axis in mi):
642+
if order < c - 1:
643+
return (order*(order + 1)*(order + 2))//6 + \
644+
(order + 2)*mi[0] - (mi[0]*(mi[0] + 1))//2 + mi[1]
645+
else:
646+
return (c*(c-1)*(c-2))//6 + (c * order * (2 + order - c)
647+
+ mi[0]*(3 - mi[0]+2*order))//2 + mi[1]
648+
else:
649+
return prim.If(prim.Comparison(order, "<", c - 1),
650+
(order*(order + 1)*(order + 2))//6
651+
+ (order + 2)*mi[0] - (mi[0]*(mi[0] + 1))//2 + mi[1],
652+
(c*(c-1)*(c-2))//6 + (c * order * (2 + order - c)
653+
+ mi[0]*(3 - mi[0]+2*order))//2 + mi[1]
654+
)
655+
elif self.dim == 2:
656+
if all(isinstance(axis, int) for axis in mi):
657+
if order < c - 1:
658+
return (order*(order + 1))//2 + mi[0]
659+
else:
660+
return (c*(c-1))//2 + c*(order - c + 1) + mi[0]
661+
else:
662+
return prim.If(prim.Comparison(order, "<", c - 1),
663+
(order*(order + 1))//2 + mi[0],
664+
(c*(c-1))//2 + c*(order - c + 1) + mi[0])
665+
else:
666+
raise NotImplementedError
667+
587668
@memoize_method
588669
def get_stored_ids_and_unscaled_projection_matrix(self):
589670
from pytools import ProcessLogger
@@ -608,7 +689,7 @@ def get_stored_ids_and_unscaled_projection_matrix(self):
608689
from_output_coeffs_by_row, shape)
609690
return mis, op
610691

611-
ordering_key = self._get_mi_ordering_key()
692+
ordering_key, _ = self._get_mi_ordering_key_and_axis_permutation()
612693
max_mi = max((ident for ident in mi_to_coeff.keys()), key=ordering_key)
613694
max_mi_coeff = mi_to_coeff[max_mi]
614695
max_mi_mult = -1/sym.sympify(max_mi_coeff)

sumpy/expansion/m2l.py

Lines changed: 66 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -440,64 +440,98 @@ def preprocess_multipole_nexprs(self, tgt_expansion, src_expansion):
440440
def preprocess_multipole_loopy_knl(self, tgt_expansion, src_expansion,
441441
result_dtype):
442442

443-
circulant_matrix_mis, _, _ = \
443+
circulant_matrix_mis, _, max_mi = \
444444
self._translation_classes_dependent_data_mis(tgt_expansion,
445445
src_expansion)
446-
circulant_matrix_ident_to_index = {
447-
ident: i for i, ident in enumerate(circulant_matrix_mis)}
448446

449447
ncoeff_src = len(src_expansion.get_coefficient_identifiers())
450448
ncoeff_preprocessed = self.preprocess_multipole_nexprs(tgt_expansion,
451449
src_expansion)
450+
order = src_expansion.order
452451

453452
output_coeffs = pymbolic.var("output_coeffs")
454453
input_coeffs = pymbolic.var("input_coeffs")
455-
srcidx_sym = pymbolic.var("srcidx")
456454
output_icoeff = pymbolic.var("output_icoeff")
457455
input_icoeff = pymbolic.var("input_icoeff")
456+
input_coeffs_copy = pymbolic.var("input_coeffs_copy")
457+
458+
dim = tgt_expansion.dim
459+
v = [pymbolic.var(f"x{i}") for i in range(dim)]
460+
461+
wrangler = src_expansion.expansion_terms_wrangler
462+
_, axis_permutation = wrangler._get_mi_ordering_key_and_axis_permutation()
463+
slowest_idx = axis_permutation[0]
464+
# max_mi[slowest_idx] = 2*(c - 1)
465+
c = max_mi[slowest_idx] // 2 + 1
466+
noutput_coeffs = c * (2*order + 1) ** (dim - 1)
458467

459468
domains = [
460469
"{[output_icoeff]: 0<=output_icoeff<noutput_coeffs}",
470+
"{[input_icoeff]: 0<=input_icoeff<ninput_coeffs}",
461471
]
472+
462473
insns = [
463474
lp.Assignment(
464-
assignee=input_icoeff,
465-
expression=srcidx_sym[output_icoeff],
466-
id="input_icoeff",
475+
assignee=input_coeffs_copy[input_icoeff],
476+
expression=input_coeffs[input_icoeff],
477+
id="input_copy",
478+
temp_var_type=lp.Optional(None),
467479
),
480+
]
481+
482+
idx = output_icoeff
483+
for i in range(dim - 1, -1, -1):
484+
new_idx = idx % (max_mi[i] + 1) if i > 0 else idx
485+
insns.append(lp.Assignment(
486+
assignee=v[i],
487+
expression=new_idx,
488+
id=f"set_x{i}",
489+
temp_var_type=lp.Optional(None),
490+
))
491+
idx = idx // (max_mi[i] + 1)
492+
493+
input_idx = wrangler.get_storage_index(v)
494+
output_idx = 0
495+
mult = 1
496+
for i in range(dim - 1, -1, -1):
497+
output_idx += mult*v[i]
498+
mult *= (max_mi[i] + 1)
499+
500+
insns += [
468501
lp.Assignment(
469502
assignee=output_coeffs[output_icoeff],
470-
expression=pymbolic.primitives.If(
471-
pymbolic.primitives.Comparison(input_icoeff, ">=", 0),
472-
input_coeffs[input_icoeff],
473-
0,
474-
),
475-
depends_on=frozenset(["input_icoeff"]),
503+
expression=input_coeffs_copy[input_idx],
504+
predicates=frozenset([
505+
pymbolic.primitives.Comparison(sum(v), "<=", order),
506+
pymbolic.primitives.Comparison(v[slowest_idx], "<", c),
507+
]),
508+
depends_on=frozenset([f"set_x{i}" for i in range(dim)]
509+
+ ["input_copy"]),
476510
)
477511
]
478512

479-
srcidx = np.full(ncoeff_preprocessed, -1, dtype=np.int32)
480-
for icoeff_src, term in enumerate(
481-
src_expansion.get_coefficient_identifiers()):
482-
new_icoeff_src = circulant_matrix_ident_to_index[term]
483-
srcidx[new_icoeff_src] = icoeff_src
484-
485-
return lp.make_function(domains, insns,
513+
knl = lp.make_function(domains, insns,
486514
kernel_data=[
487515
lp.ValueArg("src_rscale", None),
488516
lp.GlobalArg("output_coeffs", None, shape=ncoeff_preprocessed,
489517
is_input=False, is_output=True),
490518
lp.GlobalArg("input_coeffs", None, shape=ncoeff_src),
491-
lp.TemporaryVariable(input_icoeff.name, dtype=np.int32),
492-
lp.TemporaryVariable(
493-
srcidx_sym.name, initializer=srcidx,
494-
address_space=lp.AddressSpace.GLOBAL, read_only=True),
495519
...],
496520
name="m2l_preprocess_inner",
497521
lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
498-
fixed_parameters={"noutput_coeffs": ncoeff_preprocessed},
522+
fixed_parameters={"noutput_coeffs": noutput_coeffs,
523+
"ninput_coeffs": ncoeff_src},
499524
)
500525

526+
optimizations = [
527+
lambda knl: lp.split_iname(knl, "m2l__input_icoeff",
528+
32, inner_tag="l.0"),
529+
lambda knl: lp.split_iname(knl, "m2l__output_icoeff",
530+
32, inner_tag="l.0"),
531+
]
532+
533+
return (knl, optimizations)
534+
501535
def postprocess_local_exprs(self, tgt_expansion, src_expansion, m2l_result,
502536
src_rscale, tgt_rscale, sac):
503537
circulant_matrix_mis, _, _ = \
@@ -607,7 +641,12 @@ def result_func(x):
607641
"{[output_icoeff]: 0<=output_icoeff<ncoeff_tgt}"
608642
]
609643

610-
return lp.make_function(domains, insns,
644+
optimizations = [
645+
lambda knl: lp.split_iname(knl, "m2l__output_icoeff",
646+
32, inner_tag="l.0")
647+
]
648+
649+
return (lp.make_function(domains, insns,
611650
kernel_data=[
612651
lp.ValueArg("src_rscale", None),
613652
lp.ValueArg("tgt_rscale", None),
@@ -630,7 +669,7 @@ def result_func(x):
630669
name="m2l_postprocess_inner",
631670
lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
632671
fixed_parameters=fixed_parameters,
633-
)
672+
), optimizations)
634673

635674
# }}} VolumeTaylorM2LTranslation
636675

0 commit comments

Comments
 (0)