Skip to content

Commit d87132f

Browse files
authored
[mypyc] Fix coercion from short tagged int to fixed-width int (#20587)
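Coercing a short tagged int to a fixed-width type narrower than the tagged representation (i32 on a 64-bit platform, i16, u8) was previously unimplemented and hit an `assert False`. The test cases use `ord(s[i])`, since this produces a short int. I used LLM assistance but reviewed all the output manually.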
1 parent d308c0b commit d87132f

File tree

3 files changed

+272
-39
lines changed

3 files changed

+272
-39
lines changed

mypyc/irbuild/ll_builder.py

Lines changed: 109 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,6 @@ def coerce_int_to_fixed_width(self, src: Value, target_type: RType, line: int) -
459459
assert is_fixed_width_rtype(target_type), target_type
460460
assert isinstance(target_type, RPrimitive), target_type
461461

462-
res = Register(target_type)
463-
464462
fast, slow, end = BasicBlock(), BasicBlock(), BasicBlock()
465463

466464
check = self.check_tagged_short_int(src, line)
@@ -471,37 +469,20 @@ def coerce_int_to_fixed_width(self, src: Value, target_type: RType, line: int) -
471469
size = target_type.size
472470
if size < int_rprimitive.size:
473471
# Add a range check when the target type is smaller than the source type
474-
fast2, fast3 = BasicBlock(), BasicBlock()
475-
upper_bound = 1 << (size * 8 - 1)
476-
if not target_type.is_signed:
477-
upper_bound *= 2
478-
check2 = self.add(ComparisonOp(src, Integer(upper_bound, src.type), ComparisonOp.SLT))
479-
self.add(Branch(check2, fast2, slow, Branch.BOOL))
480-
self.activate_block(fast2)
481-
if target_type.is_signed:
482-
lower_bound = -upper_bound
483-
else:
484-
lower_bound = 0
485-
check3 = self.add(ComparisonOp(src, Integer(lower_bound, src.type), ComparisonOp.SGE))
486-
self.add(Branch(check3, fast3, slow, Branch.BOOL))
487-
self.activate_block(fast3)
488-
tmp = self.int_op(
489-
c_pyssize_t_rprimitive,
490-
src,
491-
Integer(1, c_pyssize_t_rprimitive),
492-
IntOp.RIGHT_SHIFT,
493-
line,
472+
# Use helper method to generate range checking and conversion
473+
res = self.coerce_tagged_to_fixed_width_with_range_check(
474+
src, target_type, int_rprimitive.size, slow, end, line
494475
)
495-
tmp = self.add(Truncate(tmp, target_type))
496476
else:
477+
# No range check needed when target is same size or larger
497478
if size > int_rprimitive.size:
498479
tmp = self.add(Extend(src, target_type, signed=True))
499480
else:
500481
tmp = src
501482
tmp = self.int_op(target_type, tmp, Integer(1, target_type), IntOp.RIGHT_SHIFT, line)
502-
503-
self.add(Assign(res, tmp))
504-
self.goto(end)
483+
res = Register(target_type)
484+
self.add(Assign(res, tmp))
485+
self.goto(end)
505486

506487
self.activate_block(slow)
507488
if is_int64_rprimitive(target_type) or (
@@ -521,31 +502,122 @@ def coerce_int_to_fixed_width(self, src: Value, target_type: RType, line: int) -
521502
self.add(Assign(res, tmp))
522503
self.add(KeepAlive([src]))
523504
self.goto(end)
524-
elif is_int32_rprimitive(target_type):
505+
else:
525506
# Slow path just always generates an OverflowError
507+
self.emit_fixed_width_overflow_error(target_type, line)
508+
509+
self.activate_block(end)
510+
return res
511+
512+
def coerce_tagged_to_fixed_width_with_range_check(
513+
self,
514+
src: Value,
515+
target_type: RType,
516+
source_size: int,
517+
overflow_block: BasicBlock,
518+
success_block: BasicBlock,
519+
line: int,
520+
) -> Register:
521+
"""Helper to convert a tagged value to a smaller fixed-width type with range checking.
522+
523+
This method generates IR for converting a tagged integer (like short_int or the fast
524+
path of int) to a smaller fixed-width type (i32, i16, uint8) with overflow detection.
525+
526+
The method performs range checks and branches to overflow_block on failure, or
527+
success_block on success (after assigning the result to a register).
528+
529+
Args:
530+
src: Tagged source value known to be a short int (value stored shifted left by one, tag bit clear)
531+
target_type: Target fixed-width type (must be smaller than source_size)
532+
source_size: Size in bytes of the source type
533+
overflow_block: Block to branch to on overflow
534+
success_block: Block to goto after successful conversion
535+
line: Line number
536+
537+
Returns:
538+
Result register containing the converted value (valid in success_block)
539+
"""
540+
assert is_fixed_width_rtype(target_type), target_type
541+
assert isinstance(target_type, RPrimitive), target_type
542+
size = target_type.size
543+
assert size < source_size, (target_type, size, source_size)
544+
545+
res = Register(target_type)
546+
in_range, in_range2 = BasicBlock(), BasicBlock()
547+
548+
# Calculate bounds for the target type as plain (untagged) values; wrapping them in Integer(..., src.type) below encodes them in the tagged representation
549+
upper_bound = 1 << (size * 8 - 1)
550+
if not target_type.is_signed:
551+
upper_bound *= 2
552+
553+
# Check if value < upper_bound
554+
check_upper = self.add(ComparisonOp(src, Integer(upper_bound, src.type), ComparisonOp.SLT))
555+
self.add(Branch(check_upper, in_range, overflow_block, Branch.BOOL))
556+
557+
self.activate_block(in_range)
558+
559+
# Check if value >= lower_bound
560+
if target_type.is_signed:
561+
lower_bound = -upper_bound
562+
else:
563+
lower_bound = 0
564+
check_lower = self.add(ComparisonOp(src, Integer(lower_bound, src.type), ComparisonOp.SGE))
565+
self.add(Branch(check_lower, in_range2, overflow_block, Branch.BOOL))
566+
567+
self.activate_block(in_range2)
568+
569+
# Value is in range - shift right to remove tag, then truncate
570+
shifted = self.int_op(
571+
c_pyssize_t_rprimitive,
572+
src,
573+
Integer(1, c_pyssize_t_rprimitive),
574+
IntOp.RIGHT_SHIFT,
575+
line,
576+
)
577+
tmp = self.add(Truncate(shifted, target_type))
578+
self.add(Assign(res, tmp))
579+
self.goto(success_block)
580+
581+
return res
582+
583+
def emit_fixed_width_overflow_error(self, target_type: RType, line: int) -> None:
584+
"""Emit overflow error for fixed-width type conversion."""
585+
if is_int32_rprimitive(target_type):
526586
self.call_c(int32_overflow, [], line)
527-
self.add(Unreachable())
528587
elif is_int16_rprimitive(target_type):
529-
# Slow path just always generates an OverflowError
530588
self.call_c(int16_overflow, [], line)
531-
self.add(Unreachable())
532589
elif is_uint8_rprimitive(target_type):
533-
# Slow path just always generates an OverflowError
534590
self.call_c(uint8_overflow, [], line)
535-
self.add(Unreachable())
536591
else:
537592
assert False, target_type
538-
539-
self.activate_block(end)
540-
return res
593+
self.add(Unreachable())
541594

542595
def coerce_short_int_to_fixed_width(self, src: Value, target_type: RType, line: int) -> Value:
596+
# short_int (CPyTagged) is guaranteed to be a tagged value, never a pointer,
597+
# so we don't need the fast/slow path split like coerce_int_to_fixed_width.
598+
# However, we still need range checking when target type is smaller than source.
599+
assert is_fixed_width_rtype(target_type), target_type
600+
assert isinstance(target_type, RPrimitive), target_type
601+
543602
if is_int64_rprimitive(target_type) or (
544603
PLATFORM_SIZE == 4 and is_int32_rprimitive(target_type)
545604
):
605+
# No range check needed - target is same size or larger than source
546606
return self.int_op(target_type, src, Integer(1, target_type), IntOp.RIGHT_SHIFT, line)
547-
# TODO: i32 on 64-bit platform
548-
assert False, (src.type, target_type, PLATFORM_SIZE)
607+
608+
# Target is smaller than source - need range checking
609+
# Use helper method to generate range checking and conversion
610+
overflow, end = BasicBlock(), BasicBlock()
611+
res = self.coerce_tagged_to_fixed_width_with_range_check(
612+
src, target_type, short_int_rprimitive.size, overflow, end, line
613+
)
614+
615+
# Handle overflow case
616+
self.activate_block(overflow)
617+
self.emit_fixed_width_overflow_error(target_type, line)
618+
619+
self.activate_block(end)
620+
return res
549621

550622
def coerce_fixed_width_to_int(self, src: Value, line: int) -> Value:
551623
if (

mypyc/test-data/irbuild-str.test

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,11 +508,19 @@ L0:
508508
return r6
509509

510510
[case testOrdOfStrIndex_64bit]
511-
from mypy_extensions import i64
511+
from mypy_extensions import i64, i32, i16, u8
512512
def ord_str_index(s: str, i: int) -> int:
513513
return ord(s[i])
514514
def ord_str_index_i64(s: str, i: i64) -> int:
515515
return ord(s[i])
516+
def ord_str_index_to_i32(s: str, i: i64) -> i32:
517+
return ord(s[i])
518+
def ord_str_index_to_i16(s: str, i: i64) -> i16:
519+
return ord(s[i])
520+
def ord_str_index_to_u8(s: str, i: i64) -> u8:
521+
return ord(s[i])
522+
def ord_str_index_to_i64(s: str, i: i64) -> i64:
523+
return ord(s[i])
516524
[typing fixtures/typing-full.pyi]
517525
[out]
518526
def ord_str_index(s, i):
@@ -565,6 +573,119 @@ L1:
565573
L2:
566574
r3 = CPyStr_GetItemUnsafeAsInt(s, r0)
567575
return r3
576+
def ord_str_index_to_i32(s, i):
577+
s :: str
578+
i, r0 :: i64
579+
r1, r2 :: bool
580+
r3 :: short_int
581+
r4, r5 :: bit
582+
r6 :: native_int
583+
r7, r8 :: i32
584+
L0:
585+
r0 = CPyStr_AdjustIndex(s, i)
586+
r1 = CPyStr_RangeCheck(s, r0)
587+
if r1 goto L2 else goto L1 :: bool
588+
L1:
589+
r2 = raise IndexError('index out of range')
590+
unreachable
591+
L2:
592+
r3 = CPyStr_GetItemUnsafeAsInt(s, r0)
593+
r4 = r3 < 4294967296 :: signed
594+
if r4 goto L3 else goto L5 :: bool
595+
L3:
596+
r5 = r3 >= -4294967296 :: signed
597+
if r5 goto L4 else goto L5 :: bool
598+
L4:
599+
r6 = r3 >> 1
600+
r7 = truncate r6: native_int to i32
601+
r8 = r7
602+
goto L6
603+
L5:
604+
CPyInt32_Overflow()
605+
unreachable
606+
L6:
607+
return r8
608+
def ord_str_index_to_i16(s, i):
609+
s :: str
610+
i, r0 :: i64
611+
r1, r2 :: bool
612+
r3 :: short_int
613+
r4, r5 :: bit
614+
r6 :: native_int
615+
r7, r8 :: i16
616+
L0:
617+
r0 = CPyStr_AdjustIndex(s, i)
618+
r1 = CPyStr_RangeCheck(s, r0)
619+
if r1 goto L2 else goto L1 :: bool
620+
L1:
621+
r2 = raise IndexError('index out of range')
622+
unreachable
623+
L2:
624+
r3 = CPyStr_GetItemUnsafeAsInt(s, r0)
625+
r4 = r3 < 65536 :: signed
626+
if r4 goto L3 else goto L5 :: bool
627+
L3:
628+
r5 = r3 >= -65536 :: signed
629+
if r5 goto L4 else goto L5 :: bool
630+
L4:
631+
r6 = r3 >> 1
632+
r7 = truncate r6: native_int to i16
633+
r8 = r7
634+
goto L6
635+
L5:
636+
CPyInt16_Overflow()
637+
unreachable
638+
L6:
639+
return r8
640+
def ord_str_index_to_u8(s, i):
641+
s :: str
642+
i, r0 :: i64
643+
r1, r2 :: bool
644+
r3 :: short_int
645+
r4, r5 :: bit
646+
r6 :: native_int
647+
r7, r8 :: u8
648+
L0:
649+
r0 = CPyStr_AdjustIndex(s, i)
650+
r1 = CPyStr_RangeCheck(s, r0)
651+
if r1 goto L2 else goto L1 :: bool
652+
L1:
653+
r2 = raise IndexError('index out of range')
654+
unreachable
655+
L2:
656+
r3 = CPyStr_GetItemUnsafeAsInt(s, r0)
657+
r4 = r3 < 512 :: signed
658+
if r4 goto L3 else goto L5 :: bool
659+
L3:
660+
r5 = r3 >= 0 :: signed
661+
if r5 goto L4 else goto L5 :: bool
662+
L4:
663+
r6 = r3 >> 1
664+
r7 = truncate r6: native_int to u8
665+
r8 = r7
666+
goto L6
667+
L5:
668+
CPyUInt8_Overflow()
669+
unreachable
670+
L6:
671+
return r8
672+
def ord_str_index_to_i64(s, i):
673+
s :: str
674+
i, r0 :: i64
675+
r1, r2 :: bool
676+
r3 :: short_int
677+
r4 :: i64
678+
L0:
679+
r0 = CPyStr_AdjustIndex(s, i)
680+
r1 = CPyStr_RangeCheck(s, r0)
681+
if r1 goto L2 else goto L1 :: bool
682+
L1:
683+
r2 = raise IndexError('index out of range')
684+
unreachable
685+
L2:
686+
r3 = CPyStr_GetItemUnsafeAsInt(s, r0)
687+
r4 = r3 >> 1
688+
return r4
568689

569690
[case testStrip]
570691
from typing import NewType, Union

mypyc/test-data/run-strings.test

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,7 @@ def test_chr() -> None:
808808

809809
[case testOrd]
810810
from testutil import assertRaises
811-
from mypy_extensions import i64, i32, i16
811+
from mypy_extensions import i64, i32, i16, u8
812812

813813
def test_ord() -> None:
814814
assert ord(' ') == 32
@@ -906,6 +906,46 @@ def test_ord_str_index_unicode_mix() -> None:
906906
assert ord(s[2 + int()]) == 20320 # 3-byte
907907
assert ord(s[3 + int()]) == 128512 # 4-byte
908908

909+
def test_ord_str_index_to_fixed_width() -> None:
910+
# Test i32 coercion (signed 32-bit: -2^31 to 2^31-1)
911+
# All valid Unicode code points fit in i32, so test min/max boundaries
912+
s_i32_min = chr(0)
913+
assert i32(ord(s_i32_min[0 + int()])) == 0
914+
915+
s_i32_max = chr(0x10FFFF) # Max Unicode code point
916+
assert i32(ord(s_i32_max[0 + int()])) == 0x10FFFF
917+
918+
# Test i16 coercion (signed 16-bit: -2^15 to 2^15-1, i.e., -32768 to 32767)
919+
# ord() returns non-negative, so test 0 to 32767
920+
s_i16_min = chr(0)
921+
assert i16(ord(s_i16_min[0 + int()])) == 0
922+
923+
s_i16_max = chr(32767) # 2^15 - 1
924+
assert i16(ord(s_i16_max[0 + int()])) == 32767
925+
926+
s_i16_overflow = chr(32768) # 2^15
927+
with assertRaises(ValueError, "int too large to convert to i16"):
928+
i16(ord(s_i16_overflow[0 + int()]))
929+
930+
s_i16_overflow2 = chr(32769) # 2^15 + 1
931+
with assertRaises(ValueError, "int too large to convert to i16"):
932+
i16(ord(s_i16_overflow2[0 + int()]))
933+
934+
# Test u8 coercion (unsigned 8-bit: 0 to 2^8-1, i.e., 0 to 255)
935+
s_u8_min = chr(0)
936+
assert u8(ord(s_u8_min[0 + int()])) == 0
937+
938+
s_u8_max = chr(255)
939+
assert u8(ord(s_u8_max[0 + int()])) == 255
940+
941+
s_u8_overflow = chr(256)
942+
with assertRaises(ValueError, "int too large or small to convert to u8"):
943+
u8(ord(s_u8_overflow[0 + int()]))
944+
945+
s_u8_overflow2 = chr(257)
946+
with assertRaises(ValueError, "int too large or small to convert to u8"):
947+
u8(ord(s_u8_overflow2[0 + int()]))
948+
909949
[case testDecode]
910950
from testutil import assertRaises
911951

0 commit comments

Comments
 (0)