Commit ed3c742

nv-alicheng and claude authored
refactor: generalize ZMQ pub/sub over message type via MessageCodec (#300)
* refactor: generalize ZMQ pub/sub over message type via MessageCodec

  Replace EventRecord-specific publisher/subscriber classes with generic
  ZmqMessagePublisher[T] / ZmqMessageSubscriber[T] parameterized by a
  MessageCodec[T] Protocol. EventRecordCodec preserves the existing wire
  format and decode-error wrapping behavior. Sets up the generic transport
  that the upcoming MetricsSnapshot publisher will reuse.

  - protocol.py: drop EventRecordPublisher/Subscriber ABCs; add MessageCodec,
    MessagePublisher[T], MessageSubscriber[T].
  - pubsub.py: rewrite as ZmqMessagePublisher[T]/ZmqMessageSubscriber[T];
    expose sndhwm/linger/conflate so future callers (e.g. live snapshots)
    can choose drop-old vs. delivery-guarantee semantics.
  - record.py: add EventRecordCodec next to the encode/decode helpers.
  - Update EventPublisherService, EventLoggerService, MetricsAggregatorService
    and tests to use the generic classes.

* refactor: keep MessageSubscriber catch generic; narrow per-codec

  Per Gemini review on PR #300: catching only msgspec.DecodeError in
  MessageSubscriber._on_readable bakes the codec implementation into the
  supposedly-generic base class. A future codec backed by json, pickle, etc.
  raises different exception types and would bypass on_decode_error, crashing
  the reader.

  - protocol.py: widen the catch back to Exception so the base class makes
    no assumption about which decoder library a codec uses; drop the
    now-unused msgspec import.
  - record.py: tighten EventRecordCodec.on_decode_error to wrap only
    msgspec.DecodeError and re-raise other exceptions. Preserves behavior
    parity with the previous code (only malformed-payload errors become
    ErrorEventType.GENERIC records; programmer bugs in the decode path
    still surface).

* test: cover EventRecordCodec.on_decode_error branches

  Address PR #300 review feedback: on_decode_error has two distinct branches
  and neither was exercised. The re-raise branch in particular is the
  behavior MessageSubscriber._on_readable relies on to surface decode-path
  bugs: a non-DecodeError must propagate, otherwise it escapes the asyncio
  reader callback and silently de-registers the subscriber.

  - test_wraps_msgspec_decode_error_into_generic_error_record: forces a real
    msgspec.DecodeError through the codec's own decoder (not a
    hand-constructed exception), then asserts on_decode_error returns a
    wrapped EventRecord(ErrorEventType.GENERIC, ErrorData(...)).
  - test_reraises_non_decode_error: passes a ValueError and asserts it
    propagates unchanged.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 9ee015e commit ed3c742

10 files changed

Lines changed: 309 additions & 181 deletions


src/inference_endpoint/async_utils/event_publisher.py

Lines changed: 6 additions & 4 deletions
@@ -18,13 +18,14 @@
 
 from inference_endpoint.async_utils.loop_manager import LoopManager
 from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext
-from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqEventRecordPublisher
+from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqMessagePublisher
+from inference_endpoint.core.record import EventRecord, EventRecordCodec
 
 
-class EventPublisherService(ZmqEventRecordPublisher):
+class EventPublisherService(ZmqMessagePublisher[EventRecord]):
     """Publisher for publishing event records over ZMQ PUB socket.
 
-    Wraps ZmqEventRecordPublisher with LoopManager integration and
+    Wraps ZmqMessagePublisher[EventRecord] with LoopManager integration and
     auto-generated socket names.
     """
 
@@ -44,7 +45,7 @@ def __init__(
                 synchronization mechanism (e.g., ENDED as a stop signal).
             isolated_event_loop: If True, runs on a separate event loop thread.
             send_threshold: Minimum number of buffered records before an
-                automatic flush is triggered. See ZmqEventRecordPublisher.
+                automatic flush is triggered. See ZmqMessagePublisher.
         """
         if extra_eager:
             loop = None
@@ -54,6 +55,7 @@ def __init__(
             loop = LoopManager().default_loop
         self.socket_name = f"ev_pub_{uuid.uuid4().hex[:8]}"
         super().__init__(
+            EventRecordCodec(),
            self.socket_name,
            managed_zmq_context,
            loop=loop,
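The diff above shows the shape of the publisher-side change: the service no longer encodes records itself, it hands a codec to the generic base and publish() becomes type-agnostic. A minimal sketch of that split, where TupleCodec and RecordingPublisher are hypothetical stand-ins for EventRecordCodec and the ZMQ transport:

```python
from typing import Generic, TypeVar

T = TypeVar("T")


class MessagePublisher(Generic[T]):
    # Simplified version of the generic base from protocol.py:
    # publish() encodes via the injected codec, send() is transport-specific.
    def __init__(self, codec) -> None:
        self._codec = codec
        self.is_closed = False

    def publish(self, item: T) -> None:
        if self.is_closed:
            return
        topic, payload = self._codec.encode(item)
        self.send(topic, payload)

    def send(self, topic: bytes, payload: bytes) -> None:
        raise NotImplementedError


class TupleCodec:
    # Stand-in codec: items are (topic_str, body_str) pairs.
    def encode(self, item: tuple[str, str]) -> tuple[bytes, bytes]:
        return item[0].encode(), item[1].encode()


class RecordingPublisher(MessagePublisher[tuple[str, str]]):
    # Captures frames instead of writing to a ZMQ PUB socket.
    def __init__(self, codec) -> None:
        super().__init__(codec)
        self.frames: list[tuple[bytes, bytes]] = []

    def send(self, topic: bytes, payload: bytes) -> None:
        self.frames.append((topic, payload))


pub = RecordingPublisher(TupleCodec())
pub.publish(("session.ended", "{}"))
assert pub.frames == [(b"session.ended", b"{}")]
```

Swapping in a MetricsSnapshot codec later requires no change to the transport class, which is the reuse the commit message points at.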

src/inference_endpoint/async_utils/services/event_logger/__main__.py

Lines changed: 4 additions & 3 deletions
@@ -29,10 +29,11 @@
 
 from inference_endpoint.async_utils.loop_manager import LoopManager
 from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext
-from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqEventRecordSubscriber
+from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqMessageSubscriber
 from inference_endpoint.async_utils.transport.zmq.ready_check import send_ready_signal
 from inference_endpoint.core.record import (
     EventRecord,
+    EventRecordCodec,
     SessionEventType,
 )
 from inference_endpoint.utils.logging import setup_logging
@@ -52,7 +53,7 @@
 _WRITER_REGISTRY["sql"] = SQLWriter
 
 
-class EventLoggerService(ZmqEventRecordSubscriber):
+class EventLoggerService(ZmqMessageSubscriber[EventRecord]):
     """Event logger service for logging event records.
 
     When SessionEventType.ENDED is received (topic 'session.ended'), the service writes
@@ -69,7 +70,7 @@ def __init__(
         shutdown_event: asyncio.Event | None = None,
         **kwargs,
     ):
-        super().__init__(*args, **kwargs)
+        super().__init__(EventRecordCodec(), *args, **kwargs)
         self._shutdown_received = False
         self._shutdown_event = shutdown_event

src/inference_endpoint/async_utils/services/metrics_aggregator/aggregator.py

Lines changed: 4 additions & 3 deletions
@@ -22,11 +22,12 @@
 from enum import Enum
 
 from inference_endpoint.async_utils.transport.zmq.pubsub import (
-    ZmqEventRecordSubscriber,
+    ZmqMessageSubscriber,
 )
 from inference_endpoint.core.record import (
     ErrorEventType,
     EventRecord,
+    EventRecordCodec,
     SampleEventType,
     SessionEventType,
 )
@@ -81,7 +82,7 @@ class MetricCounterKey(str, Enum):
 )
 
 
-class MetricsAggregatorService(ZmqEventRecordSubscriber):
+class MetricsAggregatorService(ZmqMessageSubscriber[EventRecord]):
     """Subscribes to EventRecords and computes per-sample metrics in real time.
 
     The aggregator is a thin event router. All state management, trigger
@@ -99,7 +100,7 @@ def __init__(
         shutdown_event: asyncio.Event | None = None,
         **kwargs,
     ):
-        super().__init__(*args, **kwargs)
+        super().__init__(EventRecordCodec(), *args, **kwargs)
         self._kv_store = kv_store
         self._tokenize_pool = tokenize_pool
         self._shutdown_event = shutdown_event
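Both service diffs above follow the same migration recipe: the subclass keeps its own signature and simply prepends its codec when delegating to the generic base. A hypothetical minimal version of that wiring (UpperCodec and the simplified base are illustrative, not the real classes):

```python
from typing import Generic, TypeVar

T = TypeVar("T")


class MessageSubscriber(Generic[T]):
    # Simplified base: the real one also takes loop/topics and manages _fd.
    def __init__(self, codec, connect_address: str):
        self._codec = codec
        self.connect_address = connect_address


class UpperCodec:
    # Toy stand-in for EventRecordCodec.
    def decode(self, payload: bytes) -> str:
        return payload.decode().upper()


class LoggerService(MessageSubscriber[str]):
    # Mirrors EventLoggerService: inject the codec, forward the rest.
    def __init__(self, *args, shutdown_flag: bool = False, **kwargs):
        super().__init__(UpperCodec(), *args, **kwargs)
        self._shutdown_flag = shutdown_flag


svc = LoggerService("ipc:///tmp/events")
assert svc.connect_address == "ipc:///tmp/events"
assert svc._codec.decode(b"ended") == "ENDED"
```

Because each subclass fixes its own codec, callers of EventLoggerService and MetricsAggregatorService did not have to change at this call site.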

src/inference_endpoint/async_utils/transport/protocol.py

Lines changed: 98 additions & 93 deletions
@@ -26,18 +26,13 @@
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
-from typing import Any, Protocol, runtime_checkable
+from typing import Any, Generic, Protocol, TypeVar, runtime_checkable
 
-import msgspec
 from pydantic import BaseModel, ConfigDict, Field
 
-from inference_endpoint.core.record import (
-    ErrorEventType,
-    EventRecord,
-    decode_event_record,
-    encode_event_record,
-)
-from inference_endpoint.core.types import ErrorData, Query, QueryResult, StreamChunk
+from inference_endpoint.core.types import Query, QueryResult, StreamChunk
+
+T = TypeVar("T")
 
 
 class TransportConfig(BaseModel, ABC):
@@ -235,54 +230,76 @@ def cleanup(self) -> None:
         pass
 
 
-class EventRecordPublisher(ABC):
-    """Abstract base class for publishing event records over a transport."""
+class MessageCodec(Protocol[T]):
+    """Encode/decode policy for a single message type on the pub/sub layer.
+
+    The codec is the only type-specific surface in the pub/sub stack. All
+    transport machinery (ZmqMessagePublisher / ZmqMessageSubscriber) operates
+    on (topic_bytes, payload_bytes); the codec is what binds those bytes to
+    a concrete Python type T.
+    """
+
+    def encode(self, item: T) -> tuple[bytes, bytes]:
+        """Return (topic, payload). topic must be exactly TOPIC_FRAME_SIZE bytes."""
+        ...
+
+    def decode(self, payload: bytes) -> T:
+        """Decode payload back to T. May raise; the caller routes failures
+        through on_decode_error."""
+        ...
+
+    def on_decode_error(self, payload: bytes, exc: Exception) -> T | None:
+        """Fallback for malformed payloads. Return a sentinel item or None
+        to drop the message."""
+        ...
+
+
+class MessagePublisher(ABC, Generic[T]):
+    """Abstract base for publishing typed messages over a transport.
+
+    Subclasses implement send(topic, payload) and close(). publish() is
+    generic over T via the codec.
+    """
 
     def __init__(
         self,
+        codec: MessageCodec[T],
         bind_address: str,
         loop: asyncio.AbstractEventLoop | None = None,
     ):
-        """Creates a new EventRecordPublisher.
+        """Creates a new MessagePublisher.
 
         Args:
-            bind_address: The address to bind the publisher to. This can be an IPC or TCP socket address.
-            loop: The event loop to use for the publisher. If not provided, it is assumed that the publisher
-                should always execute eagerly and will be blocking. This means that the call to `.publish()`
-                will always be called immediately and the current loop and thread will block until the message
-                is sent.
+            codec: Encode policy. Required because turning T into wire bytes
+                is the only type-specific operation; injecting it is the
+                whole point of generalization.
+            bind_address: IPC or TCP socket address to bind to.
+            loop: Event loop to register async writes on. If None, send is
+                eager/blocking — used by callers that publish before a loop
+                is running (e.g. service startup).
         """
+        self._codec = codec
         self.bind_address = bind_address
         self.loop = loop
         self.is_closed: bool = False
 
-    def publish(self, event_record: EventRecord) -> None:
-        """Publish the event record on the bound address.
-
-        Args:
-            event_record: The event record to publish.
-        """
+    def publish(self, item: T) -> None:
+        """Encode item via the codec and send."""
         if self.is_closed:
             return
-
-        topic, payload = encode_event_record(event_record)
+        topic, payload = self._codec.encode(item)
         self.send(topic, payload)
 
     @abstractmethod
     def send(self, topic: bytes, payload: bytes) -> None:
-        """Send the message via the implemented transport layer.
-
-        Args:
-            topic: The topic of the message.
-            payload: The payload of the message.
-        """
-        raise NotImplementedError("Subclasses must implement this method.")
+        """Send raw frame via the implemented transport layer."""
+        raise NotImplementedError
 
     def flush(self) -> None:  # noqa: B027 — intentionally non-abstract
         """Force-send any buffered records.
 
         Unbuffered implementations need no override. Buffered subclasses
-        (e.g., ZmqEventRecordPublisher) override this to drain their buffer.
+        (e.g. ZmqMessagePublisher) override this to drain their buffer.
         """
 
     @abstractmethod
@@ -291,34 +308,39 @@ def close(self) -> None:
 
         Implementations must flush any buffered records before closing.
         """
-        raise NotImplementedError("Subclasses must implement this method.")
+        raise NotImplementedError
 
 
-class EventRecordSubscriber(ABC):
-    """Abstract base class for subscribing to event records over a transport."""
+class MessageSubscriber(ABC, Generic[T]):
+    """Abstract base for subscribing to typed messages over a transport.
+
+    Subclasses implement receive() (raw bytes from socket) and process()
+    (handle decoded items). _on_readable wires them together using the
+    codec.
+    """
 
     def __init__(
         self,
+        codec: MessageCodec[T],
         connect_address: str,
         loop: asyncio.AbstractEventLoop,
         topics: list[str] | None = None,
    ):
-        """Creates a new EventRecordSubscriber.
+        """Creates a new MessageSubscriber.
 
-        Initializing the subscriber does NOT start processing. The subscriber connects
-        to the address and subscribes to topics, but the socket reader is only added
-        when .start() is called. This allows bookkeeping or other setup before
-        listening. Each subscriber should use its own event loop (e.g. from LoopManager),
-        not shared with the publisher.
-
-        It is mandatory for subscriber implementations to set the `_fd` attribute to the file
-        descriptor of the socket to add an asyncio reader to the event loop.
+        Initializing does NOT start processing — call .start() to add the
+        socket reader to the loop. Subclasses must set ``self._fd`` to the
+        socket file descriptor before .start() is called.
 
         Args:
-            connect_address: The address to connect the subscriber to. This can be an IPC or TCP socket address.
-            loop: The event loop to use for the subscriber (typically a dedicated loop per subscriber).
-            topics: The topics to subscribe to. If not provided, it is assumed that the subscriber should subscribe to all topics.
+            codec: Decode policy. Required for the same reason as in
+                MessagePublisher.
+            connect_address: IPC or TCP socket address to connect to.
+            loop: Dedicated loop for this subscriber (typically from
+                LoopManager — not shared with the publisher).
+            topics: Topics to subscribe to. None means subscribe to all.
         """
+        self._codec = codec
         self.connect_address = connect_address
         self.topics = topics
         self.loop = loop
@@ -328,31 +350,22 @@ def __init__(
 
     @abstractmethod
     def receive(self) -> bytes | None:
-        """Receive data from the transport.
-
-        Should receive data from the socket and return a bytes object that should be able
-        to be decoded into an EventRecord.
+        """Receive a single payload (no topic prefix) from the transport.
 
-        If the received data is malformed, this method should return None.
-
-        For the specific case that the transport is not readable or the underlying socket is busy
-        (such as when an EAGAIN error is raised), this method should raise a StopIteration exception.
+        Returns None for malformed-but-recognized frames. Raises
+        StopIteration when the transport has nothing more to deliver right
+        now (EAGAIN).
         """
-        raise NotImplementedError("Subclasses must implement this method.")
+        raise NotImplementedError
 
     @abstractmethod
-    async def process(self, records: list[EventRecord]) -> None:
-        """Process a list of EventRecords.
-
-        Called asynchronously (scheduled via create_task) so that heavy work does not
-        block the socket read path. Implementations should be async.
-        """
-        raise NotImplementedError("Subclasses must implement this method.")
+    async def process(self, items: list[T]) -> None:
+        """Handle a batch of decoded items. Called as an asyncio task so
+        heavy work does not block the socket read path."""
+        raise NotImplementedError
 
     def close(self) -> None:
-        """Close the subscriber and release resources (e.g. remove reader, close socket).
-        Should be idempotent; safe to call multiple times. Call when the session has ended.
-        """
+        """Close the subscriber. Idempotent."""
         if self.loop is not None and self._fd is not None:
             try:
                 self.loop.remove_reader(self._fd)
@@ -361,46 +374,37 @@ def close(self) -> None:
                 pass
 
     def _on_readable(self) -> None:
-        """Drain socket, decode records, and schedule process() as an async task."""
+        """Drain socket, decode via codec, and schedule process()."""
         if self.is_closed:
             return
 
-        records: list[EventRecord] = []
+        items: list[T] = []
         try:
             while True:
                 payload = self.receive()
                 if payload is None:
                     continue
-
-                # Attempt decode
                 try:
-                    event_record = decode_event_record(payload)
-                except msgspec.DecodeError as e:
-                    event_record = EventRecord(
-                        event_type=ErrorEventType.GENERIC,
-                        data=ErrorData(
-                            error_type="msgspec.DecodeError",
-                            error_message=str(e),
-                        ),
-                    )
-                records.append(event_record)
+                    items.append(self._codec.decode(payload))
+                except Exception as e:  # noqa: BLE001 — codec decides handling
+                    # The base class is codec-agnostic: different codec
+                    # implementations raise different exception types
+                    # (msgspec.DecodeError, json.JSONDecodeError, ValueError,
+                    # etc.). The codec's on_decode_error decides whether to
+                    # return a fallback item, drop the message, or re-raise.
+                    fallback = self._codec.on_decode_error(payload, e)
+                    if fallback is not None:
+                        items.append(fallback)
         except StopIteration:
-            # No more messages to receive right now
             pass
         finally:
-            if records:
-                # Schedule process() so it does not block the socket read path
-                self.loop.create_task(self.process(records))
+            if items:
+                self.loop.create_task(self.process(items))
 
     def start(self) -> None:
-        """Start the subscriber: add the socket reader to the loop and begin processing.
-
-        Call this after any setup (e.g. when the session is about to start). Before
-        start() is called, no messages are received.
-        """
+        """Add the socket reader to the loop and begin processing."""
         if self._fd is None:
             raise ValueError("Subscriber not initialized with a file descriptor")
-
         self.loop.add_reader(self._fd, self._on_readable)
 
 
@@ -410,6 +414,7 @@ def start(self) -> None:
     "SenderTransport",
     "WorkerConnector",
     "WorkerPoolTransport",
-    "EventRecordPublisher",
-    "EventRecordSubscriber",
+    "MessageCodec",
+    "MessagePublisher",
+    "MessageSubscriber",
 ]
