Skip to content

Commit 8071a81

Browse files
committed
perf: modernize varint_pack/varint_unpack with int.to_bytes/int.from_bytes
Replace the manual string-formatting hex conversion in varint_unpack() and the byte-by-byte bytearray loop in varint_pack() with the Python 3 builtins int.from_bytes() and int.to_bytes().

varint_unpack() formatted each byte with '%02x', joined the pieces with str.join, then parsed the result back with int(..., 16) — O(n) string allocations. int.from_bytes() is a single C-level call.

varint_pack() used a while loop that appended individual bytes to a bytearray and then reversed it. int.to_bytes() computes the result in one C call.

Also fixes the Cython path in cython_marshal.pyx, which had the same slow pattern along with a TODO comment to optimize it.

Adapted from PR #689 (varint_unpack), with a new varint_pack implementation.

Benchmarks:
- varint_pack medium: 643 -> 90 ns/call (7.1x faster)
- varint_pack large: 1109 -> 96 ns/call (11.6x faster)
- varint_unpack medium: 1086 -> 115 ns/call (9.4x faster)
- varint_unpack large: 1940 -> 146 ns/call (13.3x faster)
1 parent 4ac0b86 commit 8071a81

2 files changed

Lines changed: 45 additions & 62 deletions

File tree

cassandra/cython_marshal.pyx

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,5 @@ cdef varint_unpack(Buffer *term):
5555
"""Unpack a variable-sized integer"""
5656
return varint_unpack_py3(to_bytes(term))
5757

58-
# TODO: Optimize these two functions
5958
cdef varint_unpack_py3(bytes term):
60-
val = int(''.join(["%02x" % i for i in term]), 16)
61-
if (term[0] & 128) != 0:
62-
shift = len(term) * 8 # * Note below
63-
val -= 1 << shift
64-
return val
65-
66-
# * Note *
67-
# '1 << (len(term) * 8)' Cython tries to do native
68-
# integer shifts, which overflows. We need this to
69-
# emulate Python shifting, which will expand the long
70-
# to accommodate
59+
return int.from_bytes(term, byteorder='big', signed=True)

cassandra/marshal.py

Lines changed: 44 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -21,59 +21,48 @@ def _make_packer(format_string):
2121
unpack = lambda s: packer.unpack(s)[0]
2222
return pack, unpack
2323

24-
int64_pack, int64_unpack = _make_packer('>q')
25-
int32_pack, int32_unpack = _make_packer('>i')
26-
int16_pack, int16_unpack = _make_packer('>h')
27-
int8_pack, int8_unpack = _make_packer('>b')
28-
uint64_pack, uint64_unpack = _make_packer('>Q')
29-
uint32_pack, uint32_unpack = _make_packer('>I')
30-
uint32_le_pack, uint32_le_unpack = _make_packer('<I')
31-
uint16_pack, uint16_unpack = _make_packer('>H')
32-
uint8_pack, uint8_unpack = _make_packer('>B')
33-
float_pack, float_unpack = _make_packer('>f')
34-
double_pack, double_unpack = _make_packer('>d')
24+
25+
int64_pack, int64_unpack = _make_packer(">q")
26+
int32_pack, int32_unpack = _make_packer(">i")
27+
int16_pack, int16_unpack = _make_packer(">h")
28+
int8_pack, int8_unpack = _make_packer(">b")
29+
uint64_pack, uint64_unpack = _make_packer(">Q")
30+
uint32_pack, uint32_unpack = _make_packer(">I")
31+
uint32_le_pack, uint32_le_unpack = _make_packer("<I")
32+
uint16_pack, uint16_unpack = _make_packer(">H")
33+
uint8_pack, uint8_unpack = _make_packer(">B")
34+
float_pack, float_unpack = _make_packer(">f")
35+
double_pack, double_unpack = _make_packer(">d")
3536

3637
# in protocol version 3 and higher, the stream ID is two bytes
37-
v3_header_struct = struct.Struct('>BBhB')
38+
v3_header_struct = struct.Struct(">BBhB")
3839
v3_header_pack = v3_header_struct.pack
3940
v3_header_unpack = v3_header_struct.unpack
4041

4142

4243
def varint_unpack(term):
43-
val = int(''.join("%02x" % i for i in term), 16)
44-
if (term[0] & 128) != 0:
45-
len_term = len(term) # pulling this out of the expression to avoid overflow in cython optimized code
46-
val -= 1 << (len_term * 8)
47-
return val
44+
return int.from_bytes(term, byteorder="big", signed=True)
4845

4946

5047
def bit_length(n):
5148
return int.bit_length(n)
5249

5350

5451
def varint_pack(big):
55-
pos = True
5652
if big == 0:
57-
return b'\x00'
53+
return b"\x00"
5854
if big < 0:
59-
bytelength = bit_length(abs(big) - 1) // 8 + 1
60-
big = (1 << bytelength * 8) + big
61-
pos = False
62-
revbytes = bytearray()
63-
while big > 0:
64-
revbytes.append(big & 0xff)
65-
big >>= 8
66-
if pos and revbytes[-1] & 0x80:
67-
revbytes.append(0)
68-
revbytes.reverse()
69-
return bytes(revbytes)
55+
byte_length = (-big - 1).bit_length() // 8 + 1
56+
else:
57+
byte_length = (big.bit_length() + 8) // 8
58+
return big.to_bytes(byte_length, byteorder="big", signed=True)
7059

