Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,7 @@ PyObject *CPyBytes_Concat(PyObject *a, PyObject *b);
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);
CPyTagged CPyBytes_Ord(PyObject *obj);
PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count);
PyObject *CPyBytes_Translate(PyObject *bytes, PyObject *table);


int CPyBytes_Compare(PyObject *left, PyObject *right);
Expand Down
49 changes: 49 additions & 0 deletions mypyc/lib-rt/bytes_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,52 @@ PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count) {
}
return PySequence_Repeat(bytes, temp_count);
}

// Implements bytes.translate(table).
//
// Returns a new reference to the translated bytes object, or NULL with an
// exception set on failure. When both arguments are exact bytes objects the
// translation is done inline; every other combination is forwarded to the
// receiver's own translate() method.
PyObject *CPyBytes_Translate(PyObject *bytes, PyObject *table) {
    // Slow path first: subclasses of bytes and tables that are not exact
    // bytes objects are delegated to the Python-level method.
    if (!PyBytes_CheckExact(bytes) || !PyBytes_CheckExact(table)) {
        _Py_IDENTIFIER(translate);
        PyObject *method_name = _PyUnicode_FromId(&PyId_translate);
        if (method_name == NULL) {
            return NULL;
        }
        return PyObject_CallMethodOneArg(bytes, method_name, table);
    }

    // From here on both objects are exact bytes. The table must provide one
    // replacement for every possible byte value.
    if (PyBytes_GET_SIZE(table) != 256) {
        PyErr_SetString(PyExc_ValueError,
                        "translation table must be 256 characters long");
        return NULL;
    }

    Py_ssize_t size = PyBytes_GET_SIZE(bytes);
    const char *src = PyBytes_AS_STRING(bytes);
    const char *map = PyBytes_AS_STRING(table);

    // Uninitialized buffer of the same length; the loop below fills it.
    PyObject *translated = PyBytes_FromStringAndSize(NULL, size);
    if (translated == NULL) {
        return NULL;
    }

    char *dst = PyBytes_AS_STRING(translated);
    bool differs = false;

    // Without a loop unrolling hint performance can be worse than CPython
    CPY_UNROLL_LOOP(4)
    for (Py_ssize_t remaining = size; remaining > 0; remaining--) {
        char original = *src++;
        // Cast to unsigned char so bytes >= 0x80 index the table correctly
        // even where plain char is signed.
        char mapped = map[(unsigned char)original];
        *dst++ = mapped;
        if (mapped != original) {
            differs = true;
        }
    }

    if (differs) {
        return translated;
    }

    // Every byte mapped to itself: drop the copy and hand back a new
    // reference to the original object instead.
    Py_DECREF(translated);
    Py_INCREF(bytes);
    return bytes;
}
13 changes: 13 additions & 0 deletions mypyc/lib-rt/mypyc_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@

#endif // Py_GIL_DISABLED

// Helper macro for stringification in _Pragma: _Pragma takes a string
// literal, so the pragma text has to be converted with the # operator.
#define CPY_STRINGIFY(x) #x

// CPY_UNROLL_LOOP(n): written directly before a loop to ask the compiler to
// unroll it n times. The hint is spelled differently per compiler:
//   - clang:      #pragma unroll n
//   - GCC >= 8:   #pragma GCC unroll n
//   - otherwise:  expands to nothing, so the loop is compiled without a hint.
// The *_IMPL macro receives the complete pragma text as one argument and
// passes it through CPY_STRINGIFY to build the string for _Pragma.
// NOTE(review): clang is tested first, presumably because it also defines
// __GNUC__ and would otherwise take the GCC branch -- confirm.
#if defined(__clang__)
#define CPY_UNROLL_LOOP_IMPL(x) _Pragma(CPY_STRINGIFY(x))
#define CPY_UNROLL_LOOP(n) CPY_UNROLL_LOOP_IMPL(unroll n)
#elif defined(__GNUC__) && __GNUC__ >= 8
#define CPY_UNROLL_LOOP_IMPL(x) _Pragma(CPY_STRINGIFY(x))
#define CPY_UNROLL_LOOP(n) CPY_UNROLL_LOOP_IMPL(GCC unroll n)
#else
#define CPY_UNROLL_LOOP(n)
#endif

