Skip to content
Merged
6 changes: 6 additions & 0 deletions examples/01_LocalBenchmark/run_tinyllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ def issue(self, sample):
action="store_true",
help="Dump the events to a CSV file",
)
parser.add_argument(
"--total-sample-count",
type=int,
help="Total number of samples to issue (Debug only)",
)
args = parser.parse_args()

# Set up progress bar hook to monitor sample completion
Expand All @@ -205,6 +210,7 @@ def issue(self, sample):
min_sample_count=100, # Minimum samples to issue
min_duration_ms=10 * 1000, # 10 seconds minimum
max_duration_ms=5 * 60 * 1000, # 5 minutes maximum
total_sample_count=args.total_sample_count if args.total_sample_count else None,
ds_subset_size=dataloader.num_samples(), # Use all available samples
)

Expand Down
15 changes: 13 additions & 2 deletions src/inference_endpoint/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ class QueryStatus(Enum):
CANCELLED = "cancelled"


_OUTPUT_DICT_TYPE = dict[str, str | list[str]]
_OUTPUT_RESULT_TYPE = str | tuple[str, ...] | _OUTPUT_DICT_TYPE | None


class Query(msgspec.Struct, kw_only=True):
"""Represents a single inference query to be sent to an endpoint.

Expand Down Expand Up @@ -105,10 +109,10 @@ class QueryResult(msgspec.Struct, tag="query_result", kw_only=True, frozen=True)
"""

id: str = ""
response_output: str | tuple[str, ...] | None = None
response_output: _OUTPUT_RESULT_TYPE = None
metadata: dict[str, Any] = msgspec.field(default_factory=dict)
error: str | None = None
completed_at: float = msgspec.UNSET
completed_at: int = msgspec.UNSET

def __post_init__(self):
"""Set completion timestamp automatically.
Expand All @@ -122,6 +126,9 @@ def __post_init__(self):
"""
# Disallow user setting completed_at time to prevent cheating.
# Timestamp must be generated internally
# Note that this will also be regenerated during encode+decode. This is
# intentional, since timestamps in child and parent processes may be different
# due to how monotonic_ns works.
msgspec.structs.force_setattr(self, "completed_at", time.monotonic_ns())

# A list can be passed on, but we need to convert it to a tuple to maintain immutability,
Expand All @@ -130,6 +137,10 @@ def __post_init__(self):
msgspec.structs.force_setattr(
self, "response_output", tuple(self.response_output)
)
elif isinstance(self.response_output, dict):
for k, v in self.response_output.items():
if isinstance(v, list):
self.response_output[k] = tuple(v)


