Skip to content

Commit 683a23c

Browse files
committed
compiler: Simplify get_printer
1 parent 1632ebe commit 683a23c

3 files changed

Lines changed: 9 additions & 22 deletions

File tree

devito/ir/cgen/printer.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,13 @@
1818
from devito.arch.compiler import AOMPCompiler
1919
from devito.symbolics.inspection import has_integer_args, sympy_dtype
2020
from devito.symbolics.queries import q_leaf
21-
from devito.tools import ctypes_to_cstr, ctypes_vector_mapper, dtype_to_ctype
21+
from devito.tools import (
22+
ctypes_to_cstr, ctypes_vector_mapper, dtype_to_ctype, memoized_func
23+
)
2224
from devito.types.basic import AbstractFunction
2325

2426
__all__ = ['BasePrinter', 'ccode', 'get_printer']
2527

26-
_preset_dtypes = (np.float32, np.float64, np.complex64, np.complex128)
27-
_printer_registry = {}
28-
2928

3029
class BasePrinter(CodePrinter):
3130

@@ -452,22 +451,9 @@ def _print_Fallback(self, expr):
452451
sympy.printing.str.StrPrinter._print_Add = BasePrinter._print_Add
453452

454453

455-
def get_printer(printer, dtype=None):
456-
try:
457-
registry = _printer_registry[printer]
458-
except KeyError:
459-
default = printer()
460-
registry = {None: default, default.dtype: default}
461-
for i in _preset_dtypes:
462-
registry.setdefault(i, printer(settings={'dtype': i}))
463-
_printer_registry[printer] = registry
464-
465-
try:
466-
return registry[dtype]
467-
except KeyError:
468-
handle = printer(settings={'dtype': dtype})
469-
registry[dtype] = handle
470-
return handle
454+
@memoized_func
455+
def get_printer(printer, dtype):
456+
return printer(settings={'dtype': dtype})
471457

472458

473459
def ccode(expr, printer=None, dtype=None):
@@ -489,4 +475,5 @@ def ccode(expr, printer=None, dtype=None):
489475
if printer is None:
490476
from devito.passes.iet.languages.C import CPrinter
491477
printer = CPrinter
478+
dtype = printer._default_settings['dtype'] if dtype is None else dtype
492479
return get_printer(printer, dtype).doprint(expr, None)

devito/ir/iet/visitors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ def __init__(self, *args, printer=None, **kwargs):
257257
self.printer = printer
258258

259259
def ccode(self, expr, dtype=None):
260+
dtype = self.printer._default_settings['dtype'] if dtype is None else dtype
260261
return get_printer(self.printer, dtype).doprint(expr, None)
261262

262263
@property

tests/test_dtypes.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,8 @@ def test_math_functions(dtype: np.dtype[np.inexact],
205205

206206

207207
def test_printer_registry() -> None:
208-
default = get_printer(CPrinter)
208+
default = get_printer(CPrinter, np.float32)
209209

210-
assert get_printer(CPrinter) is default
211210
assert get_printer(CPrinter, np.float32) is default
212211

213212
float64 = get_printer(CPrinter, np.float64)

0 commit comments

Comments
 (0)