Skip to content

Commit 72427a9

Browse files
committed
[mypyc] Add inline primitives for bytes.__getitem__
1 parent 1c8c009 commit 72427a9

4 files changed

Lines changed: 139 additions & 1 deletion

File tree

mypyc/irbuild/specialize.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,13 @@
9494
join_formatted_strings,
9595
tokenizer_format_call,
9696
)
97-
from mypyc.primitives.bytes_ops import isinstance_bytearray, isinstance_bytes
97+
from mypyc.primitives.bytes_ops import (
98+
bytes_adjust_index_op,
99+
bytes_get_item_unsafe_op,
100+
bytes_range_check_op,
101+
isinstance_bytearray,
102+
isinstance_bytes,
103+
)
98104
from mypyc.primitives.dict_ops import (
99105
dict_items_op,
100106
dict_keys_op,
@@ -1299,3 +1305,54 @@ def translate_bytes_writer_set_item(
12991305
)
13001306

13011307
return builder.none()
1308+
1309+
1310+
@specialize_dunder("__getitem__", bytes_rprimitive)
def translate_bytes_get_item(
    builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
) -> Value | None:
    """Specialize ``bytes.__getitem__`` for a native ``i64`` index.

    Emits inline IR that adjusts a negative index, checks the adjusted index
    against the valid range, raises IndexError when it is out of range, and
    otherwise reads the byte without a further bounds check.

    Args:
        builder: IR builder that the operations are emitted into.
        base_expr: Expression producing the bytes object being indexed.
        args: Arguments of the ``__getitem__`` call; exactly one is expected.
        ctx_expr: Expression whose line number is attached to the emitted ops.

    Returns:
        The value holding the fetched byte, or None when this specialization
        does not apply (wrong argument count or a non-``i64`` index).
    """
    if len(args) != 1:
        return None

    # Decide applicability from the static type *before* calling
    # builder.accept(). accept() emits IR for the expression, so bailing out
    # afterwards would leave those ops in place and any generic lowering that
    # runs after a None result would evaluate the same expressions again.
    if not is_int64_rprimitive(builder.node_type(args[0])):
        return None

    line = ctx_expr.line
    bytes_obj = builder.accept(base_expr)
    raw_index = builder.accept(args[0])

    # Add the length to a negative index; no range checking happens here.
    adjusted_index = builder.primitive_op(bytes_adjust_index_op, [bytes_obj, raw_index], line)

    # Test whether the adjusted index falls inside [0, len).
    in_range = builder.primitive_op(bytes_range_check_op, [bytes_obj, adjusted_index], line)

    valid_block = BasicBlock()
    invalid_block = BasicBlock()
    builder.add_bool_branch(in_range, valid_block, invalid_block)

    # Out-of-range path: raise IndexError; control never continues past it.
    builder.activate_block(invalid_block)
    builder.add(RaiseStandardError(RaiseStandardError.INDEX_ERROR, "index out of range", line))
    builder.add(Unreachable())

    # In-range path: the unchecked read is safe because of the branch above.
    builder.activate_block(valid_block)
    return builder.primitive_op(bytes_get_item_unsafe_op, [bytes_obj, adjusted_index], line)

mypyc/lib-rt/bytes_extra_ops.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,30 @@
22
#define MYPYC_BYTES_EXTRA_OPS_H
33

44
#include <Python.h>
5+
#include <stdint.h>
56
#include "CPy.h"
67

78
// Optimized bytes translate operation
89
PyObject *CPyBytes_Translate(PyObject *bytes, PyObject *table);
910

11+
// Optimized bytes.__getitem__ operations
12+
13+
// If index is negative, convert to non-negative index (no range checking)
14+
static inline int64_t CPyBytes_AdjustIndex(PyObject *obj, int64_t index) {
15+
if (index < 0) {
16+
return index + Py_SIZE(obj);
17+
}
18+
return index;
19+
}
20+
21+
// Check if index is in valid range [0, len)
22+
static inline bool CPyBytes_RangeCheck(PyObject *obj, int64_t index) {
23+
return index >= 0 && index < Py_SIZE(obj);
24+
}
25+
26+
// Get byte at index (no bounds checking) - returns as CPyTagged
27+
static inline CPyTagged CPyBytes_GetItemUnsafe(PyObject *obj, int64_t index) {
28+
return ((CPyTagged)(uint8_t)(PyBytes_AS_STRING(obj))[index]) << 1;
29+
}
30+
1031
#endif

mypyc/primitives/bytes_ops.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,18 @@
1212
c_int_rprimitive,
1313
c_pyssize_t_rprimitive,
1414
dict_rprimitive,
15+
int64_rprimitive,
1516
int_rprimitive,
1617
list_rprimitive,
1718
object_rprimitive,
1819
str_rprimitive,
20+
uint8_rprimitive,
1921
)
2022
from mypyc.primitives.registry import (
2123
ERR_NEG_INT,
2224
binary_op,
2325
custom_op,
26+
custom_primitive_op,
2427
function_op,
2528
load_address_op,
2629
method_op,
@@ -166,3 +169,38 @@
166169
c_function_name="CPyBytes_Ord",
167170
error_kind=ERR_MAGIC,
168171
)
172+
173+
# Optimized bytes.__getitem__ operations
174+
175+
# Index adjustment for bytes: calls CPyBytes_AdjustIndex, which adds the length
# to a negative i64 index and performs no range checking.
bytes_adjust_index_op = custom_primitive_op(
    name="bytes_adjust_index",
    c_function_name="CPyBytes_AdjustIndex",
    arg_types=[bytes_rprimitive, int64_rprimitive],
    return_type=int64_rprimitive,
    error_kind=ERR_NEVER,
    dependencies=[BYTES_EXTRA_OPS],
    experimental=True,
)
185+
186+
# Range check for bytes: calls CPyBytes_RangeCheck to test whether an i64 index
# is a valid position in the bytes object.
bytes_range_check_op = custom_primitive_op(
    name="bytes_range_check",
    c_function_name="CPyBytes_RangeCheck",
    arg_types=[bytes_rprimitive, int64_rprimitive],
    return_type=bool_rprimitive,
    error_kind=ERR_NEVER,
    dependencies=[BYTES_EXTRA_OPS],
    experimental=True,
)
196+
197+
# Unchecked element read for bytes.__getitem__(): calls CPyBytes_GetItemUnsafe,
# which does no bounds checking, so the index must already be validated.
bytes_get_item_unsafe_op = custom_primitive_op(
    name="bytes_get_item_unsafe",
    c_function_name="CPyBytes_GetItemUnsafe",
    arg_types=[bytes_rprimitive, int64_rprimitive],
    return_type=int_rprimitive,
    error_kind=ERR_NEVER,
    dependencies=[BYTES_EXTRA_OPS],
    experimental=True,
)

mypyc/test-data/irbuild-bytes.test

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,28 @@ L0:
113113
r0 = CPyBytes_GetItem(a, i)
114114
return r0
115115

116+
[case testBytesIndex_experimental_64bit]
117+
from mypy_extensions import i64
118+
119+
def f(a: bytes, i: i64) -> int:
120+
return a[i]
121+
[out]
122+
def f(a, i):
123+
a :: bytes
124+
i, r0 :: i64
125+
r1, r2 :: bool
126+
r3 :: int
127+
L0:
128+
r0 = CPyBytes_AdjustIndex(a, i)
129+
r1 = CPyBytes_RangeCheck(a, r0)
130+
if r1 goto L2 else goto L1 :: bool
131+
L1:
132+
r2 = raise IndexError('index out of range')
133+
unreachable
134+
L2:
135+
r3 = CPyBytes_GetItemUnsafe(a, r0)
136+
return r3
137+
116138
[case testBytesConcat]
117139
def f(a: bytes, b: bytes) -> bytes:
118140
return a + b

0 commit comments

Comments
 (0)