Skip to content

Commit 59257a2

Browse files
authored
[mypyc] Speed up ord(str[n]) by inlining (#20578)
The approach is similar to #20552, which added a fast inlined implementation of bytes get item. However, we do it for `ord(str[n])` instead of just `str[n]`, since the latter produces a reference-counted string of length 1, which is often too slow for performance-critical code. Later on the idea is to add a string writer class that supports quickly appending unicode code points represented as integers. This makes a micro-benchmark that finds the highest unicode code point in a string about 18x faster.
1 parent ae16cff commit 59257a2

9 files changed

Lines changed: 269 additions & 3 deletions

File tree

mypyc/doc/str_operations.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Operators
1919
---------
2020

2121
* Concatenation (``s1 + s2``)
22-
* Indexing (``s[n]``)
22+
* Indexing (``s[n]``; also ``ord(s[n])``, which avoids the temporary length-1 string)
2323
* Slicing (``s[n:m]``, ``s[n:]``, ``s[:m]``)
2424
* Comparisons (``==``, ``!=``)
2525
* Augmented assignment (``s1 += s2``)

mypyc/ir/deps.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,4 @@ def get_header(self) -> str:
5252
BYTES_EXTRA_OPS: Final = SourceDep("bytes_extra_ops.c")
5353
BYTES_WRITER_EXTRA_OPS: Final = SourceDep("byteswriter_extra_ops.c")
5454
BYTEARRAY_EXTRA_OPS: Final = SourceDep("bytearray_extra_ops.c")
55+
STR_EXTRA_OPS: Final = SourceDep("str_extra_ops.c")

mypyc/irbuild/specialize.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
DictExpr,
2525
Expression,
2626
GeneratorExpr,
27+
IndexExpr,
2728
IntExpr,
2829
ListExpr,
2930
MemberExpr,
@@ -72,6 +73,8 @@
7273
is_int_rprimitive,
7374
is_list_rprimitive,
7475
is_sequence_rprimitive,
76+
is_str_rprimitive,
77+
is_tagged,
7578
is_uint8_rprimitive,
7679
list_rprimitive,
7780
object_rprimitive,
@@ -125,9 +128,12 @@
125128
bytes_decode_latin1_strict,
126129
bytes_decode_utf8_strict,
127130
isinstance_str,
131+
str_adjust_index_op,
128132
str_encode_ascii_strict,
129133
str_encode_latin1_strict,
130134
str_encode_utf8_strict,
135+
str_get_item_unsafe_as_int_op,
136+
str_range_check_op,
131137
)
132138
from mypyc.primitives.tuple_ops import isinstance_tuple, new_tuple_set_item_op
133139

