Skip to content

Commit 6439b29

Browse files
authored
[mypyc] Add primitives for librt.strings.StringWriter (#20624)
As expected, this can make some operations multiple times faster in tight loops. Inline `append` fast path and a few other simple ops, otherwise call through capsule. This mirrors the primitives for BytesWriter, but they are a bit more complicated due to having to support different string kinds.
1 parent 32897bf commit 6439b29

10 files changed

Lines changed: 309 additions & 5 deletions

File tree

mypyc/ir/deps.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,6 @@ def get_header(self) -> str:
5151

5252
BYTES_EXTRA_OPS: Final = SourceDep("bytes_extra_ops.c")
5353
BYTES_WRITER_EXTRA_OPS: Final = SourceDep("byteswriter_extra_ops.c")
54+
STRING_WRITER_EXTRA_OPS: Final = SourceDep("stringwriter_extra_ops.c")
5455
BYTEARRAY_EXTRA_OPS: Final = SourceDep("bytearray_extra_ops.c")
5556
STR_EXTRA_OPS: Final = SourceDep("str_extra_ops.c")

mypyc/ir/rtypes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,10 +522,12 @@ def __hash__(self) -> int:
522522
"librt.internal.WriteBuffer",
523523
"librt.internal.ReadBuffer",
524524
"librt.strings.BytesWriter",
525+
"librt.strings.StringWriter",
525526
]
526527
}
527528

528529
bytes_writer_rprimitive: Final = KNOWN_NATIVE_TYPES["librt.strings.BytesWriter"]
530+
string_writer_rprimitive: Final = KNOWN_NATIVE_TYPES["librt.strings.StringWriter"]
529531

530532

531533
def is_native_rprimitive(rtype: RType) -> bool:

