1616import math
1717import random
1818import threading
19- import time
2019
20+ import pytest
2121from inference_endpoint .load_generator .sample import SampleEventHandler
2222from inference_endpoint .load_generator .scheduler import (
2323 ConcurrencyScheduler ,
@@ -43,8 +43,8 @@ def test_without_replacement_sample_order():
4343 ), "Order should be different in each pass of the dataset"
4444
4545
46- def test_with_replacement_sample_order ():
47- ordering = WithReplacementSampleOrder (12345 , 100 , rng = random .Random (42 ))
46+ def test_with_replacement_sample_order (random_seed ):
47+ ordering = WithReplacementSampleOrder (12345 , 100 , rng = random .Random (random_seed ))
4848 indices = list (iter (ordering ))
4949
5050 # With a seeded random.Random (seed supplied by the random_seed fixture), the order is deterministic
@@ -86,126 +86,115 @@ def test_max_throughput_scheduler(max_throughput_runtime_settings):
8686 ], "Order does not match expected deterministic order"
8787
8888
@pytest.mark.parametrize("target_concurrency", [1, 2, 100, 1000], indirect=True)
def test_concurrency_scheduler(concurrency_runtime_settings, target_concurrency):
    """Test ConcurrencyScheduler properly gates issuance by completions.

    An issue thread iterates the scheduler while a completion thread completes
    queries strictly in issue order, one at a time, only when the main thread
    opens the matching gate. The main thread drives three phases:

    1. The first ``min(target_concurrency, total_samples)`` queries issue
       without any completion.
    2. Every further query issues only after one earlier query has completed.
    3. The queries still in flight are completed and both workers finish.

    Args:
        concurrency_runtime_settings: Fixture providing the runtime settings
            handed to ``ConcurrencyScheduler``; ``n_samples_to_issue`` is read
            from it as the total number of queries.
        target_concurrency: Indirectly parametrized fixture value used as the
            maximum number of queries expected in flight at once.
    """
    # Every blocking wait is bounded so a scheduler regression fails the test
    # instead of hanging the whole suite.
    wait_timeout_s = 5.0

    total_samples = concurrency_runtime_settings.n_samples_to_issue
    # The scheduler cannot have more queries in flight than it will ever issue.
    initial_burst = min(target_concurrency, total_samples)

    scheduler = ConcurrencyScheduler(
        concurrency_runtime_settings, WithReplacementSampleOrder
    )

    # State tracking (guarded by state_lock)
    state_lock = threading.RLock()
    issued_count = 0
    completed_count = 0
    current_inflight = 0
    max_inflight = 0

    # Exceptions raised inside worker threads never reach pytest on their own,
    # so they are collected here and re-checked from the main thread.
    worker_errors = []
    # Set during cleanup so workers exit instead of mutating state further.
    stop_workers = threading.Event()

    # Synchronization: signal when queries can complete and when they're done
    can_complete = [threading.Event() for _ in range(total_samples)]
    completed = [threading.Event() for _ in range(total_samples)]
    # Signal when each query is issued
    issued = [threading.Event() for _ in range(total_samples)]

    def wait_for(event, description):
        """Block on ``event`` with a timeout, surfacing worker failures first."""
        signalled = event.wait(timeout=wait_timeout_s)
        assert not worker_errors, f"Worker thread failed: {worker_errors!r}"
        assert signalled, f"Timed out waiting for {description}"

    def completion_worker():
        """Waits for signals to complete queries, in issue order."""
        nonlocal completed_count, current_inflight

        try:
            for position in range(total_samples):
                can_complete[position].wait()
                if stop_workers.is_set():
                    return

                with state_lock:
                    completed_count += 1
                    current_inflight -= 1
                    assert current_inflight >= 0, "Inflight count went negative"

                # Signal the completion to the scheduler (frees a slot) before
                # telling the main thread that this query is done.
                scheduler._release_slot()
                completed[position].set()
        except Exception as exc:  # reported to the main thread via wait_for
            worker_errors.append(exc)

    def issue_worker():
        """Issues queries through scheduler."""
        nonlocal issued_count, current_inflight, max_inflight

        try:
            for position, _ in enumerate(scheduler):
                if stop_workers.is_set():
                    return
                with state_lock:
                    issued_count += 1
                    current_inflight += 1
                    max_inflight = max(max_inflight, current_inflight)
                    assert (
                        current_inflight <= target_concurrency
                    ), f"Concurrency {current_inflight} exceeded limit {target_concurrency}"
                issued[position].set()
        except Exception as exc:  # reported to the main thread via wait_for
            worker_errors.append(exc)

    completion_thread = threading.Thread(target=completion_worker, daemon=True)
    completion_thread.start()
    issue_thread = threading.Thread(target=issue_worker, daemon=True)
    issue_thread.start()

    try:
        # Phase 1: First target_concurrency queries issue immediately
        for position in range(initial_burst):
            wait_for(issued[position], f"query {position} to be issued")

        with state_lock:
            assert issued_count == initial_burst
            assert completed_count == 0
            assert current_inflight == initial_burst

        # Phase 2: Verify scheduler blocks when at capacity, unblocks on completion
        for position in range(initial_burst, total_samples):
            position_to_complete = position - initial_burst

            # Verify next query hasn't issued yet (scheduler is blocking)
            assert not issued[
                position
            ].is_set(), f"Query {position} issued before slot was freed"

            # Free a slot
            can_complete[position_to_complete].set()
            wait_for(
                completed[position_to_complete],
                f"query {position_to_complete} to complete",
            )

            # Verify next query now issues
            wait_for(issued[position], f"query {position} to be issued")

            with state_lock:
                assert current_inflight == target_concurrency

        # Phase 3: Complete remaining queries and cleanup.
        # Phase 2 completed positions [0, total_samples - initial_burst); the
        # queries still in flight are exactly the last `initial_burst` ones.
        for position in range(total_samples - initial_burst, total_samples):
            can_complete[position].set()
            wait_for(completed[position], f"query {position} to complete")

        issue_thread.join(timeout=wait_timeout_s)
        assert not issue_thread.is_alive(), "Scheduler did not finish issuing"
        assert not worker_errors, f"Worker thread failed: {worker_errors!r}"

        # Final validation
        with state_lock:
            assert issued_count == total_samples
            assert completed_count == total_samples
            assert current_inflight == 0
            assert max_inflight == initial_burst

    finally:
        # Release the completion worker even when an assertion failed midway,
        # so it does not stay parked on a gate that will never be opened.
        stop_workers.set()
        for gate in can_complete:
            gate.set()
        completion_thread.join(timeout=wait_timeout_s)
        SampleEventHandler.clear_hooks()
