Skip to content

Commit 9f186c6

Browse files
committed
implements CPU backend
1 parent fe1688d commit 9f186c6

2 files changed

Lines changed: 164 additions & 3 deletions

File tree

pyop2/backends/cpu.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
from pyop2.types.dat import Dat as BaseDat, MixedDat, DatView
2+
from pyop2.types.set import Set, ExtrudedSet, Subset, MixedSet
3+
from pyop2.types.dataset import DataSet, GlobalDataSet, MixedDataSet
4+
from pyop2.types.map import Map, MixedMap
5+
from pyop2.parloop import AbstractParloop
6+
from pyop2.global_kernel import AbstractGlobalKernel
7+
from pyop2.types.access import INC, MIN, MAX
8+
from pyop2.types.mat import Mat
9+
from pyop2.types.glob import Global
10+
from pyop2.backends import AbstractComputeBackend
11+
from petsc4py import PETSc
12+
from pyop2 import (
13+
compilation,
14+
mpi,
15+
utils
16+
)
17+
18+
import ctypes
19+
import os
20+
import loopy as lp
21+
22+
23+
class Dat(BaseDat):
    """``Dat`` specialisation used by the CPU backend.

    Adds a cached PETSc ``Vec`` that wraps this object's data array.
    """

    @utils.cached_property
    def _vec(self):
        """Return a PETSc ``Vec`` created over this Dat's data array.

        The ``Vec`` is built with ``createWithArray`` from the leading
        ``sizes[0]`` entries of ``self._data``, where ``sizes`` is what
        ``self.dataset.layout_vec.getSizes()`` reports.  The block size is
        ``self.cdim`` and the communicator is ``self.comm``.
        """
        scalar_type = PETSc.ScalarType
        assert self.dtype == scalar_type, \
            "Can't create Vec with type %s, must be %s" % (self.dtype,
                                                           scalar_type)
        # Duplicating the dataset's layout_vec would make us carry around
        # extra unnecessary data, so a fresh Vec is created instead.
        # Its sizes are still taken from layout_vec via getSizes, which
        # saves an Allreduce in computing the global size.
        sizes = self.dataset.layout_vec.getSizes()
        local_data = self._data[:sizes[0]]
        return PETSc.Vec().createWithArray(
            local_data,
            size=sizes,
            bsize=self.cdim,
            comm=self.comm,
        )
38+
39+
40+
class GlobalKernel(AbstractGlobalKernel):
    """Global kernel for the CPU backend: generates and compiles C/C++."""

    @utils.cached_property
    def code_to_compile(self):
        """Return the C/C++ source code as a string.

        The wrapper built from ``self.builder`` is passed through loopy's
        code generator.  For C++ local kernels the device programs are
        wrapped in an ``extern "C"`` block, preceded by the processed
        device preambles; otherwise loopy's device code is returned as is.
        """
        from pyop2.codegen.rep2loopy import generate

        generated = lp.generate_code_v2(generate(self.builder))

        if not self.local_kernel.cpp:
            return generated.device_code()

        from loopy.codegen.result import process_preambles

        # ``device_preambles`` may be absent on the result object, in which
        # case an empty list is processed.
        raw_preambles = getattr(generated, "device_preambles", [])
        header = "".join(process_preambles(raw_preambles))
        programs = "\n\n".join(
            str(program.ast) for program in generated.device_programs
        )
        return header + '\nextern "C" {\n' + programs + "\n}\n"

    @PETSc.Log.EventDecorator()
    @mpi.collective
    def compile(self, comm):
        """Compile the kernel.

        :arg comm: The communicator the compilation is collective over.
        :returns: A ctypes function pointer for the compiled function.
        """
        kernel = self.local_kernel
        extension = "cpp" if kernel.cpp else "c"

        # Preprocessor flags: PETSc headers, kernel-specific include
        # directories, then the directory containing this module.
        include_flags = ["-I%s/include" % d for d in utils.get_petsc_dir()]
        include_flags.extend("-I%s" % d for d in kernel.include_dirs)
        include_flags.append(
            "-I%s" % os.path.abspath(os.path.dirname(__file__))
        )

        # Linker flags: PETSc library search path and rpath, the PETSc and
        # math libraries, then any kernel-specific linker arguments.
        link_flags = ["-L%s/lib" % d for d in utils.get_petsc_dir()]
        link_flags.extend(
            "-Wl,-rpath,%s/lib" % d for d in utils.get_petsc_dir()
        )
        link_flags.extend(("-lpetsc", "-lm"))
        link_flags.extend(kernel.ldargs)

        return compilation.load(
            self,
            extension,
            self.name,
            cppargs=tuple(include_flags),
            ldargs=tuple(link_flags),
            restype=ctypes.c_int,
            comm=comm,
        )
