Skip to content

Commit 69fbc68

Browse files
committed
Optimize checks; add bytearray tests
1 parent 5dbd6a9 commit 69fbc68

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

mypyc/lib-rt/bytes_ops.c

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,20 +178,19 @@ int CPyBytes_Startswith(PyObject *self, PyObject *subobj) {
178178
return 1;
179179
}
180180

181-
Py_ssize_t self_len = PyBytes_GET_SIZE(self);
182181
Py_ssize_t subobj_len = PyBytes_GET_SIZE(subobj);
182+
if (subobj_len == 0) {
183+
return 1;
184+
}
183185

186+
Py_ssize_t self_len = PyBytes_GET_SIZE(self);
184187
if (subobj_len > self_len) {
185188
return 0;
186189
}
187190

188191
const char *self_buf = PyBytes_AS_STRING(self);
189192
const char *subobj_buf = PyBytes_AS_STRING(subobj);
190193

191-
if (subobj_len == 0) {
192-
return 1;
193-
}
194-
195194
return memcmp(self_buf, subobj_buf, (size_t)subobj_len) == 0 ? 1 : 0;
196195
}
197196
_Py_IDENTIFIER(startswith);

mypyc/test-data/fixtures/ir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ def __add__(self, s: bytes) -> bytearray: ...
193193
def __setitem__(self, i: int, o: int) -> None: ...
194194
def __getitem__(self, i: int) -> int: ...
195195
def decode(self, x: str = ..., y: str = ...) -> str: ...
196+
def startswith(self, t: bytes) -> bool: ...
196197

197198
class bool(int):
198199
def __init__(self, o: object = ...) -> None: ...

mypyc/test-data/run-bytes.test

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,14 @@ def test_startswith() -> None:
213213
assert b''.startswith(b'')
214214
assert not b''.startswith(test)
215215

216+
# Test bytearray to verify slow paths
217+
assert test.startswith(bytearray(b'some'))
218+
assert not test.startswith(bytearray(b'other'))
219+
220+
test = bytearray(b'some string')
221+
assert test.startswith(b'some')
222+
assert not test.startswith(b'other')
223+
216224
[case testBytesSlicing]
217225
def test_bytes_slicing() -> None:
218226
b = b'abcdefg'

0 commit comments

Comments
 (0)