mypyc/irbuild/specialize.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
object_rprimitive,
8181
set_rprimitive,
8282
str_rprimitive,
83+
string_writer_rprimitive,
8384
uint8_rprimitive,
8485
)
8586
from mypyc.irbuild.builder import IRBuilder
@@ -124,6 +125,9 @@
124125
bytes_writer_get_item_unsafe_op,
125126
bytes_writer_range_check_op,
126127
bytes_writer_set_item_unsafe_op,
128+
string_writer_adjust_index_op,
129+
string_writer_get_item_unsafe_op,
130+
string_writer_range_check_op,
127131
)
128132
from mypyc.primitives.list_ops import isinstance_list, new_list_set_item_op
129133
from mypyc.primitives.misc_ops import isinstance_bool
@@ -1447,6 +1451,22 @@ def translate_bytes_writer_set_item(
14471451
return builder.none()
14481452

14491453

1454+
@specialize_dunder("__getitem__", string_writer_rprimitive)
1455+
def translate_string_writer_get_item(
1456+
builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
1457+
) -> Value | None:
1458+
"""Optimized StringWriter.__getitem__ implementation with bounds checking."""
1459+
return translate_getitem_with_bounds_check(
1460+
builder,
1461+
base_expr,
1462+
args,
1463+
ctx_expr,
1464+
string_writer_adjust_index_op,
1465+
string_writer_range_check_op,
1466+
string_writer_get_item_unsafe_op,
1467+
)
1468+
1469+
14501470
@specialize_dunder("__getitem__", bytes_rprimitive)
14511471
def translate_bytes_get_item(
14521472
builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression

mypyc/lib-rt/librt_strings.c

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,8 @@ check_string_writer(PyObject *data) {
619619
static char string_writer_switch_kind(StringWriterObject *self, int32_t value);
620620

621621
static char
622-
StringWriter_write_internal(StringWriterObject *self, PyObject *value) {
622+
StringWriter_write_internal(PyObject *obj, PyObject *value) {
623+
StringWriterObject *self = (StringWriterObject *)obj;
623624
Py_ssize_t str_len = PyUnicode_GET_LENGTH(value);
624625
if (str_len == 0) {
625626
return CPY_NONE;
@@ -675,7 +676,7 @@ StringWriter_write(PyObject *self, PyObject *const *args, size_t nargs) {
675676
PyErr_SetString(PyExc_TypeError, "value must be a str object");
676677
return NULL;
677678
}
678-
if (unlikely(StringWriter_write_internal((StringWriterObject *)self, value) == CPY_NONE_ERROR)) {
679+
if (unlikely(StringWriter_write_internal(self, value) == CPY_NONE_ERROR)) {
679680
return NULL;
680681
}
681682
Py_INCREF(Py_None);
@@ -877,6 +878,12 @@ librt_strings_module_exec(PyObject *m)
877878
(void *)_grow_buffer,
878879
(void *)BytesWriter_type_internal,
879880
(void *)BytesWriter_truncate_internal,
881+
(void *)StringWriter_internal,
882+
(void *)StringWriter_getvalue_internal,
883+
(void *)string_append_slow_path,
884+
(void *)StringWriter_type_internal,
885+
(void *)StringWriter_write_internal,
886+
(void *)grow_string_buffer,
880887
};
881888
PyObject *c_api_object = PyCapsule_New((void *)librt_strings_api, "librt.strings._C_API", NULL);
882889
if (PyModule_Add(m, "_C_API", c_api_object) < 0) {

mypyc/lib-rt/librt_strings.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ import_librt_strings(void)
2222
// API version -- more recent versions must maintain backward compatibility, i.e.
2323
// we can add new features but not remove or change existing features (unless
2424
// ABI version is changed, but see the comment above).
25-
#define LIBRT_STRINGS_API_VERSION 2
25+
#define LIBRT_STRINGS_API_VERSION 4
2626

2727
// Number of functions in the capsule API. If you add a new function, also increase
2828
// LIBRT_STRINGS_API_VERSION.
29-
#define LIBRT_STRINGS_API_LEN 8
29+
#define LIBRT_STRINGS_API_LEN 14
3030

3131
static void *LibRTStrings_API[LIBRT_STRINGS_API_LEN];
3232

@@ -58,6 +58,12 @@ typedef struct {
5858
#define LibRTStrings_ByteWriter_grow_buffer_internal (*(bool (*)(BytesWriterObject *obj, Py_ssize_t size)) LibRTStrings_API[5])
5959
#define LibRTStrings_BytesWriter_type_internal (*(PyTypeObject* (*)(void)) LibRTStrings_API[6])
6060
#define LibRTStrings_BytesWriter_truncate_internal (*(char (*)(PyObject *self, int64_t size)) LibRTStrings_API[7])
61+
#define LibRTStrings_StringWriter_internal (*(PyObject* (*)(void)) LibRTStrings_API[8])
62+
#define LibRTStrings_StringWriter_getvalue_internal (*(PyObject* (*)(PyObject *source)) LibRTStrings_API[9])
63+
#define LibRTStrings_string_append_slow_path (*(char (*)(StringWriterObject *obj, int32_t value)) LibRTStrings_API[10])
64+
#define LibRTStrings_StringWriter_type_internal (*(PyTypeObject* (*)(void)) LibRTStrings_API[11])
65+
#define LibRTStrings_StringWriter_write_internal (*(char (*)(PyObject *source, PyObject *value)) LibRTStrings_API[12])
66+
#define LibRTStrings_grow_string_buffer (*(bool (*)(StringWriterObject *obj, Py_ssize_t n)) LibRTStrings_API[13])
6167

6268
static int
6369
import_librt_strings(void)
@@ -96,6 +102,10 @@ static inline bool CPyBytesWriter_Check(PyObject *obj) {
96102
return Py_TYPE(obj) == LibRTStrings_BytesWriter_type_internal();
97103
}
98104

105+
static inline bool CPyStringWriter_Check(PyObject *obj) {
106+
return Py_TYPE(obj) == LibRTStrings_StringWriter_type_internal();
107+
}
108+
99109
#endif // MYPYC_EXPERIMENTAL
100110

101111
#endif // LIBRT_STRINGS_H
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Primitives related to librt.strings.StringWriter that get linked statically
2+
// with compiled modules, instead of being called via a capsule.
3+
4+
#include "stringwriter_extra_ops.h"
5+
6+
#ifdef MYPYC_EXPERIMENTAL
7+
8+
// All StringWriter operations are currently implemented as inline functions
9+
// in stringwriter_extra_ops.h, or use the exported capsule API directly.
10+
11+
#endif // MYPYC_EXPERIMENTAL
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#ifndef STRINGWRITER_EXTRA_OPS_H
2+
#define STRINGWRITER_EXTRA_OPS_H
3+
4+
#ifdef MYPYC_EXPERIMENTAL
5+
6+
#include <stdint.h>
7+
#include <Python.h>
8+
9+
#include "librt_strings.h"
10+
11+
static inline CPyTagged
12+
CPyStringWriter_Len(PyObject *obj) {
13+
return (CPyTagged)((StringWriterObject *)obj)->len << 1;
14+
}
15+
16+
static inline bool
17+
CPyStringWriter_EnsureSize(StringWriterObject *data, Py_ssize_t n) {
18+
if (likely(data->capacity - data->len >= n)) {
19+
return true;
20+
} else {
21+
return LibRTStrings_grow_string_buffer(data, n);
22+
}
23+
}
24+
25+
static inline char
26+
CPyStringWriter_Append(PyObject *obj, int32_t value) {
27+
StringWriterObject *self = (StringWriterObject *)obj;
28+
char kind = self->kind;
29+
30+
// Fast path: kind 1 (ASCII/Latin-1) with character < 256
31+
if (kind == 1 && (uint32_t)value < 256) {
32+
// Store length in local variable to enable additional optimizations
33+
Py_ssize_t len = self->len;
34+
if (!CPyStringWriter_EnsureSize(self, 1))
35+
return CPY_NONE_ERROR;
36+
self->buf[len] = (char)value;
37+
self->len = len + 1;
38+
return CPY_NONE;
39+
}
40+
41+
// Slow path: handles kind switching and other cases
42+
return LibRTStrings_string_append_slow_path(self, value);
43+
}
44+
45+
// If index is negative, convert to non-negative index (no range checking)
46+
static inline int64_t CPyStringWriter_AdjustIndex(PyObject *obj, int64_t index) {
47+
if (index < 0) {
48+
return index + ((StringWriterObject *)obj)->len;
49+
}
50+
return index;
51+
}
52+
53+
static inline bool CPyStringWriter_RangeCheck(PyObject *obj, int64_t index) {
54+
return index >= 0 && index < ((StringWriterObject *)obj)->len;
55+
}
56+
57+
static inline int32_t CPyStringWriter_GetItem(PyObject *obj, int64_t index) {
58+
StringWriterObject *self = (StringWriterObject *)obj;
59+
char kind = self->kind;
60+
char *buf = self->buf;
61+
62+
if (kind == 1) {
63+
return (uint8_t)buf[index];
64+
} else if (kind == 2) {
65+
uint16_t val;
66+
memcpy(&val, buf + index * 2, 2);
67+
return (int32_t)val;
68+
} else {
69+
uint32_t val;
70+
memcpy(&val, buf + index * 4, 4);
71+
return (int32_t)val;
72+
}
73+
}
74+
75+
#endif // MYPYC_EXPERIMENTAL
76+
77+
#endif

mypyc/primitives/librt_strings_ops.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
from mypyc.ir.deps import BYTES_WRITER_EXTRA_OPS, LIBRT_STRINGS
1+
from mypyc.ir.deps import BYTES_WRITER_EXTRA_OPS, LIBRT_STRINGS, STRING_WRITER_EXTRA_OPS
22
from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER
33
from mypyc.ir.rtypes import (
44
bool_rprimitive,
55
bytearray_rprimitive,
66
bytes_rprimitive,
77
bytes_writer_rprimitive,
8+
int32_rprimitive,
89
int64_rprimitive,
910
none_rprimitive,
1011
short_int_rprimitive,
12+
str_rprimitive,
13+
string_writer_rprimitive,
1114
uint8_rprimitive,
1215
void_rtype,
1316
)
@@ -126,3 +129,87 @@
126129
experimental=True,
127130
dependencies=[LIBRT_STRINGS, BYTES_WRITER_EXTRA_OPS],
128131
)
132+
133+
# StringWriter operations
134+
function_op(
135+
name="librt.strings.StringWriter",
136+
arg_types=[],
137+
return_type=string_writer_rprimitive,
138+
c_function_name="LibRTStrings_StringWriter_internal",
139+
error_kind=ERR_MAGIC,
140+
experimental=True,
141+
dependencies=[LIBRT_STRINGS],
142+
)
143+
144+
method_op(
145+
name="getvalue",
146+
arg_types=[string_writer_rprimitive],
147+
return_type=str_rprimitive,
148+
c_function_name="LibRTStrings_StringWriter_getvalue_internal",
149+
error_kind=ERR_MAGIC,
150+
experimental=True,
151+
dependencies=[LIBRT_STRINGS],
152+
)
153+
154+
method_op(
155+
name="write",
156+
arg_types=[string_writer_rprimitive, str_rprimitive],
157+
return_type=none_rprimitive,
158+
c_function_name="LibRTStrings_StringWriter_write_internal",
159+
error_kind=ERR_MAGIC,
160+
experimental=True,
161+
dependencies=[LIBRT_STRINGS],
162+
)
163+
164+
method_op(
165+
name="append",
166+
arg_types=[string_writer_rprimitive, int32_rprimitive],
167+
return_type=none_rprimitive,
168+
c_function_name="CPyStringWriter_Append",
169+
error_kind=ERR_MAGIC,
170+
experimental=True,
171+
dependencies=[LIBRT_STRINGS, STRING_WRITER_EXTRA_OPS],
172+
)
173+
174+
function_op(
175+
name="builtins.len",
176+
arg_types=[string_writer_rprimitive],
177+
return_type=short_int_rprimitive,
178+
c_function_name="CPyStringWriter_Len",
179+
error_kind=ERR_NEVER,
180+
experimental=True,
181+
dependencies=[LIBRT_STRINGS, STRING_WRITER_EXTRA_OPS],
182+
)
183+
184+
# StringWriter index adjustment - convert negative index to positive
185+
string_writer_adjust_index_op = custom_primitive_op(
186+
name="string_writer_adjust_index",
187+
arg_types=[string_writer_rprimitive, int64_rprimitive],
188+
return_type=int64_rprimitive,
189+
c_function_name="CPyStringWriter_AdjustIndex",
190+
error_kind=ERR_NEVER,
191+
experimental=True,
192+
dependencies=[LIBRT_STRINGS, STRING_WRITER_EXTRA_OPS],
193+
)
194+
195+
# StringWriter range check - check if index is in valid range
196+
string_writer_range_check_op = custom_primitive_op(
197+
name="string_writer_range_check",
198+
arg_types=[string_writer_rprimitive, int64_rprimitive],
199+
return_type=bool_rprimitive,
200+
c_function_name="CPyStringWriter_RangeCheck",
201+
error_kind=ERR_NEVER,
202+
experimental=True,
203+
dependencies=[LIBRT_STRINGS, STRING_WRITER_EXTRA_OPS],
204+
)
205+
206+
# StringWriter.__getitem__() - get character at index (no bounds checking)
207+
string_writer_get_item_unsafe_op = custom_primitive_op(
208+
name="string_writer_get_item",
209+
arg_types=[string_writer_rprimitive, int64_rprimitive],
210+
return_type=int32_rprimitive,
211+
c_function_name="CPyStringWriter_GetItem",
212+
error_kind=ERR_NEVER,
213+
experimental=True,
214+
dependencies=[LIBRT_STRINGS, STRING_WRITER_EXTRA_OPS],
215+
)

mypyc/test-data/irbuild-librt-strings.test

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,55 @@ L1:
9393
L2:
9494
CPyBytesWriter_SetItem(b, r0, x)
9595
return 1
96+
97+
[case testLibrtStrings_StringWriter_experimental_64bit]
98+
from librt.strings import StringWriter
99+
from mypy_extensions import i32, i64
100+
101+
def string_writer_basics() -> str:
102+
s = StringWriter()
103+
s.append(65)
104+
s.write('foo')
105+
return s.getvalue()
106+
def string_writer_len(s: StringWriter) -> i64:
107+
return len(s)
108+
def string_writer_get_item(s: StringWriter, i: i64) -> i32:
109+
return s[i]
110+
[out]
111+
def string_writer_basics():
112+
r0, s :: librt.strings.StringWriter
113+
r1 :: None
114+
r2 :: str
115+
r3 :: None
116+
r4 :: str
117+
L0:
118+
r0 = LibRTStrings_StringWriter_internal()
119+
s = r0
120+
r1 = CPyStringWriter_Append(s, 65)
121+
r2 = 'foo'
122+
r3 = LibRTStrings_StringWriter_write_internal(s, r2)
123+
r4 = LibRTStrings_StringWriter_getvalue_internal(s)
124+
return r4
125+
def string_writer_len(s):
126+
s :: librt.strings.StringWriter
127+
r0 :: short_int
128+
r1 :: i64
129+
L0:
130+
r0 = CPyStringWriter_Len(s)
131+
r1 = r0 >> 1
132+
return r1
133+
def string_writer_get_item(s, i):
134+
s :: librt.strings.StringWriter
135+
i, r0 :: i64
136+
r1, r2 :: bool
137+
r3 :: i32
138+
L0:
139+
r0 = CPyStringWriter_AdjustIndex(s, i)
140+
r1 = CPyStringWriter_RangeCheck(s, r0)
141+
if r1 goto L2 else goto L1 :: bool
142+
L1:
143+
r2 = raise IndexError('index out of range')
144+
unreachable
145+
L2:
146+
r3 = CPyStringWriter_GetItem(s, r0)
147+
return r3

0 commit comments

Comments
 (0)