@@ -280,15 +280,15 @@ def rec_function_definition(
280280 self , expr : FunctionDefinition , * args : P .args , ** kwargs : P .kwargs
281281 ) -> FunctionResultT :
282282 """Call the mapper method of *expr* and return the result."""
283- method : Callable [..., FunctionResultT ] | None
284-
285283 try :
286- method = self .map_function_definition # type: ignore[attr-defined]
284+ method_name = "map_function_definition"
285+ method : Callable [..., FunctionResultT ] = cast (
286+ "Callable[..., FunctionResultT]" ,
287+ getattr (self , method_name ))
287288 except AttributeError :
288289 raise ValueError (
289290 f"{ type (self ).__name__ } lacks a mapper method for functions." ) from None
290291
291- assert method is not None
292292 return method (expr , * args , ** kwargs )
293293
294294 @overload
@@ -523,10 +523,11 @@ def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT:
523523 try :
524524 return self ._cache_retrieve (inputs )
525525 except KeyError :
526- # Intentionally going to Mapper instead of super() to avoid
527- # double caching when subclasses of CachedMapper override rec,
528- # see https://github.com/inducer/pytato/pull/585
529- return self ._cache_add (inputs , Mapper .rec (self , expr , * args , ** kwargs ))
526+ # Reminder: If overriding this in a subclass and reimplementing the cache
527+ # lookup logic there, must use super(CachedMapper, self) instead of
528+ # super() below to avoid double caching,
529+ # see https://github.com/inducer/pytato/pull/585.
530+ return self ._cache_add (inputs , super ().rec (expr , * args , ** kwargs ))
530531
531532 def rec_function_definition (
532533 self , expr : FunctionDefinition , * args : P .args , ** kwargs : P .kwargs
@@ -535,11 +536,12 @@ def rec_function_definition(
535536 try :
536537 return self ._function_cache_retrieve (inputs )
537538 except KeyError :
539+ # Reminder: If overriding this in a subclass and reimplementing the cache
540+ # lookup logic there, must use super(CachedMapper, self) instead of
541+ # super() below to avoid double caching,
542+ # see https://github.com/inducer/pytato/pull/585.
538543 return self ._function_cache_add (
539- # Intentionally going to Mapper instead of super() to avoid
540- # double caching when subclasses of CachedMapper override rec,
541- # see https://github.com/inducer/pytato/pull/585
542- inputs , Mapper .rec_function_definition (self , expr , * args , ** kwargs ))
544+ inputs , super ().rec_function_definition (expr , * args , ** kwargs ))
543545
544546 def clone_for_callee (
545547 self , function : FunctionDefinition ) -> Self :
@@ -1970,10 +1972,10 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
19701972 try :
19711973 return self ._cache_retrieve (inputs )
19721974 except KeyError :
1973- # Intentionally going to Mapper instead of super() to avoid
1974- # double caching when subclasses of CachedMapper override rec,
1975- # see https://github.com/inducer/pytato/pull/585
1976- return self . _cache_add ( inputs , Mapper .rec (self , self .map_fn (expr )))
1975+ # Using super(CachedMapper, self) instead of super() to bypass
1976+ # CachedMapper.rec and avoid double caching
1977+ return self . _cache_add ( inputs ,
1978+ super ( CachedMapper , self ) .rec (self .map_fn (expr )))
19771979
19781980# }}}
19791981
0 commit comments