Skip to content

Commit 2fca994

Browse files
committed
[mypyc] Add primitives for bytes and str multiply
1 parent 0c2bf7a commit 2fca994

7 files changed

Lines changed: 100 additions & 0 deletions

File tree

mypyc/lib-rt/CPy.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,7 @@ PyObject *CPy_Encode(PyObject *obj, PyObject *encoding, PyObject *errors);
770770
Py_ssize_t CPyStr_Count(PyObject *unicode, PyObject *substring, CPyTagged start);
771771
Py_ssize_t CPyStr_CountFull(PyObject *unicode, PyObject *substring, CPyTagged start, CPyTagged end);
772772
CPyTagged CPyStr_Ord(PyObject *obj);
773+
PyObject *CPyStr_Multiply(PyObject *str, CPyTagged count);
773774

774775

775776
// Bytes operations
@@ -781,6 +782,7 @@ CPyTagged CPyBytes_GetItem(PyObject *o, CPyTagged index);
781782
PyObject *CPyBytes_Concat(PyObject *a, PyObject *b);
782783
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);
783784
CPyTagged CPyBytes_Ord(PyObject *obj);
785+
PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count);
784786

785787

786788
int CPyBytes_Compare(PyObject *left, PyObject *right);

mypyc/lib-rt/bytes_ops.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,12 @@ CPyTagged CPyBytes_Ord(PyObject *obj) {
162162
PyErr_SetString(PyExc_TypeError, "ord() expects a character");
163163
return CPY_INT_TAG;
164164
}
165+
166+
PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count) {
167+
Py_ssize_t temp_count = CPyTagged_AsSsize_t(count);
168+
if (temp_count == -1 && PyErr_Occurred()) {
169+
PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG);
170+
return NULL;
171+
}
172+
return PySequence_Repeat(bytes, temp_count);
173+
}

mypyc/lib-rt/str_ops.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,3 +621,12 @@ CPyTagged CPyStr_Ord(PyObject *obj) {
621621
PyExc_TypeError, "ord() expected a character, but a string of length %zd found", s);
622622
return CPY_INT_TAG;
623623
}
624+
625+
PyObject *CPyStr_Multiply(PyObject *str, CPyTagged count) {
626+
Py_ssize_t temp_count = CPyTagged_AsSsize_t(count);
627+
if (temp_count == -1 && PyErr_Occurred()) {
628+
PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG);
629+
return NULL;
630+
}
631+
return PySequence_Repeat(str, temp_count);
632+
}

mypyc/primitives/bytes_ops.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,25 @@
8282
steals=[True, False],
8383
)
8484

85+
# bytes * int
86+
binary_op(
87+
name="*",
88+
arg_types=[bytes_rprimitive, int_rprimitive],
89+
return_type=bytes_rprimitive,
90+
c_function_name="CPyBytes_Multiply",
91+
error_kind=ERR_MAGIC,
92+
)
93+
94+
# int * bytes
95+
binary_op(
96+
name="*",
97+
arg_types=[int_rprimitive, bytes_rprimitive],
98+
return_type=bytes_rprimitive,
99+
c_function_name="CPyBytes_Multiply",
100+
error_kind=ERR_MAGIC,
101+
ordering=[1, 0],
102+
)
103+
85104
# bytes[begin:end]
86105
bytes_slice_op = custom_op(
87106
arg_types=[bytes_rprimitive, int_rprimitive, int_rprimitive],

mypyc/primitives/str_ops.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,25 @@
7979
steals=[True, False],
8080
)
8181

82+
# str * int
83+
binary_op(
84+
name="*",
85+
arg_types=[str_rprimitive, int_rprimitive],
86+
return_type=str_rprimitive,
87+
c_function_name="CPyStr_Multiply",
88+
error_kind=ERR_MAGIC,
89+
)
90+
91+
# int * str
92+
binary_op(
93+
name="*",
94+
arg_types=[int_rprimitive, str_rprimitive],
95+
return_type=str_rprimitive,
96+
c_function_name="CPyStr_Multiply",
97+
error_kind=ERR_MAGIC,
98+
ordering=[1, 0],
99+
)
100+
82101
# str1 == str2 (very common operation, so we provide our own)
83102
str_eq = custom_primitive_op(
84103
name="str_eq",

mypyc/test-data/irbuild-bytes.test

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,24 @@ L2:
217217
L3:
218218
keep_alive y
219219
return r2
220+
221+
[case testBytesMultiply]
222+
def b_times_i(s: bytes, n: int) -> bytes:
223+
return s * n
224+
def i_times_b(s: bytes, n: int) -> bytes:
225+
return n * s
226+
[out]
227+
def b_times_i(s, n):
228+
s :: bytes
229+
n :: int
230+
r0 :: bytes
231+
L0:
232+
r0 = CPyBytes_Multiply(s, n)
233+
return r0
234+
def i_times_b(s, n):
235+
s :: bytes
236+
n :: int
237+
r0 :: bytes
238+
L0:
239+
r0 = CPyBytes_Multiply(s, n)
240+
return r0

mypyc/test-data/irbuild-str.test

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -771,3 +771,24 @@ L0:
771771
r0 = 'literal'
772772
r1 = 'literal'
773773
return 1
774+
775+
[case testStrMultiply]
776+
def s_times_i(s: str, n: int) -> str:
777+
return s * n
778+
def i_times_s(s: str, n: int) -> str:
779+
return n * s
780+
[out]
781+
def s_times_i(s, n):
782+
s :: str
783+
n :: int
784+
r0 :: str
785+
L0:
786+
r0 = CPyStr_Multiply(s, n)
787+
return r0
788+
def i_times_s(s, n):
789+
s :: str
790+
n :: int
791+
r0 :: str
792+
L0:
793+
r0 = CPyStr_Multiply(s, n)
794+
return r0

0 commit comments

Comments
 (0)