class StreamChunk(msgspec.Struct, tag="stream_chunk", kw_only=True):
Expand Down
107 changes: 71 additions & 36 deletions src/inference_endpoint/endpoint_client/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,50 +404,85 @@ async def _iter_sse_lines(
async def _handle_streaming_request(self, query: Query) -> None:
"""Handle streaming response."""
async for response in self._make_http_request(query):
accumulated_content = []
output_chunks = []
reasoning_chunks = []
first_chunk_sent = False

# Process SSE stream - yields batches of chunks
async for chunk_batch in self._iter_sse_lines(response):
accumulated_content.extend(chunk_batch)

# Determine which chunks to send: all or just first
chunks_to_send = (
chunk_batch
if self.http_config.stream_all_chunks
else chunk_batch[:1]
if not first_chunk_sent
else []
)

# Send chunks
for content in chunks_to_send:
await self._response_socket.send(
StreamChunk(
id=query.id,
response_chunk=content,
is_complete=False,
metadata={
"first_chunk": not first_chunk_sent,
"final_chunk": False,
},
)
output_delta = []
reasoning_delta = []
for delta in chunk_batch:
if delta.content:
output_delta.append(delta.content)
elif delta.reasoning:
reasoning_delta.append(delta.reasoning)
else:
logger.debug("empty SSE delta")
continue

for delta_batch, accumulator in (
(reasoning_delta, reasoning_chunks),
(output_delta, output_chunks),
):
if not delta_batch:
continue
accumulator.extend(delta_batch)

# Determine which chunks to send: all or just first
chunks_to_send = (
delta_batch
if self.http_config.stream_all_chunks
else delta_batch[:1]
if not first_chunk_sent
else []
)
first_chunk_sent = True
if self.http_config.record_worker_events:
EventRecorder.record_event(
SampleEvent.ZMQ_RESPONSE_SENT,
time.monotonic_ns(),
sample_uuid=query.id,
assert_active=True,

# Send chunks
for content in chunks_to_send:
await self._response_socket.send(
StreamChunk(
id=query.id,
response_chunk=content,
is_complete=False,
metadata={
"first_chunk": not first_chunk_sent,
"final_chunk": False,
},
)
)
first_chunk_sent = True
if self.http_config.record_worker_events:
EventRecorder.record_event(
SampleEvent.ZMQ_RESPONSE_SENT,
time.monotonic_ns(),
sample_uuid=query.id,
assert_active=True,
)

# Send final complete response
response_output = []
if accumulated_content:
response_output.append(accumulated_content[0])
if len(accumulated_content) > 1:
response_output.append("".join(accumulated_content[1:]))
if reasoning_chunks:
# If there are reasoning chunks, then the first chunk received
# is the first reasoning chunk. The rest of the reasoning chunks,
# as well as the output chunks can be joined together.
resp_reasoning = [reasoning_chunks[0]]
if len(reasoning_chunks) > 1:
resp_reasoning.append("".join(reasoning_chunks[1:]))
response_output = {
"output": "".join(output_chunks),
"reasoning": resp_reasoning,
}
elif output_chunks:
# If there are only output chunks, the first chunk is the used for
# TTFT calculations. The rest are joined together.
resp_output = [output_chunks[0]]
if len(output_chunks) > 1:
resp_output.append("".join(output_chunks[1:]))
response_output = {
"output": resp_output,
}
else:
response_output = {"output": []}

await self._response_socket.send(
QueryResult(
Expand Down
1 change: 1 addition & 0 deletions src/inference_endpoint/load_generator/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class SessionEvent(Event):
LOADGEN_ISSUE_CALLED = "loadgen_issue_called"
LOADGEN_STOP = "loadgen_stop"
LOADGEN_DATA_LOAD = "loadgen_data_load"
STOP_PERFORMANCE_TRACKING = "stop_performance_tracking"
ERROR = "error"


Expand Down
8 changes: 7 additions & 1 deletion src/inference_endpoint/load_generator/load_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def __init__(
self,
sample_issuer: SampleIssuer,
dataloader: DataLoader,
name: str | None = None,
):
"""Initialize load generator with required dependencies.

Expand All @@ -133,6 +134,8 @@ def __init__(
"""
self.sample_issuer = sample_issuer
self.dataloader = dataloader
self.name = name
self.uuid_to_index_map = {}

@abstractmethod
def __next__(self) -> tuple[Sample, int]:
Expand Down Expand Up @@ -160,6 +163,7 @@ def __next__(self) -> tuple[Sample, int]:

def __iter__(self):
"""Return self as an iterator."""
self.uuid_to_index_map = {}
Comment thread
nv-alicheng marked this conversation as resolved.
return self

def load_sample_data(self, sample_index: int, sample_uuid: str) -> Any:
Expand Down Expand Up @@ -291,6 +295,8 @@ def __next__(self) -> IssuedSample:
sample = Sample(None) # Create sample object first to generate uuid
sample.data = self.load_sample_data(s_idx, sample.uuid)

self.uuid_to_index_map[sample.uuid] = s_idx

scheduled_issue_timestamp_ns = self.last_issue_timestamp_ns + delay_ns
while (now := time.monotonic_ns()) < scheduled_issue_timestamp_ns:
sleep_ns(scheduled_issue_timestamp_ns - now)
Expand All @@ -303,4 +309,4 @@ def __iter__(self):
"SchedulerBasedLoadGenerator can only be iterated over once"
)
self._iterator = iter(self.scheduler)
return self
return super().__iter__()
4 changes: 2 additions & 2 deletions src/inference_endpoint/load_generator/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def stream_chunk_complete(self, chunk: StreamChunk) -> None:
SampleEvent.FIRST_CHUNK,
timestamp_ns,
sample_uuid=chunk.id,
output=chunk.response_chunk,
data=chunk.response_chunk,
)
hooks = self.first_chunk_hooks
else:
Expand Down Expand Up @@ -186,7 +186,7 @@ def query_result_complete(self, result: QueryResult) -> None:
SampleEvent.COMPLETE,
timestamp_ns,
sample_uuid=result.id,
output=result.response_output,
data=result.response_output,
)

for hook in self.complete_hooks:
Expand Down
68 changes: 46 additions & 22 deletions src/inference_endpoint/load_generator/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import logging
import os
import shutil
import threading
import time
import uuid
Expand Down Expand Up @@ -49,18 +48,20 @@ def __init__(
self.end_event = threading.Event()
self.thread = None

self.sample_uuid_map = {}
self.event_recorder = EventRecorder(
session_id=self.session_id, notify_idle=self.end_event
)

self.sample_uuid_map = None

@property
def is_running(self):
return self.thread is not None and self.thread.is_alive()

def _run_test(
self,
load_generator: LoadGenerator,
perf_test_generator: LoadGenerator,
accuracy_test_generators: dict[str, LoadGenerator] | None = None,
stop_sample_issuer_on_test_end: bool = True,
max_shutdown_timeout_s: float = 300.0,
report_dir: os.PathLike | None = None,
Expand All @@ -72,10 +73,20 @@ def _run_test(
EventRecorder.record_event(
SessionEvent.TEST_STARTED, time.monotonic_ns()
)
for issued_sample in load_generator:
# In the future, we'll want to push this to some thread or process that
# performs output verification / accuracy checks.
self.sample_uuid_map[issued_sample.sample.uuid] = issued_sample

for _ in perf_test_generator:
# Actual issue is done during next(generator). Nothing else to do here, just pass.
pass

EventRecorder.record_event(
SessionEvent.STOP_PERFORMANCE_TRACKING, time.monotonic_ns()
)

if accuracy_test_generators:
for _, generator in accuracy_test_generators.items():
for _ in generator:
# Actual issue is done during next(generator). Nothing else to do here, just pass.
pass
Comment thread
nv-alicheng marked this conversation as resolved.

self.event_recorder.should_check_idle = True
EventRecorder.record_event(
Expand All @@ -99,7 +110,7 @@ def _run_test(
raise e
finally:
if stop_sample_issuer_on_test_end:
load_generator.sample_issuer.shutdown()
perf_test_generator.sample_issuer.shutdown()
EventRecorder.record_event(SessionEvent.TEST_ENDED, time.monotonic_ns())

self.event_recorder.wait_for_writes()
Expand All @@ -124,17 +135,26 @@ def _run_test(
tokenizer = None
report = reporter.create_report(tokenizer)

# Consolidate UUID->index mappings
perf_name = (
perf_test_generator.name
if perf_test_generator.name
else "performance"
)
sample_idx_map = {
perf_name: perf_test_generator.uuid_to_index_map,
}
if accuracy_test_generators:
for default_name, generator in accuracy_test_generators.items():
name = generator.name if generator.name else default_name
sample_idx_map[name] = generator.uuid_to_index_map
self.sample_uuid_map = sample_idx_map

# Save to report directory if provided
if report_dir:
Path(report_dir).mkdir(parents=True, exist_ok=True)
report.to_json(save_to=Path(report_dir) / "result_summary.json")

# Copy over outputs for validation
shutil.copy(
self.event_recorder.outputs_path,
Path(report_dir) / "outputs.jsonl",
)

# Dump runtime settings to report directory
rt_settings_data = {
"min_duration_ms": self.runtime_settings.min_duration_ms,
Expand Down Expand Up @@ -163,6 +183,10 @@ def _run_test(
).decode("utf-8")
)

# Save the UUID mapping for output verification
with (Path(report_dir) / "sample_idx_map.json").open("w") as f:
f.write(orjson.dumps(self.sample_uuid_map).decode("utf-8"))

if dump_events_csv:
reporter.dump_to_csv(Path(report_dir) / "events.csv")

Expand Down Expand Up @@ -221,14 +245,14 @@ def start(
load_generator = load_generator_cls(sample_issuer, dataloader, *args)
session.thread = threading.Thread(
target=session._run_test,
args=(
load_generator,
stop_sample_issuer_on_test_end,
max_shutdown_timeout_s,
report_dir,
tokenizer_override,
dump_events_csv,
),
args=(load_generator,),
kwargs={
"stop_sample_issuer_on_test_end": stop_sample_issuer_on_test_end,
"max_shutdown_timeout_s": max_shutdown_timeout_s,
"report_dir": report_dir,
"tokenizer_override": tokenizer_override,
"dump_events_csv": dump_events_csv,
},
)
session.thread.start()
return session
Loading