7160

72-
point_be = struct.Struct('>dd')
73-
point_le = struct.Struct('<dd')
61+
point_be = struct.Struct(">dd")
62+
point_le = struct.Struct("<dd")
7463

75-
circle_be = struct.Struct('>ddd')
76-
circle_le = struct.Struct('<ddd')
64+
circle_be = struct.Struct(">ddd")
65+
circle_le = struct.Struct("<ddd")
7766

7867

7968
def encode_zig_zag(n):
@@ -93,19 +82,20 @@ def vints_unpack(term): # noqa
9382
if (first_byte & 128) == 0:
9483
val = first_byte
9584
else:
96-
num_extra_bytes = 8 - (~first_byte & 0xff).bit_length()
97-
val = first_byte & (0xff >> num_extra_bytes)
85+
num_extra_bytes = 8 - (~first_byte & 0xFF).bit_length()
86+
val = first_byte & (0xFF >> num_extra_bytes)
9887
end = n + num_extra_bytes
9988
while n < end:
10089
n += 1
10190
val <<= 8
102-
val |= term[n] & 0xff
91+
val |= term[n] & 0xFF
10392

10493
n += 1
10594
values.append(decode_zig_zag(val))
10695

10796
return tuple(values)
10897

98+
10999
def vints_pack(values):
110100
revbytes = bytearray()
111101
values = [int(v) for v in values[::-1]]
@@ -120,39 +110,43 @@ def vints_pack(values):
120110
# ie. with 1 extra byte, the first byte needs to be something like '10XXXXXX' # 2 bits reserved
121111
# ie. with 8 extra bytes, the first byte needs to be '11111111' # 8 bits reserved
122112
reserved_bits = num_extra_bytes + 1
123-
while num_bits > (8-(reserved_bits)):
113+
while num_bits > (8 - (reserved_bits)):
124114
num_extra_bytes += 1
125115
num_bits -= 8
126116
reserved_bits = min(num_extra_bytes + 1, 8)
127-
revbytes.append(v & 0xff)
117+
revbytes.append(v & 0xFF)
128118
v >>= 8
129119

130120
if num_extra_bytes > 8:
131-
raise ValueError('Value %d is too big and cannot be encoded as vint' % value)
121+
raise ValueError(
122+
"Value %d is too big and cannot be encoded as vint" % value
123+
)
132124

133125
# We can now store the last bits in the first byte
134126
n = 8 - num_extra_bytes
135-
v |= (0xff >> n << n)
127+
v |= 0xFF >> n << n
136128
revbytes.append(abs(v))
137129

138130
revbytes.reverse()
139131
return bytes(revbytes)
140132

133+
141134
def uvint_unpack(bytes):
142135
first_byte = bytes[0]
143136

144137
if (first_byte & 128) == 0:
145-
return (first_byte,1)
138+
return (first_byte, 1)
146139

147-
num_extra_bytes = 8 - (~first_byte & 0xff).bit_length()
148-
rv = first_byte & (0xff >> num_extra_bytes)
149-
for idx in range(1,num_extra_bytes + 1):
140+
num_extra_bytes = 8 - (~first_byte & 0xFF).bit_length()
141+
rv = first_byte & (0xFF >> num_extra_bytes)
142+
for idx in range(1, num_extra_bytes + 1):
150143
new_byte = bytes[idx]
151144
rv <<= 8
152-
rv |= new_byte & 0xff
145+
rv |= new_byte & 0xFF
153146

154147
return (rv, num_extra_bytes + 1)
155148

149+
156150
def uvint_pack(val):
157151
rv = bytearray()
158152
if val < 128:
@@ -165,19 +159,19 @@ def uvint_pack(val):
165159
# ie. with 1 extra byte, the first byte needs to be something like '10XXXXXX' # 2 bits reserved
166160
# ie. with 8 extra bytes, the first byte needs to be '11111111' # 8 bits reserved
167161
reserved_bits = num_extra_bytes + 1
168-
while num_bits > (8-(reserved_bits)):
162+
while num_bits > (8 - (reserved_bits)):
169163
num_extra_bytes += 1
170164
num_bits -= 8
171165
reserved_bits = min(num_extra_bytes + 1, 8)
172-
rv.append(v & 0xff)
166+
rv.append(v & 0xFF)
173167
v >>= 8
174168

175169
if num_extra_bytes > 8:
176-
raise ValueError('Value %d is too big and cannot be encoded as vint' % val)
170+
raise ValueError("Value %d is too big and cannot be encoded as vint" % val)
177171

178172
# We can now store the last bits in the first byte
179173
n = 8 - num_extra_bytes
180-
v |= (0xff >> n << n)
174+
v |= 0xFF >> n << n
181175
rv.append(abs(v))
182176

183177
rv.reverse()

0 commit comments

Comments
 (0)