Skip to content

Commit 5f07cf5

Browse files
handle comments
1 parent eb2a479 commit 5f07cf5

20 files changed

Lines changed: 367 additions & 1092 deletions

py/generate_bidi.py

Lines changed: 40 additions & 180 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
# WebDriver BiDi module: {{}}
4343
from __future__ import annotations
4444
45-
from typing import Any
4645
"""
4746

4847

@@ -198,8 +197,9 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str:
198197
if param_name in self.required_params:
199198
body += f" if {snake_param} is None:\n"
200199
msg = f"{method_snake}() missing required argument:"
200+
error_message = f"{msg} {snake_param!r}"
201201
body += (
202-
f' raise TypeError("{msg} {{{{snake_param!r}}}}")\n'
202+
f" raise TypeError({error_message!r})\n"
203203
)
204204
body += "\n"
205205

@@ -591,23 +591,32 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str:
591591
# Collect needed imports to avoid duplicates
592592
needs_command_builder = bool(self.commands)
593593
needs_dataclass = self.commands or self.types or self.events
594-
needs_threading = self.events
595594
needs_callable = self.events
596-
needs_session = self.events
595+
596+
stdlib_imports = []
597+
local_imports = []
597598

598599
# Add imports (field import will be added conditionally after code generation)
599-
if needs_command_builder:
600-
code += "from .common import command_builder\n"
601-
if needs_dataclass:
602-
code += "from dataclasses import dataclass\n"
603-
if needs_threading:
604-
code += "import threading\n"
605600
if needs_callable:
606-
code += "from collections.abc import Callable\n"
607-
if needs_session:
608-
code += "from selenium.webdriver.common.bidi.session import Session\n"
601+
stdlib_imports.append("from collections.abc import Callable")
602+
if needs_dataclass:
603+
stdlib_imports.append("from dataclasses import dataclass")
604+
stdlib_imports.append("from typing import Any")
605+
606+
if needs_command_builder:
607+
local_imports.append(
608+
"from selenium.webdriver.common.bidi.common import command_builder"
609+
)
610+
if self.events:
611+
local_imports.append(
612+
"from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager"
613+
)
614+
615+
code += "\n".join(stdlib_imports) + "\n"
616+
if local_imports:
617+
code += "\n" + "\n".join(local_imports) + "\n"
609618

610-
code += "\n\n"
619+
code += "\n"
611620

612621
# Add helper function definitions from enhancements
613622
# Collect all referenced helper functions (validate, transform)
@@ -784,165 +793,11 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str:
784793
"""
785794
code += "\n\n"
786795

