Skip to content

Commit a43fcd3

Browse files
authored
Merge pull request #639 from OP2/connorjward/refactor-base.py
Major refactor of base.py etc
2 parents 2d77e8f + e4a9de6 commit a43fcd3

22 files changed

Lines changed: 4775 additions & 4856 deletions

pyop2/base.py

Lines changed: 0 additions & 3911 deletions
This file was deleted.

pyop2/codegen/rep2loopy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from pyop2.codegen.node import traversal, Node, Memoizer, reuse_if_untouched
2121

22-
from pyop2.base import READ, WRITE
22+
from pyop2.types.access import READ, WRITE
2323
from pyop2.datatypes import as_ctypes
2424

2525
from pyop2.codegen.optimise import index_merger, rename_nodes

pyop2/compilation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
from pyop2.configuration import configuration
4949
from pyop2.logger import debug, progress, INFO
5050
from pyop2.exceptions import CompilationError
51-
from pyop2.base import JITModule
5251

5352

5453
def _check_hashes(x, y, datatype):
@@ -466,6 +465,7 @@ def load(jitmodule, extension, fn_name, cppargs=[], ldargs=[],
466465
:kwarg comm: Optional communicator to compile the code on (only
467466
rank 0 compiles code) (defaults to COMM_WORLD).
468467
"""
468+
from pyop2.parloop import JITModule
469469
if isinstance(jitmodule, str):
470470
class StrCode(object):
471471
def __init__(self, code, argtypes):

pyop2/kernel.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
import hashlib
2+
3+
import coffee
4+
import loopy as lp
5+
6+
from . import caching, configuration as conf, datatypes, exceptions as ex, utils, version
7+
8+
9+
class Kernel(caching.Cached):
10+
11+
"""OP2 kernel type.
12+
13+
:param code: kernel function definition, including signature; either a
14+
string or an AST :class:`.Node`
15+
:param name: kernel function name; must match the name of the kernel
16+
function given in `code`
17+
:param opts: options dictionary for :doc:`PyOP2 IR optimisations <ir>`
18+
(optional, ignored if `code` is a string)
19+
:param include_dirs: list of additional include directories to be searched
20+
when compiling the kernel (optional, defaults to empty)
21+
:param headers: list of system headers to include when compiling the kernel
22+
in the form ``#include <header.h>`` (optional, defaults to empty)
23+
:param user_code: code snippet to be executed once at the very start of
24+
the generated kernel wrapper code (optional, defaults to
25+
empty)
26+
:param ldargs: A list of arguments to pass to the linker when
27+
compiling this Kernel.
28+
:param requires_zeroed_output_arguments: Does this kernel require the
29+
output arguments to be zeroed on entry when called? (default no)
30+
:param cpp: Is the kernel actually C++ rather than C? If yes,
31+
then compile with the C++ compiler (kernel is wrapped in
32+
extern C for linkage reasons).
33+
34+
Consider the case of initialising a :class:`~pyop2.Dat` with seeded random
35+
values in the interval 0 to 1. The corresponding :class:`~pyop2.Kernel` is
36+
constructed as follows: ::
37+
38+
op2.Kernel("void setrand(double *x) { x[0] = (double)random()/RAND_MAX); }",
39+
name="setrand",
40+
headers=["#include <stdlib.h>"], user_code="srandom(10001);")
41+
42+
.. note::
43+
When running in parallel with MPI the generated code must be the same
44+
on all ranks.
45+
"""
46+
47+
_cache = {}
48+
49+
@classmethod
50+
@utils.validate_type(('name', str, ex.NameTypeError))
51+
def _cache_key(cls, code, name, opts={}, include_dirs=[], headers=[],
52+
user_code="", ldargs=None, cpp=False, requires_zeroed_output_arguments=False,
53+
flop_count=None):
54+
# Both code and name are relevant since there might be multiple kernels
55+
# extracting different functions from the same code
56+
# Also include the PyOP2 version, since the Kernel class might change
57+
58+
if isinstance(code, coffee.base.Node):
59+
code = code.gencode()
60+
if isinstance(code, lp.TranslationUnit):
61+
from loopy.tools import LoopyKeyBuilder
62+
from hashlib import sha256
63+
key_hash = sha256()
64+
code.update_persistent_hash(key_hash, LoopyKeyBuilder())
65+
code = key_hash.hexdigest()
66+
hashee = (str(code) + name + str(sorted(opts.items())) + str(include_dirs)
67+
+ str(headers) + version.__version__ + str(ldargs) + str(cpp) + str(requires_zeroed_output_arguments))
68+
return hashlib.md5(hashee.encode()).hexdigest()
69+
70+
@utils.cached_property
71+
def _wrapper_cache_key_(self):
72+
return (self._key, )
73+
74+
def __init__(self, code, name, opts={}, include_dirs=[], headers=[],
75+
user_code="", ldargs=None, cpp=False, requires_zeroed_output_arguments=False,
76+
flop_count=None):
77+
# Protect against re-initialization when retrieved from cache
78+
if self._initialized:
79+
return
80+
self._name = name
81+
self._cpp = cpp
82+
# Record used optimisations
83+
self._opts = opts
84+
self._include_dirs = include_dirs
85+
self._ldargs = ldargs if ldargs is not None else []
86+
self._headers = headers
87+
self._user_code = user_code
88+
assert isinstance(code, (str, coffee.base.Node, lp.Program, lp.LoopKernel, lp.TranslationUnit))
89+
self._code = code
90+
self._initialized = True
91+
self.requires_zeroed_output_arguments = requires_zeroed_output_arguments
92+
self.flop_count = flop_count
93+
94+
@property
95+
def name(self):
96+
"""Kernel name, must match the kernel function name in the code."""
97+
return self._name
98+
99+
@property
100+
def code(self):
101+
return self._code
102+
103+
@utils.cached_property
104+
def num_flops(self):
105+
if self.flop_count is not None:
106+
return self.flop_count
107+
if not conf.configuration["compute_kernel_flops"]:
108+
return 0
109+
if isinstance(self.code, coffee.base.Node):
110+
v = coffee.visitors.EstimateFlops()
111+
return v.visit(self.code)
112+
elif isinstance(self.code, lp.TranslationUnit):
113+
op_map = lp.get_op_map(
114+
self.code.copy(options=lp.Options(ignore_boostable_into=True),
115+
silenced_warnings=['insn_count_subgroups_upper_bound',
116+
'get_x_map_guessing_subgroup_size',
117+
'summing_if_branches_ops']),
118+
subgroup_size='guess')
119+
return op_map.filter_by(name=['add', 'sub', 'mul', 'div'], dtype=[datatypes.ScalarType]).eval_and_sum({})
120+
else:
121+
return 0
122+
123+
def __str__(self):
124+
return "OP2 Kernel: %s" % self._name
125+
126+
def __repr__(self):
127+
return 'Kernel("""%s""", %r)' % (self._code, self._name)
128+
129+
def __eq__(self, other):
130+
return self.cache_key == other.cache_key
131+
132+
133+
class PyKernel(Kernel):
134+
@classmethod
135+
def _cache_key(cls, *args, **kwargs):
136+
return None
137+
138+
def __init__(self, code, name=None, **kwargs):
139+
self._func = code
140+
self._name = name
141+
142+
def __getattr__(self, attr):
143+
"""Return None on unrecognised attributes"""
144+
return None
145+
146+
def __call__(self, *args):
147+
return self._func(*args)
148+
149+
def __repr__(self):
150+
return 'Kernel("""%s""", %r)' % (self._func, self._name)

pyop2/op2.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,18 @@
3939
from pyop2.logger import debug, info, warning, error, critical, set_log_level
4040
from pyop2.mpi import MPI, COMM_WORLD, collective
4141

42-
from pyop2.sequential import par_loop, Kernel # noqa: F401
43-
from pyop2.sequential import READ, WRITE, RW, INC, MIN, MAX # noqa: F401
44-
from pyop2.base import ON_BOTTOM, ON_TOP, ON_INTERIOR_FACETS, ALL # noqa: F401
45-
from pyop2.sequential import Set, ExtrudedSet, MixedSet, Subset, DataSet, MixedDataSet # noqa: F401
46-
from pyop2.sequential import Map, MixedMap, PermutedMap, Sparsity, Halo # noqa: F401
47-
from pyop2.sequential import Global, GlobalDataSet # noqa: F401
48-
from pyop2.sequential import Dat, MixedDat, DatView, Mat # noqa: F401
49-
from pyop2.sequential import ParLoop as SeqParLoop
50-
from pyop2.pyparloop import ParLoop as PyParLoop
42+
from .types import (
43+
Set, ExtrudedSet, MixedSet, Subset, DataSet, MixedDataSet,
44+
Map, MixedMap, PermutedMap, Sparsity, Halo,
45+
Global, GlobalDataSet,
46+
Dat, MixedDat, DatView, Mat
47+
)
48+
from .types.access import READ, WRITE, RW, INC, MIN, MAX
49+
50+
from pyop2.parloop import par_loop, ON_BOTTOM, ON_TOP, ON_INTERIOR_FACETS, ALL
51+
from pyop2.kernel import Kernel
52+
53+
from pyop2.parloop import ParLoop as SeqParLoop, PyParLoop
5154

5255
import types
5356
import loopy

0 commit comments

Comments
 (0)