Skip to content

Commit c0b89be

Browse files
authored
[mypyc] Implement bytes.endswith (#20447)
Rounding out #20387 by implementing `bytes.endswith`. A simple benchmark shows a ~6.4x improvement.

Tested with the following benchmark code:

```
import time

def bench(suffix: bytes, a: list[bytes], n: int) -> int:
    i = 0
    for x in range(n):
        for b in a:
            if b.endswith(suffix):
                i += 1
    return i

a = [b"foo", b"barasdfsf", b"foobar", b"ab", b"asrtert", b"sertyeryt"]
n = 5 * 1000 * 1000
suffix = b"foo"

bench(suffix, a, n)

t0 = time.time()
bench(suffix, a, n)
td = time.time() - t0
print(f"{td}s")
```

Output:

```
$ python bench.py
0.9002199172973633s
$ python -c "import bench"
0.13828086853027344s
```
1 parent d87132f commit c0b89be

File tree

8 files changed

+82
-1
lines changed

8 files changed

+82
-1
lines changed

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);
790790
CPyTagged CPyBytes_Ord(PyObject *obj);
791791
PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count);
792792
int CPyBytes_Startswith(PyObject *self, PyObject *subobj);
793-
793+
int CPyBytes_Endswith(PyObject *self, PyObject *subobj);
794794
int CPyBytes_Compare(PyObject *left, PyObject *right);
795795

796796

mypyc/lib-rt/bytes_ops.c

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,3 +183,36 @@ int CPyBytes_Startswith(PyObject *self, PyObject *subobj) {
183183
}
184184
return ret;
185185
}
186+
187+
// Check whether 'self' ends with 'subobj'.
//
// Returns 1 if it does, 0 if it does not, and 2 if the generic fallback
// call (or the truth test on its result) fails.
int CPyBytes_Endswith(PyObject *self, PyObject *subobj) {
    if (!PyBytes_CheckExact(self) || !PyBytes_CheckExact(subobj)) {
        // Slow path: at least one operand is not an exact bytes object, so
        // dispatch to the object's own endswith() method.
        PyObject *res = PyObject_CallMethodOneArg(self, mypyc_interned_str.endswith, subobj);
        if (res == NULL) {
            return 2;
        }
        int truth = PyObject_IsTrue(res);
        Py_DECREF(res);
        return truth < 0 ? 2 : truth;
    }

    // Fast path: both operands are exact bytes objects.

    // An object is always a suffix of itself.
    if (self == subobj) {
        return 1;
    }

    // The empty suffix matches everything.
    Py_ssize_t suffix_len = PyBytes_GET_SIZE(subobj);
    if (suffix_len == 0) {
        return 1;
    }

    // A suffix longer than the whole object can never match.
    Py_ssize_t total_len = PyBytes_GET_SIZE(self);
    if (suffix_len > total_len) {
        return 0;
    }

    // Compare the last 'suffix_len' bytes of 'self' against 'subobj'.
    const char *tail = PyBytes_AS_STRING(self) + (total_len - suffix_len);
    return memcmp(tail, PyBytes_AS_STRING(subobj), (size_t)suffix_len) == 0;
}

