55import time
66import traceback
77from collections .abc import Iterable
8+ from multiprocessing .queues import Queue
89
910import numpy as np
1011
1112from ... import config
13+ from ...models import ConcurrencySlotTimeoutError
1214from ..clients import api
1315
1416NUM_PER_BATCH = config .NUM_PER_BATCH
@@ -28,16 +30,18 @@ def __init__(
2830 self ,
2931 db : api .VectorDB ,
3032 test_data : list [list [float ]],
31- k : int = 100 ,
33+ k : int = config . K_DEFAULT ,
3234 filters : dict | None = None ,
3335 concurrencies : Iterable [int ] = config .NUM_CONCURRENCY ,
34- duration : int = 30 ,
36+ duration : int = config .CONCURRENCY_DURATION ,
37+ concurrency_timeout : int = config .CONCURRENCY_TIMEOUT ,
3538 ):
3639 self .db = db
3740 self .k = k
3841 self .filters = filters
3942 self .concurrencies = concurrencies
4043 self .duration = duration
44+ self .concurrency_timeout = concurrency_timeout
4145
4246 self .test_data = test_data
4347 log .debug (f"test dataset columns: { len (test_data )} " )
@@ -114,9 +118,7 @@ def _run_all_concurrencies_mem_efficient(self):
114118 log .info (f"Start search { self .duration } s in concurrency { conc } , filters: { self .filters } " )
115119 future_iter = [executor .submit (self .search , self .test_data , q , cond ) for i in range (conc )]
116120 # Sync all processes
117- while q .qsize () < conc :
118- sleep_t = conc if conc < 10 else 10
119- time .sleep (sleep_t )
121+ self ._wait_for_queue_fill (q , size = conc )
120122
121123 with cond :
122124 cond .notify_all ()
@@ -160,6 +162,15 @@ def _run_all_concurrencies_mem_efficient(self):
160162 conc_latency_avg_list ,
161163 )
162164
165+ def _wait_for_queue_fill (self , q : Queue , size : int ):
166+ wait_t = 0
167+ while q .qsize () < size :
168+ sleep_t = size if size < 10 else 10
169+ wait_t += sleep_t
170+ if wait_t > self .concurrency_timeout > 0 :
171+ raise ConcurrencySlotTimeoutError
172+ time .sleep (sleep_t )
173+
163174 def run (self ) -> float :
164175 """
165176 Returns:
0 commit comments