Skip to content

Commit 2146b0e

Browse files
committed
Fix type annotation for msgspec serialization error, add tests to catch similar errors in the future
1 parent e12f20b commit 2146b0e

3 files changed

Lines changed: 820 additions & 2 deletions

File tree

src/inference_endpoint/core/types.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ class QueryStatus(Enum):
4848
CANCELLED = "cancelled"
4949

5050

51+
_OUTPUT_DICT_TYPE = dict[str, str | list[str]]
52+
_OUTPUT_RESULT_TYPE = str | tuple[str, ...] | _OUTPUT_DICT_TYPE | None
53+
54+
5155
class Query(msgspec.Struct, kw_only=True):
5256
"""Represents a single inference query to be sent to an endpoint.
5357
@@ -105,10 +109,10 @@ class QueryResult(msgspec.Struct, tag="query_result", kw_only=True, frozen=True)
105109
"""
106110

107111
id: str = ""
108-
response_output: str | tuple[str, ...] | None = None
112+
response_output: _OUTPUT_RESULT_TYPE = None
109113
metadata: dict[str, Any] = msgspec.field(default_factory=dict)
110114
error: str | None = None
111-
completed_at: float = msgspec.UNSET
115+
completed_at: int = msgspec.UNSET
112116

113117
def __post_init__(self):
114118
"""Set completion timestamp automatically.
@@ -122,6 +126,9 @@ def __post_init__(self):
122126
"""
123127
# Disallow user setting completed_at time to prevent cheating.
124128
# Timestamp must be generated internally
129+
# Note that this will also be regenerated during encode+decode. This is
130+
# intentional, since timestamps in child and parent processes may be different
131+
# due to how monotonic_ns works.
125132
msgspec.structs.force_setattr(self, "completed_at", time.monotonic_ns())
126133

127134
# A list can be passed on, but we need to convert it to a tuple to maintain immutability,

src/inference_endpoint/endpoint_client/worker.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,9 @@ async def _handle_streaming_request(self, query: Query) -> None:
462462

463463
# Send final complete response
464464
if reasoning_chunks:
465+
# If there are reasoning chunks, then the first chunk received
466+
# is the first reasoning chunk. The rest of the reasoning chunks,
467+
# as well as the output chunks, can be joined together.
465468
resp_reasoning = [reasoning_chunks[0]]
466469
if len(reasoning_chunks) > 1:
467470
resp_reasoning.append("".join(reasoning_chunks[1:]))
@@ -470,6 +473,8 @@ async def _handle_streaming_request(self, query: Query) -> None:
470473
"reasoning": resp_reasoning,
471474
}
472475
elif output_chunks:
476+
# If there are only output chunks, the first chunk is the one used for
477+
# TTFT calculations. The rest are joined together.
473478
resp_output = [output_chunks[0]]
474479
if len(output_chunks) > 1:
475480
resp_output.append("".join(output_chunks[1:]))

0 commit comments

Comments
 (0)