Skip to content

Commit c810acd

Browse files
nv-alichengclaude
and committed
refactor: generalize ZMQ pub/sub over message type via MessageCodec
Replace EventRecord-specific publisher/subscriber classes with generic ZmqMessagePublisher[T] / ZmqMessageSubscriber[T] parameterized by a MessageCodec[T] Protocol. EventRecordCodec preserves existing wire format and decode-error wrapping behavior. Sets up the generic transport that the upcoming MetricsSnapshot publisher will reuse. - protocol.py: drop EventRecordPublisher/Subscriber ABCs; add MessageCodec, MessagePublisher[T], MessageSubscriber[T]. - pubsub.py: rewrite as ZmqMessagePublisher[T]/ZmqMessageSubscriber[T]; expose sndhwm/linger/conflate so future callers (e.g. live snapshots) can choose drop-old vs. delivery-guarantee semantics. - record.py: add EventRecordCodec next to encode/decode helpers. - Update EventPublisherService, EventLoggerService, MetricsAggregatorService and tests to use the generic classes. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent b59c56a commit c810acd

9 files changed

Lines changed: 231 additions & 138 deletions

File tree

src/inference_endpoint/async_utils/event_publisher.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818

1919
from inference_endpoint.async_utils.loop_manager import LoopManager
2020
from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext
21-
from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqEventRecordPublisher
21+
from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqMessagePublisher
22+
from inference_endpoint.core.record import EventRecord, EventRecordCodec
2223

2324

24-
class EventPublisherService(ZmqEventRecordPublisher):
25+
class EventPublisherService(ZmqMessagePublisher[EventRecord]):
2526
"""Publisher for publishing event records over ZMQ PUB socket.
2627
27-
Wraps ZmqEventRecordPublisher with LoopManager integration and
28+
Wraps ZmqMessagePublisher[EventRecord] with LoopManager integration and
2829
auto-generated socket names.
2930
"""
3031

@@ -44,7 +45,7 @@ def __init__(
4445
synchronization mechanism (e.g., ENDED as a stop signal).
4546
isolated_event_loop: If True, runs on a separate event loop thread.
4647
send_threshold: Minimum number of buffered records before an
47-
automatic flush is triggered. See ZmqEventRecordPublisher.
48+
automatic flush is triggered. See ZmqMessagePublisher.
4849
"""
4950
if extra_eager:
5051
loop = None
@@ -54,6 +55,7 @@ def __init__(
5455
loop = LoopManager().default_loop
5556
self.socket_name = f"ev_pub_{uuid.uuid4().hex[:8]}"
5657
super().__init__(
58+
EventRecordCodec(),
5759
self.socket_name,
5860
managed_zmq_context,
5961
loop=loop,

src/inference_endpoint/async_utils/services/event_logger/__main__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@
2929

3030
from inference_endpoint.async_utils.loop_manager import LoopManager
3131
from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext
32-
from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqEventRecordSubscriber
32+
from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqMessageSubscriber
3333
from inference_endpoint.async_utils.transport.zmq.ready_check import send_ready_signal
3434
from inference_endpoint.core.record import (
3535
EventRecord,
36+
EventRecordCodec,
3637
SessionEventType,
3738
)
3839
from inference_endpoint.utils.logging import setup_logging
@@ -52,7 +53,7 @@
5253
_WRITER_REGISTRY["sql"] = SQLWriter
5354

5455

55-
class EventLoggerService(ZmqEventRecordSubscriber):
56+
class EventLoggerService(ZmqMessageSubscriber[EventRecord]):
5657
"""Event logger service for logging event records.
5758
5859
When SessionEventType.ENDED is received (topic 'session.ended'), the service writes
@@ -69,7 +70,7 @@ def __init__(
6970
shutdown_event: asyncio.Event | None = None,
7071
**kwargs,
7172
):
72-
super().__init__(*args, **kwargs)
73+
super().__init__(EventRecordCodec(), *args, **kwargs)
7374
self._shutdown_received = False
7475
self._shutdown_event = shutdown_event
7576

src/inference_endpoint/async_utils/services/metrics_aggregator/aggregator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@
2222
from enum import Enum
2323

2424
from inference_endpoint.async_utils.transport.zmq.pubsub import (
25-
ZmqEventRecordSubscriber,
25+
ZmqMessageSubscriber,
2626
)
2727
from inference_endpoint.core.record import (
2828
ErrorEventType,
2929
EventRecord,
30+
EventRecordCodec,
3031
SampleEventType,
3132
SessionEventType,
3233
)
@@ -81,7 +82,7 @@ class MetricCounterKey(str, Enum):
8182
)
8283

