Skip to content

Commit ad30be8

Browse files
authored
Rework narrowing logic for equality and identity (#20492)
Mypy does not narrow as much as it could, which results in false positives. We would also like to narrow based on containment. The PR for that was previously reverted due to inconsistencies between narrowing via equality and via containment. This fixes the inconsistency on the equality side and paves the road for adding narrowing via containment. That is, we lay groundwork for fixing #17864 and fixing #17841 Fixes #18524 Fixes #20041 Fixes #17162 Fixes #16830 Fixes #13704 Fixes #7642 Fixes #3964
1 parent 3cae656 commit ad30be8

13 files changed

Lines changed: 341 additions & 225 deletions

mypy/checker.py

Lines changed: 138 additions & 152 deletions
Large diffs are not rendered by default.

mypy/test/testargs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import argparse
1111
import sys
12+
from typing import Any, cast
1213

1314
from mypy.main import infer_python_executable, process_options
1415
from mypy.options import Options
@@ -63,7 +64,7 @@ def test_executable_inference(self) -> None:
6364

6465
# first test inferring executable from version
6566
options = Options()
66-
options.python_executable = None
67+
options.python_executable = cast(Any, None)
6768
options.python_version = sys.version_info[:2]
6869
infer_python_executable(options, special_opts)
6970
assert options.python_version == sys.version_info[:2]

mypy/test/testtypes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,12 @@ def test_generic_function_type(self) -> None:
194194

195195
def test_type_alias_expand_once(self) -> None:
196196
A, target = self.fx.def_alias_1(self.fx.a)
197-
assert get_proper_type(A) == target
198197
assert get_proper_type(target) == target
198+
assert get_proper_type(A) == target
199199

200200
A, target = self.fx.def_alias_2(self.fx.a)
201-
assert get_proper_type(A) == target
202201
assert get_proper_type(target) == target
202+
assert get_proper_type(A) == target
203203

204204
def test_recursive_nested_in_non_recursive(self) -> None:
205205
A, _ = self.fx.def_alias_1(self.fx.a)

mypy/typeops.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,7 +1005,7 @@ def is_singleton_type(typ: Type) -> bool:
10051005
return typ.is_singleton_type()
10061006

10071007

1008-
def try_expanding_sum_type_to_union(typ: Type, target_fullname: str) -> Type:
1008+
def try_expanding_sum_type_to_union(typ: Type, target_fullname: str | None) -> Type:
10091009
"""Attempts to recursively expand any enum Instances with the given target_fullname
10101010
into a Union of all of its component LiteralTypes.
10111011
@@ -1034,7 +1034,9 @@ class Status(Enum):
10341034
]
10351035
return UnionType.make_union(items)
10361036

1037-
if isinstance(typ, Instance) and typ.type.fullname == target_fullname:
1037+
if isinstance(typ, Instance) and (
1038+
target_fullname is None or typ.type.fullname == target_fullname
1039+
):
10381040
if typ.type.fullname == "builtins.bool":
10391041
return UnionType([LiteralType(True, typ), LiteralType(False, typ)])
10401042

mypyc/test-data/run-classes.test

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4920,6 +4920,7 @@ def test_setattr() -> None:
49204920
assert i.one == 1
49214921
assert i.two == None
49224922
assert i.const == 42
4923+
i = i
49234924

49244925
i.__setattr__("two", "2")
49254926
assert i.two == "2"
@@ -4957,6 +4958,7 @@ def test_setattr_inherited() -> None:
49574958
assert i.one == 1
49584959
assert i.two == None
49594960
assert i.const == 42
4961+
i = i
49604962

49614963
i.__setattr__("two", "2")
49624964
assert i.two == "2"
@@ -4996,6 +4998,7 @@ def test_setattr_overridden() -> None:
49964998
assert i.one == 1
49974999
assert i.two == None
49985000
assert i.const == 42
5001+
i = i
49995002

50005003
i.__setattr__("two", "2")
50015004
assert i.two == "2"
@@ -5064,6 +5067,7 @@ def test_setattr_nonnative() -> None:
50645067
assert i.one == 1
50655068
assert i.two == None
50665069
assert i.const == 42
5070+
i = i
50675071

50685072
i.__setattr__("two", "2")
50695073
assert i.two == "2"
@@ -5134,6 +5138,8 @@ def test_no_setattr_nonnative() -> None:
51345138
object.__setattr__(i, "three", 102)
51355139
assert i.three == 102
51365140

5141+
i = i
5142+
51375143
del i.three
51385144
assert i.three == None
51395145

test-data/unit/check-enum.test

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,7 +1032,7 @@ else:
10321032
reveal_type(z) # No output: this branch is unreachable
10331033
[builtins fixtures/bool.pyi]
10341034

1035-
[case testEnumReachabilityNoNarrowingForUnionMessiness]
1035+
[case testEnumReachabilityNarrowingForUnionMessiness]
10361036
from enum import Enum
10371037
from typing import Literal
10381038

@@ -1045,17 +1045,16 @@ x: Foo
10451045
y: Literal[Foo.A, Foo.B]
10461046
z: Literal[Foo.B, Foo.C]
10471047

1048-
# For the sake of simplicity, no narrowing is done when the narrower type is a Union.
10491048
if x is y:
1050-
reveal_type(x) # N: Revealed type is "__main__.Foo"
1049+
reveal_type(x) # N: Revealed type is "Literal[__main__.Foo.A] | Literal[__main__.Foo.B]"
10511050
reveal_type(y) # N: Revealed type is "Literal[__main__.Foo.A] | Literal[__main__.Foo.B]"
10521051
else:
10531052
reveal_type(x) # N: Revealed type is "__main__.Foo"
10541053
reveal_type(y) # N: Revealed type is "Literal[__main__.Foo.A] | Literal[__main__.Foo.B]"
10551054

10561055
if y is z:
1057-
reveal_type(y) # N: Revealed type is "Literal[__main__.Foo.A] | Literal[__main__.Foo.B]"
1058-
reveal_type(z) # N: Revealed type is "Literal[__main__.Foo.B] | Literal[__main__.Foo.C]"
1056+
reveal_type(y) # N: Revealed type is "Literal[__main__.Foo.B]"
1057+
reveal_type(z) # N: Revealed type is "Literal[__main__.Foo.B]"
10591058
else:
10601059
reveal_type(y) # N: Revealed type is "Literal[__main__.Foo.A] | Literal[__main__.Foo.B]"
10611060
reveal_type(z) # N: Revealed type is "Literal[__main__.Foo.B] | Literal[__main__.Foo.C]"

test-data/unit/check-flags.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2372,7 +2372,7 @@ x: int = "" # E: Incompatible types in assignment (expression has type "str", v
23722372
x: int = "" # E: Incompatible types in assignment (expression has type "str", variable has type "int")
23732373

23742374
[case testDisableBytearrayPromotion]
2375-
# flags: --disable-bytearray-promotion --strict-equality
2375+
# flags: --disable-bytearray-promotion --strict-equality --warn-unreachable
23762376
def f(x: bytes) -> None: ...
23772377
f(bytearray(b"asdf")) # E: Argument 1 to "f" has incompatible type "bytearray"; expected "bytes"
23782378
f(memoryview(b"asdf"))

test-data/unit/check-isinstance.test

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2729,7 +2729,7 @@ from typing import Union
27292729

27302730
y: str
27312731
if type(y) is int: # E: Subclass of "str" and "int" cannot exist: would have incompatible method signatures
2732-
y # E: Statement is unreachable
2732+
y # E: Statement is unreachable
27332733
else:
27342734
reveal_type(y) # N: Revealed type is "builtins.str"
27352735
[builtins fixtures/isinstance.pyi]
@@ -2760,6 +2760,7 @@ else:
27602760
reveal_type(x) # N: Revealed type is "builtins.int | builtins.str"
27612761

27622762
[case testTypeEqualsMultipleTypesShouldntNarrow]
2763+
# flags: --warn-unreachable
27632764
# make sure we don't do any narrowing if there are multiple types being compared
27642765

27652766
from typing import Union

test-data/unit/check-narrowing.test

Lines changed: 111 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,13 +1364,13 @@ class A: ...
13641364
val: Optional[A]
13651365

13661366
if val == None:
1367-
reveal_type(val) # N: Revealed type is "__main__.A | None"
1367+
reveal_type(val) # N: Revealed type is "None"
13681368
else:
13691369
reveal_type(val) # N: Revealed type is "__main__.A"
13701370
if val != None:
13711371
reveal_type(val) # N: Revealed type is "__main__.A"
13721372
else:
1373-
reveal_type(val) # N: Revealed type is "__main__.A | None"
1373+
reveal_type(val) # N: Revealed type is "None"
13741374

13751375
if val in (None,):
13761376
reveal_type(val) # N: Revealed type is "__main__.A | None"
@@ -1380,6 +1380,19 @@ if val not in (None,):
13801380
reveal_type(val) # N: Revealed type is "__main__.A | None"
13811381
else:
13821382
reveal_type(val) # N: Revealed type is "__main__.A | None"
1383+
1384+
class Hmm:
1385+
def __eq__(self, other) -> bool: ...
1386+
1387+
hmm: Optional[Hmm]
1388+
if hmm == None:
1389+
reveal_type(hmm) # N: Revealed type is "__main__.Hmm | None"
1390+
else:
1391+
reveal_type(hmm) # N: Revealed type is "__main__.Hmm"
1392+
if hmm != None:
1393+
reveal_type(hmm) # N: Revealed type is "__main__.Hmm"
1394+
else:
1395+
reveal_type(hmm) # N: Revealed type is "__main__.Hmm | None"
13831396
[builtins fixtures/primitives.pyi]
13841397

13851398
[case testNarrowingWithTupleOfTypes]
@@ -2277,13 +2290,13 @@ def f4(x: SE) -> None:
22772290
# https://github.com/python/mypy/issues/17864
22782291
def f(x: str | int) -> None:
22792292
if x == "x":
2280-
reveal_type(x) # N: Revealed type is "builtins.str | builtins.int"
2293+
reveal_type(x) # N: Revealed type is "builtins.str"
22812294
y = x
22822295

22832296
if x in ["x"]:
22842297
# TODO: we should fix this reveal https://github.com/python/mypy/issues/3229
22852298
reveal_type(x) # N: Revealed type is "builtins.str | builtins.int"
2286-
y = x
2299+
y = x # E: Incompatible types in assignment (expression has type "str | int", variable has type "str")
22872300
z = x
22882301
z = y
22892302
[builtins fixtures/primitives.pyi]
@@ -2699,3 +2712,97 @@ reveal_type(t.foo) # N: Revealed type is "__main__.D"
26992712
t.foo = C1()
27002713
reveal_type(t.foo) # N: Revealed type is "__main__.C"
27012714
[builtins fixtures/property.pyi]
2715+
2716+
[case testNarrowingNotImplemented]
2717+
from __future__ import annotations
2718+
from typing_extensions import Self
2719+
2720+
class X:
2721+
def __divmod__(self, other: Self | int) -> tuple[Self, Self]: ...
2722+
2723+
def __floordiv__(self, other: Self | int) -> Self:
2724+
qr = self.__divmod__(other)
2725+
if qr is NotImplemented:
2726+
return NotImplemented
2727+
return qr[0]
2728+
[builtins fixtures/notimplemented.pyi]
2729+
2730+
2731+
[case testNarrowingBooleans]
2732+
# flags: --warn-return-any
2733+
from typing import Any
2734+
2735+
def foo(x: dict[str, Any]) -> bool:
2736+
if x.get("event") is False:
2737+
return False
2738+
if x.get("event") is True:
2739+
return True
2740+
raise
2741+
[builtins fixtures/dict.pyi]
2742+
2743+
2744+
[case testNarrowingTypeObjects]
2745+
from __future__ import annotations
2746+
from typing import Callable, Any, TypeVar, Generic, Protocol
2747+
_T_co = TypeVar('_T_co', covariant=True)
2748+
2749+
class Boxxy(Protocol[_T_co]):
2750+
def get_box(self) -> _T_co: ...
2751+
2752+
class TupleLike(Generic[_T_co]):
2753+
def __init__(self, iterable: Boxxy[_T_co], /) -> None:
2754+
raise
2755+
2756+
class Box1(Generic[_T_co]):
2757+
def __init__(self, content: _T_co, /) -> None: ...
2758+
def get_box(self) -> _T_co: raise
2759+
2760+
class Box2(Generic[_T_co]):
2761+
def __init__(self, content: _T_co, /) -> None: ...
2762+
def get_box(self) -> _T_co: raise
2763+
2764+
def get_type(setting_name: str) -> Callable[[Box1], Any] | type[Any]:
2765+
raise
2766+
2767+
def main(key: str):
2768+
existing_value_type = get_type(key)
2769+
if existing_value_type is TupleLike:
2770+
reveal_type(TupleLike) # N: Revealed type is "def [_T_co] (__main__.Boxxy[_T_co`1]) -> __main__.TupleLike[_T_co`1]"
2771+
TupleLike(Box2("str"))
2772+
[builtins fixtures/tuple.pyi]
2773+
2774+
[case testNarrowingCollections]
2775+
# flags: --warn-unreachable
2776+
from typing import cast
2777+
2778+
class X:
2779+
def __init__(self) -> None:
2780+
self.x: dict[str, str] = {}
2781+
self.y: list[str] = []
2782+
2783+
def xxx(self) -> None:
2784+
if self.x == {}:
2785+
reveal_type(self.x) # N: Revealed type is "builtins.dict[builtins.str, builtins.str]"
2786+
self.x["asdf"]
2787+
2788+
if self.x == dict():
2789+
reveal_type(self.x) # N: Revealed type is "builtins.dict[builtins.str, builtins.str]"
2790+
self.x["asdf"]
2791+
2792+
if self.x == cast(dict[int, int], {}):
2793+
reveal_type(self.x) # N: Revealed type is "builtins.dict[builtins.str, builtins.str]"
2794+
self.x["asdf"]
2795+
2796+
def yyy(self) -> None:
2797+
if self.y == []:
2798+
reveal_type(self.y) # N: Revealed type is "builtins.list[builtins.str]"
2799+
self.y[0].does_not_exist # E: "str" has no attribute "does_not_exist"
2800+
2801+
if self.y == list():
2802+
reveal_type(self.y) # N: Revealed type is "builtins.list[builtins.str]"
2803+
self.y[0].does_not_exist # E: "str" has no attribute "does_not_exist"
2804+
2805+
if self.y == cast(list[int], []):
2806+
reveal_type(self.y) # N: Revealed type is "builtins.list[builtins.str]"
2807+
self.y[0].does_not_exist # E: "str" has no attribute "does_not_exist"
2808+
[builtins fixtures/dict.pyi]

0 commit comments

Comments
 (0)