Skip to content

Commit edfda9b

Browse files
committed
Add functools.lru_cache plugin support
- Add lru_cache callback to functools plugin for type validation - Register callbacks in default plugin for decorator and wrapper calls - Support different lru_cache patterns: @lru_cache, @lru_cache(), @lru_cache(maxsize=N) Fixes issue #16261
1 parent b69309b commit edfda9b

4 files changed

Lines changed: 421 additions & 0 deletions

File tree

mypy/plugins/default.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@
5050
)
5151
from mypy.plugins.enums import enum_member_callback, enum_name_callback, enum_value_callback
5252
from mypy.plugins.functools import (
53+
functools_lru_cache_callback,
5354
functools_total_ordering_maker_callback,
5455
functools_total_ordering_makers,
56+
lru_cache_wrapper_call_callback,
5557
partial_call_callback,
5658
partial_new_callback,
5759
)
@@ -103,6 +105,8 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
103105
return create_singledispatch_function_callback
104106
elif fullname == "functools.partial":
105107
return partial_new_callback
108+
elif fullname == "functools.lru_cache":
109+
return functools_lru_cache_callback
106110
elif fullname == "enum.member":
107111
return enum_member_callback
108112
return None
@@ -162,6 +166,8 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No
162166
return call_singledispatch_function_after_register_argument
163167
elif fullname == "functools.partial.__call__":
164168
return partial_call_callback
169+
elif fullname == "functools._lru_cache_wrapper.__call__":
170+
return lru_cache_wrapper_call_callback
165171
return None
166172

167173
def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:

mypy/plugins/functools.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
ArgKind,
1717
Argument,
1818
CallExpr,
19+
Decorator,
20+
MemberExpr,
1921
NameExpr,
2022
Var,
2123
)
@@ -25,6 +27,8 @@
2527
AnyType,
2628
CallableType,
2729
Instance,
30+
LiteralType,
31+
NoneType,
2832
Overloaded,
2933
ParamSpecFlavor,
3034
ParamSpecType,
@@ -36,11 +40,13 @@
3640
get_proper_type,
3741
)
3842

43+
3944
functools_total_ordering_makers: Final = {"functools.total_ordering"}
4045

4146
_ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"}
4247

4348
PARTIAL: Final = "functools.partial"
49+
LRU_CACHE: Final = "functools.lru_cache"
4450

4551

4652
class _MethodInfo(NamedTuple):
@@ -393,3 +399,157 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
393399
ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)
394400

