diff --git a/setup.py b/setup.py index ed9511c..179c7f0 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,13 @@ from setuptools import setup, find_packages +# Read version without importing the package (safe for builds) +version_ns = {} +with open("summoner/_version.py", encoding="utf-8") as f: + exec(f.read(), version_ns) + setup( name="summoner", - version="0.1.0", + version=version_ns["__version__"], description="Summoner's core SDK", author="Remy Tuyeras", author_email="rtuyeras@summoner.org", diff --git a/simulations/simul_flow_2.py b/simulations/simul_flow_2.py index 5da4bb3..78786cb 100644 --- a/simulations/simul_flow_2.py +++ b/simulations/simul_flow_2.py @@ -1,5 +1,5 @@ from summoner.protocol.flow import Flow -from typing import Type, Optional, List, Dict, Union, Callable, Any, Awaitable +from typing import Coroutine, Tuple, Optional, Callable, Any from summoner.protocol.triggers import ( Signal, Action, @@ -123,7 +123,7 @@ def upload_states(flow: Flow, routes: list[str]) -> dict[str, list[Node]]: tape = StateTape(raw_states) activation_index: dict[tuple[int, ...], list[TapeActivation]] = tape.collect_activations(receiver_index=receiver_index, parsed_routes=receiver_parsed_routes) -batches: dict[tuple[int, ...], list[Callable[[Any],Awaitable]]] = {priority: [activation.fn for activation in activations] for priority, activations in activation_index.items()} +batches: dict[tuple[int, ...], list[Callable[[Any],Coroutine[Any,Any,Any]]]] = {priority: [activation.fn for activation in activations] for priority, activations in activation_index.items()} print("\n\nactivation_index") pprint.pprint(activation_index) @@ -145,7 +145,7 @@ async def _safe_call(fn, payload): payload = {"content": "hello"} async def receiver_test(): - event_buffer: dict[tuple[int, ...], list[tuple[str, ParsedRoute, Event]]] = defaultdict(list) + event_buffer: dict[tuple[int, ...], list[tuple[Optional[str], ParsedRoute, Optional[Event]]]] = defaultdict(list) for priority, batch_fns in 
sorted(batches.items(), key=lambda kv: kv[0]): label = "default priority" if priority == () else f"priority {priority}" @@ -158,7 +158,7 @@ async def receiver_test(): activations = activation_index[priority] local_tape = tape.refresh() - to_extend: dict[str, list[Node]] = defaultdict(list) + to_extend: dict[Optional[str], list[Node]] = defaultdict(list) for act, event in zip(activations, events): to_extend[act.key].extend(act.route.activated_nodes(event)) to_extend = dict(to_extend) @@ -323,7 +323,7 @@ async def _start_send_workers(num_workers, send_queue): # --- start workers and test runner --- async def sender_test(): - send_queue: asyncio.Queue[Optional[Sender]] = asyncio.Queue(maxsize=50) + send_queue: asyncio.Queue[Optional[Tuple[str,Sender]]] = asyncio.Queue(maxsize=50) num_workers = 50 await _start_send_workers(num_workers, send_queue) @@ -341,5 +341,3 @@ async def sender_test(): # run the simulation asyncio.run(sender_test()) - - diff --git a/summoner/_version.py b/summoner/_version.py new file mode 100644 index 0000000..545d07d --- /dev/null +++ b/summoner/_version.py @@ -0,0 +1 @@ +__version__ = "1.1.1" \ No newline at end of file diff --git a/summoner/client/__init__.py b/summoner/client/__init__.py index 08587fb..fc94b26 100644 --- a/summoner/client/__init__.py +++ b/summoner/client/__init__.py @@ -1,2 +1,5 @@ +""" +TODO: doc client, ClientMerger and ClientTranslation summary +""" from .client import SummonerClient -from .merger import ClientMerger, ClientTranslation \ No newline at end of file +from .merger import ClientMerger, ClientTranslation diff --git a/summoner/client/client.py b/summoner/client/client.py index 06e9a9c..7d27c3f 100644 --- a/summoner/client/client.py +++ b/summoner/client/client.py @@ -1,14 +1,27 @@ +""" +SummonerClient +""" +#pylint:disable=line-too-long, wrong-import-position, too-many-lines +#pylint:disable=logging-fstring-interpolation + import os import sys import json +from types import FrameType from typing import ( - 
Optional, - Callable, - Union, - Awaitable, - Any, - Type, + Awaitable, + Dict, + Generator, + List, + Optional, + Callable, + Set, + Tuple, + Union, + Coroutine, + cast, ) +from typing import Any import asyncio import signal import inspect @@ -19,6 +32,27 @@ if target_path not in sys.path: sys.path.insert(0, target_path) +from summoner.client.client_types import ( + DOWNLOAD_TYPE, + HOOK_TYPE, + RECEIVE_DECORATED_TYPE, + SEND_DECORATED_TYPE, + SENDING_HOOKS_TYPE, + RECEIVING_HOOKS_TYPE, + UPLOAD_TYPE, +) +from summoner.client.dna import ( + DNA_DOWNLOAD, + DNA_UPLOAD, + DNAHook, + DNAReceiver, + DNASender, + hook_entry_contribution, + receiver_entry_contribution, + sender_entry_contribution, + upload_entry_contribution, + download_entry_contribution, +) from summoner.utils import ( load_config, is_jsonable, @@ -28,40 +62,45 @@ rebuild_expression_for, ) from summoner.logger import ( - get_logger, - configure_logger, + get_logger, + configure_logger, Logger, ) from summoner.protocol.triggers import ( - Signal, - Event, + Signal, + Event, Action ) from summoner.protocol.process import ( - StateTape, - ParsedRoute, - Node, - Sender, - Receiver, + StateTape, + ParsedRoute, + Node, + Sender, + Receiver, Direction, ClientIntent, ) from summoner.protocol.flow import Flow from summoner.protocol.validation import ( - hook_priority_order, + hook_priority_order, _check_param_and_return, ) from summoner.protocol.payload import ( - wrap_with_types, + wrap_with_types, recover_with_types, RelayedMessage ) +from summoner._version import __version__ as core_version + class ServerDisconnected(Exception): """Raised when the server closes the connection.""" - pass +#pylint:disable=too-many-instance-attributes class SummonerClient: + """ + TODO doc client + """ DEFAULT_MAX_BYTES_PER_LINE = 64 * 1024 # 64 KiB DEFAULT_READ_TIMEOUT_SECONDS = None # Wait for messages to arrive @@ -74,13 +113,13 @@ class SummonerClient: DEFAULT_EVENT_BRIDGE_SIZE = 1000 DEFAULT_MAX_CONSECUTIVE_ERRORS = 3 # 
Failed attempts to send before disconnecting - core_version = "1.1.1" + core_version = core_version def __init__(self, name: Optional[str] = None): - + # Give a name to the server self.name = name if isinstance(name, str) else "" - + # Create a bare logger (no handlers yet) self.logger: Logger = get_logger(self.name) @@ -91,7 +130,7 @@ def __init__(self, name: Optional[str] = None): asyncio.set_event_loop(self.loop) # Protect concurrent access to the set of active tasks - self.active_tasks: set[asyncio.Task] = set() + self.active_tasks: set[asyncio.Task[Any]] = set() self.tasks_lock = asyncio.Lock() # Protect route registration and access for receive/send functions @@ -107,7 +146,7 @@ def __init__(self, name: Optional[str] = None): self.connection_lock = asyncio.Lock() # Safe registration of decorators (hooks, receivers, senders) - self._registration_tasks: list[asyncio.Task] = [] + self._registration_tasks: list[asyncio.Task[Any]] = [] # One-time indexing of parsed routes self.receiver_parsed_routes: dict[str, ParsedRoute] = {} @@ -117,50 +156,65 @@ def __init__(self, name: Optional[str] = None): self._flow = Flow() # Functions to read and write the flow's active states in memory - self._upload_states: Optional[Callable[[Any], Awaitable]] = None - self._download_states: Optional[Callable[[Any], Awaitable]] = None - - self.event_bridge_maxsize = None - self.max_concurrent_workers = None # Limit the sending rate (will use 50 if None is given) - self.send_queue_maxsize = None - self.max_bytes_per_line = None - self.read_timeout_seconds = None # None is prefered - self.retry_delay_seconds = None - self.batch_drain = None + self._upload_states: Optional[UPLOAD_TYPE] = None + self._download_states: Optional[DOWNLOAD_TYPE] = None + + # Sender HyperParameters + self.event_bridge_maxsize : Optional[int] = None + self.max_concurrent_workers : Optional[int] = None # Limit the sending rate (will use 50 if None is given) + self.send_queue_maxsize : Optional[int] = None + 
self.batch_drain: Optional[bool] = None + # self.max_consecutive_worker_errors is unbound until _apply_config + + # Receiver HyperParameters + self.max_bytes_per_line : Optional[int] = None + self.read_timeout_seconds : Optional[float] = None # None is preferred + + # Reconnection HyperParameters + self.retry_delay_seconds : Optional[float] = None + # self.primary_retry_limit is unbound until _apply_config + # self.default_host is unbound until _apply_config + # self.default_port is unbound until _apply_config + # self.default_retry_limit is unbound until _apply_config # Pass Event information from the receiving end to the sending end - self.event_bridge: Optional[asyncio.Queue[tuple[tuple[int, ...], Optional[str], ParsedRoute, Event]]] = None + self.event_bridge: Optional[asyncio.Queue[tuple[tuple[int, ...], Optional[str], ParsedRoute, Optional[Event]]]] = None - self.send_queue: Optional[asyncio.Queue] = None + self.send_queue: Optional[asyncio.Queue[Optional[Tuple[str,Sender]]]] = None self.send_workers_started = False # To avoid double-starting workers self.worker_tasks: list[asyncio.Task] = [] self.writer_lock = asyncio.Lock() # Store validation hooks to be used before sending and after receiving - self.sending_hooks: dict[tuple[int,...], Callable[[Union[str, dict]], Union[str, dict]]] = {} - self.receiving_hooks: dict[tuple[int,...], Callable[[Union[str, dict]], Union[str, dict]]] = {} + self.sending_hooks: dict[tuple[int,...], SENDING_HOOKS_TYPE] = {} + self.receiving_hooks: dict[tuple[int,...], RECEIVING_HOOKS_TYPE] = {} self.hooks_lock = asyncio.Lock() # ─── DNA capture for merging ───────────────────────────────────────── # lists of dicts, each entry records one decorated handler - self._dna_receivers: list[dict] = [] - self._dna_senders: list[dict] = [] - self._dna_hooks: list[dict] = [] + self._dna_receivers: list[DNAReceiver] = [] + self._dna_senders: list[DNASender] = [] + self._dna_hooks: list[DNAHook] = [] - self._dna_upload_states: Optional[dict] = 
None - self._dna_download_states: Optional[dict] = None + self._dna_upload_states: Optional[DNA_UPLOAD] = None + self._dna_download_states: Optional[DNA_DOWNLOAD] = None # ==== VERSION SPECIFIC ==== + #pylint:disable=attribute-defined-outside-init def _apply_config(self, config: dict[str,Union[str,dict[str,Union[str,dict]]]]): + """ + Given the config which says specific hyperparameters, set the corresponding + fields here. Making sure they meet the expected constraints. + If some are not provided in config, they may be given default values. + """ + self.host = config.get("host") # default is None # pyright: ignore[reportAttributeAccessIssue] + self.port = config.get("port") # default is None # pyright: ignore[reportAttributeAccessIssue] - self.host = config.get("host") # default is None - self.port = config.get("port") # default is None - - logger_cfg = config.get("logger", {}) + logger_cfg = cast(Dict[str,Any],config.get("logger", {})) configure_logger(self.logger, logger_cfg) - hp_config = config.get("hyper_parameters", {}) + hp_config = cast(Dict[str,Any],config.get("hyper_parameters", {})) reconn_cfg = hp_config.get("reconnection", {}) self.retry_delay_seconds = reconn_cfg.get("retry_delay_seconds", self.DEFAULT_RETRY_DELAY) @@ -172,42 +226,55 @@ def _apply_config(self, config: dict[str,Union[str,dict[str,Union[str,dict]]]]): receiver_cfg = hp_config.get("receiver", {}) self.max_bytes_per_line = receiver_cfg.get("max_bytes_per_line", self.DEFAULT_MAX_BYTES_PER_LINE) self.read_timeout_seconds = receiver_cfg.get("read_timeout_seconds", self.DEFAULT_READ_TIMEOUT_SECONDS) - + sender_cfg = hp_config.get("sender", {}) self.max_concurrent_workers = sender_cfg.get("concurrency_limit", self.DEFAULT_CONCURRENCY_LIMIT) self.batch_drain = bool(sender_cfg.get("batch_drain", True)) self.send_queue_maxsize = sender_cfg.get("queue_maxsize", self.max_concurrent_workers) self.event_bridge_maxsize = sender_cfg.get("event_bridge_maxsize", self.DEFAULT_EVENT_BRIDGE_SIZE) 
self.max_consecutive_worker_errors = sender_cfg.get("max_worker_errors", self.DEFAULT_MAX_CONSECUTIVE_ERRORS) - + if (not isinstance(self.max_consecutive_worker_errors, int) or self.max_consecutive_worker_errors < 1): raise ValueError("sender.max_worker_errors must be an integer ≥ 1") if not isinstance(self.max_concurrent_workers, int) or self.max_concurrent_workers <= 0: raise ValueError("sender.concurrency_limit must be an integer ≥ 1") - + if not isinstance(self.send_queue_maxsize, int) or self.send_queue_maxsize <= 0: raise ValueError("sender.queue_maxsize must be an integer ≥ 1") - + if self.send_queue_maxsize < self.max_concurrent_workers: self.logger.warning(f"queue_maxsize < concurrency_limit; back-pressure will throttle producers at {self.send_queue_maxsize}") def initialize(self): + """ + Get the regex patterns for flow ready + """ self._flow.compile_arrow_patterns() def flow(self) -> Flow: + """ + TODO: doc flow + """ return self._flow - + async def travel_to(self, host, port): + """ + Change the host and port. + So there is now pending travel. + """ async with self.connection_lock: self.host = host self.port = port self._travel = True async def quit(self): + """ + Now a pending quit + """ async with self.connection_lock: self._quit = True - + async def _reset_client_intent(self): """Clear any pending quit or travel so we start fresh next time.""" async with self.connection_lock: @@ -219,18 +286,18 @@ def upload_states(self): Decorator to supply a function that returns the current state snapshot. Must be used before client.run(). 
""" - def decorator(fn: Callable[[], Awaitable]): - + def decorator(fn: UPLOAD_TYPE): + # ----[ Safety Checks ]---- - + if not inspect.iscoroutinefunction(fn): raise TypeError(f"@upload_states handler '{fn.__name__}' must be async") - + _check_param_and_return( fn, decorator_name="@upload_states", - allow_param=(type(None), str, dict, Any), # the payload - allow_return=(type(None), str, Any, Node, list, dict, + allow_param=(type(None), str, dict, Any), # the payload # pyright: ignore[reportArgumentType] + allow_return=(type(None), str, Any, Node, list, dict, # pyright: ignore[reportArgumentType] list[str], dict[str, str], dict[str, list[str]], list[Node], dict[str, Node], dict[str, list[Node]], dict[str, Union[str, list[str]]], @@ -239,10 +306,10 @@ def decorator(fn: Callable[[], Awaitable]): ), # the payload-dependent tape logger=self.logger, ) - + # if self.loop.is_running(): # raise RuntimeError("@upload_states() must be registered before client.run()") - + if self._upload_states is not None: self.logger.warning("@upload_states handler overwritten") @@ -263,32 +330,32 @@ def download_states(self): Decorator to supply a function that receives a StateTape. Must be used before client.run(). 
""" - def decorator(fn: Callable[[Any], Awaitable]): - + def decorator(fn: DOWNLOAD_TYPE): + # ----[ Safety Checks ]---- if not inspect.iscoroutinefunction(fn): raise TypeError(f"@download_states handler '{fn.__name__}' must be async") - + _check_param_and_return( fn, decorator_name="@download_states", - allow_param=(type(None), Node, Any, list, dict, - list[Node], - dict[str, Node], - dict[str, list[Node]], - dict[str, Union[Node, list[Node]]], - dict[Optional[str], Node], + allow_param=(type(None), Node, Any, list, dict, # pyright: ignore[reportArgumentType] + list[Node], + dict[str, Node], + dict[str, list[Node]], + dict[str, Union[Node, list[Node]]], + dict[Optional[str], Node], dict[Optional[str], list[Node]], dict[Optional[str], Union[Node, list[Node]]], ), - allow_return=(type(None), Any), + allow_return=(type(None), Any), # pyright: ignore[reportArgumentType] logger=self.logger, ) # if self.loop.is_running(): # raise RuntimeError("@download_states() must be registered before client.run()") - + if self._download_states is not None: self.logger.warning("@download_states handler overwritten") @@ -303,11 +370,11 @@ def decorator(fn: Callable[[Any], Awaitable]): return fn return decorator - + # ==== REGISTRATION HELPER ==== - def _schedule_registration(self, register_coro: Awaitable): + def _schedule_registration(self, register_coro: Coroutine[Any,Any,None]): """ Schedule `register_coro` onto self.loop. If the loop isn't running yet, create the task immediately. 
@@ -325,28 +392,35 @@ def _cb(): # ==== HOOK REGISTRATION ==== def hook( - self, - direction: Direction, + self, + direction: Direction, priority: Union[int, tuple[int, ...]] = () ): - def decorator(fn: Callable[[Optional[Union[str, dict]]], Optional[Union[str, dict]]]): - - # ----[ Safety Checks ]---- - + """ + TODO: doc hook + """ + def decorator(fn : HOOK_TYPE): + """ + TODO: doc decorator + """ + + # ----[ Safety Checks ]---- + if not inspect.iscoroutinefunction(fn): raise TypeError(f"@hook handler '{fn.__name__}' must be async") - + + _check_param_and_return( fn, decorator_name="@hook", - allow_param=(Any, str, dict), - allow_return=(type(None), str, dict, Any), + allow_param=(Any, str, dict), # pyright: ignore[reportArgumentType] + allow_return=(type(None), str, dict, Any), # pyright: ignore[reportArgumentType] logger=self.logger, ) - + if not isinstance(direction, Direction): - raise TypeError(f"Direction for hook must be either Direction.SEND or Direction.RECEIVE") - + raise TypeError("Direction for hook must be either Direction.SEND or Direction.RECEIVE") + if isinstance(priority, int): tuple_priority = (priority,) elif isinstance(priority, tuple) and all(isinstance(p, int) for p in priority): @@ -355,20 +429,23 @@ def decorator(fn: Callable[[Optional[Union[str, dict]]], Optional[Union[str, dic raise ValueError(f"Priority must be an integer or a tuple of integers (got type {type(priority).__name__}: {priority!r})") # ----[ DNA capture ]---- - self._dna_hooks.append({ - "fn": fn, - "direction": direction, - "priority": tuple_priority, - "source": inspect.getsource(fn), - }) + self._dna_hooks.append(DNAHook( + fn = fn, + direction = direction, + priority = tuple_priority, + source = inspect.getsource(fn), + )) - # ----[ Registration Code ]---- async def register(): + """ + ----[ Registration Code ]---- + TODO doc register + """ async with self.hooks_lock: if direction == Direction.RECEIVE: - self.receiving_hooks[tuple_priority] = fn + 
self.receiving_hooks[tuple_priority] = cast(RECEIVING_HOOKS_TYPE, fn) elif direction == Direction.SEND: - self.sending_hooks[tuple_priority] = fn + self.sending_hooks[tuple_priority] = cast(SENDING_HOOKS_TYPE, fn) # ----[ Safe Registration ]---- # NOTE: register() is run ASAP and _registration_tasks is used to wait all registrations before run_client() @@ -381,27 +458,30 @@ async def register(): # ==== RECEIVER REGISTRATION ==== def receive( - self, - route: str, + self, + route: str, priority: Union[int, tuple[int, ...]] = () ): + """ + TODO: doc receive + """ route = route.strip() - def decorator(fn: Callable[[Union[str, dict]], Awaitable[Optional[Event]]]): - + def decorator(fn: RECEIVE_DECORATED_TYPE): + # ----[ Safety Checks ]---- - + if not inspect.iscoroutinefunction(fn): raise TypeError(f"@receive handler '{fn.__name__}' must be async") - + sig = inspect.signature(fn) if len(sig.parameters) != 1: raise TypeError(f"@receive '{fn.__name__}' must accept exactly one argument (payload)") - + _check_param_and_return( fn, decorator_name="@receive", - allow_param=(Any, str, dict), - allow_return=(type(None), Event, Any), + allow_param=(Any, str, dict), # pyright: ignore[reportArgumentType] + allow_return=(type(None), Event, Any), # pyright: ignore[reportArgumentType] logger=self.logger, ) @@ -416,17 +496,17 @@ def decorator(fn: Callable[[Union[str, dict]], Awaitable[Optional[Event]]]): raise ValueError(f"Priority must be an integer or a tuple of integers (got type {type(priority).__name__}: {priority!r})") # ----[ DNA capture ]---- - self._dna_receivers.append({ - "fn": fn, - "route": route, - "priority": tuple_priority, - "source": inspect.getsource(fn), # for text serialization - }) + self._dna_receivers.append(DNAReceiver( + fn = fn, + route = route, + priority = tuple_priority, + source = inspect.getsource(fn), # for text serialization + )) # ----[ Registration Code ]---- async def register(): receiver = Receiver(fn=fn, priority=tuple_priority) - + parsed_route 
= None normalized_route = route @@ -434,7 +514,7 @@ async def register(): try: parsed_route = self._flow.parse_route(route) normalized_route = str(parsed_route) - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.warning( f"@receive: could not parse route {route!r} while flow is enabled; " f"registering raw route. Error: {type(e).__name__}: {e}" @@ -445,7 +525,7 @@ async def register(): async with self.routes_lock: if route in self.receiver_index: self.logger.warning(f"Route '{route}' already exists. Overwriting.") - + if self._flow.in_use and parsed_route is not None: self.receiver_parsed_routes[normalized_route] = parsed_route self.receiver_index[normalized_route] = receiver @@ -456,33 +536,36 @@ async def register(): return fn return decorator - + # ==== SENDER REGISTRATION ==== def send( - self, - route: str, - multi: bool = False, - on_triggers: Optional[set[Signal]] = None, - on_actions: Optional[set[Type]] = None, + self, + route: str, + multi: bool = False, + on_triggers: Optional[set[Any]] = None, + on_actions: Optional[set[Any]] = None, ): + """ + TODO: doc send + """ route = route.strip() - def decorator(fn: Callable[[], Awaitable]): - + def decorator(fn: SEND_DECORATED_TYPE): + # ----[ Safety Checks ]---- if not inspect.iscoroutinefunction(fn): raise TypeError(f"@send sender '{fn.__name__}' must be async") - + sig = inspect.signature(fn) if len(sig.parameters) != 0: raise TypeError(f"@send '{fn.__name__}' must accept no arguments") - + if not multi: _check_param_and_return( fn, decorator_name="@send", allow_param=(), # no args allowed - allow_return=(type(None), Any, str, dict), + allow_return=(type(None), Any, str, dict), # pyright: ignore[reportArgumentType] logger=self.logger, ) else: @@ -490,10 +573,10 @@ def decorator(fn: Callable[[], Awaitable]): fn, decorator_name="@send[multi=True]", allow_param=(), # no args allowed - allow_return=(Any, list, list[str], list[dict], list[Union[str, dict]]), + 
allow_return=(Any, list, list[str], list[dict], list[Union[str, dict]]), # pyright: ignore[reportArgumentType] logger=self.logger, ) - + if not isinstance(route, str): raise TypeError(f"Argument `route` must be string. Provided: {route}") @@ -511,24 +594,24 @@ def decorator(fn: Callable[[], Awaitable]): not all(isinstance(act, type) and issubclass(act, Event) and act in {Action.MOVE, Action.STAY, Action.TEST} for act in on_actions) ): raise TypeError(f"Argument `on_actions` must be `None` or a set of Action event classes: {{Action.MOVE, Action.STAY, Action.TEST}}. Provided: {on_actions!r}") - + # ----[ DNA capture ]---- - self._dna_senders.append({ - "fn": fn, - "route": route, - "multi": multi, - "on_triggers": on_triggers, - "on_actions": on_actions, - "source": inspect.getsource(fn), - }) + self._dna_senders.append(DNASender( + fn = fn, + route = route, + multi = multi, + on_triggers = on_triggers, + on_actions = on_actions, + source = inspect.getsource(fn), + )) # ----[ Registration Code ]---- async def register(): - + sender = Sender(fn=fn, multi=multi, actions=on_actions, triggers=on_triggers) actions_exist = isinstance(on_actions, set) and bool(on_actions) triggers_exist = isinstance(on_triggers, set) and bool(on_triggers) - + parsed_route = None normalized_route = route @@ -536,7 +619,7 @@ async def register(): try: parsed_route = self._flow.parse_route(route) normalized_route = str(parsed_route) - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.warning( f"@send: could not parse route {route!r} while flow is enabled; " f"registering raw route. Error: {type(e).__name__}: {e}" @@ -634,6 +717,7 @@ def _infer_client_binding_name(self) -> str: return "agent" + #pylint:disable=too-many-locals, too-many-branches, too-many-statements def dna(self, include_context: bool = False) -> str: """ Serialize this client's registered behavior into a JSON string ("DNA"). 
@@ -689,98 +773,41 @@ def dna(self, include_context: bool = False) -> str: - Flow construction and trigger enum creation are not captured unless they are referenced and executed within decorated handler sources. """ + #pylint:disable=import-outside-toplevel import builtins - # Collect handler functions once so we can reuse them for context analysis. + # TODO: Only used below, so could have not turned into a list right away + # Keep only as a generator being iterated over handler_fns = list(self._iter_registered_handler_functions()) # ---------------------------- # Handler DNA entries # ---------------------------- - entries: list[dict] = [] + entries: List[ + Dict[str, str] | \ + Dict[str, str | tuple[int,...]] |\ + Dict[str, str | bool | List[str]] + ] = [] # Upload state hook, if present if self._dna_upload_states is not None: - fn = self._dna_upload_states["fn"] - entries.append({ - "type": "upload_states", - "source": get_callable_source(fn, self._dna_upload_states.get("source")), - "module": fn.__module__, - "fn_name": fn.__name__, - }) + entries.append(upload_entry_contribution(self._dna_upload_states)) # Download state hook, if present if self._dna_download_states is not None: - fn = self._dna_download_states["fn"] - entries.append({ - "type": "download_states", - "source": get_callable_source(fn, self._dna_download_states.get("source")), - "module": fn.__module__, - "fn_name": fn.__name__, - }) + entries.append(download_entry_contribution(self._dna_download_states)) # All receivers for dna in self._dna_receivers: - fn = dna["fn"] - raw_route = dna["route"] - - try: - if self._flow.in_use: - route_key = str(self._flow.parse_route(raw_route)) - else: - route_key = raw_route - except Exception: - route_key = raw_route - route_key = "".join(str(route_key).split()) - - entries.append({ - "type": "receive", - "route": raw_route, # original route string - "route_key": route_key, # stable route representative - "priority": dna["priority"], - "source": 
get_callable_source(fn, dna.get("source")), - "module": fn.__module__, - "fn_name": fn.__name__, - }) + entries.append(receiver_entry_contribution(dna, self._flow if self._flow.in_use else None)) # All senders for dna in self._dna_senders: - fn = dna["fn"] - raw_route = dna["route"] - - try: - if self._flow.in_use: - route_key = str(self._flow.parse_route(raw_route)) - else: - route_key = raw_route - except Exception: - route_key = raw_route - route_key = "".join(str(route_key).split()) - - entries.append({ - "type": "send", - "route": raw_route, # original route string - "route_key": route_key, # stable route representative - "multi": dna["multi"], - # Serialize triggers/actions by name so they can be re-resolved later. - "on_triggers": [t.name for t in (dna["on_triggers"] or [])], - "on_actions": [a.__name__ for a in (dna["on_actions"] or [])], - "source": get_callable_source(fn, dna.get("source")), - "module": fn.__module__, - "fn_name": fn.__name__, - }) + entries.append(sender_entry_contribution(dna, self._flow if self._flow.in_use else None)) # All hooks for dna in self._dna_hooks: - fn = dna["fn"] - entries.append({ - "type": "hook", - "direction": dna["direction"].name, - "priority": dna["priority"], - "source": get_callable_source(fn, dna.get("source")), - "module": fn.__module__, - "fn_name": fn.__name__, - }) + entries.append(hook_entry_contribution(dna)) # Fast path: return only handler entries. if not include_context: @@ -809,11 +836,12 @@ def dna(self, include_context: bool = False) -> str: # Identify the runtime type of asyncio.Lock() so we can detect it. try: lock_type = type(asyncio.Lock()) - except Exception: + except Exception:# pylint:disable=broad-exception-caught lock_type = None # Used by resolve_import_statement to avoid emitting repeated imports. known_modules: set[str] = set() + # Collect handler functions once so we can reuse them for context analysis. 
for fn in handler_fns: g = getattr(fn, "__globals__", None) @@ -821,7 +849,10 @@ def dna(self, include_context: bool = False) -> str: continue # Names referenced by the function body. - names_to_scan = set(getattr(fn, "__code__", None).co_names if hasattr(fn, "__code__") else ()) + if hasattr(fn, "__code__"): + names_to_scan = set(fn.__code__.co_names) + else: + names_to_scan: Set[str] = set() # Names referenced only via annotations. try: @@ -831,14 +862,14 @@ def dna(self, include_context: bool = False) -> str: nm = getattr(v, "__name__", None) if isinstance(nm, str) and nm: names_to_scan.add(nm) - except Exception: + except Exception:# pylint:disable=broad-exception-caught pass # Fallback: parse identifiers from annotation syntax in the source. try: src = get_callable_source(fn) names_to_scan |= extract_annotation_identifiers(src) - except Exception: + except Exception:# pylint:disable=broad-exception-caught pass for name in names_to_scan: @@ -872,6 +903,7 @@ def dna(self, include_context: bool = False) -> str: recipes.setdefault(name, r) if "Path(" in r: path_needed = True + #pylint:disable=consider-using-f-string if "{}(".format(Node.__name__) in r: imports_out.add("from summoner.protocol.process import Node") continue @@ -893,7 +925,7 @@ def dna(self, include_context: bool = False) -> str: if path_needed: imports_out.add("from pathlib import Path") - context_entry = { + context_entry : Dict[str, str | list[str] | dict[str,object] | dict[str,str]] = { "type": "__context__", "var_name": inferred_var_name, "imports": sorted(imports_out), @@ -945,16 +977,20 @@ async def _read_line_safe( continue return data - + + # pylint:disable=too-many-locals, too-many-branches, too-many-statements async def message_receiver_loop( - self, - reader: asyncio.StreamReader, + self, + reader: asyncio.StreamReader, stop_event: asyncio.Event ): - + """ + TODO: doc message receiver + """ + # ----[ Wrapper: Interpret Protocol-Only Errors as None ]---- - async def _safe_call(fn: 
Callable[[Any], Awaitable], payload: Any) -> Any: + async def _safe_call(fn: Callable[[Any],Awaitable[Any]], payload: Any) -> Optional[Any]: try: return await fn(payload) except BlockingIOError: @@ -963,21 +999,21 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: except Exception as e: self.logger.exception(f"Receiver function {fn.__name__} raised an unexpected error: {e}") raise - + try: - + # ----[ Constantly Listen ]---- while not stop_event.is_set(): - + async with self.connection_lock: if self._quit or self._travel: stop_event.set() break - + # ----[ Prepare Receiver Batches ]---- async with self.routes_lock: receiver_index: dict[str, Receiver] = self.receiver_index.copy() - + if self._flow.in_use: async with self.routes_lock: receiver_parsed_routes: dict[str, ParsedRoute] = self.receiver_parsed_routes.copy() @@ -986,29 +1022,30 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: # ----[ Empty: Skip and Prevent Client Overwhelming ]---- if not receiver_index: data = await self._read_line_safe( - reader, - limit=self.max_bytes_per_line, + reader, + limit=self.max_bytes_per_line, # max_bytes_per_line is set to an integer by now # pyright: ignore[reportArgumentType] timeout=0.1, ) # if not data: # raise ServerDisconnected("EOF while dropping messages") continue - + # ----[ Build and Run Receiver Batches ]---- try: - + # ----[ Build: Get Messages ]---- - + data = await self._read_line_safe( - reader, - limit=self.max_bytes_per_line, + reader, + limit=self.max_bytes_per_line, # max_bytes_per_line is set to an integer by now # pyright: ignore[reportArgumentType] timeout=self.read_timeout_seconds, ) # data = await reader.readline() # if not data: # raise ServerDisconnected("Server closed the connection.") - payload: RelayedMessage = recover_with_types(data.decode()) + pre_payload: RelayedMessage = recover_with_types(data.decode()) + payload = cast(Union[RelayedMessage, None], pre_payload) # ----[ Build: Validation 
]---- async with self.hooks_lock: @@ -1021,26 +1058,26 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: if new_payload is None: payload = None break - - except Exception as e: + payload = new_payload + + except Exception as e:# pylint:disable=broad-exception-caught self.logger.error( f"Receiving hook {receiving_hook.__name__} (priority={priority}) " f"failed on payload {payload!r}: {e}", exc_info=True ) - new_payload = payload - payload = new_payload - + # if *any* hook returned None, skip the rest of processing if payload is None: continue - + # ----[ Build: Organize Batches by Priority ]---- - batches: dict[tuple[int, ...], list[Callable[[Any], Awaitable]]] = {} + batches: dict[tuple[int, ...], list[Callable[[Any], Coroutine[Any,Any,Any]]]] = {} if self._flow.in_use: raw_states = (await self._upload_states(payload)) if self._upload_states is not None else None tape = StateTape(raw_states) - activation_index = tape.collect_activations(receiver_index=receiver_index, parsed_routes=receiver_parsed_routes) + activation_index = tape.collect_activations( + receiver_index=receiver_index, parsed_routes=receiver_parsed_routes) # pyright: ignore[reportPossiblyUnboundVariable] batches = {priority: [activation.fn for activation in activations] for priority, activations in activation_index.items()} else: for _, receiver in receiver_index.items(): @@ -1052,14 +1089,14 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: _ = await reader.readline() # Space await asyncio.sleep(0.1) # Time continue - + # ----[ Exec: Prepare Passage Receiver → Sender ]---- if self._flow.in_use: - event_buffer: dict[tuple[int, ...], list[tuple[Optional[str], ParsedRoute, Event]]] = defaultdict(list) + event_buffer: dict[tuple[int, ...], list[tuple[Optional[str], ParsedRoute, Optional[Event]]]] = defaultdict(list) # ----[ Exec: Run Batches in Order ]---- for priority, batch_fns in sorted(batches.items(), key=lambda kv: kv[0]): - + # ----[ Before: 
Run Batch ]---- # label = "default priority" if priority == () else f"priority {priority}" # self.logger.info(f"Running batch at {label}, {len(batch_fns)} receivers") @@ -1069,35 +1106,38 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: # ----[ After: Handle Returns ]---- if self._flow.in_use: - activations = activation_index[priority] - - local_tape = tape.refresh() - to_extend: dict[str, list[Node]] = defaultdict(list) + activations = activation_index[priority] # pyright: ignore[reportPossiblyUnboundVariable] + + local_tape = tape.refresh() # pyright: ignore[reportPossiblyUnboundVariable] + to_extend: dict[Optional[str], list[Node]] = defaultdict(list) for act, event in zip(activations, events): to_extend[act.key].extend(act.route.activated_nodes(event)) local_tape.extend(to_extend) buffer_entries = [(act.key, act.route, event) for act, event in zip(activations, events)] - event_buffer[priority].extend(buffer_entries) - + # event buffer is bound because this is within if self._flow_in_use + # some of the event's in buffer_entries could have been None + event_buffer[priority].extend(buffer_entries) # pyright: ignore[reportPossiblyUnboundVariable] + if self._download_states is not None: await self._download_states(local_tape.revert()) # ----[ Final: Pass Data Over To Senders ]---- if self._flow.in_use: - - for priority, event_list in sorted(event_buffer.items(), key=lambda kv: kv[0]): + # event buffer is bound because this is within if self._flow_in_use + for priority, event_list in sorted(event_buffer.items(), key=lambda kv: kv[0]): # pyright: ignore[reportPossiblyUnboundVariable] for event_data in event_list: # this will block if the bridge is full, slowing down readers - await self.event_bridge.put((priority,) + event_data) + to_put = (priority,) + event_data + await self.event_bridge.put(to_put) # pyright: ignore[reportOptionalMemberAccess] event_buffer = {} - + except ServerDisconnected as e: # Intentionally propagate this so 
reconnection logic can trigger self.logger.info(f"Graceful disconnect from server: {e}") raise - + except (ConnectionResetError, BrokenPipeError) as e: self.logger.warning(f"Socket-level failure (client likely to blame): {e}") break @@ -1110,10 +1150,12 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: def _start_send_workers( self, - writer: asyncio.StreamWriter, + writer: asyncio.StreamWriter, stop_event: asyncio.Event ): if not self.send_workers_started: + if self.max_concurrent_workers is None: + raise ValueError("_apply_config will make sure that the maximum number of workers is set to an integer ≥ 1") for _ in range(self.max_concurrent_workers): worker_task = self.loop.create_task(self._send_worker(writer, stop_event)) self.worker_tasks.append(worker_task) @@ -1121,18 +1163,22 @@ def _start_send_workers( async def _send_worker( self, - writer: asyncio.StreamWriter, + writer: asyncio.StreamWriter, stop_event: asyncio.Event ): consecutive_errors = 0 while True: - + + if self.send_queue is None: + # pylint:disable=line-too-long + raise ValueError("send_queue is not initialized; this should not happen because this is protected method and called within handle_session") + item: Optional[tuple[str, Sender]] = await self.send_queue.get() if item is None: self.send_queue.task_done() break - + route, sender = item try: result = await sender.fn() @@ -1149,10 +1195,10 @@ async def _send_worker( # ----[ Unpack: Handle Multi Sends ]---- payloads = result if sender.multi else [result] for payload in payloads: - + if payload is None: continue - + # ----[ Unpack: Validation ]---- async with self.hooks_lock: sending_hooks = self.sending_hooks.copy() @@ -1160,19 +1206,18 @@ async def _send_worker( for priority, sending_hook in sorted(sending_hooks.items(), key=lambda kv: hook_priority_order(kv[0])): try: new_payload = await sending_hook(payload) - + if new_payload is None: payload = None break - - except Exception as e: + payload = new_payload + + 
except Exception as e:# pylint:disable=broad-exception-caught self.logger.error( f"[route={route}] Sending hook {sending_hook.__name__} (priority={priority}) " f"failed on payload {payload!r}: {e}", exc_info=True ) - new_payload = payload - payload = new_payload # if *any* hook returned None, skip the rest of processing if payload is None: @@ -1191,13 +1236,13 @@ async def _send_worker( # ----[ Unpack: Post Messages ]---- async with self.writer_lock: writer.write(message) - + # No concurrency on batch_drain (initialized in run()) if not self.batch_drain: async with self.writer_lock: await writer.drain() - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught consecutive_errors += 1 self.logger.error( f"Worker for {sender.fn.__name__} crashed ({consecutive_errors} in a row): {e}", @@ -1213,6 +1258,12 @@ async def _send_worker( self.send_queue.task_done() async def _cleanup_workers(self): + """ + Cancel all worker taks + Gather raising any exceptions caused in the worker_tasks + Clear them + Set back to having not started send workers + """ for w in self.worker_tasks: w.cancel() if self.worker_tasks: @@ -1220,17 +1271,24 @@ async def _cleanup_workers(self): self.worker_tasks.clear() self.send_workers_started = False + #pylint:disable=too-many-nested-blocks, no-else-continue, no-else-break async def message_sender_loop( - self, - writer: asyncio.StreamWriter, + self, + writer: asyncio.StreamWriter, stop_event: asyncio.Event ): + """ + TODO: doc message_sender + """ # ----[ Helper: Matches Routes Between Senders and Receivers to Trigger Send ]---- def _route_accepts( - sender_pr: ParsedRoute, + sender_pr: ParsedRoute, receiver_pr: ParsedRoute ) -> bool: + """ + TODO: doc route_accepts + """ source_ok = all(any(n.accepts(m) for m in receiver_pr.source) for n in sender_pr.source) label_ok = all(any(n.accepts(m) for m in receiver_pr.label) for n in sender_pr.label) target_ok = all(any(n.accepts(m) for m in receiver_pr.target) for n in 
sender_pr.target) @@ -1241,28 +1299,29 @@ def _route_accepts( # ----[ Keep Sending While Actively Listening (No Travel) ]---- while not stop_event.is_set(): - + # ----[ Prepare Sender Batch ]---- - + async with self.routes_lock: sender_index: dict[str, list[Sender]] = self.sender_index.copy() # ----[ Fast upload of pending event data ]---- if self._flow.in_use: - + async with self.routes_lock: sender_parsed_routes: dict[str, ParsedRoute] = self.sender_parsed_routes.copy() - - pending: list[tuple[tuple[int, ...], Optional[str], ParsedRoute, Event]] = [] + + pending: list[tuple[tuple[int, ...], Optional[str], ParsedRoute, Optional[Event]]] = [] try: while True: - pending.append(self.event_bridge.get_nowait()) + # event_bridge exists, handle_session did so + pending.append(self.event_bridge.get_nowait()) # pyright: ignore[reportOptionalMemberAccess] except asyncio.QueueEmpty: pass - + pending.sort(key=lambda it: hook_priority_order(it[0])) - # ----[ Build Sender Batch ]---- + # ----[ Build Sender Batch ]---- senders: list[tuple[str, Sender]] = [] # De-dup set: at most one sender per (route, key-from-recv, recv-handler-name) this cycle. 
@@ -1270,48 +1329,55 @@ def _route_accepts( for route, routed_senders in sender_index.items(): for sender in routed_senders: - + # Non-reactive (no actions/triggers): preserve current behavior if (not self._flow.in_use) or (sender.actions is None and sender.triggers is None): senders.append((route, sender)) - + # Reactive: require matching a pending activation (existential) - elif self._flow.in_use and ((sender.actions and isinstance(sender.actions, set)) or + elif self._flow.in_use and ((sender.actions and isinstance(sender.actions, set)) or (sender.triggers and isinstance(sender.triggers, set))): - + + # self._flow_in_use so sender_parsed_routes is bound + sender_parsed_routes = cast(Dict[str,ParsedRoute],sender_parsed_routes) # pyright: ignore[reportPossiblyUnboundVariable] sender_parsed_route = sender_parsed_routes.get(route) if sender_parsed_route is None: continue - + # Iterate pending in queue order; first match "wins" for this (route,key,fn_name) - for (priority, key, parsed_route, event) in pending: + # _flow_in_use so pending is bound + for (_priority, key, parsed_route, event) in pending: # pyright: ignore[reportPossiblyUnboundVariable] if _route_accepts(sender_parsed_route, parsed_route) and sender.responds_to(event): dedup_key = (route, key, sender.fn.__name__) # key scopes to the activation thread/peer if dedup_key not in emitted: senders.append((route, sender)) emitted.add(dedup_key) break # do not enqueue multiple times for this sender this cycle - + # ----[ Empty: Skip and Prevent Client Overwhelming | Almost full: warning ]---- if not senders: await asyncio.sleep(0.1) # Time continue else: - queue_size = self.send_queue.qsize() + # send_queue exists, handle_session did so + queue_size = self.send_queue.qsize() # pyright: ignore[reportOptionalMemberAccess] expected_queue_size = queue_size + len(senders) - if expected_queue_size > self.send_queue_maxsize * 0.8: # 80% full + # self.send_queue_maxsize is set to an integer by now + if 
expected_queue_size > self.send_queue_maxsize * 0.8: # pyright: ignore[reportOptionalOperand] # 80% full self.logger.warning(f"Queue is about to exceed 80% its capacity; Attempted load size: {expected_queue_size} out of {self.send_queue_maxsize}") # ----[ Enqueue Sender Batch | Senders Are Run in Background ]---- try: for sender in senders: - await self.send_queue.put(sender) # Will block if full (i.e., back-pressure) + # self.send_queue exists, handle_session did so + await self.send_queue.put(sender) # pyright: ignore[reportOptionalMemberAccess] # Will block if full (i.e., back-pressure) except asyncio.CancelledError: self.logger.info("Sender enqueue loop cancelled mid-batch.") raise # ----[ Wait for Sender Batch to Finish]---- - await self.send_queue.join() + # self.send_queue exists, handle_session did so + await self.send_queue.join() # pyright: ignore[reportOptionalMemberAccess] if self.batch_drain: async with self.writer_lock: @@ -1330,6 +1396,8 @@ def _route_accepts( finally: # Best-effort signal to workers; never block on shutdown if self.send_queue is not None: + if self.max_concurrent_workers is None: + raise ValueError("_apply_config will make sure that the maximum number of workers is set to an integer ≥ 1") # This may result in redundant cancellation if shutdown() is also called, # but guarantees all workers get signaled even in abrupt exits. for _ in range(self.max_concurrent_workers): @@ -1351,15 +1419,16 @@ async def handle_session(self, host: str = '127.0.0.1', port: int = 8888): Run listener and sender concurrently; whichever exits first (due to disconnect, /quit or /travel) triggers session termination. The remaining task is cancelled. """ - while True: + # Shared flag between the two tasks to signal coordinated session termination stop_event = asyncio.Event() # Always clean up old worker tasks and queues before starting new session; guarantees a fresh worker batch and prevents zombie tasks. 
await self._cleanup_workers() - self.send_queue = asyncio.Queue(maxsize=self.send_queue_maxsize) - self.event_bridge = asyncio.Queue(maxsize = self.event_bridge_maxsize) + # The maxsize's are set to integers by now by _apply_config, so these queues will not raise TypeError for maxsize=None + self.send_queue = asyncio.Queue(maxsize=self.send_queue_maxsize) # pyright: ignore[reportArgumentType] + self.event_bridge = asyncio.Queue(maxsize = self.event_bridge_maxsize) # pyright: ignore[reportArgumentType] # reset any previous travel/quit intent so each session starts fresh; # travel is only honored if set after this point, quit likewise @@ -1372,7 +1441,9 @@ async def handle_session(self, host: str = '127.0.0.1', port: int = 8888): # Register this session's task so it can be cancelled during shutdown current_task = asyncio.current_task() async with self.tasks_lock: - self.active_tasks.add(current_task) + # current_task assumed not None + # but we are within the try block + self.active_tasks.add(current_task) # pyright: ignore[reportArgumentType] # Use lock when accessing dynamic routing information async with self.connection_lock: @@ -1410,32 +1481,36 @@ async def handle_session(self, host: str = '127.0.0.1', port: int = 8888): await task except ServerDisconnected as e: # Propagate server-side disconnection to the reconnection handler - raise ServerDisconnected(e) + raise ServerDisconnected(e) #pylint:disable=raise-missing-from except asyncio.CancelledError: # Normal during shutdown; ignore pass - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.exception(f"Unexpected error during session task: {e}") # Cleanly close the connection writer.close() await writer.wait_closed() self.logger.info("Disconnected from server.") - + finally: - + # Ensure both child tasks are cancelled & awaited even if we were cancelled mid-wait for task in (listen_task, sender_task): if task is not None and not task.done(): task.cancel() - + # 
def shutdown(self):
    """
    Log the shutdown notice, then request cancellation of every task that
    asyncio.all_tasks reports for this client's loop.
    """
    self.logger.info("Client is shutting down...")
    # Snapshot the task set first, then cancel each member.
    pending = asyncio.all_tasks(self.loop)
    for pending_task in pending:
        pending_task.cancel()
async def run_client(self, host: str = '127.0.0.1', port: int = 8888):
    """
    Drive the connect / retry / fallback cycle until the client should stop.

    Each iteration calls ``_retry_loop(host, port, limit, stage)``:

    - In the "Primary" stage the limit is ``primary_retry_limit``.
    - In the "Default" stage the limit is ``default_retry_limit``.

    When ``_retry_loop`` reports success, ``_get_client_intent`` decides what
    happens next: ``ClientIntent.TRAVEL`` restarts the cycle in the primary
    stage; any other intent (QUIT, ABORT, ...) ends the loop.

    When ``_retry_loop`` reports failure in the primary stage, ``_fallback``
    is awaited once and the cycle continues in the default stage. A failure
    in the default stage logs a critical message and ends the loop.

    Parameters
    ----------
    host, port:
        Forwarded unchanged to ``_retry_loop`` on every iteration.
        NOTE(review): the same values are passed after ``_fallback()``;
        presumably the session resolves the real endpoint from
        ``self.host`` / ``self.port`` -- confirm against ``handle_session``.
    """
    primary_stage = True

    while True:
        stage = "Primary" if primary_stage else "Default"
        limit = self.primary_retry_limit if primary_stage else self.default_retry_limit

        if await self._retry_loop(host, port, limit, stage):
            intent = await self._get_client_intent()
            if intent is ClientIntent.TRAVEL:
                # Travelling starts a fresh primary stage.
                primary_stage = True
                continue
            # QUIT and every other intent both terminate the client loop.
            break

        if primary_stage:
            primary_stage = False
            await self._fallback()
            self.logger.warning(f"Falling back to default server at {self.default_host}:{self.default_port}")
            continue

        # Use the same "None means unlimited" rule as _retry_loop's own log line.
        # `limit or '∞'` would print "∞" for a configured limit of 0.
        shown_limit = self.default_retry_limit if self.default_retry_limit is not None else '∞'
        self.logger.critical(
            f"Cannot connect to fallback {self.default_host}:{self.default_port} after "
            f"{shown_limit} attempts; exiting"
        )
        break
parameters client_config = load_config(config_path=config_path, debug=True) elif isinstance(config_dict, dict): # Shallow copy to avoid external mutation - client_config = dict(config_dict) + client_config = dict(config_dict) else: raise TypeError(f"SummonerClient.run: config_dict must be a dict or None, got {type(config_dict).__name__}") - + # client_config = load_config(config_path=config_path, debug=True) self._apply_config(client_config) @@ -1608,6 +1711,18 @@ def run( self.loop.run_until_complete(asyncio.gather(*self.worker_tasks, return_exceptions=True)) except (asyncio.CancelledError, KeyboardInterrupt): pass - + self.loop.close() - self.logger.info("Client exited cleanly.") \ No newline at end of file + self.logger.info("Client exited cleanly.") + + def _view_candidates(self) -> Generator[Optional[Callable[..., Coroutine[Any,Any,Any] | Awaitable[Any]]],None,None]: + if self._upload_states is not None: + yield self._upload_states + if self._download_states is not None: + yield self._download_states + for d in self._dna_receivers: + yield d.get("fn") + for d in self._dna_senders: + yield d.get("fn") + for d in self._dna_hooks: + yield d.get("fn") diff --git a/summoner/client/client_types.py b/summoner/client/client_types.py new file mode 100644 index 0000000..850a76e --- /dev/null +++ b/summoner/client/client_types.py @@ -0,0 +1,67 @@ +""" +Types used for client and client DNA +""" +#pylint:disable=wrong-import-position, invalid-name +from typing import ( +Dict, +List, +Optional, +Callable, +TypedDict, +Union, +Coroutine, +) +from typing import Any +import os +import sys + +target_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) +if target_path not in sys.path: + sys.path.insert(0, target_path) +from summoner.protocol.process import Node +from summoner.protocol.triggers import Event +from summoner.protocol.payload import RelayedMessage + +#pylint:disable=pointless-string-statement +""" +See the 
# ---- Hooks ----
# A hook is an async callable taking one argument and returning a payload or None.
# NOTE(review): `Any` appears in both unions, so a type checker will treat them
# as `Any`; the other members mainly document intent.
HOOK_ARGUMENT = Any | str | Dict[Any,Any]
HOOK_RETURN = None | str | Dict[Any,Any] | Any
HOOK_TYPE = Callable[[HOOK_ARGUMENT], Coroutine[Any,Any,HOOK_RETURN]]

# ---- State upload / download ----
# Return shapes listed for an upload-states function: a single str/Node, a list
# of them, or a dict keyed by str whose values are a str/Node or lists of them.
UPLOAD_RETURN = Union[None | str | Any | Node | List[Any] | Dict[Any,Any]\
    | List[str] | Dict[str, str] | Dict[str, List[str]]\
    | List[Node] | Dict[str, Node] | Dict[str, List[Node]]\
    | Dict[str, Union[str, list[str]]]\
    | Dict[str, Union[Node, list[Node]]]\
    | Dict[str, Union[str, list[str], Node, list[Node]]]
]
# NOTE(review): `TypedDict` here is the typing factory itself, not a concrete
# TypedDict class -- confirm this is the intended spelling.
UPLOAD_ARGUMENT = None | Any | str | Dict[Any,Any] | TypedDict
UPLOAD_TYPE = Callable[[UPLOAD_ARGUMENT], Coroutine[Any,Any,UPLOAD_RETURN]]

# Async callable taking one argument of any type.
DOWNLOAD_TYPE = Callable[[Any], Coroutine[Any,Any,Optional[Any]]]

# ---- Registered hook callables, split by direction ----
# The receiving variant additionally accepts a RelayedMessage as input.
SENDING_HOOKS_TYPE = Callable[
    [Optional[Union[str, dict]]],
    Coroutine[Any,Any,Optional[Union[str, dict]]]]
RECEIVING_HOOKS_TYPE = Callable[
    [Optional[Union[str, dict, RelayedMessage]]],
    Coroutine[Any,Any,Optional[Union[str, dict]]]]

# ---- Decorated send handlers ----
# Zero-argument async callables; the "multi" variant's return union lists lists of payloads.
SEND_RETURN_SINGLE_TYPE = None | Any | str | Dict[Any,Any]
SEND_RETURN_MULTI_TYPE = Any | List[Any] | List[str] | \
    List[Dict[Any,Any]] | List[Union[str, Dict[Any,Any]]]
SEND_DECORATED_TYPE = Callable[[], Coroutine[Any, Any, SEND_RETURN_SINGLE_TYPE]] |\
    Callable[[], Coroutine[Any, Any, SEND_RETURN_MULTI_TYPE]]

# ---- Decorated receive handlers ----
# One-argument async callables; the return union includes Event and None.
RECEIVE_DECORATED_TYPE = Callable[
    [Any | str | Dict[Any,Any]],
    Coroutine[Any, Any, Any | Event | None]
]
class DNAHook(TypedDict):
    """
    The sort of information kept in a SummonerClient's _dna_hooks list
    """
    # The hook callable. hook_entry_contribution records its __module__ and
    # __name__ and hands it to get_callable_source.
    fn: HOOK_TYPE
    # Serialized by name: hook_entry_contribution writes direction.name.
    direction: Direction
    # Copied as-is into the DNA entry under "priority".
    priority: tuple[int, ...]
    # Passed to get_callable_source as its second argument; may be None.
    source: Optional[str]
def sender_entry_contribution(
        sender_entry: DNASender,
        flow_in_use: Optional[Flow]) -> dict[str, str | bool | List[str]]:
    """
    Build the DNA dict contributed by one registered sender.

    Parameters
    ----------
    sender_entry:
        The sender record (callable, raw route, multi flag, trigger/action
        sets and optional recorded source).
    flow_in_use:
        When not None, its parse_route() is used to derive the route key.
        When None, the raw route string is used.

    Returns
    -------
    dict
        Keys: "type", "route", "route_key", "multi", "on_triggers",
        "on_actions", "source", "module", "fn_name".

    Notes
    -----
    - "route_key" is the parsed (or raw) route with all whitespace removed.
      If parsing raises, the raw route is used instead of failing.
    - "on_triggers" / "on_actions" are stored in the entry as sets, whose
      iteration order is arbitrary. The names are sorted here so that the
      same sender always serializes to the same DNA.
    """
    fn = sender_entry["fn"]
    raw_route = sender_entry["route"]

    try:
        if flow_in_use is not None:
            route_key = str(flow_in_use.parse_route(raw_route))
        else:
            route_key = raw_route
    except Exception: #pylint:disable=broad-exception-caught
        # Best effort: an unparsable route still gets a usable key.
        route_key = raw_route
    route_key = "".join(str(route_key).split())

    # Serialize triggers/actions by name so they can be re-resolved later.
    # Sorted for a deterministic result (assumes names are strings).
    trigger_names = sorted(t.name for t in (sender_entry["on_triggers"] or ()))
    action_names = sorted(a.__name__ for a in (sender_entry["on_actions"] or ()))

    return {
        "type": "send",
        "route": raw_route,      # original route string
        "route_key": route_key,  # stable route representative
        "multi": sender_entry["multi"],
        "on_triggers": trigger_names,
        "on_actions": action_names,
        "source": get_callable_source(fn, sender_entry.get("source")),
        "module": fn.__module__,
        "fn_name": fn.__name__,
    }
+ """ + return { + "type": "download_states", + "source": get_callable_source(download_entry["fn"], download_entry["source"]), + "module": download_entry["fn"].__module__, + "fn_name": download_entry["fn"].__name__, + } diff --git a/summoner/client/just_merger.py b/summoner/client/just_merger.py new file mode 100644 index 0000000..2a048e2 --- /dev/null +++ b/summoner/client/just_merger.py @@ -0,0 +1,894 @@ +""" +merger.py + +This module provides two related utilities built on top of SummonerClient: + +1) ClientMerger + Build a single composite SummonerClient by replaying handlers from multiple sources. + + A "source" can be: + - an imported SummonerClient instance (live Python object), or + - a DNA list (already loaded JSON list[dict]), or + - a DNA JSON file path. + + Imported-client sources: + - handlers keep their original module globals (module-backed execution), + - the original client binding (for example the name "agent") is rebound to the merged client, + - optional rebind_globals are injected into handler globals. + + DNA sources: + - handlers are reconstructed by compiling their recorded source text into an isolated + sandbox module (one sandbox per DNA source), + - the sandbox binds var_name (for example "agent") to the merged client instance, + so handler code that references `agent` executes against the composite client, + - optional context (imports, globals, recipes) is applied into the sandbox. + + Usage pattern: + - instantiate ClientMerger(...) + - configure flow / styles as usual on the merged client if desired + - call agent.initiate_all() to replay handlers onto the merged client + - call agent.run(...) + +2) ClientTranslation + Reconstruct a fresh SummonerClient from a DNA list. + + Translation compiles handler functions from their recorded source into a fresh sandbox module, + binds var_name (for example "agent") to the translated client, then registers the handlers + using the normal decorators. 
def _resolve_trigger(TriggerCls, name: str) -> Any:
    """
    Turn a trigger name stored in DNA back into the trigger object.

    Two lookups are attempted, in this order:
      1. indexing, as for an Enum:  TriggerCls["ok"]
      2. attribute access:          TriggerCls.ok

    Parameters
    ----------
    TriggerCls:
        Trigger class or enum-like object returned by flow.triggers() or load_triggers().
    name:
        Trigger name as stored in DNA.

    Returns
    -------
    Any
        Whatever the first successful lookup produced.

    Raises
    ------
    KeyError
        If neither lookup succeeds.
    """
    # Any failure of a lookup (not just KeyError/AttributeError) means "try the next one".
    with suppress(Exception):
        return TriggerCls[name]
    with suppress(Exception):
        return getattr(TriggerCls, name)
    raise KeyError(f"Unknown trigger '{name}' for {TriggerCls}")
