Skip to content
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
03bd8bc
Adds dict method to IOController to publish events
chrisk314 Jul 24, 2025
b710772
Adds placeholder method to build producer graph
chrisk314 Aug 13, 2025
dde4196
Adds placeholder method to update producer graph
chrisk314 Aug 14, 2025
7297f84
Adds method to get process for component to state backend
chrisk314 Aug 19, 2025
43711fd
Implementations for building and updating producer graph
chrisk314 Aug 21, 2025
21e227e
Removes StopEvent from producer graph
chrisk314 Aug 21, 2025
4543cc9
Adds TODO note
chrisk314 Aug 21, 2025
13e7af4
Close channels in task group
chrisk314 Aug 25, 2025
16ba4e2
Adds grace period before closing IO controller in events scenarios
chrisk314 Aug 25, 2025
f0aed10
Only calls update producer graph if no field inputs
chrisk314 Aug 25, 2025
e7786a2
Apply suggestions from code review
chrisk314 Aug 25, 2025
28d2e69
Replaces exception with log warning
chrisk314 Aug 26, 2025
e2b782d
Adds contrived test for graceful shutdown of event driven models
chrisk314 Aug 26, 2025
d15bb0f
Covariant type annotation
chrisk314 Aug 27, 2025
92d0906
Fixup test component input events spec
chrisk314 Aug 30, 2025
bfced4f
Fixup IOController timeout logic for only output events
chrisk314 Aug 30, 2025
b82511a
Modifies can step logic for event based components
chrisk314 Aug 30, 2025
b038fd0
Adds a note on implementation issue
chrisk314 Aug 30, 2025
cea1898
Makes Component.step concrete and adds check for override
chrisk314 Sep 6, 2025
4b79bc7
Passes timeout for iocontroller from component
chrisk314 Sep 6, 2025
6d38d95
fixup! Passes timeout for iocontroller from component
chrisk314 Sep 6, 2025
f4cd8f9
Changes exception class for no more event producers
chrisk314 Sep 6, 2025
9eece5c
Fixup missing exports for test assertions
chrisk314 Sep 6, 2025
577347b
Only set status complete if not stopped or failed
chrisk314 Sep 9, 2025
576b9c9
Reduces test duplication
chrisk314 Sep 9, 2025
4c12ed0
Clears state backend caches on destroy
chrisk314 Sep 9, 2025
d2380d0
Merge branch 'main' into feat/producer-graph
chrisk314 Sep 10, 2025
4cd9325
Fixup formatting issues
chrisk314 Sep 10, 2025
a08077d
Updates _can_step docstring
chrisk314 Sep 10, 2025
c3136a7
Don't include self in producer graph
chrisk314 Sep 14, 2025
ae264b5
Makes warning message more helpful for user
chrisk314 Sep 14, 2025
c6b80c1
Updates deps
chrisk314 Sep 14, 2025
a5acd90
Gives more time for integration tests
chrisk314 Sep 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 96 additions & 12 deletions plugboard/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@

from __future__ import annotations

from abc import ABC, abstractmethod
from abc import ABC
import asyncio
from collections import defaultdict, deque
from functools import wraps
from functools import cached_property, wraps
import typing as _t

from that_depends import ContextScopes, container_context

