Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,6 @@ docs/superpowers/

# User-specific local dev configs; do not commit
CLAUDE.local.md

# Generated dataset cache (created by Dataset.get_dataloader())
dataset_cache/
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,9 @@ async def main() -> None:

# Using ternary operator causes errors in MyPy object type coalescing
# (coalesces to 'object' not 'AbstractContextManager[TokenizePool | None]')
pool_cm: AbstractContextManager[TokenizePool | None]
if args.tokenizer:
pool_cm: AbstractContextManager[TokenizePool | None] = TokenizePool(
args.tokenizer, n_workers=args.tokenizer_workers
)
pool_cm = TokenizePool(args.tokenizer, n_workers=args.tokenizer_workers)
else:
pool_cm = nullcontext()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ def __init__(self, tokenizer_name: str, n_workers: int) -> None:
futures = [
self._executor.submit(self._get_thread_tokenizer) for _ in range(n_workers)
]
for f in futures:
f.result()
try:
for f in futures:
f.result()
except Exception:
self._executor.shutdown(wait=False)
self._executor = None
raise

def _get_thread_tokenizer(self) -> PreTrainedTokenizerBase:
"""Return the tokenizer for the current thread, loading it if needed."""
Expand Down
60 changes: 59 additions & 1 deletion src/inference_endpoint/commands/benchmark/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
import json
import logging
import platform
import random
import shutil
import signal
import tempfile
import uuid
from dataclasses import dataclass, field
from dataclasses import replace as dataclass_replace
from datetime import datetime
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -347,6 +349,35 @@ def _build_phases(ctx: BenchmarkContext) -> list[PhaseConfig]:
"""Build the phase list from BenchmarkContext."""
phases: list[PhaseConfig] = []

# Warmup phase (optional, before performance)
warmup_cfg = ctx.config.settings.warmup
if warmup_cfg.enabled:
warmup_dataset: Dataset = (
ctx.dataloader.with_salt(random.Random(warmup_cfg.warmup_random_seed + 2))
if warmup_cfg.salt
else ctx.dataloader
)
warmup_rt = dataclass_replace(
ctx.rt_settings,
min_duration_ms=0,
Comment thread
arekay-nv marked this conversation as resolved.
max_duration_ms=None,
n_samples_from_dataset=ctx.dataloader.num_samples(),
n_samples_to_issue=warmup_cfg.n_requests,
min_sample_count=1,
rng_sched=random.Random(warmup_cfg.warmup_random_seed),
rng_sample_index=random.Random(warmup_cfg.warmup_random_seed + 1),
load_pattern=ctx.rt_settings.load_pattern,
)
phases.append(
PhaseConfig(
"warmup",
warmup_rt,
warmup_dataset,
PhaseType.WARMUP,
drain_after=warmup_cfg.drain,
)
)

