@@ -479,6 +479,98 @@ def compute_local_geometry_data(
479479 qbx_center_to_target_box_source_level ))
480480
481481
482+ def distribute_geo_data (comm , actx , insn , bound_expr , evaluate ,
483+ global_geo_data_device ):
484+ geo_data_cache = bound_expr ._geo_data_cache
485+
486+ if insn in geo_data_cache :
487+ return geo_data_cache [insn ]
488+
489+ boxes_time = None
490+ global_geo_data = None
491+
492+ if comm .Get_rank () == 0 :
493+ # Use the cost model to estimate execution time for partitioning
494+ from pytential .qbx .cost import AbstractQBXCostModel , QBXCostModel
495+
496+ # FIXME: If the expansion wrangler is not FMMLib, the argument
497+ # 'uses_pde_expansions' might be different
498+ cost_model = QBXCostModel ()
499+
500+ import warnings
501+ warnings .warn (
502+ "Kernel-specific calibration parameters are not supplied when"
503+ "using distributed FMM." )
504+ # TODO: supply better default calibration parameters
505+ calibration_params = AbstractQBXCostModel .get_unit_calibration_params ()
506+
507+ kernel_args = {}
508+ for arg_name , arg_expr in insn .kernel_arguments .items ():
509+ kernel_args [arg_name ] = evaluate (arg_expr )
510+
511+ boxes_time , _ = cost_model .qbx_cost_per_box (
512+ actx .queue , global_geo_data_device , insn .target_kernels [0 ],
513+ kernel_args , calibration_params )
514+ boxes_time = boxes_time .get ()
515+
516+ from pytential .qbx .utils import ToHostTransferredGeoDataWrapper
517+ global_geo_data = ToHostTransferredGeoDataWrapper (global_geo_data_device )
518+
519+ # {{{ Construct a traversal builder
520+
521+ # NOTE: The distributed implementation relies on building the same traversal
522+ # objects as the one on the root rank. This means here the traversal builder
523+ # should use the same parameters as `QBXFMMGeometryData.traversal`. To make
524+ # it consistent across ranks, we broadcast the parameters here.
525+
526+ trav_param = None
527+ if comm .Get_rank () == 0 :
528+ trav_param = {
529+ "well_sep_is_n_away" :
530+ global_geo_data .geo_data .code_getter .build_traversal
531+ .well_sep_is_n_away ,
532+ "from_sep_smaller_crit" :
533+ global_geo_data .geo_data .code_getter .build_traversal .
534+ from_sep_smaller_crit ,
535+ "_from_sep_smaller_min_nsources_cumul" :
536+ global_geo_data .geo_data .lpot_source .
537+ _from_sep_smaller_min_nsources_cumul }
538+ trav_param = comm .bcast (trav_param , root = 0 )
539+
540+ traversal_builder = QBXFMMGeometryDataTraversalBuilder (
541+ actx .context ,
542+ well_sep_is_n_away = trav_param ["well_sep_is_n_away" ],
543+ from_sep_smaller_crit = trav_param ["from_sep_smaller_crit" ],
544+ _from_sep_smaller_min_nsources_cumul = trav_param [
545+ "_from_sep_smaller_min_nsources_cumul" ])
546+
547+ # }}}
548+
549+ # {{{ Broadcast the subset of the global geometry data to worker ranks
550+
551+ global_geo_data = broadcast_global_geometry_data (
552+ comm , actx , traversal_builder , global_geo_data )
553+
554+ # }}}
555+
556+ # {{{ Compute the local geometry data from the global geometry data
557+
558+ if comm .Get_rank () != 0 :
559+ boxes_time = np .empty (
560+ global_geo_data .global_traversal .tree .nboxes , dtype = np .float64 )
561+
562+ comm .Bcast (boxes_time , root = 0 )
563+
564+ local_geo_data = compute_local_geometry_data (
565+ actx , comm , global_geo_data , boxes_time , traversal_builder )
566+
567+ # }}}
568+
569+ geo_data_cache [insn ] = (global_geo_data , local_geo_data )
570+
571+ return global_geo_data , local_geo_data
572+
573+
482574class DistributedQBXLayerPotentialSource (QBXLayerPotentialSource ):
483575 def __init__ (self , comm , cl_context , * args ,
484576 _use_target_specific_qbx : Optional [bool ] = None ,
@@ -613,12 +705,7 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
613705 from pytential .qbx import get_flat_strengths_from_densities
614706 from meshmode .discretization import Discretization
615707
616- target_name_and_side_to_number = None
617- target_discrs_and_qbx_sides = None
618708 global_geo_data_device = None
619- global_geo_data = None
620- local_geo_data = None
621- boxes_time = None
622709 output_and_expansion_dtype = None
623710 flat_strengths = None
624711
@@ -631,90 +718,8 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
631718 insn .source .geometry ,
632719 target_discrs_and_qbx_sides )
633720
634- # Use the cost model to estimate execution time for partitioning
635- from pytential .qbx .cost import AbstractQBXCostModel , QBXCostModel
636-
637- # FIXME: If the expansion wrangler is not FMMLib, the argument
638- # 'uses_pde_expansions' might be different
639- cost_model = QBXCostModel ()
640-
641- import warnings
642- warnings .warn (
643- "Kernel-specific calibration parameters are not supplied when"
644- "using distributed FMM." )
645- # TODO: supply better default calibration parameters
646- calibration_params = AbstractQBXCostModel .get_unit_calibration_params ()
647-
648- kernel_args = {}
649- for arg_name , arg_expr in insn .kernel_arguments .items ():
650- kernel_args [arg_name ] = evaluate (arg_expr )
651-
652- boxes_time , _ = cost_model .qbx_cost_per_box (
653- actx .queue , global_geo_data_device , insn .target_kernels [0 ],
654- kernel_args , calibration_params )
655- boxes_time = boxes_time .get ()
656-
657- from pytential .qbx .utils import ToHostTransferredGeoDataWrapper
658- global_geo_data = ToHostTransferredGeoDataWrapper (global_geo_data_device )
659-
660- # FIXME Exert more positive control over geo_data attribute lifetimes using
661- # geo_data.<method>.clear_cache(geo_data).
662-
663- # FIXME Synthesize "bad centers" around corners and edges that have
664- # inadequate QBX coverage.
665-
666- # FIXME don't compute *all* output kernels on all targets--respect that
667- # some target discretizations may only be asking for derivatives (e.g.)
668-
669- # {{{ Construct a traversal builder
670-
671- # NOTE: The distributed implementation relies on building the same traversal
672- # objects as the one on the root rank. This means here the traversal builder
673- # should use the same parameters as `QBXFMMGeometryData.traversal`. To make
674- # it consistent across ranks, we broadcast the parameters here.
675-
676- trav_param = None
677- if self .comm .Get_rank () == 0 :
678- trav_param = {
679- "well_sep_is_n_away" :
680- global_geo_data .geo_data .code_getter .build_traversal
681- .well_sep_is_n_away ,
682- "from_sep_smaller_crit" :
683- global_geo_data .geo_data .code_getter .build_traversal .
684- from_sep_smaller_crit ,
685- "_from_sep_smaller_min_nsources_cumul" :
686- global_geo_data .geo_data .lpot_source .
687- _from_sep_smaller_min_nsources_cumul }
688- trav_param = self .comm .bcast (trav_param , root = 0 )
689-
690- traversal_builder = QBXFMMGeometryDataTraversalBuilder (
691- actx .context ,
692- well_sep_is_n_away = trav_param ["well_sep_is_n_away" ],
693- from_sep_smaller_crit = trav_param ["from_sep_smaller_crit" ],
694- _from_sep_smaller_min_nsources_cumul = trav_param [
695- "_from_sep_smaller_min_nsources_cumul" ])
696-
697- # }}}
698-
699- # {{{ Broadcast the subset of the global geometry data to worker ranks
700-
701- global_geo_data = broadcast_global_geometry_data (
702- self .comm , actx , traversal_builder , global_geo_data )
703-
704- # }}}
705-
706- # {{{ Compute the local geometry data from the global geometry data
707-
708- if self .comm .Get_rank () != 0 :
709- boxes_time = np .empty (
710- global_geo_data .global_traversal .tree .nboxes , dtype = np .float64 )
711-
712- self .comm .Bcast (boxes_time , root = 0 )
713-
714- local_geo_data = compute_local_geometry_data (
715- actx , self .comm , global_geo_data , boxes_time , traversal_builder )
716-
717- # }}}
721+ global_geo_data , local_geo_data = distribute_geo_data (
722+ self .comm , actx , insn , bound_expr , evaluate , global_geo_data_device )
718723
719724 tree_indep = self ._tree_indep_data_for_wrangler (
720725 target_kernels = insn .target_kernels ,
0 commit comments