@@ -440,64 +440,98 @@ def preprocess_multipole_nexprs(self, tgt_expansion, src_expansion):
def preprocess_multipole_loopy_knl(self, tgt_expansion, src_expansion,
        result_dtype):
    """Build the loopy function ``m2l_preprocess_inner``.

    The generated function first copies ``input_coeffs`` into the temporary
    ``input_coeffs_copy`` and then, for every ``output_icoeff``, decodes the
    flat index into one digit ``x{i}`` per axis and writes
    ``input_coeffs_copy[storage_index(x0, ..., x{dim-1})]`` into
    ``output_coeffs[output_icoeff]``. The write is predicated on
    ``sum(x) <= order`` and ``x[slowest_idx] < c``; entries failing either
    predicate are not written by this function.

    :arg tgt_expansion: target expansion; only ``dim`` is read here, the
        object is otherwise forwarded to helper methods on *self*.
    :arg src_expansion: source expansion; supplies ``order``, the
        coefficient identifiers and ``expansion_terms_wrangler``.
    :arg result_dtype: not used in this method; kept so the signature
        stays unchanged for callers.
    :returns: a tuple ``(knl, optimizations)`` where *knl* is the result of
        :func:`loopy.make_function` and *optimizations* is a list of
        callables, each taking a kernel and returning the transformed one.
    """
    # Only the per-axis maximum multi-index is needed here.
    _, _, max_mi = self._translation_classes_dependent_data_mis(
        tgt_expansion, src_expansion)

    ncoeff_src = len(src_expansion.get_coefficient_identifiers())
    ncoeff_preprocessed = self.preprocess_multipole_nexprs(
        tgt_expansion, src_expansion)
    order = src_expansion.order
    dim = tgt_expansion.dim

    output_coeffs = pymbolic.var("output_coeffs")
    input_coeffs = pymbolic.var("input_coeffs")
    output_icoeff = pymbolic.var("output_icoeff")
    input_icoeff = pymbolic.var("input_icoeff")
    input_coeffs_copy = pymbolic.var("input_coeffs_copy")

    # One symbolic per-axis index x0, x1, ... decoded from output_icoeff.
    v = [pymbolic.var(f"x{i}") for i in range(dim)]

    wrangler = src_expansion.expansion_terms_wrangler
    _, axis_permutation = wrangler._get_mi_ordering_key_and_axis_permutation()
    slowest_idx = axis_permutation[0]
    # max_mi[slowest_idx] = 2*(c - 1)
    c = max_mi[slowest_idx] // 2 + 1
    # Trip count of the output loop. NOTE(review): this is a different
    # quantity from ncoeff_preprocessed, which is used below as the shape of
    # output_coeffs -- confirm that noutput_coeffs never exceeds it.
    noutput_coeffs = c * (2 * order + 1) ** (dim - 1)

    domains = [
        "{[output_icoeff]: 0<=output_icoeff<noutput_coeffs}",
        "{[input_icoeff]: 0<=input_icoeff<ninput_coeffs}",
    ]

    # Stage the input coefficients in a temporary before the gather below.
    # temp_var_type=lp.Optional(None) presumably lets loopy infer the dtype
    # of the temporary -- TODO confirm against the loopy documentation.
    insns = [
        lp.Assignment(
            assignee=input_coeffs_copy[input_icoeff],
            expression=input_coeffs[input_icoeff],
            id="input_copy",
            temp_var_type=lp.Optional(None),
        ),
    ]

    # Mixed-radix decode of output_icoeff with radix max_mi[i] + 1 per axis.
    # The last axis is the fastest-varying digit; axis 0 takes whatever
    # quotient remains and is therefore not reduced modulo its radix.
    remaining = output_icoeff
    for i in reversed(range(dim)):
        radix = max_mi[i] + 1
        digit = remaining % radix if i > 0 else remaining
        insns.append(lp.Assignment(
            assignee=v[i],
            expression=digit,
            id=f"set_x{i}",
            temp_var_type=lp.Optional(None),
        ))
        remaining = remaining // radix

    # Position of the multi-index (x0, ..., x{dim-1}) in the source
    # expansion's coefficient storage, as computed by the wrangler.
    input_idx = wrangler.get_storage_index(v)

    insns.append(
        lp.Assignment(
            assignee=output_coeffs[output_icoeff],
            expression=input_coeffs_copy[input_idx],
            # Only execute the gather when both comparisons hold.
            predicates=frozenset([
                pymbolic.primitives.Comparison(sum(v), "<=", order),
                pymbolic.primitives.Comparison(v[slowest_idx], "<", c),
            ]),
            depends_on=frozenset(
                [f"set_x{i}" for i in range(dim)] + ["input_copy"]),
        )
    )

    knl = lp.make_function(domains, insns,
            kernel_data=[
                lp.ValueArg("src_rscale", None),
                lp.GlobalArg("output_coeffs", None,
                    shape=ncoeff_preprocessed,
                    is_input=False, is_output=True),
                lp.GlobalArg("input_coeffs", None, shape=ncoeff_src),
                ...],
            name="m2l_preprocess_inner",
            lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
            fixed_parameters={"noutput_coeffs": noutput_coeffs,
                              "ninput_coeffs": ncoeff_src},
            )

    # The iname names carry an "m2l__" prefix that does not appear in the
    # domains above. NOTE(review): presumably the caller applies these after
    # inlining this function under that prefix -- confirm against the caller.
    optimizations = [
        lambda knl: lp.split_iname(knl, "m2l__input_icoeff",
                                   32, inner_tag="l.0"),
        lambda knl: lp.split_iname(knl, "m2l__output_icoeff",
                                   32, inner_tag="l.0"),
    ]

    return (knl, optimizations)
534+
501535 def postprocess_local_exprs (self , tgt_expansion , src_expansion , m2l_result ,
502536 src_rscale , tgt_rscale , sac ):
503537 circulant_matrix_mis , _ , _ = \
@@ -607,7 +641,12 @@ def result_func(x):
607641 "{[output_icoeff]: 0<=output_icoeff<ncoeff_tgt}"
608642 ]
609643
610- return lp .make_function (domains , insns ,
644+ optimizations = [
645+ lambda knl : lp .split_iname (knl , "m2l__output_icoeff" ,
646+ 32 , inner_tag = "l.0" )
647+ ]
648+
649+ return (lp .make_function (domains , insns ,
611650 kernel_data = [
612651 lp .ValueArg ("src_rscale" , None ),
613652 lp .ValueArg ("tgt_rscale" , None ),
@@ -630,7 +669,7 @@ def result_func(x):
630669 name = "m2l_postprocess_inner" ,
631670 lang_version = lp .MOST_RECENT_LANGUAGE_VERSION ,
632671 fixed_parameters = fixed_parameters ,
633- )
672+ ), optimizations )
634673
635674# }}} VolumeTaylorM2LTranslation
636675
0 commit comments