84+
85+
86+
class Parloop(AbstractParloop):
    """Parallel loop for the CPU backend, implementing MPI reductions."""

    @PETSc.Log.EventDecorator("ParLoopRednBegin")
    @mpi.collective
    def reduction_begin(self):
        """Begin reductions.

        For every index in ``self._reduction_idxs`` an allreduce from the
        argument's ``_data`` into its ``_buf`` is issued, using the MPI
        operation matching the argument's access descriptor.

        :returns: A tuple of outstanding MPI requests.  It is empty when
            the blocking ``Allreduce`` path is taken (MPI version < 3).
        """
        # Access descriptor -> MPI reduction operation.  Descriptors not
        # listed here map to ``None``.
        op_for_access = {
            INC: mpi.MPI.SUM,
            MIN: mpi.MPI.MIN,
            MAX: mpi.MPI.MAX,
        }

        pending = []
        for idx in self._reduction_idxs:
            glob = self.arguments[idx].data
            op = op_for_access.get(self.accesses[idx])

            if mpi.MPI.VERSION < 3:
                # No non-blocking collectives: reduce synchronously.
                self.comm.Allreduce(glob._data, glob._buf, op=op)
                continue

            pending.append(
                self.comm.Iallreduce(glob._data, glob._buf, op=op)
            )
        return tuple(pending)

    @PETSc.Log.EventDecorator("ParLoopRednEnd")
    @mpi.collective
    def reduction_end(self, requests):
        """Finish reductions.

        :arg requests: The requests returned by :meth:`reduction_begin`.

        Each reduced value is copied from the argument's ``_buf`` back
        into its ``_data``.  On the non-blocking path the corresponding
        request is waited on first.
        """
        if mpi.MPI.VERSION < 3:
            # The blocking path produced no requests in reduction_begin.
            assert len(requests) == 0

            for idx in self._reduction_idxs:
                glob = self.arguments[idx].data
                glob._data[:] = glob._buf
            return

        for idx, request in zip(self._reduction_idxs, requests):
            request.Wait()
            glob = self.arguments[idx].data
            glob._data[:] = glob._buf
122+
123+
124+
class CPUBackend(AbstractComputeBackend):
    """Compute backend that bundles the CPU implementations of PyOP2 types.

    Each class attribute names the concrete type this backend provides for
    the corresponding PyOP2 concept.
    """

    # Execution.
    GlobalKernel = GlobalKernel
    Parloop = Parloop

    # Sets.
    Set = Set
    ExtrudedSet = ExtrudedSet
    MixedSet = MixedSet
    Subset = Subset

    # Data sets.
    DataSet = DataSet
    MixedDataSet = MixedDataSet
    GlobalDataSet = GlobalDataSet

    # Maps.
    Map = Map
    MixedMap = MixedMap

    # Data carriers.
    Dat = Dat
    MixedDat = MixedDat
    DatView = DatView
    Mat = Mat
    Global = Global

    # PETSc Vec type name associated with this backend.
    PETScVecType = "standard"

    def turn_on_offloading(self):
        """Do nothing; this backend takes no action when offloading is enabled."""

    def turn_off_offloading(self):
        """Do nothing; this backend takes no action when offloading is disabled."""

    @property
    def cache_key(self):
        """Return a one-element tuple holding this backend's class."""
        backend_cls = type(self)
        return (backend_cls,)
152+
153+
154+
# Module-level CPUBackend instance, created once when this module is imported.
cpu_backend = CPUBackend()

pyop2/op2.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,16 @@
5050

5151
from pyop2.local_kernel import CStringLocalKernel, LoopyLocalKernel, CoffeeLocalKernel, Kernel # noqa: F401
5252
from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg, # noqa: F401
53-
MatKernelArg, MixedMatKernelArg, MapKernelArg, GlobalKernel)
53+
MatKernelArg, MixedMatKernelArg, MapKernelArg,
54+
AbstractGlobalKernel as GlobalKernel)
5455
from pyop2.parloop import (GlobalParloopArg, DatParloopArg, MixedDatParloopArg, # noqa: F401
55-
MatParloopArg, MixedMatParloopArg, Parloop, parloop, par_loop)
56+
MatParloopArg, MixedMatParloopArg, AbstractParloop,
57+
parloop, par_loop)
5658
from pyop2.parloop import (GlobalLegacyArg, DatLegacyArg, MixedDatLegacyArg, # noqa: F401
5759
MatLegacyArg, MixedMatLegacyArg, LegacyParloop, ParLoop)
5860

61+
from pyop2.backends.cpu import cpu_backend
62+
5963
import loopy
6064

6165
__all__ = ['configuration', 'READ', 'WRITE', 'RW', 'INC', 'MIN', 'MAX',
@@ -64,7 +68,7 @@
6468
'set_log_level', 'MPI', 'init', 'exit', 'Kernel', 'Set', 'ExtrudedSet',
6569
'MixedSet', 'Subset', 'DataSet', 'GlobalDataSet', 'MixedDataSet',
6670
'Halo', 'Dat', 'MixedDat', 'Mat', 'Global', 'Map', 'MixedMap',
67-
'Sparsity', 'parloop', 'Parloop', 'ParLoop', 'par_loop',
71+
'Sparsity', 'parloop', 'AbstractParloop', 'ParLoop', 'par_loop',
6872
'DatView', 'PermutedMap']
6973

7074

@@ -121,3 +125,6 @@ def exit():
121125
configuration.reset()
122126
global _initialised
123127
_initialised = False
128+
129+
130+
# Bind the imported CPU backend instance to the module-level name ``compute_backend``.
compute_backend = cpu_backend

0 commit comments

Comments
 (0)