@@ -86,6 +86,9 @@ def __init__(self, ctx, target_kernels, exclude_self, strength_usage=None,
8686 source_kernels = source_kernels , strength_usage = strength_usage ,
8787 value_dtypes = value_dtypes , name = name , device = device )
8888
89+ import pyopencl as cl
90+ self .is_gpu = not (self .device .type & cl .device_type .CPU )
91+
8992 self .exclude_self = exclude_self
9093
9194 self .dim = single_valued (knl .dim for knl in
@@ -444,7 +447,7 @@ def default_name(self):
444447 return "p2p_from_csr"
445448
446449 def get_kernel (self , max_nsources_in_one_box , max_ntargets_in_one_box ,
447- gpu = False , nsplit = 32 ):
450+ work_items_per_group = 32 ):
448451 loopy_insns , result_names = self .get_loopy_insns_and_result_names ()
449452 arguments = self .get_default_src_tgt_arguments () \
450453 + [
@@ -473,13 +476,11 @@ def get_kernel(self, max_nsources_in_one_box, max_ntargets_in_one_box,
473476 "{[iknl]: 0 <= iknl < noutputs}" ,
474477 "{[isrc_box]: isrc_box_start <= isrc_box < isrc_box_end}" ,
475478 "{[idim]: 0 <= idim < dim}" ,
476- "{[isrc]: isrc_start <= isrc < isrc_end}"
477479 ]
478480
479- src_outer_limit = (max_nsources_in_one_box - 1 ) // nsplit
480- tgt_outer_limit = (max_ntargets_in_one_box - 1 ) // nsplit
481+ tgt_outer_limit = (max_ntargets_in_one_box - 1 ) // work_items_per_group
481482
482- if gpu :
483+ if self . is_gpu :
483484 arguments += [
484485 lp .TemporaryVariable ("local_isrc" ,
485486 shape = (self .dim , max_nsources_in_one_box )),
@@ -488,79 +489,90 @@ def get_kernel(self, max_nsources_in_one_box, max_ntargets_in_one_box,
488489 ]
489490 domains += [
490491 "{[istrength]: 0 <= istrength < nstrengths}" ,
491- "{[inner]: 0 <= inner < nsplit }" ,
492+ "{[inner]: 0 <= inner < work_items_per_group }" ,
492493 "{[itgt_offset_outer]: 0 <= itgt_offset_outer <= tgt_outer_limit}" ,
493- "{[isrc_offset_outer]: 0 <= isrc_offset_outer <= src_outer_limit}" ,
494+ "{[isrc_prefetch]: 0 <= isrc_prefetch < max_nsources_in_one_box}" ,
495+ "{[isrc_offset]: 0 <= isrc_offset < max_nsources_in_one_box"
496+ " and isrc_offset < isrc_end - isrc_start}" ,
494497 ]
495498 else :
496499 domains += [
497500 "{[itgt]: itgt_start <= itgt < itgt_end}" ,
501+ "{[isrc]: isrc_start <= isrc < isrc_end}"
498502 ]
499503
500504 # There are two algorithms here because pocl-pthread 1.9 miscompiles
501505 # the "gpu" kernel with prefetching.
502- if gpu :
506+ if self . is_gpu :
503507 instructions = (self .get_kernel_scaling_assignments ()
504508 + ["""
505509 for itgt_box
506- <> tgt_ibox = target_boxes[itgt_box]
507- <> itgt_start = box_target_starts[tgt_ibox]
508- <> itgt_end = itgt_start + box_target_counts_nonchild[tgt_ibox]
509-
510- <> isrc_box_start = source_box_starts[itgt_box]
511- <> isrc_box_end = source_box_starts[itgt_box+1]
510+ <> tgt_ibox = target_boxes[itgt_box] {id=init_0}
511+ <> itgt_start = box_target_starts[tgt_ibox] {id=init_1}
512+ <> itgt_end = itgt_start + box_target_counts_nonchild[tgt_ibox] \
513+ {id=init_2}
514+ <> isrc_box_start = source_box_starts[itgt_box] {id=init_3}
515+ <> isrc_box_end = source_box_starts[itgt_box+1] {id=init_4}
512516
513517 for itgt_offset_outer
514- <> itgt_offset = itgt_offset_outer * nsplit + inner
515- <> itgt = itgt_offset + itgt_start
516- <> cond_itgt = itgt < itgt_end
517- <> acc[iknl] = 0 {id=init_acc}
518- if cond_itgt
519- tgt_center[idim] = targets[idim, itgt] {id=prefetch_tgt,dup=idim}
518+ for inner
519+ <> itgt_offset = itgt_offset_outer * work_items_per_group + inner
520+ <> itgt = itgt_offset + itgt_start
521+ <> cond_itgt = itgt < itgt_end
522+ <> acc[iknl] = 0 {id=init_acc}
523+ if cond_itgt
524+ tgt_center[idim] = targets[idim, itgt] {id=set_tgt,dup=idim}
525+ end
520526 end
521527 for isrc_box
522528 <> src_ibox = source_box_lists[isrc_box] {id=src_box_insn_0}
523529 <> isrc_start = box_source_starts[src_ibox] {id=src_box_insn_1}
524530 <> isrc_end = isrc_start + box_source_counts_nonchild[src_ibox] \
525531 {id=src_box_insn_2}
526- for isrc_offset_outer
527- <> isrc_offset = isrc_offset_outer * nsplit + inner
528- <> cond_isrc = isrc_offset < isrc_end - isrc_start
529- if cond_isrc
530- local_isrc[idim, isrc_offset ] = sources[idim,
531- isrc_offset + isrc_start] {id=prefetch_src, dup=idim}
532- local_isrc_strength[istrength, isrc_offset ] = strength[
533- istrength, isrc_offset + isrc_start] {id=prefetch_charge}
532+ for isrc_prefetch
533+ <> cond_isrc_prefetch = isrc_prefetch < isrc_end - isrc_start \
534+ {id=cond_isrc_prefetch}
535+ if cond_isrc_prefetch
536+ local_isrc[idim, isrc_prefetch ] = sources[idim,
537+ isrc_prefetch + isrc_start] {id=prefetch_src, dup=idim}
538+ local_isrc_strength[istrength, isrc_prefetch ] = strength[
539+ istrength, isrc_prefetch + isrc_start] {id=prefetch_charge}
534540 end
535541 end
536- if cond_itgt
537- for isrc
538- <> d[idim] = (tgt_center[idim] - local_isrc[idim,
539- isrc - isrc_start]) {dep=prefetch_src:prefetch_tgt}
542+ for inner
543+ if cond_itgt
544+ for isrc_offset
545+ <> isrc = isrc_offset + isrc_start
546+ <> d[idim] = (tgt_center[idim] - local_isrc[idim,
547+ isrc_offset]) \
548+ {id=set_d,dep=prefetch_src:set_tgt}
540549 """ ] + ["""
541- <> is_self = (isrc == target_to_source[itgt])
550+ <> is_self = (isrc == target_to_source[itgt])
542551 """ if self .exclude_self else "" ]
543552 + [f"""
544- <> strength_{ i } = local_isrc_strength[{ i } , isrc - isrc_start ] \
545- {{ dep=prefetch_charge}}
553+ <> strength_{ i } = local_isrc_strength[{ i } , isrc_offset ] \
554+ {{id=set_strength { i } , dep=prefetch_charge}}
546555 """ for
547556 i in set (self .strength_usage )]
548557 + loopy_insns
549558 + [f"""
550- acc[{ iknl } ] = acc[{ iknl } ] + \
551- pair_result_{ iknl } \
552- {{id=update_acc_{ iknl } , dep=init_acc}}
559+ acc[{ iknl } ] = acc[{ iknl } ] + \
560+ pair_result_{ iknl } \
561+ {{id=update_acc_{ iknl } , dep=init_acc}}
553562 """ for iknl in range (len (self .target_kernels ))]
554563 + ["""
564+ end
555565 end
556566 end
557567 end
558568 """ ]
559569 + [f"""
570+ for inner
560571 if cond_itgt
561572 result[{ iknl } , itgt] = knl_{ iknl } _scaling * acc[{ iknl } ] \
562573 {{id_prefix=write_csr,dep=update_acc_{ iknl } }}
563574 end
575+ end
564576 """ for iknl in range (len (self .target_kernels ))]
565577 + ["""
566578 end
@@ -623,8 +635,9 @@ def get_kernel(self, max_nsources_in_one_box, max_ntargets_in_one_box,
623635 fixed_parameters = {
624636 "dim" : self .dim ,
625637 "nstrengths" : self .strength_count ,
626- "nsplit" : nsplit ,
627- "src_outer_limit" : src_outer_limit ,
638+ "max_nsources_in_one_box" : max_nsources_in_one_box ,
639+ "max_ntargets_in_one_box" : max_ntargets_in_one_box ,
640+ "work_items_per_group" : work_items_per_group ,
628641 "tgt_outer_limit" : tgt_outer_limit ,
629642 "noutputs" : len (self .target_kernels )},
630643 lang_version = MOST_RECENT_LANGUAGE_VERSION )
@@ -643,16 +656,23 @@ def get_kernel(self, max_nsources_in_one_box, max_ntargets_in_one_box,
643656 return loopy_knl
644657
645658 def get_optimized_kernel (self , max_nsources_in_one_box ,
646- max_ntargets_in_one_box ):
647- import pyopencl as cl
648- dev = self .context .devices [0 ]
649- if dev .type & cl .device_type .CPU :
659+ max_ntargets_in_one_box , dtype_size ):
660+ if not self .is_gpu :
650661 knl = self .get_kernel (max_nsources_in_one_box ,
651- max_ntargets_in_one_box , gpu = False )
662+ max_ntargets_in_one_box )
652663 knl = lp .split_iname (knl , "itgt_box" , 4 , outer_tag = "g.0" )
664+ knl = self ._allow_redundant_execution_of_knl_scaling (knl )
653665 else :
666+ work_items_per_group = min (256 , max_ntargets_in_one_box )
667+ total_local_mem = max_nsources_in_one_box * \
668+ (self .dim + self .strength_count ) * dtype_size
669+ # multiplying by 2 here to make sure at least 2 work groups
670+ # can be scheduled at the same time for latency hiding
671+ nprefetch = (2 * total_local_mem - 1 ) // self .device .local_mem_size + 1
672+
654673 knl = self .get_kernel (max_nsources_in_one_box ,
655- max_ntargets_in_one_box , gpu = True , nsplit = 32 )
674+ max_ntargets_in_one_box ,
675+ work_items_per_group = work_items_per_group )
656676 knl = lp .tag_inames (knl , {"itgt_box" : "g.0" , "inner" : "l.0" })
657677 knl = lp .set_temporary_address_space (knl ,
658678 ["local_isrc" , "local_isrc_strength" ], lp .AddressSpace .LOCAL )
@@ -670,10 +690,27 @@ def get_optimized_kernel(self, max_nsources_in_one_box,
670690 if count in [2 , 3 , 4 , 8 , 16 ]:
671691 knl = lp .tag_array_axes (knl , "local_isrc" , "vec,C" )
672692
673- knl = lp .add_inames_for_unused_hw_axes (knl )
693+ # We need to split isrc_prefetch and isrc_offset into chunks.
694+ nsources = (max_nsources_in_one_box + nprefetch - 1 ) // nprefetch
695+ knl = lp .split_array_axis (knl , "local_isrc" , 1 , nsources )
696+ knl = lp .split_iname (knl , "isrc_prefetch" , nsources ,
697+ outer_iname = "iprefetch" )
698+ knl = lp .split_iname (knl , "isrc_prefetch_inner" , work_items_per_group )
699+ knl = lp .tag_inames (knl , {"isrc_prefetch_inner_inner" : "l.0" })
700+ knl = lp .split_iname (knl , "isrc_offset" , nsources ,
701+ outer_iname = "iprefetch" )
702+
703+ # After splitting, the temporary array local_isrc need not
704+ # be as large as before. Need to simplify before unprivatizing
705+ knl = lp .simplify_indices (knl )
706+ knl = lp .unprivatize_temporaries_with_inames (knl ,
707+ "iprefetch" , only_var_names = "local_isrc" )
708+
709+ knl = lp .add_inames_to_insn (knl ,
710+ "inner" , "id:init_* or id:*_scaling or id:src_box_insn_*" )
711+ knl = lp .add_inames_to_insn (knl , "itgt_box" , "id:*_scaling" )
674712 # knl = lp.set_options(knl, write_code=True)
675713
676- knl = self ._allow_redundant_execution_of_knl_scaling (knl )
677714 knl = lp .set_options (knl ,
678715 enforce_variable_access_ordered = "no_check" )
679716
@@ -682,9 +719,17 @@ def get_optimized_kernel(self, max_nsources_in_one_box,
682719 def __call__ (self , queue , ** kwargs ):
683720 max_nsources_in_one_box = kwargs .pop ("max_nsources_in_one_box" )
684721 max_ntargets_in_one_box = kwargs .pop ("max_ntargets_in_one_box" )
722+
723+ if self .is_gpu :
724+ dtype_size = kwargs .get ("sources" )[0 ].dtype .alignment
725+ else :
726+ dtype_size = None
727+
685728 knl = self .get_cached_optimized_kernel (
686729 max_nsources_in_one_box = max_nsources_in_one_box ,
687- max_ntargets_in_one_box = max_ntargets_in_one_box )
730+ max_ntargets_in_one_box = max_ntargets_in_one_box ,
731+ dtype_size = dtype_size ,
732+ )
688733
689734 return knl (queue , ** kwargs )
690735
0 commit comments