Skip to content

Commit 659b8a7

Browse files
[mypyc] feat: extend get_expr_length for enumerate, map, zip, range, list, tuple, sorted, and reversed CallExpr [3/4] (#19927)
This PR is pretty simple: I just extended `get_expr_length` to work for a few more obvious cases:

- `builtins.enumerate`
- `builtins.map`
- `builtins.zip`
- `builtins.range`
- `builtins.list`
- `builtins.tuple`
- `builtins.sorted`
- `builtins.reversed`

This PR is ready for review. Are you going to want tests for all of these? I didn't want to spend time on them until I know for sure.

All of the `get_expr_length` PRs are entirely independent and can be reviewed/merged in any order.

---------

Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com>
1 parent 7fee02c commit 659b8a7

File tree

3 files changed

+167
-1
lines changed

3 files changed

+167
-1
lines changed

mypyc/irbuild/for_helpers.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
StarExpr,
2626
TupleExpr,
2727
TypeAlias,
28+
Var,
2829
)
2930
from mypy.types import LiteralType, TupleType, get_proper_type, get_proper_types
3031
from mypyc.ir.ops import (
@@ -1235,10 +1236,50 @@ def get_expr_length(builder: IRBuilder, expr: Expression) -> int | None:
12351236
return other + sum(stars) # type: ignore [arg-type]
12361237
elif isinstance(expr, StarExpr):
12371238
return get_expr_length(builder, expr.expr)
1239+
elif (
1240+
isinstance(expr, RefExpr)
1241+
and isinstance(expr.node, Var)
1242+
and expr.node.is_final
1243+
and isinstance(expr.node.final_value, str)
1244+
and expr.node.has_explicit_value
1245+
):
1246+
return len(expr.node.final_value)
1247+
elif (
1248+
isinstance(expr, CallExpr)
1249+
and isinstance(callee := expr.callee, NameExpr)
1250+
and all(kind == ARG_POS for kind in expr.arg_kinds)
1251+
):
1252+
fullname = callee.fullname
1253+
if (
1254+
fullname
1255+
in (
1256+
"builtins.list",
1257+
"builtins.tuple",
1258+
"builtins.enumerate",
1259+
"builtins.sorted",
1260+
"builtins.reversed",
1261+
)
1262+
and len(expr.args) == 1
1263+
):
1264+
return get_expr_length(builder, expr.args[0])
1265+
elif fullname == "builtins.map" and len(expr.args) == 2:
1266+
return get_expr_length(builder, expr.args[1])
1267+
elif fullname == "builtins.zip" and expr.args:
1268+
arg_lengths = [get_expr_length(builder, arg) for arg in expr.args]
1269+
if all(arg is not None for arg in arg_lengths):
1270+
return min(arg_lengths) # type: ignore [type-var]
1271+
elif fullname == "builtins.range" and len(expr.args) <= 3:
1272+
folded_args = [constant_fold_expr(builder, arg) for arg in expr.args]
1273+
if all(isinstance(arg, int) for arg in folded_args):
1274+
try:
1275+
return len(range(*cast(list[int], folded_args)))
1276+
except ValueError: # prevent crash if invalid args
1277+
pass
1278+
12381279
# TODO: extend this, passing length of listcomp and genexp should have worthwhile
12391280
# performance boost and can be (sometimes) figured out pretty easily. set and dict
12401281
# comps *can* be done as well but will need special logic to consider the possibility
1241-
# of key conflicts. Range, enumerate, zip are all simple logic.
1282+
# of key conflicts.
12421283

12431284
# we might still be able to get the length directly from the type
12441285
rtype = builder.node_type(expr)

mypyc/test-data/fixtures/ir.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,11 @@ def __iter__(self) -> Iterator[int]: pass
314314
def __len__(self) -> int: pass
315315
def __next__(self) -> int: pass
316316

317+
class map(Iterator[_S]):
318+
def __init__(self, func: Callable[[_T], _S], iterable: Iterable[_T]) -> None: pass
319+
def __iter__(self) -> Self: pass
320+
def __next__(self) -> _S: pass
321+
317322
class property:
318323
def __init__(self, fget: Optional[Callable[[Any], Any]] = ...,
319324
fset: Optional[Callable[[Any, Any], None]] = ...,

mypyc/test-data/irbuild-tuple.test

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,126 @@ L4:
898898
a = r1
899899
return 1
900900

901+
[case testTupleBuiltFromLengthCheckable]
902+
from typing import Tuple
903+
904+
def f(val: bool) -> bool:
905+
return not val
906+
907+
def test() -> None:
908+
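# this tuple is created from a very complex genexp but we can still compute the length and preallocate the tuple (NOTE(review): the expected IR below builds a list with PyList_New(0)/PyList_Append and converts it with PyList_AsTuple, so no preallocation is visible here -- confirm whether this comment is still accurate)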
909+
a = tuple(
910+
x
911+
for x
912+
in zip(
913+
map(str, range(5)),
914+
enumerate(sorted(reversed(tuple("abcdefg"))))
915+
)
916+
)
917+
[out]
918+
def f(val):
919+
val, r0 :: bool
920+
L0:
921+
r0 = val ^ 1
922+
return r0
923+
def test():
924+
r0 :: list
925+
r1, r2, r3 :: object
926+
r4 :: object[1]
927+
r5 :: object_ptr
928+
r6 :: object
929+
r7 :: range
930+
r8 :: object
931+
r9 :: str
932+
r10 :: object
933+
r11 :: object[2]
934+
r12 :: object_ptr
935+
r13 :: object
936+
r14 :: str
937+
r15 :: tuple
938+
r16 :: object
939+
r17 :: str
940+
r18 :: object
941+
r19 :: object[1]
942+
r20 :: object_ptr
943+
r21 :: object
944+
r22 :: list
945+
r23 :: object
946+
r24 :: str
947+
r25 :: object
948+
r26 :: object[1]
949+
r27 :: object_ptr
950+
r28, r29 :: object
951+
r30 :: str
952+
r31 :: object
953+
r32 :: object[2]
954+
r33 :: object_ptr
955+
r34, r35, r36 :: object
956+
r37, x :: tuple[str, tuple[int, str]]
957+
r38 :: object
958+
r39 :: i32
959+
r40, r41 :: bit
960+
r42, a :: tuple
961+
L0:
962+
r0 = PyList_New(0)
963+
r1 = load_address PyUnicode_Type
964+
r2 = load_address PyRange_Type
965+
r3 = object 5
966+
r4 = [r3]
967+
r5 = load_address r4
968+
r6 = PyObject_Vectorcall(r2, r5, 1, 0)
969+
keep_alive r3
970+
r7 = cast(range, r6)
971+
r8 = builtins :: module
972+
r9 = 'map'
973+
r10 = CPyObject_GetAttr(r8, r9)
974+
r11 = [r1, r7]
975+
r12 = load_address r11
976+
r13 = PyObject_Vectorcall(r10, r12, 2, 0)
977+
keep_alive r1, r7
978+
r14 = 'abcdefg'
979+
r15 = PySequence_Tuple(r14)
980+
r16 = builtins :: module
981+
r17 = 'reversed'
982+
r18 = CPyObject_GetAttr(r16, r17)
983+
r19 = [r15]
984+
r20 = load_address r19
985+
r21 = PyObject_Vectorcall(r18, r20, 1, 0)
986+
keep_alive r15
987+
r22 = CPySequence_Sort(r21)
988+
r23 = builtins :: module
989+
r24 = 'enumerate'
990+
r25 = CPyObject_GetAttr(r23, r24)
991+
r26 = [r22]
992+
r27 = load_address r26
993+
r28 = PyObject_Vectorcall(r25, r27, 1, 0)
994+
keep_alive r22
995+
r29 = builtins :: module
996+
r30 = 'zip'
997+
r31 = CPyObject_GetAttr(r29, r30)
998+
r32 = [r13, r28]
999+
r33 = load_address r32
1000+
r34 = PyObject_Vectorcall(r31, r33, 2, 0)
1001+
keep_alive r13, r28
1002+
r35 = PyObject_GetIter(r34)
1003+
L1:
1004+
r36 = PyIter_Next(r35)
1005+
if is_error(r36) goto L4 else goto L2
1006+
L2:
1007+
r37 = unbox(tuple[str, tuple[int, str]], r36)
1008+
x = r37
1009+
r38 = box(tuple[str, tuple[int, str]], x)
1010+
r39 = PyList_Append(r0, r38)
1011+
r40 = r39 >= 0 :: signed
1012+
L3:
1013+
goto L1
1014+
L4:
1015+
r41 = CPy_NoErrOccurred()
1016+
L5:
1017+
r42 = PyList_AsTuple(r0)
1018+
a = r42
1019+
return 1
1020+
9011021
[case testTupleBuiltFromStars]
9021022
from typing import Final
9031023

0 commit comments

Comments
 (0)