|
42 | 42 | # WebDriver BiDi module: {{}} |
43 | 43 | from __future__ import annotations |
44 | 44 |
|
45 | | -from typing import Any |
46 | 45 | """ |
47 | 46 |
|
48 | 47 |
|
@@ -198,8 +197,9 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: |
198 | 197 | if param_name in self.required_params: |
199 | 198 | body += f" if {snake_param} is None:\n" |
200 | 199 | msg = f"{method_snake}() missing required argument:" |
| 200 | + error_message = f"{msg} {snake_param!r}" |
201 | 201 | body += ( |
202 | | - f' raise TypeError("{msg} {{{{snake_param!r}}}}")\n' |
| 202 | + f" raise TypeError({error_message!r})\n" |
203 | 203 | ) |
204 | 204 | body += "\n" |
205 | 205 |
|
@@ -591,23 +591,32 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: |
591 | 591 | # Collect needed imports to avoid duplicates |
592 | 592 | needs_command_builder = bool(self.commands) |
593 | 593 | needs_dataclass = self.commands or self.types or self.events |
594 | | - needs_threading = self.events |
595 | 594 | needs_callable = self.events |
596 | | - needs_session = self.events |
| 595 | + |
| 596 | + stdlib_imports = [] |
| 597 | + local_imports = [] |
597 | 598 |
|
598 | 599 | # Add imports (field import will be added conditionally after code generation) |
599 | | - if needs_command_builder: |
600 | | - code += "from .common import command_builder\n" |
601 | | - if needs_dataclass: |
602 | | - code += "from dataclasses import dataclass\n" |
603 | | - if needs_threading: |
604 | | - code += "import threading\n" |
605 | 600 | if needs_callable: |
606 | | - code += "from collections.abc import Callable\n" |
607 | | - if needs_session: |
608 | | - code += "from selenium.webdriver.common.bidi.session import Session\n" |
| 601 | + stdlib_imports.append("from collections.abc import Callable") |
| 602 | + if needs_dataclass: |
| 603 | + stdlib_imports.append("from dataclasses import dataclass") |
| 604 | + stdlib_imports.append("from typing import Any") |
| 605 | + |
| 606 | + if needs_command_builder: |
| 607 | + local_imports.append( |
| 608 | + "from selenium.webdriver.common.bidi.common import command_builder" |
| 609 | + ) |
| 610 | + if self.events: |
| 611 | + local_imports.append( |
| 612 | + "from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager" |
| 613 | + ) |
| 614 | + |
| 615 | + code += "\n".join(stdlib_imports) + "\n" |
| 616 | + if local_imports: |
| 617 | + code += "\n" + "\n".join(local_imports) + "\n" |
609 | 618 |
|
610 | | - code += "\n\n" |
| 619 | + code += "\n" |
611 | 620 |
|
612 | 621 | # Add helper function definitions from enhancements |
613 | 622 | # Collect all referenced helper functions (validate, transform) |
@@ -784,165 +793,11 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: |
784 | 793 | """ |
785 | 794 | code += "\n\n" |
786 | 795 |
|
787 | | - # Generate EventConfig and _EventManager for modules with events |
788 | | - if self.events: |
789 | | - # Generate EventConfig dataclass |
790 | | - code += """@dataclass |
791 | | -class EventConfig: |
792 | | - \"\"\"Configuration for a BiDi event.\"\"\" |
793 | | - event_key: str |
794 | | - bidi_event: str |
795 | | - event_class: type |
796 | | -
|
797 | | -
|
798 | | -""" |
799 | | - |
800 | | - # Generate _EventManager class |
801 | | - code += """class _EventWrapper: |
802 | | - \"\"\"Wrapper to provide event_class attribute for WebSocketConnection callbacks.\"\"\" |
803 | | - def __init__(self, bidi_event: str, event_class: type): |
804 | | - self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class |
805 | | - self._python_class = event_class # Keep reference to Python dataclass for deserialization |
806 | | -
|
807 | | - def from_json(self, params: dict) -> Any: |
808 | | - \"\"\"Deserialize event params into the wrapped Python dataclass. |
809 | | -
|
810 | | - Args: |
811 | | - params: Raw BiDi event params with camelCase keys. |
812 | | -
|
813 | | - Returns: |
814 | | - An instance of the dataclass, or the raw dict on failure. |
815 | | - \"\"\" |
816 | | - if self._python_class is None or self._python_class is dict: |
817 | | - return params |
818 | | - try: |
819 | | - # Delegate to a classmethod from_json if the class defines one |
820 | | - if hasattr(self._python_class, \"from_json\") and callable( |
821 | | - self._python_class.from_json |
822 | | - ): |
823 | | - return self._python_class.from_json(params) |
824 | | - import dataclasses as dc |
825 | | -
|
826 | | - snake_params = {self._camel_to_snake(k): v for k, v in params.items()} |
827 | | - if dc.is_dataclass(self._python_class): |
828 | | - valid_fields = {f.name for f in dc.fields(self._python_class)} |
829 | | - filtered = {k: v for k, v in snake_params.items() if k in valid_fields} |
830 | | - return self._python_class(**filtered) |
831 | | - return self._python_class(**snake_params) |
832 | | - except Exception: |
833 | | - return params |
834 | | -
|
835 | | - @staticmethod |
836 | | - def _camel_to_snake(name: str) -> str: |
837 | | - result = [name[0].lower()] |
838 | | - for char in name[1:]: |
839 | | - if char.isupper(): |
840 | | - result.extend([\"_\", char.lower()]) |
841 | | - else: |
842 | | - result.append(char) |
843 | | - return \"\".join(result) |
844 | | -
|
845 | | -
|
846 | | -class _EventManager: |
847 | | - \"\"\"Manages event subscriptions and callbacks.\"\"\" |
848 | | -
|
849 | | - def __init__(self, conn, event_configs: dict[str, EventConfig]): |
850 | | - self.conn = conn |
851 | | - self.event_configs = event_configs |
852 | | - self.subscriptions: dict = {} |
853 | | - self._event_wrappers = {} # Cache of _EventWrapper objects |
854 | | - self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} |
855 | | - self._available_events = ", ".join(sorted(event_configs.keys())) |
856 | | - self._subscription_lock = threading.Lock() |
857 | | -
|
858 | | - # Create event wrappers for each event |
859 | | - for config in event_configs.values(): |
860 | | - wrapper = _EventWrapper(config.bidi_event, config.event_class) |
861 | | - self._event_wrappers[config.bidi_event] = wrapper |
862 | | -
|
863 | | - def validate_event(self, event: str) -> EventConfig: |
864 | | - event_config = self.event_configs.get(event) |
865 | | - if not event_config: |
866 | | - raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") |
867 | | - return event_config |
868 | | -
|
869 | | - def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: |
870 | | - \"\"\"Subscribe to a BiDi event if not already subscribed.\"\"\" |
871 | | - with self._subscription_lock: |
872 | | - if bidi_event not in self.subscriptions: |
873 | | - session = Session(self.conn) |
874 | | - result = session.subscribe([bidi_event], contexts=contexts) |
875 | | - sub_id = ( |
876 | | - result.get(\"subscription\") if isinstance(result, dict) else None |
877 | | - ) |
878 | | - self.subscriptions[bidi_event] = { |
879 | | - \"callbacks\": [], |
880 | | - \"subscription_id\": sub_id, |
881 | | - } |
882 | | -
|
883 | | - def unsubscribe_from_event(self, bidi_event: str) -> None: |
884 | | - \"\"\"Unsubscribe from a BiDi event if no more callbacks exist.\"\"\" |
885 | | - with self._subscription_lock: |
886 | | - entry = self.subscriptions.get(bidi_event) |
887 | | - if entry is not None and not entry[\"callbacks\"]: |
888 | | - session = Session(self.conn) |
889 | | - sub_id = entry.get(\"subscription_id\") |
890 | | - if sub_id: |
891 | | - session.unsubscribe(subscriptions=[sub_id]) |
892 | | - else: |
893 | | - session.unsubscribe(events=[bidi_event]) |
894 | | - del self.subscriptions[bidi_event] |
895 | | -
|
896 | | - def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: |
897 | | - with self._subscription_lock: |
898 | | - self.subscriptions[bidi_event][\"callbacks\"].append(callback_id) |
899 | | -
|
900 | | - def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: |
901 | | - with self._subscription_lock: |
902 | | - entry = self.subscriptions.get(bidi_event) |
903 | | - if entry and callback_id in entry[\"callbacks\"]: |
904 | | - entry[\"callbacks\"].remove(callback_id) |
905 | | -
|
906 | | - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: |
907 | | - event_config = self.validate_event(event) |
908 | | - # Use the event wrapper for add_callback |
909 | | - event_wrapper = self._event_wrappers.get(event_config.bidi_event) |
910 | | - callback_id = self.conn.add_callback(event_wrapper, callback) |
911 | | - self.subscribe_to_event(event_config.bidi_event, contexts) |
912 | | - self.add_callback_to_tracking(event_config.bidi_event, callback_id) |
913 | | - return callback_id |
914 | | -
|
915 | | - def remove_event_handler(self, event: str, callback_id: int) -> None: |
916 | | - event_config = self.validate_event(event) |
917 | | - event_wrapper = self._event_wrappers.get(event_config.bidi_event) |
918 | | - self.conn.remove_callback(event_wrapper, callback_id) |
919 | | - self.remove_callback_from_tracking(event_config.bidi_event, callback_id) |
920 | | - self.unsubscribe_from_event(event_config.bidi_event) |
921 | | -
|
922 | | - def clear_event_handlers(self) -> None: |
923 | | - \"\"\"Clear all event handlers.\"\"\" |
924 | | - with self._subscription_lock: |
925 | | - if not self.subscriptions: |
926 | | - return |
927 | | - session = Session(self.conn) |
928 | | - for bidi_event, entry in list(self.subscriptions.items()): |
929 | | - event_wrapper = self._event_wrappers.get(bidi_event) |
930 | | - callbacks = entry[\"callbacks\"] if isinstance(entry, dict) else entry |
931 | | - if event_wrapper: |
932 | | - for callback_id in callbacks: |
933 | | - self.conn.remove_callback(event_wrapper, callback_id) |
934 | | - sub_id = ( |
935 | | - entry.get(\"subscription_id\") if isinstance(entry, dict) else None |
936 | | - ) |
937 | | - if sub_id: |
938 | | - session.unsubscribe(subscriptions=[sub_id]) |
939 | | - else: |
940 | | - session.unsubscribe(events=[bidi_event]) |
941 | | - self.subscriptions.clear() |
942 | | -
|
943 | | -
|
944 | | -""" |
945 | | - code += "\n\n" |
| 796 | + # EventConfig, _EventWrapper, and _EventManager are imported from |
| 797 | + # selenium.webdriver.common.bidi._event_manager (see import section above) |
| 798 | + # rather than being duplicated inline in every generated module. |
| 799 | + if False: # placeholder to preserve indentation structure |
| 800 | + pass |
946 | 801 |
|
947 | 802 | # Generate class |
948 | 803 | # Convert module name (camelCase or snake_case) to proper class name (PascalCase) |
@@ -1103,15 +958,15 @@ def clear_event_handlers(self) -> None: |
1103 | 958 | if re.search(dataclass_import_pattern, code): |
1104 | 959 | code = re.sub( |
1105 | 960 | dataclass_import_pattern, |
1106 | | - "from dataclasses import dataclass\nfrom dataclasses import field\n", |
| 961 | + "from dataclasses import dataclass, field\n", |
1107 | 962 | code, |
1108 | | - count=1 |
| 963 | + count=1, |
1109 | 964 | ) |
1110 | 965 | elif "from dataclasses import" not in code: |
1111 | 966 | # If there's no dataclasses import yet, add field import after typing |
1112 | 967 | code = code.replace( |
1113 | 968 | "from typing import Any\n", |
1114 | | - "from typing import Any\nfrom dataclasses import field\n" |
| 969 | + "from dataclasses import field\nfrom typing import Any\n", |
1115 | 970 | ) |
1116 | 971 |
|
1117 | 972 | return code |
@@ -1615,7 +1470,9 @@ def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> Non |
1615 | 1470 | for module_name in sorted(modules.keys()): |
1616 | 1471 | class_name = module_name_to_class_name(module_name) |
1617 | 1472 | filename = module_name_to_filename(module_name) |
1618 | | - code += f"from .{filename} import {class_name}\n" |
| 1473 | + code += ( |
| 1474 | + f"from selenium.webdriver.common.bidi.{filename} import {class_name}\n" |
| 1475 | + ) |
1619 | 1476 |
|
1620 | 1477 | code += "\n__all__ = [\n" |
1621 | 1478 | for module_name in sorted(modules.keys()): |
@@ -1660,20 +1517,23 @@ def generate_common_file(output_path: Path) -> None: |
1660 | 1517 | "\n" |
1661 | 1518 | "\n" |
1662 | 1519 | "def command_builder(\n" |
1663 | | - " method: str, params: dict[str, Any]\n" |
| 1520 | + " method: str, params: dict[str, Any] | None = None\n" |
1664 | 1521 | ") -> Generator[dict[str, Any], Any, Any]:\n" |
1665 | 1522 | ' """Build a BiDi command generator.\n' |
1666 | 1523 | "\n" |
1667 | 1524 | " Args:\n" |
1668 | 1525 | ' method: The BiDi method name (e.g., "session.status", "browser.close")\n' |
1669 | | - " params: The parameters for the command\n" |
| 1526 | + " params: The parameters for the command. If omitted, an empty\n" |
| 1527 | + " dictionary is sent.\n" |
1670 | 1528 | "\n" |
1671 | 1529 | " Yields:\n" |
1672 | 1530 | " A dictionary representing the BiDi command\n" |
1673 | 1531 | "\n" |
1674 | 1532 | " Returns:\n" |
1675 | 1533 | " The result from the BiDi command execution\n" |
1676 | 1534 | ' """\n' |
| 1535 | + " if params is None:\n" |
| 1536 | + " params = {}\n" |
1677 | 1537 | ' result = yield {"method": method, "params": params}\n' |
1678 | 1538 | " return result\n" |
1679 | 1539 | ) |
@@ -1750,7 +1610,7 @@ def generate_permissions_file(output_path: Path) -> None: |
1750 | 1610 | "from enum import Enum\n" |
1751 | 1611 | "from typing import Any\n" |
1752 | 1612 | "\n" |
1753 | | - "from .common import command_builder\n" |
| 1613 | + "from selenium.webdriver.common.bidi.common import command_builder\n" |
1754 | 1614 | "\n" |
1755 | 1615 | '_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"}\n' |
1756 | 1616 | "\n" |
|
0 commit comments