8384

84-
class MetricsAggregatorService(ZmqEventRecordSubscriber):
85+
class MetricsAggregatorService(ZmqMessageSubscriber[EventRecord]):
8586
"""Subscribes to EventRecords and computes per-sample metrics in real time.
8687
8788
The aggregator is a thin event router. All state management, trigger
@@ -99,7 +100,7 @@ def __init__(
99100
shutdown_event: asyncio.Event | None = None,
100101
**kwargs,
101102
):
102-
super().__init__(*args, **kwargs)
103+
super().__init__(EventRecordCodec(), *args, **kwargs)
103104
self._kv_store = kv_store
104105
self._tokenize_pool = tokenize_pool
105106
self._shutdown_event = shutdown_event

src/inference_endpoint/async_utils/transport/protocol.py

Lines changed: 92 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,14 @@
2626
from abc import ABC, abstractmethod
2727
from collections.abc import AsyncIterator
2828
from contextlib import asynccontextmanager
29-
from typing import Any, Protocol, runtime_checkable
29+
from typing import Any, Generic, Protocol, TypeVar, runtime_checkable
3030

3131
import msgspec
3232
from pydantic import BaseModel, ConfigDict, Field
3333

34-
from inference_endpoint.core.record import (
35-
ErrorEventType,
36-
EventRecord,
37-
decode_event_record,
38-
encode_event_record,
39-
)
40-
from inference_endpoint.core.types import ErrorData, Query, QueryResult, StreamChunk
34+
from inference_endpoint.core.types import Query, QueryResult, StreamChunk
35+
36+
T = TypeVar("T")
4137

4238

4339
class TransportConfig(BaseModel, ABC):
@@ -235,54 +231,76 @@ def cleanup(self) -> None:
235231
pass
236232

237233

238-
class EventRecordPublisher(ABC):
239-
"""Abstract base class for publishing event records over a transport."""
234+
class MessageCodec(Protocol[T]):
235+
"""Encode/decode policy for a single message type on the pub/sub layer.
236+
237+
The codec is the only type-specific surface in the pub/sub stack. All
238+
transport machinery (ZmqMessagePublisher / ZmqMessageSubscriber) operates
239+
on (topic_bytes, payload_bytes); the codec is what binds those bytes to
240+
a concrete Python type T.
241+
"""
242+
243+
def encode(self, item: T) -> tuple[bytes, bytes]:
244+
"""Return (topic, payload). topic must be exactly TOPIC_FRAME_SIZE bytes."""
245+
...
246+
247+
def decode(self, payload: bytes) -> T:
248+
"""Decode payload back to T. May raise; the caller routes failures
249+
through on_decode_error."""
250+
...
251+
252+
def on_decode_error(self, payload: bytes, exc: Exception) -> T | None:
253+
"""Fallback for malformed payloads. Return a sentinel item or None
254+
to drop the message."""
255+
...
256+
257+
258+
class MessagePublisher(ABC, Generic[T]):
259+
"""Abstract base for publishing typed messages over a transport.
260+
261+
Subclasses implement send(topic, payload) and close(). publish() is
262+
generic over T via the codec.
263+
"""
240264

241265
def __init__(
242266
self,
267+
codec: MessageCodec[T],
243268
bind_address: str,
244269
loop: asyncio.AbstractEventLoop | None = None,
245270
):
246-
"""Creates a new EventRecordPublisher.
271+
"""Creates a new MessagePublisher.
247272
248273
Args:
249-
bind_address: The address to bind the publisher to. This can be an IPC or TCP socket address.
250-
loop: The event loop to use for the publisher. If not provided, it is assumed that the publisher
251-
should always execute eagerly and will be blocking. This means that the call to `.publish()`
252-
will always be called immediately and the current loop and thread will block until the message
253-
is sent.
274+
codec: Encode policy. Required because turning T into wire bytes
275+
is the only type-specific operation; injecting it is the
276+
whole point of generalization.
277+
bind_address: IPC or TCP socket address to bind to.
278+
loop: Event loop to register async writes on. If None, send is
279+
eager/blocking — used by callers that publish before a loop
280+
is running (e.g. service startup).
254281
"""
282+
self._codec = codec
255283
self.bind_address = bind_address
256284
self.loop = loop
257285
self.is_closed: bool = False
258286

