Skip to content

Commit 3ef962d

Browse files
committed
Add support for literal addition
1 parent 8ed16d1 commit 3ef962d

5 files changed

Lines changed: 150 additions & 11 deletions

File tree

mypy/checkexpr.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3520,6 +3520,9 @@ def visit_op_expr(self, e: OpExpr) -> Type:
35203520
items=proper_left_type.items + [UnpackType(mapped)]
35213521
)
35223522

3523+
if e.op == "+" and (result := self.literal_expression_addition(e, left_type)):
3524+
return result
3525+
35233526
use_reverse: UseReverse = USE_REVERSE_DEFAULT
35243527
if e.op == "|":
35253528
if is_named_instance(proper_left_type, "builtins.dict"):
@@ -3580,6 +3583,57 @@ def visit_op_expr(self, e: OpExpr) -> Type:
35803583
else:
35813584
raise RuntimeError(f"Unknown operator {e.op}")
35823585

3586+
def literal_value_from_expr(
3587+
self, expr: Expression, typ: Type | None = None
3588+
) -> tuple[list[str | int], str] | None:
3589+
if isinstance(expr, StrExpr):
3590+
return [expr.value], "builtins.str"
3591+
if isinstance(expr, IntExpr):
3592+
return [expr.value], "builtins.int"
3593+
if isinstance(expr, BytesExpr):
3594+
return [expr.value], "builtins.bytes"
3595+
3596+
typ = typ or self.accept(expr)
3597+
ptype = get_proper_type(typ)
3598+
3599+
if isinstance(ptype, LiteralType) and not isinstance(ptype.value, (bool, float)):
3600+
return [ptype.value], ptype.fallback.type.fullname
3601+
3602+
if isinstance(ptype, UnionType):
3603+
fallback: str | None = None
3604+
values: list[str | int] = []
3605+
for item in ptype.items:
3606+
pitem = get_proper_type(item)
3607+
if not isinstance(pitem, LiteralType) or isinstance(pitem.value, (float, bool)):
3608+
break
3609+
if fallback is None:
3610+
fallback = pitem.fallback.type.fullname
3611+
if fallback != pitem.fallback.type.fullname:
3612+
break
3613+
values.append(pitem.value)
3614+
else:
3615+
assert fallback is not None
3616+
return values, fallback
3617+
return None
3618+
3619+
def literal_expression_addition(self, e: OpExpr, left_type: Type) -> Type | None:
3620+
"""Check if literal values can be combined with addition."""
3621+
assert e.op == "+"
3622+
if not (lvalue := self.literal_value_from_expr(e.left, left_type)):
3623+
return None
3624+
if not (rvalue := self.literal_value_from_expr(e.right)) or lvalue[1] != rvalue[1]:
3625+
return None
3626+
3627+
values: list[int | str] = sorted(
3628+
{
3629+
val[0] + val[1] # type: ignore[operator]
3630+
for val in itertools.product(lvalue[0], rvalue[0])
3631+
}
3632+
)
3633+
if len(values) == 1:
3634+
return LiteralType(values[0], self.named_type(lvalue[1]))
3635+
return UnionType([LiteralType(val, self.named_type(lvalue[1])) for val in values])
3636+
35833637
def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
35843638
"""Type check a comparison expression.
35853639

test-data/unit/check-literal.test

Lines changed: 92 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,19 +1407,19 @@ c: Literal[4]
14071407
d: Literal['foo']
14081408
e: str
14091409

1410-
reveal_type(a + a) # N: Revealed type is "builtins.int"
1410+
reveal_type(a + a) # N: Revealed type is "Literal[6]"
14111411
reveal_type(a + b) # N: Revealed type is "builtins.int"
14121412
reveal_type(b + a) # N: Revealed type is "builtins.int"
1413-
reveal_type(a + 1) # N: Revealed type is "builtins.int"
1414-
reveal_type(1 + a) # N: Revealed type is "builtins.int"
1415-
reveal_type(a + c) # N: Revealed type is "builtins.int"
1416-
reveal_type(c + a) # N: Revealed type is "builtins.int"
1413+
reveal_type(a + 1) # N: Revealed type is "Literal[4]"
1414+
reveal_type(1 + a) # N: Revealed type is "Literal[4]"
1415+
reveal_type(a + c) # N: Revealed type is "Literal[7]"
1416+
reveal_type(c + a) # N: Revealed type is "Literal[7]"
14171417

1418-
reveal_type(d + d) # N: Revealed type is "builtins.str"
1418+
reveal_type(d + d) # N: Revealed type is "Literal['foofoo']"
14191419
reveal_type(d + e) # N: Revealed type is "builtins.str"
14201420
reveal_type(e + d) # N: Revealed type is "builtins.str"
1421-
reveal_type(d + 'foo') # N: Revealed type is "builtins.str"
1422-
reveal_type('foo' + d) # N: Revealed type is "builtins.str"
1421+
reveal_type(d + 'bar') # N: Revealed type is "Literal['foobar']"
1422+
reveal_type('bar' + d) # N: Revealed type is "Literal['barfoo']"
14231423

14241424
reveal_type(a.__add__(b)) # N: Revealed type is "builtins.int"
14251425
reveal_type(b.__add__(a)) # N: Revealed type is "builtins.int"
@@ -2976,3 +2976,87 @@ x: Type[Literal[1]] # E: Type[...] can't contain "Literal[...]"
29762976
y: Type[Union[Literal[1], Literal[2]]] # E: Type[...] can't contain "Union[Literal[...], Literal[...]]"
29772977
z: Type[Literal[1, 2]] # E: Type[...] can't contain "Union[Literal[...], Literal[...]]"
29782978
[builtins fixtures/tuple.pyi]
2979+
2980+
[case testLiteralAddition]
2981+
from typing import Union
2982+
from typing_extensions import Literal
2983+
2984+
str_a: Literal["a"]
2985+
str_b: Literal["b"]
2986+
str_union_1: Literal["a", "b"]
2987+
str_union_2: Literal["c", "d"]
2988+
s: str
2989+
int_1: Literal[1]
2990+
int_2: Literal[2]
2991+
int_union_1: Literal[1, 2]
2992+
int_union_2: Literal[3, 4]
2993+
i: int
2994+
bytes_a: Literal[b"a"]
2995+
bytes_b: Literal[b"b"]
2996+
bytes_union_1: Literal[b"a", b"b"]
2997+
bytes_union_2: Literal[b"c", b"d"]
2998+
b: bytes
2999+
3000+
misc_union: Literal["a", 1]
3001+
3002+
reveal_type(str_a + str_b) # N: Revealed type is "Literal['ab']"
3003+
reveal_type(str_a + "b") # N: Revealed type is "Literal['ab']"
3004+
reveal_type("a" + str_b) # N: Revealed type is "Literal['ab']"
3005+
reveal_type(str_union_1 + "b") # N: Revealed type is "Union[Literal['ab'], Literal['bb']]"
3006+
reveal_type(str_union_1 + str_b) # N: Revealed type is "Union[Literal['ab'], Literal['bb']]"
3007+
reveal_type("a" + str_union_1) # N: Revealed type is "Union[Literal['aa'], Literal['ab']]"
3008+
reveal_type(str_a + str_union_1) # N: Revealed type is "Union[Literal['aa'], Literal['ab']]"
3009+
reveal_type(str_union_1 + str_union_2) # N: Revealed type is "Union[Literal['ac'], Literal['ad'], Literal['bc'], Literal['bd']]"
3010+
reveal_type(str_a + s) # N: Revealed type is "builtins.str"
3011+
reveal_type(s + str_a) # N: Revealed type is "builtins.str"
3012+
reveal_type(str_union_1 + s) # N: Revealed type is "builtins.str"
3013+
reveal_type(s + str_union_1) # N: Revealed type is "builtins.str"
3014+
3015+
reveal_type(int_1 + int_2) # N: Revealed type is "Literal[3]"
3016+
reveal_type(int_1 + 1) # N: Revealed type is "Literal[2]"
3017+
reveal_type(1 + int_1) # N: Revealed type is "Literal[2]"
3018+
reveal_type(int_union_1 + 1) # N: Revealed type is "Union[Literal[2], Literal[3]]"
3019+
reveal_type(int_union_1 + int_1) # N: Revealed type is "Union[Literal[2], Literal[3]]"
3020+
reveal_type(1 + int_union_1) # N: Revealed type is "Union[Literal[2], Literal[3]]"
3021+
reveal_type(int_1 + int_union_1) # N: Revealed type is "Union[Literal[2], Literal[3]]"
3022+
reveal_type(int_union_1 + int_union_2) # N: Revealed type is "Union[Literal[4], Literal[5], Literal[6]]"
3023+
reveal_type(int_1 + i) # N: Revealed type is "builtins.int"
3024+
reveal_type(i + int_1) # N: Revealed type is "builtins.int"
3025+
reveal_type(int_union_1 + i) # N: Revealed type is "builtins.int"
3026+
reveal_type(i + int_union_1) # N: Revealed type is "builtins.int"
3027+
3028+
reveal_type(bytes_a + bytes_b) # N: Revealed type is "Literal[b'ab']"
3029+
reveal_type(bytes_a + b"b") # N: Revealed type is "Literal[b'ab']"
3030+
reveal_type(b"a" + bytes_b) # N: Revealed type is "Literal[b'ab']"
3031+
reveal_type(bytes_union_1 + b"b") # N: Revealed type is "Union[Literal[b'ab'], Literal[b'bb']]"
3032+
reveal_type(bytes_union_1 + bytes_b) # N: Revealed type is "Union[Literal[b'ab'], Literal[b'bb']]"
3033+
reveal_type(b"a" + bytes_union_1) # N: Revealed type is "Union[Literal[b'aa'], Literal[b'ab']]"
3034+
reveal_type(bytes_a + bytes_union_1) # N: Revealed type is "Union[Literal[b'aa'], Literal[b'ab']]"
3035+
reveal_type(bytes_union_1 + bytes_union_2) # N: Revealed type is "Union[Literal[b'ac'], Literal[b'ad'], Literal[b'bc'], Literal[b'bd']]"
3036+
reveal_type(bytes_a + b) # N: Revealed type is "builtins.bytes"
3037+
reveal_type(b + bytes_a) # N: Revealed type is "builtins.bytes"
3038+
reveal_type(bytes_union_1 + b) # N: Revealed type is "builtins.bytes"
3039+
reveal_type(b + bytes_union_1) # N: Revealed type is "builtins.bytes"
3040+
3041+
reveal_type(misc_union + "a") # N: Revealed type is "Union[builtins.str, builtins.int]" \
3042+
# E: Unsupported operand types for + ("Literal[1]" and "str") \
3043+
# N: Left operand is of type "Literal['a', 1]"
3044+
reveal_type("a" + misc_union) # E: Unsupported operand types for + ("str" and "Literal[1]") \
3045+
# N: Right operand is of type "Literal['a', 1]" \
3046+
# N: Revealed type is "builtins.str"
3047+
[builtins fixtures/primitives.pyi]
3048+
3049+
[case testLiteralAdditionTypedDict]
3050+
from typing import TypedDict
3051+
from typing_extensions import Literal
3052+
3053+
class LookupDict(TypedDict):
3054+
top_var: str
3055+
bottom_var: str
3056+
var: str
3057+
3058+
def func(d: LookupDict, pos: Literal["top_", "bottom_", ""]) -> str:
3059+
return d[pos + "var"]
3060+
3061+
[builtins fixtures/dict.pyi]
3062+
[typing fixtures/typing-typeddict.pyi]

test-data/unit/cmdline.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -912,8 +912,8 @@ test_between(1 + 1)
912912
tabs.py:2: error: Incompatible return value type (got "None", expected "str")
913913
return None
914914
^~~~
915-
tabs.py:4: error: Argument 1 to "test_between" has incompatible type "int";
916-
expected "str"
915+
tabs.py:4: error: Argument 1 to "test_between" has incompatible type
916+
"Literal[2]"; expected "str"
917917
test_between(1 + 1)
918918
^~~~~~~~~~~~
919919

test-data/unit/fixtures/primitives.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class str(Sequence[str]):
3434
def __getitem__(self, item: int) -> str: pass
3535
def format(self, *args: object, **kwargs: object) -> str: pass
3636
class bytes(Sequence[int]):
37+
def __add__(self, x: bytes) -> bytes: pass
3738
def __iter__(self) -> Iterator[int]: pass
3839
def __contains__(self, other: object) -> bool: pass
3940
def __getitem__(self, item: int) -> int: pass

test-data/unit/typexport-basic.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class str: pass
142142
class list: pass
143143
class dict: pass
144144
[out]
145-
OpExpr(3) : builtins.int
145+
OpExpr(3) : Literal[3]
146146
OpExpr(4) : builtins.float
147147
OpExpr(5) : builtins.float
148148
OpExpr(6) : builtins.float

0 commit comments

Comments
 (0)