202190
203191
204- def test_poisson_scheduler_distribution (poisson_runtime_settings ):
192+ @pytest .mark .parametrize ("target_qps" , [50.0 , 100.0 , 500.0 , 1000.0 ], indirect = True )
193+ def test_poisson_scheduler_distribution (poisson_runtime_settings , target_qps ):
205194 """Test PoissonDistributionScheduler produces exponentially distributed inter-arrival times.
206195
207- For a Poisson process with rate λ (1000 QPS), inter-arrival times must follow
208- exponential distribution with mean = 1/λ = 1ms .
196+ For a Poisson process with rate λ (target QPS), inter-arrival times must follow
197+ exponential distribution with mean = 1/λ.
209198
210199 Three-tier validation:
211200 1. Mean with 99.9% confidence interval
@@ -217,7 +206,7 @@ def test_poisson_scheduler_distribution(poisson_runtime_settings):
217206 )
218207
219208 # Test configuration
220- TARGET_QPS = poisson_runtime_settings . metric_target . target
209+ TARGET_QPS = target_qps
221210 expected_mean_s = 1.0 / TARGET_QPS
222211
223212 # Collect delays from scheduler (in seconds) for statistical analysis
@@ -235,7 +224,7 @@ def test_poisson_scheduler_distribution(poisson_runtime_settings):
235224 cv = sample_std / sample_mean
236225
237226 # Test 1: Mean with statistical confidence interval (99.9% CI)
238- # For exponential: std(X̄) = σ /√n = μ /√n
227+ # For exponential: std(X̄) = sigma /√n = mu /√n
239228 z_critical = 3.29 # 99.9% two-tailed
240229 margin_of_error = z_critical * (sample_std / math .sqrt (n ))
241230 assert abs (sample_mean - expected_mean_s ) < margin_of_error , (
@@ -264,5 +253,5 @@ def test_poisson_scheduler_distribution(poisson_runtime_settings):
264253 ALPHA = 0.0001
265254 assert p_value > ALPHA , (
266255 f"KS test rejected exponential distribution: "
267- f"p-value={ p_value :.4f} < α ={ ALPHA } (D={ ks_statistic :.4f} )"
256+ f"p-value={ p_value :.4f} < alpha ={ ALPHA } (D={ ks_statistic :.4f} )"
268257 )
0 commit comments