99
1010import hashlib
1111import uuid
12- from unittest .mock import MagicMock , patch
13-
14- from sqlalchemy import Column , Integer
15- from sqlalchemy .orm import declarative_base
12+ from unittest .mock import patch
1613
1714from pyrit .memory import MemoryInterface
18- from pyrit .memory .memory_interface import _SQLITE_MAX_BIND_VARS , _batched_in_condition
15+ from pyrit .memory .memory_interface import _SQLITE_MAX_BIND_VARS
1916from pyrit .models import MessagePiece , Score
2017
2118
@@ -51,157 +48,6 @@ def _create_score(message_piece_id: str) -> Score:
5148 )
5249
5350
54- class TestBatchedInCondition :
55- """Tests for the _batched_in_condition helper function."""
56-
57- def test_batched_in_condition_small_list (self ):
58- """Test that small lists generate a simple IN condition."""
59- Base = declarative_base ()
60-
61- class TestModel (Base ):
62- __tablename__ = "test"
63- id = Column (Integer , primary_key = True )
64-
65- values = list (range (10 ))
66- condition = _batched_in_condition (TestModel .id , values )
67-
68- # Should be a simple IN clause, not an OR
69- assert "IN" in str (condition )
70- assert "OR" not in str (condition )
71-
72- def test_batched_in_condition_exact_batch_size (self ):
73- """Test with exactly _SQLITE_MAX_BIND_VARS values."""
74- Base = declarative_base ()
75-
76- class TestModel (Base ):
77- __tablename__ = "test"
78- id = Column (Integer , primary_key = True )
79-
80- values = list (range (_SQLITE_MAX_BIND_VARS ))
81- condition = _batched_in_condition (TestModel .id , values )
82-
83- # Should still be a simple IN clause at the limit
84- assert "IN" in str (condition )
85- # May or may not have OR depending on implementation at boundary
86-
87- def test_batched_in_condition_over_batch_size (self ):
88- """Test with values exceeding _SQLITE_MAX_BIND_VARS."""
89- Base = declarative_base ()
90-
91- class TestModel (Base ):
92- __tablename__ = "test"
93- id = Column (Integer , primary_key = True )
94-
95- values = list (range (_SQLITE_MAX_BIND_VARS + 100 ))
96- condition = _batched_in_condition (TestModel .id , values )
97-
98- # Should generate OR of multiple IN clauses
99- condition_str = str (condition )
100- assert "OR" in condition_str
101- assert "IN" in condition_str
102-
103- def test_batched_in_condition_double_batch_size (self ):
104- """Test with double the batch size."""
105- Base = declarative_base ()
106-
107- class TestModel (Base ):
108- __tablename__ = "test"
109- id = Column (Integer , primary_key = True )
110-
111- values = list (range (_SQLITE_MAX_BIND_VARS * 2 ))
112- condition = _batched_in_condition (TestModel .id , values )
113-
114- # Should generate multiple batches
115- condition_str = str (condition )
116- assert "OR" in condition_str
117- # Should have at least 2 IN clauses
118- assert condition_str .count ("IN" ) >= 2
119-
120- def test_batched_in_condition_three_batches (self ):
121- """Test with enough values to require three batches."""
122- Base = declarative_base ()
123-
124- class TestModel (Base ):
125- __tablename__ = "test"
126- id = Column (Integer , primary_key = True )
127-
128- values = list (range (_SQLITE_MAX_BIND_VARS * 2 + 100 ))
129- condition = _batched_in_condition (TestModel .id , values )
130-
131- condition_str = str (condition )
132- assert "OR" in condition_str
133- # Should have at least 3 IN clauses
134- assert condition_str .count ("IN" ) >= 3
135-
136- def test_batched_in_condition_empty_list (self ):
137- """Test with an empty list."""
138- Base = declarative_base ()
139-
140- class TestModel (Base ):
141- __tablename__ = "test"
142- id = Column (Integer , primary_key = True )
143-
144- values = []
145- condition = _batched_in_condition (TestModel .id , values )
146-
147- # Empty list should still generate valid SQL
148- condition_str = str (condition )
149- assert "IN" in condition_str
150-
151- def test_batched_in_condition_multiple_columns (self ):
152- """Test combining multiple batched conditions with AND logic."""
153- from sqlalchemy import String , and_
154-
155- Base = declarative_base ()
156-
157- class TestModel (Base ):
158- __tablename__ = "test"
159- id = Column (Integer , primary_key = True )
160- name = Column (String )
161- email = Column (String )
162-
163- # Create multiple large value lists for different columns
164- num_values = (_SQLITE_MAX_BIND_VARS * 2 ) + 100
165- id_values = list (range (num_values ))
166- name_values = [f"name_{ i } " for i in range (num_values )]
167- email_values = [f"email_{ i } @test.com" for i in range (num_values )]
168-
169- # Create batched conditions for each column
170- id_condition = _batched_in_condition (TestModel .id , id_values )
171- name_condition = _batched_in_condition (TestModel .name , name_values )
172- email_condition = _batched_in_condition (TestModel .email , email_values )
173-
174- # Combine with AND (simulating real query behavior)
175- combined_condition = and_ (id_condition , name_condition , email_condition )
176- combined_str = str (combined_condition )
177-
178- # Verify all three columns are present in the query
179- assert "id" in combined_str .lower ()
180- assert "name" in combined_str .lower ()
181- assert "email" in combined_str .lower ()
182-
183- # Verify OR clauses are present (batching is active)
184- assert combined_str .count ("OR" ) >= 3 # At least one OR per batched column
185-
186- # Verify AND combines the conditions
187- assert "AND" in combined_str
188-
189- # Verify `id` count matches expected batches
190- expected_id_batches = (num_values + _SQLITE_MAX_BIND_VARS - 1 ) // _SQLITE_MAX_BIND_VARS
191- actual_id_batches = combined_str .count ("id IN" )
192- assert actual_id_batches == expected_id_batches
193-
194- # Verify `name` count matches expected batches
195- expected_name_batches = (num_values + _SQLITE_MAX_BIND_VARS - 1 ) // _SQLITE_MAX_BIND_VARS
196- actual_name_batches = combined_str .count ("name IN" )
197- assert actual_name_batches == expected_name_batches
198-
199- # Verify `email` count matches expected batches
200- expected_email_batches = (num_values + _SQLITE_MAX_BIND_VARS - 1 ) // _SQLITE_MAX_BIND_VARS
201- actual_email_batches = combined_str .count ("email IN" )
202- assert actual_email_batches == expected_email_batches
203-
204-
20551class TestBatchingScale :
20652 """Tests for batching when querying with many IDs."""
20753
@@ -370,9 +216,7 @@ def test_get_message_pieces_multiple_large_params_simultaneously(self, sqlite_in
370216 """Test batching with multiple parameters exceeding batch limit simultaneously."""
371217 # Create enough pieces to exceed batch limit with unique values
372218 num_pieces = _SQLITE_MAX_BIND_VARS + 200
373- pieces = [
374- _create_message_piece (original_value = f"original_value_{ i } " ) for i in range (num_pieces )
375- ]
219+ pieces = [_create_message_piece (original_value = f"original_value_{ i } " ) for i in range (num_pieces )]
376220
377221 # Add to memory
378222 sqlite_instance .add_message_pieces_to_memory (message_pieces = pieces )
@@ -396,8 +240,7 @@ def test_get_message_pieces_multiple_large_params_simultaneously(self, sqlite_in
396240
397241 # Should return all pieces that match ALL conditions (intersection)
398242 assert len (results ) == num_pieces , (
399- f"Expected { num_pieces } results when filtering with multiple large parameters, "
400- f"got { len (results )} "
243+ f"Expected { num_pieces } results when filtering with multiple large parameters, got { len (results )} "
401244 )
402245
403246 # Verify all returned pieces match all filter criteria
@@ -410,12 +253,10 @@ def test_get_message_pieces_multiple_large_params_simultaneously(self, sqlite_in
410253 assert result_sha256 == set (all_sha256 ), "Returned SHA256 hashes don't match filter"
411254
412255 def test_get_message_pieces_multiple_batched_params_with_query_spy (self , sqlite_instance : MemoryInterface ):
413- """Test that batching generates correct queries when multiple params exceed limit ."""
256+ """Test that batching executes multiple separate queries and merges results correctly ."""
414257 # Create pieces exceeding batch limit
415258 num_pieces = _SQLITE_MAX_BIND_VARS + 100
416- pieces = [
417- _create_message_piece (original_value = f"value_{ i } " ) for i in range (num_pieces )
418- ]
259+ pieces = [_create_message_piece (original_value = f"value_{ i } " ) for i in range (num_pieces )]
419260 sqlite_instance .add_message_pieces_to_memory (message_pieces = pieces )
420261
421262 # Get stored pieces
@@ -426,41 +267,33 @@ def test_get_message_pieces_multiple_batched_params_with_query_spy(self, sqlite_
426267 # Mock _query_entries to track how it's called
427268 original_query = sqlite_instance ._query_entries
428269 call_count = 0
429- captured_conditions = []
430270
431271 def spy_query (* args , ** kwargs ):
432272 nonlocal call_count
433273 call_count += 1
434- if "conditions" in kwargs and kwargs ["conditions" ] is not None :
435- captured_conditions .append (str (kwargs ["conditions" ]))
436274 return original_query (* args , ** kwargs )
437275
438276 with patch .object (sqlite_instance , "_query_entries" , side_effect = spy_query ):
439- results = sqlite_instance .get_message_pieces (
440- prompt_ids = all_ids , original_values = all_original_values
441- )
277+ results = sqlite_instance .get_message_pieces (prompt_ids = all_ids , original_values = all_original_values )
442278
443279 # Should get all results despite batching
444280 assert len (results ) == num_pieces
445281
446- # Should have been called (could be 1 call with OR conditions)
447- assert call_count >= 1
448-
449- # Verify query conditions include both filters
450- if captured_conditions :
451- combined_conditions = " " .join (captured_conditions )
452- # Both column filters should be present in the query
453- assert "id" in combined_conditions .lower () or "prompt" in combined_conditions .lower ()
454- assert "original_value" in combined_conditions .lower ()
282+ # With the new batching approach, multiple separate queries should be executed
283+ # when the primary batch parameter exceeds _SQLITE_MAX_BIND_VARS
284+ # Expected: at least ceil(num_pieces / _SQLITE_MAX_BIND_VARS) = 2 queries
285+ expected_min_calls = (num_pieces + _SQLITE_MAX_BIND_VARS - 1 ) // _SQLITE_MAX_BIND_VARS
286+ assert call_count >= expected_min_calls , (
287+ f"Expected at least { expected_min_calls } separate queries for { num_pieces } items, "
288+ f"but only got { call_count } calls"
289+ )
455290
456291 def test_get_message_pieces_triple_large_params_preserves_intersection (self , sqlite_instance : MemoryInterface ):
457292 """Test that filtering with 3 large parameter lists returns correct intersection."""
458293 # Create a large set of pieces
459294 total_pieces = _SQLITE_MAX_BIND_VARS + 150
460295 pieces = [
461- _create_message_piece (
462- conversation_id = str (uuid .uuid4 ()), original_value = f"content_{ i } "
463- )
296+ _create_message_piece (conversation_id = str (uuid .uuid4 ()), original_value = f"content_{ i } " )
464297 for i in range (total_pieces )
465298 ]
466299 sqlite_instance .add_message_pieces_to_memory (message_pieces = pieces )
@@ -487,10 +320,130 @@ def test_get_message_pieces_triple_large_params_preserves_intersection(self, sql
487320 )
488321
489322 # Should return only the intersection (subset_size items)
490- assert (
491- len (results ) == subset_size
492- ), f"Expected { subset_size } results from intersection, got { len (results )} "
323+ assert len (results ) == subset_size , f"Expected { subset_size } results from intersection, got { len (results )} "
493324
494325 # Verify all results have SHA256 in the filter list
495326 result_sha256 = {r .converted_value_sha256 for r in results }
496327 assert result_sha256 .issubset (set (filter_sha256 )), "Results contain unexpected SHA256 values"
328+
329+
330+ class TestExecuteBatchedQuery :
331+ """Tests for the _execute_batched_query helper method."""
332+
333+ def test_execute_batched_query_small_list_single_query (self , sqlite_instance : MemoryInterface ):
334+ """Test that small lists execute a single query."""
335+ # Create a small number of pieces (under batch limit)
336+ num_pieces = 10
337+ pieces = [_create_message_piece () for _ in range (num_pieces )]
338+ sqlite_instance .add_message_pieces_to_memory (message_pieces = pieces )
339+
340+ # Track query calls
341+ original_query = sqlite_instance ._query_entries
342+ call_count = 0
343+
344+ def spy_query (* args , ** kwargs ):
345+ nonlocal call_count
346+ call_count += 1
347+ return original_query (* args , ** kwargs )
348+
349+ with patch .object (sqlite_instance , "_query_entries" , side_effect = spy_query ):
350+ all_ids = [piece .id for piece in pieces ]
351+ results = sqlite_instance .get_message_pieces (prompt_ids = all_ids )
352+
353+ # Should be a single query for small lists
354+ assert call_count == 1
355+ assert len (results ) == num_pieces
356+
357+ def test_execute_batched_query_large_list_multiple_queries (self , sqlite_instance : MemoryInterface ):
358+ """Test that large lists execute multiple separate queries."""
359+ # Create pieces exceeding batch limit
360+ num_pieces = _SQLITE_MAX_BIND_VARS * 3 # 3 batches needed
361+ pieces = [_create_message_piece () for _ in range (num_pieces )]
362+ sqlite_instance .add_message_pieces_to_memory (message_pieces = pieces )
363+
364+ # Track query calls
365+ original_query = sqlite_instance ._query_entries
366+ call_count = 0
367+
368+ def spy_query (* args , ** kwargs ):
369+ nonlocal call_count
370+ call_count += 1
371+ return original_query (* args , ** kwargs )
372+
373+ with patch .object (sqlite_instance , "_query_entries" , side_effect = spy_query ):
374+ all_ids = [piece .id for piece in pieces ]
375+ results = sqlite_instance .get_message_pieces (prompt_ids = all_ids )
376+
377+ # Should execute 3 separate queries (one per batch)
378+ assert call_count == 3 , f"Expected 3 queries for 3 batches, got { call_count } "
379+ assert len (results ) == num_pieces
380+
381+ def test_execute_batched_query_deduplicates_results (self , sqlite_instance : MemoryInterface ):
382+ """Test that batched queries properly deduplicate results."""
383+ # Create pieces
384+ num_pieces = 50
385+ pieces = [_create_message_piece () for _ in range (num_pieces )]
386+ sqlite_instance .add_message_pieces_to_memory (message_pieces = pieces )
387+
388+ # Query with every stored ID, each listed once (each piece should be returned exactly once)
389+ all_ids = [piece .id for piece in pieces ]
390+ # A single query call is made here - its results should contain no duplicates
391+ results = sqlite_instance .get_message_pieces (prompt_ids = all_ids )
392+
393+ assert len (results ) == num_pieces
394+ # Verify no duplicates
395+ result_ids = [r .id for r in results ]
396+ assert len (result_ids ) == len (set (result_ids )), "Results contain duplicate entries"
397+
398+ def test_execute_batched_query_exact_batch_boundary (self , sqlite_instance : MemoryInterface ):
399+ """Test querying with exactly the batch limit (edge case)."""
400+ num_pieces = _SQLITE_MAX_BIND_VARS
401+ pieces = [_create_message_piece () for _ in range (num_pieces )]
402+ sqlite_instance .add_message_pieces_to_memory (message_pieces = pieces )
403+
404+ # Track query calls
405+ original_query = sqlite_instance ._query_entries
406+ call_count = 0
407+
408+ def spy_query (* args , ** kwargs ):
409+ nonlocal call_count
410+ call_count += 1
411+ return original_query (* args , ** kwargs )
412+
413+ with patch .object (sqlite_instance , "_query_entries" , side_effect = spy_query ):
414+ all_ids = [piece .id for piece in pieces ]
415+ results = sqlite_instance .get_message_pieces (prompt_ids = all_ids )
416+
417+ # Exactly at the limit should still be a single query
418+ assert call_count == 1 , f"Expected 1 query at exact batch limit, got { call_count } "
419+ assert len (results ) == num_pieces
420+
421+ def test_batching_with_scores_exceeds_limit (self , sqlite_instance : MemoryInterface ):
422+ """Test that get_scores handles large numbers of score IDs correctly."""
423+ # Create message pieces and scores exceeding the limit
424+ num_items = _SQLITE_MAX_BIND_VARS * 2 + 50
425+ pieces = [_create_message_piece () for _ in range (num_items )]
426+ sqlite_instance .add_message_pieces_to_memory (message_pieces = pieces )
427+
428+ scores = [_create_score (str (piece .id )) for piece in pieces ]
429+ sqlite_instance .add_scores_to_memory (scores = scores )
430+
431+ # Query with all score IDs
432+ all_score_ids = [str (score .id ) for score in scores ]
433+
434+ # Track query calls
435+ original_query = sqlite_instance ._query_entries
436+ call_count = 0
437+
438+ def spy_query (* args , ** kwargs ):
439+ nonlocal call_count
440+ call_count += 1
441+ return original_query (* args , ** kwargs )
442+
443+ with patch .object (sqlite_instance , "_query_entries" , side_effect = spy_query ):
444+ results = sqlite_instance .get_scores (score_ids = all_score_ids )
445+
446+ # Should execute multiple queries
447+ expected_calls = (num_items + _SQLITE_MAX_BIND_VARS - 1 ) // _SQLITE_MAX_BIND_VARS
448+ assert call_count == expected_calls , f"Expected { expected_calls } queries, got { call_count } "
449+ assert len (results ) == num_items
0 commit comments