Skip to content

Commit 8f67a54

Browse files
committed
Merge remote-tracking branch 'origin/master' into narrow57
2 parents 48fa581 + e858d5d commit 8f67a54

14 files changed

Lines changed: 253 additions & 66 deletions

mypy/checker.py

Lines changed: 54 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6494,7 +6494,7 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
64946494
and not self.is_literal_enum(expr)
64956495
# CallableType type objects are usually already maximally specific
64966496
and not (
6497-
isinstance(p_expr := get_proper_type(expr_type), CallableType)
6497+
isinstance(p_expr := get_proper_type(expr_type), FunctionLike)
64986498
and p_expr.is_type_obj()
64996499
)
65006500
# This is a little ad hoc, in the absence of intersection types
@@ -6679,7 +6679,8 @@ def narrow_type_by_identity_equality(
66796679
else:
66806680
raise AssertionError
66816681

6682-
partial_type_maps = []
6682+
all_if_maps: list[TypeMap] = []
6683+
all_else_maps: list[TypeMap] = []
66836684

66846685
# For each narrowable index, we see what we can narrow based on each relevant target
66856686
for i in expr_indices:
@@ -6690,10 +6691,8 @@ def narrow_type_by_identity_equality(
66906691
continue
66916692

66926693
expr_type = operand_types[i]
6693-
expanded_expr_type = try_expanding_sum_type_to_union(
6694-
coerce_to_literal(expr_type), None
6695-
)
66966694
expr_enum_keys = ambiguous_enum_equality_keys(expr_type)
6695+
expr_type = try_expanding_sum_type_to_union(coerce_to_literal(expr_type), None)
66976696
for j in expr_indices:
66986697
if i == j:
66996698
continue
@@ -6703,11 +6702,6 @@ def narrow_type_by_identity_equality(
67036702
continue
67046703
target_type = operand_types[j]
67056704
if should_coerce_literals:
6706-
# TODO: doing this prevents narrowing a single-member Enum to literal
6707-
# of its member, because we expand it here and then refuse to add equal
6708-
# types to typemaps. As a result, `x: Foo; x == Foo.A` does not narrow
6709-
# `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
6710-
# See testMatchEnumSingleChoice
67116705
target_type = coerce_to_literal(target_type)
67126706

67136707
if (
@@ -6718,24 +6712,21 @@ def narrow_type_by_identity_equality(
67186712
continue
67196713

67206714
target = TypeRange(target_type, is_upper_bound=False)
6721-
is_value_target = is_target_for_value_narrowing(get_proper_type(target_type))
67226715

6723-
if is_value_target:
6724-
if_map, else_map = conditional_types_to_typemaps(
6725-
operands[i], *conditional_types(expanded_expr_type, [target])
6726-
)
6727-
partial_type_maps.append((if_map, else_map))
6716+
if_map, else_map = conditional_types_to_typemaps(
6717+
operands[i], *conditional_types(expr_type, [target])
6718+
)
6719+
if is_target_for_value_narrowing(get_proper_type(target_type)):
6720+
all_if_maps.append(if_map)
6721+
all_else_maps.append(else_map)
67286722
else:
6729-
if_map, else_map = conditional_types_to_typemaps(
6730-
operands[i], *conditional_types(expr_type, [target])
6731-
)
67326723
# For value targets, it is safe to narrow in the negative case.
67336724
# e.g. if (x: Literal[5] | None) != (y: Literal[5]), we can narrow x to None
67346725
# However, for non-value targets, we cannot do this narrowing,
67356726
# and so we ignore else_map
67366727
# e.g. if (x: str | None) != (y: str), we cannot narrow x to None
6737-
if if_map:
6738-
partial_type_maps.append((if_map, {}))
6728+
if if_map is not None: # TODO: this gate is incorrect and should be removed
6729+
all_if_maps.append(if_map)
67396730

67406731
# Handle narrowing for operands with custom __eq__ methods specially
67416732
# In most cases, we won't be able to do any narrowing
@@ -6757,14 +6748,12 @@ def narrow_type_by_identity_equality(
67576748
if should_coerce_literals:
67586749
target_type = coerce_to_literal(target_type)
67596750
target = TypeRange(target_type, is_upper_bound=False)
6760-
is_value_target = is_target_for_value_narrowing(get_proper_type(target_type))
6761-
6762-
if is_value_target:
6751+
if is_target_for_value_narrowing(get_proper_type(target_type)):
67636752
if_map, else_map = conditional_types_to_typemaps(
67646753
operands[i], *conditional_types(expr_type, [target])
67656754
)
67666755
if else_map:
6767-
partial_type_maps.append(({}, else_map))
6756+
all_else_maps.append(else_map)
67686757
continue
67696758

67706759
# If our operand with custom __eq__ is a union, where only some members of the union
@@ -6778,37 +6767,24 @@ def narrow_type_by_identity_equality(
67786767
# we narrow to in the if_map
67796768
or_if_maps.append({operands[i]: expr_type})
67806769

6770+
expr_type = coerce_to_literal(try_expanding_sum_type_to_union(expr_type, None))
67816771
for j in expr_indices:
67826772
if j in custom_eq_indices:
67836773
continue
67846774
target_type = operand_types[j]
67856775
if should_coerce_literals:
67866776
target_type = coerce_to_literal(target_type)
67876777
target = TypeRange(target_type, is_upper_bound=False)
6788-
is_value_target = is_target_for_value_narrowing(get_proper_type(target_type))
67896778

6790-
if is_value_target:
6791-
expr_type = coerce_to_literal(expr_type)
6792-
expr_type = try_expanding_sum_type_to_union(expr_type, None)
67936779
if_map, else_map = conditional_types_to_typemaps(
67946780
operands[i], *conditional_types(expr_type, [target], default=expr_type)
67956781
)
67966782
or_if_maps.append(if_map)
6797-
if is_value_target:
6783+
if is_target_for_value_narrowing(get_proper_type(target_type)):
67986784
or_else_maps.append(else_map)
67996785

6800-
final_if_map: TypeMap = {}
6801-
final_else_map: TypeMap = {}
6802-
if or_if_maps:
6803-
final_if_map = or_if_maps[0]
6804-
for if_map in or_if_maps[1:]:
6805-
final_if_map = or_conditional_maps(final_if_map, if_map)
6806-
if or_else_maps:
6807-
final_else_map = or_else_maps[0]
6808-
for else_map in or_else_maps[1:]:
6809-
final_else_map = or_conditional_maps(final_else_map, else_map)
6810-
6811-
partial_type_maps.append((final_if_map, final_else_map))
6786+
all_if_maps.append(reduce_or_conditional_type_maps(or_if_maps))
6787+
all_else_maps.append(reduce_or_conditional_type_maps(or_else_maps))
68126788

68136789
# Handle narrowing for comparisons that produce additional narrowing, like
68146790
# `type(x) == T` or `x.__class__ is T`
@@ -6851,13 +6827,16 @@ def narrow_type_by_identity_equality(
68516827
if isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo)
68526828
else False
68536829
)
6854-
if not is_final:
6855-
else_map = {}
6856-
partial_type_maps.append((if_map, else_map))
6830+
all_if_maps.append(if_map)
6831+
if is_final:
6832+
# We can only narrow `type(x) == T` in the negative case if T is final
6833+
all_else_maps.append(else_map)
68576834

68586835
# We will not have duplicate entries in our type maps if we only have two operands,
68596836
# so we can skip running meets on the intersections
6860-
return reduce_conditional_maps(partial_type_maps, use_meet=len(operands) > 2)
6837+
if_map = reduce_and_conditional_type_maps(all_if_maps, use_meet=len(operands) > 2)
6838+
else_map = reduce_or_conditional_type_maps(all_else_maps)
6839+
return if_map, else_map
68616840

68626841
def propagate_up_typemap_info(self, new_types: TypeMap) -> TypeMap:
68636842
"""Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types.
@@ -8529,7 +8508,7 @@ def builtin_item_type(tp: Type) -> Type | None:
85298508
return None
85308509

85318510

8532-
def and_conditional_maps(m1: TypeMap, m2: TypeMap, use_meet: bool = False) -> TypeMap:
8511+
def and_conditional_maps(m1: TypeMap, m2: TypeMap, *, use_meet: bool = False) -> TypeMap:
85338512
"""Calculate what information we can learn from the truth of (e1 and e2)
85348513
in terms of the information that we can learn from the truth of e1 and
85358514
the truth of e2.
@@ -8562,7 +8541,7 @@ def and_conditional_maps(m1: TypeMap, m2: TypeMap, use_meet: bool = False) -> Ty
85628541
return result
85638542

85648543

8565-
def or_conditional_maps(m1: TypeMap, m2: TypeMap, coalesce_any: bool = False) -> TypeMap:
8544+
def or_conditional_maps(m1: TypeMap, m2: TypeMap, *, coalesce_any: bool = False) -> TypeMap:
85668545
"""Calculate what information we can learn from the truth of (e1 or e2)
85678546
in terms of the information that we can learn from the truth of e1 and
85688547
the truth of e2. If coalesce_any is True, consider Any a supertype when
@@ -8627,6 +8606,30 @@ def reduce_conditional_maps(
86278606
return final_if_map, final_else_map
86288607

86298608

8609+
def reduce_or_conditional_type_maps(ms: list[TypeMap]) -> TypeMap:
8610+
"""Reduces a list of TypeMaps into a single TypeMap by "or"-ing them together."""
8611+
if len(ms) == 0:
8612+
return {}
8613+
if len(ms) == 1:
8614+
return ms[0]
8615+
result = ms[0]
8616+
for m in ms[1:]:
8617+
result = or_conditional_maps(result, m)
8618+
return result
8619+
8620+
8621+
def reduce_and_conditional_type_maps(ms: list[TypeMap], *, use_meet: bool) -> TypeMap:
8622+
"""Reduces a list of TypeMaps into a single TypeMap by "and"-ing them together."""
8623+
if len(ms) == 0:
8624+
return {}
8625+
if len(ms) == 1:
8626+
return ms[0]
8627+
result = ms[0]
8628+
for m in ms[1:]:
8629+
result = and_conditional_maps(result, m, use_meet=use_meet)
8630+
return result
8631+
8632+
86308633
BUILTINS_CUSTOM_EQ_CHECKS: Final = {
86318634
"builtins.bytes",
86328635
"builtins.bytearray",
@@ -8681,12 +8684,13 @@ def flatten(t: Expression) -> list[Expression]:
86818684
def flatten_types_if_tuple(t: Type) -> list[Type]:
86828685
"""Flatten a nested sequence of tuples into one list of nodes."""
86838686
t = get_proper_type(t)
8687+
if isinstance(t, UnionType):
8688+
return [b for a in t.items for b in flatten_types(a)]
86848689
if isinstance(t, TupleType):
86858690
return [b for a in t.items for b in flatten_types_if_tuple(a)]
86868691
elif is_named_instance(t, "builtins.tuple"):
86878692
return [t.args[0]]
8688-
else:
8689-
return [t]
8693+
return [t]
86908694

86918695

86928696
def expand_func(defn: FuncItem, map: dict[TypeVarId, Type]) -> FuncItem:

mypy/semanal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2688,7 +2688,7 @@ def configure_tuple_base_class(self, defn: ClassDef, base: TupleType) -> Instanc
26882688
if info.tuple_type and info.tuple_type != base and not has_placeholder(info.tuple_type):
26892689
self.fail("Class has two incompatible bases derived from tuple", defn)
26902690
defn.has_incompatible_baseclass = True
2691-
if info.special_alias and has_placeholder(info.special_alias.target):
2691+
if has_placeholder(base):
26922692
self.process_placeholder(
26932693
None, "tuple base", defn, force_progress=base != info.tuple_type
26942694
)

mypy/semanal_namedtuple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ def build_namedtuple_typeinfo(
527527
info = existing_info or self.api.basic_new_typeinfo(name, fallback, line)
528528
info.is_named_tuple = True
529529
tuple_base = TupleType(types, fallback)
530-
if info.special_alias and has_placeholder(info.special_alias.target):
530+
if has_placeholder(tuple_base):
531531
self.api.process_placeholder(
532532
None, "NamedTuple item", info, force_progress=tuple_base != info.tuple_type
533533
)

mypy/semanal_typeddict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ def build_typeddict_typeinfo(
607607
assert fallback is not None
608608
info = existing_info or self.api.basic_new_typeinfo(name, fallback, line)
609609
typeddict_type = TypedDictType(item_types, required_keys, readonly_keys, fallback)
610-
if info.special_alias and has_placeholder(info.special_alias.target):
610+
if has_placeholder(typeddict_type):
611611
self.api.process_placeholder(
612612
None, "TypedDict item", info, force_progress=typeddict_type != info.typeddict_type
613613
)

mypy/typeops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1006,7 +1006,7 @@ def is_singleton_identity_type(typ: ProperType) -> bool:
10061006
return typ.is_enum_literal() or isinstance(typ.value, bool)
10071007
if isinstance(typ, TypeType) and isinstance(typ.item, Instance) and typ.item.type.is_final:
10081008
return True
1009-
if isinstance(typ, CallableType) and typ.is_type_obj() and typ.type_object().is_final:
1009+
if isinstance(typ, FunctionLike) and typ.is_type_obj() and typ.type_object().is_final:
10101010
return True
10111011
return False
10121012

mypyc/analysis/ircheck.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,10 @@ def expect_non_float(self, op: Op, v: Value) -> None:
241241
if is_float_rprimitive(v.type):
242242
self.fail(op, "Float not expected")
243243

244+
def expect_primitive_type(self, op: Op, v: Value) -> None:
245+
if not isinstance(v.type, RPrimitive):
246+
self.fail(op, f"RPrimitive expected, got {type(v.type).__name__}")
247+
244248
def visit_goto(self, op: Goto) -> None:
245249
self.check_control_op_targets(op)
246250

@@ -397,13 +401,36 @@ def visit_load_global(self, op: LoadGlobal) -> None:
397401
pass
398402

399403
def visit_int_op(self, op: IntOp) -> None:
404+
self.expect_primitive_type(op, op.lhs)
405+
self.expect_primitive_type(op, op.rhs)
400406
self.expect_non_float(op, op.lhs)
401407
self.expect_non_float(op, op.rhs)
408+
left = op.lhs.type
409+
right = op.rhs.type
410+
op_str = op.op_str[op.op]
411+
if (
412+
isinstance(left, RPrimitive)
413+
and isinstance(right, RPrimitive)
414+
and left.is_signed != right.is_signed
415+
and (
416+
op_str in ("+", "-", "*", "/", "%")
417+
or (op_str not in ("<<", ">>") and left.size != right.size)
418+
)
419+
):
420+
self.fail(op, f"Operand types have incompatible signs: {left}, {right}")
402421

403422
def visit_comparison_op(self, op: ComparisonOp) -> None:
404423
self.check_compatibility(op, op.lhs.type, op.rhs.type)
405424
self.expect_non_float(op, op.lhs)
406425
self.expect_non_float(op, op.rhs)
426+
left = op.lhs.type
427+
right = op.rhs.type
428+
if (
429+
isinstance(left, RPrimitive)
430+
and isinstance(right, RPrimitive)
431+
and left.is_signed != right.is_signed
432+
):
433+
self.fail(op, f"Operand types have incompatible signs: {left}, {right}")
407434

408435
def visit_float_op(self, op: FloatOp) -> None:
409436
self.expect_float(op, op.lhs)

mypyc/ir/pprint.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ def format_blocks(
414414
lines.append("L%d:%s" % (block.label, handler_msg))
415415
if block in source_to_error:
416416
for error in source_to_error[block]:
417-
lines.append(f" ERR: {error}")
417+
lines.append(f" ERROR: {error}")
418418
ops = block.ops
419419
if (
420420
isinstance(ops[-1], Goto)
@@ -429,8 +429,11 @@ def format_blocks(
429429
line = " " + op.accept(visitor)
430430
lines.append(line)
431431
if op in source_to_error:
432+
first = len(lines) - 1
433+
# Use emojis to highlight the error
432434
for error in source_to_error[op]:
433-
lines.append(f" ERR: {error}")
435+
lines.append(f" \U0001f446 ERROR: {error}")
436+
lines[first] = " \U0000274c " + lines[first][4:]
434437

435438
if not isinstance(block.ops[-1], (Goto, Branch, Return, Unreachable)):
436439
# Each basic block needs to exit somewhere.

mypyc/irbuild/function.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -852,9 +852,11 @@ def get_func_target(builder: IRBuilder, fdef: FuncDef) -> AssignmentTarget:
852852
If the function was not already defined somewhere, then define it
853853
and add it to the current environment.
854854
"""
855-
if fdef.original_def:
855+
if orig := fdef.original_def:
856+
if isinstance(orig, Decorator):
857+
orig = orig.func
856858
# Get the target associated with the previously defined FuncDef.
857-
return builder.lookup(fdef.original_def)
859+
return builder.lookup(orig)
858860

859861
if builder.fn_info.is_generator or builder.fn_info.add_nested_funcs_to_env:
860862
return builder.lookup(fdef)

mypyc/test-data/run-async.test

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1809,7 +1809,7 @@ from typing import Any, Callable, TypeVar, cast
18091809
F = TypeVar("F", bound=Callable[..., Any])
18101810

18111811

1812-
def mult(x: int) -> Callable[[F], F]:
1812+
def mult_different_wrapper_names(x: int) -> Callable[[F], F]:
18131813
def decorate(fn: F) -> F:
18141814
def get_multiplier() -> int:
18151815
return x
@@ -1829,18 +1829,47 @@ def mult(x: int) -> Callable[[F], F]:
18291829

18301830
return decorate
18311831

1832-
@mult(3)
1832+
def mult_same_wrapper_names(x: int) -> Callable[[F], F]:
1833+
def decorate(fn: F) -> F:
1834+
def get_multiplier() -> int:
1835+
return x
1836+
1837+
if inspect.iscoroutinefunction(fn):
1838+
@functools.wraps(fn)
1839+
async def wrapper(*args, **kwargs) -> Any:
1840+
return get_multiplier() * await fn(*args, **kwargs)
1841+
else:
1842+
@functools.wraps(fn)
1843+
def wrapper(*args, **kwargs) -> Any:
1844+
return get_multiplier() * fn(*args, **kwargs)
1845+
1846+
return cast(F, wrapper)
1847+
1848+
return decorate
1849+
1850+
@mult_different_wrapper_names(3)
18331851
def identity(x: int):
18341852
return x
18351853

1836-
@mult(5)
1854+
@mult_different_wrapper_names(5)
18371855
async def async_identity(x: int):
18381856
return x
18391857

1858+
@mult_same_wrapper_names(2)
1859+
def times_two(x: int):
1860+
return x * 2
1861+
1862+
@mult_same_wrapper_names(4)
1863+
async def async_times_two(x: int):
1864+
return x * 2
1865+
18401866
def test_nested_coroutine_calls_another_nested_function():
18411867
assert identity(1) == 3
18421868
assert asyncio.run(async_identity(2)) == 10
18431869

1870+
assert times_two(3) == 12
1871+
assert asyncio.run(async_times_two(4)) == 32
1872+
18441873
[file asyncio/__init__.pyi]
18451874
from typing import Any, Generator
18461875

0 commit comments

Comments
 (0)