1818from devito .arch .compiler import AOMPCompiler
1919from devito .symbolics .inspection import has_integer_args , sympy_dtype
2020from devito .symbolics .queries import q_leaf
21- from devito .tools import ctypes_to_cstr , ctypes_vector_mapper , dtype_to_ctype
21+ from devito .tools import (
22+ ctypes_to_cstr , ctypes_vector_mapper , dtype_to_ctype , memoized_func
23+ )
2224from devito .types .basic import AbstractFunction
2325
2426__all__ = ['BasePrinter' , 'ccode' , 'get_printer' ]
2527
26- _preset_dtypes = (np .float32 , np .float64 , np .complex64 , np .complex128 )
27- _printer_registry = {}
28-
2928
3029class BasePrinter (CodePrinter ):
3130
@@ -452,22 +451,9 @@ def _print_Fallback(self, expr):
452451 sympy .printing .str .StrPrinter ._print_Add = BasePrinter ._print_Add
453452
454453
455- def get_printer (printer , dtype = None ):
456- try :
457- registry = _printer_registry [printer ]
458- except KeyError :
459- default = printer ()
460- registry = {None : default , default .dtype : default }
461- for i in _preset_dtypes :
462- registry .setdefault (i , printer (settings = {'dtype' : i }))
463- _printer_registry [printer ] = registry
464-
465- try :
466- return registry [dtype ]
467- except KeyError :
468- handle = printer (settings = {'dtype' : dtype })
469- registry [dtype ] = handle
470- return handle
454+ @memoized_func
455+ def get_printer (printer , dtype ):
456+ return printer (settings = {'dtype' : dtype })
471457
472458
473459def ccode (expr , printer = None , dtype = None ):
@@ -489,4 +475,5 @@ def ccode(expr, printer=None, dtype=None):
489475 if printer is None :
490476 from devito .passes .iet .languages .C import CPrinter
491477 printer = CPrinter
478+ dtype = printer ._default_settings ['dtype' ] if dtype is None else dtype
492479 return get_printer (printer , dtype ).doprint (expr , None )
0 commit comments