from plugboard.component.io_controller import IOController as IO, IODirection
from plugboard.events import Event, EventHandlers, StopEvent
from plugboard.exceptions import (
EventStreamClosedError,
IOSetupError,
IOStreamClosedError,
ProcessStatusError,
Expand Down Expand Up @@ -45,6 +46,8 @@ class Component(ABC, ExportMixin):
io: IO = IO(input_events=[StopEvent], output_events=[StopEvent])
exports: _t.Optional[list[str]] = None

_implements_step: bool = False

def __init__(
self,
*,
Expand Down Expand Up @@ -79,6 +82,7 @@ def __init__(
namespace=self.name,
component=self,
)
self._event_producers: dict[str, set[str]] = defaultdict(set)
self._status = Status.CREATED
self._is_running = False
self._field_inputs: dict[str, _t.Any] = {}
Expand Down Expand Up @@ -144,6 +148,8 @@ def _configure_io(cls) -> None:
raise IOSetupError(
f"{cls.__name__} must extend Component abstract base class io arguments"
)
# Check if component implements step method
cls._implements_step = cls.step is not Component.step

@classmethod
def _get_component_bases(cls) -> list[_t.Type[Component]]:
Expand Down Expand Up @@ -226,29 +232,71 @@ def _handle_init_wrapper(self) -> _t.Callable:
@wraps(self.init)
async def _wrapper() -> None:
with self._job_id_ctx():
await self._build_producer_graph()
await self._init()
await self._set_status(Status.INIT)

return _wrapper

@abstractmethod
async def _build_producer_graph(self) -> None:
    """Builds the producer graph for the component.

    For every input event type of this component other than `StopEvent`, collects the keys of
    the entries in the process's `components` mapping whose `io.output_events` contain that
    event type, and stores them in `self._event_producers`. If the state backend is not
    connected, a warning is logged and the graph is left untouched.
    """
    if not (self._state and self._state_is_connected):
        self._logger.warning("State backend not connected. Cannot build producer graph.")
        return
    process = await self._state.get_process_for_component(self.id)
    input_event_set = {evt.safe_type() for evt in self.io.input_events}
    # `discard` instead of `remove`: do not raise KeyError if StopEvent is not an input event.
    input_event_set.discard(StopEvent.safe_type())
    for comp_id, comp_data in process["components"].items():
        # A component which both produces and consumes an event must not count as its own
        # producer, otherwise its event stream could never be considered closed.
        # NOTE(review): assumes `process["components"]` is keyed by component id - confirm.
        if comp_id == self.id:
            continue
        for evt in input_event_set.intersection(comp_data["io"]["output_events"]):
            self._event_producers[evt].add(comp_id)

async def _update_producer_graph(self) -> None:
    """Prunes producers which are no longer active from the producer graph.

    A producer is kept only while its status in the process state is `RUNNING` or `WAITING`.
    Event types left without any producer are dropped from the graph. Nothing is done if the
    state backend is not connected (a warning is logged) or if the graph is already empty.

    Raises:
        EventStreamClosedError: If pruning leaves the graph without any event producers.
    """
    if not (self._state and self._state_is_connected):
        self._logger.warning("State backend not connected. Cannot update producer graph.")
        return
    if not self._event_producers:
        return  # Graph is empty: nothing to prune
    process = await self._state.get_process_for_component(self.id)
    components = process["components"]
    active_statuses = (Status.RUNNING, Status.WAITING)
    # Iterate over snapshots because both the mapping and its sets are mutated below.
    for evt, producers in tuple(self._event_producers.items()):
        for comp_id in tuple(producers):
            if components[comp_id]["status"] in active_statuses:
                continue
            producers.remove(comp_id)
        if not producers:
            del self._event_producers[evt]
    if not self._event_producers:
        raise EventStreamClosedError("No more events to process.")

async def step(self) -> None:
    """Executes component logic for a single step.

    The base implementation only raises; components which step provide their own override.

    Raises:
        NotImplementedError: Always, when the base implementation is called.
    """
    raise NotImplementedError("Component step method not implemented")

@cached_property
def _produces_no_output_events(self) -> bool:
    """Whether the component declares no output event types other than `StopEvent`."""
    output_events = {evt.safe_type() for evt in self.io.output_events}
    # An empty difference means StopEvent is the only (or no) declared output event.
    return not (output_events - {StopEvent.safe_type()})

@property
def _can_step(self) -> bool:
"""Checks if the component can step.

- if a component has no input or output fields, it cannot step (purely event-driven case);
The rules for whether a component can step are as follows:
- if a component does not implement the `step` method, it cannot step;
- if a component produces no outputs and consumes no input fields, it cannot step (purely
event-driven case);
- if a component requires inputs, it can only step if all the inputs are available;
- otherwise, a component which has outputs but does not require inputs can always step.
"""
consumes_no_inputs = len(self.io.inputs) == 0
produces_no_outputs = len(self.io.outputs) == 0
if consumes_no_inputs and produces_no_outputs:
if not self._implements_step:
return False
return consumes_no_inputs or self._field_inputs_ready
produces_no_outputs = self._produces_no_output_events and len(self.io.outputs) == 0
consumes_no_input_fields = len(self.io.inputs) == 0
if consumes_no_input_fields and produces_no_outputs:
return False
return consumes_no_input_fields or self._field_inputs_ready

def _handle_step_wrapper(self) -> _t.Callable:
self._step = self.step
Expand All @@ -274,32 +322,66 @@ async def _wrapper() -> None:

return _wrapper

@cached_property
def _has_field_inputs(self) -> bool:
return len(self.io.inputs) > 0

@cached_property
def _has_event_inputs(self) -> bool:
    """Whether the component declares any input event type other than `StopEvent`."""
    input_events = {evt.safe_type() for evt in self.io.input_events}
    return bool(input_events - {StopEvent.safe_type()})

@cached_property
def _has_inputs(self) -> bool:
return self._has_field_inputs or self._has_event_inputs

@cached_property
def _has_field_outputs(self) -> bool:
return len(self.io.outputs) > 0

@cached_property
def _has_event_outputs(self) -> bool:
    """Whether the component declares any output event type other than `StopEvent`."""
    output_events = {evt.safe_type() for evt in self.io.output_events}
    return bool(output_events - {StopEvent.safe_type()})

@cached_property
def _has_outputs(self) -> bool:
return self._has_field_outputs or self._has_event_outputs

async def _io_read_with_status_check(self) -> None:
    """Reads from IO controller with concurrent periodic status checks.

    Status checks are performed periodically until the read completes. If the process is in a
    failed state, the component status is set to `STOPPED` and a `ProcessStatusError` is raised;
    otherwise another read attempt is made.

    A single read task races against the periodic status check; whichever task is still
    pending once the first one completes is cancelled. An `EventStreamClosedError` raised by
    either task is handled by closing the IO controller when the component has no input
    fields; any other exception is re-raised.
    """
    # Components with outputs but no inputs only wait briefly on the read.
    read_timeout = 1e-3 if self._has_outputs and not self._has_inputs else None
    done, pending = await asyncio.wait(
        (
            asyncio.create_task(self._periodic_status_check()),
            asyncio.create_task(self.io.read(timeout=read_timeout)),
        ),
        return_when=asyncio.FIRST_COMPLETED,
    )
    for task in pending:
        task.cancel()
    for task in done:
        exc = task.exception()
        if isinstance(exc, EventStreamClosedError) and len(self.io.inputs) == 0:
            await self.io.close()  # Call close for final wait and flush event buffer
        elif exc is not None:
            raise exc

async def _periodic_status_check(self) -> None:
    """Periodically checks the status of the process and updates the component status.

    Loops until cancelled or until one of the awaited calls raises. Each iteration sleeps for
    `IO_READ_TIMEOUT_SECONDS`, runs the status check and, for components without input
    fields, also updates the producer graph.
    """
    while True:
        await asyncio.sleep(IO_READ_TIMEOUT_SECONDS)
        await self._status_check()
        if self.io.inputs:
            continue
        # TODO : Eventually producer graph update will be event driven. For now,
        #      : the update is performed periodically, so it's called here along
        #      : with the status check.
        await self._update_producer_graph()

async def _status_check(self) -> None:
"""Checks the status of the process and updates the component status."""
Expand Down Expand Up @@ -385,7 +467,8 @@ async def run(self) -> None:
await self.step()
except IOStreamClosedError:
break
await self._set_status(Status.COMPLETED)
if self.status not in {Status.STOPPED, Status.FAILED}:
await self._set_status(Status.COMPLETED)
finally:
self._is_running = False

Expand All @@ -406,6 +489,7 @@ def dict(self) -> dict[str, _t.Any]: # noqa: D102
"status": str(self.status),
**field_data,
"exports": {name: getattr(self, name, None) for name in self.exports or []},
"io": self.io.dict(),
}


Expand Down
39 changes: 27 additions & 12 deletions plugboard/component/io_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

import asyncio
from collections import defaultdict, deque
from functools import cached_property
from functools import cache, cached_property
import typing as _t

from plugboard.connector import AsyncioChannel, Channel, Connector
from plugboard.events import Event
from plugboard.events import Event, StopEvent
from plugboard.exceptions import ChannelClosedError, IOStreamClosedError
from plugboard.schemas.io import IODirection
from plugboard.utils import DI
Expand All @@ -17,7 +17,8 @@
if _t.TYPE_CHECKING: # pragma: no cover
from plugboard.component import Component

IO_NS_UNSET = "__UNSET__"
IO_NS_UNSET: str = "__UNSET__"
IO_CLOSE_GRACE_PERIOD: float = 3.0

_t_field_key = tuple[str, str]
_io_key_in: str = str(IODirection.INPUT)
Expand Down Expand Up @@ -89,15 +90,15 @@ def is_closed(self) -> bool:
def _has_field_inputs(self) -> bool:
return len(self._input_channels) > 0

@cached_property
def _has_field_outputs(self) -> bool:
return len(self._output_channels) > 0

@cached_property
def _has_event_inputs(self) -> bool:
return len(self._input_event_channels) > 0

async def read(self) -> None:
@cached_property
def _has_inputs(self) -> bool:
return self._has_field_inputs or self._has_event_inputs

async def read(self, timeout: float | None = None) -> None:
"""Reads data and/or events from input channels.

Read behaviour is dependent on the specific combination of input fields, output fields,
Expand All @@ -116,8 +117,6 @@ async def read(self) -> None:
raise IOStreamClosedError("Attempted read on a closed io controller.")
if len(read_tasks := self._set_read_tasks()) == 0:
return
# If there are field outputs but not inputs, wait for a short time to receive input events
timeout = 1e-3 if self._has_field_outputs and not self._has_field_inputs else None
try:
try:
done, _ = await asyncio.wait(
Expand Down Expand Up @@ -310,10 +309,15 @@ def queue_event(self, event: Event) -> None:

async def close(self) -> None:
"""Closes all input/output channels."""
for chan in self._output_channels.values():
await chan.close()
async with asyncio.TaskGroup() as tg:
for chan in self._output_channels.values():
tg.create_task(chan.close())
for task in self._read_tasks.values():
task.cancel()
# If there are events to read wait some grace period before flushing event buffer
if self._input_event_types - {StopEvent.safe_type()}:
await asyncio.sleep(IO_CLOSE_GRACE_PERIOD)
await self._flush_internal_event_buffer()
self._is_closed = True
self._logger.info("IOController closed")

Expand Down Expand Up @@ -398,6 +402,17 @@ def _validate_connections(self) -> None:
if unconnected_outputs := set(self.outputs) - connected_outputs:
self._logger.warning("Output fields not connected", unconnected=unconnected_outputs)

@cache
def dict(self) -> dict[str, _t.Any]:  # noqa: D102
    # Returns a plain-dict description of the controller: namespace, field names, event type
    # names and initial values.
    # NOTE(review): `functools.cache` on a method keys on `self`, so repeated calls return the
    # very same dict object and the cache holds a reference to the instance - confirm callers
    # never mutate the result and that the described attributes do not change after first call.
    input_event_types = [evt.safe_type() for evt in self.input_events]
    output_event_types = [evt.safe_type() for evt in self.output_events]
    initial_values = {name: list(values) for name, values in self._initial_values.items()}
    return {
        "namespace": self.namespace,
        "inputs": self.inputs,
        "outputs": self.outputs,
        "input_events": input_event_types,
        "output_events": output_event_types,
        "initial_values": initial_values,
    }


class IOBuffer(_t.Protocol):
"""`IOBuffer` is a buffer for input/output data."""
Expand Down
2 changes: 1 addition & 1 deletion plugboard/events/event_connector_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class EventConnectorBuilder: # pragma: no cover
def __init__(self, connector_builder: ConnectorBuilder) -> None:
self._connector_builder = connector_builder

def build(self, components: list[Component]) -> dict[str, Connector]:
def build(self, components: _t.Iterable[Component]) -> dict[str, Connector]:
"""Returns mapping of connectors for events handled by components."""
evt_conn_map: dict[str, Connector] = {}
for component in components:
Expand Down
6 changes: 6 additions & 0 deletions plugboard/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ class IOSetupError(IOControllerError):
pass


class EventStreamClosedError(Exception):
    """Raised when no event producers are running any more."""


class NoMoreDataException(Exception):
"""Raised when there is no more data to fetch."""

Expand Down
12 changes: 12 additions & 0 deletions plugboard/state/sqlite_state_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ async def init(self) -> None:
await self._initialise_db()
await super().init()

async def destroy(self) -> None:
    """Destroys the `SqliteStateBackend`.

    Delegates to the parent class first, then clears the caches of the id lookup helpers.
    """
    await super().destroy()
    cached_lookups = (
        self._get_db_id,
        self._get_process_id_for_component,
        self._get_process_id_for_connector,
    )
    for lookup in cached_lookups:
        lookup.cache_clear()

async def _fetchone(
self, statement: str, params: _t.Tuple[_t.Any, ...]
) -> aiosqlite.Row | None:
Expand Down Expand Up @@ -159,6 +166,11 @@ async def get_process(self, process_id: str) -> dict:
process_data["connectors"] = process_connectors
return process_data

async def get_process_for_component(self, component_id: str) -> dict:
    """Gets the process that a component belongs to.

    Resolves the component's process id via `_get_process_id_for_component` and returns the
    result of `get_process` for that id.
    """
    return await self.get_process(await self._get_process_id_for_component(component_id))

@alru_cache(maxsize=128)
async def _get_process_id_for_component(self, component_id: str) -> str:
"""Returns the database id of the process which a component belongs to.
Expand Down
15 changes: 12 additions & 3 deletions plugboard/state/state_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,17 @@ async def get_process(self, process_id: str) -> dict:
"""Returns a process from the state."""
return await self._get(self._process_key(process_id))

async def _get_process_id_for_component(self, component_id: str) -> str:
process_id: str | None = await self._get(("_comp_proc_map", component_id))
if process_id is None:
raise NotFoundError(f"No process found for component with ID {component_id}")
return process_id

async def get_process_for_component(self, component_id: str) -> dict:
    """Gets the process that a component belongs to.

    Resolves the component's process id via `_get_process_id_for_component` and returns the
    result of `get_process` for that id.
    """
    return await self.get_process(await self._get_process_id_for_component(component_id))

async def upsert_component(self, component: Component) -> None:
"""Upserts a component into the state."""
process_id = await self._get(("_comp_proc_map", component.id))
Expand Down Expand Up @@ -214,7 +225,5 @@ async def get_process_status(self, process_id: str) -> Status:

async def get_process_status_for_component(self, component_id: str) -> Status:
    """Gets the status of the process that a component belongs to.

    The process id is resolved through `_get_process_id_for_component`, so the state is
    queried for it only once and any error that helper raises propagates to the caller.
    """
    process_id = await self._get_process_id_for_component(component_id)
    return await self.get_process_status(process_id)
Loading
Loading