Skip to content

Commit 176e3ea

Browse files
committed
address comments, update concurrency test
1 parent 5074226 commit 176e3ea

3 files changed

Lines changed: 123 additions & 112 deletions

File tree

src/inference_endpoint/load_generator/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def __init__(self, runtime_settings: RuntimeSettings, sample_order_cls):
381381
f"target_concurrency must be > 0 for CONCURRENCY load pattern, got {target_concurrency}"
382382
)
383383

384-
# Use Condition for concurrency control with explicit counter
384+
# Use threading.Condition for concurrency control with explicit counter
385385
self._condition = threading.Condition()
386386
self._inflight = 0
387387
self._target_concurrency = target_concurrency

tests/conftest.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,17 @@ def trtllm_docker_server(hf_model_name, trtllm_llama31_8b_cmd):
486486

487487

488488
@pytest.fixture
489-
def max_throughput_runtime_settings():
489+
def random_seed():
490+
"""Fixture providing the random seed for deterministic testing.
491+
492+
This allows tests to easily vary the random seed for different test scenarios
493+
while maintaining determinism by default.
494+
"""
495+
return 42
496+
497+
498+
@pytest.fixture
499+
def max_throughput_runtime_settings(random_seed):
490500
return RuntimeSettings(
491501
metrics.Throughput(100),
492502
reported_metrics=[],
@@ -495,42 +505,54 @@ def max_throughput_runtime_settings():
495505
n_samples_from_dataset=100,
496506
n_samples_to_issue=100,
497507
min_sample_count=100,
498-
rng_sched=random.Random(42),
499-
rng_sample_index=random.Random(42),
508+
rng_sched=random.Random(random_seed),
509+
rng_sample_index=random.Random(random_seed),
500510
load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT),
501511
)
502512

503513

504514
@pytest.fixture
505-
def poisson_runtime_settings():
515+
def target_qps(request):
516+
"""Target QPS for poisson scheduler tests (100.0 unless overridden via indirect parametrization)."""
517+
return request.param if hasattr(request, "param") else 100.0
518+
519+
520+
@pytest.fixture
521+
def target_concurrency(request):
522+
"""Target concurrency for concurrency scheduler tests (2 unless overridden via indirect parametrization)."""
523+
return request.param if hasattr(request, "param") else 2
524+
525+
526+
@pytest.fixture
527+
def poisson_runtime_settings(random_seed, target_qps):
506528
return RuntimeSettings(
507-
metric_target=metrics.Throughput(1000),
529+
metric_target=metrics.Throughput(target_qps),
508530
reported_metrics=[],
509531
min_duration_ms=10_000,
510532
max_duration_ms=15_000,
511533
n_samples_from_dataset=100,
512534
n_samples_to_issue=5000,
513535
min_sample_count=100,
514-
rng_sched=random.Random(42),
515-
rng_sample_index=random.Random(42),
516-
load_pattern=LoadPattern(type=LoadPatternType.POISSON, target_qps=100.0),
536+
rng_sched=random.Random(random_seed),
537+
rng_sample_index=random.Random(random_seed),
538+
load_pattern=LoadPattern(type=LoadPatternType.POISSON, target_qps=target_qps),
517539
)
518540

519541

520542
@pytest.fixture
521-
def concurrency_runtime_settings():
543+
def concurrency_runtime_settings(random_seed, target_concurrency):
522544
return RuntimeSettings(
523545
metric_target=None,
524546
reported_metrics=[],
525547
min_duration_ms=1000,
526548
max_duration_ms=10_000,
527549
n_samples_from_dataset=100,
528-
n_samples_to_issue=100,
550+
n_samples_to_issue=target_concurrency * 10,
529551
min_sample_count=100,
530-
rng_sched=random.Random(42),
531-
rng_sample_index=random.Random(42),
552+
rng_sched=random.Random(random_seed),
553+
rng_sample_index=random.Random(random_seed),
532554
load_pattern=LoadPattern(
533-
type=LoadPatternType.CONCURRENCY, target_concurrency=2
555+
type=LoadPatternType.CONCURRENCY, target_concurrency=target_concurrency
534556
),
535557
)
536558

tests/unit/load_generator/test_scheduler.py

Lines changed: 87 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import math
1717
import random
1818
import threading
19-
import time
2019

20+
import pytest
2121
from inference_endpoint.load_generator.sample import SampleEventHandler
2222
from inference_endpoint.load_generator.scheduler import (
2323
ConcurrencyScheduler,
@@ -43,8 +43,8 @@ def test_without_replacement_sample_order():
4343
), "Order should be different in each pass of the dataset"
4444

4545

46-
def test_with_replacement_sample_order():
47-
ordering = WithReplacementSampleOrder(12345, 100, rng=random.Random(42))
46+
def test_with_replacement_sample_order(random_seed):
47+
ordering = WithReplacementSampleOrder(12345, 100, rng=random.Random(random_seed))
4848
indices = list(iter(ordering))
4949

