Skip to content

Commit c270674

Browse files
committed
Refactor distributed eval mapper
1 parent d7425ab commit c270674

1 file changed

Lines changed: 33 additions & 32 deletions

File tree

pytential/symbolic/execution.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -381,36 +381,37 @@ def exec_assign(self, actx: PyOpenCLArrayContext, insn, bound_expr, evaluate):
381381
def exec_compute_potential_insn(
382382
self, actx: PyOpenCLArrayContext, insn, bound_expr, evaluate):
383383
from pytential.qbx.distributed import DistributedQBXLayerPotentialSource
384-
return_timing_data = self.timing_data is not None
385384

386-
is_distributed_fmm = None
385+
mpi_rank = self.comm.Get_rank()
387386
use_target_specific_qbx = None
388387
fmm_backend = None
389388
qbx_order = None
390389
fmm_level_to_order = None
391390
expansion_factory = None
392391

393-
if self.comm.Get_rank() == 0:
394-
source = bound_expr.places.get_geometry(insn.source.geometry)
395-
is_distributed_fmm = isinstance(
396-
source, DistributedQBXLayerPotentialSource)
397-
if is_distributed_fmm:
398-
use_target_specific_qbx = source._use_target_specific_qbx
399-
fmm_backend = source.fmm_backend
400-
qbx_order = source.qbx_order
401-
fmm_level_to_order = source.fmm_level_to_order
402-
expansion_factory = source.expansion_factory
403-
404-
is_distributed_fmm = self.comm.bcast(is_distributed_fmm, root=0)
405-
if is_distributed_fmm:
406-
use_target_specific_qbx = self.comm.bcast(
407-
use_target_specific_qbx, root=0)
408-
fmm_backend = self.comm.bcast(fmm_backend, root=0)
409-
qbx_order = self.comm.bcast(qbx_order, root=0)
410-
fmm_level_to_order = self.comm.bcast(fmm_level_to_order, root=0)
411-
expansion_factory = self.comm.bcast(expansion_factory, root=0)
412-
413-
if is_distributed_fmm and self.comm.Get_rank() != 0:
392+
if mpi_rank == 0:
393+
source: DistributedQBXLayerPotentialSource = \
394+
bound_expr.places.get_geometry(insn.source.geometry)
395+
if not isinstance(source, DistributedQBXLayerPotentialSource):
396+
raise TypeError("Distributed execution mapper can only process"
397+
"distributed layer potential source")
398+
399+
use_target_specific_qbx = source._use_target_specific_qbx
400+
fmm_backend = source.fmm_backend
401+
qbx_order = source.qbx_order
402+
fmm_level_to_order = source.fmm_level_to_order
403+
expansion_factory = source.expansion_factory
404+
405+
use_target_specific_qbx = self.comm.bcast(
406+
use_target_specific_qbx, root=0)
407+
fmm_backend = self.comm.bcast(fmm_backend, root=0)
408+
qbx_order = self.comm.bcast(qbx_order, root=0)
409+
fmm_level_to_order = self.comm.bcast(fmm_level_to_order, root=0)
410+
expansion_factory = self.comm.bcast(expansion_factory, root=0)
411+
412+
assert isinstance(fmm_backend, str)
413+
414+
if mpi_rank != 0:
414415
source = DistributedQBXLayerPotentialSource(
415416
self.comm,
416417
actx.context,
@@ -420,18 +421,18 @@ def exec_compute_potential_insn(
420421
fmm_backend=fmm_backend,
421422
expansion_factory=expansion_factory)
422423

423-
if self.comm.Get_rank() == 0 or is_distributed_fmm:
424-
result, timing_data = (
425-
source.exec_compute_potential_insn(
426-
actx, insn, bound_expr, evaluate, return_timing_data))
424+
return_timing_data = self.timing_data is not None
425+
result, timing_data = (
426+
source.exec_compute_potential_insn(
427+
actx, insn, bound_expr, evaluate, return_timing_data))
427428

428-
if return_timing_data:
429-
# The compiler ensures this.
430-
assert insn not in self.timing_data
429+
if return_timing_data:
430+
# The compiler ensures this.
431+
assert insn not in self.timing_data
431432

432-
self.timing_data[insn] = timing_data
433+
self.timing_data[insn] = timing_data
433434

434-
return result
435+
return result
435436

436437
def __call__(self, expr, *args, **kwargs):
437438
if self.comm.Get_rank() == 0:

0 commit comments

Comments
 (0)