mypyc/lib-rt/static_data.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ intern_strings(void) {
5252
INTERN_STRING(close_, "close");
5353
INTERN_STRING(copy, "copy");
5454
INTERN_STRING(dispatch_cache, "dispatch_cache");
55+
INTERN_STRING(endswith, "endswith");
5556
INTERN_STRING(get_type_hints, "get_type_hints");
5657
INTERN_STRING(keys, "keys");
5758
INTERN_STRING(items, "items");

mypyc/lib-rt/static_data.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ typedef struct mypyc_interned_str_struct {
4444
PyObject *close_;
4545
PyObject *copy;
4646
PyObject *dispatch_cache;
47+
PyObject *endswith;
4748
PyObject *get_type_hints;
4849
PyObject *keys;
4950
PyObject *items;

mypyc/primitives/bytes_ops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,16 @@
133133
error_kind=ERR_MAGIC,
134134
)
135135

136+
# bytes.endswith(bytes)
137+
method_op(
138+
name="endswith",
139+
arg_types=[bytes_rprimitive, bytes_rprimitive],
140+
return_type=c_int_rprimitive,
141+
c_function_name="CPyBytes_Endswith",
142+
truncated_type=bool_rprimitive,
143+
error_kind=ERR_MAGIC,
144+
)
145+
136146
# Join bytes objects and return a new bytes.
137147
# The first argument is the total number of the following bytes.
138148
bytes_build_op = custom_op(

mypyc/test-data/fixtures/ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def join(self, x: Iterable[object]) -> bytes: ...
180180
def decode(self, encoding: str=..., errors: str=...) -> str: ...
181181
def translate(self, t: bytes | bytearray) -> bytes: ...
182182
def startswith(self, t: bytes | bytearray) -> bool: ...
183+
def endswith(self, t: bytes | bytearray) -> bool: ...
183184
def __iter__(self) -> Iterator[int]: ...
184185

185186
class bytearray:
@@ -197,6 +198,7 @@ def __getitem__(self, i: int) -> int: ...
197198
def __getitem__(self, i: slice) -> bytearray: ...
198199
def decode(self, x: str = ..., y: str = ...) -> str: ...
199200
def startswith(self, t: bytes) -> bool: ...
201+
def endswith(self, t: bytes) -> bool: ...
200202

201203
class bool(int):
202204
def __init__(self, o: object = ...) -> None: ...

mypyc/test-data/irbuild-bytes.test

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,19 @@ L0:
266266
r1 = truncate r0: i32 to builtins.bool
267267
return r1
268268

269+
[case testBytesEndsWith]
270+
def f(a: bytes, b: bytes) -> bool:
271+
return a.endswith(b)
272+
[out]
273+
def f(a, b):
274+
a, b :: bytes
275+
r0 :: i32
276+
r1 :: bool
277+
L0:
278+
r0 = CPyBytes_Endswith(a, b)
279+
r1 = truncate r0: i32 to builtins.bool
280+
return r1
281+
269282
[case testBytesVsBytearray]
270283
def bytes_func(b: bytes) -> None: pass
271284
def bytearray_func(ba: bytearray) -> None: pass

mypyc/test-data/run-bytes.test

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,27 @@ def test_startswith() -> None:
221221
assert test2.startswith(b'some')
222222
assert not test2.startswith(b'other')
223223

224+
def test_endswith() -> None:
    """Exercise bytes.endswith on both the exact-bytes and generic paths."""
    # Test default behavior
    test = b'some string'
    assert test.endswith(b'string')
    assert test.endswith(b'some string')
    assert not test.endswith(b'other')
    assert not test.endswith(b'some string but longer')

    # Same object on both sides (identity shortcut in the primitive)
    assert test.endswith(test)

    # Same length but different content, and a substring that is not a suffix
    assert not test.endswith(b'some strinG')
    assert not test.endswith(b'strin')

    # Single-byte suffix
    assert test.endswith(b'g')
    assert not test.endswith(b's')

    # Test empty cases
    assert test.endswith(b'')
    assert b''.endswith(b'')
    assert not b''.endswith(test)

    # Test bytearray to verify slow paths
    assert test.endswith(bytearray(b'string'))
    assert not test.endswith(bytearray(b'other'))
    assert test.endswith(bytearray(b''))

    test2 = bytearray(b'some string')
    assert test2.endswith(b'string')
    assert not test2.endswith(b'other')
    assert test2.endswith(b'')
244+
224245
[case testBytesSlicing]
225246
def test_bytes_slicing() -> None:
226247
b = b'abcdefg'

0 commit comments

Comments
 (0)