Skip to content

Commit 20e7ed6

Browse files
committed
Add unit tests for column_encryption_policy optimization
Cover both the pure-Python path (ResultMessage.recv_results_rows) and the Cython fast path (ListParser, NumpyParser) with and without column encryption policy.
1 parent dacb79c commit 20e7ed6

1 file changed

Lines changed: 252 additions & 2 deletions

File tree

tests/unit/test_protocol.py

Lines changed: 252 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,26 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import io
1516
import unittest
1617

1718
from unittest.mock import Mock
1819

1920
from cassandra import ProtocolVersion, UnsupportedOperation
21+
from cassandra.cqltypes import Int32Type, UTF8Type
2022
from cassandra.protocol import (
2123
PrepareMessage, QueryMessage, ExecuteMessage, UnsupportedOperation,
2224
_PAGING_OPTIONS_FLAG, _WITH_SERIAL_CONSISTENCY_FLAG,
2325
_PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG,
24-
BatchMessage
26+
BatchMessage,
27+
ResultMessage, RESULT_KIND_ROWS
2528
)
2629
from cassandra.query import BatchType
27-
from cassandra.marshal import uint32_unpack
30+
from cassandra.marshal import uint32_unpack, int32_pack
2831
from cassandra.cluster import ContinuousPagingOptions
2932
import pytest
3033

34+
from cassandra.policies import ColDesc
3135

3236
class MessageTest(unittest.TestCase):
3337