5050
# With a fixed seed (the random_seed fixture, 42 by default), the order is deterministic
@@ -86,126 +86,115 @@ def test_max_throughput_scheduler(max_throughput_runtime_settings):
8686
], "Order does not match expected deterministic order"
8787

8888

89-
def test_concurrency_scheduler(concurrency_runtime_settings, clean_sample_event_hooks):
90-
"""Test ConcurrencyScheduler properly gates issuance by completions.
91-
Tests that concurrency is properly limited and queries are gated by completion events.
92-
"""
93-
target_concurrency = concurrency_runtime_settings.load_pattern.target_concurrency
89+
@pytest.mark.parametrize("target_concurrency", [1, 2, 100, 1000], indirect=True)
90+
def test_concurrency_scheduler(concurrency_runtime_settings, target_concurrency):
91+
"""Test ConcurrencyScheduler properly gates issuance by completions."""
9492
total_samples = concurrency_runtime_settings.n_samples_to_issue
9593

9694
scheduler = ConcurrencyScheduler(
9795
concurrency_runtime_settings, WithReplacementSampleOrder
9896
)
9997

100-
# Track events with synchronization primitives instead of timing
101-
issue_lock = threading.Lock()
102-
complete_lock = threading.Lock()
103-
issue_events = [] # List of issued query indices in order
104-
complete_events = [] # List of completed query indices in order
105-
106-
# Use events for deterministic synchronization
107-
issue_gates = [threading.Event() for _ in range(total_samples)]
108-
completion_gates = [threading.Event() for _ in range(total_samples)]
109-
110-
# Track concurrency level
111-
concurrency_lock = threading.Lock()
98+
# State tracking
99+
state_lock = threading.RLock()
100+
issued_count = 0
101+
completed_count = 0
112102
current_inflight = 0
113103
max_inflight = 0
114104

115-
def simulate_completions():
116-
"""Simulate query completions with event-based synchronization."""
117-
nonlocal current_inflight, max_inflight
118-
119-
for i in range(total_samples):
120-
# Wait for this query to be issued before completing it
121-
issue_gates[i].wait(timeout=5.0)
105+
# Synchronization: signal when queries can complete and when they're done
106+
can_complete = [threading.Event() for _ in range(total_samples)]
107+
completed = [threading.Event() for _ in range(total_samples)]
108+
# Signal when each query is issued
109+
issued = [threading.Event() for _ in range(total_samples)]
122110

123-
# Simulate small variable processing time
124-
time.sleep(0.001 * (1 + i % 3)) # 1-3ms pattern
111+
def completion_worker():
112+
"""Waits for signals to complete queries."""
113+
nonlocal completed_count, current_inflight
125114

126-
with complete_lock:
127-
complete_events.append(i)
115+
for position in range(total_samples):
116+
can_complete[position].wait()
128117

129-
# Decrease inflight count
130-
with concurrency_lock:
118+
with state_lock:
119+
completed_count += 1
131120
current_inflight -= 1
121+
assert current_inflight >= 0, "Inflight count went negative"
132122

133-
# Signal completion to scheduler
134123
scheduler._release_slot()
135-
completion_gates[i].set()
124+
completed[position].set()
136125

137-
completion_thread = threading.Thread(target=simulate_completions, daemon=True)
138-
completion_thread.start()
126+
threading.Thread(target=completion_worker, daemon=True).start()
139127

140-
try:
141-
# Issue queries through scheduler
142-
for query_idx, _ in enumerate(scheduler):
143-
with issue_lock:
144-
issue_events.append(query_idx)
128+
def issue_worker():
129+
"""Issues queries through scheduler."""
130+
nonlocal issued_count, current_inflight, max_inflight
145131

146-
# Track peak concurrency
147-
with concurrency_lock:
132+
for position, _ in enumerate(scheduler):
133+
with state_lock:
134+
issued_count += 1
148135
current_inflight += 1
149136
max_inflight = max(max_inflight, current_inflight)
137+
assert (
138+
current_inflight <= target_concurrency
139+
), f"Concurrency {current_inflight} exceeded limit {target_concurrency}"
140+
issued[position].set()
141+
142+
issue_thread = threading.Thread(target=issue_worker, daemon=True)
143+
issue_thread.start()
144+
145+
try:
146+
# Phase 1: First target_concurrency queries issue immediately
147+
for position in range(target_concurrency):
148+
issued[position].wait()
149+
150+
with state_lock:
151+
assert issued_count == target_concurrency
152+
assert completed_count == 0
153+
assert current_inflight == target_concurrency
154+
155+
# Phase 2: Verify scheduler blocks when at capacity, unblocks on completion
156+
for position in range(target_concurrency, total_samples):
157+
position_to_complete = position - target_concurrency
158+
159+
# Verify next query hasn't issued yet (scheduler is blocking)
160+
assert not issued[
161+
position
162+
].is_set(), f"Query {position} issued before slot was freed"
163+
164+
# Free a slot
165+
can_complete[position_to_complete].set()
166+
completed[position_to_complete].wait()
167+
168+
# Verify next query now issues
169+
issued[position].wait()
170+
171+
with state_lock:
172+
assert current_inflight == target_concurrency
173+
174+
# Phase 3: Complete remaining queries and cleanup
175+
for position in range(target_concurrency, total_samples):
176+
can_complete[position].set()
177+
completed[position].wait()
178+
179+
issue_thread.join()
150180

