2626from abc import ABC , abstractmethod
2727from collections .abc import AsyncIterator
2828from contextlib import asynccontextmanager
29- from typing import Any , Protocol , runtime_checkable
29+ from typing import Any , Generic , Protocol , TypeVar , runtime_checkable
3030
3131import msgspec
3232from pydantic import BaseModel , ConfigDict , Field
3333
34- from inference_endpoint .core .record import (
35- ErrorEventType ,
36- EventRecord ,
37- decode_event_record ,
38- encode_event_record ,
39- )
40- from inference_endpoint .core .types import ErrorData , Query , QueryResult , StreamChunk
34+ from inference_endpoint .core .types import Query , QueryResult , StreamChunk
35+
36+ T = TypeVar ("T" )
4137
4238
4339class TransportConfig (BaseModel , ABC ):
@@ -235,54 +231,76 @@ def cleanup(self) -> None:
235231 pass
236232
237233
238- class EventRecordPublisher (ABC ):
239- """Abstract base class for publishing event records over a transport."""
234+ class MessageCodec (Protocol [T ]):
235+ """Encode/decode policy for a single message type on the pub/sub layer.
236+
237+ The codec is the only type-specific surface in the pub/sub stack. All
238+ transport machinery (ZmqMessagePublisher / ZmqMessageSubscriber) operates
239+ on (topic_bytes, payload_bytes); the codec is what binds those bytes to
240+ a concrete Python type T.
241+ """
242+
243+ def encode (self , item : T ) -> tuple [bytes , bytes ]:
244+ """Return (topic, payload). topic must be exactly TOPIC_FRAME_SIZE bytes."""
245+ ...
246+
247+ def decode (self , payload : bytes ) -> T :
248+ """Decode payload back to T. May raise; the caller routes failures
249+ through on_decode_error."""
250+ ...
251+
252+ def on_decode_error (self , payload : bytes , exc : Exception ) -> T | None :
253+ """Fallback for malformed payloads. Return a sentinel item or None
254+ to drop the message."""
255+ ...
256+
257+
258+ class MessagePublisher (ABC , Generic [T ]):
259+ """Abstract base for publishing typed messages over a transport.
260+
261+ Subclasses implement send(topic, payload) and close(). publish() is
262+ generic over T via the codec.
263+ """
240264
241265 def __init__ (
242266 self ,
267+ codec : MessageCodec [T ],
243268 bind_address : str ,
244269 loop : asyncio .AbstractEventLoop | None = None ,
245270 ):
246- """Creates a new EventRecordPublisher .
271+ """Creates a new MessagePublisher .
247272
248273 Args:
249- bind_address: The address to bind the publisher to. This can be an IPC or TCP socket address.
250- loop: The event loop to use for the publisher. If not provided, it is assumed that the publisher
251- should always execute eagerly and will be blocking. This means that the call to `.publish()`
252- will always be called immediately and the current loop and thread will block until the message
253- is sent.
274+ codec: Encode policy. Required because turning T into wire bytes
275+ is the only type-specific operation; injecting it is the
276+ whole point of generalization.
277+ bind_address: IPC or TCP socket address to bind to.
278+ loop: Event loop to register async writes on. If None, send is
279+ eager/blocking — used by callers that publish before a loop
280+ is running (e.g. service startup).
254281 """
282+ self ._codec = codec
255283 self .bind_address = bind_address
256284 self .loop = loop
257285 self .is_closed : bool = False
258286
259- def publish (self , event_record : EventRecord ) -> None :
260- """Publish the event record on the bound address.
261-
262- Args:
263- event_record: The event record to publish.
264- """
287+ def publish (self , item : T ) -> None :
288+ """Encode item via the codec and send."""
265289 if self .is_closed :
266290 return
267-
268- topic , payload = encode_event_record (event_record )
291+ topic , payload = self ._codec .encode (item )
269292 self .send (topic , payload )
270293
271294 @abstractmethod
272295 def send (self , topic : bytes , payload : bytes ) -> None :
273- """Send the message via the implemented transport layer.
274-
275- Args:
276- topic: The topic of the message.
277- payload: The payload of the message.
278- """
279- raise NotImplementedError ("Subclasses must implement this method." )
296+ """Send raw frame via the implemented transport layer."""
297+ raise NotImplementedError
280298
281299 def flush (self ) -> None : # noqa: B027 — intentionally non-abstract
282300 """Force-send any buffered records.
283301
284302 Unbuffered implementations need no override. Buffered subclasses
285- (e.g., ZmqEventRecordPublisher ) override this to drain their buffer.
303+ (e.g. ZmqMessagePublisher ) override this to drain their buffer.
286304 """
287305
288306 @abstractmethod
@@ -291,34 +309,39 @@ def close(self) -> None:
291309
292310 Implementations must flush any buffered records before closing.
293311 """
294- raise NotImplementedError ( "Subclasses must implement this method." )
312+ raise NotImplementedError
295313
296314
297- class EventRecordSubscriber (ABC ):
298- """Abstract base class for subscribing to event records over a transport."""
315+ class MessageSubscriber (ABC , Generic [T ]):
316+ """Abstract base for subscribing to typed messages over a transport.
317+
318+ Subclasses implement receive() (raw bytes from socket) and process()
319+ (handle decoded items). _on_readable wires them together using the
320+ codec.
321+ """
299322
300323 def __init__ (
301324 self ,
325+ codec : MessageCodec [T ],
302326 connect_address : str ,
303327 loop : asyncio .AbstractEventLoop ,
304328 topics : list [str ] | None = None ,
305329 ):
306- """Creates a new EventRecordSubscriber .
330+ """Creates a new MessageSubscriber .
307331
308- Initializing the subscriber does NOT start processing. The subscriber connects
309- to the address and subscribes to topics, but the socket reader is only added
310- when .start() is called. This allows bookkeeping or other setup before
311- listening. Each subscriber should use its own event loop (e.g. from LoopManager),
312- not shared with the publisher.
313-
314- It is mandatory for subscriber implementations to set the `_fd` attribute to the file
315- descriptor of the socket to add an asyncio reader to the event loop.
332+ Initializing does NOT start processing — call .start() to add the
333+ socket reader to the loop. Subclasses must set ``self._fd`` to the
334+ socket file descriptor before .start() is called.
316335
317336 Args:
318- connect_address: The address to connect the subscriber to. This can be an IPC or TCP socket address.
319- loop: The event loop to use for the subscriber (typically a dedicated loop per subscriber).
320- topics: The topics to subscribe to. If not provided, it is assumed that the subscriber should subscribe to all topics.
337+ codec: Decode policy. Required for the same reason as in
338+ MessagePublisher.
339+ connect_address: IPC or TCP socket address to connect to.
340+ loop: Dedicated loop for this subscriber (typically from
341+ LoopManager — not shared with the publisher).
342+ topics: Topics to subscribe to. None means subscribe to all.
321343 """
344+ self ._codec = codec
322345 self .connect_address = connect_address
323346 self .topics = topics
324347 self .loop = loop
@@ -328,31 +351,22 @@ def __init__(
328351
329352 @abstractmethod
330353 def receive (self ) -> bytes | None :
331- """Receive data from the transport.
332-
333- Should receive data from the socket and return a bytes object that should be able
334- to be decoded into an EventRecord.
354+ """Receive a single payload (no topic prefix) from the transport.
335355
336- If the received data is malformed, this method should return None.
337-
338- For the specific case that the transport is not readable or the underlying socket is busy
339- (such as when an EAGAIN error is raised), this method should raise a StopIteration exception.
356+ Returns None for malformed-but-recognized frames. Raises
357+ StopIteration when the transport has nothing more to deliver right
358+ now (EAGAIN).
340359 """
341- raise NotImplementedError ( "Subclasses must implement this method." )
360+ raise NotImplementedError
342361
343362 @abstractmethod
344- async def process (self , records : list [EventRecord ]) -> None :
345- """Process a list of EventRecords.
346-
347- Called asynchronously (scheduled via create_task) so that heavy work does not
348- block the socket read path. Implementations should be async.
349- """
350- raise NotImplementedError ("Subclasses must implement this method." )
363+ async def process (self , items : list [T ]) -> None :
364+ """Handle a batch of decoded items. Called as an asyncio task so
365+ heavy work does not block the socket read path."""
366+ raise NotImplementedError
351367
352368 def close (self ) -> None :
353- """Close the subscriber and release resources (e.g. remove reader, close socket).
354- Should be idempotent; safe to call multiple times. Call when the session has ended.
355- """
369+ """Close the subscriber. Idempotent."""
356370 if self .loop is not None and self ._fd is not None :
357371 try :
358372 self .loop .remove_reader (self ._fd )
@@ -361,46 +375,32 @@ def close(self) -> None:
361375 pass
362376
363377 def _on_readable (self ) -> None :
364- """Drain socket, decode records , and schedule process() as an async task ."""
378+ """Drain socket, decode via codec , and schedule process()."""
365379 if self .is_closed :
366380 return
367381
368- records : list [EventRecord ] = []
382+ items : list [T ] = []
369383 try :
370384 while True :
371385 payload = self .receive ()
372386 if payload is None :
373387 continue
374-
375- # Attempt decode
376388 try :
377- event_record = decode_event_record ( payload )
389+ items . append ( self . _codec . decode ( payload ) )
378390 except msgspec .DecodeError as e :
379- event_record = EventRecord (
380- event_type = ErrorEventType .GENERIC ,
381- data = ErrorData (
382- error_type = "msgspec.DecodeError" ,
383- error_message = str (e ),
384- ),
385- )
386- records .append (event_record )
391+ fallback = self ._codec .on_decode_error (payload , e )
392+ if fallback is not None :
393+ items .append (fallback )
387394 except StopIteration :
388- # No more messages to receive right now
389395 pass
390396 finally :
391- if records :
392- # Schedule process() so it does not block the socket read path
393- self .loop .create_task (self .process (records ))
397+ if items :
398+ self .loop .create_task (self .process (items ))
394399
395400 def start (self ) -> None :
396- """Start the subscriber: add the socket reader to the loop and begin processing.
397-
398- Call this after any setup (e.g. when the session is about to start). Before
399- start() is called, no messages are received.
400- """
401+ """Add the socket reader to the loop and begin processing."""
401402 if self ._fd is None :
402403 raise ValueError ("Subscriber not initialized with a file descriptor" )
403-
404404 self .loop .add_reader (self ._fd , self ._on_readable )
405405
406406
@@ -410,6 +410,7 @@ def start(self) -> None:
410410 "SenderTransport" ,
411411 "WorkerConnector" ,
412412 "WorkerPoolTransport" ,
413- "EventRecordPublisher" ,
414- "EventRecordSubscriber" ,
413+ "MessageCodec" ,
414+ "MessagePublisher" ,
415+ "MessageSubscriber" ,
415416]
0 commit comments