787-
# Generate EventConfig and _EventManager for modules with events
788-
if self.events:
789-
# Generate EventConfig dataclass
790-
code += """@dataclass
791-
class EventConfig:
792-
\"\"\"Configuration for a BiDi event.\"\"\"
793-
event_key: str
794-
bidi_event: str
795-
event_class: type
796-
797-
798-
"""
799-
800-
# Generate _EventManager class
801-
code += """class _EventWrapper:
802-
\"\"\"Wrapper to provide event_class attribute for WebSocketConnection callbacks.\"\"\"
803-
def __init__(self, bidi_event: str, event_class: type):
804-
self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class
805-
self._python_class = event_class # Keep reference to Python dataclass for deserialization
806-
807-
def from_json(self, params: dict) -> Any:
808-
\"\"\"Deserialize event params into the wrapped Python dataclass.
809-
810-
Args:
811-
params: Raw BiDi event params with camelCase keys.
812-
813-
Returns:
814-
An instance of the dataclass, or the raw dict on failure.
815-
\"\"\"
816-
if self._python_class is None or self._python_class is dict:
817-
return params
818-
try:
819-
# Delegate to a classmethod from_json if the class defines one
820-
if hasattr(self._python_class, \"from_json\") and callable(
821-
self._python_class.from_json
822-
):
823-
return self._python_class.from_json(params)
824-
import dataclasses as dc
825-
826-
snake_params = {self._camel_to_snake(k): v for k, v in params.items()}
827-
if dc.is_dataclass(self._python_class):
828-
valid_fields = {f.name for f in dc.fields(self._python_class)}
829-
filtered = {k: v for k, v in snake_params.items() if k in valid_fields}
830-
return self._python_class(**filtered)
831-
return self._python_class(**snake_params)
832-
except Exception:
833-
return params
834-
835-
@staticmethod
836-
def _camel_to_snake(name: str) -> str:
837-
result = [name[0].lower()]
838-
for char in name[1:]:
839-
if char.isupper():
840-
result.extend([\"_\", char.lower()])
841-
else:
842-
result.append(char)
843-
return \"\".join(result)
844-
845-
846-
class _EventManager:
847-
\"\"\"Manages event subscriptions and callbacks.\"\"\"
848-
849-
def __init__(self, conn, event_configs: dict[str, EventConfig]):
850-
self.conn = conn
851-
self.event_configs = event_configs
852-
self.subscriptions: dict = {}
853-
self._event_wrappers = {} # Cache of _EventWrapper objects
854-
self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()}
855-
self._available_events = ", ".join(sorted(event_configs.keys()))
856-
self._subscription_lock = threading.Lock()
857-
858-
# Create event wrappers for each event
859-
for config in event_configs.values():
860-
wrapper = _EventWrapper(config.bidi_event, config.event_class)
861-
self._event_wrappers[config.bidi_event] = wrapper
862-
863-
def validate_event(self, event: str) -> EventConfig:
864-
event_config = self.event_configs.get(event)
865-
if not event_config:
866-
raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}")
867-
return event_config
868-
869-
def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None:
870-
\"\"\"Subscribe to a BiDi event if not already subscribed.\"\"\"
871-
with self._subscription_lock:
872-
if bidi_event not in self.subscriptions:
873-
session = Session(self.conn)
874-
result = session.subscribe([bidi_event], contexts=contexts)
875-
sub_id = (
876-
result.get(\"subscription\") if isinstance(result, dict) else None
877-
)
878-
self.subscriptions[bidi_event] = {
879-
\"callbacks\": [],
880-
\"subscription_id\": sub_id,
881-
}
882-
883-
def unsubscribe_from_event(self, bidi_event: str) -> None:
884-
\"\"\"Unsubscribe from a BiDi event if no more callbacks exist.\"\"\"
885-
with self._subscription_lock:
886-
entry = self.subscriptions.get(bidi_event)
887-
if entry is not None and not entry[\"callbacks\"]:
888-
session = Session(self.conn)
889-
sub_id = entry.get(\"subscription_id\")
890-
if sub_id:
891-
session.unsubscribe(subscriptions=[sub_id])
892-
else:
893-
session.unsubscribe(events=[bidi_event])
894-
del self.subscriptions[bidi_event]
895-
896-
def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None:
897-
with self._subscription_lock:
898-
self.subscriptions[bidi_event][\"callbacks\"].append(callback_id)
899-
900-
def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None:
901-
with self._subscription_lock:
902-
entry = self.subscriptions.get(bidi_event)
903-
if entry and callback_id in entry[\"callbacks\"]:
904-
entry[\"callbacks\"].remove(callback_id)
905-
906-
def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int:
907-
event_config = self.validate_event(event)
908-
# Use the event wrapper for add_callback
909-
event_wrapper = self._event_wrappers.get(event_config.bidi_event)
910-
callback_id = self.conn.add_callback(event_wrapper, callback)
911-
self.subscribe_to_event(event_config.bidi_event, contexts)
912-
self.add_callback_to_tracking(event_config.bidi_event, callback_id)
913-
return callback_id
914-
915-
def remove_event_handler(self, event: str, callback_id: int) -> None:
916-
event_config = self.validate_event(event)
917-
event_wrapper = self._event_wrappers.get(event_config.bidi_event)
918-
self.conn.remove_callback(event_wrapper, callback_id)
919-
self.remove_callback_from_tracking(event_config.bidi_event, callback_id)
920-
self.unsubscribe_from_event(event_config.bidi_event)
921-
922-
def clear_event_handlers(self) -> None:
923-
\"\"\"Clear all event handlers.\"\"\"
924-
with self._subscription_lock:
925-
if not self.subscriptions:
926-
return
927-
session = Session(self.conn)
928-
for bidi_event, entry in list(self.subscriptions.items()):
929-
event_wrapper = self._event_wrappers.get(bidi_event)
930-
callbacks = entry[\"callbacks\"] if isinstance(entry, dict) else entry
931-
if event_wrapper:
932-
for callback_id in callbacks:
933-
self.conn.remove_callback(event_wrapper, callback_id)
934-
sub_id = (
935-
entry.get(\"subscription_id\") if isinstance(entry, dict) else None
936-
)
937-
if sub_id:
938-
session.unsubscribe(subscriptions=[sub_id])
939-
else:
940-
session.unsubscribe(events=[bidi_event])
941-
self.subscriptions.clear()
942-
943-
944-
"""
945-
code += "\n\n"
796+
# EventConfig, _EventWrapper, and _EventManager are imported from
797+
# selenium.webdriver.common.bidi._event_manager (see import section above)
798+
# rather than being duplicated inline in every generated module.
799+
if False: # placeholder to preserve indentation structure
800+
pass
946801