@@ -1126,9 +1132,33 @@ def translate_float(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Valu
11261132
def translate_ord(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
11271133
if len(expr.args) != 1 or expr.arg_kinds[0] != ARG_POS:
11281134
return None
1129-
arg = constant_fold_expr(builder, expr.args[0])
1135+
arg_expr = expr.args[0]
1136+
arg = constant_fold_expr(builder, arg_expr)
11301137
if isinstance(arg, (str, bytes)) and len(arg) == 1:
11311138
return Integer(ord(arg))
1139+
1140+
# Check for ord(s[i]) where s is str and i is an integer
1141+
if isinstance(arg_expr, IndexExpr):
1142+
# Check base type
1143+
base_type = builder.node_type(arg_expr.base)
1144+
if is_str_rprimitive(base_type):
1145+
# Check index type
1146+
index_expr = arg_expr.index
1147+
index_type = builder.node_type(index_expr)
1148+
if is_tagged(index_type) or is_fixed_width_rtype(index_type):
1149+
# This is ord(s[i]) where s is str and i is an integer.
1150+
# Generate specialized inline code using the helper.
1151+
result = translate_getitem_with_bounds_check(
1152+
builder,
1153+
arg_expr.base,
1154+
[arg_expr.index],
1155+
expr,
1156+
str_adjust_index_op,
1157+
str_range_check_op,
1158+
str_get_item_unsafe_as_int_op,
1159+
)
1160+
return result
1161+
11321162
return None
11331163

11341164

mypyc/lib-rt/str_extra_ops.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#include "str_extra_ops.h"
2+
3+
// All str extra ops are inline functions in str_extra_ops.h
4+
// This file exists to satisfy the SourceDep requirements

mypyc/lib-rt/str_extra_ops.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#ifndef MYPYC_STR_EXTRA_OPS_H
2+
#define MYPYC_STR_EXTRA_OPS_H
3+
4+
#include <Python.h>
5+
#include <stdint.h>
6+
#include "CPy.h"
7+
8+
// Optimized str indexing for ord(s[i])
9+
10+
// If index is negative, convert to non-negative index (no range checking)
11+
static inline int64_t CPyStr_AdjustIndex(PyObject *obj, int64_t index) {
12+
if (index < 0) {
13+
return index + PyUnicode_GET_LENGTH(obj);
14+
}
15+
return index;
16+
}
17+
18+
// Check if index is in valid range [0, len)
19+
static inline bool CPyStr_RangeCheck(PyObject *obj, int64_t index) {
20+
return index >= 0 && index < PyUnicode_GET_LENGTH(obj);
21+
}
22+
23+
// Get character at index as int (ord value) - no bounds checking, returns as CPyTagged
24+
static inline CPyTagged CPyStr_GetItemUnsafeAsInt(PyObject *obj, int64_t index) {
25+
int kind = PyUnicode_KIND(obj);
26+
return PyUnicode_READ(kind, PyUnicode_DATA(obj), index) << 1;
27+
}
28+
29+
#endif

mypyc/primitives/str_ops.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
from mypyc.ir.deps import STR_EXTRA_OPS
56
from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER
67
from mypyc.ir.rtypes import (
78
RType,
@@ -10,10 +11,12 @@
1011
bytes_rprimitive,
1112
c_int_rprimitive,
1213
c_pyssize_t_rprimitive,
14+
int64_rprimitive,
1315
int_rprimitive,
1416
list_rprimitive,
1517
object_rprimitive,
1618
pointer_rprimitive,
19+
short_int_rprimitive,
1720
str_rprimitive,
1821
tuple_rprimitive,
1922
)
@@ -507,3 +510,35 @@
507510
c_function_name="CPyStr_Ord",
508511
error_kind=ERR_MAGIC,
509512
)
513+
514+
# Optimized str indexing for ord(s[i])
515+
516+
# str index adjustment - convert negative index to positive
517+
str_adjust_index_op = custom_primitive_op(
518+
name="str_adjust_index",
519+
arg_types=[str_rprimitive, int64_rprimitive],
520+
return_type=int64_rprimitive,
521+
c_function_name="CPyStr_AdjustIndex",
522+
error_kind=ERR_NEVER,
523+
dependencies=[STR_EXTRA_OPS],
524+
)
525+
526+
# str range check - check if index is in valid range
527+
str_range_check_op = custom_primitive_op(
528+
name="str_range_check",
529+
arg_types=[str_rprimitive, int64_rprimitive],
530+
return_type=bool_rprimitive,
531+
c_function_name="CPyStr_RangeCheck",
532+
error_kind=ERR_NEVER,
533+
dependencies=[STR_EXTRA_OPS],
534+
)
535+
536+
# str.__getitem__() as int - get character at index as int (ord value) - no bounds checking
537+
str_get_item_unsafe_as_int_op = custom_primitive_op(
538+
name="str_get_item_unsafe_as_int",
539+
arg_types=[str_rprimitive, int64_rprimitive],
540+
return_type=short_int_rprimitive,
541+
c_function_name="CPyStr_GetItemUnsafeAsInt",
542+
error_kind=ERR_NEVER,
543+
dependencies=[STR_EXTRA_OPS],
544+
)

mypyc/test-data/irbuild-str.test

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,65 @@ L0:
507507
r6 = unbox(int, r5)
508508
return r6
509509

510+
[case testOrdOfStrIndex_64bit]
511+
from mypy_extensions import i64
512+
def ord_str_index(s: str, i: int) -> int:
513+
return ord(s[i])
514+
def ord_str_index_i64(s: str, i: i64) -> int:
515+
return ord(s[i])
516+
[typing fixtures/typing-full.pyi]
517+
[out]
518+
def ord_str_index(s, i):
519+
s :: str
520+
i :: int
521+
r0 :: native_int
522+
r1 :: bit
523+
r2, r3 :: i64
524+
r4 :: ptr
525+
r5 :: c_ptr
526+
r6, r7 :: i64
527+
r8, r9 :: bool
528+
r10 :: short_int
529+
L0:
530+
r0 = i & 1
531+
r1 = r0 == 0
532+
if r1 goto L1 else goto L2 :: bool
533+
L1:
534+
r2 = i >> 1
535+
r3 = r2
536+
goto L3
537+
L2:
538+
r4 = i ^ 1
539+
r5 = r4
540+
r6 = CPyLong_AsInt64(r5)
541+
r3 = r6
542+
keep_alive i
543+
L3:
544+
r7 = CPyStr_AdjustIndex(s, r3)
545+
r8 = CPyStr_RangeCheck(s, r7)
546+
if r8 goto L5 else goto L4 :: bool
547+
L4:
548+
r9 = raise IndexError('index out of range')
549+
unreachable
550+
L5:
551+
r10 = CPyStr_GetItemUnsafeAsInt(s, r7)
552+
return r10
553+
def ord_str_index_i64(s, i):
554+
s :: str
555+
i, r0 :: i64
556+
r1, r2 :: bool
557+
r3 :: short_int
558+
L0:
559+
r0 = CPyStr_AdjustIndex(s, i)
560+
r1 = CPyStr_RangeCheck(s, r0)
561+
if r1 goto L2 else goto L1 :: bool
562+
L1:
563+
r2 = raise IndexError('index out of range')
564+
unreachable
565+
L2:
566+
r3 = CPyStr_GetItemUnsafeAsInt(s, r0)
567+
return r3
568+
510569
[case testStrip]
511570
from typing import NewType, Union
512571
NewStr = NewType("NewStr", str)

mypyc/test-data/run-strings.test

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,7 @@ def test_chr() -> None:
808808

809809
[case testOrd]
810810
from testutil import assertRaises
811+
from mypy_extensions import i64, i32, i16
811812

812813
def test_ord() -> None:
813814
assert ord(' ') == 32
@@ -834,6 +835,77 @@ def test_ord() -> None:
834835
with assertRaises(TypeError):
835836
ord('')
836837

838+
def test_ord_str_index() -> None:
839+
# ASCII
840+
s1 = "hello"
841+
assert ord(s1[0 + int()]) == 104 # 'h'
842+
assert ord(s1[1 + int()]) == 101 # 'e'
843+
assert ord(s1[4 + int()]) == 111 # 'o'
844+
assert ord(s1[-1 + int()]) == 111 # 'o'
845+
assert ord(s1[-5 + int()]) == 104 # 'h'
846+
847+
# Latin-1 (8 bits)
848+
s2 = "café"
849+
assert ord(s2[0 + int()]) == 99 # 'c'
850+
assert ord(s2[3 + int()]) == 233 # 'é' (U+00E9)
851+
assert ord(s2[-1 + int()]) == 233
852+
853+
# 16-bit unicode
854+
s3 = "你好" # Chinese
855+
assert ord(s3[0 + int()]) == 20320 # '你' (U+4F60)
856+
assert ord(s3[1 + int()]) == 22909 # '好' (U+597D)
857+
assert ord(s3[-1 + int()]) == 22909
858+
assert ord(s3[-2 + int()]) == 20320
859+
860+
# 4-byte unicode
861+
s5 = "a😀b" # Emoji between ASCII chars
862+
assert ord(s5[0 + int()]) == 97 # 'a'
863+
assert ord(s5[1 + int()]) == 128512 # '😀' (U+1F600)
864+
assert ord(s5[2 + int()]) == 98 # 'b'
865+
assert ord(s5[-1 + int()]) == 98
866+
assert ord(s5[-2 + int()]) == 128512
867+
assert ord(s5[-3 + int()]) == 97
868+
869+
with assertRaises(IndexError, "index out of range"):
870+
ord(s1[5 + int()])
871+
with assertRaises(IndexError, "index out of range"):
872+
ord(s1[100 + int()])
873+
with assertRaises(IndexError, "index out of range"):
874+
ord(s1[-6 + int()])
875+
with assertRaises(IndexError, "index out of range"):
876+
ord(s1[-100 + int()])
877+
878+
s_empty = ""
879+
with assertRaises(IndexError, "index out of range"):
880+
ord(s_empty[0 + int()])
881+
with assertRaises(IndexError, "index out of range"):
882+
ord(s_empty[-1 + int()])
883+
884+
def test_ord_str_index_i64() -> None:
885+
s = "hello"
886+
887+
idx_i64: i64 = 2 + int()
888+
assert ord(s[idx_i64]) == 108 # 'l'
889+
890+
idx_i64_neg: i64 = -1 + int()
891+
assert ord(s[idx_i64_neg]) == 111 # 'o'
892+
893+
idx_overflow: i64 = 10 + int()
894+
with assertRaises(IndexError, "index out of range"):
895+
ord(s[idx_overflow])
896+
897+
idx_underflow: i64 = -10 + int()
898+
with assertRaises(IndexError, "index out of range"):
899+
ord(s[idx_underflow])
900+
901+
def test_ord_str_index_unicode_mix() -> None:
902+
# Mix of 1-byte, 2-byte, 3-byte, and 4-byte characters
903+
s = "a\u00e9\u4f60😀" # 'a', 'é', '你', '😀'
904+
assert ord(s[0 + int()]) == 97 # 1-byte
905+
assert ord(s[1 + int()]) == 233 # 2-byte
906+
assert ord(s[2 + int()]) == 20320 # 3-byte
907+
assert ord(s[3 + int()]) == 128512 # 4-byte
908+
837909
[case testDecode]
838910
from testutil import assertRaises
839911

mypyc/test/test_cheader.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,23 @@
88
import unittest
99

1010
from mypyc.ir.deps import SourceDep
11-
from mypyc.primitives import registry
11+
from mypyc.ir.ops import PrimitiveDescription
12+
from mypyc.primitives import (
13+
bytearray_ops,
14+
bytes_ops,
15+
dict_ops,
16+
exc_ops,
17+
float_ops,
18+
generic_ops,
19+
int_ops,
20+
list_ops,
21+
misc_ops,
22+
registry,
23+
set_ops,
24+
str_ops,
25+
tuple_ops,
26+
weakref_ops,
27+
)
1228

1329

1430
class TestHeaderInclusion(unittest.TestCase):
@@ -35,6 +51,26 @@ def check_name(name: str) -> None:
3551
for ops in values:
3652
all_ops.extend(ops)
3753

54+
for module in [
55+
bytes_ops,
56+
str_ops,
57+
dict_ops,
58+
list_ops,
59+
bytearray_ops,
60+
generic_ops,
61+
int_ops,
62+
misc_ops,
63+
tuple_ops,
64+
exc_ops,
65+
float_ops,
66+
set_ops,
67+
weakref_ops,
68+
]:
69+
for name in dir(module):
70+
val = getattr(module, name, None)
71+
if isinstance(val, PrimitiveDescription):
72+
all_ops.append(val)
73+
3874
# Find additional headers via extra C source file dependencies.
3975
for op in all_ops:
4076
if op.dependencies:

0 commit comments

Comments
 (0)