@@ -189,3 +193,249 @@ def test_batch_message_with_keyspace(self):
189193
(b'\x00\x03',),
190194
(b'\x00\x00\x00\x80',), (b'\x00\x02',), (b'ks',))
191195
)
196+
197+
class ResultTest(unittest.TestCase):
198+
"""
199+
Tests to verify the optimization of column_encryption_policy checks
200+
in recv_results_rows. The optimization checks if the policy exists once
201+
per result message, avoiding the redundant 'column_encryption_policy and ...'
202+
check for every value.
203+
"""
204+
205+
def _create_mock_result_metadata(self):
206+
"""Create mock result metadata for testing"""
207+
return [
208+
('keyspace1', 'table1', 'col1', Int32Type),
209+
('keyspace1', 'table1', 'col2', UTF8Type),
210+
]
211+
212+
def _create_mock_result_message(self):
213+
"""Create a mock result message with data"""
214+
msg = ResultMessage(kind=RESULT_KIND_ROWS)
215+
msg.column_metadata = self._create_mock_result_metadata()
216+
msg.recv_results_metadata = Mock()
217+
msg.recv_row = Mock(side_effect=[
218+
[int32_pack(42), b'hello'],
219+
[int32_pack(100), b'world'],
220+
])
221+
return msg
222+
223+
def _create_mock_stream(self):
224+
"""Create a mock stream for reading rows"""
225+
# Pack rowcount (2 rows)
226+
data = int32_pack(2)
227+
return io.BytesIO(data)
228+
229+
def test_decode_without_encryption_policy(self):
230+
"""
231+
Test that decoding works correctly without column encryption policy.
232+
This should use the optimized simple path.
233+
"""
234+
msg = self._create_mock_result_message()
235+
f = self._create_mock_stream()
236+
237+
msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, None)
238+
239+
# Verify results
240+
self.assertEqual(len(msg.parsed_rows), 2)
241+
self.assertEqual(msg.parsed_rows[0][0], 42)
242+
self.assertEqual(msg.parsed_rows[0][1], 'hello')
243+
self.assertEqual(msg.parsed_rows[1][0], 100)
244+
self.assertEqual(msg.parsed_rows[1][1], 'world')
245+
246+
def test_decode_with_encryption_policy_no_encrypted_columns(self):
247+
"""
248+
Test that decoding works with encryption policy when no columns are encrypted.
249+
"""
250+
msg = self._create_mock_result_message()
251+
f = self._create_mock_stream()
252+
253+
# Create mock encryption policy that has no encrypted columns
254+
mock_policy = Mock()
255+
mock_policy.contains_column = Mock(return_value=False)
256+
257+
msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)
258+
259+
# Verify results
260+
self.assertEqual(len(msg.parsed_rows), 2)
261+
self.assertEqual(msg.parsed_rows[0][0], 42)
262+
self.assertEqual(msg.parsed_rows[0][1], 'hello')
263+
264+
# Verify contains_column was called for each value (but policy existence check happens once)
265+
# Should be called 4 times (2 rows × 2 columns)
266+
self.assertEqual(mock_policy.contains_column.call_count, 4)
267+
268+
def test_decode_with_encryption_policy_with_encrypted_column(self):
269+
"""
270+
Test that decoding works with encryption policy when one column is encrypted.
271+
"""
272+
msg = self._create_mock_result_message()
273+
f = self._create_mock_stream()
274+
275+
# Create mock encryption policy where first column is encrypted
276+
mock_policy = Mock()
277+
def contains_column_side_effect(col_desc):
278+
return col_desc.col == 'col1'
279+
mock_policy.contains_column = Mock(side_effect=contains_column_side_effect)
280+
mock_policy.column_type = Mock(return_value=Int32Type)
281+
mock_policy.decrypt = Mock(side_effect=lambda col_desc, val: val)
282+
283+
msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)
284+
285+
# Verify results
286+
self.assertEqual(len(msg.parsed_rows), 2)
287+
self.assertEqual(msg.parsed_rows[0][0], 42)
288+
self.assertEqual(msg.parsed_rows[0][1], 'hello')
289+
290+
# Verify contains_column was called for each value (but policy existence check happens once)
291+
# Should be called 4 times (2 rows × 2 columns)
292+
self.assertEqual(mock_policy.contains_column.call_count, 4)
293+
294+
# Verify decrypt was called for each encrypted value (2 rows * 1 encrypted column)
295+
self.assertEqual(mock_policy.decrypt.call_count, 2)
296+
297+
def test_optimization_efficiency(self):
298+
"""
299+
Verify that the optimization checks policy existence once per result message.
300+
The key optimization is checking 'if column_encryption_policy:' once,
301+
rather than 'column_encryption_policy and ...' for every value.
302+
"""
303+
msg = self._create_mock_result_message()
304+
305+
# Create more rows to make the check pattern clear
306+
msg.recv_row = Mock(side_effect=[
307+
[int32_pack(i), f'text{i}'.encode()] for i in range(100)
308+
])
309+
310+
# Create mock stream with 100 rows
311+
f = io.BytesIO(int32_pack(100))
312+
313+
mock_policy = Mock()
314+
mock_policy.contains_column = Mock(return_value=False)
315+
316+
msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)
317+
318+
# With optimization: policy existence checked once, contains_column called per value
319+
# = 100 rows * 2 columns = 200 calls to contains_column
320+
# The key is we avoid checking 'column_encryption_policy and ...' 200 times
321+
self.assertEqual(mock_policy.contains_column.call_count, 200,
322+
"contains_column should be called for each value when policy exists")
323+
324+
325+
class CythonParserTest(unittest.TestCase):
326+
"""
327+
Tests for the Cython fast-path parsers (ListParser, TupleRowParser)
328+
to verify the column_encryption_policy optimization in obj_parser.pyx.
329+
"""
330+
331+
def _build_binary_rows(self, rows):
332+
"""
333+
Build a binary buffer containing encoded rows.
334+
335+
Each row is a list of (size, raw_bytes) pairs.
336+
Prepends a 4-byte big-endian row count.
337+
"""
338+
import struct
339+
data = struct.pack('>i', len(rows))
340+
for row in rows:
341+
for raw in row:
342+
if raw is None:
343+
data += struct.pack('>i', -1) # NULL
344+
else:
345+
data += struct.pack('>i', len(raw)) + raw
346+
return data
347+
348+
def _make_parse_desc(self, column_encryption_policy=None):
349+
from cassandra.parsing import ParseDesc
350+
from cassandra.deserializers import make_deserializers
351+
from cassandra.policies import ColDesc
352+
353+
colnames = ['col1', 'col2']
354+
coltypes = [Int32Type, UTF8Type]
355+
coldescs = [ColDesc('ks', 'tbl', 'col1'), ColDesc('ks', 'tbl', 'col2')]
356+
deserializers = make_deserializers(coltypes)
357+
return ParseDesc(colnames, coltypes, column_encryption_policy,
358+
coldescs, deserializers, ProtocolVersion.V4)
359+
360+
def _int32_bytes(self, val):
361+
import struct
362+
return struct.pack('>i', val)
363+
364+
def test_list_parser_without_encryption(self):
365+
"""ListParser decodes rows correctly without encryption policy."""
366+
from cassandra.bytesio import BytesIOReader
367+
from cassandra.obj_parser import ListParser
368+
369+
desc = self._make_parse_desc(column_encryption_policy=None)
370+
data = self._build_binary_rows([
371+
[self._int32_bytes(42), b'hello'],
372+
[self._int32_bytes(100), b'world'],
373+
])
374+
reader = BytesIOReader(data)
375+
result = ListParser().parse_rows(reader, desc)
376+
377+
self.assertEqual(len(result), 2)
378+
self.assertEqual(result[0], (42, 'hello'))
379+
self.assertEqual(result[1], (100, 'world'))
380+
381+
def test_list_parser_with_encryption_no_encrypted_cols(self):
382+
"""ListParser decodes rows correctly when policy exists but no columns are encrypted."""
383+
from cassandra.bytesio import BytesIOReader
384+
from cassandra.obj_parser import ListParser
385+
386+
mock_policy = Mock()
387+
mock_policy.contains_column = Mock(return_value=False)
388+
389+
desc = self._make_parse_desc(column_encryption_policy=mock_policy)
390+
data = self._build_binary_rows([
391+
[self._int32_bytes(42), b'hello'],
392+
])
393+
reader = BytesIOReader(data)
394+
result = ListParser().parse_rows(reader, desc)
395+
396+
self.assertEqual(len(result), 1)
397+
self.assertEqual(result[0], (42, 'hello'))
398+
# 1 row * 2 columns = 2 calls
399+
self.assertEqual(mock_policy.contains_column.call_count, 2)
400+
401+
def test_list_parser_with_encrypted_column(self):
402+
"""ListParser decodes rows with an encrypted column (mock decrypt is identity)."""
403+
from cassandra.bytesio import BytesIOReader
404+
from cassandra.obj_parser import ListParser
405+
from cassandra.deserializers import find_deserializer
406+
407+
mock_policy = Mock()
408+
mock_policy.contains_column = Mock(
409+
side_effect=lambda cd: cd.col == 'col1')
410+
mock_policy.column_type = Mock(return_value=Int32Type)
411+
# decrypt returns the raw bytes unchanged (identity)
412+
mock_policy.decrypt = Mock(side_effect=lambda cd, val: val)
413+
414+
desc = self._make_parse_desc(column_encryption_policy=mock_policy)
415+
data = self._build_binary_rows([
416+
[self._int32_bytes(7), b'test'],
417+
])
418+
reader = BytesIOReader(data)
419+
result = ListParser().parse_rows(reader, desc)
420+
421+
self.assertEqual(len(result), 1)
422+
self.assertEqual(result[0], (7, 'test'))
423+
self.assertEqual(mock_policy.decrypt.call_count, 1)
424+
self.assertEqual(mock_policy.column_type.call_count, 1)
425+
426+
def test_numpy_parser_rejects_encryption(self):
427+
"""NumpyParser raises NotImplementedError when column_encryption_policy is set."""
428+
try:
429+
from cassandra.numpy_parser import NumpyParser
430+
except ImportError:
431+
self.skipTest("NumPy or numpy_parser not available")
432+
433+
from cassandra.bytesio import BytesIOReader
434+
435+
mock_policy = Mock()
436+
desc = self._make_parse_desc(column_encryption_policy=mock_policy)
437+
data = self._build_binary_rows([[self._int32_bytes(1), b'x']])
438+
reader = BytesIOReader(data)
439+
440+
with self.assertRaises(NotImplementedError):
441+
NumpyParser().parse_rows(reader, desc)

0 commit comments

Comments
 (0)