151-
# Signal that this query has been issued
152-
issue_gates[query_idx].set()
153-
154-
# Wait for all completions to finish
155-
for i in range(total_samples):
156-
assert completion_gates[i].wait(
157-
timeout=5.0
158-
), f"Query {i} completion timed out"
159-
160-
# === Deterministic Verification ===
161-
162-
# Validation: All queries were issued in sequential order
163-
assert (
164-
len(issue_events) == total_samples
165-
), f"Expected {total_samples} issues, got {len(issue_events)}"
166-
for i, query_idx in enumerate(issue_events):
167-
assert (
168-
query_idx == i
169-
), f"Issue order violated: position {i} has query {query_idx}"
170-
171-
# Validation: All queries completed
172-
assert (
173-
len(complete_events) == total_samples
174-
), f"Expected {total_samples} completions, got {len(complete_events)}"
175-
176-
# Validation: Peak concurrency actually reached target
177-
assert (
178-
max_inflight == target_concurrency
179-
), f"Max concurrent ({max_inflight}) never reached target ({target_concurrency})"
180-
181-
# Validation: gating behavior
182-
# For query i where i >= target_concurrency, query (i - target_concurrency)
183-
# must have completed before query i could issue.
184-
#
185-
# We can verify this by checking that when we issued query i,
186-
# query (i - target_concurrency) had already been issued AND the scheduler
187-
# had received its completion event.
188-
#
189-
# Since the scheduler blocks until a slot is free, and slots are freed by
190-
# completions, if query i issued, then at least (i - target_concurrency + 1)
191-
# completions must have occurred (to free up a slot).
192-
for i in range(target_concurrency, total_samples):
193-
expected_completed_query = i - target_concurrency
194-
assert (
195-
expected_completed_query in complete_events
196-
), f"Query {i} issued but query {expected_completed_query} not in completions yet"
181+
# Final validation
182+
with state_lock:
183+
assert issued_count == total_samples
184+
assert completed_count == total_samples
185+
assert current_inflight == 0
186+
assert max_inflight == target_concurrency
197187

198188
finally:
199-
# Ensure proper cleanup
200-
completion_thread.join(timeout=5.0)
201189
SampleEventHandler.clear_hooks()
202190

203191

204-
def test_poisson_scheduler_distribution(poisson_runtime_settings):
192+
@pytest.mark.parametrize("target_qps", [50.0, 100.0, 500.0, 1000.0], indirect=True)
193+
def test_poisson_scheduler_distribution(poisson_runtime_settings, target_qps):
205194
"""Test PoissonDistributionScheduler produces exponentially distributed inter-arrival times.
206195
207-
For a Poisson process with rate λ (1000 QPS), inter-arrival times must follow
208-
exponential distribution with mean = 1/λ = 1ms.
196+
For a Poisson process with rate λ (target QPS), inter-arrival times must follow
197+
exponential distribution with mean = 1/λ.
209198
210199
Three-tier validation:
211200
1. Mean with 99.9% confidence interval
@@ -217,7 +206,7 @@ def test_poisson_scheduler_distribution(poisson_runtime_settings):
217206
)
218207

219208
# Test configuration
220-
TARGET_QPS = poisson_runtime_settings.metric_target.target
209+
TARGET_QPS = target_qps
221210
expected_mean_s = 1.0 / TARGET_QPS
222211

223212
# Collect delays from scheduler (in seconds) for statistical analysis
@@ -235,7 +224,7 @@ def test_poisson_scheduler_distribution(poisson_runtime_settings):
235224
cv = sample_std / sample_mean
236225

237226
# Test 1: Mean with statistical confidence interval (99.9% CI)
238-
# For exponential: std(X̄) = σ/√n = μ/√n
227+
# For exponential: std(X̄) = sigma/√n = mu/√n
239228
z_critical = 3.29 # 99.9% two-tailed
240229
margin_of_error = z_critical * (sample_std / math.sqrt(n))
241230
assert abs(sample_mean - expected_mean_s) < margin_of_error, (
@@ -264,5 +253,5 @@ def test_poisson_scheduler_distribution(poisson_runtime_settings):
264253
ALPHA = 0.0001
265254
assert p_value > ALPHA, (
266255
f"KS test rejected exponential distribution: "
267-
f"p-value={p_value:.4f} < α={ALPHA} (D={ks_statistic:.4f})"
256+
f"p-value={p_value:.4f} < alpha={ALPHA} (D={ks_statistic:.4f})"
268257
)

0 commit comments

Comments
 (0)