Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mypyc/ir/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,6 @@ def get_header(self) -> str:

# Each SourceDep below names a C source file holding extra primitive
# implementations. Primitive ops reference these constants through their
# `dependencies=[...]` argument (e.g. STRING_WRITER_EXTRA_OPS is listed by the
# StringWriter ops that call CPyStringWriter_* functions).
BYTES_EXTRA_OPS: Final = SourceDep("bytes_extra_ops.c")
BYTES_WRITER_EXTRA_OPS: Final = SourceDep("byteswriter_extra_ops.c")
STRING_WRITER_EXTRA_OPS: Final = SourceDep("stringwriter_extra_ops.c")
BYTEARRAY_EXTRA_OPS: Final = SourceDep("bytearray_extra_ops.c")
STR_EXTRA_OPS: Final = SourceDep("str_extra_ops.c")
2 changes: 2 additions & 0 deletions mypyc/ir/rtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,10 +522,12 @@ def __hash__(self) -> int:
"librt.internal.WriteBuffer",
"librt.internal.ReadBuffer",
"librt.strings.BytesWriter",
"librt.strings.StringWriter",
]
}

# Module-level aliases for the native writer types, looked up in
# KNOWN_NATIVE_TYPES by fully qualified name. A name must be present in that
# table, otherwise the lookup raises KeyError at import time.
bytes_writer_rprimitive: Final = KNOWN_NATIVE_TYPES["librt.strings.BytesWriter"]
string_writer_rprimitive: Final = KNOWN_NATIVE_TYPES["librt.strings.StringWriter"]


