Skip to content

Commit d2f9c63

Browse files
committed
Refactor __getitem__ specializations
1 parent 72427a9 commit d2f9c63

1 file changed

Lines changed: 63 additions & 57 deletions

File tree

mypyc/irbuild/specialize.py

Lines changed: 63 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,30 +1212,56 @@ def translate_object_setattr(builder: IRBuilder, expr: CallExpr, callee: RefExpr
12121212
return builder.call_c(generic_setattr, [self_reg, name_reg, value], expr.line)
12131213

12141214

1215-
@specialize_dunder("__getitem__", bytes_writer_rprimitive)
1216-
def translate_bytes_writer_get_item(
1217-
builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
1215+
def translate_getitem_with_bounds_check(
1216+
builder: IRBuilder,
1217+
base_expr: Expression,
1218+
args: list[Expression],
1219+
ctx_expr: Expression,
1220+
adjust_index_op: PrimitiveDescription,
1221+
range_check_op: PrimitiveDescription,
1222+
get_item_unsafe_op: PrimitiveDescription,
1223+
require_i64_index: bool = False,
12181224
) -> Value | None:
1219-
"""Optimized BytesWriter.__getitem__ implementation with bounds checking."""
1225+
"""Shared helper for optimized __getitem__ with bounds checking.
1226+
1227+
This implements the common pattern of:
1228+
1. Adjusting negative indices
1229+
2. Checking if index is in valid range
1230+
3. Raising IndexError if out of range
1231+
4. Getting the item if in range
1232+
1233+
Args:
1234+
builder: The IR builder
1235+
base_expr: The base object expression
1236+
args: The arguments to __getitem__ (should be length 1)
1237+
ctx_expr: The context expression for line numbers
1238+
adjust_index_op: Primitive op to adjust negative indices
1239+
range_check_op: Primitive op to check if index is in valid range
1240+
get_item_unsafe_op: Primitive op to get item (no bounds checking)
1241+
require_i64_index: If True, only use optimization for i64 indices
1242+
1243+
Returns:
1244+
The result value, or None if optimization doesn't apply
1245+
"""
12201246
# Check that we have exactly one argument
12211247
if len(args) != 1:
12221248
return None
12231249

1224-
# Get the BytesWriter object
1250+
# Get the object
12251251
obj = builder.accept(base_expr)
12261252

12271253
# Get the index argument
12281254
index = builder.accept(args[0])
12291255

1256+
# If required, check that index is i64 (for experimental mode)
1257+
if require_i64_index and not is_int64_rprimitive(index.type):
1258+
return None
1259+
12301260
# Adjust the index (handle negative indices)
1231-
adjusted_index = builder.primitive_op(
1232-
bytes_writer_adjust_index_op, [obj, index], ctx_expr.line
1233-
)
1261+
adjusted_index = builder.primitive_op(adjust_index_op, [obj, index], ctx_expr.line)
12341262

12351263
# Check if the adjusted index is in valid range
1236-
range_check = builder.primitive_op(
1237-
bytes_writer_range_check_op, [obj, adjusted_index], ctx_expr.line
1238-
)
1264+
range_check = builder.primitive_op(range_check_op, [obj, adjusted_index], ctx_expr.line)
12391265

12401266
# Create blocks for branching
12411267
valid_block = BasicBlock()
@@ -1252,13 +1278,27 @@ def translate_bytes_writer_get_item(
12521278

12531279
# Handle valid index - get the item
12541280
builder.activate_block(valid_block)
1255-
result = builder.primitive_op(
1256-
bytes_writer_get_item_unsafe_op, [obj, adjusted_index], ctx_expr.line
1257-
)
1281+
result = builder.primitive_op(get_item_unsafe_op, [obj, adjusted_index], ctx_expr.line)
12581282

12591283
return result
12601284

12611285

1286+
@specialize_dunder("__getitem__", bytes_writer_rprimitive)
1287+
def translate_bytes_writer_get_item(
1288+
builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
1289+
) -> Value | None:
1290+
"""Optimized BytesWriter.__getitem__ implementation with bounds checking."""
1291+
return translate_getitem_with_bounds_check(
1292+
builder,
1293+
base_expr,
1294+
args,
1295+
ctx_expr,
1296+
bytes_writer_adjust_index_op,
1297+
bytes_writer_range_check_op,
1298+
bytes_writer_get_item_unsafe_op,
1299+
)
1300+
1301+
12621302
@specialize_dunder("__setitem__", bytes_writer_rprimitive)
12631303
def translate_bytes_writer_set_item(
12641304
builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
@@ -1312,47 +1352,13 @@ def translate_bytes_get_item(
13121352
builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
13131353
) -> Value | None:
13141354
"""Optimized bytes.__getitem__ implementation with bounds checking."""
1315-
# Check that we have exactly one argument
1316-
if len(args) != 1:
1317-
return None
1318-
1319-
# Get the bytes object
1320-
obj = builder.accept(base_expr)
1321-
1322-
# Get the index argument
1323-
index = builder.accept(args[0])
1324-
1325-
# Only use the optimized version for i64 index (requires experimental mode)
1326-
if not is_int64_rprimitive(index.type):
1327-
return None
1328-
1329-
# Adjust the index (handle negative indices)
1330-
adjusted_index = builder.primitive_op(
1331-
bytes_adjust_index_op, [obj, index], ctx_expr.line
1332-
)
1333-
1334-
# Check if the adjusted index is in valid range
1335-
range_check = builder.primitive_op(
1336-
bytes_range_check_op, [obj, adjusted_index], ctx_expr.line
1337-
)
1338-
1339-
# Create blocks for branching
1340-
valid_block = BasicBlock()
1341-
invalid_block = BasicBlock()
1342-
1343-
builder.add_bool_branch(range_check, valid_block, invalid_block)
1344-
1345-
# Handle invalid index - raise IndexError
1346-
builder.activate_block(invalid_block)
1347-
builder.add(
1348-
RaiseStandardError(RaiseStandardError.INDEX_ERROR, "index out of range", ctx_expr.line)
1349-
)
1350-
builder.add(Unreachable())
1351-
1352-
# Handle valid index - get the item
1353-
builder.activate_block(valid_block)
1354-
result = builder.primitive_op(
1355-
bytes_get_item_unsafe_op, [obj, adjusted_index], ctx_expr.line
1355+
return translate_getitem_with_bounds_check(
1356+
builder,
1357+
base_expr,
1358+
args,
1359+
ctx_expr,
1360+
bytes_adjust_index_op,
1361+
bytes_range_check_op,
1362+
bytes_get_item_unsafe_op,
1363+
require_i64_index=True,
13561364
)
1357-
1358-
return result

0 commit comments

Comments
 (0)