# Performance phase
phases.append(
PhaseConfig(
Expand Down Expand Up @@ -525,12 +556,39 @@ async def _run_benchmark_async(
phases = _build_phases(ctx)
report: Report | None = None

# Timer starts when the performance phase begins (after warmup drains),
# so max_duration_ms applies only to the perf phase, not warmup.
global_timeout_handle = None
_timeout_done = False
max_duration_ms = ctx.rt_settings.max_duration_ms

def _on_global_timeout() -> None:
if not _timeout_done:
logger.warning(
"Global experiment timeout reached (%d ms); stopping session.",
max_duration_ms,
)
session.stop()
Comment thread
arekay-nv marked this conversation as resolved.

def _on_phase_start(phase: PhaseConfig) -> None:
nonlocal global_timeout_handle
if (
phase.phase_type == PhaseType.PERFORMANCE
and max_duration_ms is not None
):
global_timeout_handle = loop.call_later(
max_duration_ms / 1000.0, _on_global_timeout
)

loop.add_signal_handler(signal.SIGINT, session.stop)
try:
result = await session.run(phases)
result = await session.run(phases, on_phase_start=_on_phase_start)
except Exception as e:
raise ExecutionError(f"Benchmark execution failed: {e}") from e
finally:
_timeout_done = True
if global_timeout_handle is not None:
global_timeout_handle.cancel()
loop.remove_signal_handler(signal.SIGINT)
logger.info("Cleaning up...")
try:
Expand Down
25 changes: 25 additions & 0 deletions src/inference_endpoint/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,30 @@ def _validate_completeness(self) -> Self:
return self


class WarmupConfig(BaseModel):
    """Warmup phase configuration. Runs before the performance phase; results are not recorded."""

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Master switch; the remaining fields only take effect when this is True.
    enabled: bool = Field(
        default=False,
        description="Enable warmup phase before performance run",
    )
    # Number of warmup requests; None means a single pass over the full dataset.
    n_requests: int | None = Field(
        default=None,
        gt=0,
        description="Warmup request count (None = full dataset once)",
    )
    # Salting each prompt defeats server-side KV-cache reuse during warmup.
    salt: bool = Field(
        default=False,
        description="Prepend a unique random hex salt to each warmup prompt",
    )
    # Whether to wait for outstanding warmup requests before the perf phase.
    drain: bool = Field(
        default=False,
        description="Drain in-flight warmup requests before starting the performance phase",
    )
    # Seed shared (with offsets) by the warmup scheduling/sampling RNGs.
    warmup_random_seed: int = Field(
        default=42,
        description="RNG seed for warmup scheduling and sample ordering",
    )


@cyclopts.Parameter(name="*")
class Settings(BaseModel):
"""Test settings."""
Expand All @@ -401,6 +425,7 @@ class Settings(BaseModel):
runtime: RuntimeConfig = Field(default_factory=RuntimeConfig)
load_pattern: LoadPattern = Field(default_factory=LoadPattern)
client: HTTPClientConfig = Field(default_factory=HTTPClientConfig)
warmup: WarmupConfig = Field(default_factory=WarmupConfig)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Review Council — Claude] low · api-contract

Settings.warmup: WarmupConfig has no cyclopts.Parameter annotation, so cyclopts auto-generates only dotted CLI flags like --settings.warmup.enabled=true, --settings.warmup.n-requests=.... Compare with runtime and load_pattern which expose ergonomic shorthands via field-level cyclopts.Parameter(alias=...). As-is, enabling warmup from the CLI is unwieldy. Add cyclopts.Parameter(alias="--warmup-...") aliases on the WarmupConfig fields, or document that warmup is YAML-only.



class OfflineSettings(Settings):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ settings:
max_idle_time: 4.0 # Discard connections idle longer than this (seconds)
min_required_connections: -1 # Min connections to initialize (-1=auto, 0=disabled)
worker_gc_mode: relaxed # Worker GC strategy | options: disabled, relaxed, system
warmup:
enabled: false # Enable warmup phase before performance run
n_requests: null # Warmup request count (None = full dataset once)
salt: false # Prepend a unique random hex salt to each warmup prompt
drain: false # Drain in-flight warmup requests before starting the performance phase
warmup_random_seed: 42 # RNG seed for warmup scheduling and sample ordering
endpoint_config:
endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'.
- http://localhost:8000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ settings:
max_idle_time: 4.0 # Discard connections idle longer than this (seconds)
min_required_connections: -1 # Min connections to initialize (-1=auto, 0=disabled)
worker_gc_mode: relaxed # Worker GC strategy | options: disabled, relaxed, system
warmup:
enabled: false # Enable warmup phase before performance run
n_requests: null # Warmup request count (None = full dataset once)
salt: false # Prepend a unique random hex salt to each warmup prompt
drain: false # Drain in-flight warmup requests before starting the performance phase
warmup_random_seed: 42 # RNG seed for warmup scheduling and sample ordering
endpoint_config:
endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'.
- http://localhost:8000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ settings:
max_idle_time: 4.0 # Discard connections idle longer than this (seconds)
min_required_connections: -1 # Min connections to initialize (-1=auto, 0=disabled)
worker_gc_mode: relaxed # Worker GC strategy | options: disabled, relaxed, system
warmup:
enabled: false # Enable warmup phase before performance run
n_requests: null # Warmup request count (None = full dataset once)
salt: false # Prepend a unique random hex salt to each warmup prompt
drain: false # Drain in-flight warmup requests before starting the performance phase
warmup_random_seed: 42 # RNG seed for warmup scheduling and sample ordering
endpoint_config:
endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'.
- http://localhost:8000
Expand Down
63 changes: 60 additions & 3 deletions src/inference_endpoint/dataset_manager/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import os
import random
from abc import ABC
from enum import Enum
from logging import getLogger
Expand Down Expand Up @@ -276,11 +278,12 @@ class Dataset:
def __init_subclass__(
cls,
dataset_id: str | None = None,
register: bool = True,
**kwargs,
):
super().__init_subclass__(**kwargs)

if not inspect.isabstract(cls):
if register and not inspect.isabstract(cls):
if dataset_id is None:
dataset_id = cls.__name__
cls.DATASET_ID = dataset_id
Expand Down Expand Up @@ -309,6 +312,7 @@ def __init__(
self.transforms = transforms
self.repeats = repeats
self.data: list[dict[str, Any]] | None = None
self._salt_rng: random.Random | None = None

@classmethod
def load_from_file(
Expand Down Expand Up @@ -402,7 +406,60 @@ def load_sample(self, index: int) -> Any:
IOError: If data cannot be loaded from disk.
"""
assert self.data is not None, "Dataset not loaded. Call load() first."
return self.data[index]
data = self.data[index]
if self._salt_rng is not None:
data = self._apply_salt(data)
return data

def with_salt(self, rng: random.Random) -> "Dataset":
    """Return a shallow copy of this dataset that salts each load_sample() call.

    The returned dataset shares the same loaded data — no re-loading needed.
    Each load_sample() call on the returned dataset prepends a unique hex salt
    derived from rng to the prompt field, preventing KV-cache reuse.
    """
    salted_view = copy.copy(self)
    salted_view._salt_rng = rng
    return salted_view

def _apply_salt(self, data: Any) -> Any:
"""Prepend a unique salt to the prompt field of a sample dict."""
assert self._salt_rng is not None
if not isinstance(data, dict):
return data
if "input_tokens" in data and "prompt" not in data:
self.logger.warning(
"salt=True: sample has 'input_tokens' but no 'prompt' — "
"salt cannot be applied to pre-tokenized input; KV-cache reuse may not be prevented"
)
return data
if "input_tokens" in data and "prompt" in data:
self.logger.warning(
"salt=True: sample has both 'input_tokens' and 'prompt' — "
"salt applied to 'prompt' only; adapters that use 'input_tokens' "
"directly will still reuse the KV cache"
)
if "prompt" not in data:
return data
prompt = data["prompt"]
salt = self._salt_rng.randbytes(8).hex()
if isinstance(prompt, str):
return {**data, "prompt": f"[{salt}] {prompt}"}
if isinstance(prompt, list) and prompt:
# Find the first text part at any index (image-first prompts place text at index 1+)
for i, part in enumerate(prompt):
if isinstance(part, dict) and part.get("type") == "text":
salted_parts = [
*prompt[:i],
{**part, "text": f"[{salt}] {part['text']}"},
*prompt[i + 1 :],
]
return {**data, "prompt": salted_parts}
self.logger.warning(
"salt=True: multimodal prompt has no text part — "
"salt cannot be applied; KV-cache reuse may not be prevented"
)
return data # unsupported prompt type — skip salting

def num_samples(self) -> int:
assert self.data is not None, "Dataset not loaded. Call load() first."
Expand All @@ -411,7 +468,7 @@ def num_samples(self) -> int:
@classmethod
def get_dataloader(
cls,
datasets_dir: Path = Path("datasets"),
datasets_dir: Path = Path("dataset_cache"),
num_repeats: int = 1,
transforms: list[Transform] | None = None,
force_regenerate: bool = False,
Expand Down
Loading
Loading