Skip to content

Commit ef3824c

Browse files
committed
Narrow types based on collection containment
1 parent 66f83ea commit ef3824c

5 files changed

Lines changed: 159 additions & 18 deletions

File tree

mypy/checker.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6583,6 +6583,9 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
65836583

65846584
partial_type_maps = []
65856585
for operator, expr_indices in simplified_operator_list:
6586+
if_map: TypeMap
6587+
else_map: TypeMap
6588+
65866589
if operator in {"is", "is not", "==", "!="}:
65876590
if_map, else_map = self.equality_type_narrowing_helper(
65886591
node,
@@ -6598,14 +6601,24 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
65986601
item_type = operand_types[left_index]
65996602
iterable_type = operand_types[right_index]
66006603

6601-
if_map, else_map = {}, {}
6604+
if_map = {}
6605+
else_map = {}
66026606

66036607
if left_index in narrowable_operand_index_to_hash:
6604-
# We only try and narrow away 'None' for now
6605-
if is_overlapping_none(item_type):
6606-
collection_item_type = get_proper_type(builtin_item_type(iterable_type))
6608+
collection_item_type = get_proper_type(builtin_item_type(iterable_type))
6609+
if collection_item_type is not None:
6610+
if_map, else_map = self.narrow_type_by_equality(
6611+
"==",
6612+
operands=[operands[left_index], operands[right_index]],
6613+
operand_types=[item_type, collection_item_type],
6614+
expr_indices=[left_index, right_index],
6615+
narrowable_indices={0},
6616+
)
6617+
6618+
# We only try and narrow away 'None' for now
66076619
if (
6608-
collection_item_type is not None
6620+
if_map is not None
6621+
and is_overlapping_none(item_type)
66096622
and not is_overlapping_none(collection_item_type)
66106623
and not (
66116624
isinstance(collection_item_type, Instance)
@@ -6622,11 +6635,11 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
66226635
expr = operands[right_index]
66236636
if if_type is None:
66246637
if_map = None
6625-
else:
6638+
elif if_map is not None:
66266639
if_map[expr] = if_type
66276640
if else_type is None:
66286641
else_map = None
6629-
else:
6642+
elif else_map is not None:
66306643
else_map[expr] = else_type
66316644

66326645
else:

mypy/constraints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def infer_constraints_for_callable(
124124
param_spec = callee.param_spec()
125125
param_spec_arg_types = []
126126
param_spec_arg_names = []
127-
param_spec_arg_kinds = []
127+
param_spec_arg_kinds: list[ArgKind] = []
128128

129129
incomplete_star_mapping = False
130130
for i, actuals in enumerate(formal_to_actual): # TODO: isn't this `enumerate(arg_types)`?

mypyc/irbuild/util.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import Any, Final, Literal, TypedDict, cast
5+
from typing import Any, Final, Literal, TypedDict
66
from typing_extensions import NotRequired
77

88
from mypy.nodes import (
@@ -138,7 +138,6 @@ def get_mypyc_attrs(
138138

139139
def set_mypyc_attr(key: str, value: Any, line: int) -> None:
140140
if key in MYPYC_ATTRS:
141-
key = cast(MypycAttr, key)
142141
attrs[key] = value
143142
lines[key] = line
144143
else:

test-data/unit/check-narrowing.test

Lines changed: 129 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,13 +1373,13 @@ else:
13731373
reveal_type(val) # N: Revealed type is "None"
13741374

13751375
if val in (None,):
1376-
reveal_type(val) # N: Revealed type is "__main__.A | None"
1376+
reveal_type(val) # N: Revealed type is "None"
13771377
else:
1378-
reveal_type(val) # N: Revealed type is "__main__.A | None"
1378+
reveal_type(val) # N: Revealed type is "__main__.A"
13791379
if val not in (None,):
1380-
reveal_type(val) # N: Revealed type is "__main__.A | None"
1380+
reveal_type(val) # N: Revealed type is "__main__.A"
13811381
else:
1382-
reveal_type(val) # N: Revealed type is "__main__.A | None"
1382+
reveal_type(val) # N: Revealed type is "None"
13831383

13841384
class Hmm:
13851385
def __eq__(self, other) -> bool: ...
@@ -2294,9 +2294,8 @@ def f(x: str | int) -> None:
22942294
y = x
22952295

22962296
if x in ["x"]:
2297-
# TODO: we should fix this reveal https://github.com/python/mypy/issues/3229
2298-
reveal_type(x) # N: Revealed type is "builtins.str | builtins.int"
2299-
y = x # E: Incompatible types in assignment (expression has type "str | int", variable has type "str")
2297+
reveal_type(x) # N: Revealed type is "builtins.str"
2298+
y = x
23002299
z = x
23012300
z = y
23022301
[builtins fixtures/primitives.pyi]
@@ -2806,3 +2805,126 @@ class X:
28062805
reveal_type(self.y) # N: Revealed type is "builtins.list[builtins.str]"
28072806
self.y[0].does_not_exist # E: "str" has no attribute "does_not_exist"
28082807
[builtins fixtures/dict.pyi]
2808+
2809+
2810+
[case testTypeNarrowingStringInLiteralUnion]
2811+
from typing import Literal, Tuple
2812+
typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b')
2813+
x: str = "hi!"
2814+
if x in typ:
2815+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
2816+
else:
2817+
reveal_type(x) # N: Revealed type is "builtins.str"
2818+
[builtins fixtures/tuple.pyi]
2819+
[typing fixtures/typing-medium.pyi]
2820+
2821+
[case testTypeNarrowingStringInLiteralUnionSubset]
2822+
from typing import Literal, Tuple
2823+
typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b')
2824+
strIn: str = "b"
2825+
strOut: str = "c"
2826+
if strIn in typeAlpha:
2827+
reveal_type(strIn) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']"
2828+
else:
2829+
reveal_type(strIn) # N: Revealed type is "builtins.str"
2830+
if strOut in typeAlpha:
2831+
reveal_type(strOut) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']"
2832+
else:
2833+
reveal_type(strOut) # N: Revealed type is "builtins.str"
2834+
[builtins fixtures/primitives.pyi]
2835+
[typing fixtures/typing-medium.pyi]
2836+
2837+
[case testNarrowingStringNotInLiteralUnion]
2838+
from typing import Literal, Tuple
2839+
typeAlpha: Tuple[Literal['a', 'b', 'c'],...] = ('a', 'b', 'c')
2840+
strIn: str = "c"
2841+
strOut: str = "d"
2842+
if strIn not in typeAlpha:
2843+
reveal_type(strIn) # N: Revealed type is "builtins.str"
2844+
else:
2845+
reveal_type(strIn) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']"
2846+
if strOut in typeAlpha:
2847+
reveal_type(strOut) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']"
2848+
else:
2849+
reveal_type(strOut) # N: Revealed type is "builtins.str"
2850+
[builtins fixtures/primitives.pyi]
2851+
[typing fixtures/typing-medium.pyi]
2852+
2853+
[case testNarrowingStringInLiteralUnionDontExpand]
2854+
from typing import Literal, Tuple
2855+
typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b', 'c')
2856+
strIn: Literal['c'] = "c"
2857+
reveal_type(strIn) # N: Revealed type is "Literal['c']"
2858+
#Check we don't expand a Literal into the Union type
2859+
if strIn not in typeAlpha:
2860+
reveal_type(strIn) # N: Revealed type is "Literal['c']"
2861+
else:
2862+
reveal_type(strIn) # N: Revealed type is "Literal['c']"
2863+
[builtins fixtures/primitives.pyi]
2864+
[typing fixtures/typing-medium.pyi]
2865+
2866+
[case testTypeNarrowingStringInMixedUnion]
2867+
from typing import Literal, Tuple
2868+
typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b')
2869+
x: str = "hi!"
2870+
if x in typ:
2871+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
2872+
else:
2873+
reveal_type(x) # N: Revealed type is "builtins.str"
2874+
[builtins fixtures/tuple.pyi]
2875+
[typing fixtures/typing-medium.pyi]
2876+
2877+
[case testTypeNarrowingStringInSet]
2878+
from typing import Literal, Set
2879+
typ: Set[Literal['a', 'b']] = {'a', 'b'}
2880+
x: str = "hi!"
2881+
if x in typ:
2882+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
2883+
else:
2884+
reveal_type(x) # N: Revealed type is "builtins.str"
2885+
if x not in typ:
2886+
reveal_type(x) # N: Revealed type is "builtins.str"
2887+
else:
2888+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
2889+
[builtins fixtures/narrowing.pyi]
2890+
[typing fixtures/typing-medium.pyi]
2891+
2892+
[case testTypeNarrowingStringInList]
2893+
from typing import Literal, List
2894+
typ: List[Literal['a', 'b']] = ['a', 'b']
2895+
x: str = "hi!"
2896+
if x in typ:
2897+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
2898+
else:
2899+
reveal_type(x) # N: Revealed type is "builtins.str"
2900+
if x not in typ:
2901+
reveal_type(x) # N: Revealed type is "builtins.str"
2902+
else:
2903+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
2904+
[builtins fixtures/narrowing.pyi]
2905+
[typing fixtures/typing-medium.pyi]
2906+
2907+
[case testTypeNarrowingUnionStringFloat]
2908+
from typing import Union
2909+
def foobar(foo: Union[str, float]):
2910+
if foo in ['a', 'b']:
2911+
reveal_type(foo) # N: Revealed type is "builtins.str"
2912+
else:
2913+
reveal_type(foo) # N: Revealed type is "builtins.str | builtins.float"
2914+
[builtins fixtures/primitives.pyi]
2915+
[typing fixtures/typing-medium.pyi]
2916+
2917+
[case testNarrowAnyWithEqualityOrContainment]
2918+
# https://github.com/python/mypy/issues/17841
2919+
from typing import Any
2920+
2921+
def f1(x: Any) -> None:
2922+
if x is not None and x not in ["x"]:
2923+
return
2924+
reveal_type(x) # N: Revealed type is "Any"
2925+
2926+
def f2(x: Any) -> None:
2927+
if x is not None and x != "x":
2928+
return
2929+
reveal_type(x) # N: Revealed type is "Any"
2930+
[builtins fixtures/tuple.pyi]

test-data/unit/fixtures/narrowing.pyi

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Builtins stub used in check-narrowing test cases.
2-
from typing import Generic, Sequence, Tuple, Type, TypeVar, Union
2+
from typing import Generic, Sequence, Tuple, Type, TypeVar, Union, Iterable
33

44

55
Tco = TypeVar('Tco', covariant=True)
@@ -15,6 +15,13 @@ class function: pass
1515
class ellipsis: pass
1616
class int: pass
1717
class str: pass
18+
class float: pass
1819
class dict(Generic[KT, VT]): pass
1920

2021
def isinstance(x: object, t: Union[Type[object], Tuple[Type[object], ...]]) -> bool: pass
22+
23+
class list(Sequence[Tco]):
24+
def __contains__(self, other: object) -> bool: pass
25+
class set(Iterable[Tco], Generic[Tco]):
26+
def __init__(self, iterable: Iterable[Tco] = ...) -> None: ...
27+
def __contains__(self, item: object) -> bool: pass

0 commit comments

Comments
 (0)