Skip to content

Commit 683acd7

Browse files
committed
TEST: multi query batching test for memory interface from review (#845)
1 parent feea199 commit 683acd7

1 file changed

Lines changed: 139 additions & 186 deletions

File tree

tests/unit/memory/memory_interface/test_batching_scale.py

Lines changed: 139 additions & 186 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,10 @@
99

1010
import hashlib
1111
import uuid
12-
from unittest.mock import MagicMock, patch
13-
14-
from sqlalchemy import Column, Integer
15-
from sqlalchemy.orm import declarative_base
12+
from unittest.mock import patch
1613

1714
from pyrit.memory import MemoryInterface
18-
from pyrit.memory.memory_interface import _SQLITE_MAX_BIND_VARS, _batched_in_condition
15+
from pyrit.memory.memory_interface import _SQLITE_MAX_BIND_VARS
1916
from pyrit.models import MessagePiece, Score
2017

2118

@@ -51,157 +48,6 @@ def _create_score(message_piece_id: str) -> Score:
5148
)
5249

5350

54-
class TestBatchedInCondition:
55-
"""Tests for the _batched_in_condition helper function."""
56-
57-
def test_batched_in_condition_small_list(self):
58-
"""Test that small lists generate a simple IN condition."""
59-
Base = declarative_base()
60-
61-
class TestModel(Base):
62-
__tablename__ = "test"
63-
id = Column(Integer, primary_key=True)
64-
65-
values = list(range(10))
66-
condition = _batched_in_condition(TestModel.id, values)
67-
68-
# Should be a simple IN clause, not an OR
69-
assert "IN" in str(condition)
70-
assert "OR" not in str(condition)
71-
72-
def test_batched_in_condition_exact_batch_size(self):
73-
"""Test with exactly _SQLITE_MAX_BIND_VARS values."""
74-
Base = declarative_base()
75-
76-
class TestModel(Base):
77-
__tablename__ = "test"
78-
id = Column(Integer, primary_key=True)
79-
80-
values = list(range(_SQLITE_MAX_BIND_VARS))
81-
condition = _batched_in_condition(TestModel.id, values)
82-
83-
# Should still be a simple IN clause at the limit
84-
assert "IN" in str(condition)
85-
# May or may not have OR depending on implementation at boundary
86-
87-
def test_batched_in_condition_over_batch_size(self):
88-
"""Test with values exceeding _SQLITE_MAX_BIND_VARS."""
89-
Base = declarative_base()
90-
91-
class TestModel(Base):
92-
__tablename__ = "test"
93-
id = Column(Integer, primary_key=True)
94-
95-
values = list(range(_SQLITE_MAX_BIND_VARS + 100))
96-
condition = _batched_in_condition(TestModel.id, values)
97-
98-
# Should generate OR of multiple IN clauses
99-
condition_str = str(condition)
100-
assert "OR" in condition_str
101-
assert "IN" in condition_str
102-
103-
def test_batched_in_condition_double_batch_size(self):
104-
"""Test with double the batch size."""
105-
Base = declarative_base()
106-
107-
class TestModel(Base):
108-
__tablename__ = "test"
109-
id = Column(Integer, primary_key=True)
110-
111-
values = list(range(_SQLITE_MAX_BIND_VARS * 2))
112-
condition = _batched_in_condition(TestModel.id, values)
113-
114-
# Should generate multiple batches
115-
condition_str = str(condition)
116-
assert "OR" in condition_str
117-
# Should have at least 2 IN clauses
118-
assert condition_str.count("IN") >= 2
119-
120-
def test_batched_in_condition_three_batches(self):
121-
"""Test with enough values to require three batches."""
122-
Base = declarative_base()
123-
124-
class TestModel(Base):
125-
__tablename__ = "test"
126-
id = Column(Integer, primary_key=True)
127-
128-
values = list(range(_SQLITE_MAX_BIND_VARS * 2 + 100))
129-
condition = _batched_in_condition(TestModel.id, values)
130-
131-
condition_str = str(condition)
132-
assert "OR" in condition_str
133-
# Should have at least 3 IN clauses
134-
assert condition_str.count("IN") >= 3
135-
136-
def test_batched_in_condition_empty_list(self):
137-
"""Test with an empty list."""
138-
Base = declarative_base()
139-
140-
class TestModel(Base):
141-
__tablename__ = "test"
142-
id = Column(Integer, primary_key=True)
143-
144-
values = []
145-
condition = _batched_in_condition(TestModel.id, values)
146-
147-
# Empty list should still generate valid SQL
148-
condition_str = str(condition)
149-
assert "IN" in condition_str
150-
151-
def test_batched_in_condition_multiple_columns(self):
152-
"""Test combining multiple batched conditions with AND logic."""
153-
from sqlalchemy import String, and_
154-
155-
Base = declarative_base()
156-
157-
class TestModel(Base):
158-
__tablename__ = "test"
159-
id = Column(Integer, primary_key=True)
160-
name = Column(String)
161-
email = Column(String)
162-
163-
# Create multiple large value lists for different columns
164-
num_values = (_SQLITE_MAX_BIND_VARS * 2) + 100
165-
id_values = list(range(num_values))
166-
name_values = [f"name_{i}" for i in range(num_values)]
167-
email_values = [f"email_{i}@test.com" for i in range(num_values)]
168-
169-
# Create batched conditions for each column
170-
id_condition = _batched_in_condition(TestModel.id, id_values)
171-
name_condition = _batched_in_condition(TestModel.name, name_values)
172-
email_condition = _batched_in_condition(TestModel.email, email_values)
173-
174-
# Combine with AND (simulating real query behavior)
175-
combined_condition = and_(id_condition, name_condition, email_condition)
176-
combined_str = str(combined_condition)
177-
178-
# Verify all three columns are present in the query
179-
assert "id" in combined_str.lower()
180-
assert "name" in combined_str.lower()
181-
assert "email" in combined_str.lower()
182-
183-
# Verify OR clauses are present (batching is active)
184-
assert combined_str.count("OR") >= 3 # At least one OR per batched column
185-
186-
# Verify AND combines the conditions
187-
assert "AND" in combined_str
188-
189-
# Verify `id` count matches expected batches
190-
expected_id_batches = (num_values + _SQLITE_MAX_BIND_VARS - 1) // _SQLITE_MAX_BIND_VARS
191-
actual_id_batches = combined_str.count("id IN")
192-
assert actual_id_batches == expected_id_batches
193-
194-
# Verify `name` count matches expected batches
195-
expected_name_batches = (num_values + _SQLITE_MAX_BIND_VARS - 1) // _SQLITE_MAX_BIND_VARS
196-
actual_name_batches = combined_str.count("name IN")
197-
assert actual_name_batches == expected_name_batches
198-
199-
# Verify `email` count matches expected batches
200-
expected_email_batches = (num_values + _SQLITE_MAX_BIND_VARS - 1) // _SQLITE_MAX_BIND_VARS
201-
actual_email_batches = combined_str.count("email IN")
202-
assert actual_email_batches == expected_email_batches
203-
204-
20551
class TestBatchingScale:
20652
"""Tests for batching when querying with many IDs."""
20753