// INCREF and DECREF that assert the pointer is not NULL.
// asserts are disabled in release builds so there shouldn't be a perf hit.
// I'm honestly kind of surprised that this isn't done by default.
Expand Down
9 changes: 9 additions & 0 deletions mypyc/primitives/bytes_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@
error_kind=ERR_MAGIC,
)

# bytes.translate(table)
# Registers a primitive that lowers translate() on a bytes receiver to the C
# helper CPyBytes_Translate. The table argument is typed as a plain object
# rather than bytes, so any table object (for example a bytearray) is passed
# through to the helper. The result is typed as bytes.
# NOTE(review): ERR_MAGIC presumably means failure is signalled by a special
# return value (the C helper returns NULL on error) -- confirm.
method_op(
name="translate",
arg_types=[bytes_rprimitive, object_rprimitive],
return_type=bytes_rprimitive,
c_function_name="CPyBytes_Translate",
error_kind=ERR_MAGIC,
)

# Join bytes objects and return a new bytes.
# The first argument is the total number of the following bytes.
bytes_build_op = custom_op(
Expand Down
1 change: 1 addition & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def __getitem__(self, i: int) -> int: ...
def __getitem__(self, i: slice) -> bytes: ...
def join(self, x: Iterable[object]) -> bytes: ...
def decode(self, encoding: str=..., errors: str=...) -> str: ...
def translate(self, t: bytes) -> bytes: ...
def __iter__(self) -> Iterator[int]: ...

class bytearray:
Expand Down
10 changes: 10 additions & 0 deletions mypyc/test-data/irbuild-bytes.test
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,13 @@ def i_times_b(s, n):
L0:
r0 = CPyBytes_Multiply(s, n)
return r0

[case testBytesTranslate]
def f(b: bytes, table: bytes) -> bytes:
return b.translate(table)
[out]
def f(b, table):
b, table, r0 :: bytes
L0:
r0 = CPyBytes_Translate(b, table)
return r0
32 changes: 32 additions & 0 deletions mypyc/test-data/run-bytes.test
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,38 @@ def test_multiply() -> None:
result = b * two
assert type(result) == bytes

def test_translate() -> None:
    """Check bytes.translate() with bytes tables, bytearray tables and bad tables."""
    # A 256-byte bytes table that maps every value to itself must leave the
    # data unchanged. The receiver is built with a concatenation, presumably
    # so that it is not a plain literal.
    identity_table = bytes(range(256))
    data = b'hello world' + bytes()
    assert data.translate(identity_table) == b'hello world'

    # Rotate the lowercase ASCII letters by 13 places; all other byte values
    # keep mapping to themselves.
    rot13 = bytearray(range(256))
    for code in range(ord('a'), ord('z') + 1):
        rot13[code] = ord('a') + (code - ord('a') + 13) % 26
    rot13_bytes = bytes(rot13)
    assert b'hello'.translate(rot13_bytes) == b'uryyb'
    assert (b'abc' + bytes()).translate(rot13_bytes) == b'nop'

    # Byte values at both ends of the range survive an identity mapping.
    assert b'\x00\x01\xff'.translate(identity_table) == b'\x00\x01\xff'

    # A bytearray is not an exact bytes object, so these calls exercise the
    # non-bytes table handling. Start from an identity mapping...
    upper_table = bytearray(range(256))
    assert b'hello'.translate(upper_table) == b'hello'
    # ...then remap the lowercase letters to their uppercase counterparts.
    for code in range(ord('a'), ord('z') + 1):
        upper_table[code] = ord('A') + (code - ord('a'))
    assert b'hello world'.translate(upper_table) == b'HELLO WORLD'
    assert (b'test' + bytes()).translate(upper_table) == b'TEST'

    # Bytes tables whose length is not exactly 256 must raise ValueError,
    # both when far too short and when merely shorter than required.
    with assertRaises(ValueError, "translation table must be 256 characters long"):
        b'test'.translate(b'short')
    with assertRaises(ValueError, "translation table must be 256 characters long"):
        b'test'.translate(bytes(100))

[case testBytesSlicing]
def test_bytes_slicing() -> None:
b = b'abcdefg'
Expand Down