+ """ + label: str + succeeded: list[str] + failed: list[tuple[str, str]] # list of (item, error) + skipped: list[str] + +class NormalizedClientSource(TypedDict): + """ + See _normalize_source for the meaning of these fields. + """ + kind: Literal["client"] + var_name: str + client: SummonerClient + +class NormalizedDNASource(TypedDict): + """ + See _normalize_source for the meaning of these fields. + """ + kind: Literal["dna"] + var_name: str + dna_entries: List[Dict[Any,Any]] + context: Optional[Dict[Any,Any]] + sandbox_name: str + globals: Dict[Any,Any] + import_report: StructuredReport + +def just_client_source(arbitrary_source: NormalizedDNASource | NormalizedClientSource) -> TypeGuard[NormalizedClientSource]: + """ + A type guard to help with the fact that self.sources is a list of two different dict types. + """ + return arbitrary_source["kind"] == "client" + +class ClientMerger(SummonerClient): + """ + Merge multiple sources into one client. + + Each input source can be: + - an imported SummonerClient instance (module-backed execution), + - a DNA list (list[dict]), + - a DNA JSON file path. + + Imported-client sources + ----------------------- + - Handlers keep their original module globals. + - The original client binding (var_name such as "agent") is rebound to the merged client. + - Optional rebind_globals are injected into handler globals. + - Note: rebinding mutates handler globals. This is intentional. + + DNA sources + ----------- + - Each DNA source gets its own sandbox module (isolated globals dict). + - var_name is bound to the merged client in that sandbox, so handler code referencing + `agent` executes against the composite client. + - Optional context imports/globals/recipes are executed in the sandbox. + + Execution model + --------------- + ClientMerger does not automatically register handlers during __init__. + You must call initiate_all() (or initiate_* individually) before run(). 
+ + Safety + ------ + This class executes trusted code from DNA via exec()/eval(). Do not run untrusted DNA. + """ + + # pylint:disable=too-many-arguments, too-many-positional-arguments + def __init__( + self, + named_clients: list[Any], # backward compatible: list[dict] or list[SummonerClient] or list[dna_list] + name: Optional[str] = None, + rebind_globals: Optional[dict[str, Any]] = None, + allow_context_imports: bool = True, + verbose_context_imports: bool = False, + close_subclients: bool = True, + ): + super().__init__(name=name) + + # Globals injected into: + # - sandbox globals dicts (DNA sources), and + # - imported handler globals (imported-client sources). + # This is how "missing" symbols like Trigger or shared objects are supplied. + self._rebind_globals = dict(rebind_globals or {}) + + # Context controls: + # - allow_context_imports: execute import lines found in DNA context + # - verbose_context_imports: log successes as well as failures + self._allow_context_imports = allow_context_imports + self._verbose_context_imports = verbose_context_imports + + # If True, imported template clients are cleaned up after extraction: + # cancel/drain registration tasks and close their event loops when possible. + # This reduces warnings when importing agent scripts as templates. + self._close_subclients = close_subclients + + # Normalized sources used later by initiate_* replay methods. 
+ self.sources: list[NormalizedClientSource | NormalizedDNASource] = [] + self._import_reports: list[StructuredReport] = [] + + for idx, entry in enumerate(named_clients): + src = self._normalize_source(entry, idx) + self.sources.append(src) + + if self._close_subclients: + self._shutdown_imported_clients() + + # ---------------------------- + # Source normalization + # ---------------------------- + + #pylint:disable=too-many-branches + def _normalize_source(self, + entry: SummonerClient | List[Dict[Any,Any]] | dict[str, Any], + idx: int) -> NormalizedClientSource | NormalizedDNASource: + """ + Normalize a user-provided source specification into a canonical dict. + + Accepted inputs + --------------- + - SummonerClient instance + - DNA list (list[dict]) + - dict containing one of: {"client"}, {"dna_list"}, {"dna_path"} + + Normalized output + ----------------- + Returns a dict with at least: + - kind: "client" or "dna" + - var_name: global name used by handler sources to refer to the client ("agent" by default) + + For kind="client": + - client: the imported SummonerClient instance + + For kind="dna": + - dna_entries: handler entries (context removed if present) + - context: optional __context__ entry + - sandbox_name: unique module name + - globals: sandbox globals dict where code is compiled and executed + - import_report: best-effort report for context imports + """ + # Allow passing a SummonerClient directly + if isinstance(entry, SummonerClient): + entry = {"client": entry} + + # Allow passing a dna_list directly + if isinstance(entry, list): + entry = {"dna_list": entry} + + if not isinstance(entry, dict): + raise TypeError(f"Entry #{idx} must be dict | SummonerClient | dna_list, got {type(entry).__name__}") + + if "client" in entry: + client = entry["client"] + if not isinstance(client, SummonerClient): + raise TypeError(f"Entry #{idx} 'client' must be SummonerClient, got {type(client).__name__}") + + # var_name controls which global name will be rebound to 
the merged client. + var_name = entry.get("var_name") + if var_name is None: + var_name = self._infer_client_var_name(client) or "agent" + if not isinstance(var_name, str): + raise TypeError(f"Entry #{idx} 'var_name' must be str, got {type(var_name).__name__}") + + return { + "kind": "client", + "client": client, + "var_name": var_name, + } + + # DNA sources + dna_list = None + if "dna_list" in entry: + dna_list = entry["dna_list"] + elif "dna_path" in entry: + p = Path(entry["dna_path"]) + dna_list = json.loads(p.read_text(encoding="utf-8")) + else: + raise KeyError(f"Entry #{idx} must contain 'client' or 'dna_list' or 'dna_path'") + + if not isinstance(dna_list, list): + raise TypeError(f"Entry #{idx} DNA must be a list, got {type(dna_list).__name__}") + + # Optional context header is stored in DNA as the first entry. + ctx = None + entries = dna_list + if entries and isinstance(entries[0], dict) and entries[0].get("type") == "__context__": + ctx = entries[0] + entries = entries[1:] + + # Determine var_name binding: + # - explicit var_name in entry wins + # - else use ctx["var_name"] + # - else default to "agent" + var_name = entry.get("var_name") + if var_name is None: + ctx_var = ctx.get("var_name") if isinstance(ctx, dict) else None + var_name = ctx_var if isinstance(ctx_var, str) and ctx_var else "agent" + if not isinstance(var_name, str): + raise TypeError(f"Entry #{idx} 'var_name' must be str, got {type(var_name).__name__}") + + # Each DNA source gets its own sandbox module (isolated globals dict). + sandbox_module_name = f"summoner_merge_{uuid.uuid4().hex}" + sandbox_module = types.ModuleType(sandbox_module_name) + sys.modules[sandbox_module_name] = sandbox_module + g = sandbox_module.__dict__ + + # Bind the client name used inside handler source to the merged client. + # This makes `await agent.travel_to(...)` act on the composite client. + g[var_name] = self + + # Apply context (imports/globals/recipes) into that sandbox. 
+ report = self._apply_context(ctx, g, label=f"source#{idx}") + self._import_reports.append(report) + + return { + "kind": "dna", + "dna_entries": entries, + "context": ctx, + "var_name": var_name, + "sandbox_name": sandbox_module_name, + "globals": g, + "import_report": report, + } + + def _infer_client_var_name(self, client: SummonerClient) -> Optional[str]: + """ + Infer the module-global variable name used by handlers to refer to `client`. + + This is used for imported-client sources so that we can rebind that name + (for example "agent") to the merged client in handler globals. + + Returns + ------- + Optional[str] + The inferred binding name, or None if not found. + """ + # Look for a module-global name whose value is exactly `client` + # TODO: can use _view_candidates instead of building a list + candidates = [] + #pylint:disable=protected-access + if client._upload_states is not None: + candidates.append(client._upload_states) + if client._download_states is not None: + candidates.append(client._download_states) + for d in client._dna_receivers: + candidates.append(d.get("fn")) + for d in client._dna_senders: + candidates.append(d.get("fn")) + for d in client._dna_hooks: + candidates.append(d.get("fn")) + + for fn in candidates: + if fn is None: + continue + g = getattr(fn, "__globals__", None) + if not isinstance(g, dict): + continue + for k, v in g.items(): + if v is client and isinstance(k, str): + return k + return None + + def _shutdown_imported_clients(self) -> None: + """ + Best-effort cleanup for imported template clients. + + Why this exists + -------------- + Many agent scripts create a SummonerClient at import-time. That client creates + an event loop and schedules registration tasks. 
+ + When agent scripts are imported only as templates for merging, those template + clients should not be left alive, otherwise you often see: + - "coroutine was never awaited" + - "Task was destroyed but it is pending" + + Cleanup approach + ---------------- + For each imported client: + 1) cancel pending registration tasks + 2) if its loop is not running, drive the loop to await cancellations + 3) close the loop + 4) clear the template's registration list + """ + for src in self.sources: + if not just_client_source(src): + continue + + client: SummonerClient = src["client"] + var_name: str = src["var_name"] + + #pylint:disable=protected-access + tasks = list(client._registration_tasks or []) + loop = client.loop + + # Nothing to do. + if not tasks and loop.is_closed(): + continue + + # 1) cancel tasks + for t in tasks: + with suppress(Exception): + t.cancel() + + # 2) drain tasks on that loop so they are actually awaited + if loop is not None and not loop.is_closed(): + if loop.is_running(): + # Can't safely run_until_complete; also shouldn't close a running loop. + self.logger.warning( + f"[{var_name}] Imported client loop is running; " + f"cannot drain/close registration tasks cleanly." + ) + else: + old_loop = None + try: + # Set context so asyncio.gather/futures bind to the right loop. + with suppress(Exception):# pylint:disable=broad-exception-caught + old_loop = asyncio.get_event_loop_policy().get_event_loop() + asyncio.set_event_loop(loop) + + # Await cancellation. This is what prevents warnings. + loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) + except Exception as e:# pylint:disable=broad-exception-caught + self.logger.warning(f"[{var_name}] Error draining registration tasks: {e}") + finally: + # Restore previous loop context (or clear it). 
+ with suppress(Exception): + asyncio.set_event_loop(old_loop) + + # 3) close loop after drain + try: + loop.close() + except Exception as e:# pylint:disable=broad-exception-caught + self.logger.warning(f"[{var_name}] Error closing event loop: {e}") + + # 4) clear list + with suppress(Exception): + #pylint:disable=protected-access + client._registration_tasks.clear() + + # ---------------------------- + # Context application (DNA) + # ---------------------------- + + # pylint:disable=too-many-branches + def _apply_context(self, ctx: Optional[dict[Any,Any]], g: dict[Any,Any], *, label: str) -> StructuredReport: + """ + Apply a DNA context entry (imports, globals, recipes) into a sandbox globals dict. + + This is best-effort: + - import failures are recorded and logged + - recipe failures are logged but do not abort merge + + Parameters + ---------- + ctx: + The optional "__context__" DNA header entry. + g: + The sandbox globals dict (module.__dict__) where code will be executed. + label: + Used for logging and for the returned report. + + Returns + ------- + StructuredReport + Usable as a dict[str, Any] with the keys: label, succeeded, failed, skipped. + It is used exactly like a plain dict, but makes the expected structure explicit. + + Security note + ------------- + This executes code from ctx (imports and recipes). Use only with trusted DNA.
+ """ + report : StructuredReport = {"label": label, "succeeded": [], "failed": [], "skipped": []} + + if not isinstance(ctx, dict): + return report + + # imports (executed inside sandbox namespace) + for line in (ctx.get("imports") or []): + if not isinstance(line, str) or not line.strip(): + continue + + if not self._allow_context_imports: + report["skipped"].append(line) + continue + + try: + # pylint:disable=exec-used + exec(line, g) + report["succeeded"].append(line) + if self._verbose_context_imports: + self.logger.info(f"[merge ctx:{label}] import ok: {line}") + except Exception as e:# pylint:disable=broad-exception-caught + report["failed"].append((line, f"{type(e).__name__}: {e}")) + self.logger.warning(f"[merge ctx:{label}] import failed: {line!r} ({type(e).__name__}: {e})") + + # plain globals (already JSON-friendly) + globs = ctx.get("globals") or {} + if isinstance(globs, dict): + for k, v in globs.items(): + if isinstance(k, str): + g.setdefault(k, v) + + # recipes (evaluated inside the sandbox namespace) + recipes = ctx.get("recipes") or {} + if isinstance(recipes, dict): + for k, expr in recipes.items(): + if not (isinstance(k, str) and isinstance(expr, str)): + continue + try: + # pylint:disable=eval-used + g.setdefault(k, eval(expr, g, {})) + except Exception as e:# pylint:disable=broad-exception-caught + self.logger.warning(f"[merge ctx:{label}] recipe failed {k}={expr!r} ({type(e).__name__}: {e})") + + return report + + # ---------------------------- + # Utility: source patch for getsource capture + # ---------------------------- + + def _apply_with_source_patch(self, decorator, fn, source: str): + """ + Apply a SummonerClient decorator while preserving DNA source text. + + Why patch inspect.getsource + --------------------------- + The base decorators record handler source using inspect.getsource(fn). + When functions are constructed from DNA via exec(), inspect.getsource(fn) + will typically fail because there is no real source file. 
+ + This helper temporarily patches inspect.getsource so the decorator stores + the original DNA text. This patch is process-global, so it should be used + only in controlled, single-threaded contexts (typical CLI usage). + """ + orig = inspect.getsource + inspect.getsource = lambda o: source + try: + decorator(fn) + finally: + inspect.getsource = orig + + # ---------------------------- + # Imported-client handler cloning + # ---------------------------- + + def _clone_handler(self, fn: types.FunctionType, original_name: str) -> types.FunctionType: + """ + Clone a function object for imported-client sources while preserving module globals. + + Behavior + -------- + - Mutates fn.__globals__ in-place: + - rebind original_name (for example "agent") to the merged client + - inject rebind_globals into the same globals dict + - Creates a new function object using the original code object and closure. + + This preserves the module-backed environment of imported clients. The cost is that + imported handler globals are modified, which is intentional for merge semantics. + """ + g = fn.__globals__ + + # rebind the client variable name (agent/client/etc) + try: + g[original_name] = self + except Exception as e:# pylint:disable=broad-exception-caught + self.logger.warning(f"Could not bind '{original_name}' to merged client: {e}") + + # rebind any shared globals (viz, Trigger, etc.) 
+ for k, v in self._rebind_globals.items(): + try: + g[k] = v + except Exception as e:# pylint:disable=broad-exception-caught + self.logger.warning(f"Could not bind global '{k}' in '{fn.__name__}': {e}") + + new_fn = types.FunctionType( + fn.__code__, + g, + name=fn.__name__, + argdefs=fn.__defaults__, + closure=fn.__closure__, + ) + new_fn.__annotations__ = getattr(fn, "__annotations__", {}) + new_fn.__doc__ = getattr(fn, "__doc__", None) + + # if your dna() uses a __dna_source__ fallback, keep it + if hasattr(fn, "__dna_source__"): + new_fn.__dna_source__ = fn.__dna_source__ # pyright: ignore[reportFunctionMemberAccess] + + return new_fn + + + # ---------------------------- + # DNA compilation (per-source sandbox) + # ---------------------------- + + def _make_from_source(self, entry: dict[str, Any], g: dict, sandbox_name: str) -> types.FunctionType: + """ + Build a function object from a DNA entry by executing its function body in a sandbox. + + Key rule: strip decorators + -------------------------- + DNA sources typically include decorator lines (for example "@agent.receive(...)"). + We skip all decorators and only exec the "def ..." block. Otherwise, compiling the + function would also register it immediately, which would duplicate handlers. + + Globals + ------- + - rebind_globals is injected into g before exec so required runtime symbols + (for example Trigger, viz, shared objects) are visible to the compiled function. + - var_name is already bound to the merged client in g during source normalization. + """ + fn_name = entry["fn_name"] + raw = entry["source"] + lines = raw.splitlines() + + # --------------------------------------------------------------------- + # 1) Find the def line for this function, skipping decorators. 
+ # --------------------------------------------------------------------- + def_idx = None + for idx, line in enumerate(lines): + pat = rf"\s*(async\s+)?def\s+{re.escape(fn_name)}\b" + if re.match(pat, line): + def_idx = idx + break + if def_idx is None: + raise RuntimeError(f"Could not find def for '{fn_name}'") + + func_body = "\n".join(lines[def_idx:]) + + # --------------------------------------------------------------------- + # 2) Ensure rebinding happens in the same globals dict used by exec(). + # --------------------------------------------------------------------- + rebind = self._rebind_globals + if isinstance(rebind, dict) and rebind: + g.update(rebind) + + # --------------------------------------------------------------------- + # 3) Execute and retrieve the resulting function object. + # --------------------------------------------------------------------- + if "__builtins__" not in g: + g["__builtins__"] = __builtins__ + + # pylint:disable=exec-used + exec(compile(func_body, filename=f"<{sandbox_name}>", mode="exec"), g) + + fn = g.get(fn_name) + if not isinstance(fn, types.FunctionType): + raise RuntimeError(f"Failed to construct function '{fn_name}'") + + return fn + + + # ---------------------------- + # Public replay API + # ---------------------------- + + def initiate_upload_states(self): + """Replay @upload_states from every source onto the merged client.""" + for src in self.sources: + if src["kind"] == "client": + client: SummonerClient = src["client"] + var_name: str = src["var_name"] + #pylint:disable=protected-access + fn = client._upload_states + if fn is None: + continue + fn_clone = self._clone_handler(fn, var_name) + try: + self.upload_states()(fn_clone) + except Exception as e:# pylint:disable=broad-exception-caught + self.logger.warning(f"[{var_name}] Failed to replay upload_states '{fn.__name__}': {e}") + + else: + g = src["globals"] + sandbox = src["sandbox_name"] + for entry in src["dna_entries"]: + if entry.get("type") != 
"upload_states": + continue + fn = self._make_from_source(entry, g, sandbox) + dec = self.upload_states() + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_download_states(self): + """Replay @download_states from every source onto the merged client.""" + for src in self.sources: + if src["kind"] == "client": + client: SummonerClient = src["client"] + var_name: str = src["var_name"] + #pylint:disable=protected-access + fn = client._download_states + if fn is None: + continue + fn_clone = self._clone_handler(fn, var_name) + try: + self.download_states()(fn_clone) + except Exception as e:# pylint:disable=broad-exception-caught + self.logger.warning(f"[{var_name}] Failed to replay download_states '{fn.__name__}': {e}") + + else: + g = src["globals"] + sandbox = src["sandbox_name"] + for entry in src["dna_entries"]: + if entry.get("type") != "download_states": + continue + fn = self._make_from_source(entry, g, sandbox) + dec = self.download_states() + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_hooks(self): + """Replay @hook(Direction, priority=...) 
from every source onto the merged client.""" + for src in self.sources: + if src["kind"] == "client": + client: SummonerClient = src["client"] + var_name: str = src["var_name"] + #pylint:disable=protected-access + for dna in client._dna_hooks: + fn_clone = self._clone_handler(dna["fn"], var_name) + try: + self.hook(dna["direction"], priority=dna["priority"])(fn_clone) + except Exception as e:# pylint:disable=broad-exception-caught + self.logger.warning(f"[{var_name}] Failed to replay hook '{dna['fn'].__name__}': {e}") + + else: + g = src["globals"] + sandbox = src["sandbox_name"] + for entry in src["dna_entries"]: + if entry.get("type") != "hook": + continue + fn = self._make_from_source(entry, g, sandbox) + direction = Direction[entry["direction"]] if isinstance(entry.get("direction"), str) else entry["direction"] + dec = self.hook(direction, priority=tuple(entry.get("priority", ()))) + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_receivers(self): + """Replay @receive(route, priority=...) 
from every source onto the merged client.""" + for src in self.sources: + if src["kind"] == "client": + client: SummonerClient = src["client"] + var_name: str = src["var_name"] + #pylint:disable=protected-access + for dna in client._dna_receivers: + fn_clone = self._clone_handler(dna["fn"], var_name) + try: + self.receive(dna["route"], priority=dna["priority"])(fn_clone) + except Exception as e:# pylint:disable=broad-exception-caught + self.logger.warning( + f"[{var_name}] Failed to replay receiver '{dna['fn'].__name__}' on route '{dna['route']}': {e}" + ) + + else: + g = src["globals"] + sandbox = src["sandbox_name"] + for entry in src["dna_entries"]: + if entry.get("type") != "receive": + continue + fn = self._make_from_source(entry, g, sandbox) + dec = self.receive(entry["route"], priority=tuple(entry.get("priority", ()))) + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_senders(self): + """ + Replay @send(route, multi, on_triggers, on_actions) from every source onto the merged client. + + Imported-client sources: + - carry actual trigger/action objects in _dna_senders. + + DNA sources: + - store trigger/action names as strings. + - triggers are resolved using TriggerCls: + - prefer Trigger class provided by sandbox context ("Trigger" in sandbox globals) + - else fall back to load_triggers() + - actions are resolved from Action by name via _resolve_action. 
+ """ + for src in self.sources: + if src["kind"] == "client": + client: SummonerClient = src["client"] + var_name: str = src["var_name"] + #pylint:disable=protected-access + for dna in client._dna_senders: + fn_clone = self._clone_handler(dna["fn"], var_name) + try: + self.send( + dna["route"], + multi=dna["multi"], + on_triggers=dna["on_triggers"], + on_actions=dna["on_actions"], + )(fn_clone) + except Exception as e:# pylint:disable=broad-exception-caught + self.logger.warning( + f"[{var_name}] Failed to replay sender '{dna['fn'].__name__}' on route '{dna['route']}': {e}" + ) + + else: + g = src["globals"] + sandbox = src["sandbox_name"] + + # Triggers: prefer a Trigger class provided by sandbox context; otherwise load defaults. + TriggerCls = g.get("Trigger") + if TriggerCls is None: + TriggerCls = load_triggers() + + for entry in src["dna_entries"]: + if entry.get("type") != "send": + continue + fn = self._make_from_source(entry, g, sandbox) + on_triggers = {_resolve_trigger(TriggerCls, t) for t in entry.get("on_triggers", [])} or None + on_actions = {_resolve_action(Action, a) for a in entry.get("on_actions", [])} or None + dec = self.send( + entry["route"], + multi=entry.get("multi", False), + on_triggers=on_triggers, + on_actions=on_actions, + ) + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_all(self): + """ + Replay all supported handler types in a standard order. + + This should be called before run(). 
The order matches SummonerClient.dna(): + 1) upload_states + 2) download_states + 3) hooks + 4) receivers + 5) senders + """ + self.initiate_upload_states() + self.initiate_download_states() + self.initiate_hooks() + self.initiate_receivers() + self.initiate_senders() diff --git a/summoner/client/merger.py b/summoner/client/merger.py index 3a3bad4..0c6afca 100644 --- a/summoner/client/merger.py +++ b/summoner/client/merger.py @@ -1,1234 +1,7 @@ """ -merger.py - -This module provides two related utilities built on top of SummonerClient: - -1) ClientMerger - Build a single composite SummonerClient by replaying handlers from multiple sources. - - A "source" can be: - - an imported SummonerClient instance (live Python object), or - - a DNA list (already loaded JSON list[dict]), or - - a DNA JSON file path. - - Imported-client sources: - - handlers keep their original module globals (module-backed execution), - - the original client binding (for example the name "agent") is rebound to the merged client, - - optional rebind_globals are injected into handler globals. - - DNA sources: - - handlers are reconstructed by compiling their recorded source text into an isolated - sandbox module (one sandbox per DNA source), - - the sandbox binds var_name (for example "agent") to the merged client instance, - so handler code that references `agent` executes against the composite client, - - optional context (imports, globals, recipes) is applied into the sandbox. - - Usage pattern: - - instantiate ClientMerger(...) - - configure flow / styles as usual on the merged client if desired - - call agent.initiate_all() to replay handlers onto the merged client - - call agent.run(...) - -2) ClientTranslation - Reconstruct a fresh SummonerClient from a DNA list. - - Translation compiles handler functions from their recorded source into a fresh sandbox module, - binds var_name (for example "agent") to the translated client, then registers the handlers - using the normal decorators. 
- -Security and trust model ------------------------- -Both classes execute code from DNA via exec() and eval(): - -- context imports (ctx["imports"]) -- recipes (ctx["recipes"]) -- handler bodies (entry["source"]) - -This is intended for trusted DNA (typically produced by your own agents). -Do not run untrusted DNA. +merger.py used to provide both ClientMerger and ClientTranslation. +Make sure both stay importable from here so that any existing `from summoner.client.merger import ...` still works. """ - -from importlib import import_module -from typing import Optional, Any -from contextlib import suppress -from pathlib import Path -import inspect -import asyncio -import types -import re -import json -import uuid - -import os, sys -target_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) -if target_path not in sys.path: - sys.path.insert(0, target_path) - -from summoner.client.client import SummonerClient -from summoner.protocol.triggers import Action, load_triggers -from summoner.protocol.process import Direction - - -def _resolve_trigger(TriggerCls, name: str): - """ - Resolve a trigger name into a trigger instance from TriggerCls. - - DNA stores triggers as strings. This helper supports the two common access patterns: - - Enum-style indexing: TriggerCls["ok"] - - Attribute access: TriggerCls.ok - - Parameters - ---------- - TriggerCls: - Trigger class or enum-like object returned by flow.triggers() or load_triggers(). - name: - Trigger name as stored in DNA. - - Returns - ------- - Any - The resolved trigger value. - - Raises - ------ - KeyError - If the trigger cannot be resolved. - """ - # Enum-style: TriggerCls["ok"] - try: - return TriggerCls[name] - except Exception: - pass - # Attribute-style: TriggerCls.ok - try: - return getattr(TriggerCls, name) - except Exception: - pass - raise KeyError(f"Unknown trigger '{name}' for {TriggerCls}") - - -def _resolve_action(ActionCls, name: str): - """ - Resolve an action name into the corresponding Action entry.
- - DNA stores actions as strings. Depending on how a sender was serialized, the name can be: - - the enum attribute name ("MOVE") - - a mixed-case name ("Move") - - the underlying class name (Move.__name__ == "Move") - - Parameters - ---------- - ActionCls: - The Action container used by the protocol layer (typically summoner.protocol.triggers.Action). - name: - Action name as stored in DNA. - - Returns - ------- - Any - The resolved Action entry. - - Raises - ------ - KeyError - If the action cannot be resolved. - """ - # 1) Try direct attribute match: "MOVE" - if hasattr(ActionCls, name): - return getattr(ActionCls, name) - - # 2) Try uppercased: "Move" -> "MOVE" - up = name.upper() - if hasattr(ActionCls, up): - return getattr(ActionCls, up) - - # 3) Try matching the underlying class/function name: - # Action.MOVE == Move, where Move.__name__ == "Move" - for v in ActionCls.__dict__.values(): - if getattr(v, "__name__", None) == name: - return v - - raise KeyError(f"Unknown action '{name}' for {ActionCls}") - - -class ClientMerger(SummonerClient): - """ - Merge multiple sources into one client. - - Each input source can be: - - an imported SummonerClient instance (module-backed execution), - - a DNA list (list[dict]), - - a DNA JSON file path. - - Imported-client sources - ----------------------- - - Handlers keep their original module globals. - - The original client binding (var_name such as "agent") is rebound to the merged client. - - Optional rebind_globals are injected into handler globals. - - Note: rebinding mutates handler globals. This is intentional. - - DNA sources - ----------- - - Each DNA source gets its own sandbox module (isolated globals dict). - - var_name is bound to the merged client in that sandbox, so handler code referencing - `agent` executes against the composite client. - - Optional context imports/globals/recipes are executed in the sandbox. 
- - Execution model - --------------- - ClientMerger does not automatically register handlers during __init__. - You must call initiate_all() (or initiate_* individually) before run(). - - Safety - ------ - This class executes trusted code from DNA via exec()/eval(). Do not run untrusted DNA. - """ - - def __init__( - self, - named_clients: list[Any], # backward compatible: list[dict] or list[SummonerClient] or list[dna_list] - name: Optional[str] = None, - rebind_globals: Optional[dict[str, Any]] = None, - allow_context_imports: bool = True, - verbose_context_imports: bool = False, - close_subclients: bool = True, - ): - super().__init__(name=name) - - # Globals injected into: - # - sandbox globals dicts (DNA sources), and - # - imported handler globals (imported-client sources). - # This is how "missing" symbols like Trigger or shared objects are supplied. - self._rebind_globals = dict(rebind_globals or {}) - - # Context controls: - # - allow_context_imports: execute import lines found in DNA context - # - verbose_context_imports: log successes as well as failures - self._allow_context_imports = allow_context_imports - self._verbose_context_imports = verbose_context_imports - - # If True, imported template clients are cleaned up after extraction: - # cancel/drain registration tasks and close their event loops when possible. - # This reduces warnings when importing agent scripts as templates. - self._close_subclients = close_subclients - - # Normalized sources used later by initiate_* replay methods. 
- self.sources: list[dict[str, Any]] = [] - self._import_reports: list[dict[str, Any]] = [] - - for idx, entry in enumerate(named_clients): - src = self._normalize_source(entry, idx) - self.sources.append(src) - - if self._close_subclients: - self._shutdown_imported_clients() - - # ---------------------------- - # Source normalization - # ---------------------------- - - def _normalize_source(self, entry: Any, idx: int) -> dict[str, Any]: - """ - Normalize a user-provided source specification into a canonical dict. - - Accepted inputs - --------------- - - SummonerClient instance - - DNA list (list[dict]) - - dict containing one of: {"client"}, {"dna_list"}, {"dna_path"} - - Normalized output - ----------------- - Returns a dict with at least: - - kind: "client" or "dna" - - var_name: global name used by handler sources to refer to the client ("agent" by default) - - For kind="client": - - client: the imported SummonerClient instance - - For kind="dna": - - dna_entries: handler entries (context removed if present) - - context: optional __context__ entry - - sandbox_name: unique module name - - globals: sandbox globals dict where code is compiled and executed - - import_report: best-effort report for context imports - """ - # Allow passing a SummonerClient directly - if isinstance(entry, SummonerClient): - entry = {"client": entry} - - # Allow passing a dna_list directly - if isinstance(entry, list): - entry = {"dna_list": entry} - - if not isinstance(entry, dict): - raise TypeError(f"Entry #{idx} must be dict | SummonerClient | dna_list, got {type(entry).__name__}") - - if "client" in entry: - client = entry["client"] - if not isinstance(client, SummonerClient): - raise TypeError(f"Entry #{idx} 'client' must be SummonerClient, got {type(client).__name__}") - - # var_name controls which global name will be rebound to the merged client. 
- var_name = entry.get("var_name") - if var_name is None: - var_name = self._infer_client_var_name(client) or "agent" - if not isinstance(var_name, str): - raise TypeError(f"Entry #{idx} 'var_name' must be str, got {type(var_name).__name__}") - - return { - "kind": "client", - "client": client, - "var_name": var_name, - } - - # DNA sources - dna_list = None - if "dna_list" in entry: - dna_list = entry["dna_list"] - elif "dna_path" in entry: - p = Path(entry["dna_path"]) - dna_list = json.loads(p.read_text(encoding="utf-8")) - else: - raise KeyError(f"Entry #{idx} must contain 'client' or 'dna_list' or 'dna_path'") - - if not isinstance(dna_list, list): - raise TypeError(f"Entry #{idx} DNA must be a list, got {type(dna_list).__name__}") - - # Optional context header is stored in DNA as the first entry. - ctx = None - entries = dna_list - if entries and isinstance(entries[0], dict) and entries[0].get("type") == "__context__": - ctx = entries[0] - entries = entries[1:] - - # Determine var_name binding: - # - explicit var_name in entry wins - # - else use ctx["var_name"] - # - else default to "agent" - var_name = entry.get("var_name") - if var_name is None: - ctx_var = ctx.get("var_name") if isinstance(ctx, dict) else None - var_name = ctx_var if isinstance(ctx_var, str) and ctx_var else "agent" - if not isinstance(var_name, str): - raise TypeError(f"Entry #{idx} 'var_name' must be str, got {type(var_name).__name__}") - - # Each DNA source gets its own sandbox module (isolated globals dict). - sandbox_module_name = f"summoner_merge_{uuid.uuid4().hex}" - sandbox_module = types.ModuleType(sandbox_module_name) - sys.modules[sandbox_module_name] = sandbox_module - g = sandbox_module.__dict__ - - # Bind the client name used inside handler source to the merged client. - # This makes `await agent.travel_to(...)` act on the composite client. - g[var_name] = self - - # Apply context (imports/globals/recipes) into that sandbox. 
- report = self._apply_context(ctx, g, label=f"source#{idx}") - self._import_reports.append(report) - - return { - "kind": "dna", - "dna_entries": entries, - "context": ctx, - "var_name": var_name, - "sandbox_name": sandbox_module_name, - "globals": g, - "import_report": report, - } - - def _infer_client_var_name(self, client: SummonerClient) -> Optional[str]: - """ - Infer the module-global variable name used by handlers to refer to `client`. - - This is used for imported-client sources so that we can rebind that name - (for example "agent") to the merged client in handler globals. - - Returns - ------- - Optional[str] - The inferred binding name, or None if not found. - """ - # Look for a module-global name whose value is exactly `client` - candidates = [] - if client._upload_states is not None: - candidates.append(client._upload_states) - if client._download_states is not None: - candidates.append(client._download_states) - for d in client._dna_receivers: - candidates.append(d.get("fn")) - for d in client._dna_senders: - candidates.append(d.get("fn")) - for d in client._dna_hooks: - candidates.append(d.get("fn")) - - for fn in candidates: - if fn is None: - continue - g = getattr(fn, "__globals__", None) - if not isinstance(g, dict): - continue - for k, v in g.items(): - if v is client and isinstance(k, str): - return k - return None - - def _shutdown_imported_clients(self) -> None: - """ - Best-effort cleanup for imported template clients. - - Why this exists - -------------- - Many agent scripts create a SummonerClient at import-time. That client creates - an event loop and schedules registration tasks. 
- - When agent scripts are imported only as templates for merging, those template - clients should not be left alive, otherwise you often see: - - "coroutine was never awaited" - - "Task was destroyed but it is pending" - - Cleanup approach - ---------------- - For each imported client: - 1) cancel pending registration tasks - 2) if its loop is not running, drive the loop to await cancellations - 3) close the loop - 4) clear the template's registration list - """ - for src in self.sources: - if src.get("kind") != "client": - continue - - client: SummonerClient = src["client"] - var_name: str = src["var_name"] - - tasks = list(client._registration_tasks or []) - loop = client.loop - - # Nothing to do. - if not tasks and loop.is_closed(): - continue - - # 1) cancel tasks - for t in tasks: - with suppress(Exception): - t.cancel() - - # 2) drain tasks on that loop so they are actually awaited - if loop is not None and not loop.is_closed(): - if loop.is_running(): - # Can't safely run_until_complete; also shouldn't close a running loop. - self.logger.warning( - f"[{var_name}] Imported client loop is running; " - f"cannot drain/close registration tasks cleanly." - ) - else: - old_loop = None - try: - # Set context so asyncio.gather/futures bind to the right loop. - with suppress(Exception): - old_loop = asyncio.get_event_loop_policy().get_event_loop() - asyncio.set_event_loop(loop) - - # Await cancellation. This is what prevents warnings. - loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) - except Exception as e: - self.logger.warning(f"[{var_name}] Error draining registration tasks: {e}") - finally: - # Restore previous loop context (or clear it). 
- with suppress(Exception): - asyncio.set_event_loop(old_loop) - - # 3) close loop after drain - try: - loop.close() - except Exception as e: - self.logger.warning(f"[{var_name}] Error closing event loop: {e}") - - # 4) clear list - with suppress(Exception): - client._registration_tasks.clear() - - # ---------------------------- - # Context application (DNA) - # ---------------------------- - - def _apply_context(self, ctx: Optional[dict], g: dict, *, label: str) -> dict[str, Any]: - """ - Apply a DNA context entry (imports, globals, recipes) into a sandbox globals dict. - - This is best-effort: - - import failures are recorded and logged - - recipe failures are logged but do not abort merge - - Parameters - ---------- - ctx: - The optional "__context__" DNA header entry. - g: - The sandbox globals dict (module.__dict__) where code will be executed. - label: - Used for logging and for the returned report. - - Returns - ------- - dict[str, Any] - A structured report with keys: label, succeeded, failed, skipped. - - Security note - ------------- - This executes code from ctx (imports and recipes). Use only with trusted DNA. 
- """ - report = {"label": label, "succeeded": [], "failed": [], "skipped": []} - - if not isinstance(ctx, dict): - return report - - # imports (executed inside sandbox namespace) - for line in (ctx.get("imports") or []): - if not isinstance(line, str) or not line.strip(): - continue - - if not self._allow_context_imports: - report["skipped"].append(line) - continue - - try: - exec(line, g) - report["succeeded"].append(line) - if self._verbose_context_imports: - self.logger.info(f"[merge ctx:{label}] import ok: {line}") - except Exception as e: - report["failed"].append((line, f"{type(e).__name__}: {e}")) - self.logger.warning(f"[merge ctx:{label}] import failed: {line!r} ({type(e).__name__}: {e})") - - # plain globals (already JSON-friendly) - globs = ctx.get("globals") or {} - if isinstance(globs, dict): - for k, v in globs.items(): - if isinstance(k, str): - g.setdefault(k, v) - - # recipes (evaluated inside the sandbox namespace) - recipes = ctx.get("recipes") or {} - if isinstance(recipes, dict): - for k, expr in recipes.items(): - if not (isinstance(k, str) and isinstance(expr, str)): - continue - try: - g.setdefault(k, eval(expr, g, {})) - except Exception as e: - self.logger.warning(f"[merge ctx:{label}] recipe failed {k}={expr!r} ({type(e).__name__}: {e})") - - return report - - # ---------------------------- - # Utility: source patch for getsource capture - # ---------------------------- - - def _apply_with_source_patch(self, decorator, fn, source: str): - """ - Apply a SummonerClient decorator while preserving DNA source text. - - Why patch inspect.getsource - --------------------------- - The base decorators record handler source using inspect.getsource(fn). - When functions are constructed from DNA via exec(), inspect.getsource(fn) - will typically fail because there is no real source file. - - This helper temporarily patches inspect.getsource so the decorator stores - the original DNA text. 
This patch is process-global, so it should be used - only in controlled, single-threaded contexts (typical CLI usage). - """ - orig = inspect.getsource - inspect.getsource = lambda o: source - try: - decorator(fn) - finally: - inspect.getsource = orig - - # ---------------------------- - # Imported-client handler cloning - # ---------------------------- - - def _clone_handler(self, fn: types.FunctionType, original_name: str) -> types.FunctionType: - """ - Clone a function object for imported-client sources while preserving module globals. - - Behavior - -------- - - Mutates fn.__globals__ in-place: - - rebind original_name (for example "agent") to the merged client - - inject rebind_globals into the same globals dict - - Creates a new function object using the original code object and closure. - - This preserves the module-backed environment of imported clients. The cost is that - imported handler globals are modified, which is intentional for merge semantics. - """ - g = fn.__globals__ - - # rebind the client variable name (agent/client/etc) - try: - g[original_name] = self - except Exception as e: - self.logger.warning(f"Could not bind '{original_name}' to merged client: {e}") - - # rebind any shared globals (viz, Trigger, etc.) 
- for k, v in self._rebind_globals.items(): - try: - g[k] = v - except Exception as e: - self.logger.warning(f"Could not bind global '{k}' in '{fn.__name__}': {e}") - - new_fn = types.FunctionType( - fn.__code__, - g, - name=fn.__name__, - argdefs=fn.__defaults__, - closure=fn.__closure__, - ) - new_fn.__annotations__ = getattr(fn, "__annotations__", {}) - new_fn.__doc__ = getattr(fn, "__doc__", None) - - # if your dna() uses a __dna_source__ fallback, keep it - if hasattr(fn, "__dna_source__"): - new_fn.__dna_source__ = fn.__dna_source__ - - return new_fn - - - # ---------------------------- - # DNA compilation (per-source sandbox) - # ---------------------------- - - def _make_from_source(self, entry: dict[str, Any], g: dict, sandbox_name: str) -> types.FunctionType: - """ - Build a function object from a DNA entry by executing its function body in a sandbox. - - Key rule: strip decorators - -------------------------- - DNA sources typically include decorator lines (for example "@agent.receive(...)"). - We skip all decorators and only exec the "def ..." block. Otherwise, compiling the - function would also register it immediately, which would duplicate handlers. - - Globals - ------- - - rebind_globals is injected into g before exec so required runtime symbols - (for example Trigger, viz, shared objects) are visible to the compiled function. - - var_name is already bound to the merged client in g during source normalization. - """ - fn_name = entry["fn_name"] - raw = entry["source"] - lines = raw.splitlines() - - # --------------------------------------------------------------------- - # 1) Find the def line for this function, skipping decorators. 
- # --------------------------------------------------------------------- - def_idx = None - for idx, line in enumerate(lines): - pat = rf"\s*(async\s+)?def\s+{re.escape(fn_name)}\b" - if re.match(pat, line): - def_idx = idx - break - if def_idx is None: - raise RuntimeError(f"Could not find def for '{fn_name}'") - - func_body = "\n".join(lines[def_idx:]) - - # --------------------------------------------------------------------- - # 2) Ensure rebinding happens in the same globals dict used by exec(). - # --------------------------------------------------------------------- - rebind = self._rebind_globals - if isinstance(rebind, dict) and rebind: - g.update(rebind) - - # --------------------------------------------------------------------- - # 3) Execute and retrieve the resulting function object. - # --------------------------------------------------------------------- - if "__builtins__" not in g: - g["__builtins__"] = __builtins__ - - exec(compile(func_body, filename=f"<{sandbox_name}>", mode="exec"), g) - - fn = g.get(fn_name) - if not isinstance(fn, types.FunctionType): - raise RuntimeError(f"Failed to construct function '{fn_name}'") - - return fn - - - # ---------------------------- - # Public replay API - # ---------------------------- - - def initiate_upload_states(self): - """Replay @upload_states from every source onto the merged client.""" - for src in self.sources: - if src["kind"] == "client": - client: SummonerClient = src["client"] - var_name: str = src["var_name"] - fn = client._upload_states - if fn is None: - continue - fn_clone = self._clone_handler(fn, var_name) - try: - self.upload_states()(fn_clone) - except Exception as e: - self.logger.warning(f"[{var_name}] Failed to replay upload_states '{fn.__name__}': {e}") - - else: - g = src["globals"] - sandbox = src["sandbox_name"] - for entry in src["dna_entries"]: - if entry.get("type") != "upload_states": - continue - fn = self._make_from_source(entry, g, sandbox) - dec = self.upload_states() - 
self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_download_states(self): - """Replay @download_states from every source onto the merged client.""" - for src in self.sources: - if src["kind"] == "client": - client: SummonerClient = src["client"] - var_name: str = src["var_name"] - fn = client._download_states - if fn is None: - continue - fn_clone = self._clone_handler(fn, var_name) - try: - self.download_states()(fn_clone) - except Exception as e: - self.logger.warning(f"[{var_name}] Failed to replay download_states '{fn.__name__}': {e}") - - else: - g = src["globals"] - sandbox = src["sandbox_name"] - for entry in src["dna_entries"]: - if entry.get("type") != "download_states": - continue - fn = self._make_from_source(entry, g, sandbox) - dec = self.download_states() - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_hooks(self): - """Replay @hook(Direction, priority=...) from every source onto the merged client.""" - for src in self.sources: - if src["kind"] == "client": - client: SummonerClient = src["client"] - var_name: str = src["var_name"] - for dna in client._dna_hooks: - fn_clone = self._clone_handler(dna["fn"], var_name) - try: - self.hook(dna["direction"], priority=dna["priority"])(fn_clone) - except Exception as e: - self.logger.warning(f"[{var_name}] Failed to replay hook '{dna['fn'].__name__}': {e}") - - else: - g = src["globals"] - sandbox = src["sandbox_name"] - for entry in src["dna_entries"]: - if entry.get("type") != "hook": - continue - fn = self._make_from_source(entry, g, sandbox) - direction = Direction[entry["direction"]] if isinstance(entry.get("direction"), str) else entry["direction"] - dec = self.hook(direction, priority=tuple(entry.get("priority", ()))) - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_receivers(self): - """Replay @receive(route, priority=...) 
from every source onto the merged client.""" - for src in self.sources: - if src["kind"] == "client": - client: SummonerClient = src["client"] - var_name: str = src["var_name"] - for dna in client._dna_receivers: - fn_clone = self._clone_handler(dna["fn"], var_name) - try: - self.receive(dna["route"], priority=dna["priority"])(fn_clone) - except Exception as e: - self.logger.warning( - f"[{var_name}] Failed to replay receiver '{dna['fn'].__name__}' on route '{dna['route']}': {e}" - ) - - else: - g = src["globals"] - sandbox = src["sandbox_name"] - for entry in src["dna_entries"]: - if entry.get("type") != "receive": - continue - fn = self._make_from_source(entry, g, sandbox) - dec = self.receive(entry["route"], priority=tuple(entry.get("priority", ()))) - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_senders(self): - """ - Replay @send(route, multi, on_triggers, on_actions) from every source onto the merged client. - - Imported-client sources: - - carry actual trigger/action objects in _dna_senders. - - DNA sources: - - store trigger/action names as strings. - - triggers are resolved using TriggerCls: - - prefer Trigger class provided by sandbox context ("Trigger" in sandbox globals) - - else fall back to load_triggers() - - actions are resolved from Action by name via _resolve_action. 
- """ - for src in self.sources: - if src["kind"] == "client": - client: SummonerClient = src["client"] - var_name: str = src["var_name"] - for dna in client._dna_senders: - fn_clone = self._clone_handler(dna["fn"], var_name) - try: - self.send( - dna["route"], - multi=dna["multi"], - on_triggers=dna["on_triggers"], - on_actions=dna["on_actions"], - )(fn_clone) - except Exception as e: - self.logger.warning( - f"[{var_name}] Failed to replay sender '{dna['fn'].__name__}' on route '{dna['route']}': {e}" - ) - - else: - g = src["globals"] - sandbox = src["sandbox_name"] - - # Triggers: prefer a Trigger class provided by sandbox context; otherwise load defaults. - TriggerCls = g.get("Trigger") - if TriggerCls is None: - TriggerCls = load_triggers() - - for entry in src["dna_entries"]: - if entry.get("type") != "send": - continue - fn = self._make_from_source(entry, g, sandbox) - on_triggers = {_resolve_trigger(TriggerCls, t) for t in entry.get("on_triggers", [])} or None - on_actions = {_resolve_action(Action, a) for a in entry.get("on_actions", [])} or None - dec = self.send( - entry["route"], - multi=entry.get("multi", False), - on_triggers=on_triggers, - on_actions=on_actions, - ) - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_all(self): - """ - Replay all supported handler types in a standard order. - - This should be called before run(). The order matches SummonerClient.dna(): - 1) upload_states - 2) download_states - 3) hooks - 4) receivers - 5) senders - """ - self.initiate_upload_states() - self.initiate_download_states() - self.initiate_hooks() - self.initiate_receivers() - self.initiate_senders() - - -class ClientTranslation(SummonerClient): - """ - Reconstruct a SummonerClient from its DNA list. - - Translation compiles handlers from their recorded source into a fresh sandbox module, - then registers them on this client via normal decorators. 
- - Execution environment - --------------------- - - Translated handlers do not run inside the original agent modules. - - They run inside the translation sandbox, with only explicitly imported or rebound - globals available (for example: Trigger, shared objects, viz, etc.). - - Context behavior - ---------------- - If the DNA begins with a "__context__" entry, translation may: - - exec() ctx["imports"] (if allow_context_imports=True) - - bind ctx["globals"] into sandbox - - eval() ctx["recipes"] into sandbox - - Template-client cleanup - ----------------------- - DNA entries carry a "module" field from the original capture. - Importing those modules can create template SummonerClient instances at import-time. - - This class attempts to find and clean up such template clients so they do not leave - pending registration tasks or open loops behind. - """ - def __init__( - self, - dna_list: list[dict[str, Any]], - name: Optional[str] = None, - var_name: Optional[str] = None, - rebind_globals: Optional[dict[str, Any]] = None, - allow_context_imports: bool = True, - verbose_context_imports: bool = False - ): - super().__init__(name=name) - - if not isinstance(dna_list, list): - raise TypeError("dna_list must be a list of DNA entries") - - self._rebind_globals = dict(rebind_globals or {}) - self._allow_context_imports = allow_context_imports - self._verbose_context_imports = verbose_context_imports - - # Extract optional context entry - ctx = None - if dna_list and isinstance(dna_list[0], dict) and dna_list[0].get("type") == "__context__": - ctx = dna_list[0] - dna_list = dna_list[1:] - - # Decide binding name: - # - explicit var_name wins (admin override) - # - else use context var_name if present - # - else default to "agent" - if var_name is None: - ctx_var = ctx.get("var_name") if isinstance(ctx, dict) else None - var_name = ctx_var if isinstance(ctx_var, str) and ctx_var else "agent" - - self._dna_list = dna_list - if not isinstance(var_name, str): - raise 
TypeError("var_name must be a string") - self._var_name = var_name - self._context = ctx - - # Create a sandbox module for translated code (independent from user imports) - self._sandbox_module_name = f"summoner_translation_{uuid.uuid4().hex}" - self._sandbox_module = types.ModuleType(self._sandbox_module_name) - sys.modules[self._sandbox_module_name] = self._sandbox_module - self._sandbox_globals = self._sandbox_module.__dict__ - - # Bind the client name used in handler source - self._sandbox_globals[self._var_name] = self - - # Apply context if present - self._apply_context() - - # Best-effort cleanup of template clients imported indirectly via DNA modules - self._cleanup_template_clients_from_modules() - - def _apply_context(self): - """ - Apply optional "__context__" entry into the translation sandbox. - - This is best-effort and intended for trusted DNA. - It may execute code via exec()/eval() if imports/recipes are present. - """ - if not isinstance(self._context, dict): - return - - g = self._sandbox_globals - - # Imports - for line in self._context.get("imports", []) or []: - if not isinstance(line, str) or not line.strip(): - continue - if not self._allow_context_imports: - continue - try: - exec(line, g) - if self._verbose_context_imports: - self.logger.info(f"[translation ctx] import ok: {line}") - except Exception as e: - self.logger.warning(f"[translation ctx] import failed: {line!r} ({type(e).__name__}: {e})") - - # Plain globals - globs = self._context.get("globals", {}) or {} - if isinstance(globs, dict): - for k, v in globs.items(): - if isinstance(k, str): - g.setdefault(k, v) - - # Recipes - rec = self._context.get("recipes", {}) or {} - if isinstance(rec, dict): - for k, expr in rec.items(): - if not (isinstance(k, str) and isinstance(expr, str)): - continue - # eval in the sandbox global namespace - try: - g.setdefault(k, eval(expr, g, {})) - except Exception as e: - self.logger.warning(f"Could not eval recipe for {k}: {expr!r} ({e})") - - def 
_cleanup_one_template_client(self, client: SummonerClient, label: str): - """ - Best-effort cleanup for a template client discovered in an imported module. - - If a DNA entry references a module, importing it may construct a SummonerClient - at module import-time. That client is not meant to run in translation mode. - - Cleanup steps: - - cancel pending registration tasks - - if possible, await cancellations by driving the template's loop - - close the loop when it is not running - """ - regs = list(client._registration_tasks or []) - loop = client.loop - - # cancel registrations - for t in regs: - try: - t.cancel() - except Exception: - pass - - # drain cancellations if we can drive that loop - try: - if regs and loop is not None and (not loop.is_running()) and (not loop.is_closed()): - loop.run_until_complete(asyncio.gather(*regs, return_exceptions=True)) - except Exception: - # best-effort only - pass - - # clear list - try: - client._registration_tasks.clear() - except Exception: - pass - - # close the loop (template clients are not meant to run) - try: - if loop is not None and (not loop.is_running()) and (not loop.is_closed()): - loop.close() - except Exception: - pass - - def _cleanup_template_clients_from_modules(self): - """ - For every module referenced in the DNA, attempt to find a template client: - - If the module defines a global variable named self._var_name (for example "agent") - and it points to a SummonerClient instance other than this translated client, - then treat it as a template and clean it up. - - This reduces warnings about pending registration tasks and open event loops. 
- """ - modules = {entry.get("module") for entry in self._dna_list if isinstance(entry, dict)} - modules.discard(None) - - seen_ids: set[int] = set() - - for module_name in modules: - try: - module = sys.modules.get(module_name) or import_module(module_name) - except Exception: - continue - - g = getattr(module, "__dict__", {}) - template = g.get(self._var_name) - - if isinstance(template, SummonerClient) and template is not self: - if id(template) in seen_ids: - continue - seen_ids.add(id(template)) - self._cleanup_one_template_client(template, label=f"{module_name}.{self._var_name}") - - def _make_from_source(self, entry: dict[str, Any]) -> types.FunctionType: - """ - Compile a handler function from its DNA 'source' into the translation sandbox globals. - - Decorator lines are stripped so compilation does not implicitly register handlers. - """ - fn_name = entry["fn_name"] - - module_globals = self._sandbox_globals - module_globals[self._var_name] = self - - # inject runtime globals (Trigger, viz, shared objects, etc.) - if self._rebind_globals: - module_globals.update(self._rebind_globals) - - if "__builtins__" not in module_globals: - module_globals["__builtins__"] = __builtins__ - - # strip off decorator lines so we only exec the `def` block - raw = entry["source"] - lines = raw.splitlines() - for idx, line in enumerate(lines): - pattern = rf"\s*(async\s+)?def\s+{re.escape(fn_name)}\b" - if re.match(pattern, line): - func_body = "\n".join(lines[idx:]) - break - else: - raise RuntimeError(f"Could not find definition for '{fn_name}'") - - exec(compile(func_body, filename=f"<{self._sandbox_module_name}>", mode="exec"), module_globals) - - fn = module_globals.get(fn_name) - if not isinstance(fn, types.FunctionType): - raise RuntimeError(f"Failed to construct function '{fn_name}'") - return fn - - def _apply_with_source_patch(self, decorator, fn, source: str): - """ - Temporarily override inspect.getsource so SummonerClient decorators record DNA text. 
- - This is process-global and is intended for single-threaded translation runs. - """ - orig = inspect.getsource - inspect.getsource = lambda o: source - try: - decorator(fn) - finally: - inspect.getsource = orig - - def initiate_upload_states(self): - """Replay @upload_states from DNA onto this translated client.""" - for entry in self._dna_list: - if entry.get("type") != "upload_states": - continue - fn = self._make_from_source(entry) - dec = self.upload_states() - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_download_states(self): - """Replay @download_states from DNA onto this translated client.""" - for entry in self._dna_list: - if entry.get("type") != "download_states": - continue - fn = self._make_from_source(entry) - dec = self.download_states() - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_hooks(self): - """Replay @hook from DNA onto this translated client.""" - for entry in self._dna_list: - if entry.get("type") != "hook": - continue - fn = self._make_from_source(entry) - dec = self.hook( - Direction[entry["direction"]], - priority=tuple(entry.get("priority", ())) - ) - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_receivers(self): - """Replay @receive from DNA onto this translated client.""" - for entry in self._dna_list: - if entry.get("type") != "receive": - continue - fn = self._make_from_source(entry) - dec = self.receive( - entry["route"], - priority=tuple(entry.get("priority", ())) - ) - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_senders(self): - """ - Replay @send from DNA onto this translated client. - - Triggers and actions are stored by name in DNA, then resolved here: - - Trigger is resolved using a Trigger class found in sandbox globals, else load_triggers() - - Action is resolved from the Action container by name - """ - g = self._sandbox_globals - - # Ensure rebind globals are visible before resolving triggers/actions. 
- if self._rebind_globals: - g.update(self._rebind_globals) - - TriggerCls = g.get("Trigger") - if TriggerCls is None: - TriggerCls = load_triggers() - - for entry in self._dna_list: - if entry.get("type") != "send": - continue - fn = self._make_from_source(entry) - on_triggers = {_resolve_trigger(TriggerCls, t) for t in entry.get("on_triggers", [])} or None - on_actions = {_resolve_action(Action, a) for a in entry.get("on_actions", [])} or None - dec = self.send( - entry["route"], - multi=entry.get("multi", False), - on_triggers=on_triggers, - on_actions=on_actions, - ) - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_all(self): - """ - Replay all handler types from DNA in a standard order. - - This should be called before run(). - """ - self.initiate_upload_states() - self.initiate_download_states() - self.initiate_hooks() - self.initiate_receivers() - self.initiate_senders() - - async def _async_shutdown(self): - """ - Async shutdown path used by SIGINT/SIGTERM and quit(). - - Steps - ----- - 1) cancel everything (base shutdown) - 2) cancel and await pending decorator-registration coroutines - 3) await in-flight handler/worker tasks - 4) stop the loop so run_until_complete can return - """ - # 1) cancel everything - super().shutdown() - - # 2) wait for decorator register() tasks - regs = self._registration_tasks or [] - if regs: - for t in regs: - t.cancel() - await asyncio.gather(*regs, return_exceptions=True) - regs.clear() - - # 3) wait for in-flight handlers and workers - await self._wait_for_tasks_to_finish() - - # 4) stop the loop so run_client's run_until_complete can finish - self.loop.stop() - - def shutdown(self): - """ - Override the base-class SIGINT/SIGTERM handler. - - The base SummonerClient.shutdown() cancels tasks but does not await them. - In translation mode, a cleaner exit is preferred, so we schedule an async - shutdown coroutine instead. 
- """ - self.logger.info("Client is shutting down…") - try: - asyncio.create_task(self._async_shutdown()) - except RuntimeError: - # If the loop isn't running, ignore. - pass - - async def quit(self): - """ - Quit the translated client: - - set _quit so run_client exits - - then run the same cleanup as Ctrl+C - """ - await super().quit() - await self._async_shutdown() - - def run(self, *args, **kwargs): - """ - Wrap run() so that Ctrl+C cancels leftover registration tasks and exits cleanly. - - Note: this wrapper is intentionally conservative and does not change the base - reconnection and session logic. - """ - try: - super().run(*args, **kwargs) - except KeyboardInterrupt: - self.logger.info("KeyboardInterrupt caught-cancelling registration tasks…") - for task in list(self._registration_tasks or []): - task.cancel() - return +#pylint:disable=unused-import +from .just_merger import ClientMerger +from .translation import ClientTranslation diff --git a/summoner/client/translation.py b/summoner/client/translation.py new file mode 100644 index 0000000..49f3534 --- /dev/null +++ b/summoner/client/translation.py @@ -0,0 +1,480 @@ +""" +merger.py + +This module provides two related utilities built on top of SummonerClient: + +1) ClientMerger + Build a single composite SummonerClient by replaying handlers from multiple sources. + + A "source" can be: + - an imported SummonerClient instance (live Python object), or + - a DNA list (already loaded JSON list[dict]), or + - a DNA JSON file path. + + Imported-client sources: + - handlers keep their original module globals (module-backed execution), + - the original client binding (for example the name "agent") is rebound to the merged client, + - optional rebind_globals are injected into handler globals. 
+ + DNA sources: + - handlers are reconstructed by compiling their recorded source text into an isolated + sandbox module (one sandbox per DNA source), + - the sandbox binds var_name (for example "agent") to the merged client instance, + so handler code that references `agent` executes against the composite client, + - optional context (imports, globals, recipes) is applied into the sandbox. + + Usage pattern: + - instantiate ClientMerger(...) + - configure flow / styles as usual on the merged client if desired + - call agent.initiate_all() to replay handlers onto the merged client + - call agent.run(...) + +2) ClientTranslation + Reconstruct a fresh SummonerClient from a DNA list. + + Translation compiles handler functions from their recorded source into a fresh sandbox module, + binds var_name (for example "agent") to the translated client, then registers the handlers + using the normal decorators. + +Security and trust model +------------------------ +Both classes execute code from DNA via exec() and eval(): + +- context imports (ctx["imports"]) +- recipes (ctx["recipes"]) +- handler bodies (entry["source"]) + +This is intended for trusted DNA (typically produced by your own agents). +Do not run untrusted DNA. 
+""" +#pylint:disable=line-too-long, wrong-import-position, duplicate-code +#pylint:disable=invalid-name, logging-fstring-interpolation + +from importlib import import_module +from typing import Optional +from typing import Any +import inspect +import asyncio +import types +import re +import uuid + +import os +import sys + +target_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) +if target_path not in sys.path: + sys.path.insert(0, target_path) + +from summoner.client.just_merger import _resolve_action, _resolve_trigger +from summoner.client.client import SummonerClient +from summoner.protocol.triggers import Action, load_triggers +from summoner.protocol.process import Direction + +# pylint:disable=too-many-instance-attributes +class ClientTranslation(SummonerClient): + """ + Reconstruct a SummonerClient from its DNA list. + + Translation compiles handlers from their recorded source into a fresh sandbox module, + then registers them on this client via normal decorators. + + Execution environment + --------------------- + - Translated handlers do not run inside the original agent modules. + - They run inside the translation sandbox, with only explicitly imported or rebound + globals available (for example: Trigger, shared objects, viz, etc.). + + Context behavior + ---------------- + If the DNA begins with a "__context__" entry, translation may: + - exec() ctx["imports"] (if allow_context_imports=True) + - bind ctx["globals"] into sandbox + - eval() ctx["recipes"] into sandbox + + Template-client cleanup + ----------------------- + DNA entries carry a "module" field from the original capture. + Importing those modules can create template SummonerClient instances at import-time. + + This class attempts to find and clean up such template clients so they do not leave + pending registration tasks or open loops behind. 
+ """ + # pylint:disable=too-many-arguments, too-many-positional-arguments + def __init__( + self, + dna_list: list[dict[str, Any]], + name: Optional[str] = None, + var_name: Optional[str] = None, + rebind_globals: Optional[dict[str, Any]] = None, + allow_context_imports: bool = True, + verbose_context_imports: bool = False + ): + super().__init__(name=name) + + if not isinstance(dna_list, list): + raise TypeError("dna_list must be a list of DNA entries") + + self._rebind_globals = dict(rebind_globals or {}) + self._allow_context_imports = allow_context_imports + self._verbose_context_imports = verbose_context_imports + + # Extract optional context entry + ctx = None + if dna_list and isinstance(dna_list[0], dict) and dna_list[0].get("type") == "__context__": + ctx = dna_list[0] + dna_list = dna_list[1:] + + # Decide binding name: + # - explicit var_name wins (admin override) + # - else use context var_name if present + # - else default to "agent" + if var_name is None: + ctx_var = ctx.get("var_name") if isinstance(ctx, dict) else None + var_name = ctx_var if isinstance(ctx_var, str) and ctx_var else "agent" + + self._dna_list = dna_list + if not isinstance(var_name, str): + raise TypeError("var_name must be a string") + self._var_name = var_name + self._context = ctx + + # Create a sandbox module for translated code (independent from user imports) + self._sandbox_module_name = f"summoner_translation_{uuid.uuid4().hex}" + self._sandbox_module = types.ModuleType(self._sandbox_module_name) + sys.modules[self._sandbox_module_name] = self._sandbox_module + self._sandbox_globals = self._sandbox_module.__dict__ + + # Bind the client name used in handler source + self._sandbox_globals[self._var_name] = self + + # Apply context if present + self._apply_context() + + # Best-effort cleanup of template clients imported indirectly via DNA modules + self._cleanup_template_clients_from_modules() + + #pylint:disable=too-many-branches + def _apply_context(self): + """ + Apply 
optional "__context__" entry into the translation sandbox. + + This is best-effort and intended for trusted DNA. + It may execute code via exec()/eval() if imports/recipes are present. + """ + if not isinstance(self._context, dict): + return + + g = self._sandbox_globals + + # Imports + for line in self._context.get("imports", []) or []: + if not isinstance(line, str) or not line.strip(): + continue + if not self._allow_context_imports: + continue + try: + # pylint:disable=exec-used + exec(line, g) + if self._verbose_context_imports: + self.logger.info(f"[translation ctx] import ok: {line}") + except Exception as e:# pylint:disable=broad-exception-caught + self.logger.warning(f"[translation ctx] import failed: {line!r} ({type(e).__name__}: {e})") + + # Plain globals + globs = self._context.get("globals", {}) or {} + if isinstance(globs, dict): + for k, v in globs.items(): + if isinstance(k, str): + g.setdefault(k, v) + + # Recipes + rec = self._context.get("recipes", {}) or {} + if isinstance(rec, dict): + for k, expr in rec.items(): + if not (isinstance(k, str) and isinstance(expr, str)): + continue + # eval in the sandbox global namespace + # pylint:disable=eval-used + try: + g.setdefault(k, eval(expr, g, {})) + except Exception as e:# pylint:disable=broad-exception-caught + self.logger.warning(f"Could not eval recipe for {k}: {expr!r} ({e})") + + #pylint:disable=unused-argument + def _cleanup_one_template_client(self, client: SummonerClient, label: str): + """ + Best-effort cleanup for a template client discovered in an imported module. + + If a DNA entry references a module, importing it may construct a SummonerClient + at module import-time. That client is not meant to run in translation mode. 
+ + Cleanup steps: + - cancel pending registration tasks + - if possible, await cancellations by driving the template's loop + - close the loop when it is not running + """ + # pylint:disable=protected-access + regs = list(client._registration_tasks or []) + loop = client.loop + + # cancel registrations + for t in regs: + try: + t.cancel() + except Exception:# pylint:disable=broad-exception-caught + pass + + # drain cancellations if we can drive that loop + try: + if regs and loop is not None and (not loop.is_running()) and (not loop.is_closed()): + loop.run_until_complete(asyncio.gather(*regs, return_exceptions=True)) + except Exception:# pylint:disable=broad-exception-caught + # best-effort only + pass + + # clear list + try: + # pylint:disable=protected-access + client._registration_tasks.clear() + except Exception:# pylint:disable=broad-exception-caught + pass + + # close the loop (template clients are not meant to run) + try: + if loop is not None and (not loop.is_running()) and (not loop.is_closed()): + loop.close() + except Exception:# pylint:disable=broad-exception-caught + pass + + def _cleanup_template_clients_from_modules(self): + """ + For every module referenced in the DNA, attempt to find a template client: + + If the module defines a global variable named self._var_name (for example "agent") + and it points to a SummonerClient instance other than this translated client, + then treat it as a template and clean it up. + + This reduces warnings about pending registration tasks and open event loops. 
+ """ + modules = {entry.get("module") for entry in self._dna_list if isinstance(entry, dict)} + modules.discard(None) + + seen_ids: set[int] = set() + + for module_name in modules: + try: + module = sys.modules.get(module_name) or import_module(module_name) # pyright: ignore[reportArgumentType] + except Exception:# pylint:disable=broad-exception-caught + continue + + g = getattr(module, "__dict__", {}) + template = g.get(self._var_name) + + if isinstance(template, SummonerClient) and template is not self: + if id(template) in seen_ids: + continue + seen_ids.add(id(template)) + self._cleanup_one_template_client(template, label=f"{module_name}.{self._var_name}") + + def _make_from_source(self, entry: dict[str, Any]) -> types.FunctionType: + """ + Compile a handler function from its DNA 'source' into the translation sandbox globals. + + Decorator lines are stripped so compilation does not implicitly register handlers. + """ + fn_name = entry["fn_name"] + + module_globals = self._sandbox_globals + module_globals[self._var_name] = self + + # inject runtime globals (Trigger, viz, shared objects, etc.) 
+ if self._rebind_globals: + module_globals.update(self._rebind_globals) + + if "__builtins__" not in module_globals: + module_globals["__builtins__"] = __builtins__ + + # strip off decorator lines so we only exec the `def` block + raw = entry["source"] + lines = raw.splitlines() + for idx, line in enumerate(lines): + pattern = rf"\s*(async\s+)?def\s+{re.escape(fn_name)}\b" + if re.match(pattern, line): + func_body = "\n".join(lines[idx:]) + break + else: + raise RuntimeError(f"Could not find definition for '{fn_name}'") + + #pylint:disable=exec-used + exec(compile(func_body, filename=f"<{self._sandbox_module_name}>", mode="exec"), module_globals) + + fn = module_globals.get(fn_name) + if not isinstance(fn, types.FunctionType): + raise RuntimeError(f"Failed to construct function '{fn_name}'") + return fn + + def _apply_with_source_patch(self, decorator, fn, source: str): + """ + Temporarily override inspect.getsource so SummonerClient decorators record DNA text. + + This is process-global and is intended for single-threaded translation runs. 
+ """ + orig = inspect.getsource + inspect.getsource = lambda o: source + try: + decorator(fn) + finally: + inspect.getsource = orig + + def initiate_upload_states(self): + """Replay @upload_states from DNA onto this translated client.""" + for entry in self._dna_list: + if entry.get("type") != "upload_states": + continue + fn = self._make_from_source(entry) + dec = self.upload_states() + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_download_states(self): + """Replay @download_states from DNA onto this translated client.""" + for entry in self._dna_list: + if entry.get("type") != "download_states": + continue + fn = self._make_from_source(entry) + dec = self.download_states() + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_hooks(self): + """Replay @hook from DNA onto this translated client.""" + for entry in self._dna_list: + if entry.get("type") != "hook": + continue + fn = self._make_from_source(entry) + dec = self.hook( + Direction[entry["direction"]], + priority=tuple(entry.get("priority", ())) + ) + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_receivers(self): + """Replay @receive from DNA onto this translated client.""" + for entry in self._dna_list: + if entry.get("type") != "receive": + continue + fn = self._make_from_source(entry) + dec = self.receive( + entry["route"], + priority=tuple(entry.get("priority", ())) + ) + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_senders(self): + """ + Replay @send from DNA onto this translated client. + + Triggers and actions are stored by name in DNA, then resolved here: + - Trigger is resolved using a Trigger class found in sandbox globals, else load_triggers() + - Action is resolved from the Action container by name + """ + g = self._sandbox_globals + + # Ensure rebind globals are visible before resolving triggers/actions. 
+ if self._rebind_globals: + g.update(self._rebind_globals) + + TriggerCls = g.get("Trigger") + if TriggerCls is None: + TriggerCls = load_triggers() + + for entry in self._dna_list: + if entry.get("type") != "send": + continue + fn = self._make_from_source(entry) + on_triggers = {_resolve_trigger(TriggerCls, t) for t in entry.get("on_triggers", [])} or None + on_actions = {_resolve_action(Action, a) for a in entry.get("on_actions", [])} or None + dec = self.send( + entry["route"], + multi=entry.get("multi", False), + on_triggers=on_triggers, + on_actions=on_actions, + ) + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_all(self): + """ + Replay all handler types from DNA in a standard order. + + This should be called before run(). + """ + self.initiate_upload_states() + self.initiate_download_states() + self.initiate_hooks() + self.initiate_receivers() + self.initiate_senders() + + async def _async_shutdown(self): + """ + Async shutdown path used by SIGINT/SIGTERM and quit(). + + Steps + ----- + 1) cancel everything (base shutdown) + 2) cancel and await pending decorator-registration coroutines + 3) await in-flight handler/worker tasks + 4) stop the loop so run_until_complete can return + """ + # 1) cancel everything + super().shutdown() + + # 2) wait for decorator register() tasks + regs = self._registration_tasks or [] + if regs: + for t in regs: + t.cancel() + await asyncio.gather(*regs, return_exceptions=True) + regs.clear() + + # 3) wait for in-flight handlers and workers + await self._wait_for_tasks_to_finish() + + # 4) stop the loop so run_client's run_until_complete can finish + self.loop.stop() + + def shutdown(self): + """ + Override the base-class SIGINT/SIGTERM handler. + + The base SummonerClient.shutdown() cancels tasks but does not await them. + In translation mode, a cleaner exit is preferred, so we schedule an async + shutdown coroutine instead. 
+ """ + self.logger.info("Client is shutting down…") + try: + asyncio.create_task(self._async_shutdown()) + except RuntimeError: + # If the loop isn't running, ignore. + pass + + async def quit(self): + """ + Quit the translated client: + - set _quit so run_client exits + - then run the same cleanup as Ctrl+C + """ + await super().quit() + await self._async_shutdown() + + def run(self, *args, **kwargs): + """ + Wrap run() so that Ctrl+C cancels leftover registration tasks and exits cleanly. + + Note: this wrapper is intentionally conservative and does not change the base + reconnection and session logic. + """ + try: + super().run(*args, **kwargs) + except KeyboardInterrupt: + self.logger.info("KeyboardInterrupt caught-cancelling registration tasks…") + for task in list(self._registration_tasks or []): + task.cancel() diff --git a/summoner/logger.py b/summoner/logger.py index 6cf594c..8da29c6 100644 --- a/summoner/logger.py +++ b/summoner/logger.py @@ -1,3 +1,8 @@ +""" +Handle logging specifics +to this context +like details of how things are consistently formatted +""" import sys import os import json @@ -5,14 +10,19 @@ import datetime import re -from typing import Optional, Any +from typing import Dict, Optional +from typing import Any from logging.handlers import RotatingFileHandler -from logging import Logger # This makes Logger importable from logger.py + +# This makes Logger importable from logger.py +#pylint:disable=unused-import +from logging import Logger # Log formatting style LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" -LOG_FORMAT_CONSOLE = "\033[92m%(asctime)s\033[0m - \033[94m%(name)s\033[0m - %(levelname)s - %(message)s" +LOG_FORMAT_CONSOLE = \ + "\033[92m%(asctime)s\033[0m - \033[94m%(name)s\033[0m - %(levelname)s - %(message)s" class SafeStreamHandler(logging.StreamHandler): """ @@ -131,7 +141,7 @@ def __init__(self, fmt: str, datefmt: Optional[str], log_keys: Optional[list[str self.log_keys = log_keys def format(self, 
record: logging.LogRecord) -> str: - base = { + base : Dict[str, str | Dict[Any,Any]] = { "timestamp": self.formatTime(record, self._raw_datefmt), "name": record.name, "level": record.levelname, @@ -190,7 +200,8 @@ def configure_logger(logger: logging.Logger, logger_cfg: dict[str, Any]) -> None date_format = logger_cfg.get("date_format") console_handler = SafeStreamHandler(sys.stdout) - console_handler.setFormatter(TextFormatter(console_log_format, date_format, logger_cfg.get("log_keys"))) + console_handler.setFormatter( + TextFormatter(console_log_format, date_format, logger_cfg.get("log_keys"))) logger.addHandler(console_handler) # file @@ -208,9 +219,13 @@ def configure_logger(logger: logging.Logger, logger_cfg: dict[str, Any]) -> None log_format = logger_cfg.get("log_format", LOG_FORMAT) date_format = logger_cfg.get("date_format") - file_handler = RotatingFileHandler(path, maxBytes=max_file_size, backupCount=backup_count) + file_handler = RotatingFileHandler(path, + maxBytes=max_file_size, + backupCount=backup_count) if logger_cfg.get("enable_json_log", False): - file_handler.setFormatter(JsonFormatter(log_format, date_format, logger_cfg.get("log_keys"))) + file_handler.setFormatter( + JsonFormatter(log_format, date_format, logger_cfg.get("log_keys"))) else: - file_handler.setFormatter(TextFormatter(log_format, date_format, logger_cfg.get("log_keys"))) + file_handler.setFormatter( + TextFormatter(log_format, date_format, logger_cfg.get("log_keys"))) logger.addHandler(file_handler) diff --git a/summoner/protocol/__init__.py b/summoner/protocol/__init__.py index edb0303..4189d96 100644 --- a/summoner/protocol/__init__.py +++ b/summoner/protocol/__init__.py @@ -1,18 +1,21 @@ +""" +TODO: summary of triggers, process and flow +""" from .triggers import ( - Signal, - Event, + Signal, + Event, Action, Move, Stay, Test, ) from .process import ( - StateTape, - ParsedRoute, - Node, - Sender, - Receiver, + StateTape, + ParsedRoute, + Node, + Sender, + Receiver, Direction, 
) from .flow import Flow -# from .validation import hook_priority_order, _check_param_and_return \ No newline at end of file +# from .validation import hook_priority_order, _check_param_and_return diff --git a/summoner/protocol/_deprecation.py b/summoner/protocol/_deprecation.py index 0ee7e66..c56f128 100644 --- a/summoner/protocol/_deprecation.py +++ b/summoner/protocol/_deprecation.py @@ -1,6 +1,9 @@ +""" +Warnings about using deprecated methods +""" try: - from warnings import deprecated # Python 3.13+ + from warnings import deprecated # pyright: ignore[reportAttributeAccessIssue] # Python 3.13+ except ImportError: from typing_extensions import deprecated # Python <= 3.12 -__all__ = ["deprecated"] \ No newline at end of file +__all__ = ["deprecated"] diff --git a/summoner/protocol/flow.py b/summoner/protocol/flow.py index 59e789a..30fe5e9 100644 --- a/summoner/protocol/flow.py +++ b/summoner/protocol/flow.py @@ -1,11 +1,17 @@ +""" +Handles many regexes for different ArrowStyle +""" from __future__ import annotations import re from collections.abc import Callable -from typing import Optional, Any +from typing import Iterable, Optional +from typing import Any +import warnings from .triggers import load_triggers from .process import Node, ArrowStyle, ParsedRoute -from ._deprecation import deprecated -import warnings +from ._deprecation import deprecated # pyright: ignore[reportAttributeAccessIssue] + +# pylint:disable=line-too-long # variable names or commands used in flow transitions _TOKEN_RE = re.compile(r""" @@ -34,7 +40,7 @@ def get_token_list(input_string: str, separator: str) -> list[str]: parenthesis_depth: int = 0 for character in input_string: - + if character == "(": parenthesis_depth += 1 elif character == ")": @@ -56,20 +62,34 @@ def get_token_list(input_string: str, separator: str) -> list[str]: TriggerType = type class Flow: - + """ + Handles many regexes for different ArrowStyle + """ + def __init__(self, triggers_file: Optional[str] = None) -> 
None: + """ + Handles many regexes for different ArrowStyle + """ self.triggers_file = triggers_file self.in_use: bool = False self.arrows: set[ArrowStyle] = set() - + self._regex_ready: bool = False self._regex_patterns: list[tuple[re.Pattern[str], ArrowStyle, Unpack]] = [] - + def activate(self) -> Flow: + """ + In the client there is logic + that depends on how this is toggled. + """ self.in_use = True return self def deactivate(self) -> Flow: + """ + In the client there is logic + that depends on how this is toggled. + """ self.in_use = False return self @@ -80,6 +100,12 @@ def add_arrow_style( separator: str, tip: str ) -> None: + """ + Add an arrow, that is another possibility + for parsing routes. + This means the regex patterns will need + to be recompiled as the set of ArrowStyle has changed. + """ style = ArrowStyle( stem=stem, brackets=brackets, @@ -91,14 +117,15 @@ def add_arrow_style( self._regex_patterns.clear() def triggers(self, json_dict: Optional[dict[str, Any]] = None) -> TriggerType: - if json_dict is None: + """ + Load and build the TRIG class from the nested_dict or TRIGGERS file + """ + if json_dict is None: if self.triggers_file is None: # use the file TRIGGERS return load_triggers() - else: - return load_triggers(triggers_file=self.triggers_file) - else: - return load_triggers(json_dict=json_dict) + return load_triggers(triggers_file=self.triggers_file) + return load_triggers(json_dict=json_dict) def _build_labeled_complete( self, @@ -107,6 +134,15 @@ def _build_labeled_complete( right_bracket: str, tip: str ) -> re.Pattern[str]: + """ + Build the regex to match + a complete labelled arrow. 
+ For example, + As --[Cs]--> Bs + - base says the shaft of the arrow is -- + - left_bracket,right_bracket say the labels are surrounded by [] + - tip says the tip of the arrow is > + """ left = rf"{base}{left_bracket}" right = rf"{right_bracket}{base}{tip}" regex = rf""" @@ -118,6 +154,14 @@ def _build_labeled_complete( return re.compile(regex, re.VERBOSE) def _build_unlabeled_complete(self, base: str, tip: str) -> re.Pattern[str]: + """ + Build the regex to match + a complete unlabelled arrow. + For example, + As --> Bs + - base says the shaft of the arrow is -- + - tip says the tip of the arrow is > + """ arrow = rf"{base}{tip}" regex = rf""" ^\s* @@ -125,7 +169,7 @@ def _build_unlabeled_complete(self, base: str, tip: str) -> re.Pattern[str]: (?P.+?)\s*$ """ return re.compile(regex, re.VERBOSE) - + def _build_labeled_dangling_right( self, base: str, @@ -133,6 +177,15 @@ def _build_labeled_dangling_right( right_bracket: str, tip: str ) -> re.Pattern[str]: + """ + Build the regex to match + a dangling on the right labelled arrow. + For example, + As --[Cs]--> + - base says the shaft of the arrow is -- + - left_bracket,right_bracket say the labels are surrounded by [] + - tip says the tip of the arrow is > + """ left = rf"{base}{left_bracket}" right = rf"{right_bracket}{base}{tip}" regex = rf""" @@ -143,13 +196,21 @@ def _build_labeled_dangling_right( return re.compile(regex, re.VERBOSE) def _build_unlabeled_dangling_right(self, base: str, tip: str) -> re.Pattern[str]: + """ + Build the regex to match + a dangling on the right unlabelled arrow. 
+ For example, + As --> + - base says the shaft of the arrow is -- + - tip says the tip of the arrow is > + """ arrow = rf"{base}{tip}" regex = rf""" ^\s* (?P.+?)\s*{arrow}\s*$ """ return re.compile(regex, re.VERBOSE) - + def _build_labeled_dangling_left( self, base: str, @@ -157,6 +218,15 @@ def _build_labeled_dangling_left( right_bracket: str, tip: str ) -> re.Pattern[str]: + """ + Build the regex to match + a dangling on the left labelled arrow. + For example, + --[Cs]--> Bs + - base says the shaft of the arrow is -- + - left_bracket,right_bracket say the labels are surrounded by [] + - tip says the tip of the arrow is > + """ left = rf"{base}{left_bracket}" right = rf"{right_bracket}{base}{tip}" regex = rf""" @@ -168,6 +238,14 @@ def _build_labeled_dangling_left( return re.compile(regex, re.VERBOSE) def _build_unlabeled_dangling_left(self, base: str, tip: str) -> re.Pattern[str]: + """ + Build the regex to match + a dangling on the left unlabelled arrow. + For example, + --> Bs + - base says the shaft of the arrow is -- + - tip says the tip of the arrow is > + """ arrow = rf"{base}{tip}" regex = rf""" ^\s* @@ -177,6 +255,10 @@ def _build_unlabeled_dangling_left(self, base: str, tip: str) -> re.Pattern[str] return re.compile(regex, re.VERBOSE) def _unpack_labeled_complete(self, matched_pattern: re.Match[str]) -> tuple[str, str, str]: + """ + If the arrow was complete and labelled, the produced regex will have 3 groups + for source, label and target. A match with this will give strings for each of these. + """ return ( matched_pattern.group("source"), matched_pattern.group("label"), @@ -184,21 +266,48 @@ def _unpack_labeled_complete(self, matched_pattern: re.Match[str]) -> tuple[str, ) def _unpack_unlabeled_complete(self, matched_pattern: re.Match[str]) -> tuple[str, str, str]: + """ + If the arrow was complete and unlabelled, the produced regex will have 2 groups + for source and target. 
A match with this will give strings for each of these and the label is empty. + """ return matched_pattern.group("source"), "", matched_pattern.group("target") - + def _unpack_labeled_dangling_right(self, matched_pattern: re.Match[str]) -> tuple[str, str, str]: + """ + If the arrow was dangling on the right and labelled, the produced regex will have 2 groups + for source and label. A match with this will give strings for each of these and the target is empty. + """ return matched_pattern.group("source"), matched_pattern.group("label"), "" - + def _unpack_unlabeled_dangling_right(self, matched_pattern: re.Match[str]) -> tuple[str, str, str]: + """ + If the arrow was dangling on the right and unlabelled, the produced regex will have 1 group + for source. A match with this will give strings for each of these and the label and target are empty. + """ return matched_pattern.group("source"), "", "" - + def _unpack_labeled_dangling_left(self, matched_pattern: re.Match[str]) -> tuple[str, str, str]: + """ + If the arrow was dangling on the left and labelled, the produced regex will have 2 groups + for label and target. A match with this will give strings for each of these and the source is empty. + """ return "", matched_pattern.group("label"), matched_pattern.group("target") def _unpack_unlabeled_dangling_left(self, matched_pattern: re.Match[str]) -> tuple[str, str, str]: + """ + If the arrow was dangling on the left and unlabelled, the produced regex will have 1 group + for target. A match with this will give strings for each of these and the source and label are empty. + """ return "", "", matched_pattern.group("target") def _prepare_regex(self) -> None: + """ + If the regex for all the ArrowStyle(s) have not been compiled, compile them all. + This takes care of all variants of + - labelled, unlabelled + - dangling left, dangling right, complete + on each of the styles. 
+ """ if self._regex_ready: return @@ -230,11 +339,21 @@ def _prepare_regex(self) -> None: self._regex_ready = True def _validate_tokens(self, tokens: list[str], text: str) -> None: + """ + Each source, label and target must conform to the naming style + similar to a Python identifier but also allowing /all /oneof /not + which are also valid as sources and targets. + """ for token in tokens: if not _TOKEN_RE.match(token): raise ValueError(f"Invalid token {token!r} in route {text!r}") def _parse_standalone(self, text: str) -> ParsedRoute: + """ + Similar to parsing the dangling right unlabelled case + which is like `As ->` + but this is just the As part without the -> + """ # split on commas, strip whitespace, drop empties source_list = get_token_list(text, separator=',') # validate each token @@ -259,6 +378,9 @@ def compile_arrow_patterns(self) -> None: "When using SummonerClient, patterns are compiled automatically during registration." ) def ready(self) -> None: + """ + Same as compile_arrow_patterns but the deprecated version. + """ warnings.warn( "Flow.ready() is deprecated; use Flow.compile_arrow_patterns() instead. " "In SummonerClient, you generally don't need to call this.", @@ -269,10 +391,22 @@ def ready(self) -> None: self._prepare_regex() def parse_route(self, route: str) -> ParsedRoute: + """ + route is something like As --[Bs]-> Cs + or As --> etc among all the possibilities of + having labels or not, being dangling or not, + having an arrow or not. + Give the ParsedRoute which is storing the + sources, labels and targets with empty tuples for any + of them that were not present. 
+ If present, the As, Bs, and Cs are split into tuples according + to the separator of the relevant ArrowStyle in order + to handle multiple sources, labels and/or targets + """ route = route.strip() if not self._regex_ready: self._prepare_regex() - + for pattern, style, unpack in self._regex_patterns: matched_pattern = pattern.match(route) if not matched_pattern: @@ -290,8 +424,11 @@ def parse_route(self, route: str) -> ParsedRoute: target=tuple(Node(tok) for tok in target_list), style=style, ) - + return self._parse_standalone(route) - def parse_routes(self, routes: list[str]) -> list[ParsedRoute]: + def parse_routes(self, routes: Iterable[str]) -> list[ParsedRoute]: + """ + Do parse_route for each of the routes + """ return [self.parse_route(route=route) for route in routes] diff --git a/summoner/protocol/payload.py b/summoner/protocol/payload.py index 09b7b0c..f831634 100644 --- a/summoner/protocol/payload.py +++ b/summoner/protocol/payload.py @@ -1,15 +1,22 @@ +""" +TODO: doc payload +""" +#pylint:disable=line-too-long import json from json import JSONDecodeError -from typing import Any, Tuple, Dict, List, Union, TypedDict +from typing import Tuple, Dict, List, Union, TypedDict +from typing import Any from summoner.utils import ( - fully_recover_json, + fully_recover_json, remove_last_newline, ensure_trailing_newline, ) -# Current envelope version -WRAPPER_VERSION = "0.0.1" +from summoner._version import __version__ as core_version + +# Default envelope version +DEFAULT_VERSION = "0.0.1" # Registries for versioned parsers and casters envelope_parsers: Dict[str, Any] = {} @@ -29,9 +36,6 @@ def register_envelope_version( envelope_parsers[version] = parser envelope_casters[version] = caster - -from typing import Any, Tuple, Dict, List - STR_TYPE = "str" BOOL_TYPE = "bool" INT_TYPE = "int" @@ -39,7 +43,8 @@ def register_envelope_version( NULL_TYPE = "null" -def parse_v0_0_1(obj: Any) -> Tuple[Any, Any]: +#pylint:disable=too-many-return-statements +def 
parse_v0_0_1(obj: Any) -> Tuple[Any,Any]: """ Walk `obj` and build a parallel `type_tree`. Every leaf in `obj` becomes a simple type name; every list/dict is recursed. @@ -65,13 +70,13 @@ def parse_v0_0_1(obj: Any) -> Tuple[Any, Any]: # Lists → walk each element if isinstance(obj, list): - payloads: List[Any] = [] - types: List[Any] = [] + payloads_list: List[Any] = [] + types_list: List[Any] = [] for v in obj: p, t = parse_v0_0_1(v) - payloads.append(p) - types.append(t) - return payloads, types + payloads_list.append(p) + types_list.append(t) + return payloads_list, types_list # Dicts → walk each value if isinstance(obj, dict): @@ -88,23 +93,11 @@ def parse_v0_0_1(obj: Any) -> Tuple[Any, Any]: s = str(obj) return s, STR_TYPE - -def cast_v0_0_1(val: Any, expected: Any) -> Any: +def _cast_primitive(val: Any, expected: Any) -> str | bool | int | float | None | Any: """ Coerce `val` according to `expected`, but never fail on unknown types. Instead, pass the value through unchanged if we don't understand. - - Supported expected descriptors: - • "string", "boolean", "integer", "number", "null" - • a list of descriptors → cast element-wise up to len(expected) - • a dict of descriptors → cast known keys, keep extras unchanged - • None or unknown → identity (return val) """ - # 1) Identity for missing/unknown descriptors - if expected is None: - return val - - # 2) Primitives if expected == STR_TYPE: return str(val) if expected == BOOL_TYPE: @@ -123,6 +116,27 @@ def cast_v0_0_1(val: Any, expected: Any) -> Any: if expected == NULL_TYPE: # always map to Python None return None + # Not a primitive descriptor → caller bug: cast_v0_0_1 only dispatches the five known primitive types here + raise ValueError(f"Unreachable code in _cast_primitive: expected={expected!r}") + +def cast_v0_0_1(val: Any, expected: Any) -> Any: + """ + Coerce `val` according to `expected`, but never fail on unknown types. + Instead, pass the value through unchanged if we don't understand. 
+ + Supported expected descriptors: + • "string", "boolean", "integer", "number", "null" + • a list of descriptors → cast element-wise up to len(expected) + • a dict of descriptors → cast known keys, keep extras unchanged + • None or unknown → identity (return val) + """ + # 1) Identity for missing/unknown descriptors + if expected is None: + return val + + # 2) Primitives + if expected in {STR_TYPE, BOOL_TYPE, INT_TYPE, NUMB_TYPE, NULL_TYPE}: + return _cast_primitive(val, expected) # 3) Lists: cast up to the expected length, keep extras untouched if isinstance(expected, list) and isinstance(val, list): @@ -156,12 +170,12 @@ def cast_v0_0_1(val: Any, expected: Any) -> Any: register_envelope_version("1.0.0", parse_v0_0_1, cast_v0_0_1) register_envelope_version("1.0.1", parse_v0_0_1, cast_v0_0_1) register_envelope_version("1.1.0", parse_v0_0_1, cast_v0_0_1) -register_envelope_version("1.1.1", parse_v0_0_1, cast_v0_0_1) - +# register_envelope_version("1.1.1", parse_v0_0_1, cast_v0_0_1) +register_envelope_version(core_version, parse_v0_0_1, cast_v0_0_1) def wrap_with_types( payload: Any, - version: str = WRAPPER_VERSION + version: str = DEFAULT_VERSION ) -> str: """ Wrap `payload` in a self-describing JSON envelope with type metadata. @@ -173,7 +187,7 @@ def wrap_with_types( Args: payload: Any JSON-serializable object (dict, list, primitive). - version: Version string for the envelope format (defaults to WRAPPER_VERSION). + version: Version string for the envelope format (defaults to DEFAULT_VERSION). Returns: A JSON string representing the typed envelope. @@ -209,7 +223,6 @@ class RelayedMessage(TypedDict): remote_addr: str content: Union[str, dict] - def recover_with_types(text: str) -> RelayedMessage: """ Recover and validate a typed payload from a server message. @@ -227,6 +240,11 @@ def recover_with_types(text: str) -> RelayedMessage: - Invalid JSON or non-JSON warning strings: strip trailing newline and return raw string. 
- Missing outer keys ("remote_addr"/"content"): strip newline and return raw string. - Missing envelope keys inside content: return the parsed `{"remote_addr":…, "content":…}` as-is. + + Strictly the return type annotation is not accurate. + It may end up in the fallback, but that pollutes that possibility + for the caller even when they are known to be passing in the case for text + which does not end up in the fallbacks. Args: text: Raw text received from the server (may include newline). @@ -243,13 +261,13 @@ def recover_with_types(text: str) -> RelayedMessage: obj = fully_recover_json(text) except (ValueError, JSONDecodeError): # If that fails (e.g. pure warning string), strip newline and relay the raw text - return remove_last_newline(text) + return remove_last_newline(text) # pyright: ignore[reportReturnType] # 2) Ensure we have the outer {"remote_addr":…, "content":…} wrapper if not (isinstance(obj, dict) and "remote_addr" in obj and "content" in obj): # Malformed protocol message; upstream code can catch this if needed # raise ValueError("Unsupported message format from server.") - return obj + return obj # pyright: ignore[reportReturnType] addr = obj["remote_addr"] content = obj["content"] @@ -262,11 +280,11 @@ def recover_with_types(text: str) -> RelayedMessage: and "_payload" in content and "_type" in content ): - return obj + return obj # pyright: ignore[reportReturnType] # 4) We have the versioned envelope—now look up the correct caster version = content["_version"] - caster = envelope_casters.get(version, envelope_casters[WRAPPER_VERSION]) + caster = envelope_casters.get(version, envelope_casters[DEFAULT_VERSION]) if caster is None: # Unknown version: hard error so we don't silently mis-interpret data raise ValueError(f"Unsupported wrapper version: {version}") diff --git a/summoner/protocol/process.py b/summoner/protocol/process.py index c432e23..bf66de1 100644 --- a/summoner/protocol/process.py +++ b/summoner/protocol/process.py @@ -1,7 +1,13 @@ +""" 
+TODO: doc process +""" + from __future__ import annotations import re from collections import defaultdict -from typing import Type, Any, Optional, Union, Callable, Awaitable +from typing import Coroutine, Dict, List, Literal, Mapping, Tuple, Type, \ + Optional, TypeGuard, Union, Callable, Awaitable +from typing import Any from enum import Enum, auto from dataclasses import dataclass from .triggers import Signal, Event, Action, extract_signal @@ -18,38 +24,43 @@ # Wildcard sentinel for dispatch _WILDCARD = object() +KindType = Literal["all", "not", "oneof", "plain"] + class Node: + """ + TODO: doc node + """ __slots__ = ('expr', 'kind', 'values') def __init__(self, expr: str) -> None: _expr: str = expr.strip() - self.kind: str - self.values: Optional[tuple[str]] + self.kind: KindType + self.values: Optional[tuple[str,...]] if _ALL_RE.fullmatch(_expr): self.kind = 'all' self.values = None - + elif (found_match := _NOT_RE.fullmatch(_expr)): self.kind = 'not' items = [item.strip() for item in found_match.group(1).split(',') if item.strip()] self.values = tuple(items) - + elif (found_match := _ONEOF_RE.fullmatch(_expr)): self.kind = 'oneof' items = [item.strip() for item in found_match.group(1).split(',') if item.strip()] self.values = tuple(items) - + elif _PLAIN_TOKEN_RE.fullmatch(_expr): self.kind = 'plain' - self.values = (_expr,) - + self.values = (_expr,) + else: raise ValueError(f"Invalid syntax for token: {_expr!r}") def __eq__(self, other: Any) -> bool: return ( - isinstance(other, Node) and + isinstance(other, Node) and self.kind == other.kind and self.values == other.values ) @@ -65,43 +76,57 @@ def __repr__(self) -> str: # f"\033[94mvalues\033[0m=\033[90m{self.values!r}\033[0m)" ) + #pylint:disable=too-many-return-statements def __str__(self) -> str: try: if self.kind == 'all': return '/all' - elif self.kind == 'plain': + if self.kind == 'plain': + if self.values is None: + return "" return self.values[0] - elif self.kind == 'not': + if self.kind == 'not': 
+ if self.values is None: + return "" return f"/not({','.join(self.values)})" - elif self.kind == 'oneof': + if self.kind == 'oneof': + if self.values is None: + return "" return f"/oneof({','.join(self.values)})" - else: - return f"" - except Exception as e: + return f"" + except Exception as e: # pylint:disable=broad-exception-caught return f"" def accepts(self, state: Node) -> bool: - if not isinstance(state, Node): + """ + Handle the logic of whether this Node + accepts state or not. + In the plain case it is matching, + but there is also the logic of oneof, all, not kinds on either + operand. + """ + if not isinstance(state, Node): raise TypeError(f"Argument `state` must be Node; {state} provided") - - table = { + + # pylint:disable=line-too-long + table : Dict[Tuple[KindType | object, KindType | object], Callable[[Node,Node],bool]] = { ('all', 'all'): lambda g, s: True, ('all', _WILDCARD): lambda g, s: True, (_WILDCARD, 'all'): lambda g, s: True, - ('plain', 'plain'): lambda g, s: g.values[0] == s.values[0], - ('plain', 'not'): lambda g, s: g.values[0] not in s.values, - ('plain', 'oneof'): lambda g, s: g.values[0] in s.values, - ('not', 'plain'): lambda g, s: s.values[0] not in g.values, + ('plain', 'plain'): lambda g, s: g.values[0] == s.values[0], # pyright: ignore[reportOptionalSubscript] + ('plain', 'not'): lambda g, s: g.values[0] not in s.values, # pyright: ignore[reportOperatorIssue,reportOptionalSubscript] + ('plain', 'oneof'): lambda g, s: g.values[0] in s.values, # pyright: ignore[reportOptionalSubscript,reportOperatorIssue] + ('not', 'plain'): lambda g, s: s.values[0] not in g.values, # pyright: ignore[reportOptionalSubscript,reportOperatorIssue] ('not', _WILDCARD): lambda g, s: True, - ('oneof', 'plain'): lambda g, s: s.values[0] in g.values, - ('oneof', 'not'): lambda g, s: bool(set(g.values) - set(s.values)), - ('oneof', 'oneof'): lambda g, s: bool(set(g.values) & set(s.values)), + ('oneof', 'plain'): lambda g, s: s.values[0] in g.values, # 
pyright: ignore[reportOptionalSubscript,reportOperatorIssue] + ('oneof', 'not'): lambda g, s: bool(set(g.values) - set(s.values)), # pyright: ignore[reportArgumentType] + ('oneof', 'oneof'): lambda g, s: bool(set(g.values) & set(s.values)), # pyright: ignore[reportArgumentType] } for (gk, sk), fn in table.items(): if (gk == self.kind or gk is _WILDCARD) and (sk == state.kind or sk is _WILDCARD): return fn(self, state) - + raise RuntimeError("Unhandled combination in Node.is_compatible_with") # ======= ARROW STYLE ======= @@ -205,6 +230,7 @@ def _check_regex_safe(self) -> None: try: re.escape(part) except re.error as e: + #pylint:disable=raise-missing-from raise ValueError( f"Part {part!r} invalid for regex: {e}" ) @@ -212,6 +238,17 @@ def _check_regex_safe(self) -> None: # ======= PARSED ROUTE ======= class ParsedRoute: + """ + A parsed route holds + all the sources, labels and targets for an arrow that has been parsed + As --[Bs,Cs]--> /all has + Node for As in source + Node for Bs and Node for Cs in label + Node for /all in target + + This is also for when it is not an arrow as in standalone when there + are only sources + """ __slots__ = ('source', 'label', 'target', 'style', '_string') def __init__( @@ -227,6 +264,7 @@ def __init__( self.style = style if self.is_arrow: + assert self.style is not None, "If an arrow it has a style" src = self.style.separator.join(str(n) for n in self.source) lab = self.style.separator.join(str(n) for n in self.label) tgt = self.style.separator.join(str(n) for n in self.target) @@ -262,18 +300,30 @@ def __str__(self) -> str: @property def has_label(self) -> bool: + """ + There are labels + """ return bool(self.label) @property def is_arrow(self) -> bool: + """ + This is actually an arrow, unlike the standalone only sources + """ return bool(self.target) or self.has_label @property def is_object(self) -> bool: + """ + It is standalone + """ return not self.is_arrow - + @property def is_initial(self) -> bool: + """ + It is dangling 
on the left + """ return self.is_arrow and not self.source def activated_nodes( @@ -287,9 +337,8 @@ def activated_nodes( if isinstance(event, Event) and event is not None and not self.is_arrow: if isinstance(event, Action.TEST): return () - else: - # standalone → only the source nodes - return self.source + # standalone → only the source nodes + return self.source # arrow route → pick based on the Action subtype if isinstance(event, Action.MOVE): @@ -306,79 +355,206 @@ def activated_nodes( @dataclass(frozen=True) class Sender: + """ + TODO: doc sender + """ __slots__ = ('fn', 'multi', 'actions', 'triggers') - fn: Callable[[], Awaitable] + fn: Callable[[], Awaitable[Any]] multi: bool actions: Optional[set[Type]] triggers: Optional[set[Signal]] def responds_to(self, event: Any) -> bool: + """ + TODO: doc sender + """ action_check = True if self.actions is not None: if not any(isinstance(event, action) for action in self.actions): action_check = False - + trigger_check = True if self.triggers is not None: if not any(extract_signal(event) == trig for trig in self.triggers): trigger_check = False - + return action_check and trigger_check @dataclass(frozen=True) class Receiver: + """ + TODO: doc receiver + """ __slots__ = ('fn', 'priority') - fn: Callable[[Union[str, dict]], Awaitable[Optional[Event]]] + fn: Callable[[Union[str, dict]], Coroutine[Any,Any,Optional[Event]]] priority: tuple[int, ...] class Direction(Enum): + """ + Only two directions + """ SEND = auto() RECEIVE = auto() @dataclass(frozen=True) class TapeActivation: + """ + TODO: doc tape activation + """ __slots__ = ('key', 'state', 'route', 'fn') key: Optional[str] state: Optional[Node] route: ParsedRoute - fn: Callable[[Any], Awaitable] + fn: Callable[[Any], Coroutine[Any,Any,Any]] # ======= STATE TAPE ======= +TupleStrNode = Tuple[str | Node, ...] | Tuple[Node, ...] | Tuple[str, ...] 
+ListStrNode = List[str | Node] | List[Node] | List[str] +DictSingleStrNode = Mapping[Optional[str], + str | Node + ] | Mapping[str, + str | Node + ] +DictManyStrNode = Mapping[ + Optional[str], + str | Node | ListStrNode | TupleStrNode] | \ + Mapping[str, + str | Node | ListStrNode | TupleStrNode] + +StatesType = DictSingleStrNode | DictManyStrNode | str | Node | \ + ListStrNode | TupleStrNode | None + class TapeType(Enum): + """ + TODO: doc tapetype + """ SINGLE = auto() MANY = auto() INDEX_SINGLE = auto() INDEX_MANY = auto() + @staticmethod + def single_type_guard(states: Any) -> \ + TypeGuard[str | Node]: + """ + This input as states would get + interpreted as SINGLE type, so it must be a single str or Node + """ + return isinstance(states, (str, Node)) + + @staticmethod + def many_type_guard(states: Any) -> \ + TypeGuard[TupleStrNode | ListStrNode]: + """ + This input as states would get + interpreted as MANY type, so it must be a list or tuple of str or Node + """ + return isinstance(states, (list, tuple)) and \ + all(isinstance(s, (str, Node)) for s in states) + + @staticmethod + def index_single_guard(states: Any) -> \ + TypeGuard[DictSingleStrNode]: + """ + This input as states would get + interpreted as INDEX_SINGLE type, so it must be a dict from Optional[str] to str or Node + """ + return isinstance(states, dict) and \ + all( + isinstance(k, (str, type(None))) and \ + isinstance(v, (str, Node)) for k, v in states.items() + ) + + @staticmethod + def index_many_guard(states: Any) -> \ + TypeGuard[DictManyStrNode]: + """ + This input as states would get + interpreted as INDEX_MANY type, + so it must be a dict from Optional[str] to list or tuple of str or Node + """ + return isinstance(states, dict) and \ + all( + isinstance(k, (str, type(None))) + and ( + isinstance(v, (str, Node)) + or ( + isinstance(v, (list, tuple)) + and all(isinstance(x, (str, Node)) for x in v) + ) + ) + for k, v in states.items()) + + @staticmethod + def _assess_type(states: 
StatesType) -> Optional[TapeType]: + """ + If there is only one state, SINGLE + If there is a list or tuple of many states, MANY + If there is a dictionary sending each Optional[str] key to a single state, INDEX_SINGLE + If there is a dictionary sending each Optional[str] key to possibly many states, INDEX_MANY + """ + # Scalar → SINGLE + if TapeType.single_type_guard(states): + return TapeType.SINGLE + + # Sequence of scalars → MANY + if TapeType.many_type_guard(states): + return TapeType.MANY + + # Mapping → either INDEX_SINGLE or INDEX_MANY + if isinstance(states, dict): + # all values are scalar → INDEX_SINGLE + if TapeType.index_single_guard(states): + return TapeType.INDEX_SINGLE + + # all values are either scalar or sequence of scalars → INDEX_MANY + # but at least one of the values was actually a sequence + if TapeType.index_many_guard(states): + return TapeType.INDEX_MANY + + return None + class StateTape: + """ + TODO: doc state tape + """ __slots__ = ('states', '_type') prefix: str = "tape" - def __init__(self, states: Any = None, with_prefix: bool = True): + def __init__(self, states: StatesType = None, with_prefix: bool = True): # Figure out what kind of input we have - tp = StateTape._assess_type(states) + tp = TapeType._assess_type(states) # Default: empty index-many if tp is None: - self.states = {} + self.states : Dict[str,List[Node]] = {} self._type = TapeType.INDEX_MANY # Exactly SINGLE elif tp is TapeType.SINGLE: + assert TapeType.single_type_guard(states) node = self._nodeify(states) # wrap str→Node if needed + if not with_prefix: + #pylint:disable=line-too-long + raise ValueError("StateTape constructor with with_prefix=False only gets called internally and using a dict input.") self.states = {self.prefix: [node]} self._type = tp # Exactly MANY elif tp is TapeType.MANY: + assert TapeType.many_type_guard(states) nodes = [self._nodeify(s) for s in states] + if not with_prefix: + #pylint:disable=line-too-long + raise ValueError("StateTape 
constructor with with_prefix=False only gets called internally and using a dict input.") self.states = {self.prefix: nodes} self._type = tp # Exactly INDEX_SINGLE elif tp is TapeType.INDEX_SINGLE: + assert TapeType.index_single_guard(states) self.states = { self._add_prefix(k, with_prefix): [self._nodeify(v)] for k, v in states.items() @@ -387,8 +563,15 @@ def __init__(self, states: Any = None, with_prefix: bool = True): # Exactly INDEX_MANY elif tp is TapeType.INDEX_MANY: + assert TapeType.index_many_guard(states) + def v_to_node_list(v: Union[str, Node, ListStrNode, TupleStrNode]) -> List[Node]: + if isinstance(v, (str, Node)): + return [self._nodeify(v)] + if isinstance(v, (list, tuple)): + return [self._nodeify(s) for s in v] + raise TypeError(f"Invalid value type in INDEX_MANY: {v!r}") self.states = { - self._add_prefix(k, with_prefix): [self._nodeify(s) for s in v] + self._add_prefix(k, with_prefix): v_to_node_list(v) for k, v in states.items() } self._type = tp @@ -398,54 +581,49 @@ def __init__(self, states: Any = None, with_prefix: bool = True): raise RuntimeError(f"Unhandled TapeType {tp!r}") def set_type(self, value: TapeType) -> StateTape: + """ + Change the type + """ self._type = value return self - @staticmethod - def _assess_type(states: Any) -> Optional[TapeType]: - # Scalar → SINGLE - if isinstance(states, (str, Node)): - return TapeType.SINGLE - - # Sequence of scalars → MANY - if isinstance(states, (list, tuple)): - if all(isinstance(s, (str, Node)) for s in states): - return TapeType.MANY - - # Mapping → either INDEX_SINGLE or INDEX_MANY - if isinstance(states, dict): - # all values are scalar → INDEX_SINGLE - if all( - isinstance(k, (str, type(None))) - and isinstance(v, (str, Node)) - for k, v in states.items() - ): - return TapeType.INDEX_SINGLE - - # all values are either scalar or sequence of scalars → INDEX_MANY - if all( - isinstance(k, (str, type(None))) - and ( - isinstance(v, (str, Node)) - or ( - isinstance(v, (list, tuple)) - and 
all(isinstance(x, (str, Node)) for x in v) - ) - ) - for k, v in states.items() - ): - return TapeType.INDEX_MANY - - return None - - def _add_prefix(self, key: str, with_prefix: bool = True) -> str: + def _add_prefix(self, key: str | None, with_prefix: bool = True) -> str: + """ + TODO: doc prefix + """ + if key is None and with_prefix: + return f"{self.prefix}:{key}" + if key is None: + return key # pyright: ignore[reportReturnType] return f"{self.prefix}:{key}" if with_prefix else key + def _remove_str_prefix(self, key: str) -> str: + """ + TODO: doc prefix + """ + p = f"{self.prefix}:" + return key[len(p):] if key.startswith(p) else key + + # The type annotation here is incorrect, key=None gives None + # but making this Optional[str] in return + # is not the intended behavior and pollutes that possibility + # for the caller even when they are not passing key=None def _remove_prefix(self, key: Optional[str]) -> str: + """ + TODO: doc prefix + """ p = f"{self.prefix}:" - return key[len(p):] if isinstance(key, str) and key.startswith(p) else key + if isinstance(key, str) and key.startswith(p): + return key[len(p):] + return key # pyright: ignore[reportReturnType] - def extend(self, states: Any): + def extend(self, states: StatesType): + """ + Merge in these new states with the current self.states + Creating a temporary StateTape handles the different ways + the new states can be presented as in __init__ rather than always being + a dictionary from strings to List[Node] + """ # Delegate to a local StateTape then merge local_tape = StateTape(states, with_prefix=False) for k, nodes in local_tape.states.items(): @@ -453,10 +631,21 @@ def extend(self, states: Any): self.states[k].extend(nodes) def refresh(self): - # Delegate to a fresh StateTape - return StateTape({key: [] for key in self.states.keys()}, with_prefix=False).set_type(self._type) + """ + Delegate to a fresh StateTape + This has the same keys, but no more Nodes in the accompanying values + """ + return 
StateTape( + {key: [] for key in self.states.keys()}, + with_prefix=False + ).set_type(self._type) def revert(self) -> Union[list[Node], dict[str, list[Node]], None]: + """ + The states as it was provided as the input to __init__ + StateTape(revert(StateTape(states,remove_prefix=b)),remove_prefix=b) + goes back and forth + """ # SINGLE or MANY → flatten to a single list if self._type in (TapeType.SINGLE, TapeType.MANY): out_list: list[Node] = [] @@ -468,7 +657,7 @@ def revert(self) -> Union[list[Node], dict[str, list[Node]], None]: if self._type in (TapeType.INDEX_SINGLE, TapeType.INDEX_MANY): out_dict: dict[str, list[Node]] = {} for pk, seq in self.states.items(): - key = self._remove_prefix(pk) + key = self._remove_str_prefix(pk) out_dict.setdefault(key, []) out_dict[key].extend(seq) return out_dict @@ -476,9 +665,11 @@ def revert(self) -> Union[list[Node], dict[str, list[Node]], None]: return None def _nodeify(self, x: Union[str, Node]) -> Node: - # wrap raw strings into Node + """ + wrap raw strings into Node + """ return x if isinstance(x, Node) else Node(x) - + def collect_activations( self, receiver_index: dict[str, Receiver], @@ -487,7 +678,8 @@ def collect_activations( """ For each receiver, and each (key, state) in self.states, if the parsed route matches that state (or has no source), - record a TapeActivation(priority, key, state, route, fn). 
+ record a priority: TapeActivation(key, state, route, fn) + key-value pair """ activation_index: dict[tuple[int, ...], list[TapeActivation]] = defaultdict(list) @@ -516,10 +708,10 @@ def collect_activations( # ======= LIFE CYCLES ======= -from enum import Enum, auto - class ClientIntent(Enum): + """ + Life Cycles + """ QUIT = auto() # brutal, immediate exit TRAVEL = auto() # switch to a new host/port ABORT = auto() # abort due to error - diff --git a/summoner/protocol/triggers.py b/summoner/protocol/triggers.py index d2f55f5..1449d17 100644 --- a/summoner/protocol/triggers.py +++ b/summoner/protocol/triggers.py @@ -29,13 +29,15 @@ import re import sys -from typing import Any, Optional +from typing import Optional +from typing import Any from pathlib import Path import keyword _VARNAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") +type TreeType = dict[str, Optional["TreeType"]] def is_valid_varname(name: str) -> bool: """Return True if name matches Python variable naming rules.""" @@ -83,25 +85,27 @@ def update_hierarchy(indent: int, indent_levels: list[int]) -> int: return depth -def simplify_leaves(tree: dict[str, Any]) -> None: +def simplify_leaves(tree: TreeType) -> None: """ Recursively convert empty dicts to None to mark leaves. (4) No return, operates in-place on 'tree'. """ for key, subtree in list(tree.items()): + if subtree is None: + continue if subtree == {}: tree[key] = None elif isinstance(subtree, dict): simplify_leaves(subtree) -def parse_signal_tree_lines(lines: list[str], tabsize: int = 8) -> dict[str, Any]: +def parse_signal_tree_lines(lines: list[str], tabsize: int = 8) -> TreeType: """ Parse a list of lines (strings) into a nested dict tree. This is one entry point, taking raw lines (great for testing). 
""" - root: dict[str, Any] = {} - nodes_by_depth: dict[int, dict[str, Any]] = {0: root} + root: TreeType = {} + nodes_by_depth: dict[int, TreeType] = {0: root} indent_levels: list[int] = [0] for lineno, raw in enumerate(lines, 1): @@ -117,20 +121,22 @@ def parse_signal_tree_lines(lines: list[str], tabsize: int = 8) -> dict[str, Any parent = nodes_by_depth[depth] if name in parent: - raise ValueError(f"Line {lineno}: duplicate signal name {name!r} at indent level {indent}") + raise ValueError( + f"Line {lineno}: duplicate signal name {name!r} at indent level {indent}" + ) parent[name] = {} # 3) Trim nodes_by_depth so we don't keep stale deeper dicts for d in list(nodes_by_depth): if d > depth: del nodes_by_depth[d] - nodes_by_depth[depth + 1] = parent[name] + nodes_by_depth[depth + 1] = parent[name] # pyright: ignore[reportArgumentType] simplify_leaves(root) return root -def parse_signal_tree(filepath: str, tabsize: int = 8) -> dict[str, Any]: +def parse_signal_tree(filepath: Path | str, tabsize: int = 8) -> TreeType: """ Read a file and parse it into a nested dict tree. This is the second entry point, for file-based input. @@ -141,6 +147,11 @@ def parse_signal_tree(filepath: str, tabsize: int = 8) -> dict[str, Any]: class Signal: + """ + Keep track of position in a tree + via the path followed. + Allowing comparison by the ancestor relationship. 
+ """ __slots__ = ("_path", "_name") def __init__(self, path: tuple[int, ...], name: str): self._path = path @@ -148,10 +159,16 @@ def __init__(self, path: tuple[int, ...], name: str): @property def path(self) -> tuple[int, ...]: + """ + The path in the tree from the root + """ return self._path @property def name(self) -> str: + """ + The name of this Signal + """ return self._name def __gt__(self, other): @@ -159,7 +176,7 @@ def __gt__(self, other): return NotImplemented a, b = self._path, other._path return len(a) < len(b) and b[:len(a)] == a - + def __lt__(self, other): return other > self @@ -167,7 +184,7 @@ def __ge__(self, other): if not isinstance(other, Signal): return NotImplemented return self._path == other._path or self > other - + def __le__(self, other): return other >= self @@ -183,17 +200,21 @@ def __repr__(self): return f"" -def build_triggers(tree: dict[str, Any]): +def build_triggers(tree: TreeType): + """ + TODO: doc Trigger + """ name_to_path: dict[str, tuple[int, ...]] = {} path_to_name: dict[tuple[int, ...], str] = {} - def recurse(subtree: dict[str, Any], prefix: tuple[int, ...] =()): + def recurse(subtree: TreeType, prefix: tuple[int, ...] =()): for idx, (name, child) in enumerate(subtree.items()): path = prefix + (idx,) name_to_path[name] = path path_to_name[path] = name - if isinstance(child, dict) and child: - recurse(child, path) + if child is not None: + if isinstance(child, dict) and child: + recurse(child, path) recurse(tree) @@ -215,13 +236,17 @@ def recurse(subtree: dict[str, Any], prefix: tuple[int, ...] 
=()): def name_of(*args): """Get signal name from tuple (or *args).""" return path_to_name.get(tuple(args)) - + attrs["name_of"] = staticmethod(name_of) return type("Trigger", (), attrs) +#pylint:disable=too-few-public-methods class Event: + """ + TODO: doc event + """ __slots__ = ("signal",) def __init__(self, signal: Signal) -> None: self.signal = signal @@ -229,36 +254,54 @@ def __repr__(self) -> str: return f"{type(self).__name__}({self.signal!r})" -class Move(Event): pass -class Stay(Event): pass +class Move(Event): + """ + TODO: doc move event + activated_nodes + """ +class Stay(Event): + """ + TODO: doc stay event + activated_nodes + """ class Test(Event): + """ + TODO: test event + activated_nodes + """ __test__ = False - pass + class Action: + """ + TODO: doc in activation_nodes + """ MOVE = Move STAY = Stay TEST = Test def extract_signal(trigger): + """ + Just the signal part. + Handling if it was wrapped up in an event or not + """ if isinstance(trigger, Event): return trigger.signal - elif isinstance(trigger, Signal): + if isinstance(trigger, Signal): return trigger - elif trigger is None: + if trigger is None: return None - else: - raise TypeError(f"Cannot extract signal from object of type {type(trigger).__name__!r}") - + raise TypeError(f"Cannot extract signal from object of type {type(trigger).__name__!r}") + # the file TRIGGERS needs to be next to the code for the client, hence sys.argv[0] WORKING_DIR = Path(sys.argv[0]).resolve().parent def load_triggers( triggers_file: Optional[str] = "TRIGGERS", - text: Optional[str] = None, + text: Optional[str] = None, json_dict: Optional[dict[str, Any]] = None ): """ @@ -272,19 +315,33 @@ def load_triggers( `json_dict` must match the nested structure output by `parse_signal_tree_lines`. Raises FileNotFoundError with clear message if file is missing. + Raises ValueError with message as in `parse_signal_tree_lines` and `build_triggers` + if any of the lines in the text or the file are malformed. 
""" - try: - if text is not None: - lines = text.splitlines() - tree = parse_signal_tree_lines(lines) - elif json_dict is not None: - tree = dict(json_dict) # shallow copy to avoid mutation - else: - path = WORKING_DIR / triggers_file - tree = parse_signal_tree(path) - except FileNotFoundError as e: - raise FileNotFoundError( - f"Could not find triggers file at {path if 'path' in locals() else ''}" - ) from e + if text is not None: + lines = text.splitlines() + # This below can raise ValueError + tree = parse_signal_tree_lines(lines) + elif json_dict is not None: + tree = dict(json_dict) + else: + path = "" + try: + if triggers_file is not None: + # In this case triggers_file was either provided as str + # or left out completely and the default is being used. + path = WORKING_DIR / triggers_file + else: + # In this case triggers_file was explicitly provided as None + path = WORKING_DIR / "TRIGGERS" + raise FileNotFoundError( + "no triggers file and weren't using the default either" + ) + # This below can raise ValueError or FileNotFoundError + tree =parse_signal_tree(path) + except FileNotFoundError as e: + #pylint:disable=line-too-long + raise FileNotFoundError( + f"Could not find triggers file at {path if 'path' in locals() else ''}" + ) from e return build_triggers(tree) - diff --git a/summoner/protocol/validation.py b/summoner/protocol/validation.py index f2659dc..078cb61 100644 --- a/summoner/protocol/validation.py +++ b/summoner/protocol/validation.py @@ -1,27 +1,33 @@ +""" +TODO: doc validation +""" import os import sys from typing import ( - Union, - get_type_hints, - get_origin, + Union, + get_type_hints, + get_origin, get_args, ) import inspect target_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) if target_path not in sys.path: sys.path.insert(0, target_path) +#pylint:disable=wrong-import-position from summoner.logger import Logger def hook_priority_order(priority: tuple) -> tuple: - + """ + TODO: doc hook_priority 
+ """ + if priority != () and isinstance(priority[0], str): # SDK priority: sort first, e.g., ('aurora', 0) → (0, 'aurora', 0) return (0, priority) - else: - # User priority: sort after SDK, e.g., (0, 1) → (1, 0, 1) - return (1, priority) + # User priority: sort after SDK, e.g., (0, 1) → (1, 0, 1) + return (1, priority) def _normalize_annotation(raw): """ @@ -58,16 +64,19 @@ def _check_param_and_return(fn, # parameter check params = list(sig.parameters.values()) - expected_params = 1 if decorator_name in ("@hook", "@receive", "@upload_states", "@download_states") else 0 - + expected_params = 1 if decorator_name in \ + ("@hook", "@receive", "@upload_states", "@download_states") else 0 + if len(params) != expected_params: raise TypeError(f"{decorator_name} '{fn.__name__}' must have " f"{expected_params} parameter(s), not {len(params)}") if expected_params == 1: raw_param = params[0].annotation - param_hint = _normalize_annotation(raw_param) or hints.get(params[0].name, None) + param_hint = _normalize_annotation(raw_param) or \ + hints.get(params[0].name, None) if param_hint is None: + #pylint:disable=line-too-long logger.warning( f"{decorator_name} '{fn.__name__}' missing parameter annotation; skipping type check" ) @@ -87,4 +96,4 @@ def _check_param_and_return(fn, elif not _valid_type_hint(ret_hint, allow_return): raise TypeError( f"{decorator_name} '{fn.__name__}' must return one of {allow_return}, not {ret_hint!r}" - ) \ No newline at end of file + ) diff --git a/summoner/rust/rust_server_v1_1_0/src/config/mod.rs b/summoner/rust/rust_server_v1_1_0/src/config/mod.rs index 5edc954..b1ec93e 100644 --- a/summoner/rust/rust_server_v1_1_0/src/config/mod.rs +++ b/summoner/rust/rust_server_v1_1_0/src/config/mod.rs @@ -31,6 +31,23 @@ pub struct BackpressurePolicy { pub disconnect_threshold: usize, } +impl BackpressurePolicy { + #[must_use = "The decision of whether to throttle"] + pub fn do_throttle(&self, queue_size: usize) -> bool { + self.enable_throttle && queue_size > 
self.throttle_threshold + } + + #[must_use = "The decision of whether to pause reading for chatty clients"] + pub fn do_flow_control(&self, queue_size: usize) -> bool { + self.enable_flow_control && queue_size > self.flow_control_threshold + } + + #[must_use = "The decision of whether to do a forced disconnect"] + pub fn do_disconnect(&self, queue_size: usize) -> bool { + self.enable_disconnect && queue_size > self.disconnect_threshold + } +} + ///////////////////////// // LoggerConfig // ///////////////////////// @@ -58,20 +75,19 @@ pub struct LoggerConfig { pub date_format: String, // Optional list of keys to include when logging dictionary messages in JSON pub log_keys: Option>, - // Maximum number of bytes per log file before rotation occurs // pub max_file_size: usize, // Number of rotated log files to keep before older ones are deleted // pub backup_count: usize, } - ////////////////////// // ServerConfig // ////////////////////// /// All the settings our server needs, driven by the Python-side dict #[derive(Debug, Clone)] +#[rustfmt::skip] pub struct ServerConfig { /// IP or hostname to listen on (e.g. `"127.0.0.1"`) pub host: String, @@ -81,10 +97,10 @@ pub struct ServerConfig { /// Logger configuration (text/JSON, file rotation, key filtering, etc.) 
pub logger: LoggerConfig, - + /// How many incoming connections we'll buffer before backpressure pub connection_buffer_size: usize, - + /// Seconds before we drop an idle client pub client_timeout: Option, @@ -125,6 +141,14 @@ pub struct ServerConfig { pub worker_threads: usize, } +impl ServerConfig { + #[must_use = "The decision of whether to throttle"] + #[inline] + pub fn addr(&self) -> String { + format!("{}:{}", self.host, self.port) + } +} + ///////////////////////////////////////////// // Converting from a Python dict into Rust // ///////////////////////////////////////////// @@ -133,8 +157,10 @@ impl<'py> TryFrom<&Bound<'py, PyDict>> for ServerConfig { // We return a PyErr if something goes wrong parsing type Error = PyErr; + #[allow(clippy::too_many_lines)] + #[rustfmt::skip] fn try_from(dict: &Bound<'py, PyDict>) -> Result { - + // Helper: look up `key` in the dict, or fall back to `default`. // If the Python value has the wrong type, we warn but still use `default`. fn extract_or<'py, T: FromPyObject<'py>>( @@ -160,7 +186,7 @@ impl<'py> TryFrom<&Bound<'py, PyDict>> for ServerConfig { let host = extract_or(dict, "host", "127.0.0.1".to_string()); let port = extract_or(dict, "port", 8888); let connection_buffer_size = extract_or(dict, "connection_buffer_size", 128); - + let rate_limit = extract_or(dict, "rate_limit_msgs_per_minute", 300); let command_buffer_size = extract_or(dict, "command_buffer_size", 32); let quarantine_cooldown_secs = extract_or(dict, "quarantine_cooldown_secs", 300); @@ -240,6 +266,7 @@ impl<'py> TryFrom<&Bound<'py, PyDict>> for ServerConfig { // let max_file_size = extract_or(&logger_dict, "max_file_size", 1_000_000); // let backup_count = extract_or(&logger_dict, "backup_count", 3); + #[allow(clippy::inconsistent_struct_constructor)] let logger = LoggerConfig { log_level, enable_console_log, diff --git a/summoner/rust/rust_server_v1_1_0/src/lib.rs b/summoner/rust/rust_server_v1_1_0/src/lib.rs index a4e174f..6a43244 100644 --- 
a/summoner/rust/rust_server_v1_1_0/src/lib.rs +++ b/summoner/rust/rust_server_v1_1_0/src/lib.rs @@ -1,8 +1,13 @@ +#![allow(clippy::uninlined_format_args, clippy::manual_string_new)] // Import everything we need from PyO3 so we can expose Rust functions to Python. // - `prelude::*` gives us the common traits and types. // - `PyModule` and `PyDict` let us work with Python modules and dicts. // - `Bound` helps us hold a reference into Python data safely. -use pyo3::{prelude::*, types::{PyModule, PyDict}, Bound}; +use pyo3::{ + Bound, + prelude::*, + types::{PyDict, PyModule}, +}; // Public module for parsing and validating server configuration provided by Python. pub mod config; @@ -14,9 +19,9 @@ pub mod logger; mod server; // Pull in the specific items we need so we don't have to write full paths later. +use config::ServerConfig; use logger::init_logger; use server::run_server; -use config::ServerConfig; /// Expose this Rust function to Python as `start_tokio_server(name, config)`. /// Responsibilities: @@ -33,6 +38,11 @@ use config::ServerConfig; /// Returns: /// - `Ok(())` if everything starts fine. /// - A Python `RuntimeError` if something goes wrong. +/// +/// # Errors +/// +/// A Python `RuntimeError` if something goes wrong +#[allow(clippy::needless_pass_by_value, clippy::used_underscore_binding)] #[pyfunction] pub fn start_tokio_server(_py: Python<'_>, name: String, config: Bound) -> PyResult<()> { // Convert the Python dict into our `ServerConfig` Rust struct. @@ -41,6 +51,7 @@ pub fn start_tokio_server(_py: Python<'_>, name: String, config: Bound) // Build a multi-threaded Tokio runtime based on the `worker_threads` value. // This runtime drives all our async I/O and timers. 
+ #[rustfmt::skip] let rt = tokio::runtime::Builder::new_multi_thread() .worker_threads(server_config.worker_threads) // how many threads to use .thread_name("rust-server-worker") // helpful for debugging @@ -72,6 +83,7 @@ pub fn start_tokio_server(_py: Python<'_>, name: String, config: Bound) }); // If the server returned an error, convert it into a Python RuntimeError. + #[rustfmt::skip] server_result.map_err(|e| { PyErr::new::( format!("Server execution failed: {}", e) diff --git a/summoner/rust/rust_server_v1_1_0/src/logger/mod.rs b/summoner/rust/rust_server_v1_1_0/src/logger/mod.rs index 9667b91..9e74a41 100644 --- a/summoner/rust/rust_server_v1_1_0/src/logger/mod.rs +++ b/summoner/rust/rust_server_v1_1_0/src/logger/mod.rs @@ -21,8 +21,10 @@ use crate::config::LoggerConfig; /// Given the *parsed* content (a JSON value) and an optional /// whitelist of keys, prune its `_payload` and `_type` sub-objects -/// if `keys` is Some, then return the resulting JsonValue. +/// if `keys` is Some, then return the resulting `JsonValue`. /// This consumes `content`, so no cloning is needed by the caller. +#[allow(clippy::must_use_candidate)] +#[rustfmt::skip] pub fn prune_content_value( mut content: JsonValue, keys: &Option>, @@ -32,7 +34,10 @@ pub fn prune_content_value( Some(m) if keys.is_some() => m, _ => return content, }; - let keys = keys.as_ref().unwrap(); + #[allow(clippy::missing_panics_doc)] + let keys = keys + .as_ref() + .expect("If keys was None, then it would have gone through early return"); // 2) Build a fresh map with version + filtered sub-objects let mut out = serde_json::Map::new(); @@ -46,14 +51,13 @@ pub fn prune_content_value( .filter(|(k, _)| keys.contains(k)) .map(|(k, v)| (k.clone(), v.clone())) .collect(); - out.insert(field.to_string(), JsonValue::Object(filtered)); + out.insert((*field).to_string(), JsonValue::Object(filtered)); } } JsonValue::Object(out) } - /// A simple Logger struct that wraps logging functions. 
/// Clonable to allow use across multiple threads/tasks. #[derive(Clone)] @@ -87,8 +91,8 @@ static LOGGER: OnceLock = OnceLock::new(); /// Initialize the global logger exactly once, according to the provided settings. /// After this call, all calls to `log::debug!(), info!(), warn!(), error!()` (and your /// `Logger` methods) will go through the configured fern dispatcher. +#[rustfmt::skip] pub fn init_logger(name: &str, cfg: &LoggerConfig) -> Logger { - LOGGER.get_or_init(|| { // ──────────────────────────────────────────────────────────────── // 1) Parse the configured level string into a log::LevelFilter @@ -128,7 +132,7 @@ pub fn init_logger(name: &str, cfg: &LoggerConfig) -> Logger { nm, record.level(), message - )) + )); }; base = base.chain( @@ -157,6 +161,7 @@ pub fn init_logger(name: &str, cfg: &LoggerConfig) -> Logger { let enable_json = cfg.enable_json_log; // Compute the logfile path + #[rustfmt::skip] let filepath = if cfg.log_file_path.is_empty() { format!("{}.log", nm.replace('.', "_")) } else { @@ -172,6 +177,7 @@ pub fn init_logger(name: &str, cfg: &LoggerConfig) -> Logger { if enable_json { // 1) Raw payload text let raw = message.to_string(); + #[rustfmt::skip] let message_json: JsonValue = serde_json::from_str(&raw).unwrap_or(JsonValue::String(raw.clone())); // 2) Build a real JSON envelope with "message" as an object @@ -183,7 +189,7 @@ pub fn init_logger(name: &str, cfg: &LoggerConfig) -> Logger { }); // 4) Emit the envelope unescaped - out.finish(format_args!("{}", envelope)) + out.finish(format_args!("{}", envelope)); } else { // Plain-text path out.finish(format_args!( @@ -192,7 +198,7 @@ pub fn init_logger(name: &str, cfg: &LoggerConfig) -> Logger { nm, record.level(), message - )) + )); } }; @@ -216,6 +222,7 @@ pub fn init_logger(name: &str, cfg: &LoggerConfig) -> Logger { // 5) Apply the composed dispatcher as the global logger // Any subsequent log:: calls will use this configuration. 
// ──────────────────────────────────────────────────────────────── + #[allow(clippy::missing_panics_doc)] base.apply().unwrap(); // Return our zero-sized Logger handle diff --git a/summoner/rust/rust_server_v1_1_0/src/server/backpressure.rs b/summoner/rust/rust_server_v1_1_0/src/server/backpressure.rs index e2a8edb..b8a5cf5 100644 --- a/summoner/rust/rust_server_v1_1_0/src/server/backpressure.rs +++ b/summoner/rust/rust_server_v1_1_0/src/server/backpressure.rs @@ -2,8 +2,8 @@ use std::collections::HashMap; use std::net::SocketAddr; use tokio::sync::mpsc; -use crate::logger::Logger; use crate::config::BackpressurePolicy; +use crate::logger::Logger; /// Commands that the backpressure monitor can issue to the main server loop. #[derive(Debug)] @@ -19,8 +19,9 @@ pub enum ClientCommand { FlowControl, } - /// Spawn a task to monitor backpressure from clients +#[allow(clippy::needless_pass_by_value)] +#[rustfmt::skip] pub fn spawn_backpressure_monitor( mut rx: mpsc::Receiver<(SocketAddr, usize)>, command_tx: mpsc::Sender, @@ -28,42 +29,37 @@ pub fn spawn_backpressure_monitor( policy: BackpressurePolicy, ) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { - let mut client_queues: HashMap = HashMap::new(); while let Some((addr, queue_size)) = rx.recv().await { - client_queues.insert(addr, queue_size); - + let _old_queue_size = client_queues.insert(addr, queue_size); + // Log if queue size is getting large - if policy.enable_throttle && queue_size >= policy.throttle_threshold { + if policy.do_throttle(queue_size) { logger.warn(&format!("⚠️ Throttling client {}: {} messages queued", addr, queue_size)); if let Err(e) = command_tx.send(BackpressureCommand::Throttle(addr)).await { logger.error(&format!("Failed to send throttle command: {}", e)); } - } // Log if queue size is getting large - if policy.enable_flow_control && queue_size >= policy.flow_control_threshold { + if policy.do_flow_control(queue_size) { logger.warn(&format!("⏸️ Applying flow control to client {}: 
{} messages queued", addr, queue_size)); - + if let Err(e) = command_tx.send(BackpressureCommand::FlowControl(addr)).await { logger.error(&format!("Failed to send flow control command: {}", e)); } - } // Log if queue size is getting large - if policy.enable_disconnect && queue_size >= policy.disconnect_threshold { + if policy.do_disconnect(queue_size) { logger.warn(&format!("🚨 Disconnecting client {} due to extreme backpressure: {} messages queued", addr, queue_size)); - + if let Err(e) = command_tx.send(BackpressureCommand::Disconnect(addr)).await { logger.error(&format!("Failed to send disconnect command: {}", e)); } - } - } }) -} \ No newline at end of file +} diff --git a/summoner/rust/rust_server_v1_1_0/src/server/mod.rs b/summoner/rust/rust_server_v1_1_0/src/server/mod.rs index daecff5..fd3ba37 100644 --- a/summoner/rust/rust_server_v1_1_0/src/server/mod.rs +++ b/summoner/rust/rust_server_v1_1_0/src/server/mod.rs @@ -1,5 +1,5 @@ +#![allow(clippy::empty_line_after_doc_comments)] /// === IMPORTS === - // Standard library type for holding an IP address and port together. use std::net::SocketAddr; @@ -25,7 +25,7 @@ use tokio::sync::{Mutex, RwLock, broadcast, mpsc}; use tokio::time::{self, Duration, Instant}; // Macro for building JSON payloads when broadcasting client messages. -use serde_json::{json, Value as JsonValue}; +use serde_json::{Value as JsonValue, json}; // Clone-on-write string type: avoids extra allocations when we don't modify the string. use std::borrow::Cow; @@ -33,14 +33,14 @@ use std::borrow::Cow; // Reference-counted byte buffer: enables zero-copy sharing of message data. use bytes::Bytes; - /// === MODULES === - // Private modules handling specific server features. 
-mod backpressure; // queue monitoring and control commands -mod ratelimiter; // per-client rate limiting -mod quarantine; // temporary client bans - +#[rustfmt::skip] +mod backpressure; // queue monitoring and control commands +#[rustfmt::skip] +mod ratelimiter; // per-client rate limiting +#[rustfmt::skip] +mod quarantine; // temporary client bans // Import the parsed ServerConfig struct that holds all user settings. use crate::config::ServerConfig; @@ -57,7 +57,6 @@ use crate::server::ratelimiter::RateLimiter; // Quarantine list for tracking and expiring banned clients over time. use crate::server::quarantine::QuarantineList; - /// === TYPES === // Represents one connected client. Cloning this struct is cheap (Arc + channel). @@ -79,7 +78,7 @@ pub struct Client { // - RwLock lets many readers (e.g. broadcasts) but only one writer (add/remove) at a time. pub type ClientList = Arc>>; - +#[allow(clippy::doc_markdown)] /// === RUN_SERVER === /// This function launches the main server loop: @@ -92,7 +91,7 @@ pub async fn run_server( logger: Logger, ) -> Result<(), Box> { // Build the "host:port" string so the OS knows where to listen - let addr = format!("{}:{}", config.host, config.port); + let addr = config.addr(); // Open a TCP listener on that address; the `?` returns early on error let listener = TcpListener::bind(&addr).await?; @@ -114,11 +113,13 @@ pub async fn run_server( mpsc::channel::<(SocketAddr, usize)>(config.connection_buffer_size); // Channel used by the backpressure monitor to send commands (Throttle, Disconnect, etc.) 
+ #[rustfmt::skip] let (command_tx, command_rx) = mpsc::channel::(config.command_buffer_size); // Track which clients are quarantined (banned temporarily) // Wrapped in Arc+Mutex for safe, exclusive access when marking or cleaning entries + #[rustfmt::skip] let quarantine_list = Arc::new(Mutex::new( QuarantineList::new(Duration::from_secs(config.quarantine_cooldown_secs)) )); @@ -190,6 +191,8 @@ fn spawn_shutdown_listener( /// === CONNECTIONS === /// Listens for new clients, spawns per-client tasks, and handles shutdown/backpressure +#[allow(clippy::too_many_arguments)] +#[rustfmt::skip] async fn accept_connections( listener: TcpListener, // TCP socket we bound in run_server clients: ClientList, // Shared list of connected clients @@ -274,10 +277,11 @@ async fn handle_backpressure_command( BackpressureCommand::Throttle(addr) | BackpressureCommand::FlowControl(addr) => { // Decide which control command to send + #[rustfmt::skip] let control_cmd = match cmd { BackpressureCommand::Throttle(_) => ClientCommand::Throttle, BackpressureCommand::FlowControl(_) => ClientCommand::FlowControl, - _ => unreachable!(), + BackpressureCommand::Disconnect(_) => unreachable!(), }; // Under a short read lock, grab a clone of the sender if the client still exists @@ -297,10 +301,12 @@ async fn handle_backpressure_command( control_cmd, addr, e )); } + #[rustfmt::skip] let emoji = match control_cmd { ClientCommand::Throttle => "⏳", ClientCommand::FlowControl => "⏸️", }; + #[rustfmt::skip] logger.info(&format!("{} {:?} requested for {}", emoji, control_cmd, addr)); } else { logger.warn(&format!( @@ -313,6 +319,8 @@ async fn handle_backpressure_command( } /// Validates a new TCP connection and then spawns its session task +#[allow(clippy::too_many_arguments)] +#[rustfmt::skip] async fn handle_new_connection( stream: TcpStream, // The raw TCP connection addr: SocketAddr, // Remote client address (IP:port) @@ -404,7 +412,6 @@ async fn handle_new_connection( let remaining = 
connection_count.fetch_sub(1, std::sync::atomic::Ordering::SeqCst) - 1; logger.info(&format!("🔌 {} disconnected. Active connections: {}", addr, remaining)); }); - } /// Manages a single client's session: @@ -412,6 +419,7 @@ async fn handle_new_connection( /// - Reports backpressure without blocking /// - Enforces inactivity timeouts and graceful shutdown /// - On exit, removes the client from the shared list and logs the remaining count +#[allow(clippy::too_many_arguments)] async fn handle_connection( reader_half: tokio::net::tcp::OwnedReadHalf, client: Client, @@ -446,6 +454,7 @@ async fn handle_connection( .await { // Log any unexpected error during the session + #[rustfmt::skip] logger.warn(&format!("⚠️ Error in client {} session: {}", client.addr, e)); } @@ -465,7 +474,6 @@ async fn handle_connection( Ok(()) } - /// === MESSAGES === /// Manages a single client session until disconnect or shutdown: @@ -473,6 +481,7 @@ async fn handle_connection( /// - Responds to throttle and flow-control commands /// - Enforces inactivity timeouts /// - Sends a shutdown notice on server exit +#[allow(clippy::too_many_arguments)] async fn handle_client_messages( // Line-based reader for this client's incoming data reader: &mut Lines>, @@ -502,6 +511,7 @@ async fn handle_client_messages( let mut last_active = Instant::now(); // Timer that fires every `timeout_check_interval_secs` to enforce inactivity + #[rustfmt::skip] let mut timeout_interval = time::interval(Duration::from_secs( config.timeout_check_interval_secs, )); @@ -566,6 +576,7 @@ async fn handle_client_messages( // 4) Check for inactivity timeout _ = timeout_interval.tick() => { + #[allow(clippy::collapsible_if)] if let Some(timeout) = timeout { if last_active.elapsed() > timeout { // Send a warning to the client then break @@ -664,6 +675,8 @@ async fn process_client_line( let envelope_text = envelope.to_string(); // 3) Parse the client's content *once* + #[allow(clippy::needless_borrow)] + #[rustfmt::skip] let 
json_content: JsonValue = serde_json::from_str(&content).unwrap_or(JsonValue::String(content.to_string())); // 4) Move it into pruning or keep it as-is @@ -705,6 +718,7 @@ fn remove_last_newline(s: &str) -> &str { } /// Sends `msg` to every client except the sender, reporting queue size for backpressure: +#[rustfmt::skip] async fn broadcast_message( clients: &ClientList, // Shared list of all connected clients sender: &Client, // The client who sent the original message diff --git a/summoner/rust/rust_server_v1_1_0/src/server/quarantine.rs b/summoner/rust/rust_server_v1_1_0/src/server/quarantine.rs index 54ffbd1..e398937 100644 --- a/summoner/rust/rust_server_v1_1_0/src/server/quarantine.rs +++ b/summoner/rust/rust_server_v1_1_0/src/server/quarantine.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::net::SocketAddr; -use std::time::{Duration, Instant}; use std::sync::Arc; +use std::time::{Duration, Instant}; use tokio::sync::Mutex; use crate::logger::Logger; @@ -62,4 +62,4 @@ impl QuarantineList { } }) } -} \ No newline at end of file +} diff --git a/summoner/server/__init__.py b/summoner/server/__init__.py index 1c8c79d..6123d76 100644 --- a/summoner/server/__init__.py +++ b/summoner/server/__init__.py @@ -1 +1,4 @@ -from .server import SummonerServer \ No newline at end of file +""" +TODO: doc server +""" +from .server import SummonerServer diff --git a/summoner/server/server.py b/summoner/server/server.py index cddf6c5..e5ddc45 100644 --- a/summoner/server/server.py +++ b/summoner/server/server.py @@ -1,9 +1,14 @@ +""" +TODO: doc server +""" +# pylint:disable=wrong-import-position, logging-fstring-interpolation import asyncio import signal import os import sys import json -from typing import Optional, Any +from typing import Optional +from typing import Any import platform import importlib @@ -14,7 +19,7 @@ # Imports from summoner.utils import ( - remove_last_newline, + remove_last_newline, ensure_trailing_newline, load_config, ) @@ -43,10 +48,13 @@ 
class ClientDisconnected(Exception): """Raised when the client disconnects cleanly (e.g., closes the socket).""" - pass + class SummonerServer: - + """ + TODO: doc server + """ + __slots__: tuple[str, ...] = ( "name", "logger", @@ -70,14 +78,17 @@ def __init__(self, name: Optional[str] = None): self.clients: set[asyncio.streams.StreamWriter] = set() self.clients_lock = asyncio.Lock() - self.active_tasks: dict[asyncio.Task, str] = {} + self.active_tasks: dict[Optional[asyncio.Task[Any]], str] = {} self.tasks_lock = asyncio.Lock() - + async def handle_client( self, - reader: asyncio.streams.StreamReader, + reader: asyncio.streams.StreamReader, writer: asyncio.streams.StreamWriter ): + """ + TODO: doc handle_client + """ addr = format_addr(writer.get_extra_info("peername")) self.logger.info(f"{addr} connected.") @@ -103,24 +114,29 @@ async def handle_client( # Iterate over the snapshot to avoid concurrency issues without long-held locks for other_writer in clients_snapshot: if other_writer != writer: - payload = json.dumps({"remote_addr": addr, "content": remove_last_newline(message)}) + payload = json.dumps( + { + "remote_addr": addr, + "content": remove_last_newline(message) + } + ) other_writer.write(ensure_trailing_newline(payload).encode()) await other_writer.drain() - + except ClientDisconnected: self.logger.info(f"{addr} disconnected.") - + except (ConnectionResetError, asyncio.IncompleteReadError, asyncio.CancelledError) as e: self.logger.warning(f"{addr} error or abrupt disconnect: {e}") - + finally: - + try: writer.write("Warning: Server disconnected.".encode()) await writer.drain() - except Exception: + except Exception: #pylint:disable=broad-exception-caught pass - + async with self.clients_lock: self.clients.discard(writer) writer.close() @@ -131,28 +147,43 @@ async def handle_client( self.logger.info(f"{addr} connection closed.") + #pylint:disable=duplicate-code def shutdown(self): + """ + Cancel all tasks for this loop + """ self.logger.info("Server is 
shutting down...") for task in asyncio.all_tasks(self.loop): task.cancel() def set_termination_signals(self): + """ + Upon interrupt signal or system/process-based termination + do the shutdown method. + Requires not Windows to have any effect + """ # SIGINT = interupt signal for Ctrl+C | value = 2 # SIGTERM = system/process-based termination | value = 15 if platform.system() != "Windows": for sig in (signal.SIGINT, signal.SIGTERM): + #pylint:disable=unnecessary-lambda self.loop.add_signal_handler(sig, lambda: self.shutdown()) async def run_server(self, host: str = '127.0.0.1', port: int = 8888): + """ + TODO: doc run_server + """ server = await asyncio.start_server(self.handle_client, host=host, port=port) self.logger.info(f"Python server listening on {host}:{port}.") async with server: await server.serve_forever() async def wait_for_tasks_to_finish(self): - # Wait for all client handlers to finish + """ + Wait for all client handlers to finish + """ async with self.tasks_lock: - tasks = list(self.active_tasks) + tasks = list(z for z in self.active_tasks if z is not None) if tasks: await asyncio.gather(*tasks, return_exceptions=True) @@ -163,15 +194,21 @@ def run( config_path: Optional[str] = None, config_dict: Optional[dict[str, Any]] = None, ): - + """ + TODO: doc run + """ + if config_dict is None: # Load config parameters - server_config = load_config(config_path=config_path, debug=True) + server_config = load_config(config_path=config_path, + debug=True) elif isinstance(config_dict, dict): # Shallow copy to avoid external mutation - server_config = dict(config_dict) + server_config = dict(config_dict) else: - raise TypeError(f"SummonerServer.run: config_dict must be a dict or None, got {type(config_dict).__name__}") + #pylint:disable=line-too-long + raise TypeError( + f"SummonerServer.run: config_dict must be a dict or None, got {type(config_dict).__name__}") if platform.system() != "Windows": option = server_config.get("version", None) @@ -195,6 +232,7 @@ 
def _payload(h, p): elif isinstance(option, str) and option.startswith("rust"): available = ", ".join(["rust"] + list(RUST_MODULES.keys())) if RUST_LATEST else \ ", ".join(RUST_MODULES.keys()) + #pylint:disable=line-too-long raise RuntimeError( f"Rust backend '{option}' requested but not available. Installed options: {available or '(none)'}" ) @@ -202,13 +240,14 @@ def _payload(h, p): if mod is not None: try: requested = option if option is not None else "(unset)" + #pylint:disable=line-too-long print(f"[DEBUG] Config requested version '{requested}' -> resolved to '{mod.__name__}'") mod.start_tokio_server(self.name, _payload(host, port)) except KeyboardInterrupt: pass return - + try: logger_cfg = server_config.get("logger", {}) configure_logger(self.logger, logger_cfg) @@ -218,7 +257,7 @@ def _payload(h, p): except (asyncio.CancelledError, KeyboardInterrupt): pass - + finally: self.loop.run_until_complete(self.wait_for_tasks_to_finish()) diff --git a/summoner/settings.py b/summoner/settings.py index c2101b7..cd76817 100644 --- a/summoner/settings.py +++ b/summoner/settings.py @@ -1,3 +1,6 @@ +""" +setup information +""" # settings.py import os from dotenv import load_dotenv diff --git a/summoner/utils/__init__.py b/summoner/utils/__init__.py index 208a44b..835094b 100644 --- a/summoner/utils/__init__.py +++ b/summoner/utils/__init__.py @@ -1,3 +1,7 @@ +""" +Various utilities for strings, json, peer addresses, inspecting code +""" + from .string_handlers import ( remove_last_newline, ensure_trailing_newline, @@ -15,4 +19,4 @@ extract_annotation_identifiers, rebuild_expression_for, resolve_import_statement, - ) \ No newline at end of file + ) diff --git a/summoner/utils/addr_handlers.py b/summoner/utils/addr_handlers.py index ad85b96..57537a4 100644 --- a/summoner/utils/addr_handlers.py +++ b/summoner/utils/addr_handlers.py @@ -1,7 +1,37 @@ +""" +Convert a socket peer address (or similar) to a compact, deterministic string. 
+""" from collections.abc import Iterable, Mapping from typing import Any import ipaddress +def _format_addr_iterable(addr: Iterable, max_items: int) -> str: + try: + seq = list(addr) + except Exception: # pylint:disable=broad-exception-caught + return repr(addr) + + # Common socket cases + if len(seq) >= 2 and isinstance(seq[0], (str, bytes)) and isinstance(seq[1], int): + host = seq[0].decode() if isinstance(seq[0], (bytes, bytearray, memoryview)) else seq[0] + port = seq[1] + scopeid = seq[3] if len(seq) >= 4 and isinstance(seq[3], int) else None + try: + ip = ipaddress.ip_address(host) + if ip.version == 6: + host_fmt = host if scopeid in (None, 0) else f"{host}%{scopeid}" + return f"[{host_fmt}]:{port}" + return f"{host}:{port}" + except ValueError: + # Not an IP literal; fall back to host:port + return f"{host}:{port}" + + # Generic iterable formatting + shown = seq[:max_items] + body = ",".join(map(str, shown)) + suffix = ",..." if len(seq) > max_items else "" + return f"[{body}{suffix}]" + def format_addr(addr: Any, max_items: int = 10) -> str: """ Convert a socket peer address (or similar) to a compact, deterministic string. 
@@ -49,35 +79,10 @@ def format_addr(addr: Any, max_items: int = 10) -> str: # generic iterable (list/tuple/set/generator, etc.), but not strings/bytes if isinstance(addr, Iterable): - try: - seq = list(addr) - except Exception: - return repr(addr) - - # Common socket cases - if len(seq) >= 2 and isinstance(seq[0], (str, bytes)) and isinstance(seq[1], int): - host = seq[0].decode() if isinstance(seq[0], (bytes, bytearray, memoryview)) else seq[0] - port = seq[1] - scopeid = seq[3] if len(seq) >= 4 and isinstance(seq[3], int) else None - try: - ip = ipaddress.ip_address(host) - if ip.version == 6: - host_fmt = host if scopeid in (None, 0) else f"{host}%{scopeid}" - return f"[{host_fmt}]:{port}" - else: - return f"{host}:{port}" - except ValueError: - # Not an IP literal; fall back to host:port - return f"{host}:{port}" - - # Generic iterable formatting - shown = seq[:max_items] - body = ",".join(map(str, shown)) - suffix = ",..." if len(seq) > max_items else "" - return f"[{body}{suffix}]" + return _format_addr_iterable(addr, max_items) # fallback try: return str(addr) - except Exception: + except Exception: # pylint:disable=broad-exception-caught return repr(addr) diff --git a/summoner/utils/client_hyperparameter_configs.py b/summoner/utils/client_hyperparameter_configs.py new file mode 100644 index 0000000..8124d63 --- /dev/null +++ b/summoner/utils/client_hyperparameter_configs.py @@ -0,0 +1,75 @@ +""" +In `client.py` config["hyper_parameters"]["sender"] +has several logic steps. Isolate that out here. 
+""" +#pylint:disable=line-too-long + +from typing import List, Optional, TypeAlias + +# pylint:disable=unused-import +from summoner.utils.client_reconnect_configs import \ + ReconnectConfig, RETRY_DELAY_SECONDS_TYPE, PORT_TYPE +from summoner.utils.client_sender_configs import SenderConfig + +#pylint:disable=invalid-name +TIMEOUT_TYPE: TypeAlias = Optional[float] + +#pylint:disable=too-few-public-methods +class HyperparameterConfig: + """ + Client config hyperparameter section + It is mainly broken up into + SenderConfig + ReconnectConfig + The receiver part is smaller so is directly in the remaining 2 fields + """ + sender_config: SenderConfig + reconnect_config: ReconnectConfig + max_bytes_per_line: int + read_timeout_seconds: TIMEOUT_TYPE + + def __init__(self, + sender_config: SenderConfig, + reconnect_config: ReconnectConfig, + max_bytes_per_line: int, + read_timeout_seconds: Optional[float] + ): + self.sender_config = sender_config + self.reconnect_config = reconnect_config + self.max_bytes_per_line = max_bytes_per_line + self.read_timeout_seconds = read_timeout_seconds + + def __post_init__(self): + if self.max_bytes_per_line <= 0:\ + raise ValueError("The provided max_bytes_per_line must be an integer ≥ 1") + if self.read_timeout_seconds is not None and self.read_timeout_seconds <= 0.0:\ + raise ValueError("The provided read_timeout_seconds must be an float ≥ 0.0 if provided") + + def merge_in(self, **kwargs) -> List[str]: + """ + The **kwargs are new values from configuration files + and this is setting those values as appropriate + """ + all_problems = [] + all_problems.extend(self.sender_config.merge_in(**kwargs["sender"])) + all_problems.extend(self.reconnect_config.merge_in(**kwargs["reconnection"])) + receiver_cfg = kwargs["receiver"] + if "max_bytes_per_line" in receiver_cfg: + if not isinstance((z := receiver_cfg["max_bytes_per_line"]), int): + all_problems.append(f"The provided max_bytes_per_line was not an integer. 
It was {type(z)}") + else: + if z <= 0: + all_problems.append("The provided max_bytes_per_line must be an integer ≥ 1") + else: + self.max_bytes_per_line = z + if "read_timeout_seconds" in receiver_cfg: + if not isinstance((z := receiver_cfg["read_timeout_seconds"]), float | int | None): + all_problems.append(f"The provided read_timeout_seconds was not an optional integer or a float. It was {type(z)}") + else: + if z is None: + self.read_timeout_seconds = None + elif z < 0.0: + all_problems.append(f"The provided read_timeout_seconds was negative. It was {z}") + else: + self.read_timeout_seconds = z + return all_problems diff --git a/summoner/utils/client_reconnect_configs.py b/summoner/utils/client_reconnect_configs.py new file mode 100644 index 0000000..c7e4650 --- /dev/null +++ b/summoner/utils/client_reconnect_configs.py @@ -0,0 +1,76 @@ +""" +In `client.py` config["hyper_parameters"]["reconnection"] +has several logic steps. Isolate that out here +""" +#pylint:disable=line-too-long + +from dataclasses import dataclass +from typing import List, Optional, TypeAlias + +#pylint:disable=invalid-name +RETRY_DELAY_SECONDS_TYPE : TypeAlias = float +#pylint:disable=invalid-name +PORT_TYPE : TypeAlias = Optional[int] + +@dataclass(slots=True) +class ReconnectConfig: + """ + In `client.py` config["hyper_parameters"]["reconnection"] + has several logic steps. 
Isolate that out here + """ + retry_delay_seconds: RETRY_DELAY_SECONDS_TYPE + primary_retry_limit: int + default_host: Optional[str] + default_port: PORT_TYPE + default_retry_limit: int + + def __post_init__(self): + if self.retry_delay_seconds < 0.0:\ + raise ValueError("The provided retry_delay_seconds must be a float ≥ 0.0") + if self.primary_retry_limit < 0:\ + raise ValueError("The provided primary_retry_limit must be an integer ≥ 0") + if self.default_retry_limit < 0:\ + raise ValueError("The provided default_retry_limit must be an integer ≥ 0") + + #pylint:disable=too-many-branches + def merge_in(self, **kwargs) -> List[str]: + """ + The **kwargs are new values from configuration files + and this is setting those values as appropriate + """ + all_problems = [] + if "retry_delay_seconds" in kwargs: + if not isinstance((z := kwargs["retry_delay_seconds"]), float | int): + all_problems.append(f"The provided retry_delay_seconds was not an integer or a float. It was {type(z)}") + else: + if z < 0.0: + all_problems.append(f"The provided retry_delay_seconds was negative. It was {z}") + else: + self.retry_delay_seconds = z + if "primary_retry_limit" in kwargs: + if not isinstance((z := kwargs["primary_retry_limit"]), int): + all_problems.append(f"The provided primary_retry_limit was not an integer. It was {type(z)}") + else: + if z < 0: + all_problems.append(f"The provided primary_retry_limit was negative. It was {z}") + else: + self.primary_retry_limit = int(z) + if "default_host" in kwargs: + if isinstance((z := kwargs["default_host"]), str | None): + self.default_host = z + else: + all_problems.append(f"The provided default_host was not an optional string. It was {type(z)}") + if "default_port" in kwargs: + if not isinstance((z := kwargs["default_port"]), int | None): + all_problems.append(f"The provided default_port was not an optional integer. 
It was {type(z)}") + else: + self.default_port = z + if "default_retry_limit" in kwargs: + if not isinstance((z := kwargs["default_retry_limit"]), int): + all_problems.append(f"The provided default_retry_limit was not an integer. It was {type(z)}") + else: + if z < 0: + all_problems.append(f"The provided default_retry_limit was negative. It was {z}") + else: + self.default_retry_limit = z + return all_problems diff --git a/summoner/utils/client_sender_configs.py b/summoner/utils/client_sender_configs.py new file mode 100644 index 0000000..74a7ba6 --- /dev/null +++ b/summoner/utils/client_sender_configs.py @@ -0,0 +1,92 @@ +""" +In `client.py` config["hyper_parameters"]["sender"] +has several logic steps. Isolate that out here. +""" +#pylint:disable=line-too-long + +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass(slots=True) +class SenderConfig: + """ + The __post_init__ + enforces the strict requirements beyond types + of positivity of the integers. + Also the `merge_in` of dictionaries from json configurations + will not maintain this property and give a list of strings + for the logger to output as errors. + """ + max_concurrent_workers: int + batch_drain: bool + send_queue_maxsize: int + event_bridge_maxsize: Optional[int] + max_consecutive_worker_errors: int + + def __post_init__(self): + if self.max_concurrent_workers <= 0:\ + raise ValueError("The provided max_concurrent_workers must be an integer ≥ 1") + if self.send_queue_maxsize <= 0:\ + raise ValueError("The provided send_queue_maxsize must be an integer ≥ 1") + if self.max_consecutive_worker_errors <= 0:\ + raise ValueError("The provided max_consecutive_worker_errors must be an integer ≥ 1") + + #pylint:disable=too-many-branches + def merge_in(self, **kwargs) -> List[str]: + """ + The logic of _apply_config that is relevant to the sender_cfg. 
+ In addition rather than give ValueError with only one thing that is wrong, + if there are multiple problems with the config, then the error should show them all. + """ + all_problems = [] + if "concurrency_limit" in kwargs: + if not isinstance((z := kwargs["concurrency_limit"]), int): + all_problems.append(f"The provided concurrency_limit was not an integer. It was {type(z)}") + else: + if z <= 0: + all_problems.append("The provided concurrency_limit must be an integer ≥ 1") + else: + self.max_concurrent_workers = z + if "batch_drain" in kwargs: + if not isinstance((z := kwargs["batch_drain"]), bool | None): + all_problems.append(f"The provided batch_drain was not an optional bool. It was {type(z)}") + else: + if z is None: + self.batch_drain = False + else: + self.batch_drain = z + if "queue_maxsize" in kwargs: + if not isinstance((z := kwargs["queue_maxsize"]), int): + all_problems.append(f"The provided queue_maxsize was not an integer. It was {type(z)}") + else: + if z <= 0: + all_problems.append("The provided queue_maxsize must be an integer ≥ 1") + else: + self.send_queue_maxsize = z + if "event_bridge_maxsize" in kwargs: + if not isinstance((z := kwargs["event_bridge_maxsize"]), int | None): + all_problems.append(f"The provided event_bridge_maxsize was not an optional int. It was {type(z)}") + else: + if z is None: + self.event_bridge_maxsize = None + else: + self.event_bridge_maxsize = z + if "max_worker_errors" in kwargs: + if not isinstance((z := kwargs["max_worker_errors"]), int): + all_problems.append(f"The provided max_worker_errors was not an integer. It was {type(z)}") + else: + if z <= 0: + all_problems.append("The provided max_worker_errors must be an integer ≥ 1") + else: + self.max_consecutive_worker_errors = z + return all_problems + + def throttle_warning(self) -> Optional[str]: + """ + The configuration is valid at all times by construction. + However, this is valid but merits a warning. 
+ """ + if self.send_queue_maxsize < self.max_concurrent_workers: + return f"queue_maxsize < concurrency_limit; back-pressure will throttle producers at {self.send_queue_maxsize}" + return None diff --git a/summoner/utils/code_handlers.py b/summoner/utils/code_handlers.py index 183976f..d2ccf0b 100644 --- a/summoner/utils/code_handlers.py +++ b/summoner/utils/code_handlers.py @@ -1,4 +1,9 @@ -from typing import Optional, Set, Any +""" +Using inspection for Python code +""" + +from typing import Optional, Set +from typing import Any import inspect import ast import textwrap @@ -46,7 +51,7 @@ def get_callable_source(fn: Any, override: Optional[str] = None) -> str: try: return inspect.getsource(fn) - except Exception: + except Exception: # pylint:disable=broad-exception-caught src = getattr(fn, "__dna_source__", None) if isinstance(src, str) and src.strip(): return src @@ -65,6 +70,7 @@ def get_callable_source(fn: Any, override: Optional[str] = None) -> str: ) +# pylint:disable=too-many-branches def extract_annotation_identifiers(src: str) -> Set[str]: """ Extract simple identifier names used inside function annotations from source text. @@ -97,7 +103,7 @@ def extract_annotation_identifiers(src: str) -> Set[str]: out: Set[str] = set() try: tree = ast.parse(textwrap.dedent(src)) - except Exception: + except Exception:# pylint:disable=broad-exception-caught return out fn_node = None @@ -179,7 +185,7 @@ def rebuild_expression_for(value: object, node_type: Optional[type] = None) -> O return None - +# pylint:disable=too-many-return-statements, too-many-branches def resolve_import_statement(name: str, value: object, known_modules: Set[str]) -> Optional[str]: """ Try to produce a stable Python import statement that binds `name` to `value`. 
@@ -240,7 +246,7 @@ def resolve_import_statement(name: str, value: object, known_modules: Set[str]) known_modules.add(mod) if getattr(m, name, None) is value: return f"from {mod} import {name}" - except Exception: + except Exception:# pylint:disable=broad-exception-caught pass obj_name = getattr(value, "__name__", None) @@ -251,7 +257,7 @@ def resolve_import_statement(name: str, value: object, known_modules: Set[str]) if name == obj_name: return f"from {mod} import {obj_name}" return f"from {mod} import {obj_name} as {name}" - except Exception: + except Exception:# pylint:disable=broad-exception-caught pass # Fallback: search modules we've already seen. @@ -260,7 +266,7 @@ def resolve_import_statement(name: str, value: object, known_modules: Set[str]) m = import_module(km) if getattr(m, name, None) is value: return f"from {km} import {name}" - except Exception: + except Exception:# pylint:disable=broad-exception-caught continue - return None \ No newline at end of file + return None diff --git a/summoner/utils/json_handlers.py b/summoner/utils/json_handlers.py index f69b412..57478ce 100644 --- a/summoner/utils/json_handlers.py +++ b/summoner/utils/json_handlers.py @@ -1,10 +1,16 @@ +""" +General json manipulation utilities +""" + import json from pathlib import Path -from typing import Any, Optional +from typing import Optional +from typing import Any def fully_recover_json(data): """ Recursively recover original nested structure from JSON strings. + A mix of stringified JSON and python lists, dicts and primitives Args: data (any): Data structure possibly containing nested JSON-encoded strings. 
@@ -42,9 +48,9 @@ def load_config(config_path: Optional[str], debug: bool = False) -> dict[str, An """ if config_path is None: if debug: - print(f"[DEBUG] Config file is `None`") + print("[DEBUG] Config file is `None`") return {} - + path = Path(config_path) if not path.is_file(): @@ -64,7 +70,7 @@ def load_config(config_path: Optional[str], debug: bool = False) -> dict[str, An except OSError as e: if debug: print(f"[DEBUG] OS error reading {config_path}: {e}") - except Exception as e: + except Exception as e: # pylint:disable=broad-exception-caught if debug: print(f"[DEBUG] Unexpected error loading {config_path}: {e}") @@ -97,5 +103,5 @@ def is_jsonable(value: Any) -> bool: try: json.dumps(value) return True - except Exception: + except Exception: # pylint:disable=broad-exception-caught return False diff --git a/summoner/utils/string_handlers.py b/summoner/utils/string_handlers.py index 31ed86d..9414cc5 100644 --- a/summoner/utils/string_handlers.py +++ b/summoner/utils/string_handlers.py @@ -1,7 +1,15 @@ +""" +Newline manipulation utilities +""" + def remove_last_newline(s: str): + """ + if ending with a newline, remove it + """ return s[:-1] if s.endswith('\n') else s def ensure_trailing_newline(s: str): + """ + Make sure ends with a newline + """ return s if s.endswith('\n') else s + '\n' - - diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..169266e --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,118 @@ +""" +Tests for manipulating client_config +""" +#pylint:disable=line-too-long + +from pathlib import Path +from summoner.utils.json_handlers import load_config +from summoner.utils.client_reconnect_configs import ReconnectConfig +from summoner.utils.client_sender_configs import SenderConfig +from summoner.utils.client_hyperparameter_configs import HyperparameterConfig + +def test_template_client_reconnection_config(): + """ + Handling of reconnection hyperparameters part + of the config of client config json files 
+ """ + tempate_client_config = Path() / "templates" / "client_config.json" + assert tempate_client_config.is_file(), f"{tempate_client_config} is not a file" + big_config = load_config(str(tempate_client_config)) + try: + just_reconnection = big_config["hyper_parameters"]["reconnection"] + except KeyError as e: + raise KeyError(f"Had this: {big_config}") from e + reconnect_config = ReconnectConfig( + retry_delay_seconds=85.38, + primary_retry_limit=4380, + default_host="junk", + default_port = 390359, + default_retry_limit=593090, + ) + problems = reconnect_config.merge_in(**just_reconnection) + assert len(problems) == 0, \ + "The template was supposed to be a well formed client configuration. So it should have had no problems." + assert reconnect_config.retry_delay_seconds == just_reconnection["retry_delay_seconds"] + assert reconnect_config.primary_retry_limit == just_reconnection["primary_retry_limit"] + assert reconnect_config.default_host == just_reconnection["default_host"] + assert reconnect_config.default_port == just_reconnection["default_port"] + assert reconnect_config.default_retry_limit == just_reconnection["default_retry_limit"] + +def test_template_client_server_config(): + """ + Handling of sender hyperparameters part + of the config of client config json files + """ + tempate_client_config = Path() / "templates" / "client_config.json" + assert tempate_client_config.is_file(), f"{tempate_client_config} is not a file" + big_config = load_config(str(tempate_client_config)) + try: + just_sender = big_config["hyper_parameters"]["sender"] + except KeyError as e: + raise KeyError(f"Had this: {big_config}") from e + sender_config = SenderConfig( + max_concurrent_workers = 1, + batch_drain = False, + send_queue_maxsize = 100, + event_bridge_maxsize = None, + max_consecutive_worker_errors = 5, + ) + problems = sender_config.merge_in(**just_sender) + assert len(problems) == 0, \ + "The template was supposed to be a well formed client configuration. 
So it should have had no problems." + assert sender_config.max_concurrent_workers == just_sender["concurrency_limit"] + assert sender_config.batch_drain == just_sender["batch_drain"] + assert sender_config.send_queue_maxsize == just_sender["queue_maxsize"] + assert sender_config.event_bridge_maxsize == just_sender["event_bridge_maxsize"] + assert sender_config.max_consecutive_worker_errors == just_sender["max_worker_errors"] + +def test_combined(): + """ + Handling of hyperparameters part + of the config of client config json files + """ + # Load up with the default's + sender_config = SenderConfig( + max_concurrent_workers = 1, + batch_drain = False, + send_queue_maxsize = 100, + event_bridge_maxsize = None, + max_consecutive_worker_errors = 5, + ) + reconnect_config = ReconnectConfig( + retry_delay_seconds=85.38, + primary_retry_limit=4380, + default_host="junk", + default_port = 390359, + default_retry_limit=593090, + ) + + hyperparameter_config = HyperparameterConfig( + sender_config, + reconnect_config, + max_bytes_per_line = 1024*64, + read_timeout_seconds = None, + ) + + # The actual config from the file which will override the defaults + tempate_client_config = Path() / "templates" / "client_config.json" + assert tempate_client_config.is_file(), f"{tempate_client_config} is not a file" + big_config = load_config(str(tempate_client_config)) + template_hyperparameter_config = big_config["hyper_parameters"] + + problems = hyperparameter_config.merge_in(**template_hyperparameter_config) + assert len(problems) == 0, \ + "The template was supposed to be a well formed client configuration. So it should have had no problems." 
+ assert hyperparameter_config.sender_config.max_concurrent_workers == template_hyperparameter_config["sender"]["concurrency_limit"] + assert hyperparameter_config.sender_config.batch_drain == template_hyperparameter_config["sender"]["batch_drain"] + assert hyperparameter_config.sender_config.send_queue_maxsize == template_hyperparameter_config["sender"]["queue_maxsize"] + assert hyperparameter_config.sender_config.event_bridge_maxsize == template_hyperparameter_config["sender"]["event_bridge_maxsize"] + assert hyperparameter_config.sender_config.max_consecutive_worker_errors == template_hyperparameter_config["sender"]["max_worker_errors"] + + assert hyperparameter_config.reconnect_config.retry_delay_seconds == template_hyperparameter_config["reconnection"]["retry_delay_seconds"] + assert hyperparameter_config.reconnect_config.primary_retry_limit == template_hyperparameter_config["reconnection"]["primary_retry_limit"] + assert hyperparameter_config.reconnect_config.default_host == template_hyperparameter_config["reconnection"]["default_host"] + assert hyperparameter_config.reconnect_config.default_port == template_hyperparameter_config["reconnection"]["default_port"] + assert hyperparameter_config.reconnect_config.default_retry_limit == template_hyperparameter_config["reconnection"]["default_retry_limit"] + + assert hyperparameter_config.max_bytes_per_line == template_hyperparameter_config["receiver"]["max_bytes_per_line"] + assert hyperparameter_config.read_timeout_seconds == template_hyperparameter_config["receiver"]["read_timeout_seconds"] diff --git a/tests/test_flow.py b/tests/test_flow.py index c8c0098..9ee7392 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -6,19 +6,22 @@ from summoner.protocol.flow import get_token_list, Flow from summoner.protocol.process import Node, ParsedRoute - @pytest.mark.parametrize("input_str, sep, expected", [ ("foo,bar(baz,qux),zap", ",", ["foo", "bar(baz,qux)", "zap"]), (" a , b ,c ", ",", ["a", "b", "c"]), 
("one(two,three,four),five", ",", ["one(two,three,four)", "five"]), ]) def test_get_token_list(input_str, sep, expected): - # Only top-level separators should split + """ + Only top-level separators should split + """ assert get_token_list(input_str, sep) == expected def make_flow(): - # Helper to set up a Flow with two arrow styles + """ + Helper to set up a Flow with two arrow styles + """ flow = Flow() flow.activate() flow.add_arrow_style(stem="-", brackets=("[", "]"), separator=",", tip=">") @@ -27,24 +30,45 @@ def make_flow(): return flow -def test_parse_route_complete_labeled(): +def test_parse_route_complete_simple(): + """ + A simple ParsedRoute + """ flow = make_flow() - pr = flow.parse_route("A --> B") + pr : ParsedRoute = flow.parse_route("A --> B") # Expect source A, no label, target B assert pr.source == (Node("A"),) assert pr.label == () assert pr.target == (Node("B"),) assert pr.is_arrow - def test_parse_route_unlabeled_complete(): + """ + A simple ParsedRoute that is explicitly unlabeled + """ flow = make_flow() pr = flow.parse_route("X--[]-->Y") assert pr.source == (Node("X"),) assert pr.target == (Node("Y"),) + assert pr.label == () + assert pr.is_arrow +def test_parse_route_complete_labelled(): + """ + A simple ParsedRoute + """ + flow = make_flow() + pr = flow.parse_route("A --[e,f,g]--> B") + # Expect source A, (e,f,g) label, target B + assert pr.source == (Node("A"),) + assert pr.label == (Node("e"),Node("f"), Node("g")) + assert pr.target == (Node("B"),) + assert pr.is_arrow def test_parse_route_dangling_right(): + """ + A dangling right ParsedRoute + """ flow = make_flow() pr = flow.parse_route("-->B") # Dangling left: no source, no label @@ -54,6 +78,9 @@ def test_parse_route_dangling_right(): def test_parse_route_dangling_left(): + """ + A dangling left ParsedRoute + """ flow = make_flow() pr = flow.parse_route("A-->") assert pr.source == (Node("A"),) @@ -61,6 +88,9 @@ def test_parse_route_dangling_left(): def 
test_parse_route_standalone(): + """ + Standalone object ParsedRoute + """ flow = make_flow() pr = flow.parse_route("X, Y ,Z") # Comma-separated standalone objects @@ -69,16 +99,22 @@ def test_parse_route_standalone(): def test_parse_route_invalid_token(): + """ + An invalid target causes the creation of ParsedRoute + to fail + """ flow = make_flow() with pytest.raises(ValueError): flow.parse_route("A --> inv&alid") def test_parse_routes_list(): + """ + Make several ParsedRoutes + """ flow = make_flow() routes = ["A-->B", "C"] prs = flow.parse_routes(routes) assert isinstance(prs, list) assert prs[0] == flow.parse_route("A-->B") assert prs[1] == flow.parse_route("C") - \ No newline at end of file diff --git a/tests/test_process.py b/tests/test_process.py index 5478c14..4417056 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -1,6 +1,9 @@ """ Tests for process.py: Node, ArrowStyle, ParsedRoute, activated_nodes, StateTape, collect_activations """ +#pylint:disable=import-outside-toplevel, no-member, invalid-name + +from typing import Dict, List, Tuple, cast import pytest from summoner.protocol.process import ( @@ -17,6 +20,11 @@ def test_node_parsing_and_str_repr(): + """ + The node kinds can be plain, all, exclude from a specific set, + or be one of a specific set. + Invalid names raise errors. 
+ """ # Plain token n = Node("foo") assert n.kind == "plain" and str(n) == "foo" @@ -25,10 +33,10 @@ def test_node_parsing_and_str_repr(): assert a.kind == "all" and str(a) == "/all" # /not(A,B) notn = Node("/not(A,B)") - assert notn.kind == "not" and set(notn.values) == {"A", "B"} + assert notn.kind == "not" and set(notn.values or ()) == {"A", "B"} # /oneof(X,Y) ony = Node("/oneof(X,Y)") - assert ony.kind == "oneof" and set(ony.values) == {"X", "Y"} + assert ony.kind == "oneof" and set(ony.values or ()) == {"X", "Y"} with pytest.raises(ValueError): Node("invalid(token") @@ -41,11 +49,17 @@ def test_node_parsing_and_str_repr(): (Node("/oneof(a,b)"), Node("b"), True), ]) def test_node_accepts_matrix(gate, state, expected): + """ + Gates accept or reject the given states + """ # Verify accepts logic across kinds assert gate.accepts(state) is expected def test_arrowstyle_valid_and_invalid(): + """ + Creation of ArrowStyle being valid or not + """ # Valid style style = ArrowStyle("-", ("[", "]"), ",", ">") assert style.stem == "-" @@ -64,6 +78,10 @@ def test_arrowstyle_valid_and_invalid(): def test_parsedroute_properties_and_repr(): + """ + The properties and regex for matching + of ParsedRoute + """ style = ArrowStyle("-", ("[", "]"), ",", ">") pr = ParsedRoute((Node("A"),), (), (Node("B"),), style) assert not pr.has_label and pr.is_arrow and not pr.is_object @@ -75,20 +93,29 @@ def test_parsedroute_properties_and_repr(): def test_activated_nodes_various_actions(): + """ + Check what activated nodes depending on + what selected trigger/event passed + """ style = ArrowStyle("-", ("[", "]"), ",", ">") pr = ParsedRoute((Node("A"),), (Node("L"),), (Node("B"),), style) # Build a dummy trigger/event to pass - from summoner.protocol.triggers import Move as MoveEvt, Test as TestEvt, Stay as StayEvt, load_triggers + from summoner.protocol.triggers import Move as MoveEvt, Test as TestEvt, \ + Stay as StayEvt Trigger = load_triggers(json_dict={"d": None}) - move_event = 
MoveEvt(Trigger.d) + move_event = MoveEvt(Trigger.d) # pyright: ignore[reportAttributeAccessIssue] assert pr.activated_nodes(move_event) == (Node("L"), Node("B")) - test_event = TestEvt(Trigger.d) + test_event = TestEvt(Trigger.d) # pyright: ignore[reportAttributeAccessIssue] assert pr.activated_nodes(test_event) == (Node("L"),) - stay_event = StayEvt(Trigger.d) + stay_event = StayEvt(Trigger.d) # pyright: ignore[reportAttributeAccessIssue] assert pr.activated_nodes(stay_event) == (Node("A"),) def test_statetape_and_revert_extend_refresh(): + """ + revert on StateTape gives correct results + and refreshing has the correct effect + """ # SINGLE tape1 = StateTape("X") assert tape1.revert() == [Node("X")] @@ -103,25 +130,31 @@ def test_statetape_and_revert_extend_refresh(): assert tape4.revert() == {"k": [Node("A"), Node("B")]} # Test extend and refresh tape4.extend({"k": ["C"]}) - assert tape4.revert()["k"][-1] == Node("C") + t4r = cast(Dict[str,List[Node]],tape4.revert()) + assert t4r["k"][-1] == Node("C") fresh = tape4.refresh() - assert fresh.revert()["k"] == [] + fresh_t4r = cast(Dict[str,List[Node]],fresh.revert()) + assert fresh_t4r["k"] == [] def test_collect_activations_simple_case(): + """ + What activations do we get for a simple case + """ # Setup flow and parsed route for "A --> B" flow = Flow().activate() flow.add_arrow_style("-", ("[", "]"), ",", ">") flow.compile_arrow_patterns() pr = flow.parse_route("A --> B") # Fake receiver function - async def fn(msg): + async def fn(_msg): return None receiver_index = {str(pr): Receiver(fn=fn, priority=(1,))} parsed_routes = {str(pr): pr} # Tape with nodes A and C under key 'k' tape = StateTape({"k": ["A", "C"]}) - activations = tape.collect_activations(receiver_index, parsed_routes) + activations: Dict[Tuple[int,...],List[TapeActivation]] = \ + tape.collect_activations(receiver_index, parsed_routes) # Only A should match assert (1,) in activations acts = activations[(1,)] @@ -132,35 +165,47 @@ async def fn(msg): 
assert act.route == pr assert act.fn is fn - def test_sender_responds_to_filters(): + """ + Sender that responds depending on some filters + """ # Setup a simple Trigger set with two signals Trigger = load_triggers(json_dict={"X": None, "Y": None}) - sigX, sigY = Trigger.X, Trigger.Y + sigX, sigY = Trigger.X, Trigger.Y # pyright: ignore[reportAttributeAccessIssue] from summoner.protocol.triggers import Move as MoveEvt, Stay as StayEvt, Test as TestEvt # Events - move_evt = MoveEvt(sigX) - stay_evt = StayEvt(sigY) - test_evt = TestEvt(sigX) + move_x_evt = MoveEvt(sigX) + stay_y_evt = StayEvt(sigY) + test_x_evt = TestEvt(sigX) + + async def do_nothing(): + """ + We are just checking whether or not it responds + so what to put here is just a dummy function + """ + return None # 1) No filters → always responds - sender1 = Sender(fn=lambda: None, multi=False, actions=None, triggers=None) - assert sender1.responds_to(move_evt) - assert sender1.responds_to(stay_evt) + sender1 = Sender(fn=do_nothing, multi=False, actions=None, triggers=None) + assert sender1.responds_to(move_x_evt) + assert sender1.responds_to(stay_y_evt) + assert sender1.responds_to(test_x_evt) # 2) on_actions only - sender_actions = Sender(fn=lambda: None, multi=False, actions={Action.MOVE}, triggers=None) - assert sender_actions.responds_to(move_evt) - assert not sender_actions.responds_to(stay_evt) + sender_actions = Sender(fn=do_nothing, multi=False, actions={Action.MOVE}, triggers=None) + assert sender_actions.responds_to(move_x_evt) + assert not sender_actions.responds_to(stay_y_evt) + assert not sender_actions.responds_to(test_x_evt) # 3) on_triggers only — any event carrying sigX is accepted - sender_triggers = Sender(fn=lambda: None, multi=False, actions=None, triggers={sigX}) - assert sender_triggers.responds_to(move_evt) - assert sender_triggers.responds_to(test_evt) # ← change to True - assert not sender_triggers.responds_to(stay_evt) # different signal Y - - # 4) both filters - sender_both = 
Sender(fn=lambda: None, multi=False, actions={Action.STAY}, triggers={sigY}) - assert sender_both.responds_to(stay_evt) + sender_triggers = Sender(fn=do_nothing, multi=False, actions=None, triggers={sigX}) + assert sender_triggers.responds_to(move_x_evt) + assert sender_triggers.responds_to(test_x_evt) + assert not sender_triggers.responds_to(stay_y_evt) # different signal Y so not responds + + # 4) both filters, so only the stay and that it is carrying sigY + sender_both = Sender(fn=do_nothing, multi=False, actions={Action.STAY}, triggers={sigY}) + assert sender_both.responds_to(stay_y_evt) assert not sender_both.responds_to(StayEvt(sigX)) - assert not sender_both.responds_to(move_evt) + assert not sender_both.responds_to(move_x_evt) + assert not sender_both.responds_to(test_x_evt) diff --git a/tests/test_triggers.py b/tests/test_triggers.py index 95c38cb..e05c0e3 100644 --- a/tests/test_triggers.py +++ b/tests/test_triggers.py @@ -1,5 +1,10 @@ """ -Tests for triggers.py: signal-tree parsing, Trigger class, Signal ordering, Event classes, and extract_signal. 
+Tests for triggers.py: +- signal-tree parsing +- Trigger class +- Signal ordering +- Event classes +- extract_signal """ import pytest @@ -17,6 +22,9 @@ def test_parse_signal_tree_simple_hierarchy(): + """ + parse_signal_tree on a simple example + """ # Define lines simulating a TRIGGERS file with two levels lines = [ "OK\n", @@ -29,7 +37,34 @@ def test_parse_signal_tree_simple_hierarchy(): assert tree["OK"] == {"acceptable": None, "all_good": None} +def test_parse_signal_tree_complex_hierarchy(): + """ + parse_signal_tree on a more complicated example + """ + lines = [ + "OK\n", + " acceptable\n", + " all_good\n", + " great\n", + " perfect\n", + " just_good\n", + "BAD\n", + " fixable\n", + " unfixable\n", + " catastrophic\n" + ] + tree = parse_signal_tree_lines(lines, tabsize=4) + # Root key exists and children simplify to None leaves + assert "OK" in tree + assert tree["OK"] == {"acceptable": None, "all_good": {"great": {"perfect": None}, "just_good": None}} + assert "BAD" in tree + assert tree["BAD"] == {"fixable": None, "unfixable": {"catastrophic": None}} + def test_parse_signal_tree_invalid_varname(): + """ + parse_signal_tree but when the input has an + invalid name + """ # Names must be valid Python identifiers lines = [ "123invalid\n" @@ -40,6 +75,9 @@ def test_parse_signal_tree_invalid_varname(): def test_parse_signal_tree_inconsistent_indent(): + """ + Indentation syntax error of the parse_signal_tree + """ # Indents must follow previous levels exactly or match existing levels lines = [ "OK\n", @@ -52,7 +90,9 @@ def test_parse_signal_tree_inconsistent_indent(): def test_parse_signal_tree_duplicate_name(): - # Duplicate names at same indent level are disallowed + """ + Duplicate names at same indent level are disallowed + """ lines = [ "OK\n", " acceptable\n", @@ -63,19 +103,25 @@ def test_parse_signal_tree_duplicate_name(): assert "duplicate signal name" in str(excinfo.value) +#pylint:disable=no-member def test_load_triggers_with_json_dict(): + """ + 
loading triggers from a json_dict + has all the information provided + in the constructed Trigger + """ # Provide a nested dict directly to load_triggers json_dict = {"root": {"child": None}} Trigger = load_triggers(json_dict=json_dict) # Ensure attributes and path mappings are correct assert hasattr(Trigger, "root") and hasattr(Trigger, "child") - assert Trigger.root.path == (0,) - assert Trigger.child.path == (0, 0) - assert Trigger.name_of(0, 0) == "child" + assert Trigger.root.path == (0,) # pyright: ignore[reportAttributeAccessIssue] + assert Trigger.child.path == (0, 0) # pyright: ignore[reportAttributeAccessIssue] + assert Trigger.name_of(0, 0) == "child" # pyright: ignore[reportAttributeAccessIssue] def test_load_triggers_reserved_keyword(): - # Reserved Python keyword names should raise an error + """Reserved Python keyword names should raise an error""" json_dict = {"class": None} with pytest.raises(ValueError) as excinfo: load_triggers(json_dict=json_dict) @@ -83,16 +129,20 @@ def test_load_triggers_reserved_keyword(): def test_signal_comparison_and_properties(): + """ + Comparison of signals based on ancestry relation + """ # Simple hierarchy: A -> B (implemented so that ancestor > descendant) json_dict = {"A": {"B": None}} Trigger = load_triggers(json_dict=json_dict) - sigA, sigB = Trigger.A, Trigger.B + #pylint:disable=invalid-name + sigA, sigB = Trigger.A, Trigger.B # pyright: ignore[reportAttributeAccessIssue] # A parent signal compares greater than its child assert sigA > sigB assert sigB < sigA # Equality and hashing based on path - assert sigA == Trigger.A - assert hash(sigA) == hash(Trigger.A) + assert sigA == Trigger.A # pyright: ignore[reportAttributeAccessIssue] + assert hash(sigA) == hash(Trigger.A) # pyright: ignore[reportAttributeAccessIssue] # repr, name, path assert repr(sigA) == "" assert sigA.name == "A" @@ -102,9 +152,13 @@ def test_signal_comparison_and_properties(): def test_event_and_action_classes_and_extract_signal(): + """ + 
extract_signal + """ # Instantiate via a single-signal trigger Trigger = load_triggers(json_dict={"X": None}) - sigX = Trigger.X + #pylint:disable=invalid-name + sigX : Signal = Trigger.X # pyright: ignore[reportAttributeAccessIssue] move_evt = Move(sigX) stay_evt = Stay(sigX) test_evt = Test(sigX) @@ -120,4 +174,4 @@ def test_event_and_action_classes_and_extract_signal(): assert extract_signal(sigX) is sigX assert extract_signal(None) is None with pytest.raises(TypeError): - extract_signal(123) \ No newline at end of file + extract_signal(123)