def is_native_rprimitive(rtype: RType) -> bool:
Expand Down
20 changes: 20 additions & 0 deletions mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
object_rprimitive,
set_rprimitive,
str_rprimitive,
string_writer_rprimitive,
uint8_rprimitive,
)
from mypyc.irbuild.builder import IRBuilder
Expand Down Expand Up @@ -124,6 +125,9 @@
bytes_writer_get_item_unsafe_op,
bytes_writer_range_check_op,
bytes_writer_set_item_unsafe_op,
string_writer_adjust_index_op,
string_writer_get_item_unsafe_op,
string_writer_range_check_op,
)
from mypyc.primitives.list_ops import isinstance_list, new_list_set_item_op
from mypyc.primitives.misc_ops import isinstance_bool
Expand Down Expand Up @@ -1447,6 +1451,22 @@ def translate_bytes_writer_set_item(
return builder.none()


@specialize_dunder("__getitem__", string_writer_rprimitive)
def translate_string_writer_get_item(
    builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
) -> Value | None:
    """Specialize ``StringWriter.__getitem__`` as a bounds-checked primitive read.

    All of the work is delegated to ``translate_getitem_with_bounds_check``;
    this function only supplies the three StringWriter-specific primitives.
    Whatever the shared helper returns (a ``Value`` or ``None``) is returned
    unchanged.
    """
    # Passed positionally to the helper in this order:
    # negative-index adjustment, range check, unchecked element read.
    index_ops = (
        string_writer_adjust_index_op,
        string_writer_range_check_op,
        string_writer_get_item_unsafe_op,
    )
    return translate_getitem_with_bounds_check(builder, base_expr, args, ctx_expr, *index_ops)


@specialize_dunder("__getitem__", bytes_rprimitive)
def translate_bytes_get_item(
builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
Expand Down
11 changes: 9 additions & 2 deletions mypyc/lib-rt/librt_strings.c
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,8 @@ check_string_writer(PyObject *data) {
static char string_writer_switch_kind(StringWriterObject *self, int32_t value);

static char
StringWriter_write_internal(StringWriterObject *self, PyObject *value) {
StringWriter_write_internal(PyObject *obj, PyObject *value) {
StringWriterObject *self = (StringWriterObject *)obj;
Py_ssize_t str_len = PyUnicode_GET_LENGTH(value);
if (str_len == 0) {
return CPY_NONE;
Expand Down Expand Up @@ -675,7 +676,7 @@ StringWriter_write(PyObject *self, PyObject *const *args, size_t nargs) {
PyErr_SetString(PyExc_TypeError, "value must be a str object");
return NULL;
}
if (unlikely(StringWriter_write_internal((StringWriterObject *)self, value) == CPY_NONE_ERROR)) {
if (unlikely(StringWriter_write_internal(self, value) == CPY_NONE_ERROR)) {
return NULL;
}
Py_INCREF(Py_None);
Expand Down Expand Up @@ -877,6 +878,12 @@ librt_strings_module_exec(PyObject *m)
(void *)_grow_buffer,
(void *)BytesWriter_type_internal,
(void *)BytesWriter_truncate_internal,
(void *)StringWriter_internal,
(void *)StringWriter_getvalue_internal,
(void *)string_append_slow_path,
(void *)StringWriter_type_internal,
(void *)StringWriter_write_internal,
(void *)grow_string_buffer,
};
PyObject *c_api_object = PyCapsule_New((void *)librt_strings_api, "librt.strings._C_API", NULL);
if (PyModule_Add(m, "_C_API", c_api_object) < 0) {
Expand Down
14 changes: 12 additions & 2 deletions mypyc/lib-rt/librt_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ import_librt_strings(void)
// API version -- more recent versions must maintain backward compatibility, i.e.
// we can add new features but not remove or change existing features (unless
// ABI version is changed, but see the comment above).
#define LIBRT_STRINGS_API_VERSION 2
#define LIBRT_STRINGS_API_VERSION 4

// Number of functions in the capsule API. If you add a new function, also increase
// LIBRT_STRINGS_API_VERSION.
#define LIBRT_STRINGS_API_LEN 8
#define LIBRT_STRINGS_API_LEN 14

static void *LibRTStrings_API[LIBRT_STRINGS_API_LEN];

Expand Down Expand Up @@ -58,6 +58,12 @@ typedef struct {
#define LibRTStrings_ByteWriter_grow_buffer_internal (*(bool (*)(BytesWriterObject *obj, Py_ssize_t size)) LibRTStrings_API[5])
#define LibRTStrings_BytesWriter_type_internal (*(PyTypeObject* (*)(void)) LibRTStrings_API[6])
#define LibRTStrings_BytesWriter_truncate_internal (*(char (*)(PyObject *self, int64_t size)) LibRTStrings_API[7])
// StringWriter entry points exported through the capsule (slots 8-13). Each
// macro casts the raw pointer stored in LibRTStrings_API[i] to the function
// type it was exported with, so the slot index here must match the position of
// the corresponding function in the librt_strings_api table built in
// librt_strings.c (StringWriter_internal, StringWriter_getvalue_internal,
// string_append_slow_path, StringWriter_type_internal,
// StringWriter_write_internal, grow_string_buffer -- in that order).
#define LibRTStrings_StringWriter_internal (*(PyObject* (*)(void)) LibRTStrings_API[8])
#define LibRTStrings_StringWriter_getvalue_internal (*(PyObject* (*)(PyObject *source)) LibRTStrings_API[9])
#define LibRTStrings_string_append_slow_path (*(char (*)(StringWriterObject *obj, int32_t value)) LibRTStrings_API[10])
#define LibRTStrings_StringWriter_type_internal (*(PyTypeObject* (*)(void)) LibRTStrings_API[11])
#define LibRTStrings_StringWriter_write_internal (*(char (*)(PyObject *source, PyObject *value)) LibRTStrings_API[12])
#define LibRTStrings_grow_string_buffer (*(bool (*)(StringWriterObject *obj, Py_ssize_t n)) LibRTStrings_API[13])

static int
import_librt_strings(void)
Expand Down Expand Up @@ -96,6 +102,10 @@ static inline bool CPyBytesWriter_Check(PyObject *obj) {
return Py_TYPE(obj) == LibRTStrings_BytesWriter_type_internal();
}

// Return true only if obj's type is exactly the StringWriter type obtained
// through the capsule API (identity comparison of type objects).
static inline bool CPyStringWriter_Check(PyObject *obj) {
    PyTypeObject *writer_type = LibRTStrings_StringWriter_type_internal();
    return Py_TYPE(obj) == writer_type;
}

#endif // MYPYC_EXPERIMENTAL

#endif // LIBRT_STRINGS_H
11 changes: 11 additions & 0 deletions mypyc/lib-rt/stringwriter_extra_ops.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Primitives related to librt.strings.StringWriter that get linked statically
// with compiled modules, instead of being called via a capsule.

#include "stringwriter_extra_ops.h"

#ifdef MYPYC_EXPERIMENTAL

// All StringWriter operations are currently implemented as inline functions
// in stringwriter_extra_ops.h, or use the exported capsule API directly.

#endif // MYPYC_EXPERIMENTAL
77 changes: 77 additions & 0 deletions mypyc/lib-rt/stringwriter_extra_ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#ifndef STRINGWRITER_EXTRA_OPS_H
#define STRINGWRITER_EXTRA_OPS_H

#ifdef MYPYC_EXPERIMENTAL

#include <stdint.h>
#include <Python.h>

#include "librt_strings.h"

// Return the writer's current length as a tagged integer: the raw length is
// cast to CPyTagged first and then shifted left by one bit.
static inline CPyTagged
CPyStringWriter_Len(PyObject *obj) {
    StringWriterObject *writer = (StringWriterObject *)obj;
    return (CPyTagged)writer->len << 1;
}

// Make sure at least n more slots are free after the current length.
// Returns true immediately when spare capacity already suffices; otherwise
// defers to the capsule's grow_string_buffer and returns its result.
static inline bool
CPyStringWriter_EnsureSize(StringWriterObject *data, Py_ssize_t n) {
    // Expected common case: no growth needed.
    if (likely(data->capacity - data->len >= n)) {
        return true;
    }
    return LibRTStrings_grow_string_buffer(data, n);
}

// Append a single character (given as an int32 value) to the writer.
// Returns CPY_NONE on success and CPY_NONE_ERROR if the buffer could not be
// grown; anything outside the inline fast path is handed to the capsule's
// string_append_slow_path, whose result is returned as-is.
static inline char
CPyStringWriter_Append(PyObject *obj, int32_t value) {
    StringWriterObject *writer = (StringWriterObject *)obj;
    char kind = writer->kind;
    // Inline fast path applies only to kind 1 (one byte per character) and a
    // value that fits in a single byte. The unsigned cast also rejects
    // negative values.
    bool fits_one_byte = kind == 1 && (uint32_t)value < 256;

    if (!fits_one_byte) {
        // Kind switching and all remaining cases live in the slow path.
        return LibRTStrings_string_append_slow_path(writer, value);
    }

    // Read the length once, before the capacity check, so the compiler can
    // keep it in a local across the call.
    Py_ssize_t pos = writer->len;
    if (!CPyStringWriter_EnsureSize(writer, 1)) {
        return CPY_NONE_ERROR;
    }
    // buf is re-read from the object after EnsureSize.
    writer->buf[pos] = (char)value;
    writer->len = pos + 1;
    return CPY_NONE;
}

// Map a negative index to its from-the-end position by adding the writer's
// length; non-negative indexes pass through untouched. No range checking is
// done here, so the result may still be out of bounds.
static inline int64_t CPyStringWriter_AdjustIndex(PyObject *obj, int64_t index) {
    return index < 0 ? index + ((StringWriterObject *)obj)->len : index;
}

// Return true iff 0 <= index < current length of the writer. The index is
// expected to have been adjusted for negative values already.
static inline bool CPyStringWriter_RangeCheck(PyObject *obj, int64_t index) {
    if (index < 0) {
        return false;
    }
    return index < ((StringWriterObject *)obj)->len;
}

// Read the character stored at `index` and return it as an int32.
// Performs NO bounds checking: the caller must already have validated the
// index (see CPyStringWriter_RangeCheck). The element width depends on the
// writer's kind: 1 byte for kind 1, 2 bytes for kind 2, 4 bytes otherwise.
static inline int32_t CPyStringWriter_GetItem(PyObject *obj, int64_t index) {
    StringWriterObject *writer = (StringWriterObject *)obj;
    char kind = writer->kind;
    char *data = writer->buf;

    switch (kind) {
    case 1:
        // Go through uint8_t so bytes >= 0x80 are not sign-extended.
        return (uint8_t)data[index];
    case 2: {
        // memcpy rather than a pointer cast for the multi-byte read.
        uint16_t unit;
        memcpy(&unit, data + index * 2, sizeof(unit));
        return (int32_t)unit;
    }
    default: {
        uint32_t unit;
        memcpy(&unit, data + index * 4, sizeof(unit));
        return (int32_t)unit;
    }
    }
}

#endif // MYPYC_EXPERIMENTAL

#endif
89 changes: 88 additions & 1 deletion mypyc/primitives/librt_strings_ops.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from mypyc.ir.deps import BYTES_WRITER_EXTRA_OPS, LIBRT_STRINGS
from mypyc.ir.deps import BYTES_WRITER_EXTRA_OPS, LIBRT_STRINGS, STRING_WRITER_EXTRA_OPS
from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER
from mypyc.ir.rtypes import (
bool_rprimitive,
bytearray_rprimitive,
bytes_rprimitive,
bytes_writer_rprimitive,
int32_rprimitive,
int64_rprimitive,
none_rprimitive,
short_int_rprimitive,
str_rprimitive,
string_writer_rprimitive,
uint8_rprimitive,
void_rtype,
)
Expand Down Expand Up @@ -126,3 +129,87 @@
experimental=True,
dependencies=[LIBRT_STRINGS, BYTES_WRITER_EXTRA_OPS],
)

# StringWriter operations

# StringWriter() constructor call with no arguments: lowered to a direct call
# of the capsule entry point LibRTStrings_StringWriter_internal, which returns
# the new writer object.
# NOTE(review): ERR_MAGIC presumably means failure is signalled by the C
# function's special error return value -- confirm against mypyc.ir.ops.
function_op(
name="librt.strings.StringWriter",
arg_types=[],
return_type=string_writer_rprimitive,
c_function_name="LibRTStrings_StringWriter_internal",
error_kind=ERR_MAGIC,
experimental=True,
dependencies=[LIBRT_STRINGS],
)

# StringWriter.getvalue(): lowered to the capsule entry point
# LibRTStrings_StringWriter_getvalue_internal, which takes the writer and
# returns a str object. Only the capsule (LIBRT_STRINGS) is needed.
method_op(
name="getvalue",
arg_types=[string_writer_rprimitive],
return_type=str_rprimitive,
c_function_name="LibRTStrings_StringWriter_getvalue_internal",
error_kind=ERR_MAGIC,
experimental=True,
dependencies=[LIBRT_STRINGS],
)

# StringWriter.write(str): lowered to the capsule entry point
# LibRTStrings_StringWriter_write_internal. The C function returns a char:
# CPY_NONE on success, with CPY_NONE_ERROR being the value callers test for to
# detect failure.
method_op(
name="write",
arg_types=[string_writer_rprimitive, str_rprimitive],
return_type=none_rprimitive,
c_function_name="LibRTStrings_StringWriter_write_internal",
error_kind=ERR_MAGIC,
experimental=True,
dependencies=[LIBRT_STRINGS],
)

# StringWriter.append(i32): append a single character given as an int32 value.
# CPyStringWriter_Append is a static inline function from
# stringwriter_extra_ops.h (hence the STRING_WRITER_EXTRA_OPS dependency); it
# returns CPY_NONE on success or CPY_NONE_ERROR if the buffer cannot grow.
method_op(
name="append",
arg_types=[string_writer_rprimitive, int32_rprimitive],
return_type=none_rprimitive,
c_function_name="CPyStringWriter_Append",
error_kind=ERR_MAGIC,
experimental=True,
dependencies=[LIBRT_STRINGS, STRING_WRITER_EXTRA_OPS],
)

# len(StringWriter): CPyStringWriter_Len reads the length field directly and
# returns it as a tagged short int (length << 1). It has no failure path, so
# the op is registered with ERR_NEVER.
function_op(
name="builtins.len",
arg_types=[string_writer_rprimitive],
return_type=short_int_rprimitive,
c_function_name="CPyStringWriter_Len",
error_kind=ERR_NEVER,
experimental=True,
dependencies=[LIBRT_STRINGS, STRING_WRITER_EXTRA_OPS],
)

# StringWriter index adjustment - convert negative index to positive
# by adding the writer's length. No range checking is performed, so the result
# can still be out of bounds; pair it with string_writer_range_check_op.
string_writer_adjust_index_op = custom_primitive_op(
name="string_writer_adjust_index",
arg_types=[string_writer_rprimitive, int64_rprimitive],
return_type=int64_rprimitive,
c_function_name="CPyStringWriter_AdjustIndex",
error_kind=ERR_NEVER,
experimental=True,
dependencies=[LIBRT_STRINGS, STRING_WRITER_EXTRA_OPS],
)

# StringWriter range check - check if index is in valid range
# (0 <= index < len). Expects an index that has already been adjusted for
# negative values; returns a bool and never raises by itself.
string_writer_range_check_op = custom_primitive_op(
name="string_writer_range_check",
arg_types=[string_writer_rprimitive, int64_rprimitive],
return_type=bool_rprimitive,
c_function_name="CPyStringWriter_RangeCheck",
error_kind=ERR_NEVER,
experimental=True,
dependencies=[LIBRT_STRINGS, STRING_WRITER_EXTRA_OPS],
)

# StringWriter.__getitem__() - get character at index (no bounds checking)
# The character is returned as an int32 value. Callers must validate the index
# first (adjust negative index, then range check), since the underlying C
# function reads the buffer unconditionally.
string_writer_get_item_unsafe_op = custom_primitive_op(
name="string_writer_get_item",
arg_types=[string_writer_rprimitive, int64_rprimitive],
return_type=int32_rprimitive,
c_function_name="CPyStringWriter_GetItem",
error_kind=ERR_NEVER,
experimental=True,
dependencies=[LIBRT_STRINGS, STRING_WRITER_EXTRA_OPS],
)
52 changes: 52 additions & 0 deletions mypyc/test-data/irbuild-librt-strings.test
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,55 @@ L1:
L2:
CPyBytesWriter_SetItem(b, r0, x)
return 1

[case testLibrtStrings_StringWriter_experimental_64bit]
from librt.strings import StringWriter
from mypy_extensions import i32, i64

def string_writer_basics() -> str:
s = StringWriter()
s.append(65)
s.write('foo')
return s.getvalue()
def string_writer_len(s: StringWriter) -> i64:
return len(s)
def string_writer_get_item(s: StringWriter, i: i64) -> i32:
return s[i]
[out]
def string_writer_basics():
r0, s :: librt.strings.StringWriter
r1 :: None
r2 :: str
r3 :: None
r4 :: str
L0:
r0 = LibRTStrings_StringWriter_internal()
s = r0
r1 = CPyStringWriter_Append(s, 65)
r2 = 'foo'
r3 = LibRTStrings_StringWriter_write_internal(s, r2)
r4 = LibRTStrings_StringWriter_getvalue_internal(s)
return r4
def string_writer_len(s):
s :: librt.strings.StringWriter
r0 :: short_int
r1 :: i64
L0:
r0 = CPyStringWriter_Len(s)
r1 = r0 >> 1
return r1
def string_writer_get_item(s, i):
s :: librt.strings.StringWriter
i, r0 :: i64
r1, r2 :: bool
r3 :: i32
L0:
r0 = CPyStringWriter_AdjustIndex(s, i)
r1 = CPyStringWriter_RangeCheck(s, r0)
if r1 goto L2 else goto L1 :: bool
L1:
r2 = raise IndexError('index out of range')
unreachable
L2:
r3 = CPyStringWriter_GetItem(s, r0)
return r3
Loading