Skip to content

Commit 55401f2

Browse files
authored
Better match narrowing for type objects (#20872)
This is the more general fix I alluded to in #20367 (comment) The tests I added do not yet have perfect behaviour, e.g. the type object unions should be allowed and should not have the branch marked as unreachable. I will open a PR for that next (it is easy, but the diff is thrashier so better for a separate PR) Fixes #18470 , fixes comment in #17133 (comment)
1 parent c09b174 commit 55401f2

File tree

2 files changed

+90
-6
lines changed

2 files changed

+90
-6
lines changed

mypy/checkpattern.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
560560
return self.early_non_match()
561561
elif isinstance(p_typ, FunctionLike) and p_typ.is_type_obj():
562562
typ = fill_typevars_with_any(p_typ.type_object())
563+
type_range = TypeRange(typ, is_upper_bound=False)
563564
elif (
564565
isinstance(type_info, Var)
565566
and type_info.type is not None
@@ -569,8 +570,10 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
569570
fallback = self.chk.named_type("builtins.function")
570571
any_type = AnyType(TypeOfAny.unannotated)
571572
typ = callable_with_ellipsis(any_type, ret_type=any_type, fallback=fallback)
572-
elif isinstance(p_typ, TypeType) and isinstance(p_typ.item, NoneType):
573+
type_range = TypeRange(typ, is_upper_bound=False)
574+
elif isinstance(p_typ, TypeType):
573575
typ = p_typ.item
576+
type_range = TypeRange(p_typ.item, is_upper_bound=True)
574577
elif not isinstance(p_typ, AnyType):
575578
self.msg.fail(
576579
message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(
@@ -579,9 +582,11 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
579582
o,
580583
)
581584
return self.early_non_match()
585+
else:
586+
type_range = get_type_range(typ)
582587

583588
new_type, rest_type = self.chk.conditional_types_with_intersection(
584-
current_type, [get_type_range(typ)], o, default=current_type
589+
current_type, [type_range], o, default=current_type
585590
)
586591
if is_uninhabited(new_type):
587592
return self.early_non_match()

test-data/unit/check-python310.test

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,8 @@ match x:
344344
case [str()]:
345345
pass
346346

347-
[case testMatchSequencePatternWithInvalidClassPattern]
347+
[case testMatchSequencePatternWithTypeObject]
348+
# flags: --strict-equality --warn-unreachable
348349
class Example:
349350
__match_args__ = ("value",)
350351
def __init__(self, value: str) -> None:
@@ -353,10 +354,10 @@ class Example:
353354
SubClass: type[Example]
354355

355356
match [SubClass("a"), SubClass("b")]:
356-
case [SubClass(value), *rest]: # E: Expected type in class pattern; found "type[__main__.Example]"
357-
reveal_type(value) # E: Cannot determine type of "value" \
358-
# N: Revealed type is "Any"
357+
case [SubClass(value), *rest]:
358+
reveal_type(value) # N: Revealed type is "builtins.str"
359359
reveal_type(rest) # N: Revealed type is "builtins.list[__main__.Example]"
360+
360361
[builtins fixtures/tuple.pyi]
361362

362363
# Narrowing union-based values via a literal pattern on an indexed/attribute subject
@@ -1257,6 +1258,84 @@ reveal_type(y) # N: Revealed type is "builtins.int"
12571258
reveal_type(z) # N: Revealed type is "builtins.int"
12581259
[builtins fixtures/dict-full.pyi]
12591260

1261+
[case testMatchClassPatternTypeObject]
1262+
# flags: --strict-equality --warn-unreachable
1263+
class Example:
1264+
__match_args__ = ("value",)
1265+
def __init__(self, value: str) -> None:
1266+
self.value = value
1267+
1268+
def f1(subclass: type[Example]) -> None:
1269+
match subclass("a"):
1270+
case Example(value):
1271+
reveal_type(value) # N: Revealed type is "builtins.str"
1272+
case anything:
1273+
reveal_type(anything) # E: Statement is unreachable
1274+
1275+
def f2(subclass: type[Example]) -> None:
1276+
match Example("a"):
1277+
case subclass(value):
1278+
reveal_type(value) # N: Revealed type is "builtins.str"
1279+
case anything:
1280+
reveal_type(anything) # N: Revealed type is "__main__.Example"
1281+
1282+
def f3(subclass: type[Example]) -> None:
1283+
match subclass("a"):
1284+
case subclass(value):
1285+
reveal_type(value) # N: Revealed type is "builtins.str"
1286+
case anything:
1287+
reveal_type(anything) # N: Revealed type is "__main__.Example"
1288+
1289+
class Example2:
1290+
__match_args__ = ("value",)
1291+
def __init__(self, value: str) -> None:
1292+
self.value = value
1293+
1294+
def f4(T: type[Example | Example2]) -> None:
1295+
match T("a"):
1296+
case Example(value):
1297+
reveal_type(value) # N: Revealed type is "builtins.str"
1298+
case anything:
1299+
reveal_type(anything) # N: Revealed type is "__main__.Example2"
1300+
1301+
def f5(T: type[Example | Example2]) -> None:
1302+
match Example("a"):
1303+
case T(value): # E: Expected type in class pattern; found "type[__main__.Example] | type[__main__.Example2]"
1304+
reveal_type(value) # E: Statement is unreachable
1305+
case anything:
1306+
reveal_type(anything) # N: Revealed type is "__main__.Example"
1307+
1308+
def f6(T: type[Example | Example2]) -> None:
1309+
match T("a"):
1310+
case T(value): # E: Expected type in class pattern; found "type[__main__.Example] | type[__main__.Example2]"
1311+
reveal_type(value) # E: Statement is unreachable
1312+
case anything:
1313+
reveal_type(anything) # N: Revealed type is "__main__.Example | __main__.Example2"
1314+
1315+
def f7(m: object, t: type[object]) -> None:
1316+
match m:
1317+
case t():
1318+
reveal_type(m) # N: Revealed type is "builtins.object"
1319+
case _:
1320+
reveal_type(m) # N: Revealed type is "builtins.object"
1321+
[builtins fixtures/tuple.pyi]
1322+
1323+
[case testMatchClassPatternTypeObjectGeneric]
1324+
# flags: --strict-equality --warn-unreachable
1325+
from typing import TypeVar
1326+
T = TypeVar("T")
1327+
1328+
def print_test(m: object, typ: type[T]) -> T:
1329+
match m:
1330+
case typ():
1331+
reveal_type(m) # N: Revealed type is "T`-1"
1332+
return m
1333+
case str():
1334+
reveal_type(m) # N: Revealed type is "builtins.str"
1335+
case _:
1336+
reveal_type(m) # N: Revealed type is "builtins.object"
1337+
raise
1338+
12601339
[case testMatchNonFinalMatchArgs]
12611340
class A:
12621341
__match_args__ = ("a", "b")

0 commit comments

Comments
 (0)