diff --git a/.gitignore b/.gitignore index 6681801b..30372720 100644 --- a/.gitignore +++ b/.gitignore @@ -193,3 +193,7 @@ examples/03_BenchmarkComparison/vllm_venv/ .cursor_artifacts/ .cursor/ docs/superpowers/ +.claude/agent-memory/ + +# User-specific local dev configs; do not commit +CLAUDE.local.md diff --git a/AGENTS.md b/AGENTS.md index 6fec5395..25f3b849 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -351,6 +351,8 @@ Known failure modes when AI tools generate code for this project. Reference thes - **Importing removed or renamed modules**: After refactors, AI (working from stale context) may import old module paths. Always verify imports resolve to actual files. - **Over-documenting**: AI generates verbose docstrings, inline comments explaining obvious code, and type annotations on trivial variables. This project prefers minimal comments — only where the _why_ isn't obvious from the code. - **Adding backwards-compatibility shims**: If something was renamed or removed, AI may add re-exports, aliases, or deprecation wrappers. In this project, just delete the old thing and update all call sites. +- **Empty except blocks**: Every `except` block must contain either a comment explaining why the exception is ignored, or a logging statement. Bare `except: pass` without explanation is disallowed. AI often generates empty handlers — always add the reason. +- **No lazy imports**: All imports must be at the top of the file. Imports inside functions, methods, or conditional blocks (other than `TYPE_CHECKING`) are disallowed. The only exceptions are: (1) circular import avoidance with a documenting comment, (2) optional dependencies with a top-level try/except that sets the import to `None`, (3) security sandboxing code that intentionally restricts imports. ### Dependency & Environment diff --git a/docs/ENDPOINT_CLIENT.md b/docs/ENDPOINT_CLIENT.md index bbb4272d..c417dccf 100644 --- a/docs/ENDPOINT_CLIENT.md +++ b/docs/ENDPOINT_CLIENT.md @@ -261,8 +261,6 @@ class HTTPClientConfig(BaseModel): # -1 = auto: min(max(8, numa_domain_size), 24) num_workers: int = -1 - record_worker_events: bool = False - event_logs_dir: Path | None = None log_level: str = "INFO" # When True, all SSE chunks emitted via IPC (high main-thread overhead). diff --git a/docs/async_utils/services/metrics_aggregator/DESIGN.md b/docs/async_utils/services/metrics_aggregator/DESIGN.md deleted file mode 100644 index 7ff6d0d3..00000000 --- a/docs/async_utils/services/metrics_aggregator/DESIGN.md +++ /dev/null @@ -1,389 +0,0 @@ -# Metrics Aggregator Service — Design Document - -## Overview - -The metrics aggregator is a ZMQ subscriber service that consumes `EventRecord` messages -from the pub/sub event bus, computes per-sample metrics in real time, and pushes them -to a `MetricEmitter` backend (currently JSONL; future: Prometheus PushGateway). - -It runs as an independent subprocess with its own event loop, connected to the same -ZMQ PUB socket as the EventLoggerService. 
- -``` - ZMQ PUB (ipc://) - │ - ┌──────────────┼──────────────┐ - ▼ ▼ ▼ - EventLogger MetricsAggregator (future subscribers) - (JSONL/SQL) (real-time metrics) -``` - -## Module Layout - -``` -metrics_aggregator/ -├── __init__.py -├── __main__.py # CLI entry point -├── aggregator.py # MetricsAggregatorService (ZmqEventRecordSubscriber) -├── emitter.py # MetricEmitter ABC, JsonlMetricEmitter -├── metrics_table.py # SampleRow, MetricsTable -└── token_metrics.py # TokenizePool (thread-pool tokenizer) -``` - -## Subscribed Events - -### Session Events - -| Event | Effect | -| --------------------------------------------- | ---------------------------------------------------- | -| `SessionEventType.STARTED` | Records session start timestamp | -| `SessionEventType.START_PERFORMANCE_TRACKING` | Sets `is_tracking = True` | -| `SessionEventType.STOP_PERFORMANCE_TRACKING` | Sets `is_tracking = False` | -| `SessionEventType.ENDED` | Flushes emitter, closes subscriber, signals shutdown | - -### Sample Events - -| Event | Stored Field | Metric Emitted | Formula | -| ------------------ | --------------------------------------------------- | ------------------------------------- | -------------------------------------------------------- | -| `ISSUED` | `issued_ns` | `isl` | `len(token_ids)` or `token_count(text)` via `PromptData` | -| `RECV_FIRST` | `recv_first_ns`, `last_recv_ns`, `first_chunk_text` | `ttft_ns` | `recv_first_ns - issued_ns` | -| `RECV_NON_FIRST` | `last_recv_ns` (updated) | `chunk_delta_ns` | `timestamp - last_recv_ns` | -| `CLIENT_SEND` | `client_send_ns` | — | — | -| `CLIENT_RESP_DONE` | `client_resp_done_ns` | `request_duration_ns` | `client_resp_done_ns - client_send_ns` | -| `COMPLETE` | `complete_ns` | `sample_latency_ns`, `osl`, `tpot_ns` | see below | - -Ignored sample events: `TRANSPORT_SENT`, `TRANSPORT_RECV` (infrastructure-level, not -relevant for user-facing metrics). - -## Performance Tracking Window - -The `is_tracking` flag gates which samples are tracked: - -``` - STARTED ENDED - │ │ - ▼ ▼ -────┬─────────┬─────────────────────────────┬──────────────────┬── - │ │ ◄── samples issued here │ │ - │ START_PERF_TRACKING STOP_PERF_TRACKING │ - │ │ are tracked │ │ - │ │ │ │ -``` - -- A sample is tracked **if and only if** its `ISSUED` event arrives while `is_tracking` is `True`. -- Once tracked, a sample continues to receive events and emit metrics regardless of - later `STOP_PERFORMANCE_TRACKING` events. Only new ISSUEs are blocked. -- This allows warmup queries (before START) and cooldown queries (after STOP) to be - excluded from reported metrics while still draining in-flight samples cleanly. 
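In code, this means the `is_tracking` flag is consulted exactly once per sample, at
`ISSUED` time; every later handler keys off row existence instead. A minimal sketch of
the pattern (handler and attribute names here are illustrative, not the actual
aggregator source):

```python
class _TrackingGate:
    """Illustrative: how ISSUED gating and row-keyed handlers interact."""

    def __init__(self, table, emitter):
        self.is_tracking = False
        self.table = table      # MetricsTable: sample_uuid -> SampleRow
        self.emitter = emitter  # MetricEmitter

    def on_issued(self, record) -> None:
        # The tracking window is consulted exactly once, when ISSUED arrives.
        if not self.is_tracking:
            return
        self.table.create_row(record.sample_uuid, issued_ns=record.timestamp_ns)

    def on_recv_first(self, record) -> None:
        # Later handlers key off row existence: a sample issued inside the
        # window keeps emitting metrics even after STOP_PERFORMANCE_TRACKING.
        row = self.table.get(record.sample_uuid)
        if row is None:  # ISSUED fell outside the window — ignore
            return
        row.recv_first_ns = record.timestamp_ns
        self.emitter.emit(record.sample_uuid, "ttft_ns", row.ttft_ns())
```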
- -## Data Model: SampleRow - -Each tracked sample gets a `SampleRow` — a `msgspec.Struct` with `gc=False` that -stores raw `int | None` nanosecond timestamps and accumulated text: - -``` -SampleRow -├── sample_uuid: str -├── issued_ns: int | None ← set on ISSUED -├── complete_ns: int | None ← set on COMPLETE -├── recv_first_ns: int | None ← set on RECV_FIRST -├── last_recv_ns: int | None ← set on RECV_FIRST, updated on each RECV_NON_FIRST -├── client_send_ns: int | None ← set on CLIENT_SEND -├── client_resp_done_ns: int | None ← set on CLIENT_RESP_DONE -├── prompt_text: str | None ← from ISSUED event data (for ISL tokenization) -├── first_chunk_text: str | None ← from RECV_FIRST event data (for TPOT denominator) -├── first_chunk_tokens: int | None ← token count of first_chunk_text, resolved after async tokenization -└── output_chunks: list[str] ← accumulated from RECV_FIRST/RECV_NON_FIRST data -``` - -Metric formulas are simple methods on the row: - -```python -def ttft_ns(self) -> int | None: # recv_first_ns - issued_ns -def sample_latency_ns(self) -> int | None: # complete_ns - issued_ns -def request_duration_ns(self) -> int | None: # client_resp_done_ns - client_send_ns -def output_text(self) -> str: # "".join(output_chunks) -``` - -Rows are created on ISSUED and removed on COMPLETE. - -### Design Rationale: Why Not a Declarative Field System - -An earlier iteration used `_MetricField` structs with `delta_start_field_prio` lists -to declaratively describe which field pairs produce which metrics. This was abandoned -because: - -1. The formulas are few and fixed — a declarative DSL adds indirection without flexibility. -2. String-based field lookups at runtime obscure the actual data flow. -3. The metric emission logic was coupled into the data storage layer (`set_field` both - stored a timestamp and emitted a metric), making it hard to test or reason about. -4. Special cases (`mutable` flag for `recv_non_first`, `msgspec.UNSET` sentinels) - added complexity for what is fundamentally `int | None`. - -The current design keeps data storage (SampleRow) separate from metric emission -(aggregator event handlers). Each handler is 5-15 lines, reads top-to-bottom, and -is independently testable. - -## Metrics Computed - -### Timing Metrics (emitted immediately on triggering event) - -| Metric | Emitted On | Formula | Notes | -| --------------------- | ---------------- | -------------------------------------- | ---------------------------------------------------------------------------------------------------------- | -| `ttft_ns` | RECV_FIRST | `recv_first_ns - issued_ns` | Time to first token. Streaming only. | -| `sample_latency_ns` | COMPLETE | `complete_ns - issued_ns` | End-to-end latency from issue to completion. | -| `request_duration_ns` | CLIENT_RESP_DONE | `client_resp_done_ns - client_send_ns` | HTTP-level request time (inside worker process). | -| `chunk_delta_ns` | RECV_NON_FIRST | `timestamp - last_recv_ns` | Inter-token arrival time. `last_recv_ns` starts at `recv_first_ns` and advances with each non-first chunk. 
| - -### Token Metrics (require tokenization, may be async) - -| Metric | Emitted On | Formula | Notes | -| --------- | -------------------- | ------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `isl` | ISSUED | `len(token_ids)` or `token_count(text)` | Input sequence length. ISSUED event carries `PromptData` with either `token_ids` (SGLang, emitted synchronously) or `text` (OpenAI, tokenized async). | -| `osl` | COMPLETE (awaited) | `token_count(output_text)` | Output sequence length. Output text is from accumulated chunks (streaming) or COMPLETE data (non-streaming). | -| `tpot_ns` | COMPLETE (after OSL) | `(complete_ns - recv_first_ns) / (osl - first_chunk_tokens)` | Time per output token after the first chunk. The first chunk may contain multiple tokens, so `first_chunk_text` is tokenized separately for the denominator. Only emitted for streaming responses where `osl - first_chunk_tokens > 0`. | - -## Event Dispatch Flow - -``` -process(records: list[EventRecord]) -│ -├── for each record: -│ ├── Session events → update is_tracking / session state -│ │ -│ └── Sample events (if sample_uuid non-empty): -│ ├── ISSUED -│ │ ├── if not is_tracking: skip -│ │ ├── create SampleRow in MetricsTable -│ │ ├── store issued_ns -│ │ ├── store prompt_text from record.data (if str) -│ │ └── schedule ISL tokenization (async, fire-and-forget) -│ │ -│ ├── RECV_FIRST -│ │ ├── lookup row (skip if not tracked) -│ │ ├── store recv_first_ns, last_recv_ns -│ │ ├── emit ttft_ns -│ │ └── append record.data to output_chunks -│ │ -│ ├── RECV_NON_FIRST -│ │ ├── lookup row (skip if not tracked) -│ │ ├── emit chunk_delta_ns (from last_recv_ns) -│ │ ├── update last_recv_ns -│ │ └── append record.data to output_chunks -│ │ -│ ├── CLIENT_SEND -│ │ └── store client_send_ns -│ │ -│ ├── CLIENT_RESP_DONE -│ │ ├── store client_resp_done_ns -│ │ └── emit request_duration_ns -│ │ -│ └── COMPLETE -│ ├── store complete_ns -│ ├── emit sample_latency_ns -│ ├── await OSL tokenization → emit osl -│ ├── if streaming and osl > first_chunk_tokens → emit tpot_ns -│ └── remove row from MetricsTable -│ -└── if ENDED seen: flush emitter, close subscriber, signal shutdown -``` - -## MetricEmitter - -The `MetricEmitter` ABC defines: - -```python -class MetricEmitter(ABC): - def emit(self, sample_uuid: str, metric_name: str, value: int | float) -> None: ... - def flush(self) -> None: ... - def close(self) -> None: ... -``` - -### JsonlMetricEmitter (current implementation) - -Writes one JSON line per metric: - -```json -{"sample_uuid":"a1b2c3...","metric_name":"ttft_ns","value":1500,"timestamp_ns":98765432100} -{"sample_uuid":"a1b2c3...","metric_name":"sample_latency_ns","value":4000,"timestamp_ns":98765436100} -``` - -Uses `msgspec.json.Encoder` for serialization. Supports a configurable `flush_interval` -(flush to disk every N records). `timestamp_ns` is the wall-clock time when the metric -was emitted (not the event timestamp). - -### Future: PrometheusEmitter - -Would push to Prometheus PushGateway. The `emit()` signature supports this — -`metric_name` maps to a Prometheus metric, `sample_uuid` becomes a label, -`value` is the observation. Histograms/summaries can be built by accumulating -values per metric name. 
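For concreteness, a self-contained sketch of an emitter in this mold — constructor
arguments and buffering details are illustrative, not the actual implementation:

```python
import time

import msgspec


class JsonlMetricEmitter(MetricEmitter):
    """One JSON object per line; flushed to disk every `flush_interval` records."""

    def __init__(self, path: str, flush_interval: int = 100):
        self._file = open(path, "ab")
        self._encoder = msgspec.json.Encoder()
        self._flush_interval = flush_interval
        self._pending = 0

    def emit(self, sample_uuid: str, metric_name: str, value: int | float) -> None:
        # timestamp_ns is the wall-clock emission time, not the event timestamp.
        self._file.write(self._encoder.encode({
            "sample_uuid": sample_uuid,
            "metric_name": metric_name,
            "value": value,
            "timestamp_ns": time.time_ns(),
        }) + b"\n")
        self._pending += 1
        if self._pending >= self._flush_interval:
            self.flush()

    def flush(self) -> None:
        self._file.flush()
        self._pending = 0

    def close(self) -> None:
        self.flush()
        self._file.close()
```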
- -## TokenizePool - -Thread-pool wrapper around HuggingFace `AutoTokenizer` for ISL/OSL/TPOT computation. - -### Architecture - -``` - TokenizePool - ┌─────────────┐ - │ ThreadPool │ - token_count("text")──► Executor │ - (blocking) │ ┌───────┐ │ - │ │Thread1│──► thread-local AutoTokenizer - │ │Thread2│──► thread-local AutoTokenizer - │ │ ... │ │ - │ └───────┘ │ - └─────────────┘ -``` - -### Thread-Safety Analysis - -- **`ThreadPoolExecutor.submit()`** is internally synchronized — safe to call from - any thread. -- **Thread-local tokenizer instances** (`threading.local()`) mean zero shared mutable - state during tokenization. Each worker thread lazily loads its own - `AutoTokenizer.from_pretrained()` on first use. -- **HuggingFace tokenizers** (Rust backend via `tokenizers` crate) release the GIL - during the core tokenization work, so multiple threads actually run in parallel. -- **Blocking vs async**: `tokenize()` and `token_count()` block the calling thread - on `future.result()`. In async context, use `token_count_async()` which wraps the - call in `loop.run_in_executor(None, ...)` to avoid blocking the event loop. - -### Why `run_in_executor` for async? - -The `token_count_async` method uses a double-hop: `event loop executor → TokenizePool executor`. -This seems redundant but is necessary because: - -1. The aggregator's `process()` runs as an async task on the event loop. -2. Calling `pool.token_count()` directly would block the loop (the `future.result()` - inside `token_count()` is a synchronous wait). -3. `run_in_executor` offloads the blocking call to a thread, freeing the loop to - continue processing events. - -The inner `ThreadPoolExecutor` in `TokenizePool` still provides the thread-local -tokenizer isolation. The outer executor just prevents the blocking wait from starving -the event loop. - -## ISL Tracking: How the Prompt Gets to the Aggregator - -### Current Design - -The `ISSUED` event's `data` field carries a `PromptData` struct with either: - -- `text: str` — raw prompt string (OpenAI path), tokenized async by the aggregator. -- `token_ids: tuple[int, ...]` — pre-tokenized token IDs (SGLang/Harmonize path), - ISL is `len(token_ids)` with no tokenization needed. - -`EventRecord.data` is typed as `TextModelOutput | PromptData | ErrorData | None`. - -### Where to Publish - -The ISSUED event is published in the load generator when `issue_sample()` is called. -At that point, `sample.data` contains the post-transform dataset row. The publisher -extracts the prompt: - -```python -# In the load generator, when issuing a sample: -if "input_tokens" in sample.data: - prompt_data = PromptData(token_ids=tuple(sample.data["input_tokens"])) -elif "prompt" in sample.data: - prompt_data = PromptData(text=sample.data["prompt"]) -else: - prompt_data = None - -publisher.publish(EventRecord( - event_type=SampleEventType.ISSUED, - sample_uuid=sample.uuid, - data=prompt_data, -)) -``` - -### Adapter Considerations - -The prompt data available at ISSUED time is **post-transform** — dataset transforms -have already been applied by this point. 
This matters because: - -| Adapter | Transform Pipeline | `sample.data` at ISSUED | `PromptData` | -| ----------------------- | --------------------------------------------- | ----------------------------------- | ------------------------------------------- | -| OpenAI / OpenAI-Msgspec | `ColumnFilter → AddStaticColumns` | `{"prompt": "...", "model": "..."}` | `PromptData(text=prompt)` | -| SGLang | `Harmonize → ColumnFilter → AddStaticColumns` | `{"input_tokens": [int, ...]}` | `PromptData(token_ids=tuple(input_tokens))` | - -## Lifecycle - -### Startup - -```python -python -m inference_endpoint.async_utils.services.metrics_aggregator \ - --metrics-dir /tmp/metrics \ - --socket-dir /path/to/socket_dir \ - --socket-name ev_pub_ \ - --tokenizer gpt2 \ - --tokenizer-workers 2 -``` - -1. Create `TokenizePool` (if `--tokenizer` provided) -2. Create `JsonlMetricEmitter` writing to `/metrics.jsonl` -3. Create `MetricsAggregatorService` connected to the ZMQ PUB socket -4. `aggregator.start()` adds the ZMQ socket reader to the event loop -5. `await shutdown_event.wait()` blocks until ENDED is received - -### Shutdown - -On `SessionEventType.ENDED`: - -1. `_finalize()` flushes the emitter -2. `close()` closes the emitter file and removes the ZMQ socket reader -3. `shutdown_event.set()` unblocks the main coroutine -4. `TokenizePool.close()` shuts down worker threads (in `finally` block) - -### Graceful Drain - -Events received in the same batch as ENDED are processed (the `_shutdown_received` -flag is checked at the top of the loop, so events before ENDED in the batch still -get handled). Events in subsequent batches are dropped. - -In-flight samples that never receive COMPLETE will be abandoned (their rows stay in -the table but are never emitted). This is expected — if the session ends, those -samples didn't complete. - -## Output Format - -### JSONL Example (streaming sample) - -```json -{"sample_uuid":"a1b2c3d4","metric_name":"isl","value":42,"timestamp_ns":100000000} -{"sample_uuid":"a1b2c3d4","metric_name":"ttft_ns","value":1500000,"timestamp_ns":100001500} -{"sample_uuid":"a1b2c3d4","metric_name":"chunk_delta_ns","value":500000,"timestamp_ns":100002000} -{"sample_uuid":"a1b2c3d4","metric_name":"chunk_delta_ns","value":600000,"timestamp_ns":100002600} -{"sample_uuid":"a1b2c3d4","metric_name":"request_duration_ns","value":3800000,"timestamp_ns":100003800} -{"sample_uuid":"a1b2c3d4","metric_name":"sample_latency_ns","value":4000000,"timestamp_ns":100004000} -{"sample_uuid":"a1b2c3d4","metric_name":"osl","value":28,"timestamp_ns":100004001} -{"sample_uuid":"a1b2c3d4","metric_name":"tpot_ns","value":92592.6,"timestamp_ns":100004001} -``` - -### JSONL Example (non-streaming sample) - -```json -{"sample_uuid":"e5f6a7b8","metric_name":"isl","value":15,"timestamp_ns":200000000} -{"sample_uuid":"e5f6a7b8","metric_name":"request_duration_ns","value":2500000,"timestamp_ns":200002500} -{"sample_uuid":"e5f6a7b8","metric_name":"sample_latency_ns","value":3000000,"timestamp_ns":200003000} -{"sample_uuid":"e5f6a7b8","metric_name":"osl","value":50,"timestamp_ns":200003001} -``` - -Note: no `ttft_ns`, `chunk_delta_ns`, or `tpot_ns` for non-streaming — these require -`RECV_FIRST` which only occurs in streaming mode. - -## Not Yet Wired - -The EventRecord pub/sub infrastructure is ready, but actual `publish(EventRecord(...))` -calls for sample events have not been connected in the load generator or worker -processes. What needs to happen: - -1. 
**Load generator** (`load_generator.py` / `session.py`): Publish `ISSUED` with - prompt text, `START/STOP_PERFORMANCE_TRACKING`, `STARTED`, `ENDED`. -2. **Worker** (`worker.py`): Publish `CLIENT_SEND`, `CLIENT_RESP_DONE`, - `RECV_FIRST`, `RECV_NON_FIRST`, `COMPLETE` with response data. -3. **Session orchestrator**: Spawn the metrics aggregator subprocess alongside - the event logger subprocess, passing the same ZMQ socket address. diff --git a/docs/async_utils/transport/zmq/ready_check_design.md b/docs/async_utils/transport/zmq/ready_check_design.md new file mode 100644 index 00000000..15e97968 --- /dev/null +++ b/docs/async_utils/transport/zmq/ready_check_design.md @@ -0,0 +1,114 @@ +# ReadyCheck Design + +## Problem + +Subprocess startup is asynchronous. The main process spawns workers or service +subprocesses, but cannot use them until they have completed initialization +(bound sockets, subscribed to topics, loaded resources). Without synchronization, +the main process may send messages that are dropped because the subprocess isn't +listening yet. + +## Solution + +A generic PUSH/PULL readiness protocol that works for any subprocess type: + +``` +Main Process Subprocess (worker or service) +┌───────────────────┐ ┌───────────────────────────┐ +│ ReadyCheckReceiver│ │ │ +│ (PULL, bind) │ │ 1. Initialize transports │ +│ │ │ 2. Subscribe / connect │ +│ await wait(N) │◄─── READY ───│ 3. send_ready_signal() │ +│ blocks until N │ (PUSH) │ (ctx, path, id) │ +│ signals arrive │ │ 4. Start processing │ +└───────────────────┘ └───────────────────────────┘ +``` + +## Why PUSH/PULL + +PUB/SUB has a "slow joiner" problem — the subscriber may miss messages +published before it connects. PUSH/PULL guarantees delivery: if the PULL +socket is bound before the PUSH connects, no messages are lost. + +Multiple PUSH sockets can connect to a single PULL socket (ZMQ fan-in). +This means one receiver socket handles readiness from all subprocesses. + +## Components + +### ReadyCheckReceiver (host side) + +- Binds a ZMQ PULL socket on an IPC path +- `wait(timeout)` blocks until `count` signals arrive +- Returns list of identities in arrival order +- Closes socket after all signals received (or on timeout) +- Timeout is a total deadline, not per-message + +### `send_ready_signal()` (subprocess side) + +- Free async function: `send_ready_signal(zmq_context, path, identity)` +- Uses the subprocess's **existing** ZMQ context — no new context created +- Opens one PUSH socket, sends one msgpack-encoded int, closes the socket +- Bounded LINGER (5s) to avoid hanging if receiver is gone + +## Usage Patterns + +### Workers (PUSH/PULL primary transport) + +The `_ZmqWorkerConnector` calls `send_ready_signal()` with the worker's +existing ZMQ context after connecting its request/response transports: + +```python +requests = _create_receiver(loop, request_path, zmq_context, ...) +responses = _create_sender(loop, response_path, zmq_context, ...) + +await send_ready_signal(zmq_context, self.readiness_path, worker_id) + +yield requests, responses +``` + +The `ZmqWorkerPoolTransport` creates a `ReadyCheckReceiver` and delegates +`wait_for_workers_ready()` to it. + +### Services (PUB/SUB primary transport) + +Services (EventLoggerService, MetricsAggregatorService) accept +`--readiness-path` and `--readiness-id` CLI arguments. 
After calling +`service.start()`, they signal readiness using the same ZMQ context: + +```python +service.start() + +if args.readiness_path: + await send_ready_signal(zmq_ctx, args.readiness_path, args.readiness_id) + +await shutdown_event.wait() +``` + +### ServiceLauncher + +```python +launcher = ServiceLauncher(zmq_context) +procs = await launcher.launch([ + ServiceConfig(module="...event_logger", args=["--socket-dir", d, ...]), + ServiceConfig(module="...metrics_aggregator", args=["--socket-dir", d, ...]), +], timeout=30.0) + +# ... run benchmark, publish ENDED ... + +ServiceLauncher.wait_for_exit(procs, timeout=60.0) +``` + +The launcher: + +1. Creates a `ReadyCheckReceiver` bound to a unique IPC path +2. Spawns each service as `python -m ... --readiness-path --readiness-id ` +3. Awaits all readiness signals (total deadline timeout) +4. Returns subprocess handles for later `wait_for_exit()` +5. On failure, checks for subprocess crashes and kills remaining processes + +## Ordering Guarantee + +The ready signal is sent **after** the subprocess has completed its +initialization (transport connect, topic subscribe, reader registration). +This guarantees that when the main process's `wait()` returns, all +subprocesses are ready to process messages. diff --git a/docs/load_generator/DESIGN.md b/docs/load_generator/DESIGN.md index 27182d60..c4df583b 100644 --- a/docs/load_generator/DESIGN.md +++ b/docs/load_generator/DESIGN.md @@ -1,152 +1,1077 @@ -# Load Generator — Design Spec +# Async Load Generator Design -> Central orchestrator for a benchmark run: controls what samples to issue, when to issue them via pluggable schedulers, and routes completion events to the metrics recorder. +## Overview + +The load generator is the central scheduling component that controls _when_ and _how_ +samples are issued to inference endpoints during benchmarking. It is fully async with a +single-thread, single-event-loop-per-process constraint. + +## File Structure + +``` +src/inference_endpoint/load_generator/ +├── __init__.py # Public exports +├── session.py # BenchmarkSession, SessionResult +├── strategy.py # LoadStrategy protocol, TimedIssueStrategy, +│ # BurstStrategy, ConcurrencyStrategy, +│ # create_load_strategy() +├── sample_order.py # SampleOrder, WithoutReplacement, WithReplacement +└── delay.py # poisson_delay_fn, uniform_delay_fn +``` + +## Architecture + +A `BenchmarkSession` runs one or more **phases** sequentially. Each phase has its own +`RuntimeSettings`, `Dataset`, and `LoadStrategy`. Phases are categorized as either +**tracked** (produces a performance metrics report) or **untracked** (performance is not evaluated). + +Multiple performance phases allow testing different configurations (QPS targets, +concurrency levels, datasets) against the same server instance within a single session, +each producing an independent report. + +``` +BenchmarkSession.run(phases) + | + +-- STARTED + +-- [warmup] strategy.execute() → NO drain (keep in-flight saturated) + +-- [perf phase 1] START_PERFORMANCE_TRACKING → strategy.execute() → drain → STOP_PERFORMANCE_TRACKING + +-- [warmup] strategy.execute() → NO drain (keep in-flight saturated) + +-- [perf phase 2] START_PERFORMANCE_TRACKING → strategy.execute() → drain → STOP_PERFORMANCE_TRACKING + +-- [accuracy x N] strategy.execute() → drain (uuid maps collected) + +-- ENDED + | + +-- return SessionResult { perf_results: [PhaseResult, ...], accuracy_results: [...] 
}
```

Each performance phase is bracketed by `START_PERFORMANCE_TRACKING` /
`STOP_PERFORMANCE_TRACKING` events, which the `MetricsAggregator` uses to
scope its tracked counters and duration. At the end of each perf phase,
metrics are snapshotted from the KVStoreReader and a `Report` is built.

> **TODO:** The current `MetricsAggregator` does not support per-phase scoping.
> It maintains a single set of counters and series across all tracking windows.
> To support multiple perf phases with independent reports, the aggregator will
> need one of: (a) a `RESET_METRICS` event that clears counters/series between
> phases, (b) per-phase metric namespacing (e.g., prefixing keys with the phase
> name), or (c) a report builder that computes deltas by snapshotting before and
> after each phase. This will be addressed in a future change to the
> `MetricsAggregator`. Option (b) is the most likely choice, as it is the most
> robust.

Saturation phases exist to bring the endpoint to steady-state before a
performance measurement. In-flight requests are **not drained** at the end
of a warmup phase — the next phase starts immediately with concurrency
already at the target level. Common uses:

- Fill KV caches so the perf phase measures warm inference, not cold start
- Ramp concurrency to the target level before measuring at that level
- Warm connection pools and OS TCP buffers

-**Component specs:** [async_utils](../async_utils/DESIGN.md) · [commands](../commands/DESIGN.md) · [config](../config/DESIGN.md) · [core](../core/DESIGN.md) · [dataset_manager](../dataset_manager/DESIGN.md) · [endpoint_client](../endpoint_client/DESIGN.md) · [evaluation](../evaluation/DESIGN.md) · **load_generator** · [metrics](../metrics/DESIGN.md) · [openai](../openai/DESIGN.md) · [plugins](../plugins/DESIGN.md) · [profiling](../profiling/DESIGN.md) · [sglang](../sglang/DESIGN.md) · [testing](../testing/DESIGN.md) · [utils](../utils/DESIGN.md)

### Load Strategies

Three load patterns, three implementations — each uses the optimal async primitive
for its scheduling semantics, validated by benchmarking:

| LoadPatternType   | Strategy              | Mechanism                    | Best At             |
| ----------------- | --------------------- | ---------------------------- | ------------------- |
| POISSON           | `TimedIssueStrategy`  | `loop.call_at` (default)     | ≤50k QPS            |
| POISSON (precise) | `TimedIssueStrategy`  | `run_in_executor(busy_wait)` | Sub-100μs precision |
| MAX_THROUGHPUT    | `BurstStrategy`       | `loop.call_soon`             | Max fire rate       |
| CONCURRENCY       | `ConcurrencyStrategy` | `asyncio.Semaphore`          | Fixed concurrency   |

**Default for Poisson is `loop.call_at`:** Sub-millisecond timing precision (600–700μs)
with zero GIL contention and low response latency (0.6–1.4ms). No thread pool overhead.
Degrades above 100k QPS, where the callback queue saturates.

`run_in_executor(busy_wait)` is available as an opt-in for workloads requiring sub-100μs
timing precision. It achieves 65–92μs but introduces GIL contention that adds ~6ms of
response latency at low QPS (<1k). At mid-range QPS (5k–50k), latency is comparable.
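The busy-wait primitive behind the executor numbers is just a spin on the monotonic
clock. A minimal sketch of the `_busy_wait_until` helper referenced by the
executor-mode code later in this document:

```python
import time


def _busy_wait_until(target_ns: int) -> None:
    """Spin until the monotonic clock reaches target_ns.

    Runs on an executor thread so the event loop is never blocked. The spin
    buys sub-100μs wakeup precision at the cost of one saturated core — and
    of GIL contention with the main thread, which is where the ~6ms
    response-latency penalty at low QPS comes from.
    """
    while time.monotonic_ns() < target_ns:
        pass
```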
+ +### Optional: Separate Timer Process + +For workloads requiring both precise timing AND minimal response latency (e.g., edge +inference with tight TPOT budgets), the timer can run in a dedicated process: + +``` +Timer Process (dedicated): + - Owns a tight busy-wait loop, no GIL contention + - Sends (sample_index: int) via ZMQ PUSH at precise times + +Main Process: + - Receives indices via ZMQ PULL + - Loads data, builds Query, issues via HTTPEndpointClient + - Runs receiver coroutine — event loop is never blocked +``` + +This eliminates the GIL contention that causes `run_in_executor` to add approximately +6ms response latency at low QPS. However, it adds ZMQ IPC latency (10–50μs) to timing +precision. + +**Not suitable for ConcurrencyStrategy**: the timer process has no visibility into +completion events, so it cannot gate on in-flight count. Concurrency mode always runs +in-process. --- -## Overview +## Components + +### BenchmarkSession + +**File:** `src/inference_endpoint/load_generator/session.py` -`load_generator/` is the central orchestrator for a benchmark run. It controls **what** to send -(dataset samples), **when** to send them (load pattern), and **how** to observe the results -(event hooks feeding the metrics recorder). +Async orchestrator. Runs phases sequentially on the shared event loop. -## Responsibilities +```python +class PhaseType(str, Enum): + """Phase types control tracking and reporting behavior.""" + PERFORMANCE = "performance" # Tracked, produces a report + ACCURACY = "accuracy" # Untracked, for eval scoring + WARMUP = "warmup" # Untracked, ramp up concurrency before perf phase + + +@dataclass(frozen=True, slots=True) +class PhaseConfig: + """Configuration for a single benchmark phase.""" + name: str + runtime_settings: RuntimeSettings + dataset: Dataset + phase_type: PhaseType = PhaseType.PERFORMANCE + + +class BenchmarkSession: + def __init__( + self, + issuer: SampleIssuer, + event_publisher: EventPublisher, + loop: asyncio.AbstractEventLoop, + on_sample_complete: Callable[[QueryResult | StreamChunk], None] | None = None, + session_id: str | None = None, + ): ... + + async def run(self, phases: list[PhaseConfig]) -> SessionResult: ... + def stop(self) -> None: ... +``` + +**`run(phases)`** lifecycle: -- Manage the full benchmark session lifecycle (start → run → drain → report) -- Implement timing strategies: max throughput, Poisson, fixed concurrency -- Emit structured events for every sample state transition -- Coordinate graceful shutdown with in-flight drain +1. Publish `SessionEventType.STARTED` +2. Start receiver coroutine (`_receive_responses`) +3. For each phase: + a. Create `SampleOrder` and `LoadStrategy` from phase settings + b. Set `self._current_dataset` to phase dataset + c. **WARMUP**: execute strategy, **do not drain** in-flight. No tracking + events, no report. Purpose: bring endpoint to steady-state concurrency + (e.g., fill KV caches, warm up connection pools). The next phase starts + immediately with concurrency already at the target level. + d. **PERFORMANCE**: publish `START_PERFORMANCE_TRACKING`, execute strategy, + drain in-flight, publish `STOP_PERFORMANCE_TRACKING`. Snapshot metrics + from KVStoreReader → build `PhaseResult`. + e. **ACCURACY**: execute strategy, drain in-flight. No tracking events. + UUID map collected for eval scoring. +4. Publish `SessionEventType.ENDED` +5. 
Return `SessionResult` (contains `PhaseResult` per perf phase + accuracy maps) -## Component Map +**Saturation phases** are particularly important for concurrency-based benchmarks. +A common pattern: +```python +phases = [ + # Ramp up to target concurrency, fill endpoint caches + PhaseConfig("warmup", warmup_settings, dataset, PhaseType.WARMUP), + # Measured performance run + PhaseConfig("perf", perf_settings, dataset, PhaseType.PERFORMANCE), + # Accuracy eval (uses same warmed endpoint) + PhaseConfig("accuracy", acc_settings, acc_dataset, PhaseType.ACCURACY), +] ``` -BenchmarkSession ← top-level owner; runs on background thread - └── SchedulerBasedLoadGenerator ← iterates (sample_index, delay_ns) pairs - ├── Scheduler ← determines timing - │ ├── MaxThroughputScheduler (offline: all at t=0) - │ ├── PoissonDistributionScheduler (online: exp inter-arrival) - │ └── ConcurrencyScheduler (online: fixed in-flight count) - └── SampleIssuer (ABC) ← sends the query; implemented by endpoint_client/ + +Or multiple performance sweeps with warmup between each: + +```python +phases = [ + PhaseConfig("saturate_c32", sat_32, dataset, PhaseType.WARMUP), + PhaseConfig("perf_c32", perf_32, dataset, PhaseType.PERFORMANCE), + PhaseConfig("saturate_c64", sat_64, dataset, PhaseType.WARMUP), + PhaseConfig("perf_c64", perf_64, dataset, PhaseType.PERFORMANCE), + PhaseConfig("accuracy", acc_settings, acc_dataset, PhaseType.ACCURACY), +] ``` -## Public Interface +### PhaseIssuer -### `BenchmarkSession` +**File:** `src/inference_endpoint/load_generator/session.py` (internal to session) + +Per-phase state holder that wraps the issue logic. Created fresh for each phase, +holds the phase-scoped `uuid_to_index` map and inflight counter. Passed to +strategies as a callable (`phase_issuer.issue`). + +Using an object instead of a closure makes per-phase state explicit, testable +independently, and avoids the awkward tuple return pattern. ```python -@classmethod -def start( - cls, - runtime_settings: RuntimeSettings, - dataset: Dataset, - sample_issuer: SampleIssuer, - scheduler: Scheduler, - *args, - accuracy_datasets: list[Dataset] | None = None, - load_generator_cls: type[LoadGenerator] = SchedulerBasedLoadGenerator, - name: str | None = None, - max_shutdown_timeout_s: float | None = None, - report_dir: os.PathLike | None = None, - tokenizer_override: AutoTokenizer | None = None, - dump_events_log: bool = False, -) -> "BenchmarkSession" +class PhaseIssuer: + """Wraps sample issuance for a single benchmark phase.""" -def wait_for_test_end(self, timeout: float | None = None) -> bool -def stop(self) -> None + __slots__ = ("_dataset", "_issuer", "_publisher", "_stop_check", + "uuid_to_index", "inflight", "issued_count") + + def __init__( + self, + dataset: Dataset, + issuer: SampleIssuer, + publisher: EventRecordPublisher, + stop_check: Callable[[], bool], + ): + self._dataset = dataset + self._issuer = issuer + self._publisher = publisher + self._stop_check = stop_check + self.uuid_to_index: dict[str, int] = {} + self.inflight: int = 0 + self.issued_count: int = 0 + + def issue(self, sample_index: int) -> str | None: + """Load data, build Query, publish ISSUED, send to endpoint. + + Returns query_id on success, None if session is stopping. 
+ """ + if self._stop_check(): + return None + query_id = uuid.uuid4().hex + data = self._dataset.load_sample(sample_index) + query = Query(id=query_id, data=data) + self.uuid_to_index[query_id] = sample_index + ts = time.monotonic_ns() + self._publisher.publish(EventRecord( + event_type=SampleEventType.ISSUED, + timestamp_ns=ts, + sample_uuid=query_id, + data=PromptData(text=data.get("prompt")), + )) + self._issuer.issue(query) + self.inflight += 1 + self.issued_count += 1 + return query_id ``` -`start()` spawns the run thread immediately. `wait_for_test_end()` blocks the caller until the -session finishes or the timeout expires. `stop()` signals early termination. +The strategy calls `phase_issuer.issue(idx)`. After the phase completes, +the session reads `phase_issuer.uuid_to_index` and `phase_issuer.issued_count` +to build the `PhaseResult`. + +**UUID generation before Query construction** avoids the old `Sample` catch-22. +`Query` is a frozen `msgspec.Struct` — all fields set at construction, no mutation. -### `SampleIssuer` (abstract base class — implemented externally) +**`_receive_responses()`** — concurrent coroutine, purely async: ```python -def start() -> None -def issue(sample: Sample) -> None -def shutdown() -> None +async def _receive_responses(self): + while not self._done: + resp = await self._issuer.recv() + if resp is None: + # Transport closed — trigger stop so strategy and drain don't hang. + self._stop_requested = True + self._drain_event.set() + if self._strategy_task and not self._strategy_task.done(): + self._strategy_task.cancel() + break + self._handle_response(resp) ``` -`SampleIssuer` is an `ABC`, not a structural protocol. `start()` and `shutdown()` have default -no-op implementations; subclasses must implement `issue()`. `issue()` must be non-blocking; -responses are delivered asynchronously via `SampleEventHandler`. +Uses `recv()` exclusively — no `poll()` spin. The ZMQ fd is registered with +the event loop, so `recv()` wakes exactly when a response is available with +zero CPU overhead. Each `recv()` call yields to the event loop, ensuring +strategy coroutines (call_at callbacks, semaphore waiters) are never starved. + +For `ConcurrencyStrategy`, `_handle_response` calls `strategy.on_query_complete()` +which releases the semaphore. Since `recv()` returns as soon as the fd is readable +and `eager_task_factory` executes the woken semaphore waiter synchronously, there +is no added latency compared to a poll-based approach. + +**`_handle_response(resp)`**: + +- `QueryResult`: publish COMPLETE event, decrement `_inflight`, call `on_sample_complete`, + call `strategy.on_query_complete(query_id)` if strategy supports it +- `StreamChunk(first)`: publish RECV_FIRST event +- `StreamChunk(non-first)`: publish RECV_NON_FIRST event + +**Timestamp fidelity:** -### `Scheduler` (base class) +- ISSUED: `monotonic_ns()` taken immediately before `issuer.issue()`. The ZMQ push is + sync and non-blocking, so this honestly represents when the query entered the transport. + Note: with batched publishing, `publisher.publish()` buffers the ISSUED EventRecord + in memory — the actual ZMQ send is deferred until the batch threshold is reached or + `flush()` is called. The timestamp itself is still accurate (captured before buffering), + but the EventRecord reaches subscribers with batching latency. +- COMPLETE: `QueryResult.completed_at` is set via `force_setattr(monotonic_ns())` in + `__post_init__`, regenerated on deserialization. 
Both ISSUED and COMPLETE timestamps + share the same ZMQ transit bias. TTFT (`RECV_FIRST - ISSUED`) is still sensitive + to this overhead since it spans the full ZMQ round-trip. TPOT avoids cross-process + clock skew by computing time deltas between consecutive chunks within the same process. + +### LoadStrategy (Protocol) + +**File:** `src/inference_endpoint/load_generator/strategy.py` ```python -def __iter__(self) -> Iterator[tuple[int, int]] -# yields (sample_index, delay_ns) +class LoadStrategy(Protocol): + async def execute( + self, + phase_issuer: PhaseIssuer, + ) -> int: + """Drive sample issuance. Returns count of samples issued. + + Call phase_issuer.issue(sample_index) for each sample. + Returns None when session is stopping (max_duration, stop(), or + all samples exhausted). + """ + ... + + def on_query_complete(self, query_id: str) -> None: + """Called by session on each QueryResult. Default: no-op.""" + ... ``` -Subclasses register themselves via `__init_subclass__(load_pattern=LoadPatternType.X)` and are -looked up at construction time. +The strategy calls `phase_issuer.issue(idx)` which handles data loading, Query +construction, event publishing, and the actual send. The strategy only controls +_when_ and _which index_ to issue. Stop checking is internal to `PhaseIssuer.issue()` +— it returns `None` when the session should stop. + +`on_query_complete` is the hook for `ConcurrencyStrategy` — other strategies ignore it. + +### TimedIssueStrategy + +Handles `LoadPatternType.POISSON`. Default uses `loop.call_at`; opt-in +`run_in_executor(busy_wait)` available for sub-100μs precision requirements. + +```python +class TimedIssueStrategy(LoadStrategy): + def __init__( + self, + delay_fn: Callable[[], int], + sample_order: Iterator[int], + loop: asyncio.AbstractEventLoop, + use_executor: bool = False, + ): ... + + async def execute(self, phase_issuer: PhaseIssuer) -> int: + if self.use_executor: + return await self._execute_executor(phase_issuer) + else: + return await self._execute_call_at(phase_issuer) +``` + +**call_at mode** (default): + +```python +async def _execute_call_at(self, phase_issuer): + done = asyncio.Event() + start_time = self._loop.time() + cumulative_s = 0.0 + + def schedule_next(): + nonlocal cumulative_s + idx = next(self.sample_order, None) + if idx is None: + done.set() + return + cumulative_s += self.delay_fn() / 1e9 + self._loop.call_at(start_time + cumulative_s, fire, idx) + + def fire(idx): + if phase_issuer.issue(idx) is None: + done.set() + return + schedule_next() + + schedule_next() + await done.wait() + return phase_issuer.issued_count +``` + +**Executor mode** (opt-in, `use_executor=True`): + +```python +async def _execute_executor(self, phase_issuer): + start = monotonic_ns() + cumulative = 0 + for idx in self.sample_order: + cumulative += self.delay_fn() + target = start + cumulative + now = monotonic_ns() + if target > now: + await self._loop.run_in_executor(None, _busy_wait_until, target) + if phase_issuer.issue(idx) is None: + break + return phase_issuer.issued_count +``` + +### BurstStrategy + +Handles `LoadPatternType.MAX_THROUGHPUT`. Issues all samples as fast as possible +using `loop.call_soon` to schedule each issue as an event loop callback. This +avoids starving the receiver — between each callback, the loop processes I/O +events (including ZMQ recv fd readiness). + +```python +class BurstStrategy(LoadStrategy): + def __init__(self, sample_order: Iterator[int], loop: asyncio.AbstractEventLoop): ... 
+ + async def execute(self, phase_issuer: PhaseIssuer) -> int: + done = asyncio.Event() + + def issue_next(): + idx = next(self.sample_order, None) + if idx is None or phase_issuer.issue(idx) is None: + done.set() + return + self._loop.call_soon(issue_next) + + self._loop.call_soon(issue_next) + await done.wait() + return phase_issuer.issued_count +``` + +Each `call_soon` yields to the event loop between issues, preventing receiver +starvation. Benchmark data shows `loop.call_at` (with zero delay, equivalent +to `call_soon`) achieves 104k QPS — the highest throughput of all strategies. + +### ConcurrencyStrategy + +Handles `LoadPatternType.CONCURRENCY`. Semaphore-gated by completions. + +```python +class ConcurrencyStrategy(LoadStrategy): + def __init__(self, target_concurrency: int, sample_order: Iterator[int]): ... + + async def execute(self, phase_issuer: PhaseIssuer) -> int: + for idx in self.sample_order: + await self._sem.acquire() + if phase_issuer.issue(idx) is None: + self._sem.release() + break + return phase_issuer.issued_count + + def on_query_complete(self, query_id: str) -> None: + self._sem.release() +``` + +### SampleIssuer (Protocol) + +```python +class SampleIssuer(Protocol): + def issue(self, query: Query) -> None: ... + async def recv(self) -> QueryResult | StreamChunk | None: ... + def shutdown(self) -> None: ... +``` + +`issue()` is sync (ZMQ push). `recv()` is async blocking wait. +This matches `HTTPEndpointClient`'s existing interface. + +### SampleOrder (unchanged) + +`SampleOrder` is an infinite iterator yielding dataset indices. Implementations: + +- `WithoutReplacementSampleOrder` — shuffle, exhaust, reshuffle +- `WithReplacementSampleOrder` — uniform random + +Termination is controlled by `BenchmarkSession._make_stop_check()`, not the iterator. 
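A sketch of the without-replacement variant under those semantics (illustrative, not
the actual implementation):

```python
import random
from collections.abc import Iterator


class WithoutReplacementSampleOrder:
    """Infinite index stream: shuffle, exhaust, reshuffle."""

    def __init__(self, n_samples: int, rng: random.Random):
        self._n = n_samples
        self._rng = rng

    def __iter__(self) -> Iterator[int]:
        # Infinite by design — the session's stop check, not the iterator,
        # decides when issuance ends.
        while True:
            order = list(range(self._n))
            self._rng.shuffle(order)
            yield from order
```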
+ +### SessionResult + +```python +@dataclass(frozen=True) +class PhaseResult: + """Result of a single benchmark phase.""" + name: str + phase_type: PhaseType + uuid_to_index: dict[str, int] + issued_count: int + start_time_ns: int + end_time_ns: int + + +@dataclass(frozen=True) +class SessionResult: + """Combined results from all phases in a session.""" + session_id: str + phase_results: list[PhaseResult] + start_time_ns: int + end_time_ns: int + + @property + def perf_results(self) -> list[PhaseResult]: + return [r for r in self.phase_results if r.phase_type == PhaseType.PERFORMANCE] + + @property + def accuracy_results(self) -> list[PhaseResult]: + return [r for r in self.phase_results if r.phase_type == PhaseType.ACCURACY] +``` + +--- ## Data Flow +### Happy Path: Issue → Response → Event + +```mermaid +sequenceDiagram + participant S as LoadStrategy + participant B as BenchmarkSession + participant D as Dataset + participant I as SampleIssuer + participant W as Worker Process + participant E as EventPublisher + participant M as MetricsAggregator + + S->>B: issue_fn(sample_index) + B->>D: load_sample(index) + D-->>B: sample_data + Note over B: Build Query(id=uuid4().hex, data=load_sample(idx)) + B->>E: publish(ISSUED, uuid, timestamp_ns) + E->>M: ZMQ PUB (EventRecord) + B->>I: issue(query) + I->>W: ZMQ PUSH (Query) + W->>W: HTTP request → endpoint + + alt Non-streaming + W-->>I: ZMQ PUSH (QueryResult) + I-->>B: recv() + B->>E: publish(COMPLETE, uuid, completed_at) + else Streaming + W-->>I: ZMQ PUSH (StreamChunk first_chunk) + I-->>B: recv() + B->>E: publish(RECV_FIRST, uuid, timestamp_ns) + loop For each subsequent chunk + W-->>I: ZMQ PUSH (StreamChunk) + I-->>B: recv() + B->>E: publish(RECV_NON_FIRST, uuid, timestamp_ns) + end + W-->>I: ZMQ PUSH (QueryResult — final accumulated output) + I-->>B: recv() + B->>E: publish(COMPLETE, uuid, completed_at) + end + + E->>M: ZMQ PUB (EventRecord) + B->>S: on_query_complete(uuid) + Note over S: ConcurrencyStrategy: sem.release() ``` -BenchmarkSession._run_test() - │ - ├─ for (index, delay_ns) in SchedulerBasedLoadGenerator: - │ busy_wait(delay_ns) - │ sample = load_sample_data(index) - │ SampleIssuer.issue(sample) → async, fire-and-forget - │ - └─ wait_for_drain() ← blocks until all in-flight complete - │ - └─ SampleEventHandler routes completions: - FIRST_CHUNK → recorder.record_event(SampleEvent.FIRST_CHUNK) - COMPLETE → recorder.record_event(SampleEvent.COMPLETE) + +### Multi-Phase Session Lifecycle + +```mermaid +sequenceDiagram + participant C as Caller (execute.py) + participant B as BenchmarkSession + participant E as EventPublisher + participant M as MetricsAggregator + participant K as KVStoreReader + + C->>B: run(phases) + B->>E: STARTED + + Note over B: === Saturation Phase === + B->>B: execute strategy (untracked, no drain) + + Note over B: === Perf Phase 1 (e.g. QPS=1000) === + B->>E: START_PERFORMANCE_TRACKING + B->>B: execute strategy + B->>B: drain in-flight + B->>E: STOP_PERFORMANCE_TRACKING + B->>K: snapshot metrics → Report + Note over B: PhaseResult("perf_qps1k", issued_count=N) + + Note over B: === Perf Phase 2 (e.g. 
QPS=5000) === + B->>E: START_PERFORMANCE_TRACKING + B->>B: execute strategy + B->>B: drain in-flight + B->>E: STOP_PERFORMANCE_TRACKING + B->>K: snapshot metrics → Report + Note over B: PhaseResult("perf_qps5k", issued_count=N) + + Note over B: === Accuracy Phase === + B->>B: execute strategy (untracked) + B->>B: drain in-flight + Note over B: PhaseResult("accuracy", uuid_map) + + B->>E: ENDED + B-->>C: SessionResult +``` + +### Separate Timer Process Data Flow + +```mermaid +sequenceDiagram + participant T as Timer Process + participant B as BenchmarkSession + participant D as Dataset + participant I as SampleIssuer + participant W as Worker Process + + Note over T: Busy-wait loop (no GIL contention) + T->>B: ZMQ PUSH (sample_index) + B->>D: load_sample(index) + D-->>B: sample_data + Note over B: Build Query, publish ISSUED + B->>I: issue(query) + I->>W: ZMQ PUSH + W-->>I: ZMQ PUSH (QueryResult) + I-->>B: recv() + Note over B: publish COMPLETE +``` + +--- + +## Event Loop Topology + +### Standard (single process) + +```mermaid +graph TD + subgraph "Main Process — LoopManager.default_loop (uvloop)" + A["BenchmarkSession.run()"] + B["LoadStrategy.execute()"] + C["_receive_responses() task"] + D["EventPublisher (ZMQ PUB)"] + E["HTTPEndpointClient (shared loop)"] + + A --> B + A --> C + B -->|"issue_fn → issuer.issue()"| E + C -->|"recv()"| E + B --> D + C --> D + end + + subgraph "Worker Process 1" + W1["HTTP → endpoint"] + end + subgraph "Worker Process N" + WN["HTTP → endpoint"] + end + subgraph "MetricsAggregator (subprocess)" + MA["ZmqEventRecordSubscriber"] + KB["KVStore (mmap)"] + MA --> KB + end + + E -->|"ZMQ IPC"| W1 + E -->|"ZMQ IPC"| WN + W1 -->|"ZMQ IPC"| E + WN -->|"ZMQ IPC"| E + D -->|"ZMQ PUB"| MA ``` -## Design Decisions +### With Separate Timer Process + +```mermaid +graph TD + subgraph "Timer Process" + T["Busy-wait loop + ZMQ PUSH"] + end + + subgraph "Main Process — LoopManager.default_loop" + R["ZMQ PULL receiver"] + A["BenchmarkSession"] + C["_receive_responses() task"] + D["EventPublisher"] + E["HTTPEndpointClient"] -**Busy-wait for timing precision** + R -->|"sample_index"| A + A -->|"issue()"| E + C -->|"recv()"| E + A --> D + C --> D + end + + subgraph "Worker Processes" + W["HTTP → endpoint"] + end + subgraph "MetricsAggregator" + MA["Subscriber → KVStore"] + end + + T -->|"ZMQ IPC"| R + E -->|"ZMQ IPC"| W + W -->|"ZMQ IPC"| E + D -->|"ZMQ PUB"| MA +``` + +--- + +## Load Pattern Mapping + +```python +def create_load_strategy( + runtime_settings: RuntimeSettings, + loop: asyncio.AbstractEventLoop, + sample_order: SampleOrder | None = None, + use_executor: bool = False, +) -> LoadStrategy: + lp = runtime_settings.load_pattern + + match lp.type: + case LoadPatternType.MAX_THROUGHPUT: + return BurstStrategy(sample_order, loop) + + case LoadPatternType.POISSON: + delay_fn = make_delay_fn(lp, runtime_settings.rng_sched) + return TimedIssueStrategy(delay_fn, sample_order, loop, + use_executor=use_executor) + + case LoadPatternType.CONCURRENCY: + return ConcurrencyStrategy(lp.target_concurrency, sample_order) +``` + +--- -`SchedulerBasedLoadGenerator` uses a busy-wait loop (`while time.monotonic_ns() < target_ns`) for -inter-sample delays rather than `asyncio.sleep()` or `time.sleep()`. This achieves sub-millisecond -timing accuracy at high QPS without introducing event-loop latency. The trade-off is elevated CPU -usage on the scheduling thread during the run. 
+## Benchmark Data Summary -**Thread-based session, not async** +Measured with MaxThroughputServer + real HTTPEndpointClient: -`BenchmarkSession._run_test()` runs on a `threading.Thread`, not a coroutine. The scheduler loop -is blocking by design — it must not yield to the event loop, which could introduce scheduling jitter. -The async event loop is owned by `HTTPEndpointClient`, not the load generator. +### Poisson Mode — Strategy Comparison -**`SampleEventHandler` singleton with registered hooks** +| QPS | `run_in_executor` precision | `loop.call_at` precision | `asyncio.sleep` precision | +| ------ | --------------------------- | ------------------------ | ------------------------- | +| 100 | 84 μs | 1,772 μs | 2,008 μs | +| 1,000 | 65 μs | 679 μs | 734 μs | +| 10,000 | 67 μs | 739 μs | 658 μs | +| 50,000 | 85 μs | 586 μs | 291 μs | +| 100k | 126 μs | 1,043 μs | 65 μs | -All sample-level events (FIRST_CHUNK, COMPLETE, etc.) route through a single global -`_SampleEventHandler`. Hooks are registered before the run starts and remain constant for its -duration. This eliminates per-sample dispatch overhead at runtime. +Response latency at 100 QPS: `run_in_executor` = 6.2ms, `loop.call_at` = 1.4ms. +The GIL contention from the executor busy-wait thread penalizes low-QPS latency. -**`ConcurrencyScheduler` coordination via `threading.Condition`** +### Concurrency Mode -The concurrency scheduler blocks issuance when in-flight count reaches the target, then wakes -via a Condition notified by the COMPLETE hook. This provides back-pressure without polling. +| Strategy | QPS | Latency (mean) | +| --------- | ------ | -------------- | +| Semaphore | 80,631 | 0.73 ms | +| Callback | 77,488 | 0.81 ms | -## Event Types +### Max Throughput -| Event | Enum type | Meaning | -| --------------------------- | -------------- | --------------------------------------- | -| `TEST_STARTED` | `SessionEvent` | Run begins | -| `STOP_PERFORMANCE_TRACKING` | `SessionEvent` | Performance issuance phase has ended | -| `LOADGEN_STOP` | `SessionEvent` | Load generator finished issuing samples | -| `TEST_ENDED` | `SessionEvent` | Run complete | -| `LOADGEN_ISSUE_CALLED` | `SessionEvent` | `issue()` called | -| `LOADGEN_DATA_LOAD` | `SessionEvent` | Sample payload loaded from dataset | -| `HTTP_REQUEST_ISSUED` | `SampleEvent` | Request sent to endpoint | -| `HTTP_RESPONSE_COMPLETED` | `SampleEvent` | Endpoint HTTP response fully received | -| `FIRST_CHUNK` | `SampleEvent` | First SSE chunk received | -| `NON_FIRST_CHUNK` | `SampleEvent` | Subsequent SSE chunk | -| `COMPLETE` | `SampleEvent` | Final result received | +| Strategy | QPS | Latency (mean) | +| -------------- | ------- | -------------- | +| `loop.call_at` | 104,039 | 1.47 ms | +| `run_in_exec` | 78,261 | 8.28 ms | + +--- ## Integration Points -| Dependency | Role | -| ---------------------------- | ---------------------------------------------------------- | -| `core/types.py` | `Query`, `QueryResult`, `StreamChunk` | -| `endpoint_client/` | Implements `SampleIssuer` | -| `metrics/recorder.py` | Receives all events via `SampleEventHandler` | -| `config/runtime_settings.py` | `RuntimeSettings` drives duration, sample count, RNG seeds | -| `dataset_manager/` | Provides `Dataset` for sample data | +### HTTPEndpointClient + +Pass `loop` to share the event loop. The client already supports this via the +`_owns_loop` flag. 
Two changes required: + +**Initialization deadlock:** `__init__` calls `run_coroutine_threadsafe().result()` +which deadlocks when the calling thread IS the event loop thread. Fix: add an +async classmethod factory: + +```python +@classmethod +async def create(cls, config: HTTPClientConfig, loop: asyncio.AbstractEventLoop) -> HTTPEndpointClient: + client = cls.__new__(cls) + client._setup_sync_fields(config, loop) + await client._initialize() + return client +``` + +**Shutdown deadlock:** Same pattern — `shutdown()` calls `run_coroutine_threadsafe().result()`. +Fix: expose `async shutdown_async()` as a public method. When `_owns_loop is False`, +`shutdown()` should raise if called from the event loop thread, directing callers +to use `await shutdown_async()`. + +### EventPublisher / MetricsAggregator + +Session publishes `EventRecord` instances via `ZmqEventRecordPublisher`. The publisher +uses non-blocking ZMQ send with fd-based writer fallback — safe to call from sync +callbacks (like `call_at` fire functions). + +> **Key fix:** `Report.from_kv_reader` currently reads counter keys (`n_samples_issued`, +> `duration_ns`) that don't match the `MetricCounterKey` enum written by the aggregator +> (`total_samples_issued`, `tracked_duration_ns`). Must update `from_kv_reader` to use +> the actual key names. Performance reports should use `tracked_*` counters. A +> `test_started_at` counter must be added to the aggregator (set on `SessionEventType.STARTED`). + +### HttpClientSampleIssuer Migration + +The current issuer takes `Sample` and constructs `Query` internally. In the new design, +`PhaseIssuer` constructs the `Query`, so the issuer just forwards it: + +```python +class HttpClientSampleIssuer: + def __init__(self, http_client: HTTPEndpointClient): + self.http_client = http_client + + def issue(self, query: Query) -> None: + self.http_client.issue(query) + + async def recv(self) -> QueryResult | StreamChunk | None: + return await self.http_client.recv() + + def shutdown(self) -> None: + pass # HTTPEndpointClient shutdown called separately +``` + +Removed from current issuer: `_handle_responses` coroutine, `SampleEventHandler` +routing, `run_coroutine_threadsafe` cross-loop dispatch. The session's +`_receive_responses` replaces all of this. + +### Query.id Format + +`Query.default_factory` uses `str(uuid.uuid4())` (36 chars with hyphens). +The design uses `uuid.uuid4().hex` (32 chars, no hyphens). Standardize on +`.hex` — shorter strings, no parsing overhead. Update `Query.default_factory` +to match. + +### Timestamp Fidelity + +- **ISSUED**: `monotonic_ns()` taken in `PhaseIssuer.issue()` immediately before + `issuer.issue(query)`. ZMQ push is sync/non-blocking — timestamp is honest. +- **COMPLETE**: `QueryResult.completed_at` set in `__post_init__` on deserialization + in the main process. Measures main-process receipt time, not worker-side completion. +- **TTFT**: `RECV_FIRST - ISSUED` includes full round-trip ZMQ overhead (outbound to + worker + return to main). This adds 20-100μs of systematic bias. Acceptable for + most benchmarks; document as a known measurement overhead. +- **Latency (COMPLETE - ISSUED)**: Both timestamps taken on the main process side. + ZMQ transit bias is symmetric and cancels. This is the most accurate measurement. + +### Stale Completions After Saturation + +After a warmup phase (no drain), in-flight responses arrive during the perf +phase. 
The receiver must distinguish stale vs current-phase completions: + +```python +def _handle_response(self, resp: QueryResult) -> None: + query_id = resp.id + # Always publish the event (aggregator tracks all samples) + self._publisher.publish(EventRecord( + event_type=SampleEventType.COMPLETE, + timestamp_ns=resp.completed_at, + sample_uuid=query_id, + )) + # Only route to current phase strategy if this is a current-phase query + if query_id in self._current_phase_issuer.uuid_to_index: + self._current_phase_issuer.inflight -= 1 + if self._current_strategy: + self._current_strategy.on_query_complete(query_id) + # Stale completions: event published but strategy/inflight not affected +``` + +Same guard applies to `StreamChunk` with `is_complete=True` — check +`uuid_to_index` membership before decrementing inflight. Non-final +StreamChunks don't affect inflight and can be published unconditionally. + +### Sync Per-Sample Work in Callbacks + +All three strategies call `PhaseIssuer.issue()` synchronously — from `call_at` +callbacks (Poisson), `call_soon` callbacks (Burst), or inline after `sem.acquire()` +(Concurrency). Each `issue()` call performs: `dataset.load_sample()`, `uuid4().hex`, +`Query` construction, `EventRecord` publish (ZMQ NOBLOCK), and `issuer.issue()` +(ZMQ NOBLOCK). The ZMQ sends are confirmed non-blocking with internal buffering. + +The dominant cost is `dataset.load_sample()`. **Requirement:** datasets must be +pre-loaded into memory before the benchmark starts. If `load_sample()` performs +disk I/O, it blocks the event loop and degrades both timing precision and response +processing. For lazy-loading or disk-backed datasets, either pre-materialize +during setup or use executor mode. + +At 100k+ QPS with `BurstStrategy`, the `call_soon` callback queue depth can +delay `recv()` wakeups. Benchmarking shows this is acceptable (104k QPS with +1.47ms mean response latency), but the recv latency is bounded by the queue +depth rather than being strictly real-time. + +--- + +## CLI / Logging / TUI Integration + +### CLI Integration + +The CLI entry point (`commands/benchmark/execute.py`) orchestrates setup, execution, +and finalization as three sync phases: + +```python +def run_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> None: + ctx = setup_benchmark(config, test_mode) # sync: datasets, tokenizer, config + bench = run_benchmark_async(ctx) # async: returns BenchmarkResult + finalize_benchmark(ctx, bench) # sync: scoring, report, JSON output + +def run_benchmark_async(ctx: BenchmarkContext) -> BenchmarkResult: + loop = LoopManager().default_loop + return loop.run_until_complete(_run_benchmark_async(ctx, loop)) +``` + +`_run_benchmark_async` sets up the ZMQ context, event publisher, service subprocesses +(metrics_aggregator and event_logger), HTTP client, and session — all inside a +`ManagedZMQContext.scoped()` block. The HTTP config is constructed locally via +`config.settings.client.with_updates(...)`. + +```python +async def _run_benchmark_async(ctx, loop) -> BenchmarkResult: + collector = ResponseCollector(collect_responses=ctx.collect_responses, pbar=pbar) + + with ManagedZMQContext.scoped(io_threads=2) as zmq_ctx: + publisher = EventPublisherService(zmq_ctx) + launcher = ServiceLauncher(zmq_ctx) + await launcher.launch([ + ServiceConfig(module="...metrics_aggregator", args=aggregator_args), + ServiceConfig(module="...event_logger", args=event_logger_args), + ], timeout=30.0) + + http_config = config.settings.client.with_updates(...) 
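+        # (Sketch note) aggregator_args / event_logger_args above and kv_reader
+        # below are assumed helpers, not real names from the codebase:
+        # aggregator_args carries the tmpfs --metrics-dir, and kv_reader is a
+        # BasicKVStoreReader opened on that same directory after the
+        # aggregator has exited.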
+ http_client = await HTTPEndpointClient.create(http_config, loop) + session = BenchmarkSession( + issuer=HttpClientSampleIssuer(http_client), + event_publisher=publisher, loop=loop, + on_sample_complete=collector.on_complete_hook, + ) + phases = _build_phases(ctx) + loop.add_signal_handler(signal.SIGINT, session.stop) + try: + result = await session.run(phases) + finally: + loop.remove_signal_handler(signal.SIGINT) + await http_client.shutdown_async() + publisher.close() + await asyncio.to_thread(launcher.wait_for_exit, None) + report = Report.from_kv_reader(kv_reader) # after aggregator exits + + return BenchmarkResult(session=result, collector=collector, report=report, ...) +``` + +`BenchmarkResult` is a dataclass bundling `SessionResult`, `ResponseCollector`, +`Report`, and tmpfs paths. `finalize_benchmark(ctx, bench)` unpacks it for +accuracy scoring, report display, and results JSON output. + +### Logging + +Standard Python `logging` is used throughout. Key log points: + +- Phase transitions: `logger.info("Starting phase: %s (%s)", name, phase_type)` +- Sample counts: `logger.info("Phase %s complete: %d samples issued", name, count)` +- Errors: `logger.error("Failed to issue query %s: %s", query_id, error)` +- Shutdown: `logger.info("Benchmark session cancelled")` + +Log level is configurable via `RuntimeSettings` / CLI `--log-level`. + +### Progress Reporting (tqdm) + +`ResponseCollector.on_complete_hook` drives the progress bar: + +```python +class ResponseCollector: + def __init__(self, collect_responses: bool = False, pbar: tqdm | None = None): + self.collect_responses = collect_responses + self.responses: dict[str, str] = {} + self.errors: list[str] = [] + self.count = 0 + self.pbar = pbar + + def on_complete_hook(self, result: QueryResult) -> None: + self.count += 1 + if result.error: + self.errors.append(f"Sample {result.id}: {result.error}") + if self.pbar: + self.pbar.set_postfix(refresh=True, errors=len(self.errors)) + elif self.collect_responses: + self.responses[result.id] = result.get_response_output_string() + if self.pbar: + self.pbar.update(1) +``` + +The session calls `on_sample_complete(result)` from `_handle_response`, which +fires from the `_receive_responses` coroutine on the event loop. + +### Accuracy Phase Response Collection + +After `session.run()` returns, accuracy responses are partitioned using +`PhaseResult.uuid_to_index`: + +```python +for phase_result in result.accuracy_results: + phase_responses = { + uid: collector.responses[uid] + for uid in phase_result.uuid_to_index + if uid in collector.responses + } + score = scorer.score(phase_responses, phase_result.uuid_to_index, acc_dataset) +``` + +### Future: TUI Integration + +The planned TUI architecture moves the benchmark engine (HTTPClient + load generator) +to a child process, with the TUI as the foreground process reading periodic reports. 
+ +``` +TUI Process (foreground): + - Renders live dashboard (throughput, latency, progress) + - Reads BasicKVStoreReader for real-time metrics from /dev/shm + - Receives SessionResult via IPC on completion + +Benchmark Process (child): + - Runs BenchmarkSession on its own event loop + - Writes metrics via EventPublisher -> MetricsAggregator -> KVStore + - Returns SessionResult to parent via IPC (pickle over pipe / ZMQ) +``` + +This architecture is enabled by the current design's clean separation: + +- **KVStore** is already cross-process readable (mmap on /dev/shm) +- **BenchmarkSession** has no UI dependencies — it takes callbacks +- **SessionResult** is a frozen dataclass, trivially serializable +- The `on_sample_complete` callback would not be used in TUI mode (no + cross-process callback). Instead, the TUI polls KVStoreReader for + `tracked_samples_completed` to update the progress display. + +The TUI process can also read the `Report` from the KVStore at any time for +live intermediate reports (current QPS, latency distribution so far), not +just the final report. + +> **Constraint:** The benchmark child process must be a **non-daemon** OS process +> (e.g., `subprocess.Popen` or `multiprocessing.Process(daemon=False)`). +> `HTTPEndpointClient` spawns worker processes via `WorkerManager`, and those +> workers are `daemon=True`. Python prohibits daemon processes from spawning +> children — if the benchmark process is itself a daemon, worker creation fails. + +> **TODO:** Signal forwarding: SIGINT from the terminal goes to the foreground +> process group (TUI). The TUI must forward a stop signal to the benchmark +> child process (e.g., via `process.terminate()` or a ZMQ control channel). +> Design the stop protocol during TUI implementation. + +--- + +## Multi-Perf Sweep Example + +Concurrency sweep against same endpoint: + +```python +phases = [ + PhaseConfig("sat_c16", sat_settings(16), ds, PhaseType.WARMUP), + PhaseConfig("perf_c16", perf_settings(16), ds, PhaseType.PERFORMANCE), + PhaseConfig("sat_c32", sat_settings(32), ds, PhaseType.WARMUP), + PhaseConfig("perf_c32", perf_settings(32), ds, PhaseType.PERFORMANCE), + PhaseConfig("sat_c64", sat_settings(64), ds, PhaseType.WARMUP), + PhaseConfig("perf_c64", perf_settings(64), ds, PhaseType.PERFORMANCE), + PhaseConfig("accuracy", acc_settings, acc_ds, PhaseType.ACCURACY), +] +result = await session.run(phases) + +for pr in result.perf_results: + print(f"{pr.name}: {pr.report.qps():.0f} QPS") +``` + +--- + +## Rejected Alternatives + +| Alternative | Rejected Because | +| --------------------------------------- | ------------------------------------------------------------------------------------------ | +| Unified strategy for all patterns | Benchmark data shows each pattern benefits from different async primitives | +| `asyncio.Semaphore` for all concurrency | Correct for CONCURRENCY mode, but overhead hurts MAX_THROUGHPUT | +| `run_in_executor` for all timing | GIL contention causes 6ms latency at low QPS | +| `asyncio.sleep` for all timing | 700μs precision at mid-range QPS, `run_in_executor` is 10x better | +| Direct busy-wait on event loop | Starves receiver — 26ms response latency vs 0.6ms | +| Callback-based concurrency | Semaphore is simpler and benchmarks slightly better with real ZMQ client | +| Shared `Scheduler` base class | Concurrency has no delay concept; forcing it conflates distinct semantics | +| Separate `BenchmarkOrchestrator` | Phase sequencing is simple enough to live in `BenchmarkSession.run()` | +| poll()-based 
receiver spin | Starves event loop during response bursts; pure recv() is fd-driven with zero CPU overhead | diff --git a/docs/metrics/report_design.md b/docs/metrics/report_design.md new file mode 100644 index 00000000..7f663c1a --- /dev/null +++ b/docs/metrics/report_design.md @@ -0,0 +1,106 @@ +# Report Design + +## Overview + +The report module provides benchmark result summarization, display, and +serialization. It reads from the KVStore (via `BasicKVStoreReader`) and +produces a `Report` with rollup statistics, percentiles, and histograms. + +## Architecture + +``` +BasicKVStoreReader.snapshot() + │ + ▼ + build_report(reader) + │ + ├── counters → n_issued, n_completed, n_failed, duration_ns + │ + └── for each series metric: + SeriesStats.values → compute_summary() → dict + │ + ▼ + Report + ├── .display(fn) → human-readable output + ├── .to_json(path) → JSON serialization + ├── .qps → computed from n_completed / duration + └── .tps → computed from osl total / duration +``` + +## Design Principles + +**No SQL, no UUID tracking, no deduplication.** + +The old `MetricsReporter` queried SQLite via duckdb and built `RollupQueryTable` +objects with UUID-indexed rows, repeat counts, and numpy sorted arrays. None of +this complexity is needed when the input is a `list[float]` from the KVStore. + +The entire rollup is a single function: `compute_summary(values) → dict`. +It calls numpy for percentiles and histograms. No classes, no state. + +**Reports are reproducible from the event log.** + +The KVStore is lossy aggregation — it stores per-metric series, not per-sample +provenance. The authoritative record of what happened during a run is the event +log written by the `EventLoggerService`. Every number in a `Report` can be +recomputed by replaying the event log through the same aggregator logic: if a +production report shows a TTFT spike, the event log is the ground truth a user +can mine to attribute the spike to specific samples or time windows. + +New metrics must preserve this property: the aggregator may only derive values +from event fields, never from out-of-band state. If a metric cannot be rebuilt +from the event log alone, it does not belong in the KVStore. + +## Components + +### `compute_summary(values, percentiles, n_histogram_buckets) → dict` + +Takes a `list[float]`, returns a dict with: + +- `total`, `min`, `max`, `avg`, `std_dev`, `median` +- `percentiles`: dict of `{str(p): float}` for each requested percentile +- `histogram`: `{"buckets": [(lo, hi), ...], "counts": [int, ...]}` + +Empty input returns zeros with empty histogram/percentiles. + +### `Report` (frozen dataclass) + +Fields: + +- `version`, `git_sha`, `test_started_at` +- `n_samples_issued`, `n_samples_completed`, `n_samples_failed` +- `duration_ns` +- `ttft`, `tpot`, `latency`, `output_sequence_lengths` — each a summary dict + +Properties: + +- `qps`: `n_samples_completed / (duration_ns / 1e9)`, or None +- `tps`: `osl_total / (duration_ns / 1e9)`, or None + +Methods: + +- `display(fn, summary_only, newline)` — human-readable output with histograms +- `to_json(save_to)` — JSON serialization with QPS/TPS included + +### `build_report(reader) → Report` + +Reads a `BasicKVStoreReader` snapshot and constructs a `Report`. Works +identically for live metrics (mid-test) and final reports (post-drain) — +the caller decides when to call it. + +Counter keys read: `n_samples_issued`, `n_samples_completed`, +`n_samples_failed`, `duration_ns`, `test_started_at`. 
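+
+Note: these are the `Report` field names; in the KVStore they map to the
+aggregator's `MetricCounterKey` values (`tracked_samples_issued`,
+`tracked_samples_completed`, `total_samples_failed`, `tracked_duration_ns`)
+plus the `test_started_at` counter. A minimal sketch of `build_report` under
+those key names (`VERSION`/`GIT_SHA` placeholders and default
+`compute_summary` arguments are assumptions, not part of the real module):
+
+```python
+def build_report(reader: BasicKVStoreReader) -> Report:
+    snap = reader.snapshot()  # counters -> int, series -> SeriesStats
+    return Report(
+        version=VERSION,
+        git_sha=GIT_SHA,
+        test_started_at=snap["test_started_at"],
+        n_samples_issued=snap["tracked_samples_issued"],
+        n_samples_completed=snap["tracked_samples_completed"],
+        n_samples_failed=snap["total_samples_failed"],
+        duration_ns=snap["tracked_duration_ns"],
+        ttft=compute_summary(snap["ttft_ns"].values),
+        tpot=compute_summary(snap["tpot_ns"].values),
+        latency=compute_summary(snap["sample_latency_ns"].values),
+        output_sequence_lengths=compute_summary(snap["osl"].values),
+    )
+```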
+ +Series keys summarized: `ttft_ns`, `tpot_ns`, `sample_latency_ns`, `osl`. + +## What Was Removed + +From the old `metrics/reporter.py`: + +- `MetricsReporter` — SQLite/duckdb query engine (replaced by KVStore) +- `RollupQueryTable` — UUID-indexed rollup table (replaced by `compute_summary`) +- `MetricRow` — per-row accessor (not needed) +- `TPOTReportingMode` — niche enum (can be re-added if needed) +- `SampleUUIDNotFoundError` — UUIDs not relevant in KVStore +- `output_sequence_from_data` — SQL event data parser (not needed) +- `dump_to_json` / `dump_all_to_csv` — event log export (handled by EventLoggerService) diff --git a/examples/04_GPTOSS120B_Example/gptoss_120b_example.yaml b/examples/04_GPTOSS120B_Example/gptoss_120b_example.yaml index fb5edd33..0dce3a99 100644 --- a/examples/04_GPTOSS120B_Example/gptoss_120b_example.yaml +++ b/examples/04_GPTOSS120B_Example/gptoss_120b_example.yaml @@ -28,7 +28,6 @@ settings: client: num_workers: 4 - record_worker_events: false endpoint_config: endpoints: diff --git a/examples/04_GPTOSS120B_Example/sglang_gptoss_120b_example.yaml b/examples/04_GPTOSS120B_Example/sglang_gptoss_120b_example.yaml index 761d5c3b..5a2d2050 100644 --- a/examples/04_GPTOSS120B_Example/sglang_gptoss_120b_example.yaml +++ b/examples/04_GPTOSS120B_Example/sglang_gptoss_120b_example.yaml @@ -50,7 +50,6 @@ settings: client: num_workers: 8 - record_worker_events: false endpoint_config: endpoints: diff --git a/scripts/regenerate_templates.py b/scripts/regenerate_templates.py index 30d0bb0d..5e407768 100644 --- a/scripts/regenerate_templates.py +++ b/scripts/regenerate_templates.py @@ -93,7 +93,7 @@ } PLACEHOLDER_MODEL = "" -PLACEHOLDER_ENDPOINT = "" +PLACEHOLDER_ENDPOINT = "http://localhost:8000" # --------------------------------------------------------------------------- diff --git a/src/inference_endpoint/async_utils/event_publisher.py b/src/inference_endpoint/async_utils/event_publisher.py index 600edcc0..98e25eae 100644 --- a/src/inference_endpoint/async_utils/event_publisher.py +++ b/src/inference_endpoint/async_utils/event_publisher.py @@ -19,42 +19,43 @@ from inference_endpoint.async_utils.loop_manager import LoopManager from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqEventRecordPublisher -from inference_endpoint.utils import SingletonMixin -class EventPublisherService(SingletonMixin, ZmqEventRecordPublisher): - """Singleton publisher for publishing event records.""" +class EventPublisherService(ZmqEventRecordPublisher): + """Publisher for publishing event records over ZMQ PUB socket. + + Wraps ZmqEventRecordPublisher with LoopManager integration and + auto-generated socket names. + """ def __init__( self, managed_zmq_context: ManagedZMQContext, extra_eager: bool = False, isolated_event_loop: bool = False, + send_threshold: int = 1000, ): """Creates a new EventPublisherService. - By default, the publisher will be run on the main thread's event loop (i.e. the default loop). - Args: - managed_zmq_context (ManagedZMQContext): The managed ZMQ context to use for the publisher. - extra_eager (bool): If True, the publisher will be a blocking call and calls to .publish() - will block until the message has been successfully sent. In most cases, this should not - be turned on, but it is useful for testing, or specifically in the use case where - EventRecords are being used as a synchronization mechanism (i.e. 
sending a specific - EventRecord as a STOP signal to subscribers to ensure the ordering of cleanup.) - isolated_event_loop (bool): If True, the publisher will be run in a separate event loop. + managed_zmq_context: The managed ZMQ context to use. + extra_eager: If True, publish() blocks until the message is sent. + Useful for testing or when EventRecords are used as a + synchronization mechanism (e.g., ENDED as a stop signal). + isolated_event_loop: If True, runs on a separate event loop thread. + send_threshold: Minimum number of buffered records before an + automatic flush is triggered. See ZmqEventRecordPublisher. """ - if getattr(self, "_initialized", False): - return - self._initialized = True - - # Set up event loop settings if extra_eager: loop = None elif isolated_event_loop: loop = LoopManager().create_loop("ev_pub") else: loop = LoopManager().default_loop + self.socket_name = f"ev_pub_{uuid.uuid4().hex[:8]}" super().__init__( - f"ev_pub_{uuid.uuid4().hex[:8]}", managed_zmq_context, loop=loop + self.socket_name, + managed_zmq_context, + loop=loop, + send_threshold=send_threshold, ) diff --git a/src/inference_endpoint/async_utils/services/event_logger/__main__.py b/src/inference_endpoint/async_utils/services/event_logger/__main__.py index ec8d0965..a7842b74 100644 --- a/src/inference_endpoint/async_utils/services/event_logger/__main__.py +++ b/src/inference_endpoint/async_utils/services/event_logger/__main__.py @@ -30,10 +30,12 @@ from inference_endpoint.async_utils.loop_manager import LoopManager from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqEventRecordSubscriber +from inference_endpoint.async_utils.transport.zmq.ready_check import send_ready_signal from inference_endpoint.core.record import ( EventRecord, SessionEventType, ) +from inference_endpoint.utils.logging import setup_logging from .file_writer import JSONLWriter from .writer import RecordWriter @@ -156,13 +158,26 @@ async def main() -> None: else " Install sqlalchemy for SQL support: pip install inference-endpoint[sql]" ), ) + parser.add_argument( + "--readiness-path", + type=str, + default=None, + help="ZMQ socket path to signal readiness (optional)", + ) + parser.add_argument( + "--readiness-id", + type=int, + default=0, + help="Identity to send in the readiness signal", + ) args = parser.parse_args() + setup_logging(level="INFO") writer_classes = tuple(_WRITER_REGISTRY[name] for name in args.writers) shutdown_event = asyncio.Event() loop = LoopManager().default_loop with ManagedZMQContext.scoped(socket_dir=args.socket_dir) as zmq_ctx: - logger = EventLoggerService( + service = EventLoggerService( args.log_dir, args.socket_name, zmq_ctx, @@ -172,7 +187,11 @@ async def main() -> None: shutdown_event=shutdown_event, ) - loop.call_soon(logger.start) + service.start() + + if args.readiness_path: + await send_ready_signal(zmq_ctx, args.readiness_path, args.readiness_id) + await shutdown_event.wait() diff --git a/src/inference_endpoint/async_utils/services/launcher.py b/src/inference_endpoint/async_utils/services/launcher.py new file mode 100644 index 00000000..9e53c19c --- /dev/null +++ b/src/inference_endpoint/async_utils/services/launcher.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Service subprocess launcher with ready-check synchronization. + +Launches service subprocesses (EventLoggerService, MetricsAggregatorService) +via ``python -m`` and waits for each to signal readiness over ZMQ before +returning. Uses the same ReadyCheckReceiver/send_ready_signal primitives as +the worker pool transport. +""" + +from __future__ import annotations + +import logging +import subprocess +import sys +import time +import uuid +from dataclasses import dataclass, field + +from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext +from inference_endpoint.async_utils.transport.zmq.ready_check import ReadyCheckReceiver + +logger = logging.getLogger(__name__) + + +@dataclass +class ServiceConfig: + """Configuration for a service subprocess to launch.""" + + module: str + """Python module path (e.g., 'inference_endpoint.async_utils.services.event_logger').""" + + args: list[str] = field(default_factory=list) + """Additional CLI arguments for the service.""" + + +class ServiceLauncher: + """Launches service subprocesses and waits for ready signals. + + Usage:: + + launcher = ServiceLauncher(zmq_context) + await launcher.launch([ + ServiceConfig( + module="inference_endpoint.async_utils.services.event_logger", + args=["--log-dir", "/tmp/logs", "--socket-dir", socket_dir, + "--socket-name", "events"], + ), + ], timeout=30.0) + + # ... run benchmark ... + + launcher.wait_for_exit(timeout=60.0) + """ + + def __init__(self, zmq_context: ManagedZMQContext) -> None: + self._zmq_ctx = zmq_context + self._procs: list[subprocess.Popen] = [] + + @property + def procs(self) -> list[subprocess.Popen]: + return self._procs + + async def launch( + self, + services: list[ServiceConfig], + timeout: float | None = 30.0, + ) -> None: + """Spawn service subprocesses and wait for all to signal readiness. + + Each service receives ``--readiness-path`` and ``--readiness-id`` CLI + arguments. After initialization, the service sends a ready signal via + ``send_ready_signal()`` using the same socket_dir as the launcher. + + Launched processes are stored in ``self.procs`` for later use by + ``wait_for_exit()`` and ``kill_all()``. + + Args: + services: List of ServiceConfig describing each service to launch. + timeout: Maximum total seconds to wait for all services to become ready. + + Raises: + TimeoutError: If services don't signal readiness within timeout. 
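+            RuntimeError: If one or more service subprocesses crash during
+                startup (the original exception is attached as ``__cause__``).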
+ """ + if not services: + return + + readiness_path = f"svc_ready_{uuid.uuid4().hex[:8]}" + receiver = ReadyCheckReceiver( + readiness_path, self._zmq_ctx, count=len(services) + ) + + try: + for i, svc in enumerate(services): + cmd = [ + sys.executable, + "-m", + svc.module, + *svc.args, + "--readiness-path", + readiness_path, + "--readiness-id", + str(i), + ] + logger.info("Launching service: %s (id=%d)", svc.module, i) + proc = subprocess.Popen(cmd) + self._procs.append(proc) + + await receiver.wait(timeout=timeout) + logger.info("All %d services ready", len(services)) + + except Exception as e: + # Collect all crashed subprocesses for a complete error message + crashed = [ + (proc.pid, exit_code) + for proc in self._procs + if (exit_code := proc.poll()) is not None and exit_code != 0 + ] + + self.kill_all() + receiver.close() + + if crashed: + details = ", ".join( + f"pid={pid} exit={exit_code}" for pid, exit_code in crashed + ) + raise RuntimeError( + f"{len(crashed)} service(s) crashed during startup: {details}" + ) from e + + # If for some reason the reason for the exception is not a crashed subprocess, + # re-raise the exception. + raise + + def kill_all(self) -> None: + """Kill all managed subprocesses.""" + for proc in self._procs: + if proc.poll() is None: + proc.kill() + + def wait_for_exit(self, timeout: float | None = 60.0) -> None: + """Wait for all service subprocesses to exit. + + Services self-terminate on SessionEventType.ENDED. This method + blocks until all have exited or the total timeout is reached. + + If the timeout is reached, the process will be killed without + waiting for proper cleanup (it is assumed that the process is + hanging). + + Args: + timeout: Maximum total seconds to wait across all processes. + If None, waits indefinitely. 
+ """ + deadline = None if timeout is None else time.monotonic() + timeout + for proc in self._procs: + remaining = ( + None if deadline is None else max(0, deadline - time.monotonic()) + ) + try: + proc.wait(timeout=remaining) + except subprocess.TimeoutExpired: + logger.warning( + "Service pid=%d did not exit within timeout, killing", proc.pid + ) + proc.kill() + proc.wait(timeout=5) diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py index 72a0149e..50a3163d 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py @@ -22,9 +22,11 @@ from inference_endpoint.async_utils.loop_manager import LoopManager from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext +from inference_endpoint.async_utils.transport.zmq.ready_check import send_ready_signal +from inference_endpoint.utils.logging import setup_logging from .aggregator import MetricsAggregatorService -from .emitter import JsonlMetricEmitter +from .kv_store import BasicKVStore from .token_metrics import TokenizePool @@ -32,12 +34,6 @@ async def main() -> None: parser = argparse.ArgumentParser( description="Metrics aggregator service - subscribes to EventRecords and computes real-time metrics" ) - parser.add_argument( - "--metrics-dir", - type=Path, - required=True, - help="Directory for metrics output (JSONL file)", - ) parser.add_argument( "--socket-dir", type=str, @@ -50,6 +46,12 @@ async def main() -> None: required=True, help="Socket name within socket-dir", ) + parser.add_argument( + "--metrics-dir", + type=str, + required=True, + help="Directory for mmap-backed metric files (created by the parent process)", + ) parser.add_argument( "--tokenizer", type=str, @@ -68,11 +70,22 @@ async def main() -> None: default=False, help="Enable streaming metrics (TTFT, chunk_delta, TPOT). 
Off by default.", ) + parser.add_argument( + "--readiness-path", + type=str, + default=None, + help="ZMQ socket path to signal readiness (optional)", + ) + parser.add_argument( + "--readiness-id", + type=int, + default=0, + help="Identity to send in the readiness signal", + ) args = parser.parse_args() + setup_logging(level="INFO") - args.metrics_dir.mkdir(parents=True, exist_ok=True) - metrics_file = args.metrics_dir / "metrics" - + metrics_dir = Path(args.metrics_dir) shutdown_event = asyncio.Event() loop = LoopManager().default_loop @@ -89,19 +102,26 @@ async def main() -> None: pool_cm as pool, ManagedZMQContext.scoped(socket_dir=args.socket_dir) as zmq_ctx, ): - emitter = JsonlMetricEmitter(metrics_file, flush_interval=100) - aggregator = MetricsAggregatorService( - args.socket_name, - zmq_ctx, - loop, - topics=None, - emitter=emitter, - tokenize_pool=pool, - streaming=args.streaming, - shutdown_event=shutdown_event, - ) - loop.call_soon(aggregator.start) - await shutdown_event.wait() + kv_store = BasicKVStore(metrics_dir) + try: + aggregator = MetricsAggregatorService( + args.socket_name, + zmq_ctx, + loop, + topics=None, + kv_store=kv_store, + tokenize_pool=pool, + streaming=args.streaming, + shutdown_event=shutdown_event, + ) + aggregator.start() + + if args.readiness_path: + await send_ready_signal(zmq_ctx, args.readiness_path, args.readiness_id) + + await shutdown_event.wait() + finally: + kv_store.close() if __name__ == "__main__": diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/aggregator.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/aggregator.py index 6b9e4f2a..c4640bbc 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/aggregator.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/aggregator.py @@ -18,37 +18,65 @@ from __future__ import annotations import asyncio +import logging +from enum import Enum from inference_endpoint.async_utils.transport.zmq.pubsub import ( ZmqEventRecordSubscriber, ) from inference_endpoint.core.record import ( + ErrorEventType, EventRecord, SampleEventType, SessionEventType, ) -from .emitter import MetricEmitter +from .kv_store import KVStore from .metrics_table import ( ChunkDeltaTrigger, IslTrigger, MetricsTable, OslTrigger, - RequestDurationTrigger, + SampleField, SampleLatencyTrigger, TpotTrigger, TtftTrigger, ) from .token_metrics import TokenizePool +logger = logging.getLogger(__name__) + + +class MetricCounterKey(str, Enum): + """Counter metric keys tracked by the aggregator. + + Total counters include all samples (warmup + tracked). + Tracked counters only include samples within performance tracking windows. + """ + + TOTAL_SAMPLES_ISSUED = "total_samples_issued" + TOTAL_SAMPLES_COMPLETED = "total_samples_completed" + TOTAL_SAMPLES_FAILED = "total_samples_failed" + TRACKED_SAMPLES_ISSUED = "tracked_samples_issued" + TRACKED_SAMPLES_COMPLETED = "tracked_samples_completed" + TRACKED_DURATION_NS = "tracked_duration_ns" + # Total wall-clock duration since session start. Updated on every event as + # max(current, event_timestamp - session_start) to be defensive against + # non-monotonic timestamps. + # + # An alternative design was considered: store session_start_ns once and + # compute duration as (now - start) on read. This is infeasible because + # time.monotonic_ns() has inconsistent epoch per process — a reader in + # another process would get a meaningless value. 
+ TOTAL_DURATION_NS = "total_duration_ns" + + _TRACKED_SAMPLE_EVENTS = frozenset( { SampleEventType.ISSUED, SampleEventType.COMPLETE, SampleEventType.RECV_FIRST, SampleEventType.RECV_NON_FIRST, - SampleEventType.CLIENT_SEND, - SampleEventType.CLIENT_RESP_DONE, } ) @@ -57,56 +85,77 @@ class MetricsAggregatorService(ZmqEventRecordSubscriber): """Subscribes to EventRecords and computes per-sample metrics in real time. The aggregator is a thin event router. All state management, trigger - dispatch, and row lifecycle are handled by MetricsTable. + dispatch, and row lifecycle are handled by MetricsTable. The KVStore + is shared between the table (for series metrics via triggers) and the + aggregator (for counter metrics like n_issued, n_completed, etc.). """ def __init__( self, *args, - emitter: MetricEmitter, + kv_store: KVStore, tokenize_pool: TokenizePool | None = None, streaming: bool = False, shutdown_event: asyncio.Event | None = None, **kwargs, ): super().__init__(*args, **kwargs) - self._emitter = emitter + self._kv_store = kv_store + self._tokenize_pool = tokenize_pool self._shutdown_event = shutdown_event self._shutdown_received = False - self._table = MetricsTable() - self._register_triggers( - self._table, emitter, tokenize_pool, self.loop, streaming - ) + for key in MetricCounterKey: + kv_store.create_key(key.value, "counter") + + self._total_issued = 0 + self._total_completed = 0 + self._total_failed = 0 + self._tracked_issued = 0 + self._tracked_completed = 0 + self._session_start_ns: int | None = None + self._total_duration_ns: int = 0 + self._total_processed = 0 + self._last_log_count = 0 - @staticmethod - def _register_triggers( - table: MetricsTable, - emitter: MetricEmitter, - tokenize_pool: TokenizePool | None, - loop: asyncio.AbstractEventLoop | None, - streaming: bool, - ) -> None: + self._table = MetricsTable(kv_store) + self._register_triggers(streaming) + + def _register_triggers(self, streaming: bool) -> None: """Register metric triggers on the table. Streaming-only triggers (TTFT, chunk_delta, TPOT) are only registered when ``streaming=True``. 
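+
+        Always-registered triggers: ISL (fires on ISSUED), and sample
+        latency and OSL (fire on COMPLETE).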
""" + table = self._table + store = self._kv_store + pool = self._tokenize_pool + loop = self.loop + # Always registered - table.add_trigger("issued_ns", IslTrigger(emitter, tokenize_pool, loop)) - table.add_trigger("client_resp_done_ns", RequestDurationTrigger(emitter)) - table.add_trigger("complete_ns", SampleLatencyTrigger(emitter)) - table.add_trigger("complete_ns", OslTrigger(emitter, tokenize_pool, loop)) + table.add_trigger(SampleField.ISSUED_NS, IslTrigger(store, pool, loop)) + table.add_trigger(SampleField.COMPLETE_NS, SampleLatencyTrigger(store)) + table.add_trigger(SampleField.COMPLETE_NS, OslTrigger(store, pool, loop)) # Streaming-only if streaming: - table.add_trigger("recv_first_ns", TtftTrigger(emitter)) - table.add_trigger("last_recv_ns", ChunkDeltaTrigger(emitter)) - table.add_trigger("complete_ns", TpotTrigger(emitter, tokenize_pool, loop)) + table.add_trigger(SampleField.RECV_FIRST_NS, TtftTrigger(store)) + table.add_trigger(SampleField.LAST_RECV_NS, ChunkDeltaTrigger(store)) + table.add_trigger(SampleField.COMPLETE_NS, TpotTrigger(store, pool, loop)) async def process(self, records: list[EventRecord]) -> None: saw_shutdown = False table = self._table + store = self._kv_store + + self._total_processed += len(records) + if self._total_processed - self._last_log_count >= 10000: + logger.debug( + "Aggregator processed %d records (%d in this batch)", + self._total_processed, + len(records), + ) + self._last_log_count = self._total_processed for record in records: if self._shutdown_received: @@ -114,13 +163,41 @@ async def process(self, records: list[EventRecord]) -> None: ev = record.event_type + # Update total_duration_ns on every event + if self._session_start_ns is not None: + elapsed = record.timestamp_ns - self._session_start_ns + if elapsed > self._total_duration_ns: + self._total_duration_ns = elapsed + store.update( + MetricCounterKey.TOTAL_DURATION_NS.value, + self._total_duration_ns, + ) + # --- Session events --- if isinstance(ev, SessionEventType): if ev == SessionEventType.ENDED: + logger.info("ENDED event received, shutting down aggregator") self._shutdown_received = True saw_shutdown = True else: + if ev == SessionEventType.STARTED: + self._session_start_ns = record.timestamp_ns table.handle_session_event(record) + if ev == SessionEventType.STOP_PERFORMANCE_TRACKING: + store.update( + MetricCounterKey.TRACKED_DURATION_NS.value, + table.total_tracked_duration_ns, + ) + logger.debug("Session event: %s", ev) + continue + + # --- Error events --- + if isinstance(ev, ErrorEventType): + self._total_failed += 1 + store.update( + MetricCounterKey.TOTAL_SAMPLES_FAILED.value, self._total_failed + ) + logger.debug("Error event: %s", record) continue # --- Sample events --- @@ -135,24 +212,52 @@ async def process(self, records: list[EventRecord]) -> None: ts = record.timestamp_ns if ev == SampleEventType.ISSUED: - table.set_field(uuid, "issued_ns", ts, record) + table.set_field(uuid, SampleField.ISSUED_NS, ts, record) + self._total_issued += 1 + store.update( + MetricCounterKey.TOTAL_SAMPLES_ISSUED.value, self._total_issued + ) + if table.get_row(uuid) is not None: + self._tracked_issued += 1 + store.update( + MetricCounterKey.TRACKED_SAMPLES_ISSUED.value, + self._tracked_issued, + ) elif ev == SampleEventType.RECV_FIRST: - table.set_field(uuid, "recv_first_ns", ts, record) - table.set_field(uuid, "last_recv_ns", ts, record) + table.set_field(uuid, SampleField.RECV_FIRST_NS, ts, record) + table.set_field(uuid, SampleField.LAST_RECV_NS, ts, record) elif ev == 
SampleEventType.RECV_NON_FIRST: - table.set_field(uuid, "last_recv_ns", ts, record) - elif ev == SampleEventType.CLIENT_SEND: - table.set_field(uuid, "client_send_ns", ts, record) - elif ev == SampleEventType.CLIENT_RESP_DONE: - table.set_field(uuid, "client_resp_done_ns", ts, record) + table.set_field(uuid, SampleField.LAST_RECV_NS, ts, record) elif ev == SampleEventType.COMPLETE: - table.set_field(uuid, "complete_ns", ts, record) + # Check if tracked before set_field (which removes the row) + is_tracked = table.get_row(uuid) is not None + table.set_field(uuid, SampleField.COMPLETE_NS, ts, record) + self._total_completed += 1 + store.update( + MetricCounterKey.TOTAL_SAMPLES_COMPLETED.value, + self._total_completed, + ) + if is_tracked: + self._tracked_completed += 1 + store.update( + MetricCounterKey.TRACKED_SAMPLES_COMPLETED.value, + self._tracked_completed, + ) if saw_shutdown: + logger.info("Draining %d async tasks...", len(table._in_flight_tasks)) await table.drain_tasks() + logger.info("Async tasks drained") + store.update( + MetricCounterKey.TRACKED_DURATION_NS.value, + table.total_tracked_duration_ns, + ) self._finalize() def _finalize(self) -> None: + logger.info( + "Aggregator finalized: %d total records processed", self._total_processed + ) self.close() if self._shutdown_event is not None: self._shutdown_event.set() @@ -160,5 +265,5 @@ def _finalize(self) -> None: self.loop.stop() def close(self) -> None: - self._emitter.close() + self._kv_store.close() super().close() diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/emitter.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/emitter.py deleted file mode 100644 index 5f06ccaf..00000000 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/emitter.py +++ /dev/null @@ -1,95 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Metric emitters for the metrics aggregator service.""" - -from __future__ import annotations - -import time -from abc import ABC, abstractmethod -from pathlib import Path -from typing import BinaryIO - -import msgspec - - -class MetricEmitter(ABC): - """Base class for metric emitters.""" - - @abstractmethod - def emit(self, sample_uuid: str, metric_name: str, value: int | float) -> None: - """Emit a metric value for a sample.""" - raise NotImplementedError - - @abstractmethod - def flush(self) -> None: - """Flush any buffered metrics to the underlying store.""" - raise NotImplementedError - - @abstractmethod - def close(self) -> None: - """Flush and release resources.""" - raise NotImplementedError - - -class _MetricRecord(msgspec.Struct, gc=False): # type: ignore[call-arg] - sample_uuid: str - metric_name: str - value: int | float - timestamp_ns: int - - -class JsonlMetricEmitter(MetricEmitter): - """Writes metrics as JSONL lines to a file. 
- - Each line is a JSON object: {"sample_uuid": ..., "metric_name": ..., "value": ..., "timestamp_ns": ...} - """ - - def __init__(self, file_path: Path, flush_interval: int = 100) -> None: - self._file_path = file_path.with_suffix(".jsonl") - self._file: BinaryIO | None = self._file_path.open("wb") - self._encoder = msgspec.json.Encoder() - self._flush_interval = flush_interval - self._n_since_flush = 0 - - def emit(self, sample_uuid: str, metric_name: str, value: int | float) -> None: - if self._file is None: - return - record = _MetricRecord( - sample_uuid=sample_uuid, - metric_name=metric_name, - value=value, - timestamp_ns=time.monotonic_ns(), - ) - self._file.write(self._encoder.encode(record) + b"\n") - self._n_since_flush += 1 - if self._n_since_flush >= self._flush_interval: - self.flush() - - def flush(self) -> None: - if self._file is not None: - self._file.flush() - self._n_since_flush = 0 - - def close(self) -> None: - if self._file is not None: - try: - self.flush() - self._file.close() - except (OSError, FileNotFoundError): - # File may already be closed or I/O error on close (e.g. disk full). - pass - finally: - self._file = None diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/fs_check.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/fs_check.py new file mode 100644 index 00000000..fea99811 --- /dev/null +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/fs_check.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Filesystem type detection for mmap ordering decisions. + +On tmpfs (/dev/shm), msync() is a no-op because there is no backing store. +On a real on-disk filesystem, msync() flushes dirty pages to the shared page +cache, which provides write ordering for cross-process mmap readers. + +On ARM (weak memory model), we need msync() to act as an ordering mechanism +between the value write and the count update in _SeriesItem.append(). This +only works on a real filesystem — not tmpfs. Detecting the filesystem type +lets us: + - Skip the useless msync() syscall on tmpfs (any architecture) + - Warn if ARM code is running on tmpfs (msync won't provide ordering) +""" + +from __future__ import annotations + +import ctypes +import ctypes.util +import logging +import platform +from pathlib import Path + +logger = logging.getLogger(__name__) + +_TMPFS_MAGIC = 0x01021994 +"""Special tmpfs filesystem header value.""" + + +def _is_tmpfs_via_statfs(path: str) -> bool | None: + """Check filesystem type via libc statfs(2). Returns None if unavailable.""" + try: + lib_name = ctypes.util.find_library("c") + if lib_name is None: + return None + libc = ctypes.CDLL(lib_name, use_errno=True) + + # Allocate a large buffer to account for differently sized statfs + # structs across architectures. 
f_type is always the first field + # (__SWORD_TYPE / long) at offset 0 on all Linux archs. + buf = ctypes.create_string_buffer(256) + if libc.statfs(path.encode(), buf) != 0: + return None + # f_type is a native-endian long at offset 0 + f_type = ctypes.c_long.from_buffer(buf, 0).value + return f_type == _TMPFS_MAGIC + except (OSError, AttributeError, ValueError): + return None + + +def _is_tmpfs_via_proc_mounts(path: str) -> bool | None: + """Check filesystem type via /proc/mounts. Returns None if unavailable.""" + try: + resolved = str(Path(path).resolve()) + best_match = "" + best_fstype = "" + with open("/proc/mounts") as f: + for line in f: + parts = line.split() + if len(parts) < 3: + continue + mount_point, fstype = parts[1], parts[2] + if resolved.startswith(mount_point) and len(mount_point) > len( + best_match + ): + best_match = mount_point + best_fstype = fstype + if not best_match: + return None + return best_fstype == "tmpfs" + except OSError: + return None + + +def is_tmpfs(path: str | Path) -> bool: + """Check if a path resides on a tmpfs filesystem. + + Tries statfs(2) via ctypes first, falls back to /proc/mounts. + Returns False if detection fails (safe default — will call msync). + """ + path_str = str(path) + + result = _is_tmpfs_via_statfs(path_str) + if result is not None: + return result + + result = _is_tmpfs_via_proc_mounts(path_str) + if result is not None: + return result + + logger.warning( + "Could not determine filesystem type for %s " + "(statfs and /proc/mounts both unavailable). " + "Assuming non-tmpfs (msync will be called on every series append).", + path_str, + ) + return False + + +def needs_msync(path: str | Path) -> bool: + """Determine if msync() is needed for mmap write ordering at this path. + + Returns True if msync should be called between value write and count + update in series append. This is needed on ARM when the backing store + is a real filesystem (not tmpfs). + + On x86-64 (TSO), store ordering is guaranteed by hardware — msync is + never needed regardless of filesystem type. + + On ARM with tmpfs, msync is a no-op and won't help — log a warning + since the caller should use an on-disk directory for correct ordering. + """ + if platform.machine() == "x86_64": + return False + + on_tmpfs = is_tmpfs(path) + if on_tmpfs: + logger.warning( + "ARM platform with tmpfs-backed metrics at %s. " + "Python does not support memory fences. " + "Use an on-disk metrics directory for correct cross-process reads.", + path, + ) + return False + + return True diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/kv_store.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/kv_store.py new file mode 100644 index 00000000..9f846234 --- /dev/null +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/kv_store.py @@ -0,0 +1,511 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Key-value store for metrics with per-key /dev/shm backing files.
+
+Each key in the store maps to a KVItem backed by an individual mmap'd file.
+Two item types are supported:
+
+- **counter**: A single uint64 value (e.g., error_count, n_in_flight).
+  File layout: [value: 8B uint64]
+
+- **series**: An append-only list of 8-byte values (uint64 by default,
+  float64 when dtype=float) with a length header (e.g., ttft_ns,
+  sample_latency_ns). Rollup stats are computed lazily on read.
+  File layout: [count: 8B uint64] [v0: 8B] [v1: 8B] ...
+
+Write protocol (single writer):
+    Counter: overwrite the 8-byte value.
+    Series: write the 8-byte value at HEADER + count*8, then update count.
+    On x86-64, aligned 8-byte stores are atomic (TSO), so readers always
+    see a consistent state.
+
+Read protocol (any process):
+    Counter: read 8 bytes.
+    Series: read count, then values[:count]. Rollup computed lazily with
+    incremental progress tracking (_last_rollup_idx).
+"""
+
+from __future__ import annotations
+
+import logging
+import math
+import mmap
+import os
+import shutil
+import struct
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Literal
+
+from .fs_check import needs_msync
+
+# ---------------------------------------------------------------------------
+# Series rollup stats (computed on read)
+# ---------------------------------------------------------------------------
+
+logger = logging.getLogger(__name__)
+
+_HEADER_BYTES = 8  # uint64 count for series
+_VALUE_BYTES = 8  # 8 bytes per value (uint64 or float64)
+_DEFAULT_CAPACITY = 128 * 1024  # pre-allocate for 128k values (~1 MB)
+_DEFAULT_FILE_MODE = 0o600  # rw-------
+
+# Struct format: endian prefix + per-dtype value character
+_ENDIAN = "<"
+_STRUCT_CHAR: dict[type, str] = {
+    int: "Q",  # unsigned 64-bit integer
+    float: "d",  # 64-bit IEEE 754 float
+}
+
+
+class SeriesStats:
+    """Lazily-computed statistics over a series of values.
+
+    Rollup stats (count, total, min, max, sum_sq) are computed on read,
+    not on write. ``_last_rollup_idx`` caches progress so subsequent
+    reads only process newly appended values.
+
+    When ``dtype=int`` (default), accumulators use Python int for arbitrary
+    precision with uint64 values. When ``dtype=float``, accumulators
+    use float (for float64 series).
+    """
+
+    __slots__ = (
+        "count",
+        "total",
+        "min_val",
+        "max_val",
+        "sum_sq",
+        "values",
+        "_last_rollup_idx",
+    )
+
+    def __init__(self, values: list | None = None, dtype: type = int) -> None:
+        self.values: list = values if values is not None else []
+        self.count: int = 0
+        zero = dtype()
+        self.total: int | float = zero
+        self.min_val: int | float = math.inf
+        self.max_val: int | float = -math.inf
+        self.sum_sq: int | float = zero
+        self._last_rollup_idx: int = 0
+        if self.values:
+            self._update_rollup()
+
+    def _update_rollup(self) -> None:
+        """Incrementally update rollup stats from _last_rollup_idx onward."""
+        for v in self.values[self._last_rollup_idx :]:
+            self.total += v
+            self.sum_sq += v * v
+            if v < self.min_val:
+                self.min_val = v
+            if v > self.max_val:
+                self.max_val = v
+        self.count = len(self.values)
+        self._last_rollup_idx = self.count
+
+
+# ---------------------------------------------------------------------------
+# KVStore ABC
+# ---------------------------------------------------------------------------
+
+
+class KVStore(ABC):
+    """Abstract key-value store for metrics.
+
+    Keys are created with a type (counter or series). Values are updated
+    via update() and read via get() or snapshot().
Implementations may + back keys with /dev/shm files, Prometheus, or in-memory dicts. + """ + + @abstractmethod + def create_key( + self, + key: str, + key_type: Literal["series", "counter"], + dtype: type = int, + ) -> None: + """Register a new key in the store. + + Args: + key: Key name. + key_type: "counter" (single uint64) or "series" (append-only). + dtype: Value type for series keys (int or float). + Ignored for counters (always int/uint64). + """ + raise NotImplementedError + + @abstractmethod + def update(self, key: str, value: int | float) -> None: + """Update a key. For counters, sets the value. For series, appends.""" + raise NotImplementedError + + @abstractmethod + def get(self, key: str) -> int | SeriesStats: + """Read the current value of a key.""" + raise NotImplementedError + + @abstractmethod + def snapshot(self) -> dict[str, int | SeriesStats]: + """Return a dict of all keys and their current values.""" + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Release resources.""" + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# KVItem implementations (per-key mmap files) +# --------------------------------------------------------------------------- + + +class _CounterItem: + """Single uint64 value backed by an 8-byte mmap file.""" + + _FMT = f"{_ENDIAN}{_STRUCT_CHAR[int]}" + __slots__ = ("_mm", "_path", "_closed") + + def __init__(self, path: Path) -> None: + self._path = path + self._closed = False + fd = os.open(str(path), os.O_CREAT | os.O_RDWR, _DEFAULT_FILE_MODE) + try: + os.ftruncate(fd, _VALUE_BYTES) + self._mm = mmap.mmap(fd, _VALUE_BYTES) + finally: + os.close(fd) + struct.pack_into(_CounterItem._FMT, self._mm, 0, 0) + + def set(self, value: int) -> None: + if not self._closed: + struct.pack_into(_CounterItem._FMT, self._mm, 0, value) + + def get(self) -> int: + return struct.unpack_from(_CounterItem._FMT, self._mm, 0)[0] + + def close(self) -> None: + if not self._closed: + self._closed = True + self._mm.close() + + +class _CounterReader: + """Reader for a counter item.""" + + _FMT = _CounterItem._FMT + __slots__ = ("_fd", "_mm", "_path") + + def __init__(self, path: Path) -> None: + self._path = path + self._fd: int | None = None + self._mm: mmap.mmap | None = None + if path.exists(): + self._open() + + def _open(self) -> None: + fd = os.open(str(self._path), os.O_RDONLY) + try: + self._mm = mmap.mmap(fd, _VALUE_BYTES, prot=mmap.PROT_READ) + self._fd = fd + except Exception: + os.close(fd) + raise + + def get(self) -> int: + if self._mm is None: + if self._path.exists(): + self._open() + if self._mm is None: + return 0 + return struct.unpack_from(_CounterReader._FMT, self._mm, 0)[0] + + def close(self) -> None: + if self._mm is not None: + self._mm.close() + self._mm = None + if self._fd is not None: + os.close(self._fd) + self._fd = None + + +class _SeriesItem: + """Append-only series backed by an mmap file. + + Default dtype is int (uint64 storage, suitable for nanosecond timestamps). + Pass dtype=float for floating-point series. 
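+
+    Append protocol (single writer): the value is written at
+    HEADER + count * 8 first, and the count header is bumped afterwards;
+    on x86-64 (TSO) a reader that observes the new count also observes
+    the value (see fs_check for the ARM caveat).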
+ """ + + __slots__ = ( + "_mm", + "_capacity", + "_count", + "_path", + "_closed", + "_dtype", + "_char", + "_fmt", + "_needs_msync", + ) + + def __init__( + self, + path: Path, + capacity: int = _DEFAULT_CAPACITY, + dtype: type = int, + ) -> None: + self._path = path + self._capacity = capacity + self._count = 0 + self._closed = False + self._dtype = dtype + self._char = _STRUCT_CHAR[dtype] + self._fmt = f"{_ENDIAN}{self._char}" + self._needs_msync = needs_msync(path.parent) + total = _HEADER_BYTES + capacity * _VALUE_BYTES + fd = os.open(str(path), os.O_CREAT | os.O_RDWR, _DEFAULT_FILE_MODE) + try: + os.ftruncate(fd, total) + self._mm = mmap.mmap(fd, total) + finally: + os.close(fd) + struct.pack_into(" None: + if self._closed: + logger.warning("append() called on closed series: %s", self._path) + return + if not isinstance(value, self._dtype): + raise TypeError( + f"Expected {self._dtype.__name__}, got {type(value).__name__}" + ) + if self._count >= self._capacity: + self._grow() + offset = _HEADER_BYTES + self._count * _VALUE_BYTES + struct.pack_into(self._fmt, self._mm, offset, value) + # Cross-process ordering note: msync between value write and count + # update is only needed for concurrent readers. In the current + # architecture, the reader (Report builder) runs after the writer + # process exits, so process exit flushes all dirty pages and + # ordering is guaranteed by the kernel. msync is skipped entirely. + # If concurrent reading is ever needed, re-enable via needs_msync(): + # if self._needs_msync: + # self._mm.flush() + # This has shown to be a considerable bottleneck on ARM systems - this will require a more + # sophisticated redesign for concurrent read/write and live metrics. + self._count += 1 + struct.pack_into(" SeriesStats: + """Read all values from the mmap and return as SeriesStats.""" + if self._count == 0: + return SeriesStats(dtype=self._dtype) + raw = self._mm[_HEADER_BYTES : _HEADER_BYTES + self._count * _VALUE_BYTES] + values = list(struct.unpack(f"{_ENDIAN}{self._count}{self._char}", raw)) + return SeriesStats(values, dtype=self._dtype) + + def close(self) -> None: + if not self._closed: + self._closed = True + self._mm.close() + + def _grow(self) -> None: + # Concurrency safety: readers in other processes hold their own mmap of + # this file. ftruncate() extends the file and zero-fills the new region; + # the reader's existing mmap remains valid (the kernel keeps the mapping + # alive independently). The reader detects the size change via fstat() + # and remaps. Between ftruncate and the next append(), the new region + # contains zeros, but readers are safe because they only read up to the + # count header value, which hasn't been updated yet. 
+ old_mm = self._mm + new_capacity = self._capacity * 2 + total = _HEADER_BYTES + new_capacity * _VALUE_BYTES + fd = os.open(str(self._path), os.O_RDWR) + try: + os.ftruncate(fd, total) + self._mm = mmap.mmap(fd, total) + self._capacity = new_capacity + except Exception: + self._mm = old_mm + raise + finally: + os.close(fd) + old_mm.close() + + +class _SeriesReader: + """Reader for a series item with lazy rollup.""" + + __slots__ = ("_fd", "_mm", "_path", "_stats", "_char") + + def __init__(self, path: Path, dtype: type = int) -> None: + self._path = path + self._char = _STRUCT_CHAR[dtype] + self._stats = SeriesStats(dtype=dtype) + self._fd: int | None = None + self._mm: mmap.mmap | None = None + if path.exists(): + self._open() + + def _open(self) -> None: + fd = os.open(str(self._path), os.O_RDONLY) + try: + size = os.fstat(fd).st_size + if size > 0: + self._mm = mmap.mmap(fd, 0, prot=mmap.PROT_READ) + self._fd = fd + else: + os.close(fd) + except Exception: + os.close(fd) + raise + + def get(self) -> SeriesStats: + if self._mm is None: + if self._path.exists(): + self._open() + if self._mm is None: + return self._stats + + # Re-map if file grew + file_size = os.fstat(self._fd).st_size # type: ignore[arg-type] + if file_size > self._mm.size(): + self._mm.close() + self._mm = mmap.mmap(self._fd, 0, prot=mmap.PROT_READ) # type: ignore[arg-type] + + count = struct.unpack_from(" old_count: + start_offset = _HEADER_BYTES + old_count * _VALUE_BYTES + n_new = count - old_count + raw = self._mm[start_offset : start_offset + n_new * _VALUE_BYTES] + new_vals = list(struct.unpack(f"{_ENDIAN}{n_new}{self._char}", raw)) + self._stats.values.extend(new_vals) + self._stats._update_rollup() + + return self._stats + + def close(self) -> None: + if self._mm is not None: + self._mm.close() + self._mm = None + if self._fd is not None: + os.close(self._fd) + self._fd = None + + +# --------------------------------------------------------------------------- +# BasicKVStore (mmap-backed) +# --------------------------------------------------------------------------- + + +class BasicKVStore(KVStore): + """KVStore backed by per-key mmap files on /dev/shm (or any directory). + + Each key gets its own file: counters are 8 bytes, series are append-only + with a count header. Suitable for single-writer, multi-reader access. 
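+
+    Usage (illustrative paths)::
+
+        store = BasicKVStore(Path("/dev/shm/bench_metrics"))
+        store.create_key("n_errors", "counter")
+        store.create_key("ttft_ns", "series")
+        store.update("n_errors", 3)          # counter: overwrite
+        store.update("ttft_ns", 1_234_567)   # series: append
+        store.close()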
+ """ + + def __init__(self, store_dir: Path) -> None: + self._dir = store_dir + self._dir.mkdir(parents=True, exist_ok=True) + self._items: dict[str, _CounterItem | _SeriesItem] = {} + + def create_key( + self, + key: str, + key_type: Literal["series", "counter"], + dtype: type = int, + ) -> None: + if key in self._items: + return + path = self._dir / f"{key}.kv" + if key_type == "counter": + self._items[key] = _CounterItem(path) + elif key_type == "series": + self._items[key] = _SeriesItem(path, dtype=dtype) + else: + raise ValueError(f"Unknown key type: {key_type}") + + def update(self, key: str, value: int | float) -> None: + item = self._items.get(key) + if item is None: + raise KeyError(f"Key not created: {key}") + if isinstance(item, _CounterItem): + item.set(int(value)) + else: + item.append(value) + + def get(self, key: str) -> int | SeriesStats: + item = self._items.get(key) + if item is None: + raise KeyError(f"Key not created: {key}") + return item.get() + + def snapshot(self) -> dict[str, int | SeriesStats]: + return {key: item.get() for key, item in self._items.items()} + + def close(self) -> None: + for item in self._items.values(): + item.close() + + def unlink(self) -> None: + """Close all items and remove the store directory.""" + self.close() + shutil.rmtree(self._dir, ignore_errors=True) + + +class BasicKVStoreReader: + """Read-only view of a BasicKVStore from another process. + + Lazily opens files and reads values. Each call to get() or snapshot() + picks up new values appended by the writer. + """ + + def __init__(self, store_dir: Path) -> None: + self._dir = store_dir + self._readers: dict[str, _CounterReader | _SeriesReader] = {} + + def register_key( + self, + key: str, + key_type: Literal["series", "counter"], + dtype: type = int, + ) -> None: + """Register a key to read. 
Call before get()/snapshot().""" + if key in self._readers: + return + path = self._dir / f"{key}.kv" + if key_type == "counter": + self._readers[key] = _CounterReader(path) + elif key_type == "series": + self._readers[key] = _SeriesReader(path, dtype=dtype) + + def get(self, key: str) -> int | SeriesStats: + reader = self._readers.get(key) + if reader is None: + raise KeyError(f"Key not registered: {key}") + return reader.get() + + def snapshot(self) -> dict[str, int | SeriesStats]: + return {key: reader.get() for key, reader in self._readers.items()} + + def close(self) -> None: + for reader in self._readers.values(): + reader.close() diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py index 19fa82d6..a66c1e8d 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py @@ -21,6 +21,7 @@ import logging from abc import ABC, abstractmethod from dataclasses import dataclass +from enum import Enum from typing import TYPE_CHECKING, Any import msgspec @@ -28,8 +29,8 @@ from inference_endpoint.core.types import PromptData, TextModelOutput if TYPE_CHECKING: - from inference_endpoint.async_utils.services.metrics_aggregator.emitter import ( - MetricEmitter, + from inference_endpoint.async_utils.services.metrics_aggregator.kv_store import ( + KVStore, ) from inference_endpoint.async_utils.services.metrics_aggregator.token_metrics import ( TokenizePool, @@ -39,6 +40,31 @@ logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# SampleField enum +# --------------------------------------------------------------------------- + + +class SampleField(str, Enum): + """SampleRow field names that triggers can be registered on.""" + + ISSUED_NS = "issued_ns" + RECV_FIRST_NS = "recv_first_ns" + LAST_RECV_NS = "last_recv_ns" + COMPLETE_NS = "complete_ns" + + +class MetricSeriesKey(str, Enum): + """Series metric keys written by triggers to the KV store.""" + + ISL = "isl" + OSL = "osl" + SAMPLE_LATENCY_NS = "sample_latency_ns" + TTFT_NS = "ttft_ns" + CHUNK_DELTA_NS = "chunk_delta_ns" + TPOT_NS = "tpot_ns" + + # --------------------------------------------------------------------------- # SampleRow # --------------------------------------------------------------------------- @@ -58,8 +84,6 @@ class SampleRow(msgspec.Struct, gc=False): # type: ignore[call-arg] issued_ns: int | None = None recv_first_ns: int | None = None last_recv_ns: int | None = None - client_send_ns: int | None = None - client_resp_done_ns: int | None = None complete_ns: int | None = None @@ -86,20 +110,33 @@ def duration_ns(self) -> int: # --------------------------------------------------------------------------- -# EmitTrigger +# EmitTrigger base classes # --------------------------------------------------------------------------- class EmitTrigger(ABC): """A metric computation that fires when a SampleRow field is set. - Runtime deps (emitter, pool, loop) are bound at construction. - ``fire()`` receives only event-specific context. + Each trigger has a ``metric_name`` and a ``kv_store`` reference. + When ``fire()`` computes a value, it writes directly to + ``self.kv_store.update(self.metric_name, value)``. """ - def __init__(self, metric_name: str, requires: tuple[str, ...] 
= ()): - self.metric_name = metric_name + def __init__( + self, + metric_name: str, + kv_store: KVStore, + requires: tuple[str, ...] = (), + dtype: type = int, + ): + # Resolve enum to its value string so KVStore filenames match + # what the reader expects (e.g. "ttft_ns" not "MetricSeriesKey.TTFT_NS"). + self.metric_name = ( + metric_name.value if isinstance(metric_name, Enum) else metric_name + ) + self.kv_store = kv_store self.requires = requires + self.dtype = dtype @abstractmethod def fire( @@ -112,81 +149,123 @@ def fire( raise NotImplementedError() -# --------------------------------------------------------------------------- -# Timing triggers (sync) -# --------------------------------------------------------------------------- - +class TimeDeltaTrigger(EmitTrigger): + """Sync trigger: emits ev_rec.timestamp_ns - pre_change[delta_start_fieldname]. -class TtftTrigger(EmitTrigger): - """TTFT = recv_first_ns (new, from ev_rec) - issued_ns.""" + The emitted metric is a time delta: the firing event marks the end of the + delta, and ``delta_start_fieldname`` names the SampleField whose timestamp + marks the start. Skips silently if the start field is None (the delta has + not yet opened for this sample). + """ - def __init__(self, emitter: MetricEmitter): - super().__init__("ttft_ns", requires=("issued_ns",)) - self._emitter = emitter + def __init__(self, metric_name: str, kv_store: KVStore, delta_start_fieldname: str): + super().__init__(metric_name, kv_store, requires=(delta_start_fieldname,)) + self._delta_start_fieldname = delta_start_fieldname def fire(self, ev_rec, row, pre_change): - issued_ns = pre_change.get("issued_ns") - if issued_ns is not None: - self._emitter.emit( - row.sample_uuid, "ttft_ns", ev_rec.timestamp_ns - issued_ns - ) + baseline = pre_change.get(self._delta_start_fieldname) + if baseline is not None: + self.kv_store.update(self.metric_name, ev_rec.timestamp_ns - baseline) return None -class ChunkDeltaTrigger(EmitTrigger): - """chunk_delta_ns = new timestamp - previous last_recv_ns. +class AsyncTokenTrigger(EmitTrigger): + """Base for triggers that need async tokenization. - Skips when pre-change last_recv_ns is None (first recv via RECV_FIRST). + Subclasses implement ``_extract_text()`` to pull the text to tokenize + from the event record. If text is returned, an async task is created + to tokenize and emit. Subclasses can override ``_compute_value()`` to + transform the token count before storing. """ - def __init__(self, emitter: MetricEmitter): - super().__init__("chunk_delta_ns", requires=("last_recv_ns",)) - self._emitter = emitter + def __init__( + self, + metric_name: str, + kv_store: KVStore, + tokenize_pool: TokenizePool | None, + loop: asyncio.AbstractEventLoop | None, + requires: tuple[str, ...] = (), + dtype: type = int, + ): + super().__init__(metric_name, kv_store, requires=requires, dtype=dtype) + self._pool = tokenize_pool + self._loop = loop + + @abstractmethod + def _extract_text( + self, ev_rec: EventRecord, row: SampleRow, pre_change: dict[str, Any] + ) -> str | None: + """Return the text to tokenize, or None to skip.""" + raise NotImplementedError() + + def _compute_value( + self, token_count: int, ev_rec: EventRecord, pre_change: dict[str, Any] + ) -> int | float | None: + """Transform token count into the metric value. 
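Returning None skips the KV store update.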
Default: count as-is.""" + return token_count def fire(self, ev_rec, row, pre_change): - prev = pre_change.get("last_recv_ns") - if prev is None: + if self._pool is None or self._loop is None: return None - self._emitter.emit( - row.sample_uuid, "chunk_delta_ns", ev_rec.timestamp_ns - prev + text = self._extract_text(ev_rec, row, pre_change) + if not text: + return None + + pool, loop = self._pool, self._loop + store, name = self.kv_store, self.metric_name + uuid = row.sample_uuid + + async def _tokenize_and_emit() -> None: + try: + count = await pool.token_count_async(text, loop) + value = self._compute_value(count, ev_rec, pre_change) + if value is not None: + store.update(name, value) + except Exception: + logger.exception("%s tokenization failed for %s", name, uuid) + + return loop.create_task(_tokenize_and_emit()) + + +# --------------------------------------------------------------------------- +# Timing triggers (sync) +# --------------------------------------------------------------------------- + + +class TtftTrigger(TimeDeltaTrigger): + """TTFT = recv_first_ns (new) - issued_ns.""" + + def __init__(self, kv_store: KVStore): + super().__init__( + MetricSeriesKey.TTFT_NS, + kv_store, + delta_start_fieldname=SampleField.ISSUED_NS, ) - return None -class RequestDurationTrigger(EmitTrigger): - """request_duration_ns = client_resp_done_ns (new) - client_send_ns.""" +class ChunkDeltaTrigger(TimeDeltaTrigger): + """chunk_delta_ns = new timestamp - previous last_recv_ns. - def __init__(self, emitter: MetricEmitter): - super().__init__("request_duration_ns", requires=("client_send_ns",)) - self._emitter = emitter + Skips when pre-change last_recv_ns is None (first recv via RECV_FIRST). + """ - def fire(self, ev_rec, row, pre_change): - client_send = pre_change.get("client_send_ns") - if client_send is not None: - self._emitter.emit( - row.sample_uuid, - "request_duration_ns", - ev_rec.timestamp_ns - client_send, - ) - return None + def __init__(self, kv_store: KVStore): + super().__init__( + MetricSeriesKey.CHUNK_DELTA_NS, + kv_store, + delta_start_fieldname=SampleField.LAST_RECV_NS, + ) -class SampleLatencyTrigger(EmitTrigger): +class SampleLatencyTrigger(TimeDeltaTrigger): """sample_latency_ns = complete_ns (new) - issued_ns.""" - def __init__(self, emitter: MetricEmitter): - super().__init__("sample_latency_ns", requires=("issued_ns",)) - self._emitter = emitter - - def fire(self, ev_rec, row, pre_change): - issued_ns = pre_change.get("issued_ns") - if issued_ns is not None: - self._emitter.emit( - row.sample_uuid, - "sample_latency_ns", - ev_rec.timestamp_ns - issued_ns, - ) - return None + def __init__(self, kv_store: KVStore): + super().__init__( + MetricSeriesKey.SAMPLE_LATENCY_NS, + kv_store, + delta_start_fieldname=SampleField.ISSUED_NS, + ) # --------------------------------------------------------------------------- @@ -194,88 +273,54 @@ def fire(self, ev_rec, row, pre_change): # --------------------------------------------------------------------------- -class IslTrigger(EmitTrigger): +class IslTrigger(AsyncTokenTrigger): """ISL from PromptData: len(token_ids) sync, or token_count(text) async.""" def __init__( self, - emitter: MetricEmitter, + kv_store: KVStore, tokenize_pool: TokenizePool | None, loop: asyncio.AbstractEventLoop | None, ): - super().__init__("isl", requires=()) - self._emitter = emitter - self._pool = tokenize_pool - self._loop = loop + super().__init__(MetricSeriesKey.ISL, kv_store, tokenize_pool, loop) def fire(self, ev_rec, row, pre_change): - if not 
isinstance(ev_rec.data, PromptData): - return None - if ev_rec.data.token_ids is not None: - self._emitter.emit(row.sample_uuid, "isl", len(ev_rec.data.token_ids)) + # Sync fast path: any backend that pre-populates token_ids (e.g. SGLang). + if isinstance(ev_rec.data, PromptData) and ev_rec.data.token_ids is not None: + self.kv_store.update(self.metric_name, len(ev_rec.data.token_ids)) return None - if ( - ev_rec.data.text is not None - and self._pool is not None - and self._loop is not None - ): - text = ev_rec.data.text - uuid = row.sample_uuid - pool, loop, emitter = self._pool, self._loop, self._emitter - - async def _compute() -> None: - try: - count = await pool.token_count_async(text, loop) - emitter.emit(uuid, "isl", count) - except Exception: - logger.exception("ISL tokenization failed for %s", uuid) - - return loop.create_task(_compute()) + # Async path: tokenize raw text — used when token_ids are unavailable + # (e.g. OpenAI-compatible endpoints). Handled by the base class. + return super().fire(ev_rec, row, pre_change) + + def _extract_text(self, ev_rec, row, pre_change): + if isinstance(ev_rec.data, PromptData) and ev_rec.data.text is not None: + return ev_rec.data.text return None -class OslTrigger(EmitTrigger): +class OslTrigger(AsyncTokenTrigger): """OSL = token_count(full output text) from COMPLETE event data.""" def __init__( self, - emitter: MetricEmitter, + kv_store: KVStore, tokenize_pool: TokenizePool | None, loop: asyncio.AbstractEventLoop | None, ): - super().__init__("osl", requires=()) - self._emitter = emitter - self._pool = tokenize_pool - self._loop = loop - - def fire(self, ev_rec, row, pre_change): - if self._pool is None or self._loop is None: - return None - if not isinstance(ev_rec.data, TextModelOutput): - return None - output_text = str(ev_rec.data) - if not output_text: - return None + super().__init__(MetricSeriesKey.OSL, kv_store, tokenize_pool, loop) - uuid = row.sample_uuid - pool, loop, emitter = self._pool, self._loop, self._emitter - - async def _compute() -> None: - try: - osl = await pool.token_count_async(output_text, loop) - emitter.emit(uuid, "osl", osl) - except Exception: - logger.exception("OSL tokenization failed for %s", uuid) - - return loop.create_task(_compute()) + def _extract_text(self, ev_rec, row, pre_change): + if isinstance(ev_rec.data, TextModelOutput): + text = str(ev_rec.data) + return text if text else None + return None -class TpotTrigger(EmitTrigger): +class TpotTrigger(AsyncTokenTrigger): """TPOT = (complete_ns - recv_first_ns) / token_count(text_after_first_chunk). - Only registered when streaming mode is enabled. Computes the TPOT denominator - directly from TextModelOutput.text_after_first_chunk() at COMPLETE time, - avoiding any dependency on RECV_FIRST tokenization state. + Only registered when streaming mode is enabled. # NOTE(agents): This trigger tokenizes text_after_first_chunk independently # from OslTrigger, which tokenizes the full output. 
This means the output is @@ -289,41 +334,31 @@ class TpotTrigger(EmitTrigger): def __init__( self, - emitter: MetricEmitter, + kv_store: KVStore, tokenize_pool: TokenizePool | None, loop: asyncio.AbstractEventLoop | None, ): - super().__init__("tpot_ns", requires=("recv_first_ns",)) - self._emitter = emitter - self._pool = tokenize_pool - self._loop = loop + super().__init__( + MetricSeriesKey.TPOT_NS, + kv_store, + tokenize_pool, + loop, + requires=(SampleField.RECV_FIRST_NS,), + dtype=float, + ) - def fire(self, ev_rec, row, pre_change): - if self._pool is None or self._loop is None: - return None - recv_first_ns = pre_change.get("recv_first_ns") - if recv_first_ns is None: + def _extract_text(self, ev_rec, row, pre_change): + if pre_change.get(SampleField.RECV_FIRST_NS) is None: return None - if not isinstance(ev_rec.data, TextModelOutput): - return None - after_first = ev_rec.data.text_after_first_chunk() - if not after_first: - return None - - uuid = row.sample_uuid - complete_ns = ev_rec.timestamp_ns - pool, loop, emitter = self._pool, self._loop, self._emitter - - async def _compute() -> None: - try: - tokens_after_first = await pool.token_count_async(after_first, loop) - if tokens_after_first > 0: - tpot = (complete_ns - recv_first_ns) / tokens_after_first - emitter.emit(uuid, "tpot_ns", tpot) - except Exception: - logger.exception("TPOT tokenization failed for %s", uuid) + if isinstance(ev_rec.data, TextModelOutput): + return ev_rec.data.text_after_first_chunk() or None + return None - return loop.create_task(_compute()) + def _compute_value(self, token_count, ev_rec, pre_change): + if token_count <= 0: + return None + recv_first_ns = pre_change[SampleField.RECV_FIRST_NS] + return (ev_rec.timestamp_ns - recv_first_ns) / token_count # --------------------------------------------------------------------------- @@ -334,6 +369,10 @@ async def _compute() -> None: class MetricsTable: """Stores in-flight sample rows, session state, and dispatches triggers. + Takes a KVStore for metric storage. When triggers are registered via + add_trigger(), the table creates the key in the store and wires the + store onto the trigger. + Row lifecycle is managed internally via ``set_field``: - ISSUED: creates the row if tracking is on, assigns block index. - COMPLETE: fires triggers, sets field, updates tracked block, removes row. @@ -342,7 +381,8 @@ class MetricsTable: Session state is updated via ``handle_session_event``. """ - def __init__(self) -> None: + def __init__(self, kv_store: KVStore) -> None: + self._kv_store = kv_store self._in_flight: dict[str, SampleRow] = {} self._triggers: dict[str, list[EmitTrigger]] = {} self._in_flight_tasks: set[asyncio.Task] = set() @@ -355,7 +395,12 @@ def __init__(self) -> None: # --- Trigger registration --- def add_trigger(self, field_name: str, trigger: EmitTrigger) -> None: - """Register a trigger for a SampleRow field.""" + """Register a trigger for a SampleRow field. + + Creates the trigger's metric key in the KV store as a series, + using the trigger's declared dtype. 
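+
+ Example wiring (illustrative; names are from this module):
+
+ table = MetricsTable(kv_store)
+ table.add_trigger(SampleField.RECV_FIRST_NS, TtftTrigger(kv_store))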
+ """ + self._kv_store.create_key(trigger.metric_name, "series", dtype=trigger.dtype) self._triggers.setdefault(field_name, []).append(trigger) # --- Session event handling --- diff --git a/src/inference_endpoint/async_utils/transport/protocol.py b/src/inference_endpoint/async_utils/transport/protocol.py index 92122dcc..c865eb4e 100644 --- a/src/inference_endpoint/async_utils/transport/protocol.py +++ b/src/inference_endpoint/async_utils/transport/protocol.py @@ -278,9 +278,19 @@ def send(self, topic: bytes, payload: bytes) -> None: """ raise NotImplementedError("Subclasses must implement this method.") + def flush(self) -> None: # noqa: B027 — intentionally non-abstract + """Force-send any buffered records. + + Unbuffered implementations need no override. Buffered subclasses + (e.g., ZmqEventRecordPublisher) override this to drain their buffer. + """ + @abstractmethod def close(self) -> None: - """Close the publisher and release resources.""" + """Close the publisher and release resources. + + Implementations must flush any buffered records before closing. + """ raise NotImplementedError("Subclasses must implement this method.") diff --git a/src/inference_endpoint/async_utils/transport/zmq/context.py b/src/inference_endpoint/async_utils/transport/zmq/context.py index dac0d2cf..d291447d 100644 --- a/src/inference_endpoint/async_utils/transport/zmq/context.py +++ b/src/inference_endpoint/async_utils/transport/zmq/context.py @@ -185,7 +185,13 @@ def bind(self, sock: zmq.Socket, path: str, scheme: str = "ipc") -> str: """ if scheme == "ipc": if self.socket_dir is None: - self._tmp_dir = tempfile.TemporaryDirectory(prefix="zmq_") + # Prefer /dev/shm for IPC sockets — overlayfs (common in + # containers for /tmp) does not support Unix sockets. + shm = Path("/dev/shm") + self._tmp_dir = tempfile.TemporaryDirectory( + prefix="zmq_", + dir=str(shm) if shm.is_dir() else None, + ) self.socket_dir = self._tmp_dir.name else: Path(self.socket_dir).mkdir(parents=True, exist_ok=True) diff --git a/src/inference_endpoint/async_utils/transport/zmq/pubsub.py b/src/inference_endpoint/async_utils/transport/zmq/pubsub.py index 015a5ceb..9463356b 100644 --- a/src/inference_endpoint/async_utils/transport/zmq/pubsub.py +++ b/src/inference_endpoint/async_utils/transport/zmq/pubsub.py @@ -19,35 +19,59 @@ from collections import deque from urllib.parse import urlparse +import msgspec.msgpack import zmq from inference_endpoint.async_utils.transport.protocol import ( EventRecordPublisher, EventRecordSubscriber, ) -from inference_endpoint.core.record import TOPIC_FRAME_SIZE +from inference_endpoint.core.record import BATCH_TOPIC, TOPIC_FRAME_SIZE from .context import ManagedZMQContext logger = logging.getLogger(__name__) +_batch_encoder = msgspec.msgpack.Encoder() +_batch_decoder = msgspec.msgpack.Decoder(type=list[bytes]) + class ZmqEventRecordPublisher(EventRecordPublisher): + """ZMQ PUB socket publisher with batched sending. + + Records are buffered in memory and flushed as a single msgpack-encoded + batch when the buffer reaches ``send_threshold``. This reduces syscalls + from one per record to one per batch (~19x throughput, ~29% smaller). + + The ``send_threshold`` is the *minimum* number of records in the buffer + before an automatic flush is triggered. There is no maximum — records + accumulate until the threshold is reached or ``flush()``/``close()`` + is called explicitly. Callers that need immediate delivery (e.g., + session control events) should call ``flush()`` after publishing. 
+ + Batching protocol: + - Batched messages use ``BATCH_TOPIC`` as the ZMQ routing prefix. + - The payload is ``msgpack(list[bytes])`` where each element is a + pre-encoded record payload (no per-record topic prefix). + - Subscribers unpack the list and yield payloads in insertion order. + - Per-record topics are omitted because EventRecord already contains + event_type for dispatching. + - Single-record flushes use the record's own topic (no batch overhead). + """ + def __init__( self, path: str, zmq_context: ManagedZMQContext, loop: asyncio.AbstractEventLoop | None = None, scheme: str = "ipc", + send_threshold: int = 1000, ): self._socket = zmq_context.socket(zmq.PUB) - # One of the guarantees of event records is that if it is published, - # it must be eventually received by all live subscribers. - self._socket.setsockopt(zmq.SNDHWM, 0) # Unlimited send buffer - self._socket.setsockopt( - zmq.LINGER, -1 - ) # Wait indefinitely on close() to send pending messages + # Guarantee delivery: unlimited send buffer, wait on close. + self._socket.setsockopt(zmq.SNDHWM, 0) + self._socket.setsockopt(zmq.LINGER, -1) self._socket.setsockopt(zmq.IMMEDIATE, 1) bind_address = zmq_context.bind(self._socket, path, scheme) @@ -56,85 +80,117 @@ def __init__( logger.info(f"Publisher bound to {self.bind_address}") self._fd = self._socket.getsockopt(zmq.FD) - self._buffer: deque[bytes] = deque() + self._send_threshold = send_threshold + self._batch_buffer: list[bytes] = [] + self._last_topic: bytes = b"" + self._pending: deque[bytes] = deque() self._writing = False + @property + def buffered_count(self) -> int: + """Number of records currently buffered (not yet sent).""" + return len(self._batch_buffer) + + @property + def pending_count(self) -> int: + """Number of frames queued for async write (socket was busy).""" + return len(self._pending) + def send(self, topic: bytes, payload: bytes) -> None: - """Send the message via zmq. + """Buffer a record for batched sending. + + Only the payload is buffered — topics are not stored per-record + since the EventRecord already contains event_type for dispatching. + When the buffer reaches ``send_threshold``, payloads are encoded + as a single msgpack list and sent with BATCH_TOPIC. For a single + record, a direct send with the record's own topic is used instead. + """ + self._last_topic = topic + self._batch_buffer.append(payload) + + if len(self._batch_buffer) >= self._send_threshold: + self._flush_batch() + + def flush(self) -> None: + """Force-send any buffered records, regardless of threshold. + + Uses direct per-record send when only 1 record is buffered + (avoids batch encoding overhead for single records like ENDED). + """ + if self._batch_buffer: + self._flush_batch() - Args: - topic: The topic of the message. - payload: The payload of the message. + def _flush_batch(self) -> None: + """Encode and send the buffered payloads. + + The buffer is only cleared after a successful send (or successful + enqueue into the pending queue). If ``_send_frame`` raises, the + buffer is restored so records are not lost. """ - # Combine into a single frame to avoid overhead of .send_multipart() - frame = topic + payload + buf = self._batch_buffer + + if len(buf) == 1: + # Single record: send with its own topic (no batch overhead). + # _last_topic is the topic from the most recent send() call. + frame = self._last_topic + buf[0] + else: + # Multiple records: encode payloads as msgpack list[bytes], + # prefix with BATCH_TOPIC for routing. 
Individual topics are + # not included — subscribers decode EventRecord.event_type + # from the payload for dispatching. + frame = BATCH_TOPIC + _batch_encoder.encode(buf) - # Attempt direct send: - if not self._buffer: + try: + self._batch_buffer = [] + self._send_frame(frame) + except Exception: + # Restore buffer so records are not lost. + self._batch_buffer = buf + raise + + def _send_frame(self, frame: bytes) -> None: + """Attempt direct send; fall back to pending queue + writer.""" + if not self._pending: mode = zmq.NOBLOCK if self.loop else 0 try: - self._socket.send( - frame, - flags=mode, - copy=False, - track=False, - ) + self._socket.send(frame, flags=mode, copy=False, track=False) return except zmq.Again: - # Socket would block; fall through to buffer and use writer. + # Socket would block; fall through to queue and async writer. pass if self.loop is None: - # This should never be reached, since in eager mode, the send_multipart will block and - # should always succeed, but just in case, this guard will raise an error raise RuntimeError( "Failed direct send, but publisher is set to eager-only mode." ) - # Add to buffer since socket is blocked. - self._buffer.append(frame) + self._pending.append(frame) if not self._writing: - # Add writer callback to asyncio loop to drain the buffer when writable. self._writing = True self.loop.add_writer(self._fd, self._on_writable) def _on_writable(self) -> None: - """Drains buffer when socket becomes writable. Used as an asyncio writer callback.""" + """Drain pending frames when socket becomes writable.""" if self.is_closed: return - - self._drain_buffer(force=False) - - if not self._buffer: + self._drain_pending(force=False) + if not self._pending: self._stop_writer() - def _drain_buffer(self, force: bool = False) -> None: - """Drains the buffer. - - Args: - force (bool): If True, will use blocking sends to drain the buffer to ensure - that when this method returns, the buffer is empty. - """ + def _drain_pending(self, force: bool = False) -> None: try: - while self._buffer: - # Do not pre-emptively pop in case of errors - frame = self._buffer[0] + while self._pending: + frame = self._pending[0] mode = 0 if force else zmq.NOBLOCK - self._socket.send( - frame, - flags=mode, - copy=False, - track=False, - ) - self._buffer.popleft() + self._socket.send(frame, flags=mode, copy=False, track=False) + self._pending.popleft() except zmq.Again: + # Socket would block; remaining items stay in queue for next writable callback. return def _stop_writer(self) -> None: - """Stops the writer callback.""" if self._writing: self._writing = False - if self.loop is not None and self._fd is not None: try: self.loop.remove_writer(self._fd) @@ -146,36 +202,47 @@ def close(self) -> None: if self.is_closed: return + # Flush buffered records before marking closed so that any + # concurrent publish() calls that arrive during flush are still + # accepted into the buffer rather than silently dropped. + if self._batch_buffer: + self._flush_batch() + self.is_closed = True if self.loop: - # Remove writer callback if present self._stop_writer() - - # Drain the buffer since we should not drop messages. - if self._buffer: - logger.warning( - "Closing publisher with pending messages. Draining buffer..." - ) - self._drain_buffer(force=True) - self._buffer.clear() # This should be a no-op, but just in case. - - # Socket is closed by ManagedZMQContext.cleanup() when the context scope exits. + if self._pending: + logger.warning("Closing publisher with pending frames. 
Draining...") + self._drain_pending(force=True) + self._pending.clear() # Cleanup IPC socket file. - # urlparse("ipc:///a/b/c") puts the full path in parsed.path (netloc is - # empty for non-registered URI schemes like "ipc"). parsed = urlparse(self.bind_address) if parsed.scheme == "ipc" and parsed.path: try: if os.path.exists(parsed.path): os.unlink(parsed.path) except OSError: - # IPC path already removed or unlink failed (e.g. permissions). + # IPC socket file already removed or unlink failed (e.g. permissions). pass class ZmqEventRecordSubscriber(EventRecordSubscriber): + """ZMQ SUB socket subscriber that handles both single and batched messages. + + Automatically subscribes to BATCH_TOPIC in addition to any explicit + topic subscriptions. Batched messages are unpacked into individual + records and yielded in order via ``receive()``. + + Note on topic filtering with batches: batched messages contain records + of mixed event types. Subscribers with specific topic filters will + receive ALL event types from batches, not just their filtered topics. + Per-record filtering must be done in application code (e.g., checking + ``EventRecord.event_type`` after decode). This is acceptable because + the decode cost (~0.6us/record) is negligible compared to processing. + """ + def __init__( self, path: str, @@ -185,15 +252,15 @@ def __init__( scheme: str = "ipc", ): self._socket = zmq_context.socket(zmq.SUB) - self._socket.setsockopt(zmq.RCVHWM, 0) - # Subscribe to topics if not topics: self._socket.setsockopt(zmq.SUBSCRIBE, b"") else: for topic in topics: self._socket.setsockopt(zmq.SUBSCRIBE, topic.encode("utf-8")) + # Always subscribe to batch topic so batched messages are received + self._socket.setsockopt(zmq.SUBSCRIBE, BATCH_TOPIC) connect_address = zmq_context.connect(self._socket, path, scheme) super().__init__(connect_address, loop, topics) @@ -206,26 +273,59 @@ def __init__( # Reader is added in .start(); do not add here. def receive(self) -> bytes | None: - """Receive a message from the socket""" + """Receive a single record payload. + + If a batched message was received, individual payloads are buffered + and returned one at a time in insertion order. + """ if self.is_closed: return None + # Return buffered payloads first (from a previous batch) + if self._buffer: + return self._buffer.popleft() + try: - frame = self._socket.recv(flags=zmq.NOBLOCK) + raw = self._socket.recv(flags=zmq.NOBLOCK) except zmq.Again as e: raise StopIteration from e - if len(frame) > TOPIC_FRAME_SIZE: - # Should be (padded_topic + payload). Return the payload bytes. - return frame[TOPIC_FRAME_SIZE:] + # Batch message: BATCH_TOPIC prefix + msgpack list[bytes] of payloads. + # Individual payloads do not have topic prefixes — EventRecord.event_type + # is used for dispatching instead. + if raw[:TOPIC_FRAME_SIZE] == BATCH_TOPIC: + batch_data = raw[TOPIC_FRAME_SIZE:] + try: + payloads = _batch_decoder.decode(batch_data) + except (msgspec.DecodeError, ValueError): + # Corrupt batch. On IPC this should never happen (ZMQ delivers + # complete messages atomically). Possible causes: encoder bug, + # ZMQ library bug, or memory corruption. Log enough detail to + # diagnose, but there is no recovery path — the publisher's + # buffer is already gone. + logger.error( + "Failed to decode batch message (%d bytes), dropping. 
" + "This indicates a bug — IPC messages should never be corrupt.", + len(batch_data), + ) + return None + + for payload in payloads: + if payload: + self._buffer.append(payload) + + if self._buffer: + return self._buffer.popleft() + return None + + # Single-record message: topic prefix + payload + if len(raw) > TOPIC_FRAME_SIZE: + return raw[TOPIC_FRAME_SIZE:] return None def close(self) -> None: - """Close the subscriber and remove the loop reader. Idempotent; safe to call multiple times. - Socket is closed by ManagedZMQContext.cleanup() when the context scope exits. - """ + """Close the subscriber. Idempotent.""" if self.is_closed: return self.is_closed = True - super().close() diff --git a/src/inference_endpoint/async_utils/transport/zmq/ready_check.py b/src/inference_endpoint/async_utils/transport/zmq/ready_check.py new file mode 100644 index 00000000..3e2c35a7 --- /dev/null +++ b/src/inference_endpoint/async_utils/transport/zmq/ready_check.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic ZMQ-based ready-check for subprocess startup synchronization. + +Uses a single PULL socket (host) with many PUSH sockets (subprocesses) for +fan-in readiness signaling. All sockets share the same IPC socket directory. + +See docs/async_utils/transport/zmq/ready_check_design.md for design rationale. +""" + +from __future__ import annotations + +import asyncio +import logging +import time + +import msgspec +import zmq + +from .context import ManagedZMQContext + +logger = logging.getLogger(__name__) + +_encoder = msgspec.msgpack.Encoder() +_decoder = msgspec.msgpack.Decoder(int) + +# Bounded linger: ZMQ waits up to this long to deliver queued messages when the +# PUSH socket is closed. An infinite linger (-1) guarantees delivery but could +# hang the subprocess forever if the receiver dies. 5s is generous for a single +# small IPC message — the receiver is already listening before subprocesses are +# spawned. If this proves too short (e.g., parent is heavily loaded during +# launch), increase or set to -1, but be aware of the hang risk. +_LINGER_MS = 5000 + + +async def send_ready_signal( + zmq_context: ManagedZMQContext, + path: str, + identity: int, +) -> None: + """Send a single ready signal over a PUSH socket. + + Opens a PUSH socket on the given context, sends the identity, and closes. + The subprocess's existing ZMQ context is reused — no new context created. + + Args: + zmq_context: The subprocess's existing ManagedZMQContext. + path: IPC socket path (relative to zmq_context.socket_dir). + identity: Integer identity to send (e.g., worker_id or service_id). 
+ """ + sock = zmq_context.async_socket(zmq.PUSH) + sock.setsockopt(zmq.LINGER, _LINGER_MS) + zmq_context.connect(sock, path) + await sock.send(_encoder.encode(identity)) + sock.close() + logger.debug("Ready signal sent (identity=%d)", identity) + + +class ReadyCheckReceiver: + """Host side: bind a single PULL socket, await N ready signals. + + Multiple subprocesses connect PUSH sockets to this single PULL socket + (ZMQ fan-in). After all signals are received, the socket is closed. + """ + + def __init__( + self, + path: str, + zmq_context: ManagedZMQContext, + count: int, + ) -> None: + self._count = count + self._path = path + + # Bind PULL socket for receiving ready signals + self._sock = zmq_context.async_socket(zmq.PULL) + zmq_context.bind(self._sock, path) + + async def wait(self, timeout: float | None = None) -> list[int]: + """Block until ``count`` ready signals are received. + + Uses a total deadline (not per-message timeout). + + Args: + timeout: Maximum total seconds to wait. None means wait indefinitely. + + Returns: + List of identities received (in arrival order). + + Raises: + TimeoutError: If not all signals arrive within timeout. + """ + deadline = (time.monotonic() + timeout) if timeout is not None else None + identities: list[int] = [] + + try: + while len(identities) < self._count: + remaining = None + if deadline is not None: + remaining = max(0, deadline - time.monotonic()) + + try: + if remaining is None: + raw = await self._sock.recv() + else: + raw = await asyncio.wait_for( + self._sock.recv(), timeout=remaining + ) + except TimeoutError: + raise TimeoutError( + f"Ready check failed: {len(identities)}/{self._count} " + f"signals received within {timeout}s" + ) from None + + identity = _decoder.decode(raw) + identities.append(identity) + logger.debug( + "Ready signal received (identity=%d, %d/%d)", + identity, + len(identities), + self._count, + ) + except TimeoutError: + # Don't close socket on timeout — caller may retry. + raise + except BaseException: + # Clean up socket on non-retryable failures (cancellation, etc.) + self.close() + raise + + logger.debug("All %d ready signals received", self._count) + self.close() + return identities + + def close(self) -> None: + """Close the PULL socket. Idempotent.""" + if self._sock is not None and not self._sock.closed: + self._sock.close() diff --git a/src/inference_endpoint/async_utils/transport/zmq/transport.py b/src/inference_endpoint/async_utils/transport/zmq/transport.py index c5e314a7..92dd7c6b 100644 --- a/src/inference_endpoint/async_utils/transport/zmq/transport.py +++ b/src/inference_endpoint/async_utils/transport/zmq/transport.py @@ -66,6 +66,7 @@ WorkerPoolTransport, ) from .context import ManagedZMQContext +from .ready_check import ReadyCheckReceiver, send_ready_signal logger = logging.getLogger(__name__) @@ -552,16 +553,8 @@ async def connect( loop, self.response_path, zmq_context, self.config, bind=False ) - # Signal readiness using an async socket for proper async send. 
- readiness_sock = zmq_context.async_socket(zmq.PUSH) - readiness_sock.setsockopt(zmq.LINGER, -1) # Wait indefinitely for delivery - zmq_context.connect(readiness_sock, self.readiness_path) - try: - encoder = msgspec.msgpack.Encoder() - await readiness_sock.send(encoder.encode(worker_id)) # Async send - readiness_sock.close() # LINGER ensures message is sent - logger.debug("Worker %d signaled readiness", worker_id) + await send_ready_signal(zmq_context, self.readiness_path, worker_id) yield requests, responses finally: @@ -627,9 +620,7 @@ def __init__( QueryResult | StreamChunk, # type: ignore[arg-type] bind=True, ) - self._readiness_receiver = _create_receiver( - loop, readiness_path, zmq_context, config, bind=True - ) + self._ready_check = ReadyCheckReceiver(readiness_path, zmq_context, num_workers) # socket_dir is now guaranteed set (bind() created it if needed). # Store resolved addresses for debugging and tests. @@ -701,30 +692,7 @@ async def wait_for_workers_ready(self, timeout: float | None = None) -> None: Raises: TimeoutError: If workers don't signal in time (only if timeout is set). """ - ready_count = 0 - while ready_count < self._num_workers: - try: - if timeout is None: - worker_id = await self._readiness_receiver.recv() - else: - worker_id = await asyncio.wait_for( - self._readiness_receiver.recv(), - timeout=timeout, - ) - if worker_id is not None: - ready_count += 1 - logger.debug( - f"Worker {worker_id} ready ({ready_count}/{self._num_workers})" - ) - except TimeoutError: - raise TimeoutError( - f"Workers failed to initialize: {ready_count}/{self._num_workers} ready" - ) from None - - logger.debug(f"All {self._num_workers} workers ready") - - # Close readiness receiver - no longer needed - self._readiness_receiver.close() + await self._ready_check.wait(timeout=timeout) def cleanup(self) -> None: """Close all transports and release resources. Idempotent.""" @@ -736,7 +704,7 @@ def cleanup(self) -> None: for sender in self._request_senders: sender.close() self._response_receiver.close() - self._readiness_receiver.close() + self._ready_check.close() # Only clean up the ZMQ context if we created it (singleton may be shared) if self._owns_context: diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index ec6bc5ad..35bd23b7 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -13,18 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Benchmark execution — phased architecture for threaded and future async runners. +"""Benchmark execution — phased architecture. Phases: - 1. setup_benchmark() — load tokenizer, dataset, scheduler (no IO) - 2. run_benchmark_threaded() — HTTP client + BenchmarkSession (threaded IO) + 1. setup_benchmark() — load tokenizer, dataset, config (no IO) + 2. run_benchmark_async() — HTTP client + async BenchmarkSession 3. 
finalize_benchmark() — accuracy scoring, results JSON """ from __future__ import annotations +import asyncio import json import logging +import platform +import shutil import signal import tempfile import uuid @@ -34,17 +37,35 @@ from typing import Any from urllib.parse import urljoin +import msgspec.json +from huggingface_hub import model_info from tqdm import tqdm -from transformers import AutoTokenizer from transformers.utils import logging as transformers_logging +from inference_endpoint.async_utils.event_publisher import EventPublisherService +from inference_endpoint.async_utils.loop_manager import LoopManager +from inference_endpoint.async_utils.services.launcher import ( + ServiceConfig, + ServiceLauncher, +) +from inference_endpoint.async_utils.services.metrics_aggregator.aggregator import ( + MetricCounterKey, +) +from inference_endpoint.async_utils.services.metrics_aggregator.kv_store import ( + BasicKVStoreReader, +) +from inference_endpoint.async_utils.services.metrics_aggregator.metrics_table import ( + MetricSeriesKey, +) +from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import ( APIType, BenchmarkConfig, DatasetType, + LoadPattern, + LoadPatternType, StreamingMode, - SystemDefaults, TestMode, TestType, ) @@ -61,13 +82,13 @@ InputValidationError, SetupError, ) -from inference_endpoint.load_generator import ( +from inference_endpoint.load_generator.session import ( BenchmarkSession, - SampleEvent, - SampleEventHandler, - WithoutReplacementSampleOrder, + PhaseConfig, + PhaseType, + SessionResult, ) -from inference_endpoint.load_generator.scheduler import Scheduler +from inference_endpoint.metrics.report import Report transformers_logging.set_verbosity_error() @@ -92,6 +113,7 @@ def __init__(self, collect_responses: bool = False, pbar: tqdm | None = None): self.pbar = pbar def on_complete_hook(self, result: QueryResult) -> None: + """Handle query completion (called once per query via QueryResult).""" self.count += 1 if result.error: self.errors.append(f"Sample {result.id}: {result.error}") @@ -103,6 +125,17 @@ def on_complete_hook(self, result: QueryResult) -> None: self.pbar.update(1) +@dataclass +class BenchmarkResult: + """Output of run_benchmark_async — all data needed for finalization.""" + + session: SessionResult + collector: ResponseCollector + report: Report | None + tmpfs_dir: Path + metrics_dir: Path | None = None + + @dataclass class AccuracyConfiguration: scorer: type[Scorer] @@ -124,10 +157,9 @@ class BenchmarkContext: config: BenchmarkConfig test_mode: TestMode report_dir: Path - tokenizer: AutoTokenizer | None + tokenizer_name: str | None dataloader: Dataset rt_settings: RuntimeSettings - scheduler: Scheduler total_samples: int accuracy_datasets: list[Dataset] = field(default_factory=list) eval_configs: list[AccuracyConfiguration] = field(default_factory=list) @@ -146,17 +178,40 @@ def enable_streaming(self) -> bool: return self.config.model_params.streaming == StreamingMode.ON -def _load_tokenizer(model_name: str) -> AutoTokenizer | None: - """Load HuggingFace tokenizer, warn on failure.""" +def _check_tokenizer_exists(model_name: str) -> bool: + """Check if a HuggingFace tokenizer exists for the model (API only, no download). + + Returns True if the model repo exists and has tokenizer files, False otherwise. + This function is a probe — it never loads or downloads the tokenizer itself. 
+ Downstream consumers that need tokenization (e.g. the MetricsAggregator + subprocess for ISL/OSL/TPOT, Harmony transforms for prompt preprocessing, + and any future plugin with its own tokenization need) each load their own + instance as required. + """ try: - logger.info(f"Loading tokenizer for model: {model_name}") - tokenizer = AutoTokenizer.from_pretrained(model_name) - logger.info("Tokenizer loaded successfully") - return tokenizer + info = model_info(model_name) + # Check for tokenizer files in the repo + siblings = {s.rfilename for s in (info.siblings or [])} + has_tokenizer = ( + "tokenizer_config.json" in siblings or "tokenizer.json" in siblings + ) + if has_tokenizer: + logger.info(f"Tokenizer available for model: {model_name}") + else: + logger.warning(f"Model {model_name} found but has no tokenizer files") + return has_tokenizer + except ImportError: + # huggingface_hub not installed — fall back to assuming it works + logger.info( + f"huggingface_hub not installed, assuming tokenizer exists for {model_name}" + ) + return True except Exception as e: - logger.warning(f"Failed to load tokenizer for {model_name}: {e}") - logger.warning("Continuing without tokenizer (report metrics may be limited)") - return None + logger.warning(f"Could not verify tokenizer for {model_name}: {e}") + logger.warning( + "Continuing without tokenizer (ISL/OSL/TPOT metrics will be unavailable)" + ) + return False def _load_datasets( @@ -229,22 +284,6 @@ def _load_datasets( return dataloader, accuracy_datasets, eval_configs -def _create_scheduler( - config: BenchmarkConfig, rt_settings: RuntimeSettings -) -> Scheduler: - """Create scheduler using __init_subclass__ registry.""" - load_pattern_type = config.settings.load_pattern.type - try: - scheduler_class = Scheduler.get_implementation(load_pattern_type) - scheduler = scheduler_class(rt_settings, WithoutReplacementSampleOrder) - logger.info( - f"Scheduler: {scheduler_class.__name__} (pattern: {load_pattern_type.value})" - ) - return scheduler - except KeyError as e: - raise SetupError(str(e)) from e - - def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkContext: """Load tokenizer, dataset, create scheduler, setup report dir.""" # CPU affinity @@ -261,8 +300,9 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo report_dir.mkdir(parents=True, exist_ok=True) config.to_yaml_file(report_dir / "config.yaml") - # Tokenizer (model name validated by BenchmarkConfig._resolve_and_validate) - tokenizer = _load_tokenizer(config.model_params.name) + # Tokenizer check (light API call, no download) + model_name = config.model_params.name + tokenizer_name = model_name if _check_tokenizer_exists(model_name) else None # Streaming logger.info( @@ -289,16 +329,13 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo f"Min Duration: {rt_settings.min_duration_ms / 1000:.1f}s, Expected samples: {total_samples}" ) - scheduler = _create_scheduler(config, rt_settings) - return BenchmarkContext( config=config, test_mode=test_mode, report_dir=report_dir, - tokenizer=tokenizer, + tokenizer_name=tokenizer_name, dataloader=dataloader, rt_settings=rt_settings, - scheduler=scheduler, total_samples=total_samples, accuracy_datasets=accuracy_datasets, eval_configs=eval_configs, @@ -306,98 +343,310 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo ) -def run_benchmark_threaded(ctx: BenchmarkContext) -> tuple[Any, ResponseCollector]: - """Run benchmark session with 
threaded HTTP client. Returns (report, collector).""" +def _build_phases(ctx: BenchmarkContext) -> list[PhaseConfig]: + """Build the phase list from BenchmarkContext.""" + phases: list[PhaseConfig] = [] + + # Performance phase + phases.append( + PhaseConfig( + "performance", ctx.rt_settings, ctx.dataloader, PhaseType.PERFORMANCE + ) + ) + + # Accuracy phases — use eval_cfg.dataset_name as phase name so it matches + # what Scorer._load_sample_index_map() looks up in sample_idx_map.json + for eval_cfg in ctx.eval_configs: + acc_ds = eval_cfg.dataset + acc_settings = RuntimeSettings( + metric_target=ctx.rt_settings.metric_target, + reported_metrics=ctx.rt_settings.reported_metrics, + min_duration_ms=0, + max_duration_ms=None, + n_samples_from_dataset=acc_ds.num_samples(), + n_samples_to_issue=acc_ds.num_samples() * acc_ds.repeats, + min_sample_count=acc_ds.num_samples() * acc_ds.repeats, + rng_sched=ctx.rt_settings.rng_sched, + rng_sample_index=ctx.rt_settings.rng_sample_index, + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + phases.append( + PhaseConfig(eval_cfg.dataset_name, acc_settings, acc_ds, PhaseType.ACCURACY) + ) + + return phases + + +def _setup_kv_reader( + metrics_dir: Path, + streaming: bool, +) -> BasicKVStoreReader: + """Create a KVStoreReader pre-registered with all metric keys.""" + reader = BasicKVStoreReader(metrics_dir) + for counter_key in MetricCounterKey: + reader.register_key(counter_key.value, "counter") + _STREAMING_ONLY = { + MetricSeriesKey.TTFT_NS, + MetricSeriesKey.CHUNK_DELTA_NS, + MetricSeriesKey.TPOT_NS, + } + _FLOAT_SERIES = {MetricSeriesKey.TPOT_NS} + for series_key in MetricSeriesKey: + if series_key in _STREAMING_ONLY and not streaming: + continue + dtype = float if series_key in _FLOAT_SERIES else int + reader.register_key(series_key.value, "series", dtype=dtype) + return reader + + +async def _run_benchmark_async( + ctx: BenchmarkContext, + loop: asyncio.AbstractEventLoop, +) -> BenchmarkResult: + """Run async benchmark session.""" config = ctx.config + session_id = f"cli_benchmark_{uuid.uuid4().hex[:8]}" - # Setup response collector + # Progress bar + response collector pbar = tqdm( desc=f"{config.model_params.name} (Streaming: {ctx.enable_streaming})", total=ctx.total_samples, - smoothing=0, # smoothing=0 shows average instead of EMA + smoothing=0, ) collector = ResponseCollector(collect_responses=ctx.collect_responses, pbar=pbar) - SampleEventHandler.register_hook(SampleEvent.COMPLETE, collector.on_complete_hook) - # Create endpoint client - endpoints = config.endpoint_config.endpoints - logger.info(f"Connecting: {endpoints}") - try: - api_type: APIType = config.endpoint_config.api_type - http_config = config.settings.client.with_updates( - endpoint_urls=[urljoin(e, api_type.default_route()) for e in endpoints], - api_type=api_type, - api_key=config.endpoint_config.api_key, - event_logs_dir=ctx.report_dir, - cpu_affinity=ctx.affinity_plan, + # ZMQ context for event publishing + service launcher + with ManagedZMQContext.scoped(io_threads=2) as zmq_ctx: + # Event publisher + publisher = EventPublisherService(zmq_ctx) + pub_socket_name = publisher.socket_name + + # Tmpfs for high-frequency writes (metrics mmap + event log). + # On ARM, metrics need an on-disk directory so msync provides + # write ordering for cross-process mmap reads. Event logs are + # append-only and don't have ordering requirements, so they + # can stay on tmpfs. 
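+ # Resulting layout when /dev/shm is available (on ARM the metrics dir
+ # moves on disk, see below):
+ # /dev/shm/benchmark_<session_id>/metrics/*.kv (KVStore mmap files)
+ # /dev/shm/benchmark_<session_id>/events/events.jsonl (event log)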
+ shm = Path("/dev/shm") + use_shm = shm.exists() + tmpfs_base = shm if use_shm else Path(tempfile.gettempdir()) + tmpfs_dir = tmpfs_base / f"benchmark_{session_id}" + tmpfs_dir.mkdir(parents=True, exist_ok=True) + + # On ARM, mmap write ordering requires msync on a real filesystem. + # msync is a no-op on tmpfs, so metrics must use an on-disk directory. + if use_shm and platform.machine() != "x86_64": + logger.info( + "ARM platform: using on-disk metrics directory for mmap ordering" + ) + metrics_dir = Path( + tempfile.mkdtemp(prefix=f"metrics_{session_id}_", dir=".") + ) + else: + metrics_dir = tmpfs_dir / "metrics" + metrics_dir.mkdir(parents=True, exist_ok=True) + + event_log_dir = tmpfs_dir / "events" + event_log_dir.mkdir(parents=True, exist_ok=True) + + # Launch service subprocesses + launcher = ServiceLauncher(zmq_ctx) + if zmq_ctx.socket_dir is None: + raise RuntimeError("ZMQ socket_dir must be set after publisher bind") + aggregator_args: list[str] = [ + "--socket-dir", + zmq_ctx.socket_dir, + "--socket-name", + pub_socket_name, + "--metrics-dir", + str(metrics_dir), + ] + if ctx.enable_streaming: + aggregator_args.append("--streaming") + if ctx.tokenizer_name is not None: + aggregator_args.extend(["--tokenizer", ctx.tokenizer_name]) + + # EventLoggerService writes events.jsonl to tmpfs (high-frequency writes) + event_logger_args: list[str] = [ + "--log-dir", + str(event_log_dir), + "--socket-dir", + zmq_ctx.socket_dir, + "--socket-name", + pub_socket_name, + "--writers", + "jsonl", + ] + + await launcher.launch( + [ + ServiceConfig( + module="inference_endpoint.async_utils.services.metrics_aggregator", + args=aggregator_args, + ), + ServiceConfig( + module="inference_endpoint.async_utils.services.event_logger", + args=event_logger_args, + ), + ], + timeout=30.0, ) - http_client = HTTPEndpointClient(http_config) - sample_issuer = HttpClientSampleIssuer(http_client) - except Exception as e: - raise SetupError(f"Failed to connect to endpoint: {e}") from e - # Run benchmark - logger.info("Running...") - sess = None - try: - sess = BenchmarkSession.start( - ctx.rt_settings, - ctx.dataloader, - sample_issuer, - ctx.scheduler, - name=f"cli_benchmark_{uuid.uuid4().hex[0:8]}", - report_dir=ctx.report_dir, - tokenizer_override=ctx.tokenizer, - accuracy_datasets=ctx.accuracy_datasets, - max_shutdown_timeout_s=config.timeout or SystemDefaults.DEFAULT_TIMEOUT, - dump_events_log=True, + # Create endpoint client on the shared loop + endpoints = config.endpoint_config.endpoints + logger.info(f"Connecting: {endpoints}") + http_client: HTTPEndpointClient | None = None + try: + api_type: APIType = config.endpoint_config.api_type + http_config = config.settings.client.with_updates( + endpoint_urls=[urljoin(e, api_type.default_route()) for e in endpoints], + api_type=api_type, + api_key=config.endpoint_config.api_key, + event_logs_dir=ctx.report_dir, + cpu_affinity=ctx.affinity_plan, + ) + http_client = await HTTPEndpointClient.create(http_config, loop) + issuer = HttpClientSampleIssuer(http_client) + except Exception as e: + pbar.close() + publisher.close() + launcher.kill_all() + raise SetupError(f"Failed to connect to endpoint: {e}") from e + + # Create session + session = BenchmarkSession( + issuer=issuer, + event_publisher=publisher, + loop=loop, + on_sample_complete=collector.on_complete_hook, + session_id=session_id, ) - # Wait for test end with ability to interrupt - def _raise_keyboard_interrupt(*_: object) -> None: - raise KeyboardInterrupt + phases = _build_phases(ctx) + report: Report | 
None = None - old_handler = signal.signal(signal.SIGINT, _raise_keyboard_interrupt) + loop.add_signal_handler(signal.SIGINT, session.stop) try: - sess.wait_for_test_end() + result = await session.run(phases) + except Exception as e: + raise ExecutionError(f"Benchmark execution failed: {e}") from e finally: - # Always restore original handler - signal.signal(signal.SIGINT, old_handler) - - # Prefer authoritative metrics from the session report - report = getattr(sess, "report", None) - if report is None: - raise ExecutionError("Session report missing — cannot produce results") - return report, collector + loop.remove_signal_handler(signal.SIGINT) + logger.info("Cleaning up...") + try: + if http_client: + await http_client.shutdown_async() + except Exception as e: + logger.warning(f"Client cleanup error: {e}") + logger.info( + "Closing publisher (buffer=%d, pending=%d)...", + publisher.buffered_count, + publisher.pending_count, + ) + publisher.close() + logger.info("Waiting for services to finish processing...") + await asyncio.to_thread(launcher.wait_for_exit, None) + + # Build report AFTER aggregator has exited — ensures all metrics + # (TTFT, TPOT, OSL, latency) are fully written to KVStore. + try: + kv_reader = _setup_kv_reader(metrics_dir, ctx.enable_streaming) + report = Report.from_kv_reader(kv_reader) + kv_reader.close() + except Exception as e: + logger.warning(f"Failed to build report from metrics: {e}") - except KeyboardInterrupt: - logger.warning("Benchmark interrupted by user") - raise - except ExecutionError: - # Re-raise our own exceptions - raise - except Exception as e: - raise ExecutionError(f"Benchmark execution failed: {e}") from e - finally: - # Cleanup - always execute - logger.info("Cleaning up...") - try: - if sess is not None: - sess.stop() pbar.close() - sample_issuer.shutdown() - http_client.shutdown() - except Exception as e: - logger.debug(f"Cleanup error: {e}") + + # Track metrics_dir separately if it's not under tmpfs_dir (ARM on-disk case) + separate_metrics = metrics_dir if metrics_dir.parent != tmpfs_dir else None + return BenchmarkResult( + session=result, + collector=collector, + report=report, + tmpfs_dir=tmpfs_dir, + metrics_dir=separate_metrics, + ) + + +def run_benchmark_async(ctx: BenchmarkContext) -> BenchmarkResult: + """Run async benchmark. Sync entry point — drives the event loop.""" + loop = LoopManager().default_loop + return loop.run_until_complete(_run_benchmark_async(ctx, loop)) -def finalize_benchmark( +def _write_scoring_artifacts( ctx: BenchmarkContext, - report: Any, - collector: ResponseCollector, + result: SessionResult, + tmpfs_dir: Path, ) -> None: + """Write sample_idx_map.json and copy events.jsonl for Scorer consumption. + + events.jsonl is written by EventLoggerService to tmpfs during the benchmark. + We copy it to report_dir (typically on disk) during finalization. + """ + + # sample_idx_map.json — {dataset_name: {uuid: sample_index}} + sample_idx_map: dict[str, dict[str, int]] = {} + for phase_result in result.phase_results: + sample_idx_map[phase_result.name] = phase_result.uuid_to_index + + map_path = ctx.report_dir / "sample_idx_map.json" + with map_path.open("wb") as f: + f.write(msgspec.json.format(msgspec.json.encode(sample_idx_map), indent=2)) + logger.debug(f"Wrote {map_path}") + + # Copy events.jsonl from tmpfs to report_dir. + # Tmpfs cleanup is handled by run_benchmark()'s finally block. 
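+ # _salvage_tmpfs is safe to repeat: run_benchmark()'s finally block may
+ # call it again on interrupt, and re-copies simply overwrite.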
+ _salvage_tmpfs(ctx.report_dir, tmpfs_dir) + + +def _salvage_tmpfs(report_dir: Path, tmpfs_dir: Path) -> None: + """Copy all salvageable artifacts from tmpfs to report_dir. + + Called during normal finalization and on interrupt/crash to preserve logs. + Safe to call multiple times (skips if already copied or tmpfs is gone). + """ + if not tmpfs_dir.exists(): + return + + # events.jsonl (from EventLoggerService) + src_events = tmpfs_dir / "events" / "events.jsonl" + if src_events.exists(): + dst_events = report_dir / "events.jsonl" + shutil.copy2(src_events, dst_events) + logger.debug(f"Copied {src_events} -> {dst_events}") + + # metrics mmap files (from MetricsAggregator KVStore) + src_metrics = tmpfs_dir / "metrics" + if src_metrics.exists(): + dst_metrics = report_dir / "metrics" + dst_metrics.mkdir(parents=True, exist_ok=True) + for f in src_metrics.iterdir(): + if f.is_file(): + shutil.copy2(f, dst_metrics / f.name) + logger.debug(f"Copied metrics from {src_metrics} -> {dst_metrics}") + + +def finalize_benchmark(ctx: BenchmarkContext, bench: BenchmarkResult) -> None: """Score accuracy, aggregate results, write JSON.""" config = ctx.config + result = bench.session + collector = bench.collector + report = bench.report + + # Display report if available (from MetricsAggregator KVStore) + if report is not None: + report.display(fn=lambda s: logger.info(s), summary_only=True) + report.to_json(save_to=ctx.report_dir / "result_summary.json") + + # Write human-readable report.txt + report_txt = ctx.report_dir / "report.txt" + with report_txt.open("w") as f: + report.display(fn=lambda s: print(s, file=f)) + logger.info(f"Report written to {report_txt}") + + # Write scoring artifacts + copy event log from tmpfs to disk + _write_scoring_artifacts(ctx, result, bench.tmpfs_dir) # Accuracy scoring accuracy_scores: dict[str, Any] = {} @@ -421,15 +670,27 @@ def finalize_benchmark( } logger.info(f"Score for {eval_cfg.dataset_name}: {score} ({n_repeats} repeats)") - # Report metrics - elapsed = report.duration_ns / 1e9 if report.duration_ns is not None else 0.0 - total_issued = report.n_samples_issued - success = total_issued - report.n_samples_failed - qps = report.qps or 0.0 - - logger.info(f"Completed in {elapsed:.1f}s") - logger.info(f"Results: {success}/{total_issued} successful") - logger.info(f"Estimated QPS: {qps:.1f}") + # Report metrics: prefer Report from KVStore, fall back to SessionResult + if report is not None and report.duration_ns is not None: + perf_elapsed = report.duration_ns / 1e9 + total_issued = report.n_samples_issued + n_errors = report.n_samples_failed + qps = report.qps() or 0.0 + else: + perf = result.perf_results[0] if result.perf_results else None + if perf: + perf_elapsed = (perf.end_time_ns - perf.start_time_ns) / 1e9 + total_issued = perf.issued_count + else: + perf_elapsed = (result.end_time_ns - result.start_time_ns) / 1e9 + total_issued = 0 + n_errors = len(collector.errors) + qps = total_issued / perf_elapsed if perf_elapsed > 0 else 0.0 + + logger.info(f"Completed in {perf_elapsed:.1f}s") + logger.info(f"Results: {max(0, total_issued - n_errors)}/{total_issued} successful") + if qps > 0: + logger.info(f"Estimated QPS: {qps:.1f}") if collector.errors: logger.warning(f"Errors: {len(collector.errors)}") @@ -448,9 +709,9 @@ def finalize_benchmark( }, "results": { "total": total_issued, - "successful": success, - "failed": report.n_samples_failed, - "elapsed_time": elapsed, + "successful": max(0, total_issued - n_errors), + "failed": n_errors, + "elapsed_time": 
perf_elapsed, "qps": qps, }, } @@ -477,5 +738,17 @@ def run_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> None: config.model_dump_json(indent=2, exclude_none=True), ) ctx = setup_benchmark(config, test_mode) - report, collector = run_benchmark_threaded(ctx) - finalize_benchmark(ctx, report, collector) + bench: BenchmarkResult | None = None + try: + bench = run_benchmark_async(ctx) + finalize_benchmark(ctx, bench) + except KeyboardInterrupt: + logger.warning("Benchmark interrupted by user") + finally: + if bench: + if bench.tmpfs_dir.exists(): + _salvage_tmpfs(ctx.report_dir, bench.tmpfs_dir) + shutil.rmtree(bench.tmpfs_dir, ignore_errors=True) + if bench.metrics_dir and bench.metrics_dir.exists(): + shutil.rmtree(bench.metrics_dir, ignore_errors=True) + logger.info(f"Partial results saved to {ctx.report_dir}") diff --git a/src/inference_endpoint/config/ruleset_registry.py b/src/inference_endpoint/config/ruleset_registry.py index 59cddc30..75ce5139 100644 --- a/src/inference_endpoint/config/ruleset_registry.py +++ b/src/inference_endpoint/config/ruleset_registry.py @@ -28,6 +28,8 @@ from typing import TYPE_CHECKING +from .rulesets.mlcommons.rules import CURRENT as mlcommons_current + if TYPE_CHECKING: from .ruleset_base import BenchmarkSuiteRuleset @@ -77,18 +79,10 @@ def list_rulesets() -> list[str]: # Auto-register MLCommons rulesets def _auto_register_mlcommons(): """Auto-register MLCommons rulesets.""" - try: - from .rulesets.mlcommons.rules import CURRENT as mlcommons_current - - # Register with version-specific name - register_ruleset( - f"mlperf-inference-{mlcommons_current.version}", mlcommons_current - ) - # Also register as "mlcommons-current" for convenience - register_ruleset("mlcommons-current", mlcommons_current) - except ImportError: - # MLCommons rulesets not available - pass + # Register with version-specific name + register_ruleset(f"mlperf-inference-{mlcommons_current.version}", mlcommons_current) + # Also register as "mlcommons-current" for convenience + register_ruleset("mlcommons-current", mlcommons_current) # Auto-register on import diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index e7bfe19e..6ad4224e 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -430,7 +430,10 @@ class EndpointConfig(BaseModel): endpoints: Annotated[ list[str], cyclopts.Parameter(alias="--endpoints", help="Endpoint URL(s)", negative=""), - ] = Field(min_length=1) + ] = Field( + min_length=1, + description="Endpoint URL(s). Must include scheme, e.g. 'http://host:port'.", + ) api_key: Annotated[ str | None, cyclopts.Parameter(alias="--api-key", help="API key") ] = None @@ -439,6 +442,16 @@ class EndpointConfig(BaseModel): cyclopts.Parameter(alias="--api-type", help="API type: openai or sglang"), ] = APIType.OPENAI + @field_validator("endpoints", mode="after") + @classmethod + def _validate_endpoint_scheme(cls, v: list[str]) -> list[str]: + for url in v: + if not url.startswith(("http://", "https://")): + raise ValueError( + f"Endpoint URL must include scheme (http:// or https://), got: {url!r}" + ) + return v + class BenchmarkConfig(WithUpdatesMixin, BaseModel): """Benchmark configuration — single source of truth for YAML and CLI. 
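+
+    Note (illustrative): endpoint URLs must include an explicit scheme.
+    EndpointConfig(endpoints=["localhost:8000"]) raises a validation error,
+    while EndpointConfig(endpoints=["http://localhost:8000"]) is accepted
+    (see _validate_endpoint_scheme above).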
@@ -690,7 +703,7 @@ def create_default_config(test_type: TestType) -> BenchmarkConfig: _common = { "model_params": ModelParams(name=""), "datasets": [Dataset(path="")], - "endpoint_config": EndpointConfig(endpoints=[""]), + "endpoint_config": EndpointConfig(endpoints=["http://localhost:8000"]), } if test_type == TestType.OFFLINE: return OfflineBenchmarkConfig(**_common) diff --git a/src/inference_endpoint/config/templates/concurrency_template.yaml b/src/inference_endpoint/config/templates/concurrency_template.yaml index 31e1c05f..7b560ed7 100644 --- a/src/inference_endpoint/config/templates/concurrency_template.yaml +++ b/src/inference_endpoint/config/templates/concurrency_template.yaml @@ -17,5 +17,5 @@ settings: type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, burst, step target_concurrency: 32 # Concurrent requests endpoint_config: - endpoints: # Endpoint URL(s) - - '' + endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. + - http://localhost:8000 diff --git a/src/inference_endpoint/config/templates/concurrency_template_full.yaml b/src/inference_endpoint/config/templates/concurrency_template_full.yaml index 48f1b34f..3a8e004f 100644 --- a/src/inference_endpoint/config/templates/concurrency_template_full.yaml +++ b/src/inference_endpoint/config/templates/concurrency_template_full.yaml @@ -49,7 +49,6 @@ settings: target_concurrency: 32 # Concurrent requests client: num_workers: -1 # Worker processes (-1=auto) - record_worker_events: false # Record per-worker events log_level: INFO # Worker log level warmup_connections: -1 # Pre-establish TCP connections (-1=auto, 0=disabled) max_connections: -1 # Max TCP connections (-1=unlimited) @@ -70,8 +69,8 @@ settings: min_required_connections: -1 # Min connections to initialize (-1=auto, 0=disabled) worker_gc_mode: relaxed # Worker GC strategy | options: disabled, relaxed, system endpoint_config: - endpoints: # Endpoint URL(s) - - '' + endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. + - http://localhost:8000 api_key: null # API key api_type: openai # API type: openai or sglang | options: openai, sglang report_dir: null # Report output directory diff --git a/src/inference_endpoint/config/templates/offline_template.yaml b/src/inference_endpoint/config/templates/offline_template.yaml index 6531771a..6e83d10f 100644 --- a/src/inference_endpoint/config/templates/offline_template.yaml +++ b/src/inference_endpoint/config/templates/offline_template.yaml @@ -14,5 +14,5 @@ settings: max_duration_ms: 0 # Maximum test duration in ms (0 for no limit) n_samples_to_issue: null # Sample count override endpoint_config: - endpoints: # Endpoint URL(s) - - '' + endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. 
+ - http://localhost:8000 diff --git a/src/inference_endpoint/config/templates/offline_template_full.yaml b/src/inference_endpoint/config/templates/offline_template_full.yaml index 7c5f43c6..faabffde 100644 --- a/src/inference_endpoint/config/templates/offline_template_full.yaml +++ b/src/inference_endpoint/config/templates/offline_template_full.yaml @@ -49,7 +49,6 @@ settings: target_concurrency: null # Concurrent requests client: num_workers: -1 # Worker processes (-1=auto) - record_worker_events: false # Record per-worker events log_level: INFO # Worker log level warmup_connections: -1 # Pre-establish TCP connections (-1=auto, 0=disabled) max_connections: -1 # Max TCP connections (-1=unlimited) @@ -70,8 +69,8 @@ settings: min_required_connections: -1 # Min connections to initialize (-1=auto, 0=disabled) worker_gc_mode: relaxed # Worker GC strategy | options: disabled, relaxed, system endpoint_config: - endpoints: # Endpoint URL(s) - - '' + endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. + - http://localhost:8000 api_key: null # API key api_type: openai # API type: openai or sglang | options: openai, sglang report_dir: null # Report output directory diff --git a/src/inference_endpoint/config/templates/online_template.yaml b/src/inference_endpoint/config/templates/online_template.yaml index eafac9e9..d33c1fd5 100644 --- a/src/inference_endpoint/config/templates/online_template.yaml +++ b/src/inference_endpoint/config/templates/online_template.yaml @@ -17,5 +17,5 @@ settings: type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, burst, step target_qps: 10.0 # Target QPS endpoint_config: - endpoints: # Endpoint URL(s) - - '' + endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. + - http://localhost:8000 diff --git a/src/inference_endpoint/config/templates/online_template_full.yaml b/src/inference_endpoint/config/templates/online_template_full.yaml index 6e274f8e..e9b7a673 100644 --- a/src/inference_endpoint/config/templates/online_template_full.yaml +++ b/src/inference_endpoint/config/templates/online_template_full.yaml @@ -49,7 +49,6 @@ settings: target_concurrency: null # Concurrent requests client: num_workers: -1 # Worker processes (-1=auto) - record_worker_events: false # Record per-worker events log_level: INFO # Worker log level warmup_connections: -1 # Pre-establish TCP connections (-1=auto, 0=disabled) max_connections: -1 # Max TCP connections (-1=unlimited) @@ -70,8 +69,8 @@ settings: min_required_connections: -1 # Min connections to initialize (-1=auto, 0=disabled) worker_gc_mode: relaxed # Worker GC strategy | options: disabled, relaxed, system endpoint_config: - endpoints: # Endpoint URL(s) - - '' + endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. 
+ - http://localhost:8000 api_key: null # API key api_type: openai # API type: openai or sglang | options: openai, sglang report_dir: null # Report output directory diff --git a/src/inference_endpoint/config/utils.py b/src/inference_endpoint/config/utils.py index 25c27899..fd62d2bb 100644 --- a/src/inference_endpoint/config/utils.py +++ b/src/inference_endpoint/config/utils.py @@ -124,10 +124,8 @@ def parse_dataset_string(s: str) -> dict[str, object]: # Validate parser remap targets (CLI only — YAML validated in factory) if "parser" in result and isinstance(result["parser"], dict): - # Lazy import to avoid circular dep: schema_utils → dataset_manager → schema - from inference_endpoint.dataset_manager.transforms import ( - MakeAdapterCompatible, - ) + # Lazy import: circular dependency (config.schema → config.utils → dataset_manager → config.schema) + from inference_endpoint.dataset_manager.transforms import MakeAdapterCompatible valid = set(MakeAdapterCompatible().remap.values()) invalid = set(result["parser"].keys()) - valid diff --git a/src/inference_endpoint/core/record.py b/src/inference_endpoint/core/record.py index 33b13d1c..adac5c8f 100644 --- a/src/inference_endpoint/core/record.py +++ b/src/inference_endpoint/core/record.py @@ -32,6 +32,9 @@ equal to the length of the longest topic string. """ +BATCH_TOPIC: Final[bytes] = b"batch".ljust(TOPIC_FRAME_SIZE, b"\0") +"""Reserved topic prefix for batched messages containing multiple records.""" + class EventTypeMeta(enum.EnumMeta): """Metaclass for event kind enums classes. @@ -142,10 +145,6 @@ class SampleEventType(EventType): COMPLETE = "complete" RECV_FIRST = "recv_first" RECV_NON_FIRST = "recv_non_first" - CLIENT_SEND = "client_send" - CLIENT_RESP_DONE = "client_resp_done" - TRANSPORT_SENT = "transport_sent" - TRANSPORT_RECV = "transport_recv" class EventRecord(msgspec.Struct, kw_only=True, frozen=True, gc=False): # type: ignore[call-arg] diff --git a/src/inference_endpoint/core/types.py b/src/inference_endpoint/core/types.py index aa862b66..accd2ca8 100644 --- a/src/inference_endpoint/core/types.py +++ b/src/inference_endpoint/core/types.py @@ -348,19 +348,13 @@ class StreamChunk( display and accurate Time-To-First-Token (TTFT) measurements. Multiple StreamChunks with the same id collectively form the complete response. - The is_complete flag indicates the final chunk in the sequence. + The final QueryResult (sent by the worker after all chunks) signals completion. Attributes: id: Query identifier (matches the originating Query.id). response_chunk: Partial response text for this chunk (delta, not cumulative). - is_complete: True if this is the final chunk, False for intermediate chunks. metadata: Additional metadata for this chunk (timing, token info, etc.). - Example: - Streaming "Hello World" might produce: - >>> StreamChunk(id="q1", response_chunk="Hello", is_complete=False) - >>> StreamChunk(id="q1", response_chunk=" World", is_complete=True) - Note: gc=False: Safe because metadata contains only scalar key-value pairs. Do NOT store cyclic references in metadata field. @@ -368,13 +362,12 @@ class StreamChunk( omit_defaults=True: Fields with static defaults (ie. those NOT using default_factory) are omitted if value equals default. - array_like=True: Encodes as array instead of object (e.g. ["id", "chunk", false, {}] + array_like=True: Encodes as array instead of object (e.g. ["id", "chunk", {}] instead of {"id": ..., "response_chunk": ..., ...}). Reduces payload size. 
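+
+    Example (illustrative): streaming "Hello World" might now produce
+    >>> StreamChunk(id="q1", response_chunk="Hello")
+    >>> StreamChunk(id="q1", response_chunk=" World")
+    followed by a final QueryResult for "q1" signaling completion.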
""" id: str = "" response_chunk: str = "" - is_complete: bool = False metadata: dict[str, Any] = msgspec.field(default_factory=dict) diff --git a/src/inference_endpoint/dataset_manager/factory.py b/src/inference_endpoint/dataset_manager/factory.py index 25091965..6ed1674a 100644 --- a/src/inference_endpoint/dataset_manager/factory.py +++ b/src/inference_endpoint/dataset_manager/factory.py @@ -19,6 +19,7 @@ """ import logging +from pathlib import Path from inference_endpoint.config.schema import Dataset as DatasetConfig from inference_endpoint.dataset_manager.dataset import Dataset, DatasetFormat @@ -103,8 +104,6 @@ def create_loader(config: DatasetConfig, num_repeats: int = 1, **kwargs) -> Data transforms.append(MakeAdapterCompatible()) assert dataset_path is not None - from pathlib import Path - return Dataset.load_from_file( Path(dataset_path), transforms=transforms, diff --git a/src/inference_endpoint/dataset_manager/predefined/livecodebench/__init__.py b/src/inference_endpoint/dataset_manager/predefined/livecodebench/__init__.py index be2a12af..1f1dd23c 100644 --- a/src/inference_endpoint/dataset_manager/predefined/livecodebench/__init__.py +++ b/src/inference_endpoint/dataset_manager/predefined/livecodebench/__init__.py @@ -58,7 +58,7 @@ def _ensure_venv(cls, venv_path: Path) -> Path: """ if not venv_path.exists(): logger.info(f"Creating virtual environment at {venv_path}") - venv.create(venv_path, with_pip=True, clear=True) + venv.create(venv_path, with_pip=True, clear=True, symlinks=False) # Determine Python executable path based on platform if sys.platform == "win32": diff --git a/src/inference_endpoint/dataset_manager/transforms.py b/src/inference_endpoint/dataset_manager/transforms.py index a2e2e3be..79133796 100644 --- a/src/inference_endpoint/dataset_manager/transforms.py +++ b/src/inference_endpoint/dataset_manager/transforms.py @@ -17,6 +17,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from importlib import import_module from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -24,6 +25,7 @@ import pandas as pd +from ..endpoint_client.config import ADAPTER_MAP from ..openai.harmony import Harmonizer @@ -405,10 +407,6 @@ def get_transforms_for_api_type( Returns: A list of transforms required for the given API type """ - from importlib import import_module - - from inference_endpoint.endpoint_client.config import ADAPTER_MAP - adapter_path = ADAPTER_MAP.get(api_type) if not adapter_path: raise ValueError(f"Invalid or unsupported API type: {api_type}") diff --git a/src/inference_endpoint/endpoint_client/config.py b/src/inference_endpoint/endpoint_client/config.py index 599ca7c4..53f52780 100644 --- a/src/inference_endpoint/endpoint_client/config.py +++ b/src/inference_endpoint/endpoint_client/config.py @@ -77,7 +77,6 @@ class HTTPClientConfig(WithUpdatesMixin, BaseModel): ), ] = Field(-1, ge=-1) - record_worker_events: bool = Field(False, description="Record per-worker events") log_level: str = Field("INFO", description="Worker log level") # Pre-establish TCP connections during init for reuse at runtime. diff --git a/src/inference_endpoint/endpoint_client/http_client.py b/src/inference_endpoint/endpoint_client/http_client.py index 4d158d75..e273f581 100644 --- a/src/inference_endpoint/endpoint_client/http_client.py +++ b/src/inference_endpoint/endpoint_client/http_client.py @@ -68,7 +68,10 @@ def __init__( self.loop = loop assert self.loop is not None - # Initialize on event loop + # Initialize on event loop. 
+ # NOTE: This uses run_coroutine_threadsafe().result() which DEADLOCKS + # if called from the same event loop thread. For shared-loop usage, + # use the async factory: await HTTPEndpointClient.create(config, loop) asyncio.run_coroutine_threadsafe(self._initialize(), self.loop).result() logger.info( @@ -79,6 +82,36 @@ def __init__( f"transport={self.config.transport.type if self.config.transport else 'none'}" ) + @classmethod + async def create( + cls, + config: HTTPClientConfig, + loop: asyncio.AbstractEventLoop, + ) -> "HTTPEndpointClient": + """Async factory for shared-loop usage. + + Use this instead of __init__ when the caller is already running on + the target event loop (e.g., inside run_benchmark_async). The regular + constructor uses run_coroutine_threadsafe().result() which deadlocks + when called from the same loop. + """ + self = cls.__new__(cls) + self.client_id = uuid.uuid4().hex[:8] + self.config = config + self._worker_cycle = cycle(range(config.num_workers)) + self._owns_loop = False + self._loop_name = None + self.loop = loop + await self._initialize() + logger.info( + f"EndpointClient initialized with num_workers={config.num_workers}, " + f"endpoints={config.endpoint_urls}, " + f"adapter={config.adapter.__name__ if config.adapter else 'none'}, " + f"accumulator={config.accumulator.__name__ if config.accumulator else 'none'}, " + f"transport={config.transport.type if config.transport else 'none'}" + ) + return self + async def _initialize(self) -> None: """Initialize worker manager and transports.""" self._shutdown: bool = False @@ -113,11 +146,22 @@ def drain(self) -> list[QueryResult | StreamChunk]: return list(iter(self.poll, None)) def shutdown(self) -> None: - """Gracefully shutdown client. Synchronous — blocks the caller until complete.""" - if self._shutdown: # Already shutdown, no-op + """Gracefully shutdown client. Synchronous — blocks the caller until complete. + + NOTE: This uses run_coroutine_threadsafe().result() which DEADLOCKS + if called from the same event loop thread. For shared-loop usage, + use: await client.shutdown_async() + """ + if self._shutdown: return asyncio.run_coroutine_threadsafe(self._shutdown_async(), self.loop).result() + async def shutdown_async(self) -> None: + """Async shutdown for shared-loop usage. Must be called from the event loop.""" + if self._shutdown: + return + await self._shutdown_async() + async def _shutdown_async(self) -> None: """Async shutdown internals - must be called on the event loop.""" self._shutdown = True diff --git a/src/inference_endpoint/endpoint_client/http_sample_issuer.py b/src/inference_endpoint/endpoint_client/http_sample_issuer.py index c30379ee..38f38923 100644 --- a/src/inference_endpoint/endpoint_client/http_sample_issuer.py +++ b/src/inference_endpoint/endpoint_client/http_sample_issuer.py @@ -13,90 +13,42 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""LoadGenerator integration for HTTPEndpointClient.""" +"""SampleIssuer implementation wrapping HTTPEndpointClient. -import asyncio -import logging +Thin adapter: delegates issue/recv/shutdown to the underlying HTTP client. +The BenchmarkSession owns the response receive loop — this class does NOT +run its own _handle_responses coroutine. 
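+
+Illustrative receive loop as the session might drive it (a sketch, not the
+actual BenchmarkSession code):
+
+    issuer.issue(query)
+    while (resp := await issuer.recv()) is not None:
+        ...  # resp is a QueryResult or StreamChunk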
+""" from inference_endpoint.core.types import Query, QueryResult, StreamChunk from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient -from inference_endpoint.load_generator import SampleIssuer -from inference_endpoint.load_generator.sample import Sample, SampleEventHandler -from inference_endpoint.profiling import profile -logger = logging.getLogger(__name__) +class HttpClientSampleIssuer: + """SampleIssuer wrapping an HTTPEndpointClient. -class HttpClientSampleIssuer(SampleIssuer): - """ - SampleIssuer interface for HTTPEndpointClient. - Routes completed responses to SampleEventHandler. + Satisfies the SampleIssuer protocol from load_generator.session. Usage: - # Create HTTP client and sample issuer - auto-initializes - client = HTTPEndpointClient(config) + client = await HTTPEndpointClient.create(config, loop) issuer = HttpClientSampleIssuer(client) - # Issue samples - issuer.issue(sample) - - # shutdown() is optional - only needed for early exit + issuer.issue(query) # sync ZMQ push + response = await issuer.recv() # async ZMQ recv + issuer.shutdown() # no-op (client shutdown called separately) """ - def __init__( - self, - http_client: HTTPEndpointClient, - ): - super().__init__() + def __init__(self, http_client: HTTPEndpointClient): self.http_client = http_client - # Start response handler task to route completed responses back to SampleEventHandler - self._response_task = asyncio.run_coroutine_threadsafe( - self._handle_responses(), self.http_client.loop - ) - - @profile - async def _handle_responses(self): - """Route completed responses to SampleEventHandler.""" - while True: - try: - # TODO(vir): consider using recv() + drain - match response := await self.http_client.recv(): - case StreamChunk(is_complete=False): - # NOTE(vir): is_complete=True should not be received, QueryResult is expected instead - SampleEventHandler.stream_chunk_complete(response) - - case QueryResult(error=err): - SampleEventHandler.query_result_complete(response) - if err is not None: - logger.error(f"Error in request {response.id}: {err}") - - case None: - # Transport closed or shutdown - break - - case _: - raise ValueError(f"Unexpected response type: {type(response)}") - - except asyncio.CancelledError: - # Handle shutdown signal - break - except Exception as e: - logger.error(f"Error in response handler: {e}", exc_info=True) - continue + def issue(self, query: Query) -> None: + """Issue query to HTTP endpoint. Non-blocking (ZMQ push).""" + self.http_client.issue(query) - @profile - def issue(self, sample: Sample): - """Issue sample to HTTP endpoint.""" - # NOTE(vir): - # If using extra headers (e.g., Authorization), pre-cache them in - # worker.py request-template via HttpRequestTemplate.cache_headers() - # to avoid per-request encoding overhead at runtime. - self.http_client.issue(Query(id=sample.uuid, data=sample.data)) + async def recv(self) -> QueryResult | StreamChunk | None: + """Wait for next response. Returns None when transport is closed.""" + return await self.http_client.recv() - def shutdown(self): - """ - Gracefully shutdown sample issuer. - Will cancel the response-handler task. - """ - self._response_task.cancel() + def shutdown(self) -> None: + """No-op. 
HTTPEndpointClient.shutdown() is called separately by the caller.""" + pass diff --git a/src/inference_endpoint/endpoint_client/worker.py b/src/inference_endpoint/endpoint_client/worker.py index 4f3ebd11..8e0e560e 100644 --- a/src/inference_endpoint/endpoint_client/worker.py +++ b/src/inference_endpoint/endpoint_client/worker.py @@ -23,7 +23,6 @@ import signal import ssl import sys -import time import traceback from collections.abc import AsyncGenerator from typing import Any @@ -46,9 +45,6 @@ InFlightRequest, PooledConnection, ) -from inference_endpoint.load_generator.events import SampleEvent -from inference_endpoint.metrics.recorder import EventRecorder -from inference_endpoint.metrics.reporter import MetricsReporter from inference_endpoint.profiling import profile from inference_endpoint.utils.logging import setup_logging @@ -267,29 +263,8 @@ async def run(self) -> None: f"(need {threshold}). Consider closing background TCP connections." ) - # TODO(vir): - # record_worker_events has high overhead - slows down the worker 100x - # replace with fine-grained metrics, always captured/dumped per worker - # Run main processing loop - if self.http_config.record_worker_events: - pid = os.getpid() - worker_db_name = f"worker_report_{self.worker_id}_{pid}" - assert ( - self.http_config.event_logs_dir is not None - ), "event_logs_dir must be set if record_worker_events is enabled" - report_path = self.http_config.event_logs_dir / f"{worker_db_name}.csv" - - with EventRecorder(session_id=worker_db_name) as event_recorder: - await self._run_main_loop() - event_recorder.wait_for_writes(force_commit=True) - - with MetricsReporter(event_recorder.connection_name) as reporter: - logger.debug(f"About to dump report to {report_path}") - reporter.dump_all_to_csv(report_path) - logger.debug(f"Report dumped to {report_path}") - else: - await self._run_main_loop() + await self._run_main_loop() except Exception as e: logger.error(f"Error: {type(e).__name__}: {str(e)}") @@ -328,14 +303,6 @@ async def _run_main_loop(self) -> None: if query is None: break - if self.http_config.record_worker_events: - EventRecorder.record_event( - SampleEvent.ZMQ_REQUEST_RECEIVED, - time.monotonic_ns(), - sample_uuid=query.id, - assert_active=True, - ) - # Prepare and fire request req = self._prepare_request(query) if not await self._fire_request(req): @@ -439,15 +406,6 @@ async def _process_response(self, req: InFlightRequest) -> None: # Release connection back to pool if not already self._pool.release(conn) - # Record completion event - if self.http_config.record_worker_events: - EventRecorder.record_event( - SampleEvent.HTTP_RESPONSE_COMPLETED, - time.monotonic_ns(), - sample_uuid=req.query_id, - assert_active=True, - ) - # Clean up task reference current_task = asyncio.current_task() if current_task is not None: @@ -467,26 +425,11 @@ async def _handle_streaming_body(self, req: InFlightRequest) -> None: if stream_chunk := accumulator.add_chunk(delta): self._responses.send(stream_chunk) - if self.http_config.record_worker_events: - EventRecorder.record_event( - SampleEvent.ZMQ_RESPONSE_SENT, - time.monotonic_ns(), - sample_uuid=query_id, - assert_active=True, - ) - # Release connection early - done with socket I/O (idempotent) self._pool.release(conn) # Send final complete back to main rank self._responses.send(accumulator.get_final_output()) - if self.http_config.record_worker_events: - EventRecorder.record_event( - SampleEvent.ZMQ_RESPONSE_SENT, - time.monotonic_ns(), - sample_uuid=query_id, - assert_active=True, - ) @profile 
async def _handle_non_streaming_body(self, req: InFlightRequest) -> None: @@ -505,13 +448,6 @@ async def _handle_non_streaming_body(self, req: InFlightRequest) -> None: # Send result back to main rank self._responses.send(result) - if self.http_config.record_worker_events: - EventRecorder.record_event( - SampleEvent.ZMQ_RESPONSE_SENT, - time.monotonic_ns(), - sample_uuid=query_id, - assert_active=True, - ) async def _handle_error(self, query_id: str, error: Exception | str) -> None: """Send error response for a query.""" @@ -532,13 +468,6 @@ async def _handle_error(self, query_id: str, error: Exception | str) -> None: error=error_data, ) self._responses.send(error_response) - if self.http_config.record_worker_events: - EventRecorder.record_event( - SampleEvent.ZMQ_RESPONSE_SENT, - time.monotonic_ns(), - sample_uuid=query_id, - assert_active=True, - ) @profile async def _iter_sse_lines( diff --git a/src/inference_endpoint/evaluation/scoring.py b/src/inference_endpoint/evaluation/scoring.py index f4e24fbc..f5145a75 100644 --- a/src/inference_endpoint/evaluation/scoring.py +++ b/src/inference_endpoint/evaluation/scoring.py @@ -36,9 +36,16 @@ except ImportError: websocket = None +try: + import evaluate as _evaluate + import nltk as _nltk +except ImportError: + _evaluate = None + _nltk = None + +from ..core.record import EventRecord, EventType, SampleEventType from ..dataset_manager.dataset import Dataset from ..dataset_manager.predefined.shopify_product_catalogue import ProductMetadata -from ..load_generator.events import SampleEvent from .extractor import Extractor, PythonCodeExtractor @@ -100,10 +107,6 @@ def __init__( self.dataset = dataset self.report_dir = Path(report_dir) self.extractor = extractor - # If the dataset was transformed with a preset, we still treat it as the original - # dataset name for the purposes of scoring - if "::" in dataset_name: - dataset_name = dataset_name.split("::")[0] self.dataset_name = dataset_name self.ground_truth_column = ( @@ -123,22 +126,30 @@ def _load_sample_index_map(self): return d[self.dataset_name] # Implicitly raises KeyError def get_outputs(self): - # TODO: Currently, the outputs are only saved in the events.jsonl file, which is quite - # large, and only saved optionally. Later, we should move to saving the outputs in a - # separate file for easier compute. + """Read COMPLETE events from events.jsonl and extract response text. + + The EventLoggerService writes EventRecord objects serialized via msgspec. + We decode them using the EventRecord decoder and extract the response + text from TextModelOutput data. 
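+
+        Returns:
+            DataFrame with one row per COMPLETE event and columns
+            ["sample_uuid", "output"].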
+ """ events_log_path = self.report_dir / "events.jsonl" if not events_log_path.exists(): raise FileNotFoundError(f"Events log file not found at {events_log_path}") - outputs = [] + decoder = msgspec.json.Decoder(type=EventRecord, dec_hook=EventType.decode_hook) + outputs: list[dict[str, str]] = [] with events_log_path.open("r") as f: for line in f: - event = msgspec.json.decode(line.strip()) - if event["event_type"] == SampleEvent.COMPLETE.value: - outputs.append(event) - df = pd.DataFrame(outputs, columns=["sample_uuid", "value"]) - df.rename(columns={"value": "output"}, inplace=True) - return df + stripped = line.strip() + if not stripped: + continue + record = decoder.decode(stripped) + if record.event_type == SampleEventType.COMPLETE: + output_text = str(record.data) if record.data is not None else "" + outputs.append( + {"sample_uuid": record.sample_uuid, "output": output_text} + ) + return pd.DataFrame(outputs) def match_sample_index(self, row: pd.Series) -> pd.Series: # Pandas Apply function to create a new 'sample_index' column @@ -226,27 +237,13 @@ class RougeScorer(Scorer, scorer_id="rouge"): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - try: - import importlib.util as _importlib_util - - if ( - _importlib_util.find_spec("evaluate") is None - or _importlib_util.find_spec("nltk") is None - or _importlib_util.find_spec("rouge_score") is None - ): - raise ImportError - - import evaluate - import nltk - - self.metric = evaluate.load("rouge") - self.nltk = nltk - - except ImportError: + if _evaluate is None or _nltk is None: raise ImportError( "nltk, evaluate, and rouge_score are required for ROUGE scoring. " "Install with: pip install nltk evaluate rouge_score" - ) from None + ) + self.metric = _evaluate.load("rouge") + self.nltk = _nltk def postprocess_text(self, texts): texts = [text.strip() for text in texts] diff --git a/src/inference_endpoint/load_generator/__init__.py b/src/inference_endpoint/load_generator/__init__.py index 94b6309f..840fa635 100644 --- a/src/inference_endpoint/load_generator/__init__.py +++ b/src/inference_endpoint/load_generator/__init__.py @@ -13,43 +13,51 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Load Generator for the MLPerf Inference Endpoint Benchmarking System. +"""Async load generator for the MLPerf Inference Endpoint Benchmarking System. -This module handles load pattern generation and query lifecycle management. -Status: To be implemented by the development team. +See docs/load_generator/DESIGN.md for the full design. 
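+
+Typical flow (illustrative sketch; some constructor arguments omitted):
+
+    session = BenchmarkSession(issuer=issuer, event_publisher=publisher, loop=loop)
+    result = await session.run(phases)  # phases: list[PhaseConfig]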
""" -from .events import Event, SampleEvent, SessionEvent -from .load_generator import LoadGenerator, SampleIssuer, SchedulerBasedLoadGenerator -from .sample import IssuedSample, Sample, SampleEventHandler -from .scheduler import ( - ConcurrencyScheduler, - MaxThroughputScheduler, - PoissonDistributionScheduler, +from .delay import make_delay_fn, poisson_delay_fn +from .sample_order import ( SampleOrder, - Scheduler, WithoutReplacementSampleOrder, WithReplacementSampleOrder, + create_sample_order, +) +from .session import ( + BenchmarkSession, + PhaseConfig, + PhaseIssuer, + PhaseResult, + PhaseType, + SessionResult, +) +from .strategy import ( + BurstStrategy, + ConcurrencyStrategy, + LoadStrategy, + TimedIssueStrategy, + create_load_strategy, ) -from .session import BenchmarkSession __all__ = [ - "Event", - "SessionEvent", - "SampleEvent", - "Sample", - "SampleEventHandler", - "IssuedSample", - "Scheduler", - "ConcurrencyScheduler", - "MaxThroughputScheduler", - "PoissonDistributionScheduler", + # New async API + "BenchmarkSession", + "PhaseConfig", + "PhaseType", + "PhaseResult", + "SessionResult", + "PhaseIssuer", + "LoadStrategy", + "TimedIssueStrategy", + "BurstStrategy", + "ConcurrencyStrategy", + "create_load_strategy", "SampleOrder", - "WithReplacementSampleOrder", "WithoutReplacementSampleOrder", - "LoadGenerator", - "SampleIssuer", - "SchedulerBasedLoadGenerator", - "BenchmarkSession", + "WithReplacementSampleOrder", + "create_sample_order", + "make_delay_fn", + "poisson_delay_fn", ] diff --git a/src/inference_endpoint/load_generator/delay.py b/src/inference_endpoint/load_generator/delay.py new file mode 100644 index 00000000..bfa09a25 --- /dev/null +++ b/src/inference_endpoint/load_generator/delay.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Inter-arrival delay functions for timed load strategies. + +Each function returns a callable that produces delay values in nanoseconds. +Used by TimedIssueStrategy for Poisson and other time-based load patterns. +""" + +from __future__ import annotations + +import random +from collections.abc import Callable + +from ..config.schema import LoadPattern, LoadPatternType + + +def poisson_delay_fn(target_qps: float, rng: random.Random) -> Callable[[], int]: + """Create a Poisson-distributed delay function. + + Returns inter-arrival delays following an exponential distribution + (Poisson process). Models realistic client behavior where requests + arrive independently at a target rate. + + How it works: + + ``expovariate(lambd)`` draws from the exponential distribution with rate + ``lambd``. Critically, the return value is in units of ``1 / lambd`` — + NOT in units of ``lambd``. So if ``lambd`` is expressed in + events-per-nanosecond, the return value is in nanoseconds. + + Step by step for target_qps = 50,000: + 1. lambd = 50,000 / 1e9 = 5e-5 events per nanosecond + 2. 
expovariate(5e-5) returns values with mean = 1 / 5e-5 = 20,000 ns + 3. So the average inter-arrival delay is 20,000 ns = 20 us + 4. This matches 50,000 QPS: 1 second / 20 us = 50,000 queries + + The return value is cast to int (nanoseconds). The ``max(1, ...)`` guard + prevents zero-delay at extreme QPS (> 500M), where the mean approaches + 1 ns and the exponential distribution produces sub-1 values ~63% of the + time. In practice, no system can issue > 500M QPS, so the guard is + purely defensive. + + Reference: https://docs.python.org/3/library/random.html#random.Random.expovariate + + Args: + target_qps: Target queries per second. + rng: Seeded random number generator for reproducibility. + + Returns: + Callable returning delay in nanoseconds (int, always >= 1). + """ + if target_qps <= 0: + raise ValueError(f"target_qps must be > 0, got {target_qps}") + lambd = target_qps / 1_000_000_000 # events per nanosecond + return lambda: max(1, int(rng.expovariate(lambd))) + + +def make_delay_fn(load_pattern: LoadPattern, rng: random.Random) -> Callable[[], int]: + """Create a delay function from a LoadPattern config. + + Only used by TimedIssueStrategy. MAX_THROUGHPUT uses BurstStrategy, + CONCURRENCY uses ConcurrencyStrategy — neither needs a delay function. + + Args: + load_pattern: LoadPattern config from schema.py. + rng: Seeded random number generator for reproducibility. + + Returns: + Callable returning delay in nanoseconds. + + Raises: + ValueError: If load pattern type has no delay function. + """ + if load_pattern.type == LoadPatternType.POISSON: + if load_pattern.target_qps is None or load_pattern.target_qps <= 0: + raise ValueError("Poisson load pattern requires target_qps > 0") + return poisson_delay_fn(load_pattern.target_qps, rng) + + raise ValueError( + f"No delay function for load pattern type: {load_pattern.type}. " + f"MAX_THROUGHPUT uses BurstStrategy, CONCURRENCY uses ConcurrencyStrategy." + ) diff --git a/src/inference_endpoint/load_generator/events.py b/src/inference_endpoint/load_generator/events.py deleted file mode 100644 index b3e80374..00000000 --- a/src/inference_endpoint/load_generator/events.py +++ /dev/null @@ -1,40 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from enum import Enum - - -class Event(Enum): - pass - - -class SessionEvent(Event): - TEST_STARTED = "test_started" - TEST_ENDED = "test_ended" - LOADGEN_ISSUE_CALLED = "loadgen_issue_called" - LOADGEN_STOP = "loadgen_stop" - LOADGEN_DATA_LOAD = "loadgen_data_load" - STOP_PERFORMANCE_TRACKING = "stop_performance_tracking" - ERROR = "error" - - -class SampleEvent(Event): - COMPLETE = "complete" - FIRST_CHUNK = "first_chunk_received" - NON_FIRST_CHUNK = "non_first_chunk_received" - HTTP_REQUEST_ISSUED = "http_request_issued" - HTTP_RESPONSE_COMPLETED = "http_response_completed" - ZMQ_REQUEST_RECEIVED = "zmq_request_received" - ZMQ_RESPONSE_SENT = "zmq_response_sent" diff --git a/src/inference_endpoint/load_generator/load_generator.py b/src/inference_endpoint/load_generator/load_generator.py deleted file mode 100644 index 0203d967..00000000 --- a/src/inference_endpoint/load_generator/load_generator.py +++ /dev/null @@ -1,325 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import time -from abc import ABC, abstractmethod -from typing import Any - -from ..dataset_manager.dataset import Dataset -from ..metrics.recorder import EventRecorder -from ..utils import sleep_ns -from .events import SessionEvent -from .sample import IssuedSample, Sample -from .scheduler import Scheduler - - -class SampleIssuer(ABC): - """Abstract base class for components that send samples to inference endpoints. - - SampleIssuers are responsible for the complete workflow of sending a sample - to a System Under Test (SUT): - 1. Ingest a Sample object from the Load Generator - 2. Build the appropriate request format (HTTP, gRPC, etc.) - 3. Send the request to the endpoint - 4. Handle the response asynchronously (results arrive via callbacks) - - Implementations must handle: - - Request formatting (converting Sample.data to endpoint-specific format) - - Network communication (HTTP, gRPC, WebSocket, etc.) - - Error handling (timeouts, connection errors, etc.) - - Response routing (back to metrics collector via events) - - Lifecycle: - 1. start() - Initialize connections, setup resources - 2. issue(sample) - Send samples (called repeatedly during benchmark) - 3. shutdown() - Clean up connections, release resources - - Example implementations: - - HttpClientSampleIssuer: HTTP/REST endpoints (OpenAI-compatible) - - GrpcSampleIssuer: gRPC endpoints (future) - """ - - def start(self): # noqa: B027 - """Initialize resources and establish connections. - - Called once after instantiation to set up any dependency components - like HTTP client pools, authentication, or connection pooling. - - Optional implementation - default does nothing. - - Raises: - SetupError: If initialization fails. - """ - pass - - @abstractmethod - def issue(self, sample: Sample): - """Send a sample to the SUT endpoint. - - This is the core method that sends a single sample/query to the endpoint. 
- It should be non-blocking and return quickly - actual response handling - happens asynchronously via the event system. - - The implementation must: - 1. Convert Sample.data to the endpoint's request format - 2. Send the request (typically async/non-blocking) - 3. Ensure response triggers appropriate events (COMPLETE, STREAM_CHUNK, etc.) - - Args: - sample: Sample object containing request data and metadata. - - Raises: - ExecutionError: If request cannot be sent. - """ - raise NotImplementedError - - def shutdown(self): # noqa: B027 - """Clean up resources and close connections. - - Called once when the issuer is no longer needed. Should gracefully - shutdown connections, flush pending requests, and release resources. - - Optional implementation - default does nothing. - """ - pass - - -class LoadGenerator(ABC): - """Abstract base class for load generation strategies. - - LoadGenerators control WHEN samples are issued to the SUT. They coordinate: - - Sample selection from the dataset (via DataLoader) - - Timing and scheduling (via Scheduler) - - Actual sample issuance (via SampleIssuer) - - Event recording for metrics - - Key responsibilities: - - Load sample data from dataset at the right time - - Apply scheduling/timing delays - - Issue samples via the SampleIssuer - - Record timing events for metrics - - LoadGenerators are iterators - each iteration issues one sample and - returns information about what was issued. - - Attributes: - sample_issuer: Component that sends samples to endpoints. - dataloader: Component that loads sample data from datasets. - """ - - def __init__( - self, - sample_issuer: SampleIssuer, - dataloader: Dataset, - name: str | None = None, - ): - """Initialize load generator with required dependencies. - - Args: - sample_issuer: SampleIssuer to send samples to endpoint. - dataloader: DataLoader to retrieve sample data from dataset. - """ - self.sample_issuer = sample_issuer - self.dataloader = dataloader - self.name = name - self.uuid_to_index_map: dict[str, int] = {} - - @abstractmethod - def __next__(self) -> IssuedSample: - """Issue the next sample according to the load generation strategy. - - This method should: - 1. Determine which sample to issue next - 2. Load the sample data from dataloader - 3. Apply any scheduling delays (blocking) - 4. Issue the sample via sample_issuer - 5. Return the sample and timestamp - - Note: This method MAY block to implement delays/scheduling. - It should only return AFTER the sample has been issued. - - Returns: - IssuedSample object containing the sample, index, and issue timestamp. - - Raises: - StopIteration: When all samples have been issued. - """ - raise NotImplementedError - - def __iter__(self): - """Return self as an iterator.""" - self.uuid_to_index_map = {} - return self - - def load_sample_data(self, sample_index: int, sample_uuid: str) -> Any: - """Load sample data from dataloader and record event. - - Helper method that loads sample data and records the data load event - for accurate timing measurements. - - Args: - sample_index: Index of sample in dataset. - sample_uuid: UUID of the sample being created. - - Returns: - Sample data loaded from dataloader (format depends on dataset). - """ - sample_data = self.dataloader.load_sample(sample_index) - EventRecorder.record_event( - SessionEvent.LOADGEN_DATA_LOAD, - time.monotonic_ns(), - sample_uuid=sample_uuid, - ) - return sample_data - - def issue_sample(self, sample: Sample) -> int: - """Issue a sample via the SampleIssuer and record timing event. 
- - Helper method that: - 1. Records the current timestamp - 2. Records LOADGEN_ISSUE_CALLED event for metrics - 3. Invokes sample_issuer.issue(sample) - 4. Returns the timestamp - - The timestamp is recorded BEFORE issuing to ensure accurate timing - even if the issue() call is slow or triggers immediate callbacks. - - Args: - sample: Sample to issue to the endpoint. - - Returns: - Monotonic nanosecond timestamp when issue was called. - """ - timestamp_ns = time.monotonic_ns() - - # Currently, EventRecorder will raise an Exception if the in-flight sample - # counter is negative. This happens if the SampleIssuer somehow invokes a - # SampleEvent.COMPLETE event before the record_event call for LOADGEN_ISSUE_CALLED - # goes off. - # This can be solved by just recording the issue() call right before actually - # invoking it. If this timing mechanism is a problem, we can remove the - # negative check in EventRecorder, since the order of insertions doesn't matter - # as much if the timestamps are correct. - EventRecorder.record_event( - SessionEvent.LOADGEN_ISSUE_CALLED, - timestamp_ns, - sample_uuid=sample.uuid, - ) - logging.debug(f"Issuing sample {sample.uuid} at {timestamp_ns}") - self.sample_issuer.issue(sample) - return timestamp_ns - - -class SchedulerBasedLoadGenerator(LoadGenerator): - """LoadGenerator that uses a Scheduler to control sample timing. - - This is the primary LoadGenerator implementation, delegating timing decisions - to a pluggable Scheduler. It handles: - - Sample ordering (via scheduler's sample_order) - - Timing delays (via scheduler's delay_fn) - - Sample loading and issuance - - Timing measurements - - The scheduler determines: - - Which sample to issue next (sample index) - - How long to wait before issuing (delay in nanoseconds) - - This enables different load patterns (Poisson, max throughput, burst, etc.) - without changing the LoadGenerator code. - - Attributes: - scheduler: Scheduler controlling sample timing. - _iterator: Iterator over scheduler (sample_index, delay) pairs. - last_issue_timestamp_ns: Timestamp of last issued sample (for delay calculation). - """ - - def __init__( - self, - sample_issuer: SampleIssuer, - dataloader: Dataset, - scheduler: Scheduler, - ): - """Initialize scheduler-based load generator. - - Args: - sample_issuer: SampleIssuer to send samples to endpoint. - dataloader: DataLoader to retrieve sample data. - scheduler: Scheduler controlling timing and sample order. - """ - super().__init__(sample_issuer, dataloader) - - self.scheduler = scheduler - self._iterator = None - self.last_issue_timestamp_ns = 0 - self._start_time_ns: int | None = None - - def __next__(self) -> IssuedSample: - """Issue next sample according to scheduler timing. - - This method: - 1. Gets next (sample_index, delay_ns) from scheduler - 2. Loads sample data from dataloader - 3. Waits for scheduled time (busy-wait for precision) - 4. Issues sample via sample_issuer - 5. Returns IssuedSample with timing info - - The busy-wait ensures precise timing even for high QPS scenarios - where sleep() precision would be insufficient. - - Returns: - IssuedSample containing sample, index, and actual issue timestamp. - - Raises: - StopIteration: When scheduler has no more samples to issue. - """ - # Check wall-clock timeout before advancing the iterator, so we don't - # consume a (sample_index, delay) pair that will never be issued. 
- max_duration_ms = self.scheduler.runtime_settings.max_duration_ms - if max_duration_ms is not None and self._start_time_ns is not None: - elapsed_ns = time.monotonic_ns() - self._start_time_ns - if elapsed_ns >= max_duration_ms * 1_000_000: - logging.info( - f"max_duration_ms={max_duration_ms}ms reached after " - f"{elapsed_ns / 1e6:.1f}ms, stopping sample issuance" - ) - raise StopIteration - - # Let raised StopIteration be propagated up the stack - # Ignore mypy error complaining that self._iterator maybe None - s_idx, delay_ns = next(self._iterator) # type: ignore[call-overload] - - # Data loading is not timed for Time-to-Token metrics. It is assumed that the - # hypothetical user would have put the data into memory available for a network - # request beforehand. - sample = Sample(None) # Create sample object first to generate uuid - sample.data = self.load_sample_data(s_idx, sample.uuid) - - self.uuid_to_index_map[sample.uuid] = s_idx - - scheduled_issue_timestamp_ns = self.last_issue_timestamp_ns + delay_ns - while (now := time.monotonic_ns()) < scheduled_issue_timestamp_ns: - sleep_ns(scheduled_issue_timestamp_ns - now) - self.last_issue_timestamp_ns = self.issue_sample(sample) - return IssuedSample(sample, s_idx, self.last_issue_timestamp_ns) - - def __iter__(self): - if self._iterator is not None: - raise RuntimeError( - "SchedulerBasedLoadGenerator can only be iterated over once" - ) - self._start_time_ns = time.monotonic_ns() - self._iterator = iter(self.scheduler) - return super().__iter__() diff --git a/src/inference_endpoint/load_generator/sample.py b/src/inference_endpoint/load_generator/sample.py deleted file mode 100644 index 2b9c371a..00000000 --- a/src/inference_endpoint/load_generator/sample.py +++ /dev/null @@ -1,217 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import time -import uuid -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any - -from ..core.types import QueryResult, StreamChunk -from ..metrics.recorder import EventRecorder, record_exception -from .events import SampleEvent - -logger = logging.getLogger(__name__) - - -class Sample: - """Represents a sample/query to be sent to an inference endpoint. - - A Sample encapsulates the request data and provides a unique identifier for - tracking through the benchmark lifecycle. It enforces immutability to prevent - accidental modification during benchmarking. - - Immutability rules: - - UUID is immutable once set (on creation) - - Data can be set once from None to a value, then immutable - - This allows delayed data loading while maintaining safety - - Memory optimization: - - Uses __slots__ to reduce memory overhead - - UUID as hex string (32 chars) instead of UUID object - - Attributes: - uuid: Unique hex string identifier for this sample (32 characters). 
- data: Request payload (dict, typically with prompt/model/params). - Can be None initially and set once. - - Example: - >>> sample = Sample({"prompt": "Hello", "model": "gpt-4"}) - >>> sample.uuid # '8f3d2a1b...' (32 char hex) - >>> sample.data["prompt"] # 'Hello' - """ - - __slots__ = ["uuid", "data"] - - def __init__(self, data: Any): - """Initialize sample with data and generate unique ID. - - Args: - data: Request data to send to endpoint. Can be None if data - will be loaded later, but can only be set once. - """ - # 128-bit UUID might be a little overkill for our use case, we can investigate slimming down memory usage - self.uuid = uuid.uuid4().hex - self.data = data - - def __setattr__(self, name: str, value: Any): - if not hasattr(self, name) or (name == "data" and self.data is None): - object.__setattr__(self, name, value) - else: - raise AttributeError(f"Sample is immutable - cannot set attribute: {name}") - - -class _SampleEventHandler: - """Contains handlers for SampleEvents given a sample UUID. This is also to avoid needing other classes - to do their own bookkeeping for Sample objects, which can be discarded once they are issued, as long as - their UUIDs are saved. - - This class is a singleton rather than a class method mainly because it needs to hold some state (i.e. hooks) - - A user can register hooks to any event type, and will be run in the order they were registered. - A valid hook is a callable that takes a single argument, representing the response object (StreamChunk or QueryResult). - - A simple example use-case of a hook is to update a progress bar on-completion of a sample. - - NOTE: Hook lists are not thread-safe. Hooks must be registered before the benchmark - starts (single-threaded setup phase). This is a known limitation; _SampleEventHandler - is being deprecated in favor of the pub-sub EventLoggerService. - """ - - __slots__ = ["first_chunk_hooks", "non_first_chunk_hooks", "complete_hooks"] - - SINGLETON = None - _initialized = False - - def __new__(cls): - if cls.SINGLETON is None: - cls.SINGLETON = super().__new__(cls) - return cls.SINGLETON - - def __init__(self): - if _SampleEventHandler._initialized: - return - _SampleEventHandler._initialized = True - - self.first_chunk_hooks = [] - self.non_first_chunk_hooks = [] - self.complete_hooks = [] - - def register_hook( - self, - event_type: SampleEvent, - hook: Callable[[StreamChunk], None] | Callable[[QueryResult], None], - ) -> None: - if event_type == SampleEvent.FIRST_CHUNK: - self.first_chunk_hooks.append(hook) - elif event_type == SampleEvent.NON_FIRST_CHUNK: - self.non_first_chunk_hooks.append(hook) - elif event_type == SampleEvent.COMPLETE: - self.complete_hooks.append(hook) - else: - raise ValueError(f"Invalid event type: {event_type}") - - def clear_hooks(self, ev_type: SampleEvent | None = None) -> None: - if ev_type is None: - self.first_chunk_hooks.clear() - self.non_first_chunk_hooks.clear() - self.complete_hooks.clear() - elif ev_type == SampleEvent.FIRST_CHUNK: - self.first_chunk_hooks.clear() - elif ev_type == SampleEvent.NON_FIRST_CHUNK: - self.non_first_chunk_hooks.clear() - elif ev_type == SampleEvent.COMPLETE: - self.complete_hooks.clear() - - def stream_chunk_complete(self, chunk: StreamChunk) -> None: - """Handle completion of a streaming chunk. - - Called when a chunk arrives from a streaming response. Records timing - event and invokes registered hooks for first/non-first chunks. - - Args: - chunk: StreamChunk containing response data and metadata. 
- """ - timestamp_ns = time.monotonic_ns() - - assert isinstance(chunk, StreamChunk), f"Invalid chunk type: {type(chunk)}" - - hooks = [] - if chunk.metadata.get("first_chunk", False): - EventRecorder.record_event( - SampleEvent.FIRST_CHUNK, - timestamp_ns, - sample_uuid=chunk.id, - data=chunk.response_chunk, - ) - hooks = self.first_chunk_hooks - else: - EventRecorder.record_event( - SampleEvent.NON_FIRST_CHUNK, - timestamp_ns, - sample_uuid=chunk.id, - ) - hooks = self.non_first_chunk_hooks - - for hook in hooks: - hook(chunk) - - def query_result_complete(self, result: QueryResult) -> None: - """Handle completion of a query (success or failure). - - Called when a query finishes (with response or error). Records timing - event and invokes registered completion hooks. - - Args: - result: QueryResult containing response data or error information. - """ - timestamp_ns = time.monotonic_ns() - - assert isinstance(result, QueryResult), f"Invalid result type: {type(result)}" - - # Even if there is an error, we still record the event to count the sample as complete - if result.error is not None: - err_str = str(result.error) - logger.error(f"Error in request {result.id}: {err_str}") - - record_exception(err_str, result.id) - - EventRecorder.record_event( - SampleEvent.COMPLETE, - timestamp_ns, - sample_uuid=result.id, - data=result.response_output, - ) - - for hook in self.complete_hooks: - hook(result) - - -@dataclass -class IssuedSample: - """Contains data about a sample that has been issued to the inference endpoint. - - SampleIssuer is not allowed to know the actual sample index of the data to prevent cheating - and response caching. This class contains metadata about the sample for bookkeeping by the - LoadGenerator and BenchmarkSession. - """ - - sample: Sample - index: int - issue_timestamp_ns: int - - -SampleEventHandler = _SampleEventHandler() diff --git a/src/inference_endpoint/load_generator/sample_order.py b/src/inference_endpoint/load_generator/sample_order.py new file mode 100644 index 00000000..a08156c1 --- /dev/null +++ b/src/inference_endpoint/load_generator/sample_order.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sample ordering strategies for benchmark dataset traversal. + +SampleOrder is an infinite iterator yielding dataset indices. It never raises +StopIteration — termination is controlled by BenchmarkSession._make_stop_check(). +""" + +from __future__ import annotations + +import random +from abc import ABC, abstractmethod + +from ..config.runtime_settings import RuntimeSettings + + +class SampleOrder(ABC): + """Abstract base class for sample ordering strategies. + + Yields dataset sample indices indefinitely. Different strategies enable + different testing scenarios (balanced coverage vs random sampling). + + Attributes: + n_samples_in_dataset: Number of unique samples available in dataset. 
+ rng: Random number generator for reproducible randomness. + """ + + def __init__( + self, + n_samples_in_dataset: int, + rng: random.Random = random, # type: ignore[assignment] + ): + if n_samples_in_dataset <= 0: + raise ValueError( + f"n_samples_in_dataset must be > 0, got {n_samples_in_dataset}" + ) + self.n_samples_in_dataset = n_samples_in_dataset + self.rng = rng + + def __iter__(self): + return self + + def __next__(self) -> int: + return self.next_sample_index() + + @abstractmethod + def next_sample_index(self) -> int: + """Get the next sample index to issue. + + Returns: + Sample index (0 to n_samples_in_dataset-1). + """ + raise NotImplementedError + + +class WithoutReplacementSampleOrder(SampleOrder): + """Shuffle dataset, use all samples before repeating. + + Ensures balanced coverage: shuffles all dataset indices, issues them one + by one until exhausted, then reshuffles and repeats (infinite cycle). + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.index_order = list(range(self.n_samples_in_dataset)) + # Force initial shuffle on first call + self._curr_idx = self.n_samples_in_dataset + 1 + + def _reset(self): + self.rng.shuffle(self.index_order) + self._curr_idx = 0 + + def next_sample_index(self) -> int: + if self._curr_idx >= len(self.index_order): + self._reset() + retval = self.index_order[self._curr_idx] + self._curr_idx += 1 + return retval + + +class WithReplacementSampleOrder(SampleOrder): + """Truly random sampling from dataset with replacement. + + Each sample is chosen uniformly at random, independent of previous choices. + """ + + def next_sample_index(self) -> int: + return self.rng.randint(0, self.n_samples_in_dataset - 1) + + +def create_sample_order(settings: RuntimeSettings) -> SampleOrder: + """Create a SampleOrder from RuntimeSettings.""" + return WithoutReplacementSampleOrder( + n_samples_in_dataset=settings.n_samples_from_dataset, + rng=settings.rng_sample_index, + ) diff --git a/src/inference_endpoint/load_generator/scheduler.py b/src/inference_endpoint/load_generator/scheduler.py deleted file mode 100644 index ae691d09..00000000 --- a/src/inference_endpoint/load_generator/scheduler.py +++ /dev/null @@ -1,420 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import random -import threading -from abc import ABC, abstractmethod -from collections.abc import Callable, Iterator - -from ..config.runtime_settings import RuntimeSettings -from ..config.schema import LoadPatternType -from .sample import SampleEvent, SampleEventHandler - - -class SampleOrder(ABC): - """Abstract base class for sample ordering strategies. - - SampleOrder determines which dataset sample to use next when issuing queries. - Different strategies enable different testing scenarios: - - The SampleOrder is an iterator that yields sample indices from the dataset. 
- It handles wrapping around when total_samples_to_issue > dataset size. - - Attributes: - total_samples_to_issue: Total number of samples to issue during benchmark. - n_samples_in_dataset: Number of unique samples available in dataset. - rng: Random number generator for reproducible randomness. - _issued_samples: Counter of samples issued so far. - """ - - def __init__( - self, total_samples_to_issue: int, n_samples_in_dataset: int, rng=random - ): - """Initialize sample ordering strategy. - - Args: - total_samples_to_issue: The total number of samples to issue. - May be larger than n_samples_in_dataset. - n_samples_in_dataset: The number of unique samples in the dataset. - rng: Random number generator (for reproducibility via seeding). - """ - self.total_samples_to_issue = total_samples_to_issue - self.n_samples_in_dataset = n_samples_in_dataset - self.rng = rng - - self._issued_samples = 0 - - def __iter__(self) -> Iterator[int]: - """Iterate over sample indices to issue. - - Yields sample indices until total_samples_to_issue is reached. - - Yields: - Sample index (0 to n_samples_in_dataset-1). - """ - while self._issued_samples < self.total_samples_to_issue: - yield self.next_sample_index() - self._issued_samples += 1 - - @abstractmethod - def next_sample_index(self) -> int: - """Get the next sample index to issue. - - Returns: - Sample index (0 to n_samples_in_dataset-1). - """ - raise NotImplementedError - - -class WithoutReplacementSampleOrder(SampleOrder): - """Sample ordering without replacement - shuffle dataset, use all samples before repeating. - - This strategy ensures balanced coverage of the dataset: - 1. Shuffles all dataset indices randomly - 2. Issues them one by one until exhausted - 3. Reshuffles and repeats if more samples needed - - Use this for: - - Fair benchmarking (all samples used equally) - - Avoiding bias from repeated samples - - Deterministic results with seed control - - Example with 3-sample dataset, 7 samples to issue: - - Shuffle: [2, 0, 1] - - Issue: 2, 0, 1 (first pass) - - Reshuffle: [1, 2, 0] - - Issue: 1, 2, 0, 1 (second pass, partial) - - Attributes: - index_order: Current shuffled order of indices. - _curr_idx: Position in current shuffle (resets after each complete pass). - """ - - def __init__(self, *args, **kwargs): - """Initialize without-replacement sample ordering. - - Args: - *args: Forwarded to SampleOrder.__init__. - **kwargs: Forwarded to SampleOrder.__init__. - """ - super().__init__(*args, **kwargs) - self.index_order = list(range(self.n_samples_in_dataset)) - self._curr_idx = ( - self.n_samples_in_dataset + 1 - ) # Ensure we start at an invalid index to force shuffle - - def _reset(self): - """Shuffle indices and reset position for next pass.""" - self.rng.shuffle(self.index_order) - self._curr_idx = 0 - - def next_sample_index(self) -> int: - """Get next sample index from current shuffle, reshuffling if needed. - - Returns: - Sample index from dataset. - """ - if self._curr_idx >= len(self.index_order): - self._reset() - retval = self.index_order[self._curr_idx] - self._curr_idx += 1 - return retval - - -class WithReplacementSampleOrder(SampleOrder): - """Sample ordering with replacement - truly random sampling from dataset. - - Each sample is chosen uniformly at random from the entire dataset, - independent of previous choices. The same sample can (and will) appear - multiple times, even consecutively. 
- - Use this for: - - Stress testing with realistic randomness - - Simulating unpredictable user behavior - - When dataset coverage balance is not important - - Example with 3-sample dataset, 7 samples to issue: - - Might produce: [1, 1, 0, 2, 1, 0, 0] - - Note repeated samples even without exhausting dataset - """ - - def __init__(self, *args, **kwargs): - """Initialize with-replacement sample ordering. - - Args: - *args: Forwarded to SampleOrder.__init__. - **kwargs: Forwarded to SampleOrder.__init__. - """ - super().__init__(*args, **kwargs) - - def next_sample_index(self) -> int: - """Get random sample index from dataset. - - Returns: - Random sample index (uniform distribution over dataset). - """ - return self.rng.randint(0, self.n_samples_in_dataset - 1) - - -def uniform_delay_fn( - max_delay_ns: int = 0, rng: random.Random | None = None -) -> Callable[[], float]: - """Create a uniform delay function for schedulers. - - Returns a function that generates delays uniformly distributed between - 0 and max_delay_ns. Used for max throughput (max_delay_ns=0) or uniform - load distribution. - - Args: - max_delay_ns: Maximum delay in nanoseconds. If 0, always returns 0 (no delay). - rng: Random number generator for reproducibility. - - Returns: - Function that returns delay in nanoseconds (float). - """ - rng = rng or random.Random() - - def _fn(): - if max_delay_ns == 0: - return 0 - return rng.uniform(0, max_delay_ns) - - return _fn - - -def poisson_delay_fn( - expected_queries_per_second: float, rng: random.Random | None = None -) -> Callable[[], float]: - """Create a Poisson-distributed delay function for realistic online benchmarking. - - Returns a function that generates delays following an exponential distribution - (inter-arrival times of a Poisson process). This models realistic user/client - behavior where requests arrive independently at a target rate. - - The exponential distribution has the property that: - - Mean inter-arrival time = 1 / expected_qps - - Variance = mean^2 (high variability, realistic for network traffic) - - Args: - expected_queries_per_second: Target QPS (queries per second). - rng: Random number generator for reproducibility. - - Returns: - Function that returns delay in nanoseconds (float). - """ - rng = rng or random.Random() - queries_per_ns = expected_queries_per_second / 1e9 - - def _fn(): - if queries_per_ns == 0: - return 0 - return rng.expovariate(lambd=queries_per_ns) # lambd=1/mean, where mean=latency - - return _fn - - -class Scheduler: - """Base class for query scheduling strategies that control benchmark load patterns. - - Schedulers determine: - 1. Sample ordering (which sample to use next) - 2. Timing delays (when to issue the next query) - - They combine a SampleOrder (what to issue) with a delay function (when to issue) - to produce a stream of (sample_index, delay_ns) pairs. - - Scheduler implementations auto-register via __init_subclass__ by specifying - the load_pattern parameter. This enables runtime selection of schedulers: - - scheduler_cls = Scheduler.get_implementation(LoadPatternType.POISSON) - scheduler = scheduler_cls(runtime_settings, sample_order_cls) - - Built-in schedulers: - - MaxThroughputScheduler: Issues all queries immediately (offline mode) - - PoissonDistributionScheduler: Poisson-distributed delays (online mode) - - ConcurrencyScheduler: Fixed concurrency level (online mode) - - Attributes: - _IMPL_MAP: Class-level registry mapping LoadPatternType to Scheduler classes. 
- runtime_settings: Runtime configuration (QPS, duration, seeds, etc.). - total_samples_to_issue: Total queries to issue during benchmark. - n_unique_samples: Number of unique samples in dataset. - sample_order: Iterator over sample indices to use. - delay_fn: Function returning delay before next query (nanoseconds). - """ - - # Registry for scheduler implementations (populated via __init_subclass__) - _IMPL_MAP: dict[LoadPatternType, type["Scheduler"]] = {} - - def __init__( - self, - runtime_settings: RuntimeSettings, - sample_order_cls: type[SampleOrder], - ): - """Initialize scheduler with runtime settings and sample ordering strategy. - - Args: - runtime_settings: Runtime configuration containing QPS, duration, seeds. - sample_order_cls: SampleOrder class to use for sample selection. - """ - self.runtime_settings = runtime_settings - - self.total_samples_to_issue = runtime_settings.total_samples_to_issue() - self.n_unique_samples = runtime_settings.n_samples_from_dataset - self.sample_order = iter( - sample_order_cls( - self.total_samples_to_issue, - self.n_unique_samples, - rng=self.runtime_settings.rng_sample_index, - ) - ) - self.delay_fn: Callable[[], int] | None = None # Subclasses must set this - - def __iter__(self): - """Iterate over (sample_index, delay_ns) pairs. - - Yields: - Tuple of (sample_index, delay_ns): - - sample_index: Index of sample to issue next - - delay_ns: Nanoseconds to wait before issuing - """ - for s_idx in self.sample_order: - yield s_idx, self.delay_fn() - - def __init_subclass__(cls, load_pattern: LoadPatternType | None = None, **kwargs): - """Auto-register scheduler implementations. - - Args: - load_pattern: LoadPatternType to bind this scheduler to - - Raises: - ValueError: If load_pattern already registered - """ - super().__init_subclass__(**kwargs) - - if load_pattern is not None: - if load_pattern in Scheduler._IMPL_MAP: - raise ValueError( - f"Cannot bind {cls.__name__} to {load_pattern} - " - f"Already bound to {Scheduler._IMPL_MAP[load_pattern].__name__}" - ) - Scheduler._IMPL_MAP[load_pattern] = cls - - @classmethod - def get_implementation(cls, load_pattern: LoadPatternType) -> type["Scheduler"]: - """Get scheduler implementation for load pattern. - - Args: - load_pattern: LoadPatternType enum - - Returns: - Scheduler subclass - - Raises: - NotImplementedError: If no implementation registered - KeyError: If load_pattern invalid - """ - if load_pattern not in cls._IMPL_MAP: - available_str = ", ".join(p.value for p in cls._IMPL_MAP.keys()) - raise KeyError( - f"No scheduler registered for '{load_pattern.value}'. " - f"Available: {available_str}" - ) - return cls._IMPL_MAP[load_pattern] - - -class MaxThroughputScheduler(Scheduler, load_pattern=LoadPatternType.MAX_THROUGHPUT): - """Offline max throughput scheduler (all queries at t=0). - - Auto-registers for LoadPatternType.MAX_THROUGHPUT. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.delay_fn = uniform_delay_fn(rng=self.runtime_settings.rng_sched) - - -class PoissonDistributionScheduler(Scheduler, load_pattern=LoadPatternType.POISSON): - """Poisson-distributed query scheduler for online benchmarking. - - Simulates realistic client-server network usage by using a Poisson process - to issue queries. The delay between each sample is sampled from an exponential - distribution, centered around the expected latency based on target QPS. - - Use this scheduler for online latency testing with sustained QPS. - - Auto-registers for LoadPatternType.POISSON. 
- """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.delay_fn = poisson_delay_fn( - expected_queries_per_second=self.runtime_settings.metric_target.target, - rng=self.runtime_settings.rng_sched, - ) - - -class ConcurrencyScheduler(Scheduler, load_pattern=LoadPatternType.CONCURRENCY): - """Concurrency-based scheduler that maintains fixed concurrent requests. - - Issues queries based on COMPLETION events rather than time delays. - Maintains target concurrency level (e.g., always 32 requests in-flight). - - Auto-registers for LoadPatternType.CONCURRENCY. - """ - - def __init__(self, runtime_settings: RuntimeSettings, sample_order_cls): - super().__init__(runtime_settings, sample_order_cls) - assert runtime_settings.load_pattern is not None - target_concurrency = runtime_settings.load_pattern.target_concurrency - if target_concurrency is None or target_concurrency <= 0: - raise ValueError( - f"target_concurrency must be > 0 for CONCURRENCY load pattern, got {target_concurrency}" - ) - - # Use threading.Condition for concurrency control with explicit counter - self._condition = threading.Condition() - self._inflight = 0 - self._target_concurrency = target_concurrency - - # Register completion hook - free up slot when query completes - SampleEventHandler.register_hook(SampleEvent.COMPLETE, self._release_slot) - - # Unused (required by Scheduler interface) - returns 0 delay - self.delay_fn = lambda: 0 - - def _release_slot(self, result=None): - """Release a concurrency slot and notify waiting threads. - - Args: - result: QueryResult from completed query (unused, required by hook signature) - """ - with self._condition: - self._inflight -= 1 - self._condition.notify() - - def __iter__(self): - """ - Iterate over sample indices to issue. - Yields sample indices until total_samples_to_issue is reached. - - Waits for available concurrency slot before yielding each sample index. - """ - for s_idx in self.sample_order: - with self._condition: - while self._inflight >= self._target_concurrency: - self._condition.wait() - self._inflight += 1 - yield s_idx, 0 diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index da5faf84..1c8ad992 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -13,327 +13,479 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Async benchmark session: orchestrates phases, issues samples, receives responses. + +See docs/load_generator/DESIGN.md for the full design. 
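+
+Typical wiring (illustrative sketch only; the ``client``, ``publisher``, and
+``phases`` objects are assumptions here, built by the caller from
+HTTPEndpointClient, a ZMQ-backed event publisher, and PhaseConfig entries):
+
+    loop = asyncio.new_event_loop()
+    session = BenchmarkSession(issuer=client, event_publisher=publisher, loop=loop)
+    result = loop.run_until_complete(session.run(phases))
+    print(result.perf_results[0].issued_count)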
+""" + from __future__ import annotations +import asyncio import logging import os -import threading import time import uuid -from pathlib import Path - -import msgspec.json -from transformers import AutoTokenizer +from collections.abc import Callable +from dataclasses import dataclass +from enum import Enum +from typing import Protocol from ..config.runtime_settings import RuntimeSettings +from ..core.record import ( + ErrorEventType, + EventRecord, + SampleEventType, + SessionEventType, +) +from ..core.types import PromptData, Query, QueryResult, StreamChunk from ..dataset_manager.dataset import Dataset -from ..metrics.recorder import EventRecorder -from ..metrics.reporter import MetricsReporter -from ..utils.version import get_version_info -from .events import SessionEvent -from .load_generator import LoadGenerator, SampleIssuer, SchedulerBasedLoadGenerator -from .scheduler import Scheduler, WithoutReplacementSampleOrder +from .sample_order import create_sample_order +from .strategy import LoadStrategy, create_load_strategy logger = logging.getLogger(__name__) -# poll interval for checking if test-session should end -SHUTDOWN_POLL_INTERVAL_S = 10.0 +_WARMUP_ENABLED = os.environ.get("ENABLE_WARMUP") == "1" -class BenchmarkSession: +# --------------------------------------------------------------------------- +# Phase configuration +# --------------------------------------------------------------------------- + + +class PhaseType(str, Enum): + """Phase types control tracking and reporting behavior.""" + + PERFORMANCE = "performance" + ACCURACY = "accuracy" + WARMUP = "warmup" + + +@dataclass(frozen=True, slots=True) +class PhaseConfig: + """Configuration for a single benchmark phase.""" + + name: str + runtime_settings: RuntimeSettings + dataset: Dataset + phase_type: PhaseType = PhaseType.PERFORMANCE + + +# --------------------------------------------------------------------------- +# Results +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class PhaseResult: + """Result of a single benchmark phase.""" + + name: str + phase_type: PhaseType + uuid_to_index: dict[str, int] + issued_count: int + start_time_ns: int + end_time_ns: int + + +@dataclass(frozen=True) +class SessionResult: + """Combined results from all phases in a session.""" + + session_id: str + phase_results: list[PhaseResult] + start_time_ns: int + end_time_ns: int + + @property + def perf_results(self) -> list[PhaseResult]: + return [r for r in self.phase_results if r.phase_type == PhaseType.PERFORMANCE] + + @property + def accuracy_results(self) -> list[PhaseResult]: + return [r for r in self.phase_results if r.phase_type == PhaseType.ACCURACY] + + +# --------------------------------------------------------------------------- +# SampleIssuer protocol +# --------------------------------------------------------------------------- + + +class SampleIssuer(Protocol): + """Sends queries to an endpoint and receives responses. + + Matches HTTPEndpointClient's interface: issue (sync ZMQ push), + recv (async ZMQ recv), shutdown. + """ + + def issue(self, query: Query) -> None: ... + async def recv(self) -> QueryResult | StreamChunk | None: ... + def shutdown(self) -> None: ... 
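+
+# Because SampleIssuer is a structural Protocol, tests can duck-type it without
+# subclassing. An illustrative sketch of a fake issuer (the immediate-echo
+# behavior and the _completed_result_for() helper are assumptions for the
+# sketch, not part of the contract):
+#
+#     class FakeIssuer:
+#         def __init__(self) -> None:
+#             self._results: asyncio.Queue = asyncio.Queue()
+#
+#         def issue(self, query: Query) -> None:
+#             # Echo each query back as an already-completed result.
+#             self._results.put_nowait(_completed_result_for(query))
+#
+#         async def recv(self) -> QueryResult | StreamChunk | None:
+#             return await self._results.get()
+#
+#         def shutdown(self) -> None:
+#             pass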
+ + +# --------------------------------------------------------------------------- +# EventRecordPublisher protocol +# --------------------------------------------------------------------------- + + +class EventPublisher(Protocol): + """Publishes EventRecords to the metrics pipeline.""" + + def publish(self, event_record: EventRecord) -> None: ... + def flush(self) -> None: ... + + +# --------------------------------------------------------------------------- +# PhaseIssuer +# --------------------------------------------------------------------------- + + +class PhaseIssuer: + """Per-phase state holder that wraps the issue logic. + + Created fresh for each phase. Holds the phase-scoped uuid_to_index map, + inflight counter, and issued count. Strategies call issue(sample_index) + to load data, build a Query, publish ISSUED, and send to the endpoint. + """ + + __slots__ = ( + "_dataset", + "_issuer", + "_publisher", + "_stop_check", + "uuid_to_index", + "inflight", + "issued_count", + ) + def __init__( self, - runtime_settings: RuntimeSettings, - session_id: str | None = None, + dataset: Dataset, + issuer: SampleIssuer, + publisher: EventPublisher, + stop_check: Callable[[], bool], ): - self.logger = logging.getLogger(__name__) - self.runtime_settings = runtime_settings - self.session_id = session_id if session_id else uuid.uuid4().hex + self._dataset = dataset + self._issuer = issuer + self._publisher = publisher + self._stop_check = stop_check + self.uuid_to_index: dict[str, int] = {} + self.inflight: int = 0 + self.issued_count: int = 0 + + def issue(self, sample_index: int) -> str | None: + """Load data, build Query, publish ISSUED, send to endpoint. + + Returns query_id on success, None if session is stopping. + + Note: load_sample() runs synchronously before the ISSUED timestamp. + For accurate timing, datasets MUST be pre-loaded into memory. + Disk-backed datasets will inflate timing and delay subsequent issues. + """ + if self._stop_check(): + return None + query_id = uuid.uuid4().hex + data = self._dataset.load_sample(sample_index) + query = Query(id=query_id, data=data) + self.uuid_to_index[query_id] = sample_index + ts = time.monotonic_ns() + prompt_data: PromptData + if isinstance(data, dict): + token_ids = data.get("input_tokens") or data.get("token_ids") + prompt_data = PromptData( + text=data.get("prompt"), + token_ids=tuple(token_ids) if token_ids is not None else None, + ) + else: + prompt_data = PromptData() + self._publisher.publish( + EventRecord( + event_type=SampleEventType.ISSUED, + timestamp_ns=ts, + sample_uuid=query_id, + data=prompt_data, + ) + ) + self._issuer.issue(query) + self.inflight += 1 + self.issued_count += 1 + return query_id - # EventRecorder will set this when all samples complete, helps avoid busy-waiting - self.end_event = threading.Event() - self.thread: threading.Thread | None = None - # CPython GIL provides atomic boolean writes, no need for threading.Event() - self.stop_requested = False +# --------------------------------------------------------------------------- +# BenchmarkSession +# --------------------------------------------------------------------------- - self.event_recorder = EventRecorder( - session_id=self.session_id, notify_idle=self.end_event - ) - # Will be populated after the test finishes by _run_test - self.report = None - self.sample_uuid_map: dict[str, dict[str, int]] | None = None +class BenchmarkSession: + """Async benchmark orchestrator. Single thread, single event loop. 
- @property - def is_running(self): - return self.thread is not None and self.thread.is_alive() + Runs phases sequentially. Each phase gets its own PhaseIssuer and + LoadStrategy. The receiver coroutine runs concurrently throughout, + processing responses and routing completions to the active strategy. + """ - def stop(self) -> None: - """Signal the session to stop early.""" - self.stop_requested = True - # wakeup _run_test if needed, short-circuit SHUTDOWN_POLL_INTERVAL_S - self.end_event.set() - - def _run_test( + def __init__( self, - perf_test_generator: LoadGenerator, - accuracy_test_generators: dict[str, LoadGenerator] | None = None, - max_shutdown_timeout_s: float | None = 300.0, - report_dir: os.PathLike | None = None, - tokenizer_override: AutoTokenizer | None = None, - dump_events_log: bool = False, + issuer: SampleIssuer, + event_publisher: EventPublisher, + loop: asyncio.AbstractEventLoop, + on_sample_complete: Callable[[QueryResult], None] | None = None, + session_id: str | None = None, ): - with self.event_recorder: - try: - EventRecorder.record_event( - SessionEvent.TEST_STARTED, - time.monotonic_ns(), - data=get_version_info(), - ) + self._issuer = issuer + self._publisher = event_publisher + self._loop = loop + self._on_sample_complete = on_sample_complete + self.session_id = session_id or uuid.uuid4().hex + + # Mutable state + self._stop_requested = False + self._done = False + self._current_phase_issuer: PhaseIssuer | None = None + self._current_strategy: LoadStrategy | None = None + self._recv_task: asyncio.Task | None = None + self._strategy_task: asyncio.Task | None = None + self._drain_event = asyncio.Event() + + def stop(self) -> None: + """Signal early termination. Safe to call from signal handler. + + Cancels the running strategy task to unblock strategies that may be + waiting on semaphores or other async primitives. Also sets the drain + event to unblock _drain_inflight if it's waiting for responses. + """ + self._stop_requested = True + self._drain_event.set() + if self._strategy_task and not self._strategy_task.done(): + self._strategy_task.cancel() + + async def run(self, phases: list[PhaseConfig]) -> SessionResult: + """Run all benchmark phases sequentially. - for _ in perf_test_generator: - # Actual issue is done during next(generator). Nothing else to do here, just pass. + Returns SessionResult with per-phase results. 
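+
+        Illustrative call (sketch; assumes the caller built the phase list):
+
+            result = await session.run([warmup_phase, perf_phase])
+            for pr in result.perf_results:
+                print(pr.name, pr.issued_count)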
+        """
+        session_start = time.monotonic_ns()
+        self._publish_session_event(SessionEventType.STARTED)
+
+        self._recv_task = asyncio.create_task(self._receive_responses())
+        phase_results: list[PhaseResult] = []
+
+        try:
+            for phase in phases:
+                if self._stop_requested:
+                    break
+                if phase.phase_type == PhaseType.WARMUP and not _WARMUP_ENABLED:
+                    logger.info(
+                        "Skipping warmup phase %s (set ENABLE_WARMUP=1 to enable)",
+                        phase.name,
+                    )
+                    continue
+                result = await self._run_phase(phase)
+                if result is not None:
+                    phase_results.append(result)
+        finally:
+            self._done = True
+            if self._recv_task and not self._recv_task.done():
+                self._recv_task.cancel()
+                try:
+                    await self._recv_task
+                except asyncio.CancelledError:
                     pass
+            self._publish_session_event(SessionEventType.ENDED)

-            EventRecorder.record_event(
-                SessionEvent.STOP_PERFORMANCE_TRACKING, time.monotonic_ns()
-            )
-            self.logger.info("All performance samples issued")

+        return SessionResult(
+            session_id=self.session_id,
+            phase_results=phase_results,
+            start_time_ns=session_start,
+            end_time_ns=time.monotonic_ns(),
+        )
+
+    async def _run_phase(self, phase: PhaseConfig) -> PhaseResult | None:
+        """Run a single phase. Returns PhaseResult or None for warmup."""
+        logger.info("Starting phase: %s (%s)", phase.name, phase.phase_type.value)
+        phase_start = time.monotonic_ns()
+
+        # Create per-phase state
+        sample_order = create_sample_order(phase.runtime_settings)
+        strategy = create_load_strategy(
+            phase.runtime_settings, self._loop, sample_order
+        )
+        phase_issuer = PhaseIssuer(
+            dataset=phase.dataset,
+            issuer=self._issuer,
+            publisher=self._publisher,
+            stop_check=self._make_stop_check(phase.runtime_settings, phase_start),
+        )

-        if accuracy_test_generators:
-            for _, generator in accuracy_test_generators.items():
-                for _ in generator:
-                    # Actual issue is done during next(generator). Nothing else to do here, just pass.
-                    pass

+        self._current_phase_issuer = phase_issuer
+        self._current_strategy = strategy
+
+        # Performance phases get tracking events
+        if phase.phase_type == PhaseType.PERFORMANCE:
+            self._publish_session_event(SessionEventType.START_PERFORMANCE_TRACKING)
+
+        # Execute the strategy as a task so it can be cancelled on transport close
+        self._strategy_task = asyncio.create_task(strategy.execute(phase_issuer))
+        try:
+            await self._strategy_task
+        except asyncio.CancelledError:
+            logger.info("Strategy cancelled for phase %s", phase.name)
+        finally:
+            self._strategy_task = None
+
+        # Drain in-flight (skip for warmup — keep concurrency hot)
+        if phase.phase_type != PhaseType.WARMUP:
+            await self._drain_inflight(phase_issuer)
+
+        if phase.phase_type == PhaseType.PERFORMANCE:
+            self._publish_session_event(SessionEventType.STOP_PERFORMANCE_TRACKING)
+
+        phase_end = time.monotonic_ns()
+        logger.info(
+            "Phase %s complete: %d samples issued",
+            phase.name,
+            phase_issuer.issued_count,
+        )

-            self.logger.info("All accuracy samples issued")

+        # Warmup phases produce no result
+        if phase.phase_type == PhaseType.WARMUP:
+            return None
+
+        return PhaseResult(
+            name=phase.name,
+            phase_type=phase.phase_type,
+            uuid_to_index=phase_issuer.uuid_to_index,
+            issued_count=phase_issuer.issued_count,
+            start_time_ns=phase_start,
+            end_time_ns=phase_end,
+        )

-            self.event_recorder.should_check_idle = True
-            EventRecorder.record_event(
-                SessionEvent.LOADGEN_STOP, time.monotonic_ns()
+    async def _drain_inflight(self, phase_issuer: PhaseIssuer) -> None:
+        """Wait for all in-flight responses from this phase to complete.
+ + Currently, there is no timeout for the drain step. In the future, + we can possibly add a dynamic timeout based on the rate of completion + throughout the current phase.""" + if phase_issuer.inflight <= 0 or self._stop_requested: + return + logger.info("Draining %d in-flight responses...", phase_issuer.inflight) + self._drain_event.clear() + await self._drain_event.wait() + + async def _receive_responses(self) -> None: + """Receive responses from the issuer. Runs as a concurrent task.""" + while not self._done: + resp = await self._issuer.recv() + if resp is None: + # Transport closed unexpectedly — trigger stop so strategy + # and drain don't hang waiting for responses that will never arrive. + logger.warning("Issuer recv() returned None — transport closed") + self._stop_requested = True + self._drain_event.set() # Unblock _drain_inflight + # Cancel the strategy task if it's blocked (e.g., ConcurrencyStrategy + # awaiting sem.acquire() that will never be released). + if self._strategy_task and not self._strategy_task.done(): + self._strategy_task.cancel() + break + self._handle_response(resp) + + def _handle_response(self, resp: QueryResult | StreamChunk) -> None: + """Route a response to the appropriate handler. + + Transport contract for streaming: the worker sends intermediate + StreamChunk messages for timing events, then a final QueryResult + with accumulated output for completion. + """ + phase_issuer = self._current_phase_issuer + + if isinstance(resp, QueryResult): + query_id = resp.id + self._publisher.publish( + EventRecord( + event_type=SampleEventType.COMPLETE, + timestamp_ns=resp.completed_at + if isinstance(resp.completed_at, int) + else time.monotonic_ns(), + sample_uuid=query_id, + data=resp.response_output, ) - start_time = time.monotonic() - while self.event_recorder.n_inflight_samples != 0: - if ( - max_shutdown_timeout_s is not None - and time.monotonic() - start_time > max_shutdown_timeout_s - ): - raise TimeoutError( - f"Max shutdown timeout of {max_shutdown_timeout_s}s reached" - ) - - if self.stop_requested: - self.logger.info( - f"Early stop requested (pending={self.event_recorder.n_inflight_samples}), shutting down test..." - ) - break - - self.end_event.wait(timeout=SHUTDOWN_POLL_INTERVAL_S) - if max_shutdown_timeout_s is not None: - self.logger.debug( - f"Waiting for the test to end... 
{self.event_recorder.n_inflight_samples} samples remaining" - ) - - except Exception as e: - logger.error(f"Error running benchmark session: {e}") - raise e - finally: - EventRecorder.record_event(SessionEvent.TEST_ENDED, time.monotonic_ns()) - - self.event_recorder.wait_for_writes() - - # Handle reporting - with MetricsReporter(self.event_recorder.connection_name) as reporter: - has_model = hasattr(self.runtime_settings, "model") - tokenizer = None - if tokenizer_override is not None: - tokenizer = tokenizer_override - if has_model: - model = getattr(self.runtime_settings, "model", None) - if tokenizer is None and model is not None: - try: - tokenizer = AutoTokenizer.from_pretrained( - model if isinstance(model, str) else model.name - ) - except Exception as e: - logger.error( - f"Error loading tokenizer for model {model}: {e}" - ) - tokenizer = None - report = reporter.create_report(tokenizer) - - # Store report on session so external callers can use it - self.report = report - - # Consolidate UUID->index mappings - perf_name = ( - perf_test_generator.name - if perf_test_generator.name - else "performance" + ) + if resp.error is not None: + self._publisher.publish( + EventRecord( + event_type=ErrorEventType.GENERIC, + timestamp_ns=time.monotonic_ns(), + sample_uuid=query_id, + data=resp.error, + ) ) - sample_idx_map = { - perf_name: perf_test_generator.uuid_to_index_map, - } - if accuracy_test_generators: - for default_name, generator in accuracy_test_generators.items(): - name = generator.name if generator.name else default_name - sample_idx_map[name] = generator.uuid_to_index_map - self.sample_uuid_map = sample_idx_map - - # Save to report directory if provided - if report_dir: - Path(report_dir).mkdir(parents=True, exist_ok=True) - report.to_json(save_to=Path(report_dir) / "result_summary.json") - - # Dump runtime settings to report directory - rt_settings_data: dict[str, int | str | None] = { - "min_duration_ms": self.runtime_settings.min_duration_ms, - "max_duration_ms": self.runtime_settings.max_duration_ms, - "n_samples_from_dataset": self.runtime_settings.n_samples_from_dataset, - "n_samples_to_issue": self.runtime_settings.n_samples_to_issue, - "min_sample_count": self.runtime_settings.min_sample_count, - "total_samples_to_issue": self.runtime_settings.total_samples_to_issue(), - } - # TODO: Since RuntimeSettings stores the random.Random objects directly, there is no way - # to retrieve the seed values. The best way to do this is probably a custom random.Random - # class that stores the original seed as a read-only property, and unable to set the seed - # after initialization. 
- if has_model and model is not None: - rt_settings_data["model"] = ( - model if isinstance(model, str) else str(model.name) - ) - - # TODO: After Zhihan's MR is merged, grab the scheduler class and other LG init settings - # from the runtime settings object - with (Path(report_dir) / "runtime_settings.json").open("w") as f: - f.write( - msgspec.json.format( - msgspec.json.encode( - dict(sorted(rt_settings_data.items())) - ), - indent=2, - ).decode("utf-8") - ) - - # Save the UUID mapping for output verification - with (Path(report_dir) / "sample_idx_map.json").open("w") as f: - f.write( - msgspec.json.encode(self.sample_uuid_map).decode("utf-8") - ) - - if dump_events_log: - reporter.dump_to_json(Path(report_dir) / "events.jsonl") - - # Display report to console - report.display(fn=print, summary_only=True) - - # Dump report to text file if report_dir is provided - if report_dir: - report_path = Path(report_dir) / "report.txt" - with open(report_path, "w") as f: - report.display(fn=f.write, summary_only=False, newline="\n") - logger.info(f"Report saved to {report_path}") - - def wait_for_test_end(self, timeout: float | None = None) -> bool: - """ - Join the test thread and return True if the test completed, False if it timed out. + if phase_issuer is not None and query_id in phase_issuer.uuid_to_index: + phase_issuer.inflight -= 1 + if phase_issuer.inflight <= 0: + self._drain_event.set() + if self._current_strategy: + self._current_strategy.on_query_complete(query_id) + if self._on_sample_complete: + self._on_sample_complete(resp) + + elif isinstance(resp, StreamChunk): + is_first = resp.metadata.get("first_chunk", False) + event_type = ( + SampleEventType.RECV_FIRST + if is_first + else SampleEventType.RECV_NON_FIRST + ) + self._publisher.publish( + EventRecord( + event_type=event_type, + timestamp_ns=time.monotonic_ns(), + sample_uuid=resp.id, + ) + ) - Args: - timeout: The maximum time to wait for the test to complete. If None, wait indefinitely. + def _make_stop_check( + self, settings: RuntimeSettings, phase_start_ns: int + ) -> Callable[[], bool]: + """Create a stop-check closure for a phase. - Returns: - bool: True if the test thread has completed, False if it timed out. + Reads self._current_phase_issuer at call time (not creation time). + Invariant: _current_phase_issuer must not change while a phase's + strategy is executing. This is guaranteed by sequential phase execution. """ - if not self.thread: + max_duration_ns = ( + settings.max_duration_ms * 1_000_000 + if settings.max_duration_ms is not None + else 0 + ) + total_samples = settings.total_samples_to_issue() + + def check() -> bool: + if self._stop_requested: + return True + if ( + self._current_phase_issuer + and self._current_phase_issuer.issued_count >= total_samples + ): + return True + if ( + max_duration_ns > 0 + and (time.monotonic_ns() - phase_start_ns) >= max_duration_ns + ): + return True return False - self.thread.join(timeout=timeout) - return not self.thread.is_alive() - @classmethod - def start( - cls, - runtime_settings: RuntimeSettings, - dataset: Dataset, - sample_issuer: SampleIssuer, - scheduler: Scheduler, - *args, - accuracy_datasets: list[Dataset] | None = None, - load_generator_cls: type[LoadGenerator] = SchedulerBasedLoadGenerator, - name: str | None = None, - max_shutdown_timeout_s: float | None = None, - report_dir: os.PathLike | None = None, - tokenizer_override: AutoTokenizer | None = None, - dump_events_log: bool = False, - ) -> BenchmarkSession: - """Start a new BenchmarkSession in a thread. 
-
-        Args:
-            runtime_settings: The runtime settings to use for the session.
-            dataset: The dataset to use for the performance test.
-            sample_issuer: The sample issuer to use for the session.
-            scheduler: The scheduler to use for the session.
-            accuracy_datasets: The datasets to use for the accuracy tests. If None, no accuracy tests will be run.
-            load_generator_cls: The load generator class to use for the session.
-            name: The name of the session.
-            max_shutdown_timeout_s: The maximum timeout to wait for the test to complete after all samples have been issued.
-                If None (the default), wait indefinitely.
-            report_dir: The path to save the report to. If None, no report will be saved.
-            tokenizer_override: The tokenizer to use for the session. If None, a tokenizer will be automatically selected
-                based on the model name in the runtime settings.
-            dump_events_log: Whether to dump the events to a JSONL file. Only use for debugging
-                purposes, as the events database can get quite large.
-
-        Returns:
-            The new BenchmarkSession.
-        """
-        session = cls(runtime_settings, session_id=name)
-        load_generator = load_generator_cls(sample_issuer, dataset, scheduler, *args)  # type: ignore[arg-type]
-
-        # Create accuracy test generators
-        accuracy_test_generators = None
-        if accuracy_datasets:
-            accuracy_test_generators = {}
-            for ds in accuracy_datasets:
-                if hasattr(ds.__class__, "DATASET_ID"):
-                    ds_name = ds.__class__.DATASET_ID
-                else:
-                    ds_name = ds.__class__.__name__
-
-                # Create accuracy dataset specific runtime settings
-                acc_rt_settings = RuntimeSettings(
-                    metric_target=runtime_settings.metric_target,
-                    reported_metrics=runtime_settings.reported_metrics,
-                    min_duration_ms=0,
-                    max_duration_ms=None,
-                    n_samples_from_dataset=ds.num_samples(),
-                    n_samples_to_issue=ds.num_samples() * ds.repeats,
-                    min_sample_count=ds.num_samples() * ds.repeats,
-                    rng_sched=runtime_settings.rng_sched,
-                    rng_sample_index=runtime_settings.rng_sample_index,
-                    load_pattern=runtime_settings.load_pattern,
-                )
-                acc_sched = scheduler.__class__(
-                    acc_rt_settings, WithoutReplacementSampleOrder
-                )
+        return check

-                accuracy_test_generators[ds_name] = load_generator_cls(
-                    sample_issuer,
-                    ds,
-                    acc_sched,  # type: ignore[arg-type]
-                    *args,
-                )

+    def _publish_session_event(self, event_type: SessionEventType) -> None:
+        """Publish a session event and flush the publisher immediately.

-        session.thread = threading.Thread(
-            target=session._run_test,
-            args=(load_generator,),
-            kwargs={
-                "accuracy_test_generators": accuracy_test_generators,
-                "max_shutdown_timeout_s": max_shutdown_timeout_s,
-                "report_dir": report_dir,
-                "tokenizer_override": tokenizer_override,
-                "dump_events_log": dump_events_log,
-            },
+        Session events are control signals (STARTED, ENDED, START/STOP
+        PERFORMANCE_TRACKING) that subscribers must receive promptly for
+        correct state transitions. Flushing ensures any buffered sample
+        events are sent first, followed by the session event, so ordering
+        is preserved and the signal is not delayed by batching.
+        """
+        self._publisher.publish(
+            EventRecord(event_type=event_type, timestamp_ns=time.monotonic_ns())
         )
-        session.thread.start()
-        return session
diff --git a/src/inference_endpoint/load_generator/strategy.py b/src/inference_endpoint/load_generator/strategy.py
new file mode 100644
index 00000000..dd311f10
--- /dev/null
+++ b/src/inference_endpoint/load_generator/strategy.py
@@ -0,0 +1,301 @@
+        self._publisher.flush()
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Load strategies: control the pacing of sample issuance.
+
+Three implementations, each optimized for a different load pattern:
+- TimedIssueStrategy: Poisson (loop.call_at or run_in_executor)
+- BurstStrategy: Max throughput (loop.call_soon)
+- ConcurrencyStrategy: Fixed concurrency (asyncio.Semaphore)
+
+See docs/load_generator/DESIGN.md for benchmark data and design rationale.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from collections.abc import Callable, Iterator
+from time import monotonic_ns
+from typing import Protocol
+
+from ..config.runtime_settings import RuntimeSettings
+from ..config.schema import LoadPatternType
+from .delay import make_delay_fn
+from .sample_order import SampleOrder, create_sample_order
+
+logger = logging.getLogger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# LoadStrategy Protocol
+# ---------------------------------------------------------------------------
+
+
+class PhaseIssuerProtocol(Protocol):
+    """Minimal interface that strategies see for issuing samples."""
+
+    def issue(self, sample_index: int) -> str | None:
+        """Issue a sample. Returns query_id, or None if the session is stopping."""
+        ...
+
+    issued_count: int
+
+
+class LoadStrategy(Protocol):
+    """Controls the pacing strategy for issuing requests.
+
+    Strategies call phase_issuer.issue(sample_index) to issue each sample.
+    issue() returns query_id on success, None when the session should stop.
+    """
+
+    async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int:
+        """Drive sample issuance. Returns count of samples issued."""
+        ...
+
+    def on_query_complete(self, query_id: str) -> None:
+        """Called by session on each QueryResult. Default: no-op.
+
+        Used by ConcurrencyStrategy to release semaphore slots.
+        """
+        ...
+
+
+# ---------------------------------------------------------------------------
+# TimedIssueStrategy (Poisson)
+# ---------------------------------------------------------------------------
+
+
+def _busy_wait_until(target_ns: int) -> None:
+    """Busy-wait in a thread pool thread until target timestamp."""
+    while monotonic_ns() < target_ns:
+        pass
+
+
+class TimedIssueStrategy:
+    """Schedule-driven load strategy with inter-arrival delays.
+
+    Default mode (call_at): schedules each issue as an event loop callback
+    at the precise target time. Zero GIL contention, sub-ms precision.
+    Good for <= 50k QPS.
+
+    Executor mode (opt-in): offloads busy-wait to thread pool for sub-100us
+    precision. Introduces GIL contention that adds latency at low QPS.
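+
+    Illustrative construction (sketch only; normally done via
+    create_load_strategy, with delay_fn and order assumed to come from
+    make_delay_fn and create_sample_order):
+
+        strategy = TimedIssueStrategy(delay_fn, iter(order), loop)
+        issued = await strategy.execute(phase_issuer)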
+ """ + + def __init__( + self, + delay_fn: Callable[[], int], + sample_order: Iterator[int], + loop: asyncio.AbstractEventLoop, + use_executor: bool = False, + ): + self._delay_fn = delay_fn + self._sample_order = sample_order + self._loop = loop + self._use_executor = use_executor + + async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: + if self._use_executor: + return await self._execute_executor(phase_issuer) + return await self._execute_call_at(phase_issuer) + + def on_query_complete(self, query_id: str) -> None: + pass + + async def _execute_call_at(self, phase_issuer: PhaseIssuerProtocol) -> int: + done = asyncio.Event() + start_time = self._loop.time() + cumulative_s = 0.0 + + def schedule_next(): + nonlocal cumulative_s, error + try: + idx = next(self._sample_order, None) + if idx is None: + done.set() + return + cumulative_s += self._delay_fn() / 1e9 + self._loop.call_at(start_time + cumulative_s, fire, idx) + except Exception as exc: + error = exc + done.set() + + error: BaseException | None = None + + def fire(idx: int): + nonlocal error + try: + if phase_issuer.issue(idx) is None: + done.set() + return + schedule_next() + except Exception as exc: + error = exc + done.set() + + schedule_next() + await done.wait() + if error is not None: + raise error + return phase_issuer.issued_count + + async def _execute_executor(self, phase_issuer: PhaseIssuerProtocol) -> int: + start = monotonic_ns() + cumulative = 0 + for idx in self._sample_order: + cumulative += self._delay_fn() + target = start + cumulative + now = monotonic_ns() + if target > now: + await self._loop.run_in_executor(None, _busy_wait_until, target) + if phase_issuer.issue(idx) is None: + break + return phase_issuer.issued_count + + +# --------------------------------------------------------------------------- +# BurstStrategy (Max Throughput) +# --------------------------------------------------------------------------- + + +class BurstStrategy: + """Fire-as-fast-as-possible strategy using loop.call_soon. + + Each issue is scheduled as an event loop callback, yielding between + issues so the receiver coroutine can process responses. Achieves + 100k+ QPS without starving the event loop. + """ + + def __init__( + self, + sample_order: Iterator[int], + loop: asyncio.AbstractEventLoop, + ): + self._sample_order = sample_order + self._loop = loop + + async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: + done = asyncio.Event() + error: BaseException | None = None + + def issue_next(): + nonlocal error + try: + idx = next(self._sample_order, None) + if idx is None or phase_issuer.issue(idx) is None: + done.set() + return + self._loop.call_soon(issue_next) + except Exception as exc: + error = exc + done.set() + + self._loop.call_soon(issue_next) + await done.wait() + if error is not None: + raise error + return phase_issuer.issued_count + + def on_query_complete(self, query_id: str) -> None: + pass + + +# --------------------------------------------------------------------------- +# ConcurrencyStrategy +# --------------------------------------------------------------------------- + + +class ConcurrencyStrategy: + """Completion-driven strategy maintaining fixed concurrent requests. + + Uses asyncio.Semaphore for gating: acquire before issue, release on + completion via on_query_complete(). With eager_task_factory, the woken + waiter executes synchronously within release(), minimizing jitter. 
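+
+    Illustrative flow (sketch; the session, not the strategy, observes
+    completions and calls on_query_complete):
+
+        strategy = ConcurrencyStrategy(32, iter(order))
+        task = asyncio.create_task(strategy.execute(phase_issuer))
+        ...  # on each QueryResult the session calls:
+        strategy.on_query_complete(query_id)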
+ """ + + def __init__( + self, + target_concurrency: int, + sample_order: Iterator[int], + ): + if target_concurrency <= 0: + raise ValueError( + f"target_concurrency must be > 0, got {target_concurrency}" + ) + self._target = target_concurrency + self._sem = asyncio.Semaphore(target_concurrency) + self._sample_order = sample_order + + async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: + for idx in self._sample_order: + await self._sem.acquire() + if phase_issuer.issue(idx) is None: + self._sem.release() + break + return phase_issuer.issued_count + + def on_query_complete(self, query_id: str) -> None: + self._sem.release() + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + + +def create_load_strategy( + runtime_settings: RuntimeSettings, + loop: asyncio.AbstractEventLoop, + sample_order: SampleOrder | None = None, + use_executor: bool = False, +) -> LoadStrategy: + """Create a LoadStrategy from RuntimeSettings. + + Args: + runtime_settings: Runtime configuration with load_pattern. + loop: Event loop for scheduling callbacks. + sample_order: Sample ordering iterator. If None, created from settings. + use_executor: For Poisson, use run_in_executor for sub-100us precision. + + Returns: + LoadStrategy implementation for the configured load pattern. + """ + lp = runtime_settings.load_pattern + if lp is None: + raise ValueError("RuntimeSettings.load_pattern must not be None") + + if sample_order is None: + sample_order = create_sample_order(runtime_settings) + + match lp.type: + case LoadPatternType.MAX_THROUGHPUT: + return BurstStrategy(sample_order, loop) + + case LoadPatternType.POISSON: + delay_fn = make_delay_fn(lp, runtime_settings.rng_sched) + return TimedIssueStrategy( + delay_fn, sample_order, loop, use_executor=use_executor + ) + + case LoadPatternType.CONCURRENCY: + if lp.target_concurrency is None or lp.target_concurrency <= 0: + raise ValueError( + "Concurrency load pattern requires target_concurrency > 0" + ) + return ConcurrencyStrategy(lp.target_concurrency, sample_order) + + case _: + raise ValueError(f"Unsupported load pattern type: {lp.type}") diff --git a/src/inference_endpoint/metrics/recorder.py b/src/inference_endpoint/metrics/recorder.py deleted file mode 100644 index e26827bb..00000000 --- a/src/inference_endpoint/metrics/recorder.py +++ /dev/null @@ -1,487 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-
-from __future__ import annotations
-
-import atexit
-import contextlib
-import dataclasses
-import logging
-import multiprocessing
-import queue
-import shutil
-import sqlite3
-import threading
-import time
-import uuid
-from functools import partial
-from pathlib import Path
-from typing import Any, ClassVar
-
-import msgspec.json
-
-from ..load_generator.events import Event, SampleEvent, SessionEvent
-from ..profiling import profile
-from ..utils import byte_quantity_to_str
-
-logger = logging.getLogger(__name__)
-
-
-@contextlib.contextmanager
-def sqlite3_cursor(path: Path):
-    """Context manager for SQLite cursor that properly handles connection lifecycle.
-
-    Args:
-        path: Path to the SQLite database file.
-
-    Yields:
-        A (cursor, connection) tuple for the opened database.
-    """
-    conn = sqlite3.connect(str(path))
-    cursor = conn.cursor()
-    try:
-        yield cursor, conn
-    finally:
-        cursor.close()
-        conn.close()
-
-
-@dataclasses.dataclass
-class EventRow:
-    sample_uuid: str = dataclasses.field(metadata={"sql_type": "TEXT"})
-    """UUID string identifier for the sample"""
-
-    event_type: Event = dataclasses.field(metadata={"sql_type": "TEXT"})
-    """The type of event to record"""
-
-    timestamp_ns: int = dataclasses.field(metadata={"sql_type": "INTEGER"})
-    """The timestamp of the event in nanoseconds. Note that this is a monotonic timestamp, so the value itself
-    is not meaningful, but the differences between timestamps are accurate."""
-
-    data: bytes = dataclasses.field(default=b"", metadata={"sql_type": "BLOB"})
-    """The data, if any, associated with the event, encoded as JSON bytes."""
-
-    @staticmethod
-    def to_table_query() -> str:
-        # Dynamically construct table query based on the dataclass fields
-        fields = []
-        for field in dataclasses.fields(EventRow):
-            sql_type = field.metadata.get("sql_type", "BLOB")
-            fields.append(f"{field.name} {sql_type}")
-
-        field_str = ", ".join(fields)
-        return f"CREATE TABLE IF NOT EXISTS events ({field_str})"
-
-    @staticmethod
-    def insert_query() -> str:
-        fields = dataclasses.fields(EventRow)
-        names = [field.name for field in fields]
-        names_str = ", ".join(names)
-        placeholders = ", ".join(["?"] * len(fields))
-        return f"INSERT INTO events ({names_str}) VALUES ({placeholders})"
-
-    def to_insert_params(self) -> tuple[str, str, int, bytes]:
-        return (
-            self.sample_uuid,
-            self.event_type.value,
-            self.timestamp_ns,
-            self.data,
-        )
-
-
-def register_cleanup(file_path: Path):
-    if multiprocessing.parent_process() is not None:
-        return
-    atexit.register(partial(file_path.unlink, missing_ok=True))
-    logger.debug(f"Registered at-exit cleanup for {file_path}")
-
-
-class EventRecorderSingletonViolation(RuntimeError):
-    """Raised when an attempt is made to create a second EventRecorder while one is already active.
-
-    This is to prevent:
-    - Multiple writer connections to the same database
-    - Potential writes to the wrong event database if multiple are open
-    """
-
-    pass
-
-
-class EventRecorder:
-    """Records events to a shared memory database, which can be accessed across multiple processes.
-
-    An optional session id can be provided to connect to an existing database. If the
-    database does not exist, it will first check whether /dev/shm has enough free space
-    to create a new database.
-
-    This class uses a dedicated writer thread to handle all SQLite operations, making it thread-safe.
-    Events are queued via record_event() and processed asynchronously by the writer thread.
-
-    Only 1 EventRecorder can be actively writing events at a time.
- """ - - LIVE: EventRecorder | None = None - - _created_session_dbs: ClassVar[set[str]] = set() - - # Sentinel objects for queue control - _STOP_SENTINEL: ClassVar[object] = object() - _FORCE_COMMIT_SENTINEL: ClassVar[object] = object() - - def __init__( - self, - session_id: str | None = None, - txn_buffer_size: int = 1000, - min_memory_req_bytes: int = 512 * 1024 * 1024, - notify_idle: threading.Event | None = None, - close_timeout_s: float = 10.0, - ): - """Creates a new EventRecorder. - - Args: - session_id: Optional session id to connect to an existing database. If not provided, a new database will be created. - txn_buffer_size: The number of events to buffer before committing to the database. (Default: 1000) - min_memory_req_bytes: The minimum amount of free space (in bytes) in /dev/shm required to create a new database. (Default: 1GB) - notify_idle: Optional threading.Event. If provided, EventRecorder will set when the number of inflight samples is 0. - close_timeout_s: The timeout in seconds to wait for the writer thread to finish processing when calling close(). (Default: 10.0) - """ - if session_id is None: - session_id = uuid.uuid4().hex - - self.session_id = session_id - - if self.connection_name not in EventRecorder._created_session_dbs: - register_cleanup(self.connection_name) - EventRecorder._created_session_dbs.add(str(self.connection_name)) - - if not Path(self.connection_name).parent.exists(): - raise FileNotFoundError( - "Cannot create shm db, POSIX shm dir at /dev/shm does not exist" - ) - - if not Path(self.connection_name).exists(): - # If we're creating a new db, we require a minimum of 1GB of shared memory - logging.debug(f"Creating new events db at {self.connection_name}") - shm_stats = shutil.disk_usage("/dev/shm") - logging.debug( - f"/dev/shm usage stats: total={shm_stats.total}B, free={shm_stats.free}B" - ) - - min_memory_req_str = byte_quantity_to_str(min_memory_req_bytes) - if shm_stats.total < min_memory_req_bytes: - raise MemoryError( - f"A minimum of {min_memory_req_str} of total space in /dev/shm is required. Use --shm-size={min_memory_req_str} in `docker run` if using docker." - ) - - if shm_stats.free < min_memory_req_bytes: - free_space_str = byte_quantity_to_str(shm_stats.free) - raise MemoryError( - f"A minimum of {min_memory_req_str} of free space in /dev/shm is required, but only {free_space_str} is free. Please free up space or increase the /dev/shm size limit." - ) - - # Queue for thread-safe event recording - self.event_queue: queue.Queue = queue.Queue() - self.txn_buffer_size = txn_buffer_size - - # Writer thread management - self.writer_thread: threading.Thread | None = None - self._writer_started = False - self.close_timeout_s = close_timeout_s - - self.notify_idle = notify_idle - self.n_inflight_samples = 0 - self.should_check_idle = False - - @property - def connection_name(self) -> Path: - # To support accessing in multiple processes, we store the db in /dev/shm - # Otherwise, using mode=memory&cache=shared only works within the same process - return EventRecorder.db_path(self.session_id) - - @staticmethod - def db_path(session_id: str) -> Path: - """Helper method to figure out the path of a session's database without creating an EventRecorder instance. - - Args: - session_id: The session id. - - Returns: - The path to the session's database. - """ - return Path(f"/dev/shm/mlperf_testsession_{session_id}.db") - - def _writer_loop(self): - """Writer thread loop that processes events from the queue and commits them to the database. 
- - This method runs in a dedicated thread and owns the SQLite connection and cursor. - It processes events from the queue, buffering them until the buffer is full or a force commit is requested. - """ - logging.debug(f"Writer thread started for {self.connection_name}") - - with sqlite3_cursor(self.connection_name) as (cur, conn): - # Initialize the database table - cur.execute(EventRow.to_table_query()) - conn.commit() - - event_buffer = [] - - insert_query = EventRow.insert_query() - - def commit_buffer(): - """Helper to commit and clear the event buffer.""" - if event_buffer: - cur.executemany(insert_query, event_buffer) - conn.commit() - event_buffer.clear() - - while True: - try: - # Get item from queue, blocking until available - item = self.event_queue.get(timeout=1.0) - except queue.Empty: - # Timeout - continue loop to check for stop condition - continue - - # Check for sentinel values - should_commit = False - if item is EventRecorder._STOP_SENTINEL: - # Commit any remaining events before stopping - if event_buffer: - logging.debug( - f"Writer thread stopping - committing final {len(event_buffer)} transactions" - ) - should_commit = True - elif item is EventRecorder._FORCE_COMMIT_SENTINEL: - # Force commit current buffer - if event_buffer: - logging.debug( - f"Force committing {len(event_buffer)} transactions" - ) - should_commit = True - else: - # Regular event - add to buffer - event_buffer.append(item) - should_commit = len(event_buffer) >= self.txn_buffer_size - - # Commit if buffer is full - if should_commit: - logging.debug( - f"Committing {len(event_buffer)} transactions (max buffer size: {self.txn_buffer_size})" - ) - commit_buffer() - self.event_queue.task_done() - - if ( - self.should_check_idle - and self.notify_idle is not None - and self.n_inflight_samples == 0 - and self.event_queue.empty() - ): - self.notify_idle.set() - - if item is EventRecorder._STOP_SENTINEL: - break - logging.debug(f"Writer thread stopped for {self.connection_name}") - - def _start_writer_thread(self): - """Starts the writer thread if not already started.""" - if EventRecorder.LIVE is not None: - raise EventRecorderSingletonViolation( - f"EventRecorder {EventRecorder.LIVE.session_id} is already active, cannot open {self.session_id}" - ) - EventRecorder.LIVE = self - - if self._writer_started: - logging.debug("Writer thread already started") - return - - logging.debug(f"Starting writer thread for {self.connection_name}") - self.writer_thread = threading.Thread( - target=self._writer_loop, - name=f"EventRecorder-Writer-{self.session_id}", - daemon=False, - ) - self.writer_thread.start() - self._writer_started = True - - def wait_for_writes(self, force_commit: bool = True): - """Blocks until all queued events are processed. - - Args: - force_commit: Whether to force commit the current buffer immediately. (Default: True) - """ - if not self._writer_started: - return - - if force_commit: - self.event_queue.put(self._FORCE_COMMIT_SENTINEL) - self.event_queue.join() - - @profile - @classmethod - def record_event( - cls, - ev_type: Event, - timestamp_ns: int, - sample_uuid: str = "", - force_commit: bool = False, - assert_active: bool = True, - data: Any = None, - ) -> bool: - """Records an event by pushing it to the queue for the writer thread to process. - - This method is thread-safe and can be called from multiple threads simultaneously. - The actual database write happens asynchronously in the writer thread. - - Args: - ev_type (Event): The type of event to record. 
- timestamp_ns (int): The timestamp in nanoseconds of the event. - sample_uuid (str): The sample uuid of the event. - force_commit (bool): Whether to force commit the current buffer immediately. - assert_active (bool): Whether to raise an exception if no EventRecorder is active. - If False, this method will return False. (Default: True) - data (Any): The data to record associated with the event. Must be JSON serializable. - (Default: None) - Returns: - bool: True if the event was recorded, False otherwise. If assert_active is True, - this method will always return True or raise an exception. - """ - if EventRecorder.LIVE is None: - if assert_active: - raise EventRecorderSingletonViolation( - "No EventRecorder is active, cannot record event" - ) - return False - - rec_inst = EventRecorder.LIVE - - if not rec_inst._writer_started: - raise RuntimeError( - "Writer thread not started - Users should use `with EventRecorder(...)` to ensure the writer thread is started" - ) - - # Update inflight sample tracking - # NOTE: n_inflight_samples is not thread-safe (+=/-= from multiple threads). - # This is a known issue but EventRecorder is being deprecated in favor of - # EventLoggerService (pub-sub based). Not worth fixing here. - if ev_type == SessionEvent.LOADGEN_ISSUE_CALLED: - rec_inst.n_inflight_samples += 1 - elif ev_type == SampleEvent.COMPLETE: - rec_inst.n_inflight_samples -= 1 - - if rec_inst.n_inflight_samples < 0: - raise RuntimeError( - f"Number of inflight samples is negative: {rec_inst.n_inflight_samples}" - ) - - # Push event to queue for writer thread to process - encoded_bytes: bytes = b"" - try: - if data is not None: - encoded_bytes = msgspec.json.encode(data) - except msgspec.EncodeError as e: - rec_inst.event_queue.put( - ( - sample_uuid, - SessionEvent.ERROR.value, - time.monotonic_ns(), - msgspec.json.encode( - { - "error_type": "JSONEncodeError", - "error_message": str(e), - } - ), - ) - ) - finally: - rec_inst.event_queue.put( - (sample_uuid, ev_type.value, timestamp_ns, encoded_bytes) - ) - - # If force commit requested, send sentinel - if force_commit: - rec_inst.event_queue.put(EventRecorder._FORCE_COMMIT_SENTINEL) - return True - - def close(self): - """Closes the EventRecorder and stops the writer thread. - - This method signals the writer thread to stop, waits for it to finish processing - all queued events, and then joins the thread. - """ - if EventRecorder.LIVE is not self: - raise EventRecorderSingletonViolation( - f"EventRecorder {self.session_id} is not active, cannot close" - ) - EventRecorder.LIVE = None - - if not self._writer_started: - logging.debug("Writer thread was never started, nothing to close") - return - - logging.debug("Stopping writer thread...") - # Send stop sentinel to writer thread - self.event_queue.put(self._STOP_SENTINEL) - - # Wait for the writer thread to finish - if self.writer_thread is not None: - self.writer_thread.join(timeout=self.close_timeout_s) - if self.writer_thread.is_alive(): - n_pending = self.event_queue.qsize() - raise RuntimeError( - f"Writer thread did not stop within timeout for {self.connection_name}. {n_pending} events pending." 
- ) - else: - logging.debug( - f"Writer thread stopped successfully for {self.connection_name}" - ) - self._writer_started = False - self.writer_thread = None - - def __enter__(self): - """Context manager entry - starts the writer thread.""" - if not self._writer_started: - self._start_writer_thread() - return self - - def __exit__(self, exc_type, exc_value, traceback): - """Context manager exit - stops the writer thread.""" - self.close() - - -def record_exception( - exc_value: Exception | str, - sample_uuid: str | None = None, -): - """Records an exception as an event to the current event recorder. - - This will force commit the existing event buffer immediately to ensure the error is surfaced - as soon as possible for any monitoring. - - Args: - exc_value: The exception to record, or a string error message. - sample_uuid: The sample uuid to record the error for. - """ - if EventRecorder.LIVE is None: - return - EventRecorder.record_event( - SessionEvent.ERROR, - time.monotonic_ns(), - sample_uuid=sample_uuid or "", - data={ - "error_type": exc_value.__class__.__name__, - "error_message": str(exc_value), - }, - force_commit=True, - ) diff --git a/src/inference_endpoint/metrics/report.py b/src/inference_endpoint/metrics/report.py new file mode 100644 index 00000000..e24d954b --- /dev/null +++ b/src/inference_endpoint/metrics/report.py @@ -0,0 +1,295 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmark report: summary statistics, display, and JSON serialization.""" + +from __future__ import annotations + +import math +import os +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import msgspec.json +import numpy as np + +from inference_endpoint.async_utils.services.metrics_aggregator.kv_store import ( + BasicKVStoreReader, + SeriesStats, +) +from inference_endpoint.utils.version import get_version_info + +from ..utils import monotime_to_datetime + +# --------------------------------------------------------------------------- +# Summary computation +# --------------------------------------------------------------------------- + +_DEFAULT_PERCENTILES = (99.9, 99, 97, 95, 90, 80, 75, 50, 25, 10, 5, 1) + + +def compute_summary( + stats: SeriesStats, + percentiles: tuple[float, ...] = _DEFAULT_PERCENTILES, + n_histogram_buckets: int = 10, +) -> dict[str, Any]: + """Compute rollup statistics from pre-computed SeriesStats. + + Scalar stats (total, min, max, avg, std_dev) are derived from the + incrementally maintained rollups in SeriesStats. Numpy is only used + for percentiles and histograms, which require the raw values. + + Returns a dict with: total, min, max, avg, std_dev, median, + percentiles (dict), and histogram (buckets + counts). 
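+
+    A minimal usage sketch (illustrative only; assumes `stats` is a
+    SeriesStats already populated by the aggregator, an ns-valued series,
+    and the default percentile set, whose keys are stringified numbers):
+
+        summary = compute_summary(stats)
+        p99_ms = summary["percentiles"]["99"] / 1e6  # ns shown as ms
+        print(summary["avg"], summary["std_dev"], summary["median"])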
+ """ + if stats.count == 0: + return { + "total": 0, + "min": 0, + "max": 0, + "median": 0.0, + "avg": 0.0, + "std_dev": 0.0, + "percentiles": {str(p): 0.0 for p in percentiles}, + "histogram": {"buckets": [], "counts": []}, + } + + # Scalar stats from pre-computed rollups (no numpy needed) + avg = stats.total / stats.count + # Bessel's correction (ddof=1) for sample standard deviation + if stats.count > 1: + n = stats.count + std_dev = math.sqrt((stats.sum_sq - stats.total**2 / n) / (n - 1)) + else: + std_dev = 0.0 + + # Percentiles and histogram require raw values + # Don't force float64 — numpy preserves int for uint64 series, + # so percentile(method="lower") returns actual observed values + # in their original type. + arr = np.array(stats.values) + arr.sort() + + # Inject 50th percentile for median if not already requested + need_median = 50 not in percentiles + all_percentiles = (*percentiles, 50) if need_median else percentiles + + perc_values = np.percentile(arr, all_percentiles, method="lower") + perc_dict = { + str(p): v.item() for p, v in zip(all_percentiles, perc_values, strict=True) + } + median = perc_dict.pop("50") if need_median else perc_dict["50"] + + bounds = np.histogram_bin_edges(arr, bins=n_histogram_buckets) + counts, _ = np.histogram(arr, bins=bounds) + hist_buckets = [ + (float(bounds[i]), float(bounds[i + 1])) for i in range(len(bounds) - 1) + ] + + return { + "total": stats.total, + "min": stats.min_val, + "max": stats.max_val, + "median": median, + "avg": avg, + "std_dev": std_dev, + "percentiles": perc_dict, + "histogram": {"buckets": hist_buckets, "counts": counts.tolist()}, + } + + +# --------------------------------------------------------------------------- +# Report +# --------------------------------------------------------------------------- + + +class Report(msgspec.Struct, frozen=True): # type: ignore[call-arg] + """Summarized benchmark report.""" + + version: str + git_sha: str | None + test_started_at: int + n_samples_issued: int + n_samples_completed: int + n_samples_failed: int + duration_ns: int | None + + # Per-metric rollup dicts (output of compute_summary) + ttft: dict[str, Any] + tpot: dict[str, Any] + latency: dict[str, Any] + output_sequence_lengths: dict[str, Any] + + def qps(self) -> float | None: + if self.duration_ns is None or self.duration_ns <= 0: + return None + return self.n_samples_completed / (self.duration_ns / 1e9) + + def tps(self) -> float | None: + if self.duration_ns is None or self.duration_ns <= 0: + return None + if not self.output_sequence_lengths: + return None + total = self.output_sequence_lengths.get("total", 0) + return total / (self.duration_ns / 1e9) + + @classmethod + def from_kv_reader(cls, reader: BasicKVStoreReader) -> Report: + """Build a Report from the current KVStore state. + + Reads counters and series from the reader, computes rollup summaries + (percentiles, histograms) for each series metric, and returns a Report. + + Works identically for live metrics (mid-test) and final reports + (post-drain). The caller decides when to call. 
+ """ + snap = reader.snapshot() + + def _counter(key: str) -> int: + val = snap.get(key) + return int(val) if isinstance(val, int) else 0 + + def _summarize(key: str) -> dict: + val = snap.get(key) + if isinstance(val, SeriesStats) and val.count > 0: + return compute_summary(val) + return {} + + version_info = get_version_info() + duration_ns = _counter("tracked_duration_ns") + + return cls( + version=str(version_info.get("version", "unknown")), + git_sha=version_info.get("git_sha"), + test_started_at=0, # TODO: add test_started_at counter to aggregator + n_samples_issued=_counter("tracked_samples_issued"), + n_samples_completed=_counter("tracked_samples_completed"), + # TODO: Add tracked_samples_failed to MetricCounterKey. + # For now, total_samples_failed is the best available. + n_samples_failed=_counter("total_samples_failed"), + duration_ns=duration_ns if duration_ns > 0 else None, + ttft=_summarize("ttft_ns"), + tpot=_summarize("tpot_ns"), + latency=_summarize("sample_latency_ns"), + output_sequence_lengths=_summarize("osl"), + ) + + def to_json(self, save_to: os.PathLike | None = None) -> bytes: + json_bytes = msgspec.json.format(msgspec.json.encode(self), indent=2) + if save_to is not None: + with Path(save_to).open("wb") as f: + f.write(json_bytes) + return json_bytes + + def display( + self, + fn: Callable[[str], None] = print, + summary_only: bool = False, + newline: str = "", + ) -> None: + fn(f"----------------- Summary -----------------{newline}") + fn(f"Version: {self.version}{newline}") + if self.git_sha: + fn(f"Git SHA: {self.git_sha}{newline}") + if self.test_started_at > 0: + approx = monotime_to_datetime(self.test_started_at) + fn(f"Test started at: {approx.strftime('%Y-%m-%d %H:%M:%S')}{newline}") + fn(f"Total samples issued: {self.n_samples_issued}{newline}") + fn(f"Total samples completed: {self.n_samples_completed}{newline}") + fn(f"Total samples failed: {self.n_samples_failed}{newline}") + if self.duration_ns is not None: + fn(f"Duration: {self.duration_ns / 1e9:.2f} seconds{newline}") + else: + fn(f"Duration: N/A{newline}") + + if (qps := self.qps()) is not None: + fn(f"QPS: {qps:.2f}{newline}") + else: + fn(f"QPS: N/A{newline}") + + if (tps := self.tps()) is not None: + fn(f"TPS: {tps:.2f}{newline}") + + if summary_only: + fn(f"----------------- End of Summary -----------------{newline}") + return + + fn(f"\n------------------- Latency Breakdowns -------------------{newline}") + + for section_name, metric_dict, unit, scale_factor in [ + ("TTFT", self.ttft, "ms", 1e-6), + ("TPOT", self.tpot, "ms", 1e-6), + ("Latency", self.latency, "ms", 1e-6), + ("Output sequence lengths", self.output_sequence_lengths, "tokens", 1.0), + ]: + if not metric_dict: + continue + fn(f"{section_name}:{newline}") + _display_metric( + metric_dict, + fn=fn, + unit=unit, + scale_factor=scale_factor, + newline=newline, + ) + fn(f"{newline}") + + fn(f"----------------- End of Report -----------------{newline}") + + +# --------------------------------------------------------------------------- +# Display helpers +# --------------------------------------------------------------------------- + + +def _display_metric( + metric_dict: dict[str, Any], + fn: Callable[[str], None], + unit: str = "", + max_bar_length: int = 30, + scale_factor: float = 1.0, + newline: str = "", +) -> None: + for name, key in [ + ("Min", "min"), + ("Max", "max"), + ("Median", "median"), + ("Avg.", "avg"), + ("Std Dev.", "std_dev"), + ]: + fn(f" {name}: {metric_dict[key] * scale_factor:.2f} {unit}{newline}") + + fn(f"\n 
Histogram:{newline}") + buckets = metric_dict["histogram"]["buckets"] + counts = metric_dict["histogram"]["counts"] + + if buckets: + bucket_strs = [ + f" [{lo * scale_factor:.2f}, {hi * scale_factor:.2f}" + + ("]" if i == len(buckets) - 1 else ")") + for i, (lo, hi) in enumerate(buckets) + ] + max_count = max(counts) + normalize = max_bar_length / max_count if max_count > 0 else 1 + max_label = max(len(s) for s in bucket_strs) + + for label, count in zip(bucket_strs, counts, strict=True): + bar = "#" * int(count * normalize) + fn(f" {label:>{max_label}} |{bar} {count}{newline}") + + fn(f"\n Percentiles:{newline}") + for p, val in metric_dict.get("percentiles", {}).items(): + fn(f" {p:>6}: {val * scale_factor:.2f} {unit}{newline}") diff --git a/src/inference_endpoint/metrics/reporter.py b/src/inference_endpoint/metrics/reporter.py deleted file mode 100644 index d6273496..00000000 --- a/src/inference_endpoint/metrics/reporter.py +++ /dev/null @@ -1,1368 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import csv -import dataclasses -import functools -import importlib -import logging -import numbers -import os -import sqlite3 -from collections import defaultdict -from collections.abc import Callable, Iterable -from enum import Enum -from pathlib import Path -from typing import TYPE_CHECKING, Any - -import msgspec.json -import numpy as np - -from ..load_generator.events import SampleEvent, SessionEvent -from ..profiling import profile -from ..utils import monotime_to_datetime - -if TYPE_CHECKING: - from transformers import Tokenizer - - -class TPOTReportingMode(str, Enum): - """TPOT (Time Per Output Token) reporting mode. - - - REQUEST_WEIGHTED: Each request contributes one entry to TPOT calculation (default) - - TOKEN_WEIGHTED: Each token contributes to TPOT calculation (weighted by token count) - """ - - REQUEST_WEIGHTED = "request_weighted" - TOKEN_WEIGHTED = "token_weighted" - - -class SampleUUIDNotFoundError(Exception): - def __init__(self, uuid: str, datasource: str): - super().__init__(f"Sample UUID {uuid} not found in {datasource}") - - -@dataclasses.dataclass -class MetricRow: - sample_uuid: str - metric_type: str - metric_value: float - - -@dataclasses.dataclass(frozen=True) -class RollupQueryTable: - """Represents a table that is the result of a roll-up query. - This class lazily converts tuples to MetricRow objects on-access to reduce unnecessary overhead. - - The columns are assumed to be (sample_uuid, metric_value). If a roll-up query returns different columns, - define a subclass and override the __getitem__ method. 
- """ - - metric_type: str - """A string describing the metric being computed and rolled up.""" - - from_query: str - """If provided, the query that was used to generate the table.""" - - rows: list[tuple[Any, ...]] - """The rows of the table, each a tuple of values.""" - - repeats: list[int] | None = None - """If provided, this means the rows are condensed by consecutive duplicates. `repeats` - represents the number of times each row should be repeated.""" - - _sorted_vals: np.ndarray = dataclasses.field(init=False) - _by_uuid: dict[str, list[int]] = dataclasses.field(init=False) - - def __post_init__(self): - if self.repeats is not None: - if len(self.repeats) != len(self.rows): - raise IndexError( - f"Length of repeats {len(self.repeats)} does not match length of rows {len(self.rows)}" - ) - if not isinstance(self.repeats, np.ndarray): - object.__setattr__( - self, "repeats", np.array(self.repeats, dtype=np.int64) - ) - else: - object.__setattr__(self, "repeats", self.repeats.astype(np.int64)) - - # Metrics are always differences between integer nanosecond timestamps - # Some might be 32-bit and 64 bit, so we force np.int64 here - if self.repeats is None: - sorted_vals = np.array([row[1] for row in self.rows], dtype=np.int64) - sorted_vals.sort() - else: - arr = np.array( - [(self.rows[i][1], self.repeats[i]) for i in range(len(self.rows))], - dtype=np.int64, - ) - sorted_vals = arr[arr[:, 0].argsort()] - object.__setattr__(self, "_sorted_vals", sorted_vals) - - # Pre-compute a dictionary to map sample UUIDs to values - by_uuid = defaultdict(list) - for i, (s_uuid, value) in enumerate(self.rows): - if self.repeats is not None: - value = (value, self.repeats[i]) - by_uuid[s_uuid].append(value) - object.__setattr__(self, "_by_uuid", by_uuid) - - def __getitem__(self, index: int) -> MetricRow: - """Returns the MetricRow at the given index / row number in the table. - - Returns: - MetricRow: The MetricRow at the given index / row number in the table. - """ - length = len(self) - if index >= length: - raise IndexError(f"Index {index} out of range for {self.metric_type}") - - while index < 0: - index += length - - if self.repeats is None: - return MetricRow(self.rows[index][0], self.metric_type, self.rows[index][1]) - else: - passed = 0 - for i, repeat in enumerate(self.repeats): - next_row_start = passed + repeat - if index < next_row_start: - return MetricRow(self.rows[i][0], self.metric_type, self.rows[i][1]) - else: - passed = next_row_start - # This should never happen if our index validation is correct - raise IndexError(f"Index {index} out of range for {self.metric_type}") - - def __len__(self) -> int: - if self.repeats is None: - return len(self.rows) - else: - return int(sum(self.repeats)) - - def filter_uuid(self, uuid: str, only_first: bool = False) -> Any: - """Returns the values for the given sample UUID. - - Args: - uuid: The sample UUID to filter by. - only_first: Whether to only return the first value for the sample UUID. - Returns: - The values for the given sample UUID as a tuple. If only_first is True, - returns the first value directly, unless no values are found, in which - case None is returned. 
- """ - values = self._by_uuid[uuid] - - # Expand values if there are counts - if self.repeats is not None: - if only_first: # If we only want the first value, we don't need to expand - return values[0][0] - - expanded_values = [] - for value, count in values: - expanded_values.extend([value] * count) - values = expanded_values - - if only_first: - if len(values) == 0: - return None - return values[0] - return tuple(values) - - def __contains__(self, uuid: str) -> bool: - """Returns True if the given sample UUID is in the table.""" - return uuid in self._by_uuid - - def summarize( - self, - percentiles: Iterable[float] = (99.9, 99, 97, 95, 90, 80, 75, 50, 25, 10, 5, 1), - ) -> dict[str, float]: - if len(self._sorted_vals) == 0: - return { - "total": 0.0, - "min": 0.0, - "max": 0.0, - "median": 0.0, - "avg": 0.0, - "std_dev": 0.0, - "percentiles": {str(p): 0.0 for p in percentiles}, - "histogram": { - "buckets": [], - "counts": [], - }, - } - else: - # Note values are sorted, we can avoid using np.max and np.min - # Need to convert to default Python types since msgspec doesn't support numpy dtypes - if self.repeats is None: - values = self._sorted_vals - counts = np.ones(self._sorted_vals.shape, dtype=self._sorted_vals.dtype) - else: - values = self._sorted_vals[:, 0] - counts = self._sorted_vals[:, 1] - - total = int((values * counts).sum()) - minimum = int(values[0]) - maximum = int(values[-1]) - median = self.percentile(50) - avg = float(np.average(values, weights=counts)) - if self.repeats is None: - std_dev = float(np.std(values)) - else: - deviations_squared = (values - avg) ** 2 - std_dev = float( - np.sqrt(np.sum(deviations_squared * counts) / counts.sum()) - ) - summary = { - "total": total, - "min": minimum, - "max": maximum, - "median": median, - "avg": avg, - "std_dev": std_dev, - "percentiles": { - str(p): v for p, v in self.percentile(percentiles).items() - }, - } - - # Add histogram - buckets, counts = self.to_histogram(n_buckets=10) - summary["histogram"] = { - "buckets": buckets, - "counts": counts, - } - return summary - - def to_histogram( - self, - n_buckets: int = 20, - convert_to_native_types: bool = True, - ) -> tuple[list[tuple[float, float]], list[int]] | tuple[np.ndarray, np.ndarray]: - """Returns a histogram of the metrics values. - - The returned buckets are uniformly sized, distributed between the min and max values, with an - inclusive lower bound and exclusive upper bound. - - Args: - n_buckets: The number of buckets to create. Alternatively, any valid argument for `bins` in `np.histogram` - can be provided. - convert_to_native_types: Whether to convert the buckets and counts to native Python types. - If False, returns numpy arrays. (Default: True) - - Returns: - A tuple of lists, the first list is the buckets, the second list is the counts. - If convert_to_native_types is False, returns a numpy arrays instead. 
- """ - if self.repeats is None: - values = self._sorted_vals - repeats = None - else: - values = self._sorted_vals[:, 0] - repeats = self._sorted_vals[:, 1] - - # Derive bins from values - bounds = np.histogram_bin_edges(values, bins=n_buckets) - - counts, _ = np.histogram(values, bins=bounds, weights=repeats) - if not convert_to_native_types: - buckets = np.zeros((len(bounds) - 1, 2), dtype=bounds.dtype) - buckets[:, 0] = bounds[:-1] - buckets[:, 1] = bounds[1:] - return buckets, counts - - buckets = [ - (float(bounds[i]), float(bounds[i + 1])) for i in range(len(bounds) - 1) - ] - return buckets, counts.tolist() - - def percentile( - self, - percentile: float | list[float] | tuple[float, ...], - interpolate_strategy: str = "linear", - ) -> float | dict[float, float]: - """Compute the percentile(s) of the metric values. - The value returned is the value of the metric at the index marking the percentile, - not an interpolated value. - - Args: - percentile: The percentile(s) to compute. If a single value, returns a single value. If a list of values, returns a dictionary of values. - interpolate_strategy: The percentile interpolation string to use for numpy.percentile. See - https://numpy.org/doc/2.2/reference/generated/numpy.percentile.html to see what interpolation methods are available. - (Default: "linear") - - Returns: - A single value if a single percentile is provided, a dictionary of values if a list of percentiles is provided. - """ - if not isinstance(percentile, (numbers.Number | Iterable)): - raise TypeError( - f"percentile must be a number or an iterable of numbers, got {type(percentile)}" - ) - - if isinstance(percentile, Iterable): - if len(percentile) == 0: - return {} - - if not isinstance(percentile[0], numbers.Number): - raise TypeError( - f"percentile must be an iterable of numbers, got Iterable[{type(percentile[0])}]" - ) - - if self.repeats is None: - perc_values = np.percentile( - self._sorted_vals, - percentile, - overwrite_input=False, - method=interpolate_strategy, - ) - else: - values = self._sorted_vals[:, 0] - counts = self._sorted_vals[:, 1] - perc_values = np.percentile( - values, - percentile, - weights=counts, - overwrite_input=False, - method="inverted_cdf", - ) - - if isinstance(percentile, numbers.Number): - return float(perc_values) - else: - return {p: float(v) for p, v in zip(percentile, perc_values, strict=False)} - - -@dataclasses.dataclass(frozen=True) -class Report: - """Represents a summarized report of metrics""" - - version: str - git_sha: str | None - test_started_at: int - n_samples_issued: int - n_samples_completed: int - n_samples_failed: int - duration_ns: int | None - - # For the following metrics, the key is a rollup statistic (i.e. mean, median, etc.) - ttft: dict[str, float] - tpot: dict[str, float] - latency: dict[str, float] - output_sequence_lengths: dict[str, int] - tpot_reporting_mode: TPOTReportingMode = TPOTReportingMode.REQUEST_WEIGHTED - - @functools.cached_property - def qps(self) -> float | None: - """Calculates the queries (or samples) per second (QPS) based on actual throughput. - - This is the actual throughput: total completed samples divided by test duration. - If duration is 0, which shouldn't happen in practice, returns None. - - Returns: - The QPS or None if duration is 0. 
- """ - if not self.duration_ns: - return None - return float(self.n_samples_completed / (self.duration_ns / 1e9)) - - @functools.cached_property - def tps(self) -> float | None: - """Calculates the tokens per second based on the output sequence lengths and duration. - - Returns: - The tokens per second or None if duration is 0. - """ - if not self.duration_ns: - return None - if not self.output_sequence_lengths: - return None - return float(self.output_sequence_lengths["total"] / (self.duration_ns / 1e9)) - - def to_json(self, save_to: os.PathLike | None = None) -> str: - """Returns a JSON string representation of the report. - - Args: - save_to: If provided, saves the serialized JSON to the given path. - - Returns: - The JSON string representation of the report. - """ - d = dataclasses.asdict(self) - d["qps"] = self.qps - d["tps"] = self.tps - json_str = msgspec.json.format( - msgspec.json.encode(dict(sorted(d.items()))), indent=2 - ).decode("utf-8") - if save_to is not None: - with Path(save_to).open("w") as f: - f.write(json_str) - return json_str - - @staticmethod - def _display_metric( - metric_dict, - fn: Callable[[str], None] = print, - unit: str = "", - max_bar_length: int = 30, - scale_factor: float = 1.0, - newline: str = "", - ) -> None: - """Displays a metric dictionary in a human-readable format. - - Args: - metric_dict: The metric dictionary to display. - fn: The function to call to print a string, such as logging.info, file.write, etc. (Default: `print`) - unit: The string representing the unit of the metric - max_bar_length: The maximum length of the bar to display for the histogram - scale_factor: The factor to scale metric values by. (Default: 1.0) - newline: The newline character to append to each line. Set to "\\n" for file.write. (Default: "") - """ - for name, key in [ - ("Min", "min"), - ("Max", "max"), - ("Median", "median"), - ("Avg.", "avg"), - ("Std Dev.", "std_dev"), - ]: - fn(f" {name}: {metric_dict[key] * scale_factor:.2f} {unit}{newline}") - fn(f"\n Histogram:{newline}") - - # Display histogram - buckets = metric_dict["histogram"]["buckets"] - counts = metric_dict["histogram"]["counts"] - - if len(buckets) > 0: - bucket_strs = [] - for lower, upper in buckets: - if upper is None: - bucket_strs.append(f" {lower * scale_factor:.2f}+") - else: - bucket_strs.append( - f" [{lower * scale_factor:.2f}, {upper * scale_factor:.2f})" - ) - - normalize_factor = max_bar_length / max(counts) - max_bucket_str_len = max(len(s) for s in bucket_strs) - - for bucket_str, count in zip(bucket_strs, counts, strict=False): - bar_length = int(count * normalize_factor) - fn( - f" {bucket_str:>{max_bucket_str_len}} |{'#' * bar_length} {count}{newline}" - ) - - fn(f"\n Percentiles:{newline}") - max_percentile_str_len = max( - len(str(p)) for p in metric_dict["percentiles"].keys() - ) - for percentile, value in metric_dict["percentiles"].items(): - fn( - f" {percentile:>{max_percentile_str_len}}: {value * scale_factor:.2f} {unit}{newline}" - ) - - def display( - self, - fn: Callable[[str], None] = print, - summary_only: bool = False, - newline: str = "", - ) -> None: - """Displays the report in a human-readable format. - - Args: - fn: The function to call to print a string, such as logging.info, file.write, etc. (Default: `print`) - newline: The newline character to append to each line. Set to "\\n" for file.write. 
(Default: "") - """ - - fn(f"----------------- Summary -----------------{newline}") - fn(f"Version: {self.version}{newline}") - if self.git_sha: - fn(f"Git SHA: {self.git_sha}{newline}") - # Approximate absolute time of the test started at using monotime_to_datetime from utils.py - test_started_at_approx = monotime_to_datetime(self.test_started_at) - fn( - f"Test started at: (timestamp_ns):{self.test_started_at}, approx. wall-clock time: ({test_started_at_approx.strftime('%Y-%m-%d %H:%M:%S')}){newline}" - ) - fn(f"Total samples issued: {self.n_samples_issued}{newline}") - fn(f"Total samples completed: {self.n_samples_completed}{newline}") - fn(f"Total samples failed: {self.n_samples_failed}{newline}") - if self.duration_ns is not None: - fn(f"Duration: {self.duration_ns / 1e9:.2f} seconds{newline}") - else: - fn(f"Duration: N/A (no performance samples were issued){newline}") - - if self.qps is not None: - fn(f"QPS: {self.qps:.2f}{newline}") - else: - fn(f"QPS: N/A (no performance samples were issued){newline}") - - if self.tps is not None: - fn(f"TPS: {self.tps:.2f}{newline}") - - if summary_only: - fn(f"----------------- End of Summary -----------------{newline}") - return - - fn(f"\n\n------------------- Latency Breakdowns -------------------{newline}") - if len(self.latency) > 0 and self.ttft == 0: - fn( - f"WARNING: Non-streaming-based Issuer used. TTFT metrics cannot be calculated{newline}" - ) - - for section_name, metric_dict, unit, scale_factor in [ - ("TTFT", self.ttft, "ms", 1e-6), - (f"TPOT ({self.tpot_reporting_mode.value})", self.tpot, "ms", 1e-6), - ("Latency", self.latency, "ms", 1e-6), - ("Output sequence lengths", self.output_sequence_lengths, "tokens", 1.0), - ]: - if metric_dict is None or len(metric_dict) == 0: - continue - fn(f"{section_name}:{newline}") - Report._display_metric( - metric_dict, - fn=fn, - unit=unit, - scale_factor=scale_factor, - newline=newline, - ) - fn(f"\n{newline}") - - -def _output_sequence_to_str(output_sequence: str | list[str]) -> str | None: - if isinstance(output_sequence, list): - return "".join(output_sequence) - elif isinstance(output_sequence, str): - return output_sequence - else: - return None - - -def output_sequence_from_data( - data_bytes: bytes | None, - join_chunks: bool = True, -) -> tuple[str | list[str] | None, str | list[str] | None]: - """Parse the data column from a COMPLETE event and extract output and reasoning sequences. - - The data column is expected to be a JSON-encoded byte string. The decoded value can be: - - A string: treated as the output sequence directly - - A list: tagged msgspec array from TextModelOutput (array_like=True, tag=True), - formatted as ["TextModelOutput", output, reasoning] - - A dictionary with 'output' key (required) and optionally 'reasoning' key - - Both 'output' and 'reasoning' can be either strings or lists of strings - - If a list of strings, they will be joined together - - Args: - data_bytes: The raw bytes from the database 'data' column - join_chunks: Whether to join the chunks into a single string if the data values are lists of strings - (Default: True) - Returns: - A tuple of (output_sequence, reasoning_sequence), where each is a string (if join_chunks is True), - list of strings (if join_chunks is False) or None. - If the data cannot be decoded or is invalid, returns (None, None). 
- """ - if data_bytes is None or len(data_bytes) == 0: - return None, None - - try: - decoded_data = msgspec.json.decode(data_bytes) - except (msgspec.DecodeError, TypeError): - logging.warning("Failed to decode data bytes") - return None, None - - output, reasoning = None, None - if isinstance(decoded_data, str): - # If decoded value is a string, it's the output sequence - output = decoded_data - elif isinstance(decoded_data, list): - # Tagged msgspec array_like Struct: ["TextModelOutput", output, reasoning] - # The tag is at index 0, output at index 1, reasoning at index 2 - if len(decoded_data) < 2 or decoded_data[0] != "TextModelOutput": - logging.warning( - f"Invalid TextModelOutput tagged array data: {decoded_data}" - ) - return None, None - raw_output = decoded_data[1] - raw_reasoning = decoded_data[2] if len(decoded_data) > 2 else None - output = _output_sequence_to_str(raw_output) if join_chunks else raw_output - if output is None and raw_output is not None: - logging.warning(f"Output field has unexpected type: {type(raw_output)}") - return None, None - if raw_reasoning is not None: - reasoning = ( - _output_sequence_to_str(raw_reasoning) if join_chunks else raw_reasoning - ) - elif isinstance(decoded_data, dict): - # If decoded value is a dict, extract 'output' and optionally 'reasoning' - if "output" not in decoded_data: - logging.warning("Dictionary data missing required 'output' key") - return None, None - - # Extract output - can be string or list of strings - output = ( - _output_sequence_to_str(decoded_data["output"]) - if join_chunks - else decoded_data["output"] - ) - if output is None: - logging.warning(f"Output field has unexpected type: {type(output)}") - return None, None - - # Extract reasoning if present - can be string or list of strings - if "reasoning" in decoded_data: - reasoning = ( - _output_sequence_to_str(decoded_data["reasoning"]) - if join_chunks - else decoded_data["reasoning"] - ) - else: - logging.warning(f"Decoded data has unexpected type: {type(decoded_data)}") - return None, None - return output, reasoning - - -class MetricsReporter: - """Derives metrics from events via rollup queries. This is a *read only* client.""" - - def __init__( - self, - connection_name: os.PathLike, - client_type: str = "duckdb", - ): - """ - Creates a new MetricsReporter. - - Args: - connection_name: The path to the database to connect to. - client_type: The client type to use to connect to the database. 
Choices: ["duckdb", "sqlite"] (Default: "duckdb") - """ - self.connection_name = Path(connection_name) - self.client_type = client_type - self.is_closed = True - - def init_connection(self): - if not self.is_closed: - logging.debug(f"Connection already initialized at {self.connection_name}") - return - - if self.client_type == "duckdb": - logging.debug(f"Initializing duckdb connection at {self.connection_name}") - if importlib.util.find_spec("duckdb") is None: - raise ImportError("duckdb is not installed") - duckdb = importlib.import_module("duckdb") - # Install sqlite extension - self.conn = duckdb.connect() - - logging.debug("Installing sqlite extension for duckdb") - # duckdb doesn't inherit proxy variables from environment - # Try setting proxy explicitly if environment variables are not enough - proxy = os.environ.get("http_proxy") or os.environ.get("HTTP_PROXY") - if proxy: - logging.debug(f"Setting http_proxy to {proxy} for duckdb") - # Use parameterized query to safely set http_proxy - self.conn.execute("SET http_proxy=?", [proxy]) - self.conn.install_extension("sqlite") - self.conn.load_extension("sqlite") - - logging.debug( - f"Attaching {self.connection_name} to duckdb in read-only mode" - ) - self.conn.execute( - f"ATTACH '{self.connection_name}' AS sqlite_db (TYPE sqlite, READ_ONLY)" - ) - self.conn.execute("USE sqlite_db") - - self.cur_ = ( - self.conn - ) # duckdb calls execute() on connection, there is no cursor object - elif self.client_type == "sqlite": - logging.debug( - f"Initializing read-only sqlite connection at {self.connection_name}" - ) - self.conn = sqlite3.connect( - f"file:{self.connection_name}?mode=ro", uri=True - ) - self.cur_ = self.conn.cursor() - else: - raise ValueError(f"Invalid client type: {self.client_type}") - self.is_closed = False - - @functools.cached_property - def stop_performance_tracking_timestamp_ns(self) -> float: - """Returns the timestamp_ns of the STOP_PERFORMANCE_TRACKING event. - - This method is cached to prevent re-derivation. If the event is not found, - returns positive infinity, since this indicates that the performance run is probably still - running, or the test was killed before it could complete. - - Returns: - float: The timestamp_ns of STOP_PERFORMANCE_TRACKING event, or float('inf') if not found. - """ - result = self.cur_.execute(f""" - SELECT timestamp_ns - FROM events - WHERE event_type = '{SessionEvent.STOP_PERFORMANCE_TRACKING.value}' - LIMIT 1 - """).fetchone() - - if result is None: - logging.warning( - "No STOP_PERFORMANCE_TRACKING event found, performance run not yet complete" - ) - return float("inf") - return float(result[0]) - - @profile - def derive_metric(self, query: str, metric_type: str) -> RollupQueryTable: - res = self.cur_.execute(query) - logging.debug(f"Roll-up for {metric_type}. 
Running query: {query}") - return RollupQueryTable(metric_type, query, res.fetchall()) - - def derive_TTFT(self) -> RollupQueryTable: - stop_ts = self.stop_performance_tracking_timestamp_ns - - # Build the HAVING clause conditionally to handle infinity - if stop_ts != float("inf"): - before_stop_ts_clause = f""" - HAVING COUNT(DISTINCT event_type) = 2 - AND MAX(CASE WHEN event_type = '{SessionEvent.LOADGEN_ISSUE_CALLED.value}' THEN timestamp_ns END) < {stop_ts} - """ - else: - before_stop_ts_clause = """ - HAVING COUNT(DISTINCT event_type) = 2 - """ - - return self.derive_metric( - f""" - SELECT - sample_uuid, - MAX(CASE WHEN event_type = '{SampleEvent.FIRST_CHUNK.value}' THEN timestamp_ns END) - - MAX(CASE WHEN event_type = '{SessionEvent.LOADGEN_ISSUE_CALLED.value}' THEN timestamp_ns END) AS ttft - FROM events - WHERE event_type IN ('{SessionEvent.LOADGEN_ISSUE_CALLED.value}', '{SampleEvent.FIRST_CHUNK.value}') - GROUP BY sample_uuid - {before_stop_ts_clause} - """, - "ttft", - ) - - def dump_all_to_csv(self, csv_path: Path): - logging.debug(f"Dumping to CSV at {csv_path}") - with csv_path.open("w", newline="") as f: - writer = csv.writer(f) - query = """ - SELECT - sample_uuid, - timestamp_ns, - event_type - FROM events - """ - rows = self.cur_.execute(query).fetchall() - writer.writerows(rows) - logging.debug(f"Written rows {len(rows)} to {csv_path}") - - def derive_duration(self, check_malformed: bool = True) -> float | None: - """Calculates the total test duration. - - If STOP_PERFORMANCE_TRACKING event exists: - - This method will return T_(last_perf_sample) - T_(test_started) where: - - T_(test_started) is the timestamp of the TEST_STARTED event and - - T_(last_perf_sample) is the timestamp of the latest COMPLETE event present - whose sample_uuid has a corresponding LOADGEN_ISSUE_CALLED event before - the STOP_PERFORMANCE_TRACKING event. - - If for some reason, no samples were issued before the STOP_PERFORMANCE_TRACKING event, - such as in the case of running an accuracy-only test, then this method will return None. - - If STOP_PERFORMANCE_TRACKING does not exist: - - This method will return the max(timestamp_ns) - T_(test_started) where: - - T_(test_started) is the timestamp of the TEST_STARTED event and - - max(timestamp_ns) is the largest timestamp_ns in the events database. - - An error is raised if TEST_ENDED is present, but not the event associated with max(timestamp_ns) - - If `check_malformed` is False, no checks for the error-conditions above are performed. This is useful - to in cases where the latency of this method matters, and we would like to avoid executing extra queries. - In this case, the caller can periodically set check_malformed to True to perform verification in intervals. - - Args: - check_malformed: Whether to check for malformed events. (Default: True) - - Raises: - RuntimeError: If TEST_STARTED is not present or occurs more than once - RuntimeError: If TEST_ENDED exists but is not the maximum timestamp_ns - RuntimeError: If more than one TEST_ENDED event exists - - Returns: - float: The duration in nanoseconds, None if no performance samples were issued. 
- """ - # Validate TEST_STARTED exists exactly once - test_started_result = self.cur_.execute(f""" - SELECT COUNT(*) AS n_starts, MAX(timestamp_ns) AS start_ts - FROM events - WHERE event_type = '{SessionEvent.TEST_STARTED.value}' - """).fetchone() - - n_test_started = test_started_result[0] - test_started_ts = test_started_result[1] - - # Return None early if no TEST_STARTED event to avoid errors in duration calculations - if test_started_ts is None or n_test_started == 0: - if check_malformed: - raise RuntimeError("TEST_STARTED event not found in database") - return None - - if check_malformed and n_test_started > 1: - raise RuntimeError( - f"Multiple TEST_STARTED events found - {n_test_started} events" - ) - - # Check if STOP_PERFORMANCE_TRACKING event exists - stop_ts = self.stop_performance_tracking_timestamp_ns - - if stop_ts != float("inf"): - # Build list of sample_uuids with LOADGEN_ISSUE_CALLED before stop_ts - # Then find the max timestamp_ns of any event from those sample_uuids - max_perf_ts_result = self.cur_.execute(f""" - SELECT MAX(timestamp_ns) AS max_perf_ts - FROM events - WHERE sample_uuid IN ( - SELECT DISTINCT sample_uuid - FROM events - WHERE event_type = '{SessionEvent.LOADGEN_ISSUE_CALLED.value}' - AND timestamp_ns < {stop_ts} - ) - AND event_type = '{SampleEvent.COMPLETE.value}' - """).fetchone() - - max_perf_ts = max_perf_ts_result[0] - if max_perf_ts is None: - # No samples were issued before stop_ts - return None - - return float(max_perf_ts - test_started_ts) - else: - # No STOP_PERFORMANCE_TRACKING, use max timestamp_ns in database - # Get max timestamp in database - max_ts_result = self.cur_.execute(""" - SELECT MAX(timestamp_ns) AS max_ts - FROM events - """).fetchone() - max_ts = max_ts_result[0] - - if check_malformed: - # Validate TEST_ENDED constraints - test_ended_result = self.cur_.execute(f""" - SELECT COUNT(*) AS n_ends, MAX(timestamp_ns) AS end_ts - FROM events - WHERE event_type = '{SessionEvent.TEST_ENDED.value}' - """).fetchone() - - n_test_ended = test_ended_result[0] - test_ended_ts = test_ended_result[1] - - if n_test_ended > 1: - raise RuntimeError( - f"Multiple TEST_ENDED events found - {n_test_ended} events" - ) - - # If TEST_ENDED exists, it must be the maximum timestamp - if n_test_ended == 1 and test_ended_ts != max_ts: - raise RuntimeError( - f"TEST_ENDED exists (timestamp_ns={test_ended_ts}) but is not the maximum timestamp in database (max={max_ts})" - ) - - if max_ts is None: - return None - - return float(max_ts - test_started_ts) - - def derive_sample_latency(self) -> RollupQueryTable: - """Calculates the end-to-end latency for each sample from issue to completion. - - Returns: - RollupQueryTable: A table containing per-sample latencies in nanoseconds. 
- """ - stop_ts = self.stop_performance_tracking_timestamp_ns - - # HAVING clause is different if there is a STOP_PERFORMANCE_TRACKING event - if stop_ts != float("inf"): - before_stop_ts_clause = f""" - HAVING COUNT(DISTINCT event_type) = 2 - AND MAX(CASE WHEN event_type = '{SessionEvent.LOADGEN_ISSUE_CALLED.value}' THEN timestamp_ns END) < {stop_ts} - """ - else: - before_stop_ts_clause = """ - HAVING COUNT(DISTINCT event_type) = 2 - """ - - return self.derive_metric( - f""" - SELECT - sample_uuid, - MAX(CASE WHEN event_type = '{SampleEvent.COMPLETE.value}' THEN timestamp_ns END) - - MAX(CASE WHEN event_type = '{SessionEvent.LOADGEN_ISSUE_CALLED.value}' THEN timestamp_ns END) AS latency - FROM events - WHERE event_type IN ('{SessionEvent.LOADGEN_ISSUE_CALLED.value}', '{SampleEvent.COMPLETE.value}') - GROUP BY sample_uuid - {before_stop_ts_clause} - """, - "sample_latency", - ) - - @profile - def get_sample_statuses(self) -> dict[int, str]: - """Returns a dictionary with the following keys: - - "total_sent" (int): The total number of samples sent - - "completed" (int): The number of samples completed - - "in_flight" (int): The number of samples in flight - """ - stop_ts = self.stop_performance_tracking_timestamp_ns - - # Build WHERE clause to filter samples issued before stop_ts - where_clause = "" - if stop_ts != float("inf"): - where_clause = f""" - WHERE sample_uuid IN ( - SELECT sample_uuid FROM events - WHERE event_type = '{SessionEvent.LOADGEN_ISSUE_CALLED.value}' - AND timestamp_ns < {stop_ts} - ) - """ - - statuses = self.cur_.execute(f""" - SELECT - COUNT(DISTINCT CASE WHEN event_type = '{SessionEvent.LOADGEN_ISSUE_CALLED.value}' THEN sample_uuid END) AS request_sent_count, - COUNT(DISTINCT CASE WHEN event_type = '{SampleEvent.COMPLETE.value}' THEN sample_uuid END) AS complete_count - FROM events - {where_clause} - """).fetchone() - - return { - "total_sent": statuses[0], - "completed": statuses[1], - "in_flight": statuses[0] - statuses[1], - } - - def get_error_count(self) -> int: - """Returns the number of distinct samples that encountered an error within the performance window. - - A sample with multiple ERROR events is counted only once. Only samples whose - LOADGEN_ISSUE_CALLED event occurred before the STOP_PERFORMANCE_TRACKING timestamp - are included, keeping this metric consistent with n_samples_issued and n_samples_completed. - If no STOP_PERFORMANCE_TRACKING event exists, all errored samples are counted. - - Returns: - int: The number of distinct failed sample UUIDs. - """ - stop_ts = self.stop_performance_tracking_timestamp_ns - - where_clause = "" - if stop_ts != float("inf"): - where_clause = f""" - AND sample_uuid IN ( - SELECT DISTINCT sample_uuid FROM events - WHERE event_type = '{SessionEvent.LOADGEN_ISSUE_CALLED.value}' - AND timestamp_ns < {stop_ts} - ) - """ - - return self.cur_.execute(f""" - SELECT - COUNT(DISTINCT sample_uuid) AS error_count - FROM events - WHERE event_type = '{SessionEvent.ERROR.value}' - AND sample_uuid NOT IN ('', '') - {where_clause} - """).fetchone()[0] - - def get_sample_outputs( - self, performance_only: bool = True - ) -> list[tuple[str, bytes]]: - """Query for COMPLETE events with their data column. - - Args: - performance_only: Whether to only include samples that are in the performance window. (Default: True) - - Returns: - A list of tuples containing (sample_uuid, data_bytes) for each COMPLETE event. - Returns an empty list if no COMPLETE events are found. 
- """ - stop_ts = self.stop_performance_tracking_timestamp_ns - - # Build WHERE clause to filter samples issued before STOP_PERFORMANCE_TRACKING - if performance_only and stop_ts != float("inf"): - before_stop_ts_clause = f""" - AND sample_uuid IN ( - SELECT sample_uuid FROM events - WHERE event_type = '{SessionEvent.LOADGEN_ISSUE_CALLED.value}' - AND timestamp_ns < {stop_ts} - ) - """ - else: - before_stop_ts_clause = "" - - # Query for COMPLETE events with their data column - query_result = self.cur_.execute(f""" - SELECT sample_uuid, data - FROM events - WHERE event_type = '{SampleEvent.COMPLETE.value}' - {before_stop_ts_clause} - """).fetchall() - - return query_result - - @profile - def get_output_sequence_lengths( - self, tokenizer: Tokenizer - ) -> RollupQueryTable | None: - """Returns a RollupQueryTable representing per-sample output sequence lengths based on a Tokenizer. - - Reads output data from the 'data' column of COMPLETE events in the database. - - Args: - tokenizer: A Tokenizer object from HuggingFace - - Returns: - RollupQueryTable: A table containing per-sample output sequence lengths, or None if no complete events found. - """ - query_result = self.get_sample_outputs() - - rows = [] - for sample_uuid, data_bytes in query_result: - output_sequence, reasoning_sequence = output_sequence_from_data(data_bytes) - - if output_sequence is None: - continue - - # Concatenate reasoning and output if reasoning exists - if reasoning_sequence is not None: - full_sequence = f"{reasoning_sequence} {output_sequence}" - else: - full_sequence = output_sequence - - # Tokenize and calculate length - output_tokens = tokenizer.tokenize(full_sequence) - rows.append((sample_uuid, len(output_tokens))) - - if not rows: - return None - - return RollupQueryTable("output_sequence_length", None, rows) - - @profile - def derive_TPOT( - self, - tokenizer: Tokenizer, - ttft_rollup: RollupQueryTable | None = None, - sample_latency_rollup: RollupQueryTable | None = None, - condense_table: bool = True, - reporting_mode: TPOTReportingMode = TPOTReportingMode.REQUEST_WEIGHTED, - ) -> RollupQueryTable | None: - """Derives the TPOT metric from the text outputs, ttft, and sample latencies. - - Roughly, if a sample UUID `X` has a TTFT of `a`, a total latency of `b`, and an output sequence `S`, - then `X` will contribute `len(tokenize(S)) - 1` entries in the table, each with the value: - `(b - a) / len(tokenize(S) - 1)` - If the sample was completed in non-streaming mode however, then `a` is assumed to be 0, and `X` will - instead contribute `len(tokenize(S))` entries, each with the value: `b / len(tokenize(S))` - - TPOT tracks the time it takes for each token after the first to be generated (in streaming mode). Since - the client does not have direct visibility into the endpoint / server-under-test, we have to estimate it, - assuming that in an ideal scenario, each token outputed in the output text took the same amount of - time. - - Args: - tokenizer: A Tokenizer object from HuggingFace, used to calculate the number of tokens in a sequence - ttft_rollup: Precomputed TTFT RollupQueryTable. If not provided, will be derived via self.derive_TTFT() - sample_latency_rollup: Precomputed sample latency RollupQueryTable. If not provided, will be derived via self.derive_sample_latency() - condense_table: Whether to condense the table by not storing individual token times, but rather just keeping the average time per token - and number of tokens per sample UUID. This is only supported if reporting_mode is TOKEN_WEIGHTED. 
- If reporting_mode is REQUEST_WEIGHTED, each sample only contributes one entry to the table. (Default: True)
- reporting_mode: TPOT reporting mode (REQUEST_WEIGHTED or TOKEN_WEIGHTED). (Default: REQUEST_WEIGHTED)
- """
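A worked instance of the docstring's streaming formula, with a trivial whitespace tokenizer and invented timings (illustration only; the real code below tokenizes the non-first chunks rather than re-tokenizing all of `S`):

```python
# Worked example of the streaming TPOT estimate described above.
def tokenize(text: str) -> list[str]:
    return text.split()

a = 200_000_000  # TTFT in ns (0.2 s)
b = 1_200_000_000  # total sample latency in ns (1.2 s)
S = "the quick brown fox jumps"  # 5 tokens

n = len(tokenize(S))
tpot = (b - a) / (n - 1)  # (b - a) / (len(tokenize(S)) - 1)
assert tpot == 250_000_000  # 4 non-first tokens, 0.25 s each

# Non-streaming: a is taken as 0, so each of the n entries is b / n.
assert b / n == 240_000_000
```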
- if ttft_rollup is None:
- ttft_rollup = self.derive_TTFT()
-
- # If no TTFT data available, TPOT cannot be calculated accurately for streaming mode
- if len(ttft_rollup) == 0:
- return None
-
- if sample_latency_rollup is None:
- sample_latency_rollup = self.derive_sample_latency()
-
- # Query for COMPLETE events with their data column
- query_result = self.get_sample_outputs()
-
- if not query_result:
- return None
-
- rows = []
- if condense_table and reporting_mode == TPOTReportingMode.TOKEN_WEIGHTED:
- repeats = []
- else:
- repeats = None
-
- for sample_uuid, data_bytes in query_result:
- if data_bytes is None or len(data_bytes) == 0:
- continue
-
- # Extract output from decoded data
- # For TPOT calculation, we need the output to be a list of chunks (streaming mode) with at least 2
- # elements
- output_sequence, reasoning_sequence = output_sequence_from_data(
- data_bytes, join_chunks=False
- )
- if isinstance(output_sequence, str):
- output_sequence = [output_sequence]
- if not isinstance(output_sequence, list):
- logging.warning(
- f"Output sequence for sample {sample_uuid} is not a list but {type(output_sequence)}: {output_sequence}"
- )
- continue
-
- all_chunks = output_sequence
- if isinstance(reasoning_sequence, list):
- all_chunks.extend(reasoning_sequence)
-
- # For TPOT, we need streaming data (list of chunks with at least 2 elements)
- if len(all_chunks) < 2:
- continue
-
- # Skip samples that are not in the filtered rollups (i.e., issued after STOP_PERFORMANCE_TRACKING)
- if sample_uuid not in sample_latency_rollup:
- continue
-
- # Output can be in one of two formats depending on the issuer:
- # 1. A list of all chunks (i.e. ['chunk1', 'chunk2', ...])
- # 2. A 2 item list of ['chunk1', 'chunk2chunk3...']
- # Both of these are valid as we only need to distinguish the first chunk for the purposes of TPOT calculation.
- # The choice is up to the issuer implementation depending on performance considerations.
-
- # Join list elements to get the non-first chunk text
- if len(all_chunks) > 2:
- non_first_chunk = "".join(str(chunk) for chunk in all_chunks[1:])
- else:
- non_first_chunk = str(all_chunks[1])
-
- if len(non_first_chunk) == 0:
- # Possible malformed output data where empty string is included as a non-first chunk
- continue
-
- non_first_tokens = tokenizer.tokenize(non_first_chunk)
- n_non_first_tokens = len(non_first_tokens)
-
- latency = sample_latency_rollup.filter_uuid(sample_uuid, only_first=True)
- if latency is None:
- raise SampleUUIDNotFoundError(sample_uuid, "events record")
-
- ttft = ttft_rollup.filter_uuid(sample_uuid, only_first=True)
- if ttft is None:
- # Non-streaming mode for this sample - error
- raise RuntimeError(
- f"No TTFT found for sample {sample_uuid} in streaming mode"
- )
-
- avg_tpot = (latency - ttft) / n_non_first_tokens
-
- if condense_table:
- rows.append((sample_uuid, avg_tpot))
- if reporting_mode == TPOTReportingMode.TOKEN_WEIGHTED:
- repeats.append(n_non_first_tokens)
- else:
- # Entries are tuples, and as such are immutable. We can use list multiplication for performance
- repeat_fac = (
- 1
- if reporting_mode == TPOTReportingMode.REQUEST_WEIGHTED
- else n_non_first_tokens
- )
- rows.extend([(sample_uuid, avg_tpot)] * repeat_fac)
-
- if not rows:
- return None
-
- return RollupQueryTable("tpot", None, rows, repeats=repeats)
-
- def close(self):
- if self.is_closed:
- logging.debug("Connection is already closed, skipping close")
- return
- self.is_closed = True
-
- if self.cur_ is not self.conn:
- self.cur_.close()
- self.conn.close()
-
- def dump_to_json(self, json_path: Path):
- """
- Dumps all events to a JSONL file, including decoded output data from the 'data' column.
- Each line in the output file is a valid JSON object.
- """
-
- with json_path.open("w", encoding="utf-8", newline="") as f:
- query_result = self.cur_.execute(
- "SELECT sample_uuid, event_type, timestamp_ns, data FROM events"
- )
- while True:
- if hasattr(query_result, "fetchmany"):
- rows = query_result.fetchmany(1000)
- else:
- rows = query_result.fetchall()
-
- if not rows:
- break
-
- for sample_uuid, event_type, timestamp_ns, data_bytes in rows:
- value = ""
-
- # For events with data, decode and extract the relevant value
- if data_bytes is not None and len(data_bytes) > 0:
- if event_type == SampleEvent.COMPLETE.value:
- # For COMPLETE, use helper method to extract output sequence
- output_seq, reasoning_seq = output_sequence_from_data(
- data_bytes
- )
- if output_seq is not None:
- if reasoning_seq is not None:
- value = f"[reasoning: {reasoning_seq}] {output_seq}"
- else:
- value = output_seq
- elif event_type in (
- SampleEvent.FIRST_CHUNK.value,
- SessionEvent.ERROR.value,
- ):
- # For other event types, just decode and stringify
- try:
- decoded_data = msgspec.json.decode(data_bytes)
- value = str(decoded_data) if decoded_data else ""
- except (msgspec.DecodeError, TypeError) as e:
- value = f"<decode error: {e}>"
-
- approx_datetime_str = monotime_to_datetime(timestamp_ns).isoformat()
-
- json_obj = {
- "sample_uuid": sample_uuid,
- "event_type": event_type,
- "timestamp_ns": timestamp_ns,
- "approx_datetime_str": approx_datetime_str,
- "value": value,
- }
- f.write(
- msgspec.json.encode(dict(sorted(json_obj.items()))).decode(
- "utf-8"
- )
- + "\n"
- )
-
- def __enter__(self):
- if self.is_closed:
- self.init_connection()
- return self
-
- def __exit__(self, exc_type, exc_value, traceback):
- self.close()
-
- def _get_version_info(self) -> dict[str, str | None]:
- """Extract version info from TEST_STARTED event data.
-
- Returns:
- Dictionary with 'version' and 'git_sha' keys.
- """
- query = f"""
- SELECT data FROM events
- WHERE event_type = '{SessionEvent.TEST_STARTED.value}'
- LIMIT 1
- """
- result = self.cur_.execute(query).fetchone()
- if result and result[0]:
- try:
- return msgspec.json.decode(result[0])
- except Exception:
- # Malformed TEST_STARTED data; fall back to the defaults below
- pass
- return {"version": "unknown", "git_sha": None}
-
- def get_test_started_at(self) -> int | None:
- """Gets the timestamp of the TEST_STARTED event.
-
- Returns:
- int|None: The timestamp of the TEST_STARTED event in nanoseconds, or None if not found.
- """ - query = f""" - SELECT timestamp_ns FROM events - WHERE event_type = '{SessionEvent.TEST_STARTED.value}' - ORDER BY timestamp_ns ASC - LIMIT 1""" - result = self.cur_.execute(query).fetchone() - if result and result[0]: - return result[0] - return None - - def create_report( - self, - tokenizer: Tokenizer | None = None, - tpot_reporting_mode: TPOTReportingMode = TPOTReportingMode.REQUEST_WEIGHTED, - ) -> Report: - """Creates a Report object from the metrics. - - Args: - tokenizer: A Tokenizer object from HuggingFace. If provided, output sequence lengths will be calculated. - tpot_reporting_mode: TPOT reporting mode (REQUEST_WEIGHTED or TOKEN_WEIGHTED). (Default: REQUEST_WEIGHTED) - - Returns: - Report: A Report object containing the metrics. - """ - test_started_at = self.get_test_started_at() - if test_started_at is None: - raise RuntimeError("TEST_STARTED event not found in database") - - sample_statuses = self.get_sample_statuses() - ttft_rollup = self.derive_TTFT() - sample_latency_rollup = self.derive_sample_latency() - output_sequence_lengths = None - tpot_summary = None - if tokenizer is not None: - osl_rollup = self.get_output_sequence_lengths(tokenizer) - if osl_rollup is not None: - output_sequence_lengths = osl_rollup.summarize() - - # Only calculate TPOT if TTFT data is available (streaming mode) - if len(ttft_rollup) > 0: - tpot_rollup = self.derive_TPOT( - tokenizer, - ttft_rollup=ttft_rollup, - sample_latency_rollup=sample_latency_rollup, - reporting_mode=tpot_reporting_mode, - ) - if tpot_rollup is not None: - tpot_summary = tpot_rollup.summarize() - - if len(ttft_rollup) == 0: - ttft_summary = None - else: - ttft_summary = ttft_rollup.summarize() - - # Extract version information - version_info = self._get_version_info() - - return Report( - version=version_info.get("version", "unknown"), - git_sha=version_info.get("git_sha"), - test_started_at=test_started_at, - n_samples_issued=sample_statuses["total_sent"], - n_samples_completed=sample_statuses["completed"], - n_samples_failed=self.get_error_count(), - duration_ns=self.derive_duration(), - ttft=ttft_summary, - tpot=tpot_summary, - latency=sample_latency_rollup.summarize(), - output_sequence_lengths=output_sequence_lengths, - tpot_reporting_mode=tpot_reporting_mode, - ) diff --git a/src/inference_endpoint/openai/accumulator.py b/src/inference_endpoint/openai/accumulator.py index 4400766c..6cb23ed8 100644 --- a/src/inference_endpoint/openai/accumulator.py +++ b/src/inference_endpoint/openai/accumulator.py @@ -57,10 +57,8 @@ def add_chunk(self, delta: OpenAISSEDelta) -> StreamChunk | None: chunk = StreamChunk( id=self.query_id, response_chunk=content, - is_complete=False, metadata={ "first_chunk": not self.first_chunk_sent, - "final_chunk": False, }, ) self.first_chunk_sent = True diff --git a/src/inference_endpoint/profiling/line_profiler.py b/src/inference_endpoint/profiling/line_profiler.py index 79aaa144..56c2d659 100644 --- a/src/inference_endpoint/profiling/line_profiler.py +++ b/src/inference_endpoint/profiling/line_profiler.py @@ -28,6 +28,11 @@ import io import os import sys + +try: + from line_profiler import LineProfiler +except ImportError: + LineProfiler = None from collections.abc import Callable from pathlib import Path from typing import Any, Optional, TypeVar @@ -68,18 +73,15 @@ def __init__(self): self._atexit_registered = False if self.enabled: - try: - from line_profiler import LineProfiler - - self.profiler = LineProfiler() - self.profiler.enable() - atexit.register(self._safe_cleanup) - 
self._atexit_registered = True - except ImportError as e: + if LineProfiler is None: raise ImportError( f"line_profiler not installed but {ENV_VAR_ENABLE_LINE_PROFILER}={enable_profiler} is set. " f"Install with: pip install line_profiler" - ) from e + ) + self.profiler = LineProfiler() + self.profiler.enable() + atexit.register(self._safe_cleanup) + self._atexit_registered = True def _safe_cleanup(self): """Safe cleanup wrapper that suppresses all errors during atexit.""" diff --git a/src/inference_endpoint/profiling/pytest_profiling_plugin.py b/src/inference_endpoint/profiling/pytest_profiling_plugin.py index 71641cf9..3a268061 100644 --- a/src/inference_endpoint/profiling/pytest_profiling_plugin.py +++ b/src/inference_endpoint/profiling/pytest_profiling_plugin.py @@ -26,6 +26,7 @@ import atexit import glob import os +import shutil import sys from inference_endpoint.profiling import shutdown @@ -102,8 +103,6 @@ def _print_worker_profiles(): def _cleanup_profile_files(output_file: str): """Remove profile directory and files after displaying results.""" try: - import shutil - profile_dir = os.path.dirname(output_file) if profile_dir and os.path.exists(profile_dir): shutil.rmtree(profile_dir, ignore_errors=True) diff --git a/src/inference_endpoint/sglang/accumulator.py b/src/inference_endpoint/sglang/accumulator.py index 29579e7f..081106eb 100644 --- a/src/inference_endpoint/sglang/accumulator.py +++ b/src/inference_endpoint/sglang/accumulator.py @@ -65,7 +65,6 @@ def add_chunk(self, delta: SGLangSSEDelta) -> StreamChunk | None: chunk = StreamChunk( id=self.query_id, response_chunk=content_diff, - is_complete=False, metadata=metadata, ) self.first_chunk_sent = True diff --git a/src/inference_endpoint/testing/echo_server.py b/src/inference_endpoint/testing/echo_server.py index a957c39a..6555f2e6 100644 --- a/src/inference_endpoint/testing/echo_server.py +++ b/src/inference_endpoint/testing/echo_server.py @@ -29,6 +29,7 @@ from inference_endpoint.core.types import QueryResult, TextModelOutput from inference_endpoint.openai.openai_adapter import OpenAIAdapter from inference_endpoint.openai.openai_types_gen import CreateChatCompletionRequest +from inference_endpoint.utils.logging import setup_logging class HTTPServer: @@ -427,8 +428,6 @@ def main(): """ # - from inference_endpoint.utils.logging import setup_logging - setup_logging() parser = create_parser() args = parser.parse_args() diff --git a/src/inference_endpoint/testing/max_throughput_server.py b/src/inference_endpoint/testing/max_throughput_server.py index 63bacb38..12f3267d 100644 --- a/src/inference_endpoint/testing/max_throughput_server.py +++ b/src/inference_endpoint/testing/max_throughput_server.py @@ -30,12 +30,14 @@ import argparse import asyncio +import gc import multiprocessing import multiprocessing.sharedctypes import multiprocessing.synchronize import os import signal import socket +import sys import threading import time @@ -301,8 +303,6 @@ def _worker( global _req_counter, _resp_counter, _byte_counter _req_counter, _resp_counter, _byte_counter = counters - import gc - gc.disable() uvloop.install() @@ -329,8 +329,6 @@ def protocol_factory(): try: asyncio.run(run()) except Exception as exc: - import sys - print( f"[MaxThroughputServer] Worker {wid} failed: {exc}", file=sys.stderr, diff --git a/src/inference_endpoint/testing/variable_throughput_server.py b/src/inference_endpoint/testing/variable_throughput_server.py index fae8d1a9..b640b6b4 100644 --- a/src/inference_endpoint/testing/variable_throughput_server.py +++ 
b/src/inference_endpoint/testing/variable_throughput_server.py @@ -49,6 +49,7 @@ import argparse import asyncio +import gc import math import multiprocessing import multiprocessing.sharedctypes @@ -57,8 +58,10 @@ import random import signal import socket +import sys import threading import time +import warnings import httptools import uvloop @@ -446,8 +449,6 @@ def _worker( global _req_counter, _resp_counter, _byte_counter _req_counter, _resp_counter, _byte_counter = counters - import gc - gc.disable() uvloop.install() @@ -501,8 +502,6 @@ def protocol_factory(): try: asyncio.run(run()) except Exception as exc: - import sys - print( f"[VariableResponseServer] Worker {wid} failed: {exc}", file=sys.stderr, @@ -648,8 +647,6 @@ def __init__( if max_concurrency > 0: self._max_concurrency_per_worker = max(1, max_concurrency // num_workers) if max_concurrency < num_workers: - import warnings - warnings.warn( f"max_concurrency ({max_concurrency}) < num_workers ({num_workers}): " f"each worker gets 1 slot, effective total={num_workers} exceeds cap.", diff --git a/src/inference_endpoint/utils/benchmark_httpclient.py b/src/inference_endpoint/utils/benchmark_httpclient.py index 3785af21..cb0e4ecb 100644 --- a/src/inference_endpoint/utils/benchmark_httpclient.py +++ b/src/inference_endpoint/utils/benchmark_httpclient.py @@ -37,6 +37,8 @@ import time from dataclasses import dataclass +import uvloop + from inference_endpoint.core.types import Query, QueryResult from inference_endpoint.endpoint_client.config import HTTPClientConfig from inference_endpoint.endpoint_client.cpu_affinity import ( @@ -50,6 +52,19 @@ build_response, ) +try: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.colors as mcolors + import matplotlib.pyplot as plt + import matplotlib.ticker as ticker +except ImportError: + matplotlib = None + mcolors = None + plt = None + ticker = None + # Suppress transformers "no framework found" warning (only tokenizers used) os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error") @@ -1054,14 +1069,7 @@ def generate_sweep_plot( 4 params: MxNxK facet grid — rows=param3, columns=param4. Where N = 2 (non-streaming) or 3 (streaming, adds SSE Rate). """ - try: - import matplotlib - - matplotlib.use("Agg") - import matplotlib.colors as mcolors - import matplotlib.pyplot as plt - import matplotlib.ticker as ticker - except ImportError: + if plt is None: print("\nMatplotlib not installed. 
Skipping plot generation.") print(" Install with: pip install matplotlib") return @@ -1446,9 +1454,6 @@ def main() -> None: ) gc.set_threshold(70000, 10, 100) - - import uvloop - uvloop.install() server: MaxThroughputServer | None = None diff --git a/tests/conftest.py b/tests/conftest.py index 7425d35d..d69d3c75 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,22 +22,17 @@ import logging import os import random -import sqlite3 import sys import uuid from pathlib import Path from typing import Any -import msgspec.json import pytest from inference_endpoint import metrics from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import LoadPattern, LoadPatternType -from inference_endpoint.core.types import TextModelOutput from inference_endpoint.dataset_manager.dataset import Dataset, DatasetFormat from inference_endpoint.dataset_manager.transforms import ColumnRemap -from inference_endpoint.load_generator.events import SampleEvent, SessionEvent -from inference_endpoint.load_generator.sample import SampleEventHandler from inference_endpoint.testing.docker_server import DockerServer from inference_endpoint.testing.echo_server import EchoServer, HTTPServer @@ -241,69 +236,6 @@ def fake_outputs(sample_uuids): } -@pytest.fixture -def events_db(tmp_path, sample_uuids, fake_outputs): - """Returns a sample in-memory sqlite database for events. - This database contains events for 3 sent queries, but only 2 are completed. The 3rd query has no 'received' events. - """ - logger.info(f"Creating events database at {tmp_path}") - test_db = str(tmp_path / f"test_events_{uuid.uuid4().hex}.db") - conn = sqlite3.connect(test_db) - cur = conn.cursor() - cur.execute( - "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)" - ) - - # Use deterministic UUIDs for testing - uuid1 = sample_uuids(1) - uuid2 = sample_uuids(2) - uuid3 = sample_uuids(3) - - # Define output data for COMPLETE events - events = [ - ("", SessionEvent.TEST_STARTED.value, 5000, b""), - (uuid1, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10000, b""), - (uuid2, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10003, b""), - (uuid1, SampleEvent.FIRST_CHUNK.value, 10010, b""), - (uuid2, SampleEvent.FIRST_CHUNK.value, 10190, b""), - (uuid1, SampleEvent.NON_FIRST_CHUNK.value, 10201, b""), - (uuid3, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10202, b""), - (uuid1, SampleEvent.NON_FIRST_CHUNK.value, 10203, b""), - (uuid2, SampleEvent.NON_FIRST_CHUNK.value, 10210, b""), - (uuid3, SessionEvent.ERROR.value, 10211, b""), - (uuid1, SampleEvent.NON_FIRST_CHUNK.value, 10211, b""), - ( - uuid1, - SampleEvent.COMPLETE.value, - 10211, - msgspec.json.encode(TextModelOutput(output=tuple(fake_outputs[uuid1]))), - ), - (uuid2, SampleEvent.NON_FIRST_CHUNK.value, 10214, b""), - (uuid3, SessionEvent.ERROR.value, 10216, b""), - (uuid2, SampleEvent.NON_FIRST_CHUNK.value, 10217, b""), - (uuid2, SampleEvent.NON_FIRST_CHUNK.value, 10219, b""), - ( - uuid2, - SampleEvent.COMPLETE.value, - 10219, - msgspec.json.encode(TextModelOutput(output=tuple(fake_outputs[uuid2]))), - ), - (uuid3, SessionEvent.ERROR.value, 10225, b""), - ("", SessionEvent.TEST_ENDED.value, 10300, b""), - ] - cur.executemany( - "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)", - events, - ) - conn.commit() - yield test_db - - cur.close() - conn.close() - Path(test_db).unlink() - logger.info(f"Events database at {test_db} deleted") - - class 
CharacterTokenizer: def tokenize(self, text: str) -> list[str]: return list(text) @@ -534,11 +466,3 @@ def concurrency_runtime_settings(random_seed, target_concurrency): type=LoadPatternType.CONCURRENCY, target_concurrency=target_concurrency ), ) - - -@pytest.fixture -def clean_sample_event_hooks(): - """Fixture to ensure SampleEventHandler hooks are cleared before and after each test.""" - SampleEventHandler.clear_hooks() - yield SampleEventHandler - SampleEventHandler.clear_hooks() diff --git a/tests/futures_client.py b/tests/futures_client.py index 24bcf05d..871378e1 100644 --- a/tests/futures_client.py +++ b/tests/futures_client.py @@ -69,7 +69,7 @@ async def _handle_responses(self): break # None signals transport closed - exit handler match response: - case StreamChunk(is_complete=False): + case StreamChunk(): # Intermediate stream chunk - future stays pending pass diff --git a/tests/integration/commands/test_accuracy_pipeline.py b/tests/integration/commands/test_accuracy_pipeline.py new file mode 100644 index 00000000..8574599f --- /dev/null +++ b/tests/integration/commands/test_accuracy_pipeline.py @@ -0,0 +1,170 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration test: full accuracy scoring pipeline with echo server. + +The echo server returns the user message content unchanged. We create a +dataset where some prompts match their ground_truth (correct) and some +don't (incorrect), then verify the scorer produces the expected accuracy. +""" + +import json +from pathlib import Path + +import msgspec.json +import pandas as pd +import pytest +from inference_endpoint.commands.benchmark.execute import run_benchmark +from inference_endpoint.config.schema import ( + AccuracyConfig, + BenchmarkConfig, + DatasetType, + EndpointConfig, + LoadPattern, + LoadPatternType, + ModelParams, + RuntimeConfig, + Settings, + StreamingMode, + TestMode, + TestType, +) +from inference_endpoint.config.schema import Dataset as DatasetConfig +from inference_endpoint.endpoint_client.config import HTTPClientConfig + + +def _create_accuracy_dataset(tmp_path: Path) -> Path: + """Create a CSV dataset with some correct and some incorrect ground truths. + + The echo server returns the prompt verbatim. So: + - If ground_truth == prompt → score 1.0 (correct) + - If ground_truth != prompt → score 0.0 (incorrect) + + Dataset: 5 samples, 3 correct + 2 incorrect = 60% accuracy. 
+ """ + data = { + "prompt": [ + "alpha", # correct: echo returns "alpha", ground_truth is "alpha" + "beta", # correct + "gamma", # correct + "What is the answer?", # INCORRECT: echo returns prompt, ground_truth is "42" + "Tell me a joke", # INCORRECT: echo returns prompt, ground_truth is "knock knock" + ], + "ground_truth": [ + "alpha", + "beta", + "gamma", + "42", + "knock knock", + ], + } + df = pd.DataFrame(data) + csv_path = tmp_path / "accuracy_dataset.csv" + df.to_csv(csv_path, index=False) + return csv_path + + +def _create_perf_dataset(tmp_path: Path) -> Path: + """Create a minimal perf dataset (CSV with prompt column).""" + data = {"prompt": ["hello"] * 3} + df = pd.DataFrame(data) + csv_path = tmp_path / "perf_dataset.csv" + df.to_csv(csv_path, index=False) + return csv_path + + +@pytest.mark.integration +class TestAccuracyPipeline: + def test_accuracy_scoring_with_echo_server( + self, mock_http_echo_server, tmp_path, caplog + ): + """Full end-to-end: perf phase + accuracy phase + scoring. + + Expected: 3/5 correct = 60% accuracy (0.6 score). + """ + perf_path = _create_perf_dataset(tmp_path) + acc_path = _create_accuracy_dataset(tmp_path) + + report_dir = tmp_path / "report" + config = BenchmarkConfig( + type=TestType.OFFLINE, + endpoint_config=EndpointConfig(endpoints=[mock_http_echo_server.url]), + model_params=ModelParams(name="echo-server", streaming=StreamingMode.OFF), + datasets=[ + DatasetConfig( + path=str(perf_path), + type=DatasetType.PERFORMANCE, + ), + DatasetConfig( + name="echo_accuracy", + path=str(acc_path), + type=DatasetType.ACCURACY, + accuracy_config=AccuracyConfig( + eval_method="string_match", + ground_truth="ground_truth", + extractor="identity_extractor", + ), + ), + ], + settings=Settings( + runtime=RuntimeConfig(min_duration_ms=0), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + client=HTTPClientConfig( + num_workers=1, warmup_connections=0, max_connections=10 + ), + ), + report_dir=str(report_dir), + ) + + with caplog.at_level("INFO"): + run_benchmark(config, TestMode.BOTH) + + # Verify scoring artifacts were written + assert (report_dir / "sample_idx_map.json").exists() + assert (report_dir / "events.jsonl").exists() + + # Verify sample_idx_map has both phases + with (report_dir / "sample_idx_map.json").open("rb") as f: + idx_map = msgspec.json.decode(f.read()) + assert "performance" in idx_map + assert "echo_accuracy" in idx_map + assert len(idx_map["echo_accuracy"]) == 5 # 5 accuracy samples + + # Verify events.jsonl has COMPLETE events (EventRecord format: "sample.complete") + events_path = report_dir / "events.jsonl" + with events_path.open() as f: + events = [msgspec.json.decode(line.strip()) for line in f if line.strip()] + complete_events = [ + e for e in events if e.get("event_type") == "sample.complete" + ] + # Should have both perf (3) and accuracy (5) completions + assert len(complete_events) == 8 + + # Verify results.json was written with accuracy scores + results_path = report_dir / "results.json" + assert results_path.exists() + with results_path.open() as f: + results = json.load(f) + + assert "accuracy_scores" in results + assert "echo_accuracy" in results["accuracy_scores"] + score_data = results["accuracy_scores"]["echo_accuracy"] + score = score_data["score"] + + # 3 correct out of 5 = 0.6 accuracy + assert abs(score - 0.6) < 0.01, f"Expected 0.6, got {score}" + + # Verify logs mention scoring + assert "Score for echo_accuracy" in caplog.text diff --git a/tests/integration/commands/test_benchmark_command.py 
b/tests/integration/commands/test_benchmark_command.py index a51afe8c..f052a90b 100644 --- a/tests/integration/commands/test_benchmark_command.py +++ b/tests/integration/commands/test_benchmark_command.py @@ -16,6 +16,7 @@ """Integration tests for benchmark commands against echo server.""" import json +import os import re from pathlib import Path @@ -86,7 +87,7 @@ def test_offline_benchmark( assert "Completed in" in caplog.text assert "successful" in caplog.text assert "QPS:" in caplog.text - assert "MaxThroughputScheduler" in caplog.text + assert "Starting phase:" in caplog.text @pytest.mark.integration @pytest.mark.parametrize("streaming", [StreamingMode.OFF, StreamingMode.ON]) @@ -105,8 +106,7 @@ def test_poisson_benchmark( assert "Completed in" in caplog.text assert "successful" in caplog.text - assert "PoissonDistributionScheduler" in caplog.text - assert "50" in caplog.text + assert "Starting phase:" in caplog.text @pytest.mark.integration @pytest.mark.parametrize("streaming", [StreamingMode.OFF, StreamingMode.ON]) @@ -213,6 +213,10 @@ class TestTemplateIntegration: """Verify generated templates run end-to-end against a local server.""" @pytest.mark.integration + @pytest.mark.skipif( + not os.environ.get("HF_TOKEN"), + reason="Templates reference gated HF models; requires HF_TOKEN to fetch tokenizer", + ) @pytest.mark.parametrize("template", _GENERATED_TEMPLATES) def test_template_runs(self, mock_http_echo_server, tmp_path, caplog, template): data = _resolve_template(TEMPLATE_DIR / template, mock_http_echo_server.url) diff --git a/tests/integration/test_end_to_end_oracle.py b/tests/integration/test_end_to_end_oracle.py index c2c55feb..ecb16cd0 100644 --- a/tests/integration/test_end_to_end_oracle.py +++ b/tests/integration/test_end_to_end_oracle.py @@ -13,15 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging +"""End-to-end oracle test: verify responses match expected dataset outputs. + +Uses the async BenchmarkSession to issue all samples to a mock oracle server, +then checks each response against the expected ground-truth output. 
+""" + +import asyncio import random -from pathlib import Path from urllib.parse import urljoin import pytest from inference_endpoint import metrics from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import LoadPattern, LoadPatternType +from inference_endpoint.core.record import EventRecord from inference_endpoint.core.types import QueryResult from inference_endpoint.dataset_manager import Dataset from inference_endpoint.dataset_manager.transforms import ( @@ -30,199 +36,138 @@ ) from inference_endpoint.endpoint_client.config import HTTPClientConfig from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient -from inference_endpoint.endpoint_client.http_sample_issuer import HttpClientSampleIssuer -from inference_endpoint.load_generator import ( +from inference_endpoint.endpoint_client.http_sample_issuer import ( + HttpClientSampleIssuer, +) +from inference_endpoint.load_generator.session import ( BenchmarkSession, - MaxThroughputScheduler, - SampleEvent, - SampleEventHandler, - WithoutReplacementSampleOrder, + PhaseConfig, ) -class DeepSeekR1SampleIssuer(HttpClientSampleIssuer): - def __init__(self, tmp_path: Path, url: str): - self.http_config = HTTPClientConfig( - endpoint_urls=[urljoin(url, "/v1/chat/completions")], - warmup_connections=0, - ) - super().__init__( - HTTPEndpointClient( - self.http_config, - ) - ) +class _NoOpPublisher: + """Minimal EventPublisher that discards all events.""" + def publish(self, event_record: EventRecord) -> None: + pass -async def run_benchmark(server_url, dataloader, tmp_path, rt_settings): - # Step 1. Register the complete hook to store the responses from the server. - server_responses: {str: str} = {} + def flush(self) -> None: + pass - def on_complete_hook(result: QueryResult): - """Callback to store the responses from the server.""" - server_responses[result.id] = result.get_response_output_string() - SampleEventHandler.register_hook(SampleEvent.COMPLETE, on_complete_hook) +async def _run_oracle_test(url: str, dataloader: Dataset, rt_settings: RuntimeSettings): + """Run benchmark session against an oracle server and verify responses.""" + loop = asyncio.get_running_loop() + n_samples = dataloader.num_samples() - # Step 2. Create the scheduler. - scheduler = MaxThroughputScheduler( - rt_settings, - WithoutReplacementSampleOrder, + # Collect responses via callback + responses: dict[str, str] = {} + + def on_complete(result: QueryResult) -> None: + responses[result.id] = result.get_response_output_string() + + # Create HTTP client with warmup disabled (test server) + http_config = HTTPClientConfig( + endpoint_urls=[urljoin(url, "/v1/chat/completions")], + warmup_connections=0, + num_workers=2, ) - logging.info(f"Number of samples to issue: {scheduler.total_samples_to_issue}") + http_client = await HTTPEndpointClient.create(http_config, loop) + issuer = HttpClientSampleIssuer(http_client) - sample_issuer = None try: - # Step 3. Create the sample issuer. - sample_issuer = DeepSeekR1SampleIssuer(tmp_path, server_url) - - # Step 4. Create the benchmark session. - sess = BenchmarkSession.start( - rt_settings, - dataloader, - sample_issuer, - scheduler, - name="pytest_run_benchmark", - max_shutdown_timeout_s=3 * 60, + session = BenchmarkSession( + issuer=issuer, + event_publisher=_NoOpPublisher(), + loop=loop, + on_sample_complete=on_complete, ) - - # Step 5. Wait for the test to end. - logging.info("Waiting for the test to end...") - sess.wait_for_test_end() - # Step 6. 
Return the sample UUID map and the server responses. - return sess.sample_uuid_map, server_responses + phases = [PhaseConfig("performance", rt_settings, dataloader)] + result = await asyncio.wait_for(session.run(phases), timeout=60.0) finally: - # Step 7. Shutdown the sample issuer and the HTTP client. - if sample_issuer is not None: - sample_issuer.shutdown() - sample_issuer.http_client.shutdown() - + await http_client.shutdown_async() + + # Verify all samples got responses + assert result.perf_results[0].issued_count == n_samples + assert len(responses) == n_samples + + # Build expected values from dataset + expected = {} + for i in range(n_samples): + entry = dataloader.load_sample(i) + expected[i] = entry["output"] + + # Verify each response matches ground truth + uuid_to_index = result.perf_results[0].uuid_to_index + for sample_uuid, resp in responses.items(): + sample_index = uuid_to_index[sample_uuid] + assert resp == expected[sample_index], ( + f"Sample {sample_uuid} (index {sample_index}): " + f"expected {expected[sample_index][:50]!r}, got {resp[:50]!r}" + ) -""" -Test the load generator full run with a given URL. -""" + return responses -async def _run_load_generator_full_run_url( - url, dataset_path, tmp_path, clean_sample_event_hooks, hf_model_name +@pytest.mark.integration +@pytest.mark.asyncio +async def test_load_generator_full_run_mock_http_oracle_server( + mock_http_oracle_server, + ds_dataset_path, + hf_model_name, ): dummy_dataloader = Dataset.load_from_file( - dataset_path, + ds_dataset_path, transforms=[ ColumnRemap({"text_input": "prompt", "ref_output": "output"}), AddStaticColumns({"model": hf_model_name}), ], ) dummy_dataloader.load() - assert dummy_dataloader.num_samples() > 0 + n_samples = dummy_dataloader.num_samples() + assert n_samples > 0 rt_settings = RuntimeSettings( - metrics.Throughput(50), - [metrics.Throughput(50)], - min_duration_ms=1_00, - max_duration_ms=1_000, - n_samples_from_dataset=dummy_dataloader.num_samples(), - n_samples_to_issue=dummy_dataloader.num_samples(), + metrics.Throughput(5000), + [metrics.Throughput(5000)], + min_duration_ms=0, + max_duration_ms=60_000, + n_samples_from_dataset=n_samples, + n_samples_to_issue=n_samples, + min_sample_count=1, rng_sched=random.Random(1234), rng_sample_index=random.Random(1234), load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), ) - scheduler = MaxThroughputScheduler( - rt_settings, - WithoutReplacementSampleOrder, - ) - logging.info(f"Number of samples to issue: {scheduler.total_samples_to_issue}") - # Now call the benchmark - sample_uuid_map, response_cache = await run_benchmark( - url, dummy_dataloader, tmp_path, rt_settings - ) - num_responses_in_cache = len(response_cache) - assert ( - num_responses_in_cache == scheduler.total_samples_to_issue - ), "Number of samples in response cache and number of samples in dataset should be the same" - vals = {} - for i in range(dummy_dataloader.num_samples()): - entry = dummy_dataloader.load_sample(i) - vals[i] = entry["output"] - num_samples_in_dataset = len(vals) - logging.info(f"Number of samples in dataset: {num_samples_in_dataset}") - logging.info(f"Total samples to issue: {scheduler.total_samples_to_issue}") - logging.info(f"Request data: {num_responses_in_cache}") - - for sample_uuid, resp in response_cache.items(): - if resp is None: - logging.error(f"Sample {sample_uuid} has no response") - else: - sample_index = sample_uuid_map[sample_uuid].index - logging.info( - f"Sample {sample_uuid} should have been response 
{vals[sample_index][0:30]}, but was response {resp[0:30]}" - ) + await _run_oracle_test(mock_http_oracle_server.url, dummy_dataloader, rt_settings) -@pytest.mark.asyncio -async def test_load_generator_full_run_mock_http_oracle_server( - mock_http_oracle_server, - ds_dataset_path, - tmp_path, - clean_sample_event_hooks, - hf_model_name, -): +async def _run_load_generator_full_run_url(url, dataset_path, hf_model_name): dummy_dataloader = Dataset.load_from_file( - ds_dataset_path, + dataset_path, transforms=[ ColumnRemap({"text_input": "prompt", "ref_output": "output"}), AddStaticColumns({"model": hf_model_name}), ], ) dummy_dataloader.load() - assert dummy_dataloader.num_samples() > 0 + n_samples = dummy_dataloader.num_samples() + assert n_samples > 0 rt_settings = RuntimeSettings( - metrics.Throughput(5000), - [metrics.Throughput(5000)], - min_duration_ms=1_000, - max_duration_ms=10_000_000, - n_samples_from_dataset=dummy_dataloader.num_samples(), - n_samples_to_issue=dummy_dataloader.num_samples(), - min_sample_count=1, + metrics.Throughput(50), + [metrics.Throughput(50)], + min_duration_ms=0, + max_duration_ms=60_000, + n_samples_from_dataset=n_samples, + n_samples_to_issue=n_samples, rng_sched=random.Random(1234), rng_sample_index=random.Random(1234), load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), ) - scheduler = MaxThroughputScheduler( - rt_settings, - WithoutReplacementSampleOrder, - ) - logging.info(f"Number of samples to issue: {scheduler.total_samples_to_issue}") - - sample_uuid_map, response_cache = await run_benchmark( - mock_http_oracle_server.url, dummy_dataloader, tmp_path, rt_settings - ) - num_responses_in_cache = len(response_cache) - assert ( - num_responses_in_cache == scheduler.total_samples_to_issue - ), "Number of samples in response cache and number of samples in dataset should be the same" - vals = {} - for i in range(dummy_dataloader.num_samples()): - entry = dummy_dataloader.load_sample(i) - vals[i] = entry["output"] - num_samples_in_dataset = len(vals) - logging.info(f"Number of samples in dataset: {num_samples_in_dataset}") - logging.info(f"Total samples to issue: {scheduler.total_samples_to_issue}") - logging.info(f"Request data: {num_responses_in_cache}") - assert ( - num_samples_in_dataset == scheduler.total_samples_to_issue - ), "Number of samples in dataset and number of samples in request data should be the same" - - for sample_uuid, resp in response_cache.items(): - sample_index = sample_uuid_map["performance"][sample_uuid] - logging.info( - f"Sample {sample_uuid} should have been response {vals[sample_index][0:30]}, but was response {resp[0:30]}" - ) - assert ( - resp == vals[sample_index] - ), f"Sample {sample_uuid} should have been response {vals[sample_index][0:30]}, but was response {resp[0:30]}" + await _run_oracle_test(url, dummy_dataloader, rt_settings) @pytest.mark.asyncio @@ -232,16 +177,10 @@ async def test_load_generator_full_run_mock_http_oracle_server( async def test_load_generator_full_run_vllm_docker_server( vllm_docker_server, ds_dataset_path, - tmp_path, - clean_sample_event_hooks, hf_model_name, ): await _run_load_generator_full_run_url( - vllm_docker_server.url, - ds_dataset_path, - tmp_path, - clean_sample_event_hooks, - hf_model_name, + vllm_docker_server.url, ds_dataset_path, hf_model_name ) @@ -252,16 +191,10 @@ async def test_load_generator_full_run_vllm_docker_server( async def test_load_generator_full_run_sglang_docker_server( sglang_docker_server, ds_dataset_path, - tmp_path, - clean_sample_event_hooks, 
hf_model_name, ): await _run_load_generator_full_run_url( - sglang_docker_server.url, - ds_dataset_path, - tmp_path, - clean_sample_event_hooks, - hf_model_name, + sglang_docker_server.url, ds_dataset_path, hf_model_name ) @@ -272,14 +205,8 @@ async def test_load_generator_full_run_sglang_docker_server( async def test_load_generator_full_run_trtllm_docker_server( trtllm_docker_server, ds_dataset_path, - tmp_path, - clean_sample_event_hooks, hf_model_name, ): await _run_load_generator_full_run_url( - trtllm_docker_server.url, - ds_dataset_path, - tmp_path, - clean_sample_event_hooks, - hf_model_name, + trtllm_docker_server.url, ds_dataset_path, hf_model_name ) diff --git a/tests/performance/async_utils/transport/test_zmq.py b/tests/performance/async_utils/transport/test_zmq.py index 4df5cebc..f31baa51 100644 --- a/tests/performance/async_utils/transport/test_zmq.py +++ b/tests/performance/async_utils/transport/test_zmq.py @@ -77,7 +77,6 @@ def make_stream_chunk(payload_chars: int, idx: int) -> StreamChunk: return StreamChunk( id=str(idx), response_chunk="x" * payload_chars, - is_complete=False, ) diff --git a/tests/performance/test_recorder.py b/tests/performance/test_recorder.py deleted file mode 100644 index f2d9cc3b..00000000 --- a/tests/performance/test_recorder.py +++ /dev/null @@ -1,319 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import random -import time -from dataclasses import dataclass, fields -from pathlib import Path -from typing import TextIO - -import pytest -from inference_endpoint.load_generator.events import SampleEvent, SessionEvent -from inference_endpoint.metrics.recorder import EventRecorder -from inference_endpoint.metrics.reporter import MetricsReporter -from inference_endpoint.profiling.line_profiler import ENV_VAR_ENABLE_LINE_PROFILER - - -def get_EventRecorder(*args, **kwargs): - # Set requirement to 128MB for testing - return EventRecorder(*args, min_memory_req_bytes=128 * 1024 * 1024, **kwargs) - - -class CharTokenizer: - def tokenize(self, text: str) -> list[str]: - return list(text) - - -class TimingLog: - def __init__(self, log_file: Path | str | None = None): - if log_file is None: - log_file = Path("/tmp/recorder_timing_log.txt") - self.log_file = Path(log_file) - self.f_obj: TextIO | None = None - - def __enter__(self): - if self.f_obj is not None: - raise ValueError("TimingLog already open") - self.f_obj = self.log_file.open("a") - return self - - def __exit__(self, exc_type, exc_value, traceback): - assert self.f_obj is not None - self.f_obj.close() - self.f_obj = None - - def log(self, key: str, duration_sec: float, variant: str = "default"): - assert self.f_obj is not None - self.f_obj.write(f"[{key}] {variant}: {duration_sec} sec.\n") - - -@pytest.fixture -def timing_log(tmp_path): - with TimingLog(tmp_path / "timing_log.txt") as log: - yield log - - -@pytest.fixture -def check_time_fn(timing_log): - def check_time( - fn, - thresh, - *args, - log_key: str, - variant: str = "default", - rel_tol=0.05, - **kwargs, - ): - start_time = time.monotonic_ns() - r = fn(*args, **kwargs) - end_time = time.monotonic_ns() - duration_sec = (end_time - start_time) / 1e9 - timing_log.log(log_key, duration_sec, variant=variant) - - upper_limit = thresh * (1 + rel_tol) - assert duration_sec <= upper_limit - return r - - yield check_time - - -@dataclass -class ReporterTimeThresholds: - write: float - ttft: float - tpot: float - sample_statuses: float - - def __post_init__(self): - if os.environ.get(ENV_VAR_ENABLE_LINE_PROFILER, "0") == "1": - profile_overhead_factor = 5 - else: - profile_overhead_factor = 1 - - for f in fields(self): - setattr(self, f.name, getattr(self, f.name) * profile_overhead_factor) - - -@pytest.mark.skip(reason="Only run manually for debugging and development purposes") -@pytest.mark.performance -@pytest.mark.xdist_group(name="serial_performance") -@pytest.mark.parametrize( - "client_type,time_thresholds", - [ - ( - "duckdb", - ReporterTimeThresholds( - write=20e-6, ttft=0.1, tpot=1.5, sample_statuses=0.15 - ), - ), - ( - "sqlite", - ReporterTimeThresholds( - write=20e-6, ttft=0.3, tpot=1.5, sample_statuses=0.3 - ), - ), - ], -) -def test_many_chunk_performance( - client_type, time_thresholds, cleanup_connections, check_time_fn -): - # Generate a very large number of events, and see if queries return within a threshold - n_samples = 10 - n_chunks = int(1e6) - - with get_EventRecorder() as rec: - conn_name = rec.connection_name - cleanup_connections["delete"].append(conn_name) - - start_time = time.monotonic_ns() - for sample_uuid in range(n_samples): - rec.record_event( - SessionEvent.LOADGEN_ISSUE_CALLED, - time.monotonic_ns(), - sample_uuid=str(sample_uuid + 1), - ) - - for sample_uuid in range(n_samples): - rec.record_event( - SampleEvent.FIRST_CHUNK, - time.monotonic_ns(), - sample_uuid=str(sample_uuid + 1), - ) - - for _ in range(n_chunks): - 
rec.record_event( - SampleEvent.NON_FIRST_CHUNK, - time.monotonic_ns(), - sample_uuid=str(random.randint(1, n_samples)), - ) - - for sample_uuid in range(n_samples): - rec.record_event( - SampleEvent.COMPLETE, - time.monotonic_ns(), - sample_uuid=str(sample_uuid + 1), - data="test", - ) - rec.wait_for_writes(force_commit=True) - end_time = time.monotonic_ns() - - assert ( - end_time - start_time - ) / 1e9 <= n_chunks * time_thresholds.write # Cap at ~20 microseconds per event - - with MetricsReporter(conn_name, client_type=client_type) as reporter: - variant = f"{client_type}_{n_chunks}rows_{n_samples}samples" - assert check_time_fn( - reporter.get_sample_statuses, - time_thresholds.sample_statuses, - log_key="many_chunk_completed", - variant=variant, - ) == {"total_sent": n_samples, "completed": n_samples, "in_flight": 0} - ttft_rollup = check_time_fn( - reporter.derive_TTFT, - time_thresholds.ttft, - log_key="many_chunk_ttft", - variant=variant, - ) - check_time_fn( - reporter.derive_TPOT, - time_thresholds.tpot, - CharTokenizer(), - log_key="many_chunk_tpot", - variant=variant, - ttft_rollup=ttft_rollup, - ) - - -@pytest.mark.skip(reason="Only run manually for debugging and development purposes") -@pytest.mark.performance -@pytest.mark.xdist_group(name="serial_performance") -@pytest.mark.parametrize( - "client_type,time_thresholds", - [ - ( - "duckdb", - ReporterTimeThresholds( - write=20e-6, ttft=0.3, tpot=1.5, sample_statuses=0.15 - ), - ), - ( - "sqlite", - ReporterTimeThresholds( - write=20e-6, ttft=0.6, tpot=1.5, sample_statuses=0.3 - ), - ), - ], -) -def test_2_chunk_per_query_performance( - client_type, time_thresholds, cleanup_connections, check_time_fn -): - # Generate a very large number of events, and see if queries return within a threshold - n_events = int(1e6) - n_queries = n_events // 4 - - with get_EventRecorder() as rec: - conn_name = rec.connection_name - cleanup_connections["delete"].append(conn_name) - - start_time = time.monotonic_ns() - for sample_uuid in range(n_queries): - rec.record_event( - SessionEvent.LOADGEN_ISSUE_CALLED, - time.monotonic_ns(), - sample_uuid=str(sample_uuid + 1), - ) - - order = list(range(n_queries)) - random.shuffle(order) - for sample_uuid in order: - rec.record_event( - SampleEvent.FIRST_CHUNK, - time.monotonic_ns(), - sample_uuid=str(sample_uuid + 1), - ) - - random.shuffle(order) - for sample_uuid in order: - rec.record_event( - SampleEvent.NON_FIRST_CHUNK, - time.monotonic_ns(), - sample_uuid=str(sample_uuid + 1), - ) - - random.shuffle(order) - for sample_uuid in order: - rec.record_event( - SampleEvent.COMPLETE, - time.monotonic_ns(), - sample_uuid=str(sample_uuid + 1), - data="test", - ) - - rec.wait_for_writes(force_commit=True) - end_time = time.monotonic_ns() - - assert ( - end_time - start_time - ) / 1e9 <= n_events * time_thresholds.write # Cap at ~20 microseconds per event - - with MetricsReporter(conn_name, client_type=client_type) as reporter: - variant = f"{client_type}_{n_events}events" - assert check_time_fn( - reporter.get_sample_statuses, - time_thresholds.sample_statuses, - log_key="2_chunk_per_query_completed", - variant=variant, - ) == {"total_sent": n_queries, "completed": n_queries, "in_flight": 0} - ttft_rollup = check_time_fn( - reporter.derive_TTFT, - time_thresholds.ttft, - log_key="2_chunk_per_query_ttft", - variant=variant, - ) - check_time_fn( - reporter.derive_TPOT, - time_thresholds.tpot, - CharTokenizer(), - log_key="2_chunk_per_query_tpot", - variant=variant, - ttft_rollup=ttft_rollup, - ) - - 
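Both deleted tests above rely on the same pattern from the `check_time_fn` fixture: time a callable, log the duration, and assert it finishes within a threshold plus a relative tolerance. A condensed, self-contained sketch of that pattern (the helper name and the example budget are illustrative, not project API):

```python
# Timing-assertion pattern: run fn and fail if it exceeds thresh * (1 + rel_tol).
import time


def assert_fn_under(fn, thresh_sec, *args, rel_tol=0.05, **kwargs):
    start = time.monotonic_ns()
    result = fn(*args, **kwargs)
    duration_sec = (time.monotonic_ns() - start) / 1e9
    assert duration_sec <= thresh_sec * (1 + rel_tol), (
        f"took {duration_sec:.3f}s, budget {thresh_sec:.3f}s (+{rel_tol:.0%})"
    )
    return result


# Example: give a small computation a generous 10 ms budget.
assert_fn_under(lambda: sum(range(1000)), 0.010)
```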
-@pytest.mark.performance -@pytest.mark.xdist_group(name="serial_performance") -def test_db_write_performance(cleanup_connections, check_time_fn): - with get_EventRecorder() as rec: - cleanup_connections["delete"].append(rec.connection_name) - - n_events = int(1e6) - - def bulk_write(): - for i in range(n_events): - rec.record_event( - SessionEvent.LOADGEN_ISSUE_CALLED, - time.monotonic_ns(), - sample_uuid=str(i + 1), - ) - rec.wait_for_writes(force_commit=True) - - check_time_fn( - bulk_write, - n_events * 10e-6, - log_key="bulk_write", - variant=f"{n_events}events", - ) diff --git a/tests/performance/test_reporter.py b/tests/performance/test_reporter.py deleted file mode 100644 index 29dcebbd..00000000 --- a/tests/performance/test_reporter.py +++ /dev/null @@ -1,156 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import random -import time - -import pytest -from inference_endpoint.load_generator.events import SampleEvent, SessionEvent -from inference_endpoint.metrics.recorder import EventRecorder -from inference_endpoint.metrics.reporter import MetricsReporter, TPOTReportingMode -from pympler import asizeof - - -def get_EventRecorder(*args, **kwargs): - # Set requirement to 128MB for testing - return EventRecorder(*args, min_memory_req_bytes=128 * 1024 * 1024, **kwargs) - - -class CharTokenizer: - def tokenize(self, text: str) -> list[str]: - return list(text) - - -def time_fn(fn, *args, **kwargs): - start_time = time.monotonic_ns() - result = fn(*args, **kwargs) - end_time = time.monotonic_ns() - return result, end_time - start_time - - -@pytest.mark.skip(reason="Only used to manually test TPOT performance") -@pytest.mark.performance -@pytest.mark.xdist_group(name="serial_performance") -def test_tpot_performance(cleanup_connections): - # Generate a very large number of events, and see if queries return within a threshold - n_events = int(1e6) - n_queries = n_events // 4 - - with get_EventRecorder() as rec: - conn_name = rec.connection_name - cleanup_connections["delete"].append(conn_name) - cleanup_connections["delete"].append(rec.outputs_path) - - for sample_uuid in range(n_queries): - rec.record_event( - SessionEvent.LOADGEN_ISSUE_CALLED, - time.monotonic_ns(), - sample_uuid=str(sample_uuid + 1), - ) - - order = list(range(n_queries)) - random.shuffle(order) - for sample_uuid in order: - rec.record_event( - SampleEvent.FIRST_CHUNK, - time.monotonic_ns(), - sample_uuid=str(sample_uuid + 1), - output="a", - ) - - random.shuffle(order) - for sample_uuid in order: - rec.record_event( - SampleEvent.NON_FIRST_CHUNK, - time.monotonic_ns(), - sample_uuid=str(sample_uuid + 1), - ) - - random.shuffle(order) - for sample_uuid in order: - rec.record_event( - SampleEvent.COMPLETE, - time.monotonic_ns(), - sample_uuid=str(sample_uuid + 1), - output=["a", "a" * 128], - ) - - rec.wait_for_writes(force_commit=True) - - with 
MetricsReporter(conn_name) as reporter: - # Precompute rollups to avoid recomputing them for each test - ttft_rollup = reporter.derive_TTFT() - sample_latency_rollup = reporter.derive_sample_latency() - - tpot_condensed, condensed_duration_ns = time_fn( - reporter.derive_TPOT, - CharTokenizer(), - ttft_rollup=ttft_rollup, - sample_latency_rollup=sample_latency_rollup, - condense_table=True, - reporting_mode=TPOTReportingMode.TOKEN_WEIGHTED, - ) - - tpot_full, full_duration_ns = time_fn( - reporter.derive_TPOT, - CharTokenizer(), - ttft_rollup=ttft_rollup, - sample_latency_rollup=sample_latency_rollup, - condense_table=False, - reporting_mode=TPOTReportingMode.TOKEN_WEIGHTED, - ) - - condensed_size = asizeof.asizeof(tpot_condensed) - full_size = asizeof.asizeof(tpot_full) - print(f"Condensed TPOT table size: {condensed_size} bytes") - print(f"Condensed TPOT table calculated in {condensed_duration_ns} ns") - print(f"Full TPOT table size: {full_size} bytes") - print(f"Full TPOT table calculated in {full_duration_ns} ns") - assert condensed_size <= (full_size / 5) - assert condensed_duration_ns <= full_duration_ns * 1.1 - - condensed_summary, condensed_summary_duration_ns = time_fn( - tpot_condensed.summarize, - ) - full_summary, full_summary_duration_ns = time_fn( - tpot_full.summarize, - ) - print(f"Condensed TPOT summary calculated in {condensed_summary_duration_ns} ns") - print(f"Full TPOT summary calculated in {full_summary_duration_ns} ns") - assert condensed_summary_duration_ns < (full_summary_duration_ns / 40) - - # These should definitely be the same - for k in [ - "total", - "histogram", - "min", - "max", - "avg", - ]: - assert condensed_summary[k] == full_summary[k] - - for k in [ - "median", - "std_dev", - ]: - assert math.isclose(condensed_summary[k], full_summary[k], rel_tol=0.01) - - for percentile in [99.9, 99, 95, 90, 80, 75, 50, 25, 10, 5, 1]: - assert math.isclose( - condensed_summary["percentiles"][str(percentile)], - full_summary["percentiles"][str(percentile)], - rel_tol=0.01, - ) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 7fa586e4..6e68f3cc 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -24,20 +24,13 @@ import random import string import uuid -from asyncio import Future -from concurrent.futures import ThreadPoolExecutor from pathlib import Path import zmq from inference_endpoint.core.types import ( Query, - QueryResult, - StreamChunk, - TextModelOutput, ) from inference_endpoint.dataset_manager.dataset import Dataset -from inference_endpoint.load_generator.load_generator import SampleIssuer -from inference_endpoint.load_generator.sample import SampleEventHandler def _generate_random_word( @@ -184,111 +177,3 @@ def get_test_socket_path(tmp_path: Path, test_name: str, suffix: str = "") -> st len(socket_path) <= zmq.IPC_PATH_MAX_LEN ), "socket path is too long for ZMQ IPC" return socket_path - - -class SerialSampleIssuer(SampleIssuer): - """SampleIssuer for testing. No threading, and is blocking. Whenever issue is called, - it performs the provided compute function, calling callbacks when necessary. - - The compute function should be a generator, yielding the 'chunks' of the supposed - response. 
- """ - - def __init__(self, compute_func=None): - if compute_func is None: - self.compute_func = lambda x: [x] - else: - self.compute_func = compute_func - - def issue(self, sample): - first = True - chunks = [] - for chunk in self.compute_func(sample.data): - chunks.append(chunk) - stream_chunk = StreamChunk( - id=sample.uuid, metadata={"first_chunk": first}, response_chunk=chunk - ) - SampleEventHandler.stream_chunk_complete(stream_chunk) - first = False - query_result = QueryResult( - id=sample.uuid, response_output=TextModelOutput(output="".join(chunks)) - ) - SampleEventHandler.query_result_complete(query_result) - - -class PooledSampleIssuer(SampleIssuer): - """SampleIssuer that has a non-blocking issue() method. Has a pool of workers which compute - the samples in parallel. - - Uses ThreadPoolExecutor to properly propagate exceptions from worker threads to the main thread. - Call check_errors() to raise any exceptions that occurred in workers and clean up completed futures. - """ - - def __init__(self, compute_func=None, n_workers: int = 4): - self.n_workers = n_workers - if compute_func is None: - self.compute_func = lambda x: [x] - else: - self.compute_func = compute_func - self.executor = ThreadPoolExecutor(max_workers=n_workers) - self.futures: list[Future[None]] = [] - - def shutdown(self, wait: bool = True): - """Shutdown the executor and wait for all tasks to complete. - - Args: - wait: Whether to wait for all tasks to complete before returning. - If False, the executor will be shutdown and the method will return immediately. - The caller is responsible for checking the futures for exceptions. (Default: True) - - Raises any exceptions that occurred in worker threads. - """ - self.executor.shutdown(wait=wait) - - if wait: - # Check all futures for exceptions - for future in self.futures: - future.result() # This will raise if the worker raised an exception - self.futures.clear() - - def handle_sample(self, sample): - first = True - chunks = [] - for chunk in self.compute_func(sample.data): - chunks.append(chunk) - stream_chunk = StreamChunk( - id=sample.uuid, metadata={"first_chunk": first}, response_chunk=chunk - ) - SampleEventHandler.stream_chunk_complete(stream_chunk) - first = False - query_result = QueryResult( - id=sample.uuid, response_output=TextModelOutput(output="".join(chunks)) - ) - SampleEventHandler.query_result_complete(query_result) - - def check_errors(self): - """Check if any worker thread has raised an exception and re-raise it. - - This checks completed futures without blocking and removes them from the list - to prevent unbounded memory growth. 
- """ - remaining_futures = [] - for future in self.futures: - if future.done(): - # This will raise if the worker raised an exception - future.result() - # Don't keep completed futures - else: - # Keep incomplete futures - remaining_futures.append(future) - self.futures = remaining_futures - - def issue(self, sample): - """Submit a sample to be processed by the worker pool.""" - future = self.executor.submit(self.handle_sample, sample) - self.futures.append(future) - - # Periodically clean up completed futures to prevent unbounded growth - # Check every 100 submissions to balance cleanup overhead vs memory usage - if len(self.futures) >= 100: - self.check_errors() diff --git a/tests/unit/async_utils/services/metrics_aggregator/conftest.py b/tests/unit/async_utils/services/metrics_aggregator/conftest.py new file mode 100644 index 00000000..eb80b2ba --- /dev/null +++ b/tests/unit/async_utils/services/metrics_aggregator/conftest.py @@ -0,0 +1,198 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared test doubles and factories for metrics aggregator tests.""" + +from __future__ import annotations + +import asyncio +from typing import Literal +from unittest.mock import MagicMock + +from inference_endpoint.async_utils.services.metrics_aggregator.aggregator import ( + MetricsAggregatorService, +) +from inference_endpoint.async_utils.services.metrics_aggregator.kv_store import ( + KVStore, + SeriesStats, +) +from inference_endpoint.core.record import ( + EventRecord, + SampleEventType, + SessionEventType, +) +from inference_endpoint.core.types import TextModelOutput + +# --------------------------------------------------------------------------- +# In-memory KVStore for tests +# --------------------------------------------------------------------------- + + +class InMemoryKVStore(KVStore): + """In-memory KVStore for unit tests. 
No /dev/shm files needed.""" + + def __init__(self) -> None: + self._counters: dict[str, int] = {} + self._series: dict[str, list] = {} + self._series_dtype: dict[str, type] = {} + self.closed: bool = False + + def create_key( + self, key: str, key_type: Literal["series", "counter"], dtype: type = int + ) -> None: + if key_type == "counter" and key not in self._counters: + self._counters[key] = 0 + elif key_type == "series" and key not in self._series: + self._series[key] = [] + self._series_dtype[key] = dtype + + def update(self, key: str, value: int | float) -> None: + if key in self._counters: + self._counters[key] = int(value) + elif key in self._series: + self._series[key].append(value) + else: + raise KeyError(f"Key not created: {key}") + + def get(self, key: str) -> int | SeriesStats: + if key in self._counters: + return self._counters[key] + if key in self._series: + dtype = self._series_dtype[key] + return SeriesStats(list(self._series[key]), dtype=dtype) + raise KeyError(f"Key not created: {key}") + + def snapshot(self) -> dict[str, int | SeriesStats]: + result: dict[str, int | SeriesStats] = {} + for k, v in self._counters.items(): + result[k] = v + for k, vals in self._series.items(): + dtype = self._series_dtype[k] + result[k] = SeriesStats(list(vals), dtype=dtype) + return result + + def close(self) -> None: + self.closed = True + + # --- Test helpers --- + + def get_series_values(self, key: str) -> list: + return list(self._series.get(key, [])) + + def get_counter(self, key: str) -> int: + return self._counters.get(key, 0) + + def get_all_series(self) -> dict[str, list[float]]: + """All series as {metric_name: [values]}.""" + return {k: list(v) for k, v in self._series.items()} + + +# --------------------------------------------------------------------------- +# Mock TokenizePool +# --------------------------------------------------------------------------- + + +class MockTokenizePool: + """Mock TokenizePool that splits on whitespace with artificial async delay.""" + + def __init__(self, delay: float = 0.01) -> None: + self._delay = delay + + def token_count(self, text: str) -> int: + return len(text.split()) + + async def token_count_async( + self, text: str, _loop: asyncio.AbstractEventLoop + ) -> int: + await asyncio.sleep(self._delay) + return len(text.split()) + + def close(self) -> None: + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + +# --------------------------------------------------------------------------- +# Aggregator factories +# --------------------------------------------------------------------------- + + +def mock_zmq_context() -> MagicMock: + """Create a mock ManagedZMQContext that no-ops all ZMQ operations.""" + ctx = MagicMock() + ctx.socket.return_value = MagicMock() + ctx.connect.return_value = "ipc:///mock/socket" + return ctx + + +def make_stub_aggregator( + kv_store: KVStore, + tokenize_pool=None, + streaming: bool = True, +) -> MetricsAggregatorService: + """Create a MetricsAggregatorService with ZMQ mocked out.""" + return MetricsAggregatorService( + "mock_path", + mock_zmq_context(), + MagicMock(spec=asyncio.AbstractEventLoop), + kv_store=kv_store, + tokenize_pool=tokenize_pool, + streaming=streaming, + ) + + +def make_async_stub_aggregator( + kv_store: KVStore, + tokenize_pool, + loop: asyncio.AbstractEventLoop, + streaming: bool = True, +) -> MetricsAggregatorService: + """Create a MetricsAggregatorService with a real loop and mock ZMQ.""" + return MetricsAggregatorService( + "mock_path", + 
mock_zmq_context(), + loop, + kv_store=kv_store, + tokenize_pool=tokenize_pool, + streaming=streaming, + ) + + +# --------------------------------------------------------------------------- +# EventRecord factories +# --------------------------------------------------------------------------- + + +def session_event(ev_type: SessionEventType, ts: int = 0) -> EventRecord: + return EventRecord(event_type=ev_type, timestamp_ns=ts) + + +def sample_event( + ev_type: SampleEventType, uuid: str, ts: int = 0, data=None +) -> EventRecord: + return EventRecord(event_type=ev_type, timestamp_ns=ts, sample_uuid=uuid, data=data) + + +def text_output(s: str) -> TextModelOutput: + return TextModelOutput(output=s) + + +def streaming_text(*chunks: str) -> TextModelOutput: + return TextModelOutput(output=tuple(chunks)) diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py b/tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py index c36f2f71..98674a46 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py @@ -20,90 +20,26 @@ """ import asyncio -from unittest.mock import MagicMock import pytest -from inference_endpoint.async_utils.services.metrics_aggregator.aggregator import ( - MetricsAggregatorService, -) -from inference_endpoint.async_utils.services.metrics_aggregator.emitter import ( - MetricEmitter, -) from inference_endpoint.core.record import ( ErrorEventType, EventRecord, SampleEventType, SessionEventType, ) -from inference_endpoint.core.types import ErrorData, PromptData, TextModelOutput - - -class FakeEmitter(MetricEmitter): - def __init__(self): - self.emitted: list[tuple[str, str, int | float]] = [] - self.flushed = False - self.closed = False - - def emit(self, sample_uuid: str, metric_name: str, value: int | float) -> None: - self.emitted.append((sample_uuid, metric_name, value)) - - def flush(self) -> None: - self.flushed = True - - def close(self) -> None: - self.flush() - self.closed = True - - def get_metrics(self, sample_uuid: str) -> dict[str, int | float]: - return {name: val for uuid, name, val in self.emitted if uuid == sample_uuid} - - def get_all(self, metric_name: str) -> list[tuple[str, int | float]]: - return [(uuid, val) for uuid, name, val in self.emitted if name == metric_name] - - -def _mock_zmq_context() -> MagicMock: - """Create a mock ManagedZMQContext that no-ops all ZMQ operations.""" - ctx = MagicMock() - ctx.socket.return_value = MagicMock() - ctx.connect.return_value = "ipc:///mock/socket" - return ctx - - -def make_stub_aggregator( - emitter: MetricEmitter, tokenize_pool=None, streaming: bool = True -) -> MetricsAggregatorService: - """Create a MetricsAggregatorService with ZMQ mocked out. - - Uses a mock ManagedZMQContext so the full __init__ chain runs - (including super().__init__) without creating real ZMQ sockets. 
- """ - return MetricsAggregatorService( - "mock_path", - _mock_zmq_context(), - MagicMock(spec=asyncio.AbstractEventLoop), - emitter=emitter, - tokenize_pool=tokenize_pool, - streaming=streaming, - ) - - -def _session(ev_type, ts=0): - return EventRecord(event_type=ev_type, timestamp_ns=ts) - - -def _sample(ev_type, uuid, ts=0, data=None): - return EventRecord(event_type=ev_type, timestamp_ns=ts, sample_uuid=uuid, data=data) - - -def _text(s: str) -> TextModelOutput: - """Wrap a string in TextModelOutput for use as EventRecord.data.""" - return TextModelOutput(output=s) - - -def _streaming_text(*chunks: str) -> TextModelOutput: - """Wrap chunks in a streaming TextModelOutput (tuple output).""" - return TextModelOutput(output=tuple(chunks)) - +from inference_endpoint.core.types import ErrorData, PromptData + +from .conftest import ( + InMemoryKVStore, + MockTokenizePool, + make_async_stub_aggregator, + make_stub_aggregator, + sample_event, + session_event, + streaming_text, + text_output, +) # --------------------------------------------------------------------------- # Performance tracking window @@ -114,95 +50,109 @@ def _streaming_text(*chunks: str) -> TextModelOutput: class TestTrackingWindow: @pytest.mark.asyncio async def test_not_tracked_before_start(self): - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.STARTED, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=100), + session_event(SessionEventType.STARTED, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=100), ] ) - assert agg._table.get_row("s1") is None - assert emitter.emitted == [] + assert agg._table.get_row("s1") is None, ( + "Sample issued before START_PERFORMANCE_TRACKING must not create a " + "table row — warmup samples should be excluded from the tracked set." + ) + assert ( + store.get_series_values("ttft_ns") == [] + ), "No TTFT should be recorded for samples issued before tracking begins." + assert store.get_series_values("sample_latency_ns") == [], ( + "No sample_latency should be recorded for samples issued before " + "tracking begins." + ) @pytest.mark.asyncio async def test_tracked_after_start(self): - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=100), + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=100), ] ) - assert agg._table.get_row("s1") is not None + assert agg._table.get_row("s1") is not None, ( + "Sample issued after START_PERFORMANCE_TRACKING must create a table " + "row so its metrics are included in the tracked set." 
+ ) @pytest.mark.asyncio async def test_not_tracked_after_stop(self): - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _session(SessionEventType.STOP_PERFORMANCE_TRACKING, ts=50), - _sample(SampleEventType.ISSUED, "s1", ts=100), + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + session_event(SessionEventType.STOP_PERFORMANCE_TRACKING, ts=50), + sample_event(SampleEventType.ISSUED, "s1", ts=100), ] ) - assert agg._table.get_row("s1") is None + assert agg._table.get_row("s1") is None, ( + "Sample issued after STOP_PERFORMANCE_TRACKING must not create a " + "table row — the tracking window has closed." + ) @pytest.mark.asyncio async def test_inflight_sample_continues_after_stop(self): """A sample issued during tracking completes normally after STOP.""" - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=100), - _session(SessionEventType.STOP_PERFORMANCE_TRACKING, ts=200), - _sample(SampleEventType.RECV_FIRST, "s1", ts=300), - _sample(SampleEventType.COMPLETE, "s1", ts=500), + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=100), + session_event(SessionEventType.STOP_PERFORMANCE_TRACKING, ts=200), + sample_event(SampleEventType.RECV_FIRST, "s1", ts=300), + sample_event(SampleEventType.COMPLETE, "s1", ts=500), ] ) - metrics = emitter.get_metrics("s1") - assert metrics["ttft_ns"] == 200 - assert metrics["sample_latency_ns"] == 400 + assert 200 in store.get_series_values("ttft_ns") + assert 400 in store.get_series_values("sample_latency_ns") @pytest.mark.asyncio async def test_restart_tracking_window(self): """START -> STOP -> START creates a second tracking window.""" - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=100), - _session(SessionEventType.STOP_PERFORMANCE_TRACKING, ts=200), - _sample(SampleEventType.ISSUED, "s2", ts=300), # not tracked - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=400), - _sample(SampleEventType.ISSUED, "s3", ts=500), # tracked - _sample(SampleEventType.COMPLETE, "s1", ts=600), - _sample(SampleEventType.COMPLETE, "s3", ts=700), + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=100), + session_event(SessionEventType.STOP_PERFORMANCE_TRACKING, ts=200), + sample_event(SampleEventType.ISSUED, "s2", ts=300), # not tracked + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=400), + sample_event(SampleEventType.ISSUED, "s3", ts=500), # tracked + sample_event(SampleEventType.COMPLETE, "s1", ts=600), + sample_event(SampleEventType.COMPLETE, "s3", ts=700), ] ) assert agg._table.get_row("s2") is None # never tracked - assert "sample_latency_ns" in emitter.get_metrics("s1") - assert "sample_latency_ns" in emitter.get_metrics("s3") + latencies = store.get_series_values("sample_latency_ns") + assert len(latencies) == 2 # s1 and s3 both completed @pytest.mark.asyncio async def test_tracked_block_durations(self): """Tracked blocks extend to last 
sample completion.""" - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=100), - _session(SessionEventType.STOP_PERFORMANCE_TRACKING, ts=200), - _sample(SampleEventType.COMPLETE, "s1", ts=700), - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=800), - _sample(SampleEventType.ISSUED, "s2", ts=900), - _sample(SampleEventType.COMPLETE, "s2", ts=1000), + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=100), + session_event(SessionEventType.STOP_PERFORMANCE_TRACKING, ts=200), + sample_event(SampleEventType.COMPLETE, "s1", ts=700), + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=800), + sample_event(SampleEventType.ISSUED, "s2", ts=900), + sample_event(SampleEventType.COMPLETE, "s2", ts=1000), ] ) assert agg._table.tracked_blocks[0].duration_ns == 700 # 700 - 0 @@ -220,141 +170,79 @@ async def test_tracked_block_durations(self): class TestTimingMetrics: @pytest.mark.asyncio async def test_ttft_and_sample_latency(self): - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=1000), - _sample(SampleEventType.RECV_FIRST, "s1", ts=2500), - _sample(SampleEventType.COMPLETE, "s1", ts=5000), + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=1000), + sample_event(SampleEventType.RECV_FIRST, "s1", ts=2500), + sample_event(SampleEventType.COMPLETE, "s1", ts=5000), ] ) - m = emitter.get_metrics("s1") - assert m["ttft_ns"] == 1500 - assert m["sample_latency_ns"] == 4000 - - @pytest.mark.asyncio - async def test_request_duration(self): - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) - await agg.process( - [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=1000), - _sample(SampleEventType.CLIENT_SEND, "s1", ts=1100), - _sample(SampleEventType.CLIENT_RESP_DONE, "s1", ts=4100), - _sample(SampleEventType.COMPLETE, "s1", ts=5000), - ] - ) - assert emitter.get_metrics("s1")["request_duration_ns"] == 3000 + assert 1500 in store.get_series_values("ttft_ns") + assert 4000 in store.get_series_values("sample_latency_ns") @pytest.mark.asyncio async def test_chunk_deltas(self): - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=1000), - _sample(SampleEventType.RECV_FIRST, "s1", ts=2000), - _sample(SampleEventType.RECV_NON_FIRST, "s1", ts=3000), - _sample(SampleEventType.RECV_NON_FIRST, "s1", ts=4500), - _sample(SampleEventType.COMPLETE, "s1", ts=5000), + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=1000), + sample_event(SampleEventType.RECV_FIRST, "s1", ts=2000), + sample_event(SampleEventType.RECV_NON_FIRST, "s1", ts=3000), + sample_event(SampleEventType.RECV_NON_FIRST, "s1", ts=4500), + sample_event(SampleEventType.COMPLETE, "s1", ts=5000), ] ) - deltas = [v for _, name, v in emitter.emitted if name == "chunk_delta_ns"] - assert 
deltas == [1000, 1500] + assert store.get_series_values("chunk_delta_ns") == [1000, 1500] @pytest.mark.asyncio async def test_non_streaming_latency_only(self): - """Non-streaming sample emits sample_latency_ns and OSL, but no TTFT/chunk_delta/TPOT. - - Uses AsyncStubAggregator with a real loop and MockTokenizePool so that - async triggers (OslTrigger) actually execute. This ensures OSL is - emitted for non-streaming samples, and that the absence of streaming - metrics is due to *logic* (no RECV_FIRST means no TTFT/TPOT, no - RECV_NON_FIRST means no chunk_delta), not because the pool was missing. - """ - emitter = FakeEmitter() + """Non-streaming sample emits sample_latency_ns and OSL, but no TTFT/chunk_delta/TPOT.""" + store = InMemoryKVStore() loop = asyncio.get_running_loop() pool = MockTokenizePool(delay=0.0) - agg = make_async_stub_aggregator(emitter, pool, loop) + agg = make_async_stub_aggregator(store, pool, loop) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=1000), - _sample( + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=1000), + sample_event( SampleEventType.COMPLETE, "s1", ts=3000, - data=_text("hello world"), + data=text_output("hello world"), ), ] ) await agg._table.drain_tasks() - m = emitter.get_metrics("s1") - assert m["sample_latency_ns"] == 2000 - assert m["osl"] == 2 - assert "ttft_ns" not in m - assert "chunk_delta_ns" not in m - assert "tpot_ns" not in m - - @pytest.mark.asyncio - async def test_all_timing_metrics_full_lifecycle(self): - """Full streaming sample lifecycle emits all expected timing metrics.""" - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) - await agg.process( - [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=1000), - _sample(SampleEventType.CLIENT_SEND, "s1", ts=1050), - _sample(SampleEventType.RECV_FIRST, "s1", ts=2000), - _sample(SampleEventType.RECV_NON_FIRST, "s1", ts=3000), - _sample(SampleEventType.CLIENT_RESP_DONE, "s1", ts=4000), - _sample(SampleEventType.COMPLETE, "s1", ts=4500), - ] - ) - m = emitter.get_metrics("s1") - assert m["ttft_ns"] == 1000 - assert m["sample_latency_ns"] == 3500 - assert m["request_duration_ns"] == 2950 - assert m["chunk_delta_ns"] == 1000 + assert 2000 in store.get_series_values("sample_latency_ns") + assert 2 in store.get_series_values("osl") + assert store.get_series_values("ttft_ns") == [] + assert store.get_series_values("chunk_delta_ns") == [] + assert store.get_series_values("tpot_ns") == [] @pytest.mark.asyncio async def test_chunk_delta_not_emitted_without_last_recv(self): """RECV_NON_FIRST without prior RECV_FIRST: no chunk_delta emitted.""" - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=1000), + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=1000), ] ) row = agg._table.get_row("s1") assert row is not None assert row.last_recv_ns is None # No recv events yet - @pytest.mark.asyncio - async def test_request_duration_not_emitted_without_client_send(self): - """CLIENT_RESP_DONE without CLIENT_SEND: no request_duration.""" - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) - await agg.process( - [ - 
_session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=1000), - _sample(SampleEventType.CLIENT_RESP_DONE, "s1", ts=4000), - _sample(SampleEventType.COMPLETE, "s1", ts=5000), - ] - ) - assert "request_duration_ns" not in emitter.get_metrics("s1") - # --------------------------------------------------------------------------- -# ISL (token_ids path — sync, no tokenize_pool needed) +# ISL (token_ids path -- sync, no tokenize_pool needed) # --------------------------------------------------------------------------- @@ -363,12 +251,12 @@ class TestIsl: @pytest.mark.asyncio async def test_issued_with_token_ids_emits_isl_directly(self): """SGLang path: PromptData with token_ids emits ISL = len(token_ids).""" - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample( + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event( SampleEventType.ISSUED, "s1", ts=1000, @@ -376,19 +264,19 @@ async def test_issued_with_token_ids_emits_isl_directly(self): ), ] ) - assert ("s1", "isl", 5) in emitter.emitted + assert 5 in store.get_series_values("isl") @pytest.mark.asyncio async def test_issued_without_data_no_isl(self): - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=1000), + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=1000), ] ) - assert all(name != "isl" for _, name, _ in emitter.emitted) + assert store.get_series_values("isl") == [] # --------------------------------------------------------------------------- @@ -400,170 +288,151 @@ async def test_issued_without_data_no_isl(self): class TestEdgeCases: @pytest.mark.asyncio async def test_untracked_sample_events_ignored(self): - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.RECV_FIRST, "unknown", ts=2000), - _sample(SampleEventType.COMPLETE, "unknown", ts=5000), + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.RECV_FIRST, "unknown", ts=2000), + sample_event(SampleEventType.COMPLETE, "unknown", ts=5000), ] ) - assert emitter.emitted == [] + assert store.get_series_values("ttft_ns") == [] + assert store.get_series_values("sample_latency_ns") == [] @pytest.mark.asyncio async def test_complete_removes_row(self): - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=1000), - _sample(SampleEventType.COMPLETE, "s1", ts=5000), + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=1000), + sample_event(SampleEventType.COMPLETE, "s1", ts=5000), ] ) assert agg._table.get_row("s1") is None assert len(agg._table) == 0 @pytest.mark.asyncio - async def test_session_ended_flushes_and_closes(self): - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + async def 
test_session_ended_closes_store(self): + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.STARTED, ts=0), - _session(SessionEventType.ENDED, ts=100), + session_event(SessionEventType.STARTED, ts=0), + session_event(SessionEventType.ENDED, ts=100), ] ) - assert emitter.flushed - assert emitter.closed + assert store.closed @pytest.mark.asyncio async def test_events_after_ended_are_dropped(self): - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=100), - _session(SessionEventType.ENDED, ts=200), - _sample(SampleEventType.RECV_FIRST, "s1", ts=300), + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=100), + session_event(SessionEventType.ENDED, ts=200), + sample_event(SampleEventType.RECV_FIRST, "s1", ts=300), ] ) - assert "ttft_ns" not in emitter.get_metrics("s1") + assert store.get_series_values("ttft_ns") == [] @pytest.mark.asyncio async def test_empty_sample_uuid_ignored(self): - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "", ts=1000), + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "", ts=1000), ] ) assert len(agg._table) == 0 @pytest.mark.asyncio async def test_multiple_samples_independent(self): - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=1000), - _sample(SampleEventType.ISSUED, "s2", ts=1500), - _sample(SampleEventType.RECV_FIRST, "s1", ts=2000), - _sample(SampleEventType.RECV_FIRST, "s2", ts=3000), - _sample(SampleEventType.COMPLETE, "s1", ts=4000), - _sample(SampleEventType.COMPLETE, "s2", ts=5000), + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=1000), + sample_event(SampleEventType.ISSUED, "s2", ts=1500), + sample_event(SampleEventType.RECV_FIRST, "s1", ts=2000), + sample_event(SampleEventType.RECV_FIRST, "s2", ts=3000), + sample_event(SampleEventType.COMPLETE, "s1", ts=4000), + sample_event(SampleEventType.COMPLETE, "s2", ts=5000), ] ) - s1 = emitter.get_metrics("s1") - s2 = emitter.get_metrics("s2") - assert s1["ttft_ns"] == 1000 - assert s2["ttft_ns"] == 1500 - assert s1["sample_latency_ns"] == 3000 - assert s2["sample_latency_ns"] == 3500 - - @pytest.mark.asyncio - async def test_transport_events_ignored(self): - """TRANSPORT_SENT / TRANSPORT_RECV should not affect metrics.""" - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) - await agg.process( - [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=1000), - _sample(SampleEventType.TRANSPORT_SENT, "s1", ts=1001), - _sample(SampleEventType.TRANSPORT_RECV, "s1", ts=1002), - _sample(SampleEventType.COMPLETE, "s1", ts=5000), - ] - ) - m = emitter.get_metrics("s1") - assert m == {"sample_latency_ns": 4000} + ttfts = store.get_series_values("ttft_ns") + latencies = store.get_series_values("sample_latency_ns") + 
assert 1000 in ttfts + assert 1500 in ttfts + assert 3000 in latencies + assert 3500 in latencies @pytest.mark.asyncio async def test_error_events_ignored(self): """Error events should not crash the aggregator.""" - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), EventRecord( event_type=ErrorEventType.GENERIC, timestamp_ns=500, data=ErrorData(error_type="test", error_message="boom"), ), - _sample(SampleEventType.ISSUED, "s1", ts=1000), - _sample(SampleEventType.COMPLETE, "s1", ts=2000), + sample_event(SampleEventType.ISSUED, "s1", ts=1000), + sample_event(SampleEventType.COMPLETE, "s1", ts=2000), ] ) - assert emitter.get_metrics("s1")["sample_latency_ns"] == 1000 + assert 1000 in store.get_series_values("sample_latency_ns") @pytest.mark.asyncio async def test_session_started_stores_timestamp(self): - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) - await agg.process([_session(SessionEventType.STARTED, ts=42)]) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) + await agg.process([session_event(SessionEventType.STARTED, ts=42)]) assert agg._table.session_started_ns == 42 @pytest.mark.asyncio async def test_process_multiple_batches(self): """Two sequential process() calls maintain state correctly.""" - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) + store = InMemoryKVStore() + agg = make_stub_aggregator(store) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=1000), + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=1000), ] ) assert agg._table.get_row("s1") is not None await agg.process( [ - _sample(SampleEventType.RECV_FIRST, "s1", ts=2000), - _sample(SampleEventType.COMPLETE, "s1", ts=3000), + sample_event(SampleEventType.RECV_FIRST, "s1", ts=2000), + sample_event(SampleEventType.COMPLETE, "s1", ts=3000), ] ) - m = emitter.get_metrics("s1") - assert m["ttft_ns"] == 1000 - assert m["sample_latency_ns"] == 2000 + assert 1000 in store.get_series_values("ttft_ns") + assert 2000 in store.get_series_values("sample_latency_ns") assert agg._table.get_row("s1") is None @pytest.mark.asyncio async def test_ended_in_second_batch(self): """ENDED in a later batch still triggers finalize.""" - emitter = FakeEmitter() - agg = make_stub_aggregator(emitter) - await agg.process([_session(SessionEventType.STARTED, ts=0)]) - assert not emitter.flushed - await agg.process([_session(SessionEventType.ENDED, ts=100)]) - assert emitter.flushed - assert emitter.closed + store = InMemoryKVStore() + agg = make_stub_aggregator(store) + await agg.process([session_event(SessionEventType.STARTED, ts=0)]) + assert not store.closed + await agg.process([session_event(SessionEventType.ENDED, ts=100)]) + assert store.closed # --------------------------------------------------------------------------- @@ -571,59 +440,20 @@ async def test_ended_in_second_batch(self): # --------------------------------------------------------------------------- -class MockTokenizePool: - """Mock TokenizePool that splits on whitespace with artificial async delay.""" - - def __init__(self, delay: float = 0.01): - self._delay = delay - - def token_count(self, text: str) -> int: - return len(text.split()) - - async def token_count_async( - 
self, text: str, _loop: asyncio.AbstractEventLoop - ) -> int: - await asyncio.sleep(self._delay) - return len(text.split()) - - def close(self) -> None: - pass - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() - - -def make_async_stub_aggregator( - emitter: MetricEmitter, tokenize_pool, loop, streaming: bool = True -) -> MetricsAggregatorService: - """Create a MetricsAggregatorService with a real loop and mock ZMQ.""" - return MetricsAggregatorService( - "mock_path", - _mock_zmq_context(), - loop, - emitter=emitter, - tokenize_pool=tokenize_pool, - streaming=streaming, - ) - - @pytest.mark.unit class TestAsyncTriggers: @pytest.mark.asyncio async def test_isl_text_path_async(self): """ISL with text prompt triggers async tokenization.""" - emitter = FakeEmitter() + store = InMemoryKVStore() loop = asyncio.get_running_loop() pool = MockTokenizePool(delay=0.01) - agg = make_async_stub_aggregator(emitter, pool, loop) + agg = make_async_stub_aggregator(store, pool, loop) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample( + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event( SampleEventType.ISSUED, "s1", ts=1000, @@ -633,156 +463,151 @@ async def test_isl_text_path_async(self): ) # ISL task is in-flight; drain it await agg._table.drain_tasks() - assert ("s1", "isl", 4) in emitter.emitted + assert 4 in store.get_series_values("isl") @pytest.mark.asyncio async def test_osl_emitted_on_complete(self): """OSL is emitted via async tokenization when COMPLETE carries TextModelOutput.""" - emitter = FakeEmitter() + store = InMemoryKVStore() loop = asyncio.get_running_loop() pool = MockTokenizePool(delay=0.01) - agg = make_async_stub_aggregator(emitter, pool, loop) + agg = make_async_stub_aggregator(store, pool, loop) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=1000), - _sample( + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=1000), + sample_event( SampleEventType.COMPLETE, "s1", ts=5000, - data=_text("the quick brown fox"), + data=text_output("the quick brown fox"), ), ] ) await agg._table.drain_tasks() - m = emitter.get_metrics("s1") - assert m["sample_latency_ns"] == 4000 - assert m["osl"] == 4 + assert 4000 in store.get_series_values("sample_latency_ns") + assert 4 in store.get_series_values("osl") @pytest.mark.asyncio async def test_tpot_emitted_for_streaming(self): """TPOT is emitted for streaming responses using text_after_first_chunk.""" - emitter = FakeEmitter() + store = InMemoryKVStore() loop = asyncio.get_running_loop() pool = MockTokenizePool(delay=0.0) - agg = make_async_stub_aggregator(emitter, pool, loop) + agg = make_async_stub_aggregator(store, pool, loop) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=1000), - _sample(SampleEventType.RECV_FIRST, "s1", ts=2000), - _sample( + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=1000), + sample_event(SampleEventType.RECV_FIRST, "s1", ts=2000), + sample_event( SampleEventType.COMPLETE, "s1", ts=5000, # Streaming: 3 chunks, text_after_first_chunk = "world foo" - data=_streaming_text("hello", " world", " foo"), + data=streaming_text("hello", " world", " foo"), ), ] ) await agg._table.drain_tasks() - m = emitter.get_metrics("s1") - assert m["osl"] == 
3 # "hello world foo" = 3 tokens + assert 3 in store.get_series_values("osl") # "hello world foo" = 3 tokens # tpot = (5000 - 2000) / token_count("world foo") = 3000 / 2 = 1500 - assert m["tpot_ns"] == 1500.0 + assert 1500.0 in store.get_series_values("tpot_ns") @pytest.mark.asyncio async def test_tpot_skipped_when_single_chunk(self): """TPOT is not emitted when there are no tokens after the first chunk.""" - emitter = FakeEmitter() + store = InMemoryKVStore() loop = asyncio.get_running_loop() pool = MockTokenizePool(delay=0.0) - agg = make_async_stub_aggregator(emitter, pool, loop) + agg = make_async_stub_aggregator(store, pool, loop) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=1000), - _sample(SampleEventType.RECV_FIRST, "s1", ts=2000), - _sample( + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=1000), + sample_event(SampleEventType.RECV_FIRST, "s1", ts=2000), + sample_event( SampleEventType.COMPLETE, "s1", ts=5000, # Single chunk: text_after_first_chunk = "" - data=_streaming_text("only"), + data=streaming_text("only"), ), ] ) await agg._table.drain_tasks() - m = emitter.get_metrics("s1") - assert m["osl"] == 1 - assert "tpot_ns" not in m + assert 1 in store.get_series_values("osl") + assert store.get_series_values("tpot_ns") == [] @pytest.mark.asyncio async def test_tpot_not_emitted_without_streaming_flag(self): """TPOT trigger is not registered when streaming=False.""" - emitter = FakeEmitter() + store = InMemoryKVStore() loop = asyncio.get_running_loop() pool = MockTokenizePool(delay=0.0) - agg = make_async_stub_aggregator(emitter, pool, loop, streaming=False) + agg = make_async_stub_aggregator(store, pool, loop, streaming=False) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=1000), - _sample(SampleEventType.RECV_FIRST, "s1", ts=2000), - _sample( + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=1000), + sample_event(SampleEventType.RECV_FIRST, "s1", ts=2000), + sample_event( SampleEventType.COMPLETE, "s1", ts=5000, - data=_streaming_text("hello", " world", " foo"), + data=streaming_text("hello", " world", " foo"), ), ] ) await agg._table.drain_tasks() - m = emitter.get_metrics("s1") - assert m["sample_latency_ns"] == 4000 - assert m["osl"] == 3 - assert "tpot_ns" not in m - assert "ttft_ns" not in m - assert "chunk_delta_ns" not in m + assert 4000 in store.get_series_values("sample_latency_ns") + assert 3 in store.get_series_values("osl") + assert store.get_series_values("tpot_ns") == [] + assert store.get_series_values("ttft_ns") == [] + assert store.get_series_values("chunk_delta_ns") == [] @pytest.mark.asyncio async def test_tpot_non_streaming_output_skipped(self): """TPOT is not emitted for non-streaming (str) TextModelOutput.""" - emitter = FakeEmitter() + store = InMemoryKVStore() loop = asyncio.get_running_loop() pool = MockTokenizePool(delay=0.0) - agg = make_async_stub_aggregator(emitter, pool, loop) + agg = make_async_stub_aggregator(store, pool, loop) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample(SampleEventType.ISSUED, "s1", ts=1000), - _sample(SampleEventType.RECV_FIRST, "s1", ts=2000), - _sample( + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event(SampleEventType.ISSUED, "s1", ts=1000), + 
sample_event(SampleEventType.RECV_FIRST, "s1", ts=2000), + sample_event( SampleEventType.COMPLETE, "s1", ts=5000, # Non-streaming: str output, text_after_first_chunk = "" - data=_text("hello world foo"), + data=text_output("hello world foo"), ), ] ) await agg._table.drain_tasks() - m = emitter.get_metrics("s1") - assert m["osl"] == 3 - assert "tpot_ns" not in m + assert 3 in store.get_series_values("osl") + assert store.get_series_values("tpot_ns") == [] @pytest.mark.asyncio async def test_drain_tasks_awaits_in_flight(self): """drain_tasks() properly awaits all in-flight async trigger tasks.""" - emitter = FakeEmitter() + store = InMemoryKVStore() loop = asyncio.get_running_loop() pool = MockTokenizePool(delay=0.05) - agg = make_async_stub_aggregator(emitter, pool, loop) + agg = make_async_stub_aggregator(store, pool, loop) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample( + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event( SampleEventType.ISSUED, "s1", ts=1000, @@ -795,32 +620,31 @@ async def test_drain_tasks_awaits_in_flight(self): await agg._table.drain_tasks() assert len(agg._table._in_flight_tasks) == 0 - assert ("s1", "isl", 5) in emitter.emitted + assert 5 in store.get_series_values("isl") @pytest.mark.asyncio async def test_shutdown_drains_async_tasks(self): """ENDED drains in-flight async tasks before finalizing.""" - emitter = FakeEmitter() + store = InMemoryKVStore() loop = asyncio.get_running_loop() pool = MockTokenizePool(delay=0.02) - agg = make_async_stub_aggregator(emitter, pool, loop) + agg = make_async_stub_aggregator(store, pool, loop) await agg.process( [ - _session(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), - _sample( + session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=0), + sample_event( SampleEventType.ISSUED, "s1", ts=1000, data=PromptData(text="one two three"), ), - _session(SessionEventType.ENDED, ts=2000), + session_event(SessionEventType.ENDED, ts=2000), ] ) # After ENDED, drain_tasks was called, so ISL should be emitted - assert ("s1", "isl", 3) in emitter.emitted - assert emitter.flushed - assert emitter.closed + assert 3 in store.get_series_values("isl") + assert store.closed # TODO: Add tests for trigger exception handling (logger.exception paths). # Inject a MockTokenizePool that raises on token_count_async and verify: diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_aggregator_e2e.py b/tests/unit/async_utils/services/metrics_aggregator/test_aggregator_e2e.py index fbd8c1fe..20319cec 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_aggregator_e2e.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_aggregator_e2e.py @@ -17,12 +17,12 @@ These tests launch an EventPublisherService, connect a MetricsAggregatorService over ZMQ IPC, publish EventRecords, and verify the aggregator computes and -emits the correct metrics. +emits the correct metrics into the KVStore. 
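+
+A rough sketch of the topology these tests stand up (fixture names as defined
+below; the IPC path is whatever the publisher fixture binds):
+
+    EventPublisherService (ZMQ PUB)
+        |  EventRecord stream over ipc://
+        v
+    MetricsAggregatorService (subscriber, running on aggregator_loop)
+        |  per-sample metric values
+        v
+    SignalingKVStore (in-memory; set_wait_target() wakes the waiting test)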
""" import asyncio -import json import time +from threading import Lock import pytest import zmq @@ -31,10 +31,6 @@ from inference_endpoint.async_utils.services.metrics_aggregator.aggregator import ( MetricsAggregatorService, ) -from inference_endpoint.async_utils.services.metrics_aggregator.emitter import ( - JsonlMetricEmitter, - MetricEmitter, -) from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext from inference_endpoint.core.record import ( EventRecord, @@ -42,39 +38,39 @@ SessionEventType, ) +from .conftest import InMemoryKVStore + # --------------------------------------------------------------------------- -# Helpers +# Signaling KVStore for e2e tests # --------------------------------------------------------------------------- -class CollectingEmitter(MetricEmitter): - """Thread-safe emitter that collects metrics and signals when a target count is reached.""" +class SignalingKVStore(InMemoryKVStore): + """InMemoryKVStore that signals an asyncio.Event when a target series count is reached. + + This replaces the old CollectingEmitter.set_wait_target() pattern. Call + set_wait_target(event, count) before publishing records; the event will be + set once the total number of series values across all series keys reaches + the target count. + """ - def __init__(self): - self.emitted: list[tuple[str, str, int | float]] = [] + def __init__(self) -> None: + super().__init__() self._target_event: asyncio.Event | None = None self._target_count: int = 0 - self.flushed = False - self.closed = False + self._lock = Lock() def set_wait_target(self, event: asyncio.Event, count: int) -> None: self._target_event = event self._target_count = count - def emit(self, sample_uuid: str, metric_name: str, value: int | float) -> None: - self.emitted.append((sample_uuid, metric_name, value)) - if self._target_event is not None and len(self.emitted) >= self._target_count: - self._target_event.set() - - def flush(self) -> None: - self.flushed = True - - def close(self) -> None: - self.flush() - self.closed = True - - def get_metrics(self, sample_uuid: str) -> dict[str, int | float]: - return {name: val for uuid, name, val in self.emitted if uuid == sample_uuid} + def update(self, key: str, value: float) -> None: + super().update(key, value) + with self._lock: + if self._target_event is not None: + total = sum(len(v) for v in self._series.values()) + if total >= self._target_count: + self._target_event.set() # --------------------------------------------------------------------------- @@ -92,15 +88,12 @@ def zmq_context(): @pytest.fixture def publisher(zmq_context): - EventPublisherService._instance = None try: service = EventPublisherService(zmq_context) except zmq.ZMQError as exc: - EventPublisherService._instance = None pytest.skip(f"ZMQ IPC bind unavailable (sandboxed?): {exc}") yield service service.close() - EventPublisherService._instance = None @pytest.fixture @@ -112,8 +105,8 @@ def aggregator_loop(): @pytest.fixture -def collecting_emitter(): - return CollectingEmitter() +def signaling_store(): + return SignalingKVStore() @pytest.fixture @@ -123,7 +116,7 @@ def shutdown_event(): @pytest.fixture def aggregator( - publisher, aggregator_loop, zmq_context, collecting_emitter, shutdown_event + publisher, aggregator_loop, zmq_context, signaling_store, shutdown_event ): """MetricsAggregatorService connected to the publisher via ZMQ.""" agg = MetricsAggregatorService( @@ -131,7 +124,7 @@ def aggregator( zmq_context, aggregator_loop, topics=None, - emitter=collecting_emitter, + 
kv_store=signaling_store,
         tokenize_pool=None,
         streaming=True,
         shutdown_event=shutdown_event,
@@ -145,8 +138,9 @@
 
 
 def _publish_and_sleep(publisher, record, delay=0.05):
-    """Publish a record and sleep briefly to let the event loop drain."""
+    """Publish a record, flush, and sleep briefly to let the event loop drain."""
     publisher.publish(record)
+    publisher.flush()
     time.sleep(delay)
 
 
@@ -159,12 +153,12 @@
 class TestAggregatorE2E:
     @pytest.mark.asyncio
     async def test_single_sample_timing_metrics(
-        self, publisher, aggregator, collecting_emitter
+        self, publisher, aggregator, signaling_store
     ):
         """Full streaming sample lifecycle over real ZMQ pub/sub."""
         done = asyncio.Event()
-        # Expect: ttft_ns, chunk_delta_ns, sample_latency_ns = 3 metrics
-        collecting_emitter.set_wait_target(done, 3)
+        # Expect: ttft_ns, chunk_delta_ns, sample_latency_ns = 3 series values
+        signaling_store.set_wait_target(done, 3)
 
         _publish_and_sleep(
             publisher,
@@ -208,21 +202,20 @@
 
         await asyncio.wait_for(done.wait(), timeout=_WAIT_TIMEOUT)
 
-        m = collecting_emitter.get_metrics("s1")
-        assert m["ttft_ns"] == 1000
-        assert m["chunk_delta_ns"] == 1000
-        assert m["sample_latency_ns"] == 3000
+        assert 1000 in signaling_store.get_series_values("ttft_ns")
+        assert 1000 in signaling_store.get_series_values("chunk_delta_ns")
+        assert 3000 in signaling_store.get_series_values("sample_latency_ns")
 
     @pytest.mark.asyncio
     async def test_tracking_window_respected(
-        self, publisher, aggregator, collecting_emitter
+        self, publisher, aggregator, signaling_store
     ):
         """Samples issued before START_PERFORMANCE_TRACKING are not tracked."""
         done = asyncio.Event()
-        # Only s2 should produce metrics (1 metric: sample_latency_ns)
-        collecting_emitter.set_wait_target(done, 1)
+        # Only s2 should produce metrics (1 series value: sample_latency_ns)
+        signaling_store.set_wait_target(done, 1)
 
-        # Issue s1 before tracking starts — should be ignored
+        # Issue s1 before tracking starts -- should be ignored
         _publish_and_sleep(
             publisher,
             EventRecord(
@@ -265,14 +258,16 @@
 
         await asyncio.wait_for(done.wait(), timeout=_WAIT_TIMEOUT)
 
-        assert collecting_emitter.get_metrics("s1") == {}
-        assert collecting_emitter.get_metrics("s2")["sample_latency_ns"] == 300
+        # Exactly one latency value should be recorded -- s2's. The untracked
+        # s1 must not contribute anything.
+        latencies = signaling_store.get_series_values("sample_latency_ns")
+        assert latencies == [300]
 
     @pytest.mark.asyncio
     async def test_session_ended_triggers_shutdown(
-        self, publisher, aggregator, collecting_emitter, shutdown_event
+        self, publisher, aggregator, signaling_store, shutdown_event
     ):
-        """ENDED event causes emitter flush, aggregator close, and shutdown signal."""
+        """ENDED event causes store close and shutdown signal."""
         _publish_and_sleep(
             publisher,
             EventRecord(
@@ -281,17 +276,16 @@
             ),
         )
         await asyncio.wait_for(shutdown_event.wait(), timeout=_WAIT_TIMEOUT)
-        assert collecting_emitter.flushed
-        assert collecting_emitter.closed
+        assert signaling_store.closed
 
     @pytest.mark.asyncio
     async def test_multiple_samples_concurrent(
-        self, publisher, aggregator, collecting_emitter
+        self, publisher, aggregator, signaling_store
     ):
         """Multiple samples in flight concurrently produce independent metrics."""
         done = asyncio.Event()
         # 2 samples x 2 metrics each (ttft_ns + sample_latency_ns) = 4
-        collecting_emitter.set_wait_target(done, 4)
+        
signaling_store.set_wait_target(done, 4) _publish_and_sleep( publisher, @@ -331,86 +325,9 @@ async def test_multiple_samples_concurrent( await asyncio.wait_for(done.wait(), timeout=_WAIT_TIMEOUT) - ma = collecting_emitter.get_metrics("a") - mb = collecting_emitter.get_metrics("b") - assert ma["ttft_ns"] == 100 - assert ma["sample_latency_ns"] == 300 - assert mb["ttft_ns"] == 200 - assert mb["sample_latency_ns"] == 350 - - @pytest.mark.asyncio - async def test_jsonl_emitter_e2e( - self, publisher, aggregator_loop, zmq_context, tmp_path - ): - """Full pipeline with JsonlMetricEmitter writing to disk.""" - emitter = JsonlMetricEmitter(tmp_path / "metrics", flush_interval=1) - agg = MetricsAggregatorService( - publisher.bind_path, - zmq_context, - aggregator_loop, - topics=None, - emitter=emitter, - streaming=True, - ) - aggregator_loop.call_soon_threadsafe(agg.start) - time.sleep(0.5) - - try: - _publish_and_sleep( - publisher, - EventRecord( - event_type=SessionEventType.START_PERFORMANCE_TRACKING, - timestamp_ns=0, - ), - ) - _publish_and_sleep( - publisher, - EventRecord( - event_type=SampleEventType.ISSUED, - timestamp_ns=1000, - sample_uuid="file-test", - ), - ) - _publish_and_sleep( - publisher, - EventRecord( - event_type=SampleEventType.RECV_FIRST, - timestamp_ns=2000, - sample_uuid="file-test", - ), - ) - _publish_and_sleep( - publisher, - EventRecord( - event_type=SampleEventType.COMPLETE, - timestamp_ns=3000, - sample_uuid="file-test", - ), - ) - - # Wait for metrics to be written - for _ in range(30): - try: - content = (tmp_path / "metrics.jsonl").read_text() - lines = [line for line in content.strip().split("\n") if line] - if len(lines) >= 2: - break - except FileNotFoundError: - pass # File not yet created by the emitter; retry. - await asyncio.sleep(0.1) - - content = (tmp_path / "metrics.jsonl").read_text() - lines = [line for line in content.strip().split("\n") if line] - assert len(lines) >= 2 - - records = [json.loads(line) for line in lines] - metric_names = {r["metric_name"] for r in records} - assert "ttft_ns" in metric_names - assert "sample_latency_ns" in metric_names - - ttft = next(r for r in records if r["metric_name"] == "ttft_ns") - assert ttft["value"] == 1000 - assert ttft["sample_uuid"] == "file-test" - finally: - if not agg.is_closed: - agg.close() + ttfts = signaling_store.get_series_values("ttft_ns") + latencies = signaling_store.get_series_values("sample_latency_ns") + assert 100 in ttfts # a: 200 - 100 + assert 300 in latencies # a: 400 - 100 + assert 200 in ttfts # b: 350 - 150 + assert 350 in latencies # b: 500 - 150 diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_emitter.py b/tests/unit/async_utils/services/metrics_aggregator/test_emitter.py deleted file mode 100644 index 99696687..00000000 --- a/tests/unit/async_utils/services/metrics_aggregator/test_emitter.py +++ /dev/null @@ -1,73 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -import pytest -from inference_endpoint.async_utils.services.metrics_aggregator.emitter import ( - JsonlMetricEmitter, -) - - -@pytest.mark.unit -class TestJsonlMetricEmitter: - def test_emit_writes_jsonl_line(self, tmp_path): - emitter = JsonlMetricEmitter(tmp_path / "metrics", flush_interval=1) - emitter.emit("sample1", "ttft_ns", 1500) - emitter.close() - - lines = (tmp_path / "metrics.jsonl").read_text().strip().split("\n") - assert len(lines) == 1 - record = json.loads(lines[0]) - assert record["sample_uuid"] == "sample1" - assert record["metric_name"] == "ttft_ns" - assert record["value"] == 1500 - assert "timestamp_ns" in record - - def test_emit_multiple_metrics(self, tmp_path): - emitter = JsonlMetricEmitter(tmp_path / "metrics", flush_interval=10) - emitter.emit("s1", "ttft_ns", 100) - emitter.emit("s1", "sample_latency_ns", 500) - emitter.emit("s2", "ttft_ns", 200) - emitter.close() - - lines = (tmp_path / "metrics.jsonl").read_text().strip().split("\n") - assert len(lines) == 3 - - def test_flush_interval(self, tmp_path): - emitter = JsonlMetricEmitter(tmp_path / "metrics", flush_interval=2) - emitter.emit("s1", "m1", 1) - # After 1 emit, file may not be flushed yet (OS buffering) - emitter.emit("s1", "m2", 2) - # After 2 emits, flush_interval triggers flush - emitter.flush() # explicit flush to verify no error - emitter.close() - - lines = (tmp_path / "metrics.jsonl").read_text().strip().split("\n") - assert len(lines) == 2 - - def test_close_is_idempotent(self, tmp_path): - emitter = JsonlMetricEmitter(tmp_path / "metrics") - emitter.close() - emitter.close() # Should not raise - - def test_float_value(self, tmp_path): - emitter = JsonlMetricEmitter(tmp_path / "metrics", flush_interval=1) - emitter.emit("s1", "tpot_ns", 1234.5) - emitter.close() - - lines = (tmp_path / "metrics.jsonl").read_text().strip().split("\n") - record = json.loads(lines[0]) - assert record["value"] == 1234.5 diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_kv_store.py b/tests/unit/async_utils/services/metrics_aggregator/test_kv_store.py new file mode 100644 index 00000000..f9e23cd7 --- /dev/null +++ b/tests/unit/async_utils/services/metrics_aggregator/test_kv_store.py @@ -0,0 +1,395 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
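+#
+# A minimal mental model for the API under test -- a sketch inferred only from
+# the behaviors asserted below (one writer, independent readers over the same
+# store directory):
+#
+#     writer = BasicKVStore(store_dir)            # creates backing files
+#     writer.create_key("ttft_ns", "series")      # must precede update()
+#     writer.update("ttft_ns", 100)
+#
+#     reader = BasicKVStoreReader(store_dir)      # separate handle, lazy open
+#     reader.register_key("ttft_ns", "series")
+#     stats = reader.get("ttft_ns")               # SeriesStats with count == 1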
+ +"""Tests for the KVStore (BasicKVStore + BasicKVStoreReader).""" + +import math +import multiprocessing +import struct +from pathlib import Path + +import pytest +from inference_endpoint.async_utils.services.metrics_aggregator.kv_store import ( + BasicKVStore, + BasicKVStoreReader, + SeriesStats, +) + +# --------------------------------------------------------------------------- +# SeriesStats +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestSeriesStats: + def test_from_values(self): + stats = SeriesStats([10.0, 20.0, 5.0]) + assert stats.count == 3 + assert stats.total == 35.0 + assert stats.min_val == 5.0 + assert stats.max_val == 20.0 + + def test_sum_sq(self): + stats = SeriesStats([3.0, 4.0]) + assert stats.sum_sq == pytest.approx(3.0**2 + 4.0**2) + + def test_empty(self): + stats = SeriesStats() + assert stats.count == 0 + assert stats.total == 0.0 + # Sentinel values for an empty series — compute_summary() is responsible + # for normalizing these to 0 before exposing them to users. + assert stats.min_val == math.inf + assert stats.max_val == -math.inf + + def test_incremental_rollup(self): + stats = SeriesStats([1.0, 2.0]) + assert stats._last_rollup_idx == 2 + stats.values.extend([3.0, 4.0]) + stats._update_rollup() + assert stats.count == 4 + assert stats.total == 10.0 + assert stats._last_rollup_idx == 4 + + +# --------------------------------------------------------------------------- +# BasicKVStore (writer) +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestBasicKVStore: + def test_counter(self, tmp_path: Path): + store = BasicKVStore(tmp_path / "kv") + store.create_key("error_count", "counter") + store.update("error_count", 5) + assert store.get("error_count") == 5 + store.update("error_count", 10) + assert store.get("error_count") == 10 + store.close() + + def test_counter_returns_int(self, tmp_path: Path): + store = BasicKVStore(tmp_path / "kv") + store.create_key("c", "counter") + store.update("c", 42) + val = store.get("c") + assert isinstance(val, int) + store.close() + + def test_series_uint64(self, tmp_path: Path): + store = BasicKVStore(tmp_path / "kv") + store.create_key("ttft_ns", "series") + store.update("ttft_ns", 100) + store.update("ttft_ns", 200) + result = store.get("ttft_ns") + assert isinstance(result, SeriesStats) + assert result.count == 2 + assert result.values == [100, 200] + store.close() + + def test_series_float64(self, tmp_path: Path): + store = BasicKVStore(tmp_path / "kv") + store.create_key("ratio", "series", dtype=float) + store.update("ratio", 1.5) + store.update("ratio", 2.5) + result = store.get("ratio") + assert isinstance(result, SeriesStats) + assert result.count == 2 + assert result.values == [1.5, 2.5] + store.close() + + def test_snapshot(self, tmp_path: Path): + store = BasicKVStore(tmp_path / "kv") + store.create_key("n_issued", "counter") + store.create_key("latency", "series") + store.update("n_issued", 42) + store.update("latency", 150) + store.update("latency", 250) + + snap = store.snapshot() + assert snap["n_issued"] == 42 + assert isinstance(snap["latency"], SeriesStats) + assert snap["latency"].count == 2 + store.close() + + def test_snapshot_is_isolated_from_later_writes(self, tmp_path: Path): + """Mutations after snapshot() must not alter the captured snapshot.""" + store = BasicKVStore(tmp_path / "kv") + store.create_key("n_issued", "counter") + store.create_key("latency", "series") + 
store.update("n_issued", 5) + store.update("latency", 100) + store.update("latency", 200) + + snap = store.snapshot() + + store.update("n_issued", 99) + store.update("latency", 300) + + assert snap["n_issued"] == 5 + latency_snap = snap["latency"] + assert isinstance(latency_snap, SeriesStats) + assert latency_snap.count == 2 + assert latency_snap.values == [100, 200] + assert latency_snap.total == 300 + store.close() + + def test_update_unknown_key_raises(self, tmp_path: Path): + store = BasicKVStore(tmp_path / "kv") + with pytest.raises(KeyError, match="Key not created"): + store.update("missing", 1) + store.close() + + def test_create_key_idempotent(self, tmp_path: Path): + store = BasicKVStore(tmp_path / "kv") + store.create_key("x", "counter") + store.update("x", 5) + store.create_key("x", "counter") # should not reset + assert store.get("x") == 5 + store.close() + + def test_unlink(self, tmp_path: Path): + store_dir = tmp_path / "kv" + store = BasicKVStore(store_dir) + store.create_key("a", "counter") + assert store_dir.exists() + store.unlink() + assert not store_dir.exists() + + +# --------------------------------------------------------------------------- +# BasicKVStoreReader +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestBasicKVStoreReader: + def test_read_counter(self, tmp_path: Path): + store_dir = tmp_path / "kv" + writer = BasicKVStore(store_dir) + writer.create_key("count", "counter") + writer.update("count", 7) + + reader = BasicKVStoreReader(store_dir) + reader.register_key("count", "counter") + assert reader.get("count") == 7 + + reader.close() + writer.close() + + def test_read_series(self, tmp_path: Path): + store_dir = tmp_path / "kv" + writer = BasicKVStore(store_dir) + writer.create_key("ttft", "series") + writer.update("ttft", 100) + writer.update("ttft", 200) + + reader = BasicKVStoreReader(store_dir) + reader.register_key("ttft", "series") + stats = reader.get("ttft") + assert isinstance(stats, SeriesStats) + assert stats.count == 2 + assert stats.values == [100, 200] + + reader.close() + writer.close() + + def test_incremental_read(self, tmp_path: Path): + store_dir = tmp_path / "kv" + writer = BasicKVStore(store_dir) + writer.create_key("lat", "series") + writer.update("lat", 1000) + + reader = BasicKVStoreReader(store_dir) + reader.register_key("lat", "series") + s1 = reader.get("lat") + assert isinstance(s1, SeriesStats) + assert s1.count == 1 + + writer.update("lat", 2000) + writer.update("lat", 3000) + s2 = reader.get("lat") + assert isinstance(s2, SeriesStats) + assert s2.count == 3 + assert s2.total == 6000 + + reader.close() + writer.close() + + def test_snapshot(self, tmp_path: Path): + store_dir = tmp_path / "kv" + writer = BasicKVStore(store_dir) + writer.create_key("n", "counter") + writer.create_key("s", "series") + writer.update("n", 5) + writer.update("s", 10) + + reader = BasicKVStoreReader(store_dir) + reader.register_key("n", "counter") + reader.register_key("s", "series") + snap = reader.snapshot() + assert snap["n"] == 5 + assert isinstance(snap["s"], SeriesStats) + assert snap["s"].count == 1 + + reader.close() + writer.close() + + def test_reader_lazy_open(self, tmp_path: Path): + """Reader for a key whose file doesn't exist yet opens lazily.""" + store_dir = tmp_path / "kv" + store_dir.mkdir() + reader = BasicKVStoreReader(store_dir) + reader.register_key("lat", "series") + s = reader.get("lat") + assert isinstance(s, SeriesStats) + assert s.count == 0 + + # Now create the 
writer and write + writer = BasicKVStore(store_dir) + writer.create_key("lat", "series") + writer.update("lat", 42) + + s = reader.get("lat") + assert isinstance(s, SeriesStats) + assert s.count == 1 + assert s.values == [42] + + reader.close() + writer.close() + + +# --------------------------------------------------------------------------- +# Cross-process +# --------------------------------------------------------------------------- + + +def _child_read(store_dir_str: str, queue: multiprocessing.Queue) -> None: + store_dir = Path(store_dir_str) + reader = BasicKVStoreReader(store_dir) + reader.register_key("n", "counter") + reader.register_key("ttft", "series") + snap = reader.snapshot() + ttft = snap["ttft"] + assert isinstance(ttft, SeriesStats) + queue.put((snap["n"], ttft.count, ttft.values)) + reader.close() + + +@pytest.mark.unit +class TestCrossProcess: + def test_cross_process_read(self, tmp_path: Path): + store_dir = tmp_path / "kv" + writer = BasicKVStore(store_dir) + writer.create_key("n", "counter") + writer.create_key("ttft", "series") + writer.update("n", 2) + writer.update("ttft", 42) + writer.update("ttft", 99) + + q: multiprocessing.Queue = multiprocessing.Queue() + proc = multiprocessing.Process(target=_child_read, args=(str(store_dir), q)) + proc.start() + proc.join(timeout=10) + + assert not q.empty() + n, count, values = q.get() + assert n == 2 + assert count == 2 + assert values == [42, 99] + + writer.close() + + +# --------------------------------------------------------------------------- +# Integer precision +# --------------------------------------------------------------------------- + +# First integer not exactly representable in IEEE 754 float64 (53-bit mantissa). +_BEYOND_FLOAT64 = 2**53 + 1 + + +@pytest.mark.unit +class TestIntegerPrecision: + """Verify uint64 storage preserves integers that exceed float64 precision.""" + + def test_float64_struct_loses_precision(self): + """Confirm struct float64 roundtrip is lossy for _BEYOND_FLOAT64. + + If this test fails, the other tests in TestIntegerPrecision lose + validity — they depend on _BEYOND_FLOAT64 being unrepresentable + in float64. 
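+
+        (2**53 + 1 is the smallest positive integer float64 cannot represent:
+        above 2**53 the 53-bit mantissa spaces adjacent doubles 2 apart, so
+        the value rounds to 2**53.)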
+ """ + packed_d = struct.pack(" None: + import uvloop + + async def _send(): + with ManagedZMQContext.scoped(socket_dir=socket_dir) as ctx: + await send_ready_signal(ctx, path, identity) + + uvloop.run(_send()) + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestReadyCheckCrossProcess: + async def test_cross_process_signal(self): + with tempfile.TemporaryDirectory() as tmpdir: + with ManagedZMQContext.scoped(socket_dir=tmpdir) as ctx: + receiver = ReadyCheckReceiver("ready_xproc", ctx, count=1) + + proc = multiprocessing.Process( + target=_child_send_ready, + args=(tmpdir, "ready_xproc", 99), + ) + proc.start() + + identities = await receiver.wait(timeout=10.0) + assert identities == [99] + + proc.join(timeout=5) + + async def test_multiple_child_processes(self): + with tempfile.TemporaryDirectory() as tmpdir: + n = 3 + with ManagedZMQContext.scoped(socket_dir=tmpdir) as ctx: + receiver = ReadyCheckReceiver("ready_multi_xproc", ctx, count=n) + + procs = [] + for i in range(n): + p = multiprocessing.Process( + target=_child_send_ready, + args=(tmpdir, "ready_multi_xproc", i), + ) + p.start() + procs.append(p) + + identities = await receiver.wait(timeout=10.0) + assert len(identities) == n + assert set(identities) == set(range(n)) + + for p in procs: + p.join(timeout=5) diff --git a/tests/unit/commands/test_benchmark.py b/tests/unit/commands/test_benchmark.py index b99db6d2..c664234f 100644 --- a/tests/unit/commands/test_benchmark.py +++ b/tests/unit/commands/test_benchmark.py @@ -45,6 +45,7 @@ from inference_endpoint.config.schema import ( OnlineBenchmarkConfig as OnlineConfig, ) +from inference_endpoint.config.utils import cli_error_formatter as _error_formatter from inference_endpoint.core.types import QueryResult from inference_endpoint.endpoint_client.config import HTTPClientConfig from inference_endpoint.exceptions import InputValidationError @@ -434,10 +435,6 @@ class TestErrorFormatter: @pytest.mark.unit def test_cyclopts_arg_with_children(self): - from inference_endpoint.config.utils import ( - cli_error_formatter as _error_formatter, - ) - child = SimpleNamespace( name="--endpoints", names=("--endpoints",), required=True, has_tokens=False ) @@ -449,10 +446,6 @@ def test_cyclopts_arg_with_children(self): @pytest.mark.unit def test_cyclopts_leaf_arg(self): - from inference_endpoint.config.utils import ( - cli_error_formatter as _error_formatter, - ) - arg = SimpleNamespace( name="--model", names=("--model-params.name", "--model"), children=[] ) @@ -463,10 +456,6 @@ def test_cyclopts_leaf_arg(self): @pytest.mark.unit def test_pydantic_validation_error(self): - from inference_endpoint.config.utils import ( - cli_error_formatter as _error_formatter, - ) - try: BenchmarkConfig( type=TestType.OFFLINE, @@ -481,10 +470,6 @@ def test_pydantic_validation_error(self): @pytest.mark.unit def test_generic_error_fallback(self): - from inference_endpoint.config.utils import ( - cli_error_formatter as _error_formatter, - ) - class FakeError: argument = None __cause__ = None diff --git a/tests/unit/commands/test_util_commands.py b/tests/unit/commands/test_util_commands.py index 67d0c791..e1e3a3cc 100644 --- a/tests/unit/commands/test_util_commands.py +++ b/tests/unit/commands/test_util_commands.py @@ -15,6 +15,7 @@ """Tests for utility commands (info, validate, init, probe) and main.py dispatch.""" +import asyncio from pathlib import Path from unittest.mock import MagicMock, patch @@ -22,7 +23,7 @@ from inference_endpoint import __version__ from inference_endpoint.commands.info import 
execute_info from inference_endpoint.commands.init import execute_init -from inference_endpoint.commands.probe import ProbeConfig, execute_probe +from inference_endpoint.commands.probe import ProbeConfig, _probe_async, execute_probe from inference_endpoint.commands.validate import execute_validate from inference_endpoint.config.schema import APIType from inference_endpoint.exceptions import ( @@ -31,6 +32,7 @@ InputValidationError, SetupError, ) +from inference_endpoint.main import run class TestInfoCommand: @@ -160,10 +162,6 @@ def test_execute_probe_calls_async(self, mock_run_async): def test_empty_model_raises(self): config = ProbeConfig(endpoints="http://localhost:8000", model="") with pytest.raises(InputValidationError, match="Model required"): - import asyncio - - from inference_endpoint.commands.probe import _probe_async - asyncio.run(_probe_async(config)) @pytest.mark.unit @@ -173,10 +171,6 @@ def test_setup_failure_raises(self, mock_client_cls): config = ProbeConfig(endpoints="http://localhost:8000", model="test") with pytest.raises(SetupError, match="Probe setup failed"): - import asyncio - - from inference_endpoint.commands.probe import _probe_async - asyncio.run(_probe_async(config)) @pytest.mark.unit @@ -190,10 +184,6 @@ def test_all_issues_fail_raises(self, mock_client_cls): endpoints="http://localhost:8000", model="test", requests=2 ) with pytest.raises(ExecutionError, match="no queries could be issued"): - import asyncio - - from inference_endpoint.commands.probe import _probe_async - asyncio.run(_probe_async(config)) @@ -213,8 +203,6 @@ class TestMainRunExceptionHandling: ], ) def test_exception_exit_codes(self, exc, code): - from inference_endpoint.main import run - with patch("inference_endpoint.main.app") as mock_app: mock_app.meta.side_effect = exc with pytest.raises(SystemExit) as exc_info: diff --git a/tests/unit/core/test_types.py b/tests/unit/core/test_types.py index b33f7eda..52bdbe77 100644 --- a/tests/unit/core/test_types.py +++ b/tests/unit/core/test_types.py @@ -479,7 +479,6 @@ def test_stream_chunk_minimal(self): assert decoded.id == "" assert decoded.response_chunk == "" - assert decoded.is_complete is False assert decoded.metadata == {} def test_stream_chunk_with_basic_content(self): @@ -493,14 +492,12 @@ def test_stream_chunk_with_basic_content(self): assert decoded.id == "query-123" assert decoded.response_chunk == "Hello, this is a chunk of text." - assert decoded.is_complete is False def test_stream_chunk_first_chunk(self): """Test StreamChunk representing first chunk with metadata.""" chunk = StreamChunk( id="query-456", response_chunk="First token", - is_complete=False, metadata={"first_chunk": True, "latency_ns": 1234567}, ) @@ -510,24 +507,11 @@ def test_stream_chunk_first_chunk(self): assert decoded.metadata["first_chunk"] is True assert decoded.metadata["latency_ns"] == 1234567 - def test_stream_chunk_final_chunk(self): - """Test StreamChunk representing final chunk.""" - chunk = StreamChunk( - id="query-789", response_chunk="Final text.", is_complete=True - ) - - encoded = msgspec.msgpack.encode(chunk) - decoded = msgspec.msgpack.decode(encoded, type=StreamChunk) - - assert decoded.is_complete is True - assert decoded.response_chunk == "Final text." 
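These deletions follow from dropping `is_complete` from `StreamChunk`: a stream now ends when the worker sends a terminal `QueryResult` after the last chunk, the exact order the async session tests later in this diff exercise. A minimal consumer sketch of that contract; `drain_one` and its message list are illustrative, not codebase API:

```python
from inference_endpoint.core.types import QueryResult, StreamChunk

def drain_one(messages: list[QueryResult | StreamChunk]) -> QueryResult:
    """Return the terminal QueryResult, treating chunks as timing-only events."""
    for msg in messages:
        if isinstance(msg, StreamChunk):
            continue  # chunks never signal completion anymore
        return msg  # the QueryResult alone retires the query
    raise RuntimeError("stream ended without a terminal QueryResult")
```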
- def test_stream_chunk_with_comprehensive_metadata(self): """Test StreamChunk with detailed metadata.""" chunk = StreamChunk( id="query-meta", response_chunk=" next token", - is_complete=False, metadata={ "model": "llama-2-70b", "chunk_index": 5, @@ -569,7 +553,6 @@ def test_stream_chunk_all_fields_populated(self): chunk = StreamChunk( id="query-full-chunk", response_chunk="Complete chunk text", - is_complete=True, metadata={ "model": "gpt-4", "finish_reason": "stop", @@ -582,7 +565,6 @@ def test_stream_chunk_all_fields_populated(self): assert decoded.id == "query-full-chunk" assert decoded.response_chunk == "Complete chunk text" - assert decoded.is_complete is True assert decoded.metadata["finish_reason"] == "stop" def test_stream_chunk_multiple_roundtrips(self): @@ -590,7 +572,6 @@ def test_stream_chunk_multiple_roundtrips(self): original = StreamChunk( id="query-roundtrip", response_chunk="Test chunk", - is_complete=False, metadata={"index": 1}, ) @@ -605,7 +586,6 @@ def test_stream_chunk_multiple_roundtrips(self): # Verify all fields remain consistent assert decoded2.id == original.id assert decoded2.response_chunk == original.response_chunk - assert decoded2.is_complete == original.is_complete assert decoded2.metadata == original.metadata @@ -811,7 +791,7 @@ def test_serialize_list_of_stream_chunks(self): id="q1", response_chunk="First", metadata={"first_chunk": True} ), StreamChunk(id="q1", response_chunk=" second"), - StreamChunk(id="q1", response_chunk=" final", is_complete=True), + StreamChunk(id="q1", response_chunk=" final"), ] encoded = msgspec.msgpack.encode(chunks) @@ -819,7 +799,6 @@ def test_serialize_list_of_stream_chunks(self): assert len(decoded) == 3 assert decoded[0].metadata.get("first_chunk") is True - assert decoded[2].is_complete is True def test_query_result_with_nested_metadata(self): """Test QueryResult with deeply nested metadata and TextModelOutput.""" diff --git a/tests/unit/load_generator/test_async_session.py b/tests/unit/load_generator/test_async_session.py new file mode 100644 index 00000000..38dd014e --- /dev/null +++ b/tests/unit/load_generator/test_async_session.py @@ -0,0 +1,884 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the async BenchmarkSession.""" + +from __future__ import annotations + +import asyncio +import random + +import inference_endpoint.load_generator.session as _session_mod +import pytest +from inference_endpoint.config.runtime_settings import RuntimeSettings +from inference_endpoint.config.schema import LoadPattern, LoadPatternType +from inference_endpoint.core.record import ( + ErrorEventType, + EventRecord, + SampleEventType, + SessionEventType, +) +from inference_endpoint.core.types import ErrorData, Query, QueryResult, StreamChunk +from inference_endpoint.dataset_manager.dataset import Dataset +from inference_endpoint.load_generator.session import ( + BenchmarkSession, + PhaseConfig, + PhaseIssuer, + PhaseResult, + PhaseType, + SessionResult, +) +from inference_endpoint.metrics.metric import Throughput + + +@pytest.fixture(autouse=False) +def enable_warmup(monkeypatch): + """Enable warmup phases for tests that use PhaseType.WARMUP.""" + monkeypatch.setattr(_session_mod, "_WARMUP_ENABLED", True) + + +# --------------------------------------------------------------------------- +# Test doubles +# --------------------------------------------------------------------------- + + +class FakeDataset(Dataset): + """In-memory dataset for tests.""" + + def __init__(self, n_samples: int = 10): + self._n = n_samples + + def load_sample(self, index: int) -> dict: + return {"prompt": f"sample_{index}", "model": "test"} + + def num_samples(self) -> int: + return self._n + + +class FakeIssuer: + """Fake SampleIssuer that queues responses for controlled delivery.""" + + def __init__(self, response_delay: float = 0.001): + self._issued: list[Query] = [] + self._response_queue: asyncio.Queue[QueryResult | StreamChunk | None] = ( + asyncio.Queue() + ) + self._response_delay = response_delay + self._auto_respond = True + self._loop: asyncio.AbstractEventLoop | None = None + + def issue(self, query: Query) -> None: + self._issued.append(query) + if self._auto_respond and self._loop: + + def _enqueue_response(q: Query = query) -> None: + self._response_queue.put_nowait( + QueryResult(id=q.id, response_output=None) + ) + + self._loop.call_later(self._response_delay, _enqueue_response) + + async def recv(self) -> QueryResult | StreamChunk | None: + return await self._response_queue.get() + + def shutdown(self) -> None: + self._response_queue.put_nowait(None) + + def inject_response(self, resp: QueryResult | StreamChunk) -> None: + self._response_queue.put_nowait(resp) + + @property + def issued_queries(self) -> list[Query]: + return self._issued + + +class FakePublisher: + """Captures published EventRecords.""" + + def __init__(self): + self.events: list[EventRecord] = [] + + def publish(self, event_record: EventRecord) -> None: + self.events.append(event_record) + + def flush(self) -> None: + pass + + def events_of_type(self, event_type) -> list[EventRecord]: + return [e for e in self.events if e.event_type == event_type] + + +def _make_settings( + load_pattern: LoadPattern | None = None, + n_samples: int = 10, + max_duration_ms: int | None = None, +) -> RuntimeSettings: + return RuntimeSettings( + metric_target=Throughput(100), + reported_metrics=[], + min_duration_ms=0, + max_duration_ms=max_duration_ms, + n_samples_from_dataset=n_samples, + n_samples_to_issue=n_samples, + min_sample_count=n_samples, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=load_pattern or LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + + +# 
--------------------------------------------------------------------------- +# PhaseIssuer tests +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestPhaseIssuer: + def test_issue_builds_query_and_publishes(self): + dataset = FakeDataset(5) + issuer = FakeIssuer() + issuer._auto_respond = False + publisher = FakePublisher() + phase_issuer = PhaseIssuer(dataset, issuer, publisher, lambda: False) + + result = phase_issuer.issue(3) + assert result is not None + assert phase_issuer.issued_count == 1 + assert phase_issuer.inflight == 1 + assert len(issuer.issued_queries) == 1 + assert issuer.issued_queries[0].id == result + assert 3 in phase_issuer.uuid_to_index.values() + + # Should have published ISSUED event + issued_events = publisher.events_of_type(SampleEventType.ISSUED) + assert len(issued_events) == 1 + assert issued_events[0].sample_uuid == result + + def test_issue_returns_none_when_stopped(self): + dataset = FakeDataset(5) + issuer = FakeIssuer() + issuer._auto_respond = False + publisher = FakePublisher() + phase_issuer = PhaseIssuer(dataset, issuer, publisher, lambda: True) + + result = phase_issuer.issue(0) + assert result is None + assert phase_issuer.issued_count == 0 + + def test_uuid_is_unique_per_issue(self): + dataset = FakeDataset(5) + issuer = FakeIssuer() + issuer._auto_respond = False + publisher = FakePublisher() + phase_issuer = PhaseIssuer(dataset, issuer, publisher, lambda: False) + + ids = [phase_issuer.issue(i % 5) for i in range(10)] + assert len(set(ids)) == 10 + + +# --------------------------------------------------------------------------- +# BenchmarkSession tests +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestBenchmarkSession: + @pytest.mark.asyncio + async def test_single_perf_phase(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + phases = [ + PhaseConfig("perf", _make_settings(n_samples=5), FakeDataset(5)), + ] + result = await session.run(phases) + + assert len(result.phase_results) == 1 + assert result.perf_results[0].name == "perf" + assert result.perf_results[0].issued_count == 5 + assert len(result.perf_results[0].uuid_to_index) == 5 + + # Check session events + started = publisher.events_of_type(SessionEventType.STARTED) + ended = publisher.events_of_type(SessionEventType.ENDED) + start_track = publisher.events_of_type( + SessionEventType.START_PERFORMANCE_TRACKING + ) + stop_track = publisher.events_of_type( + SessionEventType.STOP_PERFORMANCE_TRACKING + ) + assert len(started) == 1 + assert len(ended) == 1 + assert len(start_track) == 1 + assert len(stop_track) == 1 + + @pytest.mark.asyncio + async def test_accuracy_phase(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + phases = [ + PhaseConfig( + "acc", _make_settings(n_samples=3), FakeDataset(3), PhaseType.ACCURACY + ), + ] + result = await session.run(phases) + + assert len(result.accuracy_results) == 1 + assert result.accuracy_results[0].issued_count == 3 + # No tracking events for accuracy + assert ( + len(publisher.events_of_type(SessionEventType.START_PERFORMANCE_TRACKING)) + == 0 + ) + + @pytest.mark.asyncio + async def test_warmup_produces_no_result(self, enable_warmup): + loop = asyncio.get_running_loop() + issuer 
= FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + phases = [ + PhaseConfig( + "warmup", + _make_settings(n_samples=3), + FakeDataset(3), + PhaseType.WARMUP, + ), + ] + result = await session.run(phases) + assert len(result.phase_results) == 0 + + @pytest.mark.asyncio + async def test_multi_phase(self, enable_warmup): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + phases = [ + PhaseConfig( + "warmup", + _make_settings(n_samples=2), + FakeDataset(2), + PhaseType.WARMUP, + ), + PhaseConfig( + "perf", + _make_settings(n_samples=5), + FakeDataset(5), + PhaseType.PERFORMANCE, + ), + PhaseConfig( + "acc", _make_settings(n_samples=3), FakeDataset(3), PhaseType.ACCURACY + ), + ] + result = await session.run(phases) + + assert len(result.perf_results) == 1 + assert result.perf_results[0].issued_count == 5 + assert len(result.accuracy_results) == 1 + assert result.accuracy_results[0].issued_count == 3 + + @pytest.mark.asyncio + async def test_stop_terminates_early(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + + # Stop after a short delay + loop.call_later(0.05, session.stop) + + phases = [ + PhaseConfig( + "perf", + _make_settings(n_samples=100_000, max_duration_ms=10_000), + FakeDataset(100), + ), + ] + result = await session.run(phases) + # Should have stopped early, not issued all 100k + assert result.perf_results[0].issued_count < 100_000 + + @pytest.mark.asyncio + async def test_on_sample_complete_callback(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + completed: list[str] = [] + + def on_complete(result: QueryResult) -> None: + completed.append(result.id) + + session = BenchmarkSession( + issuer, publisher, loop, on_sample_complete=on_complete + ) + phases = [ + PhaseConfig("perf", _make_settings(n_samples=5), FakeDataset(5)), + ] + await session.run(phases) + assert len(completed) == 5 + + @pytest.mark.asyncio + async def test_stale_completions_ignored_by_strategy(self, enable_warmup): + """Responses from warmup phase should not affect perf phase strategy.""" + loop = asyncio.get_running_loop() + publisher = FakePublisher() + + # Issuer that delays responses significantly so they arrive in next phase + issuer = FakeIssuer(response_delay=0.1) + issuer._loop = loop + + session = BenchmarkSession(issuer, publisher, loop) + + concurrency_settings = _make_settings( + load_pattern=LoadPattern( + type=LoadPatternType.CONCURRENCY, target_concurrency=2 + ), + n_samples=3, + ) + phases = [ + PhaseConfig( + "sat", _make_settings(n_samples=2), FakeDataset(2), PhaseType.WARMUP + ), + PhaseConfig( + "perf", concurrency_settings, FakeDataset(3), PhaseType.PERFORMANCE + ), + ] + result = await session.run(phases) + + # Perf phase should complete with its own samples, not be confused by stale ones + assert len(result.perf_results) == 1 + assert result.perf_results[0].issued_count == 3 + + @pytest.mark.asyncio + async def test_recv_none_triggers_stop(self): + """If issuer.recv() returns None mid-phase, drain should abort quickly.""" + loop = asyncio.get_running_loop() + publisher = FakePublisher() + + issuer = FakeIssuer() + issuer._loop = loop + issuer._auto_respond = False + + session = BenchmarkSession(issuer, 
publisher, loop) + phases = [ + PhaseConfig("perf", _make_settings(n_samples=5), FakeDataset(5)), + ] + + # Schedule transport close after a short delay — recv returns None + loop.call_later(0.05, issuer.shutdown) + + # Session should complete quickly — recv None sets stop_requested, + # which aborts drain. wait_for prevents CI hang if this regresses. + result = await asyncio.wait_for(session.run(phases), timeout=10.0) + assert result is not None + + @pytest.mark.asyncio + async def test_streaming_query_completes_via_queryresult(self): + """Streaming: StreamChunks publish timing events, QueryResult handles completion. + + The worker sends StreamChunk(first) → StreamChunk(delta) → QueryResult. + Only the QueryResult decrements inflight and releases the concurrency + semaphore. StreamChunks only publish timing events. + """ + loop = asyncio.get_running_loop() + publisher = FakePublisher() + + issuer = FakeIssuer() + issuer._loop = loop + issuer._auto_respond = False + + session = BenchmarkSession(issuer, publisher, loop) + + settings = _make_settings( + load_pattern=LoadPattern( + type=LoadPatternType.CONCURRENCY, target_concurrency=1 + ), + n_samples=2, + ) + phases = [PhaseConfig("perf", settings, FakeDataset(2))] + + async def inject_streaming_responses(): + """Simulate worker: StreamChunk(first) → StreamChunk(delta) → QueryResult.""" + while not issuer._issued: + await asyncio.sleep(0.005) + q1 = issuer._issued[0] + issuer.inject_response( + StreamChunk(id=q1.id, metadata={"first_chunk": True}) + ) + issuer.inject_response(StreamChunk(id=q1.id, response_chunk="more")) + issuer.inject_response(QueryResult(id=q1.id, response_output="out1")) + while len(issuer._issued) < 2: + await asyncio.sleep(0.005) + q2 = issuer._issued[1] + issuer.inject_response( + StreamChunk(id=q2.id, metadata={"first_chunk": True}) + ) + issuer.inject_response(StreamChunk(id=q2.id, response_chunk="more")) + issuer.inject_response(QueryResult(id=q2.id, response_output="out2")) + + injector = asyncio.create_task(inject_streaming_responses()) + result = await asyncio.wait_for(session.run(phases), timeout=5.0) + await injector + + assert result.perf_results[0].issued_count == 2 + recv_first = publisher.events_of_type(SampleEventType.RECV_FIRST) + assert len(recv_first) == 2 + + @pytest.mark.asyncio + async def test_concurrency_strategy_transport_close_no_deadlock(self): + """ConcurrencyStrategy must not deadlock when transport closes mid-phase.""" + loop = asyncio.get_running_loop() + publisher = FakePublisher() + + issuer = FakeIssuer(response_delay=999) # Responses never arrive in time + issuer._loop = loop + issuer._auto_respond = False + + session = BenchmarkSession(issuer, publisher, loop) + settings = _make_settings( + load_pattern=LoadPattern( + type=LoadPatternType.CONCURRENCY, target_concurrency=2 + ), + n_samples=100, + ) + phases = [PhaseConfig("perf", settings, FakeDataset(10))] + + # Close transport after strategy issues initial batch and blocks on semaphore + loop.call_later(0.1, issuer.shutdown) + + # Must complete without deadlock — wait_for prevents CI hang + result = await asyncio.wait_for(session.run(phases), timeout=5.0) + assert result is not None + + @pytest.mark.asyncio + async def test_on_sample_complete_called_for_streaming_query(self): + """on_sample_complete fires exactly once per streaming query (on QueryResult). + + StreamChunks only publish timing events — callback fires only for QueryResult. 
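+
+        Delivery order exercised here (sketch; names mirror this test's fakes,
+        not a transport API guarantee):
+
+            StreamChunk(first) -> StreamChunk(delta) -> QueryResult  # callback fires here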
+ """ + loop = asyncio.get_running_loop() + publisher = FakePublisher() + + issuer = FakeIssuer() + issuer._loop = loop + issuer._auto_respond = False + + completed: list[QueryResult | StreamChunk] = [] + + def on_complete(result: QueryResult | StreamChunk) -> None: + completed.append(result) + + session = BenchmarkSession( + issuer, publisher, loop, on_sample_complete=on_complete + ) + settings = _make_settings( + load_pattern=LoadPattern( + type=LoadPatternType.CONCURRENCY, target_concurrency=1 + ), + n_samples=1, + ) + phases = [PhaseConfig("perf", settings, FakeDataset(1))] + + async def inject(): + while not issuer._issued: + await asyncio.sleep(0.005) + q = issuer._issued[0] + issuer.inject_response(StreamChunk(id=q.id, metadata={"first_chunk": True})) + issuer.inject_response(StreamChunk(id=q.id, response_chunk="more")) + issuer.inject_response(QueryResult(id=q.id, response_output="done")) + + asyncio.create_task(inject()) + await asyncio.wait_for(session.run(phases), timeout=5.0) + + assert len(completed) == 1 + assert isinstance(completed[0], QueryResult) + + @pytest.mark.asyncio + async def test_failed_query_published_as_error_event(self): + """Bug #5: QueryResult with error should publish ErrorEventType, not just COMPLETE.""" + loop = asyncio.get_running_loop() + publisher = FakePublisher() + + issuer = FakeIssuer() + issuer._loop = loop + issuer._auto_respond = False + + session = BenchmarkSession(issuer, publisher, loop) + settings = _make_settings(n_samples=1) + phases = [PhaseConfig("perf", settings, FakeDataset(1))] + + async def inject_error(): + while not issuer._issued: + await asyncio.sleep(0.005) + q = issuer._issued[0] + issuer.inject_response( + QueryResult( + id=q.id, + error=ErrorData(error_type="timeout", error_message="timed out"), + ) + ) + + asyncio.create_task(inject_error()) + await asyncio.wait_for(session.run(phases), timeout=5.0) + + # Should have published both COMPLETE and an error event + complete_events = publisher.events_of_type(SampleEventType.COMPLETE) + error_events = [ + e for e in publisher.events if isinstance(e.event_type, ErrorEventType) + ] + assert len(complete_events) == 1 + # Bug #5: error event should also be published + assert len(error_events) == 1 + + +@pytest.mark.unit +class TestBenchmarkSessionPoissonIntegration: + """Poisson strategy (TimedIssueStrategy) integration with session.""" + + @pytest.mark.asyncio + async def test_poisson_issues_all_samples(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + poisson_settings = _make_settings( + load_pattern=LoadPattern(type=LoadPatternType.POISSON, target_qps=5000.0), + n_samples=8, + ) + phases = [ + PhaseConfig("perf", poisson_settings, FakeDataset(8)), + ] + result = await asyncio.wait_for(session.run(phases), timeout=10.0) + + assert len(result.perf_results) == 1 + assert result.perf_results[0].issued_count == 8 + + @pytest.mark.asyncio + async def test_poisson_respects_stop(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + poisson_settings = _make_settings( + load_pattern=LoadPattern(type=LoadPatternType.POISSON, target_qps=100.0), + n_samples=100_000, + max_duration_ms=60_000, + ) + phases = [ + PhaseConfig("perf", poisson_settings, FakeDataset(100)), + ] + loop.call_later(0.05, session.stop) + result = await 
asyncio.wait_for(session.run(phases), timeout=10.0) + assert result.perf_results[0].issued_count < 100_000 + + +@pytest.mark.unit +class TestBenchmarkSessionMaxDuration: + """max_duration_ms timeout: phase stops after duration even with samples remaining.""" + + @pytest.mark.asyncio + async def test_max_duration_stops_phase(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + # Very short max_duration with many samples to issue + settings = _make_settings( + load_pattern=LoadPattern(type=LoadPatternType.POISSON, target_qps=10.0), + n_samples=100_000, + max_duration_ms=50, + ) + phases = [PhaseConfig("perf", settings, FakeDataset(100))] + result = await asyncio.wait_for(session.run(phases), timeout=10.0) + + # Should have stopped well before issuing all samples + assert result.perf_results[0].issued_count < 100_000 + + @pytest.mark.asyncio + async def test_max_duration_with_burst(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + settings = _make_settings(n_samples=1_000_000, max_duration_ms=20) + phases = [PhaseConfig("perf", settings, FakeDataset(100))] + result = await asyncio.wait_for(session.run(phases), timeout=10.0) + + # Burst fires fast, but stop_check should cut it short + assert result.perf_results[0].issued_count < 1_000_000 + + +@pytest.mark.unit +class TestBenchmarkSessionAccuracyErrorHandling: + """Error handling in accuracy phase: query fails, verify it doesn't corrupt scoring.""" + + @pytest.mark.asyncio + async def test_failed_query_in_accuracy_phase_preserves_uuid_map(self): + loop = asyncio.get_running_loop() + publisher = FakePublisher() + issuer = FakeIssuer() + issuer._loop = loop + issuer._auto_respond = False + + completed_results: list[QueryResult | StreamChunk] = [] + + def on_complete(result: QueryResult | StreamChunk) -> None: + completed_results.append(result) + + session = BenchmarkSession( + issuer, publisher, loop, on_sample_complete=on_complete + ) + settings = _make_settings(n_samples=3) + phases = [ + PhaseConfig("acc", settings, FakeDataset(3), PhaseType.ACCURACY), + ] + + async def inject_mixed_responses(): + while len(issuer._issued) < 3: + await asyncio.sleep(0.005) + # First query: success + issuer.inject_response( + QueryResult(id=issuer._issued[0].id, response_output="answer1") + ) + # Second query: error + issuer.inject_response( + QueryResult( + id=issuer._issued[1].id, + error=ErrorData(error_type="timeout", error_message="timed out"), + ) + ) + # Third query: success + issuer.inject_response( + QueryResult(id=issuer._issued[2].id, response_output="answer3") + ) + + asyncio.create_task(inject_mixed_responses()) + result = await asyncio.wait_for(session.run(phases), timeout=5.0) + + assert len(result.accuracy_results) == 1 + acc = result.accuracy_results[0] + # All 3 samples should be in uuid_to_index, including the failed one + assert acc.issued_count == 3 + assert len(acc.uuid_to_index) == 3 + # on_sample_complete should have fired for all 3 + assert len(completed_results) == 3 + + @pytest.mark.asyncio + async def test_error_event_published_in_accuracy_phase(self): + loop = asyncio.get_running_loop() + publisher = FakePublisher() + issuer = FakeIssuer() + issuer._loop = loop + issuer._auto_respond = False + + session = BenchmarkSession(issuer, publisher, loop) + settings = 
_make_settings(n_samples=1) + phases = [ + PhaseConfig("acc", settings, FakeDataset(1), PhaseType.ACCURACY), + ] + + async def inject_error(): + while not issuer._issued: + await asyncio.sleep(0.005) + issuer.inject_response( + QueryResult( + id=issuer._issued[0].id, + error=ErrorData(error_type="server_error", error_message="500"), + ) + ) + + asyncio.create_task(inject_error()) + await asyncio.wait_for(session.run(phases), timeout=5.0) + + error_events = [ + e for e in publisher.events if isinstance(e.event_type, ErrorEventType) + ] + assert len(error_events) == 1 + + +@pytest.mark.unit +class TestBenchmarkSessionMultiPhaseSatPerfSequence: + """Multi-perf + warmup sequence (sat -> perf -> sat -> perf).""" + + @pytest.mark.asyncio + async def test_sat_perf_sat_perf(self, enable_warmup): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + phases = [ + PhaseConfig( + "warmup1", + _make_settings(n_samples=2), + FakeDataset(2), + PhaseType.WARMUP, + ), + PhaseConfig( + "perf1", + _make_settings(n_samples=4), + FakeDataset(4), + PhaseType.PERFORMANCE, + ), + PhaseConfig( + "warmup2", + _make_settings(n_samples=3), + FakeDataset(3), + PhaseType.WARMUP, + ), + PhaseConfig( + "perf2", + _make_settings(n_samples=6), + FakeDataset(6), + PhaseType.PERFORMANCE, + ), + ] + result = await asyncio.wait_for(session.run(phases), timeout=10.0) + + # Both perf phases should produce results + assert len(result.perf_results) == 2 + assert result.perf_results[0].name == "perf1" + assert result.perf_results[0].issued_count == 4 + assert result.perf_results[1].name == "perf2" + assert result.perf_results[1].issued_count == 6 + + # Saturation phases produce no results + assert len(result.phase_results) == 2 + + # Should have start/stop tracking for each perf phase + start_track = publisher.events_of_type( + SessionEventType.START_PERFORMANCE_TRACKING + ) + stop_track = publisher.events_of_type( + SessionEventType.STOP_PERFORMANCE_TRACKING + ) + assert len(start_track) == 2 + assert len(stop_track) == 2 + + +@pytest.mark.unit +class TestBenchmarkSessionStaleStreamChunk: + """Stale StreamChunk from previous phase is ignored.""" + + @pytest.mark.asyncio + async def test_stale_stream_chunk_ignored(self, enable_warmup): + """StreamChunk from warmup phase should not affect perf phase counts.""" + loop = asyncio.get_running_loop() + publisher = FakePublisher() + + issuer = FakeIssuer() + issuer._loop = loop + issuer._auto_respond = False + + completed: list[str] = [] + + def on_complete(result: QueryResult | StreamChunk) -> None: + completed.append(result.id) + + session = BenchmarkSession( + issuer, publisher, loop, on_sample_complete=on_complete + ) + + # Saturation with slow responses, perf with concurrency + sat_settings = _make_settings(n_samples=2) + perf_settings = _make_settings( + load_pattern=LoadPattern( + type=LoadPatternType.CONCURRENCY, target_concurrency=1 + ), + n_samples=2, + ) + + phases = [ + PhaseConfig("sat", sat_settings, FakeDataset(2), PhaseType.WARMUP), + PhaseConfig("perf", perf_settings, FakeDataset(2), PhaseType.PERFORMANCE), + ] + + async def inject_responses(): + # Wait for warmup queries + while len(issuer._issued) < 2: + await asyncio.sleep(0.005) + sat_ids = [q.id for q in issuer._issued[:2]] + + # Wait for perf phase queries to start + while len(issuer._issued) < 3: + await asyncio.sleep(0.005) + + # Inject stale StreamChunks from warmup phase into perf phase + 
issuer.inject_response(StreamChunk(id=sat_ids[0], response_chunk="stale")) + issuer.inject_response(StreamChunk(id=sat_ids[1], response_chunk="stale")) + + # Now complete the perf queries + perf_queries = issuer._issued[2:] + for q in perf_queries: + issuer.inject_response(QueryResult(id=q.id, response_output="ok")) + # Wait for second perf query if not yet issued + while len(issuer._issued) < 4: + await asyncio.sleep(0.005) + for q in issuer._issued[2:]: + if q.id not in list(completed): + issuer.inject_response( + QueryResult(id=q.id, response_output="ok") + ) + + asyncio.create_task(inject_responses()) + result = await asyncio.wait_for(session.run(phases), timeout=5.0) + + # Perf phase should have exactly 2 issued samples + assert len(result.perf_results) == 1 + assert result.perf_results[0].issued_count == 2 + # on_sample_complete should only be called for perf-phase queries + # (stale sat queries are not in perf's uuid_to_index) + for cid in completed: + assert cid in result.perf_results[0].uuid_to_index + + +@pytest.mark.unit +class TestSessionResult: + def test_perf_results_filter(self, enable_warmup): + results = [ + PhaseResult("sat", PhaseType.WARMUP, {}, 0, 0, 0), + PhaseResult("perf1", PhaseType.PERFORMANCE, {"a": 1}, 10, 0, 100), + PhaseResult("perf2", PhaseType.PERFORMANCE, {"b": 2}, 20, 100, 200), + PhaseResult("acc", PhaseType.ACCURACY, {"c": 3}, 5, 200, 300), + ] + sr = SessionResult("test", results, 0, 300) + assert len(sr.perf_results) == 2 + assert len(sr.accuracy_results) == 1 + assert sr.perf_results[0].name == "perf1" diff --git a/tests/unit/load_generator/test_load_generator.py b/tests/unit/load_generator/test_load_generator.py deleted file mode 100644 index 87132964..00000000 --- a/tests/unit/load_generator/test_load_generator.py +++ /dev/null @@ -1,293 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import random -import time -from collections import defaultdict -from unittest.mock import patch - -import inference_endpoint.metrics as metrics -from inference_endpoint.config.runtime_settings import RuntimeSettings -from inference_endpoint.config.schema import LoadPattern, LoadPatternType -from inference_endpoint.core.types import QueryResult, StreamChunk -from inference_endpoint.load_generator.events import SampleEvent -from inference_endpoint.load_generator.load_generator import ( - SampleIssuer, - SchedulerBasedLoadGenerator, -) -from inference_endpoint.load_generator.sample import SampleEventHandler -from inference_endpoint.load_generator.scheduler import ( - MaxThroughputScheduler, - PoissonDistributionScheduler, - SampleOrder, - WithoutReplacementSampleOrder, -) - -from tests.test_helpers import DummyDataLoader, SerialSampleIssuer - - -class FibonacciSampleOrder(SampleOrder): - """Sample order where the corresponding value for a sample index is that number value in - the Fibonacci sequence. 
- """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.a = 0 - self.b = 1 - - def next_sample_index(self) -> int: - retval = self.a - c = self.a + self.b - self.a = self.b - self.b = c - return retval - - -@patch("inference_endpoint.load_generator.load_generator.EventRecorder.record_event") -@patch( - "inference_endpoint.load_generator.load_generator.LoadGenerator.load_sample_data" -) -def test_load_generator( - load_sample_data_mock, event_recorder_mock, max_throughput_runtime_settings -): - load_sample_data_mock.side_effect = lambda index, _uuid: index**2 - event_recorder_mock.return_value = True - - class ListAppendIssuer(SampleIssuer): - def __init__(self): - self.issued = [] - - def issue(self, sample): - self.issued.append(sample) - - fake_sample_issuer = ListAppendIssuer() - - load_generator = SchedulerBasedLoadGenerator( - fake_sample_issuer, - None, # No Dataloader to set, we're using Mock to prevent accessing the EventRecorder and DataLoader - scheduler=MaxThroughputScheduler( - max_throughput_runtime_settings, - FibonacciSampleOrder, - ), - ) - a = 0 - b = 1 - for i, issued_sample in enumerate(load_generator): - assert issued_sample.sample.data == a**2 - assert issued_sample.sample == fake_sample_issuer.issued[i] - assert len(fake_sample_issuer.issued) == i + 1 - - c = a + b - a = b - b = c - - -@patch("inference_endpoint.metrics.recorder.EventRecorder.record_event") -def test_full_run(record_event_mock): - record_event_mock.return_value = None - - rt_settings = RuntimeSettings( - metrics.Throughput(5000), - [metrics.Throughput(5000)], - min_duration_ms=1000, - max_duration_ms=10_000, - n_samples_from_dataset=100, - n_samples_to_issue=10_000, - min_sample_count=100, - rng_sched=random.Random(1234), - rng_sample_index=random.Random(1234), - load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), - ) - - def compute_digits_of_square(n: int): - yield from str(n**2) - - sample_issuer = SerialSampleIssuer(compute_digits_of_square) - load_generator = SchedulerBasedLoadGenerator( - sample_issuer, - DummyDataLoader(100), - scheduler=MaxThroughputScheduler( - rt_settings, - WithoutReplacementSampleOrder, - ), - ) - - # Hooks for chunk data and query results - received_chunks = defaultdict(list) - - def save_chunk(chunk: StreamChunk): - received_chunks[chunk.id].append(chunk.response_chunk) - - SampleEventHandler.register_hook(SampleEvent.FIRST_CHUNK, save_chunk) - SampleEventHandler.register_hook(SampleEvent.NON_FIRST_CHUNK, save_chunk) - - results = {} - - def save_query_result(result: QueryResult): - results[result.id] = result.get_response_output_string() - - SampleEventHandler.register_hook(SampleEvent.COMPLETE, save_query_result) - - sent_hist = defaultdict(int) - sent_uuids = defaultdict(list) - seen_uuids = set() - for issued_sample in load_generator: - # The test issuer is serial, so we can confirm that a sample is completed before the next - # is issued. 
- expected = str(issued_sample.index**2) - assert received_chunks[issued_sample.sample.uuid][0] == expected[0] - assert len(received_chunks[issued_sample.sample.uuid]) == len(expected) - assert "".join(received_chunks[issued_sample.sample.uuid]) == expected - assert results[issued_sample.sample.uuid] == expected - - sent_hist[issued_sample.index] += 1 - sent_uuids[issued_sample.index].append(issued_sample.sample.uuid) - seen_uuids.add(issued_sample.sample.uuid) - - # WithoutReplacementSampleOrder should ensure that as long as total # of samples issued is a multiple of dataset size, - # the number of issues per sample is the same - target_issues = rt_settings.n_samples_to_issue // rt_settings.n_samples_from_dataset - for index, n_sent in sent_hist.items(): - assert ( - n_sent == target_issues - ), f"Sample {index} should have been issued {target_issues} times, but was issued {n_sent} times" - - # Check uuid uniqueness - n_distinct_uuids = len(set(sent_uuids[index])) - assert ( - n_distinct_uuids == n_sent - ), f"Sample {index} should have {n_sent} unique uuids, but has {n_distinct_uuids}" - - # Check that ALL uuids are unique - assert ( - len(seen_uuids) == rt_settings.n_samples_to_issue - ), f"Should have seen {rt_settings.n_samples_to_issue} unique uuids, but saw {len(seen_uuids)}" - - -@patch("inference_endpoint.load_generator.load_generator.EventRecorder.record_event") -@patch( - "inference_endpoint.load_generator.load_generator.LoadGenerator.load_sample_data" -) -def test_max_duration_ms_stops_issuance(load_sample_data_mock, event_recorder_mock): - """max_duration_ms should stop iteration before n_samples_to_issue is exhausted.""" - load_sample_data_mock.side_effect = lambda index, _uuid: index - event_recorder_mock.return_value = True - - max_duration_ms = 50 - rt_settings = RuntimeSettings( - metrics.Throughput(5000), - reported_metrics=[], - min_duration_ms=0, - max_duration_ms=max_duration_ms, - n_samples_from_dataset=100, - n_samples_to_issue=1_000_000, - min_sample_count=100, - rng_sched=random.Random(42), - rng_sample_index=random.Random(42), - load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), - ) - - issued_count = 0 - - class CountingIssuer(SampleIssuer): - def issue(self, sample): - pass - - load_generator = SchedulerBasedLoadGenerator( - CountingIssuer(), - None, - scheduler=MaxThroughputScheduler(rt_settings, WithoutReplacementSampleOrder), - ) - - start = time.monotonic() - for _ in load_generator: - issued_count += 1 - elapsed_s = time.monotonic() - start - - # Should have stopped well before issuing 1,000,000 samples - assert ( - issued_count < 1_000_000 - ), f"Expected timeout to stop issuance, but {issued_count} samples were issued" - # Elapsed wall-clock should be reasonably close to max_duration_ms: - # lower bound ensures the timeout (not an early exit) was responsible for stopping, - # upper bound is generous to accommodate slow CI runners. 
- max_duration_s = max_duration_ms / 1000 - assert ( - elapsed_s >= max_duration_s * 0.5 - ), f"Elapsed time {elapsed_s:.3f}s is unexpectedly below max_duration_ms={max_duration_ms}ms" - assert ( - elapsed_s < max_duration_s * 2 - ), f"Elapsed time {elapsed_s:.3f}s far exceeds max_duration_ms={max_duration_ms}ms" - - -@patch("inference_endpoint.load_generator.load_generator.EventRecorder.record_event") -@patch( - "inference_endpoint.load_generator.load_generator.LoadGenerator.load_sample_data" -) -def test_max_duration_ms_stops_issuance_with_poisson_scheduler( - load_sample_data_mock, event_recorder_mock -): - """max_duration_ms should stop iteration even when the scheduler has inter-sample delays. - - Uses PoissonDistributionScheduler at low QPS so each inter-sample wait is measurable. - No sample should be issued after the wall-clock deadline has elapsed. - """ - load_sample_data_mock.side_effect = lambda index, _uuid: index - event_recorder_mock.return_value = True - - max_duration_ms = 200 - target_qps = 50 # ~20ms average inter-sample delay - rt_settings = RuntimeSettings( - metrics.Throughput(target_qps), - reported_metrics=[], - min_duration_ms=0, - max_duration_ms=max_duration_ms, - n_samples_from_dataset=100, - n_samples_to_issue=1_000_000, - min_sample_count=1, - rng_sched=random.Random(42), - rng_sample_index=random.Random(42), - load_pattern=LoadPattern(type=LoadPatternType.POISSON, target_qps=target_qps), - ) - - class CountingIssuer(SampleIssuer): - def issue(self, sample): - pass - - load_generator = SchedulerBasedLoadGenerator( - CountingIssuer(), - None, - scheduler=PoissonDistributionScheduler( - rt_settings, WithoutReplacementSampleOrder - ), - ) - - issued_count = 0 - start = time.monotonic() - for _ in load_generator: - issued_count += 1 - elapsed_s = time.monotonic() - start - - assert ( - issued_count < 1_000_000 - ), f"Expected timeout to stop issuance, but {issued_count} samples were issued" - max_duration_s = max_duration_ms / 1000 - assert ( - elapsed_s >= max_duration_s * 0.5 - ), f"Elapsed time {elapsed_s:.3f}s is unexpectedly below max_duration_ms={max_duration_ms}ms" - assert ( - elapsed_s < max_duration_s * 3 - ), f"Elapsed time {elapsed_s:.3f}s far exceeds max_duration_ms={max_duration_ms}ms" diff --git a/tests/unit/load_generator/test_sample.py b/tests/unit/load_generator/test_sample.py deleted file mode 100644 index 7df2fd32..00000000 --- a/tests/unit/load_generator/test_sample.py +++ /dev/null @@ -1,137 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import time -from unittest.mock import patch - -import pytest -from inference_endpoint.core.types import QueryResult, StreamChunk -from inference_endpoint.load_generator.events import SampleEvent -from inference_endpoint.load_generator.sample import Sample, SampleEventHandler - - -def test_sample_uniqueness(): - sample_uuids = [Sample(None).uuid for _ in range(1000)] - assert len(set(sample_uuids)) == len(sample_uuids), "Sample UUIDs should be unique" - - -def test_sample_lazy_data_loading(): - sample = Sample(None) - sample.data = "test_data" - assert sample.data == "test_data" - - with pytest.raises(AttributeError): - sample.data = "test_data2" - - -def test_sample_eager_data_loading(): - sample = Sample("my data") - - with pytest.raises(AttributeError): - sample.data = "test_data2" - - assert sample.data == "my data" - - -@patch("inference_endpoint.load_generator.sample.EventRecorder.record_event") -def test_sample_callback_times(record_event_mock): - events = [] - - sample = Sample(None) - first_chunk = StreamChunk(id=sample.uuid, metadata={"first_chunk": True}) - non_first_chunk = StreamChunk(id=sample.uuid, metadata={"first_chunk": False}) - complete_result = QueryResult(id=sample.uuid) - - def fake_record_event( - ev_type: SampleEvent, timestamp_ns: int, sample_uuid: str, **kwargs - ): - assert sample_uuid == sample.uuid - events.append((ev_type, timestamp_ns)) - - record_event_mock.side_effect = fake_record_event - - sleep_time_sec = 0.01 - - SampleEventHandler.stream_chunk_complete(first_chunk) - time.sleep(sleep_time_sec) - SampleEventHandler.stream_chunk_complete(non_first_chunk) - time.sleep(sleep_time_sec) - SampleEventHandler.query_result_complete(complete_result) - - assert len(events) == 3 - assert record_event_mock.call_count == 3 - - assert events[0][0] == SampleEvent.FIRST_CHUNK - assert events[1][0] == SampleEvent.NON_FIRST_CHUNK - assert events[2][0] == SampleEvent.COMPLETE - assert events[0][1] < events[1][1] - assert events[1][1] < events[2][1] - - # Times are in nanoseconds - convert to seconds to compare with sleep time - tpot_1_sec = (events[1][1] - events[0][1]) / 1e9 - tpot_2_sec = (events[2][1] - events[1][1]) / 1e9 - - # Resolution of time.sleep is very coarse, so simply check that duration is - # greater than the sleep time - assert tpot_1_sec > sleep_time_sec - assert tpot_2_sec > sleep_time_sec - - -@patch("inference_endpoint.load_generator.sample.EventRecorder.record_event") -def test_sample_invalid_type_errors(record_event_mock): - record_event_mock.return_value = None - - chunk = StreamChunk(id="123", metadata={"first_chunk": True}) - result = QueryResult(id="123") - - with pytest.raises(AssertionError, match="Invalid chunk type"): - SampleEventHandler.stream_chunk_complete(result) - - with pytest.raises(AssertionError, match="Invalid result type"): - SampleEventHandler.query_result_complete(chunk) - - -@patch("inference_endpoint.load_generator.sample.EventRecorder.record_event") -def test_sample_event_handler_register_hook(record_event_mock): - record_event_mock.return_value = None - - progress_counter = [0, 0] - - def progress_hook(_): - progress_counter[1] += 1 - - def non_first_chunk_hook(_): - progress_counter[0] += 1 - - SampleEventHandler.register_hook(SampleEvent.COMPLETE, progress_hook) - SampleEventHandler.register_hook(SampleEvent.NON_FIRST_CHUNK, non_first_chunk_hook) - - SampleEventHandler.stream_chunk_complete( - StreamChunk(id="123", metadata={"first_chunk": True}) - ) - assert progress_counter == [0, 0] - - 
SampleEventHandler.query_result_complete(QueryResult(id="123")) - assert progress_counter == [0, 1] - - SampleEventHandler.stream_chunk_complete( - StreamChunk(id="123", metadata={"first_chunk": True}) - ) - assert progress_counter == [0, 1] - - SampleEventHandler.stream_chunk_complete( - StreamChunk(id="123", metadata={"first_chunk": False}) - ) - assert progress_counter == [1, 1] diff --git a/tests/unit/load_generator/test_sample_order.py b/tests/unit/load_generator/test_sample_order.py new file mode 100644 index 00000000..4ddeff10 --- /dev/null +++ b/tests/unit/load_generator/test_sample_order.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for sample_order.py.""" + +import random + +import pytest +from inference_endpoint.load_generator.sample_order import ( + WithoutReplacementSampleOrder, + WithReplacementSampleOrder, +) + +# Exercise small/medium/large dataset sizes so shuffle-buffer behavior is +# covered for inputs both much smaller and much larger than typical batches. +_DATASET_SIZES = [3, 100, 10_000] + + +@pytest.mark.unit +class TestWithoutReplacementSampleOrder: + @pytest.mark.parametrize("n_samples", _DATASET_SIZES) + def test_yields_all_indices(self, n_samples: int): + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=n_samples, rng=random.Random(42) + ) + indices = [next(order) for _ in range(n_samples)] + assert sorted(indices) == list(range(n_samples)) + + @pytest.mark.parametrize("n_samples", _DATASET_SIZES) + def test_reshuffles_after_exhaustion(self, n_samples: int): + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=n_samples, rng=random.Random(42) + ) + first_pass = [next(order) for _ in range(n_samples)] + second_pass = [next(order) for _ in range(n_samples)] + assert sorted(first_pass) == list(range(n_samples)) + assert sorted(second_pass) == list(range(n_samples)) + + @pytest.mark.parametrize("n_samples", _DATASET_SIZES) + def test_never_raises_stop_iteration(self, n_samples: int): + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=n_samples, rng=random.Random(42) + ) + # Should be able to draw far more than dataset size + draws = max(100, n_samples * 3) + indices = [next(order) for _ in range(draws)] + assert len(indices) == draws + assert all(0 <= i < n_samples for i in indices) + + @pytest.mark.parametrize("n_samples", _DATASET_SIZES) + def test_reproducible_with_seed(self, n_samples: int): + order1 = WithoutReplacementSampleOrder( + n_samples_in_dataset=n_samples, rng=random.Random(42) + ) + order2 = WithoutReplacementSampleOrder( + n_samples_in_dataset=n_samples, rng=random.Random(42) + ) + seq1 = [next(order1) for _ in range(n_samples * 2)] + seq2 = [next(order2) for _ in range(n_samples * 2)] + assert seq1 == seq2 + + def test_invalid_size_raises(self): + with pytest.raises(ValueError, match="n_samples_in_dataset must be > 0"): + 
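# The constructor validates eagerly, so no draw is needed to trigger the error. + 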
WithoutReplacementSampleOrder(n_samples_in_dataset=0) + + +@pytest.mark.unit +class TestWithReplacementSampleOrder: + @pytest.mark.parametrize("n_samples", _DATASET_SIZES) + def test_yields_valid_indices(self, n_samples: int): + order = WithReplacementSampleOrder( + n_samples_in_dataset=n_samples, rng=random.Random(42) + ) + indices = [next(order) for _ in range(max(100, n_samples))] + assert all(0 <= i < n_samples for i in indices) + + @pytest.mark.parametrize("n_samples", _DATASET_SIZES) + def test_reproducible_with_seed(self, n_samples: int): + order1 = WithReplacementSampleOrder( + n_samples_in_dataset=n_samples, rng=random.Random(42) + ) + order2 = WithReplacementSampleOrder( + n_samples_in_dataset=n_samples, rng=random.Random(42) + ) + seq1 = [next(order1) for _ in range(n_samples * 2)] + seq2 = [next(order2) for _ in range(n_samples * 2)] + assert seq1 == seq2 diff --git a/tests/unit/load_generator/test_scheduler.py b/tests/unit/load_generator/test_scheduler.py deleted file mode 100644 index 05de0600..00000000 --- a/tests/unit/load_generator/test_scheduler.py +++ /dev/null @@ -1,257 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import math -import random -import threading - -import pytest -from inference_endpoint.load_generator.sample import SampleEventHandler -from inference_endpoint.load_generator.scheduler import ( - ConcurrencyScheduler, - MaxThroughputScheduler, - PoissonDistributionScheduler, - WithoutReplacementSampleOrder, - WithReplacementSampleOrder, -) -from scipy import stats - - -def test_without_replacement_sample_order(): - ordering = WithoutReplacementSampleOrder(12345, 100) - indices = list(iter(ordering)) - for i in range(0, 12345, 100): - assert len(set(indices[i : i + 100])) == min( - 100, 12345 - i - ), "Indices should be unique, and occur at least once" - - # Assert that order is different in each pass of the dataset - assert ( - indices[:100] != indices[100:200] - ), "Order should be different in each pass of the dataset" - - -def test_with_replacement_sample_order(random_seed): - ordering = WithReplacementSampleOrder(12345, 100, rng=random.Random(random_seed)) - indices = list(iter(ordering)) - - # With Python random.Random(42), the order can be deterministic - assert indices[:10] == [ - 81, - 14, - 3, - 94, - 35, - 31, - 28, - 17, - 94, - 13, - ], "Order does not match expected deterministic order" - # Note with this specific seed and order, 94 occurs twice in the first 10 indices - assert indices[:10].count(94) == 2, "94 should occur twice in the first 10 indices" - - -def test_max_throughput_scheduler(max_throughput_runtime_settings): - scheduler = MaxThroughputScheduler( - max_throughput_runtime_settings, WithReplacementSampleOrder - ) - indices = list(iter(scheduler)) - assert len(indices) == 100 - for _, delay in indices: - assert delay == 0 - assert [s_idx for s_idx, _ in indices[:10]] == [ - 81, - 14, - 3, - 94, - 35, - 31, - 28, - 17, - 94, - 13, - ], "Order does not match expected deterministic order" - - -@pytest.mark.parametrize("target_concurrency", [1, 2, 100, 1000], indirect=True) -def test_concurrency_scheduler(concurrency_runtime_settings, target_concurrency): - """Test ConcurrencyScheduler properly gates issuance by completions.""" - total_samples = concurrency_runtime_settings.n_samples_to_issue - - scheduler = ConcurrencyScheduler( - concurrency_runtime_settings, WithReplacementSampleOrder - ) - - # State tracking - state_lock = threading.RLock() - issued_count = 0 - completed_count = 0 - current_inflight = 0 - max_inflight = 0 - - # Synchronization: signal when queries can complete and when they're done - can_complete = [threading.Event() for _ in range(total_samples)] - completed = [threading.Event() for _ in range(total_samples)] - # Signal when each query is issued - issued = [threading.Event() for _ in range(total_samples)] - - def completion_worker(): - """Waits for signals to complete queries.""" - nonlocal completed_count, current_inflight - - for position in range(total_samples): - can_complete[position].wait() - - with state_lock: - completed_count += 1 - current_inflight -= 1 - assert current_inflight >= 0, "Inflight count went negative" - - scheduler._release_slot() - completed[position].set() - - threading.Thread(target=completion_worker, daemon=True).start() - - def issue_worker(): - """Issues queries through scheduler.""" - nonlocal issued_count, current_inflight, max_inflight - - for position, _ in enumerate(scheduler): - with state_lock: - issued_count += 1 - current_inflight += 1 - max_inflight = max(max_inflight, current_inflight) - assert ( - current_inflight <= target_concurrency - ), f"Concurrency {current_inflight} exceeded limit 
{target_concurrency}" - issued[position].set() - - issue_thread = threading.Thread(target=issue_worker, daemon=True) - issue_thread.start() - - try: - # Phase 1: First target_concurrency queries issue immediately - for position in range(target_concurrency): - issued[position].wait() - - with state_lock: - assert issued_count == target_concurrency - assert completed_count == 0 - assert current_inflight == target_concurrency - - # Phase 2: Verify scheduler blocks when at capacity, unblocks on completion - for position in range(target_concurrency, total_samples): - position_to_complete = position - target_concurrency - - # Verify next query hasn't issued yet (scheduler is blocking) - assert not issued[ - position - ].is_set(), f"Query {position} issued before slot was freed" - - # Free a slot - can_complete[position_to_complete].set() - completed[position_to_complete].wait() - - # Verify next query now issues - issued[position].wait() - - with state_lock: - assert current_inflight == target_concurrency - - # Phase 3: Complete remaining queries and cleanup - for position in range(target_concurrency, total_samples): - can_complete[position].set() - completed[position].wait() - - issue_thread.join() - - # Final validation - with state_lock: - assert issued_count == total_samples - assert completed_count == total_samples - assert current_inflight == 0 - assert max_inflight == target_concurrency - - finally: - SampleEventHandler.clear_hooks() - - -@pytest.mark.parametrize("target_qps", [50.0, 100.0, 500.0, 1000.0], indirect=True) -def test_poisson_scheduler_distribution(poisson_runtime_settings, target_qps): - """Test PoissonDistributionScheduler produces exponentially distributed inter-arrival times. - - For a Poisson process with rate λ (target QPS), inter-arrival times must follow - exponential distribution with mean = 1/λ. - - Three-tier validation: - 1. Mean with 99.9% confidence interval - 2. Coefficient of Variation (CV) ≈ 1.0 (exponential signature) - 3. Kolmogorov-Smirnov test for distribution shape - """ - scheduler = PoissonDistributionScheduler( - poisson_runtime_settings, WithReplacementSampleOrder - ) - - # Test configuration - TARGET_QPS = target_qps - expected_mean_s = 1.0 / TARGET_QPS - - # Collect delays from scheduler (in seconds) for statistical analysis - delays_s = [] - for _, delay_ns in scheduler: - delays_s.append(delay_ns / 1e9) # Convert ns to seconds - - # Validate sufficient sample size - n = len(delays_s) - - # Calculate sample statistics using Bessel's correction for unbiased variance (whitened) - sample_mean = sum(delays_s) / n - sample_variance = sum((x - sample_mean) ** 2 for x in delays_s) / (n - 1) - sample_std = math.sqrt(sample_variance) - cv = sample_std / sample_mean - - # Test 1: Mean with statistical confidence interval (99.9% CI) - # For exponential: std(X̄) = sigma/√n = mu/√n - z_critical = 3.29 # 99.9% two-tailed - margin_of_error = z_critical * (sample_std / math.sqrt(n)) - assert abs(sample_mean - expected_mean_s) < margin_of_error, ( - f"Mean {sample_mean*1000:.3f}ms outside 99.9% CI: " - f"[{(expected_mean_s - margin_of_error)*1000:.3f}, " - f"{(expected_mean_s + margin_of_error)*1000:.3f}] ms" - ) - - # Test 2: CV should be close to 1.0 (exponential property: std = mean) - # Use adaptive tolerance based on sample size, max(10%, 1 std. 
error) - cv_tolerance = max(0.10, 1.0 / math.sqrt(n)) - assert ( - abs(cv - 1.0) < cv_tolerance - ), f"CV {cv:.3f} deviates from 1.0 by more than {cv_tolerance:.3f}" - - # Test 3: Kolmogorov-Smirnov test for exponential distribution - # kstest compares data against exponential CDF with scale parameter = mean - ks_statistic, p_value = stats.kstest( - delays_s, - "expon", - args=(0, sample_mean), # loc=0 (no shift), scale=mean - alternative="two-sided", - ) - - # Reject if p-value < 0.0001 (99.99% confidence that distribution is NOT exponential) - ALPHA = 0.0001 - assert p_value > ALPHA, ( - f"KS test rejected exponential distribution: " - f"p-value={p_value:.4f} < alpha={ALPHA} (D={ks_statistic:.4f})" - ) diff --git a/tests/unit/load_generator/test_session.py b/tests/unit/load_generator/test_session.py deleted file mode 100644 index afcd9a4b..00000000 --- a/tests/unit/load_generator/test_session.py +++ /dev/null @@ -1,159 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import random -from pathlib import Path -from unittest.mock import patch - -import inference_endpoint.metrics as metrics -import pytest -from inference_endpoint.config.runtime_settings import RuntimeSettings -from inference_endpoint.config.schema import LoadPattern, LoadPatternType -from inference_endpoint.load_generator.events import SampleEvent -from inference_endpoint.load_generator.sample import Sample, SampleEventHandler -from inference_endpoint.load_generator.scheduler import ( - MaxThroughputScheduler, - WithoutReplacementSampleOrder, -) -from inference_endpoint.load_generator.session import BenchmarkSession -from inference_endpoint.metrics.reporter import MetricsReporter -from tqdm import tqdm - -from tests.test_helpers import ( - DummyDataLoader, - PooledSampleIssuer, -) - -# The following are tests for PooledSampleIssuer in test_helpers.py. If these tests pass -# but session.py tests fail, it's probably not the PooledSampleIssuer's fault. 
- - -@patch("inference_endpoint.load_generator.sample.EventRecorder.record_event") -def test_pooled_issuer_exception_propagation(record_event_mock): - record_event_mock.return_value = None - - """Test that exceptions in worker threads are properly propagated to the main thread.""" - - def failing_compute(sample): - raise ValueError("Worker thread error!") - - issuer = PooledSampleIssuer(compute_func=failing_compute, n_workers=2) - - sample1 = Sample(b"sample1") - sample2 = Sample(b"sample2") - - # Submit some work that will fail - issuer.issue(sample1) - issuer.issue(sample2) - - # Shutdown should raise the exception from the worker thread - with pytest.raises(ValueError, match="Worker thread error!"): - issuer.shutdown() - - -@patch("inference_endpoint.load_generator.sample.EventRecorder.record_event") -def test_pooled_issuer_futures_cleanup(record_event_mock): - record_event_mock.return_value = None - - """Test that completed futures are cleaned up to prevent memory leaks.""" - import time - - def slow_compute(sample): - time.sleep(0.01) # Small delay - return [sample.decode("utf-8")] - - issuer = PooledSampleIssuer(compute_func=slow_compute, n_workers=4) - - # Submit 250 samples (should trigger cleanup at 100 and 200) - for _ in range(250): - issuer.issue(Sample(b"sample")) - - # Let some time pass first - time.sleep(0.2) - - # Manually check errors to trigger cleanup - issuer.check_errors() - - for _ in range(250): - issuer.issue(Sample(b"sample")) - - issuer.shutdown() - - # After shutdown, all futures should be cleared - assert len(issuer.futures) == 0, "Futures not cleared after shutdown" - - -# session.py tests - - -def test_session_start(clean_sample_event_hooks): - rt_settings = RuntimeSettings( - metrics.Throughput(5000), - [metrics.Throughput(5000)], - min_duration_ms=1000, - max_duration_ms=None, - n_samples_from_dataset=100, - n_samples_to_issue=10_000, - min_sample_count=100, - rng_sched=random.Random(1234), - rng_sample_index=random.Random(1234), - load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), - ) - - def compute_digits_of_square(n: int): - yield from str(n**2) - - dl = DummyDataLoader(n_samples=100) - sample_issuer = PooledSampleIssuer(compute_digits_of_square) - sched = MaxThroughputScheduler(rt_settings, WithoutReplacementSampleOrder) - - class ProgressBarHook: - def __init__(self, pbar: tqdm | None = None): - self.pbar = pbar - - def __call__(self, _): - if isinstance(self.pbar, tqdm): - self.pbar.update(1) - - def set_pbar(self, pbar: tqdm): - self.pbar = pbar - - pbar_hook = ProgressBarHook() - SampleEventHandler.register_hook(SampleEvent.COMPLETE, pbar_hook) - - with tqdm(desc="pytest_test_session_start", total=10_000, unit="samples") as pbar: - pbar_hook.set_pbar(pbar) - sess = BenchmarkSession.start( - rt_settings, - dl, - sample_issuer, - sched, - name="pytest_test_session_start", - max_shutdown_timeout_s=300, - ) - events_db_path = sess.event_recorder.connection_name - assert sess.wait_for_test_end( - timeout=120.0 - ), "Session did not complete within timeout" - - # Shutdown the sample issuer to ensure proper cleanup and error propagation - sample_issuer.shutdown() - - assert Path(events_db_path).exists() - with MetricsReporter(events_db_path) as reporter: - stats = reporter.get_sample_statuses() - assert stats["total_sent"] == 10_000 - assert stats["completed"] == 10_000 - assert stats["in_flight"] == 0 diff --git a/tests/unit/load_generator/test_strategy.py b/tests/unit/load_generator/test_strategy.py new file mode 100644 index 
00000000..650a92c1 --- /dev/null +++ b/tests/unit/load_generator/test_strategy.py @@ -0,0 +1,701 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for load strategies.""" + +from __future__ import annotations + +import asyncio +import random +from collections.abc import Callable +from time import monotonic_ns + +import pytest +from inference_endpoint.config.runtime_settings import RuntimeSettings +from inference_endpoint.config.schema import LoadPattern, LoadPatternType +from inference_endpoint.load_generator.delay import make_delay_fn, poisson_delay_fn +from inference_endpoint.load_generator.sample_order import WithoutReplacementSampleOrder +from inference_endpoint.load_generator.strategy import ( + BurstStrategy, + ConcurrencyStrategy, + TimedIssueStrategy, + create_load_strategy, +) +from inference_endpoint.metrics.metric import Throughput + + +def _constant_delay(ns: int = 1_000) -> Callable[[], int]: + return lambda: ns + + +# --------------------------------------------------------------------------- +# Mock PhaseIssuer +# --------------------------------------------------------------------------- + + +class MockPhaseIssuer: + """Minimal PhaseIssuer for strategy tests.""" + + def __init__(self, max_issues: int = 100): + self.issued_indices: list[int] = [] + self.issued_count: int = 0 + self._max = max_issues + + def issue(self, sample_index: int) -> str | None: + if self.issued_count >= self._max: + return None + self.issued_indices.append(sample_index) + self.issued_count += 1 + return f"q{self.issued_count}" + + +# --------------------------------------------------------------------------- +# TimedIssueStrategy +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestTimedIssueStrategyCallAt: + @pytest.mark.asyncio + async def test_issues_correct_count(self): + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + delay_fn = _constant_delay(1_000) + strategy = TimedIssueStrategy(delay_fn, order, loop, use_executor=False) + + issuer = MockPhaseIssuer(max_issues=20) + count = await strategy.execute(issuer) + assert count == 20 + assert issuer.issued_count == 20 + + @pytest.mark.asyncio + async def test_stops_on_none(self): + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=5, rng=random.Random(42) + ) + delay_fn = _constant_delay(1_000) + strategy = TimedIssueStrategy(delay_fn, order, loop, use_executor=False) + + issuer = MockPhaseIssuer(max_issues=3) + count = await strategy.execute(issuer) + assert count == 3 + + @pytest.mark.asyncio + async def test_timing_precision(self): + """call_at should achieve sub-ms precision for moderate delays.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=100, 
rng=random.Random(42) + ) + delay_fn = _constant_delay(1_000_000) + + timestamps: list[int] = [] + + class TimingIssuer: + issued_count = 0 + + def issue(self, idx): + timestamps.append(monotonic_ns()) + self.issued_count += 1 + if self.issued_count >= 10: + return None + return f"q{self.issued_count}" + + strategy = TimedIssueStrategy(delay_fn, order, loop, use_executor=False) + await strategy.execute(TimingIssuer()) + + # Check inter-arrival times are positive (callbacks fire in order) + for i in range(1, len(timestamps)): + delta_ns = timestamps[i] - timestamps[i - 1] + assert delta_ns > 0, f"Issue {i}: non-monotonic timestamps" + # Total elapsed should be roughly 9ms (9 delays of 1ms) + total_ns = timestamps[-1] - timestamps[0] + assert ( + total_ns > 5_000_000 + ), f"Total elapsed {total_ns}ns too small for 9x1ms delays" + + +@pytest.mark.unit +class TestTimedIssueStrategyExecutor: + @pytest.mark.asyncio + async def test_issues_correct_count(self): + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + delay_fn = _constant_delay(1_000) + strategy = TimedIssueStrategy(delay_fn, order, loop, use_executor=True) + + issuer = MockPhaseIssuer(max_issues=20) + count = await strategy.execute(issuer) + assert count == 20 + + +# --------------------------------------------------------------------------- +# BurstStrategy +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestBurstStrategy: + @pytest.mark.asyncio + async def test_issues_all(self): + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + + issuer = MockPhaseIssuer(max_issues=50) + count = await strategy.execute(issuer) + assert count == 50 + + @pytest.mark.asyncio + async def test_stops_on_none(self): + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=5, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + + issuer = MockPhaseIssuer(max_issues=7) + count = await strategy.execute(issuer) + assert count == 7 + + @pytest.mark.asyncio + async def test_does_not_starve_event_loop(self): + """Verify other coroutines get to run during burst issuance. + + We schedule a coroutine that increments a counter each time it wakes. + If burst issuance yields properly, the counter should be > 0 before + issuance completes. 
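+        (A strategy that never yielded would keep the competing task from running until execute() returned, and the assertion below would fail.)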
+ """ + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=200, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + + wakeup_count = 0 + stop = asyncio.Event() + + async def competing_task(): + nonlocal wakeup_count + while not stop.is_set(): + await asyncio.sleep(0) + wakeup_count += 1 + + task = asyncio.create_task(competing_task()) + issuer = MockPhaseIssuer(max_issues=200) + await strategy.execute(issuer) + stop.set() + await task + # The competing task should have woken up multiple times during issuance + assert wakeup_count > 1, f"Competing task only woke {wakeup_count} times" + + +# --------------------------------------------------------------------------- +# ConcurrencyStrategy +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestConcurrencyStrategy: + @pytest.mark.asyncio + async def test_issues_up_to_concurrency_then_waits(self): + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + strategy = ConcurrencyStrategy(target_concurrency=3, sample_order=order) + issuer = MockPhaseIssuer(max_issues=10) + + # Start strategy but don't await — it should block after 3 issues + task = asyncio.create_task(strategy.execute(issuer)) + await asyncio.sleep(0.01) # let it run + assert issuer.issued_count == 3 + + # Simulate completions + for i in range(1, 4): + strategy.on_query_complete(f"q{i}") + await asyncio.sleep(0.01) + assert issuer.issued_count == 6 + + # Complete remaining + for i in range(4, 11): + strategy.on_query_complete(f"q{i}") + count = await asyncio.wait_for(task, timeout=2.0) + assert count == 10 + + @pytest.mark.asyncio + async def test_stops_on_none(self): + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=100, rng=random.Random(42) + ) + strategy = ConcurrencyStrategy(target_concurrency=5, sample_order=order) + issuer = MockPhaseIssuer(max_issues=3) + + # Complete queries as they arrive so strategy doesn't block + async def completer(): + while True: + await asyncio.sleep(0.005) + for i in range(1, issuer.issued_count + 1): + strategy.on_query_complete(f"q{i}") + + completer_task = asyncio.create_task(completer()) + count = await asyncio.wait_for(strategy.execute(issuer), timeout=2.0) + completer_task.cancel() + try: + await completer_task + except asyncio.CancelledError: + pass + assert count == 3 + + def test_invalid_concurrency_raises(self): + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + with pytest.raises(ValueError, match="target_concurrency must be > 0"): + ConcurrencyStrategy(target_concurrency=0, sample_order=order) + + +# --------------------------------------------------------------------------- +# Delay functions +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestDelayFunctions: + def test_poisson_delay_positive(self): + fn = poisson_delay_fn(1000.0, random.Random(42)) + delays = [fn() for _ in range(100)] + assert all(d >= 1 for d in delays) + + def test_poisson_delay_mean(self): + """Mean delay should be close to 1/target_qps in ns.""" + target_qps = 10_000.0 + fn = poisson_delay_fn(target_qps, random.Random(42)) + delays = [fn() for _ in range(10_000)] + mean_ns = sum(delays) / len(delays) + expected_ns = 1e9 / target_qps # 100_000 ns + assert abs(mean_ns - expected_ns) / expected_ns < 0.1 # within 10% + + def test_poisson_delay_invalid_qps(self): + with 
pytest.raises(ValueError, match="target_qps must be > 0"): + poisson_delay_fn(0, random.Random(42)) + + def test_make_delay_fn_unsupported_pattern(self): + lp = LoadPattern(type=LoadPatternType.MAX_THROUGHPUT) + with pytest.raises(ValueError, match="No delay function"): + make_delay_fn(lp, random.Random(42)) + + +# --------------------------------------------------------------------------- +# create_load_strategy factory +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestCreateLoadStrategy: + def test_max_throughput(self): + loop = asyncio.new_event_loop() + try: + settings = _make_settings(LoadPattern(type=LoadPatternType.MAX_THROUGHPUT)) + strategy = create_load_strategy(settings, loop) + assert isinstance(strategy, BurstStrategy) + finally: + loop.close() + + def test_poisson_default(self): + loop = asyncio.new_event_loop() + try: + settings = _make_settings( + LoadPattern(type=LoadPatternType.POISSON, target_qps=1000.0) + ) + strategy = create_load_strategy(settings, loop) + assert isinstance(strategy, TimedIssueStrategy) + assert not strategy._use_executor + finally: + loop.close() + + def test_poisson_executor(self): + loop = asyncio.new_event_loop() + try: + settings = _make_settings( + LoadPattern(type=LoadPatternType.POISSON, target_qps=1000.0) + ) + strategy = create_load_strategy(settings, loop, use_executor=True) + assert isinstance(strategy, TimedIssueStrategy) + assert strategy._use_executor + finally: + loop.close() + + def test_concurrency(self): + loop = asyncio.new_event_loop() + try: + settings = _make_settings( + LoadPattern(type=LoadPatternType.CONCURRENCY, target_concurrency=32) + ) + strategy = create_load_strategy(settings, loop) + assert isinstance(strategy, ConcurrencyStrategy) + assert strategy._target == 32 + finally: + loop.close() + + def test_no_load_pattern_raises(self): + loop = asyncio.new_event_loop() + try: + settings = _make_settings(None) + with pytest.raises(ValueError, match="load_pattern must not be None"): + create_load_strategy(settings, loop) + finally: + loop.close() + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestEdgeCases: + @pytest.mark.asyncio + async def test_burst_single_sample(self): + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=1, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + issuer = MockPhaseIssuer(max_issues=1) + count = await strategy.execute(issuer) + assert count == 1 + + @pytest.mark.asyncio + async def test_burst_stop_immediately(self): + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + issuer = MockPhaseIssuer(max_issues=0) + count = await strategy.execute(issuer) + assert count == 0 + + @pytest.mark.asyncio + async def test_burst_exception_in_issue_does_not_hang(self): + """If issue() raises, strategy should not hang forever.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + + class FailingIssuer: + issued_count = 0 + + def issue(self, idx: int) -> str | None: + self.issued_count += 1 + if self.issued_count == 3: + raise RuntimeError("load_sample failed") + return f"q{self.issued_count}" + + 
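# FailingIssuer implements the same issue() protocol as MockPhaseIssuer but raises on its third call. + 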
issuer = FailingIssuer() + # Must not hang — should complete (with error) within timeout + with pytest.raises(RuntimeError, match="load_sample failed"): + await asyncio.wait_for(strategy.execute(issuer), timeout=5.0) + + @pytest.mark.asyncio + async def test_timed_call_at_exception_in_issue_does_not_hang(self): + """If issue() raises in call_at callback, strategy should not hang.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + strategy = TimedIssueStrategy( + _constant_delay(1_000), order, loop, use_executor=False + ) + + class FailingIssuer: + issued_count = 0 + + def issue(self, idx: int) -> str | None: + self.issued_count += 1 + if self.issued_count == 3: + raise RuntimeError("load_sample failed") + return f"q{self.issued_count}" + + issuer = FailingIssuer() + with pytest.raises(RuntimeError, match="load_sample failed"): + await asyncio.wait_for(strategy.execute(issuer), timeout=5.0) + + def test_sample_order_single_element(self): + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=1, rng=random.Random(42) + ) + indices = [next(order) for _ in range(10)] + assert all(i == 0 for i in indices) + + +# --------------------------------------------------------------------------- +# Executor mode exceptions +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestTimedIssueStrategyExecutorExceptions: + @pytest.mark.asyncio + async def test_executor_issue_raises(self): + """If issue() raises inside run_in_executor path, exception propagates.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + strategy = TimedIssueStrategy( + _constant_delay(1_000), order, loop, use_executor=True + ) + + call_count = 0 + + class FailingIssuer: + issued_count = 0 + + def issue(self, idx: int) -> str | None: + nonlocal call_count + call_count += 1 + self.issued_count += 1 + if call_count == 3: + raise ValueError("executor callback failed") + return f"q{call_count}" + + with pytest.raises(ValueError, match="executor callback failed"): + await asyncio.wait_for(strategy.execute(FailingIssuer()), timeout=5.0) + + @pytest.mark.asyncio + async def test_executor_delay_fn_raises(self): + """If delay_fn raises inside executor path, exception propagates.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + call_count = 0 + + def bad_delay(): + nonlocal call_count + call_count += 1 + if call_count == 3: + raise RuntimeError("delay computation failed") + return 1_000 + + strategy = TimedIssueStrategy(bad_delay, order, loop, use_executor=True) + issuer = MockPhaseIssuer(max_issues=100) + + with pytest.raises(RuntimeError, match="delay computation failed"): + await asyncio.wait_for(strategy.execute(issuer), timeout=5.0) + + +# --------------------------------------------------------------------------- +# Concurrent on_query_complete calls +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestConcurrencyStrategyConcurrentCompletions: + @pytest.mark.asyncio + async def test_multiple_completions_simultaneously(self): + """Multiple on_query_complete calls arriving at the same time.""" + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=20, rng=random.Random(42) + ) + strategy = ConcurrencyStrategy(target_concurrency=5, sample_order=order) + issuer 
= MockPhaseIssuer(max_issues=20) + + task = asyncio.create_task(strategy.execute(issuer)) + + # Let strategy issue initial batch of 5 + await asyncio.sleep(0.02) + assert issuer.issued_count == 5 + + # Release all 5 at once + for i in range(1, 6): + strategy.on_query_complete(f"q{i}") + await asyncio.sleep(0.02) + assert issuer.issued_count == 10 + + # Release next batch all at once + for i in range(6, 11): + strategy.on_query_complete(f"q{i}") + await asyncio.sleep(0.02) + assert issuer.issued_count == 15 + + # Release rest + for i in range(11, 21): + strategy.on_query_complete(f"q{i}") + count = await asyncio.wait_for(task, timeout=2.0) + assert count == 20 + + @pytest.mark.asyncio + async def test_completions_interleaved_with_issues(self): + """Completions arriving while new issues are being scheduled.""" + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=50, rng=random.Random(42) + ) + strategy = ConcurrencyStrategy(target_concurrency=2, sample_order=order) + issuer = MockPhaseIssuer(max_issues=10) + + task = asyncio.create_task(strategy.execute(issuer)) + await asyncio.sleep(0.01) + assert issuer.issued_count == 2 + + # Alternate: complete one, let it issue one more + for i in range(1, 11): + strategy.on_query_complete(f"q{i}") + await asyncio.sleep(0.005) + + count = await asyncio.wait_for(task, timeout=2.0) + assert count == 10 + + +# --------------------------------------------------------------------------- +# Near-zero delay (high QPS poisson) +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestTimedIssueStrategyNearZeroDelay: + @pytest.mark.asyncio + async def test_very_high_qps(self): + """Poisson with extremely high QPS should still issue all samples.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=50, rng=random.Random(42) + ) + # 1ns delay -- essentially zero + strategy = TimedIssueStrategy( + _constant_delay(1), order, loop, use_executor=False + ) + issuer = MockPhaseIssuer(max_issues=50) + count = await asyncio.wait_for(strategy.execute(issuer), timeout=5.0) + assert count == 50 + + @pytest.mark.asyncio + async def test_very_high_qps_executor(self): + """Near-zero delay in executor mode.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=50, rng=random.Random(42) + ) + strategy = TimedIssueStrategy( + _constant_delay(1), order, loop, use_executor=True + ) + issuer = MockPhaseIssuer(max_issues=50) + count = await asyncio.wait_for(strategy.execute(issuer), timeout=5.0) + assert count == 50 + + @pytest.mark.asyncio + async def test_poisson_high_qps_statistical(self): + """Real poisson distribution at 1M QPS should complete quickly.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=100, rng=random.Random(42) + ) + delay_fn = poisson_delay_fn(1_000_000.0, random.Random(42)) + strategy = TimedIssueStrategy(delay_fn, order, loop, use_executor=False) + issuer = MockPhaseIssuer(max_issues=100) + count = await asyncio.wait_for(strategy.execute(issuer), timeout=5.0) + assert count == 100 + + +# --------------------------------------------------------------------------- +# Large-scale burst +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestBurstStrategyLargeScale: + @pytest.mark.asyncio + async def test_burst_1000_samples(self): + """BurstStrategy should handle 1000+ samples without issues.""" 
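+        # 1000 issues drawn from a 200-entry dataset, so the sample order wraps around several times.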
+ loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=200, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + issuer = MockPhaseIssuer(max_issues=1000) + count = await asyncio.wait_for(strategy.execute(issuer), timeout=10.0) + assert count == 1000 + + @pytest.mark.asyncio + async def test_burst_5000_samples(self): + """BurstStrategy at 5000 samples -- verify count and no event loop starvation.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=500, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + + wakeups = 0 + stop = asyncio.Event() + + async def observer(): + nonlocal wakeups + while not stop.is_set(): + await asyncio.sleep(0) + wakeups += 1 + + obs_task = asyncio.create_task(observer()) + issuer = MockPhaseIssuer(max_issues=5000) + count = await asyncio.wait_for(strategy.execute(issuer), timeout=10.0) + stop.set() + await obs_task + + assert count == 5000 + assert wakeups > 10, f"Event loop starved: observer only ran {wakeups} times" + + @pytest.mark.asyncio + async def test_burst_indices_wrap_around(self): + """With dataset_size < issue_count, indices should wrap around.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=3, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + issuer = MockPhaseIssuer(max_issues=10) + count = await asyncio.wait_for(strategy.execute(issuer), timeout=5.0) + assert count == 10 + # All indices should be 0, 1, or 2 + assert all(0 <= idx <= 2 for idx in issuer.issued_indices) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_settings(load_pattern): + """Create minimal RuntimeSettings for factory tests.""" + return RuntimeSettings( + metric_target=Throughput(100), + reported_metrics=[], + min_duration_ms=0, + max_duration_ms=None, + n_samples_from_dataset=10, + n_samples_to_issue=10, + min_sample_count=10, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=load_pattern, + ) diff --git a/tests/unit/metrics/test_recorder.py b/tests/unit/metrics/test_recorder.py deleted file mode 100644 index 058ee50a..00000000 --- a/tests/unit/metrics/test_recorder.py +++ /dev/null @@ -1,381 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import multiprocessing -import sqlite3 -import uuid -from collections import namedtuple -from unittest.mock import patch - -import msgspec.json -import pytest -from inference_endpoint.load_generator.events import SampleEvent, SessionEvent -from inference_endpoint.metrics.recorder import ( - EventRecorder, - EventRecorderSingletonViolation, - EventRow, - sqlite3_cursor, -) - - -def test_event_row_to_table_query(): - """Test that EventRow.to_table_query() generates correct SQL CREATE TABLE statement.""" - query = EventRow.to_table_query() - - # Verify it's a CREATE TABLE statement - assert query.startswith("CREATE TABLE IF NOT EXISTS events") - - # Verify all expected fields are present with correct types - assert "sample_uuid TEXT" in query - assert "event_type TEXT" in query - assert "timestamp_ns INTEGER" in query - assert "data BLOB" in query - - # Verify the query is valid SQL by executing it - conn = sqlite3.connect(":memory:") - cursor = conn.cursor() - cursor.execute(query) - conn.commit() - - # Verify the table was created with correct schema - cursor.execute("PRAGMA table_info(events)") - columns = cursor.fetchall() - - # Extract column names and types - column_info = {col[1]: col[2] for col in columns} # col[1] is name, col[2] is type - - assert "sample_uuid" in column_info - assert "event_type" in column_info - assert "timestamp_ns" in column_info - assert "data" in column_info - assert column_info["timestamp_ns"] == "INTEGER" - assert column_info["data"] == "BLOB" - - cursor.close() - conn.close() - - -def test_event_row_insert_query(): - """Test that EventRow.insert_query() generates correct SQL INSERT statement.""" - query = EventRow.insert_query() - - # Verify it's an INSERT statement with correct structure - assert query.startswith("INSERT INTO events") - assert "VALUES" in query - - # Verify all fields are included - assert "sample_uuid" in query - assert "event_type" in query - assert "timestamp_ns" in query - assert "data" in query - - # Count placeholders - should be 4 (one for each field) - placeholder_count = query.count("?") - assert placeholder_count == 4 - - # Verify the query is valid SQL by creating a table and inserting - conn = sqlite3.connect(":memory:") - cursor = conn.cursor() - cursor.execute(EventRow.to_table_query()) - - # Try inserting a row using the generated query - test_data = ("test_uuid", "TEST_EVENT", 12345, b"test_data") - cursor.execute(query, test_data) - conn.commit() - - # Verify the data was inserted - cursor.execute("SELECT * FROM events") - rows = cursor.fetchall() - assert len(rows) == 1 - assert rows[0] == test_data - - cursor.close() - conn.close() - - -def test_event_row_to_insert_params(sample_uuids): - """Test that EventRow.to_insert_params() returns correct tuple for SQL insertion.""" - uuid1 = sample_uuids(1) - test_data = {"key": "value", "number": 42} - - event_row = EventRow( - sample_uuid=uuid1, - event_type=SampleEvent.FIRST_CHUNK, - timestamp_ns=10000, - data=msgspec.json.encode(test_data), - ) - - params = event_row.to_insert_params() - - # Verify the tuple has correct structure - assert isinstance(params, tuple) - assert len(params) == 4 - - # Verify each field - assert params[0] == uuid1 - assert params[1] == SampleEvent.FIRST_CHUNK.value - assert params[2] == 10000 - assert params[3] == msgspec.json.encode(test_data) - - -def test_event_row_to_insert_params_empty_data(sample_uuids): - """Test EventRow.to_insert_params() with empty data field.""" - uuid1 = sample_uuids(1) - - event_row = EventRow( - sample_uuid=uuid1, - 
event_type=SessionEvent.LOADGEN_ISSUE_CALLED, - timestamp_ns=5000, - data=b"", - ) - - params = event_row.to_insert_params() - - assert params[0] == uuid1 - assert params[1] == SessionEvent.LOADGEN_ISSUE_CALLED.value - assert params[2] == 5000 - assert params[3] == b"" - - -def test_event_row_integration_with_sqlite(sample_uuids): - """Integration test: Create table, insert EventRow, and verify data roundtrip.""" - uuid1 = sample_uuids(1) - uuid2 = sample_uuids(2) - - conn = sqlite3.connect(":memory:") - cursor = conn.cursor() - - # Create table using EventRow.to_table_query() - cursor.execute(EventRow.to_table_query()) - conn.commit() - - # Create test events with various data types - events = [ - EventRow( - sample_uuid=uuid1, - event_type=SessionEvent.LOADGEN_ISSUE_CALLED, - timestamp_ns=10000, - data=b"", - ), - EventRow( - sample_uuid=uuid2, - event_type=SampleEvent.FIRST_CHUNK, - timestamp_ns=10100, - data=msgspec.json.encode({"chunk": "Hello"}), - ), - EventRow( - sample_uuid=uuid2, - event_type=SampleEvent.COMPLETE, - timestamp_ns=10200, - data=msgspec.json.encode({"output": ["Hello", " World"]}), - ), - ] - - # Insert using EventRow.insert_query() - insert_query = EventRow.insert_query() - for event in events: - cursor.execute(insert_query, event.to_insert_params()) - conn.commit() - - # Verify data was inserted correctly - cursor.execute("SELECT * FROM events") - rows = cursor.fetchall() - - assert len(rows) == 3 - - # Verify first row (empty data) - assert rows[0][0] == uuid1 - assert rows[0][1] == SessionEvent.LOADGEN_ISSUE_CALLED.value - assert rows[0][2] == 10000 - assert rows[0][3] == b"" - - # Verify second row (with JSON data) - assert rows[1][0] == uuid2 - assert rows[1][1] == SampleEvent.FIRST_CHUNK.value - assert rows[1][2] == 10100 - assert msgspec.json.decode(rows[1][3]) == {"chunk": "Hello"} - - # Verify third row (with complex JSON data) - assert rows[2][0] == uuid2 - assert rows[2][1] == SampleEvent.COMPLETE.value - assert rows[2][2] == 10200 - assert msgspec.json.decode(rows[2][3]) == {"output": ["Hello", " World"]} - - cursor.close() - conn.close() - - -def get_EventRecorder(*args, **kwargs): - # Set requirement to 128MB for testing - return EventRecorder(*args, min_memory_req_bytes=128 * 1024 * 1024, **kwargs) - - -def test_event_recorder_singleton_violation_create_multiple(): - with get_EventRecorder(): - with pytest.raises(EventRecorderSingletonViolation): - with get_EventRecorder(): - pass - - -def test_event_recorder_singleton_violation_close_non_active(): - with get_EventRecorder(): - other_rec = get_EventRecorder() - with pytest.raises(EventRecorderSingletonViolation): - other_rec.close() - - -def test_event_recorder_singleton_violation_record_event_non_active(sample_uuids): - assert ( - EventRecorder.LIVE is None - ), "Cannot run test - EventRecorder is active from previous test" - with pytest.raises(EventRecorderSingletonViolation): - EventRecorder.record_event( - SessionEvent.LOADGEN_ISSUE_CALLED, 10000, sample_uuid=sample_uuids(1) - ) - assert ( - EventRecorder.record_event( - SessionEvent.LOADGEN_ISSUE_CALLED, - 10000, - sample_uuid=sample_uuids(1), - assert_active=False, - ) - is False - ) - - -def test_record_event(sample_uuids): - uuid1 = sample_uuids(1) - uuid2 = sample_uuids(2) - uuid3 = sample_uuids(3) - - with get_EventRecorder() as rec: - rec.record_event(SessionEvent.LOADGEN_ISSUE_CALLED, 10000, sample_uuid=uuid1) - rec.record_event(SessionEvent.LOADGEN_ISSUE_CALLED, 10003, sample_uuid=uuid2) - rec.record_event(SampleEvent.FIRST_CHUNK, 10010, 
sample_uuid=uuid1) - rec.record_event(SampleEvent.FIRST_CHUNK, 10190, sample_uuid=uuid2) - rec.record_event(SampleEvent.NON_FIRST_CHUNK, 10201, sample_uuid=uuid1) - rec.record_event(SessionEvent.LOADGEN_ISSUE_CALLED, 10202, sample_uuid=uuid3) - rec.record_event(SampleEvent.NON_FIRST_CHUNK, 10203, sample_uuid=uuid1) - rec.record_event(SampleEvent.NON_FIRST_CHUNK, 10210, sample_uuid=uuid2) - rec.record_event(SampleEvent.NON_FIRST_CHUNK, 10211, sample_uuid=uuid1) - rec.record_event(SampleEvent.COMPLETE, 10211, sample_uuid=uuid1) - rec.record_event(SampleEvent.NON_FIRST_CHUNK, 10214, sample_uuid=uuid2) - rec.record_event(SampleEvent.NON_FIRST_CHUNK, 10217, sample_uuid=uuid2) - rec.record_event(SampleEvent.NON_FIRST_CHUNK, 10219, sample_uuid=uuid2) - rec.record_event(SampleEvent.COMPLETE, 10219, sample_uuid=uuid2) - - # Wait for writer thread to process all events - rec.wait_for_writes() - - # Read from the database directly - with sqlite3_cursor(rec.connection_name) as (cursor, _): - actual_rows = cursor.execute("SELECT * FROM events").fetchall() - - expected_rows = [ - (uuid1, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10000, b""), - (uuid2, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10003, b""), - (uuid1, SampleEvent.FIRST_CHUNK.value, 10010, b""), - (uuid2, SampleEvent.FIRST_CHUNK.value, 10190, b""), - (uuid1, SampleEvent.NON_FIRST_CHUNK.value, 10201, b""), - (uuid3, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10202, b""), - (uuid1, SampleEvent.NON_FIRST_CHUNK.value, 10203, b""), - (uuid2, SampleEvent.NON_FIRST_CHUNK.value, 10210, b""), - (uuid1, SampleEvent.NON_FIRST_CHUNK.value, 10211, b""), - (uuid1, SampleEvent.COMPLETE.value, 10211, b""), - (uuid2, SampleEvent.NON_FIRST_CHUNK.value, 10214, b""), - (uuid2, SampleEvent.NON_FIRST_CHUNK.value, 10217, b""), - (uuid2, SampleEvent.NON_FIRST_CHUNK.value, 10219, b""), - (uuid2, SampleEvent.COMPLETE.value, 10219, b""), - ] - - assert expected_rows == actual_rows - assert len(actual_rows) == 14 - - -def worker_proc_read_entries(sess_id, events_created_ev, uuid1, uuid2): - events_created_ev.wait() - expected_rows = [ - (uuid1, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10000, b""), - (uuid2, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10003, b""), - (uuid1, SampleEvent.FIRST_CHUNK.value, 10010, b""), - (uuid2, SampleEvent.FIRST_CHUNK.value, 10190, b""), - (uuid1, SampleEvent.NON_FIRST_CHUNK.value, 10201, b""), - ] - with sqlite3_cursor(EventRecorder.db_path(sess_id)) as (cursor, _): - actual_rows = cursor.execute("SELECT * FROM events").fetchall() - assert expected_rows == actual_rows - - -def test_shm_usage(sample_uuids): - uuid1 = sample_uuids(1) - uuid2 = sample_uuids(2) - - # Set mp start method - ctx = multiprocessing.get_context("spawn") - events_created_ev = ctx.Event() - sess_id = uuid.uuid4().hex - - worker_proc = ctx.Process( - target=worker_proc_read_entries, args=(sess_id, events_created_ev, uuid1, uuid2) - ) - worker_proc.start() - - with get_EventRecorder(session_id=sess_id) as rec: - rec.record_event(SessionEvent.LOADGEN_ISSUE_CALLED, 10000, sample_uuid=uuid1) - rec.record_event(SessionEvent.LOADGEN_ISSUE_CALLED, 10003, sample_uuid=uuid2) - rec.record_event(SampleEvent.FIRST_CHUNK, 10010, sample_uuid=uuid1) - rec.record_event(SampleEvent.FIRST_CHUNK, 10190, sample_uuid=uuid2) - rec.record_event(SampleEvent.NON_FIRST_CHUNK, 10201, sample_uuid=uuid1) - # Wait for writer thread to process all events - rec.wait_for_writes() - events_created_ev.set() - - worker_proc.join(timeout=10) - if worker_proc.is_alive(): - worker_proc.terminate() - 
worker_proc.join(timeout=1) - assert ( - not worker_proc.is_alive() - ), "Worker process could not be terminated after cleanup" - raise AssertionError("Worker process failed to complete in a reasonable time") - assert worker_proc.exitcode == 0 - - -MemStat = namedtuple("MemStat", ["total", "used", "free"]) - - -@patch("inference_endpoint.metrics.recorder.shutil.disk_usage") -def test_shm_too_small(mock_run): - mock_run.return_value = MemStat( - total=64 * 1024 * 1024, used=0, free=64 * 1024 * 1024 - ) - with pytest.raises(MemoryError) as err: - EventRecorder( - min_memory_req_bytes=512 * 1024 * 1024 - ) # Instantiate will not init the connection - assert "total space" in err.value.args[0] - - -@patch("inference_endpoint.metrics.recorder.shutil.disk_usage") -def test_shm_not_enough_space(mock_run): - mock_run.return_value = MemStat( - total=1024 * 1024 * 1024, used=64 * 1024 * 1024, free=960 * 1024 * 1024 - ) - with pytest.raises(MemoryError) as err: - EventRecorder( - min_memory_req_bytes=1024 * 1024 * 1024 - ) # Instantiate will not init the connection - assert "free space" in err.value.args[0] - assert "960MB" in err.value.args[0] diff --git a/tests/unit/metrics/test_report_builder.py b/tests/unit/metrics/test_report_builder.py new file mode 100644 index 00000000..d5e66526 --- /dev/null +++ b/tests/unit/metrics/test_report_builder.py @@ -0,0 +1,264 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for report.py and report_builder.py.""" + +import json +from pathlib import Path + +import pytest +from inference_endpoint.async_utils.services.metrics_aggregator.kv_store import ( + BasicKVStore, + BasicKVStoreReader, + SeriesStats, +) +from inference_endpoint.metrics.report import Report, compute_summary + +# --------------------------------------------------------------------------- +# compute_summary +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestComputeSummary: + def test_empty(self): + s = compute_summary(SeriesStats()) + assert s["total"] == 0 + assert s["min"] == 0 + assert s["max"] == 0 + assert s["std_dev"] == 0 + assert s["histogram"]["buckets"] == [] + + def test_single_value(self): + s = compute_summary(SeriesStats([42.0], dtype=float)) + assert s["min"] == 42.0 + assert s["max"] == 42.0 + assert s["avg"] == 42.0 + assert s["std_dev"] == 0.0 + + def test_multiple_values(self): + s = compute_summary(SeriesStats([1.0, 2.0, 3.0, 4.0, 5.0], dtype=float)) + assert s["min"] == 1.0 + assert s["max"] == 5.0 + assert s["total"] == 15.0 + assert s["avg"] == 3.0 + assert s["median"] == 3.0 + assert len(s["histogram"]["buckets"]) > 0 + assert len(s["percentiles"]) > 0 + + def test_percentiles(self): + values = list(range(1, 101)) # 1..100 + s = compute_summary( + SeriesStats([float(v) for v in values], dtype=float), + percentiles=(50, 90, 99), + ) + assert s["percentiles"]["50"] == pytest.approx(50.5, abs=1) + assert s["percentiles"]["90"] == pytest.approx(90.1, abs=1) + assert s["percentiles"]["99"] == pytest.approx(99.01, abs=1) + + +# --------------------------------------------------------------------------- +# Helper: create a populated KVStore writer + reader +# --------------------------------------------------------------------------- + + +def _make_store(tmp_path: Path, n_samples: int = 50): + """Create a writer with typical benchmark data and return (writer, reader).""" + store_dir = tmp_path / "kv" + w = BasicKVStore(store_dir) + + # Counter keys matching MetricCounterKey enum + for key in [ + "total_samples_issued", + "total_samples_completed", + "total_samples_failed", + "tracked_samples_issued", + "tracked_samples_completed", + "tracked_duration_ns", + "total_duration_ns", + ]: + w.create_key(key, "counter") + for key in ["ttft_ns", "sample_latency_ns", "osl", "isl", "chunk_delta_ns"]: + w.create_key(key, "series") + w.create_key("tpot_ns", "series", dtype=float) + + w.update("tracked_samples_issued", n_samples) + w.update("tracked_samples_completed", n_samples) + w.update("total_samples_failed", 0) + if n_samples > 0: + w.update("tracked_duration_ns", 10_000_000_000) + + for i in range(n_samples): + w.update("ttft_ns", 1_000_000 + i * 10_000) + w.update("sample_latency_ns", 5_000_000 + i * 50_000) + w.update("osl", 100 + i) + + r = BasicKVStoreReader(store_dir) + for key in [ + "total_samples_issued", + "total_samples_completed", + "total_samples_failed", + "tracked_samples_issued", + "tracked_samples_completed", + "tracked_duration_ns", + "total_duration_ns", + ]: + r.register_key(key, "counter") + for key in ["ttft_ns", "sample_latency_ns", "osl", "isl", "chunk_delta_ns"]: + r.register_key(key, "series") + r.register_key("tpot_ns", "series", dtype=float) + + return w, r + + +# --------------------------------------------------------------------------- +# build_report +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestBuildReport: + 
def test_empty_store(self, tmp_path: Path): + w, r = _make_store(tmp_path, n_samples=0) + report = Report.from_kv_reader(r) + + assert report.n_samples_issued == 0 + assert report.duration_ns is None + assert report.qps() is None + assert report.ttft == {} + assert report.latency == {} + + r.close() + w.close() + + def test_with_metrics(self, tmp_path: Path): + w, r = _make_store(tmp_path, n_samples=50) + report = Report.from_kv_reader(r) + + assert report.n_samples_issued == 50 + assert report.n_samples_completed == 50 + assert report.duration_ns == 10_000_000_000 + assert report.qps() == pytest.approx(5.0) + + assert "min" in report.ttft + assert "percentiles" in report.ttft + assert "histogram" in report.ttft + assert report.ttft["min"] > 0 + assert report.latency["min"] > 0 + assert report.tpot == {} # No TPOT values written + assert report.tps() is not None # OSL data present + + r.close() + w.close() + + +# --------------------------------------------------------------------------- +# Report display and serialization +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestReport: + def test_display_summary(self, tmp_path: Path): + w, r = _make_store(tmp_path, n_samples=10) + report = Report.from_kv_reader(r) + + lines: list[str] = [] + report.display(fn=lines.append, summary_only=True) + output = "\n".join(lines) + + assert "Summary" in output + assert "QPS:" in output + assert "End of Summary" in output + + r.close() + w.close() + + def test_display_full(self, tmp_path: Path): + w, r = _make_store(tmp_path, n_samples=10) + report = Report.from_kv_reader(r) + + lines: list[str] = [] + report.display(fn=lines.append, summary_only=False) + output = "\n".join(lines) + + assert "Latency Breakdowns" in output + assert "TTFT" in output + assert "Histogram" in output + assert "Percentiles" in output + + r.close() + w.close() + + def test_to_json(self, tmp_path: Path): + w, r = _make_store(tmp_path, n_samples=5) + report = Report.from_kv_reader(r) + + data = json.loads(report.to_json()) + assert data["n_samples_completed"] == 5 + assert "ttft" in data + + r.close() + w.close() + + def test_to_json_save(self, tmp_path: Path): + w, r = _make_store(tmp_path, n_samples=5) + report = Report.from_kv_reader(r) + + out_path = tmp_path / "report.json" + report.to_json(save_to=out_path) + assert out_path.exists() + data = json.loads(out_path.read_bytes()) + assert data["n_samples_completed"] == 5 + + r.close() + w.close() + + def test_qps_none_without_duration(self): + report = Report( + version="test", + git_sha=None, + test_started_at=0, + n_samples_issued=100, + n_samples_completed=100, + n_samples_failed=0, + duration_ns=None, + ttft={}, + tpot={}, + latency={}, + output_sequence_lengths={}, + ) + assert report.qps() is None + assert report.tps() is None + + def test_display_no_started_at(self): + """test_started_at=0 should not display a timestamp.""" + report = Report( + version="test", + git_sha=None, + test_started_at=0, + n_samples_issued=0, + n_samples_completed=0, + n_samples_failed=0, + duration_ns=None, + ttft={}, + tpot={}, + latency={}, + output_sequence_lengths={}, + ) + lines: list[str] = [] + report.display(fn=lines.append, summary_only=True) + output = "\n".join(lines) + assert "Test started at" not in output diff --git a/tests/unit/metrics/test_reporter.py b/tests/unit/metrics/test_reporter.py deleted file mode 100644 index 92416de8..00000000 --- a/tests/unit/metrics/test_reporter.py +++ /dev/null @@ -1,1180 +0,0 @@ -# 
SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import math - -import msgspec.json -import pytest -from inference_endpoint.core.types import TextModelOutput -from inference_endpoint.load_generator.events import SampleEvent, SessionEvent -from inference_endpoint.metrics.recorder import sqlite3_cursor -from inference_endpoint.metrics.reporter import ( - MetricsReporter, - RollupQueryTable, - TPOTReportingMode, - output_sequence_from_data, -) - - -def test_sample_counting(events_db): - with MetricsReporter(events_db) as reporter: - stats = reporter.get_sample_statuses() - assert stats["completed"] == 2 - assert stats["in_flight"] == 1 - - -def test_error_counting(events_db): - """get_error_count returns distinct failed samples, not raw ERROR event count. - - The fixture has 3 ERROR events all belonging to uuid3, so the count should be 1. - """ - with MetricsReporter(events_db) as reporter: - assert reporter.get_error_count() == 1 - - -def test_derive_ttft(events_db, sample_uuids): - uuid1 = sample_uuids(1) - uuid2 = sample_uuids(2) - - with MetricsReporter(events_db) as reporter: - ttft_rows = reporter.derive_TTFT() - assert len(ttft_rows) == 2 - assert ttft_rows[0].metric_type == "ttft" - assert ttft_rows[1].metric_type == "ttft" - assert ttft_rows.filter_uuid(uuid1, only_first=True) == 10 - assert ttft_rows.filter_uuid(uuid2, only_first=True) == 187 - assert ttft_rows.filter_uuid("asdf", only_first=True) is None - assert ttft_rows.filter_uuid("asdf", only_first=False) == () - - -def test_derive_tpot(events_db, sample_uuids, fake_outputs, tokenizer): - uuid1 = sample_uuids(1) - uuid2 = sample_uuids(2) - - with MetricsReporter(events_db) as reporter: - tpot_rows = reporter.derive_TPOT( - tokenizer, reporting_mode=TPOTReportingMode.TOKEN_WEIGHTED - ) - - # From test_derive_sample_latency and ttft: - expected_tpot1 = (10211 - 10000 - 10) / len(fake_outputs[uuid1][1]) - expected_tpot2 = (10219 - 10003 - 187) / len(fake_outputs[uuid2][1]) - - tpot1 = tpot_rows.filter_uuid(uuid1, only_first=False) - tpot2 = tpot_rows.filter_uuid(uuid2, only_first=False) - assert len(tpot1) == len(fake_outputs[uuid1][1]) - assert len(tpot2) == len(fake_outputs[uuid2][1]) - assert all(tpot == expected_tpot1 for tpot in tpot1) - assert all(tpot == expected_tpot2 for tpot in tpot2) - - -def test_derive_tpot_with_string_output(tmp_path, sample_uuids, tokenizer): - """Test that derive_TPOT handles a plain string output gracefully. - - A single-string output has only one chunk, so TPOT cannot be computed. - The reporter should not raise an exception and should return None. 
- """ - test_db = str(tmp_path / "test_string_output.db") - uuid1 = sample_uuids(1) - - with sqlite3_cursor(test_db) as (cursor, conn): - cursor.execute( - "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)" - ) - cursor.executemany( - "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)", - [ - ("", SessionEvent.TEST_STARTED.value, 5000, b""), - (uuid1, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10000, b""), - (uuid1, SampleEvent.FIRST_CHUNK.value, 10010, b""), - ( - uuid1, - SampleEvent.COMPLETE.value, - 10211, - msgspec.json.encode(TextModelOutput(output="the final answer")), - ), - ("", SessionEvent.TEST_ENDED.value, 10300, b""), - ], - ) - conn.commit() - - with MetricsReporter(test_db) as reporter: - tpot_rows = reporter.derive_TPOT(tokenizer) - - # A single-string output produces only 1 chunk — TPOT requires at least 2 - assert tpot_rows is None - - -def test_derive_tpot_string_output_with_list_reasoning( - tmp_path, sample_uuids, tokenizer -): - """Test that derive_TPOT computes TPOT when string output is paired with a list reasoning sequence. - - The fix wraps string outputs into a single-element list so they can be combined with - reasoning chunks. Without the fix, the string output causes the sample to be silently - skipped before reasoning is considered, so TPOT returns None even though there are - enough chunks (output + reasoning) to compute it. - """ - test_db = str(tmp_path / "test_string_output_with_reasoning.db") - uuid1 = sample_uuids(1) - - with sqlite3_cursor(test_db) as (cursor, conn): - cursor.execute( - "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)" - ) - cursor.executemany( - "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)", - [ - ("", SessionEvent.TEST_STARTED.value, 5000, b""), - (uuid1, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10000, b""), - (uuid1, SampleEvent.FIRST_CHUNK.value, 10010, b""), - ( - uuid1, - SampleEvent.COMPLETE.value, - 10211, - msgspec.json.encode( - TextModelOutput( - output="the answer", reasoning=("thought step",) - ) - ), - ), - ("", SessionEvent.TEST_ENDED.value, 10300, b""), - ], - ) - conn.commit() - - with MetricsReporter(test_db) as reporter: - tpot_rows = reporter.derive_TPOT(tokenizer) - - # String output ("the answer") + list reasoning (["thought step"]) = 2 chunks total, - # which is enough for TPOT computation. 
- assert tpot_rows is not None - assert len(tpot_rows) == 1 - - -def test_derive_sample_latency(events_db, sample_uuids): - uuid1 = sample_uuids(1) - uuid2 = sample_uuids(2) - - with MetricsReporter(events_db) as reporter: - sample_latency_rows = reporter.derive_sample_latency() - - assert len(sample_latency_rows) == 2 - latency1, latency2 = tuple(sorted(sample_latency_rows, key=lambda x: x.sample_uuid)) - assert latency1.metric_type == "sample_latency" - assert latency1.sample_uuid == uuid1 - assert latency1.metric_value == 10211 - 10000 - - assert latency2.metric_type == "sample_latency" - assert latency2.sample_uuid == uuid2 - assert latency2.metric_value == 10219 - 10003 - - -def test_derive_duration(events_db): - with MetricsReporter(events_db) as reporter: - duration = reporter.derive_duration() - assert duration == (10300 - 5000) - - -def test_derive_duration_malformed(tmp_path): - test_db_path = str(tmp_path / "bad_events.db") - with sqlite3_cursor(test_db_path) as (cursor, conn): - cursor.execute( - "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)" - ) - cursor.executemany( - "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)", - [ - ("", SessionEvent.TEST_STARTED.value, 5000, b""), - ("", SessionEvent.TEST_ENDED.value, 10300, b""), - ("", SessionEvent.TEST_STARTED.value, 11000, b""), - ("", SessionEvent.TEST_ENDED.value, 12000, b""), - ], - ) - conn.commit() - - with pytest.raises( - RuntimeError, match=r"Multiple .*TEST_.* events found - 2 events" - ): - with MetricsReporter(test_db_path) as reporter: - reporter.derive_duration() - - -def test_derive_duration_multiple_starts_check_malformed_false(tmp_path): - """Test that derive_duration doesn't raise error for multiple TEST_STARTED when check_malformed=False.""" - test_db_path = str(tmp_path / "multiple_starts.db") - with sqlite3_cursor(test_db_path) as (cursor, conn): - cursor.execute( - "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)" - ) - cursor.executemany( - "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)", - [ - ("", SessionEvent.TEST_STARTED.value, 5000, b""), - ("", SessionEvent.TEST_STARTED.value, 6000, b""), # Duplicate start - ("", SessionEvent.TEST_ENDED.value, 10300, b""), - ], - ) - conn.commit() - - # Should not raise when check_malformed=False - with MetricsReporter(test_db_path) as reporter: - duration = reporter.derive_duration(check_malformed=False) - - # Should use max(TEST_STARTED) which is 6000 - assert duration == 10300 - 6000 - - -def test_derive_duration_multiple_ends_check_malformed_false(tmp_path): - """Test that derive_duration doesn't raise error for multiple TEST_ENDED when check_malformed=False.""" - test_db_path = str(tmp_path / "multiple_ends.db") - with sqlite3_cursor(test_db_path) as (cursor, conn): - cursor.execute( - "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)" - ) - cursor.executemany( - "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)", - [ - ("", SessionEvent.TEST_STARTED.value, 5000, b""), - ("", SessionEvent.TEST_ENDED.value, 10300, b""), - ("", SessionEvent.TEST_ENDED.value, 12000, b""), # Duplicate end - ], - ) - conn.commit() - - # Should not raise when check_malformed=False - with MetricsReporter(test_db_path) as reporter: - duration = 
reporter.derive_duration(check_malformed=False) - - # Should use max(timestamp_ns) which is 12000 - assert duration == 12000 - 5000 - - -def test_derive_duration_test_ended_not_last_check_malformed_false(tmp_path): - """Test that derive_duration doesn't raise error when TEST_ENDED is not max timestamp and check_malformed=False.""" - test_db_path = str(tmp_path / "test_ended_not_last.db") - with sqlite3_cursor(test_db_path) as (cursor, conn): - cursor.execute( - "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)" - ) - cursor.executemany( - "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)", - [ - ("", SessionEvent.TEST_STARTED.value, 5000, b""), - ("", SessionEvent.TEST_ENDED.value, 10300, b""), - ( - "some_uuid", - SampleEvent.COMPLETE.value, - 15000, - b"", - ), # Event after TEST_ENDED - ], - ) - conn.commit() - - # Should raise when check_malformed=True (default) - with pytest.raises( - RuntimeError, - match=r"TEST_ENDED exists .* but is not the maximum timestamp in database", - ): - with MetricsReporter(test_db_path) as reporter: - reporter.derive_duration(check_malformed=True) - - # Should not raise when check_malformed=False - with MetricsReporter(test_db_path) as reporter: - duration = reporter.derive_duration(check_malformed=False) - - # Should use max(timestamp_ns) which is 15000 - assert duration == 15000 - 5000 - - -def test_tpot_to_histogram(events_db, fake_outputs, tokenizer, sample_uuids): - uuid1 = sample_uuids(1) - uuid2 = sample_uuids(2) - - expected = [ - { - "tpot": (10211 - 10000 - 10) / len(fake_outputs[uuid1][1]), - "count": len(fake_outputs[uuid1][1]), - }, - { - "tpot": (10219 - 10003 - 187) / len(fake_outputs[uuid2][1]), - "count": len(fake_outputs[uuid2][1]), - }, - ] - expected.sort(key=lambda x: x["tpot"]) - - bucket_boundaries = [ - expected[0]["tpot"] - 1, - (expected[0]["tpot"] + expected[1]["tpot"]) / 2, - expected[1]["tpot"] + 1, - ] - - with MetricsReporter(events_db) as reporter: - tpot_rows = reporter.derive_TPOT( - tokenizer, reporting_mode=TPOTReportingMode.TOKEN_WEIGHTED - ) - - # This isn't documented since it's an internal detail and should not be relied on, but `n_buckets` - # is passed directly to np.histogram, so we can specify exact buckets to use - buckets, counts = tpot_rows.to_histogram(n_buckets=bucket_boundaries) - assert len(buckets) == 2 - assert len(counts) == 2 - - assert buckets[0] == (bucket_boundaries[0], bucket_boundaries[1]) - assert buckets[1] == (bucket_boundaries[1], bucket_boundaries[2]) - assert counts[0] == expected[0]["count"] - assert counts[1] == expected[1]["count"] - - -def test_percentile(): - values = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - table = RollupQueryTable( - metric_type="test", from_query="", rows=[(0, v) for v in values] - ) - assert table.percentile(50) == 5 - assert table.percentile([50, 75]) == {50: 5, 75: 7.5} - assert table.percentile(90) == 9 - with pytest.raises(TypeError): - table.percentile("10") - with pytest.raises(ValueError): - table.percentile(101) - with pytest.raises(ValueError): - table.percentile(-1) - - -def test_rollup_summarize(events_db): - with MetricsReporter(events_db) as reporter: - latencies = reporter.derive_sample_latency() - summary = latencies.summarize() - values = [10211 - 10000, 10219 - 10003] - assert summary["total"] == sum(values) - assert summary["min"] == min(values) - assert summary["max"] == max(values) - assert summary["median"] == (values[0] + values[1]) / 2 - assert 
summary["avg"] == (values[0] + values[1]) / 2 - - deviations_squared = [(value - summary["avg"]) ** 2 for value in values] - - assert math.isclose( - summary["std_dev"], - math.sqrt(sum(deviations_squared) / len(values)), - rel_tol=1e-3, - ) - - for percentile in [99.9, 99, 95, 90, 80, 75, 50, 25, 10, 5, 1]: - s = str(percentile) - assert s in summary["percentiles"] - assert summary["percentiles"][s] == latencies.percentile(percentile) - - -def test_reporter_create_report(events_db, fake_outputs, tokenizer): - with MetricsReporter(events_db) as reporter: - report = reporter.create_report(tokenizer) - - # Expected - ttft_rollup = reporter.derive_TTFT() - sample_latency_rollup = reporter.derive_sample_latency() - tpot_rollup = reporter.derive_TPOT( - tokenizer, - ttft_rollup=ttft_rollup, - sample_latency_rollup=sample_latency_rollup, - ) - - assert report.n_samples_issued == 3 - assert report.n_samples_completed == 2 - assert ( - report.n_samples_failed == 1 - ) # 1 distinct failed sample (uuid3), with 3 ERROR events in fixture - assert report.duration_ns == (10300 - 5000) - - for k, expected in ttft_rollup.summarize().items(): - assert k in report.ttft - assert report.ttft[k] == expected - for k, expected in tpot_rollup.summarize().items(): - assert k in report.tpot - assert report.tpot[k] == expected - for k, expected in sample_latency_rollup.summarize().items(): - assert k in report.latency - assert report.latency[k] == expected - for k, expected in tpot_rollup.summarize().items(): - assert k in report.tpot - assert report.tpot[k] == expected - - # QPS should be: completed_samples / (duration_ns / 1e9) - expected_qps = report.n_samples_completed / (report.duration_ns / 1e9) - assert report.qps == expected_qps - - expected_total_tokens = 0 - for output in fake_outputs.values(): - for chunk in output: - expected_total_tokens += len(tokenizer.tokenize(chunk)) - expected_tps = expected_total_tokens / ((10300 - 5000) / 1e9) - assert report.tps == expected_tps - - -def test_reporter_json(events_db): - with MetricsReporter(events_db) as reporter: - report = reporter.create_report() - - json_str = report.to_json() - - json_dict = json.loads(json_str) - - expected_keys = [ - "version", - "git_sha", - "n_samples_issued", - "n_samples_completed", - "n_samples_failed", - "duration_ns", - "ttft", - "tpot", - "latency", - "output_sequence_lengths", - "tpot_reporting_mode", - "qps", - "tps", - "test_started_at", - ] - assert set(json_dict.keys()) == set(expected_keys) - assert json_dict["n_samples_issued"] == report.n_samples_issued - assert json_dict["n_samples_completed"] == report.n_samples_completed - assert json_dict["n_samples_failed"] == report.n_samples_failed - assert json_dict["duration_ns"] == report.duration_ns - assert json_dict["qps"] == report.qps - assert json_dict["tps"] == report.tps - - # For ttft, tpot, and latency, JSON decode will only decode as lists, not tuples - # This only matters in the histogram - def _assert_rollup_summary_equal(json_dict, summary_dict): - if summary_dict is None: - assert json_dict is None - return - - for k in summary_dict.keys(): - if k == "histogram": - continue - assert json_dict[k] == summary_dict[k] - - assert json_dict["histogram"]["buckets"] == [ - list(bucket) for bucket in summary_dict["histogram"]["buckets"] - ] - assert json_dict["histogram"]["counts"] == summary_dict["histogram"]["counts"] - - _assert_rollup_summary_equal(json_dict["ttft"], report.ttft) - _assert_rollup_summary_equal(json_dict["tpot"], report.tpot) - 
_assert_rollup_summary_equal(json_dict["latency"], report.latency) - _assert_rollup_summary_equal( - json_dict["output_sequence_lengths"], report.output_sequence_lengths - ) - - -def test_display_report(events_db): - with MetricsReporter(events_db) as reporter: - report = reporter.create_report() - - import io - - buf = io.StringIO() - - def _write_with_newline(s): - buf.write(s + "\n") - - report.display(fn=_write_with_newline) - s = buf.getvalue() - lines = s.splitlines() - - assert "- Summary -" in lines[0] - assert lines[1].startswith("Version:") - # Git SHA may or may not be present, so Total samples issued can be on line 2 or 3 - assert any(line.startswith("Total samples issued:") for line in lines[2:4]) - - -def test_stop_performance_tracking_timestamp_property(tmp_path, sample_uuids): - """Test that stop_performance_tracking_timestamp_ns returns correct value when event exists.""" - test_db = str(tmp_path / "test_stop_perf_tracking.db") - with sqlite3_cursor(test_db) as (cursor, conn): - cursor.execute( - "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)" - ) - cursor.executemany( - "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)", - [ - ("", SessionEvent.TEST_STARTED.value, 5000, b""), - ("", SessionEvent.STOP_PERFORMANCE_TRACKING.value, 10100, b""), - ("", SessionEvent.TEST_ENDED.value, 10300, b""), - ], - ) - conn.commit() - - with MetricsReporter(test_db) as reporter: - assert reporter.stop_performance_tracking_timestamp_ns == 10100 - - -def test_stop_performance_tracking_timestamp_missing(tmp_path): - """Test that stop_performance_tracking_timestamp_ns returns infinity when event is missing.""" - test_db = str(tmp_path / "test_no_stop_perf_tracking.db") - with sqlite3_cursor(test_db) as (cursor, conn): - cursor.execute( - "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)" - ) - cursor.executemany( - "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)", - [ - ("", SessionEvent.TEST_STARTED.value, 5000, b""), - ("", SessionEvent.TEST_ENDED.value, 10300, b""), - ], - ) - conn.commit() - - with MetricsReporter(test_db) as reporter: - assert reporter.stop_performance_tracking_timestamp_ns == float("inf") - assert reporter.derive_duration() == 10300 - 5000 - - -def test_derive_ttft_with_stop_performance_tracking(tmp_path, sample_uuids): - """Test that derive_TTFT excludes samples issued after STOP_PERFORMANCE_TRACKING.""" - test_db = str(tmp_path / "test_ttft_stop_perf.db") - uuid1 = sample_uuids(1) - uuid2 = sample_uuids(2) - uuid3 = sample_uuids(3) - - with sqlite3_cursor(test_db) as (cursor, conn): - cursor.execute( - "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)" - ) - cursor.executemany( - "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)", - [ - ("", SessionEvent.TEST_STARTED.value, 5000, b""), - (uuid1, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10000, b""), - (uuid2, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10003, b""), - (uuid1, SampleEvent.FIRST_CHUNK.value, 10010, b""), - (uuid2, SampleEvent.FIRST_CHUNK.value, 10190, b""), - ( - "", - SessionEvent.STOP_PERFORMANCE_TRACKING.value, - 10150, - b"", - ), # Before uuid3 issued - (uuid3, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10200, b""), - (uuid3, SampleEvent.FIRST_CHUNK.value, 10220, b""), - (uuid1, 
SampleEvent.COMPLETE.value, 10211, b""),
-                (uuid2, SampleEvent.COMPLETE.value, 10219, b""),
-                (uuid3, SampleEvent.COMPLETE.value, 10250, b""),
-                ("", SessionEvent.TEST_ENDED.value, 10300, b""),
-            ],
-        )
-        conn.commit()
-
-    with MetricsReporter(test_db) as reporter:
-        ttft_rows = reporter.derive_TTFT()
-
-    # Should only include uuid1 and uuid2, not uuid3 (issued after STOP_PERFORMANCE_TRACKING)
-    assert len(ttft_rows) == 2
-    assert ttft_rows.filter_uuid(uuid1, only_first=True) == 10
-    assert ttft_rows.filter_uuid(uuid2, only_first=True) == 187
-    assert ttft_rows.filter_uuid(uuid3, only_first=True) is None
-
-
-def test_derive_sample_latency_with_stop_performance_tracking(tmp_path, sample_uuids):
-    """Test that derive_sample_latency excludes samples issued after STOP_PERFORMANCE_TRACKING."""
-    test_db = str(tmp_path / "test_latency_stop_perf.db")
-    uuid1 = sample_uuids(1)
-    uuid2 = sample_uuids(2)
-    uuid3 = sample_uuids(3)
-
-    with sqlite3_cursor(test_db) as (cursor, conn):
-        cursor.execute(
-            "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)"
-        )
-        cursor.executemany(
-            "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)",
-            [
-                ("", SessionEvent.TEST_STARTED.value, 5000, b""),
-                (uuid1, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10000, b""),
-                (uuid2, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10003, b""),
-                (uuid1, SampleEvent.COMPLETE.value, 10211, b""),
-                (uuid2, SampleEvent.COMPLETE.value, 10219, b""),
-                ("", SessionEvent.STOP_PERFORMANCE_TRACKING.value, 10150, b""),
-                (uuid3, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10200, b""),
-                (uuid3, SampleEvent.COMPLETE.value, 10250, b""),
-                ("", SessionEvent.TEST_ENDED.value, 10300, b""),
-            ],
-        )
-        conn.commit()
-
-    with MetricsReporter(test_db) as reporter:
-        latency_rows = reporter.derive_sample_latency()
-
-    # Should only include uuid1 and uuid2
-    assert len(latency_rows) == 2
-    assert latency_rows.filter_uuid(uuid1, only_first=True) == 10211 - 10000
-    assert latency_rows.filter_uuid(uuid2, only_first=True) == 10219 - 10003
-    assert latency_rows.filter_uuid(uuid3, only_first=True) is None
-
-
-def test_derive_duration_with_stop_performance_tracking_no_samples(tmp_path):
-    """Test that derive_duration returns None when no samples were issued before STOP_PERFORMANCE_TRACKING."""
-    test_db = str(tmp_path / "test_duration_stop_perf.db")
-    with sqlite3_cursor(test_db) as (cursor, conn):
-        cursor.execute(
-            "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)"
-        )
-        cursor.executemany(
-            "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)",
-            [
-                ("", SessionEvent.TEST_STARTED.value, 5000, b""),
-                ("", SessionEvent.STOP_PERFORMANCE_TRACKING.value, 10100, b""),
-                ("", SessionEvent.TEST_ENDED.value, 10300, b""),
-            ],
-        )
-        conn.commit()
-
-    with MetricsReporter(test_db) as reporter:
-        duration = reporter.derive_duration()
-
-    # No samples were issued inside the tracking window, so there is no performance duration.
-    assert duration is None  # Default behavior - No perf test run.
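[Editor's note] The `derive_duration` tests in this deleted module encode a cutoff rule that is easy to misread from the raw event lists, and the replacement KV-store-based `Report` tests above assume the same rule. As a minimal sketch (hypothetical helper and parameter names; the real `MetricsReporter` derived this via SQL over the `events` table), the behavior these tests assert is:

```python
def tracked_duration_ns(
    test_started_ns: int,
    test_ended_ns: int,
    stop_ts_ns: float,  # float("inf") when no STOP_PERFORMANCE_TRACKING event exists
    tracked_complete_ns: list[int],  # COMPLETE timestamps of samples issued before stop_ts_ns
) -> int | None:
    if stop_ts_ns == float("inf"):
        # No tracking cutoff: the whole test window counts (TEST_ENDED - TEST_STARTED).
        return test_ended_ns - test_started_ns
    if not tracked_complete_ns:
        # Nothing was issued inside the tracking window: no performance duration (N/A).
        return None
    # In-flight samples drain past the cutoff: measure until the last tracked
    # sample completes, even when that completion lands after stop_ts_ns.
    return max(tracked_complete_ns) - test_started_ns
```

A sample is "tracked" iff its LOADGEN_ISSUE_CALLED timestamp precedes the STOP_PERFORMANCE_TRACKING timestamp; completions after the cutoff still count for tracked samples, which is exactly what the out-of-order COMPLETE events in the surrounding tests exercise.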
-
-
-def test_derive_duration_with_stop_performance_tracking(tmp_path, sample_uuids):
-    """Test that derive_duration uses STOP_PERFORMANCE_TRACKING timestamp when present."""
-    test_db = str(tmp_path / "test_duration_stop_perf.db")
-    uuid1 = sample_uuids(1)
-    uuid2 = sample_uuids(2)
-    uuid3 = sample_uuids(3)
-
-    with sqlite3_cursor(test_db) as (cursor, conn):
-        cursor.execute(
-            "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)"
-        )
-        cursor.executemany(
-            "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)",
-            [
-                ("", SessionEvent.TEST_STARTED.value, 5000, b""),
-                (uuid1, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10000, b""),
-                (uuid2, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10003, b""),
-                (uuid1, SampleEvent.COMPLETE.value, 10211, b""),
-                (uuid3, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10213, b""),
-                ("", SessionEvent.STOP_PERFORMANCE_TRACKING.value, 10216, b""),
-                (
-                    uuid2,
-                    SampleEvent.COMPLETE.value,
-                    10250,
-                    b"",
-                ),  # Intentionally out of order for test.
-                (uuid3, SampleEvent.COMPLETE.value, 10219, b""),
-                ("", SessionEvent.TEST_ENDED.value, 10300, b""),
-            ],
-        )
-        conn.commit()
-
-    with MetricsReporter(test_db) as reporter:
-        duration = reporter.derive_duration()
-
-    # Should use timestamp of uuid2's COMPLETE event - timestamp of TEST_STARTED event
-    assert duration == 10250 - 5000
-
-
-def test_derive_duration_all_samples_complete_after_stop_performance_tracking(
-    tmp_path, sample_uuids
-):
-    """Test derive_duration when samples are issued before but all complete after STOP_PERFORMANCE_TRACKING."""
-    test_db = str(tmp_path / "test_duration_all_complete_after_stop.db")
-    uuid1 = sample_uuids(1)
-    uuid2 = sample_uuids(2)
-    uuid3 = sample_uuids(3)
-
-    with sqlite3_cursor(test_db) as (cursor, conn):
-        cursor.execute(
-            "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)"
-        )
-        cursor.executemany(
-            "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)",
-            [
-                ("", SessionEvent.TEST_STARTED.value, 5000, b""),
-                (
-                    uuid1,
-                    SessionEvent.LOADGEN_ISSUE_CALLED.value,
-                    10000,
-                    b"",
-                ),  # Issued before stop
-                (
-                    uuid2,
-                    SessionEvent.LOADGEN_ISSUE_CALLED.value,
-                    10003,
-                    b"",
-                ),  # Issued before stop
-                (
-                    uuid3,
-                    SessionEvent.LOADGEN_ISSUE_CALLED.value,
-                    10010,
-                    b"",
-                ),  # Issued before stop
-                (
-                    "",
-                    SessionEvent.STOP_PERFORMANCE_TRACKING.value,
-                    10100,
-                    b"",
-                ),  # STOP marker
-                # All completions happen AFTER stop_ts
-                (
-                    uuid1,
-                    SampleEvent.COMPLETE.value,
-                    10211,
-                    b"",
-                ),  # Complete after stop
-                (
-                    uuid2,
-                    SampleEvent.COMPLETE.value,
-                    10252,
-                    b"",
-                ),  # Complete after stop
-                (
-                    uuid3,
-                    SampleEvent.COMPLETE.value,
-                    10219,
-                    b"",
-                ),  # Complete after stop
-                ("", SessionEvent.TEST_ENDED.value, 10300, b""),
-            ],
-        )
-        conn.commit()
-
-    with MetricsReporter(test_db) as reporter:
-        duration = reporter.derive_duration()
-
-    # Should use timestamp of the last COMPLETE event (uuid2 at 10252) - TEST_STARTED
-    # Since all samples were issued before STOP_PERFORMANCE_TRACKING, they all count,
-    # and duration is measured until the last one completes
-    assert duration == 10252 - 5000
-
-
-def test_derive_duration_without_stop_performance_tracking(tmp_path):
-    """Test that derive_duration uses TEST_ENDED when STOP_PERFORMANCE_TRACKING is absent."""
-    test_db = str(tmp_path / "test_duration_no_stop_perf.db")
-    with sqlite3_cursor(test_db) as (cursor, conn):
-
cursor.execute( - "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)" - ) - cursor.executemany( - "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)", - [ - ("", SessionEvent.TEST_STARTED.value, 5000, b""), - ("", SessionEvent.TEST_ENDED.value, 10300, b""), - ], - ) - conn.commit() - - with MetricsReporter(test_db) as reporter: - duration = reporter.derive_duration() - - # Should use TEST_ENDED - TEST_STARTED - assert duration == 10300 - 5000 - - -def test_get_sample_statuses_with_stop_performance_tracking(tmp_path, sample_uuids): - """Test that get_sample_statuses excludes samples issued after STOP_PERFORMANCE_TRACKING. - - This test verifies: - 1. Samples issued before stop_ts are counted in total_sent - 2. Samples issued after stop_ts are NOT counted in total_sent - 3. Completed samples are only counted if they were issued before stop_ts - 4. Samples issued before stop_ts but completing after stop_ts ARE still counted as completed - """ - test_db = str(tmp_path / "test_statuses_stop_perf.db") - uuid1 = sample_uuids(1) - uuid2 = sample_uuids(2) - uuid3 = sample_uuids(3) - - with sqlite3_cursor(test_db) as (cursor, conn): - cursor.execute( - "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)" - ) - cursor.executemany( - "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)", - [ - ("", SessionEvent.TEST_STARTED.value, 5000, b""), - ( - uuid1, - SessionEvent.LOADGEN_ISSUE_CALLED.value, - 10000, - b"", - ), # Issued before stop_ts - ( - uuid2, - SessionEvent.LOADGEN_ISSUE_CALLED.value, - 10003, - b"", - ), # Issued before stop_ts - ( - uuid1, - SampleEvent.COMPLETE.value, - 10100, - b"", - ), # Completed before stop_ts - ( - "", - SessionEvent.STOP_PERFORMANCE_TRACKING.value, - 10150, - b"", - ), # STOP marker - ( - uuid3, - SessionEvent.LOADGEN_ISSUE_CALLED.value, - 10200, - b"", - ), # Issued AFTER stop_ts - ( - uuid2, - SampleEvent.COMPLETE.value, - 10219, - b"", - ), # Issued before but completed AFTER stop_ts - (uuid3, SampleEvent.COMPLETE.value, 10250, b""), # Issued after stop_ts - ("", SessionEvent.TEST_ENDED.value, 10300, b""), - ], - ) - conn.commit() - - with MetricsReporter(test_db) as reporter: - stats = reporter.get_sample_statuses() - assert reporter.stop_performance_tracking_timestamp_ns == 10150 - - # Should only count uuid1 and uuid2 as issued (uuid3 issued after cutoff) - # Both uuid1 and uuid2 should be counted as completed even though uuid2 completed after stop_ts - assert stats["total_sent"] == 2 - assert ( - stats["completed"] == 2 - ) # uuid1 and uuid2 (uuid3 not counted because it was issued after stop_ts) - assert stats["in_flight"] == 0 - - -def test_get_sample_statuses_excludes_late_issued_completions(tmp_path, sample_uuids): - """Test that completed samples issued after STOP_PERFORMANCE_TRACKING are not counted. - - This specifically tests the edge case where a sample is issued after stop_ts and completes, - ensuring it's not included in the completed count. 
- """ - test_db = str(tmp_path / "test_late_completion.db") - uuid1 = sample_uuids(1) - uuid2 = sample_uuids(2) - uuid3 = sample_uuids(3) - uuid4 = sample_uuids(4) - - with sqlite3_cursor(test_db) as (cursor, conn): - cursor.execute( - "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)" - ) - cursor.executemany( - "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)", - [ - ("", SessionEvent.TEST_STARTED.value, 5000, b""), - (uuid1, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10000, b""), - (uuid2, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10003, b""), - (uuid1, SampleEvent.COMPLETE.value, 10100, b""), - ("", SessionEvent.STOP_PERFORMANCE_TRACKING.value, 10150, b""), - # uuid2 still in flight when stop_ts happens - ( - uuid3, - SessionEvent.LOADGEN_ISSUE_CALLED.value, - 10200, - b"", - ), # Issued after stop_ts - ( - uuid4, - SessionEvent.LOADGEN_ISSUE_CALLED.value, - 10205, - b"", - ), # Issued after stop_ts - ( - uuid3, - SampleEvent.COMPLETE.value, - 10250, - b"", - ), # Completes but was issued after stop_ts - ( - uuid4, - SampleEvent.COMPLETE.value, - 10260, - b"", - ), # Completes but was issued after stop_ts - ("", SessionEvent.TEST_ENDED.value, 10300, b""), - ], - ) - conn.commit() - - with MetricsReporter(test_db) as reporter: - stats = reporter.get_sample_statuses() - - # Only uuid1 and uuid2 should be counted (issued before stop_ts) - assert stats["total_sent"] == 2 - # Only uuid1 completed (uuid2 never completed, uuid3 and uuid4 don't count) - assert stats["completed"] == 1 - assert stats["in_flight"] == 1 # uuid2 is still in flight - - -def test_get_output_sequence_lengths_with_stop_performance_tracking( - tmp_path, sample_uuids, tokenizer -): - """Test that get_output_sequence_lengths excludes samples issued after STOP_PERFORMANCE_TRACKING.""" - test_db = str(tmp_path / "test_osl_stop_perf.db") - uuid1 = sample_uuids(1) - uuid2 = sample_uuids(2) - uuid3 = sample_uuids(3) - - with sqlite3_cursor(test_db) as (cursor, conn): - cursor.execute( - "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)" - ) - cursor.executemany( - "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)", - [ - ("", SessionEvent.TEST_STARTED.value, 5000, b""), - (uuid1, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10000, b""), - (uuid2, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10003, b""), - ( - uuid1, - SampleEvent.COMPLETE.value, - 10211, - msgspec.json.encode(TextModelOutput(output=("Hello, ", "world"))), - ), - ( - uuid2, - SampleEvent.COMPLETE.value, - 10219, - msgspec.json.encode(TextModelOutput(output=("And ", "goodbye."))), - ), - ("", SessionEvent.STOP_PERFORMANCE_TRACKING.value, 10150, b""), - (uuid3, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10200, b""), - ( - uuid3, - SampleEvent.COMPLETE.value, - 10250, - msgspec.json.encode(TextModelOutput(output=("Extra ", "sample"))), - ), - ("", SessionEvent.TEST_ENDED.value, 10300, b""), - ], - ) - conn.commit() - - with MetricsReporter(test_db) as reporter: - osl_rollup = reporter.get_output_sequence_lengths(tokenizer) - - # Should only include uuid1 and uuid2 (uuid3 issued after STOP_PERFORMANCE_TRACKING) - assert len(osl_rollup) == 2 - assert uuid1 in osl_rollup - assert uuid2 in osl_rollup - assert uuid3 not in osl_rollup - - -def test_create_report_with_stop_performance_tracking( - tmp_path, sample_uuids, tokenizer -): - """Test that create_report respects 
STOP_PERFORMANCE_TRACKING for all metrics.""" - test_db = str(tmp_path / "test_report_stop_perf.db") - uuid1 = sample_uuids(1) - uuid2 = sample_uuids(2) - uuid3 = sample_uuids(3) - - with sqlite3_cursor(test_db) as (cursor, conn): - cursor.execute( - "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)" - ) - cursor.executemany( - "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)", - [ - ("", SessionEvent.TEST_STARTED.value, 5000, b""), - (uuid1, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10000, b""), - (uuid2, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10003, b""), - (uuid1, SampleEvent.FIRST_CHUNK.value, 10010, b""), - (uuid2, SampleEvent.FIRST_CHUNK.value, 10190, b""), - ( - uuid1, - SampleEvent.COMPLETE.value, - 10211, - msgspec.json.encode(TextModelOutput(output=("Hello, ", "world"))), - ), - ( - uuid2, - SampleEvent.COMPLETE.value, - 10219, - msgspec.json.encode(TextModelOutput(output=("And ", "goodbye."))), - ), - ("", SessionEvent.STOP_PERFORMANCE_TRACKING.value, 10150, b""), - (uuid3, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10200, b""), - (uuid3, SampleEvent.FIRST_CHUNK.value, 10220, b""), - ( - uuid3, - SampleEvent.COMPLETE.value, - 10250, - msgspec.json.encode(TextModelOutput(output=("Extra ", "sample"))), - ), - ("", SessionEvent.TEST_ENDED.value, 10300, b""), - ], - ) - conn.commit() - - with MetricsReporter(test_db) as reporter: - report = reporter.create_report(tokenizer) - - # Verify that only uuid1 and uuid2 are counted - assert report.n_samples_issued == 2 # Only uuid1 and uuid2 - assert report.n_samples_completed == 2 - assert ( - report.duration_ns == 10219 - 5000 - ) # timestamp of uuid2's COMPLETE event - timestamp of TEST_STARTED event - - # Verify QPS is based on the truncated duration - expected_qps = 2 / ((10219 - 5000) / 1e9) - assert report.qps == expected_qps - - # Verify latency includes only uuid1 and uuid2 - assert report.latency["total"] == (10211 - 10000) + (10219 - 10003) - - -def test_create_report_with_zero_samples_before_stop_performance_tracking(tmp_path): - """Test that create_report shows 'Duration: N/A' when 0 samples issued before STOP_PERFORMANCE_TRACKING.""" - test_db = str(tmp_path / "test_zero_samples_stop_perf.db") - - with sqlite3_cursor(test_db) as (cursor, conn): - cursor.execute( - "CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)" - ) - cursor.executemany( - "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)", - [ - ("", SessionEvent.TEST_STARTED.value, 5000, b""), - ("", SessionEvent.STOP_PERFORMANCE_TRACKING.value, 10100, b""), - ("", SessionEvent.TEST_ENDED.value, 10300, b""), - ], - ) - conn.commit() - - with MetricsReporter(test_db) as reporter: - report = reporter.create_report() - - # Verify report values - assert report.n_samples_issued == 0 - assert report.n_samples_completed == 0 - assert report.duration_ns is None - - # Verify display shows 'Duration: N/A' - import io - - buf = io.StringIO() - - def _write_with_newline(s): - buf.write(s + "\n") - - report.display(fn=_write_with_newline) - display_output = buf.getvalue() - - assert "Duration: N/A" in display_output - assert "(no performance samples were issued)" in display_output - - -@pytest.mark.unit -class TestOutputSequenceFromData: - """Tests for output_sequence_from_data covering all supported data formats.""" - - def test_text_model_output_string(self): - """TextModelOutput 
with string output (array_like encoding).""" - data = msgspec.json.encode(TextModelOutput(output="hello world")) - output, reasoning = output_sequence_from_data(data) - assert output == "hello world" - assert reasoning is None - - def test_text_model_output_chunks(self): - """TextModelOutput with tuple output (streaming chunks).""" - data = msgspec.json.encode(TextModelOutput(output=("chunk1", "chunk2"))) - output, reasoning = output_sequence_from_data(data) - assert output == "chunk1chunk2" - assert reasoning is None - - def test_text_model_output_chunks_no_join(self): - """TextModelOutput with join_chunks=False returns raw list.""" - data = msgspec.json.encode(TextModelOutput(output=("chunk1", "chunk2"))) - output, reasoning = output_sequence_from_data(data, join_chunks=False) - assert output == ["chunk1", "chunk2"] - assert reasoning is None - - def test_text_model_output_with_reasoning(self): - """TextModelOutput with both output and reasoning.""" - data = msgspec.json.encode( - TextModelOutput(output=("out1", "out2"), reasoning=("r1", "r2")) - ) - output, reasoning = output_sequence_from_data(data) - assert output == "out1out2" - assert reasoning == "r1r2" - - def test_text_model_output_with_reasoning_no_join(self): - """TextModelOutput with reasoning and join_chunks=False.""" - data = msgspec.json.encode( - TextModelOutput(output=("out1", "out2"), reasoning=("r1", "r2")) - ) - output, reasoning = output_sequence_from_data(data, join_chunks=False) - assert output == ["out1", "out2"] - assert reasoning == ["r1", "r2"] - - def test_legacy_string_format(self): - """Legacy plain string format (backward compat).""" - data = msgspec.json.encode("just a string") - output, reasoning = output_sequence_from_data(data) - assert output == "just a string" - assert reasoning is None - - def test_legacy_dict_format(self): - """Legacy dict format with output key (backward compat).""" - data = msgspec.json.encode({"output": "from dict", "reasoning": "think"}) - output, reasoning = output_sequence_from_data(data) - assert output == "from dict" - assert reasoning == "think" - - def test_legacy_dict_chunked(self): - """Legacy dict format with list output (backward compat).""" - data = msgspec.json.encode({"output": ["c1", "c2"]}) - output, reasoning = output_sequence_from_data(data) - assert output == "c1c2" - assert reasoning is None - - def test_list_without_tag_returns_none(self): - """A list without 'TextModelOutput' tag returns (None, None).""" - data = msgspec.json.encode(["not-a-tag", "some", "data"]) - output, reasoning = output_sequence_from_data(data) - assert output is None - assert reasoning is None - - def test_short_list_returns_none(self): - """A single-element list returns (None, None).""" - data = msgspec.json.encode(["only-one"]) - output, reasoning = output_sequence_from_data(data) - assert output is None - assert reasoning is None - - def test_none_data(self): - """None data returns (None, None).""" - output, reasoning = output_sequence_from_data(None) - assert output is None - assert reasoning is None - - def test_empty_data(self): - """Empty bytes returns (None, None).""" - output, reasoning = output_sequence_from_data(b"") - assert output is None - assert reasoning is None diff --git a/tests/unit/test_core_types.py b/tests/unit/test_core_types.py index f1a574bf..99f1c169 100644 --- a/tests/unit/test_core_types.py +++ b/tests/unit/test_core_types.py @@ -97,11 +97,10 @@ class TestStreamChunk: def test_stream_chunk_creation(self) -> None: """Test creating a stream chunk.""" - chunk = 
StreamChunk(id="test-123", response_chunk="partial", is_complete=False) + chunk = StreamChunk(id="test-123", response_chunk="partial") assert chunk.id == "test-123" assert chunk.response_chunk == "partial" - assert chunk.is_complete is False assert chunk.metadata == {} diff --git a/tests/unit/transport/test_zmq_pool_transport.py b/tests/unit/transport/test_zmq_pool_transport.py new file mode 100644 index 00000000..69c90f10 --- /dev/null +++ b/tests/unit/transport/test_zmq_pool_transport.py @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ZmqWorkerPoolTransport and ReadyCheckReceiver. + +Includes regression test for the 'Socket operation on non-socket' bug where +ReadyCheckReceiver.wait() closed its socket on TimeoutError, breaking the +retry loop in WorkerManager._wait_for_workers_with_liveness_check(). +""" + +import asyncio +import uuid + +import pytest +import zmq +from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext +from inference_endpoint.async_utils.transport.zmq.pubsub import ( + ZmqEventRecordPublisher, +) +from inference_endpoint.async_utils.transport.zmq.ready_check import ( + ReadyCheckReceiver, +) +from inference_endpoint.async_utils.transport.zmq.transport import ( + ZMQTransportConfig, + ZmqWorkerPoolTransport, +) + + +@pytest.fixture(autouse=True) +def reset_zmq_singleton(): + """Ensure each test gets a fresh ManagedZMQContext singleton.""" + yield + instance = ManagedZMQContext._instance + if instance is not None and getattr(instance, "_initialized", False): + instance.cleanup() + ManagedZMQContext._instance = None + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestReadyCheckReceiverTimeout: + """Regression: ReadyCheckReceiver must survive timeout for retry.""" + + async def test_socket_survives_timeout(self): + """After wait() times out, the socket must still be usable for retry. + + This is the core regression test for the ENOTSOCK bug. The old code + had `except BaseException: self.close()` which closed the socket on + TimeoutError. The caller (_wait_for_workers_with_liveness_check) + catches TimeoutError and retries, hitting a dead socket. 
+ """ + zmq_ctx = ManagedZMQContext(io_threads=1) + dummy = zmq_ctx.socket(zmq.PUB) + zmq_ctx.bind(dummy, "dummy_pub") + + receiver = ReadyCheckReceiver("ready_test", zmq_ctx, count=1) + + # First wait should timeout (no signals sent) + with pytest.raises(TimeoutError): + await receiver.wait(timeout=0.05) + + # Socket must still be usable after timeout + assert not receiver._sock.closed, ( + "ReadyCheckReceiver closed its socket on TimeoutError — " + "this breaks the retry loop in _wait_for_workers_with_liveness_check" + ) + _ = receiver._sock.rcvtimeo # Would raise ENOTSOCK if socket is dead + + # Second wait should also timeout cleanly (not ENOTSOCK) + with pytest.raises(TimeoutError): + await receiver.wait(timeout=0.05) + + receiver.close() + dummy.close() + zmq_ctx.cleanup() + + async def test_socket_closed_on_cancellation(self): + """Socket SHOULD be closed on non-timeout exceptions (e.g. cancel).""" + zmq_ctx = ManagedZMQContext(io_threads=1) + dummy = zmq_ctx.socket(zmq.PUB) + zmq_ctx.bind(dummy, "dummy_pub") + + receiver = ReadyCheckReceiver("ready_test", zmq_ctx, count=1) + + task = asyncio.create_task(receiver.wait(timeout=10.0)) + await asyncio.sleep(0.05) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert receiver._sock.closed + + dummy.close() + zmq_ctx.cleanup() + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestZmqPoolTransport: + """Pool transport creation with and without a publisher on the same context.""" + + @pytest.mark.parametrize("num_workers", [2, 3, 4, 8]) + @pytest.mark.parametrize("create_publisher", [True, False]) + async def test_pool(self, num_workers: int, create_publisher: bool): + loop = asyncio.get_running_loop() + zmq_ctx = ManagedZMQContext(io_threads=2) + + publisher = None + dummy = None + if create_publisher: + sid = uuid.uuid4().hex[:8] + publisher = ZmqEventRecordPublisher(f"ev_pub_{sid}", zmq_ctx, loop=loop) + else: + # Baseline: bind an unrelated PUB socket so the context is non-empty. + dummy = zmq_ctx.socket(zmq.PUB) + zmq_ctx.bind(dummy, "dummy") + + pool = ZmqWorkerPoolTransport.create( + loop, num_workers, config=ZMQTransportConfig() + ) + + rc = pool._ready_check + assert not rc._sock.closed + _ = rc._sock.rcvtimeo + + with pytest.raises(TimeoutError): + await pool.wait_for_workers_ready(timeout=0.1) + + pool.cleanup() + if publisher is not None: + publisher.close() + if dummy is not None: + dummy.close() + zmq_ctx.cleanup()