diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index c79923f69e691..935852bcf30e5 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -770,6 +770,7 @@ PyObject *CPy_Encode(PyObject *obj, PyObject *encoding, PyObject *errors); Py_ssize_t CPyStr_Count(PyObject *unicode, PyObject *substring, CPyTagged start); Py_ssize_t CPyStr_CountFull(PyObject *unicode, PyObject *substring, CPyTagged start, CPyTagged end); CPyTagged CPyStr_Ord(PyObject *obj); +PyObject *CPyStr_Multiply(PyObject *str, CPyTagged count); // Bytes operations @@ -781,6 +782,7 @@ CPyTagged CPyBytes_GetItem(PyObject *o, CPyTagged index); PyObject *CPyBytes_Concat(PyObject *a, PyObject *b); PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter); CPyTagged CPyBytes_Ord(PyObject *obj); +PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count); int CPyBytes_Compare(PyObject *left, PyObject *right); diff --git a/mypyc/lib-rt/bytes_ops.c b/mypyc/lib-rt/bytes_ops.c index 6ff34b021a9a3..8ecf9337c28b8 100644 --- a/mypyc/lib-rt/bytes_ops.c +++ b/mypyc/lib-rt/bytes_ops.c @@ -162,3 +162,12 @@ CPyTagged CPyBytes_Ord(PyObject *obj) { PyErr_SetString(PyExc_TypeError, "ord() expects a character"); return CPY_INT_TAG; } + +PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count) { + Py_ssize_t temp_count = CPyTagged_AsSsize_t(count); + if (temp_count == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG); + return NULL; + } + return PySequence_Repeat(bytes, temp_count); +} diff --git a/mypyc/lib-rt/str_ops.c b/mypyc/lib-rt/str_ops.c index 721a2bbb10b98..f91ace78a301d 100644 --- a/mypyc/lib-rt/str_ops.c +++ b/mypyc/lib-rt/str_ops.c @@ -621,3 +621,12 @@ CPyTagged CPyStr_Ord(PyObject *obj) { PyExc_TypeError, "ord() expected a character, but a string of length %zd found", s); return CPY_INT_TAG; } + +PyObject *CPyStr_Multiply(PyObject *str, CPyTagged count) { + Py_ssize_t temp_count = CPyTagged_AsSsize_t(count); + if (temp_count == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG); + return NULL; + } + return PySequence_Repeat(str, temp_count); +} diff --git a/mypyc/primitives/bytes_ops.py b/mypyc/primitives/bytes_ops.py index c88e89d1a2bad..858134b204137 100644 --- a/mypyc/primitives/bytes_ops.py +++ b/mypyc/primitives/bytes_ops.py @@ -82,6 +82,25 @@ steals=[True, False], ) +# bytes * int +binary_op( + name="*", + arg_types=[bytes_rprimitive, int_rprimitive], + return_type=bytes_rprimitive, + c_function_name="CPyBytes_Multiply", + error_kind=ERR_MAGIC, +) + +# int * bytes +binary_op( + name="*", + arg_types=[int_rprimitive, bytes_rprimitive], + return_type=bytes_rprimitive, + c_function_name="CPyBytes_Multiply", + error_kind=ERR_MAGIC, + ordering=[1, 0], +) + # bytes[begin:end] bytes_slice_op = custom_op( arg_types=[bytes_rprimitive, int_rprimitive, int_rprimitive], diff --git a/mypyc/primitives/str_ops.py b/mypyc/primitives/str_ops.py index d39f1f872763e..ceaf1cfe5dd2f 100644 --- a/mypyc/primitives/str_ops.py +++ b/mypyc/primitives/str_ops.py @@ -79,6 +79,25 @@ steals=[True, False], ) +# str * int +binary_op( + name="*", + arg_types=[str_rprimitive, int_rprimitive], + return_type=str_rprimitive, + c_function_name="CPyStr_Multiply", + error_kind=ERR_MAGIC, +) + +# int * str +binary_op( + name="*", + arg_types=[int_rprimitive, str_rprimitive], + return_type=str_rprimitive, + c_function_name="CPyStr_Multiply", + error_kind=ERR_MAGIC, + ordering=[1, 0], +) + # str1 == str2 (very common operation, so we provide our own) str_eq = custom_primitive_op( name="str_eq", diff --git a/mypyc/test-data/irbuild-bytes.test b/mypyc/test-data/irbuild-bytes.test index 8cfefe03ae22c..d3bb7d52bf3f3 100644 --- a/mypyc/test-data/irbuild-bytes.test +++ b/mypyc/test-data/irbuild-bytes.test @@ -217,3 +217,24 @@ L2: L3: keep_alive y return r2 + +[case testBytesMultiply] +def b_times_i(s: bytes, n: int) -> bytes: + return s * n +def i_times_b(s: bytes, n: int) -> bytes: + return n * s +[out] +def b_times_i(s, n): + s :: bytes + n :: int + r0 :: bytes +L0: + r0 = CPyBytes_Multiply(s, n) + return r0 +def i_times_b(s, n): + s :: bytes + n :: int + r0 :: bytes +L0: + r0 = CPyBytes_Multiply(s, n) + return r0 diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test index 056f120c7bac0..881ddc3656abc 100644 --- a/mypyc/test-data/irbuild-str.test +++ b/mypyc/test-data/irbuild-str.test @@ -771,3 +771,24 @@ L0: r0 = 'literal' r1 = 'literal' return 1 + +[case testStrMultiply] +def s_times_i(s: str, n: int) -> str: + return s * n +def i_times_s(s: str, n: int) -> str: + return n * s +[out] +def s_times_i(s, n): + s :: str + n :: int + r0 :: str +L0: + r0 = CPyStr_Multiply(s, n) + return r0 +def i_times_s(s, n): + s :: str + n :: int + r0 :: str +L0: + r0 = CPyStr_Multiply(s, n) + return r0 diff --git a/mypyc/test-data/run-bytes.test b/mypyc/test-data/run-bytes.test index df5cb209b9025..1498925f23f5b 100644 --- a/mypyc/test-data/run-bytes.test +++ b/mypyc/test-data/run-bytes.test @@ -134,6 +134,40 @@ def test_ord_bytesarray() -> None: with assertRaises(TypeError): ord(bytearray(b'')) +def test_multiply() -> None: + # Use bytes() and int() to avoid constant folding + b = b'ab' + bytes() + zero = int() + one = 1 + zero + three = 3 + zero + neg_one = -1 + zero + + assert b * zero == b'' + assert b * one == b'ab' + assert b * three == b'ababab' + assert b * neg_one == b'' + assert zero * b == b'' + assert one * b == b'ab' + assert three * b == b'ababab' + assert neg_one * b == b'' + + # Test with empty bytes + empty = bytes() + five = 5 + zero + assert empty * five == b'' + assert five * empty == b'' + + # Test with single byte + single = b'\xff' + bytes() + four = 4 + zero + assert single * four == b'\xff\xff\xff\xff' + assert four * single == b'\xff\xff\xff\xff' + + # Test type preservation + two = 2 + zero + result = b * two + assert type(result) == bytes + [case testBytesSlicing] def test_bytes_slicing() -> None: b = b'abcdefg' diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test index 6a62db6ee3ee0..0ae67ed7f1c3f 100644 --- a/mypyc/test-data/run-strings.test +++ b/mypyc/test-data/run-strings.test @@ -362,6 +362,40 @@ def test_str_min_max() -> None: assert max(x, y) == 'bbb' assert max(x, z) == 'aaa' +def test_multiply() -> None: + # Use str() and int() to avoid constant folding + s = 'ab' + str() + zero = int() + one = 1 + zero + three = 3 + zero + neg_one = -1 + zero + + assert s * zero == '' + assert s * one == 'ab' + assert s * three == 'ababab' + assert s * neg_one == '' + assert zero * s == '' + assert one * s == 'ab' + assert three * s == 'ababab' + assert neg_one * s == '' + + # Test with empty string + empty = str() + five = 5 + zero + assert empty * five == '' + assert five * empty == '' + + # Test with single character + single = 'x' + str() + four = 4 + zero + assert single * four == 'xxxx' + assert four * single == 'xxxx' + + # Test type preservation + two = 2 + zero + result = s * two + assert type(result) == str + [case testStringFormattingCStyle] from typing import Tuple