@@ -381,36 +381,37 @@ def exec_assign(self, actx: PyOpenCLArrayContext, insn, bound_expr, evaluate):
381381 def exec_compute_potential_insn (
382382 self , actx : PyOpenCLArrayContext , insn , bound_expr , evaluate ):
383383 from pytential .qbx .distributed import DistributedQBXLayerPotentialSource
384- return_timing_data = self .timing_data is not None
385384
386- is_distributed_fmm = None
385+ mpi_rank = self . comm . Get_rank ()
387386 use_target_specific_qbx = None
388387 fmm_backend = None
389388 qbx_order = None
390389 fmm_level_to_order = None
391390 expansion_factory = None
392391
393- if self .comm .Get_rank () == 0 :
394- source = bound_expr .places .get_geometry (insn .source .geometry )
395- is_distributed_fmm = isinstance (
396- source , DistributedQBXLayerPotentialSource )
397- if is_distributed_fmm :
398- use_target_specific_qbx = source ._use_target_specific_qbx
399- fmm_backend = source .fmm_backend
400- qbx_order = source .qbx_order
401- fmm_level_to_order = source .fmm_level_to_order
402- expansion_factory = source .expansion_factory
403-
404- is_distributed_fmm = self .comm .bcast (is_distributed_fmm , root = 0 )
405- if is_distributed_fmm :
406- use_target_specific_qbx = self .comm .bcast (
407- use_target_specific_qbx , root = 0 )
408- fmm_backend = self .comm .bcast (fmm_backend , root = 0 )
409- qbx_order = self .comm .bcast (qbx_order , root = 0 )
410- fmm_level_to_order = self .comm .bcast (fmm_level_to_order , root = 0 )
411- expansion_factory = self .comm .bcast (expansion_factory , root = 0 )
412-
413- if is_distributed_fmm and self .comm .Get_rank () != 0 :
392+ if mpi_rank == 0 :
393+ source : DistributedQBXLayerPotentialSource = \
394+ bound_expr .places .get_geometry (insn .source .geometry )
395+ if not isinstance (source , DistributedQBXLayerPotentialSource ):
396+ raise TypeError ("Distributed execution mapper can only process"
397+ "distributed layer potential source" )
398+
399+ use_target_specific_qbx = source ._use_target_specific_qbx
400+ fmm_backend = source .fmm_backend
401+ qbx_order = source .qbx_order
402+ fmm_level_to_order = source .fmm_level_to_order
403+ expansion_factory = source .expansion_factory
404+
405+ use_target_specific_qbx = self .comm .bcast (
406+ use_target_specific_qbx , root = 0 )
407+ fmm_backend = self .comm .bcast (fmm_backend , root = 0 )
408+ qbx_order = self .comm .bcast (qbx_order , root = 0 )
409+ fmm_level_to_order = self .comm .bcast (fmm_level_to_order , root = 0 )
410+ expansion_factory = self .comm .bcast (expansion_factory , root = 0 )
411+
412+ assert isinstance (fmm_backend , str )
413+
414+ if mpi_rank != 0 :
414415 source = DistributedQBXLayerPotentialSource (
415416 self .comm ,
416417 actx .context ,
@@ -420,18 +421,18 @@ def exec_compute_potential_insn(
420421 fmm_backend = fmm_backend ,
421422 expansion_factory = expansion_factory )
422423
423- if self .comm . Get_rank () == 0 or is_distributed_fmm :
424- result , timing_data = (
425- source .exec_compute_potential_insn (
426- actx , insn , bound_expr , evaluate , return_timing_data ))
424+ return_timing_data = self .timing_data is not None
425+ result , timing_data = (
426+ source .exec_compute_potential_insn (
427+ actx , insn , bound_expr , evaluate , return_timing_data ))
427428
428- if return_timing_data :
429- # The compiler ensures this.
430- assert insn not in self .timing_data
429+ if return_timing_data :
430+ # The compiler ensures this.
431+ assert insn not in self .timing_data
431432
432- self .timing_data [insn ] = timing_data
433+ self .timing_data [insn ] = timing_data
433434
434- return result
435+ return result
435436
436437 def __call__ (self , expr , * args , ** kwargs ):
437438 if self .comm .Get_rank () == 0 :
0 commit comments