259-
def publish(self, event_record: EventRecord) -> None:
260-
"""Publish the event record on the bound address.
261-
262-
Args:
263-
event_record: The event record to publish.
264-
"""
287+
def publish(self, item: T) -> None:
288+
"""Encode item via the codec and send."""
265289
if self.is_closed:
266290
return
267-
268-
topic, payload = encode_event_record(event_record)
291+
topic, payload = self._codec.encode(item)
269292
self.send(topic, payload)
270293

271294
@abstractmethod
272295
def send(self, topic: bytes, payload: bytes) -> None:
273-
"""Send the message via the implemented transport layer.
274-
275-
Args:
276-
topic: The topic of the message.
277-
payload: The payload of the message.
278-
"""
279-
raise NotImplementedError("Subclasses must implement this method.")
296+
"""Send raw frame via the implemented transport layer."""
297+
raise NotImplementedError
280298

281299
def flush(self) -> None: # noqa: B027 — intentionally non-abstract
282300
"""Force-send any buffered records.
283301
284302
Unbuffered implementations need no override. Buffered subclasses
285-
(e.g., ZmqEventRecordPublisher) override this to drain their buffer.
303+
(e.g. ZmqMessagePublisher) override this to drain their buffer.
286304
"""
287305

288306
@abstractmethod
@@ -291,34 +309,39 @@ def close(self) -> None:
291309
292310
Implementations must flush any buffered records before closing.
293311
"""
294-
raise NotImplementedError("Subclasses must implement this method.")
312+
raise NotImplementedError
295313

296314

297-
class EventRecordSubscriber(ABC):
298-
"""Abstract base class for subscribing to event records over a transport."""
315+
class MessageSubscriber(ABC, Generic[T]):
316+
"""Abstract base for subscribing to typed messages over a transport.
317+
318+
Subclasses implement receive() (raw bytes from socket) and process()
319+
(handle decoded items). _on_readable wires them together using the
320+
codec.
321+
"""
299322

300323
def __init__(
301324
self,
325+
codec: MessageCodec[T],
302326
connect_address: str,
303327
loop: asyncio.AbstractEventLoop,
304328
topics: list[str] | None = None,
305329
):
306-
"""Creates a new EventRecordSubscriber.
330+
"""Creates a new MessageSubscriber.
307331
308-
Initializing the subscriber does NOT start processing. The subscriber connects
309-
to the address and subscribes to topics, but the socket reader is only added
310-
when .start() is called. This allows bookkeeping or other setup before
311-
listening. Each subscriber should use its own event loop (e.g. from LoopManager),
312-
not shared with the publisher.
313-
314-
It is mandatory for subscriber implementations to set the `_fd` attribute to the file
315-
descriptor of the socket to add an asyncio reader to the event loop.
332+
Initializing does NOT start processing — call .start() to add the
333+
socket reader to the loop. Subclasses must set ``self._fd`` to the
334+
socket file descriptor before .start() is called.
316335
317336
Args:
318-
connect_address: The address to connect the subscriber to. This can be an IPC or TCP socket address.
319-
loop: The event loop to use for the subscriber (typically a dedicated loop per subscriber).
320-
topics: The topics to subscribe to. If not provided, it is assumed that the subscriber should subscribe to all topics.
337+
codec: Decode policy. Required for the same reason as in
338+
MessagePublisher.
339+
connect_address: IPC or TCP socket address to connect to.
340+
loop: Dedicated loop for this subscriber (typically from
341+
LoopManager — not shared with the publisher).
342+
topics: Topics to subscribe to. None means subscribe to all.
321343
"""
344+
self._codec = codec
322345
self.connect_address = connect_address
323346
self.topics = topics
324347
self.loop = loop
@@ -328,31 +351,22 @@ def __init__(
328351

329352
@abstractmethod
330353
def receive(self) -> bytes | None:
331-
"""Receive data from the transport.
332-
333-
Should receive data from the socket and return a bytes object that should be able
334-
to be decoded into an EventRecord.
354+
"""Receive a single payload (no topic prefix) from the transport.
335355
336-
If the received data is malformed, this method should return None.
337-
338-
For the specific case that the transport is not readable or the underlying socket is busy
339-
(such as when an EAGAIN error is raised), this method should raise a StopIteration exception.
356+
Returns None for malformed-but-recognized frames. Raises
357+
StopIteration when the transport has nothing more to deliver right
358+
now (EAGAIN).
340359
"""
341-
raise NotImplementedError("Subclasses must implement this method.")
360+
raise NotImplementedError
342361

343362
@abstractmethod
344-
async def process(self, records: list[EventRecord]) -> None:
345-
"""Process a list of EventRecords.
346-
347-
Called asynchronously (scheduled via create_task) so that heavy work does not
348-
block the socket read path. Implementations should be async.
349-
"""
350-
raise NotImplementedError("Subclasses must implement this method.")
363+
async def process(self, items: list[T]) -> None:
364+
"""Handle a batch of decoded items. Called as an asyncio task so
365+
heavy work does not block the socket read path."""
366+
raise NotImplementedError
351367

352368
def close(self) -> None:
353-
"""Close the subscriber and release resources (e.g. remove reader, close socket).
354-
Should be idempotent; safe to call multiple times. Call when the session has ended.
355-
"""
369+
"""Close the subscriber. Idempotent."""
356370
if self.loop is not None and self._fd is not None:
357371
try:
358372
self.loop.remove_reader(self._fd)
@@ -361,46 +375,32 @@ def close(self) -> None:
361375
pass
362376

363377
def _on_readable(self) -> None:
364-
"""Drain socket, decode records, and schedule process() as an async task."""
378+
"""Drain socket, decode via codec, and schedule process()."""
365379
if self.is_closed:
366380
return
367381

368-
records: list[EventRecord] = []
382+
items: list[T] = []
369383
try:
370384
while True:
371385
payload = self.receive()
372386
if payload is None:
373387
continue
374-
375-
# Attempt decode
376388
try:
377-
event_record = decode_event_record(payload)
389+
items.append(self._codec.decode(payload))
378390
except msgspec.DecodeError as e:
379-
event_record = EventRecord(
380-
event_type=ErrorEventType.GENERIC,
381-
data=ErrorData(
382-
error_type="msgspec.DecodeError",
383-
error_message=str(e),
384-
),
385-
)
386-
records.append(event_record)
391+
fallback = self._codec.on_decode_error(payload, e)
392+
if fallback is not None:
393+
items.append(fallback)
387394
except StopIteration:
388-
# No more messages to receive right now
389395
pass
390396
finally:
391-
if records:
392-
# Schedule process() so it does not block the socket read path
393-
self.loop.create_task(self.process(records))
397+
if items:
398+
self.loop.create_task(self.process(items))
394399

395400
def start(self) -> None:
396-
"""Start the subscriber: add the socket reader to the loop and begin processing.
397-
398-
Call this after any setup (e.g. when the session is about to start). Before
399-
start() is called, no messages are received.
400-
"""
401+
"""Add the socket reader to the loop and begin processing."""
401402
if self._fd is None:
402403
raise ValueError("Subscriber not initialized with a file descriptor")
403-
404404
self.loop.add_reader(self._fd, self._on_readable)
405405

406406

@@ -410,6 +410,7 @@ def start(self) -> None:
410410
"SenderTransport",
411411
"WorkerConnector",
412412
"WorkerPoolTransport",
413-
"EventRecordPublisher",
414-
"EventRecordSubscriber",
413+
"MessageCodec",
414+
"MessagePublisher",
415+
"MessageSubscriber",
415416
]

0 commit comments

Comments (0)