Skip to content

Commit 58562f6

Browse files
committed
Cache the distributed geometry
1 parent c270674 commit 58562f6

2 files changed

Lines changed: 95 additions & 89 deletions

File tree

pytential/qbx/distributed.py

Lines changed: 94 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,98 @@ def compute_local_geometry_data(
479479
qbx_center_to_target_box_source_level))
480480

481481

482+
def distribute_geo_data(comm, actx, insn, bound_expr, evaluate,
                        global_geo_data_device):
    """Build (or fetch from cache) the distributed geometry data for *insn*.

    The result is memoized in ``bound_expr._geo_data_cache`` keyed by *insn*,
    so repeated evaluations of the same instruction skip the cost estimate,
    the broadcast and the local geometry computation.

    :arg comm: communicator; rank 0 acts as the root for all broadcasts.
    :arg actx: array context; ``actx.queue`` and ``actx.context`` are used.
    :arg insn: the instruction whose ``kernel_arguments`` and
        ``target_kernels[0]`` feed the cost model; also the cache key.
    :arg bound_expr: object carrying the ``_geo_data_cache`` dictionary.
    :arg evaluate: callable applied to each kernel argument expression
        (only called on rank 0).
    :arg global_geo_data_device: geometry data handed to the cost model and
        wrapped for host transfer. Only read on rank 0.
        NOTE(review): presumably *None* on worker ranks -- confirm with caller.
    :returns: a tuple ``(global_geo_data, local_geo_data)``.
    """
    geo_data_cache = bound_expr._geo_data_cache

    # NOTE: Everything past this point involves collective calls on *comm*
    # (bcast/Bcast), so this lookup is expected to hit or miss on all ranks
    # alike.
    if insn in geo_data_cache:
        return geo_data_cache[insn]

    is_root = comm.Get_rank() == 0

    boxes_time = None
    global_geo_data = None

    if is_root:
        # Use the cost model to estimate execution time for partitioning
        from pytential.qbx.cost import AbstractQBXCostModel, QBXCostModel

        # FIXME: If the expansion wrangler is not FMMLib, the argument
        # 'uses_pde_expansions' might be different
        cost_model = QBXCostModel()

        import warnings
        # The trailing space matters: adjacent literals are concatenated
        # verbatim, and without it the message reads "...whenusing...".
        warnings.warn(
            "Kernel-specific calibration parameters are not supplied when "
            "using distributed FMM.")
        # TODO: supply better default calibration parameters
        calibration_params = AbstractQBXCostModel.get_unit_calibration_params()

        kernel_args = {
            arg_name: evaluate(arg_expr)
            for arg_name, arg_expr in insn.kernel_arguments.items()}

        boxes_time, _ = cost_model.qbx_cost_per_box(
            actx.queue, global_geo_data_device, insn.target_kernels[0],
            kernel_args, calibration_params)
        boxes_time = boxes_time.get()

        from pytential.qbx.utils import ToHostTransferredGeoDataWrapper
        global_geo_data = ToHostTransferredGeoDataWrapper(global_geo_data_device)

    # {{{ Construct a traversal builder

    # NOTE: The distributed implementation relies on building the same traversal
    # objects as the one on the root rank. This means here the traversal builder
    # should use the same parameters as `QBXFMMGeometryData.traversal`. To make
    # it consistent across ranks, we broadcast the parameters here.

    trav_param = None
    if is_root:
        root_geo_data = global_geo_data.geo_data
        build_traversal = root_geo_data.code_getter.build_traversal
        trav_param = {
            "well_sep_is_n_away": build_traversal.well_sep_is_n_away,
            "from_sep_smaller_crit": build_traversal.from_sep_smaller_crit,
            "_from_sep_smaller_min_nsources_cumul":
                root_geo_data.lpot_source._from_sep_smaller_min_nsources_cumul,
        }
    trav_param = comm.bcast(trav_param, root=0)

    traversal_builder = QBXFMMGeometryDataTraversalBuilder(
        actx.context,
        well_sep_is_n_away=trav_param["well_sep_is_n_away"],
        from_sep_smaller_crit=trav_param["from_sep_smaller_crit"],
        _from_sep_smaller_min_nsources_cumul=trav_param[
            "_from_sep_smaller_min_nsources_cumul"])

    # }}}

    # {{{ Broadcast the subset of the global geometry data to worker ranks

    global_geo_data = broadcast_global_geometry_data(
        comm, actx, traversal_builder, global_geo_data)

    # }}}

    # {{{ Compute the local geometry data from the global geometry data

    if not is_root:
        # Worker ranks allocate a receive buffer with one entry per box of
        # the broadcast global tree; the root already holds the actual data.
        boxes_time = np.empty(
            global_geo_data.global_traversal.tree.nboxes, dtype=np.float64)

    comm.Bcast(boxes_time, root=0)

    local_geo_data = compute_local_geometry_data(
        actx, comm, global_geo_data, boxes_time, traversal_builder)

    # }}}

    result = (global_geo_data, local_geo_data)
    geo_data_cache[insn] = result

    return result
