Skip to content

Commit dd55b66

Browse files
[mypyc] feat: new primitive for int.to_bytes (#19674)
This PR adds new primitives covering all argument combinations of `int.to_bytes`.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent cf508c2 commit dd55b66

7 files changed

Lines changed: 222 additions & 1 deletion

File tree

mypyc/irbuild/specialize.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,12 @@
113113
)
114114
from mypyc.primitives.float_ops import isinstance_float
115115
from mypyc.primitives.generic_ops import generic_setattr, setup_object
116-
from mypyc.primitives.int_ops import isinstance_int
116+
from mypyc.primitives.int_ops import (
117+
int_to_big_endian_op,
118+
int_to_bytes_op,
119+
int_to_little_endian_op,
120+
isinstance_int,
121+
)
117122
from mypyc.primitives.librt_strings_ops import (
118123
bytes_writer_adjust_index_op,
119124
bytes_writer_get_item_unsafe_op,
@@ -1242,6 +1247,77 @@ def translate_object_setattr(builder: IRBuilder, expr: CallExpr, callee: RefExpr
12421247
return builder.call_c(generic_setattr, [self_reg, name_reg, value], expr.line)
12431248

12441249

1250+
@specialize_function("to_bytes", int_rprimitive)
def specialize_int_to_bytes(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
    """Specialize ``int.to_bytes(length, byteorder, *, signed=False)`` calls.

    A byteorder given as the literal "little" or "big" selects the matching
    fixed-endianness primitive; any other str-typed byteorder expression is
    passed to the generic primitive, which receives the string at runtime.

    Returns None when this specialization does not apply: star arguments,
    a callee that is not a member expression, a missing/duplicated/unknown
    argument, ``signed`` passed positionally, or argument types other than
    int (length), str (byteorder) and bool (signed).
    """
    if any(kind not in (ARG_POS, ARG_NAMED) for kind in expr.arg_kinds):
        return None
    if not isinstance(callee, MemberExpr):
        return None

    # Map every argument onto its parameter slot.
    length_expr: Expression | None = None
    byteorder_expr: Expression | None = None
    signed_expr: Expression | None = None
    positional_index = 0
    for name, arg in zip(expr.arg_names, expr.args):
        if name is None:
            if positional_index == 0:
                length_expr = arg
            elif positional_index == 1:
                byteorder_expr = arg
            else:
                # "signed" is keyword-only in CPython, so a third positional
                # argument is a TypeError at runtime; don't specialize it away.
                return None
            positional_index += 1
        elif name == "length":
            if length_expr is not None:
                return None
            length_expr = arg
        elif name == "byteorder":
            if byteorder_expr is not None:
                return None
            byteorder_expr = arg
        elif name == "signed":
            if signed_expr is not None:
                return None
            signed_expr = arg
        else:
            return None
    if length_expr is None or byteorder_expr is None:
        return None

    signed_is_bool = True
    if signed_expr is not None:
        signed_is_bool = is_bool_rprimitive(builder.node_type(signed_expr))
    if not (
        is_int_rprimitive(builder.node_type(length_expr))
        and is_str_rprimitive(builder.node_type(byteorder_expr))
        and signed_is_bool
    ):
        return None

    # A literal "little"/"big" byteorder picks a dedicated primitive and the
    # string itself never has to be materialized.
    fast_op = None
    if isinstance(byteorder_expr, StrExpr):
        if byteorder_expr.value == "little":
            fast_op = int_to_little_endian_op
        elif byteorder_expr.value == "big":
            fast_op = int_to_big_endian_op

    # Evaluate the receiver first and then the arguments in source order, so
    # that side effects happen in the same order as in interpreted Python
    # (keyword arguments may be written in any order).
    self_arg = builder.accept(callee.expr)
    length_arg: Value | None = None
    byteorder_arg: Value | None = None
    signed_arg: Value | None = None
    for arg in expr.args:
        if arg is length_expr:
            length_arg = builder.accept(arg)
        elif arg is signed_expr:
            signed_arg = builder.accept(arg)
        elif arg is byteorder_expr and fast_op is None:
            byteorder_arg = builder.accept(arg)
    if signed_arg is None:
        signed_arg = builder.false()
    assert length_arg is not None

    if fast_op is not None:
        return builder.call_c(fast_op, [self_arg, length_arg, signed_arg], expr.line)

    # Fallback to generic primitive op
    assert byteorder_arg is not None
    return builder.call_c(
        int_to_bytes_op, [self_arg, length_arg, byteorder_arg, signed_arg], expr.line
    )
1319+
1320+
12451321
def translate_getitem_with_bounds_check(
12461322
builder: IRBuilder,
12471323
base_expr: Expression,

mypyc/lib-rt/CPy.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ CPyTagged CPyTagged_BitwiseLongOp_(CPyTagged a, CPyTagged b, char op);
149149
CPyTagged CPyTagged_Rshift_(CPyTagged left, CPyTagged right);
150150
CPyTagged CPyTagged_Lshift_(CPyTagged left, CPyTagged right);
151151
CPyTagged CPyTagged_BitLength(CPyTagged self);
152+
PyObject *CPyTagged_ToBytes(CPyTagged self, Py_ssize_t length, PyObject *byteorder, int signed_flag);
153+
PyObject *CPyTagged_ToBigEndianBytes(CPyTagged self, Py_ssize_t length, int signed_flag);
154+
PyObject *CPyTagged_ToLittleEndianBytes(CPyTagged self, Py_ssize_t length, int signed_flag);
152155

153156
PyObject *CPyTagged_Str(CPyTagged n);
154157
CPyTagged CPyTagged_FromFloat(double f);

mypyc/lib-rt/int_ops.c

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,68 @@ double CPyTagged_TrueDivide(CPyTagged x, CPyTagged y) {
597597
return 1.0;
598598
}
599599

600+
// Serialize the Python int 'v' into a newly allocated bytes object that is
// exactly 'length' bytes long. 'little_endian' selects the byte order and
// 'signed_flag' is forwarded to _PyLong_AsByteArray as its is_signed argument.
//
// Returns a new reference, or NULL with an exception set if 'length' is
// negative, the allocation fails, or the conversion reports an error.
static PyObject *CPyLong_ToBytes(PyObject *v, Py_ssize_t length, int little_endian, int signed_flag) {
    if (length < 0) {
        // Raise the same error as int.to_bytes(); PyBytes_FromStringAndSize
        // would otherwise reject a negative size with SystemError.
        PyErr_SetString(PyExc_ValueError, "length argument must be non-negative");
        return NULL;
    }
    // This is a wrapper for PyLong_AsByteArray and PyBytes_FromStringAndSize
    PyObject *result = PyBytes_FromStringAndSize(NULL, length);
    if (!result) {
        return NULL;
    }
    unsigned char *bytes = (unsigned char *)PyBytes_AS_STRING(result);
#if PY_VERSION_HEX >= 0x030D0000  // 3.13.0
    // 3.13 added a trailing 'with_exceptions' argument; pass 1 so an error
    // is raised on failure.
    int res = _PyLong_AsByteArray((PyLongObject *)v, bytes, length, little_endian, signed_flag, 1);
#else
    int res = _PyLong_AsByteArray((PyLongObject *)v, bytes, length, little_endian, signed_flag);
#endif
    if (res < 0) {
        // Conversion failed: drop the partially filled buffer.
        Py_DECREF(result);
        return NULL;
    }
    return result;
}
618+
619+
// int.to_bytes(length, byteorder, signed=False)
620+
PyObject *CPyTagged_ToBytes(CPyTagged self, Py_ssize_t length, PyObject *byteorder, int signed_flag) {
621+
PyObject *pyint = CPyTagged_AsObject(self);
622+
if (!PyUnicode_Check(byteorder)) {
623+
Py_DECREF(pyint);
624+
PyErr_SetString(PyExc_TypeError, "byteorder must be str");
625+
return NULL;
626+
}
627+
const char *order = PyUnicode_AsUTF8(byteorder);
628+
if (!order) {
629+
Py_DECREF(pyint);
630+
return NULL;
631+
}
632+
int little_endian;
633+
if (strcmp(order, "big") == 0) {
634+
little_endian = 0;
635+
} else if (strcmp(order, "little") == 0) {
636+
little_endian = 1;
637+
} else {
638+
PyErr_SetString(PyExc_ValueError, "byteorder must be either 'little' or 'big'");
639+
return NULL;
640+
}
641+
PyObject *result = CPyLong_ToBytes(pyint, length, little_endian, signed_flag);
642+
Py_DECREF(pyint);
643+
return result;
644+
}
645+
646+
// int.to_bytes(length, byteorder="little", signed=False)
647+
PyObject *CPyTagged_ToLittleEndianBytes(CPyTagged self, Py_ssize_t length, int signed_flag) {
648+
PyObject *pyint = CPyTagged_AsObject(self);
649+
PyObject *result = CPyLong_ToBytes(pyint, length, 1, signed_flag);
650+
Py_DECREF(pyint);
651+
return result;
652+
}
653+
654+
// int.to_bytes(length, "big", signed=False)
655+
PyObject *CPyTagged_ToBigEndianBytes(CPyTagged self, Py_ssize_t length, int signed_flag) {
656+
PyObject *pyint = CPyTagged_AsObject(self);
657+
PyObject *result = CPyLong_ToBytes(pyint, length, 0, signed_flag);
658+
Py_DECREF(pyint);
659+
return result;
660+
}
661+
600662
// int.bit_length()
601663
CPyTagged CPyTagged_BitLength(CPyTagged self) {
602664
// Handle zero

mypyc/primitives/int_ops.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
RType,
2222
bit_rprimitive,
2323
bool_rprimitive,
24+
bytes_rprimitive,
2425
c_pyssize_t_rprimitive,
2526
float_rprimitive,
2627
int16_rprimitive,
@@ -313,6 +314,34 @@ def int_unary_op(name: str, c_function_name: str) -> PrimitiveDescription:
313314
error_kind=ERR_NEVER,
314315
)
315316

317+
# Custom ops for int.to_bytes. Each takes the receiver int, the output length
# as a C Py_ssize_t and the "signed" flag as a bool, and returns bytes.

# Byte order fixed to "big":
#   int.to_bytes(length, "big")
#   int.to_bytes(length, "big", signed=...)
int_to_big_endian_op = custom_op(
    c_function_name="CPyTagged_ToBigEndianBytes",
    arg_types=[int_rprimitive, c_pyssize_t_rprimitive, bool_rprimitive],
    return_type=bytes_rprimitive,
    error_kind=ERR_MAGIC,
)

# Byte order fixed to "little":
#   int.to_bytes(length, "little")
#   int.to_bytes(length, "little", signed=...)
int_to_little_endian_op = custom_op(
    c_function_name="CPyTagged_ToLittleEndianBytes",
    arg_types=[int_rprimitive, c_pyssize_t_rprimitive, bool_rprimitive],
    return_type=bytes_rprimitive,
    error_kind=ERR_MAGIC,
)

# Byte order passed as a str argument:
#   int.to_bytes(length, byteorder, signed=...)
int_to_bytes_op = custom_op(
    c_function_name="CPyTagged_ToBytes",
    arg_types=[int_rprimitive, c_pyssize_t_rprimitive, str_rprimitive, bool_rprimitive],
    return_type=bytes_rprimitive,
    error_kind=ERR_MAGIC,
)
344+
316345
# int.bit_length()
317346
method_op(
318347
name="bit_length",

mypyc/test-data/fixtures/ir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __lt__(self, n: int) -> bool: pass
8787
def __gt__(self, n: int) -> bool: pass
8888
def __le__(self, n: int) -> bool: pass
8989
def __ge__(self, n: int) -> bool: pass
90+
# The second parameter is named "byteorder", as in the real builtin, so that
# keyword calls (x.to_bytes(length=2, byteorder="big")) type check against this stub.
def to_bytes(self, length: int, byteorder: str, *, signed: bool = False) -> bytes: pass
9091
def bit_length(self) -> int: pass
9192

9293
class str:

mypyc/test-data/irbuild-int.test

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,35 @@ L0:
211211
x = r0
212212
return x
213213

214+
[case testIntToBytes]
215+
def f(x: int) -> bytes:
216+
return x.to_bytes(2, "big")
217+
def g(x: int) -> bytes:
218+
return x.to_bytes(4, "little", signed=True)
219+
def h(x: int, byteorder: str) -> bytes:
220+
return x.to_bytes(8, byteorder)
221+
222+
[out]
223+
def f(x):
224+
x :: int
225+
r0 :: bytes
226+
L0:
227+
r0 = CPyTagged_ToBigEndianBytes(x, 2, 0)
228+
return r0
229+
def g(x):
230+
x :: int
231+
r0 :: bytes
232+
L0:
233+
r0 = CPyTagged_ToLittleEndianBytes(x, 4, 1)
234+
return r0
235+
def h(x, byteorder):
236+
x :: int
237+
byteorder :: str
238+
r0 :: bytes
239+
L0:
240+
r0 = CPyTagged_ToBytes(x, 8, byteorder, 0)
241+
return r0
242+
214243
[case testIntBitLength]
215244
def f(x: int) -> int:
216245
return x.bit_length()

mypyc/test-data/run-integers.test

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,27 @@ class subc(int):
573573
class int:
574574
pass
575575

576+
[case testIntToBytes]
577+
from testutil import assertRaises
578+
# The byte order is passed through as a non-literal str argument.
def to_bytes(n: int, length: int, byteorder: str, signed: bool = False) -> bytes:
    return n.to_bytes(length, byteorder, signed=signed)
def test_to_bytes() -> None:
    assert to_bytes(255, 2, "big") == b'\x00\xff', to_bytes(255, 2, "big")
    assert to_bytes(255, 2, "little") == b'\xff\x00', to_bytes(255, 2, "little")
    assert to_bytes(-1, 2, "big", True) == b'\xff\xff', to_bytes(-1, 2, "big", True)
    assert to_bytes(-2, 2, "little", True) == b'\xfe\xff', to_bytes(-2, 2, "little", True)
    assert to_bytes(0, 1, "big") == b'\x00', to_bytes(0, 1, "big")
    # test with a value that does not fit in 64 bits
    assert to_bytes(10**30, 16, "big") == b'\x00\x00\x00\x0c\x9f,\x9c\xd0Ft\xed\xea@\x00\x00\x00', to_bytes(10**30, 16, "big")
    # little-endian output is the byte-wise reverse of big-endian output
    assert to_bytes(10**30, 16, "little") == to_bytes(10**30, 16, "big")[::-1], to_bytes(10**30, 16, "little")
    # unsigned, too large for 1 byte
    with assertRaises(OverflowError):
        to_bytes(256, 1, "big")
    # signed, too small for 1 byte
    with assertRaises(OverflowError):
        to_bytes(-129, 1, "big", True)
    # signed, too large for 1 byte
    with assertRaises(OverflowError):
        to_bytes(128, 1, "big", True)
    # byteorder must be exactly "little" or "big"
    with assertRaises(ValueError):
        to_bytes(1, 1, "middle")
596+
576597
[case testBitLength]
577598
def bit_length(n: int) -> int:
578599
return n.bit_length()

0 commit comments

Comments
 (0)