|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
| 5 | +import sys |
| 6 | +# sys.exit(1) |
| 7 | + |
5 | 8 | import enum |
6 | 9 | import itertools |
7 | 10 | import time |
|
32 | 35 | from mypy.literals import literal |
33 | 36 | from mypy.maptype import map_instance_to_supertype |
34 | 37 | from mypy.meet import is_overlapping_types, narrow_declared_type |
| 38 | +from mypy.subtypes import is_subtype |
35 | 39 | from mypy.message_registry import ErrorMessage |
36 | 40 | from mypy.messages import MessageBuilder, format_type |
37 | 41 | from mypy.nodes import ( |
@@ -1749,7 +1753,41 @@ def check_callable_call( |
1749 | 1753 | callee = callee.copy_modified(ret_type=fresh_ret_type) |
1750 | 1754 |
|
1751 | 1755 | if callee.is_generic(): |
| 1756 | + # import sys |
| 1757 | + # sys.stderr.write(f"DEBUG: Checking generic callee: {callee}\n") |
1752 | 1758 | callee = freshen_function_type_vars(callee) |
| 1759 | + ctx = self.type_context[-1] |
| 1760 | + if ctx: |
| 1761 | + p_ctx = get_proper_type(ctx) |
| 1762 | + if isinstance(p_ctx, UnionType): |
| 1763 | + # sys.stderr.write(f"DEBUG: Union Context found: {ctx}\n") |
| 1764 | + candidates = [] |
| 1765 | + for item in p_ctx.items: |
| 1766 | + candidate = self.infer_function_type_arguments_using_context( |
| 1767 | + callee, context, type_context=item |
| 1768 | + ) |
| 1769 | + # Filter out candidates that did not respect the context (e.g. remained generic |
| 1770 | + # or inferred something incompatible). |
| 1771 | + if is_subtype(candidate.ret_type, item): |
| 1772 | + candidates.append(candidate) |
| 1773 | + |
| 1774 | + if candidates: |
| 1775 | + # We use 'None' context to prevent infinite recursion when checking overloads |
| 1776 | + # provided one of the candidates remains generic. |
| 1777 | + self.type_context.append(None) |
| 1778 | + try: |
| 1779 | + return self.check_overload_call( |
| 1780 | + Overloaded(candidates), |
| 1781 | + args, |
| 1782 | + arg_kinds, |
| 1783 | + arg_names, |
| 1784 | + callable_name, |
| 1785 | + object_type, |
| 1786 | + context, |
| 1787 | + ) |
| 1788 | + finally: |
| 1789 | + self.type_context.pop() |
| 1790 | + |
1753 | 1791 | callee = self.infer_function_type_arguments_using_context(callee, context) |
1754 | 1792 |
|
1755 | 1793 | formal_to_actual = map_actuals_to_formals( |
@@ -1982,19 +2020,23 @@ def infer_arg_types_in_context( |
1982 | 2020 | return cast(list[Type], res) |
1983 | 2021 |
|
1984 | 2022 | def infer_function_type_arguments_using_context( |
1985 | | - self, callable: CallableType, error_context: Context |
| 2023 | + self, callable: CallableType, error_context: Context, type_context: Type | None = None |
1986 | 2024 | ) -> CallableType: |
1987 | 2025 | """Unify callable return type to type context to infer type vars. |
1988 | 2026 |
|
1989 | 2027 | For example, if the return type is set[t] where 't' is a type variable |
1990 | 2028 | of callable, and if the context is set[int], return callable modified |
1991 | 2029 | by substituting 't' with 'int'. |
1992 | 2030 | """ |
1993 | | - ctx = self.type_context[-1] |
| 2031 | + ctx: Type | None |
| 2032 | + if type_context: |
| 2033 | + ctx = type_context |
| 2034 | + else: |
| 2035 | + ctx = self.type_context[-1] |
1994 | 2036 | if not ctx: |
1995 | 2037 | return callable |
1996 | 2038 | # The return type may have references to type metavariables that |
1997 | | - # we are inferring right now. We must consider them as indeterminate |
| 2039 | + # we are inferred right now. We must consider them as indeterminate |
1998 | 2040 | # and they are not potential results; thus we replace them with the |
1999 | 2041 | # special ErasedType type. On the other hand, class type variables are |
2000 | 2042 | # valid results. |
@@ -3143,13 +3185,23 @@ def type_overrides_set( |
3143 | 3185 | ) -> Iterator[None]: |
3144 | 3186 | """Set _temporary_ type overrides for given expressions.""" |
3145 | 3187 | assert len(exprs) == len(overrides) |
| 3188 | + # Use a dict to store original values. This handles duplicates in exprs automatically |
| 3189 | + # by only storing the original value for the first occurrence (since we iterate and |
| 3190 | + # populate if not present). |
| 3191 | + original_values: dict[Expression, Type | None] = {} |
3146 | 3192 | for expr, typ in zip(exprs, overrides): |
| 3193 | + if expr not in original_values: |
| 3194 | + original_values[expr] = self.type_overrides.get(expr) |
3147 | 3195 | self.type_overrides[expr] = typ |
3148 | 3196 | try: |
3149 | 3197 | yield |
3150 | 3198 | finally: |
3151 | | - for expr in exprs: |
3152 | | - del self.type_overrides[expr] |
| 3199 | + for expr, prev in original_values.items(): |
| 3200 | + if prev is None: |
| 3201 | + if expr in self.type_overrides: |
| 3202 | + del self.type_overrides[expr] |
| 3203 | + else: |
| 3204 | + self.type_overrides[expr] = prev |
3153 | 3205 |
|
3154 | 3206 | def combine_function_signatures(self, types: list[ProperType]) -> AnyType | CallableType: |
3155 | 3207 | """Accepts a list of function signatures and attempts to combine them together into a |
|
0 commit comments