572+
573+
482574
class DistributedQBXLayerPotentialSource(QBXLayerPotentialSource):
483575
def __init__(self, comm, cl_context, *args,
484576
_use_target_specific_qbx: Optional[bool] = None,
@@ -613,12 +705,7 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
613705
from pytential.qbx import get_flat_strengths_from_densities
614706
from meshmode.discretization import Discretization
615707

616-
target_name_and_side_to_number = None
617-
target_discrs_and_qbx_sides = None
618708
global_geo_data_device = None
619-
global_geo_data = None
620-
local_geo_data = None
621-
boxes_time = None
622709
output_and_expansion_dtype = None
623710
flat_strengths = None
624711

@@ -631,90 +718,8 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
631718
insn.source.geometry,
632719
target_discrs_and_qbx_sides)
633720

634-
# Use the cost model to estimate execution time for partitioning
635-
from pytential.qbx.cost import AbstractQBXCostModel, QBXCostModel
636-
637-
# FIXME: If the expansion wrangler is not FMMLib, the argument
638-
# 'uses_pde_expansions' might be different
639-
cost_model = QBXCostModel()
640-
641-
import warnings
642-
warnings.warn(
643-
"Kernel-specific calibration parameters are not supplied when"
644-
"using distributed FMM.")
645-
# TODO: supply better default calibration parameters
646-
calibration_params = AbstractQBXCostModel.get_unit_calibration_params()
647-
648-
kernel_args = {}
649-
for arg_name, arg_expr in insn.kernel_arguments.items():
650-
kernel_args[arg_name] = evaluate(arg_expr)
651-
652-
boxes_time, _ = cost_model.qbx_cost_per_box(
653-
actx.queue, global_geo_data_device, insn.target_kernels[0],
654-
kernel_args, calibration_params)
655-
boxes_time = boxes_time.get()
656-
657-
from pytential.qbx.utils import ToHostTransferredGeoDataWrapper
658-
global_geo_data = ToHostTransferredGeoDataWrapper(global_geo_data_device)
659-
660-
# FIXME Exert more positive control over geo_data attribute lifetimes using
661-
# geo_data.<method>.clear_cache(geo_data).
662-
663-
# FIXME Synthesize "bad centers" around corners and edges that have
664-
# inadequate QBX coverage.
665-
666-
# FIXME don't compute *all* output kernels on all targets--respect that
667-
# some target discretizations may only be asking for derivatives (e.g.)
668-
669-
# {{{ Construct a traversal builder
670-
671-
# NOTE: The distributed implementation relies on building the same traversal
672-
# objects as the one on the root rank. This means here the traversal builder
673-
# should use the same parameters as `QBXFMMGeometryData.traversal`. To make
674-
# it consistent across ranks, we broadcast the parameters here.
675-
676-
trav_param = None
677-
if self.comm.Get_rank() == 0:
678-
trav_param = {
679-
"well_sep_is_n_away":
680-
global_geo_data.geo_data.code_getter.build_traversal
681-
.well_sep_is_n_away,
682-
"from_sep_smaller_crit":
683-
global_geo_data.geo_data.code_getter.build_traversal.
684-
from_sep_smaller_crit,
685-
"_from_sep_smaller_min_nsources_cumul":
686-
global_geo_data.geo_data.lpot_source.
687-
_from_sep_smaller_min_nsources_cumul}
688-
trav_param = self.comm.bcast(trav_param, root=0)
689-
690-
traversal_builder = QBXFMMGeometryDataTraversalBuilder(
691-
actx.context,
692-
well_sep_is_n_away=trav_param["well_sep_is_n_away"],
693-
from_sep_smaller_crit=trav_param["from_sep_smaller_crit"],
694-
_from_sep_smaller_min_nsources_cumul=trav_param[
695-
"_from_sep_smaller_min_nsources_cumul"])
696-
697-
# }}}
698-
699-
# {{{ Broadcast the subset of the global geometry data to worker ranks
700-
701-
global_geo_data = broadcast_global_geometry_data(
702-
self.comm, actx, traversal_builder, global_geo_data)
703-
704-
# }}}
705-
706-
# {{{ Compute the local geometry data from the global geometry data
707-
708-
if self.comm.Get_rank() != 0:
709-
boxes_time = np.empty(
710-
global_geo_data.global_traversal.tree.nboxes, dtype=np.float64)
711-
712-
self.comm.Bcast(boxes_time, root=0)
713-
714-
local_geo_data = compute_local_geometry_data(
715-
actx, self.comm, global_geo_data, boxes_time, traversal_builder)
716-
717-
# }}}
721+
global_geo_data, local_geo_data = distribute_geo_data(
722+
self.comm, actx, insn, bound_expr, evaluate, global_geo_data_device)
718723

719724
tree_indep = self._tree_indep_data_for_wrangler(
720725
target_kernels=insn.target_kernels,

pytential/symbolic/execution.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -967,6 +967,7 @@ class DistributedBoundExpression(BoundExpression):
967967
def __init__(self, comm, places, sym_op_expr):
968968
self.comm = comm
969969
self._code = None
970+
self._geo_data_cache = {}
970971

971972
if self.comm.Get_rank() == 0:
972973
super().__init__(places, sym_op_expr)

0 commit comments

Comments
 (0)