Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/inference_endpoint/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def create_parser() -> argparse.ArgumentParser:
required=True,
help="Template type",
)
Comment thread
anandhu-eng marked this conversation as resolved.
init_parser.add_argument("--output", "-o", type=str, help="Output filename")

return parser

Expand Down
84 changes: 45 additions & 39 deletions src/inference_endpoint/commands/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import shutil
import signal
import tempfile
Comment thread
anandhu-eng marked this conversation as resolved.
import time
import uuid
Comment thread
arekay-nv marked this conversation as resolved.
from pathlib import Path
from urllib.parse import urljoin
Expand Down Expand Up @@ -275,7 +274,6 @@ def _build_config_from_cli(
)
timeout = getattr(args, "timeout", None)
verbose = getattr(args, "verbose", False)
output = getattr(args, "output", None)
# Build BenchmarkConfig from CLI params
return BenchmarkConfig(
name=f"cli_{benchmark_mode}",
Expand Down Expand Up @@ -327,7 +325,6 @@ def _build_config_from_cli(
metrics=Metrics(),
baseline=None, # CLI mode doesn't use baseline
report_dir=report_dir,
output=output,
timeout=timeout,
verbose=verbose,
)
Expand Down Expand Up @@ -614,7 +611,6 @@ def _run_benchmark(

# Run benchmark
logger.info("Running...")
start_time = time.time()

sess = None
try:
Expand Down Expand Up @@ -644,15 +640,26 @@ def signal_handler(signum, frame):
# Always restore original handler
signal.signal(signal.SIGINT, old_handler)

elapsed_time = time.time() - start_time
success_count = response_collector.count - len(response_collector.errors)
estimated_qps = success_count / elapsed_time if elapsed_time > 0 else 0
# Prefer authoritative metrics from the session report
report = getattr(sess, "report", None)
if report is None:
logger.error(
"Session report missing — benchmark reporter failed to produce results"
)
raise ExecutionError(
"Session report missing — cannot produce benchmark results"
)
Comment thread
arekay-nv marked this conversation as resolved.

elapsed_time = report.duration_ns / 1e9
total = report.n_samples_issued
success_count = report.n_samples_completed

# qps will be None if duration was 0, so fall back to 0.0
estimated_qps = report.qps or 0.0

# Report results
logger.info(f"Completed in {elapsed_time:.1f}s")
logger.info(
f"Results: {success_count}/{scheduler.total_samples_to_issue} successful"
)
logger.info(f"Results: {success_count}/{total} successful")
logger.info(f"Estimated QPS: {estimated_qps:.1f}")

if response_collector.errors:
Expand All @@ -663,36 +670,35 @@ def signal_handler(signum, frame):
if len(response_collector.errors) > 3:
logger.warning(f" ... +{len(response_collector.errors) - 3} more")

# Save results if requested
if config.output:
try:
results = {
"config": {
"endpoint": endpoint,
"mode": test_mode,
"target_qps": target_qps,
},
"results": {
"total": scheduler.total_samples_to_issue,
"successful": success_count,
"failed": len(response_collector.errors),
"elapsed_time": elapsed_time,
"qps": estimated_qps,
},
}

if collect_responses:
results["responses"] = response_collector.responses

# Always save all errors (useful for debugging)
if response_collector.errors:
results["errors"] = response_collector.errors

with open(config.output, "w") as f:
try:
results = {
"config": {
"endpoint": endpoint,
"mode": test_mode,
"target_qps": target_qps,
},
"results": {
"total": total,
"successful": success_count,
"failed": len(response_collector.errors),
Comment thread
anandhu-eng marked this conversation as resolved.
Outdated
"elapsed_time": elapsed_time,
"qps": estimated_qps,
},
}
if collect_responses:
results["responses"] = response_collector.responses
# Always save all errors (useful for debugging)
if response_collector.errors:
results["errors"] = response_collector.errors
if config.report_dir is not None:
results_path = report_dir / "results.json"
with open(results_path, "w") as f:
json.dump(results, f, indent=2)
logger.info(f"Saved: {config.output}")
except Exception as e:
logger.error(f"Save failed: {e}")
logger.info(f"Saved: {results_path}")
else:
logger.error("No report-dir specified; results.json not saved, but summary may be available in the default report directory.")
except Exception as e:
logger.error(f"Save failed: {e}")

except KeyboardInterrupt:
logger.warning("Benchmark interrupted by user")
Expand Down
1 change: 0 additions & 1 deletion src/inference_endpoint/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,6 @@ class BenchmarkConfig(BaseModel):
settings: Settings = Field(default_factory=Settings)
metrics: Metrics = Field(default_factory=Metrics)
endpoint_config: EndpointConfig = Field(default_factory=EndpointConfig)
output: Path | None = None
report_dir: Path | None = None
timeout: int | None = None
verbose: bool = False
Expand Down
5 changes: 4 additions & 1 deletion src/inference_endpoint/load_generator/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def __init__(
self.event_recorder = EventRecorder(
session_id=self.session_id, notify_idle=self.end_event
)
# Populated by _run_test once the test finishes; stays None until then
self.report = None

@property
def is_running(self):
Expand Down Expand Up @@ -123,7 +125,8 @@ def _run_test(
)
tokenizer = None
report = reporter.create_report(tokenizer)

# Store report on session so external callers can use it
self.report = report
# Save to report directory if provided
if report_dir:
Path(report_dir).mkdir(parents=True, exist_ok=True)
Expand Down
16 changes: 11 additions & 5 deletions tests/integration/commands/test_benchmark_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,11 @@ async def test_benchmark_with_output_file(
self, mock_http_echo_server, ds_pickle_dataset_path, tmp_path
):
"""Test benchmark saves results to JSON file."""
output_file = tmp_path / "benchmark_results.json"
# The benchmark command writes results to `results.json` inside the
# configured `report_dir`. Pass `report_dir=tmp_path` so the command
# will write output into this temporary directory and we can assert on
# the produced file.
report_dir = tmp_path

args = argparse.Namespace(
benchmark_mode="offline",
Expand All @@ -133,17 +137,19 @@ async def test_benchmark_with_output_file(
min_output_tokens=None,
max_output_tokens=None,
mode=None,
output=output_file,
report_dir=report_dir,
verbose=0,
model="echo-server",
timeout=None,
)

await run_benchmark_command(args)

# Verify file was created
assert output_file.exists()
# Verify file was created at <report_dir>/results.json
results_path = report_dir / "results.json"
assert results_path.exists()

with open(output_file) as f:
with open(results_path) as f:
results = json.load(f)

assert "config" in results
Expand Down
Loading