947802
# Generate class
948803
# Convert module name (camelCase or snake_case) to proper class name (PascalCase)
@@ -1103,15 +958,15 @@ def clear_event_handlers(self) -> None:
1103958
if re.search(dataclass_import_pattern, code):
1104959
code = re.sub(
1105960
dataclass_import_pattern,
1106-
"from dataclasses import dataclass\nfrom dataclasses import field\n",
961+
"from dataclasses import dataclass, field\n",
1107962
code,
1108-
count=1
963+
count=1,
1109964
)
1110965
elif "from dataclasses import" not in code:
1111966
# If there's no dataclasses import yet, add field import after typing
1112967
code = code.replace(
1113968
"from typing import Any\n",
1114-
"from typing import Any\nfrom dataclasses import field\n"
969+
"from dataclasses import field\nfrom typing import Any\n",
1115970
)
1116971

1117972
return code
@@ -1615,7 +1470,9 @@ def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> Non
16151470
for module_name in sorted(modules.keys()):
16161471
class_name = module_name_to_class_name(module_name)
16171472
filename = module_name_to_filename(module_name)
1618-
code += f"from .{filename} import {class_name}\n"
1473+
code += (
1474+
f"from selenium.webdriver.common.bidi.{filename} import {class_name}\n"
1475+
)
16191476

16201477
code += "\n__all__ = [\n"
16211478
for module_name in sorted(modules.keys()):
@@ -1660,20 +1517,23 @@ def generate_common_file(output_path: Path) -> None:
16601517
"\n"
16611518
"\n"
16621519
"def command_builder(\n"
1663-
" method: str, params: dict[str, Any]\n"
1520+
" method: str, params: dict[str, Any] | None = None\n"
16641521
") -> Generator[dict[str, Any], Any, Any]:\n"
16651522
' """Build a BiDi command generator.\n'
16661523
"\n"
16671524
" Args:\n"
16681525
' method: The BiDi method name (e.g., "session.status", "browser.close")\n'
1669-
" params: The parameters for the command\n"
1526+
" params: The parameters for the command. If omitted, an empty\n"
1527+
" dictionary is sent.\n"
16701528
"\n"
16711529
" Yields:\n"
16721530
" A dictionary representing the BiDi command\n"
16731531
"\n"
16741532
" Returns:\n"
16751533
" The result from the BiDi command execution\n"
16761534
' """\n'
1535+
" if params is None:\n"
1536+
" params = {}\n"
16771537
' result = yield {"method": method, "params": params}\n'
16781538
" return result\n"
16791539
)
@@ -1750,7 +1610,7 @@ def generate_permissions_file(output_path: Path) -> None:
17501610
"from enum import Enum\n"
17511611
"from typing import Any\n"
17521612
"\n"
1753-
"from .common import command_builder\n"
1613+
"from selenium.webdriver.common.bidi.common import command_builder\n"
17541614
"\n"
17551615
'_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"}\n'
17561616
"\n"

0 commit comments

Comments
 (0)