From 8d2e30b3a48ef0be83637852f4962b4feb20345a Mon Sep 17 00:00:00 2001 From: Cobord Date: Tue, 10 Feb 2026 15:17:12 -0500 Subject: [PATCH 01/11] which errors can occur where --- summoner/protocol/triggers.py | 41 ++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/summoner/protocol/triggers.py b/summoner/protocol/triggers.py index d2f55f5..b8f643c 100644 --- a/summoner/protocol/triggers.py +++ b/summoner/protocol/triggers.py @@ -130,7 +130,7 @@ def parse_signal_tree_lines(lines: list[str], tabsize: int = 8) -> dict[str, Any return root -def parse_signal_tree(filepath: str, tabsize: int = 8) -> dict[str, Any]: +def parse_signal_tree(filepath: Path | str, tabsize: int = 8) -> dict[str, Any]: """ Read a file and parse it into a nested dict tree. This is the second entry point, for file-based input. @@ -272,19 +272,30 @@ def load_triggers( `json_dict` must match the nested structure output by `parse_signal_tree_lines`. Raises FileNotFoundError with clear message if file is missing. + Raises ValueError with message as in `parse_signal_tree_lines` and `build_triggers` + if any of the lines in the text or the file are malformed. """ - try: - if text is not None: - lines = text.splitlines() - tree = parse_signal_tree_lines(lines) - elif json_dict is not None: - tree = dict(json_dict) # shallow copy to avoid mutation - else: - path = WORKING_DIR / triggers_file - tree = parse_signal_tree(path) - except FileNotFoundError as e: - raise FileNotFoundError( - f"Could not find triggers file at {path if 'path' in locals() else ''}" - ) from e + if text is not None: + lines = text.splitlines() + # This below can raise ValueError + tree = parse_signal_tree_lines(lines) + elif json_dict is not None: + tree = dict(json_dict) + else: + path = "" + try: + if triggers_file is not None: + # In this case triggers_file was either provided as str + # or left out completely and the default is being used. + path = WORKING_DIR / triggers_file + else: + # In this case triggers_file was explicitly provided as None + path = WORKING_DIR / "TRIGGERS" + raise FileNotFoundError("no triggers file and weren't using the default either") + # This below can raise ValueError or FileNotFoundError + tree =parse_signal_tree(path) + except FileNotFoundError as e: + raise FileNotFoundError( + f"Could not find triggers file at {path if 'path' in locals() else ''}" + ) from e return build_triggers(tree) - From 13db25b5f85784bda69219256c906537cbd28bbe Mon Sep 17 00:00:00 2001 From: Cobord Date: Tue, 10 Feb 2026 19:39:51 -0500 Subject: [PATCH 02/11] organizing the many hyperparameters. with the goal of being possible to reduce the number of fields in SummonerClient, especially in light of the slight mismatches between how things are called in config_dicts vs as fields of the client --- summoner/client/client.py | 12 ++- .../utils/client_hyperparameter_configs.py | 59 ++++++++++ summoner/utils/client_reconnect_configs.py | 62 +++++++++++ summoner/utils/client_sender_configs.py | 90 ++++++++++++++++ tests/test_config.py | 102 ++++++++++++++++++ 5 files changed, 324 insertions(+), 1 deletion(-) create mode 100644 summoner/utils/client_hyperparameter_configs.py create mode 100644 summoner/utils/client_reconnect_configs.py create mode 100644 summoner/utils/client_sender_configs.py create mode 100644 tests/test_config.py diff --git a/summoner/client/client.py b/summoner/client/client.py index 06e9a9c..64de891 100644 --- a/summoner/client/client.py +++ b/summoner/client/client.py @@ -120,13 +120,23 @@ def __init__(self, name: Optional[str] = None): self._upload_states: Optional[Callable[[Any], Awaitable]] = None self._download_states: Optional[Callable[[Any], Awaitable]] = None + # Sender HyperParameters self.event_bridge_maxsize = None self.max_concurrent_workers = None # Limit the sending rate (will use 50 if None is given) self.send_queue_maxsize = None + self.batch_drain = None + # self.max_consecutive_worker_errors is unbound until _apply_config + + # Receiver HyperParameters self.max_bytes_per_line = None self.read_timeout_seconds = None # None is prefered + + # Reconnction HyperParameters self.retry_delay_seconds = None - self.batch_drain = None + # self.primary_retry_limit is unbound until _apply_config + # self.default_host is unbound until _apply_config + # self.default_port is unbound until _apply_config + # self.default_retry_limit is unbound until _apply_config # Pass Event information from the receiving end to the sending end self.event_bridge: Optional[asyncio.Queue[tuple[tuple[int, ...], Optional[str], ParsedRoute, Event]]] = None diff --git a/summoner/utils/client_hyperparameter_configs.py b/summoner/utils/client_hyperparameter_configs.py new file mode 100644 index 0000000..abc6f17 --- /dev/null +++ b/summoner/utils/client_hyperparameter_configs.py @@ -0,0 +1,59 @@ +""" +In `client.py` config["hyper_parameters"]["sender"] +has several logic steps. Isolate that out here. +""" + +from dataclasses import dataclass +from typing import List, Optional + +from summoner.utils.client_reconnect_configs import ReconnectConfig +from summoner.utils.client_sender_configs import SenderConfig + + +class HyperparameterConfig: + sender_config: SenderConfig + reconnect_config: ReconnectConfig + max_bytes_per_line: int + read_timeout_seconds: Optional[float] + + def __init__(self, + sender_config: SenderConfig, + reconnect_config: ReconnectConfig, + max_bytes_per_line: int, + read_timeout_seconds: Optional[float] + ): + self.sender_config = sender_config + self.reconnect_config = reconnect_config + self.max_bytes_per_line = max_bytes_per_line + self.read_timeout_seconds = read_timeout_seconds + + def __post_init__(self): + if self.max_bytes_per_line <= 0:\ + raise ValueError("The provided max_bytes_per_line must be an integer ≥ 1") + if self.read_timeout_seconds is not None and self.read_timeout_seconds <= 0.0:\ + raise ValueError("The provided read_timeout_seconds must be an float ≥ 0.0 if provided") + + def merge_in(self, **kwargs) -> List[str]: + all_problems = [] + all_problems.extend(self.sender_config.merge_in(**kwargs["sender"])) + all_problems.extend(self.reconnect_config.merge_in(**kwargs["reconnection"])) + receiver_cfg = kwargs["receiver"] + if "max_bytes_per_line" in receiver_cfg: + if not isinstance((z := receiver_cfg["max_bytes_per_line"]), int): + all_problems.append(f"The provided max_bytes_per_line was not an integer. It was {type(z)}") + else: + if z <= 0: + all_problems.append("The provided max_bytes_per_line must be an integer ≥ 1") + else: + self.max_bytes_per_line = z + if "read_timeout_seconds" in receiver_cfg: + if not isinstance((z := receiver_cfg["read_timeout_seconds"]), float | int | None): + all_problems.append(f"The provided read_timeout_seconds was not an optional integer or a float. It was {type(z)}") + else: + if z is None: + self.read_timeout_seconds = None + elif z < 0.0: + all_problems.append(f"The provided read_timeout_seconds was negative. It was {z}") + else: + self.read_timeout_seconds = z + return all_problems diff --git a/summoner/utils/client_reconnect_configs.py b/summoner/utils/client_reconnect_configs.py new file mode 100644 index 0000000..360a314 --- /dev/null +++ b/summoner/utils/client_reconnect_configs.py @@ -0,0 +1,62 @@ +""" +In `client.py` config["hyper_parameters"]["reconnection"] +has several logic steps. Isolate that out here +""" + +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass(slots=True) +class ReconnectConfig: + retry_delay_seconds: float + primary_retry_limit: int + default_host: Optional[str] + default_port: Optional[int] + default_retry_limit: int + + def __post_init__(self): + assert self.retry_delay_seconds >= 0.0,\ + "The provided retry_delay_seconds must be a float ≥ 0.0" + assert self.primary_retry_limit >= 0,\ + "The provided primary_retry_limit must be an integer ≥ 0" + assert self.default_retry_limit >= 0,\ + "The provided default_retry_limit must be an integer ≥ 0" + + def merge_in(self, **kwargs) -> List[str]: + all_problems = [] + if "retry_delay_seconds" in kwargs: + if not isinstance((z := kwargs["retry_delay_seconds"]), float | int): + all_problems.append(f"The provided retry_delay_seconds was not an integer or a float. It was {type(z)}") + else: + if z < 0.0: + all_problems.append(f"The provided retry_delay_seconds was negative. It was {z}") + else: + self.retry_delay_seconds = z + if "primary_retry_limit" in kwargs: + if not isinstance((z := kwargs["primary_retry_limit"]), int): + all_problems.append(f"The provided primary_retry_limit was not an integer. It was {type(z)}") + else: + if z < 0: + all_problems.append(f"The provided primary_retry_limit was negative. It was {z}") + else: + self.primary_retry_limit = int(z) + if "default_host" in kwargs: + if isinstance((z := kwargs["default_host"]), str | None): + self.default_host = z + else: + all_problems.append(f"The provided default_host was not an optional string. It was {type(z)}") + if "default_port" in kwargs: + if not isinstance((z := kwargs["default_port"]), int | None): + all_problems.append(f"The provided default_port was not an optional integer. It was {type(z)}") + else: + self.default_port = z + if "default_retry_limit" in kwargs: + if not isinstance((z := kwargs["default_retry_limit"]), int): + all_problems.append(f"The provided default_retry_limit was not an integer. It was {type(z)}") + else: + if z < 0: + all_problems.append(f"The provided default_retry_limit was negative. It was {z}") + else: + self.default_retry_limit = z + return all_problems diff --git a/summoner/utils/client_sender_configs.py b/summoner/utils/client_sender_configs.py new file mode 100644 index 0000000..8f038c4 --- /dev/null +++ b/summoner/utils/client_sender_configs.py @@ -0,0 +1,90 @@ +""" +In `client.py` config["hyper_parameters"]["sender"] +has several logic steps. Isolate that out here. +""" + +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass(slots=True) +class SenderConfig: + """ + The __post_init__ + enforces the strict requirements beyond types + of positivity of the integers. + Also the `merge_in` of dictionaries from json configurations + will not maintain this property and give a list of strings + for the logger to output as errors. + """ + max_concurrent_workers: int + batch_drain: bool + send_queue_maxsize: int + event_bridge_maxsize: Optional[int] + max_consecutive_worker_errors: int + + def __post_init__(self): + if self.max_concurrent_workers <= 0:\ + raise ValueError("The provided max_concurrent_workers must be an integer ≥ 1") + if self.send_queue_maxsize <= 0:\ + raise ValueError("The provided send_queue_maxsize must be an integer ≥ 1") + if self.max_consecutive_worker_errors <= 0:\ + raise ValueError("The provided max_consecutive_worker_errors must be an integer ≥ 1") + + def merge_in(self, **kwargs) -> List[str]: + """ + The logic of _apply_config that is relevant to the sender_cfg. + In addition rather than give ValueError with only one thing that is wrong, + if there are multiple problems with the config, then the error should show them all. + """ + all_problems = [] + if "concurrency_limit" in kwargs: + if not isinstance((z := kwargs["concurrency_limit"]), int): + all_problems.append(f"The provided concurrency_limit was not an integer. It was {type(z)}") + else: + if z <= 0: + all_problems.append("The provided concurrency_limit must be an integer ≥ 1") + else: + self.max_concurrent_workers = z + if "batch_drain" in kwargs: + if not isinstance((z := kwargs["batch_drain"]), bool | None): + all_problems.append(f"The provided batch_drain was not an optional bool. It was {type(z)}") + else: + if z is None: + self.batch_drain = False + else: + self.batch_drain = z + if "queue_maxsize" in kwargs: + if not isinstance((z := kwargs["queue_maxsize"]), int): + all_problems.append(f"The provided queue_maxsize was not an integer. It was {type(z)}") + else: + if z <= 0: + all_problems.append("The provided queue_maxsize must be an integer ≥ 1") + else: + self.send_queue_maxsize = z + if "event_bridge_maxsize" in kwargs: + if not isinstance((z := kwargs["event_bridge_maxsize"]), int | None): + all_problems.append(f"The provided event_bridge_maxsize was not an optional int. It was {type(z)}") + else: + if z is None: + self.event_bridge_maxsize = None + else: + self.event_bridge_maxsize = z + if "max_worker_errors" in kwargs: + if not isinstance((z := kwargs["max_worker_errors"]), int): + all_problems.append(f"The provided max_worker_errors was not an integer. It was {type(z)}") + else: + if z <= 0: + all_problems.append("The provided max_worker_errors must be an integer ≥ 1") + else: + self.max_consecutive_worker_errors = z + return all_problems + + def throttle_warning(self) -> Optional[str]: + """ + The configuration is valid at all times by construction. + However, this is valid but merits a warning. + """ + if self.send_queue_maxsize < self.max_concurrent_workers: + return f"queue_maxsize < concurrency_limit; back-pressure will throttle producers at {self.send_queue_maxsize}" + return None diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..d01f370 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,102 @@ +""" +Tests for manipulating client_config +""" + +from pathlib import Path +import sys +from summoner.utils.json_handlers import load_config +from summoner.utils.client_reconnect_configs import ReconnectConfig +from summoner.utils.client_sender_configs import SenderConfig +from summoner.utils.client_hyperparameter_configs import HyperparameterConfig + +def test_template_client_reconnection_config(): + tempate_client_config = Path() / "templates" / "client_config.json" + assert tempate_client_config.is_file(), f"{tempate_client_config} is not a file" + big_config = load_config(str(tempate_client_config)) + try: + just_reconnection = big_config["hyper_parameters"]["reconnection"] + except KeyError as e: + raise KeyError(f"Had this: {big_config}") from e + reconnect_config = ReconnectConfig( + retry_delay_seconds=85.38, + primary_retry_limit=4380, + default_host="junk", + default_port = 390359, + default_retry_limit=593090, + ) + problems = reconnect_config.merge_in(**just_reconnection) + assert len(problems) == 0, \ + "The template was supposed to be a well formed client configuration. So it should have had no problems." + assert reconnect_config.retry_delay_seconds == just_reconnection["retry_delay_seconds"] + assert reconnect_config.primary_retry_limit == just_reconnection["primary_retry_limit"] + assert reconnect_config.default_host == just_reconnection["default_host"] + assert reconnect_config.default_port == just_reconnection["default_port"] + assert reconnect_config.default_retry_limit == just_reconnection["default_retry_limit"] + +def test_template_client_server_config(): + tempate_client_config = Path() / "templates" / "client_config.json" + assert tempate_client_config.is_file(), f"{tempate_client_config} is not a file" + big_config = load_config(str(tempate_client_config)) + try: + just_sender = big_config["hyper_parameters"]["sender"] + except KeyError as e: + raise KeyError(f"Had this: {big_config}") from e + sender_config = SenderConfig( + max_concurrent_workers = 1, + batch_drain = False, + send_queue_maxsize = 100, + event_bridge_maxsize = None, + max_consecutive_worker_errors = 5, + ) + problems = sender_config.merge_in(**just_sender) + assert len(problems) == 0, \ + "The template was supposed to be a well formed client configuration. So it should have had no problems." + assert sender_config.max_concurrent_workers == just_sender["concurrency_limit"] + assert sender_config.batch_drain == just_sender["batch_drain"] + assert sender_config.send_queue_maxsize == just_sender["queue_maxsize"] + assert sender_config.event_bridge_maxsize == just_sender["event_bridge_maxsize"] + assert sender_config.max_consecutive_worker_errors == just_sender["max_worker_errors"] + +def test_combined(): + tempate_client_config = Path() / "templates" / "client_config.json" + assert tempate_client_config.is_file(), f"{tempate_client_config} is not a file" + big_config = load_config(str(tempate_client_config)) + template_hyperparameter_config = big_config["hyper_parameters"] + sender_config = SenderConfig( + max_concurrent_workers = 1, + batch_drain = False, + send_queue_maxsize = 100, + event_bridge_maxsize = None, + max_consecutive_worker_errors = 5, + ) + reconnect_config = ReconnectConfig( + retry_delay_seconds=85.38, + primary_retry_limit=4380, + default_host="junk", + default_port = 390359, + default_retry_limit=593090, + ) + + hyperparameter_config = HyperparameterConfig( + sender_config, + reconnect_config, + max_bytes_per_line = 1024*64, + read_timeout_seconds = None, + ) + problems = hyperparameter_config.merge_in(**template_hyperparameter_config) + assert len(problems) == 0, \ + "The template was supposed to be a well formed client configuration. So it should have had no problems." + assert hyperparameter_config.sender_config.max_concurrent_workers == template_hyperparameter_config["sender"]["concurrency_limit"] + assert hyperparameter_config.sender_config.batch_drain == template_hyperparameter_config["sender"]["batch_drain"] + assert hyperparameter_config.sender_config.send_queue_maxsize == template_hyperparameter_config["sender"]["queue_maxsize"] + assert hyperparameter_config.sender_config.event_bridge_maxsize == template_hyperparameter_config["sender"]["event_bridge_maxsize"] + assert hyperparameter_config.sender_config.max_consecutive_worker_errors == template_hyperparameter_config["sender"]["max_worker_errors"] + + assert hyperparameter_config.reconnect_config.retry_delay_seconds == template_hyperparameter_config["reconnection"]["retry_delay_seconds"] + assert hyperparameter_config.reconnect_config.primary_retry_limit == template_hyperparameter_config["reconnection"]["primary_retry_limit"] + assert hyperparameter_config.reconnect_config.default_host == template_hyperparameter_config["reconnection"]["default_host"] + assert hyperparameter_config.reconnect_config.default_port == template_hyperparameter_config["reconnection"]["default_port"] + assert hyperparameter_config.reconnect_config.default_retry_limit == template_hyperparameter_config["reconnection"]["default_retry_limit"] + + assert hyperparameter_config.max_bytes_per_line == template_hyperparameter_config["receiver"]["max_bytes_per_line"] + assert hyperparameter_config.read_timeout_seconds == template_hyperparameter_config["receiver"]["read_timeout_seconds"] From 3e12c8456f5eedeba93bf8941a3026dce723796a Mon Sep 17 00:00:00 2001 From: Cobord Date: Wed, 18 Feb 2026 22:17:39 -0500 Subject: [PATCH 03/11] undifferentiated cleanups, mostly suppressing until can actually fix the underlying issues or document why they are not a problem --- summoner/client/client.py | 134 ++++++++++-------- summoner/logger.py | 26 +++- summoner/protocol/flow.py | 6 +- summoner/protocol/process.py | 21 ++- summoner/settings.py | 3 + summoner/utils/addr_handlers.py | 10 +- .../utils/client_hyperparameter_configs.py | 7 +- summoner/utils/client_reconnect_configs.py | 8 +- tests/test_config.py | 30 +++- tests/test_flow.py | 50 ++++++- tests/test_process.py | 105 ++++++++++---- tests/test_triggers.py | 53 +++++-- 12 files changed, 317 insertions(+), 136 deletions(-) diff --git a/summoner/client/client.py b/summoner/client/client.py index 64de891..c6a109a 100644 --- a/summoner/client/client.py +++ b/summoner/client/client.py @@ -2,12 +2,14 @@ import sys import json from typing import ( + Dict, Optional, Callable, Union, Awaitable, Any, Type, + cast, ) import asyncio import signal @@ -15,6 +17,8 @@ from collections import defaultdict import platform +from summoner.utils.client_hyperparameter_configs import TIMEOUT_TYPE + target_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) if target_path not in sys.path: sys.path.insert(0, target_path) @@ -121,18 +125,18 @@ def __init__(self, name: Optional[str] = None): self._download_states: Optional[Callable[[Any], Awaitable]] = None # Sender HyperParameters - self.event_bridge_maxsize = None - self.max_concurrent_workers = None # Limit the sending rate (will use 50 if None is given) - self.send_queue_maxsize = None + self.event_bridge_maxsize : Optional[int] = None + self.max_concurrent_workers : Optional[int] = None # Limit the sending rate (will use 50 if None is given) + self.send_queue_maxsize : Optional[int] = None self.batch_drain = None # self.max_consecutive_worker_errors is unbound until _apply_config # Receiver HyperParameters - self.max_bytes_per_line = None - self.read_timeout_seconds = None # None is prefered + self.max_bytes_per_line : Optional[int] = None + self.read_timeout_seconds : Optional[float] = None # None is prefered # Reconnction HyperParameters - self.retry_delay_seconds = None + self.retry_delay_seconds : Optional[float] = None # self.primary_retry_limit is unbound until _apply_config # self.default_host is unbound until _apply_config # self.default_port is unbound until _apply_config @@ -163,14 +167,13 @@ def __init__(self, name: Optional[str] = None): # ==== VERSION SPECIFIC ==== def _apply_config(self, config: dict[str,Union[str,dict[str,Union[str,dict]]]]): + self.host = config.get("host") # default is None # pyright: ignore[reportAttributeAccessIssue] + self.port = config.get("port") # default is None # pyright: ignore[reportAttributeAccessIssue] - self.host = config.get("host") # default is None - self.port = config.get("port") # default is None - - logger_cfg = config.get("logger", {}) - configure_logger(self.logger, logger_cfg) + logger_cfg = cast(Dict[str,Any],config.get("logger", {})) + configure_logger(self.logger, logger_cfg) - hp_config = config.get("hyper_parameters", {}) + hp_config = cast(Dict[str,Any],config.get("hyper_parameters", {})) reconn_cfg = hp_config.get("reconnection", {}) self.retry_delay_seconds = reconn_cfg.get("retry_delay_seconds", self.DEFAULT_RETRY_DELAY) @@ -239,8 +242,8 @@ def decorator(fn: Callable[[], Awaitable]): _check_param_and_return( fn, decorator_name="@upload_states", - allow_param=(type(None), str, dict, Any), # the payload - allow_return=(type(None), str, Any, Node, list, dict, + allow_param=(type(None), str, dict, Any), # the payload # type: ignore + allow_return=(type(None), str, Any, Node, list, dict, # type: ignore list[str], dict[str, str], dict[str, list[str]], list[Node], dict[str, Node], dict[str, list[Node]], dict[str, Union[str, list[str]]], @@ -262,7 +265,7 @@ def decorator(fn: Callable[[], Awaitable]): "source": inspect.getsource(fn), } - self._upload_states = fn + self._upload_states = fn # type: ignore return fn @@ -283,7 +286,7 @@ def decorator(fn: Callable[[Any], Awaitable]): _check_param_and_return( fn, decorator_name="@download_states", - allow_param=(type(None), Node, Any, list, dict, + allow_param=(type(None), Node, Any, list, dict, # type: ignore list[Node], dict[str, Node], dict[str, list[Node]], @@ -292,7 +295,7 @@ def decorator(fn: Callable[[Any], Awaitable]): dict[Optional[str], list[Node]], dict[Optional[str], Union[Node, list[Node]]], ), - allow_return=(type(None), Any), + allow_return=(type(None), Any), # type: ignore logger=self.logger, ) @@ -325,11 +328,11 @@ def _schedule_registration(self, register_coro: Awaitable): """ if self.loop.is_running(): def _cb(): - task = self.loop.create_task(register_coro) + task = self.loop.create_task(register_coro) # type: ignore self._registration_tasks.append(task) self.loop.call_soon_threadsafe(_cb) else: - task = self.loop.create_task(register_coro) + task = self.loop.create_task(register_coro) # type: ignore self._registration_tasks.append(task) # ==== HOOK REGISTRATION ==== @@ -349,8 +352,8 @@ def decorator(fn: Callable[[Optional[Union[str, dict]]], Optional[Union[str, dic _check_param_and_return( fn, decorator_name="@hook", - allow_param=(Any, str, dict), - allow_return=(type(None), str, dict, Any), + allow_param=(Any, str, dict), # type: ignore + allow_return=(type(None), str, dict, Any), # type: ignore logger=self.logger, ) @@ -376,9 +379,9 @@ def decorator(fn: Callable[[Optional[Union[str, dict]]], Optional[Union[str, dic async def register(): async with self.hooks_lock: if direction == Direction.RECEIVE: - self.receiving_hooks[tuple_priority] = fn + self.receiving_hooks[tuple_priority] = fn # type: ignore elif direction == Direction.SEND: - self.sending_hooks[tuple_priority] = fn + self.sending_hooks[tuple_priority] = fn # type: ignore # ----[ Safe Registration ]---- # NOTE: register() is run ASAP and _registration_tasks is used to wait all registrations before run_client() @@ -410,8 +413,8 @@ def decorator(fn: Callable[[Union[str, dict]], Awaitable[Optional[Event]]]): _check_param_and_return( fn, decorator_name="@receive", - allow_param=(Any, str, dict), - allow_return=(type(None), Event, Any), + allow_param=(Any, str, dict), # type: ignore + allow_return=(type(None), Event, Any), # type: ignore logger=self.logger, ) @@ -492,7 +495,7 @@ def decorator(fn: Callable[[], Awaitable]): fn, decorator_name="@send", allow_param=(), # no args allowed - allow_return=(type(None), Any, str, dict), + allow_return=(type(None), Any, str, dict), # type: ignore logger=self.logger, ) else: @@ -500,7 +503,7 @@ def decorator(fn: Callable[[], Awaitable]): fn, decorator_name="@send[multi=True]", allow_param=(), # no args allowed - allow_return=(Any, list, list[str], list[dict], list[Union[str, dict]]), + allow_return=(Any, list, list[str], list[dict], list[Union[str, dict]]), # type: ignore logger=self.logger, ) @@ -831,7 +834,7 @@ def dna(self, include_context: bool = False) -> str: continue # Names referenced by the function body. - names_to_scan = set(getattr(fn, "__code__", None).co_names if hasattr(fn, "__code__") else ()) + names_to_scan = set(getattr(fn, "__code__", None).co_names if hasattr(fn, "__code__") else ()) # type: ignore # Names referenced only via annotations. try: @@ -840,7 +843,7 @@ def dna(self, include_context: bool = False) -> str: if isinstance(v, type) or inspect.isfunction(v) or inspect.ismodule(v): nm = getattr(v, "__name__", None) if isinstance(nm, str) and nm: - names_to_scan.add(nm) + names_to_scan.add(nm) # type: ignore except Exception: pass @@ -997,7 +1000,7 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: if not receiver_index: data = await self._read_line_safe( reader, - limit=self.max_bytes_per_line, + limit=self.max_bytes_per_line, # type: ignore timeout=0.1, ) # if not data: @@ -1011,7 +1014,7 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: data = await self._read_line_safe( reader, - limit=self.max_bytes_per_line, + limit=self.max_bytes_per_line, # type: ignore timeout=self.read_timeout_seconds, ) # data = await reader.readline() @@ -1026,10 +1029,10 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: for priority, receiving_hook in sorted(receiving_hooks.items(), key=lambda kv: hook_priority_order(kv[0])): try: - new_payload = await receiving_hook(payload) + new_payload = await receiving_hook(payload) # type: ignore if new_payload is None: - payload = None + payload = None # type: ignore break except Exception as e: @@ -1050,7 +1053,7 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: if self._flow.in_use: raw_states = (await self._upload_states(payload)) if self._upload_states is not None else None tape = StateTape(raw_states) - activation_index = tape.collect_activations(receiver_index=receiver_index, parsed_routes=receiver_parsed_routes) + activation_index = tape.collect_activations(receiver_index=receiver_index, parsed_routes=receiver_parsed_routes) # type: ignore batches = {priority: [activation.fn for activation in activations] for priority, activations in activation_index.items()} else: for _, receiver in receiver_index.items(): @@ -1079,27 +1082,35 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: # ----[ After: Handle Returns ]---- if self._flow.in_use: - activations = activation_index[priority] + activations = activation_index[priority] # type: ignore - local_tape = tape.refresh() + local_tape = tape.refresh() # type: ignore to_extend: dict[str, list[Node]] = defaultdict(list) for act, event in zip(activations, events): - to_extend[act.key].extend(act.route.activated_nodes(event)) + if act.key is None: + # Repeated the code in both branches even though this branch + # even though in this branch the key is not what it was supposed to be + # It was stated to have string keys when initializing as the defaultdict above + to_extend[act.key].extend(act.route.activated_nodes(event)) # type: ignore + else: + to_extend[act.key].extend(act.route.activated_nodes(event)) local_tape.extend(to_extend) buffer_entries = [(act.key, act.route, event) for act, event in zip(activations, events)] - event_buffer[priority].extend(buffer_entries) + # event buffer is bound because this is within if self._flow_in_use + # some of the event's in buffer_entries could have been None + event_buffer[priority].extend(buffer_entries) # type: ignore if self._download_states is not None: await self._download_states(local_tape.revert()) # ----[ Final: Pass Data Over To Senders ]---- if self._flow.in_use: - - for priority, event_list in sorted(event_buffer.items(), key=lambda kv: kv[0]): + # event buffer is bound because this is within if self._flow_in_use + for priority, event_list in sorted(event_buffer.items(), key=lambda kv: kv[0]): # type: ignore for event_data in event_list: # this will block if the bridge is full, slowing down readers - await self.event_bridge.put((priority,) + event_data) + await self.event_bridge.put((priority,) + event_data) # type: ignore event_buffer = {} @@ -1124,6 +1135,8 @@ def _start_send_workers( stop_event: asyncio.Event ): if not self.send_workers_started: + if self.max_concurrent_workers is None: + raise ValueError("_apply_config will make sure that the maximum number of workers is set to an integer ≥ 1") for _ in range(self.max_concurrent_workers): worker_task = self.loop.create_task(self._send_worker(writer, stop_event)) self.worker_tasks.append(worker_task) @@ -1138,9 +1151,9 @@ async def _send_worker( while True: - item: Optional[tuple[str, Sender]] = await self.send_queue.get() + item: Optional[tuple[str, Sender]] = await self.send_queue.get() # type: ignore if item is None: - self.send_queue.task_done() + self.send_queue.task_done() # type: ignore break route, sender = item @@ -1169,7 +1182,7 @@ async def _send_worker( for priority, sending_hook in sorted(sending_hooks.items(), key=lambda kv: hook_priority_order(kv[0])): try: - new_payload = await sending_hook(payload) + new_payload = await sending_hook(payload) # type: ignore if new_payload is None: payload = None @@ -1220,7 +1233,7 @@ async def _send_worker( break finally: - self.send_queue.task_done() + self.send_queue.task_done() # type: ignore async def _cleanup_workers(self): for w in self.worker_tasks: @@ -1266,7 +1279,7 @@ def _route_accepts( pending: list[tuple[tuple[int, ...], Optional[str], ParsedRoute, Event]] = [] try: while True: - pending.append(self.event_bridge.get_nowait()) + pending.append(self.event_bridge.get_nowait()) # type: ignore except asyncio.QueueEmpty: pass @@ -1289,12 +1302,12 @@ def _route_accepts( elif self._flow.in_use and ((sender.actions and isinstance(sender.actions, set)) or (sender.triggers and isinstance(sender.triggers, set))): - sender_parsed_route = sender_parsed_routes.get(route) + sender_parsed_route = sender_parsed_routes.get(route) # type: ignore if sender_parsed_route is None: continue # Iterate pending in queue order; first match "wins" for this (route,key,fn_name) - for (priority, key, parsed_route, event) in pending: + for (priority, key, parsed_route, event) in pending: # type: ignore if _route_accepts(sender_parsed_route, parsed_route) and sender.responds_to(event): dedup_key = (route, key, sender.fn.__name__) # key scopes to the activation thread/peer if dedup_key not in emitted: @@ -1307,21 +1320,21 @@ def _route_accepts( await asyncio.sleep(0.1) # Time continue else: - queue_size = self.send_queue.qsize() + queue_size = self.send_queue.qsize() # type: ignore expected_queue_size = queue_size + len(senders) - if expected_queue_size > self.send_queue_maxsize * 0.8: # 80% full + if expected_queue_size > self.send_queue_maxsize * 0.8: # type: ignore # 80% full self.logger.warning(f"Queue is about to exceed 80% its capacity; Attempted load size: {expected_queue_size} out of {self.send_queue_maxsize}") # ----[ Enqueue Sender Batch | Senders Are Run in Background ]---- try: for sender in senders: - await self.send_queue.put(sender) # Will block if full (i.e., back-pressure) + await self.send_queue.put(sender) # type: ignore # Will block if full (i.e., back-pressure) except asyncio.CancelledError: self.logger.info("Sender enqueue loop cancelled mid-batch.") raise # ----[ Wait for Sender Batch to Finish]---- - await self.send_queue.join() + await self.send_queue.join() # type: ignore if self.batch_drain: async with self.writer_lock: @@ -1340,6 +1353,8 @@ def _route_accepts( finally: # Best-effort signal to workers; never block on shutdown if self.send_queue is not None: + if self.max_concurrent_workers is None: + raise ValueError("_apply_config will make sure that the maximum number of workers is set to an integer ≥ 1") # This may result in redundant cancellation if shutdown() is also called, # but guarantees all workers get signaled even in abrupt exits. for _ in range(self.max_concurrent_workers): @@ -1368,8 +1383,8 @@ async def handle_session(self, host: str = '127.0.0.1', port: int = 8888): # Always clean up old worker tasks and queues before starting new session; guarantees a fresh worker batch and prevents zombie tasks. await self._cleanup_workers() - self.send_queue = asyncio.Queue(maxsize=self.send_queue_maxsize) - self.event_bridge = asyncio.Queue(maxsize = self.event_bridge_maxsize) + self.send_queue = asyncio.Queue(maxsize=self.send_queue_maxsize) # type: ignore + self.event_bridge = asyncio.Queue(maxsize = self.event_bridge_maxsize) # type: ignore # reset any previous travel/quit intent so each session starts fresh; # travel is only honored if set after this point, quit likewise @@ -1382,7 +1397,7 @@ async def handle_session(self, host: str = '127.0.0.1', port: int = 8888): # Register this session's task so it can be cancelled during shutdown current_task = asyncio.current_task() async with self.tasks_lock: - self.active_tasks.add(current_task) + self.active_tasks.add(current_task) # type: ignore # Use lock when accessing dynamic routing information async with self.connection_lock: @@ -1444,8 +1459,8 @@ async def handle_session(self, host: str = '127.0.0.1', port: int = 8888): # Deregister this session and its children from active tasks async with self.tasks_lock: - if task is not None: - self.active_tasks.discard(task) + if task is not None: # type: ignore + self.active_tasks.discard(task) # type: ignore # Check whether we should quit or loop back to travel to the next server (agent migration) async with self.connection_lock: @@ -1506,13 +1521,14 @@ async def _retry_loop(self, host, port, limit, stage = "Primary"): except (ConnectionRefusedError, ServerDisconnected, OSError) as e: attempts += 1 + sleep_time = self.retry_delay_seconds or self.DEFAULT_RETRY_DELAY self.logger.error( f"[{type(e).__name__}: {e}] " f"({stage}) retry {attempts} of " f"{limit if limit is not None else '∞'}; " - f"sleeping {self.retry_delay_seconds}s", + f"sleeping {sleep_time}s", ) - await asyncio.sleep(self.retry_delay_seconds) + await asyncio.sleep(sleep_time) # Check retry limit if (limit is not None and attempts >= limit): diff --git a/summoner/logger.py b/summoner/logger.py index 6cf594c..cb7e202 100644 --- a/summoner/logger.py +++ b/summoner/logger.py @@ -1,3 +1,8 @@ +""" +Handle logging specifics +to this context +like details of how things are consistently formatted +""" import sys import os import json @@ -7,12 +12,16 @@ from typing import Optional, Any from logging.handlers import RotatingFileHandler -from logging import Logger # This makes Logger importable from logger.py + +# This makes Logger importable from logger.py +#pylint:disable=unused-import +from logging import Logger # Log formatting style LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" -LOG_FORMAT_CONSOLE = "\033[92m%(asctime)s\033[0m - \033[94m%(name)s\033[0m - %(levelname)s - %(message)s" +LOG_FORMAT_CONSOLE = \ + "\033[92m%(asctime)s\033[0m - \033[94m%(name)s\033[0m - %(levelname)s - %(message)s" class SafeStreamHandler(logging.StreamHandler): """ @@ -190,7 +199,8 @@ def configure_logger(logger: logging.Logger, logger_cfg: dict[str, Any]) -> None date_format = logger_cfg.get("date_format") console_handler = SafeStreamHandler(sys.stdout) - console_handler.setFormatter(TextFormatter(console_log_format, date_format, logger_cfg.get("log_keys"))) + console_handler.setFormatter( + TextFormatter(console_log_format, date_format, logger_cfg.get("log_keys"))) logger.addHandler(console_handler) # file @@ -208,9 +218,13 @@ def configure_logger(logger: logging.Logger, logger_cfg: dict[str, Any]) -> None log_format = logger_cfg.get("log_format", LOG_FORMAT) date_format = logger_cfg.get("date_format") - file_handler = RotatingFileHandler(path, maxBytes=max_file_size, backupCount=backup_count) + file_handler = RotatingFileHandler(path, + maxBytes=max_file_size, + backupCount=backup_count) if logger_cfg.get("enable_json_log", False): - file_handler.setFormatter(JsonFormatter(log_format, date_format, logger_cfg.get("log_keys"))) + file_handler.setFormatter( + JsonFormatter(log_format, date_format, logger_cfg.get("log_keys"))) else: - file_handler.setFormatter(TextFormatter(log_format, date_format, logger_cfg.get("log_keys"))) + file_handler.setFormatter( + TextFormatter(log_format, date_format, logger_cfg.get("log_keys"))) logger.addHandler(file_handler) diff --git a/summoner/protocol/flow.py b/summoner/protocol/flow.py index 59e789a..d9690be 100644 --- a/summoner/protocol/flow.py +++ b/summoner/protocol/flow.py @@ -1,10 +1,10 @@ from __future__ import annotations import re from collections.abc import Callable -from typing import Optional, Any +from typing import Iterable, Optional, Any from .triggers import load_triggers from .process import Node, ArrowStyle, ParsedRoute -from ._deprecation import deprecated +from ._deprecation import deprecated # type: ignore import warnings # variable names or commands used in flow transitions @@ -293,5 +293,5 @@ def parse_route(self, route: str) -> ParsedRoute: return self._parse_standalone(route) - def parse_routes(self, routes: list[str]) -> list[ParsedRoute]: + def parse_routes(self, routes: Iterable[str]) -> list[ParsedRoute]: return [self.parse_route(route=route) for route in routes] diff --git a/summoner/protocol/process.py b/summoner/protocol/process.py index c432e23..1687828 100644 --- a/summoner/protocol/process.py +++ b/summoner/protocol/process.py @@ -24,7 +24,7 @@ class Node: def __init__(self, expr: str) -> None: _expr: str = expr.strip() self.kind: str - self.values: Optional[tuple[str]] + self.values: Optional[tuple[str,...]] if _ALL_RE.fullmatch(_expr): self.kind = 'all' @@ -70,10 +70,16 @@ def __str__(self) -> str: if self.kind == 'all': return '/all' elif self.kind == 'plain': + if self.values is None: + return f"" return self.values[0] elif self.kind == 'not': + if self.values is None: + return f"" return f"/not({','.join(self.values)})" elif self.kind == 'oneof': + if self.values is None: + return f"" return f"/oneof({','.join(self.values)})" else: return f"" @@ -227,6 +233,7 @@ def __init__( self.style = style if self.is_arrow: + assert self.style is not None, "If an arrow it has a style" src = self.style.separator.join(str(n) for n in self.source) lab = self.style.separator.join(str(n) for n in self.label) tgt = self.style.separator.join(str(n) for n in self.target) @@ -441,9 +448,17 @@ def _assess_type(states: Any) -> Optional[TapeType]: def _add_prefix(self, key: str, with_prefix: bool = True) -> str: return f"{self.prefix}:{key}" if with_prefix else key + def _remove_str_prefix(self, key: str) -> str: + p = f"{self.prefix}:" + return key[len(p):] if key.startswith(p) else key + + # The type annotation here is incorrect, key=None gives None + # but making this Optional[str] in return + # is not the intended behavior and pollutes that possibility + # for the caller even when they are not passing key=None def _remove_prefix(self, key: Optional[str]) -> str: p = f"{self.prefix}:" - return key[len(p):] if isinstance(key, str) and key.startswith(p) else key + return key[len(p):] if isinstance(key, str) and key.startswith(p) else key # type: ignore def extend(self, states: Any): # Delegate to a local StateTape then merge @@ -468,7 +483,7 @@ def revert(self) -> Union[list[Node], dict[str, list[Node]], None]: if self._type in (TapeType.INDEX_SINGLE, TapeType.INDEX_MANY): out_dict: dict[str, list[Node]] = {} for pk, seq in self.states.items(): - key = self._remove_prefix(pk) + key = self._remove_str_prefix(pk) out_dict.setdefault(key, []) out_dict[key].extend(seq) return out_dict diff --git a/summoner/settings.py b/summoner/settings.py index c2101b7..cd76817 100644 --- a/summoner/settings.py +++ b/summoner/settings.py @@ -1,3 +1,6 @@ +""" +setup information +""" # settings.py import os from dotenv import load_dotenv diff --git a/summoner/utils/addr_handlers.py b/summoner/utils/addr_handlers.py index ad85b96..d3f830b 100644 --- a/summoner/utils/addr_handlers.py +++ b/summoner/utils/addr_handlers.py @@ -1,3 +1,6 @@ +""" +Convert a socket peer address (or similar) to a compact, deterministic string. +""" from collections.abc import Iterable, Mapping from typing import Any import ipaddress @@ -51,7 +54,7 @@ def format_addr(addr: Any, max_items: int = 10) -> str: if isinstance(addr, Iterable): try: seq = list(addr) - except Exception: + except Exception: # pylint:disable=broad-exception-caught return repr(addr) # Common socket cases @@ -64,8 +67,7 @@ def format_addr(addr: Any, max_items: int = 10) -> str: if ip.version == 6: host_fmt = host if scopeid in (None, 0) else f"{host}%{scopeid}" return f"[{host_fmt}]:{port}" - else: - return f"{host}:{port}" + return f"{host}:{port}" except ValueError: # Not an IP literal; fall back to host:port return f"{host}:{port}" @@ -79,5 +81,5 @@ def format_addr(addr: Any, max_items: int = 10) -> str: # fallback try: return str(addr) - except Exception: + except Exception: # pylint:disable=broad-exception-caught return repr(addr) diff --git a/summoner/utils/client_hyperparameter_configs.py b/summoner/utils/client_hyperparameter_configs.py index abc6f17..e98ee90 100644 --- a/summoner/utils/client_hyperparameter_configs.py +++ b/summoner/utils/client_hyperparameter_configs.py @@ -4,17 +4,18 @@ """ from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, TypeAlias -from summoner.utils.client_reconnect_configs import ReconnectConfig +from summoner.utils.client_reconnect_configs import ReconnectConfig, RETRY_DELAY_SECONDS_TYPE, PORT_TYPE from summoner.utils.client_sender_configs import SenderConfig +TIMEOUT_TYPE: TypeAlias = Optional[float] class HyperparameterConfig: sender_config: SenderConfig reconnect_config: ReconnectConfig max_bytes_per_line: int - read_timeout_seconds: Optional[float] + read_timeout_seconds: TIMEOUT_TYPE def __init__(self, sender_config: SenderConfig, diff --git a/summoner/utils/client_reconnect_configs.py b/summoner/utils/client_reconnect_configs.py index 360a314..231d9ae 100644 --- a/summoner/utils/client_reconnect_configs.py +++ b/summoner/utils/client_reconnect_configs.py @@ -4,15 +4,17 @@ """ from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, TypeAlias +RETRY_DELAY_SECONDS_TYPE : TypeAlias = float +PORT_TYPE : TypeAlias = Optional[int] @dataclass(slots=True) class ReconnectConfig: - retry_delay_seconds: float + retry_delay_seconds: RETRY_DELAY_SECONDS_TYPE primary_retry_limit: int default_host: Optional[str] - default_port: Optional[int] + default_port: PORT_TYPE default_retry_limit: int def __post_init__(self): diff --git a/tests/test_config.py b/tests/test_config.py index d01f370..169266e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,15 +1,19 @@ """ Tests for manipulating client_config """ +#pylint:disable=line-too-long from pathlib import Path -import sys from summoner.utils.json_handlers import load_config from summoner.utils.client_reconnect_configs import ReconnectConfig from summoner.utils.client_sender_configs import SenderConfig from summoner.utils.client_hyperparameter_configs import HyperparameterConfig def test_template_client_reconnection_config(): + """ + Handling of reconnection hyperparameters part + of the config of client config json files + """ tempate_client_config = Path() / "templates" / "client_config.json" assert tempate_client_config.is_file(), f"{tempate_client_config} is not a file" big_config = load_config(str(tempate_client_config)) @@ -34,6 +38,10 @@ def test_template_client_reconnection_config(): assert reconnect_config.default_retry_limit == just_reconnection["default_retry_limit"] def test_template_client_server_config(): + """ + Handling of sender hyperparameters part + of the config of client config json files + """ tempate_client_config = Path() / "templates" / "client_config.json" assert tempate_client_config.is_file(), f"{tempate_client_config} is not a file" big_config = load_config(str(tempate_client_config)) @@ -58,10 +66,11 @@ def test_template_client_server_config(): assert sender_config.max_consecutive_worker_errors == just_sender["max_worker_errors"] def test_combined(): - tempate_client_config = Path() / "templates" / "client_config.json" - assert tempate_client_config.is_file(), f"{tempate_client_config} is not a file" - big_config = load_config(str(tempate_client_config)) - template_hyperparameter_config = big_config["hyper_parameters"] + """ + Handling of hyperparameters part + of the config of client config json files + """ + # Load up with the default's sender_config = SenderConfig( max_concurrent_workers = 1, batch_drain = False, @@ -83,6 +92,13 @@ def test_combined(): max_bytes_per_line = 1024*64, read_timeout_seconds = None, ) + + # The actual config from the file which will override the defaults + tempate_client_config = Path() / "templates" / "client_config.json" + assert tempate_client_config.is_file(), f"{tempate_client_config} is not a file" + big_config = load_config(str(tempate_client_config)) + template_hyperparameter_config = big_config["hyper_parameters"] + problems = hyperparameter_config.merge_in(**template_hyperparameter_config) assert len(problems) == 0, \ "The template was supposed to be a well formed client configuration. So it should have had no problems." @@ -91,12 +107,12 @@ def test_combined(): assert hyperparameter_config.sender_config.send_queue_maxsize == template_hyperparameter_config["sender"]["queue_maxsize"] assert hyperparameter_config.sender_config.event_bridge_maxsize == template_hyperparameter_config["sender"]["event_bridge_maxsize"] assert hyperparameter_config.sender_config.max_consecutive_worker_errors == template_hyperparameter_config["sender"]["max_worker_errors"] - + assert hyperparameter_config.reconnect_config.retry_delay_seconds == template_hyperparameter_config["reconnection"]["retry_delay_seconds"] assert hyperparameter_config.reconnect_config.primary_retry_limit == template_hyperparameter_config["reconnection"]["primary_retry_limit"] assert hyperparameter_config.reconnect_config.default_host == template_hyperparameter_config["reconnection"]["default_host"] assert hyperparameter_config.reconnect_config.default_port == template_hyperparameter_config["reconnection"]["default_port"] assert hyperparameter_config.reconnect_config.default_retry_limit == template_hyperparameter_config["reconnection"]["default_retry_limit"] - + assert hyperparameter_config.max_bytes_per_line == template_hyperparameter_config["receiver"]["max_bytes_per_line"] assert hyperparameter_config.read_timeout_seconds == template_hyperparameter_config["receiver"]["read_timeout_seconds"] diff --git a/tests/test_flow.py b/tests/test_flow.py index c8c0098..9ee7392 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -6,19 +6,22 @@ from summoner.protocol.flow import get_token_list, Flow from summoner.protocol.process import Node, ParsedRoute - @pytest.mark.parametrize("input_str, sep, expected", [ ("foo,bar(baz,qux),zap", ",", ["foo", "bar(baz,qux)", "zap"]), (" a , b ,c ", ",", ["a", "b", "c"]), ("one(two,three,four),five", ",", ["one(two,three,four)", "five"]), ]) def test_get_token_list(input_str, sep, expected): - # Only top-level separators should split + """ + Only top-level separators should split + """ assert get_token_list(input_str, sep) == expected def make_flow(): - # Helper to set up a Flow with two arrow styles + """ + Helper to set up a Flow with two arrow styles + """ flow = Flow() flow.activate() flow.add_arrow_style(stem="-", brackets=("[", "]"), separator=",", tip=">") @@ -27,24 +30,45 @@ def make_flow(): return flow -def test_parse_route_complete_labeled(): +def test_parse_route_complete_simple(): + """ + A simple ParsedRoute + """ flow = make_flow() - pr = flow.parse_route("A --> B") + pr : ParsedRoute = flow.parse_route("A --> B") # Expect source A, no label, target B assert pr.source == (Node("A"),) assert pr.label == () assert pr.target == (Node("B"),) assert pr.is_arrow - def test_parse_route_unlabeled_complete(): + """ + A simple ParsedRoute that is explicitly unlabeled + """ flow = make_flow() pr = flow.parse_route("X--[]-->Y") assert pr.source == (Node("X"),) assert pr.target == (Node("Y"),) + assert pr.label == () + assert pr.is_arrow +def test_parse_route_complete_labelled(): + """ + A simple ParsedRoute + """ + flow = make_flow() + pr = flow.parse_route("A --[e,f,g]--> B") + # Expect source A, (e,f,g) label, target B + assert pr.source == (Node("A"),) + assert pr.label == (Node("e"),Node("f"), Node("g")) + assert pr.target == (Node("B"),) + assert pr.is_arrow def test_parse_route_dangling_right(): + """ + A dangling right ParsedRoute + """ flow = make_flow() pr = flow.parse_route("-->B") # Dangling left: no source, no label @@ -54,6 +78,9 @@ def test_parse_route_dangling_right(): def test_parse_route_dangling_left(): + """ + A dangling left ParsedRoute + """ flow = make_flow() pr = flow.parse_route("A-->") assert pr.source == (Node("A"),) @@ -61,6 +88,9 @@ def test_parse_route_dangling_left(): def test_parse_route_standalone(): + """ + Standalone object ParsedRoute + """ flow = make_flow() pr = flow.parse_route("X, Y ,Z") # Comma-separated standalone objects @@ -69,16 +99,22 @@ def test_parse_route_standalone(): def test_parse_route_invalid_token(): + """ + An invalid target causes the creation of ParsedRoute + to fail + """ flow = make_flow() with pytest.raises(ValueError): flow.parse_route("A --> inv&alid") def test_parse_routes_list(): + """ + Make several ParsedRoutes + """ flow = make_flow() routes = ["A-->B", "C"] prs = flow.parse_routes(routes) assert isinstance(prs, list) assert prs[0] == flow.parse_route("A-->B") assert prs[1] == flow.parse_route("C") - \ No newline at end of file diff --git a/tests/test_process.py b/tests/test_process.py index 5478c14..4417056 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -1,6 +1,9 @@ """ Tests for process.py: Node, ArrowStyle, ParsedRoute, activated_nodes, StateTape, collect_activations """ +#pylint:disable=import-outside-toplevel, no-member, invalid-name + +from typing import Dict, List, Tuple, cast import pytest from summoner.protocol.process import ( @@ -17,6 +20,11 @@ def test_node_parsing_and_str_repr(): + """ + The node kinds can be plain, all, exclude from a specific set, + or be one of a specific set. + Invalid names raise errors. + """ # Plain token n = Node("foo") assert n.kind == "plain" and str(n) == "foo" @@ -25,10 +33,10 @@ def test_node_parsing_and_str_repr(): assert a.kind == "all" and str(a) == "/all" # /not(A,B) notn = Node("/not(A,B)") - assert notn.kind == "not" and set(notn.values) == {"A", "B"} + assert notn.kind == "not" and set(notn.values or ()) == {"A", "B"} # /oneof(X,Y) ony = Node("/oneof(X,Y)") - assert ony.kind == "oneof" and set(ony.values) == {"X", "Y"} + assert ony.kind == "oneof" and set(ony.values or ()) == {"X", "Y"} with pytest.raises(ValueError): Node("invalid(token") @@ -41,11 +49,17 @@ def test_node_parsing_and_str_repr(): (Node("/oneof(a,b)"), Node("b"), True), ]) def test_node_accepts_matrix(gate, state, expected): + """ + Gates accept or reject the given states + """ # Verify accepts logic across kinds assert gate.accepts(state) is expected def test_arrowstyle_valid_and_invalid(): + """ + Creation of ArrowStyle being valid or not + """ # Valid style style = ArrowStyle("-", ("[", "]"), ",", ">") assert style.stem == "-" @@ -64,6 +78,10 @@ def test_arrowstyle_valid_and_invalid(): def test_parsedroute_properties_and_repr(): + """ + The properties and regex for matching + of ParsedRoute + """ style = ArrowStyle("-", ("[", "]"), ",", ">") pr = ParsedRoute((Node("A"),), (), (Node("B"),), style) assert not pr.has_label and pr.is_arrow and not pr.is_object @@ -75,20 +93,29 @@ def test_parsedroute_properties_and_repr(): def test_activated_nodes_various_actions(): + """ + Check what activated nodes depending on + what selected trigger/event passed + """ style = ArrowStyle("-", ("[", "]"), ",", ">") pr = ParsedRoute((Node("A"),), (Node("L"),), (Node("B"),), style) # Build a dummy trigger/event to pass - from summoner.protocol.triggers import Move as MoveEvt, Test as TestEvt, Stay as StayEvt, load_triggers + from summoner.protocol.triggers import Move as MoveEvt, Test as TestEvt, \ + Stay as StayEvt Trigger = load_triggers(json_dict={"d": None}) - move_event = MoveEvt(Trigger.d) + move_event = MoveEvt(Trigger.d) # pyright: ignore[reportAttributeAccessIssue] assert pr.activated_nodes(move_event) == (Node("L"), Node("B")) - test_event = TestEvt(Trigger.d) + test_event = TestEvt(Trigger.d) # pyright: ignore[reportAttributeAccessIssue] assert pr.activated_nodes(test_event) == (Node("L"),) - stay_event = StayEvt(Trigger.d) + stay_event = StayEvt(Trigger.d) # pyright: ignore[reportAttributeAccessIssue] assert pr.activated_nodes(stay_event) == (Node("A"),) def test_statetape_and_revert_extend_refresh(): + """ + revert on StateTape gives correct results + and refreshing has the correct effect + """ # SINGLE tape1 = StateTape("X") assert tape1.revert() == [Node("X")] @@ -103,25 +130,31 @@ def test_statetape_and_revert_extend_refresh(): assert tape4.revert() == {"k": [Node("A"), Node("B")]} # Test extend and refresh tape4.extend({"k": ["C"]}) - assert tape4.revert()["k"][-1] == Node("C") + t4r = cast(Dict[str,List[Node]],tape4.revert()) + assert t4r["k"][-1] == Node("C") fresh = tape4.refresh() - assert fresh.revert()["k"] == [] + fresh_t4r = cast(Dict[str,List[Node]],fresh.revert()) + assert fresh_t4r["k"] == [] def test_collect_activations_simple_case(): + """ + What activations do we get for a simple case + """ # Setup flow and parsed route for "A --> B" flow = Flow().activate() flow.add_arrow_style("-", ("[", "]"), ",", ">") flow.compile_arrow_patterns() pr = flow.parse_route("A --> B") # Fake receiver function - async def fn(msg): + async def fn(_msg): return None receiver_index = {str(pr): Receiver(fn=fn, priority=(1,))} parsed_routes = {str(pr): pr} # Tape with nodes A and C under key 'k' tape = StateTape({"k": ["A", "C"]}) - activations = tape.collect_activations(receiver_index, parsed_routes) + activations: Dict[Tuple[int,...],List[TapeActivation]] = \ + tape.collect_activations(receiver_index, parsed_routes) # Only A should match assert (1,) in activations acts = activations[(1,)] @@ -132,35 +165,47 @@ async def fn(msg): assert act.route == pr assert act.fn is fn - def test_sender_responds_to_filters(): + """ + Sender that responds depending on some filters + """ # Setup a simple Trigger set with two signals Trigger = load_triggers(json_dict={"X": None, "Y": None}) - sigX, sigY = Trigger.X, Trigger.Y + sigX, sigY = Trigger.X, Trigger.Y # pyright: ignore[reportAttributeAccessIssue] from summoner.protocol.triggers import Move as MoveEvt, Stay as StayEvt, Test as TestEvt # Events - move_evt = MoveEvt(sigX) - stay_evt = StayEvt(sigY) - test_evt = TestEvt(sigX) + move_x_evt = MoveEvt(sigX) + stay_y_evt = StayEvt(sigY) + test_x_evt = TestEvt(sigX) + + async def do_nothing(): + """ + We are just checking whether or not it responds + so what to put here is just a dummy function + """ + return None # 1) No filters → always responds - sender1 = Sender(fn=lambda: None, multi=False, actions=None, triggers=None) - assert sender1.responds_to(move_evt) - assert sender1.responds_to(stay_evt) + sender1 = Sender(fn=do_nothing, multi=False, actions=None, triggers=None) + assert sender1.responds_to(move_x_evt) + assert sender1.responds_to(stay_y_evt) + assert sender1.responds_to(test_x_evt) # 2) on_actions only - sender_actions = Sender(fn=lambda: None, multi=False, actions={Action.MOVE}, triggers=None) - assert sender_actions.responds_to(move_evt) - assert not sender_actions.responds_to(stay_evt) + sender_actions = Sender(fn=do_nothing, multi=False, actions={Action.MOVE}, triggers=None) + assert sender_actions.responds_to(move_x_evt) + assert not sender_actions.responds_to(stay_y_evt) + assert not sender_actions.responds_to(test_x_evt) # 3) on_triggers only — any event carrying sigX is accepted - sender_triggers = Sender(fn=lambda: None, multi=False, actions=None, triggers={sigX}) - assert sender_triggers.responds_to(move_evt) - assert sender_triggers.responds_to(test_evt) # ← change to True - assert not sender_triggers.responds_to(stay_evt) # different signal Y - - # 4) both filters - sender_both = Sender(fn=lambda: None, multi=False, actions={Action.STAY}, triggers={sigY}) - assert sender_both.responds_to(stay_evt) + sender_triggers = Sender(fn=do_nothing, multi=False, actions=None, triggers={sigX}) + assert sender_triggers.responds_to(move_x_evt) + assert sender_triggers.responds_to(test_x_evt) + assert not sender_triggers.responds_to(stay_y_evt) # different signal Y so not responds + + # 4) both filters, so only the stay and that it is carrying sigY + sender_both = Sender(fn=do_nothing, multi=False, actions={Action.STAY}, triggers={sigY}) + assert sender_both.responds_to(stay_y_evt) assert not sender_both.responds_to(StayEvt(sigX)) - assert not sender_both.responds_to(move_evt) + assert not sender_both.responds_to(move_x_evt) + assert not sender_both.responds_to(test_x_evt) diff --git a/tests/test_triggers.py b/tests/test_triggers.py index 95c38cb..964dca1 100644 --- a/tests/test_triggers.py +++ b/tests/test_triggers.py @@ -1,5 +1,10 @@ """ -Tests for triggers.py: signal-tree parsing, Trigger class, Signal ordering, Event classes, and extract_signal. +Tests for triggers.py: +- signal-tree parsing +- Trigger class +- Signal ordering +- Event classes +- extract_signal """ import pytest @@ -17,6 +22,9 @@ def test_parse_signal_tree_simple_hierarchy(): + """ + parse_signal_tree on a simple example + """ # Define lines simulating a TRIGGERS file with two levels lines = [ "OK\n", @@ -30,6 +38,10 @@ def test_parse_signal_tree_simple_hierarchy(): def test_parse_signal_tree_invalid_varname(): + """ + parse_signal_tree but when the input has an + invalid name + """ # Names must be valid Python identifiers lines = [ "123invalid\n" @@ -40,6 +52,9 @@ def test_parse_signal_tree_invalid_varname(): def test_parse_signal_tree_inconsistent_indent(): + """ + Indentation syntax error of the parse_signal_tree + """ # Indents must follow previous levels exactly or match existing levels lines = [ "OK\n", @@ -52,7 +67,9 @@ def test_parse_signal_tree_inconsistent_indent(): def test_parse_signal_tree_duplicate_name(): - # Duplicate names at same indent level are disallowed + """ + Duplicate names at same indent level are disallowed + """ lines = [ "OK\n", " acceptable\n", @@ -63,19 +80,25 @@ def test_parse_signal_tree_duplicate_name(): assert "duplicate signal name" in str(excinfo.value) +#pylint:disable=no-member def test_load_triggers_with_json_dict(): + """ + loading triggers from a json_dict + has all the information provided + in the constructed Trigger + """ # Provide a nested dict directly to load_triggers json_dict = {"root": {"child": None}} Trigger = load_triggers(json_dict=json_dict) # Ensure attributes and path mappings are correct assert hasattr(Trigger, "root") and hasattr(Trigger, "child") - assert Trigger.root.path == (0,) - assert Trigger.child.path == (0, 0) - assert Trigger.name_of(0, 0) == "child" + assert Trigger.root.path == (0,) # pyright: ignore[reportAttributeAccessIssue] + assert Trigger.child.path == (0, 0) # pyright: ignore[reportAttributeAccessIssue] + assert Trigger.name_of(0, 0) == "child" # pyright: ignore[reportAttributeAccessIssue] def test_load_triggers_reserved_keyword(): - # Reserved Python keyword names should raise an error + """Reserved Python keyword names should raise an error""" json_dict = {"class": None} with pytest.raises(ValueError) as excinfo: load_triggers(json_dict=json_dict) @@ -83,16 +106,20 @@ def test_load_triggers_reserved_keyword(): def test_signal_comparison_and_properties(): + """ + Comparison of signals based on ancestry relation + """ # Simple hierarchy: A -> B (implemented so that ancestor > descendant) json_dict = {"A": {"B": None}} Trigger = load_triggers(json_dict=json_dict) - sigA, sigB = Trigger.A, Trigger.B + #pylint:disable=invalid-name + sigA, sigB = Trigger.A, Trigger.B # pyright: ignore[reportAttributeAccessIssue] # A parent signal compares greater than its child assert sigA > sigB assert sigB < sigA # Equality and hashing based on path - assert sigA == Trigger.A - assert hash(sigA) == hash(Trigger.A) + assert sigA == Trigger.A # pyright: ignore[reportAttributeAccessIssue] + assert hash(sigA) == hash(Trigger.A) # pyright: ignore[reportAttributeAccessIssue] # repr, name, path assert repr(sigA) == "" assert sigA.name == "A" @@ -102,9 +129,13 @@ def test_signal_comparison_and_properties(): def test_event_and_action_classes_and_extract_signal(): + """ + extract_signal + """ # Instantiate via a single-signal trigger Trigger = load_triggers(json_dict={"X": None}) - sigX = Trigger.X + #pylint:disable=invalid-name + sigX : Signal = Trigger.X # type: ignore move_evt = Move(sigX) stay_evt = Stay(sigX) test_evt = Test(sigX) @@ -120,4 +151,4 @@ def test_event_and_action_classes_and_extract_signal(): assert extract_signal(sigX) is sigX assert extract_signal(None) is None with pytest.raises(TypeError): - extract_signal(123) \ No newline at end of file + extract_signal(123) From 39d498722efa1b8a751846b95bb851f018a13481 Mon Sep 17 00:00:00 2001 From: Cobord Date: Wed, 18 Feb 2026 23:02:46 -0500 Subject: [PATCH 04/11] continued --- summoner/client/__init__.py | 5 +- summoner/client/client.py | 365 +++++++++++------- summoner/client/merger.py | 1 + summoner/utils/__init__.py | 6 +- summoner/utils/addr_handlers.py | 53 +-- .../utils/client_hyperparameter_configs.py | 23 +- summoner/utils/client_reconnect_configs.py | 14 +- summoner/utils/client_sender_configs.py | 4 +- summoner/utils/code_handlers.py | 19 +- summoner/utils/json_handlers.py | 13 +- summoner/utils/string_handlers.py | 12 +- 11 files changed, 322 insertions(+), 193 deletions(-) diff --git a/summoner/client/__init__.py b/summoner/client/__init__.py index 08587fb..55d4f47 100644 --- a/summoner/client/__init__.py +++ b/summoner/client/__init__.py @@ -1,2 +1,5 @@ +""" +TODO: doc client and ClientMerger summary +""" from .client import SummonerClient -from .merger import ClientMerger, ClientTranslation \ No newline at end of file +from .merger import ClientMerger, ClientTranslation diff --git a/summoner/client/client.py b/summoner/client/client.py index c6a109a..c731f29 100644 --- a/summoner/client/client.py +++ b/summoner/client/client.py @@ -1,13 +1,19 @@ +""" +SummonerClient +""" +#pylint:disable=line-too-long, too-many-lines, wrong-import-position +#pylint:disable=logging-fstring-interpolation, broad-exception-caught + import os import sys import json from typing import ( Dict, - Optional, - Callable, - Union, - Awaitable, - Any, + Optional, + Callable, + Union, + Awaitable, + Any, Type, cast, ) @@ -17,8 +23,6 @@ from collections import defaultdict import platform -from summoner.utils.client_hyperparameter_configs import TIMEOUT_TYPE - target_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) if target_path not in sys.path: sys.path.insert(0, target_path) @@ -32,40 +36,43 @@ rebuild_expression_for, ) from summoner.logger import ( - get_logger, - configure_logger, + get_logger, + configure_logger, Logger, ) from summoner.protocol.triggers import ( - Signal, - Event, + Signal, + Event, Action ) from summoner.protocol.process import ( - StateTape, - ParsedRoute, - Node, - Sender, - Receiver, + StateTape, + ParsedRoute, + Node, + Sender, + Receiver, Direction, ClientIntent, ) from summoner.protocol.flow import Flow from summoner.protocol.validation import ( - hook_priority_order, + hook_priority_order, _check_param_and_return, ) from summoner.protocol.payload import ( - wrap_with_types, + wrap_with_types, recover_with_types, RelayedMessage ) class ServerDisconnected(Exception): """Raised when the server closes the connection.""" - pass +#pylint:disable=too-many-instance-attributes class SummonerClient: + """ + TODO doc client + """ DEFAULT_MAX_BYTES_PER_LINE = 64 * 1024 # 64 KiB DEFAULT_READ_TIMEOUT_SECONDS = None # Wait for messages to arrive @@ -81,10 +88,10 @@ class SummonerClient: core_version = "1.1.1" def __init__(self, name: Optional[str] = None): - + # Give a name to the server self.name = name if isinstance(name, str) else "" - + # Create a bare logger (no handlers yet) self.logger: Logger = get_logger(self.name) @@ -134,7 +141,7 @@ def __init__(self, name: Optional[str] = None): # Receiver HyperParameters self.max_bytes_per_line : Optional[int] = None self.read_timeout_seconds : Optional[float] = None # None is prefered - + # Reconnction HyperParameters self.retry_delay_seconds : Optional[float] = None # self.primary_retry_limit is unbound until _apply_config @@ -167,11 +174,16 @@ def __init__(self, name: Optional[str] = None): # ==== VERSION SPECIFIC ==== def _apply_config(self, config: dict[str,Union[str,dict[str,Union[str,dict]]]]): + """ + Given the config which says specific hyperparameters, set the corresponding + fields here. Making sure they meet the expected constraints. + If some are not provided in config, they may be given default values. + """ self.host = config.get("host") # default is None # pyright: ignore[reportAttributeAccessIssue] self.port = config.get("port") # default is None # pyright: ignore[reportAttributeAccessIssue] logger_cfg = cast(Dict[str,Any],config.get("logger", {})) - configure_logger(self.logger, logger_cfg) + configure_logger(self.logger, logger_cfg) hp_config = cast(Dict[str,Any],config.get("hyper_parameters", {})) @@ -185,42 +197,56 @@ def _apply_config(self, config: dict[str,Union[str,dict[str,Union[str,dict]]]]): receiver_cfg = hp_config.get("receiver", {}) self.max_bytes_per_line = receiver_cfg.get("max_bytes_per_line", self.DEFAULT_MAX_BYTES_PER_LINE) self.read_timeout_seconds = receiver_cfg.get("read_timeout_seconds", self.DEFAULT_READ_TIMEOUT_SECONDS) - + sender_cfg = hp_config.get("sender", {}) self.max_concurrent_workers = sender_cfg.get("concurrency_limit", self.DEFAULT_CONCURRENCY_LIMIT) self.batch_drain = bool(sender_cfg.get("batch_drain", True)) self.send_queue_maxsize = sender_cfg.get("queue_maxsize", self.max_concurrent_workers) self.event_bridge_maxsize = sender_cfg.get("event_bridge_maxsize", self.DEFAULT_EVENT_BRIDGE_SIZE) self.max_consecutive_worker_errors = sender_cfg.get("max_worker_errors", self.DEFAULT_MAX_CONSECUTIVE_ERRORS) - + if (not isinstance(self.max_consecutive_worker_errors, int) or self.max_consecutive_worker_errors < 1): raise ValueError("sender.max_worker_errors must be an integer ≥ 1") if not isinstance(self.max_concurrent_workers, int) or self.max_concurrent_workers <= 0: raise ValueError("sender.concurrency_limit must be an integer ≥ 1") - + if not isinstance(self.send_queue_maxsize, int) or self.send_queue_maxsize <= 0: raise ValueError("sender.queue_maxsize must be an integer ≥ 1") - + if self.send_queue_maxsize < self.max_concurrent_workers: + #pylint:disable=logging-fstring-interpolation self.logger.warning(f"queue_maxsize < concurrency_limit; back-pressure will throttle producers at {self.send_queue_maxsize}") def initialize(self): + """ + Get the regex patterns for flow ready + """ self._flow.compile_arrow_patterns() def flow(self) -> Flow: + """ + TODO: doc flow + """ return self._flow - + async def travel_to(self, host, port): + """ + Change the host and port. + So there is now pending travel. + """ async with self.connection_lock: self.host = host self.port = port self._travel = True async def quit(self): + """ + Now a pending quit + """ async with self.connection_lock: self._quit = True - + async def _reset_client_intent(self): """Clear any pending quit or travel so we start fresh next time.""" async with self.connection_lock: @@ -233,12 +259,12 @@ def upload_states(self): Must be used before client.run(). """ def decorator(fn: Callable[[], Awaitable]): - + # ----[ Safety Checks ]---- - + if not inspect.iscoroutinefunction(fn): raise TypeError(f"@upload_states handler '{fn.__name__}' must be async") - + _check_param_and_return( fn, decorator_name="@upload_states", @@ -252,10 +278,10 @@ def decorator(fn: Callable[[], Awaitable]): ), # the payload-dependent tape logger=self.logger, ) - + # if self.loop.is_running(): # raise RuntimeError("@upload_states() must be registered before client.run()") - + if self._upload_states is not None: self.logger.warning("@upload_states handler overwritten") @@ -277,21 +303,21 @@ def download_states(self): Must be used before client.run(). """ def decorator(fn: Callable[[Any], Awaitable]): - + # ----[ Safety Checks ]---- if not inspect.iscoroutinefunction(fn): raise TypeError(f"@download_states handler '{fn.__name__}' must be async") - + _check_param_and_return( fn, decorator_name="@download_states", allow_param=(type(None), Node, Any, list, dict, # type: ignore - list[Node], - dict[str, Node], - dict[str, list[Node]], - dict[str, Union[Node, list[Node]]], - dict[Optional[str], Node], + list[Node], + dict[str, Node], + dict[str, list[Node]], + dict[str, Union[Node, list[Node]]], + dict[Optional[str], Node], dict[Optional[str], list[Node]], dict[Optional[str], Union[Node, list[Node]]], ), @@ -301,7 +327,7 @@ def decorator(fn: Callable[[Any], Awaitable]): # if self.loop.is_running(): # raise RuntimeError("@download_states() must be registered before client.run()") - + if self._download_states is not None: self.logger.warning("@download_states handler overwritten") @@ -316,7 +342,7 @@ def decorator(fn: Callable[[Any], Awaitable]): return fn return decorator - + # ==== REGISTRATION HELPER ==== @@ -338,17 +364,23 @@ def _cb(): # ==== HOOK REGISTRATION ==== def hook( - self, - direction: Direction, + self, + direction: Direction, priority: Union[int, tuple[int, ...]] = () ): - def decorator(fn: Callable[[Optional[Union[str, dict]]], Optional[Union[str, dict]]]): - - # ----[ Safety Checks ]---- - + """ + TODO: doc hook + """ + def decorator(fn: Callable[[Optional[Union[str, dict]]], Optional[Union[str, dict]]]): + """ + TODO: doc decorator + """ + + # ----[ Safety Checks ]---- + if not inspect.iscoroutinefunction(fn): raise TypeError(f"@hook handler '{fn.__name__}' must be async") - + _check_param_and_return( fn, decorator_name="@hook", @@ -356,10 +388,10 @@ def decorator(fn: Callable[[Optional[Union[str, dict]]], Optional[Union[str, dic allow_return=(type(None), str, dict, Any), # type: ignore logger=self.logger, ) - + if not isinstance(direction, Direction): - raise TypeError(f"Direction for hook must be either Direction.SEND or Direction.RECEIVE") - + raise TypeError("Direction for hook must be either Direction.SEND or Direction.RECEIVE") + if isinstance(priority, int): tuple_priority = (priority,) elif isinstance(priority, tuple) and all(isinstance(p, int) for p in priority): @@ -375,8 +407,11 @@ def decorator(fn: Callable[[Optional[Union[str, dict]]], Optional[Union[str, dic "source": inspect.getsource(fn), }) - # ----[ Registration Code ]---- async def register(): + """ + ----[ Registration Code ]---- + TODO doc register + """ async with self.hooks_lock: if direction == Direction.RECEIVE: self.receiving_hooks[tuple_priority] = fn # type: ignore @@ -394,22 +429,25 @@ async def register(): # ==== RECEIVER REGISTRATION ==== def receive( - self, - route: str, + self, + route: str, priority: Union[int, tuple[int, ...]] = () ): + """ + TODO: doc receive + """ route = route.strip() def decorator(fn: Callable[[Union[str, dict]], Awaitable[Optional[Event]]]): - + # ----[ Safety Checks ]---- - + if not inspect.iscoroutinefunction(fn): raise TypeError(f"@receive handler '{fn.__name__}' must be async") - + sig = inspect.signature(fn) if len(sig.parameters) != 1: raise TypeError(f"@receive '{fn.__name__}' must accept exactly one argument (payload)") - + _check_param_and_return( fn, decorator_name="@receive", @@ -439,7 +477,7 @@ def decorator(fn: Callable[[Union[str, dict]], Awaitable[Optional[Event]]]): # ----[ Registration Code ]---- async def register(): receiver = Receiver(fn=fn, priority=tuple_priority) - + parsed_route = None normalized_route = route @@ -458,7 +496,7 @@ async def register(): async with self.routes_lock: if route in self.receiver_index: self.logger.warning(f"Route '{route}' already exists. Overwriting.") - + if self._flow.in_use and parsed_route is not None: self.receiver_parsed_routes[normalized_route] = parsed_route self.receiver_index[normalized_route] = receiver @@ -469,27 +507,30 @@ async def register(): return fn return decorator - + # ==== SENDER REGISTRATION ==== def send( - self, - route: str, - multi: bool = False, + self, + route: str, + multi: bool = False, on_triggers: Optional[set[Signal]] = None, on_actions: Optional[set[Type]] = None, ): + """ + TODO: doc send + """ route = route.strip() def decorator(fn: Callable[[], Awaitable]): - + # ----[ Safety Checks ]---- if not inspect.iscoroutinefunction(fn): raise TypeError(f"@send sender '{fn.__name__}' must be async") - + sig = inspect.signature(fn) if len(sig.parameters) != 0: raise TypeError(f"@send '{fn.__name__}' must accept no arguments") - + if not multi: _check_param_and_return( fn, @@ -506,7 +547,7 @@ def decorator(fn: Callable[[], Awaitable]): allow_return=(Any, list, list[str], list[dict], list[Union[str, dict]]), # type: ignore logger=self.logger, ) - + if not isinstance(route, str): raise TypeError(f"Argument `route` must be string. Provided: {route}") @@ -524,7 +565,7 @@ def decorator(fn: Callable[[], Awaitable]): not all(isinstance(act, type) and issubclass(act, Event) and act in {Action.MOVE, Action.STAY, Action.TEST} for act in on_actions) ): raise TypeError(f"Argument `on_actions` must be `None` or a set of Action event classes: {{Action.MOVE, Action.STAY, Action.TEST}}. Provided: {on_actions!r}") - + # ----[ DNA capture ]---- self._dna_senders.append({ "fn": fn, @@ -537,11 +578,11 @@ def decorator(fn: Callable[[], Awaitable]): # ----[ Registration Code ]---- async def register(): - + sender = Sender(fn=fn, multi=multi, actions=on_actions, triggers=on_triggers) actions_exist = isinstance(on_actions, set) and bool(on_actions) triggers_exist = isinstance(on_triggers, set) and bool(on_triggers) - + parsed_route = None normalized_route = route @@ -647,6 +688,7 @@ def _infer_client_binding_name(self) -> str: return "agent" + #pylint:disable=too-many-locals, too-many-branches, too-many-statements def dna(self, include_context: bool = False) -> str: """ Serialize this client's registered behavior into a JSON string ("DNA"). @@ -702,6 +744,7 @@ def dna(self, include_context: bool = False) -> str: - Flow construction and trigger enum creation are not captured unless they are referenced and executed within decorated handler sources. """ + #pylint:disable=import-outside-toplevel import builtins # Collect handler functions once so we can reuse them for context analysis. @@ -885,6 +928,7 @@ def dna(self, include_context: bool = False) -> str: recipes.setdefault(name, r) if "Path(" in r: path_needed = True + #pylint:disable=consider-using-f-string if "{}(".format(Node.__name__) in r: imports_out.add("from summoner.protocol.process import Node") continue @@ -958,14 +1002,18 @@ async def _read_line_safe( continue return data - + + # pylint:disable=too-many-locals, too-many-branches, too-many-statements async def message_receiver_loop( - self, - reader: asyncio.StreamReader, + self, + reader: asyncio.StreamReader, stop_event: asyncio.Event ): - + """ + TODO: doc message receiver + """ + # ----[ Wrapper: Interpret Protocol-Only Errors as None ]---- async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: try: @@ -976,21 +1024,21 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: except Exception as e: self.logger.exception(f"Receiver function {fn.__name__} raised an unexpected error: {e}") raise - + try: - + # ----[ Constantly Listen ]---- while not stop_event.is_set(): - + async with self.connection_lock: if self._quit or self._travel: stop_event.set() break - + # ----[ Prepare Receiver Batches ]---- async with self.routes_lock: receiver_index: dict[str, Receiver] = self.receiver_index.copy() - + if self._flow.in_use: async with self.routes_lock: receiver_parsed_routes: dict[str, ParsedRoute] = self.receiver_parsed_routes.copy() @@ -999,21 +1047,21 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: # ----[ Empty: Skip and Prevent Client Overwhelming ]---- if not receiver_index: data = await self._read_line_safe( - reader, + reader, limit=self.max_bytes_per_line, # type: ignore timeout=0.1, ) # if not data: # raise ServerDisconnected("EOF while dropping messages") continue - + # ----[ Build and Run Receiver Batches ]---- try: - + # ----[ Build: Get Messages ]---- - + data = await self._read_line_safe( - reader, + reader, limit=self.max_bytes_per_line, # type: ignore timeout=self.read_timeout_seconds, ) @@ -1034,7 +1082,7 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: if new_payload is None: payload = None # type: ignore break - + except Exception as e: self.logger.error( f"Receiving hook {receiving_hook.__name__} (priority={priority}) " @@ -1043,11 +1091,11 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: ) new_payload = payload payload = new_payload - + # if *any* hook returned None, skip the rest of processing if payload is None: continue - + # ----[ Build: Organize Batches by Priority ]---- batches: dict[tuple[int, ...], list[Callable[[Any], Awaitable]]] = {} if self._flow.in_use: @@ -1065,14 +1113,14 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: _ = await reader.readline() # Space await asyncio.sleep(0.1) # Time continue - + # ----[ Exec: Prepare Passage Receiver → Sender ]---- if self._flow.in_use: event_buffer: dict[tuple[int, ...], list[tuple[Optional[str], ParsedRoute, Event]]] = defaultdict(list) # ----[ Exec: Run Batches in Order ]---- for priority, batch_fns in sorted(batches.items(), key=lambda kv: kv[0]): - + # ----[ Before: Run Batch ]---- # label = "default priority" if priority == () else f"priority {priority}" # self.logger.info(f"Running batch at {label}, {len(batch_fns)} receivers") @@ -1083,7 +1131,7 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: # ----[ After: Handle Returns ]---- if self._flow.in_use: activations = activation_index[priority] # type: ignore - + local_tape = tape.refresh() # type: ignore to_extend: dict[str, list[Node]] = defaultdict(list) for act, event in zip(activations, events): @@ -1100,7 +1148,7 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: # event buffer is bound because this is within if self._flow_in_use # some of the event's in buffer_entries could have been None event_buffer[priority].extend(buffer_entries) # type: ignore - + if self._download_states is not None: await self._download_states(local_tape.revert()) @@ -1113,12 +1161,12 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: await self.event_bridge.put((priority,) + event_data) # type: ignore event_buffer = {} - + except ServerDisconnected as e: # Intentionally propagate this so reconnection logic can trigger self.logger.info(f"Graceful disconnect from server: {e}") raise - + except (ConnectionResetError, BrokenPipeError) as e: self.logger.warning(f"Socket-level failure (client likely to blame): {e}") break @@ -1131,7 +1179,7 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: def _start_send_workers( self, - writer: asyncio.StreamWriter, + writer: asyncio.StreamWriter, stop_event: asyncio.Event ): if not self.send_workers_started: @@ -1144,18 +1192,18 @@ def _start_send_workers( async def _send_worker( self, - writer: asyncio.StreamWriter, + writer: asyncio.StreamWriter, stop_event: asyncio.Event ): consecutive_errors = 0 while True: - + item: Optional[tuple[str, Sender]] = await self.send_queue.get() # type: ignore if item is None: self.send_queue.task_done() # type: ignore break - + route, sender = item try: result = await sender.fn() @@ -1172,10 +1220,10 @@ async def _send_worker( # ----[ Unpack: Handle Multi Sends ]---- payloads = result if sender.multi else [result] for payload in payloads: - + if payload is None: continue - + # ----[ Unpack: Validation ]---- async with self.hooks_lock: sending_hooks = self.sending_hooks.copy() @@ -1183,11 +1231,11 @@ async def _send_worker( for priority, sending_hook in sorted(sending_hooks.items(), key=lambda kv: hook_priority_order(kv[0])): try: new_payload = await sending_hook(payload) # type: ignore - + if new_payload is None: payload = None break - + except Exception as e: self.logger.error( f"[route={route}] Sending hook {sending_hook.__name__} (priority={priority}) " @@ -1214,7 +1262,7 @@ async def _send_worker( # ----[ Unpack: Post Messages ]---- async with self.writer_lock: writer.write(message) - + # No concurrency on batch_drain (initialized in run()) if not self.batch_drain: async with self.writer_lock: @@ -1236,6 +1284,12 @@ async def _send_worker( self.send_queue.task_done() # type: ignore async def _cleanup_workers(self): + """ + Cancel all worker taks + Gather raising any exceptions caused in the worker_tasks + Clear them + Set back to having not started send workers + """ for w in self.worker_tasks: w.cancel() if self.worker_tasks: @@ -1243,17 +1297,24 @@ async def _cleanup_workers(self): self.worker_tasks.clear() self.send_workers_started = False + #pylint:disable=too-many-nested-blocks, no-else-continue, no-else-break, attribute-defined-outside-init async def message_sender_loop( - self, - writer: asyncio.StreamWriter, + self, + writer: asyncio.StreamWriter, stop_event: asyncio.Event ): + """ + TODO: doc message_sender + """ # ----[ Helper: Matches Routes Between Senders and Receivers to Trigger Send ]---- def _route_accepts( - sender_pr: ParsedRoute, + sender_pr: ParsedRoute, receiver_pr: ParsedRoute ) -> bool: + """ + TODO: doc route_accepts + """ source_ok = all(any(n.accepts(m) for m in receiver_pr.source) for n in sender_pr.source) label_ok = all(any(n.accepts(m) for m in receiver_pr.label) for n in sender_pr.label) target_ok = all(any(n.accepts(m) for m in receiver_pr.target) for n in sender_pr.target) @@ -1264,28 +1325,28 @@ def _route_accepts( # ----[ Keep Sending While Actively Listening (No Travel) ]---- while not stop_event.is_set(): - + # ----[ Prepare Sender Batch ]---- - + async with self.routes_lock: sender_index: dict[str, list[Sender]] = self.sender_index.copy() # ----[ Fast upload of pending event data ]---- if self._flow.in_use: - + async with self.routes_lock: sender_parsed_routes: dict[str, ParsedRoute] = self.sender_parsed_routes.copy() - + pending: list[tuple[tuple[int, ...], Optional[str], ParsedRoute, Event]] = [] try: while True: pending.append(self.event_bridge.get_nowait()) # type: ignore except asyncio.QueueEmpty: pass - + pending.sort(key=lambda it: hook_priority_order(it[0])) - # ----[ Build Sender Batch ]---- + # ----[ Build Sender Batch ]---- senders: list[tuple[str, Sender]] = [] # De-dup set: at most one sender per (route, key-from-recv, recv-handler-name) this cycle. @@ -1293,28 +1354,28 @@ def _route_accepts( for route, routed_senders in sender_index.items(): for sender in routed_senders: - + # Non-reactive (no actions/triggers): preserve current behavior if (not self._flow.in_use) or (sender.actions is None and sender.triggers is None): senders.append((route, sender)) - + # Reactive: require matching a pending activation (existential) - elif self._flow.in_use and ((sender.actions and isinstance(sender.actions, set)) or + elif self._flow.in_use and ((sender.actions and isinstance(sender.actions, set)) or (sender.triggers and isinstance(sender.triggers, set))): - + sender_parsed_route = sender_parsed_routes.get(route) # type: ignore if sender_parsed_route is None: continue - + # Iterate pending in queue order; first match "wins" for this (route,key,fn_name) - for (priority, key, parsed_route, event) in pending: # type: ignore + for (_priority, key, parsed_route, event) in pending: # type: ignore if _route_accepts(sender_parsed_route, parsed_route) and sender.responds_to(event): dedup_key = (route, key, sender.fn.__name__) # key scopes to the activation thread/peer if dedup_key not in emitted: senders.append((route, sender)) emitted.add(dedup_key) break # do not enqueue multiple times for this sender this cycle - + # ----[ Empty: Skip and Prevent Client Overwhelming | Almost full: warning ]---- if not senders: await asyncio.sleep(0.1) # Time @@ -1446,14 +1507,14 @@ async def handle_session(self, host: str = '127.0.0.1', port: int = 8888): writer.close() await writer.wait_closed() self.logger.info("Disconnected from server.") - + finally: - + # Ensure both child tasks are cancelled & awaited even if we were cancelled mid-wait for task in (listen_task, sender_task): if task is not None and not task.done(): task.cancel() - + # Clean up worker used in the sender loop await self._cleanup_workers() @@ -1470,10 +1531,13 @@ async def handle_session(self, host: str = '127.0.0.1', port: int = 8888): # ==== CLIENT LIFE CYCLE ==== def shutdown(self): + """ + Cancel all the tasks for this event loop + """ self.logger.info("Client is shutting down...") for task in asyncio.all_tasks(self.loop): task.cancel() - + def set_termination_signals(self): """ Install SIGINT/SIGTERM handlers onto the loop: @@ -1482,6 +1546,7 @@ def set_termination_signals(self): """ if platform.system() != "Windows": for sig in (signal.SIGINT, signal.SIGTERM): + #pylint:disable=unnecessary-lambda self.loop.add_signal_handler(sig, lambda: self.shutdown()) async def _wait_for_registration(self): @@ -1501,7 +1566,7 @@ async def _wait_for_tasks_to_finish(self): # tasks = list(self.active_tasks) # if tasks: # await asyncio.gather(*tasks, return_exceptions=True) - + async with self.tasks_lock: tasks = list(self.active_tasks) if tasks: @@ -1518,7 +1583,7 @@ async def _retry_loop(self, host, port, limit, stage = "Primary"): attempts = 0 # clean disconnect (/quit or /travel) return True - + except (ConnectionRefusedError, ServerDisconnected, OSError) as e: attempts += 1 sleep_time = self.retry_delay_seconds or self.DEFAULT_RETRY_DELAY @@ -1542,68 +1607,74 @@ async def _get_client_intent(self) -> ClientIntent: if self._travel: return ClientIntent.TRAVEL return ClientIntent.ABORT - + async def _fallback(self): + """ + use the default host and port + """ async with self.connection_lock: self.host = self.default_host self.port = self.default_port async def run_client(self, host: str = '127.0.0.1', port: int = 8888): + """ + TODO: doc run_client + """ primary_stage = True while True: - + stage = "Primary" if primary_stage else "Default" limit = self.primary_retry_limit if primary_stage else self.default_retry_limit - + succeeded = await self._retry_loop(host, port, limit, stage) - + if succeeded: - + intent = await self._get_client_intent() - + if intent is ClientIntent.QUIT: - break - + break + elif intent is ClientIntent.TRAVEL: primary_stage = True continue - + else: break else: - + if primary_stage: primary_stage = False await self._fallback() self.logger.warning(f"Falling back to default server at {self.default_host}:{self.default_port}") continue - + else: self.logger.critical( f"Cannot connect to fallback {self.default_host}:{self.default_port} after " f"{self.default_retry_limit or '∞'} attempts; exiting" ) break - + def run( - self, - host: str = '127.0.0.1', - port: int = 8888, + self, + host: str = '127.0.0.1', + port: int = 8888, config_path: Optional[str] = None, config_dict: Optional[dict[str, Any]] = None, ): try: - + if config_dict is None: # Load config parameters client_config = load_config(config_path=config_path, debug=True) elif isinstance(config_dict, dict): # Shallow copy to avoid external mutation - client_config = dict(config_dict) + client_config = dict(config_dict) else: raise TypeError(f"SummonerClient.run: config_dict must be a dict or None, got {type(config_dict).__name__}") - + # client_config = load_config(config_path=config_path, debug=True) self._apply_config(client_config) @@ -1634,6 +1705,6 @@ def run( self.loop.run_until_complete(asyncio.gather(*self.worker_tasks, return_exceptions=True)) except (asyncio.CancelledError, KeyboardInterrupt): pass - + self.loop.close() - self.logger.info("Client exited cleanly.") \ No newline at end of file + self.logger.info("Client exited cleanly.") diff --git a/summoner/client/merger.py b/summoner/client/merger.py index 3a3bad4..92e2140 100644 --- a/summoner/client/merger.py +++ b/summoner/client/merger.py @@ -47,6 +47,7 @@ This is intended for trusted DNA (typically produced by your own agents). Do not run untrusted DNA. """ +#pylint:disable=line-too-long from importlib import import_module from typing import Optional, Any diff --git a/summoner/utils/__init__.py b/summoner/utils/__init__.py index 208a44b..835094b 100644 --- a/summoner/utils/__init__.py +++ b/summoner/utils/__init__.py @@ -1,3 +1,7 @@ +""" +Various utilities for strings, json, peer addresses, inspecting code +""" + from .string_handlers import ( remove_last_newline, ensure_trailing_newline, @@ -15,4 +19,4 @@ extract_annotation_identifiers, rebuild_expression_for, resolve_import_statement, - ) \ No newline at end of file + ) diff --git a/summoner/utils/addr_handlers.py b/summoner/utils/addr_handlers.py index d3f830b..57537a4 100644 --- a/summoner/utils/addr_handlers.py +++ b/summoner/utils/addr_handlers.py @@ -5,6 +5,33 @@ from typing import Any import ipaddress +def _format_addr_iterable(addr: Iterable, max_items: int) -> str: + try: + seq = list(addr) + except Exception: # pylint:disable=broad-exception-caught + return repr(addr) + + # Common socket cases + if len(seq) >= 2 and isinstance(seq[0], (str, bytes)) and isinstance(seq[1], int): + host = seq[0].decode() if isinstance(seq[0], (bytes, bytearray, memoryview)) else seq[0] + port = seq[1] + scopeid = seq[3] if len(seq) >= 4 and isinstance(seq[3], int) else None + try: + ip = ipaddress.ip_address(host) + if ip.version == 6: + host_fmt = host if scopeid in (None, 0) else f"{host}%{scopeid}" + return f"[{host_fmt}]:{port}" + return f"{host}:{port}" + except ValueError: + # Not an IP literal; fall back to host:port + return f"{host}:{port}" + + # Generic iterable formatting + shown = seq[:max_items] + body = ",".join(map(str, shown)) + suffix = ",..." if len(seq) > max_items else "" + return f"[{body}{suffix}]" + def format_addr(addr: Any, max_items: int = 10) -> str: """ Convert a socket peer address (or similar) to a compact, deterministic string. @@ -52,31 +79,7 @@ def format_addr(addr: Any, max_items: int = 10) -> str: # generic iterable (list/tuple/set/generator, etc.), but not strings/bytes if isinstance(addr, Iterable): - try: - seq = list(addr) - except Exception: # pylint:disable=broad-exception-caught - return repr(addr) - - # Common socket cases - if len(seq) >= 2 and isinstance(seq[0], (str, bytes)) and isinstance(seq[1], int): - host = seq[0].decode() if isinstance(seq[0], (bytes, bytearray, memoryview)) else seq[0] - port = seq[1] - scopeid = seq[3] if len(seq) >= 4 and isinstance(seq[3], int) else None - try: - ip = ipaddress.ip_address(host) - if ip.version == 6: - host_fmt = host if scopeid in (None, 0) else f"{host}%{scopeid}" - return f"[{host_fmt}]:{port}" - return f"{host}:{port}" - except ValueError: - # Not an IP literal; fall back to host:port - return f"{host}:{port}" - - # Generic iterable formatting - shown = seq[:max_items] - body = ",".join(map(str, shown)) - suffix = ",..." if len(seq) > max_items else "" - return f"[{body}{suffix}]" + _format_addr_iterable(addr, max_items) # fallback try: diff --git a/summoner/utils/client_hyperparameter_configs.py b/summoner/utils/client_hyperparameter_configs.py index e98ee90..8124d63 100644 --- a/summoner/utils/client_hyperparameter_configs.py +++ b/summoner/utils/client_hyperparameter_configs.py @@ -2,21 +2,32 @@ In `client.py` config["hyper_parameters"]["sender"] has several logic steps. Isolate that out here. """ +#pylint:disable=line-too-long -from dataclasses import dataclass from typing import List, Optional, TypeAlias -from summoner.utils.client_reconnect_configs import ReconnectConfig, RETRY_DELAY_SECONDS_TYPE, PORT_TYPE +# pylint:disable=unused-import +from summoner.utils.client_reconnect_configs import \ + ReconnectConfig, RETRY_DELAY_SECONDS_TYPE, PORT_TYPE from summoner.utils.client_sender_configs import SenderConfig +#pylint:disable=invalid-name TIMEOUT_TYPE: TypeAlias = Optional[float] +#pylint:disable=too-few-public-methods class HyperparameterConfig: + """ + Client config hyperparameter section + It is mainly broken up into + SenderConfig + ReconnectConfig + The receiver part is smaller so is directly in the remaining 2 fields + """ sender_config: SenderConfig reconnect_config: ReconnectConfig max_bytes_per_line: int read_timeout_seconds: TIMEOUT_TYPE - + def __init__(self, sender_config: SenderConfig, reconnect_config: ReconnectConfig, @@ -27,7 +38,7 @@ def __init__(self, self.reconnect_config = reconnect_config self.max_bytes_per_line = max_bytes_per_line self.read_timeout_seconds = read_timeout_seconds - + def __post_init__(self): if self.max_bytes_per_line <= 0:\ raise ValueError("The provided max_bytes_per_line must be an integer ≥ 1") @@ -35,6 +46,10 @@ def __post_init__(self): raise ValueError("The provided read_timeout_seconds must be an float ≥ 0.0 if provided") def merge_in(self, **kwargs) -> List[str]: + """ + The **kwargs are new values from configuration files + and this is setting those values as appropriate + """ all_problems = [] all_problems.extend(self.sender_config.merge_in(**kwargs["sender"])) all_problems.extend(self.reconnect_config.merge_in(**kwargs["reconnection"])) diff --git a/summoner/utils/client_reconnect_configs.py b/summoner/utils/client_reconnect_configs.py index 231d9ae..c7e4650 100644 --- a/summoner/utils/client_reconnect_configs.py +++ b/summoner/utils/client_reconnect_configs.py @@ -2,21 +2,28 @@ In `client.py` config["hyper_parameters"]["reconnection"] has several logic steps. Isolate that out here """ +#pylint:disable=line-too-long from dataclasses import dataclass from typing import List, Optional, TypeAlias +#pylint:disable=invalid-name RETRY_DELAY_SECONDS_TYPE : TypeAlias = float +#pylint:disable=invalid-name PORT_TYPE : TypeAlias = Optional[int] @dataclass(slots=True) class ReconnectConfig: + """ + In `client.py` config["hyper_parameters"]["reconnection"] + has several logic steps. Isolate that out here + """ retry_delay_seconds: RETRY_DELAY_SECONDS_TYPE primary_retry_limit: int default_host: Optional[str] default_port: PORT_TYPE default_retry_limit: int - + def __post_init__(self): assert self.retry_delay_seconds >= 0.0,\ "The provided retry_delay_seconds must be a float ≥ 0.0" @@ -25,7 +32,12 @@ def __post_init__(self): assert self.default_retry_limit >= 0,\ "The provided default_retry_limit must be an integer ≥ 0" + #pylint:disable=too-many-branches def merge_in(self, **kwargs) -> List[str]: + """ + The **kwargs are new values from configuration files + and this is setting those values as appropriate + """ all_problems = [] if "retry_delay_seconds" in kwargs: if not isinstance((z := kwargs["retry_delay_seconds"]), float | int): diff --git a/summoner/utils/client_sender_configs.py b/summoner/utils/client_sender_configs.py index 8f038c4..74a7ba6 100644 --- a/summoner/utils/client_sender_configs.py +++ b/summoner/utils/client_sender_configs.py @@ -2,6 +2,7 @@ In `client.py` config["hyper_parameters"]["sender"] has several logic steps. Isolate that out here. """ +#pylint:disable=line-too-long from dataclasses import dataclass from typing import List, Optional @@ -22,7 +23,7 @@ class SenderConfig: send_queue_maxsize: int event_bridge_maxsize: Optional[int] max_consecutive_worker_errors: int - + def __post_init__(self): if self.max_concurrent_workers <= 0:\ raise ValueError("The provided max_concurrent_workers must be an integer ≥ 1") @@ -31,6 +32,7 @@ def __post_init__(self): if self.max_consecutive_worker_errors <= 0:\ raise ValueError("The provided max_consecutive_worker_errors must be an integer ≥ 1") + #pylint:disable=too-many-branches def merge_in(self, **kwargs) -> List[str]: """ The logic of _apply_config that is relevant to the sender_cfg. diff --git a/summoner/utils/code_handlers.py b/summoner/utils/code_handlers.py index 183976f..b02c4ed 100644 --- a/summoner/utils/code_handlers.py +++ b/summoner/utils/code_handlers.py @@ -1,3 +1,7 @@ +""" +Using inspection for Python code +""" + from typing import Optional, Set, Any import inspect import ast @@ -46,7 +50,7 @@ def get_callable_source(fn: Any, override: Optional[str] = None) -> str: try: return inspect.getsource(fn) - except Exception: + except Exception: # pylint:disable=broad-exception-caught src = getattr(fn, "__dna_source__", None) if isinstance(src, str) and src.strip(): return src @@ -65,6 +69,7 @@ def get_callable_source(fn: Any, override: Optional[str] = None) -> str: ) +# pylint:disable=too-many-branches def extract_annotation_identifiers(src: str) -> Set[str]: """ Extract simple identifier names used inside function annotations from source text. @@ -97,7 +102,7 @@ def extract_annotation_identifiers(src: str) -> Set[str]: out: Set[str] = set() try: tree = ast.parse(textwrap.dedent(src)) - except Exception: + except Exception:# pylint:disable=broad-exception-caught return out fn_node = None @@ -179,7 +184,7 @@ def rebuild_expression_for(value: object, node_type: Optional[type] = None) -> O return None - +# pylint:disable=too-many-return-statements, too-many-branches def resolve_import_statement(name: str, value: object, known_modules: Set[str]) -> Optional[str]: """ Try to produce a stable Python import statement that binds `name` to `value`. @@ -240,7 +245,7 @@ def resolve_import_statement(name: str, value: object, known_modules: Set[str]) known_modules.add(mod) if getattr(m, name, None) is value: return f"from {mod} import {name}" - except Exception: + except Exception:# pylint:disable=broad-exception-caught pass obj_name = getattr(value, "__name__", None) @@ -251,7 +256,7 @@ def resolve_import_statement(name: str, value: object, known_modules: Set[str]) if name == obj_name: return f"from {mod} import {obj_name}" return f"from {mod} import {obj_name} as {name}" - except Exception: + except Exception:# pylint:disable=broad-exception-caught pass # Fallback: search modules we've already seen. @@ -260,7 +265,7 @@ def resolve_import_statement(name: str, value: object, known_modules: Set[str]) m = import_module(km) if getattr(m, name, None) is value: return f"from {km} import {name}" - except Exception: + except Exception:# pylint:disable=broad-exception-caught continue - return None \ No newline at end of file + return None diff --git a/summoner/utils/json_handlers.py b/summoner/utils/json_handlers.py index f69b412..d2acc03 100644 --- a/summoner/utils/json_handlers.py +++ b/summoner/utils/json_handlers.py @@ -1,3 +1,7 @@ +""" +General json manipulation utilities +""" + import json from pathlib import Path from typing import Any, Optional @@ -5,6 +9,7 @@ def fully_recover_json(data): """ Recursively recover original nested structure from JSON strings. + A mix of stringified JSON and python lists, dicts and primitives Args: data (any): Data structure possibly containing nested JSON-encoded strings. @@ -42,9 +47,9 @@ def load_config(config_path: Optional[str], debug: bool = False) -> dict[str, An """ if config_path is None: if debug: - print(f"[DEBUG] Config file is `None`") + print("[DEBUG] Config file is `None`") return {} - + path = Path(config_path) if not path.is_file(): @@ -64,7 +69,7 @@ def load_config(config_path: Optional[str], debug: bool = False) -> dict[str, An except OSError as e: if debug: print(f"[DEBUG] OS error reading {config_path}: {e}") - except Exception as e: + except Exception as e: # pylint:disable=broad-exception-caught if debug: print(f"[DEBUG] Unexpected error loading {config_path}: {e}") @@ -97,5 +102,5 @@ def is_jsonable(value: Any) -> bool: try: json.dumps(value) return True - except Exception: + except Exception: # pylint:disable=broad-exception-caught return False diff --git a/summoner/utils/string_handlers.py b/summoner/utils/string_handlers.py index 31ed86d..9414cc5 100644 --- a/summoner/utils/string_handlers.py +++ b/summoner/utils/string_handlers.py @@ -1,7 +1,15 @@ +""" +Newline manipulation utilities +""" + def remove_last_newline(s: str): + """ + if ending with a newline, remove it + """ return s[:-1] if s.endswith('\n') else s def ensure_trailing_newline(s: str): + """ + Make sure ends with a newline + """ return s if s.endswith('\n') else s + '\n' - - From cf4c9dabe815fe1ed01649bcd81d66036316a613 Mon Sep 17 00:00:00 2001 From: Cobord Date: Wed, 18 Feb 2026 23:03:28 -0500 Subject: [PATCH 05/11] continued --- summoner/client/merger.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/summoner/client/merger.py b/summoner/client/merger.py index 92e2140..bd47cdb 100644 --- a/summoner/client/merger.py +++ b/summoner/client/merger.py @@ -593,7 +593,7 @@ def _clone_handler(self, fn: types.FunctionType, original_name: str) -> types.Fu # if your dna() uses a __dna_source__ fallback, keep it if hasattr(fn, "__dna_source__"): - new_fn.__dna_source__ = fn.__dna_source__ + new_fn.__dna_source__ = fn.__dna_source__ # type: ignore return new_fn @@ -811,7 +811,7 @@ def initiate_senders(self): dec = self.send( entry["route"], multi=entry.get("multi", False), - on_triggers=on_triggers, + on_triggers=on_triggers, # type: ignore on_actions=on_actions, ) self._apply_with_source_patch(dec, fn, entry["source"]) @@ -1019,7 +1019,7 @@ def _cleanup_template_clients_from_modules(self): for module_name in modules: try: - module = sys.modules.get(module_name) or import_module(module_name) + module = sys.modules.get(module_name) or import_module(module_name) # type: ignore except Exception: continue @@ -1150,7 +1150,7 @@ def initiate_senders(self): dec = self.send( entry["route"], multi=entry.get("multi", False), - on_triggers=on_triggers, + on_triggers=on_triggers, # type: ignore on_actions=on_actions, ) self._apply_with_source_patch(dec, fn, entry["source"]) From a2fc8c29aadc65d9e8c4be01c62efc44167ef3c1 Mon Sep 17 00:00:00 2001 From: Cobord Date: Thu, 19 Feb 2026 18:23:53 -0500 Subject: [PATCH 06/11] in protocol section --- summoner/client/client.py | 36 ++++-- summoner/client/merger.py | 36 +++++- summoner/logger.py | 2 +- summoner/protocol/__init__.py | 19 +-- summoner/protocol/_deprecation.py | 5 +- summoner/protocol/flow.py | 170 ++++++++++++++++++++++++--- summoner/protocol/payload.py | 11 +- summoner/protocol/process.py | 187 +++++++++++++++++++++++------- summoner/protocol/triggers.py | 69 ++++++++--- summoner/protocol/validation.py | 31 +++-- 10 files changed, 452 insertions(+), 114 deletions(-) diff --git a/summoner/client/client.py b/summoner/client/client.py index c731f29..171047a 100644 --- a/summoner/client/client.py +++ b/summoner/client/client.py @@ -9,6 +9,7 @@ import json from typing import ( Dict, + Generator, Optional, Callable, Union, @@ -65,6 +66,8 @@ RelayedMessage ) +ANY_TO_AWAIT = Callable[[Any],Awaitable] + class ServerDisconnected(Exception): """Raised when the server closes the connection.""" @@ -128,8 +131,8 @@ def __init__(self, name: Optional[str] = None): self._flow = Flow() # Functions to read and write the flow's active states in memory - self._upload_states: Optional[Callable[[Any], Awaitable]] = None - self._download_states: Optional[Callable[[Any], Awaitable]] = None + self._upload_states: Optional[ANY_TO_AWAIT] = None + self._download_states: Optional[ANY_TO_AWAIT] = None # Sender HyperParameters self.event_bridge_maxsize : Optional[int] = None @@ -173,6 +176,7 @@ def __init__(self, name: Optional[str] = None): # ==== VERSION SPECIFIC ==== + #pylint:disable=attribute-defined-outside-init def _apply_config(self, config: dict[str,Union[str,dict[str,Union[str,dict]]]]): """ Given the config which says specific hyperparameters, set the corresponding @@ -215,7 +219,6 @@ def _apply_config(self, config: dict[str,Union[str,dict[str,Union[str,dict]]]]): raise ValueError("sender.queue_maxsize must be an integer ≥ 1") if self.send_queue_maxsize < self.max_concurrent_workers: - #pylint:disable=logging-fstring-interpolation self.logger.warning(f"queue_maxsize < concurrency_limit; back-pressure will throttle producers at {self.send_queue_maxsize}") def initialize(self): @@ -302,7 +305,7 @@ def download_states(self): Decorator to supply a function that receives a StateTape. Must be used before client.run(). """ - def decorator(fn: Callable[[Any], Awaitable]): + def decorator(fn: ANY_TO_AWAIT): # ----[ Safety Checks ]---- @@ -1015,7 +1018,7 @@ async def message_receiver_loop( """ # ----[ Wrapper: Interpret Protocol-Only Errors as None ]---- - async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: + async def _safe_call(fn: ANY_TO_AWAIT, payload: Any) -> Any: try: return await fn(payload) except BlockingIOError: @@ -1097,7 +1100,7 @@ async def _safe_call(fn: Callable[[Any], Awaitable], payload: Any) -> Any: continue # ----[ Build: Organize Batches by Priority ]---- - batches: dict[tuple[int, ...], list[Callable[[Any], Awaitable]]] = {} + batches: dict[tuple[int, ...], list[ANY_TO_AWAIT]] = {} if self._flow.in_use: raw_states = (await self._upload_states(payload)) if self._upload_states is not None else None tape = StateTape(raw_states) @@ -1363,7 +1366,9 @@ def _route_accepts( elif self._flow.in_use and ((sender.actions and isinstance(sender.actions, set)) or (sender.triggers and isinstance(sender.triggers, set))): - sender_parsed_route = sender_parsed_routes.get(route) # type: ignore + # self._flow_in_use so sender_parsed_routes is bound + sender_parsed_routes = cast(Dict[str,ParsedRoute],sender_parsed_routes) # type: ignore + sender_parsed_route = sender_parsed_routes.get(route) if sender_parsed_route is None: continue @@ -1496,7 +1501,7 @@ async def handle_session(self, host: str = '127.0.0.1', port: int = 8888): await task except ServerDisconnected as e: # Propagate server-side disconnection to the reconnection handler - raise ServerDisconnected(e) + raise ServerDisconnected(e) #pylint:disable=raise-missing-from except asyncio.CancelledError: # Normal during shutdown; ignore pass @@ -1664,6 +1669,9 @@ def run( config_path: Optional[str] = None, config_dict: Optional[dict[str, Any]] = None, ): + """ + TODO: doc run + """ try: if config_dict is None: @@ -1708,3 +1716,15 @@ def run( self.loop.close() self.logger.info("Client exited cleanly.") + + def _view_candidates(self) -> Generator[Optional[Callable[[Any], Awaitable]],None,None]: + if self._upload_states is not None: + yield self._upload_states + if self._download_states is not None: + yield self._download_states + for d in self._dna_receivers: + yield d.get("fn") # type: ignore + for d in self._dna_senders: + yield d.get("fn") # type: ignore + for d in self._dna_hooks: + yield d.get("fn") # type: ignore diff --git a/summoner/client/merger.py b/summoner/client/merger.py index bd47cdb..c1462e5 100644 --- a/summoner/client/merger.py +++ b/summoner/client/merger.py @@ -47,7 +47,8 @@ This is intended for trusted DNA (typically produced by your own agents). Do not run untrusted DNA. """ -#pylint:disable=line-too-long +#pylint:disable=line-too-long, too-many-lines, wrong-import-position +#pylint:disable=invalid-name, broad-exception-caught,logging-fstring-interpolation from importlib import import_module from typing import Optional, Any @@ -60,7 +61,8 @@ import json import uuid -import os, sys +import os +import sys target_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) if target_path not in sys.path: sys.path.insert(0, target_path) @@ -185,6 +187,7 @@ class ClientMerger(SummonerClient): This class executes trusted code from DNA via exec()/eval(). Do not run untrusted DNA. """ + # pylint:disable=too-many-arguments, too-many-positional-arguments def __init__( self, named_clients: list[Any], # backward compatible: list[dict] or list[SummonerClient] or list[dna_list] @@ -228,6 +231,7 @@ def __init__( # Source normalization # ---------------------------- + #pylint:disable=too-many-branches def _normalize_source(self, entry: Any, idx: int) -> dict[str, Any]: """ Normalize a user-provided source specification into a canonical dict. @@ -351,7 +355,9 @@ def _infer_client_var_name(self, client: SummonerClient) -> Optional[str]: The inferred binding name, or None if not found. """ # Look for a module-global name whose value is exactly `client` + # TODO: can use _view_candidates instead of building a list candidates = [] + #pylint:disable=protected-access if client._upload_states is not None: candidates.append(client._upload_states) if client._download_states is not None: @@ -403,6 +409,7 @@ def _shutdown_imported_clients(self) -> None: client: SummonerClient = src["client"] var_name: str = src["var_name"] + #pylint:disable=protected-access tasks = list(client._registration_tasks or []) loop = client.loop @@ -448,12 +455,14 @@ def _shutdown_imported_clients(self) -> None: # 4) clear list with suppress(Exception): + #pylint:disable=protected-access client._registration_tasks.clear() # ---------------------------- # Context application (DNA) # ---------------------------- + # pylint:disable=too-many-branches def _apply_context(self, ctx: Optional[dict], g: dict, *, label: str) -> dict[str, Any]: """ Apply a DNA context entry (imports, globals, recipes) into a sandbox globals dict. @@ -495,6 +504,7 @@ def _apply_context(self, ctx: Optional[dict], g: dict, *, label: str) -> dict[st continue try: + # pylint:disable=exec-used exec(line, g) report["succeeded"].append(line) if self._verbose_context_imports: @@ -517,6 +527,7 @@ def _apply_context(self, ctx: Optional[dict], g: dict, *, label: str) -> dict[st if not (isinstance(k, str) and isinstance(expr, str)): continue try: + # pylint:disable=eval-used g.setdefault(k, eval(expr, g, {})) except Exception as e: self.logger.warning(f"[merge ctx:{label}] recipe failed {k}={expr!r} ({type(e).__name__}: {e})") @@ -649,6 +660,7 @@ def _make_from_source(self, entry: dict[str, Any], g: dict, sandbox_name: str) - if "__builtins__" not in g: g["__builtins__"] = __builtins__ + # pylint:disable=exec-used exec(compile(func_body, filename=f"<{sandbox_name}>", mode="exec"), g) fn = g.get(fn_name) @@ -668,6 +680,7 @@ def initiate_upload_states(self): if src["kind"] == "client": client: SummonerClient = src["client"] var_name: str = src["var_name"] + #pylint:disable=protected-access fn = client._upload_states if fn is None: continue @@ -693,6 +706,7 @@ def initiate_download_states(self): if src["kind"] == "client": client: SummonerClient = src["client"] var_name: str = src["var_name"] + #pylint:disable=protected-access fn = client._download_states if fn is None: continue @@ -718,6 +732,7 @@ def initiate_hooks(self): if src["kind"] == "client": client: SummonerClient = src["client"] var_name: str = src["var_name"] + #pylint:disable=protected-access for dna in client._dna_hooks: fn_clone = self._clone_handler(dna["fn"], var_name) try: @@ -742,6 +757,7 @@ def initiate_receivers(self): if src["kind"] == "client": client: SummonerClient = src["client"] var_name: str = src["var_name"] + #pylint:disable=protected-access for dna in client._dna_receivers: fn_clone = self._clone_handler(dna["fn"], var_name) try: @@ -779,6 +795,7 @@ def initiate_senders(self): if src["kind"] == "client": client: SummonerClient = src["client"] var_name: str = src["var_name"] + #pylint:disable=protected-access for dna in client._dna_senders: fn_clone = self._clone_handler(dna["fn"], var_name) try: @@ -833,7 +850,7 @@ def initiate_all(self): self.initiate_receivers() self.initiate_senders() - +# pylint:disable=too-many-instance-attributes class ClientTranslation(SummonerClient): """ Reconstruct a SummonerClient from its DNA list. @@ -862,6 +879,7 @@ class ClientTranslation(SummonerClient): This class attempts to find and clean up such template clients so they do not leave pending registration tasks or open loops behind. """ + # pylint:disable=too-many-arguments, too-many-positional-arguments def __init__( self, dna_list: list[dict[str, Any]], @@ -875,7 +893,7 @@ def __init__( if not isinstance(dna_list, list): raise TypeError("dna_list must be a list of DNA entries") - + self._rebind_globals = dict(rebind_globals or {}) self._allow_context_imports = allow_context_imports self._verbose_context_imports = verbose_context_imports @@ -915,6 +933,7 @@ def __init__( # Best-effort cleanup of template clients imported indirectly via DNA modules self._cleanup_template_clients_from_modules() + #pylint:disable=too-many-branches def _apply_context(self): """ Apply optional "__context__" entry into the translation sandbox. @@ -934,6 +953,7 @@ def _apply_context(self): if not self._allow_context_imports: continue try: + # pylint:disable=exec-used exec(line, g) if self._verbose_context_imports: self.logger.info(f"[translation ctx] import ok: {line}") @@ -954,11 +974,13 @@ def _apply_context(self): if not (isinstance(k, str) and isinstance(expr, str)): continue # eval in the sandbox global namespace + # pylint:disable=eval-used try: g.setdefault(k, eval(expr, g, {})) except Exception as e: self.logger.warning(f"Could not eval recipe for {k}: {expr!r} ({e})") + #pylint:disable=unused-argument def _cleanup_one_template_client(self, client: SummonerClient, label: str): """ Best-effort cleanup for a template client discovered in an imported module. @@ -971,6 +993,7 @@ def _cleanup_one_template_client(self, client: SummonerClient, label: str): - if possible, await cancellations by driving the template's loop - close the loop when it is not running """ + # pylint:disable=protected-access regs = list(client._registration_tasks or []) loop = client.loop @@ -991,6 +1014,7 @@ def _cleanup_one_template_client(self, client: SummonerClient, label: str): # clear list try: + # pylint:disable=protected-access client._registration_tasks.clear() except Exception: pass @@ -1061,6 +1085,7 @@ def _make_from_source(self, entry: dict[str, Any]) -> types.FunctionType: else: raise RuntimeError(f"Could not find definition for '{fn_name}'") + #pylint:disable=exec-used exec(compile(func_body, filename=f"<{self._sandbox_module_name}>", mode="exec"), module_globals) fn = module_globals.get(fn_name) @@ -1209,7 +1234,7 @@ def shutdown(self): except RuntimeError: # If the loop isn't running, ignore. pass - + async def quit(self): """ Quit the translated client: @@ -1232,4 +1257,3 @@ def run(self, *args, **kwargs): self.logger.info("KeyboardInterrupt caught-cancelling registration tasks…") for task in list(self._registration_tasks or []): task.cancel() - return diff --git a/summoner/logger.py b/summoner/logger.py index cb7e202..bb3d217 100644 --- a/summoner/logger.py +++ b/summoner/logger.py @@ -150,7 +150,7 @@ def format(self, record: logging.LogRecord) -> str: else {k: record.msg.get(k) for k in self.log_keys if k in record.msg}) else: payload = record.getMessage() - base["message"] = payload + base["message"] = payload # type: ignore return json.dumps(base, default=str) def get_logger(name: str) -> logging.Logger: diff --git a/summoner/protocol/__init__.py b/summoner/protocol/__init__.py index edb0303..4189d96 100644 --- a/summoner/protocol/__init__.py +++ b/summoner/protocol/__init__.py @@ -1,18 +1,21 @@ +""" +TODO: summary of triggers, process and flow +""" from .triggers import ( - Signal, - Event, + Signal, + Event, Action, Move, Stay, Test, ) from .process import ( - StateTape, - ParsedRoute, - Node, - Sender, - Receiver, + StateTape, + ParsedRoute, + Node, + Sender, + Receiver, Direction, ) from .flow import Flow -# from .validation import hook_priority_order, _check_param_and_return \ No newline at end of file +# from .validation import hook_priority_order, _check_param_and_return diff --git a/summoner/protocol/_deprecation.py b/summoner/protocol/_deprecation.py index 0ee7e66..c0c76c0 100644 --- a/summoner/protocol/_deprecation.py +++ b/summoner/protocol/_deprecation.py @@ -1,6 +1,9 @@ +""" +Warnings about using deprecated methods +""" try: from warnings import deprecated # Python 3.13+ except ImportError: from typing_extensions import deprecated # Python <= 3.12 -__all__ = ["deprecated"] \ No newline at end of file +__all__ = ["deprecated"] diff --git a/summoner/protocol/flow.py b/summoner/protocol/flow.py index d9690be..b9e6ab7 100644 --- a/summoner/protocol/flow.py +++ b/summoner/protocol/flow.py @@ -1,11 +1,16 @@ +""" +Handles many regexes for different ArrowStyle +""" from __future__ import annotations import re from collections.abc import Callable from typing import Iterable, Optional, Any +import warnings from .triggers import load_triggers from .process import Node, ArrowStyle, ParsedRoute from ._deprecation import deprecated # type: ignore -import warnings + +# pylint:disable=line-too-long # variable names or commands used in flow transitions _TOKEN_RE = re.compile(r""" @@ -34,7 +39,7 @@ def get_token_list(input_string: str, separator: str) -> list[str]: parenthesis_depth: int = 0 for character in input_string: - + if character == "(": parenthesis_depth += 1 elif character == ")": @@ -56,20 +61,34 @@ def get_token_list(input_string: str, separator: str) -> list[str]: TriggerType = type class Flow: - + """ + Handles many regexes for different ArrowStyle + """ + def __init__(self, triggers_file: Optional[str] = None) -> None: + """ + Handles many regexes for different ArrowStyle + """ self.triggers_file = triggers_file self.in_use: bool = False self.arrows: set[ArrowStyle] = set() - + self._regex_ready: bool = False self._regex_patterns: list[tuple[re.Pattern[str], ArrowStyle, Unpack]] = [] - + def activate(self) -> Flow: + """ + In the client there is logic + that depends on how this is toggled. + """ self.in_use = True return self def deactivate(self) -> Flow: + """ + In the client there is logic + that depends on how this is toggled. + """ self.in_use = False return self @@ -80,6 +99,12 @@ def add_arrow_style( separator: str, tip: str ) -> None: + """ + Add an arrow, that is another possibility + for parsing routes. + This means the regex patterns will need + to be recompiled as the set of ArrowStyle has changed. + """ style = ArrowStyle( stem=stem, brackets=brackets, @@ -91,14 +116,15 @@ def add_arrow_style( self._regex_patterns.clear() def triggers(self, json_dict: Optional[dict[str, Any]] = None) -> TriggerType: - if json_dict is None: + """ + Load and build the TRIG class from the nested_dict or TRIGGERS file + """ + if json_dict is None: if self.triggers_file is None: # use the file TRIGGERS return load_triggers() - else: - return load_triggers(triggers_file=self.triggers_file) - else: - return load_triggers(json_dict=json_dict) + return load_triggers(triggers_file=self.triggers_file) + return load_triggers(json_dict=json_dict) def _build_labeled_complete( self, @@ -107,6 +133,15 @@ def _build_labeled_complete( right_bracket: str, tip: str ) -> re.Pattern[str]: + """ + Build the regex to match + a complete labelled arrow. + For example, + As --[Cs]--> Bs + - base says the shaft of the arrow is -- + - left_bracket,right_bracket say the labels are surrounded by [] + - tip says the tip of the arrow is > + """ left = rf"{base}{left_bracket}" right = rf"{right_bracket}{base}{tip}" regex = rf""" @@ -118,6 +153,14 @@ def _build_labeled_complete( return re.compile(regex, re.VERBOSE) def _build_unlabeled_complete(self, base: str, tip: str) -> re.Pattern[str]: + """ + Build the regex to match + a complete unlabelled arrow. + For example, + As --> Bs + - base says the shaft of the arrow is -- + - tip says the tip of the arrow is > + """ arrow = rf"{base}{tip}" regex = rf""" ^\s* @@ -125,7 +168,7 @@ def _build_unlabeled_complete(self, base: str, tip: str) -> re.Pattern[str]: (?P.+?)\s*$ """ return re.compile(regex, re.VERBOSE) - + def _build_labeled_dangling_right( self, base: str, @@ -133,6 +176,15 @@ def _build_labeled_dangling_right( right_bracket: str, tip: str ) -> re.Pattern[str]: + """ + Build the regex to match + a dangling on the right labelled arrow. + For example, + As --[Cs]--> + - base says the shaft of the arrow is -- + - left_bracket,right_bracket say the labels are surrounded by [] + - tip says the tip of the arrow is > + """ left = rf"{base}{left_bracket}" right = rf"{right_bracket}{base}{tip}" regex = rf""" @@ -143,13 +195,21 @@ def _build_labeled_dangling_right( return re.compile(regex, re.VERBOSE) def _build_unlabeled_dangling_right(self, base: str, tip: str) -> re.Pattern[str]: + """ + Build the regex to match + a dangling on the right unlabelled arrow. + For example, + As --> + - base says the shaft of the arrow is -- + - tip says the tip of the arrow is > + """ arrow = rf"{base}{tip}" regex = rf""" ^\s* (?P.+?)\s*{arrow}\s*$ """ return re.compile(regex, re.VERBOSE) - + def _build_labeled_dangling_left( self, base: str, @@ -157,6 +217,15 @@ def _build_labeled_dangling_left( right_bracket: str, tip: str ) -> re.Pattern[str]: + """ + Build the regex to match + a dangling on the left labelled arrow. + For example, + --[Cs]--> Bs + - base says the shaft of the arrow is -- + - left_bracket,right_bracket say the labels are surrounded by [] + - tip says the tip of the arrow is > + """ left = rf"{base}{left_bracket}" right = rf"{right_bracket}{base}{tip}" regex = rf""" @@ -168,6 +237,14 @@ def _build_labeled_dangling_left( return re.compile(regex, re.VERBOSE) def _build_unlabeled_dangling_left(self, base: str, tip: str) -> re.Pattern[str]: + """ + Build the regex to match + a dangling on the left unlabelled arrow. + For example, + --> Bs + - base says the shaft of the arrow is -- + - tip says the tip of the arrow is > + """ arrow = rf"{base}{tip}" regex = rf""" ^\s* @@ -177,6 +254,10 @@ def _build_unlabeled_dangling_left(self, base: str, tip: str) -> re.Pattern[str] return re.compile(regex, re.VERBOSE) def _unpack_labeled_complete(self, matched_pattern: re.Match[str]) -> tuple[str, str, str]: + """ + If the arrow was complete and labelled, the produced regex will have 3 groups + for source, label and target. A match with this will give strings for each of these. + """ return ( matched_pattern.group("source"), matched_pattern.group("label"), @@ -184,21 +265,48 @@ def _unpack_labeled_complete(self, matched_pattern: re.Match[str]) -> tuple[str, ) def _unpack_unlabeled_complete(self, matched_pattern: re.Match[str]) -> tuple[str, str, str]: + """ + If the arrow was complete and unlabelled, the produced regex will have 2 groups + for source and target. A match with this will give strings for each of these and the label is empty. + """ return matched_pattern.group("source"), "", matched_pattern.group("target") - + def _unpack_labeled_dangling_right(self, matched_pattern: re.Match[str]) -> tuple[str, str, str]: + """ + If the arrow was dangling on the right and labelled, the produced regex will have 2 groups + for source and label. A match with this will give strings for each of these and the target is empty. + """ return matched_pattern.group("source"), matched_pattern.group("label"), "" - + def _unpack_unlabeled_dangling_right(self, matched_pattern: re.Match[str]) -> tuple[str, str, str]: + """ + If the arrow was dangling on the right and unlabelled, the produced regex will have 1 group + for source. A match with this will give strings for each of these and the label and target are empty. + """ return matched_pattern.group("source"), "", "" - + def _unpack_labeled_dangling_left(self, matched_pattern: re.Match[str]) -> tuple[str, str, str]: + """ + If the arrow was dangling on the left and labelled, the produced regex will have 2 groups + for label and target. A match with this will give strings for each of these and the source is empty. + """ return "", matched_pattern.group("label"), matched_pattern.group("target") def _unpack_unlabeled_dangling_left(self, matched_pattern: re.Match[str]) -> tuple[str, str, str]: + """ + If the arrow was dangling on the left and unlabelled, the produced regex will have 1 group + for target. A match with this will give strings for each of these and the source and label are empty. + """ return "", "", matched_pattern.group("target") def _prepare_regex(self) -> None: + """ + If the regex for all the ArrowStyle(s) have not been compiled, compile them all. + This takes care of all variants of + - labelled, unlabelled + - dangling left, dangling right, complete + on each of the styles. + """ if self._regex_ready: return @@ -230,11 +338,21 @@ def _prepare_regex(self) -> None: self._regex_ready = True def _validate_tokens(self, tokens: list[str], text: str) -> None: + """ + Each source, label and target must conform to the naming style + similar to a Python identifier but also allowing /all /oneof /not + which are also valid as sources and targets. + """ for token in tokens: if not _TOKEN_RE.match(token): raise ValueError(f"Invalid token {token!r} in route {text!r}") def _parse_standalone(self, text: str) -> ParsedRoute: + """ + Similar to parsing the dangling right unlabelled case + which is like `As ->` + but this is just the As part without the -> + """ # split on commas, strip whitespace, drop empties source_list = get_token_list(text, separator=',') # validate each token @@ -259,6 +377,9 @@ def compile_arrow_patterns(self) -> None: "When using SummonerClient, patterns are compiled automatically during registration." ) def ready(self) -> None: + """ + Same as compile_arrow_patterns but the deprecated version. + """ warnings.warn( "Flow.ready() is deprecated; use Flow.compile_arrow_patterns() instead. " "In SummonerClient, you generally don't need to call this.", @@ -269,10 +390,22 @@ def ready(self) -> None: self._prepare_regex() def parse_route(self, route: str) -> ParsedRoute: + """ + route is something like As --[Bs]-> Cs + or As --> etc among all the possibilities of + having labels or not, being dangling or not, + having an arrow or not. + Give the ParsedRoute which is storing the + sources, labels and targets with empty tuples for any + of them that were not present. + If present, the As, Bs, and Cs are split into tuples according + to the separator of the relevant ArrowStyle in order + to handle multiple sources, labels and/or targets + """ route = route.strip() if not self._regex_ready: self._prepare_regex() - + for pattern, style, unpack in self._regex_patterns: matched_pattern = pattern.match(route) if not matched_pattern: @@ -290,8 +423,11 @@ def parse_route(self, route: str) -> ParsedRoute: target=tuple(Node(tok) for tok in target_list), style=style, ) - + return self._parse_standalone(route) def parse_routes(self, routes: Iterable[str]) -> list[ParsedRoute]: + """ + Do parse_route for each of the routes + """ return [self.parse_route(route=route) for route in routes] diff --git a/summoner/protocol/payload.py b/summoner/protocol/payload.py index 09b7b0c..03c64c3 100644 --- a/summoner/protocol/payload.py +++ b/summoner/protocol/payload.py @@ -1,9 +1,13 @@ +""" +TODO: doc payload +""" +#pylint:disable=line-too-long import json from json import JSONDecodeError from typing import Any, Tuple, Dict, List, Union, TypedDict from summoner.utils import ( - fully_recover_json, + fully_recover_json, remove_last_newline, ensure_trailing_newline, ) @@ -29,9 +33,6 @@ def register_envelope_version( envelope_parsers[version] = parser envelope_casters[version] = caster - -from typing import Any, Tuple, Dict, List - STR_TYPE = "str" BOOL_TYPE = "bool" INT_TYPE = "int" @@ -39,6 +40,7 @@ def register_envelope_version( NULL_TYPE = "null" +#pylint:disable=too-many-return-statements def parse_v0_0_1(obj: Any) -> Tuple[Any, Any]: """ Walk `obj` and build a parallel `type_tree`. Every leaf in `obj` @@ -89,6 +91,7 @@ def parse_v0_0_1(obj: Any) -> Tuple[Any, Any]: return s, STR_TYPE +#pylint:disable=too-many-return-statements, too-many-branches def cast_v0_0_1(val: Any, expected: Any) -> Any: """ Coerce `val` according to `expected`, but never fail on unknown types. diff --git a/summoner/protocol/process.py b/summoner/protocol/process.py index 1687828..2ca02fb 100644 --- a/summoner/protocol/process.py +++ b/summoner/protocol/process.py @@ -1,7 +1,11 @@ +""" +TODO: doc process +""" + from __future__ import annotations import re from collections import defaultdict -from typing import Type, Any, Optional, Union, Callable, Awaitable +from typing import Dict, List, Literal, Tuple, Type, Any, Optional, Union, Callable, Awaitable from enum import Enum, auto from dataclasses import dataclass from .triggers import Signal, Event, Action, extract_signal @@ -18,38 +22,43 @@ # Wildcard sentinel for dispatch _WILDCARD = object() +KindType = Literal["all", "not", "oneof", "plain"] + class Node: + """ + TODO: doc node + """ __slots__ = ('expr', 'kind', 'values') def __init__(self, expr: str) -> None: _expr: str = expr.strip() - self.kind: str + self.kind: KindType self.values: Optional[tuple[str,...]] if _ALL_RE.fullmatch(_expr): self.kind = 'all' self.values = None - + elif (found_match := _NOT_RE.fullmatch(_expr)): self.kind = 'not' items = [item.strip() for item in found_match.group(1).split(',') if item.strip()] self.values = tuple(items) - + elif (found_match := _ONEOF_RE.fullmatch(_expr)): self.kind = 'oneof' items = [item.strip() for item in found_match.group(1).split(',') if item.strip()] self.values = tuple(items) - + elif _PLAIN_TOKEN_RE.fullmatch(_expr): self.kind = 'plain' - self.values = (_expr,) - + self.values = (_expr,) + else: raise ValueError(f"Invalid syntax for token: {_expr!r}") def __eq__(self, other: Any) -> bool: return ( - isinstance(other, Node) and + isinstance(other, Node) and self.kind == other.kind and self.values == other.values ) @@ -65,49 +74,56 @@ def __repr__(self) -> str: # f"\033[94mvalues\033[0m=\033[90m{self.values!r}\033[0m)" ) + #pylint:disable=too-many-return-statements def __str__(self) -> str: try: if self.kind == 'all': return '/all' - elif self.kind == 'plain': + if self.kind == 'plain': if self.values is None: - return f"" + return "" return self.values[0] - elif self.kind == 'not': + if self.kind == 'not': if self.values is None: - return f"" + return "" return f"/not({','.join(self.values)})" - elif self.kind == 'oneof': + if self.kind == 'oneof': if self.values is None: - return f"" + return "" return f"/oneof({','.join(self.values)})" - else: - return f"" - except Exception as e: + return f"" + except Exception as e: # pylint:disable=broad-exception-caught return f"" def accepts(self, state: Node) -> bool: - if not isinstance(state, Node): + """ + Handle the logic of whether this Node + accepts state or not. + In the plain case it is matching, + but there is also the logic of oneof, all, not kinds on either + operand. + """ + if not isinstance(state, Node): raise TypeError(f"Argument `state` must be Node; {state} provided") - - table = { + + table : Dict[Tuple[KindType | object, KindType | object], Callable[[Node,Node],bool]] = { ('all', 'all'): lambda g, s: True, ('all', _WILDCARD): lambda g, s: True, (_WILDCARD, 'all'): lambda g, s: True, - ('plain', 'plain'): lambda g, s: g.values[0] == s.values[0], - ('plain', 'not'): lambda g, s: g.values[0] not in s.values, - ('plain', 'oneof'): lambda g, s: g.values[0] in s.values, - ('not', 'plain'): lambda g, s: s.values[0] not in g.values, + ('plain', 'plain'): lambda g, s: g.values[0] == s.values[0], # type: ignore + ('plain', 'not'): lambda g, s: g.values[0] not in s.values, # type: ignore + ('plain', 'oneof'): lambda g, s: g.values[0] in s.values, # type: ignore + ('not', 'plain'): lambda g, s: s.values[0] not in g.values, # type: ignore ('not', _WILDCARD): lambda g, s: True, - ('oneof', 'plain'): lambda g, s: s.values[0] in g.values, - ('oneof', 'not'): lambda g, s: bool(set(g.values) - set(s.values)), - ('oneof', 'oneof'): lambda g, s: bool(set(g.values) & set(s.values)), + ('oneof', 'plain'): lambda g, s: s.values[0] in g.values, # type: ignore + ('oneof', 'not'): lambda g, s: bool(set(g.values) - set(s.values)), # type: ignore + ('oneof', 'oneof'): lambda g, s: bool(set(g.values) & set(s.values)), # type: ignore } for (gk, sk), fn in table.items(): if (gk == self.kind or gk is _WILDCARD) and (sk == state.kind or sk is _WILDCARD): return fn(self, state) - + raise RuntimeError("Unhandled combination in Node.is_compatible_with") # ======= ARROW STYLE ======= @@ -211,6 +227,7 @@ def _check_regex_safe(self) -> None: try: re.escape(part) except re.error as e: + #pylint:disable=raise-missing-from raise ValueError( f"Part {part!r} invalid for regex: {e}" ) @@ -218,6 +235,17 @@ def _check_regex_safe(self) -> None: # ======= PARSED ROUTE ======= class ParsedRoute: + """ + A parsed route holds + all the sources, labels and targets for an arrow that has been parsed + As --[Bs,Cs]--> /all has + Node for As in source + Node for Bs and Node for Cs in label + Node for /all in target + + This is also for when it is not an arrow as in standalone when there + are only sources + """ __slots__ = ('source', 'label', 'target', 'style', '_string') def __init__( @@ -269,18 +297,30 @@ def __str__(self) -> str: @property def has_label(self) -> bool: + """ + There are labels + """ return bool(self.label) @property def is_arrow(self) -> bool: + """ + This is actually an arrow, unlike the standalone only sources + """ return bool(self.target) or self.has_label @property def is_object(self) -> bool: + """ + It is standalone + """ return not self.is_arrow - + @property def is_initial(self) -> bool: + """ + It is dangling on the left + """ return self.is_arrow and not self.source def activated_nodes( @@ -294,9 +334,8 @@ def activated_nodes( if isinstance(event, Event) and event is not None and not self.is_arrow: if isinstance(event, Action.TEST): return () - else: - # standalone → only the source nodes - return self.source + # standalone → only the source nodes + return self.source # arrow route → pick based on the Action subtype if isinstance(event, Action.MOVE): @@ -313,6 +352,9 @@ def activated_nodes( @dataclass(frozen=True) class Sender: + """ + TODO: doc sender + """ __slots__ = ('fn', 'multi', 'actions', 'triggers') fn: Callable[[], Awaitable] multi: bool @@ -320,30 +362,42 @@ class Sender: triggers: Optional[set[Signal]] def responds_to(self, event: Any) -> bool: + """ + TODO: doc sender + """ action_check = True if self.actions is not None: if not any(isinstance(event, action) for action in self.actions): action_check = False - + trigger_check = True if self.triggers is not None: if not any(extract_signal(event) == trig for trig in self.triggers): trigger_check = False - + return action_check and trigger_check @dataclass(frozen=True) class Receiver: + """ + TODO: doc receiver + """ __slots__ = ('fn', 'priority') fn: Callable[[Union[str, dict]], Awaitable[Optional[Event]]] priority: tuple[int, ...] class Direction(Enum): + """ + Only two directions + """ SEND = auto() RECEIVE = auto() @dataclass(frozen=True) class TapeActivation: + """ + TODO: doc tape activation + """ __slots__ = ('key', 'state', 'route', 'fn') key: Optional[str] state: Optional[Node] @@ -353,12 +407,18 @@ class TapeActivation: # ======= STATE TAPE ======= class TapeType(Enum): + """ + TODO: doc tapetype + """ SINGLE = auto() MANY = auto() INDEX_SINGLE = auto() INDEX_MANY = auto() class StateTape: + """ + TODO: doc state tape + """ __slots__ = ('states', '_type') prefix: str = "tape" @@ -369,7 +429,7 @@ def __init__(self, states: Any = None, with_prefix: bool = True): # Default: empty index-many if tp is None: - self.states = {} + self.states : Dict[str,List[Node]] = {} self._type = TapeType.INDEX_MANY # Exactly SINGLE @@ -405,11 +465,20 @@ def __init__(self, states: Any = None, with_prefix: bool = True): raise RuntimeError(f"Unhandled TapeType {tp!r}") def set_type(self, value: TapeType) -> StateTape: + """ + Change the type + """ self._type = value return self @staticmethod def _assess_type(states: Any) -> Optional[TapeType]: + """ + If there is only one state, SINGLE + If there is a list or tuple of many states, MANY + If there is a dictionary sending each Optional[str] key to a single state, INDEX_SINGLE + If there is a dictionary sending each Optional[str] key to possibly many states, INDEX_MANY + """ # Scalar → SINGLE if isinstance(states, (str, Node)): return TapeType.SINGLE @@ -430,6 +499,7 @@ def _assess_type(states: Any) -> Optional[TapeType]: return TapeType.INDEX_SINGLE # all values are either scalar or sequence of scalars → INDEX_MANY + # but at least one of the values was actually a sequence if all( isinstance(k, (str, type(None))) and ( @@ -446,21 +516,36 @@ def _assess_type(states: Any) -> Optional[TapeType]: return None def _add_prefix(self, key: str, with_prefix: bool = True) -> str: + """ + TODO: doc prefix + """ return f"{self.prefix}:{key}" if with_prefix else key def _remove_str_prefix(self, key: str) -> str: + """ + TODO: doc prefix + """ p = f"{self.prefix}:" return key[len(p):] if key.startswith(p) else key - + # The type annotation here is incorrect, key=None gives None # but making this Optional[str] in return # is not the intended behavior and pollutes that possibility # for the caller even when they are not passing key=None def _remove_prefix(self, key: Optional[str]) -> str: + """ + TODO: doc prefix + """ p = f"{self.prefix}:" return key[len(p):] if isinstance(key, str) and key.startswith(p) else key # type: ignore def extend(self, states: Any): + """ + Merge in these new states with the current self.states + Creating a temporary StateTape handles the different ways + the new states can be presented as in __init__ rather than always being + a dictionary from strings to List[Node] + """ # Delegate to a local StateTape then merge local_tape = StateTape(states, with_prefix=False) for k, nodes in local_tape.states.items(): @@ -468,10 +553,21 @@ def extend(self, states: Any): self.states[k].extend(nodes) def refresh(self): - # Delegate to a fresh StateTape - return StateTape({key: [] for key in self.states.keys()}, with_prefix=False).set_type(self._type) + """ + Delegate to a fresh StateTape + This has the same keys, but no more Nodes in the accompanying values + """ + return StateTape( + {key: [] for key in self.states.keys()}, + with_prefix=False + ).set_type(self._type) def revert(self) -> Union[list[Node], dict[str, list[Node]], None]: + """ + The states as it was provided as the input to __init__ + StateTape(revert(StateTape(states,remove_prefix=b)),remove_prefix=b) + goes back and forth + """ # SINGLE or MANY → flatten to a single list if self._type in (TapeType.SINGLE, TapeType.MANY): out_list: list[Node] = [] @@ -491,9 +587,11 @@ def revert(self) -> Union[list[Node], dict[str, list[Node]], None]: return None def _nodeify(self, x: Union[str, Node]) -> Node: - # wrap raw strings into Node + """ + wrap raw strings into Node + """ return x if isinstance(x, Node) else Node(x) - + def collect_activations( self, receiver_index: dict[str, Receiver], @@ -502,7 +600,8 @@ def collect_activations( """ For each receiver, and each (key, state) in self.states, if the parsed route matches that state (or has no source), - record a TapeActivation(priority, key, state, route, fn). + record a priority: TapeActivation(key, state, route, fn) + key-value pair """ activation_index: dict[tuple[int, ...], list[TapeActivation]] = defaultdict(list) @@ -531,10 +630,10 @@ def collect_activations( # ======= LIFE CYCLES ======= -from enum import Enum, auto - class ClientIntent(Enum): + """ + Life Cycles + """ QUIT = auto() # brutal, immediate exit TRAVEL = auto() # switch to a new host/port ABORT = auto() # abort due to error - diff --git a/summoner/protocol/triggers.py b/summoner/protocol/triggers.py index b8f643c..56c688c 100644 --- a/summoner/protocol/triggers.py +++ b/summoner/protocol/triggers.py @@ -117,7 +117,9 @@ def parse_signal_tree_lines(lines: list[str], tabsize: int = 8) -> dict[str, Any parent = nodes_by_depth[depth] if name in parent: - raise ValueError(f"Line {lineno}: duplicate signal name {name!r} at indent level {indent}") + raise ValueError( + f"Line {lineno}: duplicate signal name {name!r} at indent level {indent}" + ) parent[name] = {} # 3) Trim nodes_by_depth so we don't keep stale deeper dicts @@ -141,6 +143,11 @@ def parse_signal_tree(filepath: Path | str, tabsize: int = 8) -> dict[str, Any]: class Signal: + """ + Keep track of position in a tree + via the path followed. + Allowing comparison by the ancestor relationship. + """ __slots__ = ("_path", "_name") def __init__(self, path: tuple[int, ...], name: str): self._path = path @@ -148,10 +155,16 @@ def __init__(self, path: tuple[int, ...], name: str): @property def path(self) -> tuple[int, ...]: + """ + The path in the tree from the root + """ return self._path @property def name(self) -> str: + """ + The name of this Signal + """ return self._name def __gt__(self, other): @@ -159,7 +172,7 @@ def __gt__(self, other): return NotImplemented a, b = self._path, other._path return len(a) < len(b) and b[:len(a)] == a - + def __lt__(self, other): return other > self @@ -167,7 +180,7 @@ def __ge__(self, other): if not isinstance(other, Signal): return NotImplemented return self._path == other._path or self > other - + def __le__(self, other): return other >= self @@ -184,6 +197,9 @@ def __repr__(self): def build_triggers(tree: dict[str, Any]): + """ + TODO: doc Trigger + """ name_to_path: dict[str, tuple[int, ...]] = {} path_to_name: dict[tuple[int, ...], str] = {} @@ -215,13 +231,17 @@ def recurse(subtree: dict[str, Any], prefix: tuple[int, ...] =()): def name_of(*args): """Get signal name from tuple (or *args).""" return path_to_name.get(tuple(args)) - + attrs["name_of"] = staticmethod(name_of) return type("Trigger", (), attrs) +#pylint:disable=too-few-public-methods class Event: + """ + TODO: doc event + """ __slots__ = ("signal",) def __init__(self, signal: Signal) -> None: self.signal = signal @@ -229,36 +249,54 @@ def __repr__(self) -> str: return f"{type(self).__name__}({self.signal!r})" -class Move(Event): pass -class Stay(Event): pass +class Move(Event): + """ + TODO: doc move event + activated_nodes + """ +class Stay(Event): + """ + TODO: doc stay event + activated_nodes + """ class Test(Event): + """ + TODO: test event + activated_nodes + """ __test__ = False - pass + class Action: + """ + TODO: doc in activation_nodes + """ MOVE = Move STAY = Stay TEST = Test def extract_signal(trigger): + """ + Just the signal part. + Handling if it was wrapped up in an event or not + """ if isinstance(trigger, Event): return trigger.signal - elif isinstance(trigger, Signal): + if isinstance(trigger, Signal): return trigger - elif trigger is None: + if trigger is None: return None - else: - raise TypeError(f"Cannot extract signal from object of type {type(trigger).__name__!r}") - + raise TypeError(f"Cannot extract signal from object of type {type(trigger).__name__!r}") + # the file TRIGGERS needs to be next to the code for the client, hence sys.argv[0] WORKING_DIR = Path(sys.argv[0]).resolve().parent def load_triggers( triggers_file: Optional[str] = "TRIGGERS", - text: Optional[str] = None, + text: Optional[str] = None, json_dict: Optional[dict[str, Any]] = None ): """ @@ -291,10 +329,13 @@ def load_triggers( else: # In this case triggers_file was explicitly provided as None path = WORKING_DIR / "TRIGGERS" - raise FileNotFoundError("no triggers file and weren't using the default either") + raise FileNotFoundError( + "no triggers file and weren't using the default either" + ) # This below can raise ValueError or FileNotFoundError tree =parse_signal_tree(path) except FileNotFoundError as e: + #pylint:disable=line-too-long raise FileNotFoundError( f"Could not find triggers file at {path if 'path' in locals() else ''}" ) from e diff --git a/summoner/protocol/validation.py b/summoner/protocol/validation.py index f2659dc..078cb61 100644 --- a/summoner/protocol/validation.py +++ b/summoner/protocol/validation.py @@ -1,27 +1,33 @@ +""" +TODO: doc validation +""" import os import sys from typing import ( - Union, - get_type_hints, - get_origin, + Union, + get_type_hints, + get_origin, get_args, ) import inspect target_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) if target_path not in sys.path: sys.path.insert(0, target_path) +#pylint:disable=wrong-import-position from summoner.logger import Logger def hook_priority_order(priority: tuple) -> tuple: - + """ + TODO: doc hook_priority + """ + if priority != () and isinstance(priority[0], str): # SDK priority: sort first, e.g., ('aurora', 0) → (0, 'aurora', 0) return (0, priority) - else: - # User priority: sort after SDK, e.g., (0, 1) → (1, 0, 1) - return (1, priority) + # User priority: sort after SDK, e.g., (0, 1) → (1, 0, 1) + return (1, priority) def _normalize_annotation(raw): """ @@ -58,16 +64,19 @@ def _check_param_and_return(fn, # parameter check params = list(sig.parameters.values()) - expected_params = 1 if decorator_name in ("@hook", "@receive", "@upload_states", "@download_states") else 0 - + expected_params = 1 if decorator_name in \ + ("@hook", "@receive", "@upload_states", "@download_states") else 0 + if len(params) != expected_params: raise TypeError(f"{decorator_name} '{fn.__name__}' must have " f"{expected_params} parameter(s), not {len(params)}") if expected_params == 1: raw_param = params[0].annotation - param_hint = _normalize_annotation(raw_param) or hints.get(params[0].name, None) + param_hint = _normalize_annotation(raw_param) or \ + hints.get(params[0].name, None) if param_hint is None: + #pylint:disable=line-too-long logger.warning( f"{decorator_name} '{fn.__name__}' missing parameter annotation; skipping type check" ) @@ -87,4 +96,4 @@ def _check_param_and_return(fn, elif not _valid_type_hint(ret_hint, allow_return): raise TypeError( f"{decorator_name} '{fn.__name__}' must return one of {allow_return}, not {ret_hint!r}" - ) \ No newline at end of file + ) From 5d09d191a9e0573a72b8b64528b502f64f8e8241 Mon Sep 17 00:00:00 2001 From: Cobord Date: Thu, 19 Feb 2026 23:56:04 -0500 Subject: [PATCH 07/11] misc, hook annotations wip --- summoner/client/__init__.py | 2 +- summoner/client/client.py | 20 +- summoner/client/just_merger.py | 850 +++++++++++++++++++++ summoner/client/merger.py | 1262 +------------------------------- summoner/client/translation.py | 479 ++++++++++++ summoner/logger.py | 6 +- summoner/protocol/payload.py | 25 +- summoner/server/__init__.py | 5 +- summoner/server/server.py | 80 +- 9 files changed, 1429 insertions(+), 1300 deletions(-) create mode 100644 summoner/client/just_merger.py create mode 100644 summoner/client/translation.py diff --git a/summoner/client/__init__.py b/summoner/client/__init__.py index 55d4f47..fc94b26 100644 --- a/summoner/client/__init__.py +++ b/summoner/client/__init__.py @@ -1,5 +1,5 @@ """ -TODO: doc client and ClientMerger summary +TODO: doc client, ClientMerger and ClientTranslation summary """ from .client import SummonerClient from .merger import ClientMerger, ClientTranslation diff --git a/summoner/client/client.py b/summoner/client/client.py index 171047a..bc97466 100644 --- a/summoner/client/client.py +++ b/summoner/client/client.py @@ -1,7 +1,7 @@ """ SummonerClient """ -#pylint:disable=line-too-long, too-many-lines, wrong-import-position +#pylint:disable=line-too-long, wrong-import-position #pylint:disable=logging-fstring-interpolation, broad-exception-caught import os @@ -66,7 +66,11 @@ RelayedMessage ) +#pylint:disable=invalid-name ANY_TO_AWAIT = Callable[[Any],Awaitable] +HOOK_TYPE = Callable[[Union[str, dict]], Union[str, dict]] +GEN_HOOK_TYPE = Callable[[Optional[Union[str, dict]]], Optional[Union[str, dict]]] +ASYNC_GEN_HOOK_TYPE = Callable[[Optional[Union[str, dict]]], Awaitable[Optional[Union[str, dict]]]] class ServerDisconnected(Exception): """Raised when the server closes the connection.""" @@ -161,8 +165,8 @@ def __init__(self, name: Optional[str] = None): self.writer_lock = asyncio.Lock() # Store validation hooks to be used before sending and after receiving - self.sending_hooks: dict[tuple[int,...], Callable[[Union[str, dict]], Union[str, dict]]] = {} - self.receiving_hooks: dict[tuple[int,...], Callable[[Union[str, dict]], Union[str, dict]]] = {} + self.sending_hooks: dict[tuple[int,...], HOOK_TYPE | GEN_HOOK_TYPE | ASYNC_GEN_HOOK_TYPE] = {} + self.receiving_hooks: dict[tuple[int,...], HOOK_TYPE | GEN_HOOK_TYPE | ASYNC_GEN_HOOK_TYPE] = {} self.hooks_lock = asyncio.Lock() # ─── DNA capture for merging ───────────────────────────────────────── @@ -374,7 +378,7 @@ def hook( """ TODO: doc hook """ - def decorator(fn: Callable[[Optional[Union[str, dict]]], Optional[Union[str, dict]]]): + def decorator(fn: GEN_HOOK_TYPE): """ TODO: doc decorator """ @@ -384,6 +388,7 @@ def decorator(fn: Callable[[Optional[Union[str, dict]]], Optional[Union[str, dic if not inspect.iscoroutinefunction(fn): raise TypeError(f"@hook handler '{fn.__name__}' must be async") + _check_param_and_return( fn, decorator_name="@hook", @@ -417,9 +422,9 @@ async def register(): """ async with self.hooks_lock: if direction == Direction.RECEIVE: - self.receiving_hooks[tuple_priority] = fn # type: ignore + self.receiving_hooks[tuple_priority] = fn elif direction == Direction.SEND: - self.sending_hooks[tuple_priority] = fn # type: ignore + self.sending_hooks[tuple_priority] = fn # ----[ Safe Registration ]---- # NOTE: register() is run ASAP and _registration_tasks is used to wait all registrations before run_client() @@ -1072,7 +1077,8 @@ async def _safe_call(fn: ANY_TO_AWAIT, payload: Any) -> Any: # if not data: # raise ServerDisconnected("Server closed the connection.") - payload: RelayedMessage = recover_with_types(data.decode()) + pre_payload: RelayedMessage = recover_with_types(data.decode()) + payload: str | dict | RelayedMessage = pre_payload # ----[ Build: Validation ]---- async with self.hooks_lock: diff --git a/summoner/client/just_merger.py b/summoner/client/just_merger.py new file mode 100644 index 0000000..690fc8b --- /dev/null +++ b/summoner/client/just_merger.py @@ -0,0 +1,850 @@ +""" +merger.py + +This module provides two related utilities built on top of SummonerClient: + +1) ClientMerger + Build a single composite SummonerClient by replaying handlers from multiple sources. + + A "source" can be: + - an imported SummonerClient instance (live Python object), or + - a DNA list (already loaded JSON list[dict]), or + - a DNA JSON file path. + + Imported-client sources: + - handlers keep their original module globals (module-backed execution), + - the original client binding (for example the name "agent") is rebound to the merged client, + - optional rebind_globals are injected into handler globals. + + DNA sources: + - handlers are reconstructed by compiling their recorded source text into an isolated + sandbox module (one sandbox per DNA source), + - the sandbox binds var_name (for example "agent") to the merged client instance, + so handler code that references `agent` executes against the composite client, + - optional context (imports, globals, recipes) is applied into the sandbox. + + Usage pattern: + - instantiate ClientMerger(...) + - configure flow / styles as usual on the merged client if desired + - call agent.initiate_all() to replay handlers onto the merged client + - call agent.run(...) + +2) ClientTranslation + Reconstruct a fresh SummonerClient from a DNA list. + + Translation compiles handler functions from their recorded source into a fresh sandbox module, + binds var_name (for example "agent") to the translated client, then registers the handlers + using the normal decorators. + +Security and trust model +------------------------ +Both classes execute code from DNA via exec() and eval(): + +- context imports (ctx["imports"]) +- recipes (ctx["recipes"]) +- handler bodies (entry["source"]) + +This is intended for trusted DNA (typically produced by your own agents). +Do not run untrusted DNA. +""" +#pylint:disable=line-too-long, wrong-import-position +#pylint:disable=invalid-name, broad-exception-caught,logging-fstring-interpolation + +from typing import Optional, Any +from contextlib import suppress +from pathlib import Path +import inspect +import asyncio +import types +import re +import json +import uuid + +import os +import sys +target_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) +if target_path not in sys.path: + sys.path.insert(0, target_path) + +from summoner.client.client import SummonerClient +from summoner.protocol.triggers import Action, load_triggers +from summoner.protocol.process import Direction + + +def _resolve_trigger(TriggerCls, name: str): + """ + Resolve a trigger name into a trigger instance from TriggerCls. + + DNA stores triggers as strings. This helper supports the two common access patterns: + - Enum-style indexing: TriggerCls["ok"] + - Attribute access: TriggerCls.ok + + Parameters + ---------- + TriggerCls: + Trigger class or enum-like object returned by flow.triggers() or load_triggers(). + name: + Trigger name as stored in DNA. + + Returns + ------- + Any + The resolved trigger value. + + Raises + ------ + KeyError + If the trigger cannot be resolved. + """ + # Enum-style: TriggerCls["ok"] + try: + return TriggerCls[name] + except Exception: + pass + # Attribute-style: TriggerCls.ok + try: + return getattr(TriggerCls, name) + except Exception: + pass + raise KeyError(f"Unknown trigger '{name}' for {TriggerCls}") + + +def _resolve_action(ActionCls, name: str): + """ + Resolve an action name into the corresponding Action entry. + + DNA stores actions as strings. Depending on how a sender was serialized, the name can be: + - the enum attribute name ("MOVE") + - a mixed-case name ("Move") + - the underlying class name (Move.__name__ == "Move") + + Parameters + ---------- + ActionCls: + The Action container used by the protocol layer (typically summoner.protocol.triggers.Action). + name: + Action name as stored in DNA. + + Returns + ------- + Any + The resolved Action entry. + + Raises + ------ + KeyError + If the action cannot be resolved. + """ + # 1) Try direct attribute match: "MOVE" + if hasattr(ActionCls, name): + return getattr(ActionCls, name) + + # 2) Try uppercased: "Move" -> "MOVE" + up = name.upper() + if hasattr(ActionCls, up): + return getattr(ActionCls, up) + + # 3) Try matching the underlying class/function name: + # Action.MOVE == Move, where Move.__name__ == "Move" + for v in ActionCls.__dict__.values(): + if getattr(v, "__name__", None) == name: + return v + + raise KeyError(f"Unknown action '{name}' for {ActionCls}") + + +class ClientMerger(SummonerClient): + """ + Merge multiple sources into one client. + + Each input source can be: + - an imported SummonerClient instance (module-backed execution), + - a DNA list (list[dict]), + - a DNA JSON file path. + + Imported-client sources + ----------------------- + - Handlers keep their original module globals. + - The original client binding (var_name such as "agent") is rebound to the merged client. + - Optional rebind_globals are injected into handler globals. + - Note: rebinding mutates handler globals. This is intentional. + + DNA sources + ----------- + - Each DNA source gets its own sandbox module (isolated globals dict). + - var_name is bound to the merged client in that sandbox, so handler code referencing + `agent` executes against the composite client. + - Optional context imports/globals/recipes are executed in the sandbox. + + Execution model + --------------- + ClientMerger does not automatically register handlers during __init__. + You must call initiate_all() (or initiate_* individually) before run(). + + Safety + ------ + This class executes trusted code from DNA via exec()/eval(). Do not run untrusted DNA. + """ + + # pylint:disable=too-many-arguments, too-many-positional-arguments + def __init__( + self, + named_clients: list[Any], # backward compatible: list[dict] or list[SummonerClient] or list[dna_list] + name: Optional[str] = None, + rebind_globals: Optional[dict[str, Any]] = None, + allow_context_imports: bool = True, + verbose_context_imports: bool = False, + close_subclients: bool = True, + ): + super().__init__(name=name) + + # Globals injected into: + # - sandbox globals dicts (DNA sources), and + # - imported handler globals (imported-client sources). + # This is how "missing" symbols like Trigger or shared objects are supplied. + self._rebind_globals = dict(rebind_globals or {}) + + # Context controls: + # - allow_context_imports: execute import lines found in DNA context + # - verbose_context_imports: log successes as well as failures + self._allow_context_imports = allow_context_imports + self._verbose_context_imports = verbose_context_imports + + # If True, imported template clients are cleaned up after extraction: + # cancel/drain registration tasks and close their event loops when possible. + # This reduces warnings when importing agent scripts as templates. + self._close_subclients = close_subclients + + # Normalized sources used later by initiate_* replay methods. + self.sources: list[dict[str, Any]] = [] + self._import_reports: list[dict[str, Any]] = [] + + for idx, entry in enumerate(named_clients): + src = self._normalize_source(entry, idx) + self.sources.append(src) + + if self._close_subclients: + self._shutdown_imported_clients() + + # ---------------------------- + # Source normalization + # ---------------------------- + + #pylint:disable=too-many-branches + def _normalize_source(self, entry: Any, idx: int) -> dict[str, Any]: + """ + Normalize a user-provided source specification into a canonical dict. + + Accepted inputs + --------------- + - SummonerClient instance + - DNA list (list[dict]) + - dict containing one of: {"client"}, {"dna_list"}, {"dna_path"} + + Normalized output + ----------------- + Returns a dict with at least: + - kind: "client" or "dna" + - var_name: global name used by handler sources to refer to the client ("agent" by default) + + For kind="client": + - client: the imported SummonerClient instance + + For kind="dna": + - dna_entries: handler entries (context removed if present) + - context: optional __context__ entry + - sandbox_name: unique module name + - globals: sandbox globals dict where code is compiled and executed + - import_report: best-effort report for context imports + """ + # Allow passing a SummonerClient directly + if isinstance(entry, SummonerClient): + entry = {"client": entry} + + # Allow passing a dna_list directly + if isinstance(entry, list): + entry = {"dna_list": entry} + + if not isinstance(entry, dict): + raise TypeError(f"Entry #{idx} must be dict | SummonerClient | dna_list, got {type(entry).__name__}") + + if "client" in entry: + client = entry["client"] + if not isinstance(client, SummonerClient): + raise TypeError(f"Entry #{idx} 'client' must be SummonerClient, got {type(client).__name__}") + + # var_name controls which global name will be rebound to the merged client. + var_name = entry.get("var_name") + if var_name is None: + var_name = self._infer_client_var_name(client) or "agent" + if not isinstance(var_name, str): + raise TypeError(f"Entry #{idx} 'var_name' must be str, got {type(var_name).__name__}") + + return { + "kind": "client", + "client": client, + "var_name": var_name, + } + + # DNA sources + dna_list = None + if "dna_list" in entry: + dna_list = entry["dna_list"] + elif "dna_path" in entry: + p = Path(entry["dna_path"]) + dna_list = json.loads(p.read_text(encoding="utf-8")) + else: + raise KeyError(f"Entry #{idx} must contain 'client' or 'dna_list' or 'dna_path'") + + if not isinstance(dna_list, list): + raise TypeError(f"Entry #{idx} DNA must be a list, got {type(dna_list).__name__}") + + # Optional context header is stored in DNA as the first entry. + ctx = None + entries = dna_list + if entries and isinstance(entries[0], dict) and entries[0].get("type") == "__context__": + ctx = entries[0] + entries = entries[1:] + + # Determine var_name binding: + # - explicit var_name in entry wins + # - else use ctx["var_name"] + # - else default to "agent" + var_name = entry.get("var_name") + if var_name is None: + ctx_var = ctx.get("var_name") if isinstance(ctx, dict) else None + var_name = ctx_var if isinstance(ctx_var, str) and ctx_var else "agent" + if not isinstance(var_name, str): + raise TypeError(f"Entry #{idx} 'var_name' must be str, got {type(var_name).__name__}") + + # Each DNA source gets its own sandbox module (isolated globals dict). + sandbox_module_name = f"summoner_merge_{uuid.uuid4().hex}" + sandbox_module = types.ModuleType(sandbox_module_name) + sys.modules[sandbox_module_name] = sandbox_module + g = sandbox_module.__dict__ + + # Bind the client name used inside handler source to the merged client. + # This makes `await agent.travel_to(...)` act on the composite client. + g[var_name] = self + + # Apply context (imports/globals/recipes) into that sandbox. + report = self._apply_context(ctx, g, label=f"source#{idx}") + self._import_reports.append(report) + + return { + "kind": "dna", + "dna_entries": entries, + "context": ctx, + "var_name": var_name, + "sandbox_name": sandbox_module_name, + "globals": g, + "import_report": report, + } + + def _infer_client_var_name(self, client: SummonerClient) -> Optional[str]: + """ + Infer the module-global variable name used by handlers to refer to `client`. + + This is used for imported-client sources so that we can rebind that name + (for example "agent") to the merged client in handler globals. + + Returns + ------- + Optional[str] + The inferred binding name, or None if not found. + """ + # Look for a module-global name whose value is exactly `client` + # TODO: can use _view_candidates instead of building a list + candidates = [] + #pylint:disable=protected-access + if client._upload_states is not None: + candidates.append(client._upload_states) + if client._download_states is not None: + candidates.append(client._download_states) + for d in client._dna_receivers: + candidates.append(d.get("fn")) + for d in client._dna_senders: + candidates.append(d.get("fn")) + for d in client._dna_hooks: + candidates.append(d.get("fn")) + + for fn in candidates: + if fn is None: + continue + g = getattr(fn, "__globals__", None) + if not isinstance(g, dict): + continue + for k, v in g.items(): + if v is client and isinstance(k, str): + return k + return None + + def _shutdown_imported_clients(self) -> None: + """ + Best-effort cleanup for imported template clients. + + Why this exists + -------------- + Many agent scripts create a SummonerClient at import-time. That client creates + an event loop and schedules registration tasks. + + When agent scripts are imported only as templates for merging, those template + clients should not be left alive, otherwise you often see: + - "coroutine was never awaited" + - "Task was destroyed but it is pending" + + Cleanup approach + ---------------- + For each imported client: + 1) cancel pending registration tasks + 2) if its loop is not running, drive the loop to await cancellations + 3) close the loop + 4) clear the template's registration list + """ + for src in self.sources: + if src.get("kind") != "client": + continue + + client: SummonerClient = src["client"] + var_name: str = src["var_name"] + + #pylint:disable=protected-access + tasks = list(client._registration_tasks or []) + loop = client.loop + + # Nothing to do. + if not tasks and loop.is_closed(): + continue + + # 1) cancel tasks + for t in tasks: + with suppress(Exception): + t.cancel() + + # 2) drain tasks on that loop so they are actually awaited + if loop is not None and not loop.is_closed(): + if loop.is_running(): + # Can't safely run_until_complete; also shouldn't close a running loop. + self.logger.warning( + f"[{var_name}] Imported client loop is running; " + f"cannot drain/close registration tasks cleanly." + ) + else: + old_loop = None + try: + # Set context so asyncio.gather/futures bind to the right loop. + with suppress(Exception): + old_loop = asyncio.get_event_loop_policy().get_event_loop() + asyncio.set_event_loop(loop) + + # Await cancellation. This is what prevents warnings. + loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) + except Exception as e: + self.logger.warning(f"[{var_name}] Error draining registration tasks: {e}") + finally: + # Restore previous loop context (or clear it). + with suppress(Exception): + asyncio.set_event_loop(old_loop) + + # 3) close loop after drain + try: + loop.close() + except Exception as e: + self.logger.warning(f"[{var_name}] Error closing event loop: {e}") + + # 4) clear list + with suppress(Exception): + #pylint:disable=protected-access + client._registration_tasks.clear() + + # ---------------------------- + # Context application (DNA) + # ---------------------------- + + # pylint:disable=too-many-branches + def _apply_context(self, ctx: Optional[dict], g: dict, *, label: str) -> dict[str, Any]: + """ + Apply a DNA context entry (imports, globals, recipes) into a sandbox globals dict. + + This is best-effort: + - import failures are recorded and logged + - recipe failures are logged but do not abort merge + + Parameters + ---------- + ctx: + The optional "__context__" DNA header entry. + g: + The sandbox globals dict (module.__dict__) where code will be executed. + label: + Used for logging and for the returned report. + + Returns + ------- + dict[str, Any] + A structured report with keys: label, succeeded, failed, skipped. + + Security note + ------------- + This executes code from ctx (imports and recipes). Use only with trusted DNA. + """ + report = {"label": label, "succeeded": [], "failed": [], "skipped": []} + + if not isinstance(ctx, dict): + return report + + # imports (executed inside sandbox namespace) + for line in (ctx.get("imports") or []): + if not isinstance(line, str) or not line.strip(): + continue + + if not self._allow_context_imports: + report["skipped"].append(line) + continue + + try: + # pylint:disable=exec-used + exec(line, g) + report["succeeded"].append(line) + if self._verbose_context_imports: + self.logger.info(f"[merge ctx:{label}] import ok: {line}") + except Exception as e: + report["failed"].append((line, f"{type(e).__name__}: {e}")) + self.logger.warning(f"[merge ctx:{label}] import failed: {line!r} ({type(e).__name__}: {e})") + + # plain globals (already JSON-friendly) + globs = ctx.get("globals") or {} + if isinstance(globs, dict): + for k, v in globs.items(): + if isinstance(k, str): + g.setdefault(k, v) + + # recipes (evaluated inside the sandbox namespace) + recipes = ctx.get("recipes") or {} + if isinstance(recipes, dict): + for k, expr in recipes.items(): + if not (isinstance(k, str) and isinstance(expr, str)): + continue + try: + # pylint:disable=eval-used + g.setdefault(k, eval(expr, g, {})) + except Exception as e: + self.logger.warning(f"[merge ctx:{label}] recipe failed {k}={expr!r} ({type(e).__name__}: {e})") + + return report + + # ---------------------------- + # Utility: source patch for getsource capture + # ---------------------------- + + def _apply_with_source_patch(self, decorator, fn, source: str): + """ + Apply a SummonerClient decorator while preserving DNA source text. + + Why patch inspect.getsource + --------------------------- + The base decorators record handler source using inspect.getsource(fn). + When functions are constructed from DNA via exec(), inspect.getsource(fn) + will typically fail because there is no real source file. + + This helper temporarily patches inspect.getsource so the decorator stores + the original DNA text. This patch is process-global, so it should be used + only in controlled, single-threaded contexts (typical CLI usage). + """ + orig = inspect.getsource + inspect.getsource = lambda o: source + try: + decorator(fn) + finally: + inspect.getsource = orig + + # ---------------------------- + # Imported-client handler cloning + # ---------------------------- + + def _clone_handler(self, fn: types.FunctionType, original_name: str) -> types.FunctionType: + """ + Clone a function object for imported-client sources while preserving module globals. + + Behavior + -------- + - Mutates fn.__globals__ in-place: + - rebind original_name (for example "agent") to the merged client + - inject rebind_globals into the same globals dict + - Creates a new function object using the original code object and closure. + + This preserves the module-backed environment of imported clients. The cost is that + imported handler globals are modified, which is intentional for merge semantics. + """ + g = fn.__globals__ + + # rebind the client variable name (agent/client/etc) + try: + g[original_name] = self + except Exception as e: + self.logger.warning(f"Could not bind '{original_name}' to merged client: {e}") + + # rebind any shared globals (viz, Trigger, etc.) + for k, v in self._rebind_globals.items(): + try: + g[k] = v + except Exception as e: + self.logger.warning(f"Could not bind global '{k}' in '{fn.__name__}': {e}") + + new_fn = types.FunctionType( + fn.__code__, + g, + name=fn.__name__, + argdefs=fn.__defaults__, + closure=fn.__closure__, + ) + new_fn.__annotations__ = getattr(fn, "__annotations__", {}) + new_fn.__doc__ = getattr(fn, "__doc__", None) + + # if your dna() uses a __dna_source__ fallback, keep it + if hasattr(fn, "__dna_source__"): + new_fn.__dna_source__ = fn.__dna_source__ # type: ignore + + return new_fn + + + # ---------------------------- + # DNA compilation (per-source sandbox) + # ---------------------------- + + def _make_from_source(self, entry: dict[str, Any], g: dict, sandbox_name: str) -> types.FunctionType: + """ + Build a function object from a DNA entry by executing its function body in a sandbox. + + Key rule: strip decorators + -------------------------- + DNA sources typically include decorator lines (for example "@agent.receive(...)"). + We skip all decorators and only exec the "def ..." block. Otherwise, compiling the + function would also register it immediately, which would duplicate handlers. + + Globals + ------- + - rebind_globals is injected into g before exec so required runtime symbols + (for example Trigger, viz, shared objects) are visible to the compiled function. + - var_name is already bound to the merged client in g during source normalization. + """ + fn_name = entry["fn_name"] + raw = entry["source"] + lines = raw.splitlines() + + # --------------------------------------------------------------------- + # 1) Find the def line for this function, skipping decorators. + # --------------------------------------------------------------------- + def_idx = None + for idx, line in enumerate(lines): + pat = rf"\s*(async\s+)?def\s+{re.escape(fn_name)}\b" + if re.match(pat, line): + def_idx = idx + break + if def_idx is None: + raise RuntimeError(f"Could not find def for '{fn_name}'") + + func_body = "\n".join(lines[def_idx:]) + + # --------------------------------------------------------------------- + # 2) Ensure rebinding happens in the same globals dict used by exec(). + # --------------------------------------------------------------------- + rebind = self._rebind_globals + if isinstance(rebind, dict) and rebind: + g.update(rebind) + + # --------------------------------------------------------------------- + # 3) Execute and retrieve the resulting function object. + # --------------------------------------------------------------------- + if "__builtins__" not in g: + g["__builtins__"] = __builtins__ + + # pylint:disable=exec-used + exec(compile(func_body, filename=f"<{sandbox_name}>", mode="exec"), g) + + fn = g.get(fn_name) + if not isinstance(fn, types.FunctionType): + raise RuntimeError(f"Failed to construct function '{fn_name}'") + + return fn + + + # ---------------------------- + # Public replay API + # ---------------------------- + + def initiate_upload_states(self): + """Replay @upload_states from every source onto the merged client.""" + for src in self.sources: + if src["kind"] == "client": + client: SummonerClient = src["client"] + var_name: str = src["var_name"] + #pylint:disable=protected-access + fn = client._upload_states + if fn is None: + continue + fn_clone = self._clone_handler(fn, var_name) + try: + self.upload_states()(fn_clone) + except Exception as e: + self.logger.warning(f"[{var_name}] Failed to replay upload_states '{fn.__name__}': {e}") + + else: + g = src["globals"] + sandbox = src["sandbox_name"] + for entry in src["dna_entries"]: + if entry.get("type") != "upload_states": + continue + fn = self._make_from_source(entry, g, sandbox) + dec = self.upload_states() + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_download_states(self): + """Replay @download_states from every source onto the merged client.""" + for src in self.sources: + if src["kind"] == "client": + client: SummonerClient = src["client"] + var_name: str = src["var_name"] + #pylint:disable=protected-access + fn = client._download_states + if fn is None: + continue + fn_clone = self._clone_handler(fn, var_name) + try: + self.download_states()(fn_clone) + except Exception as e: + self.logger.warning(f"[{var_name}] Failed to replay download_states '{fn.__name__}': {e}") + + else: + g = src["globals"] + sandbox = src["sandbox_name"] + for entry in src["dna_entries"]: + if entry.get("type") != "download_states": + continue + fn = self._make_from_source(entry, g, sandbox) + dec = self.download_states() + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_hooks(self): + """Replay @hook(Direction, priority=...) from every source onto the merged client.""" + for src in self.sources: + if src["kind"] == "client": + client: SummonerClient = src["client"] + var_name: str = src["var_name"] + #pylint:disable=protected-access + for dna in client._dna_hooks: + fn_clone = self._clone_handler(dna["fn"], var_name) + try: + self.hook(dna["direction"], priority=dna["priority"])(fn_clone) + except Exception as e: + self.logger.warning(f"[{var_name}] Failed to replay hook '{dna['fn'].__name__}': {e}") + + else: + g = src["globals"] + sandbox = src["sandbox_name"] + for entry in src["dna_entries"]: + if entry.get("type") != "hook": + continue + fn = self._make_from_source(entry, g, sandbox) + direction = Direction[entry["direction"]] if isinstance(entry.get("direction"), str) else entry["direction"] + dec = self.hook(direction, priority=tuple(entry.get("priority", ()))) + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_receivers(self): + """Replay @receive(route, priority=...) from every source onto the merged client.""" + for src in self.sources: + if src["kind"] == "client": + client: SummonerClient = src["client"] + var_name: str = src["var_name"] + #pylint:disable=protected-access + for dna in client._dna_receivers: + fn_clone = self._clone_handler(dna["fn"], var_name) + try: + self.receive(dna["route"], priority=dna["priority"])(fn_clone) + except Exception as e: + self.logger.warning( + f"[{var_name}] Failed to replay receiver '{dna['fn'].__name__}' on route '{dna['route']}': {e}" + ) + + else: + g = src["globals"] + sandbox = src["sandbox_name"] + for entry in src["dna_entries"]: + if entry.get("type") != "receive": + continue + fn = self._make_from_source(entry, g, sandbox) + dec = self.receive(entry["route"], priority=tuple(entry.get("priority", ()))) + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_senders(self): + """ + Replay @send(route, multi, on_triggers, on_actions) from every source onto the merged client. + + Imported-client sources: + - carry actual trigger/action objects in _dna_senders. + + DNA sources: + - store trigger/action names as strings. + - triggers are resolved using TriggerCls: + - prefer Trigger class provided by sandbox context ("Trigger" in sandbox globals) + - else fall back to load_triggers() + - actions are resolved from Action by name via _resolve_action. + """ + for src in self.sources: + if src["kind"] == "client": + client: SummonerClient = src["client"] + var_name: str = src["var_name"] + #pylint:disable=protected-access + for dna in client._dna_senders: + fn_clone = self._clone_handler(dna["fn"], var_name) + try: + self.send( + dna["route"], + multi=dna["multi"], + on_triggers=dna["on_triggers"], + on_actions=dna["on_actions"], + )(fn_clone) + except Exception as e: + self.logger.warning( + f"[{var_name}] Failed to replay sender '{dna['fn'].__name__}' on route '{dna['route']}': {e}" + ) + + else: + g = src["globals"] + sandbox = src["sandbox_name"] + + # Triggers: prefer a Trigger class provided by sandbox context; otherwise load defaults. + TriggerCls = g.get("Trigger") + if TriggerCls is None: + TriggerCls = load_triggers() + + for entry in src["dna_entries"]: + if entry.get("type") != "send": + continue + fn = self._make_from_source(entry, g, sandbox) + on_triggers = {_resolve_trigger(TriggerCls, t) for t in entry.get("on_triggers", [])} or None + on_actions = {_resolve_action(Action, a) for a in entry.get("on_actions", [])} or None + dec = self.send( + entry["route"], + multi=entry.get("multi", False), + on_triggers=on_triggers, # type: ignore + on_actions=on_actions, + ) + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_all(self): + """ + Replay all supported handler types in a standard order. + + This should be called before run(). The order matches SummonerClient.dna(): + 1) upload_states + 2) download_states + 3) hooks + 4) receivers + 5) senders + """ + self.initiate_upload_states() + self.initiate_download_states() + self.initiate_hooks() + self.initiate_receivers() + self.initiate_senders() diff --git a/summoner/client/merger.py b/summoner/client/merger.py index c1462e5..0c6afca 100644 --- a/summoner/client/merger.py +++ b/summoner/client/merger.py @@ -1,1259 +1,7 @@ """ -merger.py - -This module provides two related utilities built on top of SummonerClient: - -1) ClientMerger - Build a single composite SummonerClient by replaying handlers from multiple sources. - - A "source" can be: - - an imported SummonerClient instance (live Python object), or - - a DNA list (already loaded JSON list[dict]), or - - a DNA JSON file path. - - Imported-client sources: - - handlers keep their original module globals (module-backed execution), - - the original client binding (for example the name "agent") is rebound to the merged client, - - optional rebind_globals are injected into handler globals. - - DNA sources: - - handlers are reconstructed by compiling their recorded source text into an isolated - sandbox module (one sandbox per DNA source), - - the sandbox binds var_name (for example "agent") to the merged client instance, - so handler code that references `agent` executes against the composite client, - - optional context (imports, globals, recipes) is applied into the sandbox. - - Usage pattern: - - instantiate ClientMerger(...) - - configure flow / styles as usual on the merged client if desired - - call agent.initiate_all() to replay handlers onto the merged client - - call agent.run(...) - -2) ClientTranslation - Reconstruct a fresh SummonerClient from a DNA list. - - Translation compiles handler functions from their recorded source into a fresh sandbox module, - binds var_name (for example "agent") to the translated client, then registers the handlers - using the normal decorators. - -Security and trust model ------------------------- -Both classes execute code from DNA via exec() and eval(): - -- context imports (ctx["imports"]) -- recipes (ctx["recipes"]) -- handler bodies (entry["source"]) - -This is intended for trusted DNA (typically produced by your own agents). -Do not run untrusted DNA. +merger.py had both ClientMerger and ClientTranslation. +Make sure that still holds so that any from summoner.client.merger ... still hold """ -#pylint:disable=line-too-long, too-many-lines, wrong-import-position -#pylint:disable=invalid-name, broad-exception-caught,logging-fstring-interpolation - -from importlib import import_module -from typing import Optional, Any -from contextlib import suppress -from pathlib import Path -import inspect -import asyncio -import types -import re -import json -import uuid - -import os -import sys -target_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) -if target_path not in sys.path: - sys.path.insert(0, target_path) - -from summoner.client.client import SummonerClient -from summoner.protocol.triggers import Action, load_triggers -from summoner.protocol.process import Direction - - -def _resolve_trigger(TriggerCls, name: str): - """ - Resolve a trigger name into a trigger instance from TriggerCls. - - DNA stores triggers as strings. This helper supports the two common access patterns: - - Enum-style indexing: TriggerCls["ok"] - - Attribute access: TriggerCls.ok - - Parameters - ---------- - TriggerCls: - Trigger class or enum-like object returned by flow.triggers() or load_triggers(). - name: - Trigger name as stored in DNA. - - Returns - ------- - Any - The resolved trigger value. - - Raises - ------ - KeyError - If the trigger cannot be resolved. - """ - # Enum-style: TriggerCls["ok"] - try: - return TriggerCls[name] - except Exception: - pass - # Attribute-style: TriggerCls.ok - try: - return getattr(TriggerCls, name) - except Exception: - pass - raise KeyError(f"Unknown trigger '{name}' for {TriggerCls}") - - -def _resolve_action(ActionCls, name: str): - """ - Resolve an action name into the corresponding Action entry. - - DNA stores actions as strings. Depending on how a sender was serialized, the name can be: - - the enum attribute name ("MOVE") - - a mixed-case name ("Move") - - the underlying class name (Move.__name__ == "Move") - - Parameters - ---------- - ActionCls: - The Action container used by the protocol layer (typically summoner.protocol.triggers.Action). - name: - Action name as stored in DNA. - - Returns - ------- - Any - The resolved Action entry. - - Raises - ------ - KeyError - If the action cannot be resolved. - """ - # 1) Try direct attribute match: "MOVE" - if hasattr(ActionCls, name): - return getattr(ActionCls, name) - - # 2) Try uppercased: "Move" -> "MOVE" - up = name.upper() - if hasattr(ActionCls, up): - return getattr(ActionCls, up) - - # 3) Try matching the underlying class/function name: - # Action.MOVE == Move, where Move.__name__ == "Move" - for v in ActionCls.__dict__.values(): - if getattr(v, "__name__", None) == name: - return v - - raise KeyError(f"Unknown action '{name}' for {ActionCls}") - - -class ClientMerger(SummonerClient): - """ - Merge multiple sources into one client. - - Each input source can be: - - an imported SummonerClient instance (module-backed execution), - - a DNA list (list[dict]), - - a DNA JSON file path. - - Imported-client sources - ----------------------- - - Handlers keep their original module globals. - - The original client binding (var_name such as "agent") is rebound to the merged client. - - Optional rebind_globals are injected into handler globals. - - Note: rebinding mutates handler globals. This is intentional. - - DNA sources - ----------- - - Each DNA source gets its own sandbox module (isolated globals dict). - - var_name is bound to the merged client in that sandbox, so handler code referencing - `agent` executes against the composite client. - - Optional context imports/globals/recipes are executed in the sandbox. - - Execution model - --------------- - ClientMerger does not automatically register handlers during __init__. - You must call initiate_all() (or initiate_* individually) before run(). - - Safety - ------ - This class executes trusted code from DNA via exec()/eval(). Do not run untrusted DNA. - """ - - # pylint:disable=too-many-arguments, too-many-positional-arguments - def __init__( - self, - named_clients: list[Any], # backward compatible: list[dict] or list[SummonerClient] or list[dna_list] - name: Optional[str] = None, - rebind_globals: Optional[dict[str, Any]] = None, - allow_context_imports: bool = True, - verbose_context_imports: bool = False, - close_subclients: bool = True, - ): - super().__init__(name=name) - - # Globals injected into: - # - sandbox globals dicts (DNA sources), and - # - imported handler globals (imported-client sources). - # This is how "missing" symbols like Trigger or shared objects are supplied. - self._rebind_globals = dict(rebind_globals or {}) - - # Context controls: - # - allow_context_imports: execute import lines found in DNA context - # - verbose_context_imports: log successes as well as failures - self._allow_context_imports = allow_context_imports - self._verbose_context_imports = verbose_context_imports - - # If True, imported template clients are cleaned up after extraction: - # cancel/drain registration tasks and close their event loops when possible. - # This reduces warnings when importing agent scripts as templates. - self._close_subclients = close_subclients - - # Normalized sources used later by initiate_* replay methods. - self.sources: list[dict[str, Any]] = [] - self._import_reports: list[dict[str, Any]] = [] - - for idx, entry in enumerate(named_clients): - src = self._normalize_source(entry, idx) - self.sources.append(src) - - if self._close_subclients: - self._shutdown_imported_clients() - - # ---------------------------- - # Source normalization - # ---------------------------- - - #pylint:disable=too-many-branches - def _normalize_source(self, entry: Any, idx: int) -> dict[str, Any]: - """ - Normalize a user-provided source specification into a canonical dict. - - Accepted inputs - --------------- - - SummonerClient instance - - DNA list (list[dict]) - - dict containing one of: {"client"}, {"dna_list"}, {"dna_path"} - - Normalized output - ----------------- - Returns a dict with at least: - - kind: "client" or "dna" - - var_name: global name used by handler sources to refer to the client ("agent" by default) - - For kind="client": - - client: the imported SummonerClient instance - - For kind="dna": - - dna_entries: handler entries (context removed if present) - - context: optional __context__ entry - - sandbox_name: unique module name - - globals: sandbox globals dict where code is compiled and executed - - import_report: best-effort report for context imports - """ - # Allow passing a SummonerClient directly - if isinstance(entry, SummonerClient): - entry = {"client": entry} - - # Allow passing a dna_list directly - if isinstance(entry, list): - entry = {"dna_list": entry} - - if not isinstance(entry, dict): - raise TypeError(f"Entry #{idx} must be dict | SummonerClient | dna_list, got {type(entry).__name__}") - - if "client" in entry: - client = entry["client"] - if not isinstance(client, SummonerClient): - raise TypeError(f"Entry #{idx} 'client' must be SummonerClient, got {type(client).__name__}") - - # var_name controls which global name will be rebound to the merged client. - var_name = entry.get("var_name") - if var_name is None: - var_name = self._infer_client_var_name(client) or "agent" - if not isinstance(var_name, str): - raise TypeError(f"Entry #{idx} 'var_name' must be str, got {type(var_name).__name__}") - - return { - "kind": "client", - "client": client, - "var_name": var_name, - } - - # DNA sources - dna_list = None - if "dna_list" in entry: - dna_list = entry["dna_list"] - elif "dna_path" in entry: - p = Path(entry["dna_path"]) - dna_list = json.loads(p.read_text(encoding="utf-8")) - else: - raise KeyError(f"Entry #{idx} must contain 'client' or 'dna_list' or 'dna_path'") - - if not isinstance(dna_list, list): - raise TypeError(f"Entry #{idx} DNA must be a list, got {type(dna_list).__name__}") - - # Optional context header is stored in DNA as the first entry. - ctx = None - entries = dna_list - if entries and isinstance(entries[0], dict) and entries[0].get("type") == "__context__": - ctx = entries[0] - entries = entries[1:] - - # Determine var_name binding: - # - explicit var_name in entry wins - # - else use ctx["var_name"] - # - else default to "agent" - var_name = entry.get("var_name") - if var_name is None: - ctx_var = ctx.get("var_name") if isinstance(ctx, dict) else None - var_name = ctx_var if isinstance(ctx_var, str) and ctx_var else "agent" - if not isinstance(var_name, str): - raise TypeError(f"Entry #{idx} 'var_name' must be str, got {type(var_name).__name__}") - - # Each DNA source gets its own sandbox module (isolated globals dict). - sandbox_module_name = f"summoner_merge_{uuid.uuid4().hex}" - sandbox_module = types.ModuleType(sandbox_module_name) - sys.modules[sandbox_module_name] = sandbox_module - g = sandbox_module.__dict__ - - # Bind the client name used inside handler source to the merged client. - # This makes `await agent.travel_to(...)` act on the composite client. - g[var_name] = self - - # Apply context (imports/globals/recipes) into that sandbox. - report = self._apply_context(ctx, g, label=f"source#{idx}") - self._import_reports.append(report) - - return { - "kind": "dna", - "dna_entries": entries, - "context": ctx, - "var_name": var_name, - "sandbox_name": sandbox_module_name, - "globals": g, - "import_report": report, - } - - def _infer_client_var_name(self, client: SummonerClient) -> Optional[str]: - """ - Infer the module-global variable name used by handlers to refer to `client`. - - This is used for imported-client sources so that we can rebind that name - (for example "agent") to the merged client in handler globals. - - Returns - ------- - Optional[str] - The inferred binding name, or None if not found. - """ - # Look for a module-global name whose value is exactly `client` - # TODO: can use _view_candidates instead of building a list - candidates = [] - #pylint:disable=protected-access - if client._upload_states is not None: - candidates.append(client._upload_states) - if client._download_states is not None: - candidates.append(client._download_states) - for d in client._dna_receivers: - candidates.append(d.get("fn")) - for d in client._dna_senders: - candidates.append(d.get("fn")) - for d in client._dna_hooks: - candidates.append(d.get("fn")) - - for fn in candidates: - if fn is None: - continue - g = getattr(fn, "__globals__", None) - if not isinstance(g, dict): - continue - for k, v in g.items(): - if v is client and isinstance(k, str): - return k - return None - - def _shutdown_imported_clients(self) -> None: - """ - Best-effort cleanup for imported template clients. - - Why this exists - -------------- - Many agent scripts create a SummonerClient at import-time. That client creates - an event loop and schedules registration tasks. - - When agent scripts are imported only as templates for merging, those template - clients should not be left alive, otherwise you often see: - - "coroutine was never awaited" - - "Task was destroyed but it is pending" - - Cleanup approach - ---------------- - For each imported client: - 1) cancel pending registration tasks - 2) if its loop is not running, drive the loop to await cancellations - 3) close the loop - 4) clear the template's registration list - """ - for src in self.sources: - if src.get("kind") != "client": - continue - - client: SummonerClient = src["client"] - var_name: str = src["var_name"] - - #pylint:disable=protected-access - tasks = list(client._registration_tasks or []) - loop = client.loop - - # Nothing to do. - if not tasks and loop.is_closed(): - continue - - # 1) cancel tasks - for t in tasks: - with suppress(Exception): - t.cancel() - - # 2) drain tasks on that loop so they are actually awaited - if loop is not None and not loop.is_closed(): - if loop.is_running(): - # Can't safely run_until_complete; also shouldn't close a running loop. - self.logger.warning( - f"[{var_name}] Imported client loop is running; " - f"cannot drain/close registration tasks cleanly." - ) - else: - old_loop = None - try: - # Set context so asyncio.gather/futures bind to the right loop. - with suppress(Exception): - old_loop = asyncio.get_event_loop_policy().get_event_loop() - asyncio.set_event_loop(loop) - - # Await cancellation. This is what prevents warnings. - loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) - except Exception as e: - self.logger.warning(f"[{var_name}] Error draining registration tasks: {e}") - finally: - # Restore previous loop context (or clear it). - with suppress(Exception): - asyncio.set_event_loop(old_loop) - - # 3) close loop after drain - try: - loop.close() - except Exception as e: - self.logger.warning(f"[{var_name}] Error closing event loop: {e}") - - # 4) clear list - with suppress(Exception): - #pylint:disable=protected-access - client._registration_tasks.clear() - - # ---------------------------- - # Context application (DNA) - # ---------------------------- - - # pylint:disable=too-many-branches - def _apply_context(self, ctx: Optional[dict], g: dict, *, label: str) -> dict[str, Any]: - """ - Apply a DNA context entry (imports, globals, recipes) into a sandbox globals dict. - - This is best-effort: - - import failures are recorded and logged - - recipe failures are logged but do not abort merge - - Parameters - ---------- - ctx: - The optional "__context__" DNA header entry. - g: - The sandbox globals dict (module.__dict__) where code will be executed. - label: - Used for logging and for the returned report. - - Returns - ------- - dict[str, Any] - A structured report with keys: label, succeeded, failed, skipped. - - Security note - ------------- - This executes code from ctx (imports and recipes). Use only with trusted DNA. - """ - report = {"label": label, "succeeded": [], "failed": [], "skipped": []} - - if not isinstance(ctx, dict): - return report - - # imports (executed inside sandbox namespace) - for line in (ctx.get("imports") or []): - if not isinstance(line, str) or not line.strip(): - continue - - if not self._allow_context_imports: - report["skipped"].append(line) - continue - - try: - # pylint:disable=exec-used - exec(line, g) - report["succeeded"].append(line) - if self._verbose_context_imports: - self.logger.info(f"[merge ctx:{label}] import ok: {line}") - except Exception as e: - report["failed"].append((line, f"{type(e).__name__}: {e}")) - self.logger.warning(f"[merge ctx:{label}] import failed: {line!r} ({type(e).__name__}: {e})") - - # plain globals (already JSON-friendly) - globs = ctx.get("globals") or {} - if isinstance(globs, dict): - for k, v in globs.items(): - if isinstance(k, str): - g.setdefault(k, v) - - # recipes (evaluated inside the sandbox namespace) - recipes = ctx.get("recipes") or {} - if isinstance(recipes, dict): - for k, expr in recipes.items(): - if not (isinstance(k, str) and isinstance(expr, str)): - continue - try: - # pylint:disable=eval-used - g.setdefault(k, eval(expr, g, {})) - except Exception as e: - self.logger.warning(f"[merge ctx:{label}] recipe failed {k}={expr!r} ({type(e).__name__}: {e})") - - return report - - # ---------------------------- - # Utility: source patch for getsource capture - # ---------------------------- - - def _apply_with_source_patch(self, decorator, fn, source: str): - """ - Apply a SummonerClient decorator while preserving DNA source text. - - Why patch inspect.getsource - --------------------------- - The base decorators record handler source using inspect.getsource(fn). - When functions are constructed from DNA via exec(), inspect.getsource(fn) - will typically fail because there is no real source file. - - This helper temporarily patches inspect.getsource so the decorator stores - the original DNA text. This patch is process-global, so it should be used - only in controlled, single-threaded contexts (typical CLI usage). - """ - orig = inspect.getsource - inspect.getsource = lambda o: source - try: - decorator(fn) - finally: - inspect.getsource = orig - - # ---------------------------- - # Imported-client handler cloning - # ---------------------------- - - def _clone_handler(self, fn: types.FunctionType, original_name: str) -> types.FunctionType: - """ - Clone a function object for imported-client sources while preserving module globals. - - Behavior - -------- - - Mutates fn.__globals__ in-place: - - rebind original_name (for example "agent") to the merged client - - inject rebind_globals into the same globals dict - - Creates a new function object using the original code object and closure. - - This preserves the module-backed environment of imported clients. The cost is that - imported handler globals are modified, which is intentional for merge semantics. - """ - g = fn.__globals__ - - # rebind the client variable name (agent/client/etc) - try: - g[original_name] = self - except Exception as e: - self.logger.warning(f"Could not bind '{original_name}' to merged client: {e}") - - # rebind any shared globals (viz, Trigger, etc.) - for k, v in self._rebind_globals.items(): - try: - g[k] = v - except Exception as e: - self.logger.warning(f"Could not bind global '{k}' in '{fn.__name__}': {e}") - - new_fn = types.FunctionType( - fn.__code__, - g, - name=fn.__name__, - argdefs=fn.__defaults__, - closure=fn.__closure__, - ) - new_fn.__annotations__ = getattr(fn, "__annotations__", {}) - new_fn.__doc__ = getattr(fn, "__doc__", None) - - # if your dna() uses a __dna_source__ fallback, keep it - if hasattr(fn, "__dna_source__"): - new_fn.__dna_source__ = fn.__dna_source__ # type: ignore - - return new_fn - - - # ---------------------------- - # DNA compilation (per-source sandbox) - # ---------------------------- - - def _make_from_source(self, entry: dict[str, Any], g: dict, sandbox_name: str) -> types.FunctionType: - """ - Build a function object from a DNA entry by executing its function body in a sandbox. - - Key rule: strip decorators - -------------------------- - DNA sources typically include decorator lines (for example "@agent.receive(...)"). - We skip all decorators and only exec the "def ..." block. Otherwise, compiling the - function would also register it immediately, which would duplicate handlers. - - Globals - ------- - - rebind_globals is injected into g before exec so required runtime symbols - (for example Trigger, viz, shared objects) are visible to the compiled function. - - var_name is already bound to the merged client in g during source normalization. - """ - fn_name = entry["fn_name"] - raw = entry["source"] - lines = raw.splitlines() - - # --------------------------------------------------------------------- - # 1) Find the def line for this function, skipping decorators. - # --------------------------------------------------------------------- - def_idx = None - for idx, line in enumerate(lines): - pat = rf"\s*(async\s+)?def\s+{re.escape(fn_name)}\b" - if re.match(pat, line): - def_idx = idx - break - if def_idx is None: - raise RuntimeError(f"Could not find def for '{fn_name}'") - - func_body = "\n".join(lines[def_idx:]) - - # --------------------------------------------------------------------- - # 2) Ensure rebinding happens in the same globals dict used by exec(). - # --------------------------------------------------------------------- - rebind = self._rebind_globals - if isinstance(rebind, dict) and rebind: - g.update(rebind) - - # --------------------------------------------------------------------- - # 3) Execute and retrieve the resulting function object. - # --------------------------------------------------------------------- - if "__builtins__" not in g: - g["__builtins__"] = __builtins__ - - # pylint:disable=exec-used - exec(compile(func_body, filename=f"<{sandbox_name}>", mode="exec"), g) - - fn = g.get(fn_name) - if not isinstance(fn, types.FunctionType): - raise RuntimeError(f"Failed to construct function '{fn_name}'") - - return fn - - - # ---------------------------- - # Public replay API - # ---------------------------- - - def initiate_upload_states(self): - """Replay @upload_states from every source onto the merged client.""" - for src in self.sources: - if src["kind"] == "client": - client: SummonerClient = src["client"] - var_name: str = src["var_name"] - #pylint:disable=protected-access - fn = client._upload_states - if fn is None: - continue - fn_clone = self._clone_handler(fn, var_name) - try: - self.upload_states()(fn_clone) - except Exception as e: - self.logger.warning(f"[{var_name}] Failed to replay upload_states '{fn.__name__}': {e}") - - else: - g = src["globals"] - sandbox = src["sandbox_name"] - for entry in src["dna_entries"]: - if entry.get("type") != "upload_states": - continue - fn = self._make_from_source(entry, g, sandbox) - dec = self.upload_states() - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_download_states(self): - """Replay @download_states from every source onto the merged client.""" - for src in self.sources: - if src["kind"] == "client": - client: SummonerClient = src["client"] - var_name: str = src["var_name"] - #pylint:disable=protected-access - fn = client._download_states - if fn is None: - continue - fn_clone = self._clone_handler(fn, var_name) - try: - self.download_states()(fn_clone) - except Exception as e: - self.logger.warning(f"[{var_name}] Failed to replay download_states '{fn.__name__}': {e}") - - else: - g = src["globals"] - sandbox = src["sandbox_name"] - for entry in src["dna_entries"]: - if entry.get("type") != "download_states": - continue - fn = self._make_from_source(entry, g, sandbox) - dec = self.download_states() - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_hooks(self): - """Replay @hook(Direction, priority=...) from every source onto the merged client.""" - for src in self.sources: - if src["kind"] == "client": - client: SummonerClient = src["client"] - var_name: str = src["var_name"] - #pylint:disable=protected-access - for dna in client._dna_hooks: - fn_clone = self._clone_handler(dna["fn"], var_name) - try: - self.hook(dna["direction"], priority=dna["priority"])(fn_clone) - except Exception as e: - self.logger.warning(f"[{var_name}] Failed to replay hook '{dna['fn'].__name__}': {e}") - - else: - g = src["globals"] - sandbox = src["sandbox_name"] - for entry in src["dna_entries"]: - if entry.get("type") != "hook": - continue - fn = self._make_from_source(entry, g, sandbox) - direction = Direction[entry["direction"]] if isinstance(entry.get("direction"), str) else entry["direction"] - dec = self.hook(direction, priority=tuple(entry.get("priority", ()))) - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_receivers(self): - """Replay @receive(route, priority=...) from every source onto the merged client.""" - for src in self.sources: - if src["kind"] == "client": - client: SummonerClient = src["client"] - var_name: str = src["var_name"] - #pylint:disable=protected-access - for dna in client._dna_receivers: - fn_clone = self._clone_handler(dna["fn"], var_name) - try: - self.receive(dna["route"], priority=dna["priority"])(fn_clone) - except Exception as e: - self.logger.warning( - f"[{var_name}] Failed to replay receiver '{dna['fn'].__name__}' on route '{dna['route']}': {e}" - ) - - else: - g = src["globals"] - sandbox = src["sandbox_name"] - for entry in src["dna_entries"]: - if entry.get("type") != "receive": - continue - fn = self._make_from_source(entry, g, sandbox) - dec = self.receive(entry["route"], priority=tuple(entry.get("priority", ()))) - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_senders(self): - """ - Replay @send(route, multi, on_triggers, on_actions) from every source onto the merged client. - - Imported-client sources: - - carry actual trigger/action objects in _dna_senders. - - DNA sources: - - store trigger/action names as strings. - - triggers are resolved using TriggerCls: - - prefer Trigger class provided by sandbox context ("Trigger" in sandbox globals) - - else fall back to load_triggers() - - actions are resolved from Action by name via _resolve_action. - """ - for src in self.sources: - if src["kind"] == "client": - client: SummonerClient = src["client"] - var_name: str = src["var_name"] - #pylint:disable=protected-access - for dna in client._dna_senders: - fn_clone = self._clone_handler(dna["fn"], var_name) - try: - self.send( - dna["route"], - multi=dna["multi"], - on_triggers=dna["on_triggers"], - on_actions=dna["on_actions"], - )(fn_clone) - except Exception as e: - self.logger.warning( - f"[{var_name}] Failed to replay sender '{dna['fn'].__name__}' on route '{dna['route']}': {e}" - ) - - else: - g = src["globals"] - sandbox = src["sandbox_name"] - - # Triggers: prefer a Trigger class provided by sandbox context; otherwise load defaults. - TriggerCls = g.get("Trigger") - if TriggerCls is None: - TriggerCls = load_triggers() - - for entry in src["dna_entries"]: - if entry.get("type") != "send": - continue - fn = self._make_from_source(entry, g, sandbox) - on_triggers = {_resolve_trigger(TriggerCls, t) for t in entry.get("on_triggers", [])} or None - on_actions = {_resolve_action(Action, a) for a in entry.get("on_actions", [])} or None - dec = self.send( - entry["route"], - multi=entry.get("multi", False), - on_triggers=on_triggers, # type: ignore - on_actions=on_actions, - ) - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_all(self): - """ - Replay all supported handler types in a standard order. - - This should be called before run(). The order matches SummonerClient.dna(): - 1) upload_states - 2) download_states - 3) hooks - 4) receivers - 5) senders - """ - self.initiate_upload_states() - self.initiate_download_states() - self.initiate_hooks() - self.initiate_receivers() - self.initiate_senders() - -# pylint:disable=too-many-instance-attributes -class ClientTranslation(SummonerClient): - """ - Reconstruct a SummonerClient from its DNA list. - - Translation compiles handlers from their recorded source into a fresh sandbox module, - then registers them on this client via normal decorators. - - Execution environment - --------------------- - - Translated handlers do not run inside the original agent modules. - - They run inside the translation sandbox, with only explicitly imported or rebound - globals available (for example: Trigger, shared objects, viz, etc.). - - Context behavior - ---------------- - If the DNA begins with a "__context__" entry, translation may: - - exec() ctx["imports"] (if allow_context_imports=True) - - bind ctx["globals"] into sandbox - - eval() ctx["recipes"] into sandbox - - Template-client cleanup - ----------------------- - DNA entries carry a "module" field from the original capture. - Importing those modules can create template SummonerClient instances at import-time. - - This class attempts to find and clean up such template clients so they do not leave - pending registration tasks or open loops behind. - """ - # pylint:disable=too-many-arguments, too-many-positional-arguments - def __init__( - self, - dna_list: list[dict[str, Any]], - name: Optional[str] = None, - var_name: Optional[str] = None, - rebind_globals: Optional[dict[str, Any]] = None, - allow_context_imports: bool = True, - verbose_context_imports: bool = False - ): - super().__init__(name=name) - - if not isinstance(dna_list, list): - raise TypeError("dna_list must be a list of DNA entries") - - self._rebind_globals = dict(rebind_globals or {}) - self._allow_context_imports = allow_context_imports - self._verbose_context_imports = verbose_context_imports - - # Extract optional context entry - ctx = None - if dna_list and isinstance(dna_list[0], dict) and dna_list[0].get("type") == "__context__": - ctx = dna_list[0] - dna_list = dna_list[1:] - - # Decide binding name: - # - explicit var_name wins (admin override) - # - else use context var_name if present - # - else default to "agent" - if var_name is None: - ctx_var = ctx.get("var_name") if isinstance(ctx, dict) else None - var_name = ctx_var if isinstance(ctx_var, str) and ctx_var else "agent" - - self._dna_list = dna_list - if not isinstance(var_name, str): - raise TypeError("var_name must be a string") - self._var_name = var_name - self._context = ctx - - # Create a sandbox module for translated code (independent from user imports) - self._sandbox_module_name = f"summoner_translation_{uuid.uuid4().hex}" - self._sandbox_module = types.ModuleType(self._sandbox_module_name) - sys.modules[self._sandbox_module_name] = self._sandbox_module - self._sandbox_globals = self._sandbox_module.__dict__ - - # Bind the client name used in handler source - self._sandbox_globals[self._var_name] = self - - # Apply context if present - self._apply_context() - - # Best-effort cleanup of template clients imported indirectly via DNA modules - self._cleanup_template_clients_from_modules() - - #pylint:disable=too-many-branches - def _apply_context(self): - """ - Apply optional "__context__" entry into the translation sandbox. - - This is best-effort and intended for trusted DNA. - It may execute code via exec()/eval() if imports/recipes are present. - """ - if not isinstance(self._context, dict): - return - - g = self._sandbox_globals - - # Imports - for line in self._context.get("imports", []) or []: - if not isinstance(line, str) or not line.strip(): - continue - if not self._allow_context_imports: - continue - try: - # pylint:disable=exec-used - exec(line, g) - if self._verbose_context_imports: - self.logger.info(f"[translation ctx] import ok: {line}") - except Exception as e: - self.logger.warning(f"[translation ctx] import failed: {line!r} ({type(e).__name__}: {e})") - - # Plain globals - globs = self._context.get("globals", {}) or {} - if isinstance(globs, dict): - for k, v in globs.items(): - if isinstance(k, str): - g.setdefault(k, v) - - # Recipes - rec = self._context.get("recipes", {}) or {} - if isinstance(rec, dict): - for k, expr in rec.items(): - if not (isinstance(k, str) and isinstance(expr, str)): - continue - # eval in the sandbox global namespace - # pylint:disable=eval-used - try: - g.setdefault(k, eval(expr, g, {})) - except Exception as e: - self.logger.warning(f"Could not eval recipe for {k}: {expr!r} ({e})") - - #pylint:disable=unused-argument - def _cleanup_one_template_client(self, client: SummonerClient, label: str): - """ - Best-effort cleanup for a template client discovered in an imported module. - - If a DNA entry references a module, importing it may construct a SummonerClient - at module import-time. That client is not meant to run in translation mode. - - Cleanup steps: - - cancel pending registration tasks - - if possible, await cancellations by driving the template's loop - - close the loop when it is not running - """ - # pylint:disable=protected-access - regs = list(client._registration_tasks or []) - loop = client.loop - - # cancel registrations - for t in regs: - try: - t.cancel() - except Exception: - pass - - # drain cancellations if we can drive that loop - try: - if regs and loop is not None and (not loop.is_running()) and (not loop.is_closed()): - loop.run_until_complete(asyncio.gather(*regs, return_exceptions=True)) - except Exception: - # best-effort only - pass - - # clear list - try: - # pylint:disable=protected-access - client._registration_tasks.clear() - except Exception: - pass - - # close the loop (template clients are not meant to run) - try: - if loop is not None and (not loop.is_running()) and (not loop.is_closed()): - loop.close() - except Exception: - pass - - def _cleanup_template_clients_from_modules(self): - """ - For every module referenced in the DNA, attempt to find a template client: - - If the module defines a global variable named self._var_name (for example "agent") - and it points to a SummonerClient instance other than this translated client, - then treat it as a template and clean it up. - - This reduces warnings about pending registration tasks and open event loops. - """ - modules = {entry.get("module") for entry in self._dna_list if isinstance(entry, dict)} - modules.discard(None) - - seen_ids: set[int] = set() - - for module_name in modules: - try: - module = sys.modules.get(module_name) or import_module(module_name) # type: ignore - except Exception: - continue - - g = getattr(module, "__dict__", {}) - template = g.get(self._var_name) - - if isinstance(template, SummonerClient) and template is not self: - if id(template) in seen_ids: - continue - seen_ids.add(id(template)) - self._cleanup_one_template_client(template, label=f"{module_name}.{self._var_name}") - - def _make_from_source(self, entry: dict[str, Any]) -> types.FunctionType: - """ - Compile a handler function from its DNA 'source' into the translation sandbox globals. - - Decorator lines are stripped so compilation does not implicitly register handlers. - """ - fn_name = entry["fn_name"] - - module_globals = self._sandbox_globals - module_globals[self._var_name] = self - - # inject runtime globals (Trigger, viz, shared objects, etc.) - if self._rebind_globals: - module_globals.update(self._rebind_globals) - - if "__builtins__" not in module_globals: - module_globals["__builtins__"] = __builtins__ - - # strip off decorator lines so we only exec the `def` block - raw = entry["source"] - lines = raw.splitlines() - for idx, line in enumerate(lines): - pattern = rf"\s*(async\s+)?def\s+{re.escape(fn_name)}\b" - if re.match(pattern, line): - func_body = "\n".join(lines[idx:]) - break - else: - raise RuntimeError(f"Could not find definition for '{fn_name}'") - - #pylint:disable=exec-used - exec(compile(func_body, filename=f"<{self._sandbox_module_name}>", mode="exec"), module_globals) - - fn = module_globals.get(fn_name) - if not isinstance(fn, types.FunctionType): - raise RuntimeError(f"Failed to construct function '{fn_name}'") - return fn - - def _apply_with_source_patch(self, decorator, fn, source: str): - """ - Temporarily override inspect.getsource so SummonerClient decorators record DNA text. - - This is process-global and is intended for single-threaded translation runs. - """ - orig = inspect.getsource - inspect.getsource = lambda o: source - try: - decorator(fn) - finally: - inspect.getsource = orig - - def initiate_upload_states(self): - """Replay @upload_states from DNA onto this translated client.""" - for entry in self._dna_list: - if entry.get("type") != "upload_states": - continue - fn = self._make_from_source(entry) - dec = self.upload_states() - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_download_states(self): - """Replay @download_states from DNA onto this translated client.""" - for entry in self._dna_list: - if entry.get("type") != "download_states": - continue - fn = self._make_from_source(entry) - dec = self.download_states() - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_hooks(self): - """Replay @hook from DNA onto this translated client.""" - for entry in self._dna_list: - if entry.get("type") != "hook": - continue - fn = self._make_from_source(entry) - dec = self.hook( - Direction[entry["direction"]], - priority=tuple(entry.get("priority", ())) - ) - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_receivers(self): - """Replay @receive from DNA onto this translated client.""" - for entry in self._dna_list: - if entry.get("type") != "receive": - continue - fn = self._make_from_source(entry) - dec = self.receive( - entry["route"], - priority=tuple(entry.get("priority", ())) - ) - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_senders(self): - """ - Replay @send from DNA onto this translated client. - - Triggers and actions are stored by name in DNA, then resolved here: - - Trigger is resolved using a Trigger class found in sandbox globals, else load_triggers() - - Action is resolved from the Action container by name - """ - g = self._sandbox_globals - - # Ensure rebind globals are visible before resolving triggers/actions. - if self._rebind_globals: - g.update(self._rebind_globals) - - TriggerCls = g.get("Trigger") - if TriggerCls is None: - TriggerCls = load_triggers() - - for entry in self._dna_list: - if entry.get("type") != "send": - continue - fn = self._make_from_source(entry) - on_triggers = {_resolve_trigger(TriggerCls, t) for t in entry.get("on_triggers", [])} or None - on_actions = {_resolve_action(Action, a) for a in entry.get("on_actions", [])} or None - dec = self.send( - entry["route"], - multi=entry.get("multi", False), - on_triggers=on_triggers, # type: ignore - on_actions=on_actions, - ) - self._apply_with_source_patch(dec, fn, entry["source"]) - - def initiate_all(self): - """ - Replay all handler types from DNA in a standard order. - - This should be called before run(). - """ - self.initiate_upload_states() - self.initiate_download_states() - self.initiate_hooks() - self.initiate_receivers() - self.initiate_senders() - - async def _async_shutdown(self): - """ - Async shutdown path used by SIGINT/SIGTERM and quit(). - - Steps - ----- - 1) cancel everything (base shutdown) - 2) cancel and await pending decorator-registration coroutines - 3) await in-flight handler/worker tasks - 4) stop the loop so run_until_complete can return - """ - # 1) cancel everything - super().shutdown() - - # 2) wait for decorator register() tasks - regs = self._registration_tasks or [] - if regs: - for t in regs: - t.cancel() - await asyncio.gather(*regs, return_exceptions=True) - regs.clear() - - # 3) wait for in-flight handlers and workers - await self._wait_for_tasks_to_finish() - - # 4) stop the loop so run_client's run_until_complete can finish - self.loop.stop() - - def shutdown(self): - """ - Override the base-class SIGINT/SIGTERM handler. - - The base SummonerClient.shutdown() cancels tasks but does not await them. - In translation mode, a cleaner exit is preferred, so we schedule an async - shutdown coroutine instead. - """ - self.logger.info("Client is shutting down…") - try: - asyncio.create_task(self._async_shutdown()) - except RuntimeError: - # If the loop isn't running, ignore. - pass - - async def quit(self): - """ - Quit the translated client: - - set _quit so run_client exits - - then run the same cleanup as Ctrl+C - """ - await super().quit() - await self._async_shutdown() - - def run(self, *args, **kwargs): - """ - Wrap run() so that Ctrl+C cancels leftover registration tasks and exits cleanly. - - Note: this wrapper is intentionally conservative and does not change the base - reconnection and session logic. - """ - try: - super().run(*args, **kwargs) - except KeyboardInterrupt: - self.logger.info("KeyboardInterrupt caught-cancelling registration tasks…") - for task in list(self._registration_tasks or []): - task.cancel() +#pylint:disable=unused-import +from .just_merger import ClientMerger +from .translation import ClientTranslation diff --git a/summoner/client/translation.py b/summoner/client/translation.py new file mode 100644 index 0000000..9c68951 --- /dev/null +++ b/summoner/client/translation.py @@ -0,0 +1,479 @@ +""" +merger.py + +This module provides two related utilities built on top of SummonerClient: + +1) ClientMerger + Build a single composite SummonerClient by replaying handlers from multiple sources. + + A "source" can be: + - an imported SummonerClient instance (live Python object), or + - a DNA list (already loaded JSON list[dict]), or + - a DNA JSON file path. + + Imported-client sources: + - handlers keep their original module globals (module-backed execution), + - the original client binding (for example the name "agent") is rebound to the merged client, + - optional rebind_globals are injected into handler globals. + + DNA sources: + - handlers are reconstructed by compiling their recorded source text into an isolated + sandbox module (one sandbox per DNA source), + - the sandbox binds var_name (for example "agent") to the merged client instance, + so handler code that references `agent` executes against the composite client, + - optional context (imports, globals, recipes) is applied into the sandbox. + + Usage pattern: + - instantiate ClientMerger(...) + - configure flow / styles as usual on the merged client if desired + - call agent.initiate_all() to replay handlers onto the merged client + - call agent.run(...) + +2) ClientTranslation + Reconstruct a fresh SummonerClient from a DNA list. + + Translation compiles handler functions from their recorded source into a fresh sandbox module, + binds var_name (for example "agent") to the translated client, then registers the handlers + using the normal decorators. + +Security and trust model +------------------------ +Both classes execute code from DNA via exec() and eval(): + +- context imports (ctx["imports"]) +- recipes (ctx["recipes"]) +- handler bodies (entry["source"]) + +This is intended for trusted DNA (typically produced by your own agents). +Do not run untrusted DNA. +""" +#pylint:disable=line-too-long, wrong-import-position, duplicate-code +#pylint:disable=invalid-name, broad-exception-caught,logging-fstring-interpolation + +from importlib import import_module +from typing import Optional, Any +import inspect +import asyncio +import types +import re +import uuid + +import os +import sys + +from summoner.client.just_merger import _resolve_action, _resolve_trigger +target_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) +if target_path not in sys.path: + sys.path.insert(0, target_path) + +from summoner.client.client import SummonerClient +from summoner.protocol.triggers import Action, load_triggers +from summoner.protocol.process import Direction + +# pylint:disable=too-many-instance-attributes +class ClientTranslation(SummonerClient): + """ + Reconstruct a SummonerClient from its DNA list. + + Translation compiles handlers from their recorded source into a fresh sandbox module, + then registers them on this client via normal decorators. + + Execution environment + --------------------- + - Translated handlers do not run inside the original agent modules. + - They run inside the translation sandbox, with only explicitly imported or rebound + globals available (for example: Trigger, shared objects, viz, etc.). + + Context behavior + ---------------- + If the DNA begins with a "__context__" entry, translation may: + - exec() ctx["imports"] (if allow_context_imports=True) + - bind ctx["globals"] into sandbox + - eval() ctx["recipes"] into sandbox + + Template-client cleanup + ----------------------- + DNA entries carry a "module" field from the original capture. + Importing those modules can create template SummonerClient instances at import-time. + + This class attempts to find and clean up such template clients so they do not leave + pending registration tasks or open loops behind. + """ + # pylint:disable=too-many-arguments, too-many-positional-arguments + def __init__( + self, + dna_list: list[dict[str, Any]], + name: Optional[str] = None, + var_name: Optional[str] = None, + rebind_globals: Optional[dict[str, Any]] = None, + allow_context_imports: bool = True, + verbose_context_imports: bool = False + ): + super().__init__(name=name) + + if not isinstance(dna_list, list): + raise TypeError("dna_list must be a list of DNA entries") + + self._rebind_globals = dict(rebind_globals or {}) + self._allow_context_imports = allow_context_imports + self._verbose_context_imports = verbose_context_imports + + # Extract optional context entry + ctx = None + if dna_list and isinstance(dna_list[0], dict) and dna_list[0].get("type") == "__context__": + ctx = dna_list[0] + dna_list = dna_list[1:] + + # Decide binding name: + # - explicit var_name wins (admin override) + # - else use context var_name if present + # - else default to "agent" + if var_name is None: + ctx_var = ctx.get("var_name") if isinstance(ctx, dict) else None + var_name = ctx_var if isinstance(ctx_var, str) and ctx_var else "agent" + + self._dna_list = dna_list + if not isinstance(var_name, str): + raise TypeError("var_name must be a string") + self._var_name = var_name + self._context = ctx + + # Create a sandbox module for translated code (independent from user imports) + self._sandbox_module_name = f"summoner_translation_{uuid.uuid4().hex}" + self._sandbox_module = types.ModuleType(self._sandbox_module_name) + sys.modules[self._sandbox_module_name] = self._sandbox_module + self._sandbox_globals = self._sandbox_module.__dict__ + + # Bind the client name used in handler source + self._sandbox_globals[self._var_name] = self + + # Apply context if present + self._apply_context() + + # Best-effort cleanup of template clients imported indirectly via DNA modules + self._cleanup_template_clients_from_modules() + + #pylint:disable=too-many-branches + def _apply_context(self): + """ + Apply optional "__context__" entry into the translation sandbox. + + This is best-effort and intended for trusted DNA. + It may execute code via exec()/eval() if imports/recipes are present. + """ + if not isinstance(self._context, dict): + return + + g = self._sandbox_globals + + # Imports + for line in self._context.get("imports", []) or []: + if not isinstance(line, str) or not line.strip(): + continue + if not self._allow_context_imports: + continue + try: + # pylint:disable=exec-used + exec(line, g) + if self._verbose_context_imports: + self.logger.info(f"[translation ctx] import ok: {line}") + except Exception as e: + self.logger.warning(f"[translation ctx] import failed: {line!r} ({type(e).__name__}: {e})") + + # Plain globals + globs = self._context.get("globals", {}) or {} + if isinstance(globs, dict): + for k, v in globs.items(): + if isinstance(k, str): + g.setdefault(k, v) + + # Recipes + rec = self._context.get("recipes", {}) or {} + if isinstance(rec, dict): + for k, expr in rec.items(): + if not (isinstance(k, str) and isinstance(expr, str)): + continue + # eval in the sandbox global namespace + # pylint:disable=eval-used + try: + g.setdefault(k, eval(expr, g, {})) + except Exception as e: + self.logger.warning(f"Could not eval recipe for {k}: {expr!r} ({e})") + + #pylint:disable=unused-argument + def _cleanup_one_template_client(self, client: SummonerClient, label: str): + """ + Best-effort cleanup for a template client discovered in an imported module. + + If a DNA entry references a module, importing it may construct a SummonerClient + at module import-time. That client is not meant to run in translation mode. + + Cleanup steps: + - cancel pending registration tasks + - if possible, await cancellations by driving the template's loop + - close the loop when it is not running + """ + # pylint:disable=protected-access + regs = list(client._registration_tasks or []) + loop = client.loop + + # cancel registrations + for t in regs: + try: + t.cancel() + except Exception: + pass + + # drain cancellations if we can drive that loop + try: + if regs and loop is not None and (not loop.is_running()) and (not loop.is_closed()): + loop.run_until_complete(asyncio.gather(*regs, return_exceptions=True)) + except Exception: + # best-effort only + pass + + # clear list + try: + # pylint:disable=protected-access + client._registration_tasks.clear() + except Exception: + pass + + # close the loop (template clients are not meant to run) + try: + if loop is not None and (not loop.is_running()) and (not loop.is_closed()): + loop.close() + except Exception: + pass + + def _cleanup_template_clients_from_modules(self): + """ + For every module referenced in the DNA, attempt to find a template client: + + If the module defines a global variable named self._var_name (for example "agent") + and it points to a SummonerClient instance other than this translated client, + then treat it as a template and clean it up. + + This reduces warnings about pending registration tasks and open event loops. + """ + modules = {entry.get("module") for entry in self._dna_list if isinstance(entry, dict)} + modules.discard(None) + + seen_ids: set[int] = set() + + for module_name in modules: + try: + module = sys.modules.get(module_name) or import_module(module_name) # type: ignore + except Exception: + continue + + g = getattr(module, "__dict__", {}) + template = g.get(self._var_name) + + if isinstance(template, SummonerClient) and template is not self: + if id(template) in seen_ids: + continue + seen_ids.add(id(template)) + self._cleanup_one_template_client(template, label=f"{module_name}.{self._var_name}") + + def _make_from_source(self, entry: dict[str, Any]) -> types.FunctionType: + """ + Compile a handler function from its DNA 'source' into the translation sandbox globals. + + Decorator lines are stripped so compilation does not implicitly register handlers. + """ + fn_name = entry["fn_name"] + + module_globals = self._sandbox_globals + module_globals[self._var_name] = self + + # inject runtime globals (Trigger, viz, shared objects, etc.) + if self._rebind_globals: + module_globals.update(self._rebind_globals) + + if "__builtins__" not in module_globals: + module_globals["__builtins__"] = __builtins__ + + # strip off decorator lines so we only exec the `def` block + raw = entry["source"] + lines = raw.splitlines() + for idx, line in enumerate(lines): + pattern = rf"\s*(async\s+)?def\s+{re.escape(fn_name)}\b" + if re.match(pattern, line): + func_body = "\n".join(lines[idx:]) + break + else: + raise RuntimeError(f"Could not find definition for '{fn_name}'") + + #pylint:disable=exec-used + exec(compile(func_body, filename=f"<{self._sandbox_module_name}>", mode="exec"), module_globals) + + fn = module_globals.get(fn_name) + if not isinstance(fn, types.FunctionType): + raise RuntimeError(f"Failed to construct function '{fn_name}'") + return fn + + def _apply_with_source_patch(self, decorator, fn, source: str): + """ + Temporarily override inspect.getsource so SummonerClient decorators record DNA text. + + This is process-global and is intended for single-threaded translation runs. + """ + orig = inspect.getsource + inspect.getsource = lambda o: source + try: + decorator(fn) + finally: + inspect.getsource = orig + + def initiate_upload_states(self): + """Replay @upload_states from DNA onto this translated client.""" + for entry in self._dna_list: + if entry.get("type") != "upload_states": + continue + fn = self._make_from_source(entry) + dec = self.upload_states() + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_download_states(self): + """Replay @download_states from DNA onto this translated client.""" + for entry in self._dna_list: + if entry.get("type") != "download_states": + continue + fn = self._make_from_source(entry) + dec = self.download_states() + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_hooks(self): + """Replay @hook from DNA onto this translated client.""" + for entry in self._dna_list: + if entry.get("type") != "hook": + continue + fn = self._make_from_source(entry) + dec = self.hook( + Direction[entry["direction"]], + priority=tuple(entry.get("priority", ())) + ) + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_receivers(self): + """Replay @receive from DNA onto this translated client.""" + for entry in self._dna_list: + if entry.get("type") != "receive": + continue + fn = self._make_from_source(entry) + dec = self.receive( + entry["route"], + priority=tuple(entry.get("priority", ())) + ) + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_senders(self): + """ + Replay @send from DNA onto this translated client. + + Triggers and actions are stored by name in DNA, then resolved here: + - Trigger is resolved using a Trigger class found in sandbox globals, else load_triggers() + - Action is resolved from the Action container by name + """ + g = self._sandbox_globals + + # Ensure rebind globals are visible before resolving triggers/actions. + if self._rebind_globals: + g.update(self._rebind_globals) + + TriggerCls = g.get("Trigger") + if TriggerCls is None: + TriggerCls = load_triggers() + + for entry in self._dna_list: + if entry.get("type") != "send": + continue + fn = self._make_from_source(entry) + on_triggers = {_resolve_trigger(TriggerCls, t) for t in entry.get("on_triggers", [])} or None + on_actions = {_resolve_action(Action, a) for a in entry.get("on_actions", [])} or None + dec = self.send( + entry["route"], + multi=entry.get("multi", False), + on_triggers=on_triggers, # type: ignore + on_actions=on_actions, + ) + self._apply_with_source_patch(dec, fn, entry["source"]) + + def initiate_all(self): + """ + Replay all handler types from DNA in a standard order. + + This should be called before run(). + """ + self.initiate_upload_states() + self.initiate_download_states() + self.initiate_hooks() + self.initiate_receivers() + self.initiate_senders() + + async def _async_shutdown(self): + """ + Async shutdown path used by SIGINT/SIGTERM and quit(). + + Steps + ----- + 1) cancel everything (base shutdown) + 2) cancel and await pending decorator-registration coroutines + 3) await in-flight handler/worker tasks + 4) stop the loop so run_until_complete can return + """ + # 1) cancel everything + super().shutdown() + + # 2) wait for decorator register() tasks + regs = self._registration_tasks or [] + if regs: + for t in regs: + t.cancel() + await asyncio.gather(*regs, return_exceptions=True) + regs.clear() + + # 3) wait for in-flight handlers and workers + await self._wait_for_tasks_to_finish() + + # 4) stop the loop so run_client's run_until_complete can finish + self.loop.stop() + + def shutdown(self): + """ + Override the base-class SIGINT/SIGTERM handler. + + The base SummonerClient.shutdown() cancels tasks but does not await them. + In translation mode, a cleaner exit is preferred, so we schedule an async + shutdown coroutine instead. + """ + self.logger.info("Client is shutting down…") + try: + asyncio.create_task(self._async_shutdown()) + except RuntimeError: + # If the loop isn't running, ignore. + pass + + async def quit(self): + """ + Quit the translated client: + - set _quit so run_client exits + - then run the same cleanup as Ctrl+C + """ + await super().quit() + await self._async_shutdown() + + def run(self, *args, **kwargs): + """ + Wrap run() so that Ctrl+C cancels leftover registration tasks and exits cleanly. + + Note: this wrapper is intentionally conservative and does not change the base + reconnection and session logic. + """ + try: + super().run(*args, **kwargs) + except KeyboardInterrupt: + self.logger.info("KeyboardInterrupt caught-cancelling registration tasks…") + for task in list(self._registration_tasks or []): + task.cancel() diff --git a/summoner/logger.py b/summoner/logger.py index bb3d217..ccb5fae 100644 --- a/summoner/logger.py +++ b/summoner/logger.py @@ -10,7 +10,7 @@ import datetime import re -from typing import Optional, Any +from typing import Dict, Optional, Any from logging.handlers import RotatingFileHandler # This makes Logger importable from logger.py @@ -140,7 +140,7 @@ def __init__(self, fmt: str, datefmt: Optional[str], log_keys: Optional[list[str self.log_keys = log_keys def format(self, record: logging.LogRecord) -> str: - base = { + base : Dict[str, str | Dict[Any,Any]] = { "timestamp": self.formatTime(record, self._raw_datefmt), "name": record.name, "level": record.levelname, @@ -150,7 +150,7 @@ def format(self, record: logging.LogRecord) -> str: else {k: record.msg.get(k) for k in self.log_keys if k in record.msg}) else: payload = record.getMessage() - base["message"] = payload # type: ignore + base["message"] = payload return json.dumps(base, default=str) def get_logger(name: str) -> logging.Logger: diff --git a/summoner/protocol/payload.py b/summoner/protocol/payload.py index 03c64c3..c5c2aa6 100644 --- a/summoner/protocol/payload.py +++ b/summoner/protocol/payload.py @@ -2,9 +2,10 @@ TODO: doc payload """ #pylint:disable=line-too-long +from enum import Enum, auto import json from json import JSONDecodeError -from typing import Any, Tuple, Dict, List, Union, TypedDict +from typing import Any, Literal, Tuple, Dict, List, Union, TypedDict from summoner.utils import ( fully_recover_json, @@ -67,13 +68,13 @@ def parse_v0_0_1(obj: Any) -> Tuple[Any, Any]: # Lists → walk each element if isinstance(obj, list): - payloads: List[Any] = [] - types: List[Any] = [] + payloads_list: List[Any] = [] + types_list: List[Any] = [] for v in obj: p, t = parse_v0_0_1(v) - payloads.append(p) - types.append(t) - return payloads, types + payloads_list.append(p) + types_list.append(t) + return payloads_list, types_list # Dicts → walk each value if isinstance(obj, dict): @@ -212,7 +213,6 @@ class RelayedMessage(TypedDict): remote_addr: str content: Union[str, dict] - def recover_with_types(text: str) -> RelayedMessage: """ Recover and validate a typed payload from a server message. @@ -230,6 +230,11 @@ def recover_with_types(text: str) -> RelayedMessage: - Invalid JSON or non-JSON warning strings: strip trailing newline and return raw string. - Missing outer keys ("remote_addr"/"content"): strip newline and return raw string. - Missing envelope keys inside content: return the parsed `{"remote_addr":…, "content":…}` as-is. + + Strictly the return type annotation is not accurate. + It may end up in the fallback, but that pollutes that possibility + for the caller even when they are known to be passing in the case for text + which does not end up in the fallbacks. Args: text: Raw text received from the server (may include newline). @@ -246,13 +251,13 @@ def recover_with_types(text: str) -> RelayedMessage: obj = fully_recover_json(text) except (ValueError, JSONDecodeError): # If that fails (e.g. pure warning string), strip newline and relay the raw text - return remove_last_newline(text) + return remove_last_newline(text) # type: ignore # 2) Ensure we have the outer {"remote_addr":…, "content":…} wrapper if not (isinstance(obj, dict) and "remote_addr" in obj and "content" in obj): # Malformed protocol message; upstream code can catch this if needed # raise ValueError("Unsupported message format from server.") - return obj + return obj # type: ignore addr = obj["remote_addr"] content = obj["content"] @@ -265,7 +270,7 @@ def recover_with_types(text: str) -> RelayedMessage: and "_payload" in content and "_type" in content ): - return obj + return obj # type: ignore # 4) We have the versioned envelope—now look up the correct caster version = content["_version"] diff --git a/summoner/server/__init__.py b/summoner/server/__init__.py index 1c8c79d..6123d76 100644 --- a/summoner/server/__init__.py +++ b/summoner/server/__init__.py @@ -1 +1,4 @@ -from .server import SummonerServer \ No newline at end of file +""" +TODO: doc server +""" +from .server import SummonerServer diff --git a/summoner/server/server.py b/summoner/server/server.py index cddf6c5..11a72ec 100644 --- a/summoner/server/server.py +++ b/summoner/server/server.py @@ -1,3 +1,7 @@ +""" +TODO: doc server +""" +# pylint:disable=wrong-import-position, logging-fstring-interpolation import asyncio import signal import os @@ -14,7 +18,7 @@ # Imports from summoner.utils import ( - remove_last_newline, + remove_last_newline, ensure_trailing_newline, load_config, ) @@ -43,10 +47,13 @@ class ClientDisconnected(Exception): """Raised when the client disconnects cleanly (e.g., closes the socket).""" - pass + class SummonerServer: - + """ + TODO: doc server + """ + __slots__: tuple[str, ...] = ( "name", "logger", @@ -72,12 +79,15 @@ def __init__(self, name: Optional[str] = None): self.active_tasks: dict[asyncio.Task, str] = {} self.tasks_lock = asyncio.Lock() - + async def handle_client( self, - reader: asyncio.streams.StreamReader, + reader: asyncio.streams.StreamReader, writer: asyncio.streams.StreamWriter ): + """ + TODO: doc handle_client + """ addr = format_addr(writer.get_extra_info("peername")) self.logger.info(f"{addr} connected.") @@ -86,7 +96,7 @@ async def handle_client( self.clients.add(writer) async with self.tasks_lock: - self.active_tasks[task] = addr + self.active_tasks[task] = addr # type: ignore try: while True: @@ -103,54 +113,74 @@ async def handle_client( # Iterate over the snapshot to avoid concurrency issues without long-held locks for other_writer in clients_snapshot: if other_writer != writer: - payload = json.dumps({"remote_addr": addr, "content": remove_last_newline(message)}) + payload = json.dumps( + { + "remote_addr": addr, + "content": remove_last_newline(message) + } + ) other_writer.write(ensure_trailing_newline(payload).encode()) await other_writer.drain() - + except ClientDisconnected: self.logger.info(f"{addr} disconnected.") - + except (ConnectionResetError, asyncio.IncompleteReadError, asyncio.CancelledError) as e: self.logger.warning(f"{addr} error or abrupt disconnect: {e}") - + finally: - + try: writer.write("Warning: Server disconnected.".encode()) await writer.drain() - except Exception: + except Exception: #pylint:disable=broad-exception-caught pass - + async with self.clients_lock: self.clients.discard(writer) writer.close() await writer.wait_closed() async with self.tasks_lock: - self.active_tasks.pop(task, None) + self.active_tasks.pop(task, None) # type: ignore self.logger.info(f"{addr} connection closed.") + #pylint:disable=duplicate-code def shutdown(self): + """ + Cancel all tasks for this loop + """ self.logger.info("Server is shutting down...") for task in asyncio.all_tasks(self.loop): task.cancel() def set_termination_signals(self): + """ + Upon interrupt signal or system/process-based termination + do the shutdown method. + Requires not Windows to have any effect + """ # SIGINT = interupt signal for Ctrl+C | value = 2 # SIGTERM = system/process-based termination | value = 15 if platform.system() != "Windows": for sig in (signal.SIGINT, signal.SIGTERM): + #pylint:disable=unnecessary-lambda self.loop.add_signal_handler(sig, lambda: self.shutdown()) async def run_server(self, host: str = '127.0.0.1', port: int = 8888): + """ + TODO: doc run_server + """ server = await asyncio.start_server(self.handle_client, host=host, port=port) self.logger.info(f"Python server listening on {host}:{port}.") async with server: await server.serve_forever() async def wait_for_tasks_to_finish(self): - # Wait for all client handlers to finish + """ + Wait for all client handlers to finish + """ async with self.tasks_lock: tasks = list(self.active_tasks) if tasks: @@ -163,15 +193,21 @@ def run( config_path: Optional[str] = None, config_dict: Optional[dict[str, Any]] = None, ): - + """ + TODO: doc run + """ + if config_dict is None: # Load config parameters - server_config = load_config(config_path=config_path, debug=True) + server_config = load_config(config_path=config_path, + debug=True) elif isinstance(config_dict, dict): # Shallow copy to avoid external mutation - server_config = dict(config_dict) + server_config = dict(config_dict) else: - raise TypeError(f"SummonerServer.run: config_dict must be a dict or None, got {type(config_dict).__name__}") + #pylint:disable=line-too-long + raise TypeError( + f"SummonerServer.run: config_dict must be a dict or None, got {type(config_dict).__name__}") if platform.system() != "Windows": option = server_config.get("version", None) @@ -195,6 +231,7 @@ def _payload(h, p): elif isinstance(option, str) and option.startswith("rust"): available = ", ".join(["rust"] + list(RUST_MODULES.keys())) if RUST_LATEST else \ ", ".join(RUST_MODULES.keys()) + #pylint:disable=line-too-long raise RuntimeError( f"Rust backend '{option}' requested but not available. Installed options: {available or '(none)'}" ) @@ -202,13 +239,14 @@ def _payload(h, p): if mod is not None: try: requested = option if option is not None else "(unset)" + #pylint:disable=line-too-long print(f"[DEBUG] Config requested version '{requested}' -> resolved to '{mod.__name__}'") mod.start_tokio_server(self.name, _payload(host, port)) except KeyboardInterrupt: pass return - + try: logger_cfg = server_config.get("logger", {}) configure_logger(self.logger, logger_cfg) @@ -218,7 +256,7 @@ def _payload(h, p): except (asyncio.CancelledError, KeyboardInterrupt): pass - + finally: self.loop.run_until_complete(self.wait_for_tasks_to_finish()) From ff2805a883e22b6730db940176a32b8aef7d0601 Mon Sep 17 00:00:00 2001 From: Cobord Date: Fri, 20 Feb 2026 20:44:29 -0500 Subject: [PATCH 08/11] types for the kinds of functions that are decorated, moving the functionality of _dna... into own file, so client.py can be simplified by continuing this process, removed all the type ignores for their specific reasons with explanations of why they do not matter --- simulations/simul_flow_2.py | 12 +- summoner/client/client.py | 275 +++++++++++++++++------------- summoner/client/client_types.py | 61 +++++++ summoner/client/dna.py | 167 ++++++++++++++++++ summoner/client/just_merger.py | 6 +- summoner/client/translation.py | 6 +- summoner/protocol/_deprecation.py | 2 +- summoner/protocol/flow.py | 2 +- summoner/protocol/payload.py | 9 +- summoner/protocol/process.py | 28 +-- summoner/server/server.py | 8 +- tests/test_triggers.py | 2 +- 12 files changed, 419 insertions(+), 159 deletions(-) create mode 100644 summoner/client/client_types.py create mode 100644 summoner/client/dna.py diff --git a/simulations/simul_flow_2.py b/simulations/simul_flow_2.py index 5da4bb3..78786cb 100644 --- a/simulations/simul_flow_2.py +++ b/simulations/simul_flow_2.py @@ -1,5 +1,5 @@ from summoner.protocol.flow import Flow -from typing import Type, Optional, List, Dict, Union, Callable, Any, Awaitable +from typing import Coroutine, Tuple, Optional, Callable, Any from summoner.protocol.triggers import ( Signal, Action, @@ -123,7 +123,7 @@ def upload_states(flow: Flow, routes: list[str]) -> dict[str, list[Node]]: tape = StateTape(raw_states) activation_index: dict[tuple[int, ...], list[TapeActivation]] = tape.collect_activations(receiver_index=receiver_index, parsed_routes=receiver_parsed_routes) -batches: dict[tuple[int, ...], list[Callable[[Any],Awaitable]]] = {priority: [activation.fn for activation in activations] for priority, activations in activation_index.items()} +batches: dict[tuple[int, ...], list[Callable[[Any],Coroutine[Any,Any,Any]]]] = {priority: [activation.fn for activation in activations] for priority, activations in activation_index.items()} print("\n\nactivation_index") pprint.pprint(activation_index) @@ -145,7 +145,7 @@ async def _safe_call(fn, payload): payload = {"content": "hello"} async def receiver_test(): - event_buffer: dict[tuple[int, ...], list[tuple[str, ParsedRoute, Event]]] = defaultdict(list) + event_buffer: dict[tuple[int, ...], list[tuple[Optional[str], ParsedRoute, Optional[Event]]]] = defaultdict(list) for priority, batch_fns in sorted(batches.items(), key=lambda kv: kv[0]): label = "default priority" if priority == () else f"priority {priority}" @@ -158,7 +158,7 @@ async def receiver_test(): activations = activation_index[priority] local_tape = tape.refresh() - to_extend: dict[str, list[Node]] = defaultdict(list) + to_extend: dict[Optional[str], list[Node]] = defaultdict(list) for act, event in zip(activations, events): to_extend[act.key].extend(act.route.activated_nodes(event)) to_extend = dict(to_extend) @@ -323,7 +323,7 @@ async def _start_send_workers(num_workers, send_queue): # --- start workers and test runner --- async def sender_test(): - send_queue: asyncio.Queue[Optional[Sender]] = asyncio.Queue(maxsize=50) + send_queue: asyncio.Queue[Optional[Tuple[str,Sender]]] = asyncio.Queue(maxsize=50) num_workers = 50 await _start_send_workers(num_workers, send_queue) @@ -341,5 +341,3 @@ async def sender_test(): # run the simulation asyncio.run(sender_test()) - - diff --git a/summoner/client/client.py b/summoner/client/client.py index bc97466..b4b12aa 100644 --- a/summoner/client/client.py +++ b/summoner/client/client.py @@ -1,21 +1,24 @@ """ SummonerClient """ -#pylint:disable=line-too-long, wrong-import-position +#pylint:disable=line-too-long, wrong-import-position, too-many-lines #pylint:disable=logging-fstring-interpolation, broad-exception-caught import os import sys import json from typing import ( + Awaitable, Dict, Generator, + List, Optional, Callable, + Set, + Tuple, Union, - Awaitable, + Coroutine, Any, - Type, cast, ) import asyncio @@ -28,6 +31,22 @@ if target_path not in sys.path: sys.path.insert(0, target_path) +from summoner.client.client_types import ( + DOWNLOAD_TYPE, + HOOK_TYPE, + RECEIVE_DECORATED_TYPE, + SEND_DECORATED_TYPE, + SENDING_HOOKS_TYPE, + RECEIVING_HOOKS_TYPE, + UPLOAD_TYPE, +) +from summoner.client.dna import ( + DNA_DOWNLOAD, + DNA_UPLOAD, + DNAHook, + DNAReceiver, + DNASender +) from summoner.utils import ( load_config, is_jsonable, @@ -66,12 +85,6 @@ RelayedMessage ) -#pylint:disable=invalid-name -ANY_TO_AWAIT = Callable[[Any],Awaitable] -HOOK_TYPE = Callable[[Union[str, dict]], Union[str, dict]] -GEN_HOOK_TYPE = Callable[[Optional[Union[str, dict]]], Optional[Union[str, dict]]] -ASYNC_GEN_HOOK_TYPE = Callable[[Optional[Union[str, dict]]], Awaitable[Optional[Union[str, dict]]]] - class ServerDisconnected(Exception): """Raised when the server closes the connection.""" @@ -109,7 +122,7 @@ def __init__(self, name: Optional[str] = None): asyncio.set_event_loop(self.loop) # Protect concurrent access to the set of active tasks - self.active_tasks: set[asyncio.Task] = set() + self.active_tasks: set[asyncio.Task[Any]] = set() self.tasks_lock = asyncio.Lock() # Protect route registration and access for receive/send functions @@ -125,7 +138,7 @@ def __init__(self, name: Optional[str] = None): self.connection_lock = asyncio.Lock() # Safe registration of decorators (hooks, receivers, senders) - self._registration_tasks: list[asyncio.Task] = [] + self._registration_tasks: list[asyncio.Task[Any]] = [] # One-time indexing of parsed routes self.receiver_parsed_routes: dict[str, ParsedRoute] = {} @@ -135,14 +148,14 @@ def __init__(self, name: Optional[str] = None): self._flow = Flow() # Functions to read and write the flow's active states in memory - self._upload_states: Optional[ANY_TO_AWAIT] = None - self._download_states: Optional[ANY_TO_AWAIT] = None + self._upload_states: Optional[UPLOAD_TYPE] = None + self._download_states: Optional[DOWNLOAD_TYPE] = None # Sender HyperParameters self.event_bridge_maxsize : Optional[int] = None self.max_concurrent_workers : Optional[int] = None # Limit the sending rate (will use 50 if None is given) self.send_queue_maxsize : Optional[int] = None - self.batch_drain = None + self.batch_drain: Optional[bool] = None # self.max_consecutive_worker_errors is unbound until _apply_config # Receiver HyperParameters @@ -157,26 +170,26 @@ def __init__(self, name: Optional[str] = None): # self.default_retry_limit is unbound until _apply_config # Pass Event information from the receiving end to the sending end - self.event_bridge: Optional[asyncio.Queue[tuple[tuple[int, ...], Optional[str], ParsedRoute, Event]]] = None + self.event_bridge: Optional[asyncio.Queue[tuple[tuple[int, ...], Optional[str], ParsedRoute, Optional[Event]]]] = None - self.send_queue: Optional[asyncio.Queue] = None + self.send_queue: Optional[asyncio.Queue[Optional[Tuple[str,Sender]]]] = None self.send_workers_started = False # To avoid double-starting workers self.worker_tasks: list[asyncio.Task] = [] self.writer_lock = asyncio.Lock() # Store validation hooks to be used before sending and after receiving - self.sending_hooks: dict[tuple[int,...], HOOK_TYPE | GEN_HOOK_TYPE | ASYNC_GEN_HOOK_TYPE] = {} - self.receiving_hooks: dict[tuple[int,...], HOOK_TYPE | GEN_HOOK_TYPE | ASYNC_GEN_HOOK_TYPE] = {} + self.sending_hooks: dict[tuple[int,...], SENDING_HOOKS_TYPE] = {} + self.receiving_hooks: dict[tuple[int,...], RECEIVING_HOOKS_TYPE] = {} self.hooks_lock = asyncio.Lock() # ─── DNA capture for merging ───────────────────────────────────────── # lists of dicts, each entry records one decorated handler - self._dna_receivers: list[dict] = [] - self._dna_senders: list[dict] = [] - self._dna_hooks: list[dict] = [] + self._dna_receivers: list[DNAReceiver] = [] + self._dna_senders: list[DNASender] = [] + self._dna_hooks: list[DNAHook] = [] - self._dna_upload_states: Optional[dict] = None - self._dna_download_states: Optional[dict] = None + self._dna_upload_states: Optional[DNA_UPLOAD] = None + self._dna_download_states: Optional[DNA_DOWNLOAD] = None # ==== VERSION SPECIFIC ==== @@ -265,7 +278,7 @@ def upload_states(self): Decorator to supply a function that returns the current state snapshot. Must be used before client.run(). """ - def decorator(fn: Callable[[], Awaitable]): + def decorator(fn: UPLOAD_TYPE): # ----[ Safety Checks ]---- @@ -275,8 +288,8 @@ def decorator(fn: Callable[[], Awaitable]): _check_param_and_return( fn, decorator_name="@upload_states", - allow_param=(type(None), str, dict, Any), # the payload # type: ignore - allow_return=(type(None), str, Any, Node, list, dict, # type: ignore + allow_param=(type(None), str, dict, Any), # the payload # pyright: ignore[reportArgumentType] + allow_return=(type(None), str, Any, Node, list, dict, # pyright: ignore[reportArgumentType] list[str], dict[str, str], dict[str, list[str]], list[Node], dict[str, Node], dict[str, list[Node]], dict[str, Union[str, list[str]]], @@ -298,7 +311,7 @@ def decorator(fn: Callable[[], Awaitable]): "source": inspect.getsource(fn), } - self._upload_states = fn # type: ignore + self._upload_states = fn return fn @@ -309,7 +322,7 @@ def download_states(self): Decorator to supply a function that receives a StateTape. Must be used before client.run(). """ - def decorator(fn: ANY_TO_AWAIT): + def decorator(fn: DOWNLOAD_TYPE): # ----[ Safety Checks ]---- @@ -319,7 +332,7 @@ def decorator(fn: ANY_TO_AWAIT): _check_param_and_return( fn, decorator_name="@download_states", - allow_param=(type(None), Node, Any, list, dict, # type: ignore + allow_param=(type(None), Node, Any, list, dict, # pyright: ignore[reportArgumentType] list[Node], dict[str, Node], dict[str, list[Node]], @@ -328,7 +341,7 @@ def decorator(fn: ANY_TO_AWAIT): dict[Optional[str], list[Node]], dict[Optional[str], Union[Node, list[Node]]], ), - allow_return=(type(None), Any), # type: ignore + allow_return=(type(None), Any), # pyright: ignore[reportArgumentType] logger=self.logger, ) @@ -353,7 +366,7 @@ def decorator(fn: ANY_TO_AWAIT): # ==== REGISTRATION HELPER ==== - def _schedule_registration(self, register_coro: Awaitable): + def _schedule_registration(self, register_coro: Coroutine[Any,Any,None]): """ Schedule `register_coro` onto self.loop. If the loop isn't running yet, create the task immediately. @@ -361,11 +374,11 @@ def _schedule_registration(self, register_coro: Awaitable): """ if self.loop.is_running(): def _cb(): - task = self.loop.create_task(register_coro) # type: ignore + task = self.loop.create_task(register_coro) self._registration_tasks.append(task) self.loop.call_soon_threadsafe(_cb) else: - task = self.loop.create_task(register_coro) # type: ignore + task = self.loop.create_task(register_coro) self._registration_tasks.append(task) # ==== HOOK REGISTRATION ==== @@ -378,7 +391,7 @@ def hook( """ TODO: doc hook """ - def decorator(fn: GEN_HOOK_TYPE): + def decorator(fn : HOOK_TYPE): """ TODO: doc decorator """ @@ -388,12 +401,12 @@ def decorator(fn: GEN_HOOK_TYPE): if not inspect.iscoroutinefunction(fn): raise TypeError(f"@hook handler '{fn.__name__}' must be async") - + _check_param_and_return( fn, decorator_name="@hook", - allow_param=(Any, str, dict), # type: ignore - allow_return=(type(None), str, dict, Any), # type: ignore + allow_param=(Any, str, dict), # pyright: ignore[reportArgumentType] + allow_return=(type(None), str, dict, Any), # pyright: ignore[reportArgumentType] logger=self.logger, ) @@ -408,12 +421,12 @@ def decorator(fn: GEN_HOOK_TYPE): raise ValueError(f"Priority must be an integer or a tuple of integers (got type {type(priority).__name__}: {priority!r})") # ----[ DNA capture ]---- - self._dna_hooks.append({ - "fn": fn, - "direction": direction, - "priority": tuple_priority, - "source": inspect.getsource(fn), - }) + self._dna_hooks.append(DNAHook( + fn = fn, + direction = direction, + priority = tuple_priority, + source = inspect.getsource(fn), + )) async def register(): """ @@ -422,9 +435,9 @@ async def register(): """ async with self.hooks_lock: if direction == Direction.RECEIVE: - self.receiving_hooks[tuple_priority] = fn + self.receiving_hooks[tuple_priority] = cast(RECEIVING_HOOKS_TYPE, fn) elif direction == Direction.SEND: - self.sending_hooks[tuple_priority] = fn + self.sending_hooks[tuple_priority] = cast(SENDING_HOOKS_TYPE, fn) # ----[ Safe Registration ]---- # NOTE: register() is run ASAP and _registration_tasks is used to wait all registrations before run_client() @@ -445,7 +458,7 @@ def receive( TODO: doc receive """ route = route.strip() - def decorator(fn: Callable[[Union[str, dict]], Awaitable[Optional[Event]]]): + def decorator(fn: RECEIVE_DECORATED_TYPE): # ----[ Safety Checks ]---- @@ -459,8 +472,8 @@ def decorator(fn: Callable[[Union[str, dict]], Awaitable[Optional[Event]]]): _check_param_and_return( fn, decorator_name="@receive", - allow_param=(Any, str, dict), # type: ignore - allow_return=(type(None), Event, Any), # type: ignore + allow_param=(Any, str, dict), # pyright: ignore[reportArgumentType] + allow_return=(type(None), Event, Any), # pyright: ignore[reportArgumentType] logger=self.logger, ) @@ -475,12 +488,12 @@ def decorator(fn: Callable[[Union[str, dict]], Awaitable[Optional[Event]]]): raise ValueError(f"Priority must be an integer or a tuple of integers (got type {type(priority).__name__}: {priority!r})") # ----[ DNA capture ]---- - self._dna_receivers.append({ - "fn": fn, - "route": route, - "priority": tuple_priority, - "source": inspect.getsource(fn), # for text serialization - }) + self._dna_receivers.append(DNAReceiver( + fn = fn, + route = route, + priority = tuple_priority, + source = inspect.getsource(fn), # for text serialization + )) # ----[ Registration Code ]---- async def register(): @@ -522,14 +535,14 @@ def send( self, route: str, multi: bool = False, - on_triggers: Optional[set[Signal]] = None, - on_actions: Optional[set[Type]] = None, + on_triggers: Optional[set[Any]] = None, + on_actions: Optional[set[Any]] = None, ): """ TODO: doc send """ route = route.strip() - def decorator(fn: Callable[[], Awaitable]): + def decorator(fn: SEND_DECORATED_TYPE): # ----[ Safety Checks ]---- if not inspect.iscoroutinefunction(fn): @@ -544,7 +557,7 @@ def decorator(fn: Callable[[], Awaitable]): fn, decorator_name="@send", allow_param=(), # no args allowed - allow_return=(type(None), Any, str, dict), # type: ignore + allow_return=(type(None), Any, str, dict), # pyright: ignore[reportArgumentType] logger=self.logger, ) else: @@ -552,7 +565,7 @@ def decorator(fn: Callable[[], Awaitable]): fn, decorator_name="@send[multi=True]", allow_param=(), # no args allowed - allow_return=(Any, list, list[str], list[dict], list[Union[str, dict]]), # type: ignore + allow_return=(Any, list, list[str], list[dict], list[Union[str, dict]]), # pyright: ignore[reportArgumentType] logger=self.logger, ) @@ -575,14 +588,14 @@ def decorator(fn: Callable[[], Awaitable]): raise TypeError(f"Argument `on_actions` must be `None` or a set of Action event classes: {{Action.MOVE, Action.STAY, Action.TEST}}. Provided: {on_actions!r}") # ----[ DNA capture ]---- - self._dna_senders.append({ - "fn": fn, - "route": route, - "multi": multi, - "on_triggers": on_triggers, - "on_actions": on_actions, - "source": inspect.getsource(fn), - }) + self._dna_senders.append(DNASender( + fn = fn, + route = route, + multi = multi, + on_triggers = on_triggers, + on_actions = on_actions, + source = inspect.getsource(fn), + )) # ----[ Registration Code ]---- async def register(): @@ -752,16 +765,19 @@ def dna(self, include_context: bool = False) -> str: - Flow construction and trigger enum creation are not captured unless they are referenced and executed within decorated handler sources. """ + # TODO: use *_entry_contribution methods to build entries and keep this method cleaner + #pylint:disable=import-outside-toplevel import builtins - - # Collect handler functions once so we can reuse them for context analysis. + + # TODO: Only used below, so could have not turned into a list right away + # Keep only as a generator being iterated over handler_fns = list(self._iter_registered_handler_functions()) # ---------------------------- # Handler DNA entries # ---------------------------- - entries: list[dict] = [] + entries: List[Dict[str, Any]] = [] # Upload state hook, if present if self._dna_upload_states is not None: @@ -829,7 +845,7 @@ def dna(self, include_context: bool = False) -> str: # Serialize triggers/actions by name so they can be re-resolved later. "on_triggers": [t.name for t in (dna["on_triggers"] or [])], "on_actions": [a.__name__ for a in (dna["on_actions"] or [])], - "source": get_callable_source(fn, dna.get("source")), + "source": get_callable_source(fn, dna["source"]), "module": fn.__module__, "fn_name": fn.__name__, }) @@ -878,6 +894,7 @@ def dna(self, include_context: bool = False) -> str: # Used by resolve_import_statement to avoid emitting repeated imports. known_modules: set[str] = set() + # Collect handler functions once so we can reuse them for context analysis. for fn in handler_fns: g = getattr(fn, "__globals__", None) @@ -885,7 +902,10 @@ def dna(self, include_context: bool = False) -> str: continue # Names referenced by the function body. - names_to_scan = set(getattr(fn, "__code__", None).co_names if hasattr(fn, "__code__") else ()) # type: ignore + if hasattr(fn, "__code__"): + names_to_scan = set(fn.__code__.co_names) + else: + names_to_scan: Set[str] = set() # Names referenced only via annotations. try: @@ -894,7 +914,7 @@ def dna(self, include_context: bool = False) -> str: if isinstance(v, type) or inspect.isfunction(v) or inspect.ismodule(v): nm = getattr(v, "__name__", None) if isinstance(nm, str) and nm: - names_to_scan.add(nm) # type: ignore + names_to_scan.add(nm) except Exception: pass @@ -958,7 +978,7 @@ def dna(self, include_context: bool = False) -> str: if path_needed: imports_out.add("from pathlib import Path") - context_entry = { + context_entry : Dict[str, Any] = { "type": "__context__", "var_name": inferred_var_name, "imports": sorted(imports_out), @@ -1023,7 +1043,7 @@ async def message_receiver_loop( """ # ----[ Wrapper: Interpret Protocol-Only Errors as None ]---- - async def _safe_call(fn: ANY_TO_AWAIT, payload: Any) -> Any: + async def _safe_call(fn: Callable[[Any],Awaitable[Any]], payload: Any) -> Optional[Any]: try: return await fn(payload) except BlockingIOError: @@ -1056,7 +1076,7 @@ async def _safe_call(fn: ANY_TO_AWAIT, payload: Any) -> Any: if not receiver_index: data = await self._read_line_safe( reader, - limit=self.max_bytes_per_line, # type: ignore + limit=self.max_bytes_per_line, # max_bytes_per_line is set to an integer by now # pyright: ignore[reportArgumentType] timeout=0.1, ) # if not data: @@ -1070,7 +1090,7 @@ async def _safe_call(fn: ANY_TO_AWAIT, payload: Any) -> Any: data = await self._read_line_safe( reader, - limit=self.max_bytes_per_line, # type: ignore + limit=self.max_bytes_per_line, # max_bytes_per_line is set to an integer by now # pyright: ignore[reportArgumentType] timeout=self.read_timeout_seconds, ) # data = await reader.readline() @@ -1078,7 +1098,7 @@ async def _safe_call(fn: ANY_TO_AWAIT, payload: Any) -> Any: # raise ServerDisconnected("Server closed the connection.") pre_payload: RelayedMessage = recover_with_types(data.decode()) - payload: str | dict | RelayedMessage = pre_payload + payload = cast(Union[RelayedMessage, None], pre_payload) # ----[ Build: Validation ]---- async with self.hooks_lock: @@ -1086,11 +1106,12 @@ async def _safe_call(fn: ANY_TO_AWAIT, payload: Any) -> Any: for priority, receiving_hook in sorted(receiving_hooks.items(), key=lambda kv: hook_priority_order(kv[0])): try: - new_payload = await receiving_hook(payload) # type: ignore + new_payload = await receiving_hook(payload) if new_payload is None: - payload = None # type: ignore + payload = None break + payload = new_payload except Exception as e: self.logger.error( @@ -1098,19 +1119,18 @@ async def _safe_call(fn: ANY_TO_AWAIT, payload: Any) -> Any: f"failed on payload {payload!r}: {e}", exc_info=True ) - new_payload = payload - payload = new_payload # if *any* hook returned None, skip the rest of processing if payload is None: continue # ----[ Build: Organize Batches by Priority ]---- - batches: dict[tuple[int, ...], list[ANY_TO_AWAIT]] = {} + batches: dict[tuple[int, ...], list[Callable[[Any], Coroutine[Any,Any,Any]]]] = {} if self._flow.in_use: raw_states = (await self._upload_states(payload)) if self._upload_states is not None else None tape = StateTape(raw_states) - activation_index = tape.collect_activations(receiver_index=receiver_index, parsed_routes=receiver_parsed_routes) # type: ignore + activation_index = tape.collect_activations( + receiver_index=receiver_index, parsed_routes=receiver_parsed_routes) # pyright: ignore[reportPossiblyUnboundVariable] batches = {priority: [activation.fn for activation in activations] for priority, activations in activation_index.items()} else: for _, receiver in receiver_index.items(): @@ -1125,7 +1145,7 @@ async def _safe_call(fn: ANY_TO_AWAIT, payload: Any) -> Any: # ----[ Exec: Prepare Passage Receiver → Sender ]---- if self._flow.in_use: - event_buffer: dict[tuple[int, ...], list[tuple[Optional[str], ParsedRoute, Event]]] = defaultdict(list) + event_buffer: dict[tuple[int, ...], list[tuple[Optional[str], ParsedRoute, Optional[Event]]]] = defaultdict(list) # ----[ Exec: Run Batches in Order ]---- for priority, batch_fns in sorted(batches.items(), key=lambda kv: kv[0]): @@ -1139,24 +1159,18 @@ async def _safe_call(fn: ANY_TO_AWAIT, payload: Any) -> Any: # ----[ After: Handle Returns ]---- if self._flow.in_use: - activations = activation_index[priority] # type: ignore + activations = activation_index[priority] # pyright: ignore[reportPossiblyUnboundVariable] - local_tape = tape.refresh() # type: ignore - to_extend: dict[str, list[Node]] = defaultdict(list) + local_tape = tape.refresh() # pyright: ignore[reportPossiblyUnboundVariable] + to_extend: dict[Optional[str], list[Node]] = defaultdict(list) for act, event in zip(activations, events): - if act.key is None: - # Repeated the code in both branches even though this branch - # even though in this branch the key is not what it was supposed to be - # It was stated to have string keys when initializing as the defaultdict above - to_extend[act.key].extend(act.route.activated_nodes(event)) # type: ignore - else: - to_extend[act.key].extend(act.route.activated_nodes(event)) + to_extend[act.key].extend(act.route.activated_nodes(event)) local_tape.extend(to_extend) buffer_entries = [(act.key, act.route, event) for act, event in zip(activations, events)] # event buffer is bound because this is within if self._flow_in_use # some of the event's in buffer_entries could have been None - event_buffer[priority].extend(buffer_entries) # type: ignore + event_buffer[priority].extend(buffer_entries) # pyright: ignore[reportPossiblyUnboundVariable] if self._download_states is not None: await self._download_states(local_tape.revert()) @@ -1164,10 +1178,11 @@ async def _safe_call(fn: ANY_TO_AWAIT, payload: Any) -> Any: # ----[ Final: Pass Data Over To Senders ]---- if self._flow.in_use: # event buffer is bound because this is within if self._flow_in_use - for priority, event_list in sorted(event_buffer.items(), key=lambda kv: kv[0]): # type: ignore + for priority, event_list in sorted(event_buffer.items(), key=lambda kv: kv[0]): # pyright: ignore[reportPossiblyUnboundVariable] for event_data in event_list: # this will block if the bridge is full, slowing down readers - await self.event_bridge.put((priority,) + event_data) # type: ignore + to_put = (priority,) + event_data + await self.event_bridge.put(to_put) # pyright: ignore[reportOptionalMemberAccess] event_buffer = {} @@ -1208,9 +1223,13 @@ async def _send_worker( while True: - item: Optional[tuple[str, Sender]] = await self.send_queue.get() # type: ignore + if self.send_queue is None: + # pylint:disable=line-too-long + raise ValueError("send_queue is not initialized; this should not happen because this is protected method and called within handle_session") + + item: Optional[tuple[str, Sender]] = await self.send_queue.get() if item is None: - self.send_queue.task_done() # type: ignore + self.send_queue.task_done() break route, sender = item @@ -1239,11 +1258,12 @@ async def _send_worker( for priority, sending_hook in sorted(sending_hooks.items(), key=lambda kv: hook_priority_order(kv[0])): try: - new_payload = await sending_hook(payload) # type: ignore + new_payload = await sending_hook(payload) if new_payload is None: payload = None break + payload = new_payload except Exception as e: self.logger.error( @@ -1251,8 +1271,6 @@ async def _send_worker( f"failed on payload {payload!r}: {e}", exc_info=True ) - new_payload = payload - payload = new_payload # if *any* hook returned None, skip the rest of processing if payload is None: @@ -1290,7 +1308,7 @@ async def _send_worker( break finally: - self.send_queue.task_done() # type: ignore + self.send_queue.task_done() async def _cleanup_workers(self): """ @@ -1346,10 +1364,11 @@ def _route_accepts( async with self.routes_lock: sender_parsed_routes: dict[str, ParsedRoute] = self.sender_parsed_routes.copy() - pending: list[tuple[tuple[int, ...], Optional[str], ParsedRoute, Event]] = [] + pending: list[tuple[tuple[int, ...], Optional[str], ParsedRoute, Optional[Event]]] = [] try: while True: - pending.append(self.event_bridge.get_nowait()) # type: ignore + # event_bridge exists, handle_session did so + pending.append(self.event_bridge.get_nowait()) # pyright: ignore[reportOptionalMemberAccess] except asyncio.QueueEmpty: pass @@ -1373,13 +1392,14 @@ def _route_accepts( (sender.triggers and isinstance(sender.triggers, set))): # self._flow_in_use so sender_parsed_routes is bound - sender_parsed_routes = cast(Dict[str,ParsedRoute],sender_parsed_routes) # type: ignore + sender_parsed_routes = cast(Dict[str,ParsedRoute],sender_parsed_routes) # pyright: ignore[reportPossiblyUnboundVariable] sender_parsed_route = sender_parsed_routes.get(route) if sender_parsed_route is None: continue # Iterate pending in queue order; first match "wins" for this (route,key,fn_name) - for (_priority, key, parsed_route, event) in pending: # type: ignore + # _flow_in_use so pending is bound + for (_priority, key, parsed_route, event) in pending: # pyright: ignore[reportPossiblyUnboundVariable] if _route_accepts(sender_parsed_route, parsed_route) and sender.responds_to(event): dedup_key = (route, key, sender.fn.__name__) # key scopes to the activation thread/peer if dedup_key not in emitted: @@ -1392,21 +1412,25 @@ def _route_accepts( await asyncio.sleep(0.1) # Time continue else: - queue_size = self.send_queue.qsize() # type: ignore + # send_queue exists, handle_session did so + queue_size = self.send_queue.qsize() # pyright: ignore[reportOptionalMemberAccess] expected_queue_size = queue_size + len(senders) - if expected_queue_size > self.send_queue_maxsize * 0.8: # type: ignore # 80% full + # self.send_queue_maxsize is set to an integer by now + if expected_queue_size > self.send_queue_maxsize * 0.8: # pyright: ignore[reportOptionalOperand] # 80% full self.logger.warning(f"Queue is about to exceed 80% its capacity; Attempted load size: {expected_queue_size} out of {self.send_queue_maxsize}") # ----[ Enqueue Sender Batch | Senders Are Run in Background ]---- try: for sender in senders: - await self.send_queue.put(sender) # type: ignore # Will block if full (i.e., back-pressure) + # self.send_queue exists, handle_session did so + await self.send_queue.put(sender) # pyright: ignore[reportOptionalMemberAccess] # Will block if full (i.e., back-pressure) except asyncio.CancelledError: self.logger.info("Sender enqueue loop cancelled mid-batch.") raise # ----[ Wait for Sender Batch to Finish]---- - await self.send_queue.join() # type: ignore + # self.send_queue exists, handle_session did so + await self.send_queue.join() # pyright: ignore[reportOptionalMemberAccess] if self.batch_drain: async with self.writer_lock: @@ -1448,15 +1472,16 @@ async def handle_session(self, host: str = '127.0.0.1', port: int = 8888): Run listener and sender concurrently; whichever exits first (due to disconnect, /quit or /travel) triggers session termination. The remaining task is cancelled. """ - while True: + # Shared flag between the two tasks to signal coordinated session termination stop_event = asyncio.Event() # Always clean up old worker tasks and queues before starting new session; guarantees a fresh worker batch and prevents zombie tasks. await self._cleanup_workers() - self.send_queue = asyncio.Queue(maxsize=self.send_queue_maxsize) # type: ignore - self.event_bridge = asyncio.Queue(maxsize = self.event_bridge_maxsize) # type: ignore + # The maxsize's are set to integers by now by _apply_config, so these queues will not raise TypeError for maxsize=None + self.send_queue = asyncio.Queue(maxsize=self.send_queue_maxsize) # pyright: ignore[reportArgumentType] + self.event_bridge = asyncio.Queue(maxsize = self.event_bridge_maxsize) # pyright: ignore[reportArgumentType] # reset any previous travel/quit intent so each session starts fresh; # travel is only honored if set after this point, quit likewise @@ -1469,7 +1494,9 @@ async def handle_session(self, host: str = '127.0.0.1', port: int = 8888): # Register this session's task so it can be cancelled during shutdown current_task = asyncio.current_task() async with self.tasks_lock: - self.active_tasks.add(current_task) # type: ignore + # current_task assumed not None + # but we are within the try block + self.active_tasks.add(current_task) # pyright: ignore[reportArgumentType] # Use lock when accessing dynamic routing information async with self.connection_lock: @@ -1531,8 +1558,12 @@ async def handle_session(self, host: str = '127.0.0.1', port: int = 8888): # Deregister this session and its children from active tasks async with self.tasks_lock: - if task is not None: # type: ignore - self.active_tasks.discard(task) # type: ignore + try: + if task is not None: # pyright: ignore[reportPossiblyUnboundVariable] + self.active_tasks.discard(task) # pyright: ignore[reportPossiblyUnboundVariable] + except NameError: + pass + # Check whether we should quit or loop back to travel to the next server (agent migration) async with self.connection_lock: @@ -1723,14 +1754,14 @@ def run( self.loop.close() self.logger.info("Client exited cleanly.") - def _view_candidates(self) -> Generator[Optional[Callable[[Any], Awaitable]],None,None]: + def _view_candidates(self) -> Generator[Optional[Callable[..., Coroutine[Any,Any,Any] | Awaitable[Any]]],None,None]: if self._upload_states is not None: yield self._upload_states if self._download_states is not None: yield self._download_states for d in self._dna_receivers: - yield d.get("fn") # type: ignore + yield d.get("fn") for d in self._dna_senders: - yield d.get("fn") # type: ignore + yield d.get("fn") for d in self._dna_hooks: - yield d.get("fn") # type: ignore + yield d.get("fn") diff --git a/summoner/client/client_types.py b/summoner/client/client_types.py new file mode 100644 index 0000000..85278bc --- /dev/null +++ b/summoner/client/client_types.py @@ -0,0 +1,61 @@ +""" +Types used for client and client DNA +""" +#pylint:disable=wrong-import-position +from typing import ( +Dict, +List, +Optional, +Callable, +TypedDict, +Union, +Any, +Coroutine, +) +import os +import sys + +target_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) +if target_path not in sys.path: + sys.path.insert(0, target_path) +from summoner.protocol.process import Node +from summoner.protocol.triggers import Event +from summoner.protocol.payload import RelayedMessage + +""" +See the _check_param_and_return_types function in client.py +inside the decorator internal functions for where these types are actually enforced. +This tells us what the expected signatures of client hooks, upload functions, and decorated send/receive handlers are. +That way we can put those expected signatures in one place here. Though +the actual enforcement of these types is in client.py, not here. +""" + +HOOK_ARGUMENT = Any | str | Dict[Any,Any] +HOOK_RETURN = None | str | Dict[Any,Any] | Any +HOOK_TYPE = Callable[[HOOK_ARGUMENT], Coroutine[Any,Any,HOOK_RETURN]] + +UPLOAD_RETURN = Union[None | str | Any | Node | List[Any] | Dict[Any,Any]\ + | List[str] | Dict[str, str] | Dict[str, List[str]]\ + | List[Node] | Dict[str, Node] | Dict[str, List[Node]]\ + | Dict[str, Union[str, list[str]]]\ + | Dict[str, Union[Node, list[Node]]]\ + | Dict[str, Union[str, list[str], Node, list[Node]]] +] +UPLOAD_ARGUMENT = None | Any | str | Dict[Any,Any] | TypedDict +UPLOAD_TYPE = Callable[[UPLOAD_ARGUMENT], Coroutine[Any,Any,UPLOAD_RETURN]] + +DOWNLOAD_TYPE = Callable[[Any], Coroutine[Any,Any,Optional[Any]]] + +SENDING_HOOKS_TYPE = Callable[ + [Optional[Union[str, dict]]], + Coroutine[Any,Any,Optional[Union[str, dict]]]] +RECEIVING_HOOKS_TYPE = Callable[ + [Optional[Union[str, dict, RelayedMessage]]], + Coroutine[Any,Any,Optional[Union[str, dict]]]] + +SEND_RETURN_SINGLE_TYPE = None | Any | str | Dict[Any,Any] +SEND_RETURN_MULTI_TYPE = Any | List[Any] | List[str] | List[Dict[Any,Any]] | List[Union[str, Dict[Any,Any]]] +SEND_DECORATED_TYPE = Callable[[], Coroutine[Any, Any, SEND_RETURN_SINGLE_TYPE]] |\ + Callable[[], Coroutine[Any, Any, SEND_RETURN_MULTI_TYPE]] + +RECEIVE_DECORATED_TYPE = Callable[[Any | str | Dict[Any,Any]], Coroutine[Any, Any, Any | Event | None]] diff --git a/summoner/client/dna.py b/summoner/client/dna.py new file mode 100644 index 0000000..eed24e6 --- /dev/null +++ b/summoner/client/dna.py @@ -0,0 +1,167 @@ +""" +A client's registered behavior can be serialized into a JSON string ("DNA"). +The types and functions that support this are defined here. + +Purpose +------- +DNA is intended to support: + 1) cloning: rehydrate a client from data by re-evaluating handler sources + 2) merging: combine multiple clients into a composite client (ClientMerger) + +Portability rule +---------------- +DNA is meant to be replayable across environments. Unstable runtime bindings +(for example '__main__' imports or live objects that cannot be rebuilt) should +end up in "missing", not embedded implicitly. +""" +#pylint:disable=wrong-import-position + +from typing import Any, Optional, Set, Type, TypedDict + +import os +import sys + +target_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) +if target_path not in sys.path: + sys.path.insert(0, target_path) + +from summoner.protocol.flow import Flow +from summoner.utils.code_handlers import get_callable_source +from summoner.protocol.triggers import Signal +from summoner.client.client_types import DOWNLOAD_TYPE, RECEIVE_DECORATED_TYPE, SEND_DECORATED_TYPE, HOOK_TYPE, UPLOAD_TYPE +from summoner.protocol.process import Direction + + +class DNAHook(TypedDict): + """ + The sort of information kept in a SummonerClient's _dna_hooks list + """ + fn: HOOK_TYPE + direction: Direction + priority: tuple[int, ...] + source: Optional[str] + +def hook_entry_contribution(hook_entry: DNAHook) -> dict[str, Any]: + """ + The contribution of this entry to the overall DNA dict. + """ + fn = hook_entry["fn"] + return { + "type": "hook", + "direction": hook_entry["direction"].name, + "priority": hook_entry["priority"], + "source": get_callable_source(fn, hook_entry.get("source")), + "module": fn.__module__, + "fn_name": fn.__name__, + } + +class DNAReceiver(TypedDict): + """ + The sort of information kept in a SummonerClient's _dna_receivers list + """ + fn: RECEIVE_DECORATED_TYPE + route: str + priority: tuple[int, ...] + source: Optional[str] + +def receiver_entry_contribution(receiver_entry: DNAReceiver, flow_in_use: Optional[Flow]) -> dict[str, Any]: + """ + The contribution of this entry to the overall DNA dict. + """ + fn = receiver_entry["fn"] + raw_route = receiver_entry["route"] + + try: + if flow_in_use is not None: + route_key = str(flow_in_use.parse_route(raw_route)) + else: + route_key = raw_route + except Exception: + route_key = raw_route + route_key = "".join(str(route_key).split()) + + return { + "type": "receive", + "route": raw_route, # original route string + "route_key": route_key, # stable route representative + "priority": receiver_entry["priority"], + "source": get_callable_source(fn, receiver_entry.get("source")), + "module": fn.__module__, + "fn_name": fn.__name__, + } + +class DNASender(TypedDict): + """ + The sort of information kept in a SummonerClient's _dna_senders list + """ + fn: SEND_DECORATED_TYPE + route: str + multi: bool + on_triggers: Optional[Set[Any] | Set[Signal]] + on_actions: Optional[set[Any] | Set[Type]] + source: Optional[str] + +def sender_entry_contribution(sender_entry: DNASender, flow_in_use: Optional[Flow]) -> dict[str, Any]: + """ + The contribution of this entry to the overall DNA dict. + """ + fn = sender_entry["fn"] + raw_route = sender_entry["route"] + + try: + if flow_in_use is not None: + route_key = str(flow_in_use.parse_route(raw_route)) + else: + route_key = raw_route + except Exception: + route_key = raw_route + route_key = "".join(str(route_key).split()) + + return { + "type": "send", + "route": raw_route, # original route string + "route_key": route_key, # stable route representative + "multi": sender_entry["multi"], + # Serialize triggers/actions by name so they can be re-resolved later. + "on_triggers": [t.name for t in (sender_entry["on_triggers"] or [])], + "on_actions": [a.__name__ for a in (sender_entry["on_actions"] or [])], + "source": get_callable_source(fn, sender_entry.get("source")), + "module": fn.__module__, + "fn_name": fn.__name__, + } + +class DNA_UPLOAD(TypedDict): + """ + The sort of information kept in a SummonerClient's _dna_upload_states entry + """ + fn: UPLOAD_TYPE + source: str + +def upload_entry_contribution(upload_entry: DNA_UPLOAD) -> dict[str, Any]: + """ + The contribution of this entry to the overall DNA dict. + """ + return { + "type": "upload_states", + "source": get_callable_source(upload_entry["fn"], upload_entry["source"]), + "module": upload_entry["fn"].__module__, + "fn_name": upload_entry["fn"].__name__, + } + +class DNA_DOWNLOAD(TypedDict): + """ + The sort of information kept in a SummonerClient's _dna_download_states entry + """ + fn: DOWNLOAD_TYPE + source: str + +def download_entry_contribution(download_entry: DNA_DOWNLOAD) -> dict[str, Any]: + """ + The contribution of this entry to the overall DNA dict. + """ + return { + "type": "download_states", + "source": get_callable_source(download_entry["fn"], download_entry["source"]), + "module": download_entry["fn"].__module__, + "fn_name": download_entry["fn"].__name__, + } diff --git a/summoner/client/just_merger.py b/summoner/client/just_merger.py index 690fc8b..f7f47e5 100644 --- a/summoner/client/just_merger.py +++ b/summoner/client/just_merger.py @@ -71,7 +71,7 @@ from summoner.protocol.process import Direction -def _resolve_trigger(TriggerCls, name: str): +def _resolve_trigger(TriggerCls, name: str) -> Any: """ Resolve a trigger name into a trigger instance from TriggerCls. @@ -603,7 +603,7 @@ def _clone_handler(self, fn: types.FunctionType, original_name: str) -> types.Fu # if your dna() uses a __dna_source__ fallback, keep it if hasattr(fn, "__dna_source__"): - new_fn.__dna_source__ = fn.__dna_source__ # type: ignore + new_fn.__dna_source__ = fn.__dna_source__ # pyright: ignore[reportFunctionMemberAccess] return new_fn @@ -827,7 +827,7 @@ def initiate_senders(self): dec = self.send( entry["route"], multi=entry.get("multi", False), - on_triggers=on_triggers, # type: ignore + on_triggers=on_triggers, on_actions=on_actions, ) self._apply_with_source_patch(dec, fn, entry["source"]) diff --git a/summoner/client/translation.py b/summoner/client/translation.py index 9c68951..3beebc2 100644 --- a/summoner/client/translation.py +++ b/summoner/client/translation.py @@ -61,11 +61,11 @@ import os import sys -from summoner.client.just_merger import _resolve_action, _resolve_trigger target_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) if target_path not in sys.path: sys.path.insert(0, target_path) +from summoner.client.just_merger import _resolve_action, _resolve_trigger from summoner.client.client import SummonerClient from summoner.protocol.triggers import Action, load_triggers from summoner.protocol.process import Direction @@ -263,7 +263,7 @@ def _cleanup_template_clients_from_modules(self): for module_name in modules: try: - module = sys.modules.get(module_name) or import_module(module_name) # type: ignore + module = sys.modules.get(module_name) or import_module(module_name) # pyright: ignore[reportArgumentType] except Exception: continue @@ -395,7 +395,7 @@ def initiate_senders(self): dec = self.send( entry["route"], multi=entry.get("multi", False), - on_triggers=on_triggers, # type: ignore + on_triggers=on_triggers, on_actions=on_actions, ) self._apply_with_source_patch(dec, fn, entry["source"]) diff --git a/summoner/protocol/_deprecation.py b/summoner/protocol/_deprecation.py index c0c76c0..c56f128 100644 --- a/summoner/protocol/_deprecation.py +++ b/summoner/protocol/_deprecation.py @@ -2,7 +2,7 @@ Warnings about using deprecated methods """ try: - from warnings import deprecated # Python 3.13+ + from warnings import deprecated # pyright: ignore[reportAttributeAccessIssue] # Python 3.13+ except ImportError: from typing_extensions import deprecated # Python <= 3.12 diff --git a/summoner/protocol/flow.py b/summoner/protocol/flow.py index b9e6ab7..44ae5a2 100644 --- a/summoner/protocol/flow.py +++ b/summoner/protocol/flow.py @@ -8,7 +8,7 @@ import warnings from .triggers import load_triggers from .process import Node, ArrowStyle, ParsedRoute -from ._deprecation import deprecated # type: ignore +from ._deprecation import deprecated # pyright: ignore[reportAttributeAccessIssue] # pylint:disable=line-too-long diff --git a/summoner/protocol/payload.py b/summoner/protocol/payload.py index c5c2aa6..4536a8e 100644 --- a/summoner/protocol/payload.py +++ b/summoner/protocol/payload.py @@ -2,10 +2,9 @@ TODO: doc payload """ #pylint:disable=line-too-long -from enum import Enum, auto import json from json import JSONDecodeError -from typing import Any, Literal, Tuple, Dict, List, Union, TypedDict +from typing import Any, Tuple, Dict, List, Union, TypedDict from summoner.utils import ( fully_recover_json, @@ -251,13 +250,13 @@ def recover_with_types(text: str) -> RelayedMessage: obj = fully_recover_json(text) except (ValueError, JSONDecodeError): # If that fails (e.g. pure warning string), strip newline and relay the raw text - return remove_last_newline(text) # type: ignore + return remove_last_newline(text) # pyright: ignore[reportReturnType] # 2) Ensure we have the outer {"remote_addr":…, "content":…} wrapper if not (isinstance(obj, dict) and "remote_addr" in obj and "content" in obj): # Malformed protocol message; upstream code can catch this if needed # raise ValueError("Unsupported message format from server.") - return obj # type: ignore + return obj # pyright: ignore[reportReturnType] addr = obj["remote_addr"] content = obj["content"] @@ -270,7 +269,7 @@ def recover_with_types(text: str) -> RelayedMessage: and "_payload" in content and "_type" in content ): - return obj # type: ignore + return obj # pyright: ignore[reportReturnType] # 4) We have the versioned envelope—now look up the correct caster version = content["_version"] diff --git a/summoner/protocol/process.py b/summoner/protocol/process.py index 2ca02fb..993b143 100644 --- a/summoner/protocol/process.py +++ b/summoner/protocol/process.py @@ -5,7 +5,8 @@ from __future__ import annotations import re from collections import defaultdict -from typing import Dict, List, Literal, Tuple, Type, Any, Optional, Union, Callable, Awaitable +from typing import Coroutine, Dict, List, Literal, Tuple, Type, \ + Any, Optional, Union, Callable, Awaitable from enum import Enum, auto from dataclasses import dataclass from .triggers import Signal, Event, Action, extract_signal @@ -106,18 +107,19 @@ def accepts(self, state: Node) -> bool: if not isinstance(state, Node): raise TypeError(f"Argument `state` must be Node; {state} provided") + # pylint:disable=line-too-long table : Dict[Tuple[KindType | object, KindType | object], Callable[[Node,Node],bool]] = { ('all', 'all'): lambda g, s: True, ('all', _WILDCARD): lambda g, s: True, (_WILDCARD, 'all'): lambda g, s: True, - ('plain', 'plain'): lambda g, s: g.values[0] == s.values[0], # type: ignore - ('plain', 'not'): lambda g, s: g.values[0] not in s.values, # type: ignore - ('plain', 'oneof'): lambda g, s: g.values[0] in s.values, # type: ignore - ('not', 'plain'): lambda g, s: s.values[0] not in g.values, # type: ignore + ('plain', 'plain'): lambda g, s: g.values[0] == s.values[0], # pyright: ignore[reportOptionalSubscript] + ('plain', 'not'): lambda g, s: g.values[0] not in s.values, # pyright: ignore[reportOperatorIssue,reportOptionalSubscript] + ('plain', 'oneof'): lambda g, s: g.values[0] in s.values, # pyright: ignore[reportOptionalSubscript,reportOperatorIssue] + ('not', 'plain'): lambda g, s: s.values[0] not in g.values, # pyright: ignore[reportOptionalSubscript,reportOperatorIssue] ('not', _WILDCARD): lambda g, s: True, - ('oneof', 'plain'): lambda g, s: s.values[0] in g.values, # type: ignore - ('oneof', 'not'): lambda g, s: bool(set(g.values) - set(s.values)), # type: ignore - ('oneof', 'oneof'): lambda g, s: bool(set(g.values) & set(s.values)), # type: ignore + ('oneof', 'plain'): lambda g, s: s.values[0] in g.values, # pyright: ignore[reportOptionalSubscript,reportOperatorIssue] + ('oneof', 'not'): lambda g, s: bool(set(g.values) - set(s.values)), # pyright: ignore[reportArgumentType] + ('oneof', 'oneof'): lambda g, s: bool(set(g.values) & set(s.values)), # pyright: ignore[reportArgumentType] } for (gk, sk), fn in table.items(): @@ -356,7 +358,7 @@ class Sender: TODO: doc sender """ __slots__ = ('fn', 'multi', 'actions', 'triggers') - fn: Callable[[], Awaitable] + fn: Callable[[], Awaitable[Any]] multi: bool actions: Optional[set[Type]] triggers: Optional[set[Signal]] @@ -383,7 +385,7 @@ class Receiver: TODO: doc receiver """ __slots__ = ('fn', 'priority') - fn: Callable[[Union[str, dict]], Awaitable[Optional[Event]]] + fn: Callable[[Union[str, dict]], Coroutine[Any,Any,Optional[Event]]] priority: tuple[int, ...] class Direction(Enum): @@ -402,7 +404,7 @@ class TapeActivation: key: Optional[str] state: Optional[Node] route: ParsedRoute - fn: Callable[[Any], Awaitable] + fn: Callable[[Any], Coroutine[Any,Any,Any]] # ======= STATE TAPE ======= @@ -537,7 +539,9 @@ def _remove_prefix(self, key: Optional[str]) -> str: TODO: doc prefix """ p = f"{self.prefix}:" - return key[len(p):] if isinstance(key, str) and key.startswith(p) else key # type: ignore + if isinstance(key, str) and key.startswith(p): + return key[len(p):] + return key # pyright: ignore[reportReturnType] def extend(self, states: Any): """ diff --git a/summoner/server/server.py b/summoner/server/server.py index 11a72ec..fb9dca5 100644 --- a/summoner/server/server.py +++ b/summoner/server/server.py @@ -77,7 +77,7 @@ def __init__(self, name: Optional[str] = None): self.clients: set[asyncio.streams.StreamWriter] = set() self.clients_lock = asyncio.Lock() - self.active_tasks: dict[asyncio.Task, str] = {} + self.active_tasks: dict[Optional[asyncio.Task[Any]], str] = {} self.tasks_lock = asyncio.Lock() async def handle_client( @@ -96,7 +96,7 @@ async def handle_client( self.clients.add(writer) async with self.tasks_lock: - self.active_tasks[task] = addr # type: ignore + self.active_tasks[task] = addr try: while True: @@ -142,7 +142,7 @@ async def handle_client( await writer.wait_closed() async with self.tasks_lock: - self.active_tasks.pop(task, None) # type: ignore + self.active_tasks.pop(task, None) self.logger.info(f"{addr} connection closed.") @@ -182,7 +182,7 @@ async def wait_for_tasks_to_finish(self): Wait for all client handlers to finish """ async with self.tasks_lock: - tasks = list(self.active_tasks) + tasks = list(z for z in self.active_tasks if z is not None) if tasks: await asyncio.gather(*tasks, return_exceptions=True) diff --git a/tests/test_triggers.py b/tests/test_triggers.py index 964dca1..35346ee 100644 --- a/tests/test_triggers.py +++ b/tests/test_triggers.py @@ -135,7 +135,7 @@ def test_event_and_action_classes_and_extract_signal(): # Instantiate via a single-signal trigger Trigger = load_triggers(json_dict={"X": None}) #pylint:disable=invalid-name - sigX : Signal = Trigger.X # type: ignore + sigX : Signal = Trigger.X # pyright: ignore[reportAttributeAccessIssue] move_evt = Move(sigX) stay_evt = Stay(sigX) test_evt = Test(sigX) From 4bf96786a1d0c9e36c11d00775a8f67c59c6fe9a Mon Sep 17 00:00:00 2001 From: Cobord Date: Mon, 23 Feb 2026 04:52:38 -0500 Subject: [PATCH 09/11] guards more typeddict that are still compatible with using as regular dict elsewhere --- setup.py | 7 +- summoner/_version.py | 1 + summoner/client/client.py | 143 ++++++++++---------------- summoner/client/client_types.py | 16 ++- summoner/client/dna.py | 32 +++--- summoner/client/just_merger.py | 92 ++++++++++++----- summoner/client/translation.py | 19 ++-- summoner/logger.py | 3 +- summoner/protocol/flow.py | 3 +- summoner/protocol/payload.py | 57 ++++++----- summoner/protocol/process.py | 176 +++++++++++++++++++++++--------- summoner/protocol/triggers.py | 27 +++-- summoner/server/server.py | 3 +- summoner/utils/code_handlers.py | 3 +- summoner/utils/json_handlers.py | 3 +- tests/test_triggers.py | 23 +++++ 16 files changed, 375 insertions(+), 233 deletions(-) create mode 100644 summoner/_version.py diff --git a/setup.py b/setup.py index ed9511c..179c7f0 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,13 @@ from setuptools import setup, find_packages +# Read version without importing the package (safe for builds) +version_ns = {} +with open("summoner/_version.py", encoding="utf-8") as f: + exec(f.read(), version_ns) + setup( name="summoner", - version="0.1.0", + version=version_ns["__version__"], description="Summoner's core SDK", author="Remy Tuyeras", author_email="rtuyeras@summoner.org", diff --git a/summoner/_version.py b/summoner/_version.py new file mode 100644 index 0000000..545d07d --- /dev/null +++ b/summoner/_version.py @@ -0,0 +1 @@ +__version__ = "1.1.1" \ No newline at end of file diff --git a/summoner/client/client.py b/summoner/client/client.py index b4b12aa..7d27c3f 100644 --- a/summoner/client/client.py +++ b/summoner/client/client.py @@ -2,11 +2,12 @@ SummonerClient """ #pylint:disable=line-too-long, wrong-import-position, too-many-lines -#pylint:disable=logging-fstring-interpolation, broad-exception-caught +#pylint:disable=logging-fstring-interpolation import os import sys import json +from types import FrameType from typing import ( Awaitable, Dict, @@ -18,9 +19,9 @@ Tuple, Union, Coroutine, - Any, cast, ) +from typing import Any import asyncio import signal import inspect @@ -45,7 +46,12 @@ DNA_UPLOAD, DNAHook, DNAReceiver, - DNASender + DNASender, + hook_entry_contribution, + receiver_entry_contribution, + sender_entry_contribution, + upload_entry_contribution, + download_entry_contribution, ) from summoner.utils import ( load_config, @@ -85,6 +91,8 @@ RelayedMessage ) +from summoner._version import __version__ as core_version + class ServerDisconnected(Exception): """Raised when the server closes the connection.""" @@ -105,7 +113,7 @@ class SummonerClient: DEFAULT_EVENT_BRIDGE_SIZE = 1000 DEFAULT_MAX_CONSECUTIVE_ERRORS = 3 # Failed attempts to send before disconnecting - core_version = "1.1.1" + core_version = core_version def __init__(self, name: Optional[str] = None): @@ -506,7 +514,7 @@ async def register(): try: parsed_route = self._flow.parse_route(route) normalized_route = str(parsed_route) - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.warning( f"@receive: could not parse route {route!r} while flow is enabled; " f"registering raw route. Error: {type(e).__name__}: {e}" @@ -611,7 +619,7 @@ async def register(): try: parsed_route = self._flow.parse_route(route) normalized_route = str(parsed_route) - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.warning( f"@send: could not parse route {route!r} while flow is enabled; " f"registering raw route. Error: {type(e).__name__}: {e}" @@ -765,11 +773,9 @@ def dna(self, include_context: bool = False) -> str: - Flow construction and trigger enum creation are not captured unless they are referenced and executed within decorated handler sources. """ - # TODO: use *_entry_contribution methods to build entries and keep this method cleaner - #pylint:disable=import-outside-toplevel import builtins - + # TODO: Only used below, so could have not turned into a list right away # Keep only as a generator being iterated over handler_fns = list(self._iter_registered_handler_functions()) @@ -777,90 +783,31 @@ def dna(self, include_context: bool = False) -> str: # ---------------------------- # Handler DNA entries # ---------------------------- - entries: List[Dict[str, Any]] = [] + entries: List[ + Dict[str, str] | \ + Dict[str, str | tuple[int,...]] |\ + Dict[str, str | bool | List[str]] + ] = [] # Upload state hook, if present if self._dna_upload_states is not None: - fn = self._dna_upload_states["fn"] - entries.append({ - "type": "upload_states", - "source": get_callable_source(fn, self._dna_upload_states.get("source")), - "module": fn.__module__, - "fn_name": fn.__name__, - }) + entries.append(upload_entry_contribution(self._dna_upload_states)) # Download state hook, if present if self._dna_download_states is not None: - fn = self._dna_download_states["fn"] - entries.append({ - "type": "download_states", - "source": get_callable_source(fn, self._dna_download_states.get("source")), - "module": fn.__module__, - "fn_name": fn.__name__, - }) + entries.append(download_entry_contribution(self._dna_download_states)) # All receivers for dna in self._dna_receivers: - fn = dna["fn"] - raw_route = dna["route"] - - try: - if self._flow.in_use: - route_key = str(self._flow.parse_route(raw_route)) - else: - route_key = raw_route - except Exception: - route_key = raw_route - route_key = "".join(str(route_key).split()) - - entries.append({ - "type": "receive", - "route": raw_route, # original route string - "route_key": route_key, # stable route representative - "priority": dna["priority"], - "source": get_callable_source(fn, dna.get("source")), - "module": fn.__module__, - "fn_name": fn.__name__, - }) + entries.append(receiver_entry_contribution(dna, self._flow if self._flow.in_use else None)) # All senders for dna in self._dna_senders: - fn = dna["fn"] - raw_route = dna["route"] - - try: - if self._flow.in_use: - route_key = str(self._flow.parse_route(raw_route)) - else: - route_key = raw_route - except Exception: - route_key = raw_route - route_key = "".join(str(route_key).split()) - - entries.append({ - "type": "send", - "route": raw_route, # original route string - "route_key": route_key, # stable route representative - "multi": dna["multi"], - # Serialize triggers/actions by name so they can be re-resolved later. - "on_triggers": [t.name for t in (dna["on_triggers"] or [])], - "on_actions": [a.__name__ for a in (dna["on_actions"] or [])], - "source": get_callable_source(fn, dna["source"]), - "module": fn.__module__, - "fn_name": fn.__name__, - }) + entries.append(sender_entry_contribution(dna, self._flow if self._flow.in_use else None)) # All hooks for dna in self._dna_hooks: - fn = dna["fn"] - entries.append({ - "type": "hook", - "direction": dna["direction"].name, - "priority": dna["priority"], - "source": get_callable_source(fn, dna.get("source")), - "module": fn.__module__, - "fn_name": fn.__name__, - }) + entries.append(hook_entry_contribution(dna)) # Fast path: return only handler entries. if not include_context: @@ -889,7 +836,7 @@ def dna(self, include_context: bool = False) -> str: # Identify the runtime type of asyncio.Lock() so we can detect it. try: lock_type = type(asyncio.Lock()) - except Exception: + except Exception:# pylint:disable=broad-exception-caught lock_type = None # Used by resolve_import_statement to avoid emitting repeated imports. @@ -915,14 +862,14 @@ def dna(self, include_context: bool = False) -> str: nm = getattr(v, "__name__", None) if isinstance(nm, str) and nm: names_to_scan.add(nm) - except Exception: + except Exception:# pylint:disable=broad-exception-caught pass # Fallback: parse identifiers from annotation syntax in the source. try: src = get_callable_source(fn) names_to_scan |= extract_annotation_identifiers(src) - except Exception: + except Exception:# pylint:disable=broad-exception-caught pass for name in names_to_scan: @@ -978,7 +925,7 @@ def dna(self, include_context: bool = False) -> str: if path_needed: imports_out.add("from pathlib import Path") - context_entry : Dict[str, Any] = { + context_entry : Dict[str, str | list[str] | dict[str,object] | dict[str,str]] = { "type": "__context__", "var_name": inferred_var_name, "imports": sorted(imports_out), @@ -1113,7 +1060,7 @@ async def _safe_call(fn: Callable[[Any],Awaitable[Any]], payload: Any) -> Option break payload = new_payload - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.error( f"Receiving hook {receiving_hook.__name__} (priority={priority}) " f"failed on payload {payload!r}: {e}", @@ -1265,7 +1212,7 @@ async def _send_worker( break payload = new_payload - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.error( f"[route={route}] Sending hook {sending_hook.__name__} (priority={priority}) " f"failed on payload {payload!r}: {e}", @@ -1295,7 +1242,7 @@ async def _send_worker( async with self.writer_lock: await writer.drain() - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught consecutive_errors += 1 self.logger.error( f"Worker for {sender.fn.__name__} crashed ({consecutive_errors} in a row): {e}", @@ -1324,7 +1271,7 @@ async def _cleanup_workers(self): self.worker_tasks.clear() self.send_workers_started = False - #pylint:disable=too-many-nested-blocks, no-else-continue, no-else-break, attribute-defined-outside-init + #pylint:disable=too-many-nested-blocks, no-else-continue, no-else-break async def message_sender_loop( self, writer: asyncio.StreamWriter, @@ -1538,7 +1485,7 @@ async def handle_session(self, host: str = '127.0.0.1', port: int = 8888): except asyncio.CancelledError: # Normal during shutdown; ignore pass - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.exception(f"Unexpected error during session task: {e}") # Cleanly close the connection @@ -1559,8 +1506,8 @@ async def handle_session(self, host: str = '127.0.0.1', port: int = 8888): # Deregister this session and its children from active tasks async with self.tasks_lock: try: - if task is not None: # pyright: ignore[reportPossiblyUnboundVariable] - self.active_tasks.discard(task) # pyright: ignore[reportPossiblyUnboundVariable] + if current_task is not None: + self.active_tasks.discard(current_task) except NameError: pass @@ -1588,8 +1535,22 @@ def set_termination_signals(self): """ if platform.system() != "Windows": for sig in (signal.SIGINT, signal.SIGTERM): - #pylint:disable=unnecessary-lambda - self.loop.add_signal_handler(sig, lambda: self.shutdown()) + self.loop.add_signal_handler(sig, self.shutdown) + else: + def _handler(_sig: int, _frame: Optional[FrameType]): + # thread-safe: schedule shutdown on the event loop + try: + self.loop.call_soon_threadsafe(self.shutdown) + except RuntimeError: + pass + signal.signal(signal.SIGINT, _handler) + # SIGTERM exists on Windows in Python, but behavior varies by launcher + if hasattr(signal, "SIGTERM"): + try: + signal.signal(signal.SIGTERM, _handler) + except Exception:# pylint:disable=broad-exception-caught + pass + async def _wait_for_registration(self): """ diff --git a/summoner/client/client_types.py b/summoner/client/client_types.py index 85278bc..850a76e 100644 --- a/summoner/client/client_types.py +++ b/summoner/client/client_types.py @@ -1,7 +1,7 @@ """ Types used for client and client DNA """ -#pylint:disable=wrong-import-position +#pylint:disable=wrong-import-position, invalid-name from typing import ( Dict, List, @@ -9,9 +9,9 @@ Callable, TypedDict, Union, -Any, Coroutine, ) +from typing import Any import os import sys @@ -22,10 +22,12 @@ from summoner.protocol.triggers import Event from summoner.protocol.payload import RelayedMessage +#pylint:disable=pointless-string-statement """ See the _check_param_and_return_types function in client.py inside the decorator internal functions for where these types are actually enforced. -This tells us what the expected signatures of client hooks, upload functions, and decorated send/receive handlers are. +This tells us what the expected signatures of client hooks, upload functions, +and decorated send/receive handlers are. That way we can put those expected signatures in one place here. Though the actual enforcement of these types is in client.py, not here. """ @@ -54,8 +56,12 @@ Coroutine[Any,Any,Optional[Union[str, dict]]]] SEND_RETURN_SINGLE_TYPE = None | Any | str | Dict[Any,Any] -SEND_RETURN_MULTI_TYPE = Any | List[Any] | List[str] | List[Dict[Any,Any]] | List[Union[str, Dict[Any,Any]]] +SEND_RETURN_MULTI_TYPE = Any | List[Any] | List[str] | \ + List[Dict[Any,Any]] | List[Union[str, Dict[Any,Any]]] SEND_DECORATED_TYPE = Callable[[], Coroutine[Any, Any, SEND_RETURN_SINGLE_TYPE]] |\ Callable[[], Coroutine[Any, Any, SEND_RETURN_MULTI_TYPE]] -RECEIVE_DECORATED_TYPE = Callable[[Any | str | Dict[Any,Any]], Coroutine[Any, Any, Any | Event | None]] +RECEIVE_DECORATED_TYPE = Callable[ + [Any | str | Dict[Any,Any]], + Coroutine[Any, Any, Any | Event | None] +] diff --git a/summoner/client/dna.py b/summoner/client/dna.py index eed24e6..91d2fd3 100644 --- a/summoner/client/dna.py +++ b/summoner/client/dna.py @@ -14,9 +14,10 @@ (for example '__main__' imports or live objects that cannot be rebuilt) should end up in "missing", not embedded implicitly. """ -#pylint:disable=wrong-import-position +#pylint:disable=wrong-import-position, invalid-name, duplicate-code -from typing import Any, Optional, Set, Type, TypedDict +from typing import List, Optional, Set, Type, TypedDict +from typing import Any import os import sys @@ -28,7 +29,8 @@ from summoner.protocol.flow import Flow from summoner.utils.code_handlers import get_callable_source from summoner.protocol.triggers import Signal -from summoner.client.client_types import DOWNLOAD_TYPE, RECEIVE_DECORATED_TYPE, SEND_DECORATED_TYPE, HOOK_TYPE, UPLOAD_TYPE +from summoner.client.client_types import DOWNLOAD_TYPE, RECEIVE_DECORATED_TYPE, \ + SEND_DECORATED_TYPE, HOOK_TYPE, UPLOAD_TYPE from summoner.protocol.process import Direction @@ -41,7 +43,7 @@ class DNAHook(TypedDict): priority: tuple[int, ...] source: Optional[str] -def hook_entry_contribution(hook_entry: DNAHook) -> dict[str, Any]: +def hook_entry_contribution(hook_entry: DNAHook) -> dict[str, str | tuple[int,...]]: """ The contribution of this entry to the overall DNA dict. """ @@ -64,19 +66,21 @@ class DNAReceiver(TypedDict): priority: tuple[int, ...] source: Optional[str] -def receiver_entry_contribution(receiver_entry: DNAReceiver, flow_in_use: Optional[Flow]) -> dict[str, Any]: +def receiver_entry_contribution( + receiver_entry: DNAReceiver, + flow_in_use: Optional[Flow]) -> dict[str, str | tuple[int,...]]: """ The contribution of this entry to the overall DNA dict. """ fn = receiver_entry["fn"] raw_route = receiver_entry["route"] - + try: if flow_in_use is not None: route_key = str(flow_in_use.parse_route(raw_route)) else: route_key = raw_route - except Exception: + except Exception: #pylint:disable=broad-exception-caught route_key = raw_route route_key = "".join(str(route_key).split()) @@ -98,10 +102,12 @@ class DNASender(TypedDict): route: str multi: bool on_triggers: Optional[Set[Any] | Set[Signal]] - on_actions: Optional[set[Any] | Set[Type]] + on_actions: Optional[Set[Any] | Set[Type]] source: Optional[str] -def sender_entry_contribution(sender_entry: DNASender, flow_in_use: Optional[Flow]) -> dict[str, Any]: +def sender_entry_contribution( + sender_entry: DNASender, + flow_in_use: Optional[Flow]) -> dict[str, str | bool | List[str]]: """ The contribution of this entry to the overall DNA dict. """ @@ -113,7 +119,7 @@ def sender_entry_contribution(sender_entry: DNASender, flow_in_use: Optional[Flo route_key = str(flow_in_use.parse_route(raw_route)) else: route_key = raw_route - except Exception: + except Exception: #pylint:disable=broad-exception-caught route_key = raw_route route_key = "".join(str(route_key).split()) @@ -136,8 +142,8 @@ class DNA_UPLOAD(TypedDict): """ fn: UPLOAD_TYPE source: str - -def upload_entry_contribution(upload_entry: DNA_UPLOAD) -> dict[str, Any]: + +def upload_entry_contribution(upload_entry: DNA_UPLOAD) -> dict[str, str]: """ The contribution of this entry to the overall DNA dict. """ @@ -155,7 +161,7 @@ class DNA_DOWNLOAD(TypedDict): fn: DOWNLOAD_TYPE source: str -def download_entry_contribution(download_entry: DNA_DOWNLOAD) -> dict[str, Any]: +def download_entry_contribution(download_entry: DNA_DOWNLOAD) -> dict[str, str]: """ The contribution of this entry to the overall DNA dict. """ diff --git a/summoner/client/just_merger.py b/summoner/client/just_merger.py index f7f47e5..2a048e2 100644 --- a/summoner/client/just_merger.py +++ b/summoner/client/just_merger.py @@ -48,9 +48,10 @@ Do not run untrusted DNA. """ #pylint:disable=line-too-long, wrong-import-position -#pylint:disable=invalid-name, broad-exception-caught,logging-fstring-interpolation +#pylint:disable=invalid-name, logging-fstring-interpolation -from typing import Optional, Any +from typing import Dict, List, Literal, Optional, TypeGuard, TypedDict +from typing import Any from contextlib import suppress from pathlib import Path import inspect @@ -99,12 +100,12 @@ def _resolve_trigger(TriggerCls, name: str) -> Any: # Enum-style: TriggerCls["ok"] try: return TriggerCls[name] - except Exception: + except Exception:# pylint:disable=broad-exception-caught pass # Attribute-style: TriggerCls.ok try: return getattr(TriggerCls, name) - except Exception: + except Exception:# pylint:disable=broad-exception-caught pass raise KeyError(f"Unknown trigger '{name}' for {TriggerCls}") @@ -152,6 +153,46 @@ def _resolve_action(ActionCls, name: str): raise KeyError(f"Unknown action '{name}' for {ActionCls}") +class StructuredReport(TypedDict): + """ + This can be used as a dict[str, Any] + but fixed with keys label, succeeded, failed, skipped + This is the information returned by _apply_context + which says what happened for each + line in the ctx["imports"] + as far as whether it succeeded, failed with an error, or was skipped due to settings. + If ctx was None or there was no key for imports, then this report will have empty lists. + """ + label: str + succeeded: list[str] + failed: list[tuple[str, str]] # list of (item, error) + skipped: list[str] + +class NormalizedClientSource(TypedDict): + """ + See _normalize_source for the meaning of these fields. + """ + kind: Literal["client"] + var_name: str + client: SummonerClient + +class NormalizedDNASource(TypedDict): + """ + See _normalize_source for the meaning of these fields. + """ + kind: Literal["dna"] + var_name: str + dna_entries: List[Dict[Any,Any]] + context: Optional[Dict[Any,Any]] + sandbox_name: str + globals: Dict[Any,Any] + import_report: StructuredReport + +def just_client_source(arbitrary_source: NormalizedDNASource | NormalizedClientSource) -> TypeGuard[NormalizedClientSource]: + """ + A type guard to help with the fact that self.sources is a list of two different dict types. + """ + return arbitrary_source["kind"] == "client" class ClientMerger(SummonerClient): """ @@ -216,8 +257,8 @@ def __init__( self._close_subclients = close_subclients # Normalized sources used later by initiate_* replay methods. - self.sources: list[dict[str, Any]] = [] - self._import_reports: list[dict[str, Any]] = [] + self.sources: list[NormalizedClientSource | NormalizedDNASource] = [] + self._import_reports: list[StructuredReport] = [] for idx, entry in enumerate(named_clients): src = self._normalize_source(entry, idx) @@ -231,7 +272,9 @@ def __init__( # ---------------------------- #pylint:disable=too-many-branches - def _normalize_source(self, entry: Any, idx: int) -> dict[str, Any]: + def _normalize_source(self, + entry: SummonerClient | List[Dict[Any,Any]] | dict[str, Any], + idx: int) -> NormalizedClientSource | NormalizedDNASource: """ Normalize a user-provided source specification into a canonical dict. @@ -402,7 +445,7 @@ def _shutdown_imported_clients(self) -> None: 4) clear the template's registration list """ for src in self.sources: - if src.get("kind") != "client": + if not just_client_source(src): continue client: SummonerClient = src["client"] @@ -433,13 +476,13 @@ def _shutdown_imported_clients(self) -> None: old_loop = None try: # Set context so asyncio.gather/futures bind to the right loop. - with suppress(Exception): + with suppress(Exception):# pylint:disable=broad-exception-caught old_loop = asyncio.get_event_loop_policy().get_event_loop() asyncio.set_event_loop(loop) # Await cancellation. This is what prevents warnings. loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.warning(f"[{var_name}] Error draining registration tasks: {e}") finally: # Restore previous loop context (or clear it). @@ -449,7 +492,7 @@ def _shutdown_imported_clients(self) -> None: # 3) close loop after drain try: loop.close() - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.warning(f"[{var_name}] Error closing event loop: {e}") # 4) clear list @@ -462,7 +505,7 @@ def _shutdown_imported_clients(self) -> None: # ---------------------------- # pylint:disable=too-many-branches - def _apply_context(self, ctx: Optional[dict], g: dict, *, label: str) -> dict[str, Any]: + def _apply_context(self, ctx: Optional[dict[Any,Any]], g: dict[Any,Any], *, label: str) -> StructuredReport: """ Apply a DNA context entry (imports, globals, recipes) into a sandbox globals dict. @@ -481,14 +524,15 @@ def _apply_context(self, ctx: Optional[dict], g: dict, *, label: str) -> dict[st Returns ------- - dict[str, Any] - A structured report with keys: label, succeeded, failed, skipped. + StructurcedReport + can be used as a dict[str, Any] and has keys: label, succeeded, failed, skipped. + so how it is used remains the same as a dict, but it is more explicit about the expected structure. Security note ------------- This executes code from ctx (imports and recipes). Use only with trusted DNA. """ - report = {"label": label, "succeeded": [], "failed": [], "skipped": []} + report : StructuredReport = {"label": label, "succeeded": [], "failed": [], "skipped": []} if not isinstance(ctx, dict): return report @@ -508,7 +552,7 @@ def _apply_context(self, ctx: Optional[dict], g: dict, *, label: str) -> dict[st report["succeeded"].append(line) if self._verbose_context_imports: self.logger.info(f"[merge ctx:{label}] import ok: {line}") - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught report["failed"].append((line, f"{type(e).__name__}: {e}")) self.logger.warning(f"[merge ctx:{label}] import failed: {line!r} ({type(e).__name__}: {e})") @@ -528,7 +572,7 @@ def _apply_context(self, ctx: Optional[dict], g: dict, *, label: str) -> dict[st try: # pylint:disable=eval-used g.setdefault(k, eval(expr, g, {})) - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.warning(f"[merge ctx:{label}] recipe failed {k}={expr!r} ({type(e).__name__}: {e})") return report @@ -581,14 +625,14 @@ def _clone_handler(self, fn: types.FunctionType, original_name: str) -> types.Fu # rebind the client variable name (agent/client/etc) try: g[original_name] = self - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.warning(f"Could not bind '{original_name}' to merged client: {e}") # rebind any shared globals (viz, Trigger, etc.) for k, v in self._rebind_globals.items(): try: g[k] = v - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.warning(f"Could not bind global '{k}' in '{fn.__name__}': {e}") new_fn = types.FunctionType( @@ -686,7 +730,7 @@ def initiate_upload_states(self): fn_clone = self._clone_handler(fn, var_name) try: self.upload_states()(fn_clone) - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.warning(f"[{var_name}] Failed to replay upload_states '{fn.__name__}': {e}") else: @@ -712,7 +756,7 @@ def initiate_download_states(self): fn_clone = self._clone_handler(fn, var_name) try: self.download_states()(fn_clone) - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.warning(f"[{var_name}] Failed to replay download_states '{fn.__name__}': {e}") else: @@ -736,7 +780,7 @@ def initiate_hooks(self): fn_clone = self._clone_handler(dna["fn"], var_name) try: self.hook(dna["direction"], priority=dna["priority"])(fn_clone) - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.warning(f"[{var_name}] Failed to replay hook '{dna['fn'].__name__}': {e}") else: @@ -761,7 +805,7 @@ def initiate_receivers(self): fn_clone = self._clone_handler(dna["fn"], var_name) try: self.receive(dna["route"], priority=dna["priority"])(fn_clone) - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.warning( f"[{var_name}] Failed to replay receiver '{dna['fn'].__name__}' on route '{dna['route']}': {e}" ) @@ -804,7 +848,7 @@ def initiate_senders(self): on_triggers=dna["on_triggers"], on_actions=dna["on_actions"], )(fn_clone) - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.warning( f"[{var_name}] Failed to replay sender '{dna['fn'].__name__}' on route '{dna['route']}': {e}" ) diff --git a/summoner/client/translation.py b/summoner/client/translation.py index 3beebc2..49f3534 100644 --- a/summoner/client/translation.py +++ b/summoner/client/translation.py @@ -48,10 +48,11 @@ Do not run untrusted DNA. """ #pylint:disable=line-too-long, wrong-import-position, duplicate-code -#pylint:disable=invalid-name, broad-exception-caught,logging-fstring-interpolation +#pylint:disable=invalid-name, logging-fstring-interpolation from importlib import import_module -from typing import Optional, Any +from typing import Optional +from typing import Any import inspect import asyncio import types @@ -177,7 +178,7 @@ def _apply_context(self): exec(line, g) if self._verbose_context_imports: self.logger.info(f"[translation ctx] import ok: {line}") - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.warning(f"[translation ctx] import failed: {line!r} ({type(e).__name__}: {e})") # Plain globals @@ -197,7 +198,7 @@ def _apply_context(self): # pylint:disable=eval-used try: g.setdefault(k, eval(expr, g, {})) - except Exception as e: + except Exception as e:# pylint:disable=broad-exception-caught self.logger.warning(f"Could not eval recipe for {k}: {expr!r} ({e})") #pylint:disable=unused-argument @@ -221,14 +222,14 @@ def _cleanup_one_template_client(self, client: SummonerClient, label: str): for t in regs: try: t.cancel() - except Exception: + except Exception:# pylint:disable=broad-exception-caught pass # drain cancellations if we can drive that loop try: if regs and loop is not None and (not loop.is_running()) and (not loop.is_closed()): loop.run_until_complete(asyncio.gather(*regs, return_exceptions=True)) - except Exception: + except Exception:# pylint:disable=broad-exception-caught # best-effort only pass @@ -236,14 +237,14 @@ def _cleanup_one_template_client(self, client: SummonerClient, label: str): try: # pylint:disable=protected-access client._registration_tasks.clear() - except Exception: + except Exception:# pylint:disable=broad-exception-caught pass # close the loop (template clients are not meant to run) try: if loop is not None and (not loop.is_running()) and (not loop.is_closed()): loop.close() - except Exception: + except Exception:# pylint:disable=broad-exception-caught pass def _cleanup_template_clients_from_modules(self): @@ -264,7 +265,7 @@ def _cleanup_template_clients_from_modules(self): for module_name in modules: try: module = sys.modules.get(module_name) or import_module(module_name) # pyright: ignore[reportArgumentType] - except Exception: + except Exception:# pylint:disable=broad-exception-caught continue g = getattr(module, "__dict__", {}) diff --git a/summoner/logger.py b/summoner/logger.py index ccb5fae..8da29c6 100644 --- a/summoner/logger.py +++ b/summoner/logger.py @@ -10,7 +10,8 @@ import datetime import re -from typing import Dict, Optional, Any +from typing import Dict, Optional +from typing import Any from logging.handlers import RotatingFileHandler # This makes Logger importable from logger.py diff --git a/summoner/protocol/flow.py b/summoner/protocol/flow.py index 44ae5a2..30fe5e9 100644 --- a/summoner/protocol/flow.py +++ b/summoner/protocol/flow.py @@ -4,7 +4,8 @@ from __future__ import annotations import re from collections.abc import Callable -from typing import Iterable, Optional, Any +from typing import Iterable, Optional +from typing import Any import warnings from .triggers import load_triggers from .process import Node, ArrowStyle, ParsedRoute diff --git a/summoner/protocol/payload.py b/summoner/protocol/payload.py index 4536a8e..f831634 100644 --- a/summoner/protocol/payload.py +++ b/summoner/protocol/payload.py @@ -4,7 +4,8 @@ #pylint:disable=line-too-long import json from json import JSONDecodeError -from typing import Any, Tuple, Dict, List, Union, TypedDict +from typing import Tuple, Dict, List, Union, TypedDict +from typing import Any from summoner.utils import ( fully_recover_json, @@ -12,8 +13,10 @@ ensure_trailing_newline, ) -# Current envelope version -WRAPPER_VERSION = "0.0.1" +from summoner._version import __version__ as core_version + +# Default envelope version +DEFAULT_VERSION = "0.0.1" # Registries for versioned parsers and casters envelope_parsers: Dict[str, Any] = {} @@ -41,7 +44,7 @@ def register_envelope_version( #pylint:disable=too-many-return-statements -def parse_v0_0_1(obj: Any) -> Tuple[Any, Any]: +def parse_v0_0_1(obj: Any) -> Tuple[Any,Any]: """ Walk `obj` and build a parallel `type_tree`. Every leaf in `obj` becomes a simple type name; every list/dict is recursed. @@ -90,24 +93,11 @@ def parse_v0_0_1(obj: Any) -> Tuple[Any, Any]: s = str(obj) return s, STR_TYPE - -#pylint:disable=too-many-return-statements, too-many-branches -def cast_v0_0_1(val: Any, expected: Any) -> Any: +def _cast_primitive(val: Any, expected: Any) -> str | bool | int | float | None | Any: """ Coerce `val` according to `expected`, but never fail on unknown types. Instead, pass the value through unchanged if we don't understand. - - Supported expected descriptors: - • "string", "boolean", "integer", "number", "null" - • a list of descriptors → cast element-wise up to len(expected) - • a dict of descriptors → cast known keys, keep extras unchanged - • None or unknown → identity (return val) """ - # 1) Identity for missing/unknown descriptors - if expected is None: - return val - - # 2) Primitives if expected == STR_TYPE: return str(val) if expected == BOOL_TYPE: @@ -126,6 +116,27 @@ def cast_v0_0_1(val: Any, expected: Any) -> Any: if expected == NULL_TYPE: # always map to Python None return None + # Unknown expected type → return the raw value + raise ValueError(f"Unreachable code in _cast_primitive: expected={expected!r}") + +def cast_v0_0_1(val: Any, expected: Any) -> Any: + """ + Coerce `val` according to `expected`, but never fail on unknown types. + Instead, pass the value through unchanged if we don't understand. + + Supported expected descriptors: + • "string", "boolean", "integer", "number", "null" + • a list of descriptors → cast element-wise up to len(expected) + • a dict of descriptors → cast known keys, keep extras unchanged + • None or unknown → identity (return val) + """ + # 1) Identity for missing/unknown descriptors + if expected is None: + return val + + # 2) Primitives + if expected in {STR_TYPE, BOOL_TYPE, INT_TYPE, NUMB_TYPE, NULL_TYPE}: + return _cast_primitive(val, expected) # 3) Lists: cast up to the expected length, keep extras untouched if isinstance(expected, list) and isinstance(val, list): @@ -159,12 +170,12 @@ def cast_v0_0_1(val: Any, expected: Any) -> Any: register_envelope_version("1.0.0", parse_v0_0_1, cast_v0_0_1) register_envelope_version("1.0.1", parse_v0_0_1, cast_v0_0_1) register_envelope_version("1.1.0", parse_v0_0_1, cast_v0_0_1) -register_envelope_version("1.1.1", parse_v0_0_1, cast_v0_0_1) - +# register_envelope_version("1.1.1", parse_v0_0_1, cast_v0_0_1) +register_envelope_version(core_version, parse_v0_0_1, cast_v0_0_1) def wrap_with_types( payload: Any, - version: str = WRAPPER_VERSION + version: str = DEFAULT_VERSION ) -> str: """ Wrap `payload` in a self-describing JSON envelope with type metadata. @@ -176,7 +187,7 @@ def wrap_with_types( Args: payload: Any JSON-serializable object (dict, list, primitive). - version: Version string for the envelope format (defaults to WRAPPER_VERSION). + version: Version string for the envelope format (defaults to DEFAULT_VERSION). Returns: A JSON string representing the typed envelope. @@ -273,7 +284,7 @@ def recover_with_types(text: str) -> RelayedMessage: # 4) We have the versioned envelope—now look up the correct caster version = content["_version"] - caster = envelope_casters.get(version, envelope_casters[WRAPPER_VERSION]) + caster = envelope_casters.get(version, envelope_casters[DEFAULT_VERSION]) if caster is None: # Unknown version: hard error so we don't silently mis-interpret data raise ValueError(f"Unsupported wrapper version: {version}") diff --git a/summoner/protocol/process.py b/summoner/protocol/process.py index 993b143..bf66de1 100644 --- a/summoner/protocol/process.py +++ b/summoner/protocol/process.py @@ -5,8 +5,9 @@ from __future__ import annotations import re from collections import defaultdict -from typing import Coroutine, Dict, List, Literal, Tuple, Type, \ - Any, Optional, Union, Callable, Awaitable +from typing import Coroutine, Dict, List, Literal, Mapping, Tuple, Type, \ + Optional, TypeGuard, Union, Callable, Awaitable +from typing import Any from enum import Enum, auto from dataclasses import dataclass from .triggers import Signal, Event, Action, extract_signal @@ -408,6 +409,22 @@ class TapeActivation: # ======= STATE TAPE ======= +TupleStrNode = Tuple[str | Node, ...] | Tuple[Node, ...] | Tuple[str, ...] +ListStrNode = List[str | Node] | List[Node] | List[str] +DictSingleStrNode = Mapping[Optional[str], + str | Node + ] | Mapping[str, + str | Node + ] +DictManyStrNode = Mapping[ + Optional[str], + str | Node | ListStrNode | TupleStrNode] | \ + Mapping[str, + str | Node | ListStrNode | TupleStrNode] + +StatesType = DictSingleStrNode | DictManyStrNode | str | Node | \ + ListStrNode | TupleStrNode | None + class TapeType(Enum): """ TODO: doc tapetype @@ -417,6 +434,87 @@ class TapeType(Enum): INDEX_SINGLE = auto() INDEX_MANY = auto() + @staticmethod + def single_type_guard(states: Any) -> \ + TypeGuard[str | Node]: + """ + This input as states would get + interpreted as SINGLE type, so it must be a single str or Node + """ + return isinstance(states, (str, Node)) + + @staticmethod + def many_type_guard(states: Any) -> \ + TypeGuard[TupleStrNode | ListStrNode]: + """ + This input as states would get + interpreted as MANY type, so it must be a list or tuple of str or Node + """ + return isinstance(states, (list, tuple)) and \ + all(isinstance(s, (str, Node)) for s in states) + + @staticmethod + def index_single_guard(states: Any) -> \ + TypeGuard[DictSingleStrNode]: + """ + This input as states would get + interpreted as INDEX_SINGLE type, so it must be a dict from Optional[str] to str or Node + """ + return isinstance(states, dict) and \ + all( + isinstance(k, (str, type(None))) and \ + isinstance(v, (str, Node)) for k, v in states.items() + ) + + @staticmethod + def index_many_guard(states: Any) -> \ + TypeGuard[DictManyStrNode]: + """ + This input as states would get + interpreted as INDEX_MANY type, + so it must be a dict from Optional[str] to list or tuple of str or Node + """ + return isinstance(states, dict) and \ + all( + isinstance(k, (str, type(None))) + and ( + isinstance(v, (str, Node)) + or ( + isinstance(v, (list, tuple)) + and all(isinstance(x, (str, Node)) for x in v) + ) + ) + for k, v in states.items()) + + @staticmethod + def _assess_type(states: StatesType) -> Optional[TapeType]: + """ + If there is only one state, SINGLE + If there is a list or tuple of many states, MANY + If there is a dictionary sending each Optional[str] key to a single state, INDEX_SINGLE + If there is a dictionary sending each Optional[str] key to possibly many states, INDEX_MANY + """ + # Scalar → SINGLE + if TapeType.single_type_guard(states): + return TapeType.SINGLE + + # Sequence of scalars → MANY + if TapeType.many_type_guard(states): + return TapeType.MANY + + # Mapping → either INDEX_SINGLE or INDEX_MANY + if isinstance(states, dict): + # all values are scalar → INDEX_SINGLE + if TapeType.index_single_guard(states): + return TapeType.INDEX_SINGLE + + # all values are either scalar or sequence of scalars → INDEX_MANY + # but at least one of the values was actually a sequence + if TapeType.index_many_guard(states): + return TapeType.INDEX_MANY + + return None + class StateTape: """ TODO: doc state tape @@ -425,9 +523,9 @@ class StateTape: prefix: str = "tape" - def __init__(self, states: Any = None, with_prefix: bool = True): + def __init__(self, states: StatesType = None, with_prefix: bool = True): # Figure out what kind of input we have - tp = StateTape._assess_type(states) + tp = TapeType._assess_type(states) # Default: empty index-many if tp is None: @@ -436,18 +534,27 @@ def __init__(self, states: Any = None, with_prefix: bool = True): # Exactly SINGLE elif tp is TapeType.SINGLE: + assert TapeType.single_type_guard(states) node = self._nodeify(states) # wrap str→Node if needed + if not with_prefix: + #pylint:disable=line-too-long + raise ValueError("StateTape constructor with with_prefix=False only gets called internally and using a dict input.") self.states = {self.prefix: [node]} self._type = tp # Exactly MANY elif tp is TapeType.MANY: + assert TapeType.many_type_guard(states) nodes = [self._nodeify(s) for s in states] + if not with_prefix: + #pylint:disable=line-too-long + raise ValueError("StateTape constructor with with_prefix=False only gets called internally and using a dict input.") self.states = {self.prefix: nodes} self._type = tp # Exactly INDEX_SINGLE elif tp is TapeType.INDEX_SINGLE: + assert TapeType.index_single_guard(states) self.states = { self._add_prefix(k, with_prefix): [self._nodeify(v)] for k, v in states.items() @@ -456,8 +563,15 @@ def __init__(self, states: Any = None, with_prefix: bool = True): # Exactly INDEX_MANY elif tp is TapeType.INDEX_MANY: + assert TapeType.index_many_guard(states) + def v_to_node_list(v: Union[str, Node, ListStrNode, TupleStrNode]) -> List[Node]: + if isinstance(v, (str, Node)): + return [self._nodeify(v)] + if isinstance(v, (list, tuple)): + return [self._nodeify(s) for s in v] + raise TypeError(f"Invalid value type in INDEX_MANY: {v!r}") self.states = { - self._add_prefix(k, with_prefix): [self._nodeify(s) for s in v] + self._add_prefix(k, with_prefix): v_to_node_list(v) for k, v in states.items() } self._type = tp @@ -473,54 +587,14 @@ def set_type(self, value: TapeType) -> StateTape: self._type = value return self - @staticmethod - def _assess_type(states: Any) -> Optional[TapeType]: - """ - If there is only one state, SINGLE - If there is a list or tuple of many states, MANY - If there is a dictionary sending each Optional[str] key to a single state, INDEX_SINGLE - If there is a dictionary sending each Optional[str] key to possibly many states, INDEX_MANY - """ - # Scalar → SINGLE - if isinstance(states, (str, Node)): - return TapeType.SINGLE - - # Sequence of scalars → MANY - if isinstance(states, (list, tuple)): - if all(isinstance(s, (str, Node)) for s in states): - return TapeType.MANY - - # Mapping → either INDEX_SINGLE or INDEX_MANY - if isinstance(states, dict): - # all values are scalar → INDEX_SINGLE - if all( - isinstance(k, (str, type(None))) - and isinstance(v, (str, Node)) - for k, v in states.items() - ): - return TapeType.INDEX_SINGLE - - # all values are either scalar or sequence of scalars → INDEX_MANY - # but at least one of the values was actually a sequence - if all( - isinstance(k, (str, type(None))) - and ( - isinstance(v, (str, Node)) - or ( - isinstance(v, (list, tuple)) - and all(isinstance(x, (str, Node)) for x in v) - ) - ) - for k, v in states.items() - ): - return TapeType.INDEX_MANY - - return None - - def _add_prefix(self, key: str, with_prefix: bool = True) -> str: + def _add_prefix(self, key: str | None, with_prefix: bool = True) -> str: """ TODO: doc prefix """ + if key is None and with_prefix: + return f"{self.prefix}:{key}" + if key is None: + return key # pyright: ignore[reportReturnType] return f"{self.prefix}:{key}" if with_prefix else key def _remove_str_prefix(self, key: str) -> str: @@ -543,7 +617,7 @@ def _remove_prefix(self, key: Optional[str]) -> str: return key[len(p):] return key # pyright: ignore[reportReturnType] - def extend(self, states: Any): + def extend(self, states: StatesType): """ Merge in these new states with the current self.states Creating a temporary StateTape handles the different ways diff --git a/summoner/protocol/triggers.py b/summoner/protocol/triggers.py index 56c688c..1449d17 100644 --- a/summoner/protocol/triggers.py +++ b/summoner/protocol/triggers.py @@ -29,13 +29,15 @@ import re import sys -from typing import Any, Optional +from typing import Optional +from typing import Any from pathlib import Path import keyword _VARNAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") +type TreeType = dict[str, Optional["TreeType"]] def is_valid_varname(name: str) -> bool: """Return True if name matches Python variable naming rules.""" @@ -83,25 +85,27 @@ def update_hierarchy(indent: int, indent_levels: list[int]) -> int: return depth -def simplify_leaves(tree: dict[str, Any]) -> None: +def simplify_leaves(tree: TreeType) -> None: """ Recursively convert empty dicts to None to mark leaves. (4) No return, operates in-place on 'tree'. """ for key, subtree in list(tree.items()): + if subtree is None: + continue if subtree == {}: tree[key] = None elif isinstance(subtree, dict): simplify_leaves(subtree) -def parse_signal_tree_lines(lines: list[str], tabsize: int = 8) -> dict[str, Any]: +def parse_signal_tree_lines(lines: list[str], tabsize: int = 8) -> TreeType: """ Parse a list of lines (strings) into a nested dict tree. This is one entry point, taking raw lines (great for testing). """ - root: dict[str, Any] = {} - nodes_by_depth: dict[int, dict[str, Any]] = {0: root} + root: TreeType = {} + nodes_by_depth: dict[int, TreeType] = {0: root} indent_levels: list[int] = [0] for lineno, raw in enumerate(lines, 1): @@ -126,13 +130,13 @@ def parse_signal_tree_lines(lines: list[str], tabsize: int = 8) -> dict[str, Any for d in list(nodes_by_depth): if d > depth: del nodes_by_depth[d] - nodes_by_depth[depth + 1] = parent[name] + nodes_by_depth[depth + 1] = parent[name] # pyright: ignore[reportArgumentType] simplify_leaves(root) return root -def parse_signal_tree(filepath: Path | str, tabsize: int = 8) -> dict[str, Any]: +def parse_signal_tree(filepath: Path | str, tabsize: int = 8) -> TreeType: """ Read a file and parse it into a nested dict tree. This is the second entry point, for file-based input. @@ -196,20 +200,21 @@ def __repr__(self): return f"" -def build_triggers(tree: dict[str, Any]): +def build_triggers(tree: TreeType): """ TODO: doc Trigger """ name_to_path: dict[str, tuple[int, ...]] = {} path_to_name: dict[tuple[int, ...], str] = {} - def recurse(subtree: dict[str, Any], prefix: tuple[int, ...] =()): + def recurse(subtree: TreeType, prefix: tuple[int, ...] =()): for idx, (name, child) in enumerate(subtree.items()): path = prefix + (idx,) name_to_path[name] = path path_to_name[path] = name - if isinstance(child, dict) and child: - recurse(child, path) + if child is not None: + if isinstance(child, dict) and child: + recurse(child, path) recurse(tree) diff --git a/summoner/server/server.py b/summoner/server/server.py index fb9dca5..e5ddc45 100644 --- a/summoner/server/server.py +++ b/summoner/server/server.py @@ -7,7 +7,8 @@ import os import sys import json -from typing import Optional, Any +from typing import Optional +from typing import Any import platform import importlib diff --git a/summoner/utils/code_handlers.py b/summoner/utils/code_handlers.py index b02c4ed..d2ccf0b 100644 --- a/summoner/utils/code_handlers.py +++ b/summoner/utils/code_handlers.py @@ -2,7 +2,8 @@ Using inspection for Python code """ -from typing import Optional, Set, Any +from typing import Optional, Set +from typing import Any import inspect import ast import textwrap diff --git a/summoner/utils/json_handlers.py b/summoner/utils/json_handlers.py index d2acc03..57478ce 100644 --- a/summoner/utils/json_handlers.py +++ b/summoner/utils/json_handlers.py @@ -4,7 +4,8 @@ import json from pathlib import Path -from typing import Any, Optional +from typing import Optional +from typing import Any def fully_recover_json(data): """ diff --git a/tests/test_triggers.py b/tests/test_triggers.py index 35346ee..e05c0e3 100644 --- a/tests/test_triggers.py +++ b/tests/test_triggers.py @@ -37,6 +37,29 @@ def test_parse_signal_tree_simple_hierarchy(): assert tree["OK"] == {"acceptable": None, "all_good": None} +def test_parse_signal_tree_complex_hierarchy(): + """ + parse_signal_tree on a more complicated example + """ + lines = [ + "OK\n", + " acceptable\n", + " all_good\n", + " great\n", + " perfect\n", + " just_good\n", + "BAD\n", + " fixable\n", + " unfixable\n", + " catastrophic\n" + ] + tree = parse_signal_tree_lines(lines, tabsize=4) + # Root key exists and children simplify to None leaves + assert "OK" in tree + assert tree["OK"] == {"acceptable": None, "all_good": {"great": {"perfect": None}, "just_good": None}} + assert "BAD" in tree + assert tree["BAD"] == {"fixable": None, "unfixable": {"catastrophic": None}} + def test_parse_signal_tree_invalid_varname(): """ parse_signal_tree but when the input has an From e0b4776fe03f264bc6602b49e72e4132939a1c01 Mon Sep 17 00:00:00 2001 From: Cobord Date: Thu, 5 Mar 2026 11:38:56 -0500 Subject: [PATCH 10/11] simple methods on backpressure policies, make play with cargo fmt consistently --- .../rust/rust_server_v1_1_0/src/config/mod.rs | 35 +++++++++++++++--- summoner/rust/rust_server_v1_1_0/src/lib.rs | 16 +++++++-- .../rust/rust_server_v1_1_0/src/logger/mod.rs | 23 +++++++----- .../src/server/backpressure.rs | 22 +++++------- .../rust/rust_server_v1_1_0/src/server/mod.rs | 36 +++++++++++++------ .../src/server/quarantine.rs | 2 +- 6 files changed, 95 insertions(+), 39 deletions(-) diff --git a/summoner/rust/rust_server_v1_1_0/src/config/mod.rs b/summoner/rust/rust_server_v1_1_0/src/config/mod.rs index 5edc954..15ebdc1 100644 --- a/summoner/rust/rust_server_v1_1_0/src/config/mod.rs +++ b/summoner/rust/rust_server_v1_1_0/src/config/mod.rs @@ -31,6 +31,23 @@ pub struct BackpressurePolicy { pub disconnect_threshold: usize, } +impl BackpressurePolicy { + #[must_use = "The decision of whether to throttle"] + pub fn do_throttle(&self, queue_size: usize) -> bool { + self.enable_throttle && queue_size > self.throttle_threshold + } + + #[must_use = "The decision of whether to pause reading for chatty clients"] + pub fn do_flow_control(&self, queue_size: usize) -> bool { + self.enable_flow_control && queue_size > self.flow_control_threshold + } + + #[must_use = "The decision of whether to do a forced disconnect"] + pub fn do_disconnect(&self, queue_size: usize) -> bool { + self.enable_disconnect && queue_size > self.disconnect_threshold + } +} + ///////////////////////// // LoggerConfig // ///////////////////////// @@ -58,20 +75,19 @@ pub struct LoggerConfig { pub date_format: String, // Optional list of keys to include when logging dictionary messages in JSON pub log_keys: Option>, - // Maximum number of bytes per log file before rotation occurs // pub max_file_size: usize, // Number of rotated log files to keep before older ones are deleted // pub backup_count: usize, } - ////////////////////// // ServerConfig // ////////////////////// /// All the settings our server needs, driven by the Python-side dict #[derive(Debug, Clone)] +#[rustfmt::skip] pub struct ServerConfig { /// IP or hostname to listen on (e.g. `"127.0.0.1"`) pub host: String, @@ -81,7 +97,7 @@ pub struct ServerConfig { /// Logger configuration (text/JSON, file rotation, key filtering, etc.) pub logger: LoggerConfig, - + /// How many incoming connections we'll buffer before backpressure pub connection_buffer_size: usize, @@ -125,6 +141,14 @@ pub struct ServerConfig { pub worker_threads: usize, } +impl ServerConfig { + #[must_use = "The decision of whether to throttle"] + #[inline] + pub fn addr(&self) -> String { + format!("{}:{}", self.host, self.port) + } +} + ///////////////////////////////////////////// // Converting from a Python dict into Rust // ///////////////////////////////////////////// @@ -133,8 +157,10 @@ impl<'py> TryFrom<&Bound<'py, PyDict>> for ServerConfig { // We return a PyErr if something goes wrong parsing type Error = PyErr; + #[allow(clippy::too_many_lines)] + #[rustfmt::skip] fn try_from(dict: &Bound<'py, PyDict>) -> Result { - + // Helper: look up `key` in the dict, or fall back to `default`. // If the Python value has the wrong type, we warn but still use `default`. fn extract_or<'py, T: FromPyObject<'py>>( @@ -240,6 +266,7 @@ impl<'py> TryFrom<&Bound<'py, PyDict>> for ServerConfig { // let max_file_size = extract_or(&logger_dict, "max_file_size", 1_000_000); // let backup_count = extract_or(&logger_dict, "backup_count", 3); + #[allow(clippy::inconsistent_struct_constructor)] let logger = LoggerConfig { log_level, enable_console_log, diff --git a/summoner/rust/rust_server_v1_1_0/src/lib.rs b/summoner/rust/rust_server_v1_1_0/src/lib.rs index a4e174f..6a43244 100644 --- a/summoner/rust/rust_server_v1_1_0/src/lib.rs +++ b/summoner/rust/rust_server_v1_1_0/src/lib.rs @@ -1,8 +1,13 @@ +#![allow(clippy::uninlined_format_args, clippy::manual_string_new)] // Import everything we need from PyO3 so we can expose Rust functions to Python. // - `prelude::*` gives us the common traits and types. // - `PyModule` and `PyDict` let us work with Python modules and dicts. // - `Bound` helps us hold a reference into Python data safely. -use pyo3::{prelude::*, types::{PyModule, PyDict}, Bound}; +use pyo3::{ + Bound, + prelude::*, + types::{PyDict, PyModule}, +}; // Public module for parsing and validating server configuration provided by Python. pub mod config; @@ -14,9 +19,9 @@ pub mod logger; mod server; // Pull in the specific items we need so we don't have to write full paths later. +use config::ServerConfig; use logger::init_logger; use server::run_server; -use config::ServerConfig; /// Expose this Rust function to Python as `start_tokio_server(name, config)`. /// Responsibilities: @@ -33,6 +38,11 @@ use config::ServerConfig; /// Returns: /// - `Ok(())` if everything starts fine. /// - A Python `RuntimeError` if something goes wrong. +/// +/// # Errors +/// +/// A Python `RuntimeError` if something goes wrong +#[allow(clippy::needless_pass_by_value, clippy::used_underscore_binding)] #[pyfunction] pub fn start_tokio_server(_py: Python<'_>, name: String, config: Bound) -> PyResult<()> { // Convert the Python dict into our `ServerConfig` Rust struct. @@ -41,6 +51,7 @@ pub fn start_tokio_server(_py: Python<'_>, name: String, config: Bound) // Build a multi-threaded Tokio runtime based on the `worker_threads` value. // This runtime drives all our async I/O and timers. + #[rustfmt::skip] let rt = tokio::runtime::Builder::new_multi_thread() .worker_threads(server_config.worker_threads) // how many threads to use .thread_name("rust-server-worker") // helpful for debugging @@ -72,6 +83,7 @@ pub fn start_tokio_server(_py: Python<'_>, name: String, config: Bound) }); // If the server returned an error, convert it into a Python RuntimeError. + #[rustfmt::skip] server_result.map_err(|e| { PyErr::new::( format!("Server execution failed: {}", e) diff --git a/summoner/rust/rust_server_v1_1_0/src/logger/mod.rs b/summoner/rust/rust_server_v1_1_0/src/logger/mod.rs index 9667b91..9e74a41 100644 --- a/summoner/rust/rust_server_v1_1_0/src/logger/mod.rs +++ b/summoner/rust/rust_server_v1_1_0/src/logger/mod.rs @@ -21,8 +21,10 @@ use crate::config::LoggerConfig; /// Given the *parsed* content (a JSON value) and an optional /// whitelist of keys, prune its `_payload` and `_type` sub-objects -/// if `keys` is Some, then return the resulting JsonValue. +/// if `keys` is Some, then return the resulting `JsonValue`. /// This consumes `content`, so no cloning is needed by the caller. +#[allow(clippy::must_use_candidate)] +#[rustfmt::skip] pub fn prune_content_value( mut content: JsonValue, keys: &Option>, @@ -32,7 +34,10 @@ pub fn prune_content_value( Some(m) if keys.is_some() => m, _ => return content, }; - let keys = keys.as_ref().unwrap(); + #[allow(clippy::missing_panics_doc)] + let keys = keys + .as_ref() + .expect("If keys was None, then it would have gone through early return"); // 2) Build a fresh map with version + filtered sub-objects let mut out = serde_json::Map::new(); @@ -46,14 +51,13 @@ pub fn prune_content_value( .filter(|(k, _)| keys.contains(k)) .map(|(k, v)| (k.clone(), v.clone())) .collect(); - out.insert(field.to_string(), JsonValue::Object(filtered)); + out.insert((*field).to_string(), JsonValue::Object(filtered)); } } JsonValue::Object(out) } - /// A simple Logger struct that wraps logging functions. /// Clonable to allow use across multiple threads/tasks. #[derive(Clone)] @@ -87,8 +91,8 @@ static LOGGER: OnceLock = OnceLock::new(); /// Initialize the global logger exactly once, according to the provided settings. /// After this call, all calls to `log::debug!(), info!(), warn!(), error!()` (and your /// `Logger` methods) will go through the configured fern dispatcher. +#[rustfmt::skip] pub fn init_logger(name: &str, cfg: &LoggerConfig) -> Logger { - LOGGER.get_or_init(|| { // ──────────────────────────────────────────────────────────────── // 1) Parse the configured level string into a log::LevelFilter @@ -128,7 +132,7 @@ pub fn init_logger(name: &str, cfg: &LoggerConfig) -> Logger { nm, record.level(), message - )) + )); }; base = base.chain( @@ -157,6 +161,7 @@ pub fn init_logger(name: &str, cfg: &LoggerConfig) -> Logger { let enable_json = cfg.enable_json_log; // Compute the logfile path + #[rustfmt::skip] let filepath = if cfg.log_file_path.is_empty() { format!("{}.log", nm.replace('.', "_")) } else { @@ -172,6 +177,7 @@ pub fn init_logger(name: &str, cfg: &LoggerConfig) -> Logger { if enable_json { // 1) Raw payload text let raw = message.to_string(); + #[rustfmt::skip] let message_json: JsonValue = serde_json::from_str(&raw).unwrap_or(JsonValue::String(raw.clone())); // 2) Build a real JSON envelope with "message" as an object @@ -183,7 +189,7 @@ pub fn init_logger(name: &str, cfg: &LoggerConfig) -> Logger { }); // 4) Emit the envelope unescaped - out.finish(format_args!("{}", envelope)) + out.finish(format_args!("{}", envelope)); } else { // Plain-text path out.finish(format_args!( @@ -192,7 +198,7 @@ pub fn init_logger(name: &str, cfg: &LoggerConfig) -> Logger { nm, record.level(), message - )) + )); } }; @@ -216,6 +222,7 @@ pub fn init_logger(name: &str, cfg: &LoggerConfig) -> Logger { // 5) Apply the composed dispatcher as the global logger // Any subsequent log:: calls will use this configuration. // ──────────────────────────────────────────────────────────────── + #[allow(clippy::missing_panics_doc)] base.apply().unwrap(); // Return our zero-sized Logger handle diff --git a/summoner/rust/rust_server_v1_1_0/src/server/backpressure.rs b/summoner/rust/rust_server_v1_1_0/src/server/backpressure.rs index e2a8edb..06c8c13 100644 --- a/summoner/rust/rust_server_v1_1_0/src/server/backpressure.rs +++ b/summoner/rust/rust_server_v1_1_0/src/server/backpressure.rs @@ -2,8 +2,8 @@ use std::collections::HashMap; use std::net::SocketAddr; use tokio::sync::mpsc; -use crate::logger::Logger; use crate::config::BackpressurePolicy; +use crate::logger::Logger; /// Commands that the backpressure monitor can issue to the main server loop. #[derive(Debug)] @@ -19,8 +19,9 @@ pub enum ClientCommand { FlowControl, } - /// Spawn a task to monitor backpressure from clients +#[allow(clippy::needless_pass_by_value)] +#[rustfmt::skip] pub fn spawn_backpressure_monitor( mut rx: mpsc::Receiver<(SocketAddr, usize)>, command_tx: mpsc::Sender, @@ -28,42 +29,37 @@ pub fn spawn_backpressure_monitor( policy: BackpressurePolicy, ) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { - let mut client_queues: HashMap = HashMap::new(); while let Some((addr, queue_size)) = rx.recv().await { - client_queues.insert(addr, queue_size); - + let _old_queue_size = client_queues.insert(addr, queue_size); + // Log if queue size is getting large - if policy.enable_throttle && queue_size >= policy.throttle_threshold { + if policy.do_throttle(queue_size) { logger.warn(&format!("⚠️ Throttling client {}: {} messages queued", addr, queue_size)); if let Err(e) = command_tx.send(BackpressureCommand::Throttle(addr)).await { logger.error(&format!("Failed to send throttle command: {}", e)); } - } // Log if queue size is getting large - if policy.enable_flow_control && queue_size >= policy.flow_control_threshold { + if policy.do_flow_control(queue_size) { logger.warn(&format!("⏸️ Applying flow control to client {}: {} messages queued", addr, queue_size)); if let Err(e) = command_tx.send(BackpressureCommand::FlowControl(addr)).await { logger.error(&format!("Failed to send flow control command: {}", e)); } - } // Log if queue size is getting large - if policy.enable_disconnect && queue_size >= policy.disconnect_threshold { + if policy.do_disconnect(queue_size) { logger.warn(&format!("🚨 Disconnecting client {} due to extreme backpressure: {} messages queued", addr, queue_size)); - + if let Err(e) = command_tx.send(BackpressureCommand::Disconnect(addr)).await { logger.error(&format!("Failed to send disconnect command: {}", e)); } - } - } }) } \ No newline at end of file diff --git a/summoner/rust/rust_server_v1_1_0/src/server/mod.rs b/summoner/rust/rust_server_v1_1_0/src/server/mod.rs index daecff5..c348e90 100644 --- a/summoner/rust/rust_server_v1_1_0/src/server/mod.rs +++ b/summoner/rust/rust_server_v1_1_0/src/server/mod.rs @@ -1,5 +1,5 @@ +#![allow(clippy::empty_line_after_doc_comments)] /// === IMPORTS === - // Standard library type for holding an IP address and port together. use std::net::SocketAddr; @@ -25,7 +25,7 @@ use tokio::sync::{Mutex, RwLock, broadcast, mpsc}; use tokio::time::{self, Duration, Instant}; // Macro for building JSON payloads when broadcasting client messages. -use serde_json::{json, Value as JsonValue}; +use serde_json::{Value as JsonValue, json}; // Clone-on-write string type: avoids extra allocations when we don't modify the string. use std::borrow::Cow; @@ -33,15 +33,15 @@ use std::borrow::Cow; // Reference-counted byte buffer: enables zero-copy sharing of message data. use bytes::Bytes; - /// === MODULES === - // Private modules handling specific server features. +#[rustfmt::skip] mod backpressure; // queue monitoring and control commands +#[rustfmt::skip] mod ratelimiter; // per-client rate limiting +#[rustfmt::skip] mod quarantine; // temporary client bans - // Import the parsed ServerConfig struct that holds all user settings. use crate::config::ServerConfig; @@ -57,7 +57,6 @@ use crate::server::ratelimiter::RateLimiter; // Quarantine list for tracking and expiring banned clients over time. use crate::server::quarantine::QuarantineList; - /// === TYPES === // Represents one connected client. Cloning this struct is cheap (Arc + channel). @@ -79,7 +78,7 @@ pub struct Client { // - RwLock lets many readers (e.g. broadcasts) but only one writer (add/remove) at a time. pub type ClientList = Arc>>; - +#[allow(clippy::doc_markdown)] /// === RUN_SERVER === /// This function launches the main server loop: @@ -92,7 +91,7 @@ pub async fn run_server( logger: Logger, ) -> Result<(), Box> { // Build the "host:port" string so the OS knows where to listen - let addr = format!("{}:{}", config.host, config.port); + let addr = config.addr(); // Open a TCP listener on that address; the `?` returns early on error let listener = TcpListener::bind(&addr).await?; @@ -114,11 +113,13 @@ pub async fn run_server( mpsc::channel::<(SocketAddr, usize)>(config.connection_buffer_size); // Channel used by the backpressure monitor to send commands (Throttle, Disconnect, etc.) + #[rustfmt::skip] let (command_tx, command_rx) = mpsc::channel::(config.command_buffer_size); // Track which clients are quarantined (banned temporarily) // Wrapped in Arc+Mutex for safe, exclusive access when marking or cleaning entries + #[rustfmt::skip] let quarantine_list = Arc::new(Mutex::new( QuarantineList::new(Duration::from_secs(config.quarantine_cooldown_secs)) )); @@ -190,6 +191,8 @@ fn spawn_shutdown_listener( /// === CONNECTIONS === /// Listens for new clients, spawns per-client tasks, and handles shutdown/backpressure +#[allow(clippy::too_many_arguments)] +#[rustfmt::skip] async fn accept_connections( listener: TcpListener, // TCP socket we bound in run_server clients: ClientList, // Shared list of connected clients @@ -274,10 +277,11 @@ async fn handle_backpressure_command( BackpressureCommand::Throttle(addr) | BackpressureCommand::FlowControl(addr) => { // Decide which control command to send + #[rustfmt::skip] let control_cmd = match cmd { BackpressureCommand::Throttle(_) => ClientCommand::Throttle, BackpressureCommand::FlowControl(_) => ClientCommand::FlowControl, - _ => unreachable!(), + BackpressureCommand::Disconnect(_) => unreachable!(), }; // Under a short read lock, grab a clone of the sender if the client still exists @@ -297,10 +301,12 @@ async fn handle_backpressure_command( control_cmd, addr, e )); } + #[rustfmt::skip] let emoji = match control_cmd { ClientCommand::Throttle => "⏳", ClientCommand::FlowControl => "⏸️", }; + #[rustfmt::skip] logger.info(&format!("{} {:?} requested for {}", emoji, control_cmd, addr)); } else { logger.warn(&format!( @@ -313,6 +319,8 @@ async fn handle_backpressure_command( } /// Validates a new TCP connection and then spawns its session task +#[allow(clippy::too_many_arguments)] +#[rustfmt::skip] async fn handle_new_connection( stream: TcpStream, // The raw TCP connection addr: SocketAddr, // Remote client address (IP:port) @@ -404,7 +412,6 @@ async fn handle_new_connection( let remaining = connection_count.fetch_sub(1, std::sync::atomic::Ordering::SeqCst) - 1; logger.info(&format!("🔌 {} disconnected. Active connections: {}", addr, remaining)); }); - } /// Manages a single client's session: @@ -412,6 +419,7 @@ async fn handle_new_connection( /// - Reports backpressure without blocking /// - Enforces inactivity timeouts and graceful shutdown /// - On exit, removes the client from the shared list and logs the remaining count +#[allow(clippy::too_many_arguments)] async fn handle_connection( reader_half: tokio::net::tcp::OwnedReadHalf, client: Client, @@ -446,6 +454,7 @@ async fn handle_connection( .await { // Log any unexpected error during the session + #[rustfmt::skip] logger.warn(&format!("⚠️ Error in client {} session: {}", client.addr, e)); } @@ -465,7 +474,6 @@ async fn handle_connection( Ok(()) } - /// === MESSAGES === /// Manages a single client session until disconnect or shutdown: @@ -473,6 +481,7 @@ async fn handle_connection( /// - Responds to throttle and flow-control commands /// - Enforces inactivity timeouts /// - Sends a shutdown notice on server exit +#[allow(clippy::too_many_arguments)] async fn handle_client_messages( // Line-based reader for this client's incoming data reader: &mut Lines>, @@ -502,6 +511,7 @@ async fn handle_client_messages( let mut last_active = Instant::now(); // Timer that fires every `timeout_check_interval_secs` to enforce inactivity + #[rustfmt::skip] let mut timeout_interval = time::interval(Duration::from_secs( config.timeout_check_interval_secs, )); @@ -566,6 +576,7 @@ async fn handle_client_messages( // 4) Check for inactivity timeout _ = timeout_interval.tick() => { + #[allow(clippy::collapsible_if)] if let Some(timeout) = timeout { if last_active.elapsed() > timeout { // Send a warning to the client then break @@ -664,6 +675,8 @@ async fn process_client_line( let envelope_text = envelope.to_string(); // 3) Parse the client's content *once* + #[allow(clippy::needless_borrow)] + #[rustfmt::skip] let json_content: JsonValue = serde_json::from_str(&content).unwrap_or(JsonValue::String(content.to_string())); // 4) Move it into pruning or keep it as-is @@ -705,6 +718,7 @@ fn remove_last_newline(s: &str) -> &str { } /// Sends `msg` to every client except the sender, reporting queue size for backpressure: +#[rustfmt::skip] async fn broadcast_message( clients: &ClientList, // Shared list of all connected clients sender: &Client, // The client who sent the original message diff --git a/summoner/rust/rust_server_v1_1_0/src/server/quarantine.rs b/summoner/rust/rust_server_v1_1_0/src/server/quarantine.rs index 54ffbd1..30daddc 100644 --- a/summoner/rust/rust_server_v1_1_0/src/server/quarantine.rs +++ b/summoner/rust/rust_server_v1_1_0/src/server/quarantine.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::net::SocketAddr; -use std::time::{Duration, Instant}; use std::sync::Arc; +use std::time::{Duration, Instant}; use tokio::sync::Mutex; use crate::logger::Logger; From f81cd5db3e6cfd782a70b8a45bda543fb1951c5e Mon Sep 17 00:00:00 2001 From: Cobord Date: Thu, 5 Mar 2026 11:47:41 -0500 Subject: [PATCH 11/11] cargo fmt --- summoner/rust/rust_server_v1_1_0/src/config/mod.rs | 4 ++-- summoner/rust/rust_server_v1_1_0/src/server/backpressure.rs | 4 ++-- summoner/rust/rust_server_v1_1_0/src/server/mod.rs | 6 +++--- summoner/rust/rust_server_v1_1_0/src/server/quarantine.rs | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/summoner/rust/rust_server_v1_1_0/src/config/mod.rs b/summoner/rust/rust_server_v1_1_0/src/config/mod.rs index 15ebdc1..b1ec93e 100644 --- a/summoner/rust/rust_server_v1_1_0/src/config/mod.rs +++ b/summoner/rust/rust_server_v1_1_0/src/config/mod.rs @@ -100,7 +100,7 @@ pub struct ServerConfig { /// How many incoming connections we'll buffer before backpressure pub connection_buffer_size: usize, - + /// Seconds before we drop an idle client pub client_timeout: Option, @@ -186,7 +186,7 @@ impl<'py> TryFrom<&Bound<'py, PyDict>> for ServerConfig { let host = extract_or(dict, "host", "127.0.0.1".to_string()); let port = extract_or(dict, "port", 8888); let connection_buffer_size = extract_or(dict, "connection_buffer_size", 128); - + let rate_limit = extract_or(dict, "rate_limit_msgs_per_minute", 300); let command_buffer_size = extract_or(dict, "command_buffer_size", 32); let quarantine_cooldown_secs = extract_or(dict, "quarantine_cooldown_secs", 300); diff --git a/summoner/rust/rust_server_v1_1_0/src/server/backpressure.rs b/summoner/rust/rust_server_v1_1_0/src/server/backpressure.rs index 06c8c13..b8a5cf5 100644 --- a/summoner/rust/rust_server_v1_1_0/src/server/backpressure.rs +++ b/summoner/rust/rust_server_v1_1_0/src/server/backpressure.rs @@ -46,7 +46,7 @@ pub fn spawn_backpressure_monitor( // Log if queue size is getting large if policy.do_flow_control(queue_size) { logger.warn(&format!("⏸️ Applying flow control to client {}: {} messages queued", addr, queue_size)); - + if let Err(e) = command_tx.send(BackpressureCommand::FlowControl(addr)).await { logger.error(&format!("Failed to send flow control command: {}", e)); } @@ -62,4 +62,4 @@ pub fn spawn_backpressure_monitor( } } }) -} \ No newline at end of file +} diff --git a/summoner/rust/rust_server_v1_1_0/src/server/mod.rs b/summoner/rust/rust_server_v1_1_0/src/server/mod.rs index c348e90..fd3ba37 100644 --- a/summoner/rust/rust_server_v1_1_0/src/server/mod.rs +++ b/summoner/rust/rust_server_v1_1_0/src/server/mod.rs @@ -36,11 +36,11 @@ use bytes::Bytes; /// === MODULES === // Private modules handling specific server features. #[rustfmt::skip] -mod backpressure; // queue monitoring and control commands +mod backpressure; // queue monitoring and control commands #[rustfmt::skip] -mod ratelimiter; // per-client rate limiting +mod ratelimiter; // per-client rate limiting #[rustfmt::skip] -mod quarantine; // temporary client bans +mod quarantine; // temporary client bans // Import the parsed ServerConfig struct that holds all user settings. use crate::config::ServerConfig; diff --git a/summoner/rust/rust_server_v1_1_0/src/server/quarantine.rs b/summoner/rust/rust_server_v1_1_0/src/server/quarantine.rs index 30daddc..e398937 100644 --- a/summoner/rust/rust_server_v1_1_0/src/server/quarantine.rs +++ b/summoner/rust/rust_server_v1_1_0/src/server/quarantine.rs @@ -62,4 +62,4 @@ impl QuarantineList { } }) } -} \ No newline at end of file +}