Skip to content

Commit 4ac3457

Browse files
committed
Improve narrowing with numeric types
I added tests for this change in #20709 so it's easier to see the diff. Like a few others, this was factored out of #20660 to make that one easier to land. The change in bytes narrowing is also desirable (but unfortunately only applies with --no-strict-bytes). We can figure out something else for that case
1 parent 79e9c78 commit 4ac3457

2 files changed

Lines changed: 38 additions & 12 deletions

File tree

mypy/checker.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6687,7 +6687,8 @@ def narrow_type_by_identity_equality(
66876687
target = TypeRange(target_type, is_upper_bound=False)
66886688

66896689
if_map, else_map = conditional_types_to_typemaps(
6690-
operands[i], *conditional_types(expr_type, [target])
6690+
operands[i],
6691+
*conditional_types(expr_type, [target], consider_promotion_overlap=True),
66916692
)
66926693
if is_target_for_value_narrowing(get_proper_type(target_type)):
66936694
all_if_maps.append(if_map)
@@ -6724,7 +6725,10 @@ def narrow_type_by_identity_equality(
67246725
target = TypeRange(target_type, is_upper_bound=False)
67256726
if is_target_for_value_narrowing(get_proper_type(target_type)):
67266727
if_map, else_map = conditional_types_to_typemaps(
6727-
operands[i], *conditional_types(expr_type, [target])
6728+
operands[i],
6729+
*conditional_types(
6730+
expr_type, [target], consider_promotion_overlap=True
6731+
),
67286732
)
67296733
all_else_maps.append(else_map)
67306734
continue
@@ -6754,7 +6758,10 @@ def narrow_type_by_identity_equality(
67546758
target = TypeRange(target_type, is_upper_bound=False)
67556759

67566760
if_map, else_map = conditional_types_to_typemaps(
6757-
operands[i], *conditional_types(expr_type, [target], default=expr_type)
6761+
operands[i],
6762+
*conditional_types(
6763+
expr_type, [target], default=expr_type, consider_promotion_overlap=True
6764+
),
67586765
)
67596766
or_if_maps.append(if_map)
67606767
if is_target_for_value_narrowing(get_proper_type(target_type)):
@@ -8256,6 +8263,7 @@ def conditional_types(
82568263
default: None = None,
82578264
*,
82588265
consider_runtime_isinstance: bool = True,
8266+
consider_promotion_overlap: bool = False,
82598267
) -> tuple[Type | None, Type | None]: ...
82608268

82618269

@@ -8266,6 +8274,7 @@ def conditional_types(
82668274
default: Type,
82678275
*,
82688276
consider_runtime_isinstance: bool = True,
8277+
consider_promotion_overlap: bool = False,
82698278
) -> tuple[Type, Type]: ...
82708279

82718280

@@ -8275,6 +8284,7 @@ def conditional_types(
82758284
default: Type | None = None,
82768285
*,
82778286
consider_runtime_isinstance: bool = True,
8287+
consider_promotion_overlap: bool = False,
82788288
) -> tuple[Type | None, Type | None]:
82798289
"""Takes in the current type and a proposed type of an expression.
82808290
@@ -8319,6 +8329,7 @@ def conditional_types(
83198329
proposed_type_ranges,
83208330
default=union_item,
83218331
consider_runtime_isinstance=consider_runtime_isinstance,
8332+
consider_promotion_overlap=consider_promotion_overlap,
83228333
)
83238334
yes_items.append(yes_type)
83248335
no_items.append(no_type)
@@ -8354,9 +8365,17 @@ def conditional_types(
83548365
consider_runtime_isinstance=consider_runtime_isinstance,
83558366
)
83568367
return default, remainder
8357-
if not is_overlapping_types(current_type, proposed_type, ignore_promotions=True):
8368+
if not is_overlapping_types(
8369+
current_type, proposed_type, ignore_promotions=not consider_promotion_overlap
8370+
):
83588371
# Expression is never of any type in proposed_type_ranges
83598372
return UninhabitedType(), default
8373+
if consider_promotion_overlap and not is_overlapping_types(
8374+
current_type, proposed_type, ignore_promotions=True
8375+
):
8376+
# We set consider_promotion_overlap when comparing equality. This is one of the places
8377+
# at runtime where subtyping with promotion does happen to match runtime semantics
8378+
return default, default
83608379
# we can only restrict when the type is precise, not bounded
83618380
proposed_precise_type = UnionType.make_union(
83628381
[type_range.item for type_range in proposed_type_ranges if not type_range.is_upper_bound]

test-data/unit/check-narrowing.test

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3567,30 +3567,37 @@ def convert_type(target_type: Type[TargetType]) -> TargetType:
35673567
from __future__ import annotations
35683568
from typing import Literal
35693569

3570-
# TODO: the behaviour on some of these test cases is incorrect
35713570
def f1(number: float, i: int):
35723571
if number == i:
35733572
reveal_type(number) # N: Revealed type is "builtins.float"
35743573
reveal_type(i) # N: Revealed type is "builtins.int"
35753574

35763575
def f2(number: float, five: Literal[5]):
35773576
if number == five:
3578-
reveal_type(number) # E: Statement is unreachable
3579-
reveal_type(five)
3577+
reveal_type(number) # N: Revealed type is "builtins.float"
3578+
reveal_type(five) # N: Revealed type is "Literal[5]"
35803579

35813580
def f3(number: float | int, five: Literal[5]):
35823581
if number == five:
3583-
reveal_type(number) # N: Revealed type is "Literal[5]"
3582+
reveal_type(number) # N: Revealed type is "builtins.float | Literal[5]"
3583+
reveal_type(five) # N: Revealed type is "Literal[5]"
3584+
3585+
def f8(number: float | Literal[5], five: Literal[5]):
3586+
if number == five:
3587+
reveal_type(number) # N: Revealed type is "builtins.float | Literal[5]"
3588+
reveal_type(five) # N: Revealed type is "Literal[5]"
3589+
else:
3590+
reveal_type(number) # N: Revealed type is "builtins.float"
35843591
reveal_type(five) # N: Revealed type is "Literal[5]"
35853592

35863593
def f4(number: float | None, i: int):
35873594
if number == i:
3588-
reveal_type(number) # N: Revealed type is "builtins.float | None"
3595+
reveal_type(number) # N: Revealed type is "builtins.float"
35893596
reveal_type(i) # N: Revealed type is "builtins.int"
35903597

35913598
def f5(number: float | int, i: int):
35923599
if number == i:
3593-
reveal_type(number) # N: Revealed type is "builtins.int"
3600+
reveal_type(number) # N: Revealed type is "builtins.float | builtins.int"
35943601
reveal_type(i) # N: Revealed type is "builtins.int"
35953602

35963603
def f6(number: float | complex, i: int):
@@ -3604,7 +3611,7 @@ class Custom:
36043611
def f7(number: float, x: Custom | int):
36053612
if number == x:
36063613
reveal_type(number) # N: Revealed type is "builtins.float"
3607-
reveal_type(x) # N: Revealed type is "__main__.Custom"
3614+
reveal_type(x) # N: Revealed type is "__main__.Custom | builtins.int"
36083615
[builtins fixtures/primitives.pyi]
36093616

36103617
[case testNarrowingAnyNegativeIntersection-xfail]
@@ -3694,7 +3701,7 @@ def main(
36943701
reveal_type(v_memoryview) # N: Revealed type is "builtins.memoryview"
36953702

36963703
if v_all == v_bytes:
3697-
reveal_type(v_all) # N: Revealed type is "builtins.bytes"
3704+
reveal_type(v_all) # N: Revealed type is "builtins.bytes | builtins.bytearray | builtins.memoryview"
36983705
reveal_type(v_bytes) # N: Revealed type is "builtins.bytes"
36993706
if v_all == v_bytearray:
37003707
reveal_type(v_all) # N: Revealed type is "builtins.bytes | builtins.bytearray | builtins.memoryview"

0 commit comments

Comments
 (0)