395401
return result
402+
403+
404+
def functools_lru_cache_callback(ctx: mypy.plugin.FunctionContext) -> Type:
405+
"""Infer a more precise return type for functools.lru_cache decorator"""
406+
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
407+
return ctx.default_return_type
408+
409+
# Only handle the very specific case: @lru_cache (without parentheses)
410+
# where a single function is passed directly as the only argument
411+
if (
412+
len(ctx.arg_types) == 1
413+
and len(ctx.arg_types[0]) == 1
414+
and len(ctx.args) == 1
415+
and len(ctx.args[0]) == 1
416+
):
417+
418+
first_arg_type = ctx.arg_types[0][0]
419+
420+
proper_first_arg_type = get_proper_type(first_arg_type)
421+
if isinstance(proper_first_arg_type, (LiteralType, Instance, NoneType)):
422+
return ctx.default_return_type
423+
424+
# Try to extract callable type
425+
fn_type = ctx.api.extract_callable_type(first_arg_type, ctx=ctx.default_return_type)
426+
if fn_type is not None:
427+
# This is the @lru_cache case (function passed directly)
428+
return fn_type
429+
430+
# For all other cases (parameterized, multiple args, etc.), don't interfere
431+
return ctx.default_return_type
432+
433+
434+
def lru_cache_wrapper_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
435+
"""Handle calls to functools._lru_cache_wrapper objects to provide parameter validation"""
436+
if not isinstance(ctx.api, mypy.checker.TypeChecker):
437+
return ctx.default_return_type
438+
439+
# Safety check: ensure we have the required context
440+
if not ctx.context or not ctx.args or not ctx.arg_types:
441+
return ctx.default_return_type
442+
443+
# Try to find the original function signature using AST/symbol table analysis
444+
original_signature = _find_original_function_signature(ctx)
445+
446+
if original_signature is not None:
447+
# Validate the call against the original function signature
448+
actual_args = []
449+
actual_arg_kinds = []
450+
actual_arg_names = []
451+
seen_args = set()
452+
453+
for i, param in enumerate(ctx.args):
454+
for j, a in enumerate(param):
455+
if a in seen_args:
456+
continue
457+
seen_args.add(a)
458+
actual_args.append(a)
459+
actual_arg_kinds.append(ctx.arg_kinds[i][j])
460+
actual_arg_names.append(ctx.arg_names[i][j])
461+
462+
# Check the call against the original signature
463+
result, _ = ctx.api.expr_checker.check_call(
464+
callee=original_signature,
465+
args=actual_args,
466+
arg_kinds=actual_arg_kinds,
467+
arg_names=actual_arg_names,
468+
context=ctx.context,
469+
)
470+
return result
471+
472+
return ctx.default_return_type
473+
474+
475+
def _get_callable_from_decorator(decorator_node: Decorator) -> CallableType | None:
476+
"""Extract the CallableType from a Decorator node if available."""
477+
if decorator_node.func is None:
478+
return None
479+
func_def = decorator_node.func
480+
if isinstance(func_def.type, CallableType):
481+
return func_def.type
482+
return None
483+
484+
485+
def _bind_method(method_type: CallableType, decorator_node: Decorator) -> CallableType:
486+
"""
487+
Bind a method by removing the self parameter for instance methods.
488+
489+
Static and class methods are returned unchanged.
490+
"""
491+
func_def = decorator_node.func
492+
if func_def is None:
493+
return method_type
494+
495+
# For instance methods, bind self by removing the first parameter
496+
if not func_def.is_static and not func_def.is_class and method_type.arg_types:
497+
return method_type.copy_modified(
498+
arg_types=method_type.arg_types[1:],
499+
arg_kinds=method_type.arg_kinds[1:],
500+
arg_names=method_type.arg_names[1:],
501+
)
502+
return method_type
503+
504+
505+
def _find_original_function_signature(ctx: mypy.plugin.MethodContext) -> CallableType | None:
506+
"""
507+
Find the original function signature from an lru_cache decorated function call.
508+
509+
Returns the CallableType of the original function if found, None otherwise.
510+
"""
511+
if not isinstance(ctx.context, CallExpr):
512+
return None
513+
514+
callee = ctx.context.callee
515+
516+
# Handle method calls (obj.method() or Class.method())
517+
if isinstance(callee, MemberExpr):
518+
method_name = callee.name
519+
if not method_name:
520+
return None
521+
522+
# Get the type of the object or class being accessed
523+
member_type = ctx.api.expr_checker.accept(callee.expr)
524+
proper_type = get_proper_type(member_type)
525+
526+
if not isinstance(proper_type, Instance):
527+
return None
528+
529+
# Look up the method in the class
530+
class_info = proper_type.type
531+
if method_name not in class_info.names:
532+
return None
533+
534+
symbol = class_info.names[method_name]
535+
if not isinstance(symbol.node, Decorator):
536+
return None
537+
538+
method_type = _get_callable_from_decorator(symbol.node)
539+
if method_type is None:
540+
return None
541+
542+
return _bind_method(method_type, symbol.node)
543+
544+
# Handle module-level function calls
545+
if isinstance(callee, NameExpr) and callee.name:
546+
if callee.name not in ctx.api.globals:
547+
return None
548+
549+
symbol = ctx.api.globals[callee.name]
550+
if not isinstance(symbol.node, Decorator):
551+
return None
552+
553+
return _get_callable_from_decorator(symbol.node)
554+
555+
return None

0 commit comments

Comments
 (0)