Skip to content

Commit ad19ae1

Browse files
committed
Small refactors and comments to explain return values and reduce code duplication
1 parent 2146b0e commit ad19ae1

3 files changed

Lines changed: 40 additions & 32 deletions

File tree

src/inference_endpoint/load_generator/load_generator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def __init__(
124124
self,
125125
sample_issuer: SampleIssuer,
126126
dataloader: DataLoader,
127+
name: str | None = None,
127128
):
128129
"""Initialize load generator with required dependencies.
129130
@@ -133,7 +134,7 @@ def __init__(
133134
"""
134135
self.sample_issuer = sample_issuer
135136
self.dataloader = dataloader
136-
137+
self.name = name
137138
self.uuid_to_index_map = {}
138139

139140
@abstractmethod

src/inference_endpoint/load_generator/session.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def _run_test(
7575
)
7676

7777
for _ in perf_test_generator:
78+
# Actual issue is done during next(generator). Nothing else to do here, just pass.
7879
pass
7980

8081
EventRecorder.record_event(
@@ -84,6 +85,7 @@ def _run_test(
8485
if accuracy_test_generators:
8586
for _, generator in accuracy_test_generators.items():
8687
for _ in generator:
88+
# Actual issue is done during next(generator). Nothing else to do here, just pass.
8789
pass
8890

8991
self.event_recorder.should_check_idle = True
@@ -134,11 +136,17 @@ def _run_test(
134136
report = reporter.create_report(tokenizer)
135137

136138
# Consolidate UUID->index mappings
139+
perf_name = (
140+
perf_test_generator.name
141+
if perf_test_generator.name
142+
else "performance"
143+
)
137144
sample_idx_map = {
138-
"performance": perf_test_generator.uuid_to_index_map,
145+
perf_name: perf_test_generator.uuid_to_index_map,
139146
}
140147
if accuracy_test_generators:
141-
for name, generator in accuracy_test_generators.items():
148+
for default_name, generator in accuracy_test_generators.items():
149+
name = generator.name if generator.name else default_name
142150
sample_idx_map[name] = generator.uuid_to_index_map
143151
self.sample_uuid_map = sample_idx_map
144152

src/inference_endpoint/metrics/reporter.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,15 @@ def display(
523523
fn("\n")
524524

525525

526+
def _output_sequence_to_str(output_sequence: str | list[str]) -> str | None:
527+
if isinstance(output_sequence, list):
528+
return "".join(output_sequence)
529+
elif isinstance(output_sequence, str):
530+
return output_sequence
531+
else:
532+
return None
533+
534+
526535
def output_sequence_from_data(
527536
data_bytes: bytes,
528537
join_chunks: bool = True,
@@ -553,51 +562,37 @@ def output_sequence_from_data(
553562
logging.warning("Failed to decode data bytes")
554563
return None, None
555564

556-
output_sequence = None
557-
reasoning_sequence = None
558-
565+
output, reasoning = None, None
559566
if isinstance(decoded_data, str):
560567
# If decoded value is a string, it's the output sequence
561-
output_sequence = decoded_data
568+
output = decoded_data
562569
elif isinstance(decoded_data, dict):
563570
# If decoded value is a dict, extract 'output' and optionally 'reasoning'
564571
if "output" not in decoded_data:
565572
logging.warning("Dictionary data missing required 'output' key")
566573
return None, None
567574

568575
# Extract output - can be string or list of strings
569-
output = decoded_data["output"]
570-
if isinstance(output, list):
571-
if join_chunks:
572-
output_sequence = "".join(output)
573-
else:
574-
output_sequence = output
575-
elif isinstance(output, str):
576-
output_sequence = output
577-
else:
576+
output = (
577+
_output_sequence_to_str(decoded_data["output"])
578+
if join_chunks
579+
else decoded_data["output"]
580+
)
581+
if output is None:
578582
logging.warning(f"Output field has unexpected type: {type(output)}")
579583
return None, None
580584

581585
# Extract reasoning if present - can be string or list of strings
582586
if "reasoning" in decoded_data:
583-
reasoning = decoded_data["reasoning"]
584-
if isinstance(reasoning, list):
585-
if join_chunks:
586-
reasoning_sequence = "".join(reasoning)
587-
else:
588-
reasoning_sequence = reasoning
589-
elif isinstance(reasoning, str):
590-
reasoning_sequence = reasoning
591-
else:
592-
logging.warning(
593-
f"Reasoning field has unexpected type: {type(reasoning)}"
594-
)
595-
# Continue with output_sequence, reasoning is optional
587+
reasoning = (
588+
_output_sequence_to_str(decoded_data["reasoning"])
589+
if join_chunks
590+
else decoded_data["reasoning"]
591+
)
596592
else:
597593
logging.warning(f"Decoded data has unexpected type: {type(decoded_data)}")
598594
return None, None
599-
600-
return output_sequence, reasoning_sequence
595+
return output, reasoning
601596

602597

603598
class MetricsReporter:
@@ -664,7 +659,8 @@ def stop_performance_tracking_timestamp_ns(self) -> float:
664659
"""Returns the timestamp_ns of the STOP_PERFORMANCE_TRACKING event.
665660
666661
This method is cached to prevent re-derivation. If the event is not found,
667-
returns positive infinity.
662+
returns positive infinity, since this indicates that the performance run is probably still
663+
running, or the test was killed before it could complete.
668664
669665
Returns:
670666
float: The timestamp_ns of STOP_PERFORMANCE_TRACKING event, or float('inf') if not found.
@@ -677,6 +673,9 @@ def stop_performance_tracking_timestamp_ns(self) -> float:
677673
""").fetchone()
678674

679675
if result is None:
676+
logging.warning(
677+
"No STOP_PERFORMANCE_TRACKING event found, performance run not yet complete"
678+
)
680679
return float("inf")
681680
return float(result[0])
682681

0 commit comments

Comments
 (0)