diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 935852bcf30e5..32daf32520802 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -783,6 +783,7 @@ PyObject *CPyBytes_Concat(PyObject *a, PyObject *b); PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter); CPyTagged CPyBytes_Ord(PyObject *obj); PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count); +PyObject *CPyBytes_Translate(PyObject *bytes, PyObject *table); int CPyBytes_Compare(PyObject *left, PyObject *right); diff --git a/mypyc/lib-rt/bytes_ops.c b/mypyc/lib-rt/bytes_ops.c index 8ecf9337c28b8..6c138ad90db0d 100644 --- a/mypyc/lib-rt/bytes_ops.c +++ b/mypyc/lib-rt/bytes_ops.c @@ -171,3 +171,52 @@ PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count) { } return PySequence_Repeat(bytes, temp_count); } + +PyObject *CPyBytes_Translate(PyObject *bytes, PyObject *table) { + // Fast path: exact bytes object with exact bytes table + if (PyBytes_CheckExact(bytes) && PyBytes_CheckExact(table)) { + Py_ssize_t table_len = PyBytes_GET_SIZE(table); + if (table_len != 256) { + PyErr_SetString(PyExc_ValueError, + "translation table must be 256 characters long"); + return NULL; + } + + Py_ssize_t len = PyBytes_GET_SIZE(bytes); + const char *input = PyBytes_AS_STRING(bytes); + const char *trans_table = PyBytes_AS_STRING(table); + + PyObject *result = PyBytes_FromStringAndSize(NULL, len); + if (result == NULL) { + return NULL; + } + + char *output = PyBytes_AS_STRING(result); + bool changed = false; + + // Without a loop unrolling hint performance can be worse than CPython + CPY_UNROLL_LOOP(4) + for (Py_ssize_t i = len; --i >= 0;) { + char c = *input++; + if ((*output++ = trans_table[(unsigned char)c]) != c) + changed = true; + } + + // If nothing changed, discard result and return the original object + if (!changed) { + Py_DECREF(result); + Py_INCREF(bytes); + return bytes; + } + + return result; + } + + // Fallback to Python method call for non-exact types or non-standard tables + _Py_IDENTIFIER(translate); + PyObject *name = _PyUnicode_FromId(&PyId_translate); + if (name == NULL) { + return NULL; + } + return PyObject_CallMethodOneArg(bytes, name, table); +} diff --git a/mypyc/lib-rt/mypyc_util.h b/mypyc/lib-rt/mypyc_util.h index 4168d3c53ee28..50a806c91a8a6 100644 --- a/mypyc/lib-rt/mypyc_util.h +++ b/mypyc/lib-rt/mypyc_util.h @@ -48,6 +48,19 @@ #endif // Py_GIL_DISABLED +// Helper macro for stringification in _Pragma +#define CPY_STRINGIFY(x) #x + +#if defined(__clang__) + #define CPY_UNROLL_LOOP_IMPL(x) _Pragma(CPY_STRINGIFY(x)) + #define CPY_UNROLL_LOOP(n) CPY_UNROLL_LOOP_IMPL(unroll n) +#elif defined(__GNUC__) && __GNUC__ >= 8 + #define CPY_UNROLL_LOOP_IMPL(x) _Pragma(CPY_STRINGIFY(x)) + #define CPY_UNROLL_LOOP(n) CPY_UNROLL_LOOP_IMPL(GCC unroll n) +#else + #define CPY_UNROLL_LOOP(n) +#endif + // INCREF and DECREF that assert the pointer is not NULL. // asserts are disabled in release builds so there shouldn't be a perf hit. // I'm honestly kind of surprised that this isn't done by default. diff --git a/mypyc/primitives/bytes_ops.py b/mypyc/primitives/bytes_ops.py index 858134b204137..982c41dc25b10 100644 --- a/mypyc/primitives/bytes_ops.py +++ b/mypyc/primitives/bytes_ops.py @@ -128,6 +128,15 @@ error_kind=ERR_MAGIC, ) +# bytes.translate(table) +method_op( + name="translate", + arg_types=[bytes_rprimitive, object_rprimitive], + return_type=bytes_rprimitive, + c_function_name="CPyBytes_Translate", + error_kind=ERR_MAGIC, +) + # Join bytes objects and return a new bytes. # The first argument is the total number of the following bytes. bytes_build_op = custom_op( diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 5033100223a3d..d9202707124b1 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -178,6 +178,7 @@ def __getitem__(self, i: int) -> int: ... def __getitem__(self, i: slice) -> bytes: ... def join(self, x: Iterable[object]) -> bytes: ... def decode(self, encoding: str=..., errors: str=...) -> str: ... + def translate(self, t: bytes) -> bytes: ... def __iter__(self) -> Iterator[int]: ... class bytearray: diff --git a/mypyc/test-data/irbuild-bytes.test b/mypyc/test-data/irbuild-bytes.test index d3bb7d52bf3f3..9473944f44fef 100644 --- a/mypyc/test-data/irbuild-bytes.test +++ b/mypyc/test-data/irbuild-bytes.test @@ -238,3 +238,13 @@ def i_times_b(s, n): L0: r0 = CPyBytes_Multiply(s, n) return r0 + +[case testBytesTranslate] +def f(b: bytes, table: bytes) -> bytes: + return b.translate(table) +[out] +def f(b, table): + b, table, r0 :: bytes +L0: + r0 = CPyBytes_Translate(b, table) + return r0 diff --git a/mypyc/test-data/run-bytes.test b/mypyc/test-data/run-bytes.test index 1498925f23f5b..9a319b636772f 100644 --- a/mypyc/test-data/run-bytes.test +++ b/mypyc/test-data/run-bytes.test @@ -168,6 +168,38 @@ def test_multiply() -> None: result = b * two assert type(result) == bytes +def test_translate() -> None: + # Identity translation table (fast path - exact bytes) + identity = bytes(range(256)) + b = b'hello world' + bytes() + assert b.translate(identity) == b'hello world' + + # ROT13-like translation for lowercase letters + table = bytearray(range(256)) + for i in range(ord('a'), ord('z') + 1): + table[i] = ord('a') + (i - ord('a') + 13) % 26 + table_bytes = bytes(table) + assert b'hello'.translate(table_bytes) == b'uryyb' + assert (b'abc' + bytes()).translate(table_bytes) == b'nop' + + # Test with special characters + assert b'\x00\x01\xff'.translate(identity) == b'\x00\x01\xff' + + # Test with bytearray table (slow path - fallback to Python method) + bytearray_table = bytearray(range(256)) + assert b'hello'.translate(bytearray_table) == b'hello' + # Modify bytearray table to uppercase + for i in range(ord('a'), ord('z') + 1): + bytearray_table[i] = ord('A') + (i - ord('a')) + assert b'hello world'.translate(bytearray_table) == b'HELLO WORLD' + assert (b'test' + bytes()).translate(bytearray_table) == b'TEST' + + # Test error on wrong table size + with assertRaises(ValueError, "translation table must be 256 characters long"): + b'test'.translate(b'short') + with assertRaises(ValueError, "translation table must be 256 characters long"): + b'test'.translate(bytes(100)) + [case testBytesSlicing] def test_bytes_slicing() -> None: b = b'abcdefg'