Skip to content

Commit 8ea1cb4

Browse files
committed
generalize loopy codegen to allow OpenCL/CUDA targets
1 parent 8b6ec76 commit 8ea1cb4

1 file changed

Lines changed: 20 additions & 4 deletions

File tree

pyop2/codegen/rep2loopy.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def bounds(exprs):
399399
return dict((op, (names[op], dep)) for op, dep in deps.items())
400400

401401

402-
def generate(builder, wrapper_name=None):
402+
def generate(builder, wrapper_name=None, include_math=True, include_petsc=True, include_complex=True):
403403
if builder.layer_index is not None:
404404
outer_inames = frozenset([builder._loop_index.name,
405405
builder.layer_index.name])
@@ -546,7 +546,16 @@ def renamer(expr):
546546
# register kernel
547547
kernel = builder.kernel
548548
headers = set(kernel.headers)
549-
headers = headers | set(["#include <math.h>", "#include <complex.h>", "#include <petsc.h>"])
549+
550+
if include_math:
551+
headers.add("#include <math.h>")
552+
553+
if include_petsc:
554+
headers.add("#include <petsc.h>")
555+
556+
if include_complex:
557+
headers.add("#include <complex.h>")
558+
550559
if PETSc.Log.isActive():
551560
headers = headers | set(["#include <petsclog.h>"])
552561
preamble = "\n".join(sorted(headers))
@@ -621,15 +630,22 @@ def statement_assign(expr, context):
621630
if isinstance(lvalue, Indexed):
622631
context.index_ordering.append(tuple(i.name for i in lvalue.index_ordering()))
623632
lvalue, rvalue = tuple(expression(c, context.parameters) for c in expr.children)
624-
within_inames = context.within_inames[expr]
633+
if isinstance(expr.label, (PreUnpackInst, UnpackInst)):
634+
tag = "scatter"
635+
elif isinstance(expr.label, PackInst):
636+
tag = "gather"
637+
else:
638+
raise NotImplementedError()
625639

640+
within_inames = context.within_inames[expr]
626641
id, depends_on = context.instruction_dependencies[expr]
627642
predicates = frozenset(context.conditions)
628643
return loopy.Assignment(lvalue, rvalue, within_inames=within_inames,
629644
within_inames_is_final=True,
630645
predicates=predicates,
631646
id=id,
632-
depends_on=depends_on, depends_on_is_final=True)
647+
depends_on=depends_on, depends_on_is_final=True,
648+
tags=frozenset([tag]))
633649

634650

635651
@statement.register(FunctionCall)

0 commit comments

Comments
 (0)