@@ -370,9 +216,7 @@ def test_get_message_pieces_multiple_large_params_simultaneously(self, sqlite_in
370216
"""Test batching with multiple parameters exceeding batch limit simultaneously."""
371217
# Create enough pieces to exceed batch limit with unique values
372218
num_pieces = _SQLITE_MAX_BIND_VARS + 200
373-
pieces = [
374-
_create_message_piece(original_value=f"original_value_{i}") for i in range(num_pieces)
375-
]
219+
pieces = [_create_message_piece(original_value=f"original_value_{i}") for i in range(num_pieces)]
376220

377221
# Add to memory
378222
sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
@@ -396,8 +240,7 @@ def test_get_message_pieces_multiple_large_params_simultaneously(self, sqlite_in
396240

397241
# Should return all pieces that match ALL conditions (intersection)
398242
assert len(results) == num_pieces, (
399-
f"Expected {num_pieces} results when filtering with multiple large parameters, "
400-
f"got {len(results)}"
243+
f"Expected {num_pieces} results when filtering with multiple large parameters, got {len(results)}"
401244
)
402245

403246
# Verify all returned pieces match all filter criteria
@@ -410,12 +253,10 @@ def test_get_message_pieces_multiple_large_params_simultaneously(self, sqlite_in
410253
assert result_sha256 == set(all_sha256), "Returned SHA256 hashes don't match filter"
411254

412255
def test_get_message_pieces_multiple_batched_params_with_query_spy(self, sqlite_instance: MemoryInterface):
413-
"""Test that batching generates correct queries when multiple params exceed limit."""
256+
"""Test that batching executes multiple separate queries and merges results correctly."""
414257
# Create pieces exceeding batch limit
415258
num_pieces = _SQLITE_MAX_BIND_VARS + 100
416-
pieces = [
417-
_create_message_piece(original_value=f"value_{i}") for i in range(num_pieces)
418-
]
259+
pieces = [_create_message_piece(original_value=f"value_{i}") for i in range(num_pieces)]
419260
sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
420261

421262
# Get stored pieces
@@ -426,41 +267,33 @@ def test_get_message_pieces_multiple_batched_params_with_query_spy(self, sqlite_
426267
# Mock _query_entries to track how it's called
427268
original_query = sqlite_instance._query_entries
428269
call_count = 0
429-
captured_conditions = []
430270

431271
def spy_query(*args, **kwargs):
432272
nonlocal call_count
433273
call_count += 1
434-
if "conditions" in kwargs and kwargs["conditions"] is not None:
435-
captured_conditions.append(str(kwargs["conditions"]))
436274
return original_query(*args, **kwargs)
437275

438276
with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query):
439-
results = sqlite_instance.get_message_pieces(
440-
prompt_ids=all_ids, original_values=all_original_values
441-
)
277+
results = sqlite_instance.get_message_pieces(prompt_ids=all_ids, original_values=all_original_values)
442278

443279
# Should get all results despite batching
444280
assert len(results) == num_pieces
445281

446-
# Should have been called (could be 1 call with OR conditions)
447-
assert call_count >= 1
448-
449-
# Verify query conditions include both filters
450-
if captured_conditions:
451-
combined_conditions = " ".join(captured_conditions)
452-
# Both column filters should be present in the query
453-
assert "id" in combined_conditions.lower() or "prompt" in combined_conditions.lower()
454-
assert "original_value" in combined_conditions.lower()
282+
# With the new batching approach, multiple separate queries should be executed
283+
# when the primary batch parameter exceeds _SQLITE_MAX_BIND_VARS
284+
# Expected: ceil(num_pieces / _SQLITE_MAX_BIND_VARS) = 2 queries
285+
expected_min_calls = (num_pieces + _SQLITE_MAX_BIND_VARS - 1) // _SQLITE_MAX_BIND_VARS
286+
assert call_count >= expected_min_calls, (
287+
f"Expected at least {expected_min_calls} separate queries for {num_pieces} items, "
288+
f"but only got {call_count} calls"
289+
)
455290

456291
def test_get_message_pieces_triple_large_params_preserves_intersection(self, sqlite_instance: MemoryInterface):
457292
"""Test that filtering with 3 large parameter lists returns correct intersection."""
458293
# Create a large set of pieces
459294
total_pieces = _SQLITE_MAX_BIND_VARS + 150
460295
pieces = [
461-
_create_message_piece(
462-
conversation_id=str(uuid.uuid4()), original_value=f"content_{i}"
463-
)
296+
_create_message_piece(conversation_id=str(uuid.uuid4()), original_value=f"content_{i}")
464297
for i in range(total_pieces)
465298
]
466299
sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
@@ -487,10 +320,130 @@ def test_get_message_pieces_triple_large_params_preserves_intersection(self, sql
487320
)
488321

489322
# Should return only the intersection (subset_size items)
490-
assert (
491-
len(results) == subset_size
492-
), f"Expected {subset_size} results from intersection, got {len(results)}"
323+
assert len(results) == subset_size, f"Expected {subset_size} results from intersection, got {len(results)}"
493324

494325
# Verify all results have SHA256 in the filter list
495326
result_sha256 = {r.converted_value_sha256 for r in results}
496327
assert result_sha256.issubset(set(filter_sha256)), "Results contain unexpected SHA256 values"
328+
329+
330+
class TestExecuteBatchedQuery:
331+
"""Tests for the _execute_batched_query helper method."""
332+
333+
def test_execute_batched_query_small_list_single_query(self, sqlite_instance: MemoryInterface):
334+
"""Test that small lists execute a single query."""
335+
# Create a small number of pieces (under batch limit)
336+
num_pieces = 10
337+
pieces = [_create_message_piece() for _ in range(num_pieces)]
338+
sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
339+
340+
# Track query calls
341+
original_query = sqlite_instance._query_entries
342+
call_count = 0
343+
344+
def spy_query(*args, **kwargs):
345+
nonlocal call_count
346+
call_count += 1
347+
return original_query(*args, **kwargs)
348+
349+
with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query):
350+
all_ids = [piece.id for piece in pieces]
351+
results = sqlite_instance.get_message_pieces(prompt_ids=all_ids)
352+
353+
# Should be a single query for small lists
354+
assert call_count == 1
355+
assert len(results) == num_pieces
356+
357+
def test_execute_batched_query_large_list_multiple_queries(self, sqlite_instance: MemoryInterface):
358+
"""Test that large lists execute multiple separate queries."""
359+
# Create pieces exceeding batch limit
360+
num_pieces = _SQLITE_MAX_BIND_VARS * 3 # 3 batches needed
361+
pieces = [_create_message_piece() for _ in range(num_pieces)]
362+
sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
363+
364+
# Track query calls
365+
original_query = sqlite_instance._query_entries
366+
call_count = 0
367+
368+
def spy_query(*args, **kwargs):
369+
nonlocal call_count
370+
call_count += 1
371+
return original_query(*args, **kwargs)
372+
373+
with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query):
374+
all_ids = [piece.id for piece in pieces]
375+
results = sqlite_instance.get_message_pieces(prompt_ids=all_ids)
376+
377+
# Should execute 3 separate queries (one per batch)
378+
assert call_count == 3, f"Expected 3 queries for 3 batches, got {call_count}"
379+
assert len(results) == num_pieces
380+
381+
def test_execute_batched_query_deduplicates_results(self, sqlite_instance: MemoryInterface):
382+
"""Test that batched queries properly deduplicate results."""
383+
# Create pieces
384+
num_pieces = 50
385+
pieces = [_create_message_piece() for _ in range(num_pieces)]
386+
sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
387+
388+
# Query with the same IDs repeated (should still return unique results)
389+
all_ids = [piece.id for piece in pieces]
390+
# Query twice with same IDs - results should still be unique
391+
results = sqlite_instance.get_message_pieces(prompt_ids=all_ids)
392+
393+
assert len(results) == num_pieces
394+
# Verify no duplicates
395+
result_ids = [r.id for r in results]
396+
assert len(result_ids) == len(set(result_ids)), "Results contain duplicate entries"
397+
398+
def test_execute_batched_query_exact_batch_boundary(self, sqlite_instance: MemoryInterface):
399+
"""Test querying with exactly the batch limit (edge case)."""
400+
num_pieces = _SQLITE_MAX_BIND_VARS
401+
pieces = [_create_message_piece() for _ in range(num_pieces)]
402+
sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
403+
404+
# Track query calls
405+
original_query = sqlite_instance._query_entries
406+
call_count = 0
407+
408+
def spy_query(*args, **kwargs):
409+
nonlocal call_count
410+
call_count += 1
411+
return original_query(*args, **kwargs)
412+
413+
with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query):
414+
all_ids = [piece.id for piece in pieces]
415+
results = sqlite_instance.get_message_pieces(prompt_ids=all_ids)
416+
417+
# Exactly at the limit should still be a single query
418+
assert call_count == 1, f"Expected 1 query at exact batch limit, got {call_count}"
419+
assert len(results) == num_pieces
420+
421+
def test_batching_with_scores_exceeds_limit(self, sqlite_instance: MemoryInterface):
422+
"""Test that get_scores handles large numbers of score IDs correctly."""
423+
# Create message pieces and scores exceeding the limit
424+
num_items = _SQLITE_MAX_BIND_VARS * 2 + 50
425+
pieces = [_create_message_piece() for _ in range(num_items)]
426+
sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
427+
428+
scores = [_create_score(str(piece.id)) for piece in pieces]
429+
sqlite_instance.add_scores_to_memory(scores=scores)
430+
431+
# Query with all score IDs
432+
all_score_ids = [str(score.id) for score in scores]
433+
434+
# Track query calls
435+
original_query = sqlite_instance._query_entries
436+
call_count = 0
437+
438+
def spy_query(*args, **kwargs):
439+
nonlocal call_count
440+
call_count += 1
441+
return original_query(*args, **kwargs)
442+
443+
with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query):
444+
results = sqlite_instance.get_scores(score_ids=all_score_ids)
445+
446+
# Should execute multiple queries
447+
expected_calls = (num_items + _SQLITE_MAX_BIND_VARS - 1) // _SQLITE_MAX_BIND_VARS
448+
assert call_count == expected_calls, f"Expected {expected_calls} queries, got {call_count}"
449+
assert len(results) == num_items

0 commit comments

Comments
 (0)