Skip to content

Commit 4b7e8f6

Browse files
authored
[fix] Pre-parse data + decouple args (#31)
* Pre-parse data + decouple args + remove concurrency flag from worker --------- Signed-off-by: Rashid Kaleem <230885705+arekay-nv@users.noreply.github.com>
1 parent b9fb4de commit 4b7e8f6

17 files changed

Lines changed: 88 additions & 87 deletions

File tree

examples/02_ServerBenchmarking/offline_llama3_8b_cnn.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ settings:
2929

3030
client:
3131
workers: 4
32-
max_concurrency: -1 # -1 = unlimited
3332

3433
metrics:
3534
collect:

examples/02_ServerBenchmarking/online_llama2_70b_cnn.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ settings:
3030

3131
client:
3232
workers: 4
33-
max_concurrency: -1 # -1 = unlimited
3433

3534
metrics:
3635
collect:

src/inference_endpoint/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def _add_shared_benchmark_args(parser):
213213
parser.add_argument("--min-output-tokens", type=int, help="Min output tokens")
214214
parser.add_argument("--max-output-tokens", type=int, help="Max output tokens")
215215
parser.add_argument(
216-
"--report-path", type=Path, help="Path to save detailed benchmark report"
216+
"--report-dir", type=Path, help="Path to save detailed benchmark report"
217217
)
218218

219219

src/inference_endpoint/commands/benchmark.py

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ async def run_benchmark_command(args: argparse.Namespace) -> None:
235235
collect_responses = test_mode in [TestMode.ACC, TestMode.BOTH]
236236

237237
# Run benchmark
238-
_run_benchmark(args, effective_config, collect_responses, test_mode, benchmark_mode)
238+
_run_benchmark(effective_config, collect_responses, test_mode, benchmark_mode)
239239

240240

241241
def _build_config_from_cli(
@@ -264,7 +264,10 @@ def _build_config_from_cli(
264264
load_pattern_type = LoadPatternType.CONCURRENCY
265265
case "online":
266266
load_pattern_type = LoadPatternType.POISSON
267-
267+
report_dir = getattr(args, "report_dir", None)
268+
timeout = getattr(args, "timeout", None)
269+
verbose = getattr(args, "verbose", False)
270+
output = getattr(args, "output", None)
268271
# Build BenchmarkConfig from CLI params
269272
return BenchmarkConfig(
270273
name=f"cli_{benchmark_mode}",
@@ -315,6 +318,10 @@ def _build_config_from_cli(
315318
endpoint_config=EndpointConfig(endpoint=args.endpoint, api_key=args.api_key),
316319
metrics=Metrics(),
317320
baseline=None, # CLI mode doesn't use baseline
321+
report_dir=report_dir,
322+
output=output,
323+
timeout=timeout,
324+
verbose=verbose,
318325
)
319326

320327

@@ -391,7 +398,6 @@ def _get_dataset_format(config: BenchmarkConfig, dataset_path: Path) -> str:
391398

392399

393400
def _run_benchmark(
394-
args: argparse.Namespace,
395401
config: BenchmarkConfig,
396402
collect_responses: bool,
397403
test_mode: TestMode,
@@ -440,12 +446,19 @@ def _run_benchmark(
440446
# Load tokenizer if model name is provided
441447
# Priority: CLI args (offline/online modes) > config submission_ref (from-config mode)
442448
tokenizer = None
443-
model_name = getattr(args, "model", None)
449+
model_name = config.model_params.name
444450
if not model_name and config.submission_ref:
445451
model_name = config.submission_ref.model
446452
if not model_name and config.model_params.name:
447453
model_name = config.model_params.name
448454

455+
if config.report_dir:
456+
report_dir = Path(config.report_dir)
457+
report_dir.mkdir(parents=True, exist_ok=True)
458+
config.to_yaml_file(report_dir / "config.yaml")
459+
460+
max_tokens = config.model_params.max_new_tokens
461+
449462
if model_name:
450463
try:
451464
logger.info(f"Loading tokenizer for model: {model_name}")
@@ -460,18 +473,14 @@ def _run_benchmark(
460473
# Throw exception if no model name is provided
461474
raise InputValidationError("No model name provided")
462475

463-
# Get report path if specified
464-
report_path = getattr(args, "report_path", None)
465-
if report_path:
466-
logger.info(f"Report will be saved to: {report_path}")
467-
468476
# Get dataset - from CLI or from config
469477
# TODO: Dataset Logic is not yet fully implemented
470-
dataset_path = _get_dataset_path(args, config)
478+
# dataset_path = _get_dataset_path(args, config)
479+
dataset_path = config.datasets[0].path
471480

472481
# Load dataset using factory
473482
dataset_format = _get_dataset_format(config, dataset_path)
474-
logger.info(f"Loading: {dataset_path.name} (format: {dataset_format})")
483+
logger.info(f"Loading: {dataset_path} (format: {dataset_format})")
475484

476485
# Determine if streaming should be enabled based on config
477486
streaming_mode = config.model_params.streaming
@@ -500,10 +509,17 @@ def _run_benchmark(
500509
dataset_path,
501510
format=dataset_format,
502511
key_maps=key_maps,
503-
metadata={"model": model_name, "stream": enable_streaming},
512+
metadata={
513+
"model": model_name,
514+
"stream": enable_streaming,
515+
"max_completion_tokens": max_tokens,
516+
},
504517
)
505518
dataloader.load()
506519
logger.info(f"Loaded {dataloader.num_samples()} samples")
520+
except FileNotFoundError as e:
521+
logger.error(f"Dataset file not found: {dataset_path}")
522+
raise InputValidationError(f"Dataset file not found: {dataset_path}") from e
507523
except NotImplementedError as e:
508524
logger.error(f"Dataset format not supported: {dataset_format}")
509525
raise SetupError(str(e)) from e
@@ -550,20 +566,17 @@ def _run_benchmark(
550566
# Create endpoint client
551567
endpoint = config.endpoint_config.endpoint
552568
num_workers = config.settings.client.workers
553-
max_concurrency = config.settings.client.max_concurrency
554569

555570
logger.info(f"Connecting: {endpoint}")
556-
logger.info(
557-
f"Client config: workers={num_workers}, max_concurrency={max_concurrency if max_concurrency > 0 else 'unlimited'}"
558-
)
571+
logger.info(f"Client config: workers={num_workers}")
559572

560573
tmp_dir = tempfile.mkdtemp(prefix="inference_endpoint_")
561574

562575
try:
563576
http_config = HTTPClientConfig(
564577
endpoint_url=urljoin(endpoint, "/v1/chat/completions"),
565578
num_workers=num_workers,
566-
max_concurrency=max_concurrency,
579+
max_concurrency=-1, # unlimited
567580
)
568581
aiohttp_config = AioHttpConfig()
569582
zmq_config = ZMQConfig(
@@ -595,9 +608,9 @@ def _run_benchmark(
595608
scheduler,
596609
name="cli_benchmark",
597610
stop_sample_issuer_on_test_end=False,
598-
report_path=report_path,
611+
report_dir=config.report_dir,
599612
tokenizer_override=tokenizer,
600-
max_shutdown_timeout_s=args.timeout if args.timeout else None,
613+
max_shutdown_timeout_s=config.timeout if config.timeout else None,
601614
)
602615

603616
# Wait for test end with ability to interrupt
@@ -629,14 +642,14 @@ def signal_handler(signum, frame):
629642

630643
if response_collector.errors:
631644
logger.warning(f"Errors: {len(response_collector.errors)}")
632-
if args.verbose:
645+
if config.verbose:
633646
for error in response_collector.errors[:3]:
634647
logger.warning(f" {error}")
635648
if len(response_collector.errors) > 3:
636649
logger.warning(f" ... +{len(response_collector.errors) - 3} more")
637650

638651
# Save results if requested
639-
if hasattr(args, "output") and args.output:
652+
if config.output:
640653
try:
641654
results = {
642655
"config": {
@@ -660,9 +673,9 @@ def signal_handler(signum, frame):
660673
if response_collector.errors:
661674
results["errors"] = response_collector.errors
662675

663-
with open(args.output, "w") as f:
676+
with open(config.output, "w") as f:
664677
json.dump(results, f, indent=2)
665-
logger.info(f"Saved: {args.output}")
678+
logger.info(f"Saved: {config.output}")
666679
except Exception as e:
667680
logger.error(f"Save failed: {e}")
668681

@@ -685,5 +698,5 @@ def signal_handler(signum, frame):
685698
http_client.shutdown()
686699
shutil.rmtree(tmp_dir, ignore_errors=True)
687700
except Exception as e:
688-
if args.verbose:
701+
if config.verbose:
689702
logger.warning(f"Cleanup error: {e}")

src/inference_endpoint/config/schema.py

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -230,14 +230,12 @@ class LoadPattern(BaseModel):
230230
class ClientSettings(BaseModel):
231231
"""HTTP client configuration.
232232
233-
Only workers and max_concurrency are required to configure the client.
233+
Only workers are required to configure the client.
234234
Timeout is handled by the HTTP client internally.
235235
236-
Note: max_concurrency = -1 means unlimited (no semaphore limit).
237236
"""
238237

239238
workers: int = 4
240-
max_concurrency: int = -1 # -1 = unlimited (default for CLI and YAML)
241239

242240

243241
class Settings(BaseModel):
@@ -321,6 +319,10 @@ class BenchmarkConfig(BaseModel):
321319
settings: Settings = Field(default_factory=Settings)
322320
metrics: Metrics = Field(default_factory=Metrics)
323321
endpoint_config: EndpointConfig = Field(default_factory=EndpointConfig)
322+
output: Path | None = None
323+
report_dir: Path | None = None
324+
timeout: int | None = None
325+
verbose: bool = False
324326

325327
@classmethod
326328
def from_yaml_file(cls, path: Path) -> BenchmarkConfig:
@@ -470,27 +472,6 @@ def validate_client_settings(self) -> None:
470472
f"workers must be >= 1, got {self.settings.client.workers}"
471473
)
472474

473-
# max_concurrency: -1 means unlimited, otherwise must be >= 1
474-
if (
475-
self.settings.client.max_concurrency < -1
476-
or self.settings.client.max_concurrency == 0
477-
):
478-
raise ValueError(
479-
f"max_concurrency must be -1 (unlimited) or >= 1, got {self.settings.client.max_concurrency}"
480-
)
481-
482-
# Ensure max_concurrency can handle target_concurrency if set
483-
target_concurrency = self.settings.load_pattern.target_concurrency
484-
max_concurrency = self.settings.client.max_concurrency
485-
486-
if (
487-
target_concurrency is not None and max_concurrency > 0
488-
): # Skip check if unlimited (-1)
489-
if max_concurrency < target_concurrency:
490-
raise ValueError(
491-
f"max_concurrency ({max_concurrency}) must be >= target_concurrency ({target_concurrency})"
492-
)
493-
494475
def validate_runtime_settings(self) -> None:
495476
"""Validate runtime settings are reasonable.
496477
@@ -579,7 +560,7 @@ def create_default_config(cls, test_type: TestType) -> BenchmarkConfig:
579560
scheduler_random_seed=42,
580561
dataloader_random_seed=42,
581562
),
582-
client=ClientSettings(workers=4, max_concurrency=-1),
563+
client=ClientSettings(workers=4),
583564
),
584565
model_params=ModelParams(temperature=0.7, max_new_tokens=1024),
585566
metrics=Metrics(),
@@ -601,7 +582,7 @@ def create_default_config(cls, test_type: TestType) -> BenchmarkConfig:
601582
scheduler_random_seed=42,
602583
dataloader_random_seed=42,
603584
),
604-
client=ClientSettings(workers=4, max_concurrency=-1),
585+
client=ClientSettings(workers=4),
605586
),
606587
model_params=ModelParams(temperature=0.7, max_new_tokens=1024),
607588
metrics=Metrics(),

src/inference_endpoint/config/templates/concurrency_template.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ settings:
3131

3232
client:
3333
workers: 4
34-
max_concurrency: -1 # -1 = unlimited # Should exceed/match target_concurrency
3534

3635
metrics:
3736
collect:

src/inference_endpoint/config/templates/eval_template.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ settings:
2525

2626
client:
2727
workers: 4
28-
max_concurrency: -1 # -1 = unlimited
2928

3029
metrics:
3130
collect:

src/inference_endpoint/config/templates/offline_template.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ settings:
3030

3131
client:
3232
workers: 4
33-
max_concurrency: -1 # -1 = unlimited
3433

3534
metrics:
3635
collect:

src/inference_endpoint/config/templates/online_template.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ settings:
3030

3131
client:
3232
workers: 4
33-
max_concurrency: -1 # -1 = unlimited
3433

3534
metrics:
3635
collect:

src/inference_endpoint/config/templates/submission_template.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ settings:
5656

5757
client:
5858
workers: 4
59-
max_concurrency: -1 # -1 = unlimited
6059

6160
metrics:
6261
collect:

0 commit comments

Comments
 (0)