@@ -399,7 +399,7 @@ def bounds(exprs):
399399 return dict ((op , (names [op ], dep )) for op , dep in deps .items ())
400400
401401
402- def generate (builder , wrapper_name = None ):
402+ def generate (builder , wrapper_name = None , include_math = True , include_petsc = True , include_complex = True ):
403403 if builder .layer_index is not None :
404404 outer_inames = frozenset ([builder ._loop_index .name ,
405405 builder .layer_index .name ])
@@ -546,7 +546,16 @@ def renamer(expr):
546546 # register kernel
547547 kernel = builder .kernel
548548 headers = set (kernel .headers )
549- headers = headers | set (["#include <math.h>" , "#include <complex.h>" , "#include <petsc.h>" ])
549+
550+ if include_math :
551+ headers .add ("#include <math.h>" )
552+
553+ if include_petsc :
554+ headers .add ("#include <petsc.h>" )
555+
556+ if include_complex :
557+ headers .add ("#include <complex.h>" )
558+
550559 if PETSc .Log .isActive ():
551560 headers = headers | set (["#include <petsclog.h>" ])
552561 preamble = "\n " .join (sorted (headers ))
@@ -621,15 +630,22 @@ def statement_assign(expr, context):
621630 if isinstance (lvalue , Indexed ):
622631 context .index_ordering .append (tuple (i .name for i in lvalue .index_ordering ()))
623632 lvalue , rvalue = tuple (expression (c , context .parameters ) for c in expr .children )
624- within_inames = context .within_inames [expr ]
633+ if isinstance (expr .label , (PreUnpackInst , UnpackInst )):
634+ tag = "scatter"
635+ elif isinstance (expr .label , PackInst ):
636+ tag = "gather"
637+ else :
638+ raise NotImplementedError ()
625639
640+ within_inames = context .within_inames [expr ]
626641 id , depends_on = context .instruction_dependencies [expr ]
627642 predicates = frozenset (context .conditions )
628643 return loopy .Assignment (lvalue , rvalue , within_inames = within_inames ,
629644 within_inames_is_final = True ,
630645 predicates = predicates ,
631646 id = id ,
632- depends_on = depends_on , depends_on_is_final = True )
647+ depends_on = depends_on , depends_on_is_final = True ,
648+ tags = frozenset ([tag ]))
633649
634650
635651@statement .register (FunctionCall )
0 commit comments