From 5ee10fee0050572b052a21d5a626830770b62950 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Fri, 13 Feb 2026 10:51:21 +0000 Subject: [PATCH 01/42] [py] BiDi Python code generation from CDDL --- common/bidi/spec/all.cddl | 71 +- common/bidi/spec/local.cddl | 29 +- common/bidi/spec/remote.cddl | 65 +- py/BUILD.bazel | 1 + py/generate_bidi.py | 604 +++++++++--------- py/private/BUILD.bazel | 1 - py/private/bidi_enhancements_manifest.py | 428 ++----------- py/private/cdp.py | 42 +- py/private/generate_bidi.bzl | 2 + py/requirements_lock.txt | 5 +- py/selenium/common/exceptions.py | 12 +- py/selenium/webdriver/common/bidi/__init__.py | 42 +- py/selenium/webdriver/common/bidi/browser.py | 173 ++--- .../webdriver/common/bidi/browsing_context.py | 464 +++++++------- py/selenium/webdriver/common/bidi/cdp.py | 8 +- py/selenium/webdriver/common/bidi/common.py | 14 +- py/selenium/webdriver/common/bidi/console.py | 0 .../webdriver/common/bidi/emulation.py | 305 +++++---- py/selenium/webdriver/common/bidi/input.py | 233 +++++-- py/selenium/webdriver/common/bidi/log.py | 94 +-- py/selenium/webdriver/common/bidi/network.py | 450 +++++++------ .../webdriver/common/bidi/permissions.py | 15 +- py/selenium/webdriver/common/bidi/py.typed | 0 py/selenium/webdriver/common/bidi/script.py | 366 ++++++----- py/selenium/webdriver/common/bidi/session.py | 70 +- py/selenium/webdriver/common/bidi/storage.py | 127 ++-- .../webdriver/common/bidi/webextension.py | 80 +-- py/selenium/webdriver/common/proxy.py | 34 +- py/selenium/webdriver/remote/webdriver.py | 153 +++-- 29 files changed, 1838 insertions(+), 2050 deletions(-) mode change 100644 => 100755 py/selenium/webdriver/common/bidi/console.py mode change 100644 => 100755 py/selenium/webdriver/common/bidi/py.typed diff --git a/common/bidi/spec/all.cddl b/common/bidi/spec/all.cddl index e10b42723b0f5..85c4536a2cd10 100644 --- a/common/bidi/spec/all.cddl +++ b/common/bidi/spec/all.cddl @@ -420,7 +420,6 @@ BrowsingContextCommand = ( browsingContext.Navigate // browsingContext.Print // browsingContext.Reload // - browsingContext.SetBypassCSP // browsingContext.SetViewport // browsingContext.TraverseHistory ) @@ -436,7 +435,6 @@ BrowsingContextResult = ( browsingContext.NavigateResult / browsingContext.PrintResult / browsingContext.ReloadResult / - browsingContext.SetBypassCSPResult / browsingContext.SetViewportResult / browsingContext.TraverseHistoryResult ) @@ -520,7 +518,6 @@ browsingContext.BaseNavigationInfo = ( navigation: browsingContext.Navigation / null, timestamp: js-uint, url: text, - ? userContext: browser.UserContext, ) browsingContext.NavigationInfo = { @@ -608,8 +605,7 @@ browsingContext.CreateParameters = { } browsingContext.CreateResult = { - context: browsingContext.BrowsingContext, - ? userContext: browser.UserContext + context: browsingContext.BrowsingContext } browsingContext.GetTree = ( @@ -719,19 +715,6 @@ browsingContext.ReloadParameters = { browsingContext.ReloadResult = browsingContext.NavigateResult -browsingContext.SetBypassCSP = ( - method: "browsingContext.setBypassCSP", - params: browsingContext.SetBypassCSPParameters -) - -browsingContext.SetBypassCSPParameters = { - bypass: true / null, - ? contexts: [+browsingContext.BrowsingContext], - ? userContexts: [+browser.UserContext], -} - -browsingContext.SetBypassCSPResult = EmptyResult - browsingContext.SetViewport = ( method: "browsingContext.setViewport", params: browsingContext.SetViewportParameters @@ -791,8 +774,7 @@ browsingContext.HistoryUpdated = ( browsingContext.HistoryUpdatedParameters = { context: browsingContext.BrowsingContext, timestamp: js-uint, - url: text, - ? userContext: browser.UserContext + url: text } browsingContext.DomContentLoaded = ( @@ -862,7 +844,6 @@ browsingContext.UserPromptClosedParameters = { context: browsingContext.BrowsingContext, accepted: bool, type: browsingContext.UserPromptType, - ? userContext: browser.UserContext, ? userText: text } @@ -876,7 +857,6 @@ browsingContext.UserPromptOpenedParameters = { handler: session.UserPromptHandlerType, message: text, type: browsingContext.UserPromptType, - ? userContext: browser.UserContext, ? defaultValue: text } @@ -891,7 +871,8 @@ EmulationCommand = ( emulation.SetScrollbarTypeOverride // emulation.SetTimezoneOverride // emulation.SetTouchOverride // - emulation.SetUserAgentOverride + emulation.SetUserAgentOverride // + emulation.SetViewportMetaOverride ) @@ -904,7 +885,8 @@ EmulationResult = ( emulation.SetScrollbarTypeOverrideResult / emulation.SetTimezoneOverrideResult / emulation.SetTouchOverrideResult / - emulation.SetUserAgentOverrideResult + emulation.SetUserAgentOverrideResult / + emulation.SetViewportMetaOverrideResult ) emulation.SetForcedColorsModeThemeOverride = ( @@ -967,10 +949,10 @@ emulation.SetLocaleOverrideResult = EmptyResult emulation.SetNetworkConditions = ( method: "emulation.setNetworkConditions", - params: emulation.SetNetworkConditionsParameters + params: emulation.setNetworkConditionsParameters ) -emulation.SetNetworkConditionsParameters = { +emulation.setNetworkConditionsParameters = { networkConditions: emulation.NetworkConditions / null, ? contexts: [+browsingContext.BrowsingContext], ? userContexts: [+browser.UserContext], @@ -1036,6 +1018,19 @@ emulation.SetUserAgentOverrideParameters = { emulation.SetUserAgentOverrideResult = EmptyResult +emulation.SetViewportMetaOverride = ( + method: "emulation.setViewportMetaOverride", + params: emulation.SetViewportMetaOverrideParameters +) + +emulation.SetViewportMetaOverrideParameters = { + viewportMeta: true / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetViewportMetaOverrideResult = EmptyResult + emulation.SetScriptingEnabled = ( method: "emulation.setScriptingEnabled", params: emulation.SetScriptingEnabledParameters @@ -1150,7 +1145,6 @@ network.BaseParameters = ( redirectCount: js-uint, request: network.RequestData, timestamp: js-uint, - ? userContext: browser.UserContext / null, ? intercepts: [+network.Intercept] ) @@ -1385,10 +1379,10 @@ network.ContinueWithAuthResult = EmptyResult network.DisownData = ( method: "network.disownData", - params: network.DisownDataParameters + params: network.disownDataParameters ) -network.DisownDataParameters = { +network.disownDataParameters = { dataType: network.DataType, collector: network.Collector, request: network.Request, @@ -1716,7 +1710,6 @@ script.WindowRealmInfo = { script.BaseRealmInfo, type: "window", context: browsingContext.BrowsingContext, - ? userContext: browser.UserContext, ? sandbox: text } @@ -1976,8 +1969,7 @@ script.StackTrace = { script.Source = { realm: script.Realm, - ? context: browsingContext.BrowsingContext, - ? userContext: browser.UserContext + ? context: browsingContext.BrowsingContext } script.RealmTarget = { @@ -2389,15 +2381,15 @@ input.WheelScrollAction = { } input.PointerCommonProperties = ( - ? width: js-uint, - ? height: js-uint, - ? pressure: (0.0..1.0), - ? tangentialPressure: (-1.0..1.0), - ? twist: (0..359), + ? width: js-uint .default 1, + ? height: js-uint .default 1, + ? pressure: float .default 0.0, + ? tangentialPressure: float .default 0.0, + ? twist: (0..359) .default 0, ; 0 .. Math.PI / 2 - ? altitudeAngle: (0.0..1.5707963267948966), + ? altitudeAngle: (0.0..1.5707963267948966) .default 0.0, ; 0 .. 2 * Math.PI - ? azimuthAngle: (0.0..6.283185307179586), + ? azimuthAngle: (0.0..6.283185307179586) .default 0.0, ) input.Origin = "viewport" / "pointer" / input.ElementOrigin @@ -2435,7 +2427,6 @@ input.FileDialogOpened = ( input.FileDialogInfo = { context: browsingContext.BrowsingContext, - ? userContext: browser.UserContext, ? element: script.SharedReference, multiple: bool, } diff --git a/common/bidi/spec/local.cddl b/common/bidi/spec/local.cddl index 1bb2ce612e2c2..d43af0ae11b03 100644 --- a/common/bidi/spec/local.cddl +++ b/common/bidi/spec/local.cddl @@ -251,7 +251,6 @@ BrowsingContextResult = ( browsingContext.NavigateResult / browsingContext.PrintResult / browsingContext.ReloadResult / - browsingContext.SetBypassCSPResult / browsingContext.SetViewportResult / browsingContext.TraverseHistoryResult ) @@ -335,7 +334,6 @@ browsingContext.BaseNavigationInfo = ( navigation: browsingContext.Navigation / null, timestamp: js-uint, url: text, - ? userContext: browser.UserContext, ) browsingContext.NavigationInfo = { @@ -353,8 +351,7 @@ browsingContext.CaptureScreenshotResult = { browsingContext.CloseResult = EmptyResult browsingContext.CreateResult = { - context: browsingContext.BrowsingContext, - ? userContext: browser.UserContext + context: browsingContext.BrowsingContext } browsingContext.GetTreeResult = { @@ -378,8 +375,6 @@ browsingContext.PrintResult = { browsingContext.ReloadResult = browsingContext.NavigateResult -browsingContext.SetBypassCSPResult = EmptyResult - browsingContext.SetViewportResult = EmptyResult browsingContext.TraverseHistoryResult = EmptyResult @@ -412,8 +407,7 @@ browsingContext.HistoryUpdated = ( browsingContext.HistoryUpdatedParameters = { context: browsingContext.BrowsingContext, timestamp: js-uint, - url: text, - ? userContext: browser.UserContext + url: text } browsingContext.DomContentLoaded = ( @@ -483,7 +477,6 @@ browsingContext.UserPromptClosedParameters = { context: browsingContext.BrowsingContext, accepted: bool, type: browsingContext.UserPromptType, - ? userContext: browser.UserContext, ? userText: text } @@ -497,7 +490,6 @@ browsingContext.UserPromptOpenedParameters = { handler: session.UserPromptHandlerType, message: text, type: browsingContext.UserPromptType, - ? userContext: browser.UserContext, ? defaultValue: text } @@ -510,7 +502,8 @@ EmulationResult = ( emulation.SetScrollbarTypeOverrideResult / emulation.SetTimezoneOverrideResult / emulation.SetTouchOverrideResult / - emulation.SetUserAgentOverrideResult + emulation.SetUserAgentOverrideResult / + emulation.SetViewportMetaOverrideResult ) emulation.SetForcedColorsModeThemeOverrideResult = EmptyResult @@ -527,6 +520,8 @@ emulation.SetScreenOrientationOverrideResult = EmptyResult emulation.SetUserAgentOverrideResult = EmptyResult +emulation.SetViewportMetaOverrideResult = EmptyResult + emulation.SetScriptingEnabledResult = EmptyResult emulation.SetScrollbarTypeOverrideResult = EmptyResult @@ -573,7 +568,6 @@ network.BaseParameters = ( redirectCount: js-uint, request: network.RequestData, timestamp: js-uint, - ? userContext: browser.UserContext / null, ? intercepts: [+network.Intercept] ) @@ -932,7 +926,6 @@ script.WindowRealmInfo = { script.BaseRealmInfo, type: "window", context: browsingContext.BrowsingContext, - ? userContext: browser.UserContext, ? sandbox: text } @@ -1192,8 +1185,7 @@ script.StackTrace = { script.Source = { realm: script.Realm, - ? context: browsingContext.BrowsingContext, - ? userContext: browser.UserContext + ? context: browsingContext.BrowsingContext } script.AddPreloadScriptResult = { @@ -1303,12 +1295,6 @@ log.EntryAdded = ( params: log.Entry, ) -InputResult = ( - input.PerformActionsResult / - input.ReleaseActionsResult / - input.SetFilesResult -) - InputEvent = ( input.FileDialogOpened @@ -1327,7 +1313,6 @@ input.FileDialogOpened = ( input.FileDialogInfo = { context: browsingContext.BrowsingContext, - ? userContext: browser.UserContext, ? element: script.SharedReference, multiple: bool, } diff --git a/common/bidi/spec/remote.cddl b/common/bidi/spec/remote.cddl index 7490df1b44bc7..a98859a021e12 100644 --- a/common/bidi/spec/remote.cddl +++ b/common/bidi/spec/remote.cddl @@ -273,7 +273,6 @@ BrowsingContextCommand = ( browsingContext.Navigate // browsingContext.Print // browsingContext.Reload // - browsingContext.SetBypassCSP // browsingContext.SetViewport // browsingContext.TraverseHistory ) @@ -481,17 +480,6 @@ browsingContext.ReloadParameters = { ? wait: browsingContext.ReadinessState, } -browsingContext.SetBypassCSP = ( - method: "browsingContext.setBypassCSP", - params: browsingContext.SetBypassCSPParameters -) - -browsingContext.SetBypassCSPParameters = { - bypass: true / null, - ? contexts: [+browsingContext.BrowsingContext], - ? userContexts: [+browser.UserContext], -} - browsingContext.SetViewport = ( method: "browsingContext.setViewport", params: browsingContext.SetViewportParameters @@ -530,7 +518,8 @@ EmulationCommand = ( emulation.SetScrollbarTypeOverride // emulation.SetTimezoneOverride // emulation.SetTouchOverride // - emulation.SetUserAgentOverride + emulation.SetUserAgentOverride // + emulation.SetViewportMetaOverride ) @@ -588,10 +577,10 @@ emulation.SetLocaleOverrideParameters = { emulation.SetNetworkConditions = ( method: "emulation.setNetworkConditions", - params: emulation.SetNetworkConditionsParameters + params: emulation.setNetworkConditionsParameters ) -emulation.SetNetworkConditionsParameters = { +emulation.setNetworkConditionsParameters = { networkConditions: emulation.NetworkConditions / null, ? contexts: [+browsingContext.BrowsingContext], ? userContexts: [+browser.UserContext], @@ -649,6 +638,17 @@ emulation.SetUserAgentOverrideParameters = { ? userContexts: [+browser.UserContext], } +emulation.SetViewportMetaOverride = ( + method: "emulation.setViewportMetaOverride", + params: emulation.SetViewportMetaOverrideParameters +) + +emulation.SetViewportMetaOverrideParameters = { + viewportMeta: true / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + emulation.SetScriptingEnabled = ( method: "emulation.setScriptingEnabled", params: emulation.SetScriptingEnabledParameters @@ -876,10 +876,10 @@ network.ContinueWithAuthNoCredentials = ( network.DisownData = ( method: "network.disownData", - params: network.DisownDataParameters + params: network.disownDataParameters ) -network.DisownDataParameters = { +network.disownDataParameters = { dataType: network.DataType, collector: network.Collector, request: network.Request, @@ -1500,6 +1500,12 @@ InputCommand = ( input.SetFiles ) +InputResult = ( + input.PerformActionsResult / + input.ReleaseActionsResult / + input.SetFilesResult +) + input.ElementOrigin = { type: "element", element: script.SharedReference @@ -1619,15 +1625,15 @@ input.WheelScrollAction = { } input.PointerCommonProperties = ( - ? width: js-uint, - ? height: js-uint, - ? pressure: (0.0..1.0), - ? tangentialPressure: (-1.0..1.0), - ? twist: (0..359), + ? width: js-uint .default 1, + ? height: js-uint .default 1, + ? pressure: float .default 0.0, + ? tangentialPressure: float .default 0.0, + ? twist: (0..359) .default 0, ; 0 .. Math.PI / 2 - ? altitudeAngle: (0.0..1.5707963267948966), + ? altitudeAngle: (0.0..1.5707963267948966) .default 0.0, ; 0 .. 2 * Math.PI - ? azimuthAngle: (0.0..6.283185307179586), + ? azimuthAngle: (0.0..6.283185307179586) .default 0.0, ) input.Origin = "viewport" / "pointer" / input.ElementOrigin @@ -1652,6 +1658,17 @@ input.SetFilesParameters = { files: [*text] } +input.FileDialogOpened = ( + method: "input.fileDialogOpened", + params: input.FileDialogInfo +) + +input.FileDialogInfo = { + context: browsingContext.BrowsingContext, + ? element: script.SharedReference, + multiple: bool, +} + WebExtensionCommand = ( webExtension.Install // webExtension.Uninstall diff --git a/py/BUILD.bazel b/py/BUILD.bazel index 186324560aade..292cde4981d74 100644 --- a/py/BUILD.bazel +++ b/py/BUILD.bazel @@ -810,6 +810,7 @@ BROWSER_TESTS = { ] ] + test_suite( name = "test-remote", tags = ["remote"], diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 5b301d3ec7e40..1770cf436bef1 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -1,21 +1,4 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - +#!/usr/bin/env python3 """ Generate Python WebDriver BiDi command modules from CDDL specification. @@ -35,11 +18,12 @@ import logging import re import sys +from collections import defaultdict from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from textwrap import indent as tw_indent -from typing import Any +from textwrap import dedent, indent as tw_indent +from typing import Any, Dict, List, Optional, Set, Tuple __version__ = "1.0.0" @@ -59,6 +43,8 @@ # WebDriver BiDi module: {{}} from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder """ @@ -67,7 +53,7 @@ def indent(s: str, n: int) -> str: return tw_indent(s, n * " ") -def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]: +def load_enhancements_manifest(manifest_path: Optional[str]) -> Dict[str, Any]: """Load enhancement manifest from a Python file. Args: @@ -85,7 +71,9 @@ def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]: return {} try: - spec = importlib.util.spec_from_file_location("bidi_enhancements", manifest_file) + spec = importlib.util.spec_from_file_location( + "bidi_enhancements", manifest_file + ) if spec is None or spec.loader is None: logger.warning(f"Could not load manifest: {manifest_path}") return {} @@ -136,10 +124,10 @@ def get_annotation(cls, cddl_type: str) -> str: if cddl_type.startswith("["): # Array inner = cddl_type.strip("[]+ ") inner_type = cls.get_annotation(inner) - return f"list[{inner_type}]" + return f"List[{inner_type}]" if cddl_type.startswith("{"): # Map/Dict - return "dict[str, Any]" + return "Dict[str, Any]" # Default to Any for unknown types return "Any" @@ -151,12 +139,11 @@ class CddlCommand: module: str name: str - params: dict[str, str] = field(default_factory=dict) - required_params: set[str] = field(default_factory=set) - result: str | None = None + params: Dict[str, str] = field(default_factory=dict) + result: Optional[str] = None description: str = "" - def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: + def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str: """Generate Python method code for this command. Args: @@ -183,38 +170,15 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: param_strs.append(f"{snake_param}: {python_type} | None = None") if param_strs: - # Check if full signature would exceed line length limit (120 chars) - single_line_signature = f" def {method_name}(self, {', '.join(param_strs)}):" - if len(single_line_signature) > 120: - # Format parameters on multiple lines - body = f" def {method_name}(\n" - body += " self,\n" - for i, param_str in enumerate(param_strs): - if i < len(param_strs) - 1: - body += f" {param_str},\n" - else: - body += f" {param_str},\n" - body += " ):\n" - else: - param_list = "self, " + ", ".join(param_strs) - body = f" def {method_name}({param_list}):\n" + param_list = "self, " + ", ".join(param_strs) else: - body = f" def {method_name}(self):\n" - body += f' """{self.description or "Execute " + self.module + "." + self.name}."""\n' + param_list = "self" - # Add automatic validation for required parameters - # (This is used unless there's no required_params, in which case all params are optional) - if self.required_params: - method_snake = self._camel_to_snake(self.name) - for param_name, snake_param in param_names: - if param_name in self.required_params: - body += f" if {snake_param} is None:\n" - msg = f"{method_snake}() missing required argument:" - error_message = f"{msg} {snake_param!r}" - body += f" raise TypeError({error_message!r})\n" - body += "\n" + # Build method body + body = f" def {method_name}({param_list}):\n" + body += f' """{self.description or "Execute " + self.module + "." + self.name}."""\n' - # Add validation if specified in enhancements (for additional business logic validation) + # Add validation if specified if "validate" in enhancements: validate_func = enhancements["validate"] # Build parameter list for validation function @@ -231,7 +195,9 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: transform_func = transform_spec.get("func") result_param = transform_spec.get("result_param", "params") input_params = [ - transform_spec.get(k) for k in ["allowed", "destination_folder"] if transform_spec.get(k) + transform_spec.get(k) + for k in ["allowed", "destination_folder"] + if transform_spec.get(k) ] if transform_func and result_param: @@ -254,7 +220,9 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: snake_param = self._camel_to_snake(param_name) if preprocess_type == "check_serialize_method": body += f" if {snake_param} and hasattr({snake_param}, 'to_bidi_dict'):\n" - body += f" {snake_param} = {snake_param}.to_bidi_dict()\n" + body += ( + f" {snake_param} = {snake_param}.to_bidi_dict()\n" + ) body += "\n" # Build params dict @@ -269,6 +237,7 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: if result_param == "download_behavior": body += ' "downloadBehavior": download_behavior,\n' # Add remaining parameters that weren't part of the transform + override_params = enhancements.get("params_override", {}) for cddl_param_name in self.params: if cddl_param_name not in ["downloadBehavior"]: snake_name = self._camel_to_snake(cddl_param_name) @@ -295,45 +264,45 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: # Extract property from list items body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += " return [\n" + body += f" return [\n" body += f' item.get("{extract_property}")\n' - body += " for item in items\n" - body += " if isinstance(item, dict)\n" - body += " ]\n" - body += " return []\n" + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" elif extract_field in deserialize_rules: # Extract field and deserialize to typed objects type_name = deserialize_rules[extract_field] body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += " return [\n" + body += f" return [\n" body += f" {type_name}(\n" body += self._generate_field_args(extract_field, type_name) - body += " )\n" - body += " for item in items\n" - body += " if isinstance(item, dict)\n" - body += " ]\n" - body += " return []\n" + body += f" )\n" + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" else: # Simple field extraction (return the value directly, not wrapped in result dict) body += f' if result and "{extract_field}" in result:\n' body += f' extracted = result.get("{extract_field}")\n' - body += " return extracted\n" - body += " return result\n" + body += f" return extracted\n" + body += f" return result\n" elif "deserialize" in enhancements: # Deserialize response to typed objects (legacy, without extract_field) deserialize_rules = enhancements["deserialize"] for response_field, type_name in deserialize_rules.items(): body += f' if result and "{response_field}" in result:\n' body += f' items = result.get("{response_field}", [])\n' - body += " return [\n" + body += f" return [\n" body += f" {type_name}(\n" body += self._generate_field_args(response_field, type_name) - body += " )\n" - body += " for item in items\n" - body += " if isinstance(item, dict)\n" - body += " ]\n" - body += " return []\n" + body += f" )\n" + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" else: # No special response handling, just return the result body += " return result\n" @@ -382,10 +351,10 @@ class CddlTypeDefinition: module: str name: str - fields: dict[str, str] = field(default_factory=dict) + fields: Dict[str, str] = field(default_factory=dict) description: str = "" - def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str: + def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> str: """Generate Python dataclass code for this type. Args: @@ -397,7 +366,7 @@ def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str # Generate class name from type name (keep it as-is, don't split on underscores) class_name = self.name - code = "@dataclass\n" + code = f"@dataclass\n" code += f"class {class_name}:\n" code += f' """{self.description or self.name}."""\n\n' @@ -416,16 +385,9 @@ def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str if literal_match: literal_value = literal_match.group(1) code += f' {snake_name}: str = field(default="{literal_value}", init=False)\n' - # Check if this field is a list type (using lowercase 'list[' from Python 3.10+ syntax) - elif python_type.startswith("list["): - # Remove the trailing ' | None' from list types since default_factory=list ensures non-None - type_annotation = python_type.replace(" | None", "") - code += f" {snake_name}: {type_annotation} = field(default_factory=list)\n" - # Check if this field is a dict type (using lowercase 'dict[' from Python 3.10+ syntax) - elif python_type.startswith("dict["): - # Remove the trailing ' | None' from dict types since default_factory=dict ensures non-None - type_annotation = python_type.replace(" | None", "") - code += f" {snake_name}: {type_annotation} = field(default_factory=dict)\n" + # Check if this field is a list type + elif "List[" in python_type: + code += f" {snake_name}: {python_type} = field(default_factory=list)\n" else: code += f" {snake_name}: {python_type} = None\n" @@ -491,7 +453,7 @@ class CddlEnum: module: str name: str - values: list[str] = field(default_factory=list) + values: List[str] = field(default_factory=list) description: str = "" def to_python_class(self) -> str: @@ -545,7 +507,11 @@ def to_python_dataclass(self) -> str: # Extract the type name from params_type (e.g., "browsingContext.Info" -> "Info") # The params_type comes from the CDDL and includes module prefix - type_name = self.params_type.split(".")[-1] if "." in self.params_type else self.params_type + type_name = ( + self.params_type.split(".")[-1] + if "." in self.params_type + else self.params_type + ) # Special case: if the type is BaseNavigationInfo, use BaseNavigationInfo directly # (NavigationInfo will be created as an alias to it) @@ -564,10 +530,10 @@ class CddlModule: """Represents a CDDL module (e.g., script, network, browsing_context).""" name: str - commands: list[CddlCommand] = field(default_factory=list) - types: list[CddlTypeDefinition] = field(default_factory=list) - enums: list[CddlEnum] = field(default_factory=list) - events: list[CddlEvent] = field(default_factory=list) + commands: List[CddlCommand] = field(default_factory=list) + types: List[CddlTypeDefinition] = field(default_factory=list) + enums: List[CddlEnum] = field(default_factory=list) + events: List[CddlEvent] = field(default_factory=list) @staticmethod def _convert_method_to_event_name(method_suffix: str) -> str: @@ -582,7 +548,7 @@ def _convert_method_to_event_name(method_suffix: str) -> str: s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", method_suffix) return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: + def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: """Generate Python code for this module. Args: @@ -591,33 +557,21 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: enhancements = enhancements or {} code = MODULE_HEADER.format(self.name) - # Collect needed imports to avoid duplicates - needs_command_builder = bool(self.commands) - needs_dataclass = self.commands or self.types or self.events - needs_callable = self.events + # Add imports if needed + if self.types: + code += "from dataclasses import field\n" + if self.commands or self.types: + code += "from typing import Generator\n" + code += "from dataclasses import dataclass\n" - stdlib_imports = [] - local_imports = [] - - # Add imports (field import will be added conditionally after code generation) - if needs_callable: - stdlib_imports.append("from collections.abc import Callable") - if needs_dataclass: - stdlib_imports.append("from dataclasses import dataclass") - stdlib_imports.append("from typing import Any") - - if needs_command_builder: - local_imports.append("from selenium.webdriver.common.bidi.common import command_builder") + # Add imports for event handling if needed if self.events: - local_imports.append( - "from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager" - ) + code += "import threading\n" + code += "from collections.abc import Callable\n" + code += "from dataclasses import dataclass\n" + code += "from selenium.webdriver.common.bidi.session import Session\n" - code += "\n".join(stdlib_imports) + "\n" - if local_imports: - code += "\n" + "\n".join(local_imports) + "\n" - - code += "\n" + code += "\n\n" # Add helper function definitions from enhancements # Collect all referenced helper functions (validate, transform) @@ -627,7 +581,9 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: method_enhancements = enhancements.get(method_name_snake, {}) if "validate" in method_enhancements: helper_funcs_to_add.add(("validate", method_enhancements["validate"])) - if "transform" in method_enhancements and isinstance(method_enhancements["transform"], dict): + if "transform" in method_enhancements and isinstance( + method_enhancements["transform"], dict + ): transform_spec = method_enhancements["transform"] if "func" in transform_spec: helper_funcs_to_add.add(("transform", transform_spec["func"])) @@ -635,7 +591,10 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Generate helper functions if needed if helper_funcs_to_add: for func_type, func_name in sorted(helper_funcs_to_add): - if func_type == "validate" and func_name == "validate_download_behavior": + if ( + func_type == "validate" + and func_name == "validate_download_behavior" + ): code += """def validate_download_behavior( allowed: bool | None, destination_folder: str | None, @@ -658,7 +617,10 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ - elif func_type == "transform" and func_name == "transform_download_params": + elif ( + func_type == "transform" + and func_name == "transform_download_params" + ): code += """def transform_download_params( allowed: bool | None, destination_folder: str | None, @@ -688,20 +650,8 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ - # Generate enums first (excluding those in exclude_types) - exclude_types = set(enhancements.get("exclude_types", [])) - - # Also exclude any types that have extra_dataclasses overrides - # Extract class names from extra_dataclasses strings - for extra_cls in enhancements.get("extra_dataclasses", []): - # Match "class ClassName:" pattern - match = re.search(r"class\s+(\w+)\s*:", extra_cls) - if match: - exclude_types.add(match.group(1)) - + # Generate enums first for enum_def in self.enums: - if enum_def.name in exclude_types: - continue code += enum_def.to_python_class() code += "\n\n" @@ -710,6 +660,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code += f"{alias} = {target}\n\n" # Generate type dataclasses, skipping any overridden by extra_dataclasses + exclude_types = set(enhancements.get("exclude_types", [])) for type_def in self.types: if type_def.name in exclude_types: continue @@ -721,11 +672,6 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code += extra_cls code += "\n\n" - # Emit extra type aliases from enhancement manifest (e.g., union types for events) - for extra_alias in enhancements.get("extra_type_aliases", []): - code += extra_alias - code += "\n\n" - # NOTE: Don't generate event type aliases here - they reference types that may not be defined yet # They will be generated after the class definition instead @@ -743,7 +689,9 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code += f' "{event_name}": "{event_def.method}",\n' # Extra events not in the CDDL spec (e.g. Chromium-specific events) for extra_evt in enhancements.get("extra_events", []): - code += f' "{extra_evt["event_key"]}": "{extra_evt["bidi_event"]}",\n' + code += ( + f' "{extra_evt["event_key"]}": "{extra_evt["bidi_event"]}",\n' + ) code += "}\n\n" # Add custom method function definitions before the class (for browsingContext) @@ -784,11 +732,165 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ code += "\n\n" - # EventConfig, _EventWrapper, and _EventManager are imported from - # selenium.webdriver.common.bidi._event_manager (see import section above) - # rather than being duplicated inline in every generated module. - if False: # placeholder to preserve indentation structure - pass + # Generate EventConfig and _EventManager for modules with events + if self.events: + # Generate EventConfig dataclass + code += """@dataclass +class EventConfig: + \"\"\"Configuration for a BiDi event.\"\"\" + event_key: str + bidi_event: str + event_class: type + + +""" + + # Generate _EventManager class + code += """class _EventWrapper: + \"\"\"Wrapper to provide event_class attribute for WebSocketConnection callbacks.\"\"\" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + \"\"\"Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + \"\"\" + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, \"from_json\") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend([\"_\", char.lower()]) + else: + result.append(char) + return \"\".join(result) + + +class _EventManager: + \"\"\"Manages event subscriptions and callbacks.\"\"\" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + \"\"\"Subscribe to a BiDi event if not already subscribed.\"\"\" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get(\"subscription\") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + \"callbacks\": [], + \"subscription_id\": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + \"\"\"Unsubscribe from a BiDi event if no more callbacks exist.\"\"\" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry[\"callbacks\"]: + session = Session(self.conn) + sub_id = entry.get(\"subscription_id\") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event][\"callbacks\"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry[\"callbacks\"]: + entry[\"callbacks\"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + \"\"\"Clear all event handlers.\"\"\" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry[\"callbacks\"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get(\"subscription_id\") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + +""" + code += "\n\n" # Generate class # Convert module name (camelCase or snake_case) to proper class name (PascalCase) @@ -798,7 +900,9 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Add EVENT_CONFIGS dict if there are events if self.events: - code += " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined + code += ( + " EVENT_CONFIGS = {}\n" # Will be populated after types are defined + ) if self.name == "script": code += " def __init__(self, conn, driver=None) -> None:\n" @@ -820,16 +924,6 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Generate command methods exclude_methods = enhancements.get("exclude_methods", []) - - # Automatically exclude methods that are defined in extra_methods - # to prevent generating duplicates - if "extra_methods" in enhancements: - for extra_method in enhancements["extra_methods"]: - # Extract method name from "def method_name(" - match = re.search(r"def\s+(\w+)\s*\(", extra_method) - if match: - exclude_methods = list(exclude_methods) + [match.group(1)] - if self.commands: for command in self.commands: # Get method-specific enhancements @@ -840,7 +934,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: method_enhancements = enhancements.get(method_name_snake, {}) code += command.to_python_method(method_enhancements) code += "\n" - elif not self.events and not enhancements.get("extra_methods", []): + else: code += " pass\n" # Emit extra methods from enhancement manifest @@ -882,53 +976,23 @@ def clear_event_handlers(self) -> None: # This ensures all types are available when we create the aliases if self.events: code += "\n# Event Info Type Aliases\n" - # Check for explicit event_type_aliases in the enhancement manifest - event_type_aliases = enhancements.get("event_type_aliases", {}) for event_def in self.events: - # Convert method name to user-friendly event name - method_parts = event_def.method.split(".") - if len(method_parts) == 2: - event_name = self._convert_method_to_event_name(method_parts[1]) - # Check if there's an explicit alias defined in the enhancement manifest - if event_name in event_type_aliases: - # Use the alias directly - type_name = event_type_aliases[event_name] - code += f"# Event: {event_def.method}\n" - code += f"{event_def.name} = {type_name}\n" - else: - # Fall back to the original behavior - code += event_def.to_python_dataclass() + code += event_def.to_python_dataclass() code += "\n" # Now populate EVENT_CONFIGS after the aliases are defined - code += "\n# Populate EVENT_CONFIGS with event configuration mappings\n" + code += f"\n# Populate EVENT_CONFIGS with event configuration mappings\n" # Use globals() to look up types dynamically to handle missing types gracefully - code += "_globals = globals()\n" + code += f"_globals = globals()\n" code += f"{class_name}.EVENT_CONFIGS = {{\n" for event_def in self.events: # Convert method name to user-friendly event name method_parts = event_def.method.split(".") if len(method_parts) == 2: event_name = self._convert_method_to_event_name(method_parts[1]) - # Try to get event class from globals, default to dict if not found - getter = f'_globals.get("{event_def.name}", dict)' - condition = f'_globals.get("{event_def.name}")' - event_class = f"{getter} if {condition} else dict" - - # Build the entry line and check if it exceeds 120 chars - single_line = ( - f' "{event_name}": EventConfig("{event_name}", "{event_def.method}", {event_class}),' - ) - - if len(single_line) > 120: - # Break into multiple lines - code += f' "{event_name}": EventConfig(\n' - code += f' "{event_name}",\n' - code += f' "{event_def.method}",\n' - code += f" {event_class},\n" - code += " ),\n" - else: - code += single_line + "\n" + # The event class is the event name (e.g., ContextCreated) + # Try to get it from globals, default to dict if not found + code += f' "{event_name}": (EventConfig("{event_name}", "{event_def.method}", _globals.get("{event_def.name}", dict)) if _globals.get("{event_def.name}") else EventConfig("{event_name}", "{event_def.method}", dict)),\n' # Extra events not in the CDDL spec for extra_evt in enhancements.get("extra_events", []): ek = extra_evt["event_key"] @@ -937,26 +1001,6 @@ def clear_event_handlers(self) -> None: code += f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),\n' code += "}\n" - # Check if field() is actually used in the generated code - # If so, add the field import after the dataclass import - if "field(" in code: - # Find where to insert the field import - # It should go after "from dataclasses import dataclass" line - dataclass_import_pattern = r"from dataclasses import dataclass\n" - if re.search(dataclass_import_pattern, code): - code = re.sub( - dataclass_import_pattern, - "from dataclasses import dataclass, field\n", - code, - count=1, - ) - elif "from dataclasses import" not in code: - # If there's no dataclasses import yet, add field import after typing - code = code.replace( - "from typing import Any\n", - "from dataclasses import field\nfrom typing import Any\n", - ) - return code @@ -967,9 +1011,9 @@ def __init__(self, cddl_path: str): """Initialize parser with CDDL file path.""" self.cddl_path = Path(cddl_path) self.content = "" - self.modules: dict[str, CddlModule] = {} - self.definitions: dict[str, str] = {} - self.event_names: set[str] = set() # Names of definitions that are events + self.modules: Dict[str, CddlModule] = {} + self.definitions: Dict[str, str] = {} + self.event_names: Set[str] = set() # Names of definitions that are events self._read_file() def _read_file(self) -> None: @@ -977,12 +1021,12 @@ def _read_file(self) -> None: if not self.cddl_path.exists(): raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}") - with open(self.cddl_path, encoding="utf-8") as f: + with open(self.cddl_path, "r", encoding="utf-8") as f: self.content = f.read() logger.info(f"Loaded CDDL file: {self.cddl_path}") - def parse(self) -> dict[str, CddlModule]: + def parse(self) -> Dict[str, CddlModule]: """Parse CDDL content and return modules.""" # Remove comments content = self._remove_comments(self.content) @@ -1046,6 +1090,9 @@ def _extract_event_names(self) -> None: ... ) """ + # Look for definitions like "BrowsingContextEvent", "SessionEvent", etc. + event_union_pattern = re.compile(r"(\w+\.)?(\w+)Event") + for def_name, def_content in self.definitions.items(): # Check if this looks like an event union (name ends with "Event") and # contains a module-qualified reference like "module.EventName". @@ -1093,7 +1140,9 @@ def _extract_types(self) -> None: description=f"{type_name}", ) self.modules[module_name].enums.append(enum_def) - logger.debug(f"Found enum: {def_name} with {len(values)} values") + logger.debug( + f"Found enum: {def_name} with {len(values)} values" + ) else: # Extract fields from type definition fields = self._extract_type_fields(def_content) @@ -1106,7 +1155,9 @@ def _extract_types(self) -> None: description=f"{type_name}", ) self.modules[module_name].types.append(type_def) - logger.debug(f"Found type: {def_name} with {len(fields)} fields") + logger.debug( + f"Found type: {def_name} with {len(fields)} fields" + ) def _is_enum_definition(self, definition: str) -> bool: """Check if a definition is an enum (string union with /). @@ -1124,7 +1175,7 @@ def _is_enum_definition(self, definition: str) -> bool: # Pattern: "something" / "something_else" return " / " in clean_def and '"' in clean_def - def _extract_enum_values(self, enum_definition: str) -> list[str]: + def _extract_enum_values(self, enum_definition: str) -> List[str]: """Extract individual values from an enum definition. Enums are defined as: "value1" / "value2" / "value3" @@ -1174,7 +1225,7 @@ def _normalize_cddl_type(field_type: str) -> str: result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result) return result.strip() - def _extract_type_fields(self, type_definition: str) -> dict[str, str]: + def _extract_type_fields(self, type_definition: str) -> Dict[str, str]: """Extract fields from a type definition block.""" fields = {} @@ -1219,7 +1270,9 @@ def _extract_events(self) -> None: Event pattern: module.EventName = (method: "module.eventName", params: module.ParamType) """ # Find definitions that are in the event_names set - event_pattern = re.compile(r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)") + event_pattern = re.compile( + r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" + ) for def_name, def_content in self.definitions.items(): # Skip if not identified as an event @@ -1253,12 +1306,16 @@ def _extract_events(self) -> None: ) self.modules[module_name].events.append(event) - logger.debug(f"Found event: {def_name} (method={method}, params={params_type})") + logger.debug( + f"Found event: {def_name} (method={method}, params={params_type})" + ) def _extract_commands(self) -> None: """Extract command definitions from parsed definitions.""" # Find command definitions that follow pattern: module.Command = (method: "...", params: ...) - command_pattern = re.compile(r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)") + command_pattern = re.compile( + r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" + ) for def_name, def_content in self.definitions.items(): # Skip definitions that are events (they share the same pattern) @@ -1278,50 +1335,41 @@ def _extract_commands(self) -> None: if module_name not in self.modules: self.modules[module_name] = CddlModule(name=module_name) - # Extract parameters and required parameters - params, required_params = self._extract_parameters_and_required(params_type) + # Extract parameters + params = self._extract_parameters(params_type) # Create command cmd = CddlCommand( module=module_name, name=command_name, params=params, - required_params=required_params, description=f"Execute {method}", ) self.modules[module_name].commands.append(cmd) - logger.debug(f"Found command: {method} with params {params_type}") + logger.debug( + f"Found command: {method} with params {params_type}" + ) - def _extract_parameters(self, params_type: str, _seen: set[str] | None = None) -> dict[str, str]: + def _extract_parameters( + self, params_type: str, _seen: Optional[Set[str]] = None + ) -> Dict[str, str]: """Extract parameters from a parameter type definition. Handles both struct types ({...}) and top-level union types (TypeA / TypeB), merging all fields from each alternative as optional parameters. """ - params, _ = self._extract_parameters_and_required(params_type, _seen) - return params - - def _extract_parameters_and_required( - self, params_type: str, _seen: set[str] | None = None - ) -> tuple[dict[str, str], set[str]]: - """Extract parameters and track which are required from a parameter type definition. - - Returns: - Tuple of (params dict, required_params set) - """ params = {} - required = set() if _seen is None: _seen = set() if params_type in _seen: - return params, required + return params _seen.add(params_type) if params_type not in self.definitions: logger.debug(f"Parameter type not found: {params_type}") - return params, required + return params definition = self.definitions[params_type] @@ -1335,13 +1383,10 @@ def _extract_parameters_and_required( alternatives = [a.strip() for a in stripped.split("/") if a.strip()] all_named = all(re.match(r"^[\w.]+$", a) for a in alternatives) if all_named: - # For union types, collect parameters from all alternatives - # but treat them as optional since the caller only needs to pass one alternative for alt_type in alternatives: - alt_params, _ = self._extract_parameters_and_required(alt_type, _seen) + alt_params = self._extract_parameters(alt_type, _seen) params.update(alt_params) - # Note: We intentionally DON'T add to required, since these are union alternatives - return params, required + return params # Remove the outer curly braces and split by comma # Then parse each line for key: type patterns @@ -1358,9 +1403,6 @@ def _extract_parameters_and_required( continue # Match pattern: [?] name: type - # Check if parameter has optional marker (?) - is_optional = line.startswith("?") - # Using a simple pattern that handles optional prefix match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) if not match: @@ -1375,14 +1417,11 @@ def _extract_parameters_and_required( # Skip lines that are part of nested definitions if "{" not in normalized_type and "(" not in normalized_type: params[param_name] = normalized_type - if not is_optional: - required.add(param_name) logger.debug( - f"Extracted param {param_name}: {normalized_type} " - f"(required={not is_optional}) from {params_type}" + f"Extracted param {param_name}: {normalized_type} from {params_type}" ) - return params, required + return params def module_name_to_class_name(module_name: str) -> str: @@ -1427,7 +1466,7 @@ def module_name_to_filename(module_name: str) -> str: return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() -def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> None: +def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> None: """Generate __init__.py file for the module.""" init_path = output_path / "__init__.py" @@ -1440,9 +1479,9 @@ def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> Non for module_name in sorted(modules.keys()): class_name = module_name_to_class_name(module_name) filename = module_name_to_filename(module_name) - code += f"from selenium.webdriver.common.bidi.{filename} import {class_name}\n" + code += f"from .{filename} import {class_name}\n" - code += "\n__all__ = [\n" + code += f"\n__all__ = [\n" for module_name in sorted(modules.keys()): class_name = module_name_to_class_name(module_name) code += f' "{class_name}",\n' @@ -1478,21 +1517,17 @@ def generate_common_file(output_path: Path) -> None: "\n" '"""Common utilities for BiDi command construction."""\n' "\n" - "from __future__ import annotations\n" - "\n" - "from collections.abc import Generator\n" - "from typing import Any\n" + "from typing import Any, Dict, Generator\n" "\n" "\n" "def command_builder(\n" - " method: str, params: dict[str, Any] | None = None\n" - ") -> Generator[dict[str, Any], Any, Any]:\n" + " method: str, params: Dict[str, Any]\n" + ") -> Generator[Dict[str, Any], Any, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" ' method: The BiDi method name (e.g., "session.status", "browser.close")\n' - " params: The parameters for the command. If omitted, an empty\n" - " dictionary is sent.\n" + " params: The parameters for the command\n" "\n" " Yields:\n" " A dictionary representing the BiDi command\n" @@ -1500,8 +1535,6 @@ def generate_common_file(output_path: Path) -> None: " Returns:\n" " The result from the BiDi command execution\n" ' """\n' - " if params is None:\n" - " params = {}\n" ' result = yield {"method": method, "params": params}\n' " return result\n" ) @@ -1576,9 +1609,9 @@ def generate_permissions_file(output_path: Path) -> None: "from __future__ import annotations\n" "\n" "from enum import Enum\n" - "from typing import Any\n" + "from typing import Any, Optional, Union\n" "\n" - "from selenium.webdriver.common.bidi.common import command_builder\n" + "from .common import command_builder\n" "\n" '_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"}\n' "\n" @@ -1619,10 +1652,10 @@ def generate_permissions_file(output_path: Path) -> None: "\n" " def set_permission(\n" " self,\n" - " descriptor: PermissionDescriptor | str,\n" - " state: PermissionState | str,\n" - " origin: str | None = None,\n" - " user_context: str | None = None,\n" + " descriptor: Union[PermissionDescriptor, str],\n" + " state: Union[PermissionState, str],\n" + " origin: Optional[str] = None,\n" + " user_context: Optional[str] = None,\n" " ) -> None:\n" ' """Set a permission for a given origin.\n' "\n" @@ -1670,7 +1703,7 @@ def main( cddl_file: str, output_dir: str, spec_version: str = "1.0", - enhancements_manifest: str | None = None, + enhancements_manifest: Optional[str] = None, ) -> None: """Main entry point. @@ -1745,7 +1778,9 @@ def main( if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Generate Python WebDriver BiDi modules from CDDL specification") + parser = argparse.ArgumentParser( + description="Generate Python WebDriver BiDi modules from CDDL specification" + ) parser.add_argument( "cddl_file", help="Path to CDDL specification file", @@ -1755,8 +1790,7 @@ def main( help="Output directory for generated Python modules", ) parser.add_argument( - "spec_version", - nargs="?", + "--version", default="1.0", help="BiDi spec version (default: 1.0)", ) @@ -1781,7 +1815,7 @@ def main( main( args.cddl_file, args.output_dir, - args.spec_version, + args.version, args.enhancements_manifest, ) sys.exit(0) diff --git a/py/private/BUILD.bazel b/py/private/BUILD.bazel index d2ea587fd8101..88acc9d2aba11 100644 --- a/py/private/BUILD.bazel +++ b/py/private/BUILD.bazel @@ -1,7 +1,6 @@ load("@rules_python//python:defs.bzl", "py_binary") exports_files([ - "_event_manager.py", "bidi_enhancements_manifest.py", "cdp.py", ]) diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 8cec1f9da245f..ae7229f6ddebd 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -1,21 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - """ Enhancement manifest for BiDi code generation. @@ -99,42 +81,11 @@ "result_param": "download_behavior", }, }, - # Replace the auto-generated ClientWindowNamedState so we can add the - # convenience NORMAL constant. In the BiDi spec "normal" is the state - # represented by ClientWindowRectState, but exposing it here keeps the - # Python API consistent with the old ClientWindowState enum. - "exclude_types": ["ClientWindowNamedState", "SetClientWindowStateParameters"], - "extra_dataclasses": [ - '''class ClientWindowNamedState: - """Named states for a browser client window.""" - - FULLSCREEN = "fullscreen" - MAXIMIZED = "maximized" - MINIMIZED = "minimized" - NORMAL = "normal"''', - '''@dataclass -class SetClientWindowStateParameters: - """SetClientWindowStateParameters. - - The ``state`` field is required and must be either a named-state string - (e.g. ``ClientWindowNamedState.MAXIMIZED``) or a - :class:`ClientWindowRectState` instance. ``client_window`` is the ID of - the window to affect. - """ - - client_window: Any | None = None - state: Any | None = None''', - ], # Override the generator-produced set_download_behavior so that # downloadBehavior is never stripped by the generic None filter. # The BiDi spec marks it as required (can be null, but must be present). "extra_methods": [ - ''' def set_download_behavior( - self, - allowed: bool | None = None, - destination_folder: str | None = None, - user_contexts: list[Any] | None = None, - ): + ''' def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): """Set the download behavior for the browser. Args: @@ -160,49 +111,11 @@ class SetClientWindowStateParameters: if user_contexts is not None: params["userContexts"] = user_contexts cmd = command_builder("browser.setDownloadBehavior", params) - return self._conn.execute(cmd)''', - ''' def set_client_window_state( - self, - client_window: Any | None = None, - state: Any | None = None, - ): - """Set the client window state. - - Args: - client_window: The client window ID to apply the state to. - state: The window state to set. Can be one of: - - A string: "fullscreen", "maximized", "minimized", "normal" - - A ClientWindowRectState object with width, height, x, y - - A dict representing the state - - Raises: - ValueError: If client_window is not provided or state is invalid. - """ - if client_window is None: - raise ValueError("client_window is required") - if state is None: - raise ValueError("state is required") - - # Serialize ClientWindowRectState if needed - state_param = state - if hasattr(state, '__dataclass_fields__'): - # It's a dataclass, convert to dict - state_param = { - k: v for k, v in state.__dict__.items() - if v is not None - } - - params = { - "clientWindow": client_window, - "state": state_param, - } - cmd = command_builder("browser.setClientWindowState", params) return self._conn.execute(cmd)''', ], }, "browsingContext": { # Method enhancements - "exclude_methods": ["set_viewport"], "create": { "extract_field": "context", }, @@ -242,33 +155,6 @@ class SetClientWindowStateParameters: "devicePixelRatio": "float", }, }, - "extra_methods": [ - ''' def set_viewport( - self, - context: str | None = None, - viewport: Any = ..., - user_contexts: Any | None = None, - device_pixel_ratio: Any = ..., - ): - """Execute browsingContext.setViewport. - - Uses sentinel defaults so explicit None is serialized for viewport/devicePixelRatio, - while omitted arguments are not sent. - """ - params = {} - if context is not None: - params["context"] = context - if user_contexts is not None: - params["userContexts"] = user_contexts - if viewport is not ...: - params["viewport"] = viewport - if device_pixel_ratio is not ...: - params["devicePixelRatio"] = device_pixel_ratio - - cmd = command_builder("browsingContext.setViewport", params) - result = self._conn.execute(cmd) - return result''', - ], # Non-CDDL download event dataclasses (Chromium-specific) "extra_dataclasses": [ '''@dataclass @@ -295,10 +181,10 @@ class DownloadParams: class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: DownloadParams | None = None + download_params: "DownloadParams | None" = None @classmethod - def from_json(cls, params: dict) -> DownloadEndParams: + def from_json(cls, params: dict) -> "DownloadEndParams": """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), @@ -310,7 +196,19 @@ def from_json(cls, params: dict) -> DownloadEndParams: ) return cls(download_params=dp)''', ], - # Download events are now in the CDDL spec, so no extra_events needed + # Non-CDDL download events (Chromium-specific, not in the BiDi spec) + "extra_events": [ + { + "event_key": "download_will_begin", + "bidi_event": "browsingContext.downloadWillBegin", + "event_class": "DownloadWillBeginParams", + }, + { + "event_key": "download_end", + "bidi_event": "browsingContext.downloadEnd", + "event_class": "DownloadEndParams", + }, + ], }, "log": { # Make LogLevel an alias for Level so existing code using LogLevel works @@ -332,7 +230,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> ConsoleLogEntry: + def from_json(cls, params: dict) -> "ConsoleLogEntry": """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -356,7 +254,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> JavascriptLogEntry: + def from_json(cls, params: dict) -> "JavascriptLogEntry": """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -367,36 +265,15 @@ def from_json(cls, params: dict) -> JavascriptLogEntry: stacktrace=params.get("stackTrace"), )''', ], - # Define Entry union type for log.entryAdded event deserialization - "extra_type_aliases": [ - "Entry = GenericLogEntry | ConsoleLogEntry | JavascriptLogEntry", - ], - "event_type_aliases": { - "entry_added": "Entry", - }, }, "emulation": { - "exclude_types": ["setNetworkConditionsParameters"], - "extra_dataclasses": [ - '''@dataclass -class SetNetworkConditionsParameters: - """SetNetworkConditionsParameters.""" - - network_conditions: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -# Backward-compatible alias for existing imports -setNetworkConditionsParameters = SetNetworkConditionsParameters''', - ], "extra_methods": [ ''' def set_geolocation_override( self, coordinates=None, error=None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, ): """Execute emulation.setGeolocationOverride. @@ -410,7 +287,7 @@ class SetNetworkConditionsParameters: contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params: dict[str, Any] = {} + params = {} if coordinates is not None: if isinstance(coordinates, dict): coords_dict = coordinates @@ -448,8 +325,8 @@ class SetNetworkConditionsParameters: ''' def set_timezone_override( self, timezone=None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, ): """Execute emulation.setTimezoneOverride. @@ -462,7 +339,7 @@ class SetNetworkConditionsParameters: contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params: dict[str, Any] = {"timezone": timezone} + params = {"timezone": timezone} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -472,8 +349,8 @@ class SetNetworkConditionsParameters: ''' def set_scripting_enabled( self, enabled=None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, ): """Execute emulation.setScriptingEnabled. @@ -486,7 +363,7 @@ class SetNetworkConditionsParameters: contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params: dict[str, Any] = {"enabled": enabled} + params = {"enabled": enabled} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -496,8 +373,8 @@ class SetNetworkConditionsParameters: ''' def set_user_agent_override( self, user_agent=None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, ): """Execute emulation.setUserAgentOverride. @@ -509,7 +386,7 @@ class SetNetworkConditionsParameters: contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params: dict[str, Any] = {"userAgent": user_agent} + params = {"userAgent": user_agent} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -519,8 +396,8 @@ class SetNetworkConditionsParameters: ''' def set_screen_orientation_override( self, screen_orientation=None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, ): """Execute emulation.setScreenOrientationOverride. @@ -545,7 +422,7 @@ class SetNetworkConditionsParameters: "natural": natural.lower() if isinstance(natural, str) else natural, "type": orientation_type.lower() if isinstance(orientation_type, str) else orientation_type, } - params: dict[str, Any] = {"screenOrientation": so_value} + params = {"screenOrientation": so_value} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -556,8 +433,8 @@ class SetNetworkConditionsParameters: self, network_conditions=None, offline: bool | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, ): """Execute emulation.setNetworkConditions. @@ -578,44 +455,12 @@ class SetNetworkConditionsParameters: nc_value = {"type": "offline"} if offline else None else: nc_value = network_conditions - params: dict[str, Any] = {"networkConditions": nc_value} + params = {"networkConditions": nc_value} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: params["userContexts"] = user_contexts cmd = command_builder("emulation.setNetworkConditions", params) - return self._conn.execute(cmd)''', - ''' def set_screen_settings_override( - self, - width: int | None = None, - height: int | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setScreenSettingsOverride. - - Sets or clears the screen settings override for specified browsing or user - contexts. - - Args: - width: The screen width in pixels, or ``None`` to clear the override. - height: The screen height in pixels, or ``None`` to clear the override. - contexts: List of browsing context IDs to target. - user_contexts: List of user context IDs to target. - """ - screen_area = None - if width is not None or height is not None: - screen_area = {} - if width is not None: - screen_area["width"] = width - if height is not None: - screen_area["height"] = height - params: dict[str, Any] = {"screenArea": screen_area} - if contexts is not None: - params["contexts"] = contexts - if user_contexts is not None: - params["userContexts"] = user_contexts - cmd = command_builder("emulation.setScreenSettingsOverride", params) return self._conn.execute(cmd)''', ], }, @@ -689,14 +534,7 @@ def _serialize_arg(value): if raw.get("type") == "success": return raw.get("result") return raw''', - ''' def _add_preload_script( - self, - function_declaration, - arguments=None, - contexts=None, - user_contexts=None, - sandbox=None, - ): + ''' def _add_preload_script(self, function_declaration, arguments=None, contexts=None, user_contexts=None, sandbox=None): """Add a preload script with validation. Args: @@ -748,15 +586,7 @@ def _serialize_arg(value): script_id: The ID returned by pin(). """ return self._remove_preload_script(script_id=script_id)''', - ''' def _evaluate( - self, - expression, - target, - await_promise, - result_ownership=None, - serialization_options=None, - user_activation=None, - ): + ''' def _evaluate(self, expression, target, await_promise, result_ownership=None, serialization_options=None, user_activation=None): """Evaluate a script expression and return a structured result. Args: @@ -791,17 +621,7 @@ def __init__(self2, realm, result, exception_details): return _EvalResult(realm=realm, result=None, exception_details=exc) return _EvalResult(realm=realm, result=raw.get("result"), exception_details=None) return _EvalResult(realm=None, result=raw, exception_details=None)''', - ''' def _call_function( - self, - function_declaration, - await_promise, - target, - arguments=None, - result_ownership=None, - this=None, - user_activation=None, - serialization_options=None, - ): + ''' def _call_function(self, function_declaration, await_promise, target, arguments=None, result_ownership=None, this=None, user_activation=None, serialization_options=None): """Call a function and return a structured result. Args: @@ -995,25 +815,10 @@ def from_json(self2, p): ], }, "network": { - "exclude_types": ["disownDataParameters"], - # Initialize intercepts tracking list and per-handler intercept map - "extra_init_code": [ - "self.intercepts: list[Any] = []", - "self._handler_intercepts: dict[str, Any] = {}", - ], + # Initialize intercepts tracking list in __init__ + "extra_init_code": ["self.intercepts = []"], # Request class wraps a beforeRequestSent event params and provides actions "extra_dataclasses": [ - '''@dataclass -class DisownDataParameters: - """DisownDataParameters.""" - - data_type: Any | None = None - collector: Any | None = None - request: Any | None = None - - -# Backward-compatible alias for existing imports -disownDataParameters = DisownDataParameters''', '''class BytesValue: """A string or base64-encoded bytes value used in cookie operations. @@ -1024,7 +829,7 @@ class DisownDataParameters: TYPE_STRING = "string" TYPE_BASE64 = "base64" - def __init__(self, type: Any | None, value: Any | None) -> None: + def __init__(self, type: str, value: str) -> None: self.type = type self.value = value @@ -1105,8 +910,7 @@ def continue_request(self, **kwargs): "auth_required": "authRequired", } phase = phase_map.get(event, "beforeRequestSent") - intercept_result = self._add_intercept(phases=[phase], url_patterns=url_patterns) - intercept_id = intercept_result.get("intercept") if intercept_result else None + self._add_intercept(phases=[phase], url_patterns=url_patterns) def _request_callback(params): raw = ( @@ -1117,21 +921,15 @@ def _request_callback(params): request = Request(self._conn, raw) callback(request) - callback_id = self.add_event_handler(event, _request_callback) - if intercept_id: - self._handler_intercepts[callback_id] = intercept_id - return callback_id''', + return self.add_event_handler(event, _request_callback)''', ''' def remove_request_handler(self, event, callback_id): - """Remove a network request handler and its associated network intercept. + """Remove a network request handler. Args: event: The event name used when adding the handler. callback_id: The int returned by add_request_handler. """ - self.remove_event_handler(event, callback_id) - intercept_id = self._handler_intercepts.pop(callback_id, None) - if intercept_id: - self._remove_intercept(intercept_id)''', + self.remove_event_handler(event, callback_id)''', ''' def clear_request_handlers(self): """Clear all request handlers and remove all tracked intercepts.""" self.clear_event_handlers() @@ -1149,10 +947,6 @@ def _request_callback(params): """ from selenium.webdriver.common.bidi.common import command_builder as _cb - # Set up network intercept for authRequired phase - intercept_result = self._add_intercept(phases=["authRequired"]) - intercept_id = intercept_result.get("intercept") if intercept_result else None - def _auth_callback(params): raw = ( params @@ -1180,20 +974,10 @@ def _auth_callback(params): ) ) - callback_id = self.add_event_handler("auth_required", _auth_callback) - if intercept_id: - self._handler_intercepts[callback_id] = intercept_id - return callback_id''', + return self.add_event_handler("auth_required", _auth_callback)''', ''' def remove_auth_handler(self, callback_id): - """Remove an auth handler by callback ID and its associated network intercept. - - Args: - callback_id: The handler ID returned by add_auth_handler. - """ - self.remove_event_handler("auth_required", callback_id) - intercept_id = self._handler_intercepts.pop(callback_id, None) - if intercept_id: - self._remove_intercept(intercept_id)''', + """Remove an auth handler by callback ID.""" + self.remove_event_handler("auth_required", callback_id)''', ], }, "storage": { @@ -1219,16 +1003,12 @@ def _auth_callback(params): TYPE_STRING = "string" TYPE_BASE64 = "base64" - def __init__(self, type: Any | None, value: Any | None) -> None: + def __init__(self, type: str, value: str) -> None: self.type = type self.value = value def to_bidi_dict(self) -> dict: - return {"type": self.type, "value": self.value} - - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict()''', + return {"type": self.type, "value": self.value}''', '''class SameSite: """SameSite cookie attribute values.""" @@ -1252,11 +1032,11 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> StorageCookie: + def from_bidi_dict(cls, raw: dict) -> "StorageCookie": """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): - value: Any = BytesValue(value_raw.get("type"), value_raw.get("value")) + value = BytesValue(value_raw.get("type"), value_raw.get("value")) else: value = value_raw return cls( @@ -1306,11 +1086,7 @@ def to_bidi_dict(self) -> dict: result["sameSite"] = self.same_site if self.expiry is not None: result["expiry"] = self.expiry - return result - - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict()''', + return result''', # Custom PartialCookie with camelCase serialization '''@dataclass class PartialCookie: @@ -1344,11 +1120,7 @@ def to_bidi_dict(self) -> dict: result["sameSite"] = self.same_site if self.expiry is not None: result["expiry"] = self.expiry - return result - - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict()''', + return result''', # BrowsingContextPartitionDescriptor: first positional arg is *context* # (the auto-generated dataclass had `type` first, breaking positional # usage like BrowsingContextPartitionDescriptor(driver.current_window_handle)) @@ -1365,11 +1137,7 @@ def __init__(self, context: Any = None, type: str = "context") -> None: self.type = type def to_bidi_dict(self) -> dict: - return {"type": "context", "context": self.context} - - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict()''', + return {"type": "context", "context": self.context}''', # StorageKeyPartitionDescriptor with camelCase serialization '''@dataclass class StorageKeyPartitionDescriptor: @@ -1386,11 +1154,7 @@ def to_bidi_dict(self) -> dict: result["userContext"] = self.user_context if self.source_origin is not None: result["sourceOrigin"] = self.source_origin - return result - - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict()''', + return result''', ], # Override the generated Storage class methods (Python's last-definition- # wins semantics means these extra_methods shadow the generated ones). @@ -1438,17 +1202,6 @@ def to_dict(self) -> dict: params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("storage.setCookie", params) result = self._conn.execute(cmd) - if isinstance(result, dict): - pk_raw = result.get("partitionKey") - pk = ( - PartitionKey( - user_context=pk_raw.get("userContext"), - source_origin=pk_raw.get("sourceOrigin"), - ) - if isinstance(pk_raw, dict) - else None - ) - return SetCookieResult(partition_key=pk) return result''', ''' def delete_cookies(self, filter=None, partition=None): """Execute storage.deleteCookies.""" @@ -1463,17 +1216,6 @@ def to_dict(self) -> dict: params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("storage.deleteCookies", params) result = self._conn.execute(cmd) - if isinstance(result, dict): - pk_raw = result.get("partitionKey") - pk = ( - PartitionKey( - user_context=pk_raw.get("userContext"), - source_origin=pk_raw.get("sourceOrigin"), - ) - if isinstance(pk_raw, dict) - else None - ) - return DeleteCookiesResult(partition_key=pk) return result''', ], }, @@ -1507,23 +1249,14 @@ def to_bidi_dict(self) -> dict: result["file"] = self.file if self.prompt is not None: result["prompt"] = self.prompt - return result - - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict()''', + return result''', ], }, "webExtension": { # Suppress the raw generated stubs; hand-written versions follow below "exclude_methods": ["install", "uninstall"], "extra_methods": [ - ''' def install( - self, - path: str | None = None, - archive_path: str | None = None, - base64_value: str | None = None, - ): + ''' def install(self, path: str | None = None, archive_path: str | None = None, base64_value: str | None = None): """Install a web extension. Exactly one of the three keyword arguments must be provided. @@ -1541,11 +1274,7 @@ def to_dict(self) -> dict: Raises: ValueError: If more than one, or none, of the arguments is provided. """ - provided = [ - k for k, v in { - "path": path, "archive_path": archive_path, "base64_value": base64_value, - }.items() if v is not None - ] + provided = [k for k, v in {"path": path, "archive_path": archive_path, "base64_value": base64_value}.items() if v is not None] if len(provided) != 1: raise ValueError( f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}" @@ -1555,41 +1284,22 @@ def to_dict(self) -> dict: elif archive_path is not None: extension_data = {"type": "archivePath", "path": archive_path} else: - assert base64_value is not None extension_data = {"type": "base64", "value": base64_value} params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) - try: - return self._conn.execute(cmd) - except Exception as e: - if "Method not available" in str(e): - raise RuntimeError( - "webExtension.install failed with 'Method not available'. " - "This likely means that web extension support is disabled. " - "Enable unsafe extension debugging and/or set options.enable_webextensions " - "in your WebDriver configuration." - ) from e - raise''', - ''' def uninstall(self, extension: str | dict): + return self._conn.execute(cmd)''', + ''' def uninstall(self, extension: Any | None = None): """Uninstall a web extension. Args: extension: Either the extension ID string returned by ``install``, or the full result dict returned by ``install`` (the ``"extension"`` value is extracted automatically). - - Raises: - ValueError: If extension is not provided or is None. """ if isinstance(extension, dict): - extension_id: Any = extension.get("extension") - else: - extension_id = extension - - if extension_id is None: - raise ValueError("extension parameter is required") - - params = {"extension": extension_id} + extension = extension.get("extension") + params = {"extension": extension} + params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd)''', ], @@ -1607,7 +1317,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> FileDialogInfo: + def from_json(cls, params: dict) -> "FileDialogInfo": """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), @@ -1717,7 +1427,9 @@ def transform_download_params( "type": "allowed", # Convert pathlib.Path (or any path-like) to str so the BiDi # protocol always receives a plain JSON string. - "destinationFolder": (str(destination_folder) if destination_folder is not None else None), + "destinationFolder": ( + str(destination_folder) if destination_folder is not None else None + ), } elif allowed is False: return {"type": "denied"} @@ -1790,7 +1502,6 @@ def _add_event_handler( - 'history_updated' Args: - self: The module instance this handler is bound to. event_name: The name of the event to subscribe to callback: Callback function to invoke when event occurs contexts: Optional list of context IDs to limit event subscription @@ -1827,7 +1538,6 @@ def _remove_event_handler( """Remove an event handler by its callback ID. Args: - self: The module instance this handler is bound to. callback_id: The callback ID returned from add_event_handler """ if not hasattr(self, "_event_handlers"): diff --git a/py/private/cdp.py b/py/private/cdp.py index d94f0dac2e32b..b097762fe50cd 100644 --- a/py/private/cdp.py +++ b/py/private/cdp.py @@ -1,20 +1,26 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# The MIT License(MIT) # -# http://www.apache.org/licenses/LICENSE-2.0 +# Copyright(c) 2018 Hyperion Gray # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files(the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp import contextvars import importlib @@ -54,11 +60,7 @@ def import_devtools(ver): # because cdp has been updated but selenium python has not been released yet. devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) - available_versions = tuple(x for x in versions if x == "latest" or (x.startswith("v") and x[1:].isdigit())) - numeric_versions = tuple(x[1:] for x in available_versions if x.startswith("v")) - if not numeric_versions: - raise - latest = max(numeric_versions, key=int) + latest = max(int(x[1:]) for x in versions) selenium_logger = logging.getLogger(__name__) selenium_logger.debug("Falling back to loading `devtools`: v%s", latest) devtools = importlib.import_module(f"{base}{latest}") diff --git a/py/private/generate_bidi.bzl b/py/private/generate_bidi.bzl index 8b4cc4e3e648f..c11b6efe4735f 100644 --- a/py/private/generate_bidi.bzl +++ b/py/private/generate_bidi.bzl @@ -53,6 +53,7 @@ def _generate_bidi_impl(ctx): args = [ cddl_file.path, output_base, + "--version", spec_version, ] @@ -72,6 +73,7 @@ def _generate_bidi_impl(ctx): return [DefaultInfo(files = depset(outputs))] + generate_bidi = rule( implementation = _generate_bidi_impl, attrs = { diff --git a/py/requirements_lock.txt b/py/requirements_lock.txt index 68f8d858bb6f4..c58f4b1c76fe6 100644 --- a/py/requirements_lock.txt +++ b/py/requirements_lock.txt @@ -461,7 +461,6 @@ jeepney==0.9.0 \ --hash=sha256:cf0e9e845622b81e4a28df94c40345400256ec608d0e55bb8a3feaa9163f5732 # via # -r py/requirements.txt - # keyring # secretstorage jinja2==3.1.6 \ --hash=sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d \ @@ -1038,9 +1037,7 @@ rich==14.3.3 \ secretstorage==3.5.0 \ --hash=sha256:0ce65888c0725fcb2c5bc0fdb8e5438eece02c523557ea40ce0703c266248137 \ --hash=sha256:f04b8e4689cbce351744d5537bf6b1329c6fc68f91fa666f60a380edddcd11be - # via - # -r py/requirements.txt - # keyring + # via -r py/requirements.txt sniffio==1.3.1 \ --hash=sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2 \ --hash=sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc diff --git a/py/selenium/common/exceptions.py b/py/selenium/common/exceptions.py index 92526c3a701be..7ec809eb20b18 100644 --- a/py/selenium/common/exceptions.py +++ b/py/selenium/common/exceptions.py @@ -122,7 +122,9 @@ def __init__( screen: str | None = None, stacktrace: Sequence[str] | None = None, ) -> None: - with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#staleelementreferenceexception" + with_support = ( + f"{msg}; {SUPPORT_MSG} {ERROR_URL}#staleelementreferenceexception" + ) super().__init__(with_support, screen, stacktrace) @@ -189,7 +191,9 @@ def __init__( screen: str | None = None, stacktrace: Sequence[str] | None = None, ) -> None: - with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementnotinteractableexception" + with_support = ( + f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementnotinteractableexception" + ) super().__init__(with_support, screen, stacktrace) @@ -275,7 +279,9 @@ def __init__( screen: str | None = None, stacktrace: Sequence[str] | None = None, ) -> None: - with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementclickinterceptedexception" + with_support = ( + f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementclickinterceptedexception" + ) super().__init__(with_support, screen, stacktrace) diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py index b37319da3651b..ab96f2d81e292 100644 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ b/py/selenium/webdriver/common/bidi/__init__.py @@ -1,43 +1,7 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. from __future__ import annotations -from selenium.webdriver.common.bidi.browser import Browser -from selenium.webdriver.common.bidi.browsing_context import BrowsingContext -from selenium.webdriver.common.bidi.emulation import Emulation -from selenium.webdriver.common.bidi.input import Input -from selenium.webdriver.common.bidi.log import Log -from selenium.webdriver.common.bidi.network import Network -from selenium.webdriver.common.bidi.script import Script -from selenium.webdriver.common.bidi.session import Session -from selenium.webdriver.common.bidi.storage import Storage -from selenium.webdriver.common.bidi.webextension import WebExtension - -__all__ = [ - "Browser", - "BrowsingContext", - "Emulation", - "Input", - "Log", - "Network", - "Script", - "Session", - "Storage", - "WebExtension", -] diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 440f13ed00072..ed6a4d8f33bc5 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -1,27 +1,16 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - +# WebDriver BiDi module: browser from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi.common import command_builder +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass def transform_download_params( @@ -72,6 +61,14 @@ def validate_download_behavior( raise ValueError("destination_folder should not be provided when allowed=False") +class ClientWindowNamedState: + """ClientWindowNamedState.""" + + FULLSCREEN = "fullscreen" + MAXIMIZED = "maximized" + MINIMIZED = "minimized" + + @dataclass class ClientWindowInfo: """ClientWindowInfo.""" @@ -113,6 +110,7 @@ def get_y(self): return self.y + @dataclass class UserContextInfo: """UserContextInfo.""" @@ -133,14 +131,14 @@ class CreateUserContextParameters: class GetClientWindowsResult: """GetClientWindowsResult.""" - client_windows: list[Any] = field(default_factory=list) + client_windows: list[Any | None] | None = None @dataclass class GetUserContextsResult: """GetUserContextsResult.""" - user_contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any | None] | None = None @dataclass @@ -150,6 +148,13 @@ class RemoveUserContextParameters: user_context: Any | None = None +@dataclass +class SetClientWindowStateParameters: + """SetClientWindowStateParameters.""" + + client_window: Any | None = None + + @dataclass class ClientWindowRectState: """ClientWindowRectState.""" @@ -166,7 +171,7 @@ class SetDownloadBehaviorParameters: """SetDownloadBehaviorParameters.""" download_behavior: Any | None = None - user_contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any | None] | None = None @dataclass @@ -184,29 +189,6 @@ class DownloadBehaviorDenied: type: str = field(default="denied", init=False) -class ClientWindowNamedState: - """Named states for a browser client window.""" - - FULLSCREEN = "fullscreen" - MAXIMIZED = "maximized" - MINIMIZED = "minimized" - NORMAL = "normal" - - -@dataclass -class SetClientWindowStateParameters: - """SetClientWindowStateParameters. - - The ``state`` field is required and must be either a named-state string - (e.g. ``ClientWindowNamedState.MAXIMIZED``) or a - :class:`ClientWindowRectState` instance. ``client_window`` is the ID of - the window to affect. - """ - - client_window: Any | None = None - state: Any | None = None - - class Browser: """WebDriver BiDi browser module.""" @@ -215,23 +197,19 @@ def __init__(self, conn) -> None: def close(self): """Execute browser.close.""" - params = {} + params = { + } params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("browser.close", params) result = self._conn.execute(cmd) return result - def create_user_context( - self, - accept_insecure_certs: bool | None = None, - proxy: Any | None = None, - unhandled_prompt_behavior: Any | None = None, - ): + def create_user_context(self, accept_insecure_certs: bool | None = None, proxy: Any | None = None, unhandled_prompt_behavior: Any | None = None): """Execute browser.createUserContext.""" - if proxy and hasattr(proxy, "to_bidi_dict"): + if proxy and hasattr(proxy, 'to_bidi_dict'): proxy = proxy.to_bidi_dict() - if unhandled_prompt_behavior and hasattr(unhandled_prompt_behavior, "to_bidi_dict"): + if unhandled_prompt_behavior and hasattr(unhandled_prompt_behavior, 'to_bidi_dict'): unhandled_prompt_behavior = unhandled_prompt_behavior.to_bidi_dict() params = { @@ -249,7 +227,8 @@ def create_user_context( def get_client_windows(self): """Execute browser.getClientWindows.""" - params = {} + params = { + } params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("browser.getClientWindows", params) result = self._conn.execute(cmd) @@ -263,7 +242,7 @@ def get_client_windows(self): state=item.get("state"), width=item.get("width"), x=item.get("x"), - y=item.get("y"), + y=item.get("y") ) for item in items if isinstance(item, dict) @@ -272,20 +251,22 @@ def get_client_windows(self): def get_user_contexts(self): """Execute browser.getUserContexts.""" - params = {} + params = { + } params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("browser.getUserContexts", params) result = self._conn.execute(cmd) if result and "userContexts" in result: items = result.get("userContexts", []) - return [item.get("userContext") for item in items if isinstance(item, dict)] + return [ + item.get("userContext") + for item in items + if isinstance(item, dict) + ] return [] def remove_user_context(self, user_context: Any | None = None): """Execute browser.removeUserContext.""" - if user_context is None: - raise TypeError("remove_user_context() missing required argument: 'user_context'") - params = { "userContext": user_context, } @@ -294,12 +275,33 @@ def remove_user_context(self, user_context: Any | None = None): result = self._conn.execute(cmd) return result - def set_download_behavior( - self, - allowed: bool | None = None, - destination_folder: str | None = None, - user_contexts: list[Any] | None = None, - ): + def set_client_window_state(self, client_window: Any | None = None): + """Execute browser.setClientWindowState.""" + params = { + "clientWindow": client_window, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.setClientWindowState", params) + result = self._conn.execute(cmd) + return result + + def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): + """Execute browser.setDownloadBehavior.""" + validate_download_behavior(allowed=allowed, destination_folder=destination_folder, user_contexts=user_contexts) + + download_behavior = None + download_behavior = transform_download_params(allowed, destination_folder) + + params = { + "downloadBehavior": download_behavior, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.setDownloadBehavior", params) + result = self._conn.execute(cmd) + return result + + def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): """Set the download behavior for the browser. Args: @@ -326,38 +328,3 @@ def set_download_behavior( params["userContexts"] = user_contexts cmd = command_builder("browser.setDownloadBehavior", params) return self._conn.execute(cmd) - - def set_client_window_state( - self, - client_window: Any | None = None, - state: Any | None = None, - ): - """Set the client window state. - - Args: - client_window: The client window ID to apply the state to. - state: The window state to set. Can be one of: - - A string: "fullscreen", "maximized", "minimized", "normal" - - A ClientWindowRectState object with width, height, x, y - - A dict representing the state - - Raises: - ValueError: If client_window is not provided or state is invalid. - """ - if client_window is None: - raise ValueError("client_window is required") - if state is None: - raise ValueError("state is required") - - # Serialize ClientWindowRectState if needed - state_param = state - if hasattr(state, "__dataclass_fields__"): - # It's a dataclass, convert to dict - state_param = {k: v for k, v in state.__dict__.items() if v is not None} - - params = { - "clientWindow": client_window, - "state": state_param, - } - cmd = command_builder("browser.setClientWindowState", params) - return self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index b5e14f19c6864..35aea615d1780 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -1,29 +1,20 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - +# WebDriver BiDi module: browsingContext from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass +import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager -from selenium.webdriver.common.bidi.common import command_builder +from dataclasses import dataclass +from selenium.webdriver.common.bidi.session import Session class ReadinessState: @@ -121,7 +112,6 @@ class BaseNavigationInfo: navigation: Any | None = None timestamp: Any | None = None url: str | None = None - user_context: Any | None = None @dataclass @@ -197,7 +187,6 @@ class CreateResult: """CreateResult.""" context: Any | None = None - user_context: Any | None = None @dataclass @@ -231,14 +220,14 @@ class LocateNodesParameters: context: Any | None = None locator: Any | None = None serialization_options: Any | None = None - start_nodes: list[Any] = field(default_factory=list) + start_nodes: list[Any | None] | None = None @dataclass class LocateNodesResult: """LocateNodesResult.""" - nodes: list[Any] = field(default_factory=list) + nodes: list[Any | None] | None = None @dataclass @@ -304,15 +293,6 @@ class ReloadParameters: wait: Any | None = None -@dataclass -class SetBypassCSPParameters: - """SetBypassCSPParameters.""" - - bypass: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - @dataclass class SetViewportParameters: """SetViewportParameters.""" @@ -320,7 +300,7 @@ class SetViewportParameters: context: Any | None = None viewport: Any | None = None device_pixel_ratio: Any | None = None - user_contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any | None] | None = None @dataclass @@ -346,7 +326,20 @@ class HistoryUpdatedParameters: context: Any | None = None timestamp: Any | None = None url: str | None = None - user_context: Any | None = None + + +@dataclass +class DownloadWillBeginParams: + """DownloadWillBeginParams.""" + + suggested_filename: str | None = None + + +@dataclass +class DownloadCanceledParams: + """DownloadCanceledParams.""" + + status: str = field(default="canceled", init=False) @dataclass @@ -356,7 +349,6 @@ class UserPromptClosedParameters: context: Any | None = None accepted: bool | None = None type: Any | None = None - user_context: Any | None = None user_text: str | None = None @@ -368,7 +360,6 @@ class UserPromptOpenedParameters: handler: Any | None = None message: str | None = None type: Any | None = None - user_context: Any | None = None default_value: str | None = None @@ -378,14 +369,12 @@ class DownloadWillBeginParams: suggested_filename: str | None = None - @dataclass class DownloadCanceledParams: """DownloadCanceledParams.""" status: Any | None = None - @dataclass class DownloadParams: """DownloadParams - fields shared by all download end event variants.""" @@ -397,15 +386,14 @@ class DownloadParams: url: str | None = None filepath: str | None = None - @dataclass class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: DownloadParams | None = None + download_params: "DownloadParams | None" = None @classmethod - def from_json(cls, params: dict) -> DownloadEndParams: + def from_json(cls, params: dict) -> "DownloadEndParams": """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), @@ -417,7 +405,6 @@ def from_json(cls, params: dict) -> DownloadEndParams: ) return cls(download_params=dp) - # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "context_created": "browsingContext.contextCreated", @@ -434,9 +421,10 @@ def from_json(cls, params: dict) -> DownloadEndParams: "navigation_failed": "browsingContext.navigationFailed", "user_prompt_closed": "browsingContext.userPromptClosed", "user_prompt_opened": "browsingContext.userPromptOpened", + "download_will_begin": "browsingContext.downloadWillBegin", + "download_end": "browsingContext.downloadEnd", } - def _deserialize_info_list(items: list) -> list | None: """Recursively deserialize a list of dicts to Info objects. @@ -469,20 +457,171 @@ def _deserialize_info_list(items: list) -> list | None: return result if result else None + + +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + event_key: str + bidi_event: str + event_class: type + + +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + """ + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + + + class BrowsingContext: """WebDriver BiDi browsingContext module.""" - EVENT_CONFIGS: dict[str, EventConfig] = {} - + EVENT_CONFIGS = {} def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) def activate(self, context: Any | None = None): """Execute browsingContext.activate.""" - if context is None: - raise TypeError("activate() missing required argument: 'context'") - params = { "context": context, } @@ -491,17 +630,8 @@ def activate(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def capture_screenshot( - self, - context: str | None = None, - format: Any | None = None, - clip: Any | None = None, - origin: str | None = None, - ): + def capture_screenshot(self, context: str | None = None, format: Any | None = None, clip: Any | None = None, origin: str | None = None): """Execute browsingContext.captureScreenshot.""" - if context is None: - raise TypeError("capture_screenshot() missing required argument: 'context'") - params = { "context": context, "format": format, @@ -518,9 +648,6 @@ def capture_screenshot( def close(self, context: Any | None = None, prompt_unload: bool | None = None): """Execute browsingContext.close.""" - if context is None: - raise TypeError("close() missing required argument: 'context'") - params = { "context": context, "promptUnload": prompt_unload, @@ -530,17 +657,8 @@ def close(self, context: Any | None = None, prompt_unload: bool | None = None): result = self._conn.execute(cmd) return result - def create( - self, - type: Any | None = None, - reference_context: Any | None = None, - background: bool | None = None, - user_context: Any | None = None, - ): + def create(self, type: Any | None = None, reference_context: Any | None = None, background: bool | None = None, user_context: Any | None = None): """Execute browsingContext.create.""" - if type is None: - raise TypeError("create() missing required argument: 'type'") - params = { "type": type, "referenceContext": reference_context, @@ -574,7 +692,7 @@ def get_tree(self, max_depth: Any | None = None, root: Any | None = None): original_opener=item.get("originalOpener"), url=item.get("url"), user_context=item.get("userContext"), - parent=item.get("parent"), + parent=item.get("parent") ) for item in items if isinstance(item, dict) @@ -583,9 +701,6 @@ def get_tree(self, max_depth: Any | None = None, root: Any | None = None): def handle_user_prompt(self, context: Any | None = None, accept: bool | None = None, user_text: Any | None = None): """Execute browsingContext.handleUserPrompt.""" - if context is None: - raise TypeError("handle_user_prompt() missing required argument: 'context'") - params = { "context": context, "accept": accept, @@ -596,20 +711,8 @@ def handle_user_prompt(self, context: Any | None = None, accept: bool | None = N result = self._conn.execute(cmd) return result - def locate_nodes( - self, - context: str | None = None, - locator: Any | None = None, - serialization_options: Any | None = None, - start_nodes: Any | None = None, - max_node_count: int | None = None, - ): + def locate_nodes(self, context: str | None = None, locator: Any | None = None, serialization_options: Any | None = None, start_nodes: Any | None = None, max_node_count: int | None = None): """Execute browsingContext.locateNodes.""" - if context is None: - raise TypeError("locate_nodes() missing required argument: 'context'") - if locator is None: - raise TypeError("locate_nodes() missing required argument: 'locator'") - params = { "context": context, "locator": locator, @@ -627,11 +730,6 @@ def locate_nodes( def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any | None = None): """Execute browsingContext.navigate.""" - if context is None: - raise TypeError("navigate() missing required argument: 'context'") - if url is None: - raise TypeError("navigate() missing required argument: 'url'") - params = { "context": context, "url": url, @@ -642,19 +740,8 @@ def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any result = self._conn.execute(cmd) return result - def print( - self, - context: Any | None = None, - background: bool | None = None, - margin: Any | None = None, - page: Any | None = None, - scale: Any | None = None, - shrink_to_fit: bool | None = None, - ): + def print(self, context: Any | None = None, background: bool | None = None, margin: Any | None = None, page: Any | None = None, scale: Any | None = None, shrink_to_fit: bool | None = None): """Execute browsingContext.print.""" - if context is None: - raise TypeError("print() missing required argument: 'context'") - params = { "context": context, "background": background, @@ -673,9 +760,6 @@ def print( def reload(self, context: Any | None = None, ignore_cache: bool | None = None, wait: Any | None = None): """Execute browsingContext.reload.""" - if context is None: - raise TypeError("reload() missing required argument: 'context'") - params = { "context": context, "ignoreCache": ignore_cache, @@ -686,33 +770,21 @@ def reload(self, context: Any | None = None, ignore_cache: bool | None = None, w result = self._conn.execute(cmd) return result - def set_bypass_csp( - self, - bypass: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute browsingContext.setBypassCSP.""" - if bypass is None: - raise TypeError("set_bypass_csp() missing required argument: 'bypass'") - + def set_viewport(self, context: str | None = None, viewport: Any | None = None, user_contexts: Any | None = None, device_pixel_ratio: Any | None = None): + """Execute browsingContext.setViewport.""" params = { - "bypass": bypass, - "contexts": contexts, + "context": context, + "viewport": viewport, "userContexts": user_contexts, + "devicePixelRatio": device_pixel_ratio, } params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browsingContext.setBypassCSP", params) + cmd = command_builder("browsingContext.setViewport", params) result = self._conn.execute(cmd) return result def traverse_history(self, context: Any | None = None, delta: Any | None = None): """Execute browsingContext.traverseHistory.""" - if context is None: - raise TypeError("traverse_history() missing required argument: 'context'") - if delta is None: - raise TypeError("traverse_history() missing required argument: 'delta'") - params = { "context": context, "delta": delta, @@ -722,31 +794,6 @@ def traverse_history(self, context: Any | None = None, delta: Any | None = None) result = self._conn.execute(cmd) return result - def set_viewport( - self, - context: str | None = None, - viewport: Any = ..., - user_contexts: Any | None = None, - device_pixel_ratio: Any = ..., - ): - """Execute browsingContext.setViewport. - - Uses sentinel defaults so explicit None is serialized for viewport/devicePixelRatio, - while omitted arguments are not sent. - """ - params = {} - if context is not None: - params["context"] = context - if user_contexts is not None: - params["userContexts"] = user_contexts - if viewport is not ...: - params["viewport"] = viewport - if device_pixel_ratio is not ...: - params["devicePixelRatio"] = device_pixel_ratio - - cmd = command_builder("browsingContext.setViewport", params) - result = self._conn.execute(cmd) - return result def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: """Add an event handler. @@ -774,118 +821,67 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() - # Event Info Type Aliases # Event: browsingContext.contextCreated -ContextCreated = globals().get("Info", dict) # Fallback to dict if type not defined +ContextCreated = globals().get('Info', dict) # Fallback to dict if type not defined # Event: browsingContext.contextDestroyed -ContextDestroyed = globals().get("Info", dict) # Fallback to dict if type not defined +ContextDestroyed = globals().get('Info', dict) # Fallback to dict if type not defined # Event: browsingContext.navigationStarted -NavigationStarted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +NavigationStarted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.fragmentNavigated -FragmentNavigated = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +FragmentNavigated = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.historyUpdated -HistoryUpdated = globals().get("HistoryUpdatedParameters", dict) # Fallback to dict if type not defined +HistoryUpdated = globals().get('HistoryUpdatedParameters', dict) # Fallback to dict if type not defined # Event: browsingContext.domContentLoaded -DomContentLoaded = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +DomContentLoaded = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.load -Load = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +Load = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.downloadWillBegin -DownloadWillBegin = globals().get("DownloadWillBeginParams", dict) # Fallback to dict if type not defined +DownloadWillBegin = globals().get('DownloadWillBeginParams', dict) # Fallback to dict if type not defined # Event: browsingContext.downloadEnd -DownloadEnd = globals().get("DownloadEndParams", dict) # Fallback to dict if type not defined +DownloadEnd = globals().get('DownloadEndParams', dict) # Fallback to dict if type not defined # Event: browsingContext.navigationAborted -NavigationAborted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +NavigationAborted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.navigationCommitted -NavigationCommitted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +NavigationCommitted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.navigationFailed -NavigationFailed = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +NavigationFailed = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.userPromptClosed -UserPromptClosed = globals().get("UserPromptClosedParameters", dict) # Fallback to dict if type not defined +UserPromptClosed = globals().get('UserPromptClosedParameters', dict) # Fallback to dict if type not defined # Event: browsingContext.userPromptOpened -UserPromptOpened = globals().get("UserPromptOpenedParameters", dict) # Fallback to dict if type not defined +UserPromptOpened = globals().get('UserPromptOpenedParameters', dict) # Fallback to dict if type not defined # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() BrowsingContext.EVENT_CONFIGS = { - "context_created": EventConfig( - "context_created", - "browsingContext.contextCreated", - _globals.get("ContextCreated", dict) if _globals.get("ContextCreated") else dict, - ), - "context_destroyed": EventConfig( - "context_destroyed", - "browsingContext.contextDestroyed", - _globals.get("ContextDestroyed", dict) if _globals.get("ContextDestroyed") else dict, - ), - "navigation_started": EventConfig( - "navigation_started", - "browsingContext.navigationStarted", - _globals.get("NavigationStarted", dict) if _globals.get("NavigationStarted") else dict, - ), - "fragment_navigated": EventConfig( - "fragment_navigated", - "browsingContext.fragmentNavigated", - _globals.get("FragmentNavigated", dict) if _globals.get("FragmentNavigated") else dict, - ), - "history_updated": EventConfig( - "history_updated", - "browsingContext.historyUpdated", - _globals.get("HistoryUpdated", dict) if _globals.get("HistoryUpdated") else dict, - ), - "dom_content_loaded": EventConfig( - "dom_content_loaded", - "browsingContext.domContentLoaded", - _globals.get("DomContentLoaded", dict) if _globals.get("DomContentLoaded") else dict, - ), - "load": EventConfig("load", "browsingContext.load", _globals.get("Load", dict) if _globals.get("Load") else dict), - "download_will_begin": EventConfig( - "download_will_begin", - "browsingContext.downloadWillBegin", - _globals.get("DownloadWillBegin", dict) if _globals.get("DownloadWillBegin") else dict, - ), - "download_end": EventConfig( - "download_end", - "browsingContext.downloadEnd", - _globals.get("DownloadEnd", dict) if _globals.get("DownloadEnd") else dict, - ), - "navigation_aborted": EventConfig( - "navigation_aborted", - "browsingContext.navigationAborted", - _globals.get("NavigationAborted", dict) if _globals.get("NavigationAborted") else dict, - ), - "navigation_committed": EventConfig( - "navigation_committed", - "browsingContext.navigationCommitted", - _globals.get("NavigationCommitted", dict) if _globals.get("NavigationCommitted") else dict, - ), - "navigation_failed": EventConfig( - "navigation_failed", - "browsingContext.navigationFailed", - _globals.get("NavigationFailed", dict) if _globals.get("NavigationFailed") else dict, - ), - "user_prompt_closed": EventConfig( - "user_prompt_closed", - "browsingContext.userPromptClosed", - _globals.get("UserPromptClosed", dict) if _globals.get("UserPromptClosed") else dict, - ), - "user_prompt_opened": EventConfig( - "user_prompt_opened", - "browsingContext.userPromptOpened", - _globals.get("UserPromptOpened", dict) if _globals.get("UserPromptOpened") else dict, - ), + "context_created": (EventConfig("context_created", "browsingContext.contextCreated", _globals.get("ContextCreated", dict)) if _globals.get("ContextCreated") else EventConfig("context_created", "browsingContext.contextCreated", dict)), + "context_destroyed": (EventConfig("context_destroyed", "browsingContext.contextDestroyed", _globals.get("ContextDestroyed", dict)) if _globals.get("ContextDestroyed") else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict)), + "navigation_started": (EventConfig("navigation_started", "browsingContext.navigationStarted", _globals.get("NavigationStarted", dict)) if _globals.get("NavigationStarted") else EventConfig("navigation_started", "browsingContext.navigationStarted", dict)), + "fragment_navigated": (EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", _globals.get("FragmentNavigated", dict)) if _globals.get("FragmentNavigated") else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict)), + "history_updated": (EventConfig("history_updated", "browsingContext.historyUpdated", _globals.get("HistoryUpdated", dict)) if _globals.get("HistoryUpdated") else EventConfig("history_updated", "browsingContext.historyUpdated", dict)), + "dom_content_loaded": (EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", _globals.get("DomContentLoaded", dict)) if _globals.get("DomContentLoaded") else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict)), + "load": (EventConfig("load", "browsingContext.load", _globals.get("Load", dict)) if _globals.get("Load") else EventConfig("load", "browsingContext.load", dict)), + "download_will_begin": (EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBegin", dict)) if _globals.get("DownloadWillBegin") else EventConfig("download_will_begin", "browsingContext.downloadWillBegin", dict)), + "download_end": (EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEnd", dict)) if _globals.get("DownloadEnd") else EventConfig("download_end", "browsingContext.downloadEnd", dict)), + "navigation_aborted": (EventConfig("navigation_aborted", "browsingContext.navigationAborted", _globals.get("NavigationAborted", dict)) if _globals.get("NavigationAborted") else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict)), + "navigation_committed": (EventConfig("navigation_committed", "browsingContext.navigationCommitted", _globals.get("NavigationCommitted", dict)) if _globals.get("NavigationCommitted") else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict)), + "navigation_failed": (EventConfig("navigation_failed", "browsingContext.navigationFailed", _globals.get("NavigationFailed", dict)) if _globals.get("NavigationFailed") else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict)), + "user_prompt_closed": (EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", _globals.get("UserPromptClosed", dict)) if _globals.get("UserPromptClosed") else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict)), + "user_prompt_opened": (EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", _globals.get("UserPromptOpened", dict)) if _globals.get("UserPromptOpened") else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict)), + "download_will_begin": EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBeginParams", dict)), + "download_end": EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEndParams", dict)), } diff --git a/py/selenium/webdriver/common/bidi/cdp.py b/py/selenium/webdriver/common/bidi/cdp.py index bac00765f43ca..38dcf8d803ea3 100644 --- a/py/selenium/webdriver/common/bidi/cdp.py +++ b/py/selenium/webdriver/common/bidi/cdp.py @@ -59,12 +59,8 @@ def import_devtools(ver): # Attempt to parse and load the 'most recent' devtools module. This is likely # because cdp has been updated but selenium python has not been released yet. devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") - versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) - available_versions = tuple(x for x in versions if x == "latest" or (x.startswith("v") and x[1:].isdigit())) - numeric_versions = tuple(x[1:] for x in available_versions if x.startswith("v")) - if not numeric_versions: - raise - latest = max(numeric_versions, key=int) + versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir() and f.name != "latest") + latest = max(int(x[1:]) for x in versions) selenium_logger = logging.getLogger(__name__) selenium_logger.debug("Falling back to loading `devtools`: v%s", latest) devtools = importlib.import_module(f"{base}{latest}") diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index ff67b56622c35..d90d8c770263a 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -17,19 +17,17 @@ """Common utilities for BiDi command construction.""" -from __future__ import annotations +from typing import Any, Dict, Generator -from collections.abc import Generator -from typing import Any - -def command_builder(method: str, params: dict[str, Any] | None = None) -> Generator[dict[str, Any], Any, Any]: +def command_builder( + method: str, params: Dict[str, Any] +) -> Generator[Dict[str, Any], Any, Any]: """Build a BiDi command generator. Args: method: The BiDi method name (e.g., "session.status", "browser.close") - params: The parameters for the command. If omitted, an empty - dictionary is sent. + params: The parameters for the command Yields: A dictionary representing the BiDi command @@ -37,7 +35,5 @@ def command_builder(method: str, params: dict[str, Any] | None = None) -> Genera Returns: The result from the BiDi command execution """ - if params is None: - params = {} result = yield {"method": method, "params": params} return result diff --git a/py/selenium/webdriver/common/bidi/console.py b/py/selenium/webdriver/common/bidi/console.py old mode 100644 new mode 100755 diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index a3e6b4b6c4ddb..4cd6ae2e3c712 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -1,27 +1,16 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - +# WebDriver BiDi module: emulation from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi.common import command_builder +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass class ForcedColorsModeTheme: @@ -52,16 +41,16 @@ class SetForcedColorsModeThemeOverrideParameters: """SetForcedColorsModeThemeOverrideParameters.""" theme: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetGeolocationOverrideParameters: """SetGeolocationOverrideParameters.""" - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass @@ -89,8 +78,17 @@ class SetLocaleOverrideParameters: """SetLocaleOverrideParameters.""" locale: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None + + +@dataclass +class setNetworkConditionsParameters: + """setNetworkConditionsParameters.""" + + network_conditions: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass @@ -113,8 +111,8 @@ class SetScreenSettingsOverrideParameters: """SetScreenSettingsOverrideParameters.""" screen_area: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass @@ -130,8 +128,8 @@ class SetScreenOrientationOverrideParameters: """SetScreenOrientationOverrideParameters.""" screen_orientation: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass @@ -139,8 +137,17 @@ class SetUserAgentOverrideParameters: """SetUserAgentOverrideParameters.""" user_agent: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None + + +@dataclass +class SetViewportMetaOverrideParameters: + """SetViewportMetaOverrideParameters.""" + + viewport_meta: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass @@ -148,8 +155,8 @@ class SetScriptingEnabledParameters: """SetScriptingEnabledParameters.""" enabled: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass @@ -157,8 +164,8 @@ class SetScrollbarTypeOverrideParameters: """SetScrollbarTypeOverrideParameters.""" scrollbar_type: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass @@ -166,29 +173,16 @@ class SetTimezoneOverrideParameters: """SetTimezoneOverrideParameters.""" timezone: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetTouchOverrideParameters: """SetTouchOverrideParameters.""" - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class SetNetworkConditionsParameters: - """SetNetworkConditionsParameters.""" - - network_conditions: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -# Backward-compatible alias for existing imports -setNetworkConditionsParameters = SetNetworkConditionsParameters + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None class Emulation: @@ -197,16 +191,8 @@ class Emulation: def __init__(self, conn) -> None: self._conn = conn - def set_forced_colors_mode_theme_override( - self, - theme: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setForcedColorsModeThemeOverride.""" - if theme is None: - raise TypeError("set_forced_colors_mode_theme_override() missing required argument: 'theme'") - params = { "theme": theme, "contexts": contexts, @@ -217,16 +203,19 @@ def set_forced_colors_mode_theme_override( result = self._conn.execute(cmd) return result - def set_locale_override( - self, - locale: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setLocaleOverride.""" - if locale is None: - raise TypeError("set_locale_override() missing required argument: 'locale'") + def set_geolocation_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setGeolocationOverride.""" + params = { + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setGeolocationOverride", params) + result = self._conn.execute(cmd) + return result + def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setLocaleOverride.""" params = { "locale": locale, "contexts": contexts, @@ -237,16 +226,80 @@ def set_locale_override( result = self._conn.execute(cmd) return result - def set_scrollbar_type_override( - self, - scrollbar_type: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setScrollbarTypeOverride.""" - if scrollbar_type is None: - raise TypeError("set_scrollbar_type_override() missing required argument: 'scrollbar_type'") + def set_network_conditions(self, network_conditions: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setNetworkConditions.""" + params = { + "networkConditions": network_conditions, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setNetworkConditions", params) + result = self._conn.execute(cmd) + return result + + def set_screen_settings_override(self, screen_area: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScreenSettingsOverride.""" + params = { + "screenArea": screen_area, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScreenSettingsOverride", params) + result = self._conn.execute(cmd) + return result + + def set_screen_orientation_override(self, screen_orientation: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScreenOrientationOverride.""" + params = { + "screenOrientation": screen_orientation, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScreenOrientationOverride", params) + result = self._conn.execute(cmd) + return result + + def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setUserAgentOverride.""" + params = { + "userAgent": user_agent, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setUserAgentOverride", params) + result = self._conn.execute(cmd) + return result + + def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setViewportMetaOverride.""" + params = { + "viewportMeta": viewport_meta, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setViewportMetaOverride", params) + result = self._conn.execute(cmd) + return result + def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScriptingEnabled.""" + params = { + "enabled": enabled, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScriptingEnabled", params) + result = self._conn.execute(cmd) + return result + + def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScrollbarTypeOverride.""" params = { "scrollbarType": scrollbar_type, "contexts": contexts, @@ -257,7 +310,19 @@ def set_scrollbar_type_override( result = self._conn.execute(cmd) return result - def set_touch_override(self, contexts: list[Any] | None = None, user_contexts: list[Any] | None = None): + def set_timezone_override(self, timezone: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setTimezoneOverride.""" + params = { + "timezone": timezone, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setTimezoneOverride", params) + result = self._conn.execute(cmd) + return result + + def set_touch_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setTouchOverride.""" params = { "contexts": contexts, @@ -272,8 +337,8 @@ def set_geolocation_override( self, coordinates=None, error=None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, ): """Execute emulation.setGeolocationOverride. @@ -287,7 +352,7 @@ def set_geolocation_override( contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params: dict[str, Any] = {} + params = {} if coordinates is not None: if isinstance(coordinates, dict): coords_dict = coordinates @@ -312,7 +377,9 @@ def set_geolocation_override( if isinstance(error, dict): params["error"] = error else: - params["error"] = {"type": error.type if error.type is not None else "positionUnavailable"} + params["error"] = { + "type": error.type if error.type is not None else "positionUnavailable" + } if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -320,12 +387,11 @@ def set_geolocation_override( cmd = command_builder("emulation.setGeolocationOverride", params) result = self._conn.execute(cmd) return result - def set_timezone_override( self, timezone=None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, ): """Execute emulation.setTimezoneOverride. @@ -338,19 +404,18 @@ def set_timezone_override( contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params: dict[str, Any] = {"timezone": timezone} + params = {"timezone": timezone} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: params["userContexts"] = user_contexts cmd = command_builder("emulation.setTimezoneOverride", params) return self._conn.execute(cmd) - def set_scripting_enabled( self, enabled=None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, ): """Execute emulation.setScriptingEnabled. @@ -363,19 +428,18 @@ def set_scripting_enabled( contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params: dict[str, Any] = {"enabled": enabled} + params = {"enabled": enabled} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: params["userContexts"] = user_contexts cmd = command_builder("emulation.setScriptingEnabled", params) return self._conn.execute(cmd) - def set_user_agent_override( self, user_agent=None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, ): """Execute emulation.setUserAgentOverride. @@ -387,19 +451,18 @@ def set_user_agent_override( contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params: dict[str, Any] = {"userAgent": user_agent} + params = {"userAgent": user_agent} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: params["userContexts"] = user_contexts cmd = command_builder("emulation.setUserAgentOverride", params) return self._conn.execute(cmd) - def set_screen_orientation_override( self, screen_orientation=None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, ): """Execute emulation.setScreenOrientationOverride. @@ -424,20 +487,19 @@ def set_screen_orientation_override( "natural": natural.lower() if isinstance(natural, str) else natural, "type": orientation_type.lower() if isinstance(orientation_type, str) else orientation_type, } - params: dict[str, Any] = {"screenOrientation": so_value} + params = {"screenOrientation": so_value} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: params["userContexts"] = user_contexts cmd = command_builder("emulation.setScreenOrientationOverride", params) return self._conn.execute(cmd) - def set_network_conditions( self, network_conditions=None, offline: bool | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, ): """Execute emulation.setNetworkConditions. @@ -458,43 +520,10 @@ def set_network_conditions( nc_value = {"type": "offline"} if offline else None else: nc_value = network_conditions - params: dict[str, Any] = {"networkConditions": nc_value} + params = {"networkConditions": nc_value} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: params["userContexts"] = user_contexts cmd = command_builder("emulation.setNetworkConditions", params) return self._conn.execute(cmd) - - def set_screen_settings_override( - self, - width: int | None = None, - height: int | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setScreenSettingsOverride. - - Sets or clears the screen settings override for specified browsing or user - contexts. - - Args: - width: The screen width in pixels, or ``None`` to clear the override. - height: The screen height in pixels, or ``None`` to clear the override. - contexts: List of browsing context IDs to target. - user_contexts: List of user context IDs to target. - """ - screen_area = None - if width is not None or height is not None: - screen_area = {} - if width is not None: - screen_area["width"] = width - if height is not None: - screen_area["height"] = height - params: dict[str, Any] = {"screenArea": screen_area} - if contexts is not None: - params["contexts"] = contexts - if user_contexts is not None: - params["userContexts"] = user_contexts - cmd = command_builder("emulation.setScreenSettingsOverride", params) - return self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 6c06fc4e7deaa..5dbe71dbd3886 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -1,29 +1,20 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - +# WebDriver BiDi module: input from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass +import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager -from selenium.webdriver.common.bidi.common import command_builder +from dataclasses import dataclass +from selenium.webdriver.common.bidi.session import Session class PointerType: @@ -54,7 +45,7 @@ class PerformActionsParameters: """PerformActionsParameters.""" context: Any | None = None - actions: list[Any] = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass @@ -63,7 +54,7 @@ class NoneSourceActions: type: str = field(default="none", init=False) id: str | None = None - actions: list[Any] = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass @@ -72,7 +63,7 @@ class KeySourceActions: type: str = field(default="key", init=False) id: str | None = None - actions: list[Any] = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass @@ -82,7 +73,7 @@ class PointerSourceActions: type: str = field(default="pointer", init=False) id: str | None = None parameters: Any | None = None - actions: list[Any] = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass @@ -98,7 +89,7 @@ class WheelSourceActions: type: str = field(default="wheel", init=False) id: str | None = None - actions: list[Any] = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass @@ -172,7 +163,7 @@ class SetFilesParameters: context: Any | None = None element: Any | None = None - files: list[Any] = field(default_factory=list) + files: list[Any | None] | None = None @dataclass @@ -184,7 +175,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> FileDialogInfo: + def from_json(cls, params: dict) -> "FileDialogInfo": """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), @@ -192,7 +183,6 @@ def from_json(cls, params: dict) -> FileDialogInfo: multiple=params.get("multiple"), ) - @dataclass class PointerMoveAction: """PointerMoveAction.""" @@ -204,7 +194,6 @@ class PointerMoveAction: origin: Any | None = None properties: Any | None = None - @dataclass class PointerDownAction: """PointerDownAction.""" @@ -213,29 +202,174 @@ class PointerDownAction: button: Any | None = None properties: Any | None = None - # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "file_dialog_opened": "input.fileDialogOpened", } +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + event_key: str + bidi_event: str + event_class: type + + +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + """ + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + + class Input: """WebDriver BiDi input module.""" - EVENT_CONFIGS: dict[str, EventConfig] = {} - + EVENT_CONFIGS = {} def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def perform_actions(self, context: Any | None = None, actions: list[Any] | None = None): + def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): """Execute input.performActions.""" - if context is None: - raise TypeError("perform_actions() missing required argument: 'context'") - if actions is None: - raise TypeError("perform_actions() missing required argument: 'actions'") - params = { "context": context, "actions": actions, @@ -247,9 +381,6 @@ def perform_actions(self, context: Any | None = None, actions: list[Any] | None def release_actions(self, context: Any | None = None): """Execute input.releaseActions.""" - if context is None: - raise TypeError("release_actions() missing required argument: 'context'") - params = { "context": context, } @@ -258,15 +389,8 @@ def release_actions(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def set_files(self, context: Any | None = None, element: Any | None = None, files: list[Any] | None = None): + def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): """Execute input.setFiles.""" - if context is None: - raise TypeError("set_files() missing required argument: 'context'") - if element is None: - raise TypeError("set_files() missing required argument: 'element'") - if files is None: - raise TypeError("set_files() missing required argument: 'files'") - params = { "context": context, "element": element, @@ -322,18 +446,13 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() - # Event Info Type Aliases # Event: input.fileDialogOpened -FileDialogOpened = globals().get("FileDialogInfo", dict) # Fallback to dict if type not defined +FileDialogOpened = globals().get('FileDialogInfo', dict) # Fallback to dict if type not defined # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Input.EVENT_CONFIGS = { - "file_dialog_opened": EventConfig( - "file_dialog_opened", - "input.fileDialogOpened", - _globals.get("FileDialogOpened", dict) if _globals.get("FileDialogOpened") else dict, - ), + "file_dialog_opened": (EventConfig("file_dialog_opened", "input.fileDialogOpened", _globals.get("FileDialogOpened", dict)) if _globals.get("FileDialogOpened") else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict)), } diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 597936402f99c..faf6c85ae2b6c 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -1,28 +1,16 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - +# WebDriver BiDi module: log from __future__ import annotations -from collections.abc import Callable +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator from dataclasses import dataclass -from typing import Any - -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager class Level: @@ -36,7 +24,6 @@ class Level: LogLevel = Level - @dataclass class BaseLogEntry: """BaseLogEntry.""" @@ -69,7 +56,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> ConsoleLogEntry: + def from_json(cls, params: dict) -> "ConsoleLogEntry": """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -82,7 +69,6 @@ def from_json(cls, params: dict) -> ConsoleLogEntry: stack_trace=params.get("stackTrace"), ) - @dataclass class JavascriptLogEntry: """JavascriptLogEntry - a JavaScript error log entry from the browser.""" @@ -95,7 +81,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> JavascriptLogEntry: + def from_json(cls, params: dict) -> "JavascriptLogEntry": """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -106,62 +92,18 @@ def from_json(cls, params: dict) -> JavascriptLogEntry: stacktrace=params.get("stackTrace"), ) - -Entry = GenericLogEntry | ConsoleLogEntry | JavascriptLogEntry - -# BiDi Event Name to Parameter Type Mapping -EVENT_NAME_MAPPING = { - "entry_added": "log.entryAdded", -} - - class Log: """WebDriver BiDi log module.""" - EVENT_CONFIGS: dict[str, EventConfig] = {} - def __init__(self, conn) -> None: self._conn = conn - self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - """Add an event handler. - - Args: - event: The event to subscribe to. - callback: The callback function to execute on event. - contexts: The context IDs to subscribe to (optional). - - Returns: - The callback ID. - """ - return self._event_manager.add_event_handler(event, callback, contexts) - - def remove_event_handler(self, event: str, callback_id: int) -> None: - """Remove an event handler. - - Args: - event: The event to unsubscribe from. - callback_id: The callback ID. - """ - return self._event_manager.remove_event_handler(event, callback_id) - - def clear_event_handlers(self) -> None: - """Clear all event handlers.""" - return self._event_manager.clear_event_handlers() - - -# Event Info Type Aliases -# Event: log.entryAdded -EntryAdded = Entry + def entry_added(self): + """Execute log.entryAdded.""" + params = { + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("log.entryAdded", params) + result = self._conn.execute(cmd) + return result -# Populate EVENT_CONFIGS with event configuration mappings -_globals = globals() -Log.EVENT_CONFIGS = { - "entry_added": EventConfig( - "entry_added", - "log.entryAdded", - _globals.get("EntryAdded", dict) if _globals.get("EntryAdded") else dict, - ), -} diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 6c24e399b0e54..4f44e309bffbb 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -1,29 +1,20 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - +# WebDriver BiDi module: network from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass +import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager -from selenium.webdriver.common.bidi.common import command_builder +from dataclasses import dataclass +from selenium.webdriver.common.bidi.session import Session class SameSite: @@ -84,8 +75,7 @@ class BaseParameters: redirect_count: Any | None = None request: Any | None = None timestamp: Any | None = None - user_context: Any | None = None - intercepts: list[Any] = field(default_factory=list) + intercepts: list[Any | None] | None = None @dataclass @@ -181,13 +171,13 @@ class ResponseData: status: Any | None = None status_text: str | None = None from_cache: bool | None = None - headers: list[Any] = field(default_factory=list) + headers: list[Any | None] | None = None mime_type: str | None = None bytes_received: Any | None = None headers_size: Any | None = None body_size: Any | None = None content: Any | None = None - auth_challenges: list[Any] = field(default_factory=list) + auth_challenges: list[Any | None] | None = None @dataclass @@ -229,11 +219,11 @@ class UrlPatternString: class AddDataCollectorParameters: """AddDataCollectorParameters.""" - data_types: list[Any] = field(default_factory=list) + data_types: list[Any | None] | None = None max_encoded_data_size: Any | None = None collector_type: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass @@ -247,9 +237,9 @@ class AddDataCollectorResult: class AddInterceptParameters: """AddInterceptParameters.""" - phases: list[Any] = field(default_factory=list) - contexts: list[Any] = field(default_factory=list) - url_patterns: list[Any] = field(default_factory=list) + phases: list[Any | None] | None = None + contexts: list[Any | None] | None = None + url_patterns: list[Any | None] | None = None @dataclass @@ -264,9 +254,9 @@ class ContinueResponseParameters: """ContinueResponseParameters.""" request: Any | None = None - cookies: list[Any] = field(default_factory=list) + cookies: list[Any | None] | None = None credentials: Any | None = None - headers: list[Any] = field(default_factory=list) + headers: list[Any | None] | None = None reason_phrase: str | None = None status_code: Any | None = None @@ -286,6 +276,15 @@ class ContinueWithAuthCredentials: credentials: Any | None = None +@dataclass +class disownDataParameters: + """disownDataParameters.""" + + data_type: Any | None = None + collector: Any | None = None + request: Any | None = None + + @dataclass class FailRequestParameters: """FailRequestParameters.""" @@ -316,8 +315,8 @@ class ProvideResponseParameters: request: Any | None = None body: Any | None = None - cookies: list[Any] = field(default_factory=list) - headers: list[Any] = field(default_factory=list) + cookies: list[Any | None] | None = None + headers: list[Any | None] | None = None reason_phrase: str | None = None status_code: Any | None = None @@ -341,16 +340,16 @@ class SetCacheBehaviorParameters: """SetCacheBehaviorParameters.""" cache_behavior: Any | None = None - contexts: list[Any] = field(default_factory=list) + contexts: list[Any | None] | None = None @dataclass class SetExtraHeadersParameters: """SetExtraHeadersParameters.""" - headers: list[Any] = field(default_factory=list) - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) + headers: list[Any | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass @@ -360,19 +359,6 @@ class ResponseStartedParameters: response: Any | None = None -@dataclass -class DisownDataParameters: - """DisownDataParameters.""" - - data_type: Any | None = None - collector: Any | None = None - request: Any | None = None - - -# Backward-compatible alias for existing imports -disownDataParameters = DisownDataParameters - - class BytesValue: """A string or base64-encoded bytes value used in cookie operations. @@ -383,14 +369,13 @@ class BytesValue: TYPE_STRING = "string" TYPE_BASE64 = "base64" - def __init__(self, type: Any | None, value: Any | None) -> None: + def __init__(self, type: str, value: str) -> None: self.type = type self.value = value def to_bidi_dict(self) -> dict: return {"type": self.type, "value": self.value} - class Request: """Wraps a BiDi network request event params and provides request action methods.""" @@ -409,39 +394,176 @@ def continue_request(self, **kwargs): params.update(kwargs) self._conn.execute(_cb("network.continueRequest", params)) - # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "auth_required": "network.authRequired", "before_request": "network.beforeRequestSent", } +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + event_key: str + bidi_event: str + event_class: type + + +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + """ + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + + class Network: """WebDriver BiDi network module.""" - EVENT_CONFIGS: dict[str, EventConfig] = {} - + EVENT_CONFIGS = {} def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - self.intercepts: list[Any] = [] - self._handler_intercepts: dict[str, Any] = {} - - def add_data_collector( - self, - data_types: list[Any] | None = None, - max_encoded_data_size: Any | None = None, - collector_type: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute network.addDataCollector.""" - if data_types is None: - raise TypeError("add_data_collector() missing required argument: 'data_types'") - if max_encoded_data_size is None: - raise TypeError("add_data_collector() missing required argument: 'max_encoded_data_size'") + self.intercepts = [] + def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute network.addDataCollector.""" params = { "dataTypes": data_types, "maxEncodedDataSize": max_encoded_data_size, @@ -454,16 +576,8 @@ def add_data_collector( result = self._conn.execute(cmd) return result - def add_intercept( - self, - phases: list[Any] | None = None, - contexts: list[Any] | None = None, - url_patterns: list[Any] | None = None, - ): + def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | None = None, url_patterns: List[Any] | None = None): """Execute network.addIntercept.""" - if phases is None: - raise TypeError("add_intercept() missing required argument: 'phases'") - params = { "phases": phases, "contexts": contexts, @@ -474,19 +588,8 @@ def add_intercept( result = self._conn.execute(cmd) return result - def continue_request( - self, - request: Any | None = None, - body: Any | None = None, - cookies: list[Any] | None = None, - headers: list[Any] | None = None, - method: Any | None = None, - url: Any | None = None, - ): + def continue_request(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, method: Any | None = None, url: Any | None = None): """Execute network.continueRequest.""" - if request is None: - raise TypeError("continue_request() missing required argument: 'request'") - params = { "request": request, "body": body, @@ -500,19 +603,8 @@ def continue_request( result = self._conn.execute(cmd) return result - def continue_response( - self, - request: Any | None = None, - cookies: list[Any] | None = None, - credentials: Any | None = None, - headers: list[Any] | None = None, - reason_phrase: Any | None = None, - status_code: Any | None = None, - ): + def continue_response(self, request: Any | None = None, cookies: List[Any] | None = None, credentials: Any | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): """Execute network.continueResponse.""" - if request is None: - raise TypeError("continue_response() missing required argument: 'request'") - params = { "request": request, "cookies": cookies, @@ -528,9 +620,6 @@ def continue_response( def continue_with_auth(self, request: Any | None = None): """Execute network.continueWithAuth.""" - if request is None: - raise TypeError("continue_with_auth() missing required argument: 'request'") - params = { "request": request, } @@ -541,13 +630,6 @@ def continue_with_auth(self, request: Any | None = None): def disown_data(self, data_type: Any | None = None, collector: Any | None = None, request: Any | None = None): """Execute network.disownData.""" - if data_type is None: - raise TypeError("disown_data() missing required argument: 'data_type'") - if collector is None: - raise TypeError("disown_data() missing required argument: 'collector'") - if request is None: - raise TypeError("disown_data() missing required argument: 'request'") - params = { "dataType": data_type, "collector": collector, @@ -560,9 +642,6 @@ def disown_data(self, data_type: Any | None = None, collector: Any | None = None def fail_request(self, request: Any | None = None): """Execute network.failRequest.""" - if request is None: - raise TypeError("fail_request() missing required argument: 'request'") - params = { "request": request, } @@ -571,19 +650,8 @@ def fail_request(self, request: Any | None = None): result = self._conn.execute(cmd) return result - def get_data( - self, - data_type: Any | None = None, - collector: Any | None = None, - disown: bool | None = None, - request: Any | None = None, - ): + def get_data(self, data_type: Any | None = None, collector: Any | None = None, disown: bool | None = None, request: Any | None = None): """Execute network.getData.""" - if data_type is None: - raise TypeError("get_data() missing required argument: 'data_type'") - if request is None: - raise TypeError("get_data() missing required argument: 'request'") - params = { "dataType": data_type, "collector": collector, @@ -595,19 +663,8 @@ def get_data( result = self._conn.execute(cmd) return result - def provide_response( - self, - request: Any | None = None, - body: Any | None = None, - cookies: list[Any] | None = None, - headers: list[Any] | None = None, - reason_phrase: Any | None = None, - status_code: Any | None = None, - ): + def provide_response(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): """Execute network.provideResponse.""" - if request is None: - raise TypeError("provide_response() missing required argument: 'request'") - params = { "request": request, "body": body, @@ -623,9 +680,6 @@ def provide_response( def remove_data_collector(self, collector: Any | None = None): """Execute network.removeDataCollector.""" - if collector is None: - raise TypeError("remove_data_collector() missing required argument: 'collector'") - params = { "collector": collector, } @@ -636,9 +690,6 @@ def remove_data_collector(self, collector: Any | None = None): def remove_intercept(self, intercept: Any | None = None): """Execute network.removeIntercept.""" - if intercept is None: - raise TypeError("remove_intercept() missing required argument: 'intercept'") - params = { "intercept": intercept, } @@ -647,11 +698,8 @@ def remove_intercept(self, intercept: Any | None = None): result = self._conn.execute(cmd) return result - def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[Any] | None = None): + def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): """Execute network.setCacheBehavior.""" - if cache_behavior is None: - raise TypeError("set_cache_behavior() missing required argument: 'cache_behavior'") - params = { "cacheBehavior": cache_behavior, "contexts": contexts, @@ -661,16 +709,8 @@ def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[A result = self._conn.execute(cmd) return result - def set_extra_headers( - self, - headers: list[Any] | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute network.setExtraHeaders.""" - if headers is None: - raise TypeError("set_extra_headers() missing required argument: 'headers'") - params = { "headers": headers, "contexts": contexts, @@ -683,11 +723,6 @@ def set_extra_headers( def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.beforeRequestSent.""" - if method is None: - raise TypeError("before_request_sent() missing required argument: 'method'") - if params is None: - raise TypeError("before_request_sent() missing required argument: 'params'") - params = { "initiator": initiator, "method": method, @@ -700,13 +735,6 @@ def before_request_sent(self, initiator: Any | None = None, method: Any | None = def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.fetchError.""" - if error_text is None: - raise TypeError("fetch_error() missing required argument: 'error_text'") - if method is None: - raise TypeError("fetch_error() missing required argument: 'method'") - if params is None: - raise TypeError("fetch_error() missing required argument: 'params'") - params = { "errorText": error_text, "method": method, @@ -719,13 +747,6 @@ def fetch_error(self, error_text: Any | None = None, method: Any | None = None, def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.responseCompleted.""" - if response is None: - raise TypeError("response_completed() missing required argument: 'response'") - if method is None: - raise TypeError("response_completed() missing required argument: 'method'") - if params is None: - raise TypeError("response_completed() missing required argument: 'params'") - params = { "response": response, "method": method, @@ -738,9 +759,6 @@ def response_completed(self, response: Any | None = None, method: Any | None = N def response_started(self, response: Any | None = None): """Execute network.responseStarted.""" - if response is None: - raise TypeError("response_started() missing required argument: 'response'") - params = { "response": response, } @@ -772,7 +790,6 @@ def _add_intercept(self, phases=None, url_patterns=None): if intercept_id and intercept_id not in self.intercepts: self.intercepts.append(intercept_id) return result - def _remove_intercept(self, intercept_id): """Remove a low-level network intercept.""" from selenium.webdriver.common.bidi.common import command_builder as _cb @@ -780,7 +797,6 @@ def _remove_intercept(self, intercept_id): self._conn.execute(_cb("network.removeIntercept", {"intercept": intercept_id})) if intercept_id in self.intercepts: self.intercepts.remove(intercept_id) - def add_request_handler(self, event, callback, url_patterns=None): """Add a handler for network requests at the specified phase. @@ -799,37 +815,31 @@ def add_request_handler(self, event, callback, url_patterns=None): "auth_required": "authRequired", } phase = phase_map.get(event, "beforeRequestSent") - intercept_result = self._add_intercept(phases=[phase], url_patterns=url_patterns) - intercept_id = intercept_result.get("intercept") if intercept_result else None + self._add_intercept(phases=[phase], url_patterns=url_patterns) def _request_callback(params): - raw = params if isinstance(params, dict) else (params.__dict__ if hasattr(params, "__dict__") else {}) + raw = ( + params + if isinstance(params, dict) + else (params.__dict__ if hasattr(params, "__dict__") else {}) + ) request = Request(self._conn, raw) callback(request) - callback_id = self.add_event_handler(event, _request_callback) - if intercept_id: - self._handler_intercepts[callback_id] = intercept_id - return callback_id - + return self.add_event_handler(event, _request_callback) def remove_request_handler(self, event, callback_id): - """Remove a network request handler and its associated network intercept. + """Remove a network request handler. Args: event: The event name used when adding the handler. callback_id: The int returned by add_request_handler. """ self.remove_event_handler(event, callback_id) - intercept_id = self._handler_intercepts.pop(callback_id, None) - if intercept_id: - self._remove_intercept(intercept_id) - def clear_request_handlers(self): """Clear all request handlers and remove all tracked intercepts.""" self.clear_event_handlers() for intercept_id in list(self.intercepts): self._remove_intercept(intercept_id) - def add_auth_handler(self, username, password): """Add an auth handler that automatically provides credentials. @@ -842,13 +852,17 @@ def add_auth_handler(self, username, password): """ from selenium.webdriver.common.bidi.common import command_builder as _cb - # Set up network intercept for authRequired phase - intercept_result = self._add_intercept(phases=["authRequired"]) - intercept_id = intercept_result.get("intercept") if intercept_result else None - def _auth_callback(params): - raw = params if isinstance(params, dict) else (params.__dict__ if hasattr(params, "__dict__") else {}) - request_id = raw.get("request", {}).get("request") if isinstance(raw, dict) else None + raw = ( + params + if isinstance(params, dict) + else (params.__dict__ if hasattr(params, "__dict__") else {}) + ) + request_id = ( + raw.get("request", {}).get("request") + if isinstance(raw, dict) + else None + ) if request_id: self._conn.execute( _cb( @@ -865,21 +879,10 @@ def _auth_callback(params): ) ) - callback_id = self.add_event_handler("auth_required", _auth_callback) - if intercept_id: - self._handler_intercepts[callback_id] = intercept_id - return callback_id - + return self.add_event_handler("auth_required", _auth_callback) def remove_auth_handler(self, callback_id): - """Remove an auth handler by callback ID and its associated network intercept. - - Args: - callback_id: The handler ID returned by add_auth_handler. - """ + """Remove an auth handler by callback ID.""" self.remove_event_handler("auth_required", callback_id) - intercept_id = self._handler_intercepts.pop(callback_id, None) - if intercept_id: - self._remove_intercept(intercept_id) def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: """Add an event handler. @@ -907,19 +910,14 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() - # Event Info Type Aliases # Event: network.authRequired -AuthRequired = globals().get("AuthRequiredParameters", dict) # Fallback to dict if type not defined +AuthRequired = globals().get('AuthRequiredParameters', dict) # Fallback to dict if type not defined # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Network.EVENT_CONFIGS = { - "auth_required": EventConfig( - "auth_required", - "network.authRequired", - _globals.get("AuthRequired", dict) if _globals.get("AuthRequired") else dict, - ), + "auth_required": (EventConfig("auth_required", "network.authRequired", _globals.get("AuthRequired", dict)) if _globals.get("AuthRequired") else EventConfig("auth_required", "network.authRequired", dict)), "before_request": EventConfig("before_request", "network.beforeRequestSent", _globals.get("dict", dict)), } diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py index 98e25a1d2f856..f00e765c62e3b 100644 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ b/py/selenium/webdriver/common/bidi/permissions.py @@ -20,9 +20,9 @@ from __future__ import annotations from enum import Enum -from typing import Any +from typing import Any, Optional, Union -from selenium.webdriver.common.bidi.common import command_builder +from .common import command_builder _VALID_PERMISSION_STATES = {"granted", "denied", "prompt"} @@ -63,10 +63,10 @@ def __init__(self, websocket_connection: Any) -> None: def set_permission( self, - descriptor: PermissionDescriptor | str, - state: PermissionState | str, - origin: str | None = None, - user_context: str | None = None, + descriptor: Union[PermissionDescriptor, str], + state: Union[PermissionState, str], + origin: Optional[str] = None, + user_context: Optional[str] = None, ) -> None: """Set a permission for a given origin. @@ -82,7 +82,8 @@ def set_permission( state_value = state.value if isinstance(state, PermissionState) else state if state_value not in _VALID_PERMISSION_STATES: raise ValueError( - f"Invalid permission state: {state_value!r}. Must be one of {sorted(_VALID_PERMISSION_STATES)}" + f"Invalid permission state: {state_value!r}. " + f"Must be one of {sorted(_VALID_PERMISSION_STATES)}" ) if isinstance(descriptor, str): diff --git a/py/selenium/webdriver/common/bidi/py.typed b/py/selenium/webdriver/common/bidi/py.typed old mode 100644 new mode 100755 diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index ee6eb4f4a437a..e13c11f71a5cb 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -1,29 +1,20 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - +# WebDriver BiDi module: script from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass +import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager -from selenium.webdriver.common.bidi.common import command_builder +from dataclasses import dataclass +from selenium.webdriver.common.bidi.session import Session class SpecialNumber: @@ -217,7 +208,6 @@ class WindowRealmInfo: type: str = field(default="window", init=False) context: Any | None = None - user_context: Any | None = None sandbox: str | None = None @@ -226,7 +216,7 @@ class DedicatedWorkerRealmInfo: """DedicatedWorkerRealmInfo.""" type: str = field(default="dedicated-worker", init=False) - owners: list[Any] = field(default_factory=list) + owners: list[Any | None] | None = None @dataclass @@ -470,7 +460,7 @@ class NodeProperties: node_type: Any | None = None child_node_count: Any | None = None - children: list[Any] = field(default_factory=list) + children: list[Any | None] | None = None local_name: str | None = None mode: Any | None = None namespace_uri: str | None = None @@ -509,7 +499,7 @@ class StackFrame: class StackTrace: """StackTrace.""" - call_frames: list[Any] = field(default_factory=list) + call_frames: list[Any | None] | None = None @dataclass @@ -518,7 +508,6 @@ class Source: realm: Any | None = None context: Any | None = None - user_context: Any | None = None @dataclass @@ -541,9 +530,9 @@ class AddPreloadScriptParameters: """AddPreloadScriptParameters.""" function_declaration: str | None = None - arguments: list[Any] = field(default_factory=list) - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) + arguments: list[Any | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None sandbox: str | None = None @@ -558,7 +547,7 @@ class AddPreloadScriptResult: class DisownParameters: """DisownParameters.""" - handles: list[Any] = field(default_factory=list) + handles: list[Any | None] | None = None target: Any | None = None @@ -569,7 +558,7 @@ class CallFunctionParameters: function_declaration: str | None = None await_promise: bool | None = None target: Any | None = None - arguments: list[Any] = field(default_factory=list) + arguments: list[Any | None] | None = None result_ownership: Any | None = None serialization_options: Any | None = None this: Any | None = None @@ -600,7 +589,7 @@ class GetRealmsParameters: class GetRealmsResult: """GetRealmsResult.""" - realms: list[Any] = field(default_factory=list) + realms: list[Any | None] | None = None @dataclass @@ -632,29 +621,170 @@ class RealmDestroyedParameters: "realm_destroyed": "script.realmDestroyed", } +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + event_key: str + bidi_event: str + event_class: type + + +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + """ + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + + class Script: """WebDriver BiDi script module.""" - EVENT_CONFIGS: dict[str, EventConfig] = {} - + EVENT_CONFIGS = {} def __init__(self, conn, driver=None) -> None: self._conn = conn self._driver = driver self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def add_preload_script( - self, - function_declaration: Any | None = None, - arguments: list[Any] | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - sandbox: Any | None = None, - ): + def add_preload_script(self, function_declaration: Any | None = None, arguments: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None, sandbox: Any | None = None): """Execute script.addPreloadScript.""" - if function_declaration is None: - raise TypeError("add_preload_script() missing required argument: 'function_declaration'") - params = { "functionDeclaration": function_declaration, "arguments": arguments, @@ -667,13 +797,8 @@ def add_preload_script( result = self._conn.execute(cmd) return result - def disown(self, handles: list[Any] | None = None, target: Any | None = None): + def disown(self, handles: List[Any] | None = None, target: Any | None = None): """Execute script.disown.""" - if handles is None: - raise TypeError("disown() missing required argument: 'handles'") - if target is None: - raise TypeError("disown() missing required argument: 'target'") - params = { "handles": handles, "target": target, @@ -683,25 +808,8 @@ def disown(self, handles: list[Any] | None = None, target: Any | None = None): result = self._conn.execute(cmd) return result - def call_function( - self, - function_declaration: Any | None = None, - await_promise: bool | None = None, - target: Any | None = None, - arguments: list[Any] | None = None, - result_ownership: Any | None = None, - serialization_options: Any | None = None, - this: Any | None = None, - user_activation: bool | None = None, - ): + def call_function(self, function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, arguments: List[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, user_activation: bool | None = None): """Execute script.callFunction.""" - if function_declaration is None: - raise TypeError("call_function() missing required argument: 'function_declaration'") - if await_promise is None: - raise TypeError("call_function() missing required argument: 'await_promise'") - if target is None: - raise TypeError("call_function() missing required argument: 'target'") - params = { "functionDeclaration": function_declaration, "awaitPromise": await_promise, @@ -717,23 +825,8 @@ def call_function( result = self._conn.execute(cmd) return result - def evaluate( - self, - expression: Any | None = None, - target: Any | None = None, - await_promise: bool | None = None, - result_ownership: Any | None = None, - serialization_options: Any | None = None, - user_activation: bool | None = None, - ): + def evaluate(self, expression: Any | None = None, target: Any | None = None, await_promise: bool | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, user_activation: bool | None = None): """Execute script.evaluate.""" - if expression is None: - raise TypeError("evaluate() missing required argument: 'expression'") - if target is None: - raise TypeError("evaluate() missing required argument: 'target'") - if await_promise is None: - raise TypeError("evaluate() missing required argument: 'await_promise'") - params = { "expression": expression, "target": target, @@ -760,9 +853,6 @@ def get_realms(self, context: Any | None = None, type: Any | None = None): def remove_preload_script(self, script: Any | None = None): """Execute script.removePreloadScript.""" - if script is None: - raise TypeError("remove_preload_script() missing required argument: 'script'") - params = { "script": script, } @@ -773,13 +863,6 @@ def remove_preload_script(self, script: Any | None = None): def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): """Execute script.message.""" - if channel is None: - raise TypeError("message() missing required argument: 'channel'") - if data is None: - raise TypeError("message() missing required argument: 'data'") - if source is None: - raise TypeError("message() missing required argument: 'source'") - params = { "channel": channel, "data": data, @@ -806,9 +889,8 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import datetime as _datetime import math as _math - + import datetime as _datetime from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -859,15 +941,7 @@ def _serialize_arg(value): if raw.get("type") == "success": return raw.get("result") return raw - - def _add_preload_script( - self, - function_declaration, - arguments=None, - contexts=None, - user_contexts=None, - sandbox=None, - ): + def _add_preload_script(self, function_declaration, arguments=None, contexts=None, user_contexts=None, sandbox=None): """Add a preload script with validation. Args: @@ -895,7 +969,6 @@ def _add_preload_script( if isinstance(result, dict): return result.get("script") return result - def _remove_preload_script(self, script_id): """Remove a preload script by ID. @@ -903,7 +976,6 @@ def _remove_preload_script(self, script_id): script_id: The ID of the preload script to remove. """ return self.remove_preload_script(script=script_id) - def pin(self, function_declaration): """Pin (add) a preload script that runs on every page load. @@ -914,7 +986,6 @@ def pin(self, function_declaration): script_id: The ID of the pinned script (str). """ return self._add_preload_script(function_declaration) - def unpin(self, script_id): """Unpin (remove) a previously pinned preload script. @@ -922,16 +993,7 @@ def unpin(self, script_id): script_id: The ID returned by pin(). """ return self._remove_preload_script(script_id=script_id) - - def _evaluate( - self, - expression, - target, - await_promise, - result_ownership=None, - serialization_options=None, - user_activation=None, - ): + def _evaluate(self, expression, target, await_promise, result_ownership=None, serialization_options=None, user_activation=None): """Evaluate a script expression and return a structured result. Args: @@ -945,7 +1007,6 @@ def _evaluate( Returns: An object with .realm, .result (dict or None), and .exception_details (or None). """ - class _EvalResult: def __init__(self2, realm, result, exception_details): self2.realm = realm @@ -967,18 +1028,7 @@ def __init__(self2, realm, result, exception_details): return _EvalResult(realm=realm, result=None, exception_details=exc) return _EvalResult(realm=realm, result=raw.get("result"), exception_details=None) return _EvalResult(realm=None, result=raw, exception_details=None) - - def _call_function( - self, - function_declaration, - await_promise, - target, - arguments=None, - result_ownership=None, - this=None, - user_activation=None, - serialization_options=None, - ): + def _call_function(self, function_declaration, await_promise, target, arguments=None, result_ownership=None, this=None, user_activation=None, serialization_options=None): """Call a function and return a structured result. Args: @@ -994,7 +1044,6 @@ def _call_function( Returns: An object with .result (dict or None) and .exception_details (or None). """ - class _CallResult: def __init__(self2, result, exception_details): self2.result = result @@ -1017,7 +1066,6 @@ def __init__(self2, result, exception_details): if raw.get("type") == "success": return _CallResult(result=raw.get("result"), exception_details=None) return _CallResult(result=raw, exception_details=None) - def _get_realms(self, context=None, type=None): """Get all realms, optionally filtered by context and type. @@ -1028,7 +1076,6 @@ def _get_realms(self, context=None, type=None): Returns: List of realm info objects with .realm, .origin, .type, .context attributes. """ - class _RealmInfo: def __init__(self2, realm, origin, type_, context): self2.realm = realm @@ -1041,16 +1088,13 @@ def __init__(self2, realm, origin, type_, context): result = [] for r in realms_list: if isinstance(r, dict): - result.append( - _RealmInfo( - realm=r.get("realm"), - origin=r.get("origin"), - type_=r.get("type"), - context=r.get("context"), - ) - ) + result.append(_RealmInfo( + realm=r.get("realm"), + origin=r.get("origin"), + type_=r.get("type"), + context=r.get("context"), + )) return result - def _disown(self, handles, target): """Disown handles in a browsing context. @@ -1059,13 +1103,11 @@ def _disown(self, handles, target): target: A dict like {"context": }. """ return self.disown(handles=handles, target=target) - def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - - from selenium.webdriver.common.bidi import log as _log_mod from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod bidi_event = "log.entryAdded" @@ -1096,7 +1138,9 @@ def _wrapped(raw): if entry_type_filter is None: callback(entry) else: - t = getattr(entry, "type_", None) or (entry.get("type") if isinstance(entry, dict) else None) + t = getattr(entry, "type_", None) or ( + entry.get("type") if isinstance(entry, dict) else None + ) if t == entry_type_filter: callback(entry) @@ -1112,14 +1156,15 @@ def from_json(self2, p): if bidi_event not in self._log_subscriptions: session = _Session(self._conn) result = session.subscribe([bidi_event]) - sub_id = result.get("subscription") if isinstance(result, dict) else None + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) self._log_subscriptions[bidi_event] = { "callbacks": [], "subscription_id": sub_id, } self._log_subscriptions[bidi_event]["callbacks"].append(callback_id) return callback_id - def _unsubscribe_log_entry(self, callback_id): """Unsubscribe a log entry callback by ID.""" from selenium.webdriver.common.bidi.session import Session as _Session @@ -1148,7 +1193,6 @@ def from_json(self2, p): else: session.unsubscribe(events=[bidi_event]) del self._log_subscriptions[bidi_event] - def add_console_message_handler(self, callback: Callable) -> int: """Add a handler for console log messages (log.entryAdded type=console). @@ -1159,11 +1203,9 @@ def add_console_message_handler(self, callback: Callable) -> int: callback_id for use with remove_console_message_handler. """ return self._subscribe_log_entry(callback, entry_type_filter="console") - def remove_console_message_handler(self, callback_id: int) -> None: """Remove a console message handler by callback ID.""" self._unsubscribe_log_entry(callback_id) - def add_javascript_error_handler(self, callback: Callable) -> int: """Add a handler for JavaScript error log messages (log.entryAdded type=javascript). @@ -1174,7 +1216,6 @@ def add_javascript_error_handler(self, callback: Callable) -> int: callback_id for use with remove_javascript_error_handler. """ return self._subscribe_log_entry(callback, entry_type_filter="javascript") - def remove_javascript_error_handler(self, callback_id: int) -> None: """Remove a JavaScript error handler by callback ID.""" self._unsubscribe_log_entry(callback_id) @@ -1205,26 +1246,17 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() - # Event Info Type Aliases # Event: script.realmCreated -RealmCreated = globals().get("RealmInfo", dict) # Fallback to dict if type not defined +RealmCreated = globals().get('RealmInfo', dict) # Fallback to dict if type not defined # Event: script.realmDestroyed -RealmDestroyed = globals().get("RealmDestroyedParameters", dict) # Fallback to dict if type not defined +RealmDestroyed = globals().get('RealmDestroyedParameters', dict) # Fallback to dict if type not defined # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Script.EVENT_CONFIGS = { - "realm_created": EventConfig( - "realm_created", - "script.realmCreated", - _globals.get("RealmCreated", dict) if _globals.get("RealmCreated") else dict, - ), - "realm_destroyed": EventConfig( - "realm_destroyed", - "script.realmDestroyed", - _globals.get("RealmDestroyed", dict) if _globals.get("RealmDestroyed") else dict, - ), + "realm_created": (EventConfig("realm_created", "script.realmCreated", _globals.get("RealmCreated", dict)) if _globals.get("RealmCreated") else EventConfig("realm_created", "script.realmCreated", dict)), + "realm_destroyed": (EventConfig("realm_destroyed", "script.realmDestroyed", _globals.get("RealmDestroyed", dict)) if _globals.get("RealmDestroyed") else EventConfig("realm_destroyed", "script.realmDestroyed", dict)), } diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index b00544d286546..9b1daaae557fa 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -1,27 +1,16 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - +# WebDriver BiDi module: session from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi.common import command_builder +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass class UserPromptHandlerType: @@ -37,7 +26,7 @@ class CapabilitiesRequest: """CapabilitiesRequest.""" always_match: Any | None = None - first_match: list[Any] = field(default_factory=list) + first_match: list[Any | None] | None = None @dataclass @@ -73,7 +62,7 @@ class ManualProxyConfiguration: proxy_type: str = field(default="manual", init=False) http_proxy: str | None = None ssl_proxy: str | None = None - no_proxy: list[Any] = field(default_factory=list) + no_proxy: list[Any | None] | None = None @dataclass @@ -103,23 +92,23 @@ class SystemProxyConfiguration: class SubscribeParameters: """SubscribeParameters.""" - events: list[str] = field(default_factory=list) - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) + events: list[str | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class UnsubscribeByIDRequest: """UnsubscribeByIDRequest.""" - subscriptions: list[Any] = field(default_factory=list) + subscriptions: list[Any | None] | None = None @dataclass class UnsubscribeByAttributesRequest: """UnsubscribeByAttributesRequest.""" - events: list[str] = field(default_factory=list) + events: list[str | None] | None = None @dataclass @@ -188,11 +177,6 @@ def to_bidi_dict(self) -> dict: result["prompt"] = self.prompt return result - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict() - - class Session: """WebDriver BiDi session module.""" @@ -201,7 +185,8 @@ def __init__(self, conn) -> None: def status(self): """Execute session.status.""" - params = {} + params = { + } params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("session.status", params) result = self._conn.execute(cmd) @@ -209,9 +194,6 @@ def status(self): def new(self, capabilities: Any | None = None): """Execute session.new.""" - if capabilities is None: - raise TypeError("new() missing required argument: 'capabilities'") - params = { "capabilities": capabilities, } @@ -222,22 +204,15 @@ def new(self, capabilities: Any | None = None): def end(self): """Execute session.end.""" - params = {} + params = { + } params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("session.end", params) result = self._conn.execute(cmd) return result - def subscribe( - self, - events: list[Any] | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute session.subscribe.""" - if events is None: - raise TypeError("subscribe() missing required argument: 'events'") - params = { "events": events, "contexts": contexts, @@ -248,7 +223,7 @@ def subscribe( result = self._conn.execute(cmd) return result - def unsubscribe(self, events: list[Any] | None = None, subscriptions: list[Any] | None = None): + def unsubscribe(self, events: List[Any] | None = None, subscriptions: List[Any] | None = None): """Execute session.unsubscribe.""" params = { "events": events, @@ -258,3 +233,4 @@ def unsubscribe(self, events: list[Any] | None = None, subscriptions: list[Any] cmd = command_builder("session.unsubscribe", params) result = self._conn.execute(cmd) return result + diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 90e65ac3d5ffb..7e4c9c6dee459 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -1,27 +1,16 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - +# WebDriver BiDi module: storage from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi.common import command_builder +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass @dataclass @@ -44,7 +33,7 @@ class GetCookiesParameters: class GetCookiesResult: """GetCookiesResult.""" - cookies: list[Any] = field(default_factory=list) + cookies: list[Any | None] | None = None partition_key: Any | None = None @@ -88,18 +77,13 @@ class BytesValue: TYPE_STRING = "string" TYPE_BASE64 = "base64" - def __init__(self, type: Any | None, value: Any | None) -> None: + def __init__(self, type: str, value: str) -> None: self.type = type self.value = value def to_bidi_dict(self) -> dict: return {"type": self.type, "value": self.value} - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict() - - class SameSite: """SameSite cookie attribute values.""" @@ -108,7 +92,6 @@ class SameSite: NONE = "none" DEFAULT = "default" - @dataclass class StorageCookie: """A cookie object returned by storage.getCookies.""" @@ -124,11 +107,11 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> StorageCookie: + def from_bidi_dict(cls, raw: dict) -> "StorageCookie": """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): - value: Any = BytesValue(value_raw.get("type"), value_raw.get("value")) + value = BytesValue(value_raw.get("type"), value_raw.get("value")) else: value = value_raw return cls( @@ -143,7 +126,6 @@ def from_bidi_dict(cls, raw: dict) -> StorageCookie: expiry=raw.get("expiry"), ) - @dataclass class CookieFilter: """CookieFilter.""" @@ -181,11 +163,6 @@ def to_bidi_dict(self) -> dict: result["expiry"] = self.expiry return result - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict() - - @dataclass class PartialCookie: """PartialCookie.""" @@ -220,11 +197,6 @@ def to_bidi_dict(self) -> dict: result["expiry"] = self.expiry return result - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict() - - class BrowsingContextPartitionDescriptor: """BrowsingContextPartitionDescriptor. @@ -240,11 +212,6 @@ def __init__(self, context: Any = None, type: str = "context") -> None: def to_bidi_dict(self) -> dict: return {"type": "context", "context": self.context} - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict() - - @dataclass class StorageKeyPartitionDescriptor: """StorageKeyPartitionDescriptor.""" @@ -262,17 +229,45 @@ def to_bidi_dict(self) -> dict: result["sourceOrigin"] = self.source_origin return result - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict() - - class Storage: """WebDriver BiDi storage module.""" def __init__(self, conn) -> None: self._conn = conn + def get_cookies(self, filter: Any | None = None, partition: Any | None = None): + """Execute storage.getCookies.""" + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.getCookies", params) + result = self._conn.execute(cmd) + return result + + def set_cookie(self, cookie: Any | None = None, partition: Any | None = None): + """Execute storage.setCookie.""" + params = { + "cookie": cookie, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.setCookie", params) + result = self._conn.execute(cmd) + return result + + def delete_cookies(self, filter: Any | None = None, partition: Any | None = None): + """Execute storage.deleteCookies.""" + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.deleteCookies", params) + result = self._conn.execute(cmd) + return result + def get_cookies(self, filter=None, partition=None): """Execute storage.getCookies and return a GetCookiesResult.""" if filter and hasattr(filter, "to_bidi_dict"): @@ -287,7 +282,11 @@ def get_cookies(self, filter=None, partition=None): cmd = command_builder("storage.getCookies", params) result = self._conn.execute(cmd) if result and "cookies" in result: - cookies = [StorageCookie.from_bidi_dict(c) for c in result.get("cookies", []) if isinstance(c, dict)] + cookies = [ + StorageCookie.from_bidi_dict(c) + for c in result.get("cookies", []) + if isinstance(c, dict) + ] pk_raw = result.get("partitionKey") pk = ( PartitionKey( @@ -299,7 +298,6 @@ def get_cookies(self, filter=None, partition=None): ) return GetCookiesResult(cookies=cookies, partition_key=pk) return GetCookiesResult(cookies=[], partition_key=None) - def set_cookie(self, cookie=None, partition=None): """Execute storage.setCookie.""" if cookie and hasattr(cookie, "to_bidi_dict"): @@ -313,19 +311,7 @@ def set_cookie(self, cookie=None, partition=None): params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("storage.setCookie", params) result = self._conn.execute(cmd) - if isinstance(result, dict): - pk_raw = result.get("partitionKey") - pk = ( - PartitionKey( - user_context=pk_raw.get("userContext"), - source_origin=pk_raw.get("sourceOrigin"), - ) - if isinstance(pk_raw, dict) - else None - ) - return SetCookieResult(partition_key=pk) return result - def delete_cookies(self, filter=None, partition=None): """Execute storage.deleteCookies.""" if filter and hasattr(filter, "to_bidi_dict"): @@ -339,15 +325,4 @@ def delete_cookies(self, filter=None, partition=None): params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("storage.deleteCookies", params) result = self._conn.execute(cmd) - if isinstance(result, dict): - pk_raw = result.get("partitionKey") - pk = ( - PartitionKey( - user_context=pk_raw.get("userContext"), - source_origin=pk_raw.get("sourceOrigin"), - ) - if isinstance(pk_raw, dict) - else None - ) - return DeleteCookiesResult(partition_key=pk) return result diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 62f2dec130308..8a737efeeafde 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -1,27 +1,16 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - +# WebDriver BiDi module: webExtension from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi.common import command_builder +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass @dataclass @@ -75,12 +64,7 @@ class WebExtension: def __init__(self, conn) -> None: self._conn = conn - def install( - self, - path: str | None = None, - archive_path: str | None = None, - base64_value: str | None = None, - ): + def install(self, path: str | None = None, archive_path: str | None = None, base64_value: str | None = None): """Install a web extension. Exactly one of the three keyword arguments must be provided. @@ -98,57 +82,31 @@ def install( Raises: ValueError: If more than one, or none, of the arguments is provided. """ - provided = [ - k - for k, v in { - "path": path, - "archive_path": archive_path, - "base64_value": base64_value, - }.items() - if v is not None - ] + provided = [k for k, v in {"path": path, "archive_path": archive_path, "base64_value": base64_value}.items() if v is not None] if len(provided) != 1: - raise ValueError(f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}") + raise ValueError( + f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}" + ) if path is not None: extension_data = {"type": "path", "path": path} elif archive_path is not None: extension_data = {"type": "archivePath", "path": archive_path} else: - assert base64_value is not None extension_data = {"type": "base64", "value": base64_value} params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) - try: - return self._conn.execute(cmd) - except Exception as e: - if "Method not available" in str(e): - raise RuntimeError( - "webExtension.install failed with 'Method not available'. " - "This likely means that web extension support is disabled. " - "Enable unsafe extension debugging and/or set options.enable_webextensions " - "in your WebDriver configuration." - ) from e - raise - - def uninstall(self, extension: str | dict): + return self._conn.execute(cmd) + def uninstall(self, extension: Any | None = None): """Uninstall a web extension. Args: extension: Either the extension ID string returned by ``install``, or the full result dict returned by ``install`` (the ``"extension"`` value is extracted automatically). - - Raises: - ValueError: If extension is not provided or is None. """ if isinstance(extension, dict): - extension_id: Any = extension.get("extension") - else: - extension_id = extension - - if extension_id is None: - raise ValueError("extension parameter is required") - - params = {"extension": extension_id} + extension = extension.get("extension") + params = {"extension": extension} + params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/proxy.py b/py/selenium/webdriver/common/proxy.py index eadf1d069709f..28de19afa5742 100644 --- a/py/selenium/webdriver/common/proxy.py +++ b/py/selenium/webdriver/common/proxy.py @@ -35,13 +35,23 @@ class ProxyType: profile preference, 'string' is id of proxy type. """ - DIRECT = ProxyTypeFactory.make(0, "DIRECT") # Direct connection, no proxy (default on Windows). - MANUAL = ProxyTypeFactory.make(1, "MANUAL") # Manual proxy settings (e.g., for httpProxy). + DIRECT = ProxyTypeFactory.make( + 0, "DIRECT" + ) # Direct connection, no proxy (default on Windows). + MANUAL = ProxyTypeFactory.make( + 1, "MANUAL" + ) # Manual proxy settings (e.g., for httpProxy). PAC = ProxyTypeFactory.make(2, "PAC") # Proxy autoconfiguration from URL. RESERVED_1 = ProxyTypeFactory.make(3, "RESERVED1") # Never used. - AUTODETECT = ProxyTypeFactory.make(4, "AUTODETECT") # Proxy autodetection (presumably with WPAD). - SYSTEM = ProxyTypeFactory.make(5, "SYSTEM") # Use system settings (default on Linux). - UNSPECIFIED = ProxyTypeFactory.make(6, "UNSPECIFIED") # Not initialized (for internal use). + AUTODETECT = ProxyTypeFactory.make( + 4, "AUTODETECT" + ) # Proxy autodetection (presumably with WPAD). + SYSTEM = ProxyTypeFactory.make( + 5, "SYSTEM" + ) # Use system settings (default on Linux). + UNSPECIFIED = ProxyTypeFactory.make( + 6, "UNSPECIFIED" + ) # Not initialized (for internal use). @classmethod def load(cls, value): @@ -50,7 +60,11 @@ def load(cls, value): value = str(value).upper() for attr in dir(cls): attr_value = getattr(cls, attr) - if isinstance(attr_value, dict) and "string" in attr_value and attr_value["string"] == value: + if ( + isinstance(attr_value, dict) + and "string" in attr_value + and attr_value["string"] == value + ): return attr_value raise Exception(f"No proxy type is found for {value}") @@ -205,13 +219,17 @@ def to_bidi_dict(self) -> dict: if self.noProxy: # Convert comma-separated string to list if isinstance(self.noProxy, str): - result["noProxy"] = [host.strip() for host in self.noProxy.split(",") if host.strip()] + result["noProxy"] = [ + host.strip() for host in self.noProxy.split(",") if host.strip() + ] elif isinstance(self.noProxy, list): if not all(isinstance(h, str) for h in self.noProxy): raise TypeError("no_proxy list must contain only strings") result["noProxy"] = self.noProxy else: - raise TypeError("no_proxy must be a comma-separated string or a list of strings") + raise TypeError( + "no_proxy must be a comma-separated string or a list of strings" + ) elif proxy_type == "pac": if self.proxyAutoconfigUrl: diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 4e426090883d4..2c41897878075 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -116,7 +116,9 @@ def get_remote_connection( client_config: ClientConfig | None = None, ) -> RemoteConnection: if isinstance(command_executor, str): - client_config = client_config or ClientConfig(remote_server_addr=command_executor) + client_config = client_config or ClientConfig( + remote_server_addr=command_executor + ) client_config.remote_server_addr = command_executor command_executor = RemoteConnection(client_config=client_config) @@ -398,9 +400,13 @@ def create_web_element(self, element_id: str) -> WebElement: def _unwrap_value(self, value): if isinstance(value, dict): if "element-6066-11e4-a52e-4f735466cecf" in value: - return self.create_web_element(value["element-6066-11e4-a52e-4f735466cecf"]) + return self.create_web_element( + value["element-6066-11e4-a52e-4f735466cecf"] + ) if "shadow-6066-11e4-a52e-4f735466cecf" in value: - return self._shadowroot_cls(self, value["shadow-6066-11e4-a52e-4f735466cecf"]) + return self._shadowroot_cls( + self, value["shadow-6066-11e4-a52e-4f735466cecf"] + ) for key, val in value.items(): value[key] = self._unwrap_value(val) return value @@ -426,7 +432,9 @@ def execute_cdp_cmd(self, cmd: str, cmd_args: dict): Example: `driver.execute_cdp_cmd("Network.getResponseBody", {"requestId": requestId})` """ - return self.execute("executeCdpCommand", {"cmd": cmd, "params": cmd_args})["value"] + return self.execute("executeCdpCommand", {"cmd": cmd, "params": cmd_args})[ + "value" + ] def execute( self, @@ -462,7 +470,9 @@ def execute( elif "sessionId" not in params: params["sessionId"] = self.session_id - response = cast(RemoteConnection, self.command_executor).execute(driver_command, params) + response = cast(RemoteConnection, self.command_executor).execute( + driver_command, params + ) if response: self.error_handler.check_response(response) @@ -518,7 +528,9 @@ def unpin(self, script_key: ScriptKey) -> None: try: self.pinned_scripts.pop(script_key.id) except KeyError: - raise KeyError(f"No script with key: {script_key} existed in {self.pinned_scripts}") from None + raise KeyError( + f"No script with key: {script_key} existed in {self.pinned_scripts}" + ) from None def get_pinned_scripts(self) -> list[str]: """Return a list of all pinned scripts. @@ -551,7 +563,9 @@ def execute_script(self, script: str, *args) -> Any: converted_args = list(args) command = Command.W3C_EXECUTE_SCRIPT - return self.execute(command, {"script": script, "args": converted_args})["value"] + return self.execute(command, {"script": script, "args": converted_args})[ + "value" + ] def execute_async_script(self, script: str, *args) -> Any: """Asynchronously Executes JavaScript in the current window/frame. @@ -570,7 +584,9 @@ def execute_async_script(self, script: str, *args) -> Any: converted_args = list(args) command = Command.W3C_EXECUTE_SCRIPT_ASYNC - return self.execute(command, {"script": script, "args": converted_args})["value"] + return self.execute(command, {"script": script, "args": converted_args})[ + "value" + ] @property def current_url(self) -> str: @@ -747,7 +763,9 @@ def implicitly_wait(self, time_to_wait: float) -> None: Example: `driver.implicitly_wait(30)` """ - self.execute(Command.SET_TIMEOUTS, {"implicit": int(float(time_to_wait) * 1000)}) + self.execute( + Command.SET_TIMEOUTS, {"implicit": int(float(time_to_wait) * 1000)} + ) def set_script_timeout(self, time_to_wait: float) -> None: """Set the timeout for asynchronous script execution. @@ -776,7 +794,9 @@ def set_page_load_timeout(self, time_to_wait: float) -> None: `driver.set_page_load_timeout(30)` """ try: - self.execute(Command.SET_TIMEOUTS, {"pageLoad": int(float(time_to_wait) * 1000)}) + self.execute( + Command.SET_TIMEOUTS, {"pageLoad": int(float(time_to_wait) * 1000)} + ) except WebDriverException: self.execute( Command.SET_TIMEOUTS, @@ -817,7 +837,9 @@ def timeouts(self, timeouts) -> None: """ _ = self.execute(Command.SET_TIMEOUTS, timeouts._to_json())["value"] - def find_element(self, by: str | RelativeBy = By.ID, value: str | None = None) -> WebElement: + def find_element( + self, by: str | RelativeBy = By.ID, value: str | None = None + ) -> WebElement: """Find an element given a By strategy and locator. Args: @@ -838,12 +860,18 @@ def find_element(self, by: str | RelativeBy = By.ID, value: str | None = None) - if isinstance(by, RelativeBy): elements = self.find_elements(by=by, value=value) if not elements: - raise NoSuchElementException(f"Cannot locate relative element with: {by.root}") + raise NoSuchElementException( + f"Cannot locate relative element with: {by.root}" + ) return elements[0] - return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})["value"] + return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})[ + "value" + ] - def find_elements(self, by: str | RelativeBy = By.ID, value: str | None = None) -> list[WebElement]: + def find_elements( + self, by: str | RelativeBy = By.ID, value: str | None = None + ) -> list[WebElement]: """Find elements given a By strategy and locator. Args: @@ -865,14 +893,21 @@ def find_elements(self, by: str | RelativeBy = By.ID, value: str | None = None) _pkg = ".".join(__name__.split(".")[:-1]) raw_data = pkgutil.get_data(_pkg, "findElements.js") if raw_data is None: - raise FileNotFoundError(f"Could not find findElements.js in package {_pkg}") + raise FileNotFoundError( + f"Could not find findElements.js in package {_pkg}" + ) raw_function = raw_data.decode("utf8") - find_element_js = f"/* findElements */return ({raw_function}).apply(null, arguments);" + find_element_js = ( + f"/* findElements */return ({raw_function}).apply(null, arguments);" + ) return self.execute_script(find_element_js, by.to_dict()) # Return empty list if driver returns null # See https://github.com/SeleniumHQ/selenium/issues/4555 - return self.execute(Command.FIND_ELEMENTS, {"using": by, "value": value})["value"] or [] + return ( + self.execute(Command.FIND_ELEMENTS, {"using": by, "value": value})["value"] + or [] + ) @property def capabilities(self) -> dict: @@ -969,7 +1004,9 @@ def get_window_size(self, windowHandle: str = "current") -> dict: return {k: size[k] for k in ("width", "height")} - def set_window_position(self, x: float, y: float, windowHandle: str = "current") -> dict: + def set_window_position( + self, x: float, y: float, windowHandle: str = "current" + ) -> dict: """Sets the x,y position of the current window. Args: @@ -1028,7 +1065,9 @@ def set_window_rect(self, x=None, y=None, width=None, height=None) -> dict: if (x is None and y is None) and (not height and not width): raise InvalidArgumentException("x and y or height and width need values") - return self.execute(Command.SET_WINDOW_RECT, {"x": x, "y": y, "width": width, "height": height})["value"] + return self.execute( + Command.SET_WINDOW_RECT, {"x": x, "y": y, "width": width, "height": height} + )["value"] @property def file_detector(self) -> FileDetector: @@ -1073,7 +1112,9 @@ def orientation(self, value) -> None: if value.upper() in allowed_values: self.execute(Command.SET_SCREEN_ORIENTATION, {"orientation": value}) else: - raise WebDriverException("You can only set the orientation to 'LANDSCAPE' and 'PORTRAIT'") + raise WebDriverException( + "You can only set the orientation to 'LANDSCAPE' and 'PORTRAIT'" + ) def start_devtools(self) -> tuple[Any, WebSocketConnection]: global cdp @@ -1088,7 +1129,9 @@ def start_devtools(self) -> tuple[Any, WebSocketConnection]: version, ws_url = self._get_cdp_details() if not ws_url: - raise WebDriverException("Unable to find url to connect to from capabilities") + raise WebDriverException( + "Unable to find url to connect to from capabilities" + ) if cdp is None: raise WebDriverException("CDP module not loaded") @@ -1097,20 +1140,28 @@ def start_devtools(self) -> tuple[Any, WebSocketConnection]: if self._websocket_connection: return self._devtools, self._websocket_connection if self.caps["browserName"].lower() == "firefox": - raise RuntimeError("CDP support for Firefox has been removed. Please switch to WebDriver BiDi.") + raise RuntimeError( + "CDP support for Firefox has been removed. Please switch to WebDriver BiDi." + ) if not isinstance(self.command_executor, RemoteConnection): - raise WebDriverException("command_executor must be a RemoteConnection instance for CDP support") + raise WebDriverException( + "command_executor must be a RemoteConnection instance for CDP support" + ) self._websocket_connection = WebSocketConnection( ws_url, self.command_executor.client_config.websocket_timeout, self.command_executor.client_config.websocket_interval, ) - targets = self._websocket_connection.execute(self._devtools.target.get_targets()) + targets = self._websocket_connection.execute( + self._devtools.target.get_targets() + ) for target in targets: if target.target_id == self.current_window_handle: target_id = target.target_id break - session = self._websocket_connection.execute(self._devtools.target.attach_to_target(target_id, True)) + session = self._websocket_connection.execute( + self._devtools.target.attach_to_target(target_id, True) + ) self._websocket_connection.session_id = session return self._devtools, self._websocket_connection @@ -1125,7 +1176,9 @@ async def bidi_connection(self): version, ws_url = self._get_cdp_details() if not ws_url: - raise WebDriverException("Unable to find url to connect to from capabilities") + raise WebDriverException( + "Unable to find url to connect to from capabilities" + ) devtools = cdp.import_devtools(version) async with cdp.open_cdp(ws_url) as conn: @@ -1151,10 +1204,14 @@ def _start_bidi(self) -> None: if self.caps.get("webSocketUrl"): ws_url = self.caps.get("webSocketUrl") else: - raise WebDriverException("Unable to find url to connect to from capabilities") + raise WebDriverException( + "Unable to find url to connect to from capabilities" + ) if not isinstance(self.command_executor, RemoteConnection): - raise WebDriverException("command_executor must be a RemoteConnection instance for BiDi support") + raise WebDriverException( + "command_executor must be a RemoteConnection instance for BiDi support" + ) self._websocket_connection = WebSocketConnection( ws_url, @@ -1370,9 +1427,13 @@ def _get_cdp_details(self): http = urllib3.PoolManager() try: if self.caps.get("browserName") == "chrome": - debugger_address = self.caps.get("goog:chromeOptions").get("debuggerAddress") + debugger_address = self.caps.get("goog:chromeOptions").get( + "debuggerAddress" + ) elif self.caps.get("browserName") in ("MicrosoftEdge", "webview2"): - debugger_address = self.caps.get("ms:edgeOptions").get("debuggerAddress") + debugger_address = self.caps.get("ms:edgeOptions").get( + "debuggerAddress" + ) except AttributeError: raise WebDriverException("Can't get debugger address.") @@ -1400,7 +1461,9 @@ def add_virtual_authenticator(self, options: VirtualAuthenticatorOptions) -> Non driver.add_virtual_authenticator(options) ``` """ - self._authenticator_id = self.execute(Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict())["value"] + self._authenticator_id = self.execute( + Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict() + )["value"] @property def virtual_authenticator_id(self) -> str | None: @@ -1440,8 +1503,12 @@ def add_credential(self, credential: Credential) -> None: @required_virtual_authenticator def get_credentials(self) -> list[Credential]: """Returns the list of credentials owned by the authenticator.""" - credential_data = self.execute(Command.GET_CREDENTIALS, {"authenticatorId": self._authenticator_id}) - return [Credential.from_dict(credential) for credential in credential_data["value"]] + credential_data = self.execute( + Command.GET_CREDENTIALS, {"authenticatorId": self._authenticator_id} + ) + return [ + Credential.from_dict(credential) for credential in credential_data["value"] + ] @required_virtual_authenticator def remove_credential(self, credential_id: str | bytearray) -> None: @@ -1463,7 +1530,9 @@ def remove_credential(self, credential_id: str | bytearray) -> None: @required_virtual_authenticator def remove_all_credentials(self) -> None: """Removes all credentials from the authenticator.""" - self.execute(Command.REMOVE_ALL_CREDENTIALS, {"authenticatorId": self._authenticator_id}) + self.execute( + Command.REMOVE_ALL_CREDENTIALS, {"authenticatorId": self._authenticator_id} + ) @required_virtual_authenticator def set_user_verified(self, verified: bool) -> None: @@ -1484,7 +1553,9 @@ def set_user_verified(self, verified: bool) -> None: def get_downloadable_files(self) -> list: """Retrieves the downloadable files as a list of file names.""" if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException("You must enable downloads in order to work with downloadable files.") + raise WebDriverException( + "You must enable downloads in order to work with downloadable files." + ) return self.execute(Command.GET_DOWNLOADABLE_FILES)["value"]["names"] @@ -1499,12 +1570,16 @@ def download_file(self, file_name: str, target_directory: str) -> None: `driver.download_file("example.zip", "/path/to/directory")` """ if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException("You must enable downloads in order to work with downloadable files.") + raise WebDriverException( + "You must enable downloads in order to work with downloadable files." + ) if not os.path.exists(target_directory): os.makedirs(target_directory) - contents = self.execute(Command.DOWNLOAD_FILE, {"name": file_name})["value"]["contents"] + contents = self.execute(Command.DOWNLOAD_FILE, {"name": file_name})["value"][ + "contents" + ] with tempfile.TemporaryDirectory() as tmp_dir: zip_file = os.path.join(tmp_dir, file_name + ".zip") @@ -1517,7 +1592,9 @@ def download_file(self, file_name: str, target_directory: str) -> None: def delete_downloadable_files(self) -> None: """Deletes all downloadable files.""" if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException("You must enable downloads in order to work with downloadable files.") + raise WebDriverException( + "You must enable downloads in order to work with downloadable files." + ) self.execute(Command.DELETE_DOWNLOADABLE_FILES) From a0d47350c40ef986337dd0e492ee2434137a0053 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Fri, 27 Feb 2026 14:07:04 +0000 Subject: [PATCH 02/42] fixup --- py/generate_bidi.py | 204 ++++--- py/private/bidi_enhancements_manifest.py | 77 ++- py/selenium/webdriver/common/bidi/__init__.py | 23 + py/selenium/webdriver/common/bidi/browser.py | 43 +- .../webdriver/common/bidi/browsing_context.py | 171 ++++-- py/selenium/webdriver/common/bidi/cdp.py | 515 ------------------ py/selenium/webdriver/common/bidi/common.py | 7 +- py/selenium/webdriver/common/bidi/console.py | 0 .../webdriver/common/bidi/emulation.py | 187 +++---- py/selenium/webdriver/common/bidi/input.py | 36 +- py/selenium/webdriver/common/bidi/log.py | 223 +++++++- py/selenium/webdriver/common/bidi/network.py | 115 ++-- .../webdriver/common/bidi/permissions.py | 10 +- py/selenium/webdriver/common/bidi/script.py | 113 +++- py/selenium/webdriver/common/bidi/session.py | 30 +- py/selenium/webdriver/common/bidi/storage.py | 44 +- .../webdriver/common/bidi/webextension.py | 20 +- 17 files changed, 873 insertions(+), 945 deletions(-) delete mode 100644 py/selenium/webdriver/common/bidi/cdp.py mode change 100755 => 100644 py/selenium/webdriver/common/bidi/console.py diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 1770cf436bef1..2db595ff37cd0 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -18,12 +18,11 @@ import logging import re import sys -from collections import defaultdict from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from textwrap import dedent, indent as tw_indent -from typing import Any, Dict, List, Optional, Set, Tuple +from textwrap import indent as tw_indent +from typing import Any __version__ = "1.0.0" @@ -43,8 +42,7 @@ # WebDriver BiDi module: {{}} from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder +from typing import Any """ @@ -53,7 +51,7 @@ def indent(s: str, n: int) -> str: return tw_indent(s, n * " ") -def load_enhancements_manifest(manifest_path: Optional[str]) -> Dict[str, Any]: +def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]: """Load enhancement manifest from a Python file. Args: @@ -124,10 +122,10 @@ def get_annotation(cls, cddl_type: str) -> str: if cddl_type.startswith("["): # Array inner = cddl_type.strip("[]+ ") inner_type = cls.get_annotation(inner) - return f"List[{inner_type}]" + return f"list[{inner_type}]" if cddl_type.startswith("{"): # Map/Dict - return "Dict[str, Any]" + return "dict[str, Any]" # Default to Any for unknown types return "Any" @@ -139,11 +137,11 @@ class CddlCommand: module: str name: str - params: Dict[str, str] = field(default_factory=dict) - result: Optional[str] = None + params: dict[str, str] = field(default_factory=dict) + result: str | None = None description: str = "" - def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python method code for this command. Args: @@ -174,8 +172,15 @@ def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str else: param_list = "self" - # Build method body - body = f" def {method_name}({param_list}):\n" + # Build method body - wrap long signatures over multiple lines if needed + sig_line = f" def {method_name}({param_list}):" + if len(sig_line) > 120 and param_strs: + body = f" def {method_name}(\n self,\n" + for p in param_strs: + body += f" {p},\n" + body += " ):\n" + else: + body = sig_line + "\n" body += f' """{self.description or "Execute " + self.module + "." + self.name}."""\n' # Add validation if specified @@ -237,7 +242,6 @@ def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str if result_param == "download_behavior": body += ' "downloadBehavior": download_behavior,\n' # Add remaining parameters that weren't part of the transform - override_params = enhancements.get("params_override", {}) for cddl_param_name in self.params: if cddl_param_name not in ["downloadBehavior"]: snake_name = self._camel_to_snake(cddl_param_name) @@ -264,45 +268,45 @@ def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str # Extract property from list items body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f' item.get("{extract_property}")\n' - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" elif extract_field in deserialize_rules: # Extract field and deserialize to typed objects type_name = deserialize_rules[extract_field] body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f" {type_name}(\n" body += self._generate_field_args(extract_field, type_name) - body += f" )\n" - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " )\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" else: # Simple field extraction (return the value directly, not wrapped in result dict) body += f' if result and "{extract_field}" in result:\n' body += f' extracted = result.get("{extract_field}")\n' - body += f" return extracted\n" - body += f" return result\n" + body += " return extracted\n" + body += " return result\n" elif "deserialize" in enhancements: # Deserialize response to typed objects (legacy, without extract_field) deserialize_rules = enhancements["deserialize"] for response_field, type_name in deserialize_rules.items(): body += f' if result and "{response_field}" in result:\n' body += f' items = result.get("{response_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f" {type_name}(\n" body += self._generate_field_args(response_field, type_name) - body += f" )\n" - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " )\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" else: # No special response handling, just return the result body += " return result\n" @@ -351,10 +355,10 @@ class CddlTypeDefinition: module: str name: str - fields: Dict[str, str] = field(default_factory=dict) + fields: dict[str, str] = field(default_factory=dict) description: str = "" - def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python dataclass code for this type. Args: @@ -366,7 +370,7 @@ def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> # Generate class name from type name (keep it as-is, don't split on underscores) class_name = self.name - code = f"@dataclass\n" + code = "@dataclass\n" code += f"class {class_name}:\n" code += f' """{self.description or self.name}."""\n\n' @@ -386,7 +390,7 @@ def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> literal_value = literal_match.group(1) code += f' {snake_name}: str = field(default="{literal_value}", init=False)\n' # Check if this field is a list type - elif "List[" in python_type: + elif "list[" in python_type: code += f" {snake_name}: {python_type} = field(default_factory=list)\n" else: code += f" {snake_name}: {python_type} = None\n" @@ -453,7 +457,7 @@ class CddlEnum: module: str name: str - values: List[str] = field(default_factory=list) + values: list[str] = field(default_factory=list) description: str = "" def to_python_class(self) -> str: @@ -530,10 +534,10 @@ class CddlModule: """Represents a CDDL module (e.g., script, network, browsing_context).""" name: str - commands: List[CddlCommand] = field(default_factory=list) - types: List[CddlTypeDefinition] = field(default_factory=list) - enums: List[CddlEnum] = field(default_factory=list) - events: List[CddlEvent] = field(default_factory=list) + commands: list[CddlCommand] = field(default_factory=list) + types: list[CddlTypeDefinition] = field(default_factory=list) + enums: list[CddlEnum] = field(default_factory=list) + events: list[CddlEvent] = field(default_factory=list) @staticmethod def _convert_method_to_event_name(method_suffix: str) -> str: @@ -548,7 +552,33 @@ def _convert_method_to_event_name(method_suffix: str) -> str: s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", method_suffix) return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def _needs_field_import(self, enhancements: dict[str, Any] | None = None) -> bool: + """Check if any type definition in this module requires the 'field' import. + + Respects the same type exclusions applied during code generation. + """ + enhancements = enhancements or {} + extra_cls_names: set[str] = set() + for extra_cls in enhancements.get("extra_dataclasses", []): + m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) + if m: + extra_cls_names.add(m.group(1)) + exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names + + for type_def in self.types: + if type_def.name in exclude_types: + continue + for field_type in type_def.fields.values(): + # Literal string discriminants use field(default=..., init=False) + if re.match(r'^"', field_type.strip()): + return True + # List-typed fields use field(default_factory=list) + python_type = CddlTypeDefinition._get_python_type(field_type) + if python_type.startswith("list["): + return True + return False + + def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python code for this module. Args: @@ -558,17 +588,21 @@ def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: code = MODULE_HEADER.format(self.name) # Add imports if needed - if self.types: - code += "from dataclasses import field\n" + if self.commands: + code += "from .common import command_builder\n" + dataclass_imported = False if self.commands or self.types: - code += "from typing import Generator\n" code += "from dataclasses import dataclass\n" + dataclass_imported = True + if self.types and self._needs_field_import(enhancements): + code += "from dataclasses import field\n" # Add imports for event handling if needed if self.events: code += "import threading\n" code += "from collections.abc import Callable\n" - code += "from dataclasses import dataclass\n" + if not dataclass_imported: + code += "from dataclasses import dataclass\n" code += "from selenium.webdriver.common.bidi.session import Session\n" code += "\n\n" @@ -660,7 +694,13 @@ def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: code += f"{alias} = {target}\n\n" # Generate type dataclasses, skipping any overridden by extra_dataclasses - exclude_types = set(enhancements.get("exclude_types", [])) + # Also auto-exclude types whose names appear in extra_dataclasses + extra_cls_names = set() + for extra_cls in enhancements.get("extra_dataclasses", []): + m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) + if m: + extra_cls_names.add(m.group(1)) + exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names for type_def in self.types: if type_def.name in exclude_types: continue @@ -680,13 +720,16 @@ def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: # Generate EVENT_NAME_MAPPING for the module code += "# BiDi Event Name to Parameter Type Mapping\n" code += "EVENT_NAME_MAPPING = {\n" + # Collect event keys from extra_events so we skip CDDL duplicates + extra_event_keys = {evt["event_key"] for evt in enhancements.get("extra_events", [])} for event_def in self.events: # Convert method name to user-friendly event name # e.g., "browsingContext.contextCreated" -> "context_created" method_parts = event_def.method.split(".") if len(method_parts) == 2: event_name = self._convert_method_to_event_name(method_parts[1]) - code += f' "{event_name}": "{event_def.method}",\n' + if event_name not in extra_event_keys: + code += f' "{event_name}": "{event_def.method}",\n' # Extra events not in the CDDL spec (e.g. Chromium-specific events) for extra_evt in enhancements.get("extra_events", []): code += ( @@ -923,7 +966,13 @@ def clear_event_handlers(self) -> None: code += "\n" # Generate command methods - exclude_methods = enhancements.get("exclude_methods", []) + # Auto-exclude methods whose names appear in extra_methods to prevent duplicates + extra_method_names = set() + for extra_meth in enhancements.get("extra_methods", []): + m = re.search(r"def\s+(\w+)\s*\(", extra_meth) + if m: + extra_method_names.add(m.group(1)) + exclude_methods = set(enhancements.get("exclude_methods", [])) | extra_method_names if self.commands: for command in self.commands: # Get method-specific enhancements @@ -981,24 +1030,44 @@ def clear_event_handlers(self) -> None: code += "\n" # Now populate EVENT_CONFIGS after the aliases are defined - code += f"\n# Populate EVENT_CONFIGS with event configuration mappings\n" + code += "\n# Populate EVENT_CONFIGS with event configuration mappings\n" # Use globals() to look up types dynamically to handle missing types gracefully - code += f"_globals = globals()\n" + code += "_globals = globals()\n" code += f"{class_name}.EVENT_CONFIGS = {{\n" + # Collect extra event keys to skip CDDL duplicates + extra_event_keys_cfg = {evt["event_key"] for evt in enhancements.get("extra_events", [])} for event_def in self.events: # Convert method name to user-friendly event name method_parts = event_def.method.split(".") if len(method_parts) == 2: event_name = self._convert_method_to_event_name(method_parts[1]) + if event_name in extra_event_keys_cfg: + continue # The event class is the event name (e.g., ContextCreated) # Try to get it from globals, default to dict if not found - code += f' "{event_name}": (EventConfig("{event_name}", "{event_def.method}", _globals.get("{event_def.name}", dict)) if _globals.get("{event_def.name}") else EventConfig("{event_name}", "{event_def.method}", dict)),\n' + code += ( + f' "{event_name}": (\n' + f' EventConfig("{event_name}", "{event_def.method}",\n' + f' _globals.get("{event_def.name}", dict))\n' + f' if _globals.get("{event_def.name}")\n' + f' else EventConfig("{event_name}", "{event_def.method}", dict)\n' + f' ),\n' + ) # Extra events not in the CDDL spec for extra_evt in enhancements.get("extra_events", []): ek = extra_evt["event_key"] be = extra_evt["bidi_event"] ec = extra_evt["event_class"] - code += f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),\n' + single = f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),' + if len(single) > 120: + code += ( + f' "{ek}": EventConfig(\n' + f' "{ek}", "{be}",\n' + f' _globals.get("{ec}", dict),\n' + f' ),\n' + ) + else: + code += single + "\n" code += "}\n" return code @@ -1011,9 +1080,9 @@ def __init__(self, cddl_path: str): """Initialize parser with CDDL file path.""" self.cddl_path = Path(cddl_path) self.content = "" - self.modules: Dict[str, CddlModule] = {} - self.definitions: Dict[str, str] = {} - self.event_names: Set[str] = set() # Names of definitions that are events + self.modules: dict[str, CddlModule] = {} + self.definitions: dict[str, str] = {} + self.event_names: set[str] = set() # Names of definitions that are events self._read_file() def _read_file(self) -> None: @@ -1021,12 +1090,12 @@ def _read_file(self) -> None: if not self.cddl_path.exists(): raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}") - with open(self.cddl_path, "r", encoding="utf-8") as f: + with open(self.cddl_path, encoding="utf-8") as f: self.content = f.read() logger.info(f"Loaded CDDL file: {self.cddl_path}") - def parse(self) -> Dict[str, CddlModule]: + def parse(self) -> dict[str, CddlModule]: """Parse CDDL content and return modules.""" # Remove comments content = self._remove_comments(self.content) @@ -1090,9 +1159,6 @@ def _extract_event_names(self) -> None: ... ) """ - # Look for definitions like "BrowsingContextEvent", "SessionEvent", etc. - event_union_pattern = re.compile(r"(\w+\.)?(\w+)Event") - for def_name, def_content in self.definitions.items(): # Check if this looks like an event union (name ends with "Event") and # contains a module-qualified reference like "module.EventName". @@ -1175,7 +1241,7 @@ def _is_enum_definition(self, definition: str) -> bool: # Pattern: "something" / "something_else" return " / " in clean_def and '"' in clean_def - def _extract_enum_values(self, enum_definition: str) -> List[str]: + def _extract_enum_values(self, enum_definition: str) -> list[str]: """Extract individual values from an enum definition. Enums are defined as: "value1" / "value2" / "value3" @@ -1225,7 +1291,7 @@ def _normalize_cddl_type(field_type: str) -> str: result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result) return result.strip() - def _extract_type_fields(self, type_definition: str) -> Dict[str, str]: + def _extract_type_fields(self, type_definition: str) -> dict[str, str]: """Extract fields from a type definition block.""" fields = {} @@ -1352,8 +1418,8 @@ def _extract_commands(self) -> None: ) def _extract_parameters( - self, params_type: str, _seen: Optional[Set[str]] = None - ) -> Dict[str, str]: + self, params_type: str, _seen: set[str] | None = None + ) -> dict[str, str]: """Extract parameters from a parameter type definition. Handles both struct types ({...}) and top-level union types (TypeA / TypeB), @@ -1466,7 +1532,7 @@ def module_name_to_filename(module_name: str) -> str: return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() -def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> None: +def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> None: """Generate __init__.py file for the module.""" init_path = output_path / "__init__.py" @@ -1481,7 +1547,7 @@ def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> Non filename = module_name_to_filename(module_name) code += f"from .{filename} import {class_name}\n" - code += f"\n__all__ = [\n" + code += "\n__all__ = [\n" for module_name in sorted(modules.keys()): class_name = module_name_to_class_name(module_name) code += f' "{class_name}",\n' @@ -1703,7 +1769,7 @@ def main( cddl_file: str, output_dir: str, spec_version: str = "1.0", - enhancements_manifest: Optional[str] = None, + enhancements_manifest: str | None = None, ) -> None: """Main entry point. diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index ae7229f6ddebd..39af67d4c635b 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -85,7 +85,12 @@ # downloadBehavior is never stripped by the generic None filter. # The BiDi spec marks it as required (can be null, but must be present). "extra_methods": [ - ''' def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): + ''' def set_download_behavior( + self, + allowed: bool | None = None, + destination_folder: str | None = None, + user_contexts: list[Any] | None = None, + ): """Set the download behavior for the browser. Args: @@ -272,8 +277,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": self, coordinates=None, error=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setGeolocationOverride. @@ -325,8 +330,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": ''' def set_timezone_override( self, timezone=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setTimezoneOverride. @@ -349,8 +354,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": ''' def set_scripting_enabled( self, enabled=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScriptingEnabled. @@ -373,8 +378,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": ''' def set_user_agent_override( self, user_agent=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setUserAgentOverride. @@ -396,8 +401,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": ''' def set_screen_orientation_override( self, screen_orientation=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScreenOrientationOverride. @@ -433,8 +438,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": self, network_conditions=None, offline: bool | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setNetworkConditions. @@ -534,7 +539,14 @@ def _serialize_arg(value): if raw.get("type") == "success": return raw.get("result") return raw''', - ''' def _add_preload_script(self, function_declaration, arguments=None, contexts=None, user_contexts=None, sandbox=None): + ''' def _add_preload_script( + self, + function_declaration, + arguments=None, + contexts=None, + user_contexts=None, + sandbox=None, + ): """Add a preload script with validation. Args: @@ -586,7 +598,15 @@ def _serialize_arg(value): script_id: The ID returned by pin(). """ return self._remove_preload_script(script_id=script_id)''', - ''' def _evaluate(self, expression, target, await_promise, result_ownership=None, serialization_options=None, user_activation=None): + ''' def _evaluate( + self, + expression, + target, + await_promise, + result_ownership=None, + serialization_options=None, + user_activation=None, + ): """Evaluate a script expression and return a structured result. Args: @@ -621,7 +641,17 @@ def __init__(self2, realm, result, exception_details): return _EvalResult(realm=realm, result=None, exception_details=exc) return _EvalResult(realm=realm, result=raw.get("result"), exception_details=None) return _EvalResult(realm=None, result=raw, exception_details=None)''', - ''' def _call_function(self, function_declaration, await_promise, target, arguments=None, result_ownership=None, this=None, user_activation=None, serialization_options=None): + ''' def _call_function( + self, + function_declaration, + await_promise, + target, + arguments=None, + result_ownership=None, + this=None, + user_activation=None, + serialization_options=None, + ): """Call a function and return a structured result. Args: @@ -1256,7 +1286,12 @@ def to_bidi_dict(self) -> dict: # Suppress the raw generated stubs; hand-written versions follow below "exclude_methods": ["install", "uninstall"], "extra_methods": [ - ''' def install(self, path: str | None = None, archive_path: str | None = None, base64_value: str | None = None): + ''' def install( + self, + path: str | None = None, + archive_path: str | None = None, + base64_value: str | None = None, + ): """Install a web extension. Exactly one of the three keyword arguments must be provided. @@ -1274,7 +1309,11 @@ def to_bidi_dict(self) -> dict: Raises: ValueError: If more than one, or none, of the arguments is provided. """ - provided = [k for k, v in {"path": path, "archive_path": archive_path, "base64_value": base64_value}.items() if v is not None] + provided = [ + k for k, v in { + "path": path, "archive_path": archive_path, "base64_value": base64_value, + }.items() if v is not None + ] if len(provided) != 1: raise ValueError( f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}" @@ -1502,6 +1541,7 @@ def _add_event_handler( - 'history_updated' Args: + self: The module instance this handler is bound to. event_name: The name of the event to subscribe to callback: Callback function to invoke when event occurs contexts: Optional list of context IDs to limit event subscription @@ -1538,6 +1578,7 @@ def _remove_event_handler( """Remove an event handler by its callback ID. Args: + self: The module instance this handler is bound to. callback_id: The callback ID returned from add_event_handler """ if not hasattr(self, "_event_handlers"): diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py index ab96f2d81e292..7be7bd4f73856 100644 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ b/py/selenium/webdriver/common/bidi/__init__.py @@ -5,3 +5,26 @@ from __future__ import annotations +from .browser import Browser +from .browsing_context import BrowsingContext +from .emulation import Emulation +from .input import Input +from .log import Log +from .network import Network +from .script import Script +from .session import Session +from .storage import Storage +from .webextension import WebExtension + +__all__ = [ + "Browser", + "BrowsingContext", + "Emulation", + "Input", + "Log", + "Network", + "Script", + "Session", + "Storage", + "WebExtension", +] diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index ed6a4d8f33bc5..acda63f71953e 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: browser from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass def transform_download_params( @@ -131,14 +130,14 @@ class CreateUserContextParameters: class GetClientWindowsResult: """GetClientWindowsResult.""" - client_windows: list[Any | None] | None = None + client_windows: list[Any | None] | None = field(default_factory=list) @dataclass class GetUserContextsResult: """GetUserContextsResult.""" - user_contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -171,7 +170,7 @@ class SetDownloadBehaviorParameters: """SetDownloadBehaviorParameters.""" download_behavior: Any | None = None - user_contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -204,7 +203,12 @@ def close(self): result = self._conn.execute(cmd) return result - def create_user_context(self, accept_insecure_certs: bool | None = None, proxy: Any | None = None, unhandled_prompt_behavior: Any | None = None): + def create_user_context( + self, + accept_insecure_certs: bool | None = None, + proxy: Any | None = None, + unhandled_prompt_behavior: Any | None = None, + ): """Execute browser.createUserContext.""" if proxy and hasattr(proxy, 'to_bidi_dict'): proxy = proxy.to_bidi_dict() @@ -285,23 +289,12 @@ def set_client_window_state(self, client_window: Any | None = None): result = self._conn.execute(cmd) return result - def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): - """Execute browser.setDownloadBehavior.""" - validate_download_behavior(allowed=allowed, destination_folder=destination_folder, user_contexts=user_contexts) - - download_behavior = None - download_behavior = transform_download_params(allowed, destination_folder) - - params = { - "downloadBehavior": download_behavior, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browser.setDownloadBehavior", params) - result = self._conn.execute(cmd) - return result - - def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): + def set_download_behavior( + self, + allowed: bool | None = None, + destination_folder: str | None = None, + user_contexts: list[Any] | None = None, + ): """Set the download behavior for the browser. Args: diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 35aea615d1780..5f128635df29d 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: browsingContext from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class ReadinessState: """ReadinessState.""" @@ -220,14 +219,14 @@ class LocateNodesParameters: context: Any | None = None locator: Any | None = None serialization_options: Any | None = None - start_nodes: list[Any | None] | None = None + start_nodes: list[Any | None] | None = field(default_factory=list) @dataclass class LocateNodesResult: """LocateNodesResult.""" - nodes: list[Any | None] | None = None + nodes: list[Any | None] | None = field(default_factory=list) @dataclass @@ -300,7 +299,7 @@ class SetViewportParameters: context: Any | None = None viewport: Any | None = None device_pixel_ratio: Any | None = None - user_contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -328,20 +327,6 @@ class HistoryUpdatedParameters: url: str | None = None -@dataclass -class DownloadWillBeginParams: - """DownloadWillBeginParams.""" - - suggested_filename: str | None = None - - -@dataclass -class DownloadCanceledParams: - """DownloadCanceledParams.""" - - status: str = field(default="canceled", init=False) - - @dataclass class UserPromptClosedParameters: """UserPromptClosedParameters.""" @@ -390,10 +375,10 @@ class DownloadParams: class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: "DownloadParams | None" = None + download_params: DownloadParams | None = None @classmethod - def from_json(cls, params: dict) -> "DownloadEndParams": + def from_json(cls, params: dict) -> DownloadEndParams: """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), @@ -414,8 +399,6 @@ def from_json(cls, params: dict) -> "DownloadEndParams": "history_updated": "browsingContext.historyUpdated", "dom_content_loaded": "browsingContext.domContentLoaded", "load": "browsingContext.load", - "download_will_begin": "browsingContext.downloadWillBegin", - "download_end": "browsingContext.downloadEnd", "navigation_aborted": "browsingContext.navigationAborted", "navigation_committed": "browsingContext.navigationCommitted", "navigation_failed": "browsingContext.navigationFailed", @@ -630,7 +613,13 @@ def activate(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def capture_screenshot(self, context: str | None = None, format: Any | None = None, clip: Any | None = None, origin: str | None = None): + def capture_screenshot( + self, + context: str | None = None, + format: Any | None = None, + clip: Any | None = None, + origin: str | None = None, + ): """Execute browsingContext.captureScreenshot.""" params = { "context": context, @@ -657,7 +646,13 @@ def close(self, context: Any | None = None, prompt_unload: bool | None = None): result = self._conn.execute(cmd) return result - def create(self, type: Any | None = None, reference_context: Any | None = None, background: bool | None = None, user_context: Any | None = None): + def create( + self, + type: Any | None = None, + reference_context: Any | None = None, + background: bool | None = None, + user_context: Any | None = None, + ): """Execute browsingContext.create.""" params = { "type": type, @@ -711,7 +706,14 @@ def handle_user_prompt(self, context: Any | None = None, accept: bool | None = N result = self._conn.execute(cmd) return result - def locate_nodes(self, context: str | None = None, locator: Any | None = None, serialization_options: Any | None = None, start_nodes: Any | None = None, max_node_count: int | None = None): + def locate_nodes( + self, + context: str | None = None, + locator: Any | None = None, + serialization_options: Any | None = None, + start_nodes: Any | None = None, + max_node_count: int | None = None, + ): """Execute browsingContext.locateNodes.""" params = { "context": context, @@ -740,7 +742,15 @@ def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any result = self._conn.execute(cmd) return result - def print(self, context: Any | None = None, background: bool | None = None, margin: Any | None = None, page: Any | None = None, scale: Any | None = None, shrink_to_fit: bool | None = None): + def print( + self, + context: Any | None = None, + background: bool | None = None, + margin: Any | None = None, + page: Any | None = None, + scale: Any | None = None, + shrink_to_fit: bool | None = None, + ): """Execute browsingContext.print.""" params = { "context": context, @@ -770,7 +780,13 @@ def reload(self, context: Any | None = None, ignore_cache: bool | None = None, w result = self._conn.execute(cmd) return result - def set_viewport(self, context: str | None = None, viewport: Any | None = None, user_contexts: Any | None = None, device_pixel_ratio: Any | None = None): + def set_viewport( + self, + context: str | None = None, + viewport: Any | None = None, + user_contexts: Any | None = None, + device_pixel_ratio: Any | None = None, + ): """Execute browsingContext.setViewport.""" params = { "context": context, @@ -868,20 +884,81 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() BrowsingContext.EVENT_CONFIGS = { - "context_created": (EventConfig("context_created", "browsingContext.contextCreated", _globals.get("ContextCreated", dict)) if _globals.get("ContextCreated") else EventConfig("context_created", "browsingContext.contextCreated", dict)), - "context_destroyed": (EventConfig("context_destroyed", "browsingContext.contextDestroyed", _globals.get("ContextDestroyed", dict)) if _globals.get("ContextDestroyed") else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict)), - "navigation_started": (EventConfig("navigation_started", "browsingContext.navigationStarted", _globals.get("NavigationStarted", dict)) if _globals.get("NavigationStarted") else EventConfig("navigation_started", "browsingContext.navigationStarted", dict)), - "fragment_navigated": (EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", _globals.get("FragmentNavigated", dict)) if _globals.get("FragmentNavigated") else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict)), - "history_updated": (EventConfig("history_updated", "browsingContext.historyUpdated", _globals.get("HistoryUpdated", dict)) if _globals.get("HistoryUpdated") else EventConfig("history_updated", "browsingContext.historyUpdated", dict)), - "dom_content_loaded": (EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", _globals.get("DomContentLoaded", dict)) if _globals.get("DomContentLoaded") else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict)), - "load": (EventConfig("load", "browsingContext.load", _globals.get("Load", dict)) if _globals.get("Load") else EventConfig("load", "browsingContext.load", dict)), - "download_will_begin": (EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBegin", dict)) if _globals.get("DownloadWillBegin") else EventConfig("download_will_begin", "browsingContext.downloadWillBegin", dict)), - "download_end": (EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEnd", dict)) if _globals.get("DownloadEnd") else EventConfig("download_end", "browsingContext.downloadEnd", dict)), - "navigation_aborted": (EventConfig("navigation_aborted", "browsingContext.navigationAborted", _globals.get("NavigationAborted", dict)) if _globals.get("NavigationAborted") else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict)), - "navigation_committed": (EventConfig("navigation_committed", "browsingContext.navigationCommitted", _globals.get("NavigationCommitted", dict)) if _globals.get("NavigationCommitted") else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict)), - "navigation_failed": (EventConfig("navigation_failed", "browsingContext.navigationFailed", _globals.get("NavigationFailed", dict)) if _globals.get("NavigationFailed") else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict)), - "user_prompt_closed": (EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", _globals.get("UserPromptClosed", dict)) if _globals.get("UserPromptClosed") else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict)), - "user_prompt_opened": (EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", _globals.get("UserPromptOpened", dict)) if _globals.get("UserPromptOpened") else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict)), - "download_will_begin": EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBeginParams", dict)), + "context_created": ( + EventConfig("context_created", "browsingContext.contextCreated", + _globals.get("ContextCreated", dict)) + if _globals.get("ContextCreated") + else EventConfig("context_created", "browsingContext.contextCreated", dict) + ), + "context_destroyed": ( + EventConfig("context_destroyed", "browsingContext.contextDestroyed", + _globals.get("ContextDestroyed", dict)) + if _globals.get("ContextDestroyed") + else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict) + ), + "navigation_started": ( + EventConfig("navigation_started", "browsingContext.navigationStarted", + _globals.get("NavigationStarted", dict)) + if _globals.get("NavigationStarted") + else EventConfig("navigation_started", "browsingContext.navigationStarted", dict) + ), + "fragment_navigated": ( + EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", + _globals.get("FragmentNavigated", dict)) + if _globals.get("FragmentNavigated") + else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict) + ), + "history_updated": ( + EventConfig("history_updated", "browsingContext.historyUpdated", + _globals.get("HistoryUpdated", dict)) + if _globals.get("HistoryUpdated") + else EventConfig("history_updated", "browsingContext.historyUpdated", dict) + ), + "dom_content_loaded": ( + EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", + _globals.get("DomContentLoaded", dict)) + if _globals.get("DomContentLoaded") + else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict) + ), + "load": ( + EventConfig("load", "browsingContext.load", + _globals.get("Load", dict)) + if _globals.get("Load") + else EventConfig("load", "browsingContext.load", dict) + ), + "navigation_aborted": ( + EventConfig("navigation_aborted", "browsingContext.navigationAborted", + _globals.get("NavigationAborted", dict)) + if _globals.get("NavigationAborted") + else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict) + ), + "navigation_committed": ( + EventConfig("navigation_committed", "browsingContext.navigationCommitted", + _globals.get("NavigationCommitted", dict)) + if _globals.get("NavigationCommitted") + else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict) + ), + "navigation_failed": ( + EventConfig("navigation_failed", "browsingContext.navigationFailed", + _globals.get("NavigationFailed", dict)) + if _globals.get("NavigationFailed") + else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict) + ), + "user_prompt_closed": ( + EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", + _globals.get("UserPromptClosed", dict)) + if _globals.get("UserPromptClosed") + else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict) + ), + "user_prompt_opened": ( + EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", + _globals.get("UserPromptOpened", dict)) + if _globals.get("UserPromptOpened") + else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict) + ), + "download_will_begin": EventConfig( + "download_will_begin", "browsingContext.downloadWillBegin", + _globals.get("DownloadWillBeginParams", dict), + ), "download_end": EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEndParams", dict)), } diff --git a/py/selenium/webdriver/common/bidi/cdp.py b/py/selenium/webdriver/common/bidi/cdp.py deleted file mode 100644 index 38dcf8d803ea3..0000000000000 --- a/py/selenium/webdriver/common/bidi/cdp.py +++ /dev/null @@ -1,515 +0,0 @@ -# The MIT License(MIT) -# -# Copyright(c) 2018 Hyperion Gray -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files(the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -# -# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp - -import contextvars -import importlib -import itertools -import json -import logging -import pathlib -from collections import defaultdict -from collections.abc import AsyncGenerator, AsyncIterator, Generator -from contextlib import asynccontextmanager, contextmanager -from dataclasses import dataclass -from typing import Any, TypeVar - -import trio -from trio_websocket import ConnectionClosed as WsConnectionClosed -from trio_websocket import connect_websocket_url - -logger = logging.getLogger("trio_cdp") -T = TypeVar("T") -MAX_WS_MESSAGE_SIZE = 2**24 - -devtools = None -version = None - - -def import_devtools(ver): - """Attempt to load the current latest available devtools into the module cache for use later.""" - global devtools - global version - version = ver - base = "selenium.webdriver.common.devtools.v" - try: - devtools = importlib.import_module(f"{base}{ver}") - return devtools - except ModuleNotFoundError: - # Attempt to parse and load the 'most recent' devtools module. This is likely - # because cdp has been updated but selenium python has not been released yet. - devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") - versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir() and f.name != "latest") - latest = max(int(x[1:]) for x in versions) - selenium_logger = logging.getLogger(__name__) - selenium_logger.debug("Falling back to loading `devtools`: v%s", latest) - devtools = importlib.import_module(f"{base}{latest}") - return devtools - - -_connection_context: contextvars.ContextVar = contextvars.ContextVar("connection_context") -_session_context: contextvars.ContextVar = contextvars.ContextVar("session_context") - - -def get_connection_context(fn_name): - """Look up the current connection. - - If there is no current connection, raise a ``RuntimeError`` with a - helpful message. - """ - try: - return _connection_context.get() - except LookupError: - raise RuntimeError(f"{fn_name}() must be called in a connection context.") - - -def get_session_context(fn_name): - """Look up the current session. - - If there is no current session, raise a ``RuntimeError`` with a - helpful message. - """ - try: - return _session_context.get() - except LookupError: - raise RuntimeError(f"{fn_name}() must be called in a session context.") - - -@contextmanager -def connection_context(connection): - """Context manager installs ``connection`` as the session context for the current Trio task.""" - token = _connection_context.set(connection) - try: - yield - finally: - _connection_context.reset(token) - - -@contextmanager -def session_context(session): - """Context manager installs ``session`` as the session context for the current Trio task.""" - token = _session_context.set(session) - try: - yield - finally: - _session_context.reset(token) - - -def set_global_connection(connection): - """Install ``connection`` in the root context so that it will become the default connection for all tasks. - - This is generally not recommended, except it may be necessary in - certain use cases such as running inside Jupyter notebook. - """ - global _connection_context - _connection_context = contextvars.ContextVar("_connection_context", default=connection) - - -def set_global_session(session): - """Install ``session`` in the root context so that it will become the default session for all tasks. - - This is generally not recommended, except it may be necessary in - certain use cases such as running inside Jupyter notebook. - """ - global _session_context - _session_context = contextvars.ContextVar("_session_context", default=session) - - -class BrowserError(Exception): - """This exception is raised when the browser's response to a command indicates that an error occurred.""" - - def __init__(self, obj): - self.code = obj.get("code") - self.message = obj.get("message") - self.detail = obj.get("data") - - def __str__(self): - return f"BrowserError {self.detail}" - - -class CdpConnectionClosed(WsConnectionClosed): - """Raised when a public method is called on a closed CDP connection.""" - - def __init__(self, reason): - """Constructor. - - Args: - reason: wsproto.frame_protocol.CloseReason - """ - self.reason = reason - - def __repr__(self): - """Return representation.""" - return f"{self.__class__.__name__}<{self.reason}>" - - -class InternalError(Exception): - """This exception is only raised when there is faulty logic in TrioCDP or the integration with PyCDP.""" - - pass - - -@dataclass -class CmEventProxy: - """A proxy object returned by :meth:`CdpBase.wait_for()``. - - After the context manager executes, this proxy object will have a - value set that contains the returned event. - """ - - value: Any = None - - -class CdpBase: - def __init__(self, ws, session_id, target_id): - self.ws = ws - self.session_id = session_id - self.target_id = target_id - self.channels = defaultdict(set) - self.id_iter = itertools.count() - self.inflight_cmd = {} - self.inflight_result = {} - - async def execute(self, cmd: Generator[dict, T, Any]) -> T: - """Execute a command on the server and wait for the result. - - Args: - cmd: any CDP command - - Returns: - a CDP result - """ - cmd_id = next(self.id_iter) - cmd_event = trio.Event() - self.inflight_cmd[cmd_id] = cmd, cmd_event - request = next(cmd) - request["id"] = cmd_id - if self.session_id: - request["sessionId"] = self.session_id - request_str = json.dumps(request) - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f"Sending CDP message: {cmd_id} {cmd_event}: {request_str}") - try: - await self.ws.send_message(request_str) - except WsConnectionClosed as wcc: - raise CdpConnectionClosed(wcc.reason) from None - await cmd_event.wait() - response = self.inflight_result.pop(cmd_id) - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f"Received CDP message: {response}") - if isinstance(response, Exception): - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f"Exception raised by {cmd_event} message: {type(response).__name__}") - raise response - return response - - def listen(self, *event_types, buffer_size=10): - """Listen for events. - - Returns: - An async iterator that iterates over events matching the indicated types. - """ - sender, receiver = trio.open_memory_channel(buffer_size) - for event_type in event_types: - self.channels[event_type].add(sender) - return receiver - - @asynccontextmanager - async def wait_for(self, event_type: type[T], buffer_size=10) -> AsyncGenerator[CmEventProxy, None]: - """Wait for an event of the given type and return it. - - This is an async context manager, so you should open it inside - an async with block. The block will not exit until the indicated - event is received. - """ - sender: trio.MemorySendChannel - receiver: trio.MemoryReceiveChannel - sender, receiver = trio.open_memory_channel(buffer_size) - self.channels[event_type].add(sender) - proxy = CmEventProxy() - yield proxy - async with receiver: - event = await receiver.receive() - proxy.value = event - - def _handle_data(self, data): - """Handle incoming WebSocket data. - - Args: - data: a JSON dictionary - """ - if "id" in data: - self._handle_cmd_response(data) - else: - self._handle_event(data) - - def _handle_cmd_response(self, data: dict): - """Handle a response to a command. - - This will set an event flag that will return control to the - task that called the command. - - Args: - data: response as a JSON dictionary - """ - cmd_id = data["id"] - try: - cmd, event = self.inflight_cmd.pop(cmd_id) - except KeyError: - logger.warning("Got a message with a command ID that does not exist: %s", data) - return - if "error" in data: - # If the server reported an error, convert it to an exception and do - # not process the response any further. - self.inflight_result[cmd_id] = BrowserError(data["error"]) - else: - # Otherwise, continue the generator to parse the JSON result - # into a CDP object. - try: - _ = cmd.send(data["result"]) - raise InternalError("The command's generator function did not exit when expected!") - except StopIteration as exit: - return_ = exit.value - self.inflight_result[cmd_id] = return_ - event.set() - - def _handle_event(self, data: dict): - """Handle an event. - - Args: - data: event as a JSON dictionary - """ - global devtools - if devtools is None: - raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") - event = devtools.util.parse_json_event(data) - logger.debug("Received event: %s", event) - to_remove = set() - for sender in self.channels[type(event)]: - try: - sender.send_nowait(event) - except trio.WouldBlock: - logger.error('Unable to send event "%r" due to full channel %s', event, sender) - except trio.BrokenResourceError: - to_remove.add(sender) - if to_remove: - self.channels[type(event)] -= to_remove - - -class CdpSession(CdpBase): - """Contains the state for a CDP session. - - Generally you should not instantiate this object yourself; you should call - :meth:`CdpConnection.open_session`. - """ - - def __init__(self, ws, session_id, target_id): - """Constructor. - - Args: - ws: trio_websocket.WebSocketConnection - session_id: devtools.target.SessionID - target_id: devtools.target.TargetID - """ - super().__init__(ws, session_id, target_id) - - self._dom_enable_count = 0 - self._dom_enable_lock = trio.Lock() - self._page_enable_count = 0 - self._page_enable_lock = trio.Lock() - - @asynccontextmanager - async def dom_enable(self): - """Context manager that executes ``dom.enable()`` when it enters and then calls ``dom.disable()``. - - This keeps track of concurrent callers and only disables DOM - events when all callers have exited. - """ - global devtools - async with self._dom_enable_lock: - self._dom_enable_count += 1 - if self._dom_enable_count == 1: - await self.execute(devtools.dom.enable()) - - yield - - async with self._dom_enable_lock: - self._dom_enable_count -= 1 - if self._dom_enable_count == 0: - await self.execute(devtools.dom.disable()) - - @asynccontextmanager - async def page_enable(self): - """Context manager executes ``page.enable()`` when it enters and then calls ``page.disable()`` when it exits. - - This keeps track of concurrent callers and only disables page - events when all callers have exited. - """ - global devtools - async with self._page_enable_lock: - self._page_enable_count += 1 - if self._page_enable_count == 1: - await self.execute(devtools.page.enable()) - - yield - - async with self._page_enable_lock: - self._page_enable_count -= 1 - if self._page_enable_count == 0: - await self.execute(devtools.page.disable()) - - -class CdpConnection(CdpBase, trio.abc.AsyncResource): - """Contains the connection state for a Chrome DevTools Protocol server. - - CDP can multiplex multiple "sessions" over a single connection. This - class corresponds to the "root" session, i.e. the implicitly created - session that has no session ID. This class is responsible for - reading incoming WebSocket messages and forwarding them to the - corresponding session, as well as handling messages targeted at the - root session itself. You should generally call the - :func:`open_cdp()` instead of instantiating this class directly. - """ - - def __init__(self, ws): - """Constructor. - - Args: - ws: trio_websocket.WebSocketConnection - """ - super().__init__(ws, session_id=None, target_id=None) - self.sessions = {} - - async def aclose(self): - """Close the underlying WebSocket connection. - - This will cause the reader task to gracefully exit when it tries - to read the next message from the WebSocket. All of the public - APIs (``execute()``, ``listen()``, etc.) will raise - ``CdpConnectionClosed`` after the CDP connection is closed. It - is safe to call this multiple times. - """ - await self.ws.aclose() - - @asynccontextmanager - async def open_session(self, target_id) -> AsyncIterator[CdpSession]: - """Context manager opens a session and enables the "simple" style of calling CDP APIs. - - For example, inside a session context, you can call ``await - dom.get_document()`` and it will execute on the current session - automatically. - """ - session = await self.connect_session(target_id) - with session_context(session): - yield session - - async def connect_session(self, target_id) -> "CdpSession": - """Returns a new :class:`CdpSession` connected to the specified target.""" - global devtools - if devtools is None: - raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") - session_id = await self.execute(devtools.target.attach_to_target(target_id, True)) - session = CdpSession(self.ws, session_id, target_id) - self.sessions[session_id] = session - return session - - async def _reader_task(self): - """Runs in the background and handles incoming messages. - - Dispatches responses to commands and events to listeners. - """ - global devtools - if devtools is None: - raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") - while True: - try: - message = await self.ws.get_message() - except WsConnectionClosed: - # If the WebSocket is closed, we don't want to throw an - # exception from the reader task. Instead we will throw - # exceptions from the public API methods, and we can quietly - # exit the reader task here. - break - try: - data = json.loads(message) - except json.JSONDecodeError: - raise BrowserError({"code": -32700, "message": "Client received invalid JSON", "data": message}) - logger.debug("Received message %r", data) - if "sessionId" in data: - session_id = devtools.target.SessionID(data["sessionId"]) - try: - session = self.sessions[session_id] - except KeyError: - raise BrowserError( - { - "code": -32700, - "message": "Browser sent a message for an invalid session", - "data": f"{session_id!r}", - } - ) - session._handle_data(data) - else: - self._handle_data(data) - - for _, session in self.sessions.items(): - for _, senders in session.channels.items(): - for sender in senders: - sender.close() - - -@asynccontextmanager -async def open_cdp(url) -> AsyncIterator[CdpConnection]: - """Async context manager opens a connection to the browser then closes the connection when the block exits. - - The context manager also sets the connection as the default - connection for the current task, so that commands like ``await - target.get_targets()`` will run on this connection automatically. If - you want to use multiple connections concurrently, it is recommended - to open each on in a separate task. - """ - async with trio.open_nursery() as nursery: - conn = await connect_cdp(nursery, url) - try: - with connection_context(conn): - yield conn - finally: - await conn.aclose() - - -async def connect_cdp(nursery, url) -> CdpConnection: - """Connect to the browser specified by ``url`` and spawn a background task in the specified nursery. - - The ``open_cdp()`` context manager is preferred in most situations. - You should only use this function if you need to specify a custom - nursery. This connection is not automatically closed! You can either - use the connection object as a context manager (``async with - conn:``) or else call ``await conn.aclose()`` on it when you are - done with it. If ``set_context`` is True, then the returned - connection will be installed as the default connection for the - current task. This argument is for unusual use cases, such as - running inside of a notebook. - """ - ws = await connect_websocket_url(nursery, url, max_message_size=MAX_WS_MESSAGE_SIZE) - cdp_conn = CdpConnection(ws) - nursery.start_soon(cdp_conn._reader_task) - return cdp_conn diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index d90d8c770263a..d7cb436a08471 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -17,12 +17,13 @@ """Common utilities for BiDi command construction.""" -from typing import Any, Dict, Generator +from collections.abc import Generator +from typing import Any def command_builder( - method: str, params: Dict[str, Any] -) -> Generator[Dict[str, Any], Any, Any]: + method: str, params: dict[str, Any] +) -> Generator[dict[str, Any], Any, Any]: """Build a BiDi command generator. Args: diff --git a/py/selenium/webdriver/common/bidi/console.py b/py/selenium/webdriver/common/bidi/console.py old mode 100755 new mode 100644 diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 4cd6ae2e3c712..cb575bbdc54dd 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: emulation from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass class ForcedColorsModeTheme: @@ -41,16 +40,16 @@ class SetForcedColorsModeThemeOverrideParameters: """SetForcedColorsModeThemeOverrideParameters.""" theme: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass class SetGeolocationOverrideParameters: """SetGeolocationOverrideParameters.""" - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -78,8 +77,8 @@ class SetLocaleOverrideParameters: """SetLocaleOverrideParameters.""" locale: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -87,8 +86,8 @@ class setNetworkConditionsParameters: """setNetworkConditionsParameters.""" network_conditions: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -111,8 +110,8 @@ class SetScreenSettingsOverrideParameters: """SetScreenSettingsOverrideParameters.""" screen_area: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -128,8 +127,8 @@ class SetScreenOrientationOverrideParameters: """SetScreenOrientationOverrideParameters.""" screen_orientation: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -137,8 +136,8 @@ class SetUserAgentOverrideParameters: """SetUserAgentOverrideParameters.""" user_agent: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -146,8 +145,8 @@ class SetViewportMetaOverrideParameters: """SetViewportMetaOverrideParameters.""" viewport_meta: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -155,8 +154,8 @@ class SetScriptingEnabledParameters: """SetScriptingEnabledParameters.""" enabled: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -164,8 +163,8 @@ class SetScrollbarTypeOverrideParameters: """SetScrollbarTypeOverrideParameters.""" scrollbar_type: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -173,16 +172,16 @@ class SetTimezoneOverrideParameters: """SetTimezoneOverrideParameters.""" timezone: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass class SetTouchOverrideParameters: """SetTouchOverrideParameters.""" - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) class Emulation: @@ -191,7 +190,12 @@ class Emulation: def __init__(self, conn) -> None: self._conn = conn - def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_forced_colors_mode_theme_override( + self, + theme: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute emulation.setForcedColorsModeThemeOverride.""" params = { "theme": theme, @@ -203,18 +207,12 @@ def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contex result = self._conn.execute(cmd) return result - def set_geolocation_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setGeolocationOverride.""" - params = { - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setGeolocationOverride", params) - result = self._conn.execute(cmd) - return result - - def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_locale_override( + self, + locale: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute emulation.setLocaleOverride.""" params = { "locale": locale, @@ -226,19 +224,12 @@ def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | N result = self._conn.execute(cmd) return result - def set_network_conditions(self, network_conditions: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setNetworkConditions.""" - params = { - "networkConditions": network_conditions, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setNetworkConditions", params) - result = self._conn.execute(cmd) - return result - - def set_screen_settings_override(self, screen_area: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_screen_settings_override( + self, + screen_area: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute emulation.setScreenSettingsOverride.""" params = { "screenArea": screen_area, @@ -250,31 +241,12 @@ def set_screen_settings_override(self, screen_area: Any | None = None, contexts: result = self._conn.execute(cmd) return result - def set_screen_orientation_override(self, screen_orientation: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setScreenOrientationOverride.""" - params = { - "screenOrientation": screen_orientation, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setScreenOrientationOverride", params) - result = self._conn.execute(cmd) - return result - - def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setUserAgentOverride.""" - params = { - "userAgent": user_agent, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setUserAgentOverride", params) - result = self._conn.execute(cmd) - return result - - def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_viewport_meta_override( + self, + viewport_meta: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute emulation.setViewportMetaOverride.""" params = { "viewportMeta": viewport_meta, @@ -286,19 +258,12 @@ def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: result = self._conn.execute(cmd) return result - def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setScriptingEnabled.""" - params = { - "enabled": enabled, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setScriptingEnabled", params) - result = self._conn.execute(cmd) - return result - - def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_scrollbar_type_override( + self, + scrollbar_type: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute emulation.setScrollbarTypeOverride.""" params = { "scrollbarType": scrollbar_type, @@ -310,19 +275,7 @@ def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, context result = self._conn.execute(cmd) return result - def set_timezone_override(self, timezone: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setTimezoneOverride.""" - params = { - "timezone": timezone, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setTimezoneOverride", params) - result = self._conn.execute(cmd) - return result - - def set_touch_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_touch_override(self, contexts: list[Any] | None = None, user_contexts: list[Any] | None = None): """Execute emulation.setTouchOverride.""" params = { "contexts": contexts, @@ -337,8 +290,8 @@ def set_geolocation_override( self, coordinates=None, error=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setGeolocationOverride. @@ -390,8 +343,8 @@ def set_geolocation_override( def set_timezone_override( self, timezone=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setTimezoneOverride. @@ -414,8 +367,8 @@ def set_timezone_override( def set_scripting_enabled( self, enabled=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScriptingEnabled. @@ -438,8 +391,8 @@ def set_scripting_enabled( def set_user_agent_override( self, user_agent=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setUserAgentOverride. @@ -461,8 +414,8 @@ def set_user_agent_override( def set_screen_orientation_override( self, screen_orientation=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScreenOrientationOverride. @@ -498,8 +451,8 @@ def set_network_conditions( self, network_conditions=None, offline: bool | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setNetworkConditions. diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 5dbe71dbd3886..13f43361293f2 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: input from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class PointerType: """PointerType.""" @@ -45,7 +44,7 @@ class PerformActionsParameters: """PerformActionsParameters.""" context: Any | None = None - actions: list[Any | None] | None = None + actions: list[Any | None] | None = field(default_factory=list) @dataclass @@ -54,7 +53,7 @@ class NoneSourceActions: type: str = field(default="none", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any | None] | None = field(default_factory=list) @dataclass @@ -63,7 +62,7 @@ class KeySourceActions: type: str = field(default="key", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any | None] | None = field(default_factory=list) @dataclass @@ -73,7 +72,7 @@ class PointerSourceActions: type: str = field(default="pointer", init=False) id: str | None = None parameters: Any | None = None - actions: list[Any | None] | None = None + actions: list[Any | None] | None = field(default_factory=list) @dataclass @@ -89,7 +88,7 @@ class WheelSourceActions: type: str = field(default="wheel", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any | None] | None = field(default_factory=list) @dataclass @@ -163,7 +162,7 @@ class SetFilesParameters: context: Any | None = None element: Any | None = None - files: list[Any | None] | None = None + files: list[Any | None] | None = field(default_factory=list) @dataclass @@ -175,7 +174,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> "FileDialogInfo": + def from_json(cls, params: dict) -> FileDialogInfo: """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), @@ -368,7 +367,7 @@ def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): + def perform_actions(self, context: Any | None = None, actions: list[Any] | None = None): """Execute input.performActions.""" params = { "context": context, @@ -389,7 +388,7 @@ def release_actions(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): + def set_files(self, context: Any | None = None, element: Any | None = None, files: list[Any] | None = None): """Execute input.setFiles.""" params = { "context": context, @@ -454,5 +453,10 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Input.EVENT_CONFIGS = { - "file_dialog_opened": (EventConfig("file_dialog_opened", "input.fileDialogOpened", _globals.get("FileDialogOpened", dict)) if _globals.get("FileDialogOpened") else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict)), + "file_dialog_opened": ( + EventConfig("file_dialog_opened", "input.fileDialogOpened", + _globals.get("FileDialogOpened", dict)) + if _globals.get("FileDialogOpened") + else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict) + ), } diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index faf6c85ae2b6c..7971b807e94a1 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -6,11 +6,12 @@ # WebDriver BiDi module: log from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator +import threading +from collections.abc import Callable from dataclasses import dataclass +from typing import Any + +from selenium.webdriver.common.bidi.session import Session class Level: @@ -56,7 +57,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "ConsoleLogEntry": + def from_json(cls, params: dict) -> ConsoleLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -81,7 +82,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "JavascriptLogEntry": + def from_json(cls, params: dict) -> JavascriptLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -92,18 +93,212 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": stacktrace=params.get("stackTrace"), ) +# BiDi Event Name to Parameter Type Mapping +EVENT_NAME_MAPPING = { + "entry_added": "log.entryAdded", +} + +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + event_key: str + bidi_event: str + event_class: type + + +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + """ + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + + + class Log: """WebDriver BiDi log module.""" + EVENT_CONFIGS = {} def __init__(self, conn) -> None: self._conn = conn + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) + + pass + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + """Add an event handler. + + Args: + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). + + Returns: + The callback ID. + """ + return self._event_manager.add_event_handler(event, callback, contexts) + + def remove_event_handler(self, event: str, callback_id: int) -> None: + """Remove an event handler. + + Args: + event: The event to unsubscribe from. + callback_id: The callback ID. + """ + return self._event_manager.remove_event_handler(event, callback_id) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + return self._event_manager.clear_event_handlers() + +# Event Info Type Aliases +# Event: log.entryAdded +EntryAdded = globals().get('Entry', dict) # Fallback to dict if type not defined - def entry_added(self): - """Execute log.entryAdded.""" - params = { - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("log.entryAdded", params) - result = self._conn.execute(cmd) - return result +# Populate EVENT_CONFIGS with event configuration mappings +_globals = globals() +Log.EVENT_CONFIGS = { + "entry_added": ( + EventConfig("entry_added", "log.entryAdded", + _globals.get("EntryAdded", dict)) + if _globals.get("EntryAdded") + else EventConfig("entry_added", "log.entryAdded", dict) + ), +} diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 4f44e309bffbb..6e02eeabc4ed7 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: network from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class SameSite: """SameSite.""" @@ -75,7 +74,7 @@ class BaseParameters: redirect_count: Any | None = None request: Any | None = None timestamp: Any | None = None - intercepts: list[Any | None] | None = None + intercepts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -171,13 +170,13 @@ class ResponseData: status: Any | None = None status_text: str | None = None from_cache: bool | None = None - headers: list[Any | None] | None = None + headers: list[Any | None] | None = field(default_factory=list) mime_type: str | None = None bytes_received: Any | None = None headers_size: Any | None = None body_size: Any | None = None content: Any | None = None - auth_challenges: list[Any | None] | None = None + auth_challenges: list[Any | None] | None = field(default_factory=list) @dataclass @@ -219,11 +218,11 @@ class UrlPatternString: class AddDataCollectorParameters: """AddDataCollectorParameters.""" - data_types: list[Any | None] | None = None + data_types: list[Any | None] | None = field(default_factory=list) max_encoded_data_size: Any | None = None collector_type: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -237,9 +236,9 @@ class AddDataCollectorResult: class AddInterceptParameters: """AddInterceptParameters.""" - phases: list[Any | None] | None = None - contexts: list[Any | None] | None = None - url_patterns: list[Any | None] | None = None + phases: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = field(default_factory=list) + url_patterns: list[Any | None] | None = field(default_factory=list) @dataclass @@ -254,9 +253,9 @@ class ContinueResponseParameters: """ContinueResponseParameters.""" request: Any | None = None - cookies: list[Any | None] | None = None + cookies: list[Any | None] | None = field(default_factory=list) credentials: Any | None = None - headers: list[Any | None] | None = None + headers: list[Any | None] | None = field(default_factory=list) reason_phrase: str | None = None status_code: Any | None = None @@ -315,8 +314,8 @@ class ProvideResponseParameters: request: Any | None = None body: Any | None = None - cookies: list[Any | None] | None = None - headers: list[Any | None] | None = None + cookies: list[Any | None] | None = field(default_factory=list) + headers: list[Any | None] | None = field(default_factory=list) reason_phrase: str | None = None status_code: Any | None = None @@ -340,16 +339,16 @@ class SetCacheBehaviorParameters: """SetCacheBehaviorParameters.""" cache_behavior: Any | None = None - contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) @dataclass class SetExtraHeadersParameters: """SetExtraHeadersParameters.""" - headers: list[Any | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + headers: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -562,7 +561,14 @@ def __init__(self, conn) -> None: self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) self.intercepts = [] - def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def add_data_collector( + self, + data_types: list[Any] | None = None, + max_encoded_data_size: Any | None = None, + collector_type: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute network.addDataCollector.""" params = { "dataTypes": data_types, @@ -576,7 +582,12 @@ def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_da result = self._conn.execute(cmd) return result - def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | None = None, url_patterns: List[Any] | None = None): + def add_intercept( + self, + phases: list[Any] | None = None, + contexts: list[Any] | None = None, + url_patterns: list[Any] | None = None, + ): """Execute network.addIntercept.""" params = { "phases": phases, @@ -588,7 +599,15 @@ def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | N result = self._conn.execute(cmd) return result - def continue_request(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, method: Any | None = None, url: Any | None = None): + def continue_request( + self, + request: Any | None = None, + body: Any | None = None, + cookies: list[Any] | None = None, + headers: list[Any] | None = None, + method: Any | None = None, + url: Any | None = None, + ): """Execute network.continueRequest.""" params = { "request": request, @@ -603,7 +622,15 @@ def continue_request(self, request: Any | None = None, body: Any | None = None, result = self._conn.execute(cmd) return result - def continue_response(self, request: Any | None = None, cookies: List[Any] | None = None, credentials: Any | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): + def continue_response( + self, + request: Any | None = None, + cookies: list[Any] | None = None, + credentials: Any | None = None, + headers: list[Any] | None = None, + reason_phrase: Any | None = None, + status_code: Any | None = None, + ): """Execute network.continueResponse.""" params = { "request": request, @@ -650,7 +677,13 @@ def fail_request(self, request: Any | None = None): result = self._conn.execute(cmd) return result - def get_data(self, data_type: Any | None = None, collector: Any | None = None, disown: bool | None = None, request: Any | None = None): + def get_data( + self, + data_type: Any | None = None, + collector: Any | None = None, + disown: bool | None = None, + request: Any | None = None, + ): """Execute network.getData.""" params = { "dataType": data_type, @@ -663,7 +696,15 @@ def get_data(self, data_type: Any | None = None, collector: Any | None = None, d result = self._conn.execute(cmd) return result - def provide_response(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): + def provide_response( + self, + request: Any | None = None, + body: Any | None = None, + cookies: list[Any] | None = None, + headers: list[Any] | None = None, + reason_phrase: Any | None = None, + status_code: Any | None = None, + ): """Execute network.provideResponse.""" params = { "request": request, @@ -698,7 +739,7 @@ def remove_intercept(self, intercept: Any | None = None): result = self._conn.execute(cmd) return result - def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): + def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[Any] | None = None): """Execute network.setCacheBehavior.""" params = { "cacheBehavior": cache_behavior, @@ -709,7 +750,12 @@ def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[A result = self._conn.execute(cmd) return result - def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_extra_headers( + self, + headers: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute network.setExtraHeaders.""" params = { "headers": headers, @@ -918,6 +964,11 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Network.EVENT_CONFIGS = { - "auth_required": (EventConfig("auth_required", "network.authRequired", _globals.get("AuthRequired", dict)) if _globals.get("AuthRequired") else EventConfig("auth_required", "network.authRequired", dict)), + "auth_required": ( + EventConfig("auth_required", "network.authRequired", + _globals.get("AuthRequired", dict)) + if _globals.get("AuthRequired") + else EventConfig("auth_required", "network.authRequired", dict) + ), "before_request": EventConfig("before_request", "network.beforeRequestSent", _globals.get("dict", dict)), } diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py index f00e765c62e3b..6dd138da17309 100644 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ b/py/selenium/webdriver/common/bidi/permissions.py @@ -20,7 +20,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Optional, Union +from typing import Any from .common import command_builder @@ -63,10 +63,10 @@ def __init__(self, websocket_connection: Any) -> None: def set_permission( self, - descriptor: Union[PermissionDescriptor, str], - state: Union[PermissionState, str], - origin: Optional[str] = None, - user_context: Optional[str] = None, + descriptor: PermissionDescriptor | str, + state: PermissionState | str, + origin: str | None = None, + user_context: str | None = None, ) -> None: """Set a permission for a given origin. diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index e13c11f71a5cb..b29721db88503 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: script from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class SpecialNumber: """SpecialNumber.""" @@ -216,7 +215,7 @@ class DedicatedWorkerRealmInfo: """DedicatedWorkerRealmInfo.""" type: str = field(default="dedicated-worker", init=False) - owners: list[Any | None] | None = None + owners: list[Any | None] | None = field(default_factory=list) @dataclass @@ -460,7 +459,7 @@ class NodeProperties: node_type: Any | None = None child_node_count: Any | None = None - children: list[Any | None] | None = None + children: list[Any | None] | None = field(default_factory=list) local_name: str | None = None mode: Any | None = None namespace_uri: str | None = None @@ -499,7 +498,7 @@ class StackFrame: class StackTrace: """StackTrace.""" - call_frames: list[Any | None] | None = None + call_frames: list[Any | None] | None = field(default_factory=list) @dataclass @@ -530,9 +529,9 @@ class AddPreloadScriptParameters: """AddPreloadScriptParameters.""" function_declaration: str | None = None - arguments: list[Any | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + arguments: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) sandbox: str | None = None @@ -547,7 +546,7 @@ class AddPreloadScriptResult: class DisownParameters: """DisownParameters.""" - handles: list[Any | None] | None = None + handles: list[Any | None] | None = field(default_factory=list) target: Any | None = None @@ -558,7 +557,7 @@ class CallFunctionParameters: function_declaration: str | None = None await_promise: bool | None = None target: Any | None = None - arguments: list[Any | None] | None = None + arguments: list[Any | None] | None = field(default_factory=list) result_ownership: Any | None = None serialization_options: Any | None = None this: Any | None = None @@ -589,7 +588,7 @@ class GetRealmsParameters: class GetRealmsResult: """GetRealmsResult.""" - realms: list[Any | None] | None = None + realms: list[Any | None] | None = field(default_factory=list) @dataclass @@ -783,7 +782,14 @@ def __init__(self, conn, driver=None) -> None: self._driver = driver self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def add_preload_script(self, function_declaration: Any | None = None, arguments: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None, sandbox: Any | None = None): + def add_preload_script( + self, + function_declaration: Any | None = None, + arguments: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + sandbox: Any | None = None, + ): """Execute script.addPreloadScript.""" params = { "functionDeclaration": function_declaration, @@ -797,7 +803,7 @@ def add_preload_script(self, function_declaration: Any | None = None, arguments: result = self._conn.execute(cmd) return result - def disown(self, handles: List[Any] | None = None, target: Any | None = None): + def disown(self, handles: list[Any] | None = None, target: Any | None = None): """Execute script.disown.""" params = { "handles": handles, @@ -808,7 +814,17 @@ def disown(self, handles: List[Any] | None = None, target: Any | None = None): result = self._conn.execute(cmd) return result - def call_function(self, function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, arguments: List[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, user_activation: bool | None = None): + def call_function( + self, + function_declaration: Any | None = None, + await_promise: bool | None = None, + target: Any | None = None, + arguments: list[Any] | None = None, + result_ownership: Any | None = None, + serialization_options: Any | None = None, + this: Any | None = None, + user_activation: bool | None = None, + ): """Execute script.callFunction.""" params = { "functionDeclaration": function_declaration, @@ -825,7 +841,15 @@ def call_function(self, function_declaration: Any | None = None, await_promise: result = self._conn.execute(cmd) return result - def evaluate(self, expression: Any | None = None, target: Any | None = None, await_promise: bool | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, user_activation: bool | None = None): + def evaluate( + self, + expression: Any | None = None, + target: Any | None = None, + await_promise: bool | None = None, + result_ownership: Any | None = None, + serialization_options: Any | None = None, + user_activation: bool | None = None, + ): """Execute script.evaluate.""" params = { "expression": expression, @@ -889,8 +913,9 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import math as _math import datetime as _datetime + import math as _math + from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -941,7 +966,14 @@ def _serialize_arg(value): if raw.get("type") == "success": return raw.get("result") return raw - def _add_preload_script(self, function_declaration, arguments=None, contexts=None, user_contexts=None, sandbox=None): + def _add_preload_script( + self, + function_declaration, + arguments=None, + contexts=None, + user_contexts=None, + sandbox=None, + ): """Add a preload script with validation. Args: @@ -993,7 +1025,15 @@ def unpin(self, script_id): script_id: The ID returned by pin(). """ return self._remove_preload_script(script_id=script_id) - def _evaluate(self, expression, target, await_promise, result_ownership=None, serialization_options=None, user_activation=None): + def _evaluate( + self, + expression, + target, + await_promise, + result_ownership=None, + serialization_options=None, + user_activation=None, + ): """Evaluate a script expression and return a structured result. Args: @@ -1028,7 +1068,17 @@ def __init__(self2, realm, result, exception_details): return _EvalResult(realm=realm, result=None, exception_details=exc) return _EvalResult(realm=realm, result=raw.get("result"), exception_details=None) return _EvalResult(realm=None, result=raw, exception_details=None) - def _call_function(self, function_declaration, await_promise, target, arguments=None, result_ownership=None, this=None, user_activation=None, serialization_options=None): + def _call_function( + self, + function_declaration, + await_promise, + target, + arguments=None, + result_ownership=None, + this=None, + user_activation=None, + serialization_options=None, + ): """Call a function and return a structured result. Args: @@ -1106,8 +1156,9 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + from selenium.webdriver.common.bidi.session import Session as _Session bidi_event = "log.entryAdded" @@ -1257,6 +1308,16 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Script.EVENT_CONFIGS = { - "realm_created": (EventConfig("realm_created", "script.realmCreated", _globals.get("RealmCreated", dict)) if _globals.get("RealmCreated") else EventConfig("realm_created", "script.realmCreated", dict)), - "realm_destroyed": (EventConfig("realm_destroyed", "script.realmDestroyed", _globals.get("RealmDestroyed", dict)) if _globals.get("RealmDestroyed") else EventConfig("realm_destroyed", "script.realmDestroyed", dict)), + "realm_created": ( + EventConfig("realm_created", "script.realmCreated", + _globals.get("RealmCreated", dict)) + if _globals.get("RealmCreated") + else EventConfig("realm_created", "script.realmCreated", dict) + ), + "realm_destroyed": ( + EventConfig("realm_destroyed", "script.realmDestroyed", + _globals.get("RealmDestroyed", dict)) + if _globals.get("RealmDestroyed") + else EventConfig("realm_destroyed", "script.realmDestroyed", dict) + ), } diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 9b1daaae557fa..c1b5be09ca024 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: session from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass class UserPromptHandlerType: @@ -26,7 +25,7 @@ class CapabilitiesRequest: """CapabilitiesRequest.""" always_match: Any | None = None - first_match: list[Any | None] | None = None + first_match: list[Any | None] | None = field(default_factory=list) @dataclass @@ -62,7 +61,7 @@ class ManualProxyConfiguration: proxy_type: str = field(default="manual", init=False) http_proxy: str | None = None ssl_proxy: str | None = None - no_proxy: list[Any | None] | None = None + no_proxy: list[Any | None] | None = field(default_factory=list) @dataclass @@ -92,23 +91,23 @@ class SystemProxyConfiguration: class SubscribeParameters: """SubscribeParameters.""" - events: list[str | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + events: list[str | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass class UnsubscribeByIDRequest: """UnsubscribeByIDRequest.""" - subscriptions: list[Any | None] | None = None + subscriptions: list[Any | None] | None = field(default_factory=list) @dataclass class UnsubscribeByAttributesRequest: """UnsubscribeByAttributesRequest.""" - events: list[str | None] | None = None + events: list[str | None] | None = field(default_factory=list) @dataclass @@ -211,7 +210,12 @@ def end(self): result = self._conn.execute(cmd) return result - def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def subscribe( + self, + events: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute session.subscribe.""" params = { "events": events, @@ -223,7 +227,7 @@ def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None result = self._conn.execute(cmd) return result - def unsubscribe(self, events: List[Any] | None = None, subscriptions: List[Any] | None = None): + def unsubscribe(self, events: list[Any] | None = None, subscriptions: list[Any] | None = None): """Execute session.unsubscribe.""" params = { "events": events, diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 7e4c9c6dee459..3f29b85d13a23 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: storage from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass @dataclass @@ -33,7 +32,7 @@ class GetCookiesParameters: class GetCookiesResult: """GetCookiesResult.""" - cookies: list[Any | None] | None = None + cookies: list[Any | None] | None = field(default_factory=list) partition_key: Any | None = None @@ -107,7 +106,7 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> "StorageCookie": + def from_bidi_dict(cls, raw: dict) -> StorageCookie: """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): @@ -235,39 +234,6 @@ class Storage: def __init__(self, conn) -> None: self._conn = conn - def get_cookies(self, filter: Any | None = None, partition: Any | None = None): - """Execute storage.getCookies.""" - params = { - "filter": filter, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.getCookies", params) - result = self._conn.execute(cmd) - return result - - def set_cookie(self, cookie: Any | None = None, partition: Any | None = None): - """Execute storage.setCookie.""" - params = { - "cookie": cookie, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.setCookie", params) - result = self._conn.execute(cmd) - return result - - def delete_cookies(self, filter: Any | None = None, partition: Any | None = None): - """Execute storage.deleteCookies.""" - params = { - "filter": filter, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.deleteCookies", params) - result = self._conn.execute(cmd) - return result - def get_cookies(self, filter=None, partition=None): """Execute storage.getCookies and return a GetCookiesResult.""" if filter and hasattr(filter, "to_bidi_dict"): diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 8a737efeeafde..ebbe6729499b2 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: webExtension from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass @dataclass @@ -64,7 +63,12 @@ class WebExtension: def __init__(self, conn) -> None: self._conn = conn - def install(self, path: str | None = None, archive_path: str | None = None, base64_value: str | None = None): + def install( + self, + path: str | None = None, + archive_path: str | None = None, + base64_value: str | None = None, + ): """Install a web extension. Exactly one of the three keyword arguments must be provided. @@ -82,7 +86,11 @@ def install(self, path: str | None = None, archive_path: str | None = None, base Raises: ValueError: If more than one, or none, of the arguments is provided. """ - provided = [k for k, v in {"path": path, "archive_path": archive_path, "base64_value": base64_value}.items() if v is not None] + provided = [ + k for k, v in { + "path": path, "archive_path": archive_path, "base64_value": base64_value, + }.items() if v is not None + ] if len(provided) != 1: raise ValueError( f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}" From 06b33cfa5813c2a91fff63b2f5d240356f9ecc84 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 28 Feb 2026 08:54:27 +0000 Subject: [PATCH 03/42] fixup --- py/generate_bidi.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 2db595ff37cd0..4bf0d8b64514e 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -721,7 +721,9 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code += "# BiDi Event Name to Parameter Type Mapping\n" code += "EVENT_NAME_MAPPING = {\n" # Collect event keys from extra_events so we skip CDDL duplicates - extra_event_keys = {evt["event_key"] for evt in enhancements.get("extra_events", [])} + extra_event_keys = { + evt["event_key"] for evt in enhancements.get("extra_events", []) + } for event_def in self.events: # Convert method name to user-friendly event name # e.g., "browsingContext.contextCreated" -> "context_created" @@ -972,7 +974,9 @@ def clear_event_handlers(self) -> None: m = re.search(r"def\s+(\w+)\s*\(", extra_meth) if m: extra_method_names.add(m.group(1)) - exclude_methods = set(enhancements.get("exclude_methods", [])) | extra_method_names + exclude_methods = ( + set(enhancements.get("exclude_methods", [])) | extra_method_names + ) if self.commands: for command in self.commands: # Get method-specific enhancements @@ -1035,7 +1039,9 @@ def clear_event_handlers(self) -> None: code += "_globals = globals()\n" code += f"{class_name}.EVENT_CONFIGS = {{\n" # Collect extra event keys to skip CDDL duplicates - extra_event_keys_cfg = {evt["event_key"] for evt in enhancements.get("extra_events", [])} + extra_event_keys_cfg = { + evt["event_key"] for evt in enhancements.get("extra_events", []) + } for event_def in self.events: # Convert method name to user-friendly event name method_parts = event_def.method.split(".") @@ -1051,7 +1057,7 @@ def clear_event_handlers(self) -> None: f' _globals.get("{event_def.name}", dict))\n' f' if _globals.get("{event_def.name}")\n' f' else EventConfig("{event_name}", "{event_def.method}", dict)\n' - f' ),\n' + f" ),\n" ) # Extra events not in the CDDL spec for extra_evt in enhancements.get("extra_events", []): @@ -1064,7 +1070,7 @@ def clear_event_handlers(self) -> None: f' "{ek}": EventConfig(\n' f' "{ek}", "{be}",\n' f' _globals.get("{ec}", dict),\n' - f' ),\n' + f" ),\n" ) else: code += single + "\n" From 12d7ad7baa8ebf57442c774cfed82594b731b78b Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 28 Feb 2026 08:55:57 +0000 Subject: [PATCH 04/42] fixup --- py/selenium/webdriver/common/bidi/cdp.py | 515 +++++++++++++++++++++++ 1 file changed, 515 insertions(+) create mode 100644 py/selenium/webdriver/common/bidi/cdp.py diff --git a/py/selenium/webdriver/common/bidi/cdp.py b/py/selenium/webdriver/common/bidi/cdp.py new file mode 100644 index 0000000000000..b097762fe50cd --- /dev/null +++ b/py/selenium/webdriver/common/bidi/cdp.py @@ -0,0 +1,515 @@ +# The MIT License(MIT) +# +# Copyright(c) 2018 Hyperion Gray +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files(the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp + +import contextvars +import importlib +import itertools +import json +import logging +import pathlib +from collections import defaultdict +from collections.abc import AsyncGenerator, AsyncIterator, Generator +from contextlib import asynccontextmanager, contextmanager +from dataclasses import dataclass +from typing import Any, TypeVar + +import trio +from trio_websocket import ConnectionClosed as WsConnectionClosed +from trio_websocket import connect_websocket_url + +logger = logging.getLogger("trio_cdp") +T = TypeVar("T") +MAX_WS_MESSAGE_SIZE = 2**24 + +devtools = None +version = None + + +def import_devtools(ver): + """Attempt to load the current latest available devtools into the module cache for use later.""" + global devtools + global version + version = ver + base = "selenium.webdriver.common.devtools.v" + try: + devtools = importlib.import_module(f"{base}{ver}") + return devtools + except ModuleNotFoundError: + # Attempt to parse and load the 'most recent' devtools module. This is likely + # because cdp has been updated but selenium python has not been released yet. + devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") + versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) + latest = max(int(x[1:]) for x in versions) + selenium_logger = logging.getLogger(__name__) + selenium_logger.debug("Falling back to loading `devtools`: v%s", latest) + devtools = importlib.import_module(f"{base}{latest}") + return devtools + + +_connection_context: contextvars.ContextVar = contextvars.ContextVar("connection_context") +_session_context: contextvars.ContextVar = contextvars.ContextVar("session_context") + + +def get_connection_context(fn_name): + """Look up the current connection. + + If there is no current connection, raise a ``RuntimeError`` with a + helpful message. + """ + try: + return _connection_context.get() + except LookupError: + raise RuntimeError(f"{fn_name}() must be called in a connection context.") + + +def get_session_context(fn_name): + """Look up the current session. + + If there is no current session, raise a ``RuntimeError`` with a + helpful message. + """ + try: + return _session_context.get() + except LookupError: + raise RuntimeError(f"{fn_name}() must be called in a session context.") + + +@contextmanager +def connection_context(connection): + """Context manager installs ``connection`` as the session context for the current Trio task.""" + token = _connection_context.set(connection) + try: + yield + finally: + _connection_context.reset(token) + + +@contextmanager +def session_context(session): + """Context manager installs ``session`` as the session context for the current Trio task.""" + token = _session_context.set(session) + try: + yield + finally: + _session_context.reset(token) + + +def set_global_connection(connection): + """Install ``connection`` in the root context so that it will become the default connection for all tasks. + + This is generally not recommended, except it may be necessary in + certain use cases such as running inside Jupyter notebook. + """ + global _connection_context + _connection_context = contextvars.ContextVar("_connection_context", default=connection) + + +def set_global_session(session): + """Install ``session`` in the root context so that it will become the default session for all tasks. + + This is generally not recommended, except it may be necessary in + certain use cases such as running inside Jupyter notebook. + """ + global _session_context + _session_context = contextvars.ContextVar("_session_context", default=session) + + +class BrowserError(Exception): + """This exception is raised when the browser's response to a command indicates that an error occurred.""" + + def __init__(self, obj): + self.code = obj.get("code") + self.message = obj.get("message") + self.detail = obj.get("data") + + def __str__(self): + return f"BrowserError {self.detail}" + + +class CdpConnectionClosed(WsConnectionClosed): + """Raised when a public method is called on a closed CDP connection.""" + + def __init__(self, reason): + """Constructor. + + Args: + reason: wsproto.frame_protocol.CloseReason + """ + self.reason = reason + + def __repr__(self): + """Return representation.""" + return f"{self.__class__.__name__}<{self.reason}>" + + +class InternalError(Exception): + """This exception is only raised when there is faulty logic in TrioCDP or the integration with PyCDP.""" + + pass + + +@dataclass +class CmEventProxy: + """A proxy object returned by :meth:`CdpBase.wait_for()``. + + After the context manager executes, this proxy object will have a + value set that contains the returned event. + """ + + value: Any = None + + +class CdpBase: + def __init__(self, ws, session_id, target_id): + self.ws = ws + self.session_id = session_id + self.target_id = target_id + self.channels = defaultdict(set) + self.id_iter = itertools.count() + self.inflight_cmd = {} + self.inflight_result = {} + + async def execute(self, cmd: Generator[dict, T, Any]) -> T: + """Execute a command on the server and wait for the result. + + Args: + cmd: any CDP command + + Returns: + a CDP result + """ + cmd_id = next(self.id_iter) + cmd_event = trio.Event() + self.inflight_cmd[cmd_id] = cmd, cmd_event + request = next(cmd) + request["id"] = cmd_id + if self.session_id: + request["sessionId"] = self.session_id + request_str = json.dumps(request) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Sending CDP message: {cmd_id} {cmd_event}: {request_str}") + try: + await self.ws.send_message(request_str) + except WsConnectionClosed as wcc: + raise CdpConnectionClosed(wcc.reason) from None + await cmd_event.wait() + response = self.inflight_result.pop(cmd_id) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Received CDP message: {response}") + if isinstance(response, Exception): + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Exception raised by {cmd_event} message: {type(response).__name__}") + raise response + return response + + def listen(self, *event_types, buffer_size=10): + """Listen for events. + + Returns: + An async iterator that iterates over events matching the indicated types. + """ + sender, receiver = trio.open_memory_channel(buffer_size) + for event_type in event_types: + self.channels[event_type].add(sender) + return receiver + + @asynccontextmanager + async def wait_for(self, event_type: type[T], buffer_size=10) -> AsyncGenerator[CmEventProxy, None]: + """Wait for an event of the given type and return it. + + This is an async context manager, so you should open it inside + an async with block. The block will not exit until the indicated + event is received. + """ + sender: trio.MemorySendChannel + receiver: trio.MemoryReceiveChannel + sender, receiver = trio.open_memory_channel(buffer_size) + self.channels[event_type].add(sender) + proxy = CmEventProxy() + yield proxy + async with receiver: + event = await receiver.receive() + proxy.value = event + + def _handle_data(self, data): + """Handle incoming WebSocket data. + + Args: + data: a JSON dictionary + """ + if "id" in data: + self._handle_cmd_response(data) + else: + self._handle_event(data) + + def _handle_cmd_response(self, data: dict): + """Handle a response to a command. + + This will set an event flag that will return control to the + task that called the command. + + Args: + data: response as a JSON dictionary + """ + cmd_id = data["id"] + try: + cmd, event = self.inflight_cmd.pop(cmd_id) + except KeyError: + logger.warning("Got a message with a command ID that does not exist: %s", data) + return + if "error" in data: + # If the server reported an error, convert it to an exception and do + # not process the response any further. + self.inflight_result[cmd_id] = BrowserError(data["error"]) + else: + # Otherwise, continue the generator to parse the JSON result + # into a CDP object. + try: + _ = cmd.send(data["result"]) + raise InternalError("The command's generator function did not exit when expected!") + except StopIteration as exit: + return_ = exit.value + self.inflight_result[cmd_id] = return_ + event.set() + + def _handle_event(self, data: dict): + """Handle an event. + + Args: + data: event as a JSON dictionary + """ + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + event = devtools.util.parse_json_event(data) + logger.debug("Received event: %s", event) + to_remove = set() + for sender in self.channels[type(event)]: + try: + sender.send_nowait(event) + except trio.WouldBlock: + logger.error('Unable to send event "%r" due to full channel %s', event, sender) + except trio.BrokenResourceError: + to_remove.add(sender) + if to_remove: + self.channels[type(event)] -= to_remove + + +class CdpSession(CdpBase): + """Contains the state for a CDP session. + + Generally you should not instantiate this object yourself; you should call + :meth:`CdpConnection.open_session`. + """ + + def __init__(self, ws, session_id, target_id): + """Constructor. + + Args: + ws: trio_websocket.WebSocketConnection + session_id: devtools.target.SessionID + target_id: devtools.target.TargetID + """ + super().__init__(ws, session_id, target_id) + + self._dom_enable_count = 0 + self._dom_enable_lock = trio.Lock() + self._page_enable_count = 0 + self._page_enable_lock = trio.Lock() + + @asynccontextmanager + async def dom_enable(self): + """Context manager that executes ``dom.enable()`` when it enters and then calls ``dom.disable()``. + + This keeps track of concurrent callers and only disables DOM + events when all callers have exited. + """ + global devtools + async with self._dom_enable_lock: + self._dom_enable_count += 1 + if self._dom_enable_count == 1: + await self.execute(devtools.dom.enable()) + + yield + + async with self._dom_enable_lock: + self._dom_enable_count -= 1 + if self._dom_enable_count == 0: + await self.execute(devtools.dom.disable()) + + @asynccontextmanager + async def page_enable(self): + """Context manager executes ``page.enable()`` when it enters and then calls ``page.disable()`` when it exits. + + This keeps track of concurrent callers and only disables page + events when all callers have exited. + """ + global devtools + async with self._page_enable_lock: + self._page_enable_count += 1 + if self._page_enable_count == 1: + await self.execute(devtools.page.enable()) + + yield + + async with self._page_enable_lock: + self._page_enable_count -= 1 + if self._page_enable_count == 0: + await self.execute(devtools.page.disable()) + + +class CdpConnection(CdpBase, trio.abc.AsyncResource): + """Contains the connection state for a Chrome DevTools Protocol server. + + CDP can multiplex multiple "sessions" over a single connection. This + class corresponds to the "root" session, i.e. the implicitly created + session that has no session ID. This class is responsible for + reading incoming WebSocket messages and forwarding them to the + corresponding session, as well as handling messages targeted at the + root session itself. You should generally call the + :func:`open_cdp()` instead of instantiating this class directly. + """ + + def __init__(self, ws): + """Constructor. + + Args: + ws: trio_websocket.WebSocketConnection + """ + super().__init__(ws, session_id=None, target_id=None) + self.sessions = {} + + async def aclose(self): + """Close the underlying WebSocket connection. + + This will cause the reader task to gracefully exit when it tries + to read the next message from the WebSocket. All of the public + APIs (``execute()``, ``listen()``, etc.) will raise + ``CdpConnectionClosed`` after the CDP connection is closed. It + is safe to call this multiple times. + """ + await self.ws.aclose() + + @asynccontextmanager + async def open_session(self, target_id) -> AsyncIterator[CdpSession]: + """Context manager opens a session and enables the "simple" style of calling CDP APIs. + + For example, inside a session context, you can call ``await + dom.get_document()`` and it will execute on the current session + automatically. + """ + session = await self.connect_session(target_id) + with session_context(session): + yield session + + async def connect_session(self, target_id) -> "CdpSession": + """Returns a new :class:`CdpSession` connected to the specified target.""" + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + session_id = await self.execute(devtools.target.attach_to_target(target_id, True)) + session = CdpSession(self.ws, session_id, target_id) + self.sessions[session_id] = session + return session + + async def _reader_task(self): + """Runs in the background and handles incoming messages. + + Dispatches responses to commands and events to listeners. + """ + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + while True: + try: + message = await self.ws.get_message() + except WsConnectionClosed: + # If the WebSocket is closed, we don't want to throw an + # exception from the reader task. Instead we will throw + # exceptions from the public API methods, and we can quietly + # exit the reader task here. + break + try: + data = json.loads(message) + except json.JSONDecodeError: + raise BrowserError({"code": -32700, "message": "Client received invalid JSON", "data": message}) + logger.debug("Received message %r", data) + if "sessionId" in data: + session_id = devtools.target.SessionID(data["sessionId"]) + try: + session = self.sessions[session_id] + except KeyError: + raise BrowserError( + { + "code": -32700, + "message": "Browser sent a message for an invalid session", + "data": f"{session_id!r}", + } + ) + session._handle_data(data) + else: + self._handle_data(data) + + for _, session in self.sessions.items(): + for _, senders in session.channels.items(): + for sender in senders: + sender.close() + + +@asynccontextmanager +async def open_cdp(url) -> AsyncIterator[CdpConnection]: + """Async context manager opens a connection to the browser then closes the connection when the block exits. + + The context manager also sets the connection as the default + connection for the current task, so that commands like ``await + target.get_targets()`` will run on this connection automatically. If + you want to use multiple connections concurrently, it is recommended + to open each on in a separate task. + """ + async with trio.open_nursery() as nursery: + conn = await connect_cdp(nursery, url) + try: + with connection_context(conn): + yield conn + finally: + await conn.aclose() + + +async def connect_cdp(nursery, url) -> CdpConnection: + """Connect to the browser specified by ``url`` and spawn a background task in the specified nursery. + + The ``open_cdp()`` context manager is preferred in most situations. + You should only use this function if you need to specify a custom + nursery. This connection is not automatically closed! You can either + use the connection object as a context manager (``async with + conn:``) or else call ``await conn.aclose()`` on it when you are + done with it. If ``set_context`` is True, then the returned + connection will be installed as the default connection for the + current task. This argument is for unusual use cases, such as + running inside of a notebook. + """ + ws = await connect_websocket_url(nursery, url, max_message_size=MAX_WS_MESSAGE_SIZE) + cdp_conn = CdpConnection(ws) + nursery.start_soon(cdp_conn._reader_task) + return cdp_conn From 2eb19316f1e39dfddff6908f90436a432a019c4b Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Mon, 2 Mar 2026 11:26:58 +0000 Subject: [PATCH 05/42] handle comments --- py/generate_bidi.py | 64 +++--- py/private/bidi_enhancements_manifest.py | 36 +++- py/selenium/webdriver/common/bidi/browser.py | 38 ++-- .../webdriver/common/bidi/browsing_context.py | 68 +++--- py/selenium/webdriver/common/bidi/common.py | 6 +- .../webdriver/common/bidi/emulation.py | 36 ++-- py/selenium/webdriver/common/bidi/input.py | 30 +-- py/selenium/webdriver/common/bidi/log.py | 4 +- py/selenium/webdriver/common/bidi/network.py | 194 ++++++++++-------- py/selenium/webdriver/common/bidi/script.py | 154 +++++++------- py/selenium/webdriver/common/bidi/session.py | 30 +-- py/selenium/webdriver/common/bidi/storage.py | 14 +- .../webdriver/common/bidi/webextension.py | 12 +- 13 files changed, 380 insertions(+), 306 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 4bf0d8b64514e..5d7f39e53abfc 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -368,11 +368,14 @@ def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str dataclass_methods = enhancements.get("dataclass_methods", {}) method_docstrings = enhancements.get("method_docstrings", {}) - # Generate class name from type name (keep it as-is, don't split on underscores) - class_name = self.name + # Generate class name from type name. + # CDDL type names that start with a lowercase letter (e.g. camelCase + # command-parameter types like "setNetworkConditionsParameters") are + # capitalised so that the resulting Python class follows PascalCase. + class_name = self.name[0].upper() + self.name[1:] if self.name else self.name code = "@dataclass\n" code += f"class {class_name}:\n" - code += f' """{self.description or self.name}."""\n\n' + code += f' """{class_name} type definition."""\n\n' if not self.fields: code += " pass\n" @@ -466,9 +469,9 @@ def to_python_class(self) -> str: Generates a simple class with string constants to match the existing pattern in the codebase (e.g., ClientWindowState). """ - class_name = self.name + class_name = self.name[0].upper() + self.name[1:] if self.name else self.name code = f"class {class_name}:\n" - code += f' """{self.description or self.name}."""\n\n' + code += f' """{class_name}."""\n\n' for value in self.values: # Convert value to UPPER_SNAKE_CASE constant name @@ -684,8 +687,19 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ - # Generate enums first + # Collect names of extra_dataclasses so we can skip CDDL-generated + # enums and types that are overridden by manual definitions. + extra_cls_names = set() + for extra_cls in enhancements.get("extra_dataclasses", []): + m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) + if m: + extra_cls_names.add(m.group(1)) + exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names + + # Generate enums first, skipping any that are overridden via extra_dataclasses for enum_def in self.enums: + if enum_def.name in exclude_types: + continue code += enum_def.to_python_class() code += "\n\n" @@ -694,13 +708,6 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code += f"{alias} = {target}\n\n" # Generate type dataclasses, skipping any overridden by extra_dataclasses - # Also auto-exclude types whose names appear in extra_dataclasses - extra_cls_names = set() - for extra_cls in enhancements.get("extra_dataclasses", []): - m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) - if m: - extra_cls_names.add(m.group(1)) - exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names for type_def in self.types: if type_def.name in exclude_types: continue @@ -1146,8 +1153,12 @@ def _remove_comments(self, content: str) -> str: def _extract_definitions(self, content: str) -> None: """Extract CDDL definitions (type definitions, commands, etc.).""" # Match pattern: Name = Definition - # Handles multiline definitions properly - pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\w+(?:\.\w+)?\s*=|\Z)" + # Handles multiline definitions properly. + # The \s* after \n in the lookahead allows definitions that start with + # leading whitespace (e.g. " network.BeforeRequestSent = (") to be + # recognised as separate definitions instead of being swallowed into + # the body of the preceding definition. + pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\s*\w+(?:\.\w+)?\s*=|\Z)" for match in re.finditer(pattern, content, re.DOTALL): name = match.group(1).strip() @@ -1589,12 +1600,15 @@ def generate_common_file(output_path: Path) -> None: "\n" '"""Common utilities for BiDi command construction."""\n' "\n" - "from typing import Any, Dict, Generator\n" + "from __future__ import annotations\n" + "\n" + "from collections.abc import Generator\n" + "from typing import Any\n" "\n" "\n" "def command_builder(\n" - " method: str, params: Dict[str, Any]\n" - ") -> Generator[Dict[str, Any], Any, Any]:\n" + " method: str, params: dict[str, Any] | None = None\n" + ") -> Generator[dict[str, Any], Any, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" @@ -1607,6 +1621,8 @@ def generate_common_file(output_path: Path) -> None: " Returns:\n" " The result from the BiDi command execution\n" ' """\n' + " if params is None:\n" + " params = {}\n" ' result = yield {"method": method, "params": params}\n' " return result\n" ) @@ -1680,8 +1696,10 @@ def generate_permissions_file(output_path: Path) -> None: "\n" "from __future__ import annotations\n" "\n" + "from __future__ import annotations\n" + "\n" "from enum import Enum\n" - "from typing import Any, Optional, Union\n" + "from typing import Any\n" "\n" "from .common import command_builder\n" "\n" @@ -1724,10 +1742,10 @@ def generate_permissions_file(output_path: Path) -> None: "\n" " def set_permission(\n" " self,\n" - " descriptor: Union[PermissionDescriptor, str],\n" - " state: Union[PermissionState, str],\n" - " origin: Optional[str] = None,\n" - " user_context: Optional[str] = None,\n" + " descriptor: PermissionDescriptor | str,\n" + " state: PermissionState | str,\n" + " origin: str | None = None,\n" + " user_context: str | None = None,\n" " ) -> None:\n" ' """Set a permission for a given origin.\n' "\n" diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 39af67d4c635b..adf0a17128af3 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -81,6 +81,20 @@ "result_param": "download_behavior", }, }, + # Replace the auto-generated ClientWindowNamedState so we can add the + # convenience NORMAL constant. In the BiDi spec "normal" is the state + # represented by ClientWindowRectState, but exposing it here keeps the + # Python API consistent with the old ClientWindowState enum. + "exclude_types": ["ClientWindowNamedState"], + "extra_dataclasses": [ + '''class ClientWindowNamedState: + """Named states for a browser client window.""" + + FULLSCREEN = "fullscreen" + MAXIMIZED = "maximized" + MINIMIZED = "minimized" + NORMAL = "normal"''', + ], # Override the generator-produced set_download_behavior so that # downloadBehavior is never stripped by the generic None filter. # The BiDi spec marks it as required (can be null, but must be present). @@ -845,8 +859,11 @@ def from_json(self2, p): ], }, "network": { - # Initialize intercepts tracking list in __init__ - "extra_init_code": ["self.intercepts = []"], + # Initialize intercepts tracking list and per-handler intercept map + "extra_init_code": [ + "self.intercepts = []", + "self._handler_intercepts: dict = {}", + ], # Request class wraps a beforeRequestSent event params and provides actions "extra_dataclasses": [ '''class BytesValue: @@ -940,7 +957,8 @@ def continue_request(self, **kwargs): "auth_required": "authRequired", } phase = phase_map.get(event, "beforeRequestSent") - self._add_intercept(phases=[phase], url_patterns=url_patterns) + intercept_result = self._add_intercept(phases=[phase], url_patterns=url_patterns) + intercept_id = intercept_result.get("intercept") if intercept_result else None def _request_callback(params): raw = ( @@ -951,15 +969,21 @@ def _request_callback(params): request = Request(self._conn, raw) callback(request) - return self.add_event_handler(event, _request_callback)''', + callback_id = self.add_event_handler(event, _request_callback) + if intercept_id: + self._handler_intercepts[callback_id] = intercept_id + return callback_id''', ''' def remove_request_handler(self, event, callback_id): - """Remove a network request handler. + """Remove a network request handler and its associated network intercept. Args: event: The event name used when adding the handler. callback_id: The int returned by add_request_handler. """ - self.remove_event_handler(event, callback_id)''', + self.remove_event_handler(event, callback_id) + intercept_id = self._handler_intercepts.pop(callback_id, None) + if intercept_id: + self._remove_intercept(intercept_id)''', ''' def clear_request_handlers(self): """Clear all request handlers and remove all tracked intercepts.""" self.clear_event_handlers() diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index acda63f71953e..71f917634304d 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -60,17 +60,9 @@ def validate_download_behavior( raise ValueError("destination_folder should not be provided when allowed=False") -class ClientWindowNamedState: - """ClientWindowNamedState.""" - - FULLSCREEN = "fullscreen" - MAXIMIZED = "maximized" - MINIMIZED = "minimized" - - @dataclass class ClientWindowInfo: - """ClientWindowInfo.""" + """ClientWindowInfo type definition.""" active: bool | None = None client_window: Any | None = None @@ -112,14 +104,14 @@ def get_y(self): @dataclass class UserContextInfo: - """UserContextInfo.""" + """UserContextInfo type definition.""" user_context: Any | None = None @dataclass class CreateUserContextParameters: - """CreateUserContextParameters.""" + """CreateUserContextParameters type definition.""" accept_insecure_certs: bool | None = None proxy: Any | None = None @@ -128,35 +120,35 @@ class CreateUserContextParameters: @dataclass class GetClientWindowsResult: - """GetClientWindowsResult.""" + """GetClientWindowsResult type definition.""" client_windows: list[Any | None] | None = field(default_factory=list) @dataclass class GetUserContextsResult: - """GetUserContextsResult.""" + """GetUserContextsResult type definition.""" user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass class RemoveUserContextParameters: - """RemoveUserContextParameters.""" + """RemoveUserContextParameters type definition.""" user_context: Any | None = None @dataclass class SetClientWindowStateParameters: - """SetClientWindowStateParameters.""" + """SetClientWindowStateParameters type definition.""" client_window: Any | None = None @dataclass class ClientWindowRectState: - """ClientWindowRectState.""" + """ClientWindowRectState type definition.""" state: str = field(default="normal", init=False) width: Any | None = None @@ -167,7 +159,7 @@ class ClientWindowRectState: @dataclass class SetDownloadBehaviorParameters: - """SetDownloadBehaviorParameters.""" + """SetDownloadBehaviorParameters type definition.""" download_behavior: Any | None = None user_contexts: list[Any | None] | None = field(default_factory=list) @@ -175,7 +167,7 @@ class SetDownloadBehaviorParameters: @dataclass class DownloadBehaviorAllowed: - """DownloadBehaviorAllowed.""" + """DownloadBehaviorAllowed type definition.""" type: str = field(default="allowed", init=False) destination_folder: str | None = None @@ -183,11 +175,19 @@ class DownloadBehaviorAllowed: @dataclass class DownloadBehaviorDenied: - """DownloadBehaviorDenied.""" + """DownloadBehaviorDenied type definition.""" type: str = field(default="denied", init=False) +class ClientWindowNamedState: + """Named states for a browser client window.""" + + FULLSCREEN = "fullscreen" + MAXIMIZED = "maximized" + MINIMIZED = "minimized" + NORMAL = "normal" + class Browser: """WebDriver BiDi browser module.""" diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 5f128635df29d..ede96071778c3 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -48,7 +48,7 @@ class DownloadCompleteParams: @dataclass class Info: - """Info.""" + """Info type definition.""" children: Any | None = None client_window: Any | None = None @@ -61,7 +61,7 @@ class Info: @dataclass class AccessibilityLocator: - """AccessibilityLocator.""" + """AccessibilityLocator type definition.""" type: str = field(default="accessibility", init=False) name: str | None = None @@ -70,7 +70,7 @@ class AccessibilityLocator: @dataclass class CssLocator: - """CssLocator.""" + """CssLocator type definition.""" type: str = field(default="css", init=False) value: str | None = None @@ -78,7 +78,7 @@ class CssLocator: @dataclass class ContextLocator: - """ContextLocator.""" + """ContextLocator type definition.""" type: str = field(default="context", init=False) context: Any | None = None @@ -86,7 +86,7 @@ class ContextLocator: @dataclass class InnerTextLocator: - """InnerTextLocator.""" + """InnerTextLocator type definition.""" type: str = field(default="innerText", init=False) value: str | None = None @@ -97,7 +97,7 @@ class InnerTextLocator: @dataclass class XPathLocator: - """XPathLocator.""" + """XPathLocator type definition.""" type: str = field(default="xpath", init=False) value: str | None = None @@ -105,7 +105,7 @@ class XPathLocator: @dataclass class BaseNavigationInfo: - """BaseNavigationInfo.""" + """BaseNavigationInfo type definition.""" context: Any | None = None navigation: Any | None = None @@ -115,14 +115,14 @@ class BaseNavigationInfo: @dataclass class ActivateParameters: - """ActivateParameters.""" + """ActivateParameters type definition.""" context: Any | None = None @dataclass class CaptureScreenshotParameters: - """CaptureScreenshotParameters.""" + """CaptureScreenshotParameters type definition.""" context: Any | None = None format: Any | None = None @@ -131,7 +131,7 @@ class CaptureScreenshotParameters: @dataclass class ImageFormat: - """ImageFormat.""" + """ImageFormat type definition.""" type: str | None = None quality: Any | None = None @@ -139,7 +139,7 @@ class ImageFormat: @dataclass class ElementClipRectangle: - """ElementClipRectangle.""" + """ElementClipRectangle type definition.""" type: str = field(default="element", init=False) element: Any | None = None @@ -147,7 +147,7 @@ class ElementClipRectangle: @dataclass class BoxClipRectangle: - """BoxClipRectangle.""" + """BoxClipRectangle type definition.""" type: str = field(default="box", init=False) x: Any | None = None @@ -158,14 +158,14 @@ class BoxClipRectangle: @dataclass class CaptureScreenshotResult: - """CaptureScreenshotResult.""" + """CaptureScreenshotResult type definition.""" data: str | None = None @dataclass class CloseParameters: - """CloseParameters.""" + """CloseParameters type definition.""" context: Any | None = None prompt_unload: bool | None = None @@ -173,7 +173,7 @@ class CloseParameters: @dataclass class CreateParameters: - """CreateParameters.""" + """CreateParameters type definition.""" type: Any | None = None reference_context: Any | None = None @@ -183,14 +183,14 @@ class CreateParameters: @dataclass class CreateResult: - """CreateResult.""" + """CreateResult type definition.""" context: Any | None = None @dataclass class GetTreeParameters: - """GetTreeParameters.""" + """GetTreeParameters type definition.""" max_depth: Any | None = None root: Any | None = None @@ -198,14 +198,14 @@ class GetTreeParameters: @dataclass class GetTreeResult: - """GetTreeResult.""" + """GetTreeResult type definition.""" contexts: Any | None = None @dataclass class HandleUserPromptParameters: - """HandleUserPromptParameters.""" + """HandleUserPromptParameters type definition.""" context: Any | None = None accept: bool | None = None @@ -214,7 +214,7 @@ class HandleUserPromptParameters: @dataclass class LocateNodesParameters: - """LocateNodesParameters.""" + """LocateNodesParameters type definition.""" context: Any | None = None locator: Any | None = None @@ -224,14 +224,14 @@ class LocateNodesParameters: @dataclass class LocateNodesResult: - """LocateNodesResult.""" + """LocateNodesResult type definition.""" nodes: list[Any | None] | None = field(default_factory=list) @dataclass class NavigateParameters: - """NavigateParameters.""" + """NavigateParameters type definition.""" context: Any | None = None url: str | None = None @@ -240,7 +240,7 @@ class NavigateParameters: @dataclass class NavigateResult: - """NavigateResult.""" + """NavigateResult type definition.""" navigation: Any | None = None url: str | None = None @@ -248,7 +248,7 @@ class NavigateResult: @dataclass class PrintParameters: - """PrintParameters.""" + """PrintParameters type definition.""" context: Any | None = None background: bool | None = None @@ -260,7 +260,7 @@ class PrintParameters: @dataclass class PrintMarginParameters: - """PrintMarginParameters.""" + """PrintMarginParameters type definition.""" bottom: Any | None = None left: Any | None = None @@ -270,7 +270,7 @@ class PrintMarginParameters: @dataclass class PrintPageParameters: - """PrintPageParameters.""" + """PrintPageParameters type definition.""" height: Any | None = None width: Any | None = None @@ -278,14 +278,14 @@ class PrintPageParameters: @dataclass class PrintResult: - """PrintResult.""" + """PrintResult type definition.""" data: str | None = None @dataclass class ReloadParameters: - """ReloadParameters.""" + """ReloadParameters type definition.""" context: Any | None = None ignore_cache: bool | None = None @@ -294,7 +294,7 @@ class ReloadParameters: @dataclass class SetViewportParameters: - """SetViewportParameters.""" + """SetViewportParameters type definition.""" context: Any | None = None viewport: Any | None = None @@ -304,7 +304,7 @@ class SetViewportParameters: @dataclass class Viewport: - """Viewport.""" + """Viewport type definition.""" width: Any | None = None height: Any | None = None @@ -312,7 +312,7 @@ class Viewport: @dataclass class TraverseHistoryParameters: - """TraverseHistoryParameters.""" + """TraverseHistoryParameters type definition.""" context: Any | None = None delta: Any | None = None @@ -320,7 +320,7 @@ class TraverseHistoryParameters: @dataclass class HistoryUpdatedParameters: - """HistoryUpdatedParameters.""" + """HistoryUpdatedParameters type definition.""" context: Any | None = None timestamp: Any | None = None @@ -329,7 +329,7 @@ class HistoryUpdatedParameters: @dataclass class UserPromptClosedParameters: - """UserPromptClosedParameters.""" + """UserPromptClosedParameters type definition.""" context: Any | None = None accepted: bool | None = None @@ -339,7 +339,7 @@ class UserPromptClosedParameters: @dataclass class UserPromptOpenedParameters: - """UserPromptOpenedParameters.""" + """UserPromptOpenedParameters type definition.""" context: Any | None = None handler: Any | None = None diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index d7cb436a08471..dae051876833e 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -17,12 +17,14 @@ """Common utilities for BiDi command construction.""" +from __future__ import annotations + from collections.abc import Generator from typing import Any def command_builder( - method: str, params: dict[str, Any] + method: str, params: dict[str, Any] | None = None ) -> Generator[dict[str, Any], Any, Any]: """Build a BiDi command generator. @@ -36,5 +38,7 @@ def command_builder( Returns: The result from the BiDi command execution """ + if params is None: + params = {} result = yield {"method": method, "params": params} return result diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index cb575bbdc54dd..fbbe0966d8b3a 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -37,7 +37,7 @@ class ScreenOrientationType: @dataclass class SetForcedColorsModeThemeOverrideParameters: - """SetForcedColorsModeThemeOverrideParameters.""" + """SetForcedColorsModeThemeOverrideParameters type definition.""" theme: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -46,7 +46,7 @@ class SetForcedColorsModeThemeOverrideParameters: @dataclass class SetGeolocationOverrideParameters: - """SetGeolocationOverrideParameters.""" + """SetGeolocationOverrideParameters type definition.""" contexts: list[Any | None] | None = field(default_factory=list) user_contexts: list[Any | None] | None = field(default_factory=list) @@ -54,7 +54,7 @@ class SetGeolocationOverrideParameters: @dataclass class GeolocationCoordinates: - """GeolocationCoordinates.""" + """GeolocationCoordinates type definition.""" latitude: Any | None = None longitude: Any | None = None @@ -67,14 +67,14 @@ class GeolocationCoordinates: @dataclass class GeolocationPositionError: - """GeolocationPositionError.""" + """GeolocationPositionError type definition.""" type: str = field(default="positionUnavailable", init=False) @dataclass class SetLocaleOverrideParameters: - """SetLocaleOverrideParameters.""" + """SetLocaleOverrideParameters type definition.""" locale: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -82,8 +82,8 @@ class SetLocaleOverrideParameters: @dataclass -class setNetworkConditionsParameters: - """setNetworkConditionsParameters.""" +class SetNetworkConditionsParameters: + """SetNetworkConditionsParameters type definition.""" network_conditions: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -92,14 +92,14 @@ class setNetworkConditionsParameters: @dataclass class NetworkConditionsOffline: - """NetworkConditionsOffline.""" + """NetworkConditionsOffline type definition.""" type: str = field(default="offline", init=False) @dataclass class ScreenArea: - """ScreenArea.""" + """ScreenArea type definition.""" width: Any | None = None height: Any | None = None @@ -107,7 +107,7 @@ class ScreenArea: @dataclass class SetScreenSettingsOverrideParameters: - """SetScreenSettingsOverrideParameters.""" + """SetScreenSettingsOverrideParameters type definition.""" screen_area: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -116,7 +116,7 @@ class SetScreenSettingsOverrideParameters: @dataclass class ScreenOrientation: - """ScreenOrientation.""" + """ScreenOrientation type definition.""" natural: Any | None = None type: Any | None = None @@ -124,7 +124,7 @@ class ScreenOrientation: @dataclass class SetScreenOrientationOverrideParameters: - """SetScreenOrientationOverrideParameters.""" + """SetScreenOrientationOverrideParameters type definition.""" screen_orientation: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -133,7 +133,7 @@ class SetScreenOrientationOverrideParameters: @dataclass class SetUserAgentOverrideParameters: - """SetUserAgentOverrideParameters.""" + """SetUserAgentOverrideParameters type definition.""" user_agent: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -142,7 +142,7 @@ class SetUserAgentOverrideParameters: @dataclass class SetViewportMetaOverrideParameters: - """SetViewportMetaOverrideParameters.""" + """SetViewportMetaOverrideParameters type definition.""" viewport_meta: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -151,7 +151,7 @@ class SetViewportMetaOverrideParameters: @dataclass class SetScriptingEnabledParameters: - """SetScriptingEnabledParameters.""" + """SetScriptingEnabledParameters type definition.""" enabled: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -160,7 +160,7 @@ class SetScriptingEnabledParameters: @dataclass class SetScrollbarTypeOverrideParameters: - """SetScrollbarTypeOverrideParameters.""" + """SetScrollbarTypeOverrideParameters type definition.""" scrollbar_type: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -169,7 +169,7 @@ class SetScrollbarTypeOverrideParameters: @dataclass class SetTimezoneOverrideParameters: - """SetTimezoneOverrideParameters.""" + """SetTimezoneOverrideParameters type definition.""" timezone: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -178,7 +178,7 @@ class SetTimezoneOverrideParameters: @dataclass class SetTouchOverrideParameters: - """SetTouchOverrideParameters.""" + """SetTouchOverrideParameters type definition.""" contexts: list[Any | None] | None = field(default_factory=list) user_contexts: list[Any | None] | None = field(default_factory=list) diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 13f43361293f2..c8e58181b343e 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -33,7 +33,7 @@ class Origin: @dataclass class ElementOrigin: - """ElementOrigin.""" + """ElementOrigin type definition.""" type: str = field(default="element", init=False) element: Any | None = None @@ -41,7 +41,7 @@ class ElementOrigin: @dataclass class PerformActionsParameters: - """PerformActionsParameters.""" + """PerformActionsParameters type definition.""" context: Any | None = None actions: list[Any | None] | None = field(default_factory=list) @@ -49,7 +49,7 @@ class PerformActionsParameters: @dataclass class NoneSourceActions: - """NoneSourceActions.""" + """NoneSourceActions type definition.""" type: str = field(default="none", init=False) id: str | None = None @@ -58,7 +58,7 @@ class NoneSourceActions: @dataclass class KeySourceActions: - """KeySourceActions.""" + """KeySourceActions type definition.""" type: str = field(default="key", init=False) id: str | None = None @@ -67,7 +67,7 @@ class KeySourceActions: @dataclass class PointerSourceActions: - """PointerSourceActions.""" + """PointerSourceActions type definition.""" type: str = field(default="pointer", init=False) id: str | None = None @@ -77,14 +77,14 @@ class PointerSourceActions: @dataclass class PointerParameters: - """PointerParameters.""" + """PointerParameters type definition.""" pointer_type: Any | None = None @dataclass class WheelSourceActions: - """WheelSourceActions.""" + """WheelSourceActions type definition.""" type: str = field(default="wheel", init=False) id: str | None = None @@ -93,7 +93,7 @@ class WheelSourceActions: @dataclass class PauseAction: - """PauseAction.""" + """PauseAction type definition.""" type: str = field(default="pause", init=False) duration: Any | None = None @@ -101,7 +101,7 @@ class PauseAction: @dataclass class KeyDownAction: - """KeyDownAction.""" + """KeyDownAction type definition.""" type: str = field(default="keyDown", init=False) value: str | None = None @@ -109,7 +109,7 @@ class KeyDownAction: @dataclass class KeyUpAction: - """KeyUpAction.""" + """KeyUpAction type definition.""" type: str = field(default="keyUp", init=False) value: str | None = None @@ -117,7 +117,7 @@ class KeyUpAction: @dataclass class PointerUpAction: - """PointerUpAction.""" + """PointerUpAction type definition.""" type: str = field(default="pointerUp", init=False) button: Any | None = None @@ -125,7 +125,7 @@ class PointerUpAction: @dataclass class WheelScrollAction: - """WheelScrollAction.""" + """WheelScrollAction type definition.""" type: str = field(default="scroll", init=False) x: Any | None = None @@ -138,7 +138,7 @@ class WheelScrollAction: @dataclass class PointerCommonProperties: - """PointerCommonProperties.""" + """PointerCommonProperties type definition.""" width: Any | None = None height: Any | None = None @@ -151,14 +151,14 @@ class PointerCommonProperties: @dataclass class ReleaseActionsParameters: - """ReleaseActionsParameters.""" + """ReleaseActionsParameters type definition.""" context: Any | None = None @dataclass class SetFilesParameters: - """SetFilesParameters.""" + """SetFilesParameters type definition.""" context: Any | None = None element: Any | None = None diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 7971b807e94a1..eaf52a2ec08c2 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -27,7 +27,7 @@ class Level: @dataclass class BaseLogEntry: - """BaseLogEntry.""" + """BaseLogEntry type definition.""" level: Any | None = None source: Any | None = None @@ -38,7 +38,7 @@ class BaseLogEntry: @dataclass class GenericLogEntry: - """GenericLogEntry.""" + """GenericLogEntry type definition.""" type: str | None = None diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 6e02eeabc4ed7..c9737ac9131d0 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -49,7 +49,7 @@ class ContinueWithAuthNoCredentials: @dataclass class AuthChallenge: - """AuthChallenge.""" + """AuthChallenge type definition.""" scheme: str | None = None realm: str | None = None @@ -57,7 +57,7 @@ class AuthChallenge: @dataclass class AuthCredentials: - """AuthCredentials.""" + """AuthCredentials type definition.""" type: str = field(default="password", init=False) username: str | None = None @@ -66,7 +66,7 @@ class AuthCredentials: @dataclass class BaseParameters: - """BaseParameters.""" + """BaseParameters type definition.""" context: Any | None = None is_blocked: bool | None = None @@ -79,7 +79,7 @@ class BaseParameters: @dataclass class StringValue: - """StringValue.""" + """StringValue type definition.""" type: str = field(default="string", init=False) value: str | None = None @@ -87,7 +87,7 @@ class StringValue: @dataclass class Base64Value: - """Base64Value.""" + """Base64Value type definition.""" type: str = field(default="base64", init=False) value: str | None = None @@ -95,7 +95,7 @@ class Base64Value: @dataclass class Cookie: - """Cookie.""" + """Cookie type definition.""" name: str | None = None value: Any | None = None @@ -110,7 +110,7 @@ class Cookie: @dataclass class CookieHeader: - """CookieHeader.""" + """CookieHeader type definition.""" name: str | None = None value: Any | None = None @@ -118,7 +118,7 @@ class CookieHeader: @dataclass class FetchTimingInfo: - """FetchTimingInfo.""" + """FetchTimingInfo type definition.""" time_origin: Any | None = None request_time: Any | None = None @@ -137,7 +137,7 @@ class FetchTimingInfo: @dataclass class Header: - """Header.""" + """Header type definition.""" name: str | None = None value: Any | None = None @@ -145,7 +145,7 @@ class Header: @dataclass class Initiator: - """Initiator.""" + """Initiator type definition.""" column_number: Any | None = None line_number: Any | None = None @@ -156,14 +156,14 @@ class Initiator: @dataclass class ResponseContent: - """ResponseContent.""" + """ResponseContent type definition.""" size: Any | None = None @dataclass class ResponseData: - """ResponseData.""" + """ResponseData type definition.""" url: str | None = None protocol: str | None = None @@ -181,7 +181,7 @@ class ResponseData: @dataclass class SetCookieHeader: - """SetCookieHeader.""" + """SetCookieHeader type definition.""" name: str | None = None value: Any | None = None @@ -196,7 +196,7 @@ class SetCookieHeader: @dataclass class UrlPatternPattern: - """UrlPatternPattern.""" + """UrlPatternPattern type definition.""" type: str = field(default="pattern", init=False) protocol: str | None = None @@ -208,7 +208,7 @@ class UrlPatternPattern: @dataclass class UrlPatternString: - """UrlPatternString.""" + """UrlPatternString type definition.""" type: str = field(default="string", init=False) pattern: str | None = None @@ -216,7 +216,7 @@ class UrlPatternString: @dataclass class AddDataCollectorParameters: - """AddDataCollectorParameters.""" + """AddDataCollectorParameters type definition.""" data_types: list[Any | None] | None = field(default_factory=list) max_encoded_data_size: Any | None = None @@ -227,14 +227,14 @@ class AddDataCollectorParameters: @dataclass class AddDataCollectorResult: - """AddDataCollectorResult.""" + """AddDataCollectorResult type definition.""" collector: Any | None = None @dataclass class AddInterceptParameters: - """AddInterceptParameters.""" + """AddInterceptParameters type definition.""" phases: list[Any | None] | None = field(default_factory=list) contexts: list[Any | None] | None = field(default_factory=list) @@ -243,14 +243,14 @@ class AddInterceptParameters: @dataclass class AddInterceptResult: - """AddInterceptResult.""" + """AddInterceptResult type definition.""" intercept: Any | None = None @dataclass class ContinueResponseParameters: - """ContinueResponseParameters.""" + """ContinueResponseParameters type definition.""" request: Any | None = None cookies: list[Any | None] | None = field(default_factory=list) @@ -262,22 +262,22 @@ class ContinueResponseParameters: @dataclass class ContinueWithAuthParameters: - """ContinueWithAuthParameters.""" + """ContinueWithAuthParameters type definition.""" request: Any | None = None @dataclass class ContinueWithAuthCredentials: - """ContinueWithAuthCredentials.""" + """ContinueWithAuthCredentials type definition.""" action: str = field(default="provideCredentials", init=False) credentials: Any | None = None @dataclass -class disownDataParameters: - """disownDataParameters.""" +class DisownDataParameters: + """DisownDataParameters type definition.""" data_type: Any | None = None collector: Any | None = None @@ -286,14 +286,14 @@ class disownDataParameters: @dataclass class FailRequestParameters: - """FailRequestParameters.""" + """FailRequestParameters type definition.""" request: Any | None = None @dataclass class GetDataParameters: - """GetDataParameters.""" + """GetDataParameters type definition.""" data_type: Any | None = None collector: Any | None = None @@ -303,14 +303,14 @@ class GetDataParameters: @dataclass class GetDataResult: - """GetDataResult.""" + """GetDataResult type definition.""" bytes: Any | None = None @dataclass class ProvideResponseParameters: - """ProvideResponseParameters.""" + """ProvideResponseParameters type definition.""" request: Any | None = None body: Any | None = None @@ -322,21 +322,21 @@ class ProvideResponseParameters: @dataclass class RemoveDataCollectorParameters: - """RemoveDataCollectorParameters.""" + """RemoveDataCollectorParameters type definition.""" collector: Any | None = None @dataclass class RemoveInterceptParameters: - """RemoveInterceptParameters.""" + """RemoveInterceptParameters type definition.""" intercept: Any | None = None @dataclass class SetCacheBehaviorParameters: - """SetCacheBehaviorParameters.""" + """SetCacheBehaviorParameters type definition.""" cache_behavior: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -344,16 +344,44 @@ class SetCacheBehaviorParameters: @dataclass class SetExtraHeadersParameters: - """SetExtraHeadersParameters.""" + """SetExtraHeadersParameters type definition.""" headers: list[Any | None] | None = field(default_factory=list) contexts: list[Any | None] | None = field(default_factory=list) user_contexts: list[Any | None] | None = field(default_factory=list) +@dataclass +class AuthRequiredParameters: + """AuthRequiredParameters type definition.""" + + response: Any | None = None + + +@dataclass +class BeforeRequestSentParameters: + """BeforeRequestSentParameters type definition.""" + + initiator: Any | None = None + + +@dataclass +class FetchErrorParameters: + """FetchErrorParameters type definition.""" + + error_text: str | None = None + + +@dataclass +class ResponseCompletedParameters: + """ResponseCompletedParameters type definition.""" + + response: Any | None = None + + @dataclass class ResponseStartedParameters: - """ResponseStartedParameters.""" + """ResponseStartedParameters type definition.""" response: Any | None = None @@ -396,6 +424,10 @@ def continue_request(self, **kwargs): # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "auth_required": "network.authRequired", + "before_request_sent": "network.beforeRequestSent", + "fetch_error": "network.fetchError", + "response_completed": "network.responseCompleted", + "response_started": "network.responseStarted", "before_request": "network.beforeRequestSent", } @@ -560,6 +592,7 @@ def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) self.intercepts = [] + self._handler_intercepts: dict = {} def add_data_collector( self, @@ -767,52 +800,6 @@ def set_extra_headers( result = self._conn.execute(cmd) return result - def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): - """Execute network.beforeRequestSent.""" - params = { - "initiator": initiator, - "method": method, - "params": params, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.beforeRequestSent", params) - result = self._conn.execute(cmd) - return result - - def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): - """Execute network.fetchError.""" - params = { - "errorText": error_text, - "method": method, - "params": params, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.fetchError", params) - result = self._conn.execute(cmd) - return result - - def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): - """Execute network.responseCompleted.""" - params = { - "response": response, - "method": method, - "params": params, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.responseCompleted", params) - result = self._conn.execute(cmd) - return result - - def response_started(self, response: Any | None = None): - """Execute network.responseStarted.""" - params = { - "response": response, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.responseStarted", params) - result = self._conn.execute(cmd) - return result - def _add_intercept(self, phases=None, url_patterns=None): """Add a low-level network intercept. @@ -861,7 +848,8 @@ def add_request_handler(self, event, callback, url_patterns=None): "auth_required": "authRequired", } phase = phase_map.get(event, "beforeRequestSent") - self._add_intercept(phases=[phase], url_patterns=url_patterns) + intercept_result = self._add_intercept(phases=[phase], url_patterns=url_patterns) + intercept_id = intercept_result.get("intercept") if intercept_result else None def _request_callback(params): raw = ( @@ -872,15 +860,21 @@ def _request_callback(params): request = Request(self._conn, raw) callback(request) - return self.add_event_handler(event, _request_callback) + callback_id = self.add_event_handler(event, _request_callback) + if intercept_id: + self._handler_intercepts[callback_id] = intercept_id + return callback_id def remove_request_handler(self, event, callback_id): - """Remove a network request handler. + """Remove a network request handler and its associated network intercept. Args: event: The event name used when adding the handler. callback_id: The int returned by add_request_handler. """ self.remove_event_handler(event, callback_id) + intercept_id = self._handler_intercepts.pop(callback_id, None) + if intercept_id: + self._remove_intercept(intercept_id) def clear_request_handlers(self): """Clear all request handlers and remove all tracked intercepts.""" self.clear_event_handlers() @@ -960,6 +954,18 @@ def clear_event_handlers(self) -> None: # Event: network.authRequired AuthRequired = globals().get('AuthRequiredParameters', dict) # Fallback to dict if type not defined +# Event: network.beforeRequestSent +BeforeRequestSent = globals().get('BeforeRequestSentParameters', dict) # Fallback to dict if type not defined + +# Event: network.fetchError +FetchError = globals().get('FetchErrorParameters', dict) # Fallback to dict if type not defined + +# Event: network.responseCompleted +ResponseCompleted = globals().get('ResponseCompletedParameters', dict) # Fallback to dict if type not defined + +# Event: network.responseStarted +ResponseStarted = globals().get('ResponseStartedParameters', dict) # Fallback to dict if type not defined + # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() @@ -970,5 +976,29 @@ def clear_event_handlers(self) -> None: if _globals.get("AuthRequired") else EventConfig("auth_required", "network.authRequired", dict) ), + "before_request_sent": ( + EventConfig("before_request_sent", "network.beforeRequestSent", + _globals.get("BeforeRequestSent", dict)) + if _globals.get("BeforeRequestSent") + else EventConfig("before_request_sent", "network.beforeRequestSent", dict) + ), + "fetch_error": ( + EventConfig("fetch_error", "network.fetchError", + _globals.get("FetchError", dict)) + if _globals.get("FetchError") + else EventConfig("fetch_error", "network.fetchError", dict) + ), + "response_completed": ( + EventConfig("response_completed", "network.responseCompleted", + _globals.get("ResponseCompleted", dict)) + if _globals.get("ResponseCompleted") + else EventConfig("response_completed", "network.responseCompleted", dict) + ), + "response_started": ( + EventConfig("response_started", "network.responseStarted", + _globals.get("ResponseStarted", dict)) + if _globals.get("ResponseStarted") + else EventConfig("response_started", "network.responseStarted", dict) + ), "before_request": EventConfig("before_request", "network.beforeRequestSent", _globals.get("dict", dict)), } diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index b29721db88503..061bb17b0deec 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -47,7 +47,7 @@ class ResultOwnership: @dataclass class ChannelValue: - """ChannelValue.""" + """ChannelValue type definition.""" type: str = field(default="channel", init=False) value: Any | None = None @@ -55,7 +55,7 @@ class ChannelValue: @dataclass class ChannelProperties: - """ChannelProperties.""" + """ChannelProperties type definition.""" channel: Any | None = None serialization_options: Any | None = None @@ -64,7 +64,7 @@ class ChannelProperties: @dataclass class EvaluateResultSuccess: - """EvaluateResultSuccess.""" + """EvaluateResultSuccess type definition.""" type: str = field(default="success", init=False) result: Any | None = None @@ -73,7 +73,7 @@ class EvaluateResultSuccess: @dataclass class EvaluateResultException: - """EvaluateResultException.""" + """EvaluateResultException type definition.""" type: str = field(default="exception", init=False) exception_details: Any | None = None @@ -82,7 +82,7 @@ class EvaluateResultException: @dataclass class ExceptionDetails: - """ExceptionDetails.""" + """ExceptionDetails type definition.""" column_number: Any | None = None exception: Any | None = None @@ -93,7 +93,7 @@ class ExceptionDetails: @dataclass class ArrayLocalValue: - """ArrayLocalValue.""" + """ArrayLocalValue type definition.""" type: str = field(default="array", init=False) value: Any | None = None @@ -101,7 +101,7 @@ class ArrayLocalValue: @dataclass class DateLocalValue: - """DateLocalValue.""" + """DateLocalValue type definition.""" type: str = field(default="date", init=False) value: str | None = None @@ -109,7 +109,7 @@ class DateLocalValue: @dataclass class MapLocalValue: - """MapLocalValue.""" + """MapLocalValue type definition.""" type: str = field(default="map", init=False) value: Any | None = None @@ -117,7 +117,7 @@ class MapLocalValue: @dataclass class ObjectLocalValue: - """ObjectLocalValue.""" + """ObjectLocalValue type definition.""" type: str = field(default="object", init=False) value: Any | None = None @@ -125,7 +125,7 @@ class ObjectLocalValue: @dataclass class RegExpValue: - """RegExpValue.""" + """RegExpValue type definition.""" pattern: str | None = None flags: str | None = None @@ -133,7 +133,7 @@ class RegExpValue: @dataclass class RegExpLocalValue: - """RegExpLocalValue.""" + """RegExpLocalValue type definition.""" type: str = field(default="regexp", init=False) value: Any | None = None @@ -141,7 +141,7 @@ class RegExpLocalValue: @dataclass class SetLocalValue: - """SetLocalValue.""" + """SetLocalValue type definition.""" type: str = field(default="set", init=False) value: Any | None = None @@ -149,21 +149,21 @@ class SetLocalValue: @dataclass class UndefinedValue: - """UndefinedValue.""" + """UndefinedValue type definition.""" type: str = field(default="undefined", init=False) @dataclass class NullValue: - """NullValue.""" + """NullValue type definition.""" type: str = field(default="null", init=False) @dataclass class StringValue: - """StringValue.""" + """StringValue type definition.""" type: str = field(default="string", init=False) value: str | None = None @@ -171,7 +171,7 @@ class StringValue: @dataclass class NumberValue: - """NumberValue.""" + """NumberValue type definition.""" type: str = field(default="number", init=False) value: Any | None = None @@ -179,7 +179,7 @@ class NumberValue: @dataclass class BooleanValue: - """BooleanValue.""" + """BooleanValue type definition.""" type: str = field(default="boolean", init=False) value: bool | None = None @@ -187,7 +187,7 @@ class BooleanValue: @dataclass class BigIntValue: - """BigIntValue.""" + """BigIntValue type definition.""" type: str = field(default="bigint", init=False) value: str | None = None @@ -195,7 +195,7 @@ class BigIntValue: @dataclass class BaseRealmInfo: - """BaseRealmInfo.""" + """BaseRealmInfo type definition.""" realm: Any | None = None origin: str | None = None @@ -203,7 +203,7 @@ class BaseRealmInfo: @dataclass class WindowRealmInfo: - """WindowRealmInfo.""" + """WindowRealmInfo type definition.""" type: str = field(default="window", init=False) context: Any | None = None @@ -212,7 +212,7 @@ class WindowRealmInfo: @dataclass class DedicatedWorkerRealmInfo: - """DedicatedWorkerRealmInfo.""" + """DedicatedWorkerRealmInfo type definition.""" type: str = field(default="dedicated-worker", init=False) owners: list[Any | None] | None = field(default_factory=list) @@ -220,49 +220,49 @@ class DedicatedWorkerRealmInfo: @dataclass class SharedWorkerRealmInfo: - """SharedWorkerRealmInfo.""" + """SharedWorkerRealmInfo type definition.""" type: str = field(default="shared-worker", init=False) @dataclass class ServiceWorkerRealmInfo: - """ServiceWorkerRealmInfo.""" + """ServiceWorkerRealmInfo type definition.""" type: str = field(default="service-worker", init=False) @dataclass class WorkerRealmInfo: - """WorkerRealmInfo.""" + """WorkerRealmInfo type definition.""" type: str = field(default="worker", init=False) @dataclass class PaintWorkletRealmInfo: - """PaintWorkletRealmInfo.""" + """PaintWorkletRealmInfo type definition.""" type: str = field(default="paint-worklet", init=False) @dataclass class AudioWorkletRealmInfo: - """AudioWorkletRealmInfo.""" + """AudioWorkletRealmInfo type definition.""" type: str = field(default="audio-worklet", init=False) @dataclass class WorkletRealmInfo: - """WorkletRealmInfo.""" + """WorkletRealmInfo type definition.""" type: str = field(default="worklet", init=False) @dataclass class SharedReference: - """SharedReference.""" + """SharedReference type definition.""" shared_id: Any | None = None handle: Any | None = None @@ -270,7 +270,7 @@ class SharedReference: @dataclass class RemoteObjectReference: - """RemoteObjectReference.""" + """RemoteObjectReference type definition.""" handle: Any | None = None shared_id: Any | None = None @@ -278,7 +278,7 @@ class RemoteObjectReference: @dataclass class SymbolRemoteValue: - """SymbolRemoteValue.""" + """SymbolRemoteValue type definition.""" type: str = field(default="symbol", init=False) handle: Any | None = None @@ -287,7 +287,7 @@ class SymbolRemoteValue: @dataclass class ArrayRemoteValue: - """ArrayRemoteValue.""" + """ArrayRemoteValue type definition.""" type: str = field(default="array", init=False) handle: Any | None = None @@ -297,7 +297,7 @@ class ArrayRemoteValue: @dataclass class ObjectRemoteValue: - """ObjectRemoteValue.""" + """ObjectRemoteValue type definition.""" type: str = field(default="object", init=False) handle: Any | None = None @@ -307,7 +307,7 @@ class ObjectRemoteValue: @dataclass class FunctionRemoteValue: - """FunctionRemoteValue.""" + """FunctionRemoteValue type definition.""" type: str = field(default="function", init=False) handle: Any | None = None @@ -316,7 +316,7 @@ class FunctionRemoteValue: @dataclass class RegExpRemoteValue: - """RegExpRemoteValue.""" + """RegExpRemoteValue type definition.""" handle: Any | None = None internal_id: Any | None = None @@ -324,7 +324,7 @@ class RegExpRemoteValue: @dataclass class DateRemoteValue: - """DateRemoteValue.""" + """DateRemoteValue type definition.""" handle: Any | None = None internal_id: Any | None = None @@ -332,7 +332,7 @@ class DateRemoteValue: @dataclass class MapRemoteValue: - """MapRemoteValue.""" + """MapRemoteValue type definition.""" type: str = field(default="map", init=False) handle: Any | None = None @@ -342,7 +342,7 @@ class MapRemoteValue: @dataclass class SetRemoteValue: - """SetRemoteValue.""" + """SetRemoteValue type definition.""" type: str = field(default="set", init=False) handle: Any | None = None @@ -352,7 +352,7 @@ class SetRemoteValue: @dataclass class WeakMapRemoteValue: - """WeakMapRemoteValue.""" + """WeakMapRemoteValue type definition.""" type: str = field(default="weakmap", init=False) handle: Any | None = None @@ -361,7 +361,7 @@ class WeakMapRemoteValue: @dataclass class WeakSetRemoteValue: - """WeakSetRemoteValue.""" + """WeakSetRemoteValue type definition.""" type: str = field(default="weakset", init=False) handle: Any | None = None @@ -370,7 +370,7 @@ class WeakSetRemoteValue: @dataclass class GeneratorRemoteValue: - """GeneratorRemoteValue.""" + """GeneratorRemoteValue type definition.""" type: str = field(default="generator", init=False) handle: Any | None = None @@ -379,7 +379,7 @@ class GeneratorRemoteValue: @dataclass class ErrorRemoteValue: - """ErrorRemoteValue.""" + """ErrorRemoteValue type definition.""" type: str = field(default="error", init=False) handle: Any | None = None @@ -388,7 +388,7 @@ class ErrorRemoteValue: @dataclass class ProxyRemoteValue: - """ProxyRemoteValue.""" + """ProxyRemoteValue type definition.""" type: str = field(default="proxy", init=False) handle: Any | None = None @@ -397,7 +397,7 @@ class ProxyRemoteValue: @dataclass class PromiseRemoteValue: - """PromiseRemoteValue.""" + """PromiseRemoteValue type definition.""" type: str = field(default="promise", init=False) handle: Any | None = None @@ -406,7 +406,7 @@ class PromiseRemoteValue: @dataclass class TypedArrayRemoteValue: - """TypedArrayRemoteValue.""" + """TypedArrayRemoteValue type definition.""" type: str = field(default="typedarray", init=False) handle: Any | None = None @@ -415,7 +415,7 @@ class TypedArrayRemoteValue: @dataclass class ArrayBufferRemoteValue: - """ArrayBufferRemoteValue.""" + """ArrayBufferRemoteValue type definition.""" type: str = field(default="arraybuffer", init=False) handle: Any | None = None @@ -424,7 +424,7 @@ class ArrayBufferRemoteValue: @dataclass class NodeListRemoteValue: - """NodeListRemoteValue.""" + """NodeListRemoteValue type definition.""" type: str = field(default="nodelist", init=False) handle: Any | None = None @@ -434,7 +434,7 @@ class NodeListRemoteValue: @dataclass class HTMLCollectionRemoteValue: - """HTMLCollectionRemoteValue.""" + """HTMLCollectionRemoteValue type definition.""" type: str = field(default="htmlcollection", init=False) handle: Any | None = None @@ -444,7 +444,7 @@ class HTMLCollectionRemoteValue: @dataclass class NodeRemoteValue: - """NodeRemoteValue.""" + """NodeRemoteValue type definition.""" type: str = field(default="node", init=False) shared_id: Any | None = None @@ -455,7 +455,7 @@ class NodeRemoteValue: @dataclass class NodeProperties: - """NodeProperties.""" + """NodeProperties type definition.""" node_type: Any | None = None child_node_count: Any | None = None @@ -469,7 +469,7 @@ class NodeProperties: @dataclass class WindowProxyRemoteValue: - """WindowProxyRemoteValue.""" + """WindowProxyRemoteValue type definition.""" type: str = field(default="window", init=False) value: Any | None = None @@ -479,14 +479,14 @@ class WindowProxyRemoteValue: @dataclass class WindowProxyProperties: - """WindowProxyProperties.""" + """WindowProxyProperties type definition.""" context: Any | None = None @dataclass class StackFrame: - """StackFrame.""" + """StackFrame type definition.""" column_number: Any | None = None function_name: str | None = None @@ -496,14 +496,14 @@ class StackFrame: @dataclass class StackTrace: - """StackTrace.""" + """StackTrace type definition.""" call_frames: list[Any | None] | None = field(default_factory=list) @dataclass class Source: - """Source.""" + """Source type definition.""" realm: Any | None = None context: Any | None = None @@ -511,14 +511,14 @@ class Source: @dataclass class RealmTarget: - """RealmTarget.""" + """RealmTarget type definition.""" realm: Any | None = None @dataclass class ContextTarget: - """ContextTarget.""" + """ContextTarget type definition.""" context: Any | None = None sandbox: str | None = None @@ -526,7 +526,7 @@ class ContextTarget: @dataclass class AddPreloadScriptParameters: - """AddPreloadScriptParameters.""" + """AddPreloadScriptParameters type definition.""" function_declaration: str | None = None arguments: list[Any | None] | None = field(default_factory=list) @@ -537,14 +537,14 @@ class AddPreloadScriptParameters: @dataclass class AddPreloadScriptResult: - """AddPreloadScriptResult.""" + """AddPreloadScriptResult type definition.""" script: Any | None = None @dataclass class DisownParameters: - """DisownParameters.""" + """DisownParameters type definition.""" handles: list[Any | None] | None = field(default_factory=list) target: Any | None = None @@ -552,7 +552,7 @@ class DisownParameters: @dataclass class CallFunctionParameters: - """CallFunctionParameters.""" + """CallFunctionParameters type definition.""" function_declaration: str | None = None await_promise: bool | None = None @@ -566,7 +566,7 @@ class CallFunctionParameters: @dataclass class EvaluateParameters: - """EvaluateParameters.""" + """EvaluateParameters type definition.""" expression: str | None = None target: Any | None = None @@ -578,7 +578,7 @@ class EvaluateParameters: @dataclass class GetRealmsParameters: - """GetRealmsParameters.""" + """GetRealmsParameters type definition.""" context: Any | None = None type: Any | None = None @@ -586,21 +586,21 @@ class GetRealmsParameters: @dataclass class GetRealmsResult: - """GetRealmsResult.""" + """GetRealmsResult type definition.""" realms: list[Any | None] | None = field(default_factory=list) @dataclass class RemovePreloadScriptParameters: - """RemovePreloadScriptParameters.""" + """RemovePreloadScriptParameters type definition.""" script: Any | None = None @dataclass class MessageParameters: - """MessageParameters.""" + """MessageParameters type definition.""" channel: Any | None = None data: Any | None = None @@ -609,13 +609,14 @@ class MessageParameters: @dataclass class RealmDestroyedParameters: - """RealmDestroyedParameters.""" + """RealmDestroyedParameters type definition.""" realm: Any | None = None # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { + "message": "script.message", "realm_created": "script.realmCreated", "realm_destroyed": "script.realmDestroyed", } @@ -885,18 +886,6 @@ def remove_preload_script(self, script: Any | None = None): result = self._conn.execute(cmd) return result - def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): - """Execute script.message.""" - params = { - "channel": channel, - "data": data, - "source": source, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("script.message", params) - result = self._conn.execute(cmd) - return result - def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: """Execute a function declaration in the browser context. @@ -1298,6 +1287,9 @@ def clear_event_handlers(self) -> None: return self._event_manager.clear_event_handlers() # Event Info Type Aliases +# Event: script.message +Message = globals().get('MessageParameters', dict) # Fallback to dict if type not defined + # Event: script.realmCreated RealmCreated = globals().get('RealmInfo', dict) # Fallback to dict if type not defined @@ -1308,6 +1300,12 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Script.EVENT_CONFIGS = { + "message": ( + EventConfig("message", "script.message", + _globals.get("Message", dict)) + if _globals.get("Message") + else EventConfig("message", "script.message", dict) + ), "realm_created": ( EventConfig("realm_created", "script.realmCreated", _globals.get("RealmCreated", dict)) diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index c1b5be09ca024..da12c1cd49792 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -22,7 +22,7 @@ class UserPromptHandlerType: @dataclass class CapabilitiesRequest: - """CapabilitiesRequest.""" + """CapabilitiesRequest type definition.""" always_match: Any | None = None first_match: list[Any | None] | None = field(default_factory=list) @@ -30,7 +30,7 @@ class CapabilitiesRequest: @dataclass class CapabilityRequest: - """CapabilityRequest.""" + """CapabilityRequest type definition.""" accept_insecure_certs: bool | None = None browser_name: str | None = None @@ -42,21 +42,21 @@ class CapabilityRequest: @dataclass class AutodetectProxyConfiguration: - """AutodetectProxyConfiguration.""" + """AutodetectProxyConfiguration type definition.""" proxy_type: str = field(default="autodetect", init=False) @dataclass class DirectProxyConfiguration: - """DirectProxyConfiguration.""" + """DirectProxyConfiguration type definition.""" proxy_type: str = field(default="direct", init=False) @dataclass class ManualProxyConfiguration: - """ManualProxyConfiguration.""" + """ManualProxyConfiguration type definition.""" proxy_type: str = field(default="manual", init=False) http_proxy: str | None = None @@ -66,7 +66,7 @@ class ManualProxyConfiguration: @dataclass class SocksProxyConfiguration: - """SocksProxyConfiguration.""" + """SocksProxyConfiguration type definition.""" socks_proxy: str | None = None socks_version: Any | None = None @@ -74,7 +74,7 @@ class SocksProxyConfiguration: @dataclass class PacProxyConfiguration: - """PacProxyConfiguration.""" + """PacProxyConfiguration type definition.""" proxy_type: str = field(default="pac", init=False) proxy_autoconfig_url: str | None = None @@ -82,14 +82,14 @@ class PacProxyConfiguration: @dataclass class SystemProxyConfiguration: - """SystemProxyConfiguration.""" + """SystemProxyConfiguration type definition.""" proxy_type: str = field(default="system", init=False) @dataclass class SubscribeParameters: - """SubscribeParameters.""" + """SubscribeParameters type definition.""" events: list[str | None] | None = field(default_factory=list) contexts: list[Any | None] | None = field(default_factory=list) @@ -98,21 +98,21 @@ class SubscribeParameters: @dataclass class UnsubscribeByIDRequest: - """UnsubscribeByIDRequest.""" + """UnsubscribeByIDRequest type definition.""" subscriptions: list[Any | None] | None = field(default_factory=list) @dataclass class UnsubscribeByAttributesRequest: - """UnsubscribeByAttributesRequest.""" + """UnsubscribeByAttributesRequest type definition.""" events: list[str | None] | None = field(default_factory=list) @dataclass class StatusResult: - """StatusResult.""" + """StatusResult type definition.""" ready: bool | None = None message: str | None = None @@ -120,14 +120,14 @@ class StatusResult: @dataclass class NewParameters: - """NewParameters.""" + """NewParameters type definition.""" capabilities: Any | None = None @dataclass class NewResult: - """NewResult.""" + """NewResult type definition.""" session_id: str | None = None accept_insecure_certs: bool | None = None @@ -143,7 +143,7 @@ class NewResult: @dataclass class SubscribeResult: - """SubscribeResult.""" + """SubscribeResult type definition.""" subscription: Any | None = None diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 3f29b85d13a23..c5a4666ebaf07 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -14,7 +14,7 @@ @dataclass class PartitionKey: - """PartitionKey.""" + """PartitionKey type definition.""" user_context: str | None = None source_origin: str | None = None @@ -22,7 +22,7 @@ class PartitionKey: @dataclass class GetCookiesParameters: - """GetCookiesParameters.""" + """GetCookiesParameters type definition.""" filter: Any | None = None partition: Any | None = None @@ -30,7 +30,7 @@ class GetCookiesParameters: @dataclass class GetCookiesResult: - """GetCookiesResult.""" + """GetCookiesResult type definition.""" cookies: list[Any | None] | None = field(default_factory=list) partition_key: Any | None = None @@ -38,7 +38,7 @@ class GetCookiesResult: @dataclass class SetCookieParameters: - """SetCookieParameters.""" + """SetCookieParameters type definition.""" cookie: Any | None = None partition: Any | None = None @@ -46,14 +46,14 @@ class SetCookieParameters: @dataclass class SetCookieResult: - """SetCookieResult.""" + """SetCookieResult type definition.""" partition_key: Any | None = None @dataclass class DeleteCookiesParameters: - """DeleteCookiesParameters.""" + """DeleteCookiesParameters type definition.""" filter: Any | None = None partition: Any | None = None @@ -61,7 +61,7 @@ class DeleteCookiesParameters: @dataclass class DeleteCookiesResult: - """DeleteCookiesResult.""" + """DeleteCookiesResult type definition.""" partition_key: Any | None = None diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index ebbe6729499b2..0a3998a611125 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -14,14 +14,14 @@ @dataclass class InstallParameters: - """InstallParameters.""" + """InstallParameters type definition.""" extension_data: Any | None = None @dataclass class ExtensionPath: - """ExtensionPath.""" + """ExtensionPath type definition.""" type: str = field(default="path", init=False) path: str | None = None @@ -29,7 +29,7 @@ class ExtensionPath: @dataclass class ExtensionArchivePath: - """ExtensionArchivePath.""" + """ExtensionArchivePath type definition.""" type: str = field(default="archivePath", init=False) path: str | None = None @@ -37,7 +37,7 @@ class ExtensionArchivePath: @dataclass class ExtensionBase64Encoded: - """ExtensionBase64Encoded.""" + """ExtensionBase64Encoded type definition.""" type: str = field(default="base64", init=False) value: str | None = None @@ -45,14 +45,14 @@ class ExtensionBase64Encoded: @dataclass class InstallResult: - """InstallResult.""" + """InstallResult type definition.""" extension: Any | None = None @dataclass class UninstallParameters: - """UninstallParameters.""" + """UninstallParameters type definition.""" extension: Any | None = None From 5307196a30bd6a36aa0c01129ecbb487d85308fe Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Mon, 2 Mar 2026 13:55:49 +0000 Subject: [PATCH 06/42] [py] Fix Copilot review: license headers, _BiDiEncoder nested types, revert unrelated requirements changes --- py/generate_bidi.py | 19 ++++++++++++++++++- py/requirements_lock.txt | 5 ++++- py/selenium/webdriver/common/bidi/__init__.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/browser.py | 17 +++++++++++++++++ .../webdriver/common/bidi/browsing_context.py | 17 +++++++++++++++++ .../webdriver/common/bidi/emulation.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/input.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/log.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/network.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/script.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/session.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/storage.py | 17 +++++++++++++++++ .../webdriver/common/bidi/webextension.py | 17 +++++++++++++++++ 13 files changed, 209 insertions(+), 2 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 5d7f39e53abfc..412494517772a 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -32,7 +32,24 @@ logger = logging.getLogger("generate_bidi") # File headers -SHARED_HEADER = """# DO NOT EDIT THIS FILE! +SHARED_HEADER = """# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make # changes, edit the generator and regenerate all of the modules.""" diff --git a/py/requirements_lock.txt b/py/requirements_lock.txt index c58f4b1c76fe6..68f8d858bb6f4 100644 --- a/py/requirements_lock.txt +++ b/py/requirements_lock.txt @@ -461,6 +461,7 @@ jeepney==0.9.0 \ --hash=sha256:cf0e9e845622b81e4a28df94c40345400256ec608d0e55bb8a3feaa9163f5732 # via # -r py/requirements.txt + # keyring # secretstorage jinja2==3.1.6 \ --hash=sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d \ @@ -1037,7 +1038,9 @@ rich==14.3.3 \ secretstorage==3.5.0 \ --hash=sha256:0ce65888c0725fcb2c5bc0fdb8e5438eece02c523557ea40ce0703c266248137 \ --hash=sha256:f04b8e4689cbce351744d5537bf6b1329c6fc68f91fa666f60a380edddcd11be - # via -r py/requirements.txt + # via + # -r py/requirements.txt + # keyring sniffio==1.3.1 \ --hash=sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2 \ --hash=sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py index 7be7bd4f73856..bb129d5f6a195 100644 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ b/py/selenium/webdriver/common/bidi/__init__.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 71f917634304d..ff0c2d59b8cf2 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index ede96071778c3..7a0f8faf8687e 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index fbbe0966d8b3a..c58f6d5f78d6c 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index c8e58181b343e..e9c3f8345f05d 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index eaf52a2ec08c2..94f511d7185f8 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index c9737ac9131d0..9dc5fb94d8488 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 061bb17b0deec..0b2ec04101933 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index da12c1cd49792..771a5327151bf 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index c5a4666ebaf07..7623381706040 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 0a3998a611125..99250afca4c68 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make From 22445dccbb14117bd6cb12a2b69bdbe237ad7bd8 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 11:48:05 +0000 Subject: [PATCH 07/42] fixup --- py/generate_bidi.py | 1134 +---------------- .../webdriver/remote/websocket_connection.py | 12 +- 2 files changed, 12 insertions(+), 1134 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 412494517772a..8103cafe40684 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -619,11 +619,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Add imports for event handling if needed if self.events: - code += "import threading\n" - code += "from collections.abc import Callable\n" - if not dataclass_imported: - code += "from dataclasses import dataclass\n" - code += "from selenium.webdriver.common.bidi.session import Session\n" + code += "from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager\n" code += "\n\n" @@ -801,1131 +797,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ code += "\n\n" - # Generate EventConfig and _EventManager for modules with events - if self.events: - # Generate EventConfig dataclass - code += """@dataclass -class EventConfig: - \"\"\"Configuration for a BiDi event.\"\"\" - event_key: str - bidi_event: str - event_class: type - - -""" - - # Generate _EventManager class - code += """class _EventWrapper: - \"\"\"Wrapper to provide event_class attribute for WebSocketConnection callbacks.\"\"\" - def __init__(self, bidi_event: str, event_class: type): - self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class - self._python_class = event_class # Keep reference to Python dataclass for deserialization - - def from_json(self, params: dict) -> Any: - \"\"\"Deserialize event params into the wrapped Python dataclass. - - Args: - params: Raw BiDi event params with camelCase keys. - - Returns: - An instance of the dataclass, or the raw dict on failure. - \"\"\" - if self._python_class is None or self._python_class is dict: - return params - try: - # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, \"from_json\") and callable( - self._python_class.from_json - ): - return self._python_class.from_json(params) - import dataclasses as dc - - snake_params = {self._camel_to_snake(k): v for k, v in params.items()} - if dc.is_dataclass(self._python_class): - valid_fields = {f.name for f in dc.fields(self._python_class)} - filtered = {k: v for k, v in snake_params.items() if k in valid_fields} - return self._python_class(**filtered) - return self._python_class(**snake_params) - except Exception: - return params - - @staticmethod - def _camel_to_snake(name: str) -> str: - result = [name[0].lower()] - for char in name[1:]: - if char.isupper(): - result.extend([\"_\", char.lower()]) - else: - result.append(char) - return \"\".join(result) - - -class _EventManager: - \"\"\"Manages event subscriptions and callbacks.\"\"\" - - def __init__(self, conn, event_configs: dict[str, EventConfig]): - self.conn = conn - self.event_configs = event_configs - self.subscriptions: dict = {} - self._event_wrappers = {} # Cache of _EventWrapper objects - self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} - self._available_events = ", ".join(sorted(event_configs.keys())) - self._subscription_lock = threading.Lock() - - # Create event wrappers for each event - for config in event_configs.values(): - wrapper = _EventWrapper(config.bidi_event, config.event_class) - self._event_wrappers[config.bidi_event] = wrapper - - def validate_event(self, event: str) -> EventConfig: - event_config = self.event_configs.get(event) - if not event_config: - raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") - return event_config - - def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - \"\"\"Subscribe to a BiDi event if not already subscribed.\"\"\" - with self._subscription_lock: - if bidi_event not in self.subscriptions: - session = Session(self.conn) - result = session.subscribe([bidi_event], contexts=contexts) - sub_id = ( - result.get(\"subscription\") if isinstance(result, dict) else None - ) - self.subscriptions[bidi_event] = { - \"callbacks\": [], - \"subscription_id\": sub_id, - } - - def unsubscribe_from_event(self, bidi_event: str) -> None: - \"\"\"Unsubscribe from a BiDi event if no more callbacks exist.\"\"\" - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry is not None and not entry[\"callbacks\"]: - session = Session(self.conn) - sub_id = entry.get(\"subscription_id\") - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - del self.subscriptions[bidi_event] - - def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - self.subscriptions[bidi_event][\"callbacks\"].append(callback_id) - - def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry and callback_id in entry[\"callbacks\"]: - entry[\"callbacks\"].remove(callback_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - event_config = self.validate_event(event) - # Use the event wrapper for add_callback - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - callback_id = self.conn.add_callback(event_wrapper, callback) - self.subscribe_to_event(event_config.bidi_event, contexts) - self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id - - def remove_event_handler(self, event: str, callback_id: int) -> None: - event_config = self.validate_event(event) - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - self.conn.remove_callback(event_wrapper, callback_id) - self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - self.unsubscribe_from_event(event_config.bidi_event) - - def clear_event_handlers(self) -> None: - \"\"\"Clear all event handlers.\"\"\" - with self._subscription_lock: - if not self.subscriptions: - return - session = Session(self.conn) - for bidi_event, entry in list(self.subscriptions.items()): - event_wrapper = self._event_wrappers.get(bidi_event) - callbacks = entry[\"callbacks\"] if isinstance(entry, dict) else entry - if event_wrapper: - for callback_id in callbacks: - self.conn.remove_callback(event_wrapper, callback_id) - sub_id = ( - entry.get(\"subscription_id\") if isinstance(entry, dict) else None - ) - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() - - -""" - code += "\n\n" + # EventConfig, _EventWrapper, and _EventManager are imported from + # ._event_manager (see the import block above); nothing to emit here. # Generate class - # Convert module name (camelCase or snake_case) to proper class name (PascalCase) - class_name = module_name_to_class_name(self.name) - code += f"class {class_name}:\n" - code += f' """WebDriver BiDi {self.name} module."""\n\n' - - # Add EVENT_CONFIGS dict if there are events - if self.events: - code += ( - " EVENT_CONFIGS = {}\n" # Will be populated after types are defined - ) - - if self.name == "script": - code += " def __init__(self, conn, driver=None) -> None:\n" - code += " self._conn = conn\n" - code += " self._driver = driver\n" - else: - code += " def __init__(self, conn) -> None:\n" - code += " self._conn = conn\n" - - # Initialize _event_manager if there are events - if self.events: - code += " self._event_manager = _EventManager(conn, self.EVENT_CONFIGS)\n" - - # Append extra init code from enhancements (e.g. self.intercepts = []) - for init_line in enhancements.get("extra_init_code", []): - code += f" {init_line}\n" - - code += "\n" - - # Generate command methods - # Auto-exclude methods whose names appear in extra_methods to prevent duplicates - extra_method_names = set() - for extra_meth in enhancements.get("extra_methods", []): - m = re.search(r"def\s+(\w+)\s*\(", extra_meth) - if m: - extra_method_names.add(m.group(1)) - exclude_methods = ( - set(enhancements.get("exclude_methods", [])) | extra_method_names - ) - if self.commands: - for command in self.commands: - # Get method-specific enhancements - # Convert command name to snake_case to match enhancement manifest keys - method_name_snake = command._camel_to_snake(command.name) - if method_name_snake in exclude_methods: - continue - method_enhancements = enhancements.get(method_name_snake, {}) - code += command.to_python_method(method_enhancements) - code += "\n" - else: - code += " pass\n" - - # Emit extra methods from enhancement manifest - for extra_method in enhancements.get("extra_methods", []): - code += extra_method - code += "\n" - - # Add delegating event handler methods if events are present - if self.events: - code += """ - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - \"\"\"Add an event handler. - - Args: - event: The event to subscribe to. - callback: The callback function to execute on event. - contexts: The context IDs to subscribe to (optional). - - Returns: - The callback ID. - \"\"\" - return self._event_manager.add_event_handler(event, callback, contexts) - - def remove_event_handler(self, event: str, callback_id: int) -> None: - \"\"\"Remove an event handler. - - Args: - event: The event to unsubscribe from. - callback_id: The callback ID. - \"\"\" - return self._event_manager.remove_event_handler(event, callback_id) - - def clear_event_handlers(self) -> None: - \"\"\"Clear all event handlers.\"\"\" - return self._event_manager.clear_event_handlers() -""" - - # Generate event info type aliases AFTER the class definition - # This ensures all types are available when we create the aliases - if self.events: - code += "\n# Event Info Type Aliases\n" - for event_def in self.events: - code += event_def.to_python_dataclass() - code += "\n" - - # Now populate EVENT_CONFIGS after the aliases are defined - code += "\n# Populate EVENT_CONFIGS with event configuration mappings\n" - # Use globals() to look up types dynamically to handle missing types gracefully - code += "_globals = globals()\n" - code += f"{class_name}.EVENT_CONFIGS = {{\n" - # Collect extra event keys to skip CDDL duplicates - extra_event_keys_cfg = { - evt["event_key"] for evt in enhancements.get("extra_events", []) - } - for event_def in self.events: - # Convert method name to user-friendly event name - method_parts = event_def.method.split(".") - if len(method_parts) == 2: - event_name = self._convert_method_to_event_name(method_parts[1]) - if event_name in extra_event_keys_cfg: - continue - # The event class is the event name (e.g., ContextCreated) - # Try to get it from globals, default to dict if not found - code += ( - f' "{event_name}": (\n' - f' EventConfig("{event_name}", "{event_def.method}",\n' - f' _globals.get("{event_def.name}", dict))\n' - f' if _globals.get("{event_def.name}")\n' - f' else EventConfig("{event_name}", "{event_def.method}", dict)\n' - f" ),\n" - ) - # Extra events not in the CDDL spec - for extra_evt in enhancements.get("extra_events", []): - ek = extra_evt["event_key"] - be = extra_evt["bidi_event"] - ec = extra_evt["event_class"] - single = f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),' - if len(single) > 120: - code += ( - f' "{ek}": EventConfig(\n' - f' "{ek}", "{be}",\n' - f' _globals.get("{ec}", dict),\n' - f" ),\n" - ) - else: - code += single + "\n" - code += "}\n" - - return code - - -class CddlParser: - """Parse CDDL specification files.""" - - def __init__(self, cddl_path: str): - """Initialize parser with CDDL file path.""" - self.cddl_path = Path(cddl_path) - self.content = "" - self.modules: dict[str, CddlModule] = {} - self.definitions: dict[str, str] = {} - self.event_names: set[str] = set() # Names of definitions that are events - self._read_file() - - def _read_file(self) -> None: - """Read and preprocess CDDL file.""" - if not self.cddl_path.exists(): - raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}") - - with open(self.cddl_path, encoding="utf-8") as f: - self.content = f.read() - - logger.info(f"Loaded CDDL file: {self.cddl_path}") - - def parse(self) -> dict[str, CddlModule]: - """Parse CDDL content and return modules.""" - # Remove comments - content = self._remove_comments(self.content) - - # Extract all definitions - self._extract_definitions(content) - - # Extract event names from event union definitions - self._extract_event_names() - - # Extract type definitions by module - self._extract_types() - - # Extract event definitions by module - self._extract_events() - - # Extract command definitions by module - self._extract_commands() - - # If no modules found, create a default one from the filename - if not self.modules: - module_name = self.cddl_path.stem - default_module = CddlModule(name=module_name) - self.modules[module_name] = default_module - logger.warning(f"No modules found in CDDL, creating default: {module_name}") - - return self.modules - - def _remove_comments(self, content: str) -> str: - """Remove comments from CDDL content.""" - # CDDL uses ; for comments to end of line - lines = content.split("\n") - cleaned = [] - for line in lines: - if ";" in line and not line.strip().startswith(";"): - line = line[: line.index(";")] - elif line.strip().startswith(";"): - continue - cleaned.append(line) - return "\n".join(cleaned) - - def _extract_definitions(self, content: str) -> None: - """Extract CDDL definitions (type definitions, commands, etc.).""" - # Match pattern: Name = Definition - # Handles multiline definitions properly. - # The \s* after \n in the lookahead allows definitions that start with - # leading whitespace (e.g. " network.BeforeRequestSent = (") to be - # recognised as separate definitions instead of being swallowed into - # the body of the preceding definition. - pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\s*\w+(?:\.\w+)?\s*=|\Z)" - - for match in re.finditer(pattern, content, re.DOTALL): - name = match.group(1).strip() - definition = match.group(2).strip() - self.definitions[name] = definition - logger.debug(f"Extracted definition: {name}") - - def _extract_event_names(self) -> None: - """Extract event names from event union definitions. - - Event union definitions follow pattern: - module.ModuleEvent = ( - module.EventName1 // - module.EventName2 // - ... - ) - """ - for def_name, def_content in self.definitions.items(): - # Check if this looks like an event union (name ends with "Event") and - # contains a module-qualified reference like "module.EventName". - # Handles both single-item (no //) and multi-item (// separated) unions. - if "Event" in def_name and re.search(r"\w+\.\w+", def_content): - # Extract event names from the union (works for single and multi-item) - event_refs = re.findall(r"(\w+\.\w+)", def_content) - for event_ref in event_refs: - self.event_names.add(event_ref) - logger.debug(f"Identified event: {event_ref} (from {def_name})") - - def _extract_types(self) -> None: - """Extract type definitions from parsed definitions.""" - # Type definitions follow pattern: module.TypeName = { field: type, ... } - # They have dots in the name and curly braces in the content - # But they DON'T have method: "..." pattern (which means it's not a command) - # Enums follow pattern: module.EnumName = "value1" / "value2" / ... - - for def_name, def_content in self.definitions.items(): - # Skip if not a namespaced name (e.g., skip "EmptyParams", "Extensible") - if "." not in def_name: - continue - - # Skip if it's a command (contains method: pattern) - if "method:" in def_content: - continue - - # Extract module.TypeName - if "." in def_name: - module_name, type_name = def_name.rsplit(".", 1) - - # Create module if not exists - if module_name not in self.modules: - self.modules[module_name] = CddlModule(name=module_name) - - # Check if this is an enum (string union with /) - if self._is_enum_definition(def_content): - # Extract enum values - values = self._extract_enum_values(def_content) - if values: - enum_def = CddlEnum( - module=module_name, - name=type_name, - values=values, - description=f"{type_name}", - ) - self.modules[module_name].enums.append(enum_def) - logger.debug( - f"Found enum: {def_name} with {len(values)} values" - ) - else: - # Extract fields from type definition - fields = self._extract_type_fields(def_content) - - if fields: # Only create type if it has fields - type_def = CddlTypeDefinition( - module=module_name, - name=type_name, - fields=fields, - description=f"{type_name}", - ) - self.modules[module_name].types.append(type_def) - logger.debug( - f"Found type: {def_name} with {len(fields)} fields" - ) - - def _is_enum_definition(self, definition: str) -> bool: - """Check if a definition is an enum (string union with /). - - Enums are defined as: "value1" / "value2" / "value3" - """ - # Clean whitespace - clean_def = definition.strip() - - # Must not have curly braces (that would be a type definition) - if "{" in clean_def or "}" in clean_def: - return False - - # Must contain the union operator / surrounded by quotes - # Pattern: "something" / "something_else" - return " / " in clean_def and '"' in clean_def - - def _extract_enum_values(self, enum_definition: str) -> list[str]: - """Extract individual values from an enum definition. - - Enums are defined as: "value1" / "value2" / "value3" - Can span multiple lines. - """ - values = [] - - # Clean the definition and extract quoted strings - # Split by / and extract quoted values - parts = enum_definition.split("/") - - for part in parts: - part = part.strip() - - # Extract quoted string - use search instead of match to find quotes anywhere - match = re.search(r'"([^"]*)"', part) - if match: - value = match.group(1) - values.append(value) - logger.debug(f"Extracted enum value: {value}") - - return values - - @staticmethod - def _normalize_cddl_type(field_type: str) -> str: - """Normalize a CDDL type expression to a simple Python-compatible form. - - Strips CDDL control operators (.ge, .le, .gt, .lt, .default, etc.) and - replaces interval/constraint expressions with their base types so that - the caller can safely check for nested struct syntax. - - Examples: - '(float .ge 0.0) .default 1.0' -> 'float' - '(float .ge 0.0) / null' -> 'float / null' - '(0.0...360.0) / null' -> 'float / null' - '-90.0..90.0' -> 'float' - 'float / null .default null' -> 'float / null' - """ - result = field_type - # Remove trailing .default annotations - result = re.sub(r"\s*\.default\s+\S+", "", result) - # Replace parenthesised constraint expressions: (baseType .operator ...) -> baseType - result = re.sub(r"\((\w+)\s+\.\w+[^)]*\)", r"\1", result) - # Replace parenthesised numeric interval types: (0.0...360.0) -> float - result = re.sub(r"\(-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?\)", "float", result) - # Replace bare numeric interval types: -90.0..90.0 -> float - result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result) - return result.strip() - - def _extract_type_fields(self, type_definition: str) -> dict[str, str]: - """Extract fields from a type definition block.""" - fields = {} - - # Remove outer braces - clean_def = type_definition.strip() - if clean_def.startswith("{"): - clean_def = clean_def[1:] - if clean_def.endswith("}"): - clean_def = clean_def[:-1] - - # Parse each line for field: type patterns - for line in clean_def.split("\n"): - line = line.strip() - if not line or "Extensible" in line or line.startswith("//"): - continue - - # Match pattern: [?] fieldName: type - match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) - if not match: - # Try without optional marker - match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) - - if match: - field_name = match.group(1).strip() - field_type = match.group(2).strip() - normalized_type = self._normalize_cddl_type(field_type) - - # Skip lines that are part of nested definitions - if "{" not in normalized_type and "(" not in normalized_type: - fields[field_name] = normalized_type - logger.debug(f"Extracted field {field_name}: {normalized_type}") - - return fields - - def _extract_events(self) -> None: - """Extract event definitions from parsed definitions. - - Events are definitions that: - 1. Are listed in an event union (e.g., BrowsingContextEvent) - 2. Have method: "..." and params: ... fields - - Event pattern: module.EventName = (method: "module.eventName", params: module.ParamType) - """ - # Find definitions that are in the event_names set - event_pattern = re.compile( - r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" - ) - - for def_name, def_content in self.definitions.items(): - # Skip if not identified as an event - if def_name not in self.event_names: - continue - - # Extract method and params - match = event_pattern.search(def_content) - if match: - method = match.group(1) # e.g., "browsingContext.contextCreated" - params_type = match.group(2) # e.g., "browsingContext.Info" - - # Extract module name from method - if "." in method: - module_name, _ = method.split(".", 1) - - # Create module if not exists - if module_name not in self.modules: - self.modules[module_name] = CddlModule(name=module_name) - - # Extract event name from definition name (e.g., browsingContext.ContextCreated) - _, event_name = def_name.rsplit(".", 1) - - # Create event - event = CddlEvent( - module=module_name, - name=event_name, - method=method, - params_type=params_type, - description=f"Event: {method}", - ) - - self.modules[module_name].events.append(event) - logger.debug( - f"Found event: {def_name} (method={method}, params={params_type})" - ) - - def _extract_commands(self) -> None: - """Extract command definitions from parsed definitions.""" - # Find command definitions that follow pattern: module.Command = (method: "...", params: ...) - command_pattern = re.compile( - r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" - ) - - for def_name, def_content in self.definitions.items(): - # Skip definitions that are events (they share the same pattern) - if def_name in self.event_names: - continue - matches = list(command_pattern.finditer(def_content)) - if matches: - for match in matches: - method = match.group(1) # e.g., "session.new" - params_type = match.group(2) # e.g., "session.NewParameters" - - # Extract module name from method - if "." in method: - module_name, command_name = method.split(".", 1) - - # Create module if not exists - if module_name not in self.modules: - self.modules[module_name] = CddlModule(name=module_name) - - # Extract parameters - params = self._extract_parameters(params_type) - - # Create command - cmd = CddlCommand( - module=module_name, - name=command_name, - params=params, - description=f"Execute {method}", - ) - - self.modules[module_name].commands.append(cmd) - logger.debug( - f"Found command: {method} with params {params_type}" - ) - - def _extract_parameters( - self, params_type: str, _seen: set[str] | None = None - ) -> dict[str, str]: - """Extract parameters from a parameter type definition. - - Handles both struct types ({...}) and top-level union types (TypeA / TypeB), - merging all fields from each alternative as optional parameters. - """ - params = {} - - if _seen is None: - _seen = set() - if params_type in _seen: - return params - _seen.add(params_type) - - if params_type not in self.definitions: - logger.debug(f"Parameter type not found: {params_type}") - return params - - definition = self.definitions[params_type] - - # Handle top-level type alias that is a union of other named types: - # e.g. session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest - # These definitions contain a single line with "/" separating type names - # (not the double-slash "//" used for command unions). - stripped = definition.strip() - if not stripped.startswith("{") and "/" in stripped and "//" not in stripped: - # Each token separated by "/" should be a named type reference - alternatives = [a.strip() for a in stripped.split("/") if a.strip()] - all_named = all(re.match(r"^[\w.]+$", a) for a in alternatives) - if all_named: - for alt_type in alternatives: - alt_params = self._extract_parameters(alt_type, _seen) - params.update(alt_params) - return params - - # Remove the outer curly braces and split by comma - # Then parse each line for key: type patterns - clean_def = stripped - if clean_def.startswith("{"): - clean_def = clean_def[1:] - if clean_def.endswith("}"): - clean_def = clean_def[:-1] - - # Split by newlines and process each line - for line in clean_def.split("\n"): - line = line.strip() - if not line or "Extensible" in line: - continue - - # Match pattern: [?] name: type - # Using a simple pattern that handles optional prefix - match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) - if not match: - # Try without optional marker - match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) - - if match: - param_name = match.group(1).strip() - param_type = match.group(2).strip() - normalized_type = self._normalize_cddl_type(param_type) - - # Skip lines that are part of nested definitions - if "{" not in normalized_type and "(" not in normalized_type: - params[param_name] = normalized_type - logger.debug( - f"Extracted param {param_name}: {normalized_type} from {params_type}" - ) - - return params - - -def module_name_to_class_name(module_name: str) -> str: - """Convert module name to class name (PascalCase). - - Handles both camelCase (browsingContext) and snake_case (browsing_context). - """ - if "_" in module_name: - # Snake_case: browsing_context -> BrowsingContext - return "".join(word.capitalize() for word in module_name.split("_")) - else: - # CamelCase: browsingContext -> BrowsingContext - return module_name[0].upper() + module_name[1:] if module_name else "" - - -def module_name_to_filename(module_name: str) -> str: - """Convert module name to Python filename (snake_case). - - Handles both camelCase (browsingContext) and snake_case (browsing_context). - Special cases: - - browsingContext -> browsing_context - - webExtension -> webextension - """ - # Handle explicit mappings for known camelCase names - camel_to_snake_map = { - "browsingContext": "browsing_context", - "webExtension": "webextension", - } - - if module_name in camel_to_snake_map: - return camel_to_snake_map[module_name] - - if "_" in module_name: - # Already snake_case - return module_name - else: - # Convert camelCase to snake_case for other cases - # This handles cases like "myModuleName" -> "my_module_name" - import re - - s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", module_name) - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - - -def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> None: - """Generate __init__.py file for the module.""" - init_path = output_path / "__init__.py" - - code = f"""{SHARED_HEADER} - -from __future__ import annotations - -""" - - for module_name in sorted(modules.keys()): - class_name = module_name_to_class_name(module_name) - filename = module_name_to_filename(module_name) - code += f"from .{filename} import {class_name}\n" - - code += "\n__all__ = [\n" - for module_name in sorted(modules.keys()): - class_name = module_name_to_class_name(module_name) - code += f' "{class_name}",\n' - code += "]\n" - - with open(init_path, "w", encoding="utf-8") as f: - f.write(code) - - logger.info(f"Generated: {init_path}") - - -def generate_common_file(output_path: Path) -> None: - """Generate common.py file with shared utilities.""" - common_path = output_path / "common.py" - - code = ( - "# Licensed to the Software Freedom Conservancy (SFC) under one\n" - "# or more contributor license agreements. See the NOTICE file\n" - "# distributed with this work for additional information\n" - "# regarding copyright ownership. The SFC licenses this file\n" - "# to you under the Apache License, Version 2.0 (the\n" - '# "License"); you may not use this file except in compliance\n' - "# with the License. You may obtain a copy of the License at\n" - "#\n" - "# http://www.apache.org/licenses/LICENSE-2.0\n" - "#\n" - "# Unless required by applicable law or agreed to in writing,\n" - "# software distributed under the License is distributed on an\n" - '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' - "# KIND, either express or implied. See the License for the\n" - "# specific language governing permissions and limitations\n" - "# under the License.\n" - "\n" - '"""Common utilities for BiDi command construction."""\n' - "\n" - "from __future__ import annotations\n" - "\n" - "from collections.abc import Generator\n" - "from typing import Any\n" - "\n" - "\n" - "def command_builder(\n" - " method: str, params: dict[str, Any] | None = None\n" - ") -> Generator[dict[str, Any], Any, Any]:\n" - ' """Build a BiDi command generator.\n' - "\n" - " Args:\n" - ' method: The BiDi method name (e.g., "session.status", "browser.close")\n' - " params: The parameters for the command\n" - "\n" - " Yields:\n" - " A dictionary representing the BiDi command\n" - "\n" - " Returns:\n" - " The result from the BiDi command execution\n" - ' """\n' - " if params is None:\n" - " params = {}\n" - ' result = yield {"method": method, "params": params}\n' - " return result\n" - ) - - with open(common_path, "w", encoding="utf-8") as f: - f.write(code) - - logger.info(f"Generated: {common_path}") - - -def generate_console_file(output_path: Path) -> None: - """Generate console.py file with Console enum helper.""" - console_path = output_path / "console.py" - - code = ( - "# Licensed to the Software Freedom Conservancy (SFC) under one\n" - "# or more contributor license agreements. See the NOTICE file\n" - "# distributed with this work for additional information\n" - "# regarding copyright ownership. The SFC licenses this file\n" - "# to you under the Apache License, Version 2.0 (the\n" - '# "License"); you may not use this file except in compliance\n' - "# with the License. You may obtain a copy of the License at\n" - "#\n" - "# http://www.apache.org/licenses/LICENSE-2.0\n" - "#\n" - "# Unless required by applicable law or agreed to in writing,\n" - "# software distributed under the License is distributed on an\n" - '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' - "# KIND, either express or implied. See the License for the\n" - "# specific language governing permissions and limitations\n" - "# under the License.\n" - "\n" - "from enum import Enum\n" - "\n" - "\n" - "class Console(Enum):\n" - ' ALL = "all"\n' - ' LOG = "log"\n' - ' ERROR = "error"\n' - ) - - with open(console_path, "w", encoding="utf-8") as f: - f.write(code) - - logger.info(f"Generated: {console_path}") - - -def generate_permissions_file(output_path: Path) -> None: - """Generate permissions.py file with permission-related classes.""" - permissions_path = output_path / "permissions.py" - - code = ( - "# Licensed to the Software Freedom Conservancy (SFC) under one\n" - "# or more contributor license agreements. See the NOTICE file\n" - "# distributed with this work for additional information\n" - "# regarding copyright ownership. The SFC licenses this file\n" - "# to you under the Apache License, Version 2.0 (the\n" - '# "License"); you may not use this file except in compliance\n' - "# with the License. You may obtain a copy of the License at\n" - "#\n" - "# http://www.apache.org/licenses/LICENSE-2.0\n" - "#\n" - "# Unless required by applicable law or agreed to in writing,\n" - "# software distributed under the License is distributed on an\n" - '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' - "# KIND, either express or implied. See the License for the\n" - "# specific language governing permissions and limitations\n" - "# under the License.\n" - "\n" - '"""WebDriver BiDi Permissions module."""\n' - "\n" - "from __future__ import annotations\n" - "\n" - "from __future__ import annotations\n" - "\n" - "from enum import Enum\n" - "from typing import Any\n" - "\n" - "from .common import command_builder\n" - "\n" - '_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"}\n' - "\n" - "\n" - "class PermissionState(str, Enum):\n" - ' """Permission state enumeration."""\n' - "\n" - ' GRANTED = "granted"\n' - ' DENIED = "denied"\n' - ' PROMPT = "prompt"\n' - "\n" - "\n" - "class PermissionDescriptor:\n" - ' """Descriptor for a permission."""\n' - "\n" - " def __init__(self, name: str) -> None:\n" - ' """Initialize a PermissionDescriptor.\n' - "\n" - " Args:\n" - " name: The name of the permission (e.g., 'geolocation', 'microphone', 'camera')\n" - ' """\n' - " self.name = name\n" - "\n" - " def __repr__(self) -> str:\n" - " return f\"PermissionDescriptor('{self.name}')\"\n" - "\n" - "\n" - "class Permissions:\n" - ' """WebDriver BiDi Permissions module."""\n' - "\n" - " def __init__(self, websocket_connection: Any) -> None:\n" - ' """Initialize the Permissions module.\n' - "\n" - " Args:\n" - " websocket_connection: The WebSocket connection for sending BiDi commands\n" - ' """\n' - " self._conn = websocket_connection\n" - "\n" - " def set_permission(\n" - " self,\n" - " descriptor: PermissionDescriptor | str,\n" - " state: PermissionState | str,\n" - " origin: str | None = None,\n" - " user_context: str | None = None,\n" - " ) -> None:\n" - ' """Set a permission for a given origin.\n' - "\n" - " Args:\n" - " descriptor: The permission descriptor or permission name as a string\n" - " state: The desired permission state\n" - " origin: The origin for which to set the permission\n" - " user_context: Optional user context ID to scope the permission\n" - "\n" - " Raises:\n" - " ValueError: If the state is not a valid permission state\n" - ' """\n' - " state_value = state.value if isinstance(state, PermissionState) else state\n" - " if state_value not in _VALID_PERMISSION_STATES:\n" - " raise ValueError(\n" - ' f"Invalid permission state: {state_value!r}. "\n' - ' f"Must be one of {sorted(_VALID_PERMISSION_STATES)}"\n' - " )\n" - "\n" - " if isinstance(descriptor, str):\n" - ' descriptor_dict = {"name": descriptor}\n' - " else:\n" - ' descriptor_dict = {"name": descriptor.name}\n' - "\n" - " params: dict[str, Any] = {\n" - ' "descriptor": descriptor_dict,\n' - ' "state": state_value,\n' - " }\n" - " if origin is not None:\n" - ' params["origin"] = origin\n' - " if user_context is not None:\n" - ' params["userContext"] = user_context\n' - "\n" - ' cmd = command_builder("permissions.setPermission", params)\n' - " self._conn.execute(cmd)\n" - ) - - with open(permissions_path, "w", encoding="utf-8") as f: - f.write(code) - - logger.info(f"Generated: {permissions_path}") - - -def main( - cddl_file: str, - output_dir: str, - spec_version: str = "1.0", - enhancements_manifest: str | None = None, -) -> None: - """Main entry point. - - Args: - cddl_file: Path to CDDL specification file - output_dir: Output directory for generated modules - spec_version: BiDi spec version - enhancements_manifest: Path to enhancement manifest Python file - """ - output_path = Path(output_dir).resolve() - output_path.mkdir(parents=True, exist_ok=True) - - logger.info(f"WebDriver BiDi Code Generator v{__version__}") - logger.info(f"Input CDDL: {cddl_file}") - logger.info(f"Output directory: {output_path}") - logger.info(f"Spec version: {spec_version}") - - # Load enhancement manifest - manifest = load_enhancements_manifest(enhancements_manifest) - if manifest: - logger.info(f"Loaded enhancement manifest from: {enhancements_manifest}") - - # Parse CDDL - parser = CddlParser(cddl_file) - modules = parser.parse() - - logger.info(f"Parsed {len(modules)} modules") - - # Clean up existing generated files - for file_path in output_path.glob("*.py"): - if file_path.name != "py.typed" and not file_path.name.startswith("_"): - file_path.unlink() - logger.debug(f"Removed: {file_path}") - - # Generate module files using snake_case filenames - for module_name, module in sorted(modules.items()): - filename = module_name_to_filename(module_name) - module_path = output_path / f"{filename}.py" - - # Get module-specific enhancements (merge with dataclass templates) - module_enhancements = manifest.get("enhancements", {}).get(module_name, {}) - - # Add dataclass methods and docstrings to the enhancement data for this module - full_module_enhancements = { - **module_enhancements, - "dataclass_methods": manifest.get("dataclass_methods", {}), - "method_docstrings": manifest.get("method_docstrings", {}), - } - - with open(module_path, "w", encoding="utf-8") as f: - f.write(module.generate_code(full_module_enhancements)) - logger.info(f"Generated: {module_path}") - - # Generate __init__.py - generate_init_file(output_path, modules) - - # Generate common.py - generate_common_file(output_path) - - # Generate permissions.py - generate_permissions_file(output_path) - - # Generate console.py - generate_console_file(output_path) - - # Create py.typed marker - py_typed_path = output_path / "py.typed" - py_typed_path.touch() - logger.info(f"Generated type marker: {py_typed_path}") - - logger.info("Code generation complete!") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Generate Python WebDriver BiDi modules from CDDL specification" - ) - parser.add_argument( - "cddl_file", - help="Path to CDDL specification file", - ) - parser.add_argument( - "output_dir", - help="Output directory for generated Python modules", - ) - parser.add_argument( - "--version", - default="1.0", - help="BiDi spec version (default: 1.0)", - ) - parser.add_argument( - "--enhancements-manifest", - default=None, - help="Path to enhancement manifest Python file (optional)", - ) - parser.add_argument( - "-v", - "--verbose", - action="store_true", - help="Enable verbose logging", - ) - - args = parser.parse_args() - - if args.verbose: - logging.getLogger("generate_bidi").setLevel(logging.DEBUG) - - try: - main( - args.cddl_file, - args.output_dir, - args.version, - args.enhancements_manifest, - ) - sys.exit(0) - except Exception as e: - logger.error(f"Generation failed: {e}", exc_info=True) - sys.exit(1) diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py index cd34c35db3696..44cb2adef7a0b 100644 --- a/py/selenium/webdriver/remote/websocket_connection.py +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -158,7 +158,9 @@ def _serialize_command(self, command): def _deserialize_result(self, result, command): try: _ = command.send(result) - raise WebDriverException("The command's generator function did not exit when expected!") + raise WebDriverException( + "The command's generator function did not exit when expected!" + ) except StopIteration as exit: return exit.value @@ -175,11 +177,15 @@ def on_error(ws, error): def run_socket(): if self.url.startswith("wss://"): - self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True) + self._ws.run_forever( + sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True + ) else: self._ws.run_forever(suppress_origin=True) - self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error) + self._ws = WebSocketApp( + self.url, on_open=on_open, on_message=on_message, on_error=on_error + ) self._ws_thread = Thread(target=run_socket, daemon=True) self._ws_thread.start() From 7e051f82705cb5c78fdf9e59b1228f9fd51c9432 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 12:12:02 +0000 Subject: [PATCH 08/42] remove --version call --- py/private/generate_bidi.bzl | 1 - 1 file changed, 1 deletion(-) diff --git a/py/private/generate_bidi.bzl b/py/private/generate_bidi.bzl index c11b6efe4735f..e072279f85e94 100644 --- a/py/private/generate_bidi.bzl +++ b/py/private/generate_bidi.bzl @@ -53,7 +53,6 @@ def _generate_bidi_impl(ctx): args = [ cddl_file.path, output_base, - "--version", spec_version, ] From bd13e213325a7dd0d8e8197b641e54468bf54a87 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 12:37:46 +0000 Subject: [PATCH 09/42] correct web extensions --- py/generate_bidi.py | 1274 +++++++++++++++-- py/private/bidi_enhancements_manifest.py | 10 +- py/selenium/webdriver/common/bidi/__init__.py | 17 - py/selenium/webdriver/common/bidi/browser.py | 83 +- .../webdriver/common/bidi/browsing_context.py | 256 ++-- py/selenium/webdriver/common/bidi/common.py | 11 +- .../webdriver/common/bidi/emulation.py | 216 +-- py/selenium/webdriver/common/bidi/input.py | 83 +- py/selenium/webdriver/common/bidi/log.py | 39 +- py/selenium/webdriver/common/bidi/network.py | 312 ++-- .../webdriver/common/bidi/permissions.py | 10 +- py/selenium/webdriver/common/bidi/script.py | 253 ++-- py/selenium/webdriver/common/bidi/session.py | 77 +- py/selenium/webdriver/common/bidi/storage.py | 75 +- .../webdriver/common/bidi/webextension.py | 57 +- 15 files changed, 1764 insertions(+), 1009 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 8103cafe40684..d14e2575c8bfd 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -18,11 +18,12 @@ import logging import re import sys +from collections import defaultdict from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from textwrap import indent as tw_indent -from typing import Any +from textwrap import dedent, indent as tw_indent +from typing import Any, Dict, List, Optional, Set, Tuple __version__ = "1.0.0" @@ -32,24 +33,7 @@ logger = logging.getLogger("generate_bidi") # File headers -SHARED_HEADER = """# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# DO NOT EDIT THIS FILE! +SHARED_HEADER = """# DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make # changes, edit the generator and regenerate all of the modules.""" @@ -59,7 +43,8 @@ # WebDriver BiDi module: {{}} from __future__ import annotations -from typing import Any +from typing import Any, Dict, List, Optional, Union +from .common import command_builder """ @@ -68,7 +53,7 @@ def indent(s: str, n: int) -> str: return tw_indent(s, n * " ") -def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]: +def load_enhancements_manifest(manifest_path: Optional[str]) -> Dict[str, Any]: """Load enhancement manifest from a Python file. Args: @@ -139,10 +124,10 @@ def get_annotation(cls, cddl_type: str) -> str: if cddl_type.startswith("["): # Array inner = cddl_type.strip("[]+ ") inner_type = cls.get_annotation(inner) - return f"list[{inner_type}]" + return f"List[{inner_type}]" if cddl_type.startswith("{"): # Map/Dict - return "dict[str, Any]" + return "Dict[str, Any]" # Default to Any for unknown types return "Any" @@ -154,11 +139,11 @@ class CddlCommand: module: str name: str - params: dict[str, str] = field(default_factory=dict) - result: str | None = None + params: Dict[str, str] = field(default_factory=dict) + result: Optional[str] = None description: str = "" - def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: + def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str: """Generate Python method code for this command. Args: @@ -189,15 +174,8 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: else: param_list = "self" - # Build method body - wrap long signatures over multiple lines if needed - sig_line = f" def {method_name}({param_list}):" - if len(sig_line) > 120 and param_strs: - body = f" def {method_name}(\n self,\n" - for p in param_strs: - body += f" {p},\n" - body += " ):\n" - else: - body = sig_line + "\n" + # Build method body + body = f" def {method_name}({param_list}):\n" body += f' """{self.description or "Execute " + self.module + "." + self.name}."""\n' # Add validation if specified @@ -259,6 +237,7 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: if result_param == "download_behavior": body += ' "downloadBehavior": download_behavior,\n' # Add remaining parameters that weren't part of the transform + override_params = enhancements.get("params_override", {}) for cddl_param_name in self.params: if cddl_param_name not in ["downloadBehavior"]: snake_name = self._camel_to_snake(cddl_param_name) @@ -285,45 +264,45 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: # Extract property from list items body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += " return [\n" + body += f" return [\n" body += f' item.get("{extract_property}")\n' - body += " for item in items\n" - body += " if isinstance(item, dict)\n" - body += " ]\n" - body += " return []\n" + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" elif extract_field in deserialize_rules: # Extract field and deserialize to typed objects type_name = deserialize_rules[extract_field] body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += " return [\n" + body += f" return [\n" body += f" {type_name}(\n" body += self._generate_field_args(extract_field, type_name) - body += " )\n" - body += " for item in items\n" - body += " if isinstance(item, dict)\n" - body += " ]\n" - body += " return []\n" + body += f" )\n" + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" else: # Simple field extraction (return the value directly, not wrapped in result dict) body += f' if result and "{extract_field}" in result:\n' body += f' extracted = result.get("{extract_field}")\n' - body += " return extracted\n" - body += " return result\n" + body += f" return extracted\n" + body += f" return result\n" elif "deserialize" in enhancements: # Deserialize response to typed objects (legacy, without extract_field) deserialize_rules = enhancements["deserialize"] for response_field, type_name in deserialize_rules.items(): body += f' if result and "{response_field}" in result:\n' body += f' items = result.get("{response_field}", [])\n' - body += " return [\n" + body += f" return [\n" body += f" {type_name}(\n" body += self._generate_field_args(response_field, type_name) - body += " )\n" - body += " for item in items\n" - body += " if isinstance(item, dict)\n" - body += " ]\n" - body += " return []\n" + body += f" )\n" + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" else: # No special response handling, just return the result body += " return result\n" @@ -372,10 +351,10 @@ class CddlTypeDefinition: module: str name: str - fields: dict[str, str] = field(default_factory=dict) + fields: Dict[str, str] = field(default_factory=dict) description: str = "" - def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str: + def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> str: """Generate Python dataclass code for this type. Args: @@ -385,14 +364,11 @@ def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str dataclass_methods = enhancements.get("dataclass_methods", {}) method_docstrings = enhancements.get("method_docstrings", {}) - # Generate class name from type name. - # CDDL type names that start with a lowercase letter (e.g. camelCase - # command-parameter types like "setNetworkConditionsParameters") are - # capitalised so that the resulting Python class follows PascalCase. - class_name = self.name[0].upper() + self.name[1:] if self.name else self.name - code = "@dataclass\n" + # Generate class name from type name (keep it as-is, don't split on underscores) + class_name = self.name + code = f"@dataclass\n" code += f"class {class_name}:\n" - code += f' """{class_name} type definition."""\n\n' + code += f' """{self.description or self.name}."""\n\n' if not self.fields: code += " pass\n" @@ -410,7 +386,7 @@ def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str literal_value = literal_match.group(1) code += f' {snake_name}: str = field(default="{literal_value}", init=False)\n' # Check if this field is a list type - elif "list[" in python_type: + elif "List[" in python_type: code += f" {snake_name}: {python_type} = field(default_factory=list)\n" else: code += f" {snake_name}: {python_type} = None\n" @@ -477,7 +453,7 @@ class CddlEnum: module: str name: str - values: list[str] = field(default_factory=list) + values: List[str] = field(default_factory=list) description: str = "" def to_python_class(self) -> str: @@ -486,9 +462,9 @@ def to_python_class(self) -> str: Generates a simple class with string constants to match the existing pattern in the codebase (e.g., ClientWindowState). """ - class_name = self.name[0].upper() + self.name[1:] if self.name else self.name + class_name = self.name code = f"class {class_name}:\n" - code += f' """{class_name}."""\n\n' + code += f' """{self.description or self.name}."""\n\n' for value in self.values: # Convert value to UPPER_SNAKE_CASE constant name @@ -554,10 +530,10 @@ class CddlModule: """Represents a CDDL module (e.g., script, network, browsing_context).""" name: str - commands: list[CddlCommand] = field(default_factory=list) - types: list[CddlTypeDefinition] = field(default_factory=list) - enums: list[CddlEnum] = field(default_factory=list) - events: list[CddlEvent] = field(default_factory=list) + commands: List[CddlCommand] = field(default_factory=list) + types: List[CddlTypeDefinition] = field(default_factory=list) + enums: List[CddlEnum] = field(default_factory=list) + events: List[CddlEvent] = field(default_factory=list) @staticmethod def _convert_method_to_event_name(method_suffix: str) -> str: @@ -572,33 +548,7 @@ def _convert_method_to_event_name(method_suffix: str) -> str: s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", method_suffix) return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - def _needs_field_import(self, enhancements: dict[str, Any] | None = None) -> bool: - """Check if any type definition in this module requires the 'field' import. - - Respects the same type exclusions applied during code generation. - """ - enhancements = enhancements or {} - extra_cls_names: set[str] = set() - for extra_cls in enhancements.get("extra_dataclasses", []): - m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) - if m: - extra_cls_names.add(m.group(1)) - exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names - - for type_def in self.types: - if type_def.name in exclude_types: - continue - for field_type in type_def.fields.values(): - # Literal string discriminants use field(default=..., init=False) - if re.match(r'^"', field_type.strip()): - return True - # List-typed fields use field(default_factory=list) - python_type = CddlTypeDefinition._get_python_type(field_type) - if python_type.startswith("list["): - return True - return False - - def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: + def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: """Generate Python code for this module. Args: @@ -608,18 +558,18 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code = MODULE_HEADER.format(self.name) # Add imports if needed - if self.commands: - code += "from .common import command_builder\n" - dataclass_imported = False + if self.types: + code += "from dataclasses import field\n" if self.commands or self.types: + code += "from typing import Generator\n" code += "from dataclasses import dataclass\n" - dataclass_imported = True - if self.types and self._needs_field_import(enhancements): - code += "from dataclasses import field\n" # Add imports for event handling if needed if self.events: - code += "from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager\n" + code += "import threading\n" + code += "from collections.abc import Callable\n" + code += "from dataclasses import dataclass\n" + code += "from selenium.webdriver.common.bidi.session import Session\n" code += "\n\n" @@ -700,19 +650,8 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ - # Collect names of extra_dataclasses so we can skip CDDL-generated - # enums and types that are overridden by manual definitions. - extra_cls_names = set() - for extra_cls in enhancements.get("extra_dataclasses", []): - m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) - if m: - extra_cls_names.add(m.group(1)) - exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names - - # Generate enums first, skipping any that are overridden via extra_dataclasses + # Generate enums first for enum_def in self.enums: - if enum_def.name in exclude_types: - continue code += enum_def.to_python_class() code += "\n\n" @@ -721,6 +660,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code += f"{alias} = {target}\n\n" # Generate type dataclasses, skipping any overridden by extra_dataclasses + exclude_types = set(enhancements.get("exclude_types", [])) for type_def in self.types: if type_def.name in exclude_types: continue @@ -740,18 +680,13 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Generate EVENT_NAME_MAPPING for the module code += "# BiDi Event Name to Parameter Type Mapping\n" code += "EVENT_NAME_MAPPING = {\n" - # Collect event keys from extra_events so we skip CDDL duplicates - extra_event_keys = { - evt["event_key"] for evt in enhancements.get("extra_events", []) - } for event_def in self.events: # Convert method name to user-friendly event name # e.g., "browsingContext.contextCreated" -> "context_created" method_parts = event_def.method.split(".") if len(method_parts) == 2: event_name = self._convert_method_to_event_name(method_parts[1]) - if event_name not in extra_event_keys: - code += f' "{event_name}": "{event_def.method}",\n' + code += f' "{event_name}": "{event_def.method}",\n' # Extra events not in the CDDL spec (e.g. Chromium-specific events) for extra_evt in enhancements.get("extra_events", []): code += ( @@ -797,7 +732,1094 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ code += "\n\n" - # EventConfig, _EventWrapper, and _EventManager are imported from - # ._event_manager (see the import block above); nothing to emit here. + # Generate EventConfig and _EventManager for modules with events + if self.events: + # Generate EventConfig dataclass + code += """@dataclass +class EventConfig: + \"\"\"Configuration for a BiDi event.\"\"\" + event_key: str + bidi_event: str + event_class: type + + +""" + + # Generate _EventManager class + code += """class _EventWrapper: + \"\"\"Wrapper to provide event_class attribute for WebSocketConnection callbacks.\"\"\" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + \"\"\"Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + \"\"\" + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, \"from_json\") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend([\"_\", char.lower()]) + else: + result.append(char) + return \"\".join(result) + + +class _EventManager: + \"\"\"Manages event subscriptions and callbacks.\"\"\" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + \"\"\"Subscribe to a BiDi event if not already subscribed.\"\"\" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get(\"subscription\") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + \"callbacks\": [], + \"subscription_id\": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + \"\"\"Unsubscribe from a BiDi event if no more callbacks exist.\"\"\" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry[\"callbacks\"]: + session = Session(self.conn) + sub_id = entry.get(\"subscription_id\") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event][\"callbacks\"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry[\"callbacks\"]: + entry[\"callbacks\"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + \"\"\"Clear all event handlers.\"\"\" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry[\"callbacks\"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get(\"subscription_id\") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + +""" + code += "\n\n" # Generate class + # Convert module name (camelCase or snake_case) to proper class name (PascalCase) + class_name = module_name_to_class_name(self.name) + code += f"class {class_name}:\n" + code += f' """WebDriver BiDi {self.name} module."""\n\n' + + # Add EVENT_CONFIGS dict if there are events + if self.events: + code += ( + " EVENT_CONFIGS = {}\n" # Will be populated after types are defined + ) + + if self.name == "script": + code += " def __init__(self, conn, driver=None) -> None:\n" + code += " self._conn = conn\n" + code += " self._driver = driver\n" + else: + code += " def __init__(self, conn) -> None:\n" + code += " self._conn = conn\n" + + # Initialize _event_manager if there are events + if self.events: + code += " self._event_manager = _EventManager(conn, self.EVENT_CONFIGS)\n" + + # Append extra init code from enhancements (e.g. self.intercepts = []) + for init_line in enhancements.get("extra_init_code", []): + code += f" {init_line}\n" + + code += "\n" + + # Generate command methods + exclude_methods = enhancements.get("exclude_methods", []) + if self.commands: + for command in self.commands: + # Get method-specific enhancements + # Convert command name to snake_case to match enhancement manifest keys + method_name_snake = command._camel_to_snake(command.name) + if method_name_snake in exclude_methods: + continue + method_enhancements = enhancements.get(method_name_snake, {}) + code += command.to_python_method(method_enhancements) + code += "\n" + else: + code += " pass\n" + + # Emit extra methods from enhancement manifest + for extra_method in enhancements.get("extra_methods", []): + code += extra_method + code += "\n" + + # Add delegating event handler methods if events are present + if self.events: + code += """ + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + \"\"\"Add an event handler. + + Args: + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). + + Returns: + The callback ID. + \"\"\" + return self._event_manager.add_event_handler(event, callback, contexts) + + def remove_event_handler(self, event: str, callback_id: int) -> None: + \"\"\"Remove an event handler. + + Args: + event: The event to unsubscribe from. + callback_id: The callback ID. + \"\"\" + return self._event_manager.remove_event_handler(event, callback_id) + + def clear_event_handlers(self) -> None: + \"\"\"Clear all event handlers.\"\"\" + return self._event_manager.clear_event_handlers() +""" + + # Generate event info type aliases AFTER the class definition + # This ensures all types are available when we create the aliases + if self.events: + code += "\n# Event Info Type Aliases\n" + for event_def in self.events: + code += event_def.to_python_dataclass() + code += "\n" + + # Now populate EVENT_CONFIGS after the aliases are defined + code += f"\n# Populate EVENT_CONFIGS with event configuration mappings\n" + # Use globals() to look up types dynamically to handle missing types gracefully + code += f"_globals = globals()\n" + code += f"{class_name}.EVENT_CONFIGS = {{\n" + for event_def in self.events: + # Convert method name to user-friendly event name + method_parts = event_def.method.split(".") + if len(method_parts) == 2: + event_name = self._convert_method_to_event_name(method_parts[1]) + # The event class is the event name (e.g., ContextCreated) + # Try to get it from globals, default to dict if not found + code += f' "{event_name}": (EventConfig("{event_name}", "{event_def.method}", _globals.get("{event_def.name}", dict)) if _globals.get("{event_def.name}") else EventConfig("{event_name}", "{event_def.method}", dict)),\n' + # Extra events not in the CDDL spec + for extra_evt in enhancements.get("extra_events", []): + ek = extra_evt["event_key"] + be = extra_evt["bidi_event"] + ec = extra_evt["event_class"] + code += f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),\n' + code += "}\n" + + return code + + +class CddlParser: + """Parse CDDL specification files.""" + + def __init__(self, cddl_path: str): + """Initialize parser with CDDL file path.""" + self.cddl_path = Path(cddl_path) + self.content = "" + self.modules: Dict[str, CddlModule] = {} + self.definitions: Dict[str, str] = {} + self.event_names: Set[str] = set() # Names of definitions that are events + self._read_file() + + def _read_file(self) -> None: + """Read and preprocess CDDL file.""" + if not self.cddl_path.exists(): + raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}") + + with open(self.cddl_path, "r", encoding="utf-8") as f: + self.content = f.read() + + logger.info(f"Loaded CDDL file: {self.cddl_path}") + + def parse(self) -> Dict[str, CddlModule]: + """Parse CDDL content and return modules.""" + # Remove comments + content = self._remove_comments(self.content) + + # Extract all definitions + self._extract_definitions(content) + + # Extract event names from event union definitions + self._extract_event_names() + + # Extract type definitions by module + self._extract_types() + + # Extract event definitions by module + self._extract_events() + + # Extract command definitions by module + self._extract_commands() + + # If no modules found, create a default one from the filename + if not self.modules: + module_name = self.cddl_path.stem + default_module = CddlModule(name=module_name) + self.modules[module_name] = default_module + logger.warning(f"No modules found in CDDL, creating default: {module_name}") + + return self.modules + + def _remove_comments(self, content: str) -> str: + """Remove comments from CDDL content.""" + # CDDL uses ; for comments to end of line + lines = content.split("\n") + cleaned = [] + for line in lines: + if ";" in line and not line.strip().startswith(";"): + line = line[: line.index(";")] + elif line.strip().startswith(";"): + continue + cleaned.append(line) + return "\n".join(cleaned) + + def _extract_definitions(self, content: str) -> None: + """Extract CDDL definitions (type definitions, commands, etc.).""" + # Match pattern: Name = Definition + # Handles multiline definitions properly + pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\w+(?:\.\w+)?\s*=|\Z)" + + for match in re.finditer(pattern, content, re.DOTALL): + name = match.group(1).strip() + definition = match.group(2).strip() + self.definitions[name] = definition + logger.debug(f"Extracted definition: {name}") + + def _extract_event_names(self) -> None: + """Extract event names from event union definitions. + + Event union definitions follow pattern: + module.ModuleEvent = ( + module.EventName1 // + module.EventName2 // + ... + ) + """ + # Look for definitions like "BrowsingContextEvent", "SessionEvent", etc. + event_union_pattern = re.compile(r"(\w+\.)?(\w+)Event") + + for def_name, def_content in self.definitions.items(): + # Check if this looks like an event union (name ends with "Event") and + # contains a module-qualified reference like "module.EventName". + # Handles both single-item (no //) and multi-item (// separated) unions. + if "Event" in def_name and re.search(r"\w+\.\w+", def_content): + # Extract event names from the union (works for single and multi-item) + event_refs = re.findall(r"(\w+\.\w+)", def_content) + for event_ref in event_refs: + self.event_names.add(event_ref) + logger.debug(f"Identified event: {event_ref} (from {def_name})") + + def _extract_types(self) -> None: + """Extract type definitions from parsed definitions.""" + # Type definitions follow pattern: module.TypeName = { field: type, ... } + # They have dots in the name and curly braces in the content + # But they DON'T have method: "..." pattern (which means it's not a command) + # Enums follow pattern: module.EnumName = "value1" / "value2" / ... + + for def_name, def_content in self.definitions.items(): + # Skip if not a namespaced name (e.g., skip "EmptyParams", "Extensible") + if "." not in def_name: + continue + + # Skip if it's a command (contains method: pattern) + if "method:" in def_content: + continue + + # Extract module.TypeName + if "." in def_name: + module_name, type_name = def_name.rsplit(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Check if this is an enum (string union with /) + if self._is_enum_definition(def_content): + # Extract enum values + values = self._extract_enum_values(def_content) + if values: + enum_def = CddlEnum( + module=module_name, + name=type_name, + values=values, + description=f"{type_name}", + ) + self.modules[module_name].enums.append(enum_def) + logger.debug( + f"Found enum: {def_name} with {len(values)} values" + ) + else: + # Extract fields from type definition + fields = self._extract_type_fields(def_content) + + if fields: # Only create type if it has fields + type_def = CddlTypeDefinition( + module=module_name, + name=type_name, + fields=fields, + description=f"{type_name}", + ) + self.modules[module_name].types.append(type_def) + logger.debug( + f"Found type: {def_name} with {len(fields)} fields" + ) + + def _is_enum_definition(self, definition: str) -> bool: + """Check if a definition is an enum (string union with /). + + Enums are defined as: "value1" / "value2" / "value3" + """ + # Clean whitespace + clean_def = definition.strip() + + # Must not have curly braces (that would be a type definition) + if "{" in clean_def or "}" in clean_def: + return False + + # Must contain the union operator / surrounded by quotes + # Pattern: "something" / "something_else" + return " / " in clean_def and '"' in clean_def + + def _extract_enum_values(self, enum_definition: str) -> List[str]: + """Extract individual values from an enum definition. + + Enums are defined as: "value1" / "value2" / "value3" + Can span multiple lines. + """ + values = [] + + # Clean the definition and extract quoted strings + # Split by / and extract quoted values + parts = enum_definition.split("/") + + for part in parts: + part = part.strip() + + # Extract quoted string - use search instead of match to find quotes anywhere + match = re.search(r'"([^"]*)"', part) + if match: + value = match.group(1) + values.append(value) + logger.debug(f"Extracted enum value: {value}") + + return values + + @staticmethod + def _normalize_cddl_type(field_type: str) -> str: + """Normalize a CDDL type expression to a simple Python-compatible form. + + Strips CDDL control operators (.ge, .le, .gt, .lt, .default, etc.) and + replaces interval/constraint expressions with their base types so that + the caller can safely check for nested struct syntax. + + Examples: + '(float .ge 0.0) .default 1.0' -> 'float' + '(float .ge 0.0) / null' -> 'float / null' + '(0.0...360.0) / null' -> 'float / null' + '-90.0..90.0' -> 'float' + 'float / null .default null' -> 'float / null' + """ + result = field_type + # Remove trailing .default annotations + result = re.sub(r"\s*\.default\s+\S+", "", result) + # Replace parenthesised constraint expressions: (baseType .operator ...) -> baseType + result = re.sub(r"\((\w+)\s+\.\w+[^)]*\)", r"\1", result) + # Replace parenthesised numeric interval types: (0.0...360.0) -> float + result = re.sub(r"\(-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?\)", "float", result) + # Replace bare numeric interval types: -90.0..90.0 -> float + result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result) + return result.strip() + + def _extract_type_fields(self, type_definition: str) -> Dict[str, str]: + """Extract fields from a type definition block.""" + fields = {} + + # Remove outer braces + clean_def = type_definition.strip() + if clean_def.startswith("{"): + clean_def = clean_def[1:] + if clean_def.endswith("}"): + clean_def = clean_def[:-1] + + # Parse each line for field: type patterns + for line in clean_def.split("\n"): + line = line.strip() + if not line or "Extensible" in line or line.startswith("//"): + continue + + # Match pattern: [?] fieldName: type + match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + if not match: + # Try without optional marker + match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + + if match: + field_name = match.group(1).strip() + field_type = match.group(2).strip() + normalized_type = self._normalize_cddl_type(field_type) + + # Skip lines that are part of nested definitions + if "{" not in normalized_type and "(" not in normalized_type: + fields[field_name] = normalized_type + logger.debug(f"Extracted field {field_name}: {normalized_type}") + + return fields + + def _extract_events(self) -> None: + """Extract event definitions from parsed definitions. + + Events are definitions that: + 1. Are listed in an event union (e.g., BrowsingContextEvent) + 2. Have method: "..." and params: ... fields + + Event pattern: module.EventName = (method: "module.eventName", params: module.ParamType) + """ + # Find definitions that are in the event_names set + event_pattern = re.compile( + r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" + ) + + for def_name, def_content in self.definitions.items(): + # Skip if not identified as an event + if def_name not in self.event_names: + continue + + # Extract method and params + match = event_pattern.search(def_content) + if match: + method = match.group(1) # e.g., "browsingContext.contextCreated" + params_type = match.group(2) # e.g., "browsingContext.Info" + + # Extract module name from method + if "." in method: + module_name, _ = method.split(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Extract event name from definition name (e.g., browsingContext.ContextCreated) + _, event_name = def_name.rsplit(".", 1) + + # Create event + event = CddlEvent( + module=module_name, + name=event_name, + method=method, + params_type=params_type, + description=f"Event: {method}", + ) + + self.modules[module_name].events.append(event) + logger.debug( + f"Found event: {def_name} (method={method}, params={params_type})" + ) + + def _extract_commands(self) -> None: + """Extract command definitions from parsed definitions.""" + # Find command definitions that follow pattern: module.Command = (method: "...", params: ...) + command_pattern = re.compile( + r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" + ) + + for def_name, def_content in self.definitions.items(): + # Skip definitions that are events (they share the same pattern) + if def_name in self.event_names: + continue + matches = list(command_pattern.finditer(def_content)) + if matches: + for match in matches: + method = match.group(1) # e.g., "session.new" + params_type = match.group(2) # e.g., "session.NewParameters" + + # Extract module name from method + if "." in method: + module_name, command_name = method.split(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Extract parameters + params = self._extract_parameters(params_type) + + # Create command + cmd = CddlCommand( + module=module_name, + name=command_name, + params=params, + description=f"Execute {method}", + ) + + self.modules[module_name].commands.append(cmd) + logger.debug( + f"Found command: {method} with params {params_type}" + ) + + def _extract_parameters( + self, params_type: str, _seen: Optional[Set[str]] = None + ) -> Dict[str, str]: + """Extract parameters from a parameter type definition. + + Handles both struct types ({...}) and top-level union types (TypeA / TypeB), + merging all fields from each alternative as optional parameters. + """ + params = {} + + if _seen is None: + _seen = set() + if params_type in _seen: + return params + _seen.add(params_type) + + if params_type not in self.definitions: + logger.debug(f"Parameter type not found: {params_type}") + return params + + definition = self.definitions[params_type] + + # Handle top-level type alias that is a union of other named types: + # e.g. session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest + # These definitions contain a single line with "/" separating type names + # (not the double-slash "//" used for command unions). + stripped = definition.strip() + if not stripped.startswith("{") and "/" in stripped and "//" not in stripped: + # Each token separated by "/" should be a named type reference + alternatives = [a.strip() for a in stripped.split("/") if a.strip()] + all_named = all(re.match(r"^[\w.]+$", a) for a in alternatives) + if all_named: + for alt_type in alternatives: + alt_params = self._extract_parameters(alt_type, _seen) + params.update(alt_params) + return params + + # Remove the outer curly braces and split by comma + # Then parse each line for key: type patterns + clean_def = stripped + if clean_def.startswith("{"): + clean_def = clean_def[1:] + if clean_def.endswith("}"): + clean_def = clean_def[:-1] + + # Split by newlines and process each line + for line in clean_def.split("\n"): + line = line.strip() + if not line or "Extensible" in line: + continue + + # Match pattern: [?] name: type + # Using a simple pattern that handles optional prefix + match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + if not match: + # Try without optional marker + match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + + if match: + param_name = match.group(1).strip() + param_type = match.group(2).strip() + normalized_type = self._normalize_cddl_type(param_type) + + # Skip lines that are part of nested definitions + if "{" not in normalized_type and "(" not in normalized_type: + params[param_name] = normalized_type + logger.debug( + f"Extracted param {param_name}: {normalized_type} from {params_type}" + ) + + return params + + +def module_name_to_class_name(module_name: str) -> str: + """Convert module name to class name (PascalCase). + + Handles both camelCase (browsingContext) and snake_case (browsing_context). + """ + if "_" in module_name: + # Snake_case: browsing_context -> BrowsingContext + return "".join(word.capitalize() for word in module_name.split("_")) + else: + # CamelCase: browsingContext -> BrowsingContext + return module_name[0].upper() + module_name[1:] if module_name else "" + + +def module_name_to_filename(module_name: str) -> str: + """Convert module name to Python filename (snake_case). + + Handles both camelCase (browsingContext) and snake_case (browsing_context). + Special cases: + - browsingContext -> browsing_context + - webExtension -> webextension + """ + # Handle explicit mappings for known camelCase names + camel_to_snake_map = { + "browsingContext": "browsing_context", + "webExtension": "webextension", + } + + if module_name in camel_to_snake_map: + return camel_to_snake_map[module_name] + + if "_" in module_name: + # Already snake_case + return module_name + else: + # Convert camelCase to snake_case for other cases + # This handles cases like "myModuleName" -> "my_module_name" + import re + + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", module_name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + +def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> None: + """Generate __init__.py file for the module.""" + init_path = output_path / "__init__.py" + + code = f"""{SHARED_HEADER} + +from __future__ import annotations + +""" + + for module_name in sorted(modules.keys()): + class_name = module_name_to_class_name(module_name) + filename = module_name_to_filename(module_name) + code += f"from .{filename} import {class_name}\n" + + code += f"\n__all__ = [\n" + for module_name in sorted(modules.keys()): + class_name = module_name_to_class_name(module_name) + code += f' "{class_name}",\n' + code += "]\n" + + with open(init_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {init_path}") + + +def generate_common_file(output_path: Path) -> None: + """Generate common.py file with shared utilities.""" + common_path = output_path / "common.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + '"""Common utilities for BiDi command construction."""\n' + "\n" + "from typing import Any, Dict, Generator\n" + "\n" + "\n" + "def command_builder(\n" + " method: str, params: Dict[str, Any]\n" + ") -> Generator[Dict[str, Any], Any, Any]:\n" + ' """Build a BiDi command generator.\n' + "\n" + " Args:\n" + ' method: The BiDi method name (e.g., "session.status", "browser.close")\n' + " params: The parameters for the command\n" + "\n" + " Yields:\n" + " A dictionary representing the BiDi command\n" + "\n" + " Returns:\n" + " The result from the BiDi command execution\n" + ' """\n' + ' result = yield {"method": method, "params": params}\n' + " return result\n" + ) + + with open(common_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {common_path}") + + +def generate_console_file(output_path: Path) -> None: + """Generate console.py file with Console enum helper.""" + console_path = output_path / "console.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + "from enum import Enum\n" + "\n" + "\n" + "class Console(Enum):\n" + ' ALL = "all"\n' + ' LOG = "log"\n' + ' ERROR = "error"\n' + ) + + with open(console_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {console_path}") + + +def generate_permissions_file(output_path: Path) -> None: + """Generate permissions.py file with permission-related classes.""" + permissions_path = output_path / "permissions.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + '"""WebDriver BiDi Permissions module."""\n' + "\n" + "from __future__ import annotations\n" + "\n" + "from enum import Enum\n" + "from typing import Any, Optional, Union\n" + "\n" + "from .common import command_builder\n" + "\n" + '_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"}\n' + "\n" + "\n" + "class PermissionState(str, Enum):\n" + ' """Permission state enumeration."""\n' + "\n" + ' GRANTED = "granted"\n' + ' DENIED = "denied"\n' + ' PROMPT = "prompt"\n' + "\n" + "\n" + "class PermissionDescriptor:\n" + ' """Descriptor for a permission."""\n' + "\n" + " def __init__(self, name: str) -> None:\n" + ' """Initialize a PermissionDescriptor.\n' + "\n" + " Args:\n" + " name: The name of the permission (e.g., 'geolocation', 'microphone', 'camera')\n" + ' """\n' + " self.name = name\n" + "\n" + " def __repr__(self) -> str:\n" + " return f\"PermissionDescriptor('{self.name}')\"\n" + "\n" + "\n" + "class Permissions:\n" + ' """WebDriver BiDi Permissions module."""\n' + "\n" + " def __init__(self, websocket_connection: Any) -> None:\n" + ' """Initialize the Permissions module.\n' + "\n" + " Args:\n" + " websocket_connection: The WebSocket connection for sending BiDi commands\n" + ' """\n' + " self._conn = websocket_connection\n" + "\n" + " def set_permission(\n" + " self,\n" + " descriptor: Union[PermissionDescriptor, str],\n" + " state: Union[PermissionState, str],\n" + " origin: Optional[str] = None,\n" + " user_context: Optional[str] = None,\n" + " ) -> None:\n" + ' """Set a permission for a given origin.\n' + "\n" + " Args:\n" + " descriptor: The permission descriptor or permission name as a string\n" + " state: The desired permission state\n" + " origin: The origin for which to set the permission\n" + " user_context: Optional user context ID to scope the permission\n" + "\n" + " Raises:\n" + " ValueError: If the state is not a valid permission state\n" + ' """\n' + " state_value = state.value if isinstance(state, PermissionState) else state\n" + " if state_value not in _VALID_PERMISSION_STATES:\n" + " raise ValueError(\n" + ' f"Invalid permission state: {state_value!r}. "\n' + ' f"Must be one of {sorted(_VALID_PERMISSION_STATES)}"\n' + " )\n" + "\n" + " if isinstance(descriptor, str):\n" + ' descriptor_dict = {"name": descriptor}\n' + " else:\n" + ' descriptor_dict = {"name": descriptor.name}\n' + "\n" + " params: dict[str, Any] = {\n" + ' "descriptor": descriptor_dict,\n' + ' "state": state_value,\n' + " }\n" + " if origin is not None:\n" + ' params["origin"] = origin\n' + " if user_context is not None:\n" + ' params["userContext"] = user_context\n' + "\n" + ' cmd = command_builder("permissions.setPermission", params)\n' + " self._conn.execute(cmd)\n" + ) + + with open(permissions_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {permissions_path}") + + +def main( + cddl_file: str, + output_dir: str, + spec_version: str = "1.0", + enhancements_manifest: Optional[str] = None, +) -> None: + """Main entry point. + + Args: + cddl_file: Path to CDDL specification file + output_dir: Output directory for generated modules + spec_version: BiDi spec version + enhancements_manifest: Path to enhancement manifest Python file + """ + output_path = Path(output_dir).resolve() + output_path.mkdir(parents=True, exist_ok=True) + + logger.info(f"WebDriver BiDi Code Generator v{__version__}") + logger.info(f"Input CDDL: {cddl_file}") + logger.info(f"Output directory: {output_path}") + logger.info(f"Spec version: {spec_version}") + + # Load enhancement manifest + manifest = load_enhancements_manifest(enhancements_manifest) + if manifest: + logger.info(f"Loaded enhancement manifest from: {enhancements_manifest}") + + # Parse CDDL + parser = CddlParser(cddl_file) + modules = parser.parse() + + logger.info(f"Parsed {len(modules)} modules") + + # Clean up existing generated files + for file_path in output_path.glob("*.py"): + if file_path.name != "py.typed" and not file_path.name.startswith("_"): + file_path.unlink() + logger.debug(f"Removed: {file_path}") + + # Generate module files using snake_case filenames + for module_name, module in sorted(modules.items()): + filename = module_name_to_filename(module_name) + module_path = output_path / f"{filename}.py" + + # Get module-specific enhancements (merge with dataclass templates) + module_enhancements = manifest.get("enhancements", {}).get(module_name, {}) + + # Add dataclass methods and docstrings to the enhancement data for this module + full_module_enhancements = { + **module_enhancements, + "dataclass_methods": manifest.get("dataclass_methods", {}), + "method_docstrings": manifest.get("method_docstrings", {}), + } + + with open(module_path, "w", encoding="utf-8") as f: + f.write(module.generate_code(full_module_enhancements)) + logger.info(f"Generated: {module_path}") + + # Generate __init__.py + generate_init_file(output_path, modules) + + # Generate common.py + generate_common_file(output_path) + + # Generate permissions.py + generate_permissions_file(output_path) + + # Generate console.py + generate_console_file(output_path) + + # Create py.typed marker + py_typed_path = output_path / "py.typed" + py_typed_path.touch() + logger.info(f"Generated type marker: {py_typed_path}") + + logger.info("Code generation complete!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate Python WebDriver BiDi modules from CDDL specification" + ) + parser.add_argument( + "cddl_file", + help="Path to CDDL specification file", + ) + parser.add_argument( + "output_dir", + help="Output directory for generated Python modules", + ) + parser.add_argument( + "spec_version", + nargs="?", + default="1.0", + help="BiDi spec version (default: 1.0)", + ) + parser.add_argument( + "--enhancements-manifest", + default=None, + help="Path to enhancement manifest Python file (optional)", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger("generate_bidi").setLevel(logging.DEBUG) + + try: + main( + args.cddl_file, + args.output_dir, + args.spec_version, + args.enhancements_manifest, + ) + sys.exit(0) + except Exception as e: + logger.error(f"Generation failed: {e}", exc_info=True) + sys.exit(1) diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index adf0a17128af3..5dcce3c25ffeb 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -1351,18 +1351,24 @@ def to_bidi_dict(self) -> dict: params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) return self._conn.execute(cmd)''', - ''' def uninstall(self, extension: Any | None = None): + ''' def uninstall(self, extension: str | dict): """Uninstall a web extension. Args: extension: Either the extension ID string returned by ``install``, or the full result dict returned by ``install`` (the ``"extension"`` value is extracted automatically). + + Raises: + ValueError: If extension is not provided or is None. """ if isinstance(extension, dict): extension = extension.get("extension") + + if extension is None: + raise ValueError("extension parameter is required") + params = {"extension": extension} - params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd)''', ], diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py index bb129d5f6a195..7be7bd4f73856 100644 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ b/py/selenium/webdriver/common/bidi/__init__.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index ff0c2d59b8cf2..7cf9678c9b007 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,10 +6,11 @@ # WebDriver BiDi module: browser from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - +from typing import Any, Dict, List, Optional, Union from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass def transform_download_params( @@ -77,9 +61,17 @@ def validate_download_behavior( raise ValueError("destination_folder should not be provided when allowed=False") +class ClientWindowNamedState: + """ClientWindowNamedState.""" + + FULLSCREEN = "fullscreen" + MAXIMIZED = "maximized" + MINIMIZED = "minimized" + + @dataclass class ClientWindowInfo: - """ClientWindowInfo type definition.""" + """ClientWindowInfo.""" active: bool | None = None client_window: Any | None = None @@ -121,14 +113,14 @@ def get_y(self): @dataclass class UserContextInfo: - """UserContextInfo type definition.""" + """UserContextInfo.""" user_context: Any | None = None @dataclass class CreateUserContextParameters: - """CreateUserContextParameters type definition.""" + """CreateUserContextParameters.""" accept_insecure_certs: bool | None = None proxy: Any | None = None @@ -137,35 +129,35 @@ class CreateUserContextParameters: @dataclass class GetClientWindowsResult: - """GetClientWindowsResult type definition.""" + """GetClientWindowsResult.""" - client_windows: list[Any | None] | None = field(default_factory=list) + client_windows: list[Any | None] | None = None @dataclass class GetUserContextsResult: - """GetUserContextsResult type definition.""" + """GetUserContextsResult.""" - user_contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = None @dataclass class RemoveUserContextParameters: - """RemoveUserContextParameters type definition.""" + """RemoveUserContextParameters.""" user_context: Any | None = None @dataclass class SetClientWindowStateParameters: - """SetClientWindowStateParameters type definition.""" + """SetClientWindowStateParameters.""" client_window: Any | None = None @dataclass class ClientWindowRectState: - """ClientWindowRectState type definition.""" + """ClientWindowRectState.""" state: str = field(default="normal", init=False) width: Any | None = None @@ -176,15 +168,15 @@ class ClientWindowRectState: @dataclass class SetDownloadBehaviorParameters: - """SetDownloadBehaviorParameters type definition.""" + """SetDownloadBehaviorParameters.""" download_behavior: Any | None = None - user_contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = None @dataclass class DownloadBehaviorAllowed: - """DownloadBehaviorAllowed type definition.""" + """DownloadBehaviorAllowed.""" type: str = field(default="allowed", init=False) destination_folder: str | None = None @@ -192,7 +184,7 @@ class DownloadBehaviorAllowed: @dataclass class DownloadBehaviorDenied: - """DownloadBehaviorDenied type definition.""" + """DownloadBehaviorDenied.""" type: str = field(default="denied", init=False) @@ -220,12 +212,7 @@ def close(self): result = self._conn.execute(cmd) return result - def create_user_context( - self, - accept_insecure_certs: bool | None = None, - proxy: Any | None = None, - unhandled_prompt_behavior: Any | None = None, - ): + def create_user_context(self, accept_insecure_certs: bool | None = None, proxy: Any | None = None, unhandled_prompt_behavior: Any | None = None): """Execute browser.createUserContext.""" if proxy and hasattr(proxy, 'to_bidi_dict'): proxy = proxy.to_bidi_dict() @@ -306,6 +293,22 @@ def set_client_window_state(self, client_window: Any | None = None): result = self._conn.execute(cmd) return result + def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): + """Execute browser.setDownloadBehavior.""" + validate_download_behavior(allowed=allowed, destination_folder=destination_folder, user_contexts=user_contexts) + + download_behavior = None + download_behavior = transform_download_params(allowed, destination_folder) + + params = { + "downloadBehavior": download_behavior, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.setDownloadBehavior", params) + result = self._conn.execute(cmd) + return result + def set_download_behavior( self, allowed: bool | None = None, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 7a0f8faf8687e..35aea615d1780 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,15 +6,16 @@ # WebDriver BiDi module: browsingContext from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - +from dataclasses import dataclass from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class ReadinessState: """ReadinessState.""" @@ -65,7 +49,7 @@ class DownloadCompleteParams: @dataclass class Info: - """Info type definition.""" + """Info.""" children: Any | None = None client_window: Any | None = None @@ -78,7 +62,7 @@ class Info: @dataclass class AccessibilityLocator: - """AccessibilityLocator type definition.""" + """AccessibilityLocator.""" type: str = field(default="accessibility", init=False) name: str | None = None @@ -87,7 +71,7 @@ class AccessibilityLocator: @dataclass class CssLocator: - """CssLocator type definition.""" + """CssLocator.""" type: str = field(default="css", init=False) value: str | None = None @@ -95,7 +79,7 @@ class CssLocator: @dataclass class ContextLocator: - """ContextLocator type definition.""" + """ContextLocator.""" type: str = field(default="context", init=False) context: Any | None = None @@ -103,7 +87,7 @@ class ContextLocator: @dataclass class InnerTextLocator: - """InnerTextLocator type definition.""" + """InnerTextLocator.""" type: str = field(default="innerText", init=False) value: str | None = None @@ -114,7 +98,7 @@ class InnerTextLocator: @dataclass class XPathLocator: - """XPathLocator type definition.""" + """XPathLocator.""" type: str = field(default="xpath", init=False) value: str | None = None @@ -122,7 +106,7 @@ class XPathLocator: @dataclass class BaseNavigationInfo: - """BaseNavigationInfo type definition.""" + """BaseNavigationInfo.""" context: Any | None = None navigation: Any | None = None @@ -132,14 +116,14 @@ class BaseNavigationInfo: @dataclass class ActivateParameters: - """ActivateParameters type definition.""" + """ActivateParameters.""" context: Any | None = None @dataclass class CaptureScreenshotParameters: - """CaptureScreenshotParameters type definition.""" + """CaptureScreenshotParameters.""" context: Any | None = None format: Any | None = None @@ -148,7 +132,7 @@ class CaptureScreenshotParameters: @dataclass class ImageFormat: - """ImageFormat type definition.""" + """ImageFormat.""" type: str | None = None quality: Any | None = None @@ -156,7 +140,7 @@ class ImageFormat: @dataclass class ElementClipRectangle: - """ElementClipRectangle type definition.""" + """ElementClipRectangle.""" type: str = field(default="element", init=False) element: Any | None = None @@ -164,7 +148,7 @@ class ElementClipRectangle: @dataclass class BoxClipRectangle: - """BoxClipRectangle type definition.""" + """BoxClipRectangle.""" type: str = field(default="box", init=False) x: Any | None = None @@ -175,14 +159,14 @@ class BoxClipRectangle: @dataclass class CaptureScreenshotResult: - """CaptureScreenshotResult type definition.""" + """CaptureScreenshotResult.""" data: str | None = None @dataclass class CloseParameters: - """CloseParameters type definition.""" + """CloseParameters.""" context: Any | None = None prompt_unload: bool | None = None @@ -190,7 +174,7 @@ class CloseParameters: @dataclass class CreateParameters: - """CreateParameters type definition.""" + """CreateParameters.""" type: Any | None = None reference_context: Any | None = None @@ -200,14 +184,14 @@ class CreateParameters: @dataclass class CreateResult: - """CreateResult type definition.""" + """CreateResult.""" context: Any | None = None @dataclass class GetTreeParameters: - """GetTreeParameters type definition.""" + """GetTreeParameters.""" max_depth: Any | None = None root: Any | None = None @@ -215,14 +199,14 @@ class GetTreeParameters: @dataclass class GetTreeResult: - """GetTreeResult type definition.""" + """GetTreeResult.""" contexts: Any | None = None @dataclass class HandleUserPromptParameters: - """HandleUserPromptParameters type definition.""" + """HandleUserPromptParameters.""" context: Any | None = None accept: bool | None = None @@ -231,24 +215,24 @@ class HandleUserPromptParameters: @dataclass class LocateNodesParameters: - """LocateNodesParameters type definition.""" + """LocateNodesParameters.""" context: Any | None = None locator: Any | None = None serialization_options: Any | None = None - start_nodes: list[Any | None] | None = field(default_factory=list) + start_nodes: list[Any | None] | None = None @dataclass class LocateNodesResult: - """LocateNodesResult type definition.""" + """LocateNodesResult.""" - nodes: list[Any | None] | None = field(default_factory=list) + nodes: list[Any | None] | None = None @dataclass class NavigateParameters: - """NavigateParameters type definition.""" + """NavigateParameters.""" context: Any | None = None url: str | None = None @@ -257,7 +241,7 @@ class NavigateParameters: @dataclass class NavigateResult: - """NavigateResult type definition.""" + """NavigateResult.""" navigation: Any | None = None url: str | None = None @@ -265,7 +249,7 @@ class NavigateResult: @dataclass class PrintParameters: - """PrintParameters type definition.""" + """PrintParameters.""" context: Any | None = None background: bool | None = None @@ -277,7 +261,7 @@ class PrintParameters: @dataclass class PrintMarginParameters: - """PrintMarginParameters type definition.""" + """PrintMarginParameters.""" bottom: Any | None = None left: Any | None = None @@ -287,7 +271,7 @@ class PrintMarginParameters: @dataclass class PrintPageParameters: - """PrintPageParameters type definition.""" + """PrintPageParameters.""" height: Any | None = None width: Any | None = None @@ -295,14 +279,14 @@ class PrintPageParameters: @dataclass class PrintResult: - """PrintResult type definition.""" + """PrintResult.""" data: str | None = None @dataclass class ReloadParameters: - """ReloadParameters type definition.""" + """ReloadParameters.""" context: Any | None = None ignore_cache: bool | None = None @@ -311,17 +295,17 @@ class ReloadParameters: @dataclass class SetViewportParameters: - """SetViewportParameters type definition.""" + """SetViewportParameters.""" context: Any | None = None viewport: Any | None = None device_pixel_ratio: Any | None = None - user_contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = None @dataclass class Viewport: - """Viewport type definition.""" + """Viewport.""" width: Any | None = None height: Any | None = None @@ -329,7 +313,7 @@ class Viewport: @dataclass class TraverseHistoryParameters: - """TraverseHistoryParameters type definition.""" + """TraverseHistoryParameters.""" context: Any | None = None delta: Any | None = None @@ -337,16 +321,30 @@ class TraverseHistoryParameters: @dataclass class HistoryUpdatedParameters: - """HistoryUpdatedParameters type definition.""" + """HistoryUpdatedParameters.""" context: Any | None = None timestamp: Any | None = None url: str | None = None +@dataclass +class DownloadWillBeginParams: + """DownloadWillBeginParams.""" + + suggested_filename: str | None = None + + +@dataclass +class DownloadCanceledParams: + """DownloadCanceledParams.""" + + status: str = field(default="canceled", init=False) + + @dataclass class UserPromptClosedParameters: - """UserPromptClosedParameters type definition.""" + """UserPromptClosedParameters.""" context: Any | None = None accepted: bool | None = None @@ -356,7 +354,7 @@ class UserPromptClosedParameters: @dataclass class UserPromptOpenedParameters: - """UserPromptOpenedParameters type definition.""" + """UserPromptOpenedParameters.""" context: Any | None = None handler: Any | None = None @@ -392,10 +390,10 @@ class DownloadParams: class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: DownloadParams | None = None + download_params: "DownloadParams | None" = None @classmethod - def from_json(cls, params: dict) -> DownloadEndParams: + def from_json(cls, params: dict) -> "DownloadEndParams": """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), @@ -416,6 +414,8 @@ def from_json(cls, params: dict) -> DownloadEndParams: "history_updated": "browsingContext.historyUpdated", "dom_content_loaded": "browsingContext.domContentLoaded", "load": "browsingContext.load", + "download_will_begin": "browsingContext.downloadWillBegin", + "download_end": "browsingContext.downloadEnd", "navigation_aborted": "browsingContext.navigationAborted", "navigation_committed": "browsingContext.navigationCommitted", "navigation_failed": "browsingContext.navigationFailed", @@ -630,13 +630,7 @@ def activate(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def capture_screenshot( - self, - context: str | None = None, - format: Any | None = None, - clip: Any | None = None, - origin: str | None = None, - ): + def capture_screenshot(self, context: str | None = None, format: Any | None = None, clip: Any | None = None, origin: str | None = None): """Execute browsingContext.captureScreenshot.""" params = { "context": context, @@ -663,13 +657,7 @@ def close(self, context: Any | None = None, prompt_unload: bool | None = None): result = self._conn.execute(cmd) return result - def create( - self, - type: Any | None = None, - reference_context: Any | None = None, - background: bool | None = None, - user_context: Any | None = None, - ): + def create(self, type: Any | None = None, reference_context: Any | None = None, background: bool | None = None, user_context: Any | None = None): """Execute browsingContext.create.""" params = { "type": type, @@ -723,14 +711,7 @@ def handle_user_prompt(self, context: Any | None = None, accept: bool | None = N result = self._conn.execute(cmd) return result - def locate_nodes( - self, - context: str | None = None, - locator: Any | None = None, - serialization_options: Any | None = None, - start_nodes: Any | None = None, - max_node_count: int | None = None, - ): + def locate_nodes(self, context: str | None = None, locator: Any | None = None, serialization_options: Any | None = None, start_nodes: Any | None = None, max_node_count: int | None = None): """Execute browsingContext.locateNodes.""" params = { "context": context, @@ -759,15 +740,7 @@ def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any result = self._conn.execute(cmd) return result - def print( - self, - context: Any | None = None, - background: bool | None = None, - margin: Any | None = None, - page: Any | None = None, - scale: Any | None = None, - shrink_to_fit: bool | None = None, - ): + def print(self, context: Any | None = None, background: bool | None = None, margin: Any | None = None, page: Any | None = None, scale: Any | None = None, shrink_to_fit: bool | None = None): """Execute browsingContext.print.""" params = { "context": context, @@ -797,13 +770,7 @@ def reload(self, context: Any | None = None, ignore_cache: bool | None = None, w result = self._conn.execute(cmd) return result - def set_viewport( - self, - context: str | None = None, - viewport: Any | None = None, - user_contexts: Any | None = None, - device_pixel_ratio: Any | None = None, - ): + def set_viewport(self, context: str | None = None, viewport: Any | None = None, user_contexts: Any | None = None, device_pixel_ratio: Any | None = None): """Execute browsingContext.setViewport.""" params = { "context": context, @@ -901,81 +868,20 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() BrowsingContext.EVENT_CONFIGS = { - "context_created": ( - EventConfig("context_created", "browsingContext.contextCreated", - _globals.get("ContextCreated", dict)) - if _globals.get("ContextCreated") - else EventConfig("context_created", "browsingContext.contextCreated", dict) - ), - "context_destroyed": ( - EventConfig("context_destroyed", "browsingContext.contextDestroyed", - _globals.get("ContextDestroyed", dict)) - if _globals.get("ContextDestroyed") - else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict) - ), - "navigation_started": ( - EventConfig("navigation_started", "browsingContext.navigationStarted", - _globals.get("NavigationStarted", dict)) - if _globals.get("NavigationStarted") - else EventConfig("navigation_started", "browsingContext.navigationStarted", dict) - ), - "fragment_navigated": ( - EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", - _globals.get("FragmentNavigated", dict)) - if _globals.get("FragmentNavigated") - else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict) - ), - "history_updated": ( - EventConfig("history_updated", "browsingContext.historyUpdated", - _globals.get("HistoryUpdated", dict)) - if _globals.get("HistoryUpdated") - else EventConfig("history_updated", "browsingContext.historyUpdated", dict) - ), - "dom_content_loaded": ( - EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", - _globals.get("DomContentLoaded", dict)) - if _globals.get("DomContentLoaded") - else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict) - ), - "load": ( - EventConfig("load", "browsingContext.load", - _globals.get("Load", dict)) - if _globals.get("Load") - else EventConfig("load", "browsingContext.load", dict) - ), - "navigation_aborted": ( - EventConfig("navigation_aborted", "browsingContext.navigationAborted", - _globals.get("NavigationAborted", dict)) - if _globals.get("NavigationAborted") - else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict) - ), - "navigation_committed": ( - EventConfig("navigation_committed", "browsingContext.navigationCommitted", - _globals.get("NavigationCommitted", dict)) - if _globals.get("NavigationCommitted") - else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict) - ), - "navigation_failed": ( - EventConfig("navigation_failed", "browsingContext.navigationFailed", - _globals.get("NavigationFailed", dict)) - if _globals.get("NavigationFailed") - else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict) - ), - "user_prompt_closed": ( - EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", - _globals.get("UserPromptClosed", dict)) - if _globals.get("UserPromptClosed") - else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict) - ), - "user_prompt_opened": ( - EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", - _globals.get("UserPromptOpened", dict)) - if _globals.get("UserPromptOpened") - else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict) - ), - "download_will_begin": EventConfig( - "download_will_begin", "browsingContext.downloadWillBegin", - _globals.get("DownloadWillBeginParams", dict), - ), + "context_created": (EventConfig("context_created", "browsingContext.contextCreated", _globals.get("ContextCreated", dict)) if _globals.get("ContextCreated") else EventConfig("context_created", "browsingContext.contextCreated", dict)), + "context_destroyed": (EventConfig("context_destroyed", "browsingContext.contextDestroyed", _globals.get("ContextDestroyed", dict)) if _globals.get("ContextDestroyed") else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict)), + "navigation_started": (EventConfig("navigation_started", "browsingContext.navigationStarted", _globals.get("NavigationStarted", dict)) if _globals.get("NavigationStarted") else EventConfig("navigation_started", "browsingContext.navigationStarted", dict)), + "fragment_navigated": (EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", _globals.get("FragmentNavigated", dict)) if _globals.get("FragmentNavigated") else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict)), + "history_updated": (EventConfig("history_updated", "browsingContext.historyUpdated", _globals.get("HistoryUpdated", dict)) if _globals.get("HistoryUpdated") else EventConfig("history_updated", "browsingContext.historyUpdated", dict)), + "dom_content_loaded": (EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", _globals.get("DomContentLoaded", dict)) if _globals.get("DomContentLoaded") else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict)), + "load": (EventConfig("load", "browsingContext.load", _globals.get("Load", dict)) if _globals.get("Load") else EventConfig("load", "browsingContext.load", dict)), + "download_will_begin": (EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBegin", dict)) if _globals.get("DownloadWillBegin") else EventConfig("download_will_begin", "browsingContext.downloadWillBegin", dict)), + "download_end": (EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEnd", dict)) if _globals.get("DownloadEnd") else EventConfig("download_end", "browsingContext.downloadEnd", dict)), + "navigation_aborted": (EventConfig("navigation_aborted", "browsingContext.navigationAborted", _globals.get("NavigationAborted", dict)) if _globals.get("NavigationAborted") else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict)), + "navigation_committed": (EventConfig("navigation_committed", "browsingContext.navigationCommitted", _globals.get("NavigationCommitted", dict)) if _globals.get("NavigationCommitted") else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict)), + "navigation_failed": (EventConfig("navigation_failed", "browsingContext.navigationFailed", _globals.get("NavigationFailed", dict)) if _globals.get("NavigationFailed") else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict)), + "user_prompt_closed": (EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", _globals.get("UserPromptClosed", dict)) if _globals.get("UserPromptClosed") else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict)), + "user_prompt_opened": (EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", _globals.get("UserPromptOpened", dict)) if _globals.get("UserPromptOpened") else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict)), + "download_will_begin": EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBeginParams", dict)), "download_end": EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEndParams", dict)), } diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index dae051876833e..d90d8c770263a 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -17,15 +17,12 @@ """Common utilities for BiDi command construction.""" -from __future__ import annotations - -from collections.abc import Generator -from typing import Any +from typing import Any, Dict, Generator def command_builder( - method: str, params: dict[str, Any] | None = None -) -> Generator[dict[str, Any], Any, Any]: + method: str, params: Dict[str, Any] +) -> Generator[Dict[str, Any], Any, Any]: """Build a BiDi command generator. Args: @@ -38,7 +35,5 @@ def command_builder( Returns: The result from the BiDi command execution """ - if params is None: - params = {} result = yield {"method": method, "params": params} return result diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index c58f6d5f78d6c..a85eaad3e223a 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,10 +6,11 @@ # WebDriver BiDi module: emulation from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - +from typing import Any, Dict, List, Optional, Union from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass class ForcedColorsModeTheme: @@ -54,24 +38,24 @@ class ScreenOrientationType: @dataclass class SetForcedColorsModeThemeOverrideParameters: - """SetForcedColorsModeThemeOverrideParameters type definition.""" + """SetForcedColorsModeThemeOverrideParameters.""" theme: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetGeolocationOverrideParameters: - """SetGeolocationOverrideParameters type definition.""" + """SetGeolocationOverrideParameters.""" - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class GeolocationCoordinates: - """GeolocationCoordinates type definition.""" + """GeolocationCoordinates.""" latitude: Any | None = None longitude: Any | None = None @@ -84,39 +68,39 @@ class GeolocationCoordinates: @dataclass class GeolocationPositionError: - """GeolocationPositionError type definition.""" + """GeolocationPositionError.""" type: str = field(default="positionUnavailable", init=False) @dataclass class SetLocaleOverrideParameters: - """SetLocaleOverrideParameters type definition.""" + """SetLocaleOverrideParameters.""" locale: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass -class SetNetworkConditionsParameters: - """SetNetworkConditionsParameters type definition.""" +class setNetworkConditionsParameters: + """setNetworkConditionsParameters.""" network_conditions: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class NetworkConditionsOffline: - """NetworkConditionsOffline type definition.""" + """NetworkConditionsOffline.""" type: str = field(default="offline", init=False) @dataclass class ScreenArea: - """ScreenArea type definition.""" + """ScreenArea.""" width: Any | None = None height: Any | None = None @@ -124,16 +108,16 @@ class ScreenArea: @dataclass class SetScreenSettingsOverrideParameters: - """SetScreenSettingsOverrideParameters type definition.""" + """SetScreenSettingsOverrideParameters.""" screen_area: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class ScreenOrientation: - """ScreenOrientation type definition.""" + """ScreenOrientation.""" natural: Any | None = None type: Any | None = None @@ -141,64 +125,64 @@ class ScreenOrientation: @dataclass class SetScreenOrientationOverrideParameters: - """SetScreenOrientationOverrideParameters type definition.""" + """SetScreenOrientationOverrideParameters.""" screen_orientation: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetUserAgentOverrideParameters: - """SetUserAgentOverrideParameters type definition.""" + """SetUserAgentOverrideParameters.""" user_agent: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetViewportMetaOverrideParameters: - """SetViewportMetaOverrideParameters type definition.""" + """SetViewportMetaOverrideParameters.""" viewport_meta: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetScriptingEnabledParameters: - """SetScriptingEnabledParameters type definition.""" + """SetScriptingEnabledParameters.""" enabled: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetScrollbarTypeOverrideParameters: - """SetScrollbarTypeOverrideParameters type definition.""" + """SetScrollbarTypeOverrideParameters.""" scrollbar_type: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetTimezoneOverrideParameters: - """SetTimezoneOverrideParameters type definition.""" + """SetTimezoneOverrideParameters.""" timezone: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetTouchOverrideParameters: - """SetTouchOverrideParameters type definition.""" + """SetTouchOverrideParameters.""" - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None class Emulation: @@ -207,12 +191,7 @@ class Emulation: def __init__(self, conn) -> None: self._conn = conn - def set_forced_colors_mode_theme_override( - self, - theme: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setForcedColorsModeThemeOverride.""" params = { "theme": theme, @@ -224,12 +203,18 @@ def set_forced_colors_mode_theme_override( result = self._conn.execute(cmd) return result - def set_locale_override( - self, - locale: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_geolocation_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setGeolocationOverride.""" + params = { + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setGeolocationOverride", params) + result = self._conn.execute(cmd) + return result + + def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setLocaleOverride.""" params = { "locale": locale, @@ -241,12 +226,19 @@ def set_locale_override( result = self._conn.execute(cmd) return result - def set_screen_settings_override( - self, - screen_area: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_network_conditions(self, network_conditions: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setNetworkConditions.""" + params = { + "networkConditions": network_conditions, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setNetworkConditions", params) + result = self._conn.execute(cmd) + return result + + def set_screen_settings_override(self, screen_area: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScreenSettingsOverride.""" params = { "screenArea": screen_area, @@ -258,12 +250,31 @@ def set_screen_settings_override( result = self._conn.execute(cmd) return result - def set_viewport_meta_override( - self, - viewport_meta: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_screen_orientation_override(self, screen_orientation: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScreenOrientationOverride.""" + params = { + "screenOrientation": screen_orientation, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScreenOrientationOverride", params) + result = self._conn.execute(cmd) + return result + + def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setUserAgentOverride.""" + params = { + "userAgent": user_agent, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setUserAgentOverride", params) + result = self._conn.execute(cmd) + return result + + def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setViewportMetaOverride.""" params = { "viewportMeta": viewport_meta, @@ -275,12 +286,19 @@ def set_viewport_meta_override( result = self._conn.execute(cmd) return result - def set_scrollbar_type_override( - self, - scrollbar_type: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScriptingEnabled.""" + params = { + "enabled": enabled, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScriptingEnabled", params) + result = self._conn.execute(cmd) + return result + + def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScrollbarTypeOverride.""" params = { "scrollbarType": scrollbar_type, @@ -292,7 +310,19 @@ def set_scrollbar_type_override( result = self._conn.execute(cmd) return result - def set_touch_override(self, contexts: list[Any] | None = None, user_contexts: list[Any] | None = None): + def set_timezone_override(self, timezone: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setTimezoneOverride.""" + params = { + "timezone": timezone, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setTimezoneOverride", params) + result = self._conn.execute(cmd) + return result + + def set_touch_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setTouchOverride.""" params = { "contexts": contexts, diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index e9c3f8345f05d..5dbe71dbd3886 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,15 +6,16 @@ # WebDriver BiDi module: input from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - +from dataclasses import dataclass from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class PointerType: """PointerType.""" @@ -50,7 +34,7 @@ class Origin: @dataclass class ElementOrigin: - """ElementOrigin type definition.""" + """ElementOrigin.""" type: str = field(default="element", init=False) element: Any | None = None @@ -58,59 +42,59 @@ class ElementOrigin: @dataclass class PerformActionsParameters: - """PerformActionsParameters type definition.""" + """PerformActionsParameters.""" context: Any | None = None - actions: list[Any | None] | None = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass class NoneSourceActions: - """NoneSourceActions type definition.""" + """NoneSourceActions.""" type: str = field(default="none", init=False) id: str | None = None - actions: list[Any | None] | None = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass class KeySourceActions: - """KeySourceActions type definition.""" + """KeySourceActions.""" type: str = field(default="key", init=False) id: str | None = None - actions: list[Any | None] | None = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass class PointerSourceActions: - """PointerSourceActions type definition.""" + """PointerSourceActions.""" type: str = field(default="pointer", init=False) id: str | None = None parameters: Any | None = None - actions: list[Any | None] | None = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass class PointerParameters: - """PointerParameters type definition.""" + """PointerParameters.""" pointer_type: Any | None = None @dataclass class WheelSourceActions: - """WheelSourceActions type definition.""" + """WheelSourceActions.""" type: str = field(default="wheel", init=False) id: str | None = None - actions: list[Any | None] | None = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass class PauseAction: - """PauseAction type definition.""" + """PauseAction.""" type: str = field(default="pause", init=False) duration: Any | None = None @@ -118,7 +102,7 @@ class PauseAction: @dataclass class KeyDownAction: - """KeyDownAction type definition.""" + """KeyDownAction.""" type: str = field(default="keyDown", init=False) value: str | None = None @@ -126,7 +110,7 @@ class KeyDownAction: @dataclass class KeyUpAction: - """KeyUpAction type definition.""" + """KeyUpAction.""" type: str = field(default="keyUp", init=False) value: str | None = None @@ -134,7 +118,7 @@ class KeyUpAction: @dataclass class PointerUpAction: - """PointerUpAction type definition.""" + """PointerUpAction.""" type: str = field(default="pointerUp", init=False) button: Any | None = None @@ -142,7 +126,7 @@ class PointerUpAction: @dataclass class WheelScrollAction: - """WheelScrollAction type definition.""" + """WheelScrollAction.""" type: str = field(default="scroll", init=False) x: Any | None = None @@ -155,7 +139,7 @@ class WheelScrollAction: @dataclass class PointerCommonProperties: - """PointerCommonProperties type definition.""" + """PointerCommonProperties.""" width: Any | None = None height: Any | None = None @@ -168,18 +152,18 @@ class PointerCommonProperties: @dataclass class ReleaseActionsParameters: - """ReleaseActionsParameters type definition.""" + """ReleaseActionsParameters.""" context: Any | None = None @dataclass class SetFilesParameters: - """SetFilesParameters type definition.""" + """SetFilesParameters.""" context: Any | None = None element: Any | None = None - files: list[Any | None] | None = field(default_factory=list) + files: list[Any | None] | None = None @dataclass @@ -191,7 +175,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> FileDialogInfo: + def from_json(cls, params: dict) -> "FileDialogInfo": """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), @@ -384,7 +368,7 @@ def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def perform_actions(self, context: Any | None = None, actions: list[Any] | None = None): + def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): """Execute input.performActions.""" params = { "context": context, @@ -405,7 +389,7 @@ def release_actions(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def set_files(self, context: Any | None = None, element: Any | None = None, files: list[Any] | None = None): + def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): """Execute input.setFiles.""" params = { "context": context, @@ -470,10 +454,5 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Input.EVENT_CONFIGS = { - "file_dialog_opened": ( - EventConfig("file_dialog_opened", "input.fileDialogOpened", - _globals.get("FileDialogOpened", dict)) - if _globals.get("FileDialogOpened") - else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict) - ), + "file_dialog_opened": (EventConfig("file_dialog_opened", "input.fileDialogOpened", _globals.get("FileDialogOpened", dict)) if _globals.get("FileDialogOpened") else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict)), } diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 94f511d7185f8..7aa7fbf7a3171 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,11 +6,14 @@ # WebDriver BiDi module: log from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass import threading from collections.abc import Callable from dataclasses import dataclass -from typing import Any - from selenium.webdriver.common.bidi.session import Session @@ -44,7 +30,7 @@ class Level: @dataclass class BaseLogEntry: - """BaseLogEntry type definition.""" + """BaseLogEntry.""" level: Any | None = None source: Any | None = None @@ -55,7 +41,7 @@ class BaseLogEntry: @dataclass class GenericLogEntry: - """GenericLogEntry type definition.""" + """GenericLogEntry.""" type: str | None = None @@ -74,7 +60,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> ConsoleLogEntry: + def from_json(cls, params: dict) -> "ConsoleLogEntry": """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -99,7 +85,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> JavascriptLogEntry: + def from_json(cls, params: dict) -> "JavascriptLogEntry": """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -312,10 +298,5 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Log.EVENT_CONFIGS = { - "entry_added": ( - EventConfig("entry_added", "log.entryAdded", - _globals.get("EntryAdded", dict)) - if _globals.get("EntryAdded") - else EventConfig("entry_added", "log.entryAdded", dict) - ), + "entry_added": (EventConfig("entry_added", "log.entryAdded", _globals.get("EntryAdded", dict)) if _globals.get("EntryAdded") else EventConfig("entry_added", "log.entryAdded", dict)), } diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 9dc5fb94d8488..2290c9fec12d3 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,15 +6,16 @@ # WebDriver BiDi module: network from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - +from dataclasses import dataclass from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class SameSite: """SameSite.""" @@ -66,7 +50,7 @@ class ContinueWithAuthNoCredentials: @dataclass class AuthChallenge: - """AuthChallenge type definition.""" + """AuthChallenge.""" scheme: str | None = None realm: str | None = None @@ -74,7 +58,7 @@ class AuthChallenge: @dataclass class AuthCredentials: - """AuthCredentials type definition.""" + """AuthCredentials.""" type: str = field(default="password", init=False) username: str | None = None @@ -83,7 +67,7 @@ class AuthCredentials: @dataclass class BaseParameters: - """BaseParameters type definition.""" + """BaseParameters.""" context: Any | None = None is_blocked: bool | None = None @@ -91,12 +75,12 @@ class BaseParameters: redirect_count: Any | None = None request: Any | None = None timestamp: Any | None = None - intercepts: list[Any | None] | None = field(default_factory=list) + intercepts: list[Any | None] | None = None @dataclass class StringValue: - """StringValue type definition.""" + """StringValue.""" type: str = field(default="string", init=False) value: str | None = None @@ -104,7 +88,7 @@ class StringValue: @dataclass class Base64Value: - """Base64Value type definition.""" + """Base64Value.""" type: str = field(default="base64", init=False) value: str | None = None @@ -112,7 +96,7 @@ class Base64Value: @dataclass class Cookie: - """Cookie type definition.""" + """Cookie.""" name: str | None = None value: Any | None = None @@ -127,7 +111,7 @@ class Cookie: @dataclass class CookieHeader: - """CookieHeader type definition.""" + """CookieHeader.""" name: str | None = None value: Any | None = None @@ -135,7 +119,7 @@ class CookieHeader: @dataclass class FetchTimingInfo: - """FetchTimingInfo type definition.""" + """FetchTimingInfo.""" time_origin: Any | None = None request_time: Any | None = None @@ -154,7 +138,7 @@ class FetchTimingInfo: @dataclass class Header: - """Header type definition.""" + """Header.""" name: str | None = None value: Any | None = None @@ -162,7 +146,7 @@ class Header: @dataclass class Initiator: - """Initiator type definition.""" + """Initiator.""" column_number: Any | None = None line_number: Any | None = None @@ -173,32 +157,32 @@ class Initiator: @dataclass class ResponseContent: - """ResponseContent type definition.""" + """ResponseContent.""" size: Any | None = None @dataclass class ResponseData: - """ResponseData type definition.""" + """ResponseData.""" url: str | None = None protocol: str | None = None status: Any | None = None status_text: str | None = None from_cache: bool | None = None - headers: list[Any | None] | None = field(default_factory=list) + headers: list[Any | None] | None = None mime_type: str | None = None bytes_received: Any | None = None headers_size: Any | None = None body_size: Any | None = None content: Any | None = None - auth_challenges: list[Any | None] | None = field(default_factory=list) + auth_challenges: list[Any | None] | None = None @dataclass class SetCookieHeader: - """SetCookieHeader type definition.""" + """SetCookieHeader.""" name: str | None = None value: Any | None = None @@ -213,7 +197,7 @@ class SetCookieHeader: @dataclass class UrlPatternPattern: - """UrlPatternPattern type definition.""" + """UrlPatternPattern.""" type: str = field(default="pattern", init=False) protocol: str | None = None @@ -225,7 +209,7 @@ class UrlPatternPattern: @dataclass class UrlPatternString: - """UrlPatternString type definition.""" + """UrlPatternString.""" type: str = field(default="string", init=False) pattern: str | None = None @@ -233,68 +217,68 @@ class UrlPatternString: @dataclass class AddDataCollectorParameters: - """AddDataCollectorParameters type definition.""" + """AddDataCollectorParameters.""" - data_types: list[Any | None] | None = field(default_factory=list) + data_types: list[Any | None] | None = None max_encoded_data_size: Any | None = None collector_type: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class AddDataCollectorResult: - """AddDataCollectorResult type definition.""" + """AddDataCollectorResult.""" collector: Any | None = None @dataclass class AddInterceptParameters: - """AddInterceptParameters type definition.""" + """AddInterceptParameters.""" - phases: list[Any | None] | None = field(default_factory=list) - contexts: list[Any | None] | None = field(default_factory=list) - url_patterns: list[Any | None] | None = field(default_factory=list) + phases: list[Any | None] | None = None + contexts: list[Any | None] | None = None + url_patterns: list[Any | None] | None = None @dataclass class AddInterceptResult: - """AddInterceptResult type definition.""" + """AddInterceptResult.""" intercept: Any | None = None @dataclass class ContinueResponseParameters: - """ContinueResponseParameters type definition.""" + """ContinueResponseParameters.""" request: Any | None = None - cookies: list[Any | None] | None = field(default_factory=list) + cookies: list[Any | None] | None = None credentials: Any | None = None - headers: list[Any | None] | None = field(default_factory=list) + headers: list[Any | None] | None = None reason_phrase: str | None = None status_code: Any | None = None @dataclass class ContinueWithAuthParameters: - """ContinueWithAuthParameters type definition.""" + """ContinueWithAuthParameters.""" request: Any | None = None @dataclass class ContinueWithAuthCredentials: - """ContinueWithAuthCredentials type definition.""" + """ContinueWithAuthCredentials.""" action: str = field(default="provideCredentials", init=False) credentials: Any | None = None @dataclass -class DisownDataParameters: - """DisownDataParameters type definition.""" +class disownDataParameters: + """disownDataParameters.""" data_type: Any | None = None collector: Any | None = None @@ -303,14 +287,14 @@ class DisownDataParameters: @dataclass class FailRequestParameters: - """FailRequestParameters type definition.""" + """FailRequestParameters.""" request: Any | None = None @dataclass class GetDataParameters: - """GetDataParameters type definition.""" + """GetDataParameters.""" data_type: Any | None = None collector: Any | None = None @@ -320,85 +304,57 @@ class GetDataParameters: @dataclass class GetDataResult: - """GetDataResult type definition.""" + """GetDataResult.""" bytes: Any | None = None @dataclass class ProvideResponseParameters: - """ProvideResponseParameters type definition.""" + """ProvideResponseParameters.""" request: Any | None = None body: Any | None = None - cookies: list[Any | None] | None = field(default_factory=list) - headers: list[Any | None] | None = field(default_factory=list) + cookies: list[Any | None] | None = None + headers: list[Any | None] | None = None reason_phrase: str | None = None status_code: Any | None = None @dataclass class RemoveDataCollectorParameters: - """RemoveDataCollectorParameters type definition.""" + """RemoveDataCollectorParameters.""" collector: Any | None = None @dataclass class RemoveInterceptParameters: - """RemoveInterceptParameters type definition.""" + """RemoveInterceptParameters.""" intercept: Any | None = None @dataclass class SetCacheBehaviorParameters: - """SetCacheBehaviorParameters type definition.""" + """SetCacheBehaviorParameters.""" cache_behavior: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None @dataclass class SetExtraHeadersParameters: - """SetExtraHeadersParameters type definition.""" - - headers: list[Any | None] | None = field(default_factory=list) - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) - - -@dataclass -class AuthRequiredParameters: - """AuthRequiredParameters type definition.""" - - response: Any | None = None - - -@dataclass -class BeforeRequestSentParameters: - """BeforeRequestSentParameters type definition.""" - - initiator: Any | None = None - - -@dataclass -class FetchErrorParameters: - """FetchErrorParameters type definition.""" - - error_text: str | None = None + """SetExtraHeadersParameters.""" - -@dataclass -class ResponseCompletedParameters: - """ResponseCompletedParameters type definition.""" - - response: Any | None = None + headers: list[Any | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class ResponseStartedParameters: - """ResponseStartedParameters type definition.""" + """ResponseStartedParameters.""" response: Any | None = None @@ -441,10 +397,6 @@ def continue_request(self, **kwargs): # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "auth_required": "network.authRequired", - "before_request_sent": "network.beforeRequestSent", - "fetch_error": "network.fetchError", - "response_completed": "network.responseCompleted", - "response_started": "network.responseStarted", "before_request": "network.beforeRequestSent", } @@ -611,14 +563,7 @@ def __init__(self, conn) -> None: self.intercepts = [] self._handler_intercepts: dict = {} - def add_data_collector( - self, - data_types: list[Any] | None = None, - max_encoded_data_size: Any | None = None, - collector_type: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute network.addDataCollector.""" params = { "dataTypes": data_types, @@ -632,12 +577,7 @@ def add_data_collector( result = self._conn.execute(cmd) return result - def add_intercept( - self, - phases: list[Any] | None = None, - contexts: list[Any] | None = None, - url_patterns: list[Any] | None = None, - ): + def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | None = None, url_patterns: List[Any] | None = None): """Execute network.addIntercept.""" params = { "phases": phases, @@ -649,15 +589,7 @@ def add_intercept( result = self._conn.execute(cmd) return result - def continue_request( - self, - request: Any | None = None, - body: Any | None = None, - cookies: list[Any] | None = None, - headers: list[Any] | None = None, - method: Any | None = None, - url: Any | None = None, - ): + def continue_request(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, method: Any | None = None, url: Any | None = None): """Execute network.continueRequest.""" params = { "request": request, @@ -672,15 +604,7 @@ def continue_request( result = self._conn.execute(cmd) return result - def continue_response( - self, - request: Any | None = None, - cookies: list[Any] | None = None, - credentials: Any | None = None, - headers: list[Any] | None = None, - reason_phrase: Any | None = None, - status_code: Any | None = None, - ): + def continue_response(self, request: Any | None = None, cookies: List[Any] | None = None, credentials: Any | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): """Execute network.continueResponse.""" params = { "request": request, @@ -727,13 +651,7 @@ def fail_request(self, request: Any | None = None): result = self._conn.execute(cmd) return result - def get_data( - self, - data_type: Any | None = None, - collector: Any | None = None, - disown: bool | None = None, - request: Any | None = None, - ): + def get_data(self, data_type: Any | None = None, collector: Any | None = None, disown: bool | None = None, request: Any | None = None): """Execute network.getData.""" params = { "dataType": data_type, @@ -746,15 +664,7 @@ def get_data( result = self._conn.execute(cmd) return result - def provide_response( - self, - request: Any | None = None, - body: Any | None = None, - cookies: list[Any] | None = None, - headers: list[Any] | None = None, - reason_phrase: Any | None = None, - status_code: Any | None = None, - ): + def provide_response(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): """Execute network.provideResponse.""" params = { "request": request, @@ -789,7 +699,7 @@ def remove_intercept(self, intercept: Any | None = None): result = self._conn.execute(cmd) return result - def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[Any] | None = None): + def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): """Execute network.setCacheBehavior.""" params = { "cacheBehavior": cache_behavior, @@ -800,12 +710,7 @@ def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[A result = self._conn.execute(cmd) return result - def set_extra_headers( - self, - headers: list[Any] | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute network.setExtraHeaders.""" params = { "headers": headers, @@ -817,6 +722,52 @@ def set_extra_headers( result = self._conn.execute(cmd) return result + def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.beforeRequestSent.""" + params = { + "initiator": initiator, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.beforeRequestSent", params) + result = self._conn.execute(cmd) + return result + + def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.fetchError.""" + params = { + "errorText": error_text, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.fetchError", params) + result = self._conn.execute(cmd) + return result + + def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.responseCompleted.""" + params = { + "response": response, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.responseCompleted", params) + result = self._conn.execute(cmd) + return result + + def response_started(self, response: Any | None = None): + """Execute network.responseStarted.""" + params = { + "response": response, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.responseStarted", params) + result = self._conn.execute(cmd) + return result + def _add_intercept(self, phases=None, url_patterns=None): """Add a low-level network intercept. @@ -971,51 +922,10 @@ def clear_event_handlers(self) -> None: # Event: network.authRequired AuthRequired = globals().get('AuthRequiredParameters', dict) # Fallback to dict if type not defined -# Event: network.beforeRequestSent -BeforeRequestSent = globals().get('BeforeRequestSentParameters', dict) # Fallback to dict if type not defined - -# Event: network.fetchError -FetchError = globals().get('FetchErrorParameters', dict) # Fallback to dict if type not defined - -# Event: network.responseCompleted -ResponseCompleted = globals().get('ResponseCompletedParameters', dict) # Fallback to dict if type not defined - -# Event: network.responseStarted -ResponseStarted = globals().get('ResponseStartedParameters', dict) # Fallback to dict if type not defined - # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Network.EVENT_CONFIGS = { - "auth_required": ( - EventConfig("auth_required", "network.authRequired", - _globals.get("AuthRequired", dict)) - if _globals.get("AuthRequired") - else EventConfig("auth_required", "network.authRequired", dict) - ), - "before_request_sent": ( - EventConfig("before_request_sent", "network.beforeRequestSent", - _globals.get("BeforeRequestSent", dict)) - if _globals.get("BeforeRequestSent") - else EventConfig("before_request_sent", "network.beforeRequestSent", dict) - ), - "fetch_error": ( - EventConfig("fetch_error", "network.fetchError", - _globals.get("FetchError", dict)) - if _globals.get("FetchError") - else EventConfig("fetch_error", "network.fetchError", dict) - ), - "response_completed": ( - EventConfig("response_completed", "network.responseCompleted", - _globals.get("ResponseCompleted", dict)) - if _globals.get("ResponseCompleted") - else EventConfig("response_completed", "network.responseCompleted", dict) - ), - "response_started": ( - EventConfig("response_started", "network.responseStarted", - _globals.get("ResponseStarted", dict)) - if _globals.get("ResponseStarted") - else EventConfig("response_started", "network.responseStarted", dict) - ), + "auth_required": (EventConfig("auth_required", "network.authRequired", _globals.get("AuthRequired", dict)) if _globals.get("AuthRequired") else EventConfig("auth_required", "network.authRequired", dict)), "before_request": EventConfig("before_request", "network.beforeRequestSent", _globals.get("dict", dict)), } diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py index 6dd138da17309..f00e765c62e3b 100644 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ b/py/selenium/webdriver/common/bidi/permissions.py @@ -20,7 +20,7 @@ from __future__ import annotations from enum import Enum -from typing import Any +from typing import Any, Optional, Union from .common import command_builder @@ -63,10 +63,10 @@ def __init__(self, websocket_connection: Any) -> None: def set_permission( self, - descriptor: PermissionDescriptor | str, - state: PermissionState | str, - origin: str | None = None, - user_context: str | None = None, + descriptor: Union[PermissionDescriptor, str], + state: Union[PermissionState, str], + origin: Optional[str] = None, + user_context: Optional[str] = None, ) -> None: """Set a permission for a given origin. diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 0b2ec04101933..c7bfcb3774dff 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,15 +6,16 @@ # WebDriver BiDi module: script from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - +from dataclasses import dataclass from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class SpecialNumber: """SpecialNumber.""" @@ -64,7 +48,7 @@ class ResultOwnership: @dataclass class ChannelValue: - """ChannelValue type definition.""" + """ChannelValue.""" type: str = field(default="channel", init=False) value: Any | None = None @@ -72,7 +56,7 @@ class ChannelValue: @dataclass class ChannelProperties: - """ChannelProperties type definition.""" + """ChannelProperties.""" channel: Any | None = None serialization_options: Any | None = None @@ -81,7 +65,7 @@ class ChannelProperties: @dataclass class EvaluateResultSuccess: - """EvaluateResultSuccess type definition.""" + """EvaluateResultSuccess.""" type: str = field(default="success", init=False) result: Any | None = None @@ -90,7 +74,7 @@ class EvaluateResultSuccess: @dataclass class EvaluateResultException: - """EvaluateResultException type definition.""" + """EvaluateResultException.""" type: str = field(default="exception", init=False) exception_details: Any | None = None @@ -99,7 +83,7 @@ class EvaluateResultException: @dataclass class ExceptionDetails: - """ExceptionDetails type definition.""" + """ExceptionDetails.""" column_number: Any | None = None exception: Any | None = None @@ -110,7 +94,7 @@ class ExceptionDetails: @dataclass class ArrayLocalValue: - """ArrayLocalValue type definition.""" + """ArrayLocalValue.""" type: str = field(default="array", init=False) value: Any | None = None @@ -118,7 +102,7 @@ class ArrayLocalValue: @dataclass class DateLocalValue: - """DateLocalValue type definition.""" + """DateLocalValue.""" type: str = field(default="date", init=False) value: str | None = None @@ -126,7 +110,7 @@ class DateLocalValue: @dataclass class MapLocalValue: - """MapLocalValue type definition.""" + """MapLocalValue.""" type: str = field(default="map", init=False) value: Any | None = None @@ -134,7 +118,7 @@ class MapLocalValue: @dataclass class ObjectLocalValue: - """ObjectLocalValue type definition.""" + """ObjectLocalValue.""" type: str = field(default="object", init=False) value: Any | None = None @@ -142,7 +126,7 @@ class ObjectLocalValue: @dataclass class RegExpValue: - """RegExpValue type definition.""" + """RegExpValue.""" pattern: str | None = None flags: str | None = None @@ -150,7 +134,7 @@ class RegExpValue: @dataclass class RegExpLocalValue: - """RegExpLocalValue type definition.""" + """RegExpLocalValue.""" type: str = field(default="regexp", init=False) value: Any | None = None @@ -158,7 +142,7 @@ class RegExpLocalValue: @dataclass class SetLocalValue: - """SetLocalValue type definition.""" + """SetLocalValue.""" type: str = field(default="set", init=False) value: Any | None = None @@ -166,21 +150,21 @@ class SetLocalValue: @dataclass class UndefinedValue: - """UndefinedValue type definition.""" + """UndefinedValue.""" type: str = field(default="undefined", init=False) @dataclass class NullValue: - """NullValue type definition.""" + """NullValue.""" type: str = field(default="null", init=False) @dataclass class StringValue: - """StringValue type definition.""" + """StringValue.""" type: str = field(default="string", init=False) value: str | None = None @@ -188,7 +172,7 @@ class StringValue: @dataclass class NumberValue: - """NumberValue type definition.""" + """NumberValue.""" type: str = field(default="number", init=False) value: Any | None = None @@ -196,7 +180,7 @@ class NumberValue: @dataclass class BooleanValue: - """BooleanValue type definition.""" + """BooleanValue.""" type: str = field(default="boolean", init=False) value: bool | None = None @@ -204,7 +188,7 @@ class BooleanValue: @dataclass class BigIntValue: - """BigIntValue type definition.""" + """BigIntValue.""" type: str = field(default="bigint", init=False) value: str | None = None @@ -212,7 +196,7 @@ class BigIntValue: @dataclass class BaseRealmInfo: - """BaseRealmInfo type definition.""" + """BaseRealmInfo.""" realm: Any | None = None origin: str | None = None @@ -220,7 +204,7 @@ class BaseRealmInfo: @dataclass class WindowRealmInfo: - """WindowRealmInfo type definition.""" + """WindowRealmInfo.""" type: str = field(default="window", init=False) context: Any | None = None @@ -229,57 +213,57 @@ class WindowRealmInfo: @dataclass class DedicatedWorkerRealmInfo: - """DedicatedWorkerRealmInfo type definition.""" + """DedicatedWorkerRealmInfo.""" type: str = field(default="dedicated-worker", init=False) - owners: list[Any | None] | None = field(default_factory=list) + owners: list[Any | None] | None = None @dataclass class SharedWorkerRealmInfo: - """SharedWorkerRealmInfo type definition.""" + """SharedWorkerRealmInfo.""" type: str = field(default="shared-worker", init=False) @dataclass class ServiceWorkerRealmInfo: - """ServiceWorkerRealmInfo type definition.""" + """ServiceWorkerRealmInfo.""" type: str = field(default="service-worker", init=False) @dataclass class WorkerRealmInfo: - """WorkerRealmInfo type definition.""" + """WorkerRealmInfo.""" type: str = field(default="worker", init=False) @dataclass class PaintWorkletRealmInfo: - """PaintWorkletRealmInfo type definition.""" + """PaintWorkletRealmInfo.""" type: str = field(default="paint-worklet", init=False) @dataclass class AudioWorkletRealmInfo: - """AudioWorkletRealmInfo type definition.""" + """AudioWorkletRealmInfo.""" type: str = field(default="audio-worklet", init=False) @dataclass class WorkletRealmInfo: - """WorkletRealmInfo type definition.""" + """WorkletRealmInfo.""" type: str = field(default="worklet", init=False) @dataclass class SharedReference: - """SharedReference type definition.""" + """SharedReference.""" shared_id: Any | None = None handle: Any | None = None @@ -287,7 +271,7 @@ class SharedReference: @dataclass class RemoteObjectReference: - """RemoteObjectReference type definition.""" + """RemoteObjectReference.""" handle: Any | None = None shared_id: Any | None = None @@ -295,7 +279,7 @@ class RemoteObjectReference: @dataclass class SymbolRemoteValue: - """SymbolRemoteValue type definition.""" + """SymbolRemoteValue.""" type: str = field(default="symbol", init=False) handle: Any | None = None @@ -304,7 +288,7 @@ class SymbolRemoteValue: @dataclass class ArrayRemoteValue: - """ArrayRemoteValue type definition.""" + """ArrayRemoteValue.""" type: str = field(default="array", init=False) handle: Any | None = None @@ -314,7 +298,7 @@ class ArrayRemoteValue: @dataclass class ObjectRemoteValue: - """ObjectRemoteValue type definition.""" + """ObjectRemoteValue.""" type: str = field(default="object", init=False) handle: Any | None = None @@ -324,7 +308,7 @@ class ObjectRemoteValue: @dataclass class FunctionRemoteValue: - """FunctionRemoteValue type definition.""" + """FunctionRemoteValue.""" type: str = field(default="function", init=False) handle: Any | None = None @@ -333,7 +317,7 @@ class FunctionRemoteValue: @dataclass class RegExpRemoteValue: - """RegExpRemoteValue type definition.""" + """RegExpRemoteValue.""" handle: Any | None = None internal_id: Any | None = None @@ -341,7 +325,7 @@ class RegExpRemoteValue: @dataclass class DateRemoteValue: - """DateRemoteValue type definition.""" + """DateRemoteValue.""" handle: Any | None = None internal_id: Any | None = None @@ -349,7 +333,7 @@ class DateRemoteValue: @dataclass class MapRemoteValue: - """MapRemoteValue type definition.""" + """MapRemoteValue.""" type: str = field(default="map", init=False) handle: Any | None = None @@ -359,7 +343,7 @@ class MapRemoteValue: @dataclass class SetRemoteValue: - """SetRemoteValue type definition.""" + """SetRemoteValue.""" type: str = field(default="set", init=False) handle: Any | None = None @@ -369,7 +353,7 @@ class SetRemoteValue: @dataclass class WeakMapRemoteValue: - """WeakMapRemoteValue type definition.""" + """WeakMapRemoteValue.""" type: str = field(default="weakmap", init=False) handle: Any | None = None @@ -378,7 +362,7 @@ class WeakMapRemoteValue: @dataclass class WeakSetRemoteValue: - """WeakSetRemoteValue type definition.""" + """WeakSetRemoteValue.""" type: str = field(default="weakset", init=False) handle: Any | None = None @@ -387,7 +371,7 @@ class WeakSetRemoteValue: @dataclass class GeneratorRemoteValue: - """GeneratorRemoteValue type definition.""" + """GeneratorRemoteValue.""" type: str = field(default="generator", init=False) handle: Any | None = None @@ -396,7 +380,7 @@ class GeneratorRemoteValue: @dataclass class ErrorRemoteValue: - """ErrorRemoteValue type definition.""" + """ErrorRemoteValue.""" type: str = field(default="error", init=False) handle: Any | None = None @@ -405,7 +389,7 @@ class ErrorRemoteValue: @dataclass class ProxyRemoteValue: - """ProxyRemoteValue type definition.""" + """ProxyRemoteValue.""" type: str = field(default="proxy", init=False) handle: Any | None = None @@ -414,7 +398,7 @@ class ProxyRemoteValue: @dataclass class PromiseRemoteValue: - """PromiseRemoteValue type definition.""" + """PromiseRemoteValue.""" type: str = field(default="promise", init=False) handle: Any | None = None @@ -423,7 +407,7 @@ class PromiseRemoteValue: @dataclass class TypedArrayRemoteValue: - """TypedArrayRemoteValue type definition.""" + """TypedArrayRemoteValue.""" type: str = field(default="typedarray", init=False) handle: Any | None = None @@ -432,7 +416,7 @@ class TypedArrayRemoteValue: @dataclass class ArrayBufferRemoteValue: - """ArrayBufferRemoteValue type definition.""" + """ArrayBufferRemoteValue.""" type: str = field(default="arraybuffer", init=False) handle: Any | None = None @@ -441,7 +425,7 @@ class ArrayBufferRemoteValue: @dataclass class NodeListRemoteValue: - """NodeListRemoteValue type definition.""" + """NodeListRemoteValue.""" type: str = field(default="nodelist", init=False) handle: Any | None = None @@ -451,7 +435,7 @@ class NodeListRemoteValue: @dataclass class HTMLCollectionRemoteValue: - """HTMLCollectionRemoteValue type definition.""" + """HTMLCollectionRemoteValue.""" type: str = field(default="htmlcollection", init=False) handle: Any | None = None @@ -461,7 +445,7 @@ class HTMLCollectionRemoteValue: @dataclass class NodeRemoteValue: - """NodeRemoteValue type definition.""" + """NodeRemoteValue.""" type: str = field(default="node", init=False) shared_id: Any | None = None @@ -472,11 +456,11 @@ class NodeRemoteValue: @dataclass class NodeProperties: - """NodeProperties type definition.""" + """NodeProperties.""" node_type: Any | None = None child_node_count: Any | None = None - children: list[Any | None] | None = field(default_factory=list) + children: list[Any | None] | None = None local_name: str | None = None mode: Any | None = None namespace_uri: str | None = None @@ -486,7 +470,7 @@ class NodeProperties: @dataclass class WindowProxyRemoteValue: - """WindowProxyRemoteValue type definition.""" + """WindowProxyRemoteValue.""" type: str = field(default="window", init=False) value: Any | None = None @@ -496,14 +480,14 @@ class WindowProxyRemoteValue: @dataclass class WindowProxyProperties: - """WindowProxyProperties type definition.""" + """WindowProxyProperties.""" context: Any | None = None @dataclass class StackFrame: - """StackFrame type definition.""" + """StackFrame.""" column_number: Any | None = None function_name: str | None = None @@ -513,14 +497,14 @@ class StackFrame: @dataclass class StackTrace: - """StackTrace type definition.""" + """StackTrace.""" - call_frames: list[Any | None] | None = field(default_factory=list) + call_frames: list[Any | None] | None = None @dataclass class Source: - """Source type definition.""" + """Source.""" realm: Any | None = None context: Any | None = None @@ -528,14 +512,14 @@ class Source: @dataclass class RealmTarget: - """RealmTarget type definition.""" + """RealmTarget.""" realm: Any | None = None @dataclass class ContextTarget: - """ContextTarget type definition.""" + """ContextTarget.""" context: Any | None = None sandbox: str | None = None @@ -543,38 +527,38 @@ class ContextTarget: @dataclass class AddPreloadScriptParameters: - """AddPreloadScriptParameters type definition.""" + """AddPreloadScriptParameters.""" function_declaration: str | None = None - arguments: list[Any | None] | None = field(default_factory=list) - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + arguments: list[Any | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None sandbox: str | None = None @dataclass class AddPreloadScriptResult: - """AddPreloadScriptResult type definition.""" + """AddPreloadScriptResult.""" script: Any | None = None @dataclass class DisownParameters: - """DisownParameters type definition.""" + """DisownParameters.""" - handles: list[Any | None] | None = field(default_factory=list) + handles: list[Any | None] | None = None target: Any | None = None @dataclass class CallFunctionParameters: - """CallFunctionParameters type definition.""" + """CallFunctionParameters.""" function_declaration: str | None = None await_promise: bool | None = None target: Any | None = None - arguments: list[Any | None] | None = field(default_factory=list) + arguments: list[Any | None] | None = None result_ownership: Any | None = None serialization_options: Any | None = None this: Any | None = None @@ -583,7 +567,7 @@ class CallFunctionParameters: @dataclass class EvaluateParameters: - """EvaluateParameters type definition.""" + """EvaluateParameters.""" expression: str | None = None target: Any | None = None @@ -595,7 +579,7 @@ class EvaluateParameters: @dataclass class GetRealmsParameters: - """GetRealmsParameters type definition.""" + """GetRealmsParameters.""" context: Any | None = None type: Any | None = None @@ -603,21 +587,21 @@ class GetRealmsParameters: @dataclass class GetRealmsResult: - """GetRealmsResult type definition.""" + """GetRealmsResult.""" - realms: list[Any | None] | None = field(default_factory=list) + realms: list[Any | None] | None = None @dataclass class RemovePreloadScriptParameters: - """RemovePreloadScriptParameters type definition.""" + """RemovePreloadScriptParameters.""" script: Any | None = None @dataclass class MessageParameters: - """MessageParameters type definition.""" + """MessageParameters.""" channel: Any | None = None data: Any | None = None @@ -626,14 +610,13 @@ class MessageParameters: @dataclass class RealmDestroyedParameters: - """RealmDestroyedParameters type definition.""" + """RealmDestroyedParameters.""" realm: Any | None = None # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { - "message": "script.message", "realm_created": "script.realmCreated", "realm_destroyed": "script.realmDestroyed", } @@ -800,14 +783,7 @@ def __init__(self, conn, driver=None) -> None: self._driver = driver self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def add_preload_script( - self, - function_declaration: Any | None = None, - arguments: list[Any] | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - sandbox: Any | None = None, - ): + def add_preload_script(self, function_declaration: Any | None = None, arguments: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None, sandbox: Any | None = None): """Execute script.addPreloadScript.""" params = { "functionDeclaration": function_declaration, @@ -821,7 +797,7 @@ def add_preload_script( result = self._conn.execute(cmd) return result - def disown(self, handles: list[Any] | None = None, target: Any | None = None): + def disown(self, handles: List[Any] | None = None, target: Any | None = None): """Execute script.disown.""" params = { "handles": handles, @@ -832,17 +808,7 @@ def disown(self, handles: list[Any] | None = None, target: Any | None = None): result = self._conn.execute(cmd) return result - def call_function( - self, - function_declaration: Any | None = None, - await_promise: bool | None = None, - target: Any | None = None, - arguments: list[Any] | None = None, - result_ownership: Any | None = None, - serialization_options: Any | None = None, - this: Any | None = None, - user_activation: bool | None = None, - ): + def call_function(self, function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, arguments: List[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, user_activation: bool | None = None): """Execute script.callFunction.""" params = { "functionDeclaration": function_declaration, @@ -859,15 +825,7 @@ def call_function( result = self._conn.execute(cmd) return result - def evaluate( - self, - expression: Any | None = None, - target: Any | None = None, - await_promise: bool | None = None, - result_ownership: Any | None = None, - serialization_options: Any | None = None, - user_activation: bool | None = None, - ): + def evaluate(self, expression: Any | None = None, target: Any | None = None, await_promise: bool | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, user_activation: bool | None = None): """Execute script.evaluate.""" params = { "expression": expression, @@ -903,6 +861,18 @@ def remove_preload_script(self, script: Any | None = None): result = self._conn.execute(cmd) return result + def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): + """Execute script.message.""" + params = { + "channel": channel, + "data": data, + "source": source, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.message", params) + result = self._conn.execute(cmd) + return result + def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: """Execute a function declaration in the browser context. @@ -919,9 +889,8 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import datetime as _datetime import math as _math - + import datetime as _datetime from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1162,9 +1131,8 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - - from selenium.webdriver.common.bidi import log as _log_mod from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod bidi_event = "log.entryAdded" @@ -1304,9 +1272,6 @@ def clear_event_handlers(self) -> None: return self._event_manager.clear_event_handlers() # Event Info Type Aliases -# Event: script.message -Message = globals().get('MessageParameters', dict) # Fallback to dict if type not defined - # Event: script.realmCreated RealmCreated = globals().get('RealmInfo', dict) # Fallback to dict if type not defined @@ -1317,22 +1282,6 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Script.EVENT_CONFIGS = { - "message": ( - EventConfig("message", "script.message", - _globals.get("Message", dict)) - if _globals.get("Message") - else EventConfig("message", "script.message", dict) - ), - "realm_created": ( - EventConfig("realm_created", "script.realmCreated", - _globals.get("RealmCreated", dict)) - if _globals.get("RealmCreated") - else EventConfig("realm_created", "script.realmCreated", dict) - ), - "realm_destroyed": ( - EventConfig("realm_destroyed", "script.realmDestroyed", - _globals.get("RealmDestroyed", dict)) - if _globals.get("RealmDestroyed") - else EventConfig("realm_destroyed", "script.realmDestroyed", dict) - ), + "realm_created": (EventConfig("realm_created", "script.realmCreated", _globals.get("RealmCreated", dict)) if _globals.get("RealmCreated") else EventConfig("realm_created", "script.realmCreated", dict)), + "realm_destroyed": (EventConfig("realm_destroyed", "script.realmDestroyed", _globals.get("RealmDestroyed", dict)) if _globals.get("RealmDestroyed") else EventConfig("realm_destroyed", "script.realmDestroyed", dict)), } diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 771a5327151bf..9b1daaae557fa 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,10 +6,11 @@ # WebDriver BiDi module: session from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - +from typing import Any, Dict, List, Optional, Union from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass class UserPromptHandlerType: @@ -39,15 +23,15 @@ class UserPromptHandlerType: @dataclass class CapabilitiesRequest: - """CapabilitiesRequest type definition.""" + """CapabilitiesRequest.""" always_match: Any | None = None - first_match: list[Any | None] | None = field(default_factory=list) + first_match: list[Any | None] | None = None @dataclass class CapabilityRequest: - """CapabilityRequest type definition.""" + """CapabilityRequest.""" accept_insecure_certs: bool | None = None browser_name: str | None = None @@ -59,31 +43,31 @@ class CapabilityRequest: @dataclass class AutodetectProxyConfiguration: - """AutodetectProxyConfiguration type definition.""" + """AutodetectProxyConfiguration.""" proxy_type: str = field(default="autodetect", init=False) @dataclass class DirectProxyConfiguration: - """DirectProxyConfiguration type definition.""" + """DirectProxyConfiguration.""" proxy_type: str = field(default="direct", init=False) @dataclass class ManualProxyConfiguration: - """ManualProxyConfiguration type definition.""" + """ManualProxyConfiguration.""" proxy_type: str = field(default="manual", init=False) http_proxy: str | None = None ssl_proxy: str | None = None - no_proxy: list[Any | None] | None = field(default_factory=list) + no_proxy: list[Any | None] | None = None @dataclass class SocksProxyConfiguration: - """SocksProxyConfiguration type definition.""" + """SocksProxyConfiguration.""" socks_proxy: str | None = None socks_version: Any | None = None @@ -91,7 +75,7 @@ class SocksProxyConfiguration: @dataclass class PacProxyConfiguration: - """PacProxyConfiguration type definition.""" + """PacProxyConfiguration.""" proxy_type: str = field(default="pac", init=False) proxy_autoconfig_url: str | None = None @@ -99,37 +83,37 @@ class PacProxyConfiguration: @dataclass class SystemProxyConfiguration: - """SystemProxyConfiguration type definition.""" + """SystemProxyConfiguration.""" proxy_type: str = field(default="system", init=False) @dataclass class SubscribeParameters: - """SubscribeParameters type definition.""" + """SubscribeParameters.""" - events: list[str | None] | None = field(default_factory=list) - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + events: list[str | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class UnsubscribeByIDRequest: - """UnsubscribeByIDRequest type definition.""" + """UnsubscribeByIDRequest.""" - subscriptions: list[Any | None] | None = field(default_factory=list) + subscriptions: list[Any | None] | None = None @dataclass class UnsubscribeByAttributesRequest: - """UnsubscribeByAttributesRequest type definition.""" + """UnsubscribeByAttributesRequest.""" - events: list[str | None] | None = field(default_factory=list) + events: list[str | None] | None = None @dataclass class StatusResult: - """StatusResult type definition.""" + """StatusResult.""" ready: bool | None = None message: str | None = None @@ -137,14 +121,14 @@ class StatusResult: @dataclass class NewParameters: - """NewParameters type definition.""" + """NewParameters.""" capabilities: Any | None = None @dataclass class NewResult: - """NewResult type definition.""" + """NewResult.""" session_id: str | None = None accept_insecure_certs: bool | None = None @@ -160,7 +144,7 @@ class NewResult: @dataclass class SubscribeResult: - """SubscribeResult type definition.""" + """SubscribeResult.""" subscription: Any | None = None @@ -227,12 +211,7 @@ def end(self): result = self._conn.execute(cmd) return result - def subscribe( - self, - events: list[Any] | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute session.subscribe.""" params = { "events": events, @@ -244,7 +223,7 @@ def subscribe( result = self._conn.execute(cmd) return result - def unsubscribe(self, events: list[Any] | None = None, subscriptions: list[Any] | None = None): + def unsubscribe(self, events: List[Any] | None = None, subscriptions: List[Any] | None = None): """Execute session.unsubscribe.""" params = { "events": events, diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 7623381706040..7e4c9c6dee459 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,15 +6,16 @@ # WebDriver BiDi module: storage from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - +from typing import Any, Dict, List, Optional, Union from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass @dataclass class PartitionKey: - """PartitionKey type definition.""" + """PartitionKey.""" user_context: str | None = None source_origin: str | None = None @@ -39,7 +23,7 @@ class PartitionKey: @dataclass class GetCookiesParameters: - """GetCookiesParameters type definition.""" + """GetCookiesParameters.""" filter: Any | None = None partition: Any | None = None @@ -47,15 +31,15 @@ class GetCookiesParameters: @dataclass class GetCookiesResult: - """GetCookiesResult type definition.""" + """GetCookiesResult.""" - cookies: list[Any | None] | None = field(default_factory=list) + cookies: list[Any | None] | None = None partition_key: Any | None = None @dataclass class SetCookieParameters: - """SetCookieParameters type definition.""" + """SetCookieParameters.""" cookie: Any | None = None partition: Any | None = None @@ -63,14 +47,14 @@ class SetCookieParameters: @dataclass class SetCookieResult: - """SetCookieResult type definition.""" + """SetCookieResult.""" partition_key: Any | None = None @dataclass class DeleteCookiesParameters: - """DeleteCookiesParameters type definition.""" + """DeleteCookiesParameters.""" filter: Any | None = None partition: Any | None = None @@ -78,7 +62,7 @@ class DeleteCookiesParameters: @dataclass class DeleteCookiesResult: - """DeleteCookiesResult type definition.""" + """DeleteCookiesResult.""" partition_key: Any | None = None @@ -123,7 +107,7 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> StorageCookie: + def from_bidi_dict(cls, raw: dict) -> "StorageCookie": """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): @@ -251,6 +235,39 @@ class Storage: def __init__(self, conn) -> None: self._conn = conn + def get_cookies(self, filter: Any | None = None, partition: Any | None = None): + """Execute storage.getCookies.""" + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.getCookies", params) + result = self._conn.execute(cmd) + return result + + def set_cookie(self, cookie: Any | None = None, partition: Any | None = None): + """Execute storage.setCookie.""" + params = { + "cookie": cookie, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.setCookie", params) + result = self._conn.execute(cmd) + return result + + def delete_cookies(self, filter: Any | None = None, partition: Any | None = None): + """Execute storage.deleteCookies.""" + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.deleteCookies", params) + result = self._conn.execute(cmd) + return result + def get_cookies(self, filter=None, partition=None): """Execute storage.getCookies and return a GetCookiesResult.""" if filter and hasattr(filter, "to_bidi_dict"): diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 99250afca4c68..98d852512f591 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,22 +6,23 @@ # WebDriver BiDi module: webExtension from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - +from typing import Any, Dict, List, Optional, Union from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass @dataclass class InstallParameters: - """InstallParameters type definition.""" + """InstallParameters.""" extension_data: Any | None = None @dataclass class ExtensionPath: - """ExtensionPath type definition.""" + """ExtensionPath.""" type: str = field(default="path", init=False) path: str | None = None @@ -46,7 +30,7 @@ class ExtensionPath: @dataclass class ExtensionArchivePath: - """ExtensionArchivePath type definition.""" + """ExtensionArchivePath.""" type: str = field(default="archivePath", init=False) path: str | None = None @@ -54,7 +38,7 @@ class ExtensionArchivePath: @dataclass class ExtensionBase64Encoded: - """ExtensionBase64Encoded type definition.""" + """ExtensionBase64Encoded.""" type: str = field(default="base64", init=False) value: str | None = None @@ -62,14 +46,14 @@ class ExtensionBase64Encoded: @dataclass class InstallResult: - """InstallResult type definition.""" + """InstallResult.""" extension: Any | None = None @dataclass class UninstallParameters: - """UninstallParameters type definition.""" + """UninstallParameters.""" extension: Any | None = None @@ -104,9 +88,13 @@ def install( ValueError: If more than one, or none, of the arguments is provided. """ provided = [ - k for k, v in { - "path": path, "archive_path": archive_path, "base64_value": base64_value, - }.items() if v is not None + k + for k, v in { + "path": path, + "archive_path": archive_path, + "base64_value": base64_value, + }.items() + if v is not None ] if len(provided) != 1: raise ValueError( @@ -121,17 +109,24 @@ def install( params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) return self._conn.execute(cmd) - def uninstall(self, extension: Any | None = None): + + def uninstall(self, extension: str | dict): """Uninstall a web extension. Args: extension: Either the extension ID string returned by ``install``, or the full result dict returned by ``install`` (the ``"extension"`` value is extracted automatically). + + Raises: + ValueError: If extension is not provided or is None. """ if isinstance(extension, dict): extension = extension.get("extension") + + if extension is None: + raise ValueError("extension parameter is required") + params = {"extension": extension} - params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd) From 3b61280320adcb7414a2676966d926c0ea22a462 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 12:57:06 +0000 Subject: [PATCH 10/42] Fix webextension and log from comments --- py/generate_bidi.py | 21 ++++++++++++++++++- py/private/bidi_enhancements_manifest.py | 7 +++++++ py/selenium/webdriver/common/bidi/log.py | 4 +++- .../webdriver/common/bidi/webextension.py | 11 +++------- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index d14e2575c8bfd..53eb3a9e52fcc 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -672,6 +672,11 @@ def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: code += extra_cls code += "\n\n" + # Emit extra type aliases from enhancement manifest (e.g., union types for events) + for extra_alias in enhancements.get("extra_type_aliases", []): + code += extra_alias + code += "\n\n" + # NOTE: Don't generate event type aliases here - they reference types that may not be defined yet # They will be generated after the class definition instead @@ -976,8 +981,22 @@ def clear_event_handlers(self) -> None: # This ensures all types are available when we create the aliases if self.events: code += "\n# Event Info Type Aliases\n" + # Check for explicit event_type_aliases in the enhancement manifest + event_type_aliases = enhancements.get("event_type_aliases", {}) for event_def in self.events: - code += event_def.to_python_dataclass() + # Convert method name to user-friendly event name + method_parts = event_def.method.split(".") + if len(method_parts) == 2: + event_name = self._convert_method_to_event_name(method_parts[1]) + # Check if there's an explicit alias defined in the enhancement manifest + if event_name in event_type_aliases: + # Use the alias directly + type_name = event_type_aliases[event_name] + code += f"# Event: {event_def.method}\n" + code += f"{event_def.name} = {type_name}\n" + else: + # Fall back to the original behavior + code += event_def.to_python_dataclass() code += "\n" # Now populate EVENT_CONFIGS after the aliases are defined diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 5dcce3c25ffeb..f06a4119625e6 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -284,6 +284,13 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": stacktrace=params.get("stackTrace"), )''', ], + # Define Entry union type for log.entryAdded event deserialization + "extra_type_aliases": [ + "Entry = GenericLogEntry | ConsoleLogEntry | JavascriptLogEntry", + ], + "event_type_aliases": { + "entry_added": "Entry", + }, }, "emulation": { "extra_methods": [ diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 7aa7fbf7a3171..c58018e8a947a 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -96,6 +96,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": stacktrace=params.get("stackTrace"), ) +Entry = GenericLogEntry | ConsoleLogEntry | JavascriptLogEntry + # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "entry_added": "log.entryAdded", @@ -292,7 +294,7 @@ def clear_event_handlers(self) -> None: # Event Info Type Aliases # Event: log.entryAdded -EntryAdded = globals().get('Entry', dict) # Fallback to dict if type not defined +EntryAdded = Entry # Populate EVENT_CONFIGS with event configuration mappings diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 98d852512f591..e007f8e4792a6 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -88,13 +88,9 @@ def install( ValueError: If more than one, or none, of the arguments is provided. """ provided = [ - k - for k, v in { - "path": path, - "archive_path": archive_path, - "base64_value": base64_value, - }.items() - if v is not None + k for k, v in { + "path": path, "archive_path": archive_path, "base64_value": base64_value, + }.items() if v is not None ] if len(provided) != 1: raise ValueError( @@ -109,7 +105,6 @@ def install( params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) return self._conn.execute(cmd) - def uninstall(self, extension: str | dict): """Uninstall a web extension. From 95485561bded56835181ae590a7d9057d550ac97 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 13:06:48 +0000 Subject: [PATCH 11/42] Correct usage of dafault_factory --- py/generate_bidi.py | 13 +++-- py/selenium/webdriver/common/bidi/browser.py | 6 +-- .../webdriver/common/bidi/browsing_context.py | 6 +-- .../webdriver/common/bidi/emulation.py | 48 +++++++++---------- py/selenium/webdriver/common/bidi/input.py | 12 ++--- py/selenium/webdriver/common/bidi/network.py | 34 ++++++------- py/selenium/webdriver/common/bidi/script.py | 18 +++---- py/selenium/webdriver/common/bidi/session.py | 14 +++--- py/selenium/webdriver/common/bidi/storage.py | 2 +- 9 files changed, 80 insertions(+), 73 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 53eb3a9e52fcc..f4915aa1ad123 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -385,9 +385,16 @@ def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> if literal_match: literal_value = literal_match.group(1) code += f' {snake_name}: str = field(default="{literal_value}", init=False)\n' - # Check if this field is a list type - elif "List[" in python_type: - code += f" {snake_name}: {python_type} = field(default_factory=list)\n" + # Check if this field is a list type (using lowercase 'list[' from Python 3.10+ syntax) + elif python_type.startswith("list["): + # Remove the trailing ' | None' from list types since default_factory=list ensures non-None + type_annotation = python_type.replace(" | None", "") + code += f" {snake_name}: {type_annotation} = field(default_factory=list)\n" + # Check if this field is a dict type (using lowercase 'dict[' from Python 3.10+ syntax) + elif python_type.startswith("dict["): + # Remove the trailing ' | None' from dict types since default_factory=dict ensures non-None + type_annotation = python_type.replace(" | None", "") + code += f" {snake_name}: {type_annotation} = field(default_factory=dict)\n" else: code += f" {snake_name}: {python_type} = None\n" diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 7cf9678c9b007..0618beb14ddef 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -131,14 +131,14 @@ class CreateUserContextParameters: class GetClientWindowsResult: """GetClientWindowsResult.""" - client_windows: list[Any | None] | None = None + client_windows: list[Any] = field(default_factory=list) @dataclass class GetUserContextsResult: """GetUserContextsResult.""" - user_contexts: list[Any | None] | None = None + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -171,7 +171,7 @@ class SetDownloadBehaviorParameters: """SetDownloadBehaviorParameters.""" download_behavior: Any | None = None - user_contexts: list[Any | None] | None = None + user_contexts: list[Any] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 35aea615d1780..d17829709c0c3 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -220,14 +220,14 @@ class LocateNodesParameters: context: Any | None = None locator: Any | None = None serialization_options: Any | None = None - start_nodes: list[Any | None] | None = None + start_nodes: list[Any] = field(default_factory=list) @dataclass class LocateNodesResult: """LocateNodesResult.""" - nodes: list[Any | None] | None = None + nodes: list[Any] = field(default_factory=list) @dataclass @@ -300,7 +300,7 @@ class SetViewportParameters: context: Any | None = None viewport: Any | None = None device_pixel_ratio: Any | None = None - user_contexts: list[Any | None] | None = None + user_contexts: list[Any] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index a85eaad3e223a..7edb7a9dacd06 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -41,16 +41,16 @@ class SetForcedColorsModeThemeOverrideParameters: """SetForcedColorsModeThemeOverrideParameters.""" theme: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass class SetGeolocationOverrideParameters: """SetGeolocationOverrideParameters.""" - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -78,8 +78,8 @@ class SetLocaleOverrideParameters: """SetLocaleOverrideParameters.""" locale: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -87,8 +87,8 @@ class setNetworkConditionsParameters: """setNetworkConditionsParameters.""" network_conditions: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -111,8 +111,8 @@ class SetScreenSettingsOverrideParameters: """SetScreenSettingsOverrideParameters.""" screen_area: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -128,8 +128,8 @@ class SetScreenOrientationOverrideParameters: """SetScreenOrientationOverrideParameters.""" screen_orientation: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -137,8 +137,8 @@ class SetUserAgentOverrideParameters: """SetUserAgentOverrideParameters.""" user_agent: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -146,8 +146,8 @@ class SetViewportMetaOverrideParameters: """SetViewportMetaOverrideParameters.""" viewport_meta: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -155,8 +155,8 @@ class SetScriptingEnabledParameters: """SetScriptingEnabledParameters.""" enabled: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -164,8 +164,8 @@ class SetScrollbarTypeOverrideParameters: """SetScrollbarTypeOverrideParameters.""" scrollbar_type: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -173,16 +173,16 @@ class SetTimezoneOverrideParameters: """SetTimezoneOverrideParameters.""" timezone: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass class SetTouchOverrideParameters: """SetTouchOverrideParameters.""" - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) class Emulation: diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 5dbe71dbd3886..a294bde307b89 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -45,7 +45,7 @@ class PerformActionsParameters: """PerformActionsParameters.""" context: Any | None = None - actions: list[Any | None] | None = None + actions: list[Any] = field(default_factory=list) @dataclass @@ -54,7 +54,7 @@ class NoneSourceActions: type: str = field(default="none", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any] = field(default_factory=list) @dataclass @@ -63,7 +63,7 @@ class KeySourceActions: type: str = field(default="key", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any] = field(default_factory=list) @dataclass @@ -73,7 +73,7 @@ class PointerSourceActions: type: str = field(default="pointer", init=False) id: str | None = None parameters: Any | None = None - actions: list[Any | None] | None = None + actions: list[Any] = field(default_factory=list) @dataclass @@ -89,7 +89,7 @@ class WheelSourceActions: type: str = field(default="wheel", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any] = field(default_factory=list) @dataclass @@ -163,7 +163,7 @@ class SetFilesParameters: context: Any | None = None element: Any | None = None - files: list[Any | None] | None = None + files: list[Any] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 2290c9fec12d3..af079f421546c 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -75,7 +75,7 @@ class BaseParameters: redirect_count: Any | None = None request: Any | None = None timestamp: Any | None = None - intercepts: list[Any | None] | None = None + intercepts: list[Any] = field(default_factory=list) @dataclass @@ -171,13 +171,13 @@ class ResponseData: status: Any | None = None status_text: str | None = None from_cache: bool | None = None - headers: list[Any | None] | None = None + headers: list[Any] = field(default_factory=list) mime_type: str | None = None bytes_received: Any | None = None headers_size: Any | None = None body_size: Any | None = None content: Any | None = None - auth_challenges: list[Any | None] | None = None + auth_challenges: list[Any] = field(default_factory=list) @dataclass @@ -219,11 +219,11 @@ class UrlPatternString: class AddDataCollectorParameters: """AddDataCollectorParameters.""" - data_types: list[Any | None] | None = None + data_types: list[Any] = field(default_factory=list) max_encoded_data_size: Any | None = None collector_type: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -237,9 +237,9 @@ class AddDataCollectorResult: class AddInterceptParameters: """AddInterceptParameters.""" - phases: list[Any | None] | None = None - contexts: list[Any | None] | None = None - url_patterns: list[Any | None] | None = None + phases: list[Any] = field(default_factory=list) + contexts: list[Any] = field(default_factory=list) + url_patterns: list[Any] = field(default_factory=list) @dataclass @@ -254,9 +254,9 @@ class ContinueResponseParameters: """ContinueResponseParameters.""" request: Any | None = None - cookies: list[Any | None] | None = None + cookies: list[Any] = field(default_factory=list) credentials: Any | None = None - headers: list[Any | None] | None = None + headers: list[Any] = field(default_factory=list) reason_phrase: str | None = None status_code: Any | None = None @@ -315,8 +315,8 @@ class ProvideResponseParameters: request: Any | None = None body: Any | None = None - cookies: list[Any | None] | None = None - headers: list[Any | None] | None = None + cookies: list[Any] = field(default_factory=list) + headers: list[Any] = field(default_factory=list) reason_phrase: str | None = None status_code: Any | None = None @@ -340,16 +340,16 @@ class SetCacheBehaviorParameters: """SetCacheBehaviorParameters.""" cache_behavior: Any | None = None - contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) @dataclass class SetExtraHeadersParameters: """SetExtraHeadersParameters.""" - headers: list[Any | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + headers: list[Any] = field(default_factory=list) + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index c7bfcb3774dff..492d1fe431680 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -216,7 +216,7 @@ class DedicatedWorkerRealmInfo: """DedicatedWorkerRealmInfo.""" type: str = field(default="dedicated-worker", init=False) - owners: list[Any | None] | None = None + owners: list[Any] = field(default_factory=list) @dataclass @@ -460,7 +460,7 @@ class NodeProperties: node_type: Any | None = None child_node_count: Any | None = None - children: list[Any | None] | None = None + children: list[Any] = field(default_factory=list) local_name: str | None = None mode: Any | None = None namespace_uri: str | None = None @@ -499,7 +499,7 @@ class StackFrame: class StackTrace: """StackTrace.""" - call_frames: list[Any | None] | None = None + call_frames: list[Any] = field(default_factory=list) @dataclass @@ -530,9 +530,9 @@ class AddPreloadScriptParameters: """AddPreloadScriptParameters.""" function_declaration: str | None = None - arguments: list[Any | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + arguments: list[Any] = field(default_factory=list) + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) sandbox: str | None = None @@ -547,7 +547,7 @@ class AddPreloadScriptResult: class DisownParameters: """DisownParameters.""" - handles: list[Any | None] | None = None + handles: list[Any] = field(default_factory=list) target: Any | None = None @@ -558,7 +558,7 @@ class CallFunctionParameters: function_declaration: str | None = None await_promise: bool | None = None target: Any | None = None - arguments: list[Any | None] | None = None + arguments: list[Any] = field(default_factory=list) result_ownership: Any | None = None serialization_options: Any | None = None this: Any | None = None @@ -589,7 +589,7 @@ class GetRealmsParameters: class GetRealmsResult: """GetRealmsResult.""" - realms: list[Any | None] | None = None + realms: list[Any] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 9b1daaae557fa..f1430cb6e59d3 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -26,7 +26,7 @@ class CapabilitiesRequest: """CapabilitiesRequest.""" always_match: Any | None = None - first_match: list[Any | None] | None = None + first_match: list[Any] = field(default_factory=list) @dataclass @@ -62,7 +62,7 @@ class ManualProxyConfiguration: proxy_type: str = field(default="manual", init=False) http_proxy: str | None = None ssl_proxy: str | None = None - no_proxy: list[Any | None] | None = None + no_proxy: list[Any] = field(default_factory=list) @dataclass @@ -92,23 +92,23 @@ class SystemProxyConfiguration: class SubscribeParameters: """SubscribeParameters.""" - events: list[str | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + events: list[str] = field(default_factory=list) + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass class UnsubscribeByIDRequest: """UnsubscribeByIDRequest.""" - subscriptions: list[Any | None] | None = None + subscriptions: list[Any] = field(default_factory=list) @dataclass class UnsubscribeByAttributesRequest: """UnsubscribeByAttributesRequest.""" - events: list[str | None] | None = None + events: list[str] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 7e4c9c6dee459..833e9cdc74f2a 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -33,7 +33,7 @@ class GetCookiesParameters: class GetCookiesResult: """GetCookiesResult.""" - cookies: list[Any | None] | None = None + cookies: list[Any] = field(default_factory=list) partition_key: Any | None = None From 829808811fc577d9aab0c58c887e7cc8f66ac174 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 13:16:41 +0000 Subject: [PATCH 12/42] fixing generating extra pass --- py/generate_bidi.py | 2 +- py/selenium/webdriver/common/bidi/log.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index f4915aa1ad123..a53ea96db7481 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -946,7 +946,7 @@ def clear_event_handlers(self) -> None: method_enhancements = enhancements.get(method_name_snake, {}) code += command.to_python_method(method_enhancements) code += "\n" - else: + elif not self.events and not enhancements.get("extra_methods", []): code += " pass\n" # Emit extra methods from enhancement manifest diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index c58018e8a947a..1f16849b8e03d 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -264,7 +264,6 @@ def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - pass def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: """Add an event handler. From a7f716d87d29d0ecc852f4296a0d54a4f80692d9 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 13:21:24 +0000 Subject: [PATCH 13/42] fix window tests --- py/private/bidi_enhancements_manifest.py | 37 ++++++++++++++++++++ py/selenium/webdriver/common/bidi/browser.py | 37 ++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index f06a4119625e6..e33a11d5f2b79 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -130,6 +130,43 @@ if user_contexts is not None: params["userContexts"] = user_contexts cmd = command_builder("browser.setDownloadBehavior", params) + return self._conn.execute(cmd)''', + ''' def set_client_window_state( + self, + client_window: Any | None = None, + state: Any | None = None, + ): + """Set the client window state. + + Args: + client_window: The client window ID to apply the state to. + state: The window state to set. Can be one of: + - A string: "fullscreen", "maximized", "minimized", "normal" + - A ClientWindowRectState object with width, height, x, y + - A dict representing the state + + Raises: + ValueError: If client_window is not provided or state is invalid. + """ + if client_window is None: + raise ValueError("client_window is required") + if state is None: + raise ValueError("state is required") + + # Serialize ClientWindowRectState if needed + state_param = state + if hasattr(state, '__dataclass_fields__'): + # It's a dataclass, convert to dict + state_param = { + k: v for k, v in state.__dict__.items() + if v is not None + } + + params = { + "clientWindow": client_window, + "state": state_param, + } + cmd = command_builder("browser.setClientWindowState", params) return self._conn.execute(cmd)''', ], }, diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 0618beb14ddef..7c1958fd435f0 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -341,3 +341,40 @@ def set_download_behavior( params["userContexts"] = user_contexts cmd = command_builder("browser.setDownloadBehavior", params) return self._conn.execute(cmd) + def set_client_window_state( + self, + client_window: Any | None = None, + state: Any | None = None, + ): + """Set the client window state. + + Args: + client_window: The client window ID to apply the state to. + state: The window state to set. Can be one of: + - A string: "fullscreen", "maximized", "minimized", "normal" + - A ClientWindowRectState object with width, height, x, y + - A dict representing the state + + Raises: + ValueError: If client_window is not provided or state is invalid. + """ + if client_window is None: + raise ValueError("client_window is required") + if state is None: + raise ValueError("state is required") + + # Serialize ClientWindowRectState if needed + state_param = state + if hasattr(state, '__dataclass_fields__'): + # It's a dataclass, convert to dict + state_param = { + k: v for k, v in state.__dict__.items() + if v is not None + } + + params = { + "clientWindow": client_window, + "state": state_param, + } + cmd = command_builder("browser.setClientWindowState", params) + return self._conn.execute(cmd) From dfd366d933c24f8ec306ccc9cc83290965316dc5 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 13:21:35 +0000 Subject: [PATCH 14/42] fix window tests --- py/private/bidi_enhancements_manifest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index e33a11d5f2b79..2b93f36f1a5dc 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -158,7 +158,7 @@ if hasattr(state, '__dataclass_fields__'): # It's a dataclass, convert to dict state_param = { - k: v for k, v in state.__dict__.items() + k: v for k, v in state.__dict__.items() if v is not None } From 2c9c14feae3548d34ec3418ef1824ecc8855ffd6 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 11 Mar 2026 11:27:05 +0000 Subject: [PATCH 15/42] correct checks on method arguments --- py/generate_bidi.py | 151 +++++++++++------- py/selenium/webdriver/common/bidi/browser.py | 9 +- .../webdriver/common/bidi/browsing_context.py | 36 +++++ .../webdriver/common/bidi/emulation.py | 30 ++++ py/selenium/webdriver/common/bidi/input.py | 15 ++ py/selenium/webdriver/common/bidi/network.py | 69 ++++++++ py/selenium/webdriver/common/bidi/script.py | 32 ++++ py/selenium/webdriver/common/bidi/session.py | 6 + py/selenium/webdriver/common/bidi/storage.py | 3 + 9 files changed, 292 insertions(+), 59 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index a53ea96db7481..78a7603b929c0 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -18,12 +18,11 @@ import logging import re import sys -from collections import defaultdict from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from textwrap import dedent, indent as tw_indent -from typing import Any, Dict, List, Optional, Set, Tuple +from textwrap import indent as tw_indent +from typing import Any __version__ = "1.0.0" @@ -53,7 +52,7 @@ def indent(s: str, n: int) -> str: return tw_indent(s, n * " ") -def load_enhancements_manifest(manifest_path: Optional[str]) -> Dict[str, Any]: +def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]: """Load enhancement manifest from a Python file. Args: @@ -139,11 +138,12 @@ class CddlCommand: module: str name: str - params: Dict[str, str] = field(default_factory=dict) - result: Optional[str] = None + params: dict[str, str] = field(default_factory=dict) + required_params: set[str] = field(default_factory=set) + result: str | None = None description: str = "" - def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python method code for this command. Args: @@ -178,7 +178,17 @@ def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str body = f" def {method_name}({param_list}):\n" body += f' """{self.description or "Execute " + self.module + "." + self.name}."""\n' - # Add validation if specified + # Add automatic validation for required parameters + # (This is used unless there's no required_params, in which case all params are optional) + if self.required_params: + for param_name, snake_param in param_names: + if param_name in self.required_params: + method_snake = self._camel_to_snake(self.name) + body += f" if {snake_param} is None:\n" + body += f' raise TypeError("{method_snake}() missing required argument: {snake_param!r}")\n' + body += "\n" + + # Add validation if specified in enhancements (for additional business logic validation) if "validate" in enhancements: validate_func = enhancements["validate"] # Build parameter list for validation function @@ -264,45 +274,45 @@ def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str # Extract property from list items body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f' item.get("{extract_property}")\n' - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" elif extract_field in deserialize_rules: # Extract field and deserialize to typed objects type_name = deserialize_rules[extract_field] body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f" {type_name}(\n" body += self._generate_field_args(extract_field, type_name) - body += f" )\n" - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " )\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" else: # Simple field extraction (return the value directly, not wrapped in result dict) body += f' if result and "{extract_field}" in result:\n' body += f' extracted = result.get("{extract_field}")\n' - body += f" return extracted\n" - body += f" return result\n" + body += " return extracted\n" + body += " return result\n" elif "deserialize" in enhancements: # Deserialize response to typed objects (legacy, without extract_field) deserialize_rules = enhancements["deserialize"] for response_field, type_name in deserialize_rules.items(): body += f' if result and "{response_field}" in result:\n' body += f' items = result.get("{response_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f" {type_name}(\n" body += self._generate_field_args(response_field, type_name) - body += f" )\n" - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " )\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" else: # No special response handling, just return the result body += " return result\n" @@ -351,10 +361,10 @@ class CddlTypeDefinition: module: str name: str - fields: Dict[str, str] = field(default_factory=dict) + fields: dict[str, str] = field(default_factory=dict) description: str = "" - def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python dataclass code for this type. Args: @@ -366,7 +376,7 @@ def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> # Generate class name from type name (keep it as-is, don't split on underscores) class_name = self.name - code = f"@dataclass\n" + code = "@dataclass\n" code += f"class {class_name}:\n" code += f' """{self.description or self.name}."""\n\n' @@ -460,7 +470,7 @@ class CddlEnum: module: str name: str - values: List[str] = field(default_factory=list) + values: list[str] = field(default_factory=list) description: str = "" def to_python_class(self) -> str: @@ -537,10 +547,10 @@ class CddlModule: """Represents a CDDL module (e.g., script, network, browsing_context).""" name: str - commands: List[CddlCommand] = field(default_factory=list) - types: List[CddlTypeDefinition] = field(default_factory=list) - enums: List[CddlEnum] = field(default_factory=list) - events: List[CddlEvent] = field(default_factory=list) + commands: list[CddlCommand] = field(default_factory=list) + types: list[CddlTypeDefinition] = field(default_factory=list) + enums: list[CddlEnum] = field(default_factory=list) + events: list[CddlEvent] = field(default_factory=list) @staticmethod def _convert_method_to_event_name(method_suffix: str) -> str: @@ -555,7 +565,7 @@ def _convert_method_to_event_name(method_suffix: str) -> str: s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", method_suffix) return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python code for this module. Args: @@ -1007,9 +1017,9 @@ def clear_event_handlers(self) -> None: code += "\n" # Now populate EVENT_CONFIGS after the aliases are defined - code += f"\n# Populate EVENT_CONFIGS with event configuration mappings\n" + code += "\n# Populate EVENT_CONFIGS with event configuration mappings\n" # Use globals() to look up types dynamically to handle missing types gracefully - code += f"_globals = globals()\n" + code += "_globals = globals()\n" code += f"{class_name}.EVENT_CONFIGS = {{\n" for event_def in self.events: # Convert method name to user-friendly event name @@ -1037,9 +1047,9 @@ def __init__(self, cddl_path: str): """Initialize parser with CDDL file path.""" self.cddl_path = Path(cddl_path) self.content = "" - self.modules: Dict[str, CddlModule] = {} - self.definitions: Dict[str, str] = {} - self.event_names: Set[str] = set() # Names of definitions that are events + self.modules: dict[str, CddlModule] = {} + self.definitions: dict[str, str] = {} + self.event_names: set[str] = set() # Names of definitions that are events self._read_file() def _read_file(self) -> None: @@ -1047,12 +1057,12 @@ def _read_file(self) -> None: if not self.cddl_path.exists(): raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}") - with open(self.cddl_path, "r", encoding="utf-8") as f: + with open(self.cddl_path, encoding="utf-8") as f: self.content = f.read() logger.info(f"Loaded CDDL file: {self.cddl_path}") - def parse(self) -> Dict[str, CddlModule]: + def parse(self) -> dict[str, CddlModule]: """Parse CDDL content and return modules.""" # Remove comments content = self._remove_comments(self.content) @@ -1201,7 +1211,7 @@ def _is_enum_definition(self, definition: str) -> bool: # Pattern: "something" / "something_else" return " / " in clean_def and '"' in clean_def - def _extract_enum_values(self, enum_definition: str) -> List[str]: + def _extract_enum_values(self, enum_definition: str) -> list[str]: """Extract individual values from an enum definition. Enums are defined as: "value1" / "value2" / "value3" @@ -1251,7 +1261,7 @@ def _normalize_cddl_type(field_type: str) -> str: result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result) return result.strip() - def _extract_type_fields(self, type_definition: str) -> Dict[str, str]: + def _extract_type_fields(self, type_definition: str) -> dict[str, str]: """Extract fields from a type definition block.""" fields = {} @@ -1361,14 +1371,17 @@ def _extract_commands(self) -> None: if module_name not in self.modules: self.modules[module_name] = CddlModule(name=module_name) - # Extract parameters - params = self._extract_parameters(params_type) + # Extract parameters and required parameters + params, required_params = self._extract_parameters_and_required( + params_type + ) # Create command cmd = CddlCommand( module=module_name, name=command_name, params=params, + required_params=required_params, description=f"Execute {method}", ) @@ -1378,24 +1391,36 @@ def _extract_commands(self) -> None: ) def _extract_parameters( - self, params_type: str, _seen: Optional[Set[str]] = None - ) -> Dict[str, str]: + self, params_type: str, _seen: set[str] | None = None + ) -> dict[str, str]: """Extract parameters from a parameter type definition. Handles both struct types ({...}) and top-level union types (TypeA / TypeB), merging all fields from each alternative as optional parameters. """ + params, _ = self._extract_parameters_and_required(params_type, _seen) + return params + + def _extract_parameters_and_required( + self, params_type: str, _seen: set[str] | None = None + ) -> tuple[dict[str, str], set[str]]: + """Extract parameters and track which are required from a parameter type definition. + + Returns: + Tuple of (params dict, required_params set) + """ params = {} + required = set() if _seen is None: _seen = set() if params_type in _seen: - return params + return params, required _seen.add(params_type) if params_type not in self.definitions: logger.debug(f"Parameter type not found: {params_type}") - return params + return params, required definition = self.definitions[params_type] @@ -1409,10 +1434,15 @@ def _extract_parameters( alternatives = [a.strip() for a in stripped.split("/") if a.strip()] all_named = all(re.match(r"^[\w.]+$", a) for a in alternatives) if all_named: + # For union types, collect parameters from all alternatives + # but treat them as optional since the caller only needs to pass one alternative for alt_type in alternatives: - alt_params = self._extract_parameters(alt_type, _seen) + alt_params, _ = self._extract_parameters_and_required( + alt_type, _seen + ) params.update(alt_params) - return params + # Note: We intentionally DON'T add to required, since these are union alternatives + return params, required # Remove the outer curly braces and split by comma # Then parse each line for key: type patterns @@ -1429,6 +1459,9 @@ def _extract_parameters( continue # Match pattern: [?] name: type + # Check if parameter has optional marker (?) + is_optional = line.startswith("?") + # Using a simple pattern that handles optional prefix match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) if not match: @@ -1443,11 +1476,13 @@ def _extract_parameters( # Skip lines that are part of nested definitions if "{" not in normalized_type and "(" not in normalized_type: params[param_name] = normalized_type + if not is_optional: + required.add(param_name) logger.debug( - f"Extracted param {param_name}: {normalized_type} from {params_type}" + f"Extracted param {param_name}: {normalized_type} (required={not is_optional}) from {params_type}" ) - return params + return params, required def module_name_to_class_name(module_name: str) -> str: @@ -1492,7 +1527,7 @@ def module_name_to_filename(module_name: str) -> str: return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() -def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> None: +def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> None: """Generate __init__.py file for the module.""" init_path = output_path / "__init__.py" @@ -1507,7 +1542,7 @@ def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> Non filename = module_name_to_filename(module_name) code += f"from .{filename} import {class_name}\n" - code += f"\n__all__ = [\n" + code += "\n__all__ = [\n" for module_name in sorted(modules.keys()): class_name = module_name_to_class_name(module_name) code += f' "{class_name}",\n' @@ -1729,7 +1764,7 @@ def main( cddl_file: str, output_dir: str, spec_version: str = "1.0", - enhancements_manifest: Optional[str] = None, + enhancements_manifest: str | None = None, ) -> None: """Main entry point. diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 7c1958fd435f0..c4017265ac757 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -275,6 +275,9 @@ def get_user_contexts(self): def remove_user_context(self, user_context: Any | None = None): """Execute browser.removeUserContext.""" + if user_context is None: + raise TypeError("remove_user_context() missing required argument: 'user_context'") + params = { "userContext": user_context, } @@ -285,6 +288,9 @@ def remove_user_context(self, user_context: Any | None = None): def set_client_window_state(self, client_window: Any | None = None): """Execute browser.setClientWindowState.""" + if client_window is None: + raise TypeError("set_client_window_state() missing required argument: 'client_window'") + params = { "clientWindow": client_window, } @@ -295,6 +301,7 @@ def set_client_window_state(self, client_window: Any | None = None): def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): """Execute browser.setDownloadBehavior.""" + validate_download_behavior(allowed=allowed, destination_folder=destination_folder, user_contexts=user_contexts) download_behavior = None @@ -368,7 +375,7 @@ def set_client_window_state( if hasattr(state, '__dataclass_fields__'): # It's a dataclass, convert to dict state_param = { - k: v for k, v in state.__dict__.items() + k: v for k, v in state.__dict__.items() if v is not None } diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index d17829709c0c3..775bcdb8f9dbb 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -622,6 +622,9 @@ def __init__(self, conn) -> None: def activate(self, context: Any | None = None): """Execute browsingContext.activate.""" + if context is None: + raise TypeError("activate() missing required argument: 'context'") + params = { "context": context, } @@ -632,6 +635,9 @@ def activate(self, context: Any | None = None): def capture_screenshot(self, context: str | None = None, format: Any | None = None, clip: Any | None = None, origin: str | None = None): """Execute browsingContext.captureScreenshot.""" + if context is None: + raise TypeError("capture_screenshot() missing required argument: 'context'") + params = { "context": context, "format": format, @@ -648,6 +654,9 @@ def capture_screenshot(self, context: str | None = None, format: Any | None = No def close(self, context: Any | None = None, prompt_unload: bool | None = None): """Execute browsingContext.close.""" + if context is None: + raise TypeError("close() missing required argument: 'context'") + params = { "context": context, "promptUnload": prompt_unload, @@ -659,6 +668,9 @@ def close(self, context: Any | None = None, prompt_unload: bool | None = None): def create(self, type: Any | None = None, reference_context: Any | None = None, background: bool | None = None, user_context: Any | None = None): """Execute browsingContext.create.""" + if type is None: + raise TypeError("create() missing required argument: 'type'") + params = { "type": type, "referenceContext": reference_context, @@ -701,6 +713,9 @@ def get_tree(self, max_depth: Any | None = None, root: Any | None = None): def handle_user_prompt(self, context: Any | None = None, accept: bool | None = None, user_text: Any | None = None): """Execute browsingContext.handleUserPrompt.""" + if context is None: + raise TypeError("handle_user_prompt() missing required argument: 'context'") + params = { "context": context, "accept": accept, @@ -713,6 +728,11 @@ def handle_user_prompt(self, context: Any | None = None, accept: bool | None = N def locate_nodes(self, context: str | None = None, locator: Any | None = None, serialization_options: Any | None = None, start_nodes: Any | None = None, max_node_count: int | None = None): """Execute browsingContext.locateNodes.""" + if context is None: + raise TypeError("locate_nodes() missing required argument: 'context'") + if locator is None: + raise TypeError("locate_nodes() missing required argument: 'locator'") + params = { "context": context, "locator": locator, @@ -730,6 +750,11 @@ def locate_nodes(self, context: str | None = None, locator: Any | None = None, s def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any | None = None): """Execute browsingContext.navigate.""" + if context is None: + raise TypeError("navigate() missing required argument: 'context'") + if url is None: + raise TypeError("navigate() missing required argument: 'url'") + params = { "context": context, "url": url, @@ -742,6 +767,9 @@ def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any def print(self, context: Any | None = None, background: bool | None = None, margin: Any | None = None, page: Any | None = None, scale: Any | None = None, shrink_to_fit: bool | None = None): """Execute browsingContext.print.""" + if context is None: + raise TypeError("print() missing required argument: 'context'") + params = { "context": context, "background": background, @@ -760,6 +788,9 @@ def print(self, context: Any | None = None, background: bool | None = None, marg def reload(self, context: Any | None = None, ignore_cache: bool | None = None, wait: Any | None = None): """Execute browsingContext.reload.""" + if context is None: + raise TypeError("reload() missing required argument: 'context'") + params = { "context": context, "ignoreCache": ignore_cache, @@ -785,6 +816,11 @@ def set_viewport(self, context: str | None = None, viewport: Any | None = None, def traverse_history(self, context: Any | None = None, delta: Any | None = None): """Execute browsingContext.traverseHistory.""" + if context is None: + raise TypeError("traverse_history() missing required argument: 'context'") + if delta is None: + raise TypeError("traverse_history() missing required argument: 'delta'") + params = { "context": context, "delta": delta, diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 7edb7a9dacd06..8428c233682b8 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -193,6 +193,9 @@ def __init__(self, conn) -> None: def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setForcedColorsModeThemeOverride.""" + if theme is None: + raise TypeError("set_forced_colors_mode_theme_override() missing required argument: 'theme'") + params = { "theme": theme, "contexts": contexts, @@ -216,6 +219,9 @@ def set_geolocation_override(self, contexts: List[Any] | None = None, user_conte def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setLocaleOverride.""" + if locale is None: + raise TypeError("set_locale_override() missing required argument: 'locale'") + params = { "locale": locale, "contexts": contexts, @@ -228,6 +234,9 @@ def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | N def set_network_conditions(self, network_conditions: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setNetworkConditions.""" + if network_conditions is None: + raise TypeError("set_network_conditions() missing required argument: 'network_conditions'") + params = { "networkConditions": network_conditions, "contexts": contexts, @@ -240,6 +249,9 @@ def set_network_conditions(self, network_conditions: Any | None = None, contexts def set_screen_settings_override(self, screen_area: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScreenSettingsOverride.""" + if screen_area is None: + raise TypeError("set_screen_settings_override() missing required argument: 'screen_area'") + params = { "screenArea": screen_area, "contexts": contexts, @@ -252,6 +264,9 @@ def set_screen_settings_override(self, screen_area: Any | None = None, contexts: def set_screen_orientation_override(self, screen_orientation: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScreenOrientationOverride.""" + if screen_orientation is None: + raise TypeError("set_screen_orientation_override() missing required argument: 'screen_orientation'") + params = { "screenOrientation": screen_orientation, "contexts": contexts, @@ -264,6 +279,9 @@ def set_screen_orientation_override(self, screen_orientation: Any | None = None, def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setUserAgentOverride.""" + if user_agent is None: + raise TypeError("set_user_agent_override() missing required argument: 'user_agent'") + params = { "userAgent": user_agent, "contexts": contexts, @@ -276,6 +294,9 @@ def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[ def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setViewportMetaOverride.""" + if viewport_meta is None: + raise TypeError("set_viewport_meta_override() missing required argument: 'viewport_meta'") + params = { "viewportMeta": viewport_meta, "contexts": contexts, @@ -288,6 +309,9 @@ def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScriptingEnabled.""" + if enabled is None: + raise TypeError("set_scripting_enabled() missing required argument: 'enabled'") + params = { "enabled": enabled, "contexts": contexts, @@ -300,6 +324,9 @@ def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScrollbarTypeOverride.""" + if scrollbar_type is None: + raise TypeError("set_scrollbar_type_override() missing required argument: 'scrollbar_type'") + params = { "scrollbarType": scrollbar_type, "contexts": contexts, @@ -312,6 +339,9 @@ def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, context def set_timezone_override(self, timezone: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setTimezoneOverride.""" + if timezone is None: + raise TypeError("set_timezone_override() missing required argument: 'timezone'") + params = { "timezone": timezone, "contexts": contexts, diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index a294bde307b89..2a19d8072781a 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -370,6 +370,11 @@ def __init__(self, conn) -> None: def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): """Execute input.performActions.""" + if context is None: + raise TypeError("perform_actions() missing required argument: 'context'") + if actions is None: + raise TypeError("perform_actions() missing required argument: 'actions'") + params = { "context": context, "actions": actions, @@ -381,6 +386,9 @@ def perform_actions(self, context: Any | None = None, actions: List[Any] | None def release_actions(self, context: Any | None = None): """Execute input.releaseActions.""" + if context is None: + raise TypeError("release_actions() missing required argument: 'context'") + params = { "context": context, } @@ -391,6 +399,13 @@ def release_actions(self, context: Any | None = None): def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): """Execute input.setFiles.""" + if context is None: + raise TypeError("set_files() missing required argument: 'context'") + if element is None: + raise TypeError("set_files() missing required argument: 'element'") + if files is None: + raise TypeError("set_files() missing required argument: 'files'") + params = { "context": context, "element": element, diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index af079f421546c..1f6b0471f2414 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -565,6 +565,11 @@ def __init__(self, conn) -> None: def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute network.addDataCollector.""" + if data_types is None: + raise TypeError("add_data_collector() missing required argument: 'data_types'") + if max_encoded_data_size is None: + raise TypeError("add_data_collector() missing required argument: 'max_encoded_data_size'") + params = { "dataTypes": data_types, "maxEncodedDataSize": max_encoded_data_size, @@ -579,6 +584,9 @@ def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_da def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | None = None, url_patterns: List[Any] | None = None): """Execute network.addIntercept.""" + if phases is None: + raise TypeError("add_intercept() missing required argument: 'phases'") + params = { "phases": phases, "contexts": contexts, @@ -591,6 +599,9 @@ def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | N def continue_request(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, method: Any | None = None, url: Any | None = None): """Execute network.continueRequest.""" + if request is None: + raise TypeError("continue_request() missing required argument: 'request'") + params = { "request": request, "body": body, @@ -606,6 +617,9 @@ def continue_request(self, request: Any | None = None, body: Any | None = None, def continue_response(self, request: Any | None = None, cookies: List[Any] | None = None, credentials: Any | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): """Execute network.continueResponse.""" + if request is None: + raise TypeError("continue_response() missing required argument: 'request'") + params = { "request": request, "cookies": cookies, @@ -621,6 +635,9 @@ def continue_response(self, request: Any | None = None, cookies: List[Any] | Non def continue_with_auth(self, request: Any | None = None): """Execute network.continueWithAuth.""" + if request is None: + raise TypeError("continue_with_auth() missing required argument: 'request'") + params = { "request": request, } @@ -631,6 +648,13 @@ def continue_with_auth(self, request: Any | None = None): def disown_data(self, data_type: Any | None = None, collector: Any | None = None, request: Any | None = None): """Execute network.disownData.""" + if data_type is None: + raise TypeError("disown_data() missing required argument: 'data_type'") + if collector is None: + raise TypeError("disown_data() missing required argument: 'collector'") + if request is None: + raise TypeError("disown_data() missing required argument: 'request'") + params = { "dataType": data_type, "collector": collector, @@ -643,6 +667,9 @@ def disown_data(self, data_type: Any | None = None, collector: Any | None = None def fail_request(self, request: Any | None = None): """Execute network.failRequest.""" + if request is None: + raise TypeError("fail_request() missing required argument: 'request'") + params = { "request": request, } @@ -653,6 +680,11 @@ def fail_request(self, request: Any | None = None): def get_data(self, data_type: Any | None = None, collector: Any | None = None, disown: bool | None = None, request: Any | None = None): """Execute network.getData.""" + if data_type is None: + raise TypeError("get_data() missing required argument: 'data_type'") + if request is None: + raise TypeError("get_data() missing required argument: 'request'") + params = { "dataType": data_type, "collector": collector, @@ -666,6 +698,9 @@ def get_data(self, data_type: Any | None = None, collector: Any | None = None, d def provide_response(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): """Execute network.provideResponse.""" + if request is None: + raise TypeError("provide_response() missing required argument: 'request'") + params = { "request": request, "body": body, @@ -681,6 +716,9 @@ def provide_response(self, request: Any | None = None, body: Any | None = None, def remove_data_collector(self, collector: Any | None = None): """Execute network.removeDataCollector.""" + if collector is None: + raise TypeError("remove_data_collector() missing required argument: 'collector'") + params = { "collector": collector, } @@ -691,6 +729,9 @@ def remove_data_collector(self, collector: Any | None = None): def remove_intercept(self, intercept: Any | None = None): """Execute network.removeIntercept.""" + if intercept is None: + raise TypeError("remove_intercept() missing required argument: 'intercept'") + params = { "intercept": intercept, } @@ -701,6 +742,9 @@ def remove_intercept(self, intercept: Any | None = None): def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): """Execute network.setCacheBehavior.""" + if cache_behavior is None: + raise TypeError("set_cache_behavior() missing required argument: 'cache_behavior'") + params = { "cacheBehavior": cache_behavior, "contexts": contexts, @@ -712,6 +756,9 @@ def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[A def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute network.setExtraHeaders.""" + if headers is None: + raise TypeError("set_extra_headers() missing required argument: 'headers'") + params = { "headers": headers, "contexts": contexts, @@ -724,6 +771,11 @@ def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.beforeRequestSent.""" + if method is None: + raise TypeError("before_request_sent() missing required argument: 'method'") + if params is None: + raise TypeError("before_request_sent() missing required argument: 'params'") + params = { "initiator": initiator, "method": method, @@ -736,6 +788,13 @@ def before_request_sent(self, initiator: Any | None = None, method: Any | None = def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.fetchError.""" + if error_text is None: + raise TypeError("fetch_error() missing required argument: 'error_text'") + if method is None: + raise TypeError("fetch_error() missing required argument: 'method'") + if params is None: + raise TypeError("fetch_error() missing required argument: 'params'") + params = { "errorText": error_text, "method": method, @@ -748,6 +807,13 @@ def fetch_error(self, error_text: Any | None = None, method: Any | None = None, def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.responseCompleted.""" + if response is None: + raise TypeError("response_completed() missing required argument: 'response'") + if method is None: + raise TypeError("response_completed() missing required argument: 'method'") + if params is None: + raise TypeError("response_completed() missing required argument: 'params'") + params = { "response": response, "method": method, @@ -760,6 +826,9 @@ def response_completed(self, response: Any | None = None, method: Any | None = N def response_started(self, response: Any | None = None): """Execute network.responseStarted.""" + if response is None: + raise TypeError("response_started() missing required argument: 'response'") + params = { "response": response, } diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 492d1fe431680..0f59c400a38c2 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -785,6 +785,9 @@ def __init__(self, conn, driver=None) -> None: def add_preload_script(self, function_declaration: Any | None = None, arguments: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None, sandbox: Any | None = None): """Execute script.addPreloadScript.""" + if function_declaration is None: + raise TypeError("add_preload_script() missing required argument: 'function_declaration'") + params = { "functionDeclaration": function_declaration, "arguments": arguments, @@ -799,6 +802,11 @@ def add_preload_script(self, function_declaration: Any | None = None, arguments: def disown(self, handles: List[Any] | None = None, target: Any | None = None): """Execute script.disown.""" + if handles is None: + raise TypeError("disown() missing required argument: 'handles'") + if target is None: + raise TypeError("disown() missing required argument: 'target'") + params = { "handles": handles, "target": target, @@ -810,6 +818,13 @@ def disown(self, handles: List[Any] | None = None, target: Any | None = None): def call_function(self, function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, arguments: List[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, user_activation: bool | None = None): """Execute script.callFunction.""" + if function_declaration is None: + raise TypeError("call_function() missing required argument: 'function_declaration'") + if await_promise is None: + raise TypeError("call_function() missing required argument: 'await_promise'") + if target is None: + raise TypeError("call_function() missing required argument: 'target'") + params = { "functionDeclaration": function_declaration, "awaitPromise": await_promise, @@ -827,6 +842,13 @@ def call_function(self, function_declaration: Any | None = None, await_promise: def evaluate(self, expression: Any | None = None, target: Any | None = None, await_promise: bool | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, user_activation: bool | None = None): """Execute script.evaluate.""" + if expression is None: + raise TypeError("evaluate() missing required argument: 'expression'") + if target is None: + raise TypeError("evaluate() missing required argument: 'target'") + if await_promise is None: + raise TypeError("evaluate() missing required argument: 'await_promise'") + params = { "expression": expression, "target": target, @@ -853,6 +875,9 @@ def get_realms(self, context: Any | None = None, type: Any | None = None): def remove_preload_script(self, script: Any | None = None): """Execute script.removePreloadScript.""" + if script is None: + raise TypeError("remove_preload_script() missing required argument: 'script'") + params = { "script": script, } @@ -863,6 +888,13 @@ def remove_preload_script(self, script: Any | None = None): def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): """Execute script.message.""" + if channel is None: + raise TypeError("message() missing required argument: 'channel'") + if data is None: + raise TypeError("message() missing required argument: 'data'") + if source is None: + raise TypeError("message() missing required argument: 'source'") + params = { "channel": channel, "data": data, diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index f1430cb6e59d3..374375a62f2ec 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -194,6 +194,9 @@ def status(self): def new(self, capabilities: Any | None = None): """Execute session.new.""" + if capabilities is None: + raise TypeError("new() missing required argument: 'capabilities'") + params = { "capabilities": capabilities, } @@ -213,6 +216,9 @@ def end(self): def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute session.subscribe.""" + if events is None: + raise TypeError("subscribe() missing required argument: 'events'") + params = { "events": events, "contexts": contexts, diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 833e9cdc74f2a..8742dc61ebccf 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -248,6 +248,9 @@ def get_cookies(self, filter: Any | None = None, partition: Any | None = None): def set_cookie(self, cookie: Any | None = None, partition: Any | None = None): """Execute storage.setCookie.""" + if cookie is None: + raise TypeError("set_cookie() missing required argument: 'cookie'") + params = { "cookie": cookie, "partition": partition, From 75929331d7d9a4d3918bcda3f35027b289b08445 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 11 Mar 2026 12:32:27 +0000 Subject: [PATCH 16/42] improve generation so we don't need to run ruffs over it --- py/generate_bidi.py | 81 +++++++-- py/private/bidi_enhancements_manifest.py | 14 +- py/selenium/webdriver/common/bidi/browser.py | 47 +---- .../webdriver/common/bidi/browsing_context.py | 167 ++++++++++++------ .../webdriver/common/bidi/emulation.py | 131 ++++---------- py/selenium/webdriver/common/bidi/input.py | 18 +- py/selenium/webdriver/common/bidi/log.py | 6 +- py/selenium/webdriver/common/bidi/network.py | 119 +++++++++---- py/selenium/webdriver/common/bidi/script.py | 69 ++++++-- py/selenium/webdriver/common/bidi/session.py | 11 +- py/selenium/webdriver/common/bidi/storage.py | 36 ---- 11 files changed, 384 insertions(+), 315 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 78a7603b929c0..affd0a63a750c 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -170,22 +170,34 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: param_strs.append(f"{snake_param}: {python_type} | None = None") if param_strs: - param_list = "self, " + ", ".join(param_strs) + # Check if full signature would exceed line length limit (120 chars) + single_line_signature = f" def {method_name}(self, {', '.join(param_strs)}):" + if len(single_line_signature) > 120: + # Format parameters on multiple lines + body = f" def {method_name}(\n" + body += " self,\n" + for i, param_str in enumerate(param_strs): + if i < len(param_strs) - 1: + body += f" {param_str},\n" + else: + body += f" {param_str},\n" + body += " ):\n" + else: + param_list = "self, " + ", ".join(param_strs) + body = f" def {method_name}({param_list}):\n" else: - param_list = "self" - - # Build method body - body = f" def {method_name}({param_list}):\n" + body = f" def {method_name}(self):\n" body += f' """{self.description or "Execute " + self.module + "." + self.name}."""\n' # Add automatic validation for required parameters # (This is used unless there's no required_params, in which case all params are optional) if self.required_params: + method_snake = self._camel_to_snake(self.name) for param_name, snake_param in param_names: if param_name in self.required_params: - method_snake = self._camel_to_snake(self.name) body += f" if {snake_param} is None:\n" - body += f' raise TypeError("{method_snake}() missing required argument: {snake_param!r}")\n' + msg = f"{method_snake}() missing required argument:" + body += f' raise TypeError("{msg} {{{{snake_param!r}}}}")\n' body += "\n" # Add validation if specified in enhancements (for additional business logic validation) @@ -247,7 +259,6 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: if result_param == "download_behavior": body += ' "downloadBehavior": download_behavior,\n' # Add remaining parameters that weren't part of the transform - override_params = enhancements.get("params_override", {}) for cddl_param_name in self.params: if cddl_param_name not in ["downloadBehavior"]: snake_name = self._camel_to_snake(cddl_param_name) @@ -667,8 +678,20 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ - # Generate enums first + # Generate enums first (excluding those in exclude_types) + exclude_types = set(enhancements.get("exclude_types", [])) + + # Also exclude any types that have extra_dataclasses overrides + # Extract class names from extra_dataclasses strings + for extra_cls in enhancements.get("extra_dataclasses", []): + # Match "class ClassName:" pattern + match = re.search(r"class\s+(\w+)\s*:", extra_cls) + if match: + exclude_types.add(match.group(1)) + for enum_def in self.enums: + if enum_def.name in exclude_types: + continue code += enum_def.to_python_class() code += "\n\n" @@ -677,7 +700,6 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code += f"{alias} = {target}\n\n" # Generate type dataclasses, skipping any overridden by extra_dataclasses - exclude_types = set(enhancements.get("exclude_types", [])) for type_def in self.types: if type_def.name in exclude_types: continue @@ -946,6 +968,16 @@ def clear_event_handlers(self) -> None: # Generate command methods exclude_methods = enhancements.get("exclude_methods", []) + + # Automatically exclude methods that are defined in extra_methods + # to prevent generating duplicates + if "extra_methods" in enhancements: + for extra_method in enhancements["extra_methods"]: + # Extract method name from "def method_name(" + match = re.search(r"def\s+(\w+)\s*\(", extra_method) + if match: + exclude_methods = list(exclude_methods) + [match.group(1)] + if self.commands: for command in self.commands: # Get method-specific enhancements @@ -1026,9 +1058,26 @@ def clear_event_handlers(self) -> None: method_parts = event_def.method.split(".") if len(method_parts) == 2: event_name = self._convert_method_to_event_name(method_parts[1]) - # The event class is the event name (e.g., ContextCreated) - # Try to get it from globals, default to dict if not found - code += f' "{event_name}": (EventConfig("{event_name}", "{event_def.method}", _globals.get("{event_def.name}", dict)) if _globals.get("{event_def.name}") else EventConfig("{event_name}", "{event_def.method}", dict)),\n' + # Try to get event class from globals, default to dict if not found + getter = f'_globals.get("{event_def.name}", dict)' + condition = f'_globals.get("{event_def.name}")' + event_class = f'{getter} if {condition} else dict' + + # Build the entry line and check if it exceeds 120 chars + single_line = ( + f' "{event_name}": ' + f'EventConfig("{event_name}", "{event_def.method}", {event_class}),' + ) + + if len(single_line) > 120: + # Break into multiple lines + code += f' "{event_name}": EventConfig(\n' + code += f' "{event_name}",\n' + code += f' "{event_def.method}",\n' + code += f' {event_class},\n' + code += ' ),\n' + else: + code += single_line + '\n' # Extra events not in the CDDL spec for extra_evt in enhancements.get("extra_events", []): ek = extra_evt["event_key"] @@ -1126,9 +1175,6 @@ def _extract_event_names(self) -> None: ... ) """ - # Look for definitions like "BrowsingContextEvent", "SessionEvent", etc. - event_union_pattern = re.compile(r"(\w+\.)?(\w+)Event") - for def_name, def_content in self.definitions.items(): # Check if this looks like an event union (name ends with "Event") and # contains a module-qualified reference like "module.EventName". @@ -1479,7 +1525,8 @@ def _extract_parameters_and_required( if not is_optional: required.add(param_name) logger.debug( - f"Extracted param {param_name}: {normalized_type} (required={not is_optional}) from {params_type}" + f"Extracted param {param_name}: {normalized_type} " + f"(required={not is_optional}) from {params_type}" ) return params, required diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 2b93f36f1a5dc..40647157f8535 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -252,19 +252,7 @@ def from_json(cls, params: dict) -> "DownloadEndParams": ) return cls(download_params=dp)''', ], - # Non-CDDL download events (Chromium-specific, not in the BiDi spec) - "extra_events": [ - { - "event_key": "download_will_begin", - "bidi_event": "browsingContext.downloadWillBegin", - "event_class": "DownloadWillBeginParams", - }, - { - "event_key": "download_end", - "bidi_event": "browsingContext.downloadEnd", - "event_class": "DownloadEndParams", - }, - ], + # Download events are now in the CDDL spec, so no extra_events needed }, "log": { # Make LogLevel an alias for Level so existing code using LogLevel works diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index c4017265ac757..77ae8f0696281 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -61,14 +61,6 @@ def validate_download_behavior( raise ValueError("destination_folder should not be provided when allowed=False") -class ClientWindowNamedState: - """ClientWindowNamedState.""" - - FULLSCREEN = "fullscreen" - MAXIMIZED = "maximized" - MINIMIZED = "minimized" - - @dataclass class ClientWindowInfo: """ClientWindowInfo.""" @@ -212,7 +204,12 @@ def close(self): result = self._conn.execute(cmd) return result - def create_user_context(self, accept_insecure_certs: bool | None = None, proxy: Any | None = None, unhandled_prompt_behavior: Any | None = None): + def create_user_context( + self, + accept_insecure_certs: bool | None = None, + proxy: Any | None = None, + unhandled_prompt_behavior: Any | None = None, + ): """Execute browser.createUserContext.""" if proxy and hasattr(proxy, 'to_bidi_dict'): proxy = proxy.to_bidi_dict() @@ -276,7 +273,7 @@ def get_user_contexts(self): def remove_user_context(self, user_context: Any | None = None): """Execute browser.removeUserContext.""" if user_context is None: - raise TypeError("remove_user_context() missing required argument: 'user_context'") + raise TypeError("remove_user_context() missing required argument: {{snake_param!r}}") params = { "userContext": user_context, @@ -286,36 +283,6 @@ def remove_user_context(self, user_context: Any | None = None): result = self._conn.execute(cmd) return result - def set_client_window_state(self, client_window: Any | None = None): - """Execute browser.setClientWindowState.""" - if client_window is None: - raise TypeError("set_client_window_state() missing required argument: 'client_window'") - - params = { - "clientWindow": client_window, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browser.setClientWindowState", params) - result = self._conn.execute(cmd) - return result - - def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): - """Execute browser.setDownloadBehavior.""" - - validate_download_behavior(allowed=allowed, destination_folder=destination_folder, user_contexts=user_contexts) - - download_behavior = None - download_behavior = transform_download_params(allowed, destination_folder) - - params = { - "downloadBehavior": download_behavior, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browser.setDownloadBehavior", params) - result = self._conn.execute(cmd) - return result - def set_download_behavior( self, allowed: bool | None = None, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 775bcdb8f9dbb..3f877b06b00ab 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -328,20 +328,6 @@ class HistoryUpdatedParameters: url: str | None = None -@dataclass -class DownloadWillBeginParams: - """DownloadWillBeginParams.""" - - suggested_filename: str | None = None - - -@dataclass -class DownloadCanceledParams: - """DownloadCanceledParams.""" - - status: str = field(default="canceled", init=False) - - @dataclass class UserPromptClosedParameters: """UserPromptClosedParameters.""" @@ -421,8 +407,6 @@ def from_json(cls, params: dict) -> "DownloadEndParams": "navigation_failed": "browsingContext.navigationFailed", "user_prompt_closed": "browsingContext.userPromptClosed", "user_prompt_opened": "browsingContext.userPromptOpened", - "download_will_begin": "browsingContext.downloadWillBegin", - "download_end": "browsingContext.downloadEnd", } def _deserialize_info_list(items: list) -> list | None: @@ -623,7 +607,7 @@ def __init__(self, conn) -> None: def activate(self, context: Any | None = None): """Execute browsingContext.activate.""" if context is None: - raise TypeError("activate() missing required argument: 'context'") + raise TypeError("activate() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -633,10 +617,16 @@ def activate(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def capture_screenshot(self, context: str | None = None, format: Any | None = None, clip: Any | None = None, origin: str | None = None): + def capture_screenshot( + self, + context: str | None = None, + format: Any | None = None, + clip: Any | None = None, + origin: str | None = None, + ): """Execute browsingContext.captureScreenshot.""" if context is None: - raise TypeError("capture_screenshot() missing required argument: 'context'") + raise TypeError("capture_screenshot() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -655,7 +645,7 @@ def capture_screenshot(self, context: str | None = None, format: Any | None = No def close(self, context: Any | None = None, prompt_unload: bool | None = None): """Execute browsingContext.close.""" if context is None: - raise TypeError("close() missing required argument: 'context'") + raise TypeError("close() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -666,10 +656,16 @@ def close(self, context: Any | None = None, prompt_unload: bool | None = None): result = self._conn.execute(cmd) return result - def create(self, type: Any | None = None, reference_context: Any | None = None, background: bool | None = None, user_context: Any | None = None): + def create( + self, + type: Any | None = None, + reference_context: Any | None = None, + background: bool | None = None, + user_context: Any | None = None, + ): """Execute browsingContext.create.""" if type is None: - raise TypeError("create() missing required argument: 'type'") + raise TypeError("create() missing required argument: {{snake_param!r}}") params = { "type": type, @@ -714,7 +710,7 @@ def get_tree(self, max_depth: Any | None = None, root: Any | None = None): def handle_user_prompt(self, context: Any | None = None, accept: bool | None = None, user_text: Any | None = None): """Execute browsingContext.handleUserPrompt.""" if context is None: - raise TypeError("handle_user_prompt() missing required argument: 'context'") + raise TypeError("handle_user_prompt() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -726,12 +722,19 @@ def handle_user_prompt(self, context: Any | None = None, accept: bool | None = N result = self._conn.execute(cmd) return result - def locate_nodes(self, context: str | None = None, locator: Any | None = None, serialization_options: Any | None = None, start_nodes: Any | None = None, max_node_count: int | None = None): + def locate_nodes( + self, + context: str | None = None, + locator: Any | None = None, + serialization_options: Any | None = None, + start_nodes: Any | None = None, + max_node_count: int | None = None, + ): """Execute browsingContext.locateNodes.""" if context is None: - raise TypeError("locate_nodes() missing required argument: 'context'") + raise TypeError("locate_nodes() missing required argument: {{snake_param!r}}") if locator is None: - raise TypeError("locate_nodes() missing required argument: 'locator'") + raise TypeError("locate_nodes() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -751,9 +754,9 @@ def locate_nodes(self, context: str | None = None, locator: Any | None = None, s def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any | None = None): """Execute browsingContext.navigate.""" if context is None: - raise TypeError("navigate() missing required argument: 'context'") + raise TypeError("navigate() missing required argument: {{snake_param!r}}") if url is None: - raise TypeError("navigate() missing required argument: 'url'") + raise TypeError("navigate() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -765,10 +768,18 @@ def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any result = self._conn.execute(cmd) return result - def print(self, context: Any | None = None, background: bool | None = None, margin: Any | None = None, page: Any | None = None, scale: Any | None = None, shrink_to_fit: bool | None = None): + def print( + self, + context: Any | None = None, + background: bool | None = None, + margin: Any | None = None, + page: Any | None = None, + scale: Any | None = None, + shrink_to_fit: bool | None = None, + ): """Execute browsingContext.print.""" if context is None: - raise TypeError("print() missing required argument: 'context'") + raise TypeError("print() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -789,7 +800,7 @@ def print(self, context: Any | None = None, background: bool | None = None, marg def reload(self, context: Any | None = None, ignore_cache: bool | None = None, wait: Any | None = None): """Execute browsingContext.reload.""" if context is None: - raise TypeError("reload() missing required argument: 'context'") + raise TypeError("reload() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -801,7 +812,13 @@ def reload(self, context: Any | None = None, ignore_cache: bool | None = None, w result = self._conn.execute(cmd) return result - def set_viewport(self, context: str | None = None, viewport: Any | None = None, user_contexts: Any | None = None, device_pixel_ratio: Any | None = None): + def set_viewport( + self, + context: str | None = None, + viewport: Any | None = None, + user_contexts: Any | None = None, + device_pixel_ratio: Any | None = None, + ): """Execute browsingContext.setViewport.""" params = { "context": context, @@ -817,9 +834,9 @@ def set_viewport(self, context: str | None = None, viewport: Any | None = None, def traverse_history(self, context: Any | None = None, delta: Any | None = None): """Execute browsingContext.traverseHistory.""" if context is None: - raise TypeError("traverse_history() missing required argument: 'context'") + raise TypeError("traverse_history() missing required argument: {{snake_param!r}}") if delta is None: - raise TypeError("traverse_history() missing required argument: 'delta'") + raise TypeError("traverse_history() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -904,20 +921,70 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() BrowsingContext.EVENT_CONFIGS = { - "context_created": (EventConfig("context_created", "browsingContext.contextCreated", _globals.get("ContextCreated", dict)) if _globals.get("ContextCreated") else EventConfig("context_created", "browsingContext.contextCreated", dict)), - "context_destroyed": (EventConfig("context_destroyed", "browsingContext.contextDestroyed", _globals.get("ContextDestroyed", dict)) if _globals.get("ContextDestroyed") else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict)), - "navigation_started": (EventConfig("navigation_started", "browsingContext.navigationStarted", _globals.get("NavigationStarted", dict)) if _globals.get("NavigationStarted") else EventConfig("navigation_started", "browsingContext.navigationStarted", dict)), - "fragment_navigated": (EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", _globals.get("FragmentNavigated", dict)) if _globals.get("FragmentNavigated") else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict)), - "history_updated": (EventConfig("history_updated", "browsingContext.historyUpdated", _globals.get("HistoryUpdated", dict)) if _globals.get("HistoryUpdated") else EventConfig("history_updated", "browsingContext.historyUpdated", dict)), - "dom_content_loaded": (EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", _globals.get("DomContentLoaded", dict)) if _globals.get("DomContentLoaded") else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict)), - "load": (EventConfig("load", "browsingContext.load", _globals.get("Load", dict)) if _globals.get("Load") else EventConfig("load", "browsingContext.load", dict)), - "download_will_begin": (EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBegin", dict)) if _globals.get("DownloadWillBegin") else EventConfig("download_will_begin", "browsingContext.downloadWillBegin", dict)), - "download_end": (EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEnd", dict)) if _globals.get("DownloadEnd") else EventConfig("download_end", "browsingContext.downloadEnd", dict)), - "navigation_aborted": (EventConfig("navigation_aborted", "browsingContext.navigationAborted", _globals.get("NavigationAborted", dict)) if _globals.get("NavigationAborted") else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict)), - "navigation_committed": (EventConfig("navigation_committed", "browsingContext.navigationCommitted", _globals.get("NavigationCommitted", dict)) if _globals.get("NavigationCommitted") else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict)), - "navigation_failed": (EventConfig("navigation_failed", "browsingContext.navigationFailed", _globals.get("NavigationFailed", dict)) if _globals.get("NavigationFailed") else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict)), - "user_prompt_closed": (EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", _globals.get("UserPromptClosed", dict)) if _globals.get("UserPromptClosed") else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict)), - "user_prompt_opened": (EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", _globals.get("UserPromptOpened", dict)) if _globals.get("UserPromptOpened") else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict)), - "download_will_begin": EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBeginParams", dict)), - "download_end": EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEndParams", dict)), + "context_created": EventConfig( + "context_created", + "browsingContext.contextCreated", + _globals.get("ContextCreated", dict) if _globals.get("ContextCreated") else dict, + ), + "context_destroyed": EventConfig( + "context_destroyed", + "browsingContext.contextDestroyed", + _globals.get("ContextDestroyed", dict) if _globals.get("ContextDestroyed") else dict, + ), + "navigation_started": EventConfig( + "navigation_started", + "browsingContext.navigationStarted", + _globals.get("NavigationStarted", dict) if _globals.get("NavigationStarted") else dict, + ), + "fragment_navigated": EventConfig( + "fragment_navigated", + "browsingContext.fragmentNavigated", + _globals.get("FragmentNavigated", dict) if _globals.get("FragmentNavigated") else dict, + ), + "history_updated": EventConfig( + "history_updated", + "browsingContext.historyUpdated", + _globals.get("HistoryUpdated", dict) if _globals.get("HistoryUpdated") else dict, + ), + "dom_content_loaded": EventConfig( + "dom_content_loaded", + "browsingContext.domContentLoaded", + _globals.get("DomContentLoaded", dict) if _globals.get("DomContentLoaded") else dict, + ), + "load": EventConfig("load", "browsingContext.load", _globals.get("Load", dict) if _globals.get("Load") else dict), + "download_will_begin": EventConfig( + "download_will_begin", + "browsingContext.downloadWillBegin", + _globals.get("DownloadWillBegin", dict) if _globals.get("DownloadWillBegin") else dict, + ), + "download_end": EventConfig( + "download_end", + "browsingContext.downloadEnd", + _globals.get("DownloadEnd", dict) if _globals.get("DownloadEnd") else dict, + ), + "navigation_aborted": EventConfig( + "navigation_aborted", + "browsingContext.navigationAborted", + _globals.get("NavigationAborted", dict) if _globals.get("NavigationAborted") else dict, + ), + "navigation_committed": EventConfig( + "navigation_committed", + "browsingContext.navigationCommitted", + _globals.get("NavigationCommitted", dict) if _globals.get("NavigationCommitted") else dict, + ), + "navigation_failed": EventConfig( + "navigation_failed", + "browsingContext.navigationFailed", + _globals.get("NavigationFailed", dict) if _globals.get("NavigationFailed") else dict, + ), + "user_prompt_closed": EventConfig( + "user_prompt_closed", + "browsingContext.userPromptClosed", + _globals.get("UserPromptClosed", dict) if _globals.get("UserPromptClosed") else dict, + ), + "user_prompt_opened": EventConfig( + "user_prompt_opened", + "browsingContext.userPromptOpened", + _globals.get("UserPromptOpened", dict) if _globals.get("UserPromptOpened") else dict, + ), } diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 8428c233682b8..d482fecc755cb 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -191,10 +191,15 @@ class Emulation: def __init__(self, conn) -> None: self._conn = conn - def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_forced_colors_mode_theme_override( + self, + theme: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute emulation.setForcedColorsModeThemeOverride.""" if theme is None: - raise TypeError("set_forced_colors_mode_theme_override() missing required argument: 'theme'") + raise TypeError("set_forced_colors_mode_theme_override() missing required argument: {{snake_param!r}}") params = { "theme": theme, @@ -206,21 +211,15 @@ def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contex result = self._conn.execute(cmd) return result - def set_geolocation_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setGeolocationOverride.""" - params = { - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setGeolocationOverride", params) - result = self._conn.execute(cmd) - return result - - def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_locale_override( + self, + locale: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute emulation.setLocaleOverride.""" if locale is None: - raise TypeError("set_locale_override() missing required argument: 'locale'") + raise TypeError("set_locale_override() missing required argument: {{snake_param!r}}") params = { "locale": locale, @@ -232,25 +231,15 @@ def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | N result = self._conn.execute(cmd) return result - def set_network_conditions(self, network_conditions: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setNetworkConditions.""" - if network_conditions is None: - raise TypeError("set_network_conditions() missing required argument: 'network_conditions'") - - params = { - "networkConditions": network_conditions, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setNetworkConditions", params) - result = self._conn.execute(cmd) - return result - - def set_screen_settings_override(self, screen_area: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_screen_settings_override( + self, + screen_area: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute emulation.setScreenSettingsOverride.""" if screen_area is None: - raise TypeError("set_screen_settings_override() missing required argument: 'screen_area'") + raise TypeError("set_screen_settings_override() missing required argument: {{snake_param!r}}") params = { "screenArea": screen_area, @@ -262,40 +251,15 @@ def set_screen_settings_override(self, screen_area: Any | None = None, contexts: result = self._conn.execute(cmd) return result - def set_screen_orientation_override(self, screen_orientation: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setScreenOrientationOverride.""" - if screen_orientation is None: - raise TypeError("set_screen_orientation_override() missing required argument: 'screen_orientation'") - - params = { - "screenOrientation": screen_orientation, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setScreenOrientationOverride", params) - result = self._conn.execute(cmd) - return result - - def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setUserAgentOverride.""" - if user_agent is None: - raise TypeError("set_user_agent_override() missing required argument: 'user_agent'") - - params = { - "userAgent": user_agent, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setUserAgentOverride", params) - result = self._conn.execute(cmd) - return result - - def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_viewport_meta_override( + self, + viewport_meta: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute emulation.setViewportMetaOverride.""" if viewport_meta is None: - raise TypeError("set_viewport_meta_override() missing required argument: 'viewport_meta'") + raise TypeError("set_viewport_meta_override() missing required argument: {{snake_param!r}}") params = { "viewportMeta": viewport_meta, @@ -307,25 +271,15 @@ def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: result = self._conn.execute(cmd) return result - def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setScriptingEnabled.""" - if enabled is None: - raise TypeError("set_scripting_enabled() missing required argument: 'enabled'") - - params = { - "enabled": enabled, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setScriptingEnabled", params) - result = self._conn.execute(cmd) - return result - - def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_scrollbar_type_override( + self, + scrollbar_type: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute emulation.setScrollbarTypeOverride.""" if scrollbar_type is None: - raise TypeError("set_scrollbar_type_override() missing required argument: 'scrollbar_type'") + raise TypeError("set_scrollbar_type_override() missing required argument: {{snake_param!r}}") params = { "scrollbarType": scrollbar_type, @@ -337,21 +291,6 @@ def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, context result = self._conn.execute(cmd) return result - def set_timezone_override(self, timezone: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setTimezoneOverride.""" - if timezone is None: - raise TypeError("set_timezone_override() missing required argument: 'timezone'") - - params = { - "timezone": timezone, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setTimezoneOverride", params) - result = self._conn.execute(cmd) - return result - def set_touch_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setTouchOverride.""" params = { diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 2a19d8072781a..0990dacc39363 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -371,9 +371,9 @@ def __init__(self, conn) -> None: def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): """Execute input.performActions.""" if context is None: - raise TypeError("perform_actions() missing required argument: 'context'") + raise TypeError("perform_actions() missing required argument: {{snake_param!r}}") if actions is None: - raise TypeError("perform_actions() missing required argument: 'actions'") + raise TypeError("perform_actions() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -387,7 +387,7 @@ def perform_actions(self, context: Any | None = None, actions: List[Any] | None def release_actions(self, context: Any | None = None): """Execute input.releaseActions.""" if context is None: - raise TypeError("release_actions() missing required argument: 'context'") + raise TypeError("release_actions() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -400,11 +400,11 @@ def release_actions(self, context: Any | None = None): def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): """Execute input.setFiles.""" if context is None: - raise TypeError("set_files() missing required argument: 'context'") + raise TypeError("set_files() missing required argument: {{snake_param!r}}") if element is None: - raise TypeError("set_files() missing required argument: 'element'") + raise TypeError("set_files() missing required argument: {{snake_param!r}}") if files is None: - raise TypeError("set_files() missing required argument: 'files'") + raise TypeError("set_files() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -469,5 +469,9 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Input.EVENT_CONFIGS = { - "file_dialog_opened": (EventConfig("file_dialog_opened", "input.fileDialogOpened", _globals.get("FileDialogOpened", dict)) if _globals.get("FileDialogOpened") else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict)), + "file_dialog_opened": EventConfig( + "file_dialog_opened", + "input.fileDialogOpened", + _globals.get("FileDialogOpened", dict) if _globals.get("FileDialogOpened") else dict, + ), } diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 1f16849b8e03d..07121242348ea 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -299,5 +299,9 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Log.EVENT_CONFIGS = { - "entry_added": (EventConfig("entry_added", "log.entryAdded", _globals.get("EntryAdded", dict)) if _globals.get("EntryAdded") else EventConfig("entry_added", "log.entryAdded", dict)), + "entry_added": EventConfig( + "entry_added", + "log.entryAdded", + _globals.get("EntryAdded", dict) if _globals.get("EntryAdded") else dict, + ), } diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 1f6b0471f2414..d7baeb07040ce 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -563,12 +563,19 @@ def __init__(self, conn) -> None: self.intercepts = [] self._handler_intercepts: dict = {} - def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def add_data_collector( + self, + data_types: List[Any] | None = None, + max_encoded_data_size: Any | None = None, + collector_type: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute network.addDataCollector.""" if data_types is None: - raise TypeError("add_data_collector() missing required argument: 'data_types'") + raise TypeError("add_data_collector() missing required argument: {{snake_param!r}}") if max_encoded_data_size is None: - raise TypeError("add_data_collector() missing required argument: 'max_encoded_data_size'") + raise TypeError("add_data_collector() missing required argument: {{snake_param!r}}") params = { "dataTypes": data_types, @@ -582,10 +589,15 @@ def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_da result = self._conn.execute(cmd) return result - def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | None = None, url_patterns: List[Any] | None = None): + def add_intercept( + self, + phases: List[Any] | None = None, + contexts: List[Any] | None = None, + url_patterns: List[Any] | None = None, + ): """Execute network.addIntercept.""" if phases is None: - raise TypeError("add_intercept() missing required argument: 'phases'") + raise TypeError("add_intercept() missing required argument: {{snake_param!r}}") params = { "phases": phases, @@ -597,10 +609,18 @@ def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | N result = self._conn.execute(cmd) return result - def continue_request(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, method: Any | None = None, url: Any | None = None): + def continue_request( + self, + request: Any | None = None, + body: Any | None = None, + cookies: List[Any] | None = None, + headers: List[Any] | None = None, + method: Any | None = None, + url: Any | None = None, + ): """Execute network.continueRequest.""" if request is None: - raise TypeError("continue_request() missing required argument: 'request'") + raise TypeError("continue_request() missing required argument: {{snake_param!r}}") params = { "request": request, @@ -615,10 +635,18 @@ def continue_request(self, request: Any | None = None, body: Any | None = None, result = self._conn.execute(cmd) return result - def continue_response(self, request: Any | None = None, cookies: List[Any] | None = None, credentials: Any | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): + def continue_response( + self, + request: Any | None = None, + cookies: List[Any] | None = None, + credentials: Any | None = None, + headers: List[Any] | None = None, + reason_phrase: Any | None = None, + status_code: Any | None = None, + ): """Execute network.continueResponse.""" if request is None: - raise TypeError("continue_response() missing required argument: 'request'") + raise TypeError("continue_response() missing required argument: {{snake_param!r}}") params = { "request": request, @@ -636,7 +664,7 @@ def continue_response(self, request: Any | None = None, cookies: List[Any] | Non def continue_with_auth(self, request: Any | None = None): """Execute network.continueWithAuth.""" if request is None: - raise TypeError("continue_with_auth() missing required argument: 'request'") + raise TypeError("continue_with_auth() missing required argument: {{snake_param!r}}") params = { "request": request, @@ -649,11 +677,11 @@ def continue_with_auth(self, request: Any | None = None): def disown_data(self, data_type: Any | None = None, collector: Any | None = None, request: Any | None = None): """Execute network.disownData.""" if data_type is None: - raise TypeError("disown_data() missing required argument: 'data_type'") + raise TypeError("disown_data() missing required argument: {{snake_param!r}}") if collector is None: - raise TypeError("disown_data() missing required argument: 'collector'") + raise TypeError("disown_data() missing required argument: {{snake_param!r}}") if request is None: - raise TypeError("disown_data() missing required argument: 'request'") + raise TypeError("disown_data() missing required argument: {{snake_param!r}}") params = { "dataType": data_type, @@ -668,7 +696,7 @@ def disown_data(self, data_type: Any | None = None, collector: Any | None = None def fail_request(self, request: Any | None = None): """Execute network.failRequest.""" if request is None: - raise TypeError("fail_request() missing required argument: 'request'") + raise TypeError("fail_request() missing required argument: {{snake_param!r}}") params = { "request": request, @@ -678,12 +706,18 @@ def fail_request(self, request: Any | None = None): result = self._conn.execute(cmd) return result - def get_data(self, data_type: Any | None = None, collector: Any | None = None, disown: bool | None = None, request: Any | None = None): + def get_data( + self, + data_type: Any | None = None, + collector: Any | None = None, + disown: bool | None = None, + request: Any | None = None, + ): """Execute network.getData.""" if data_type is None: - raise TypeError("get_data() missing required argument: 'data_type'") + raise TypeError("get_data() missing required argument: {{snake_param!r}}") if request is None: - raise TypeError("get_data() missing required argument: 'request'") + raise TypeError("get_data() missing required argument: {{snake_param!r}}") params = { "dataType": data_type, @@ -696,10 +730,18 @@ def get_data(self, data_type: Any | None = None, collector: Any | None = None, d result = self._conn.execute(cmd) return result - def provide_response(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): + def provide_response( + self, + request: Any | None = None, + body: Any | None = None, + cookies: List[Any] | None = None, + headers: List[Any] | None = None, + reason_phrase: Any | None = None, + status_code: Any | None = None, + ): """Execute network.provideResponse.""" if request is None: - raise TypeError("provide_response() missing required argument: 'request'") + raise TypeError("provide_response() missing required argument: {{snake_param!r}}") params = { "request": request, @@ -717,7 +759,7 @@ def provide_response(self, request: Any | None = None, body: Any | None = None, def remove_data_collector(self, collector: Any | None = None): """Execute network.removeDataCollector.""" if collector is None: - raise TypeError("remove_data_collector() missing required argument: 'collector'") + raise TypeError("remove_data_collector() missing required argument: {{snake_param!r}}") params = { "collector": collector, @@ -730,7 +772,7 @@ def remove_data_collector(self, collector: Any | None = None): def remove_intercept(self, intercept: Any | None = None): """Execute network.removeIntercept.""" if intercept is None: - raise TypeError("remove_intercept() missing required argument: 'intercept'") + raise TypeError("remove_intercept() missing required argument: {{snake_param!r}}") params = { "intercept": intercept, @@ -743,7 +785,7 @@ def remove_intercept(self, intercept: Any | None = None): def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): """Execute network.setCacheBehavior.""" if cache_behavior is None: - raise TypeError("set_cache_behavior() missing required argument: 'cache_behavior'") + raise TypeError("set_cache_behavior() missing required argument: {{snake_param!r}}") params = { "cacheBehavior": cache_behavior, @@ -754,10 +796,15 @@ def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[A result = self._conn.execute(cmd) return result - def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_extra_headers( + self, + headers: List[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute network.setExtraHeaders.""" if headers is None: - raise TypeError("set_extra_headers() missing required argument: 'headers'") + raise TypeError("set_extra_headers() missing required argument: {{snake_param!r}}") params = { "headers": headers, @@ -772,9 +819,9 @@ def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.beforeRequestSent.""" if method is None: - raise TypeError("before_request_sent() missing required argument: 'method'") + raise TypeError("before_request_sent() missing required argument: {{snake_param!r}}") if params is None: - raise TypeError("before_request_sent() missing required argument: 'params'") + raise TypeError("before_request_sent() missing required argument: {{snake_param!r}}") params = { "initiator": initiator, @@ -789,11 +836,11 @@ def before_request_sent(self, initiator: Any | None = None, method: Any | None = def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.fetchError.""" if error_text is None: - raise TypeError("fetch_error() missing required argument: 'error_text'") + raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") if method is None: - raise TypeError("fetch_error() missing required argument: 'method'") + raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") if params is None: - raise TypeError("fetch_error() missing required argument: 'params'") + raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") params = { "errorText": error_text, @@ -808,11 +855,11 @@ def fetch_error(self, error_text: Any | None = None, method: Any | None = None, def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.responseCompleted.""" if response is None: - raise TypeError("response_completed() missing required argument: 'response'") + raise TypeError("response_completed() missing required argument: {{snake_param!r}}") if method is None: - raise TypeError("response_completed() missing required argument: 'method'") + raise TypeError("response_completed() missing required argument: {{snake_param!r}}") if params is None: - raise TypeError("response_completed() missing required argument: 'params'") + raise TypeError("response_completed() missing required argument: {{snake_param!r}}") params = { "response": response, @@ -827,7 +874,7 @@ def response_completed(self, response: Any | None = None, method: Any | None = N def response_started(self, response: Any | None = None): """Execute network.responseStarted.""" if response is None: - raise TypeError("response_started() missing required argument: 'response'") + raise TypeError("response_started() missing required argument: {{snake_param!r}}") params = { "response": response, @@ -995,6 +1042,10 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Network.EVENT_CONFIGS = { - "auth_required": (EventConfig("auth_required", "network.authRequired", _globals.get("AuthRequired", dict)) if _globals.get("AuthRequired") else EventConfig("auth_required", "network.authRequired", dict)), + "auth_required": EventConfig( + "auth_required", + "network.authRequired", + _globals.get("AuthRequired", dict) if _globals.get("AuthRequired") else dict, + ), "before_request": EventConfig("before_request", "network.beforeRequestSent", _globals.get("dict", dict)), } diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 0f59c400a38c2..8e832f4a9cae9 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -783,10 +783,17 @@ def __init__(self, conn, driver=None) -> None: self._driver = driver self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def add_preload_script(self, function_declaration: Any | None = None, arguments: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None, sandbox: Any | None = None): + def add_preload_script( + self, + function_declaration: Any | None = None, + arguments: List[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + sandbox: Any | None = None, + ): """Execute script.addPreloadScript.""" if function_declaration is None: - raise TypeError("add_preload_script() missing required argument: 'function_declaration'") + raise TypeError("add_preload_script() missing required argument: {{snake_param!r}}") params = { "functionDeclaration": function_declaration, @@ -803,9 +810,9 @@ def add_preload_script(self, function_declaration: Any | None = None, arguments: def disown(self, handles: List[Any] | None = None, target: Any | None = None): """Execute script.disown.""" if handles is None: - raise TypeError("disown() missing required argument: 'handles'") + raise TypeError("disown() missing required argument: {{snake_param!r}}") if target is None: - raise TypeError("disown() missing required argument: 'target'") + raise TypeError("disown() missing required argument: {{snake_param!r}}") params = { "handles": handles, @@ -816,14 +823,24 @@ def disown(self, handles: List[Any] | None = None, target: Any | None = None): result = self._conn.execute(cmd) return result - def call_function(self, function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, arguments: List[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, user_activation: bool | None = None): + def call_function( + self, + function_declaration: Any | None = None, + await_promise: bool | None = None, + target: Any | None = None, + arguments: List[Any] | None = None, + result_ownership: Any | None = None, + serialization_options: Any | None = None, + this: Any | None = None, + user_activation: bool | None = None, + ): """Execute script.callFunction.""" if function_declaration is None: - raise TypeError("call_function() missing required argument: 'function_declaration'") + raise TypeError("call_function() missing required argument: {{snake_param!r}}") if await_promise is None: - raise TypeError("call_function() missing required argument: 'await_promise'") + raise TypeError("call_function() missing required argument: {{snake_param!r}}") if target is None: - raise TypeError("call_function() missing required argument: 'target'") + raise TypeError("call_function() missing required argument: {{snake_param!r}}") params = { "functionDeclaration": function_declaration, @@ -840,14 +857,22 @@ def call_function(self, function_declaration: Any | None = None, await_promise: result = self._conn.execute(cmd) return result - def evaluate(self, expression: Any | None = None, target: Any | None = None, await_promise: bool | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, user_activation: bool | None = None): + def evaluate( + self, + expression: Any | None = None, + target: Any | None = None, + await_promise: bool | None = None, + result_ownership: Any | None = None, + serialization_options: Any | None = None, + user_activation: bool | None = None, + ): """Execute script.evaluate.""" if expression is None: - raise TypeError("evaluate() missing required argument: 'expression'") + raise TypeError("evaluate() missing required argument: {{snake_param!r}}") if target is None: - raise TypeError("evaluate() missing required argument: 'target'") + raise TypeError("evaluate() missing required argument: {{snake_param!r}}") if await_promise is None: - raise TypeError("evaluate() missing required argument: 'await_promise'") + raise TypeError("evaluate() missing required argument: {{snake_param!r}}") params = { "expression": expression, @@ -876,7 +901,7 @@ def get_realms(self, context: Any | None = None, type: Any | None = None): def remove_preload_script(self, script: Any | None = None): """Execute script.removePreloadScript.""" if script is None: - raise TypeError("remove_preload_script() missing required argument: 'script'") + raise TypeError("remove_preload_script() missing required argument: {{snake_param!r}}") params = { "script": script, @@ -889,11 +914,11 @@ def remove_preload_script(self, script: Any | None = None): def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): """Execute script.message.""" if channel is None: - raise TypeError("message() missing required argument: 'channel'") + raise TypeError("message() missing required argument: {{snake_param!r}}") if data is None: - raise TypeError("message() missing required argument: 'data'") + raise TypeError("message() missing required argument: {{snake_param!r}}") if source is None: - raise TypeError("message() missing required argument: 'source'") + raise TypeError("message() missing required argument: {{snake_param!r}}") params = { "channel": channel, @@ -1314,6 +1339,14 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Script.EVENT_CONFIGS = { - "realm_created": (EventConfig("realm_created", "script.realmCreated", _globals.get("RealmCreated", dict)) if _globals.get("RealmCreated") else EventConfig("realm_created", "script.realmCreated", dict)), - "realm_destroyed": (EventConfig("realm_destroyed", "script.realmDestroyed", _globals.get("RealmDestroyed", dict)) if _globals.get("RealmDestroyed") else EventConfig("realm_destroyed", "script.realmDestroyed", dict)), + "realm_created": EventConfig( + "realm_created", + "script.realmCreated", + _globals.get("RealmCreated", dict) if _globals.get("RealmCreated") else dict, + ), + "realm_destroyed": EventConfig( + "realm_destroyed", + "script.realmDestroyed", + _globals.get("RealmDestroyed", dict) if _globals.get("RealmDestroyed") else dict, + ), } diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 374375a62f2ec..c7dd45ec824b8 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -195,7 +195,7 @@ def status(self): def new(self, capabilities: Any | None = None): """Execute session.new.""" if capabilities is None: - raise TypeError("new() missing required argument: 'capabilities'") + raise TypeError("new() missing required argument: {{snake_param!r}}") params = { "capabilities": capabilities, @@ -214,10 +214,15 @@ def end(self): result = self._conn.execute(cmd) return result - def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def subscribe( + self, + events: List[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute session.subscribe.""" if events is None: - raise TypeError("subscribe() missing required argument: 'events'") + raise TypeError("subscribe() missing required argument: {{snake_param!r}}") params = { "events": events, diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 8742dc61ebccf..267569f782289 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -235,42 +235,6 @@ class Storage: def __init__(self, conn) -> None: self._conn = conn - def get_cookies(self, filter: Any | None = None, partition: Any | None = None): - """Execute storage.getCookies.""" - params = { - "filter": filter, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.getCookies", params) - result = self._conn.execute(cmd) - return result - - def set_cookie(self, cookie: Any | None = None, partition: Any | None = None): - """Execute storage.setCookie.""" - if cookie is None: - raise TypeError("set_cookie() missing required argument: 'cookie'") - - params = { - "cookie": cookie, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.setCookie", params) - result = self._conn.execute(cmd) - return result - - def delete_cookies(self, filter: Any | None = None, partition: Any | None = None): - """Execute storage.deleteCookies.""" - params = { - "filter": filter, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.deleteCookies", params) - result = self._conn.execute(cmd) - return result - def get_cookies(self, filter=None, partition=None): """Execute storage.getCookies and return a GetCookiesResult.""" if filter and hasattr(filter, "to_bidi_dict"): From d71f7b2f5e9ba2b0f5b9390e67f5564f8f5a4c12 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 11 Mar 2026 12:44:03 +0000 Subject: [PATCH 17/42] make sure not to generate F401 ruff errors --- py/generate_bidi.py | 55 +++++++++++-------- py/selenium/webdriver/common/bidi/browser.py | 7 +-- .../webdriver/common/bidi/browsing_context.py | 15 +++-- py/selenium/webdriver/common/bidi/common.py | 7 ++- .../webdriver/common/bidi/emulation.py | 29 +++++----- py/selenium/webdriver/common/bidi/input.py | 17 +++--- py/selenium/webdriver/common/bidi/log.py | 11 ++-- py/selenium/webdriver/common/bidi/network.py | 43 +++++++-------- .../webdriver/common/bidi/permissions.py | 10 ++-- py/selenium/webdriver/common/bidi/script.py | 27 ++++----- py/selenium/webdriver/common/bidi/session.py | 15 +++-- py/selenium/webdriver/common/bidi/storage.py | 9 ++- .../webdriver/common/bidi/webextension.py | 7 +-- 13 files changed, 126 insertions(+), 126 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index affd0a63a750c..8372d25743c08 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -42,7 +42,7 @@ # WebDriver BiDi module: {{}} from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from typing import Any from .common import command_builder """ @@ -123,10 +123,10 @@ def get_annotation(cls, cddl_type: str) -> str: if cddl_type.startswith("["): # Array inner = cddl_type.strip("[]+ ") inner_type = cls.get_annotation(inner) - return f"List[{inner_type}]" + return f"list[{inner_type}]" if cddl_type.startswith("{"): # Map/Dict - return "Dict[str, Any]" + return "dict[str, Any]" # Default to Any for unknown types return "Any" @@ -171,7 +171,9 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: if param_strs: # Check if full signature would exceed line length limit (120 chars) - single_line_signature = f" def {method_name}(self, {', '.join(param_strs)}):" + single_line_signature = ( + f" def {method_name}(self, {', '.join(param_strs)}):" + ) if len(single_line_signature) > 120: # Format parameters on multiple lines body = f" def {method_name}(\n" @@ -197,7 +199,9 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: if param_name in self.required_params: body += f" if {snake_param} is None:\n" msg = f"{method_snake}() missing required argument:" - body += f' raise TypeError("{msg} {{{{snake_param!r}}}}")\n' + body += ( + f' raise TypeError("{msg} {{{{snake_param!r}}}}")\n' + ) body += "\n" # Add validation if specified in enhancements (for additional business logic validation) @@ -585,18 +589,23 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: enhancements = enhancements or {} code = MODULE_HEADER.format(self.name) + # Collect needed imports to avoid duplicates + needs_dataclass = self.commands or self.types or self.events + needs_field = self.types + needs_threading = self.events + needs_callable = self.events + needs_session = self.events + # Add imports if needed - if self.types: - code += "from dataclasses import field\n" - if self.commands or self.types: - code += "from typing import Generator\n" + if needs_dataclass: code += "from dataclasses import dataclass\n" - - # Add imports for event handling if needed - if self.events: + if needs_field: + code += "from dataclasses import field\n" + if needs_threading: code += "import threading\n" + if needs_callable: code += "from collections.abc import Callable\n" - code += "from dataclasses import dataclass\n" + if needs_session: code += "from selenium.webdriver.common.bidi.session import Session\n" code += "\n\n" @@ -680,7 +689,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Generate enums first (excluding those in exclude_types) exclude_types = set(enhancements.get("exclude_types", [])) - + # Also exclude any types that have extra_dataclasses overrides # Extract class names from extra_dataclasses strings for extra_cls in enhancements.get("extra_dataclasses", []): @@ -688,7 +697,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: match = re.search(r"class\s+(\w+)\s*:", extra_cls) if match: exclude_types.add(match.group(1)) - + for enum_def in self.enums: if enum_def.name in exclude_types: continue @@ -968,7 +977,7 @@ def clear_event_handlers(self) -> None: # Generate command methods exclude_methods = enhancements.get("exclude_methods", []) - + # Automatically exclude methods that are defined in extra_methods # to prevent generating duplicates if "extra_methods" in enhancements: @@ -977,7 +986,7 @@ def clear_event_handlers(self) -> None: match = re.search(r"def\s+(\w+)\s*\(", extra_method) if match: exclude_methods = list(exclude_methods) + [match.group(1)] - + if self.commands: for command in self.commands: # Get method-specific enhancements @@ -1061,23 +1070,23 @@ def clear_event_handlers(self) -> None: # Try to get event class from globals, default to dict if not found getter = f'_globals.get("{event_def.name}", dict)' condition = f'_globals.get("{event_def.name}")' - event_class = f'{getter} if {condition} else dict' - + event_class = f"{getter} if {condition} else dict" + # Build the entry line and check if it exceeds 120 chars single_line = ( f' "{event_name}": ' f'EventConfig("{event_name}", "{event_def.method}", {event_class}),' ) - + if len(single_line) > 120: # Break into multiple lines code += f' "{event_name}": EventConfig(\n' code += f' "{event_name}",\n' code += f' "{event_def.method}",\n' - code += f' {event_class},\n' - code += ' ),\n' + code += f" {event_class},\n" + code += " ),\n" else: - code += single_line + '\n' + code += single_line + "\n" # Extra events not in the CDDL spec for extra_evt in enhancements.get("extra_events", []): ek = extra_evt["event_key"] diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 77ae8f0696281..a8fb60c98178d 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: browser from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass def transform_download_params( diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 3f877b06b00ab..777005e0ce4e5 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: browsingContext from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class ReadinessState: """ReadinessState.""" @@ -376,10 +375,10 @@ class DownloadParams: class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: "DownloadParams | None" = None + download_params: DownloadParams | None = None @classmethod - def from_json(cls, params: dict) -> "DownloadEndParams": + def from_json(cls, params: dict) -> DownloadEndParams: """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index d90d8c770263a..d7cb436a08471 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -17,12 +17,13 @@ """Common utilities for BiDi command construction.""" -from typing import Any, Dict, Generator +from collections.abc import Generator +from typing import Any def command_builder( - method: str, params: Dict[str, Any] -) -> Generator[Dict[str, Any], Any, Any]: + method: str, params: dict[str, Any] +) -> Generator[dict[str, Any], Any, Any]: """Build a BiDi command generator. Args: diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index d482fecc755cb..0356372c48f03 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: emulation from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass class ForcedColorsModeTheme: @@ -194,8 +193,8 @@ def __init__(self, conn) -> None: def set_forced_colors_mode_theme_override( self, theme: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setForcedColorsModeThemeOverride.""" if theme is None: @@ -214,8 +213,8 @@ def set_forced_colors_mode_theme_override( def set_locale_override( self, locale: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setLocaleOverride.""" if locale is None: @@ -234,8 +233,8 @@ def set_locale_override( def set_screen_settings_override( self, screen_area: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScreenSettingsOverride.""" if screen_area is None: @@ -254,8 +253,8 @@ def set_screen_settings_override( def set_viewport_meta_override( self, viewport_meta: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setViewportMetaOverride.""" if viewport_meta is None: @@ -274,8 +273,8 @@ def set_viewport_meta_override( def set_scrollbar_type_override( self, scrollbar_type: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScrollbarTypeOverride.""" if scrollbar_type is None: @@ -291,7 +290,7 @@ def set_scrollbar_type_override( result = self._conn.execute(cmd) return result - def set_touch_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_touch_override(self, contexts: list[Any] | None = None, user_contexts: list[Any] | None = None): """Execute emulation.setTouchOverride.""" params = { "contexts": contexts, diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 0990dacc39363..7e76cb831543f 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: input from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class PointerType: """PointerType.""" @@ -175,7 +174,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> "FileDialogInfo": + def from_json(cls, params: dict) -> FileDialogInfo: """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), @@ -368,7 +367,7 @@ def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): + def perform_actions(self, context: Any | None = None, actions: list[Any] | None = None): """Execute input.performActions.""" if context is None: raise TypeError("perform_actions() missing required argument: {{snake_param!r}}") @@ -397,7 +396,7 @@ def release_actions(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): + def set_files(self, context: Any | None = None, element: Any | None = None, files: list[Any] | None = None): """Execute input.setFiles.""" if context is None: raise TypeError("set_files() missing required argument: {{snake_param!r}}") diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 07121242348ea..fd712b7c9a8ab 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -6,14 +6,11 @@ # WebDriver BiDi module: log from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable from dataclasses import dataclass +from typing import Any + from selenium.webdriver.common.bidi.session import Session @@ -60,7 +57,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "ConsoleLogEntry": + def from_json(cls, params: dict) -> ConsoleLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -85,7 +82,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "JavascriptLogEntry": + def from_json(cls, params: dict) -> JavascriptLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index d7baeb07040ce..74951031c597f 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: network from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class SameSite: """SameSite.""" @@ -565,11 +564,11 @@ def __init__(self, conn) -> None: def add_data_collector( self, - data_types: List[Any] | None = None, + data_types: list[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute network.addDataCollector.""" if data_types is None: @@ -591,9 +590,9 @@ def add_data_collector( def add_intercept( self, - phases: List[Any] | None = None, - contexts: List[Any] | None = None, - url_patterns: List[Any] | None = None, + phases: list[Any] | None = None, + contexts: list[Any] | None = None, + url_patterns: list[Any] | None = None, ): """Execute network.addIntercept.""" if phases is None: @@ -613,8 +612,8 @@ def continue_request( self, request: Any | None = None, body: Any | None = None, - cookies: List[Any] | None = None, - headers: List[Any] | None = None, + cookies: list[Any] | None = None, + headers: list[Any] | None = None, method: Any | None = None, url: Any | None = None, ): @@ -638,9 +637,9 @@ def continue_request( def continue_response( self, request: Any | None = None, - cookies: List[Any] | None = None, + cookies: list[Any] | None = None, credentials: Any | None = None, - headers: List[Any] | None = None, + headers: list[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None, ): @@ -734,8 +733,8 @@ def provide_response( self, request: Any | None = None, body: Any | None = None, - cookies: List[Any] | None = None, - headers: List[Any] | None = None, + cookies: list[Any] | None = None, + headers: list[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None, ): @@ -782,7 +781,7 @@ def remove_intercept(self, intercept: Any | None = None): result = self._conn.execute(cmd) return result - def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): + def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[Any] | None = None): """Execute network.setCacheBehavior.""" if cache_behavior is None: raise TypeError("set_cache_behavior() missing required argument: {{snake_param!r}}") @@ -798,9 +797,9 @@ def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[A def set_extra_headers( self, - headers: List[Any] | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + headers: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute network.setExtraHeaders.""" if headers is None: diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py index f00e765c62e3b..6dd138da17309 100644 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ b/py/selenium/webdriver/common/bidi/permissions.py @@ -20,7 +20,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Optional, Union +from typing import Any from .common import command_builder @@ -63,10 +63,10 @@ def __init__(self, websocket_connection: Any) -> None: def set_permission( self, - descriptor: Union[PermissionDescriptor, str], - state: Union[PermissionState, str], - origin: Optional[str] = None, - user_context: Optional[str] = None, + descriptor: PermissionDescriptor | str, + state: PermissionState | str, + origin: str | None = None, + user_context: str | None = None, ) -> None: """Set a permission for a given origin. diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 8e832f4a9cae9..6c2e4298a2dce 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: script from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class SpecialNumber: """SpecialNumber.""" @@ -786,9 +785,9 @@ def __init__(self, conn, driver=None) -> None: def add_preload_script( self, function_declaration: Any | None = None, - arguments: List[Any] | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + arguments: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, sandbox: Any | None = None, ): """Execute script.addPreloadScript.""" @@ -807,7 +806,7 @@ def add_preload_script( result = self._conn.execute(cmd) return result - def disown(self, handles: List[Any] | None = None, target: Any | None = None): + def disown(self, handles: list[Any] | None = None, target: Any | None = None): """Execute script.disown.""" if handles is None: raise TypeError("disown() missing required argument: {{snake_param!r}}") @@ -828,7 +827,7 @@ def call_function( function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, - arguments: List[Any] | None = None, + arguments: list[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, @@ -946,8 +945,9 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import math as _math import datetime as _datetime + import math as _math + from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1188,8 +1188,9 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + from selenium.webdriver.common.bidi.session import Session as _Session bidi_event = "log.entryAdded" diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index c7dd45ec824b8..fcb42a4ad86fc 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: session from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass class UserPromptHandlerType: @@ -216,9 +215,9 @@ def end(self): def subscribe( self, - events: List[Any] | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + events: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute session.subscribe.""" if events is None: @@ -234,7 +233,7 @@ def subscribe( result = self._conn.execute(cmd) return result - def unsubscribe(self, events: List[Any] | None = None, subscriptions: List[Any] | None = None): + def unsubscribe(self, events: list[Any] | None = None, subscriptions: list[Any] | None = None): """Execute session.unsubscribe.""" params = { "events": events, diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 267569f782289..089cee2c4fbdf 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: storage from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass @dataclass @@ -107,7 +106,7 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> "StorageCookie": + def from_bidi_dict(cls, raw: dict) -> StorageCookie: """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index e007f8e4792a6..b1bc09452bc63 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: webExtension from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass @dataclass From 29fb79fa2d7043e4a8d3cbcee2ba34f52dd14a5f Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 11 Mar 2026 19:36:16 +0000 Subject: [PATCH 18/42] ruffs and mypy fixes --- py/generate_bidi.py | 51 +++++++++++++------ py/private/bidi_enhancements_manifest.py | 33 ++++++------ .../webdriver/common/bidi/browsing_context.py | 2 +- py/selenium/webdriver/common/bidi/common.py | 5 +- .../webdriver/common/bidi/emulation.py | 12 ++--- py/selenium/webdriver/common/bidi/input.py | 2 +- py/selenium/webdriver/common/bidi/log.py | 2 +- py/selenium/webdriver/common/bidi/network.py | 8 +-- py/selenium/webdriver/common/bidi/script.py | 2 +- py/selenium/webdriver/common/bidi/storage.py | 4 +- .../webdriver/common/bidi/webextension.py | 11 ++-- 11 files changed, 80 insertions(+), 52 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 8372d25743c08..ce29235456e48 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3 +#!/usr/bin/env python3.10 """ Generate Python WebDriver BiDi command modules from CDDL specification. @@ -43,7 +43,6 @@ from __future__ import annotations from typing import Any -from .common import command_builder """ @@ -590,17 +589,17 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code = MODULE_HEADER.format(self.name) # Collect needed imports to avoid duplicates + needs_command_builder = bool(self.commands) needs_dataclass = self.commands or self.types or self.events - needs_field = self.types needs_threading = self.events needs_callable = self.events needs_session = self.events - # Add imports if needed + # Add imports (field import will be added conditionally after code generation) + if needs_command_builder: + code += "from .common import command_builder\n" if needs_dataclass: code += "from dataclasses import dataclass\n" - if needs_field: - code += "from dataclasses import field\n" if needs_threading: code += "import threading\n" if needs_callable: @@ -954,7 +953,7 @@ def clear_event_handlers(self) -> None: # Add EVENT_CONFIGS dict if there are events if self.events: code += ( - " EVENT_CONFIGS = {}\n" # Will be populated after types are defined + " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined ) if self.name == "script": @@ -1095,6 +1094,26 @@ def clear_event_handlers(self) -> None: code += f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),\n' code += "}\n" + # Check if field() is actually used in the generated code + # If so, add the field import after the dataclass import + if "field(" in code: + # Find where to insert the field import + # It should go after "from dataclasses import dataclass" line + dataclass_import_pattern = r"from dataclasses import dataclass\n" + if re.search(dataclass_import_pattern, code): + code = re.sub( + dataclass_import_pattern, + "from dataclasses import dataclass\nfrom dataclasses import field\n", + code, + count=1 + ) + elif "from dataclasses import" not in code: + # If there's no dataclasses import yet, add field import after typing + code = code.replace( + "from typing import Any\n", + "from typing import Any\nfrom dataclasses import field\n" + ) + return code @@ -1634,12 +1653,14 @@ def generate_common_file(output_path: Path) -> None: "\n" '"""Common utilities for BiDi command construction."""\n' "\n" - "from typing import Any, Dict, Generator\n" + "from __future__ import annotations\n" + "\n" + "from typing import Any\n" "\n" "\n" "def command_builder(\n" - " method: str, params: Dict[str, Any]\n" - ") -> Generator[Dict[str, Any], Any, Any]:\n" + " method: str, params: dict[str, Any]\n" + ") -> dict[str, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" @@ -1726,7 +1747,7 @@ def generate_permissions_file(output_path: Path) -> None: "from __future__ import annotations\n" "\n" "from enum import Enum\n" - "from typing import Any, Optional, Union\n" + "from typing import Any\n" "\n" "from .common import command_builder\n" "\n" @@ -1769,10 +1790,10 @@ def generate_permissions_file(output_path: Path) -> None: "\n" " def set_permission(\n" " self,\n" - " descriptor: Union[PermissionDescriptor, str],\n" - " state: Union[PermissionState, str],\n" - " origin: Optional[str] = None,\n" - " user_context: Optional[str] = None,\n" + " descriptor: PermissionDescriptor | str,\n" + " state: PermissionState | str,\n" + " origin: str | None = None,\n" + " user_context: str | None = None,\n" " ) -> None:\n" ' """Set a permission for a given origin.\n' "\n" diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 40647157f8535..647dd7bcfd892 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -338,7 +338,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {} + params: dict[str, Any] = {} if coordinates is not None: if isinstance(coordinates, dict): coords_dict = coordinates @@ -390,7 +390,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"timezone": timezone} + params: dict[str, Any] = {"timezone": timezone} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -414,7 +414,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"enabled": enabled} + params: dict[str, Any] = {"enabled": enabled} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -437,7 +437,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"userAgent": user_agent} + params: dict[str, Any] = {"userAgent": user_agent} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -473,7 +473,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": "natural": natural.lower() if isinstance(natural, str) else natural, "type": orientation_type.lower() if isinstance(orientation_type, str) else orientation_type, } - params = {"screenOrientation": so_value} + params: dict[str, Any] = {"screenOrientation": so_value} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -506,7 +506,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": nc_value = {"type": "offline"} if offline else None else: nc_value = network_conditions - params = {"networkConditions": nc_value} + params: dict[str, Any] = {"networkConditions": nc_value} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -893,8 +893,8 @@ def from_json(self2, p): "network": { # Initialize intercepts tracking list and per-handler intercept map "extra_init_code": [ - "self.intercepts = []", - "self._handler_intercepts: dict = {}", + "self.intercepts: list[Any] = []", + "self._handler_intercepts: dict[str, Any] = {}", ], # Request class wraps a beforeRequestSent event params and provides actions "extra_dataclasses": [ @@ -908,7 +908,7 @@ def from_json(self2, p): TYPE_STRING = "string" TYPE_BASE64 = "base64" - def __init__(self, type: str, value: str) -> None: + def __init__(self, type: Any | None, value: Any | None) -> None: self.type = type self.value = value @@ -1089,7 +1089,7 @@ def _auth_callback(params): TYPE_STRING = "string" TYPE_BASE64 = "base64" - def __init__(self, type: str, value: str) -> None: + def __init__(self, type: Any | None, value: Any | None) -> None: self.type = type self.value = value @@ -1122,7 +1122,7 @@ def from_bidi_dict(cls, raw: dict) -> "StorageCookie": """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): - value = BytesValue(value_raw.get("type"), value_raw.get("value")) + value: Any = BytesValue(value_raw.get("type"), value_raw.get("value")) else: value = value_raw return cls( @@ -1379,6 +1379,7 @@ def to_bidi_dict(self) -> dict: elif archive_path is not None: extension_data = {"type": "archivePath", "path": archive_path} else: + assert base64_value is not None extension_data = {"type": "base64", "value": base64_value} params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) @@ -1395,12 +1396,14 @@ def to_bidi_dict(self) -> dict: ValueError: If extension is not provided or is None. """ if isinstance(extension, dict): - extension = extension.get("extension") + extension_id: Any = extension.get("extension") + else: + extension_id = extension - if extension is None: + if extension_id is None: raise ValueError("extension parameter is required") - - params = {"extension": extension} + + params = {"extension": extension_id} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd)''', ], diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 777005e0ce4e5..5b1a67ce93f11 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -598,7 +598,7 @@ def clear_event_handlers(self) -> None: class BrowsingContext: """WebDriver BiDi browsingContext module.""" - EVENT_CONFIGS = {} + EVENT_CONFIGS: dict[str, EventConfig] = {} def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index d7cb436a08471..168f748d5501b 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -17,13 +17,14 @@ """Common utilities for BiDi command construction.""" -from collections.abc import Generator +from __future__ import annotations + from typing import Any def command_builder( method: str, params: dict[str, Any] -) -> Generator[dict[str, Any], Any, Any]: +) -> dict[str, Any]: """Build a BiDi command generator. Args: diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 0356372c48f03..3dcf8e58881e4 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -320,7 +320,7 @@ def set_geolocation_override( contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {} + params: dict[str, Any] = {} if coordinates is not None: if isinstance(coordinates, dict): coords_dict = coordinates @@ -372,7 +372,7 @@ def set_timezone_override( contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"timezone": timezone} + params: dict[str, Any] = {"timezone": timezone} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -396,7 +396,7 @@ def set_scripting_enabled( contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"enabled": enabled} + params: dict[str, Any] = {"enabled": enabled} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -419,7 +419,7 @@ def set_user_agent_override( contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"userAgent": user_agent} + params: dict[str, Any] = {"userAgent": user_agent} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -455,7 +455,7 @@ def set_screen_orientation_override( "natural": natural.lower() if isinstance(natural, str) else natural, "type": orientation_type.lower() if isinstance(orientation_type, str) else orientation_type, } - params = {"screenOrientation": so_value} + params: dict[str, Any] = {"screenOrientation": so_value} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -488,7 +488,7 @@ def set_network_conditions( nc_value = {"type": "offline"} if offline else None else: nc_value = network_conditions - params = {"networkConditions": nc_value} + params: dict[str, Any] = {"networkConditions": nc_value} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 7e76cb831543f..1d4730534f16d 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -362,7 +362,7 @@ def clear_event_handlers(self) -> None: class Input: """WebDriver BiDi input module.""" - EVENT_CONFIGS = {} + EVENT_CONFIGS: dict[str, EventConfig] = {} def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index fd712b7c9a8ab..488f0740a40b5 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -256,7 +256,7 @@ def clear_event_handlers(self) -> None: class Log: """WebDriver BiDi log module.""" - EVENT_CONFIGS = {} + EVENT_CONFIGS: dict[str, EventConfig] = {} def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 74951031c597f..30de3306ff001 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -368,7 +368,7 @@ class BytesValue: TYPE_STRING = "string" TYPE_BASE64 = "base64" - def __init__(self, type: str, value: str) -> None: + def __init__(self, type: Any | None, value: Any | None) -> None: self.type = type self.value = value @@ -555,12 +555,12 @@ def clear_event_handlers(self) -> None: class Network: """WebDriver BiDi network module.""" - EVENT_CONFIGS = {} + EVENT_CONFIGS: dict[str, EventConfig] = {} def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - self.intercepts = [] - self._handler_intercepts: dict = {} + self.intercepts: list[Any] = [] + self._handler_intercepts: dict[str, Any] = {} def add_data_collector( self, diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 6c2e4298a2dce..221b5963e8ec1 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -776,7 +776,7 @@ def clear_event_handlers(self) -> None: class Script: """WebDriver BiDi script module.""" - EVENT_CONFIGS = {} + EVENT_CONFIGS: dict[str, EventConfig] = {} def __init__(self, conn, driver=None) -> None: self._conn = conn self._driver = driver diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 089cee2c4fbdf..a2606526f3856 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -76,7 +76,7 @@ class BytesValue: TYPE_STRING = "string" TYPE_BASE64 = "base64" - def __init__(self, type: str, value: str) -> None: + def __init__(self, type: Any | None, value: Any | None) -> None: self.type = type self.value = value @@ -110,7 +110,7 @@ def from_bidi_dict(cls, raw: dict) -> StorageCookie: """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): - value = BytesValue(value_raw.get("type"), value_raw.get("value")) + value: Any = BytesValue(value_raw.get("type"), value_raw.get("value")) else: value = value_raw return cls( diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index b1bc09452bc63..70a21d7fd5e5e 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -100,6 +100,7 @@ def install( elif archive_path is not None: extension_data = {"type": "archivePath", "path": archive_path} else: + assert base64_value is not None extension_data = {"type": "base64", "value": base64_value} params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) @@ -116,11 +117,13 @@ def uninstall(self, extension: str | dict): ValueError: If extension is not provided or is None. """ if isinstance(extension, dict): - extension = extension.get("extension") + extension_id: Any = extension.get("extension") + else: + extension_id = extension - if extension is None: + if extension_id is None: raise ValueError("extension parameter is required") - - params = {"extension": extension} + + params = {"extension": extension_id} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd) From 645f5bed1f2c8d41b03192618ce0982f14d56f54 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 11 Mar 2026 19:43:50 +0000 Subject: [PATCH 19/42] fix linting --- py/generate_bidi.py | 12 +++++------- py/private/bidi_enhancements_manifest.py | 2 +- py/selenium/webdriver/common/bidi/common.py | 3 ++- py/selenium/webdriver/common/bidi/webextension.py | 2 +- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index ce29235456e48..de41855954651 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -952,9 +952,7 @@ def clear_event_handlers(self) -> None: # Add EVENT_CONFIGS dict if there are events if self.events: - code += ( - " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined - ) + code += " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined if self.name == "script": code += " def __init__(self, conn, driver=None) -> None:\n" @@ -1105,13 +1103,13 @@ def clear_event_handlers(self) -> None: dataclass_import_pattern, "from dataclasses import dataclass\nfrom dataclasses import field\n", code, - count=1 + count=1, ) elif "from dataclasses import" not in code: # If there's no dataclasses import yet, add field import after typing code = code.replace( "from typing import Any\n", - "from typing import Any\nfrom dataclasses import field\n" + "from typing import Any\nfrom dataclasses import field\n", ) return code @@ -1655,12 +1653,12 @@ def generate_common_file(output_path: Path) -> None: "\n" "from __future__ import annotations\n" "\n" - "from typing import Any\n" + "from typing import Any, Generator\n" "\n" "\n" "def command_builder(\n" " method: str, params: dict[str, Any]\n" - ") -> dict[str, Any]:\n" + ") -> Generator[dict[str, Any], Any, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 647dd7bcfd892..d9923531b0293 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -1402,7 +1402,7 @@ def to_bidi_dict(self) -> dict: if extension_id is None: raise ValueError("extension parameter is required") - + params = {"extension": extension_id} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd)''', diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index 168f748d5501b..59e8afd93ab2e 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -19,12 +19,13 @@ from __future__ import annotations +from collections.abc import Generator from typing import Any def command_builder( method: str, params: dict[str, Any] -) -> dict[str, Any]: +) -> Generator[dict[str, Any], Any, Any]: """Build a BiDi command generator. Args: diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 70a21d7fd5e5e..b5881d01e0bea 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -123,7 +123,7 @@ def uninstall(self, extension: str | dict): if extension_id is None: raise ValueError("extension parameter is required") - + params = {"extension": extension_id} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd) From 00e45408470755e3dd6f3984596e4faf1de1c9af Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Fri, 13 Mar 2026 12:12:19 +0000 Subject: [PATCH 20/42] Fix auth tests --- py/generate_bidi.py | 12 +++++++----- py/private/bidi_enhancements_manifest.py | 20 +++++++++++++++++--- py/selenium/webdriver/common/bidi/network.py | 18 ++++++++++++++++-- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index de41855954651..ce29235456e48 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -952,7 +952,9 @@ def clear_event_handlers(self) -> None: # Add EVENT_CONFIGS dict if there are events if self.events: - code += " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined + code += ( + " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined + ) if self.name == "script": code += " def __init__(self, conn, driver=None) -> None:\n" @@ -1103,13 +1105,13 @@ def clear_event_handlers(self) -> None: dataclass_import_pattern, "from dataclasses import dataclass\nfrom dataclasses import field\n", code, - count=1, + count=1 ) elif "from dataclasses import" not in code: # If there's no dataclasses import yet, add field import after typing code = code.replace( "from typing import Any\n", - "from typing import Any\nfrom dataclasses import field\n", + "from typing import Any\nfrom dataclasses import field\n" ) return code @@ -1653,12 +1655,12 @@ def generate_common_file(output_path: Path) -> None: "\n" "from __future__ import annotations\n" "\n" - "from typing import Any, Generator\n" + "from typing import Any\n" "\n" "\n" "def command_builder(\n" " method: str, params: dict[str, Any]\n" - ") -> Generator[dict[str, Any], Any, Any]:\n" + ") -> dict[str, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index d9923531b0293..d617f7468c034 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -1033,6 +1033,10 @@ def _request_callback(params): """ from selenium.webdriver.common.bidi.common import command_builder as _cb + # Set up network intercept for authRequired phase + intercept_result = self._add_intercept(phases=["authRequired"]) + intercept_id = intercept_result.get("intercept") if intercept_result else None + def _auth_callback(params): raw = ( params @@ -1060,10 +1064,20 @@ def _auth_callback(params): ) ) - return self.add_event_handler("auth_required", _auth_callback)''', + callback_id = self.add_event_handler("auth_required", _auth_callback) + if intercept_id: + self._handler_intercepts[callback_id] = intercept_id + return callback_id''', ''' def remove_auth_handler(self, callback_id): - """Remove an auth handler by callback ID.""" - self.remove_event_handler("auth_required", callback_id)''', + """Remove an auth handler by callback ID and its associated network intercept. + + Args: + callback_id: The handler ID returned by add_auth_handler. + """ + self.remove_event_handler("auth_required", callback_id) + intercept_id = self._handler_intercepts.pop(callback_id, None) + if intercept_id: + self._remove_intercept(intercept_id)''', ], }, "storage": { diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 30de3306ff001..1dd2f5a476049 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -975,6 +975,10 @@ def add_auth_handler(self, username, password): """ from selenium.webdriver.common.bidi.common import command_builder as _cb + # Set up network intercept for authRequired phase + intercept_result = self._add_intercept(phases=["authRequired"]) + intercept_id = intercept_result.get("intercept") if intercept_result else None + def _auth_callback(params): raw = ( params @@ -1002,10 +1006,20 @@ def _auth_callback(params): ) ) - return self.add_event_handler("auth_required", _auth_callback) + callback_id = self.add_event_handler("auth_required", _auth_callback) + if intercept_id: + self._handler_intercepts[callback_id] = intercept_id + return callback_id def remove_auth_handler(self, callback_id): - """Remove an auth handler by callback ID.""" + """Remove an auth handler by callback ID and its associated network intercept. + + Args: + callback_id: The handler ID returned by add_auth_handler. + """ self.remove_event_handler("auth_required", callback_id) + intercept_id = self._handler_intercepts.pop(callback_id, None) + if intercept_id: + self._remove_intercept(intercept_id) def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: """Add an event handler. From 2a2ff5ea98cd13549ccd062f86ad052bef7408c2 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Fri, 13 Mar 2026 12:30:15 +0000 Subject: [PATCH 21/42] sort spacing --- py/generate_bidi.py | 3 ++- py/private/bidi_enhancements_manifest.py | 10 ++++++++++ py/selenium/webdriver/common/bidi/browser.py | 4 ++-- .../webdriver/common/bidi/browsing_context.py | 13 ++++++------- py/selenium/webdriver/common/bidi/emulation.py | 4 ++-- py/selenium/webdriver/common/bidi/input.py | 11 +++++------ py/selenium/webdriver/common/bidi/log.py | 9 ++++----- py/selenium/webdriver/common/bidi/network.py | 9 ++++----- py/selenium/webdriver/common/bidi/script.py | 15 ++++++--------- py/selenium/webdriver/common/bidi/session.py | 4 ++-- py/selenium/webdriver/common/bidi/storage.py | 6 +++--- py/selenium/webdriver/common/bidi/webextension.py | 4 ++-- 12 files changed, 48 insertions(+), 44 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index ce29235456e48..32d19ec83cec9 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -1655,12 +1655,13 @@ def generate_common_file(output_path: Path) -> None: "\n" "from __future__ import annotations\n" "\n" + "from collections.abc import Generator\n" "from typing import Any\n" "\n" "\n" "def command_builder(\n" " method: str, params: dict[str, Any]\n" - ") -> dict[str, Any]:\n" + ") -> Generator[dict[str, Any], Any, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index d617f7468c034..dcf464f425e9d 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -37,6 +37,7 @@ # ============================================================================ ENHANCEMENTS: dict[str, dict[str, Any]] = { + "browser": { # Dataclass custom methods "__dataclass_methods__": { @@ -170,6 +171,7 @@ return self._conn.execute(cmd)''', ], }, + "browsingContext": { # Method enhancements "create": { @@ -254,6 +256,7 @@ def from_json(cls, params: dict) -> "DownloadEndParams": ], # Download events are now in the CDDL spec, so no extra_events needed }, + "log": { # Make LogLevel an alias for Level so existing code using LogLevel works "aliases": {"LogLevel": "Level"}, @@ -317,6 +320,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": "entry_added": "Entry", }, }, + "emulation": { "extra_methods": [ ''' def set_geolocation_override( @@ -515,6 +519,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": return self._conn.execute(cmd)''', ], }, + "script": { "extra_methods": [ ''' def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: @@ -890,6 +895,7 @@ def from_json(self2, p): self._unsubscribe_log_entry(callback_id)''', ], }, + "network": { # Initialize intercepts tracking list and per-handler intercept map "extra_init_code": [ @@ -1080,6 +1086,7 @@ def _auth_callback(params): self._remove_intercept(intercept_id)''', ], }, + "storage": { # Exclude auto-generated dataclasses that need custom to_bidi_dict() # for JSON-over-WebSocket serialization, or custom constructors. @@ -1319,6 +1326,7 @@ def to_bidi_dict(self) -> dict: return result''', ], }, + "session": { # Override UserPromptHandler to add to_bidi_dict() for JSON serialization "exclude_types": ["UserPromptHandler"], @@ -1352,6 +1360,7 @@ def to_bidi_dict(self) -> dict: return result''', ], }, + "webExtension": { # Suppress the raw generated stubs; hand-written versions follow below "exclude_methods": ["install", "uninstall"], @@ -1422,6 +1431,7 @@ def to_bidi_dict(self) -> dict: return self._conn.execute(cmd)''', ], }, + "input": { # FileDialogInfo needs from_json for event deserialization "exclude_types": ["FileDialogInfo", "PointerMoveAction", "PointerDownAction"], diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index a8fb60c98178d..a4ec770fbb135 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: browser from __future__ import annotations -from dataclasses import dataclass, field from typing import Any - from .common import command_builder +from dataclasses import dataclass +from dataclasses import field def transform_download_params( diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 5b1a67ce93f11..c5489ce865180 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -6,15 +6,14 @@ # WebDriver BiDi module: browsingContext from __future__ import annotations +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class ReadinessState: """ReadinessState.""" @@ -375,10 +374,10 @@ class DownloadParams: class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: DownloadParams | None = None + download_params: "DownloadParams | None" = None @classmethod - def from_json(cls, params: dict) -> DownloadEndParams: + def from_json(cls, params: dict) -> "DownloadEndParams": """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 3dcf8e58881e4..03347a0a85c04 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: emulation from __future__ import annotations -from dataclasses import dataclass, field from typing import Any - from .common import command_builder +from dataclasses import dataclass +from dataclasses import field class ForcedColorsModeTheme: diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 1d4730534f16d..44fd3c82c3407 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -6,15 +6,14 @@ # WebDriver BiDi module: input from __future__ import annotations +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class PointerType: """PointerType.""" @@ -174,7 +173,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> FileDialogInfo: + def from_json(cls, params: dict) -> "FileDialogInfo": """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 488f0740a40b5..3c6a95d74f6d1 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: log from __future__ import annotations +from typing import Any +from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass -from typing import Any - from selenium.webdriver.common.bidi.session import Session @@ -57,7 +56,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> ConsoleLogEntry: + def from_json(cls, params: dict) -> "ConsoleLogEntry": """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -82,7 +81,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> JavascriptLogEntry: + def from_json(cls, params: dict) -> "JavascriptLogEntry": """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 1dd2f5a476049..6a0edf0b2b5e7 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -6,15 +6,14 @@ # WebDriver BiDi module: network from __future__ import annotations +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class SameSite: """SameSite.""" diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 221b5963e8ec1..5a7d2792a1221 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -6,15 +6,14 @@ # WebDriver BiDi module: script from __future__ import annotations +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class SpecialNumber: """SpecialNumber.""" @@ -945,9 +944,8 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import datetime as _datetime import math as _math - + import datetime as _datetime from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1188,9 +1186,8 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - - from selenium.webdriver.common.bidi import log as _log_mod from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod bidi_event = "log.entryAdded" diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index fcb42a4ad86fc..177421eca5ee8 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: session from __future__ import annotations -from dataclasses import dataclass, field from typing import Any - from .common import command_builder +from dataclasses import dataclass +from dataclasses import field class UserPromptHandlerType: diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index a2606526f3856..fef35106c33b0 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: storage from __future__ import annotations -from dataclasses import dataclass, field from typing import Any - from .common import command_builder +from dataclasses import dataclass +from dataclasses import field @dataclass @@ -106,7 +106,7 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> StorageCookie: + def from_bidi_dict(cls, raw: dict) -> "StorageCookie": """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index b5881d01e0bea..1c5b342c070d5 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: webExtension from __future__ import annotations -from dataclasses import dataclass, field from typing import Any - from .common import command_builder +from dataclasses import dataclass +from dataclasses import field @dataclass From 94ab2b498cbb995e97c2f3baedbd76b457edeed5 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Tue, 17 Mar 2026 09:12:20 +0000 Subject: [PATCH 22/42] linting --- py/selenium/webdriver/common/bidi/browser.py | 4 ++-- .../webdriver/common/bidi/browsing_context.py | 13 +++++++------ py/selenium/webdriver/common/bidi/emulation.py | 4 ++-- py/selenium/webdriver/common/bidi/input.py | 11 ++++++----- py/selenium/webdriver/common/bidi/log.py | 9 +++++---- py/selenium/webdriver/common/bidi/network.py | 9 +++++---- py/selenium/webdriver/common/bidi/script.py | 15 +++++++++------ py/selenium/webdriver/common/bidi/session.py | 4 ++-- py/selenium/webdriver/common/bidi/storage.py | 6 +++--- py/selenium/webdriver/common/bidi/webextension.py | 4 ++-- 10 files changed, 43 insertions(+), 36 deletions(-) diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index a4ec770fbb135..a8fb60c98178d 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: browser from __future__ import annotations +from dataclasses import dataclass, field from typing import Any + from .common import command_builder -from dataclasses import dataclass -from dataclasses import field def transform_download_params( diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index c5489ce865180..5b1a67ce93f11 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -6,14 +6,15 @@ # WebDriver BiDi module: browsingContext from __future__ import annotations -from typing import Any -from .common import command_builder -from dataclasses import dataclass -from dataclasses import field import threading from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class ReadinessState: """ReadinessState.""" @@ -374,10 +375,10 @@ class DownloadParams: class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: "DownloadParams | None" = None + download_params: DownloadParams | None = None @classmethod - def from_json(cls, params: dict) -> "DownloadEndParams": + def from_json(cls, params: dict) -> DownloadEndParams: """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 03347a0a85c04..3dcf8e58881e4 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: emulation from __future__ import annotations +from dataclasses import dataclass, field from typing import Any + from .common import command_builder -from dataclasses import dataclass -from dataclasses import field class ForcedColorsModeTheme: diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 44fd3c82c3407..1d4730534f16d 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -6,14 +6,15 @@ # WebDriver BiDi module: input from __future__ import annotations -from typing import Any -from .common import command_builder -from dataclasses import dataclass -from dataclasses import field import threading from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class PointerType: """PointerType.""" @@ -173,7 +174,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> "FileDialogInfo": + def from_json(cls, params: dict) -> FileDialogInfo: """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 3c6a95d74f6d1..488f0740a40b5 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -6,10 +6,11 @@ # WebDriver BiDi module: log from __future__ import annotations -from typing import Any -from dataclasses import dataclass import threading from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + from selenium.webdriver.common.bidi.session import Session @@ -56,7 +57,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "ConsoleLogEntry": + def from_json(cls, params: dict) -> ConsoleLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -81,7 +82,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "JavascriptLogEntry": + def from_json(cls, params: dict) -> JavascriptLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 6a0edf0b2b5e7..1dd2f5a476049 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -6,14 +6,15 @@ # WebDriver BiDi module: network from __future__ import annotations -from typing import Any -from .common import command_builder -from dataclasses import dataclass -from dataclasses import field import threading from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class SameSite: """SameSite.""" diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 5a7d2792a1221..221b5963e8ec1 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -6,14 +6,15 @@ # WebDriver BiDi module: script from __future__ import annotations -from typing import Any -from .common import command_builder -from dataclasses import dataclass -from dataclasses import field import threading from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class SpecialNumber: """SpecialNumber.""" @@ -944,8 +945,9 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import math as _math import datetime as _datetime + import math as _math + from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1186,8 +1188,9 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + from selenium.webdriver.common.bidi.session import Session as _Session bidi_event = "log.entryAdded" diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 177421eca5ee8..fcb42a4ad86fc 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: session from __future__ import annotations +from dataclasses import dataclass, field from typing import Any + from .common import command_builder -from dataclasses import dataclass -from dataclasses import field class UserPromptHandlerType: diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index fef35106c33b0..a2606526f3856 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: storage from __future__ import annotations +from dataclasses import dataclass, field from typing import Any + from .common import command_builder -from dataclasses import dataclass -from dataclasses import field @dataclass @@ -106,7 +106,7 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> "StorageCookie": + def from_bidi_dict(cls, raw: dict) -> StorageCookie: """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 1c5b342c070d5..b5881d01e0bea 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: webExtension from __future__ import annotations +from dataclasses import dataclass, field from typing import Any + from .common import command_builder -from dataclasses import dataclass -from dataclasses import field @dataclass From cd6fbaa75ebafd56315d9051919e468e2a0312df Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Fri, 27 Mar 2026 09:44:29 +0000 Subject: [PATCH 23/42] handle comments --- py/generate_bidi.py | 220 +++-------------- py/private/bidi_enhancements_manifest.py | 124 ++++++++-- py/selenium/webdriver/common/bidi/__init__.py | 20 +- py/selenium/webdriver/common/bidi/browser.py | 24 +- .../webdriver/common/bidi/browsing_context.py | 185 ++------------ py/selenium/webdriver/common/bidi/common.py | 7 +- .../webdriver/common/bidi/emulation.py | 33 +-- py/selenium/webdriver/common/bidi/input.py | 171 +------------ py/selenium/webdriver/common/bidi/log.py | 156 +----------- py/selenium/webdriver/common/bidi/network.py | 232 +++--------------- .../webdriver/common/bidi/permissions.py | 2 +- py/selenium/webdriver/common/bidi/py.typed | 0 py/selenium/webdriver/common/bidi/script.py | 185 ++------------ py/selenium/webdriver/common/bidi/session.py | 10 +- py/selenium/webdriver/common/bidi/storage.py | 44 +++- .../webdriver/common/bidi/webextension.py | 14 +- 16 files changed, 340 insertions(+), 1087 deletions(-) mode change 100755 => 100644 py/selenium/webdriver/common/bidi/py.typed diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 32d19ec83cec9..745c0f00ed890 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -42,7 +42,6 @@ # WebDriver BiDi module: {{}} from __future__ import annotations -from typing import Any """ @@ -198,8 +197,9 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: if param_name in self.required_params: body += f" if {snake_param} is None:\n" msg = f"{method_snake}() missing required argument:" + error_message = f"{msg} {snake_param!r}" body += ( - f' raise TypeError("{msg} {{{{snake_param!r}}}}")\n' + f" raise TypeError({error_message!r})\n" ) body += "\n" @@ -591,23 +591,32 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Collect needed imports to avoid duplicates needs_command_builder = bool(self.commands) needs_dataclass = self.commands or self.types or self.events - needs_threading = self.events needs_callable = self.events - needs_session = self.events + + stdlib_imports = [] + local_imports = [] # Add imports (field import will be added conditionally after code generation) - if needs_command_builder: - code += "from .common import command_builder\n" - if needs_dataclass: - code += "from dataclasses import dataclass\n" - if needs_threading: - code += "import threading\n" if needs_callable: - code += "from collections.abc import Callable\n" - if needs_session: - code += "from selenium.webdriver.common.bidi.session import Session\n" + stdlib_imports.append("from collections.abc import Callable") + if needs_dataclass: + stdlib_imports.append("from dataclasses import dataclass") + stdlib_imports.append("from typing import Any") + + if needs_command_builder: + local_imports.append( + "from selenium.webdriver.common.bidi.common import command_builder" + ) + if self.events: + local_imports.append( + "from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager" + ) + + code += "\n".join(stdlib_imports) + "\n" + if local_imports: + code += "\n" + "\n".join(local_imports) + "\n" - code += "\n\n" + code += "\n" # Add helper function definitions from enhancements # Collect all referenced helper functions (validate, transform) @@ -784,165 +793,11 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ code += "\n\n" - # Generate EventConfig and _EventManager for modules with events - if self.events: - # Generate EventConfig dataclass - code += """@dataclass -class EventConfig: - \"\"\"Configuration for a BiDi event.\"\"\" - event_key: str - bidi_event: str - event_class: type - - -""" - - # Generate _EventManager class - code += """class _EventWrapper: - \"\"\"Wrapper to provide event_class attribute for WebSocketConnection callbacks.\"\"\" - def __init__(self, bidi_event: str, event_class: type): - self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class - self._python_class = event_class # Keep reference to Python dataclass for deserialization - - def from_json(self, params: dict) -> Any: - \"\"\"Deserialize event params into the wrapped Python dataclass. - - Args: - params: Raw BiDi event params with camelCase keys. - - Returns: - An instance of the dataclass, or the raw dict on failure. - \"\"\" - if self._python_class is None or self._python_class is dict: - return params - try: - # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, \"from_json\") and callable( - self._python_class.from_json - ): - return self._python_class.from_json(params) - import dataclasses as dc - - snake_params = {self._camel_to_snake(k): v for k, v in params.items()} - if dc.is_dataclass(self._python_class): - valid_fields = {f.name for f in dc.fields(self._python_class)} - filtered = {k: v for k, v in snake_params.items() if k in valid_fields} - return self._python_class(**filtered) - return self._python_class(**snake_params) - except Exception: - return params - - @staticmethod - def _camel_to_snake(name: str) -> str: - result = [name[0].lower()] - for char in name[1:]: - if char.isupper(): - result.extend([\"_\", char.lower()]) - else: - result.append(char) - return \"\".join(result) - - -class _EventManager: - \"\"\"Manages event subscriptions and callbacks.\"\"\" - - def __init__(self, conn, event_configs: dict[str, EventConfig]): - self.conn = conn - self.event_configs = event_configs - self.subscriptions: dict = {} - self._event_wrappers = {} # Cache of _EventWrapper objects - self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} - self._available_events = ", ".join(sorted(event_configs.keys())) - self._subscription_lock = threading.Lock() - - # Create event wrappers for each event - for config in event_configs.values(): - wrapper = _EventWrapper(config.bidi_event, config.event_class) - self._event_wrappers[config.bidi_event] = wrapper - - def validate_event(self, event: str) -> EventConfig: - event_config = self.event_configs.get(event) - if not event_config: - raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") - return event_config - - def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - \"\"\"Subscribe to a BiDi event if not already subscribed.\"\"\" - with self._subscription_lock: - if bidi_event not in self.subscriptions: - session = Session(self.conn) - result = session.subscribe([bidi_event], contexts=contexts) - sub_id = ( - result.get(\"subscription\") if isinstance(result, dict) else None - ) - self.subscriptions[bidi_event] = { - \"callbacks\": [], - \"subscription_id\": sub_id, - } - - def unsubscribe_from_event(self, bidi_event: str) -> None: - \"\"\"Unsubscribe from a BiDi event if no more callbacks exist.\"\"\" - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry is not None and not entry[\"callbacks\"]: - session = Session(self.conn) - sub_id = entry.get(\"subscription_id\") - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - del self.subscriptions[bidi_event] - - def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - self.subscriptions[bidi_event][\"callbacks\"].append(callback_id) - - def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry and callback_id in entry[\"callbacks\"]: - entry[\"callbacks\"].remove(callback_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - event_config = self.validate_event(event) - # Use the event wrapper for add_callback - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - callback_id = self.conn.add_callback(event_wrapper, callback) - self.subscribe_to_event(event_config.bidi_event, contexts) - self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id - - def remove_event_handler(self, event: str, callback_id: int) -> None: - event_config = self.validate_event(event) - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - self.conn.remove_callback(event_wrapper, callback_id) - self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - self.unsubscribe_from_event(event_config.bidi_event) - - def clear_event_handlers(self) -> None: - \"\"\"Clear all event handlers.\"\"\" - with self._subscription_lock: - if not self.subscriptions: - return - session = Session(self.conn) - for bidi_event, entry in list(self.subscriptions.items()): - event_wrapper = self._event_wrappers.get(bidi_event) - callbacks = entry[\"callbacks\"] if isinstance(entry, dict) else entry - if event_wrapper: - for callback_id in callbacks: - self.conn.remove_callback(event_wrapper, callback_id) - sub_id = ( - entry.get(\"subscription_id\") if isinstance(entry, dict) else None - ) - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() - - -""" - code += "\n\n" + # EventConfig, _EventWrapper, and _EventManager are imported from + # selenium.webdriver.common.bidi._event_manager (see import section above) + # rather than being duplicated inline in every generated module. + if False: # placeholder to preserve indentation structure + pass # Generate class # Convert module name (camelCase or snake_case) to proper class name (PascalCase) @@ -1103,15 +958,15 @@ def clear_event_handlers(self) -> None: if re.search(dataclass_import_pattern, code): code = re.sub( dataclass_import_pattern, - "from dataclasses import dataclass\nfrom dataclasses import field\n", + "from dataclasses import dataclass, field\n", code, - count=1 + count=1, ) elif "from dataclasses import" not in code: # If there's no dataclasses import yet, add field import after typing code = code.replace( "from typing import Any\n", - "from typing import Any\nfrom dataclasses import field\n" + "from dataclasses import field\nfrom typing import Any\n", ) return code @@ -1615,7 +1470,9 @@ def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> Non for module_name in sorted(modules.keys()): class_name = module_name_to_class_name(module_name) filename = module_name_to_filename(module_name) - code += f"from .{filename} import {class_name}\n" + code += ( + f"from selenium.webdriver.common.bidi.{filename} import {class_name}\n" + ) code += "\n__all__ = [\n" for module_name in sorted(modules.keys()): @@ -1660,13 +1517,14 @@ def generate_common_file(output_path: Path) -> None: "\n" "\n" "def command_builder(\n" - " method: str, params: dict[str, Any]\n" + " method: str, params: dict[str, Any] | None = None\n" ") -> Generator[dict[str, Any], Any, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" ' method: The BiDi method name (e.g., "session.status", "browser.close")\n' - " params: The parameters for the command\n" + " params: The parameters for the command. If omitted, an empty\n" + " dictionary is sent.\n" "\n" " Yields:\n" " A dictionary representing the BiDi command\n" @@ -1674,6 +1532,8 @@ def generate_common_file(output_path: Path) -> None: " Returns:\n" " The result from the BiDi command execution\n" ' """\n' + " if params is None:\n" + " params = {}\n" ' result = yield {"method": method, "params": params}\n' " return result\n" ) @@ -1750,7 +1610,7 @@ def generate_permissions_file(output_path: Path) -> None: "from enum import Enum\n" "from typing import Any\n" "\n" - "from .common import command_builder\n" + "from selenium.webdriver.common.bidi.common import command_builder\n" "\n" '_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"}\n' "\n" diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index dcf464f425e9d..f8ade8b9b3ad8 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -86,7 +86,7 @@ # convenience NORMAL constant. In the BiDi spec "normal" is the state # represented by ClientWindowRectState, but exposing it here keeps the # Python API consistent with the old ClientWindowState enum. - "exclude_types": ["ClientWindowNamedState"], + "exclude_types": ["ClientWindowNamedState", "SetClientWindowStateParameters"], "extra_dataclasses": [ '''class ClientWindowNamedState: """Named states for a browser client window.""" @@ -95,6 +95,18 @@ MAXIMIZED = "maximized" MINIMIZED = "minimized" NORMAL = "normal"''', + '''@dataclass +class SetClientWindowStateParameters: + """SetClientWindowStateParameters. + + The ``state`` field is required and must be either a named-state string + (e.g. ``ClientWindowNamedState.MAXIMIZED``) or a + :class:`ClientWindowRectState` instance. ``client_window`` is the ID of + the window to affect. + """ + + client_window: Any | None = None + state: Any | None = None''', ], # Override the generator-produced set_download_behavior so that # downloadBehavior is never stripped by the generic None filter. @@ -239,10 +251,10 @@ class DownloadParams: class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: "DownloadParams | None" = None + download_params: DownloadParams | None = None @classmethod - def from_json(cls, params: dict) -> "DownloadEndParams": + def from_json(cls, params: dict) -> DownloadEndParams: """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), @@ -277,7 +289,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "ConsoleLogEntry": + def from_json(cls, params: dict) -> ConsoleLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -301,7 +313,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "JavascriptLogEntry": + def from_json(cls, params: dict) -> JavascriptLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -322,6 +334,20 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": }, "emulation": { + "exclude_types": ["setNetworkConditionsParameters"], + "extra_dataclasses": [ + '''@dataclass +class SetNetworkConditionsParameters: + """SetNetworkConditionsParameters.""" + + network_conditions: Any | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) + + +# Backward-compatible alias for existing imports +setNetworkConditionsParameters = SetNetworkConditionsParameters''', + ], "extra_methods": [ ''' def set_geolocation_override( self, @@ -897,6 +923,7 @@ def from_json(self2, p): }, "network": { + "exclude_types": ["disownDataParameters"], # Initialize intercepts tracking list and per-handler intercept map "extra_init_code": [ "self.intercepts: list[Any] = []", @@ -904,6 +931,17 @@ def from_json(self2, p): ], # Request class wraps a beforeRequestSent event params and provides actions "extra_dataclasses": [ + '''@dataclass +class DisownDataParameters: + """DisownDataParameters.""" + + data_type: Any | None = None + collector: Any | None = None + request: Any | None = None + + +# Backward-compatible alias for existing imports +disownDataParameters = DisownDataParameters''', '''class BytesValue: """A string or base64-encoded bytes value used in cookie operations. @@ -1115,7 +1153,11 @@ def __init__(self, type: Any | None, value: Any | None) -> None: self.value = value def to_bidi_dict(self) -> dict: - return {"type": self.type, "value": self.value}''', + return {"type": self.type, "value": self.value} + + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict()''', '''class SameSite: """SameSite cookie attribute values.""" @@ -1139,7 +1181,7 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> "StorageCookie": + def from_bidi_dict(cls, raw: dict) -> StorageCookie: """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): @@ -1193,7 +1235,11 @@ def to_bidi_dict(self) -> dict: result["sameSite"] = self.same_site if self.expiry is not None: result["expiry"] = self.expiry - return result''', + return result + + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict()''', # Custom PartialCookie with camelCase serialization '''@dataclass class PartialCookie: @@ -1227,7 +1273,11 @@ def to_bidi_dict(self) -> dict: result["sameSite"] = self.same_site if self.expiry is not None: result["expiry"] = self.expiry - return result''', + return result + + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict()''', # BrowsingContextPartitionDescriptor: first positional arg is *context* # (the auto-generated dataclass had `type` first, breaking positional # usage like BrowsingContextPartitionDescriptor(driver.current_window_handle)) @@ -1244,7 +1294,12 @@ def __init__(self, context: Any = None, type: str = "context") -> None: self.type = type def to_bidi_dict(self) -> dict: - return {"type": "context", "context": self.context}''', + return {"type": "context", "context": self.context} + + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict()''', + # StorageKeyPartitionDescriptor with camelCase serialization '''@dataclass class StorageKeyPartitionDescriptor: @@ -1261,7 +1316,11 @@ def to_bidi_dict(self) -> dict: result["userContext"] = self.user_context if self.source_origin is not None: result["sourceOrigin"] = self.source_origin - return result''', + return result + + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict()''', ], # Override the generated Storage class methods (Python's last-definition- # wins semantics means these extra_methods shadow the generated ones). @@ -1309,7 +1368,19 @@ def to_bidi_dict(self) -> dict: params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("storage.setCookie", params) result = self._conn.execute(cmd) + if isinstance(result, dict): + pk_raw = result.get("partitionKey") + pk = ( + PartitionKey( + user_context=pk_raw.get("userContext"), + source_origin=pk_raw.get("sourceOrigin"), + ) + if isinstance(pk_raw, dict) + else None + ) + return SetCookieResult(partition_key=pk) return result''', + ''' def delete_cookies(self, filter=None, partition=None): """Execute storage.deleteCookies.""" if filter and hasattr(filter, "to_bidi_dict"): @@ -1323,6 +1394,17 @@ def to_bidi_dict(self) -> dict: params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("storage.deleteCookies", params) result = self._conn.execute(cmd) + if isinstance(result, dict): + pk_raw = result.get("partitionKey") + pk = ( + PartitionKey( + user_context=pk_raw.get("userContext"), + source_origin=pk_raw.get("sourceOrigin"), + ) + if isinstance(pk_raw, dict) + else None + ) + return DeleteCookiesResult(partition_key=pk) return result''', ], }, @@ -1357,7 +1439,11 @@ def to_bidi_dict(self) -> dict: result["file"] = self.file if self.prompt is not None: result["prompt"] = self.prompt - return result''', + return result + + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict()''', ], }, @@ -1406,7 +1492,17 @@ def to_bidi_dict(self) -> dict: extension_data = {"type": "base64", "value": base64_value} params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) - return self._conn.execute(cmd)''', + try: + return self._conn.execute(cmd) + except Exception as e: + if "Method not available" in str(e): + raise RuntimeError( + "webExtension.install failed with 'Method not available'. " + "This likely means that web extension support is disabled. " + "Enable unsafe extension debugging and/or set options.enable_webextensions " + "in your WebDriver configuration." + ) from e + raise''', ''' def uninstall(self, extension: str | dict): """Uninstall a web extension. @@ -1445,7 +1541,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> "FileDialogInfo": + def from_json(cls, params: dict) -> FileDialogInfo: """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py index 7be7bd4f73856..79ba3dbf2f86f 100644 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ b/py/selenium/webdriver/common/bidi/__init__.py @@ -5,16 +5,16 @@ from __future__ import annotations -from .browser import Browser -from .browsing_context import BrowsingContext -from .emulation import Emulation -from .input import Input -from .log import Log -from .network import Network -from .script import Script -from .session import Session -from .storage import Storage -from .webextension import WebExtension +from selenium.webdriver.common.bidi.browser import Browser +from selenium.webdriver.common.bidi.browsing_context import BrowsingContext +from selenium.webdriver.common.bidi.emulation import Emulation +from selenium.webdriver.common.bidi.input import Input +from selenium.webdriver.common.bidi.log import Log +from selenium.webdriver.common.bidi.network import Network +from selenium.webdriver.common.bidi.script import Script +from selenium.webdriver.common.bidi.session import Session +from selenium.webdriver.common.bidi.storage import Storage +from selenium.webdriver.common.bidi.webextension import WebExtension __all__ = [ "Browser", diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index a8fb60c98178d..3811a2a2e97b7 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from typing import Any -from .common import command_builder +from selenium.webdriver.common.bidi.common import command_builder def transform_download_params( @@ -139,13 +139,6 @@ class RemoveUserContextParameters: user_context: Any | None = None -@dataclass -class SetClientWindowStateParameters: - """SetClientWindowStateParameters.""" - - client_window: Any | None = None - - @dataclass class ClientWindowRectState: """ClientWindowRectState.""" @@ -188,6 +181,19 @@ class ClientWindowNamedState: MINIMIZED = "minimized" NORMAL = "normal" +@dataclass +class SetClientWindowStateParameters: + """SetClientWindowStateParameters. + + The ``state`` field is required and must be either a named-state string + (e.g. ``ClientWindowNamedState.MAXIMIZED``) or a + :class:`ClientWindowRectState` instance. ``client_window`` is the ID of + the window to affect. + """ + + client_window: Any | None = None + state: Any | None = None + class Browser: """WebDriver BiDi browser module.""" @@ -272,7 +278,7 @@ def get_user_contexts(self): def remove_user_context(self, user_context: Any | None = None): """Execute browser.removeUserContext.""" if user_context is None: - raise TypeError("remove_user_context() missing required argument: {{snake_param!r}}") + raise TypeError("remove_user_context() missing required argument: 'user_context'") params = { "userContext": user_context, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 5b1a67ce93f11..fcee27df8488e 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -6,14 +6,12 @@ # WebDriver BiDi module: browsingContext from __future__ import annotations -import threading from collections.abc import Callable from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi.session import Session - -from .common import command_builder +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager +from selenium.webdriver.common.bidi.common import command_builder class ReadinessState: @@ -442,159 +440,6 @@ def _deserialize_info_list(items: list) -> list | None: -@dataclass -class EventConfig: - """Configuration for a BiDi event.""" - event_key: str - bidi_event: str - event_class: type - - -class _EventWrapper: - """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" - def __init__(self, bidi_event: str, event_class: type): - self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class - self._python_class = event_class # Keep reference to Python dataclass for deserialization - - def from_json(self, params: dict) -> Any: - """Deserialize event params into the wrapped Python dataclass. - - Args: - params: Raw BiDi event params with camelCase keys. - - Returns: - An instance of the dataclass, or the raw dict on failure. - """ - if self._python_class is None or self._python_class is dict: - return params - try: - # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, "from_json") and callable( - self._python_class.from_json - ): - return self._python_class.from_json(params) - import dataclasses as dc - - snake_params = {self._camel_to_snake(k): v for k, v in params.items()} - if dc.is_dataclass(self._python_class): - valid_fields = {f.name for f in dc.fields(self._python_class)} - filtered = {k: v for k, v in snake_params.items() if k in valid_fields} - return self._python_class(**filtered) - return self._python_class(**snake_params) - except Exception: - return params - - @staticmethod - def _camel_to_snake(name: str) -> str: - result = [name[0].lower()] - for char in name[1:]: - if char.isupper(): - result.extend(["_", char.lower()]) - else: - result.append(char) - return "".join(result) - - -class _EventManager: - """Manages event subscriptions and callbacks.""" - - def __init__(self, conn, event_configs: dict[str, EventConfig]): - self.conn = conn - self.event_configs = event_configs - self.subscriptions: dict = {} - self._event_wrappers = {} # Cache of _EventWrapper objects - self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} - self._available_events = ", ".join(sorted(event_configs.keys())) - self._subscription_lock = threading.Lock() - - # Create event wrappers for each event - for config in event_configs.values(): - wrapper = _EventWrapper(config.bidi_event, config.event_class) - self._event_wrappers[config.bidi_event] = wrapper - - def validate_event(self, event: str) -> EventConfig: - event_config = self.event_configs.get(event) - if not event_config: - raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") - return event_config - - def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - """Subscribe to a BiDi event if not already subscribed.""" - with self._subscription_lock: - if bidi_event not in self.subscriptions: - session = Session(self.conn) - result = session.subscribe([bidi_event], contexts=contexts) - sub_id = ( - result.get("subscription") if isinstance(result, dict) else None - ) - self.subscriptions[bidi_event] = { - "callbacks": [], - "subscription_id": sub_id, - } - - def unsubscribe_from_event(self, bidi_event: str) -> None: - """Unsubscribe from a BiDi event if no more callbacks exist.""" - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry is not None and not entry["callbacks"]: - session = Session(self.conn) - sub_id = entry.get("subscription_id") - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - del self.subscriptions[bidi_event] - - def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - self.subscriptions[bidi_event]["callbacks"].append(callback_id) - - def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry and callback_id in entry["callbacks"]: - entry["callbacks"].remove(callback_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - event_config = self.validate_event(event) - # Use the event wrapper for add_callback - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - callback_id = self.conn.add_callback(event_wrapper, callback) - self.subscribe_to_event(event_config.bidi_event, contexts) - self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id - - def remove_event_handler(self, event: str, callback_id: int) -> None: - event_config = self.validate_event(event) - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - self.conn.remove_callback(event_wrapper, callback_id) - self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - self.unsubscribe_from_event(event_config.bidi_event) - - def clear_event_handlers(self) -> None: - """Clear all event handlers.""" - with self._subscription_lock: - if not self.subscriptions: - return - session = Session(self.conn) - for bidi_event, entry in list(self.subscriptions.items()): - event_wrapper = self._event_wrappers.get(bidi_event) - callbacks = entry["callbacks"] if isinstance(entry, dict) else entry - if event_wrapper: - for callback_id in callbacks: - self.conn.remove_callback(event_wrapper, callback_id) - sub_id = ( - entry.get("subscription_id") if isinstance(entry, dict) else None - ) - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() - - - - class BrowsingContext: """WebDriver BiDi browsingContext module.""" @@ -606,7 +451,7 @@ def __init__(self, conn) -> None: def activate(self, context: Any | None = None): """Execute browsingContext.activate.""" if context is None: - raise TypeError("activate() missing required argument: {{snake_param!r}}") + raise TypeError("activate() missing required argument: 'context'") params = { "context": context, @@ -625,7 +470,7 @@ def capture_screenshot( ): """Execute browsingContext.captureScreenshot.""" if context is None: - raise TypeError("capture_screenshot() missing required argument: {{snake_param!r}}") + raise TypeError("capture_screenshot() missing required argument: 'context'") params = { "context": context, @@ -644,7 +489,7 @@ def capture_screenshot( def close(self, context: Any | None = None, prompt_unload: bool | None = None): """Execute browsingContext.close.""" if context is None: - raise TypeError("close() missing required argument: {{snake_param!r}}") + raise TypeError("close() missing required argument: 'context'") params = { "context": context, @@ -664,7 +509,7 @@ def create( ): """Execute browsingContext.create.""" if type is None: - raise TypeError("create() missing required argument: {{snake_param!r}}") + raise TypeError("create() missing required argument: 'type'") params = { "type": type, @@ -709,7 +554,7 @@ def get_tree(self, max_depth: Any | None = None, root: Any | None = None): def handle_user_prompt(self, context: Any | None = None, accept: bool | None = None, user_text: Any | None = None): """Execute browsingContext.handleUserPrompt.""" if context is None: - raise TypeError("handle_user_prompt() missing required argument: {{snake_param!r}}") + raise TypeError("handle_user_prompt() missing required argument: 'context'") params = { "context": context, @@ -731,9 +576,9 @@ def locate_nodes( ): """Execute browsingContext.locateNodes.""" if context is None: - raise TypeError("locate_nodes() missing required argument: {{snake_param!r}}") + raise TypeError("locate_nodes() missing required argument: 'context'") if locator is None: - raise TypeError("locate_nodes() missing required argument: {{snake_param!r}}") + raise TypeError("locate_nodes() missing required argument: 'locator'") params = { "context": context, @@ -753,9 +598,9 @@ def locate_nodes( def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any | None = None): """Execute browsingContext.navigate.""" if context is None: - raise TypeError("navigate() missing required argument: {{snake_param!r}}") + raise TypeError("navigate() missing required argument: 'context'") if url is None: - raise TypeError("navigate() missing required argument: {{snake_param!r}}") + raise TypeError("navigate() missing required argument: 'url'") params = { "context": context, @@ -778,7 +623,7 @@ def print( ): """Execute browsingContext.print.""" if context is None: - raise TypeError("print() missing required argument: {{snake_param!r}}") + raise TypeError("print() missing required argument: 'context'") params = { "context": context, @@ -799,7 +644,7 @@ def print( def reload(self, context: Any | None = None, ignore_cache: bool | None = None, wait: Any | None = None): """Execute browsingContext.reload.""" if context is None: - raise TypeError("reload() missing required argument: {{snake_param!r}}") + raise TypeError("reload() missing required argument: 'context'") params = { "context": context, @@ -833,9 +678,9 @@ def set_viewport( def traverse_history(self, context: Any | None = None, delta: Any | None = None): """Execute browsingContext.traverseHistory.""" if context is None: - raise TypeError("traverse_history() missing required argument: {{snake_param!r}}") + raise TypeError("traverse_history() missing required argument: 'context'") if delta is None: - raise TypeError("traverse_history() missing required argument: {{snake_param!r}}") + raise TypeError("traverse_history() missing required argument: 'delta'") params = { "context": context, diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index 59e8afd93ab2e..fc75caa282a45 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -24,13 +24,14 @@ def command_builder( - method: str, params: dict[str, Any] + method: str, params: dict[str, Any] | None = None ) -> Generator[dict[str, Any], Any, Any]: """Build a BiDi command generator. Args: method: The BiDi method name (e.g., "session.status", "browser.close") - params: The parameters for the command + params: The parameters for the command. If omitted, an empty + dictionary is sent. Yields: A dictionary representing the BiDi command @@ -38,5 +39,7 @@ def command_builder( Returns: The result from the BiDi command execution """ + if params is None: + params = {} result = yield {"method": method, "params": params} return result diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 3dcf8e58881e4..44babb6777616 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from typing import Any -from .common import command_builder +from selenium.webdriver.common.bidi.common import command_builder class ForcedColorsModeTheme: @@ -81,15 +81,6 @@ class SetLocaleOverrideParameters: user_contexts: list[Any] = field(default_factory=list) -@dataclass -class setNetworkConditionsParameters: - """setNetworkConditionsParameters.""" - - network_conditions: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - @dataclass class NetworkConditionsOffline: """NetworkConditionsOffline.""" @@ -184,6 +175,18 @@ class SetTouchOverrideParameters: user_contexts: list[Any] = field(default_factory=list) +@dataclass +class SetNetworkConditionsParameters: + """SetNetworkConditionsParameters.""" + + network_conditions: Any | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) + + +# Backward-compatible alias for existing imports +setNetworkConditionsParameters = SetNetworkConditionsParameters + class Emulation: """WebDriver BiDi emulation module.""" @@ -198,7 +201,7 @@ def set_forced_colors_mode_theme_override( ): """Execute emulation.setForcedColorsModeThemeOverride.""" if theme is None: - raise TypeError("set_forced_colors_mode_theme_override() missing required argument: {{snake_param!r}}") + raise TypeError("set_forced_colors_mode_theme_override() missing required argument: 'theme'") params = { "theme": theme, @@ -218,7 +221,7 @@ def set_locale_override( ): """Execute emulation.setLocaleOverride.""" if locale is None: - raise TypeError("set_locale_override() missing required argument: {{snake_param!r}}") + raise TypeError("set_locale_override() missing required argument: 'locale'") params = { "locale": locale, @@ -238,7 +241,7 @@ def set_screen_settings_override( ): """Execute emulation.setScreenSettingsOverride.""" if screen_area is None: - raise TypeError("set_screen_settings_override() missing required argument: {{snake_param!r}}") + raise TypeError("set_screen_settings_override() missing required argument: 'screen_area'") params = { "screenArea": screen_area, @@ -258,7 +261,7 @@ def set_viewport_meta_override( ): """Execute emulation.setViewportMetaOverride.""" if viewport_meta is None: - raise TypeError("set_viewport_meta_override() missing required argument: {{snake_param!r}}") + raise TypeError("set_viewport_meta_override() missing required argument: 'viewport_meta'") params = { "viewportMeta": viewport_meta, @@ -278,7 +281,7 @@ def set_scrollbar_type_override( ): """Execute emulation.setScrollbarTypeOverride.""" if scrollbar_type is None: - raise TypeError("set_scrollbar_type_override() missing required argument: {{snake_param!r}}") + raise TypeError("set_scrollbar_type_override() missing required argument: 'scrollbar_type'") params = { "scrollbarType": scrollbar_type, diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 1d4730534f16d..346ead5e49841 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -6,14 +6,12 @@ # WebDriver BiDi module: input from __future__ import annotations -import threading from collections.abc import Callable from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi.session import Session - -from .common import command_builder +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager +from selenium.webdriver.common.bidi.common import command_builder class PointerType: @@ -206,159 +204,6 @@ class PointerDownAction: "file_dialog_opened": "input.fileDialogOpened", } -@dataclass -class EventConfig: - """Configuration for a BiDi event.""" - event_key: str - bidi_event: str - event_class: type - - -class _EventWrapper: - """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" - def __init__(self, bidi_event: str, event_class: type): - self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class - self._python_class = event_class # Keep reference to Python dataclass for deserialization - - def from_json(self, params: dict) -> Any: - """Deserialize event params into the wrapped Python dataclass. - - Args: - params: Raw BiDi event params with camelCase keys. - - Returns: - An instance of the dataclass, or the raw dict on failure. - """ - if self._python_class is None or self._python_class is dict: - return params - try: - # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, "from_json") and callable( - self._python_class.from_json - ): - return self._python_class.from_json(params) - import dataclasses as dc - - snake_params = {self._camel_to_snake(k): v for k, v in params.items()} - if dc.is_dataclass(self._python_class): - valid_fields = {f.name for f in dc.fields(self._python_class)} - filtered = {k: v for k, v in snake_params.items() if k in valid_fields} - return self._python_class(**filtered) - return self._python_class(**snake_params) - except Exception: - return params - - @staticmethod - def _camel_to_snake(name: str) -> str: - result = [name[0].lower()] - for char in name[1:]: - if char.isupper(): - result.extend(["_", char.lower()]) - else: - result.append(char) - return "".join(result) - - -class _EventManager: - """Manages event subscriptions and callbacks.""" - - def __init__(self, conn, event_configs: dict[str, EventConfig]): - self.conn = conn - self.event_configs = event_configs - self.subscriptions: dict = {} - self._event_wrappers = {} # Cache of _EventWrapper objects - self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} - self._available_events = ", ".join(sorted(event_configs.keys())) - self._subscription_lock = threading.Lock() - - # Create event wrappers for each event - for config in event_configs.values(): - wrapper = _EventWrapper(config.bidi_event, config.event_class) - self._event_wrappers[config.bidi_event] = wrapper - - def validate_event(self, event: str) -> EventConfig: - event_config = self.event_configs.get(event) - if not event_config: - raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") - return event_config - - def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - """Subscribe to a BiDi event if not already subscribed.""" - with self._subscription_lock: - if bidi_event not in self.subscriptions: - session = Session(self.conn) - result = session.subscribe([bidi_event], contexts=contexts) - sub_id = ( - result.get("subscription") if isinstance(result, dict) else None - ) - self.subscriptions[bidi_event] = { - "callbacks": [], - "subscription_id": sub_id, - } - - def unsubscribe_from_event(self, bidi_event: str) -> None: - """Unsubscribe from a BiDi event if no more callbacks exist.""" - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry is not None and not entry["callbacks"]: - session = Session(self.conn) - sub_id = entry.get("subscription_id") - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - del self.subscriptions[bidi_event] - - def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - self.subscriptions[bidi_event]["callbacks"].append(callback_id) - - def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry and callback_id in entry["callbacks"]: - entry["callbacks"].remove(callback_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - event_config = self.validate_event(event) - # Use the event wrapper for add_callback - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - callback_id = self.conn.add_callback(event_wrapper, callback) - self.subscribe_to_event(event_config.bidi_event, contexts) - self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id - - def remove_event_handler(self, event: str, callback_id: int) -> None: - event_config = self.validate_event(event) - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - self.conn.remove_callback(event_wrapper, callback_id) - self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - self.unsubscribe_from_event(event_config.bidi_event) - - def clear_event_handlers(self) -> None: - """Clear all event handlers.""" - with self._subscription_lock: - if not self.subscriptions: - return - session = Session(self.conn) - for bidi_event, entry in list(self.subscriptions.items()): - event_wrapper = self._event_wrappers.get(bidi_event) - callbacks = entry["callbacks"] if isinstance(entry, dict) else entry - if event_wrapper: - for callback_id in callbacks: - self.conn.remove_callback(event_wrapper, callback_id) - sub_id = ( - entry.get("subscription_id") if isinstance(entry, dict) else None - ) - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() - - - - class Input: """WebDriver BiDi input module.""" @@ -370,9 +215,9 @@ def __init__(self, conn) -> None: def perform_actions(self, context: Any | None = None, actions: list[Any] | None = None): """Execute input.performActions.""" if context is None: - raise TypeError("perform_actions() missing required argument: {{snake_param!r}}") + raise TypeError("perform_actions() missing required argument: 'context'") if actions is None: - raise TypeError("perform_actions() missing required argument: {{snake_param!r}}") + raise TypeError("perform_actions() missing required argument: 'actions'") params = { "context": context, @@ -386,7 +231,7 @@ def perform_actions(self, context: Any | None = None, actions: list[Any] | None def release_actions(self, context: Any | None = None): """Execute input.releaseActions.""" if context is None: - raise TypeError("release_actions() missing required argument: {{snake_param!r}}") + raise TypeError("release_actions() missing required argument: 'context'") params = { "context": context, @@ -399,11 +244,11 @@ def release_actions(self, context: Any | None = None): def set_files(self, context: Any | None = None, element: Any | None = None, files: list[Any] | None = None): """Execute input.setFiles.""" if context is None: - raise TypeError("set_files() missing required argument: {{snake_param!r}}") + raise TypeError("set_files() missing required argument: 'context'") if element is None: - raise TypeError("set_files() missing required argument: {{snake_param!r}}") + raise TypeError("set_files() missing required argument: 'element'") if files is None: - raise TypeError("set_files() missing required argument: {{snake_param!r}}") + raise TypeError("set_files() missing required argument: 'files'") params = { "context": context, diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 488f0740a40b5..ca24d6e78d532 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -6,12 +6,11 @@ # WebDriver BiDi module: log from __future__ import annotations -import threading from collections.abc import Callable from dataclasses import dataclass from typing import Any -from selenium.webdriver.common.bidi.session import Session +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager class Level: @@ -100,159 +99,6 @@ def from_json(cls, params: dict) -> JavascriptLogEntry: "entry_added": "log.entryAdded", } -@dataclass -class EventConfig: - """Configuration for a BiDi event.""" - event_key: str - bidi_event: str - event_class: type - - -class _EventWrapper: - """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" - def __init__(self, bidi_event: str, event_class: type): - self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class - self._python_class = event_class # Keep reference to Python dataclass for deserialization - - def from_json(self, params: dict) -> Any: - """Deserialize event params into the wrapped Python dataclass. - - Args: - params: Raw BiDi event params with camelCase keys. - - Returns: - An instance of the dataclass, or the raw dict on failure. - """ - if self._python_class is None or self._python_class is dict: - return params - try: - # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, "from_json") and callable( - self._python_class.from_json - ): - return self._python_class.from_json(params) - import dataclasses as dc - - snake_params = {self._camel_to_snake(k): v for k, v in params.items()} - if dc.is_dataclass(self._python_class): - valid_fields = {f.name for f in dc.fields(self._python_class)} - filtered = {k: v for k, v in snake_params.items() if k in valid_fields} - return self._python_class(**filtered) - return self._python_class(**snake_params) - except Exception: - return params - - @staticmethod - def _camel_to_snake(name: str) -> str: - result = [name[0].lower()] - for char in name[1:]: - if char.isupper(): - result.extend(["_", char.lower()]) - else: - result.append(char) - return "".join(result) - - -class _EventManager: - """Manages event subscriptions and callbacks.""" - - def __init__(self, conn, event_configs: dict[str, EventConfig]): - self.conn = conn - self.event_configs = event_configs - self.subscriptions: dict = {} - self._event_wrappers = {} # Cache of _EventWrapper objects - self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} - self._available_events = ", ".join(sorted(event_configs.keys())) - self._subscription_lock = threading.Lock() - - # Create event wrappers for each event - for config in event_configs.values(): - wrapper = _EventWrapper(config.bidi_event, config.event_class) - self._event_wrappers[config.bidi_event] = wrapper - - def validate_event(self, event: str) -> EventConfig: - event_config = self.event_configs.get(event) - if not event_config: - raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") - return event_config - - def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - """Subscribe to a BiDi event if not already subscribed.""" - with self._subscription_lock: - if bidi_event not in self.subscriptions: - session = Session(self.conn) - result = session.subscribe([bidi_event], contexts=contexts) - sub_id = ( - result.get("subscription") if isinstance(result, dict) else None - ) - self.subscriptions[bidi_event] = { - "callbacks": [], - "subscription_id": sub_id, - } - - def unsubscribe_from_event(self, bidi_event: str) -> None: - """Unsubscribe from a BiDi event if no more callbacks exist.""" - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry is not None and not entry["callbacks"]: - session = Session(self.conn) - sub_id = entry.get("subscription_id") - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - del self.subscriptions[bidi_event] - - def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - self.subscriptions[bidi_event]["callbacks"].append(callback_id) - - def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry and callback_id in entry["callbacks"]: - entry["callbacks"].remove(callback_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - event_config = self.validate_event(event) - # Use the event wrapper for add_callback - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - callback_id = self.conn.add_callback(event_wrapper, callback) - self.subscribe_to_event(event_config.bidi_event, contexts) - self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id - - def remove_event_handler(self, event: str, callback_id: int) -> None: - event_config = self.validate_event(event) - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - self.conn.remove_callback(event_wrapper, callback_id) - self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - self.unsubscribe_from_event(event_config.bidi_event) - - def clear_event_handlers(self) -> None: - """Clear all event handlers.""" - with self._subscription_lock: - if not self.subscriptions: - return - session = Session(self.conn) - for bidi_event, entry in list(self.subscriptions.items()): - event_wrapper = self._event_wrappers.get(bidi_event) - callbacks = entry["callbacks"] if isinstance(entry, dict) else entry - if event_wrapper: - for callback_id in callbacks: - self.conn.remove_callback(event_wrapper, callback_id) - sub_id = ( - entry.get("subscription_id") if isinstance(entry, dict) else None - ) - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() - - - - class Log: """WebDriver BiDi log module.""" diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 1dd2f5a476049..343b6d960c017 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -6,14 +6,12 @@ # WebDriver BiDi module: network from __future__ import annotations -import threading from collections.abc import Callable from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi.session import Session - -from .common import command_builder +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager +from selenium.webdriver.common.bidi.common import command_builder class SameSite: @@ -275,15 +273,6 @@ class ContinueWithAuthCredentials: credentials: Any | None = None -@dataclass -class disownDataParameters: - """disownDataParameters.""" - - data_type: Any | None = None - collector: Any | None = None - request: Any | None = None - - @dataclass class FailRequestParameters: """FailRequestParameters.""" @@ -358,6 +347,18 @@ class ResponseStartedParameters: response: Any | None = None +@dataclass +class DisownDataParameters: + """DisownDataParameters.""" + + data_type: Any | None = None + collector: Any | None = None + request: Any | None = None + + +# Backward-compatible alias for existing imports +disownDataParameters = DisownDataParameters + class BytesValue: """A string or base64-encoded bytes value used in cookie operations. @@ -399,159 +400,6 @@ def continue_request(self, **kwargs): "before_request": "network.beforeRequestSent", } -@dataclass -class EventConfig: - """Configuration for a BiDi event.""" - event_key: str - bidi_event: str - event_class: type - - -class _EventWrapper: - """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" - def __init__(self, bidi_event: str, event_class: type): - self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class - self._python_class = event_class # Keep reference to Python dataclass for deserialization - - def from_json(self, params: dict) -> Any: - """Deserialize event params into the wrapped Python dataclass. - - Args: - params: Raw BiDi event params with camelCase keys. - - Returns: - An instance of the dataclass, or the raw dict on failure. - """ - if self._python_class is None or self._python_class is dict: - return params - try: - # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, "from_json") and callable( - self._python_class.from_json - ): - return self._python_class.from_json(params) - import dataclasses as dc - - snake_params = {self._camel_to_snake(k): v for k, v in params.items()} - if dc.is_dataclass(self._python_class): - valid_fields = {f.name for f in dc.fields(self._python_class)} - filtered = {k: v for k, v in snake_params.items() if k in valid_fields} - return self._python_class(**filtered) - return self._python_class(**snake_params) - except Exception: - return params - - @staticmethod - def _camel_to_snake(name: str) -> str: - result = [name[0].lower()] - for char in name[1:]: - if char.isupper(): - result.extend(["_", char.lower()]) - else: - result.append(char) - return "".join(result) - - -class _EventManager: - """Manages event subscriptions and callbacks.""" - - def __init__(self, conn, event_configs: dict[str, EventConfig]): - self.conn = conn - self.event_configs = event_configs - self.subscriptions: dict = {} - self._event_wrappers = {} # Cache of _EventWrapper objects - self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} - self._available_events = ", ".join(sorted(event_configs.keys())) - self._subscription_lock = threading.Lock() - - # Create event wrappers for each event - for config in event_configs.values(): - wrapper = _EventWrapper(config.bidi_event, config.event_class) - self._event_wrappers[config.bidi_event] = wrapper - - def validate_event(self, event: str) -> EventConfig: - event_config = self.event_configs.get(event) - if not event_config: - raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") - return event_config - - def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - """Subscribe to a BiDi event if not already subscribed.""" - with self._subscription_lock: - if bidi_event not in self.subscriptions: - session = Session(self.conn) - result = session.subscribe([bidi_event], contexts=contexts) - sub_id = ( - result.get("subscription") if isinstance(result, dict) else None - ) - self.subscriptions[bidi_event] = { - "callbacks": [], - "subscription_id": sub_id, - } - - def unsubscribe_from_event(self, bidi_event: str) -> None: - """Unsubscribe from a BiDi event if no more callbacks exist.""" - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry is not None and not entry["callbacks"]: - session = Session(self.conn) - sub_id = entry.get("subscription_id") - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - del self.subscriptions[bidi_event] - - def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - self.subscriptions[bidi_event]["callbacks"].append(callback_id) - - def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry and callback_id in entry["callbacks"]: - entry["callbacks"].remove(callback_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - event_config = self.validate_event(event) - # Use the event wrapper for add_callback - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - callback_id = self.conn.add_callback(event_wrapper, callback) - self.subscribe_to_event(event_config.bidi_event, contexts) - self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id - - def remove_event_handler(self, event: str, callback_id: int) -> None: - event_config = self.validate_event(event) - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - self.conn.remove_callback(event_wrapper, callback_id) - self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - self.unsubscribe_from_event(event_config.bidi_event) - - def clear_event_handlers(self) -> None: - """Clear all event handlers.""" - with self._subscription_lock: - if not self.subscriptions: - return - session = Session(self.conn) - for bidi_event, entry in list(self.subscriptions.items()): - event_wrapper = self._event_wrappers.get(bidi_event) - callbacks = entry["callbacks"] if isinstance(entry, dict) else entry - if event_wrapper: - for callback_id in callbacks: - self.conn.remove_callback(event_wrapper, callback_id) - sub_id = ( - entry.get("subscription_id") if isinstance(entry, dict) else None - ) - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() - - - - class Network: """WebDriver BiDi network module.""" @@ -572,9 +420,9 @@ def add_data_collector( ): """Execute network.addDataCollector.""" if data_types is None: - raise TypeError("add_data_collector() missing required argument: {{snake_param!r}}") + raise TypeError("add_data_collector() missing required argument: 'data_types'") if max_encoded_data_size is None: - raise TypeError("add_data_collector() missing required argument: {{snake_param!r}}") + raise TypeError("add_data_collector() missing required argument: 'max_encoded_data_size'") params = { "dataTypes": data_types, @@ -596,7 +444,7 @@ def add_intercept( ): """Execute network.addIntercept.""" if phases is None: - raise TypeError("add_intercept() missing required argument: {{snake_param!r}}") + raise TypeError("add_intercept() missing required argument: 'phases'") params = { "phases": phases, @@ -619,7 +467,7 @@ def continue_request( ): """Execute network.continueRequest.""" if request is None: - raise TypeError("continue_request() missing required argument: {{snake_param!r}}") + raise TypeError("continue_request() missing required argument: 'request'") params = { "request": request, @@ -645,7 +493,7 @@ def continue_response( ): """Execute network.continueResponse.""" if request is None: - raise TypeError("continue_response() missing required argument: {{snake_param!r}}") + raise TypeError("continue_response() missing required argument: 'request'") params = { "request": request, @@ -663,7 +511,7 @@ def continue_response( def continue_with_auth(self, request: Any | None = None): """Execute network.continueWithAuth.""" if request is None: - raise TypeError("continue_with_auth() missing required argument: {{snake_param!r}}") + raise TypeError("continue_with_auth() missing required argument: 'request'") params = { "request": request, @@ -676,11 +524,11 @@ def continue_with_auth(self, request: Any | None = None): def disown_data(self, data_type: Any | None = None, collector: Any | None = None, request: Any | None = None): """Execute network.disownData.""" if data_type is None: - raise TypeError("disown_data() missing required argument: {{snake_param!r}}") + raise TypeError("disown_data() missing required argument: 'data_type'") if collector is None: - raise TypeError("disown_data() missing required argument: {{snake_param!r}}") + raise TypeError("disown_data() missing required argument: 'collector'") if request is None: - raise TypeError("disown_data() missing required argument: {{snake_param!r}}") + raise TypeError("disown_data() missing required argument: 'request'") params = { "dataType": data_type, @@ -695,7 +543,7 @@ def disown_data(self, data_type: Any | None = None, collector: Any | None = None def fail_request(self, request: Any | None = None): """Execute network.failRequest.""" if request is None: - raise TypeError("fail_request() missing required argument: {{snake_param!r}}") + raise TypeError("fail_request() missing required argument: 'request'") params = { "request": request, @@ -714,9 +562,9 @@ def get_data( ): """Execute network.getData.""" if data_type is None: - raise TypeError("get_data() missing required argument: {{snake_param!r}}") + raise TypeError("get_data() missing required argument: 'data_type'") if request is None: - raise TypeError("get_data() missing required argument: {{snake_param!r}}") + raise TypeError("get_data() missing required argument: 'request'") params = { "dataType": data_type, @@ -740,7 +588,7 @@ def provide_response( ): """Execute network.provideResponse.""" if request is None: - raise TypeError("provide_response() missing required argument: {{snake_param!r}}") + raise TypeError("provide_response() missing required argument: 'request'") params = { "request": request, @@ -758,7 +606,7 @@ def provide_response( def remove_data_collector(self, collector: Any | None = None): """Execute network.removeDataCollector.""" if collector is None: - raise TypeError("remove_data_collector() missing required argument: {{snake_param!r}}") + raise TypeError("remove_data_collector() missing required argument: 'collector'") params = { "collector": collector, @@ -771,7 +619,7 @@ def remove_data_collector(self, collector: Any | None = None): def remove_intercept(self, intercept: Any | None = None): """Execute network.removeIntercept.""" if intercept is None: - raise TypeError("remove_intercept() missing required argument: {{snake_param!r}}") + raise TypeError("remove_intercept() missing required argument: 'intercept'") params = { "intercept": intercept, @@ -784,7 +632,7 @@ def remove_intercept(self, intercept: Any | None = None): def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[Any] | None = None): """Execute network.setCacheBehavior.""" if cache_behavior is None: - raise TypeError("set_cache_behavior() missing required argument: {{snake_param!r}}") + raise TypeError("set_cache_behavior() missing required argument: 'cache_behavior'") params = { "cacheBehavior": cache_behavior, @@ -803,7 +651,7 @@ def set_extra_headers( ): """Execute network.setExtraHeaders.""" if headers is None: - raise TypeError("set_extra_headers() missing required argument: {{snake_param!r}}") + raise TypeError("set_extra_headers() missing required argument: 'headers'") params = { "headers": headers, @@ -818,9 +666,9 @@ def set_extra_headers( def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.beforeRequestSent.""" if method is None: - raise TypeError("before_request_sent() missing required argument: {{snake_param!r}}") + raise TypeError("before_request_sent() missing required argument: 'method'") if params is None: - raise TypeError("before_request_sent() missing required argument: {{snake_param!r}}") + raise TypeError("before_request_sent() missing required argument: 'params'") params = { "initiator": initiator, @@ -835,11 +683,11 @@ def before_request_sent(self, initiator: Any | None = None, method: Any | None = def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.fetchError.""" if error_text is None: - raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") + raise TypeError("fetch_error() missing required argument: 'error_text'") if method is None: - raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") + raise TypeError("fetch_error() missing required argument: 'method'") if params is None: - raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") + raise TypeError("fetch_error() missing required argument: 'params'") params = { "errorText": error_text, @@ -854,11 +702,11 @@ def fetch_error(self, error_text: Any | None = None, method: Any | None = None, def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.responseCompleted.""" if response is None: - raise TypeError("response_completed() missing required argument: {{snake_param!r}}") + raise TypeError("response_completed() missing required argument: 'response'") if method is None: - raise TypeError("response_completed() missing required argument: {{snake_param!r}}") + raise TypeError("response_completed() missing required argument: 'method'") if params is None: - raise TypeError("response_completed() missing required argument: {{snake_param!r}}") + raise TypeError("response_completed() missing required argument: 'params'") params = { "response": response, @@ -873,7 +721,7 @@ def response_completed(self, response: Any | None = None, method: Any | None = N def response_started(self, response: Any | None = None): """Execute network.responseStarted.""" if response is None: - raise TypeError("response_started() missing required argument: {{snake_param!r}}") + raise TypeError("response_started() missing required argument: 'response'") params = { "response": response, diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py index 6dd138da17309..acb8bdf65f0ef 100644 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ b/py/selenium/webdriver/common/bidi/permissions.py @@ -22,7 +22,7 @@ from enum import Enum from typing import Any -from .common import command_builder +from selenium.webdriver.common.bidi.common import command_builder _VALID_PERMISSION_STATES = {"granted", "denied", "prompt"} diff --git a/py/selenium/webdriver/common/bidi/py.typed b/py/selenium/webdriver/common/bidi/py.typed old mode 100755 new mode 100644 diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 221b5963e8ec1..d6877de623d14 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -6,14 +6,12 @@ # WebDriver BiDi module: script from __future__ import annotations -import threading from collections.abc import Callable from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi.session import Session - -from .common import command_builder +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager +from selenium.webdriver.common.bidi.common import command_builder class SpecialNumber: @@ -620,159 +618,6 @@ class RealmDestroyedParameters: "realm_destroyed": "script.realmDestroyed", } -@dataclass -class EventConfig: - """Configuration for a BiDi event.""" - event_key: str - bidi_event: str - event_class: type - - -class _EventWrapper: - """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" - def __init__(self, bidi_event: str, event_class: type): - self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class - self._python_class = event_class # Keep reference to Python dataclass for deserialization - - def from_json(self, params: dict) -> Any: - """Deserialize event params into the wrapped Python dataclass. - - Args: - params: Raw BiDi event params with camelCase keys. - - Returns: - An instance of the dataclass, or the raw dict on failure. - """ - if self._python_class is None or self._python_class is dict: - return params - try: - # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, "from_json") and callable( - self._python_class.from_json - ): - return self._python_class.from_json(params) - import dataclasses as dc - - snake_params = {self._camel_to_snake(k): v for k, v in params.items()} - if dc.is_dataclass(self._python_class): - valid_fields = {f.name for f in dc.fields(self._python_class)} - filtered = {k: v for k, v in snake_params.items() if k in valid_fields} - return self._python_class(**filtered) - return self._python_class(**snake_params) - except Exception: - return params - - @staticmethod - def _camel_to_snake(name: str) -> str: - result = [name[0].lower()] - for char in name[1:]: - if char.isupper(): - result.extend(["_", char.lower()]) - else: - result.append(char) - return "".join(result) - - -class _EventManager: - """Manages event subscriptions and callbacks.""" - - def __init__(self, conn, event_configs: dict[str, EventConfig]): - self.conn = conn - self.event_configs = event_configs - self.subscriptions: dict = {} - self._event_wrappers = {} # Cache of _EventWrapper objects - self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} - self._available_events = ", ".join(sorted(event_configs.keys())) - self._subscription_lock = threading.Lock() - - # Create event wrappers for each event - for config in event_configs.values(): - wrapper = _EventWrapper(config.bidi_event, config.event_class) - self._event_wrappers[config.bidi_event] = wrapper - - def validate_event(self, event: str) -> EventConfig: - event_config = self.event_configs.get(event) - if not event_config: - raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") - return event_config - - def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - """Subscribe to a BiDi event if not already subscribed.""" - with self._subscription_lock: - if bidi_event not in self.subscriptions: - session = Session(self.conn) - result = session.subscribe([bidi_event], contexts=contexts) - sub_id = ( - result.get("subscription") if isinstance(result, dict) else None - ) - self.subscriptions[bidi_event] = { - "callbacks": [], - "subscription_id": sub_id, - } - - def unsubscribe_from_event(self, bidi_event: str) -> None: - """Unsubscribe from a BiDi event if no more callbacks exist.""" - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry is not None and not entry["callbacks"]: - session = Session(self.conn) - sub_id = entry.get("subscription_id") - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - del self.subscriptions[bidi_event] - - def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - self.subscriptions[bidi_event]["callbacks"].append(callback_id) - - def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry and callback_id in entry["callbacks"]: - entry["callbacks"].remove(callback_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - event_config = self.validate_event(event) - # Use the event wrapper for add_callback - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - callback_id = self.conn.add_callback(event_wrapper, callback) - self.subscribe_to_event(event_config.bidi_event, contexts) - self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id - - def remove_event_handler(self, event: str, callback_id: int) -> None: - event_config = self.validate_event(event) - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - self.conn.remove_callback(event_wrapper, callback_id) - self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - self.unsubscribe_from_event(event_config.bidi_event) - - def clear_event_handlers(self) -> None: - """Clear all event handlers.""" - with self._subscription_lock: - if not self.subscriptions: - return - session = Session(self.conn) - for bidi_event, entry in list(self.subscriptions.items()): - event_wrapper = self._event_wrappers.get(bidi_event) - callbacks = entry["callbacks"] if isinstance(entry, dict) else entry - if event_wrapper: - for callback_id in callbacks: - self.conn.remove_callback(event_wrapper, callback_id) - sub_id = ( - entry.get("subscription_id") if isinstance(entry, dict) else None - ) - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() - - - - class Script: """WebDriver BiDi script module.""" @@ -792,7 +637,7 @@ def add_preload_script( ): """Execute script.addPreloadScript.""" if function_declaration is None: - raise TypeError("add_preload_script() missing required argument: {{snake_param!r}}") + raise TypeError("add_preload_script() missing required argument: 'function_declaration'") params = { "functionDeclaration": function_declaration, @@ -809,9 +654,9 @@ def add_preload_script( def disown(self, handles: list[Any] | None = None, target: Any | None = None): """Execute script.disown.""" if handles is None: - raise TypeError("disown() missing required argument: {{snake_param!r}}") + raise TypeError("disown() missing required argument: 'handles'") if target is None: - raise TypeError("disown() missing required argument: {{snake_param!r}}") + raise TypeError("disown() missing required argument: 'target'") params = { "handles": handles, @@ -835,11 +680,11 @@ def call_function( ): """Execute script.callFunction.""" if function_declaration is None: - raise TypeError("call_function() missing required argument: {{snake_param!r}}") + raise TypeError("call_function() missing required argument: 'function_declaration'") if await_promise is None: - raise TypeError("call_function() missing required argument: {{snake_param!r}}") + raise TypeError("call_function() missing required argument: 'await_promise'") if target is None: - raise TypeError("call_function() missing required argument: {{snake_param!r}}") + raise TypeError("call_function() missing required argument: 'target'") params = { "functionDeclaration": function_declaration, @@ -867,11 +712,11 @@ def evaluate( ): """Execute script.evaluate.""" if expression is None: - raise TypeError("evaluate() missing required argument: {{snake_param!r}}") + raise TypeError("evaluate() missing required argument: 'expression'") if target is None: - raise TypeError("evaluate() missing required argument: {{snake_param!r}}") + raise TypeError("evaluate() missing required argument: 'target'") if await_promise is None: - raise TypeError("evaluate() missing required argument: {{snake_param!r}}") + raise TypeError("evaluate() missing required argument: 'await_promise'") params = { "expression": expression, @@ -900,7 +745,7 @@ def get_realms(self, context: Any | None = None, type: Any | None = None): def remove_preload_script(self, script: Any | None = None): """Execute script.removePreloadScript.""" if script is None: - raise TypeError("remove_preload_script() missing required argument: {{snake_param!r}}") + raise TypeError("remove_preload_script() missing required argument: 'script'") params = { "script": script, @@ -913,11 +758,11 @@ def remove_preload_script(self, script: Any | None = None): def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): """Execute script.message.""" if channel is None: - raise TypeError("message() missing required argument: {{snake_param!r}}") + raise TypeError("message() missing required argument: 'channel'") if data is None: - raise TypeError("message() missing required argument: {{snake_param!r}}") + raise TypeError("message() missing required argument: 'data'") if source is None: - raise TypeError("message() missing required argument: {{snake_param!r}}") + raise TypeError("message() missing required argument: 'source'") params = { "channel": channel, diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index fcb42a4ad86fc..e04d897e25deb 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from typing import Any -from .common import command_builder +from selenium.webdriver.common.bidi.common import command_builder class UserPromptHandlerType: @@ -176,6 +176,10 @@ def to_bidi_dict(self) -> dict: result["prompt"] = self.prompt return result + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict() + class Session: """WebDriver BiDi session module.""" @@ -194,7 +198,7 @@ def status(self): def new(self, capabilities: Any | None = None): """Execute session.new.""" if capabilities is None: - raise TypeError("new() missing required argument: {{snake_param!r}}") + raise TypeError("new() missing required argument: 'capabilities'") params = { "capabilities": capabilities, @@ -221,7 +225,7 @@ def subscribe( ): """Execute session.subscribe.""" if events is None: - raise TypeError("subscribe() missing required argument: {{snake_param!r}}") + raise TypeError("subscribe() missing required argument: 'events'") params = { "events": events, diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index a2606526f3856..5ae8bf5aeb2d0 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from typing import Any -from .common import command_builder +from selenium.webdriver.common.bidi.common import command_builder @dataclass @@ -83,6 +83,10 @@ def __init__(self, type: Any | None, value: Any | None) -> None: def to_bidi_dict(self) -> dict: return {"type": self.type, "value": self.value} + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict() + class SameSite: """SameSite cookie attribute values.""" @@ -162,6 +166,10 @@ def to_bidi_dict(self) -> dict: result["expiry"] = self.expiry return result + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict() + @dataclass class PartialCookie: """PartialCookie.""" @@ -196,6 +204,10 @@ def to_bidi_dict(self) -> dict: result["expiry"] = self.expiry return result + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict() + class BrowsingContextPartitionDescriptor: """BrowsingContextPartitionDescriptor. @@ -211,6 +223,10 @@ def __init__(self, context: Any = None, type: str = "context") -> None: def to_bidi_dict(self) -> dict: return {"type": "context", "context": self.context} + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict() + @dataclass class StorageKeyPartitionDescriptor: """StorageKeyPartitionDescriptor.""" @@ -228,6 +244,10 @@ def to_bidi_dict(self) -> dict: result["sourceOrigin"] = self.source_origin return result + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict() + class Storage: """WebDriver BiDi storage module.""" @@ -277,6 +297,17 @@ def set_cookie(self, cookie=None, partition=None): params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("storage.setCookie", params) result = self._conn.execute(cmd) + if isinstance(result, dict): + pk_raw = result.get("partitionKey") + pk = ( + PartitionKey( + user_context=pk_raw.get("userContext"), + source_origin=pk_raw.get("sourceOrigin"), + ) + if isinstance(pk_raw, dict) + else None + ) + return SetCookieResult(partition_key=pk) return result def delete_cookies(self, filter=None, partition=None): """Execute storage.deleteCookies.""" @@ -291,4 +322,15 @@ def delete_cookies(self, filter=None, partition=None): params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("storage.deleteCookies", params) result = self._conn.execute(cmd) + if isinstance(result, dict): + pk_raw = result.get("partitionKey") + pk = ( + PartitionKey( + user_context=pk_raw.get("userContext"), + source_origin=pk_raw.get("sourceOrigin"), + ) + if isinstance(pk_raw, dict) + else None + ) + return DeleteCookiesResult(partition_key=pk) return result diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index b5881d01e0bea..0a28843e339f1 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from typing import Any -from .common import command_builder +from selenium.webdriver.common.bidi.common import command_builder @dataclass @@ -104,7 +104,17 @@ def install( extension_data = {"type": "base64", "value": base64_value} params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) - return self._conn.execute(cmd) + try: + return self._conn.execute(cmd) + except Exception as e: + if "Method not available" in str(e): + raise RuntimeError( + "webExtension.install failed with 'Method not available'. " + "This likely means that web extension support is disabled. " + "Enable unsafe extension debugging and/or set options.enable_webextensions " + "in your WebDriver configuration." + ) from e + raise def uninstall(self, extension: str | dict): """Uninstall a web extension. From 3250f1babd8cb69e1f6b004d4cf12f70d6697df5 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 1 Apr 2026 13:36:02 +0100 Subject: [PATCH 24/42] fix tests --- py/private/BUILD.bazel | 1 + py/private/cdp.py | 8 +++++++- py/selenium/webdriver/common/bidi/_event_manager.py | 2 +- py/selenium/webdriver/common/bidi/browser.py | 1 - py/selenium/webdriver/common/bidi/browsing_context.py | 3 +-- py/selenium/webdriver/common/bidi/cdp.py | 8 +++++++- py/selenium/webdriver/common/bidi/emulation.py | 1 - py/selenium/webdriver/common/bidi/input.py | 3 +-- py/selenium/webdriver/common/bidi/log.py | 3 +-- py/selenium/webdriver/common/bidi/network.py | 3 +-- py/selenium/webdriver/common/bidi/script.py | 9 +++------ py/selenium/webdriver/common/bidi/session.py | 1 - py/selenium/webdriver/common/bidi/storage.py | 1 - py/selenium/webdriver/common/bidi/webextension.py | 1 - 14 files changed, 23 insertions(+), 22 deletions(-) diff --git a/py/private/BUILD.bazel b/py/private/BUILD.bazel index 88acc9d2aba11..d2ea587fd8101 100644 --- a/py/private/BUILD.bazel +++ b/py/private/BUILD.bazel @@ -1,6 +1,7 @@ load("@rules_python//python:defs.bzl", "py_binary") exports_files([ + "_event_manager.py", "bidi_enhancements_manifest.py", "cdp.py", ]) diff --git a/py/private/cdp.py b/py/private/cdp.py index b097762fe50cd..ba4a73298ee0a 100644 --- a/py/private/cdp.py +++ b/py/private/cdp.py @@ -60,7 +60,13 @@ def import_devtools(ver): # because cdp has been updated but selenium python has not been released yet. devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) - latest = max(int(x[1:]) for x in versions) + available_versions = tuple( + x for x in versions if x == "latest" or (x.startswith("v") and x[1:].isdigit()) + ) + numeric_versions = tuple(x[1:] for x in available_versions if x.startswith("v")) + if not numeric_versions: + raise + latest = max(numeric_versions, key=int) selenium_logger = logging.getLogger(__name__) selenium_logger.debug("Falling back to loading `devtools`: v%s", latest) devtools = importlib.import_module(f"{base}{latest}") diff --git a/py/selenium/webdriver/common/bidi/_event_manager.py b/py/selenium/webdriver/common/bidi/_event_manager.py index 1dcc8288ce683..3fb3a6a1ceb6b 100644 --- a/py/selenium/webdriver/common/bidi/_event_manager.py +++ b/py/selenium/webdriver/common/bidi/_event_manager.py @@ -177,4 +177,4 @@ def clear_event_handlers(self) -> None: session.unsubscribe(subscriptions=[sub_id]) else: session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() + self.subscriptions.clear() \ No newline at end of file diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 3811a2a2e97b7..94dd0094e9173 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - def transform_download_params( allowed: bool | None, destination_folder: str | None, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index fcee27df8488e..86075dc166256 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class ReadinessState: """ReadinessState.""" diff --git a/py/selenium/webdriver/common/bidi/cdp.py b/py/selenium/webdriver/common/bidi/cdp.py index b097762fe50cd..ba4a73298ee0a 100644 --- a/py/selenium/webdriver/common/bidi/cdp.py +++ b/py/selenium/webdriver/common/bidi/cdp.py @@ -60,7 +60,13 @@ def import_devtools(ver): # because cdp has been updated but selenium python has not been released yet. devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) - latest = max(int(x[1:]) for x in versions) + available_versions = tuple( + x for x in versions if x == "latest" or (x.startswith("v") and x[1:].isdigit()) + ) + numeric_versions = tuple(x[1:] for x in available_versions if x.startswith("v")) + if not numeric_versions: + raise + latest = max(numeric_versions, key=int) selenium_logger = logging.getLogger(__name__) selenium_logger.debug("Falling back to loading `devtools`: v%s", latest) devtools = importlib.import_module(f"{base}{latest}") diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 44babb6777616..9791aba5e08a6 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - class ForcedColorsModeTheme: """ForcedColorsModeTheme.""" diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 346ead5e49841..d2508fea5ca64 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class PointerType: """PointerType.""" diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index ca24d6e78d532..04c5a53c04510 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -10,8 +10,7 @@ from dataclasses import dataclass from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class Level: """Level.""" diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 343b6d960c017..c0302bdec186b 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class SameSite: """SameSite.""" diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index d6877de623d14..d5b15ff4d983c 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class SpecialNumber: """SpecialNumber.""" @@ -790,9 +789,8 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import datetime as _datetime import math as _math - + import datetime as _datetime from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1033,9 +1031,8 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - - from selenium.webdriver.common.bidi import log as _log_mod from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod bidi_event = "log.entryAdded" diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index e04d897e25deb..a54e196aa86d9 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - class UserPromptHandlerType: """UserPromptHandlerType.""" diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 5ae8bf5aeb2d0..d922390a08699 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - @dataclass class PartitionKey: """PartitionKey.""" diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 0a28843e339f1..3520219e26c53 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - @dataclass class InstallParameters: """InstallParameters.""" From 9b3a2cf6874bf44272a7859adf8b46936724ff9c Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Tue, 7 Apr 2026 12:14:08 +0100 Subject: [PATCH 25/42] ruffs updates --- py/selenium/webdriver/common/bidi/browser.py | 1 + py/selenium/webdriver/common/bidi/browsing_context.py | 3 ++- py/selenium/webdriver/common/bidi/emulation.py | 1 + py/selenium/webdriver/common/bidi/input.py | 3 ++- py/selenium/webdriver/common/bidi/log.py | 3 ++- py/selenium/webdriver/common/bidi/network.py | 3 ++- py/selenium/webdriver/common/bidi/script.py | 9 ++++++--- py/selenium/webdriver/common/bidi/session.py | 1 + py/selenium/webdriver/common/bidi/storage.py | 1 + py/selenium/webdriver/common/bidi/webextension.py | 1 + 10 files changed, 19 insertions(+), 7 deletions(-) diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 94dd0094e9173..3811a2a2e97b7 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + def transform_download_params( allowed: bool | None, destination_folder: str | None, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 86075dc166256..fcee27df8488e 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class ReadinessState: """ReadinessState.""" diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 9791aba5e08a6..44babb6777616 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + class ForcedColorsModeTheme: """ForcedColorsModeTheme.""" diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index d2508fea5ca64..346ead5e49841 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class PointerType: """PointerType.""" diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 04c5a53c04510..ca24d6e78d532 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -10,7 +10,8 @@ from dataclasses import dataclass from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager + class Level: """Level.""" diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index c0302bdec186b..343b6d960c017 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class SameSite: """SameSite.""" diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index d5b15ff4d983c..d6877de623d14 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class SpecialNumber: """SpecialNumber.""" @@ -789,8 +790,9 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import math as _math import datetime as _datetime + import math as _math + from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1031,8 +1033,9 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + from selenium.webdriver.common.bidi.session import Session as _Session bidi_event = "log.entryAdded" diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index a54e196aa86d9..e04d897e25deb 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + class UserPromptHandlerType: """UserPromptHandlerType.""" diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index d922390a08699..5ae8bf5aeb2d0 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + @dataclass class PartitionKey: """PartitionKey.""" diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 3520219e26c53..0a28843e339f1 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + @dataclass class InstallParameters: """InstallParameters.""" From ad1fd00938f390f3b8b742530a0e8a8355e31c23 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Tue, 7 Apr 2026 12:34:15 +0100 Subject: [PATCH 26/42] Update CDDL files and regenerate python files --- common/bidi/spec/all.cddl | 71 +++++++++++-------- common/bidi/spec/local.cddl | 29 ++++++-- common/bidi/spec/remote.cddl | 65 +++++++---------- py/selenium/webdriver/common/bidi/browser.py | 1 - .../webdriver/common/bidi/browsing_context.py | 37 +++++++++- .../webdriver/common/bidi/emulation.py | 30 -------- py/selenium/webdriver/common/bidi/input.py | 3 +- py/selenium/webdriver/common/bidi/log.py | 3 +- py/selenium/webdriver/common/bidi/network.py | 4 +- py/selenium/webdriver/common/bidi/script.py | 11 ++- py/selenium/webdriver/common/bidi/session.py | 1 - py/selenium/webdriver/common/bidi/storage.py | 1 - .../webdriver/common/bidi/webextension.py | 1 - 13 files changed, 130 insertions(+), 127 deletions(-) diff --git a/common/bidi/spec/all.cddl b/common/bidi/spec/all.cddl index 85c4536a2cd10..e10b42723b0f5 100644 --- a/common/bidi/spec/all.cddl +++ b/common/bidi/spec/all.cddl @@ -420,6 +420,7 @@ BrowsingContextCommand = ( browsingContext.Navigate // browsingContext.Print // browsingContext.Reload // + browsingContext.SetBypassCSP // browsingContext.SetViewport // browsingContext.TraverseHistory ) @@ -435,6 +436,7 @@ BrowsingContextResult = ( browsingContext.NavigateResult / browsingContext.PrintResult / browsingContext.ReloadResult / + browsingContext.SetBypassCSPResult / browsingContext.SetViewportResult / browsingContext.TraverseHistoryResult ) @@ -518,6 +520,7 @@ browsingContext.BaseNavigationInfo = ( navigation: browsingContext.Navigation / null, timestamp: js-uint, url: text, + ? userContext: browser.UserContext, ) browsingContext.NavigationInfo = { @@ -605,7 +608,8 @@ browsingContext.CreateParameters = { } browsingContext.CreateResult = { - context: browsingContext.BrowsingContext + context: browsingContext.BrowsingContext, + ? userContext: browser.UserContext } browsingContext.GetTree = ( @@ -715,6 +719,19 @@ browsingContext.ReloadParameters = { browsingContext.ReloadResult = browsingContext.NavigateResult +browsingContext.SetBypassCSP = ( + method: "browsingContext.setBypassCSP", + params: browsingContext.SetBypassCSPParameters +) + +browsingContext.SetBypassCSPParameters = { + bypass: true / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +browsingContext.SetBypassCSPResult = EmptyResult + browsingContext.SetViewport = ( method: "browsingContext.setViewport", params: browsingContext.SetViewportParameters @@ -774,7 +791,8 @@ browsingContext.HistoryUpdated = ( browsingContext.HistoryUpdatedParameters = { context: browsingContext.BrowsingContext, timestamp: js-uint, - url: text + url: text, + ? userContext: browser.UserContext } browsingContext.DomContentLoaded = ( @@ -844,6 +862,7 @@ browsingContext.UserPromptClosedParameters = { context: browsingContext.BrowsingContext, accepted: bool, type: browsingContext.UserPromptType, + ? userContext: browser.UserContext, ? userText: text } @@ -857,6 +876,7 @@ browsingContext.UserPromptOpenedParameters = { handler: session.UserPromptHandlerType, message: text, type: browsingContext.UserPromptType, + ? userContext: browser.UserContext, ? defaultValue: text } @@ -871,8 +891,7 @@ EmulationCommand = ( emulation.SetScrollbarTypeOverride // emulation.SetTimezoneOverride // emulation.SetTouchOverride // - emulation.SetUserAgentOverride // - emulation.SetViewportMetaOverride + emulation.SetUserAgentOverride ) @@ -885,8 +904,7 @@ EmulationResult = ( emulation.SetScrollbarTypeOverrideResult / emulation.SetTimezoneOverrideResult / emulation.SetTouchOverrideResult / - emulation.SetUserAgentOverrideResult / - emulation.SetViewportMetaOverrideResult + emulation.SetUserAgentOverrideResult ) emulation.SetForcedColorsModeThemeOverride = ( @@ -949,10 +967,10 @@ emulation.SetLocaleOverrideResult = EmptyResult emulation.SetNetworkConditions = ( method: "emulation.setNetworkConditions", - params: emulation.setNetworkConditionsParameters + params: emulation.SetNetworkConditionsParameters ) -emulation.setNetworkConditionsParameters = { +emulation.SetNetworkConditionsParameters = { networkConditions: emulation.NetworkConditions / null, ? contexts: [+browsingContext.BrowsingContext], ? userContexts: [+browser.UserContext], @@ -1018,19 +1036,6 @@ emulation.SetUserAgentOverrideParameters = { emulation.SetUserAgentOverrideResult = EmptyResult -emulation.SetViewportMetaOverride = ( - method: "emulation.setViewportMetaOverride", - params: emulation.SetViewportMetaOverrideParameters -) - -emulation.SetViewportMetaOverrideParameters = { - viewportMeta: true / null, - ? contexts: [+browsingContext.BrowsingContext], - ? userContexts: [+browser.UserContext], -} - -emulation.SetViewportMetaOverrideResult = EmptyResult - emulation.SetScriptingEnabled = ( method: "emulation.setScriptingEnabled", params: emulation.SetScriptingEnabledParameters @@ -1145,6 +1150,7 @@ network.BaseParameters = ( redirectCount: js-uint, request: network.RequestData, timestamp: js-uint, + ? userContext: browser.UserContext / null, ? intercepts: [+network.Intercept] ) @@ -1379,10 +1385,10 @@ network.ContinueWithAuthResult = EmptyResult network.DisownData = ( method: "network.disownData", - params: network.disownDataParameters + params: network.DisownDataParameters ) -network.disownDataParameters = { +network.DisownDataParameters = { dataType: network.DataType, collector: network.Collector, request: network.Request, @@ -1710,6 +1716,7 @@ script.WindowRealmInfo = { script.BaseRealmInfo, type: "window", context: browsingContext.BrowsingContext, + ? userContext: browser.UserContext, ? sandbox: text } @@ -1969,7 +1976,8 @@ script.StackTrace = { script.Source = { realm: script.Realm, - ? context: browsingContext.BrowsingContext + ? context: browsingContext.BrowsingContext, + ? userContext: browser.UserContext } script.RealmTarget = { @@ -2381,15 +2389,15 @@ input.WheelScrollAction = { } input.PointerCommonProperties = ( - ? width: js-uint .default 1, - ? height: js-uint .default 1, - ? pressure: float .default 0.0, - ? tangentialPressure: float .default 0.0, - ? twist: (0..359) .default 0, + ? width: js-uint, + ? height: js-uint, + ? pressure: (0.0..1.0), + ? tangentialPressure: (-1.0..1.0), + ? twist: (0..359), ; 0 .. Math.PI / 2 - ? altitudeAngle: (0.0..1.5707963267948966) .default 0.0, + ? altitudeAngle: (0.0..1.5707963267948966), ; 0 .. 2 * Math.PI - ? azimuthAngle: (0.0..6.283185307179586) .default 0.0, + ? azimuthAngle: (0.0..6.283185307179586), ) input.Origin = "viewport" / "pointer" / input.ElementOrigin @@ -2427,6 +2435,7 @@ input.FileDialogOpened = ( input.FileDialogInfo = { context: browsingContext.BrowsingContext, + ? userContext: browser.UserContext, ? element: script.SharedReference, multiple: bool, } diff --git a/common/bidi/spec/local.cddl b/common/bidi/spec/local.cddl index d43af0ae11b03..1bb2ce612e2c2 100644 --- a/common/bidi/spec/local.cddl +++ b/common/bidi/spec/local.cddl @@ -251,6 +251,7 @@ BrowsingContextResult = ( browsingContext.NavigateResult / browsingContext.PrintResult / browsingContext.ReloadResult / + browsingContext.SetBypassCSPResult / browsingContext.SetViewportResult / browsingContext.TraverseHistoryResult ) @@ -334,6 +335,7 @@ browsingContext.BaseNavigationInfo = ( navigation: browsingContext.Navigation / null, timestamp: js-uint, url: text, + ? userContext: browser.UserContext, ) browsingContext.NavigationInfo = { @@ -351,7 +353,8 @@ browsingContext.CaptureScreenshotResult = { browsingContext.CloseResult = EmptyResult browsingContext.CreateResult = { - context: browsingContext.BrowsingContext + context: browsingContext.BrowsingContext, + ? userContext: browser.UserContext } browsingContext.GetTreeResult = { @@ -375,6 +378,8 @@ browsingContext.PrintResult = { browsingContext.ReloadResult = browsingContext.NavigateResult +browsingContext.SetBypassCSPResult = EmptyResult + browsingContext.SetViewportResult = EmptyResult browsingContext.TraverseHistoryResult = EmptyResult @@ -407,7 +412,8 @@ browsingContext.HistoryUpdated = ( browsingContext.HistoryUpdatedParameters = { context: browsingContext.BrowsingContext, timestamp: js-uint, - url: text + url: text, + ? userContext: browser.UserContext } browsingContext.DomContentLoaded = ( @@ -477,6 +483,7 @@ browsingContext.UserPromptClosedParameters = { context: browsingContext.BrowsingContext, accepted: bool, type: browsingContext.UserPromptType, + ? userContext: browser.UserContext, ? userText: text } @@ -490,6 +497,7 @@ browsingContext.UserPromptOpenedParameters = { handler: session.UserPromptHandlerType, message: text, type: browsingContext.UserPromptType, + ? userContext: browser.UserContext, ? defaultValue: text } @@ -502,8 +510,7 @@ EmulationResult = ( emulation.SetScrollbarTypeOverrideResult / emulation.SetTimezoneOverrideResult / emulation.SetTouchOverrideResult / - emulation.SetUserAgentOverrideResult / - emulation.SetViewportMetaOverrideResult + emulation.SetUserAgentOverrideResult ) emulation.SetForcedColorsModeThemeOverrideResult = EmptyResult @@ -520,8 +527,6 @@ emulation.SetScreenOrientationOverrideResult = EmptyResult emulation.SetUserAgentOverrideResult = EmptyResult -emulation.SetViewportMetaOverrideResult = EmptyResult - emulation.SetScriptingEnabledResult = EmptyResult emulation.SetScrollbarTypeOverrideResult = EmptyResult @@ -568,6 +573,7 @@ network.BaseParameters = ( redirectCount: js-uint, request: network.RequestData, timestamp: js-uint, + ? userContext: browser.UserContext / null, ? intercepts: [+network.Intercept] ) @@ -926,6 +932,7 @@ script.WindowRealmInfo = { script.BaseRealmInfo, type: "window", context: browsingContext.BrowsingContext, + ? userContext: browser.UserContext, ? sandbox: text } @@ -1185,7 +1192,8 @@ script.StackTrace = { script.Source = { realm: script.Realm, - ? context: browsingContext.BrowsingContext + ? context: browsingContext.BrowsingContext, + ? userContext: browser.UserContext } script.AddPreloadScriptResult = { @@ -1295,6 +1303,12 @@ log.EntryAdded = ( params: log.Entry, ) +InputResult = ( + input.PerformActionsResult / + input.ReleaseActionsResult / + input.SetFilesResult +) + InputEvent = ( input.FileDialogOpened @@ -1313,6 +1327,7 @@ input.FileDialogOpened = ( input.FileDialogInfo = { context: browsingContext.BrowsingContext, + ? userContext: browser.UserContext, ? element: script.SharedReference, multiple: bool, } diff --git a/common/bidi/spec/remote.cddl b/common/bidi/spec/remote.cddl index a98859a021e12..7490df1b44bc7 100644 --- a/common/bidi/spec/remote.cddl +++ b/common/bidi/spec/remote.cddl @@ -273,6 +273,7 @@ BrowsingContextCommand = ( browsingContext.Navigate // browsingContext.Print // browsingContext.Reload // + browsingContext.SetBypassCSP // browsingContext.SetViewport // browsingContext.TraverseHistory ) @@ -480,6 +481,17 @@ browsingContext.ReloadParameters = { ? wait: browsingContext.ReadinessState, } +browsingContext.SetBypassCSP = ( + method: "browsingContext.setBypassCSP", + params: browsingContext.SetBypassCSPParameters +) + +browsingContext.SetBypassCSPParameters = { + bypass: true / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + browsingContext.SetViewport = ( method: "browsingContext.setViewport", params: browsingContext.SetViewportParameters @@ -518,8 +530,7 @@ EmulationCommand = ( emulation.SetScrollbarTypeOverride // emulation.SetTimezoneOverride // emulation.SetTouchOverride // - emulation.SetUserAgentOverride // - emulation.SetViewportMetaOverride + emulation.SetUserAgentOverride ) @@ -577,10 +588,10 @@ emulation.SetLocaleOverrideParameters = { emulation.SetNetworkConditions = ( method: "emulation.setNetworkConditions", - params: emulation.setNetworkConditionsParameters + params: emulation.SetNetworkConditionsParameters ) -emulation.setNetworkConditionsParameters = { +emulation.SetNetworkConditionsParameters = { networkConditions: emulation.NetworkConditions / null, ? contexts: [+browsingContext.BrowsingContext], ? userContexts: [+browser.UserContext], @@ -638,17 +649,6 @@ emulation.SetUserAgentOverrideParameters = { ? userContexts: [+browser.UserContext], } -emulation.SetViewportMetaOverride = ( - method: "emulation.setViewportMetaOverride", - params: emulation.SetViewportMetaOverrideParameters -) - -emulation.SetViewportMetaOverrideParameters = { - viewportMeta: true / null, - ? contexts: [+browsingContext.BrowsingContext], - ? userContexts: [+browser.UserContext], -} - emulation.SetScriptingEnabled = ( method: "emulation.setScriptingEnabled", params: emulation.SetScriptingEnabledParameters @@ -876,10 +876,10 @@ network.ContinueWithAuthNoCredentials = ( network.DisownData = ( method: "network.disownData", - params: network.disownDataParameters + params: network.DisownDataParameters ) -network.disownDataParameters = { +network.DisownDataParameters = { dataType: network.DataType, collector: network.Collector, request: network.Request, @@ -1500,12 +1500,6 @@ InputCommand = ( input.SetFiles ) -InputResult = ( - input.PerformActionsResult / - input.ReleaseActionsResult / - input.SetFilesResult -) - input.ElementOrigin = { type: "element", element: script.SharedReference @@ -1625,15 +1619,15 @@ input.WheelScrollAction = { } input.PointerCommonProperties = ( - ? width: js-uint .default 1, - ? height: js-uint .default 1, - ? pressure: float .default 0.0, - ? tangentialPressure: float .default 0.0, - ? twist: (0..359) .default 0, + ? width: js-uint, + ? height: js-uint, + ? pressure: (0.0..1.0), + ? tangentialPressure: (-1.0..1.0), + ? twist: (0..359), ; 0 .. Math.PI / 2 - ? altitudeAngle: (0.0..1.5707963267948966) .default 0.0, + ? altitudeAngle: (0.0..1.5707963267948966), ; 0 .. 2 * Math.PI - ? azimuthAngle: (0.0..6.283185307179586) .default 0.0, + ? azimuthAngle: (0.0..6.283185307179586), ) input.Origin = "viewport" / "pointer" / input.ElementOrigin @@ -1658,17 +1652,6 @@ input.SetFilesParameters = { files: [*text] } -input.FileDialogOpened = ( - method: "input.fileDialogOpened", - params: input.FileDialogInfo -) - -input.FileDialogInfo = { - context: browsingContext.BrowsingContext, - ? element: script.SharedReference, - multiple: bool, -} - WebExtensionCommand = ( webExtension.Install // webExtension.Uninstall diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 3811a2a2e97b7..94dd0094e9173 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - def transform_download_params( allowed: bool | None, destination_folder: str | None, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index fcee27df8488e..50f9d61d487f6 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class ReadinessState: """ReadinessState.""" @@ -109,6 +108,7 @@ class BaseNavigationInfo: navigation: Any | None = None timestamp: Any | None = None url: str | None = None + user_context: Any | None = None @dataclass @@ -184,6 +184,7 @@ class CreateResult: """CreateResult.""" context: Any | None = None + user_context: Any | None = None @dataclass @@ -290,6 +291,15 @@ class ReloadParameters: wait: Any | None = None +@dataclass +class SetBypassCSPParameters: + """SetBypassCSPParameters.""" + + bypass: Any | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) + + @dataclass class SetViewportParameters: """SetViewportParameters.""" @@ -323,6 +333,7 @@ class HistoryUpdatedParameters: context: Any | None = None timestamp: Any | None = None url: str | None = None + user_context: Any | None = None @dataclass @@ -332,6 +343,7 @@ class UserPromptClosedParameters: context: Any | None = None accepted: bool | None = None type: Any | None = None + user_context: Any | None = None user_text: str | None = None @@ -343,6 +355,7 @@ class UserPromptOpenedParameters: handler: Any | None = None message: str | None = None type: Any | None = None + user_context: Any | None = None default_value: str | None = None @@ -656,6 +669,26 @@ def reload(self, context: Any | None = None, ignore_cache: bool | None = None, w result = self._conn.execute(cmd) return result + def set_bypass_csp( + self, + bypass: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute browsingContext.setBypassCSP.""" + if bypass is None: + raise TypeError("set_bypass_csp() missing required argument: 'bypass'") + + params = { + "bypass": bypass, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.setBypassCSP", params) + result = self._conn.execute(cmd) + return result + def set_viewport( self, context: str | None = None, diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 44babb6777616..67f95b933aa16 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - class ForcedColorsModeTheme: """ForcedColorsModeTheme.""" @@ -131,15 +130,6 @@ class SetUserAgentOverrideParameters: user_contexts: list[Any] = field(default_factory=list) -@dataclass -class SetViewportMetaOverrideParameters: - """SetViewportMetaOverrideParameters.""" - - viewport_meta: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - @dataclass class SetScriptingEnabledParameters: """SetScriptingEnabledParameters.""" @@ -253,26 +243,6 @@ def set_screen_settings_override( result = self._conn.execute(cmd) return result - def set_viewport_meta_override( - self, - viewport_meta: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setViewportMetaOverride.""" - if viewport_meta is None: - raise TypeError("set_viewport_meta_override() missing required argument: 'viewport_meta'") - - params = { - "viewportMeta": viewport_meta, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setViewportMetaOverride", params) - result = self._conn.execute(cmd) - return result - def set_scrollbar_type_override( self, scrollbar_type: Any | None = None, diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 346ead5e49841..d2508fea5ca64 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class PointerType: """PointerType.""" diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index ca24d6e78d532..04c5a53c04510 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -10,8 +10,7 @@ from dataclasses import dataclass from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class Level: """Level.""" diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 343b6d960c017..e13befc582f08 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class SameSite: """SameSite.""" @@ -72,6 +71,7 @@ class BaseParameters: redirect_count: Any | None = None request: Any | None = None timestamp: Any | None = None + user_context: Any | None = None intercepts: list[Any] = field(default_factory=list) diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index d6877de623d14..572f620a6d63b 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class SpecialNumber: """SpecialNumber.""" @@ -205,6 +204,7 @@ class WindowRealmInfo: type: str = field(default="window", init=False) context: Any | None = None + user_context: Any | None = None sandbox: str | None = None @@ -505,6 +505,7 @@ class Source: realm: Any | None = None context: Any | None = None + user_context: Any | None = None @dataclass @@ -790,9 +791,8 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import datetime as _datetime import math as _math - + import datetime as _datetime from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1033,9 +1033,8 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - - from selenium.webdriver.common.bidi import log as _log_mod from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod bidi_event = "log.entryAdded" diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index e04d897e25deb..a54e196aa86d9 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - class UserPromptHandlerType: """UserPromptHandlerType.""" diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 5ae8bf5aeb2d0..d922390a08699 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - @dataclass class PartitionKey: """PartitionKey.""" diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 0a28843e339f1..3520219e26c53 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - @dataclass class InstallParameters: """InstallParameters.""" From bca421b8a6d803c2cf46bf8ca7e99f0a6a22b94a Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Tue, 7 Apr 2026 12:34:54 +0100 Subject: [PATCH 27/42] ruffs updates --- py/selenium/webdriver/common/bidi/browser.py | 1 + py/selenium/webdriver/common/bidi/browsing_context.py | 3 ++- py/selenium/webdriver/common/bidi/emulation.py | 1 + py/selenium/webdriver/common/bidi/input.py | 3 ++- py/selenium/webdriver/common/bidi/log.py | 3 ++- py/selenium/webdriver/common/bidi/network.py | 3 ++- py/selenium/webdriver/common/bidi/script.py | 9 ++++++--- py/selenium/webdriver/common/bidi/session.py | 1 + py/selenium/webdriver/common/bidi/storage.py | 1 + py/selenium/webdriver/common/bidi/webextension.py | 1 + 10 files changed, 19 insertions(+), 7 deletions(-) diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 94dd0094e9173..3811a2a2e97b7 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + def transform_download_params( allowed: bool | None, destination_folder: str | None, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 50f9d61d487f6..175c511393098 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class ReadinessState: """ReadinessState.""" diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 67f95b933aa16..1c48100cc343b 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + class ForcedColorsModeTheme: """ForcedColorsModeTheme.""" diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index d2508fea5ca64..346ead5e49841 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class PointerType: """PointerType.""" diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 04c5a53c04510..ca24d6e78d532 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -10,7 +10,8 @@ from dataclasses import dataclass from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager + class Level: """Level.""" diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index e13befc582f08..d6875e14fa58a 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class SameSite: """SameSite.""" diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 572f620a6d63b..ecc2a75e0922d 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class SpecialNumber: """SpecialNumber.""" @@ -791,8 +792,9 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import math as _math import datetime as _datetime + import math as _math + from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1033,8 +1035,9 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + from selenium.webdriver.common.bidi.session import Session as _Session bidi_event = "log.entryAdded" diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index a54e196aa86d9..e04d897e25deb 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + class UserPromptHandlerType: """UserPromptHandlerType.""" diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index d922390a08699..5ae8bf5aeb2d0 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + @dataclass class PartitionKey: """PartitionKey.""" diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 3520219e26c53..0a28843e339f1 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + @dataclass class InstallParameters: """InstallParameters.""" From 5322480f62aa84b208b941b23c5ef26c36668c59 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 8 Apr 2026 11:15:10 +0100 Subject: [PATCH 28/42] fix test --- py/test/selenium/webdriver/common/bidi_browsing_context_tests.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py index 86e3d11af0341..51f745d1c98af 100644 --- a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py @@ -436,7 +436,6 @@ def test_set_viewport_back_to_default(driver, pages): # Allow some tolerance since some window managers might not put it to the exact value assert abs(viewport_size[0] - default_viewport_size[0]) <= 5 assert abs(viewport_size[1] - default_viewport_size[1]) <= 5 - assert device_pixel_ratio == default_device_pixel_ratio finally: driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) From 93dffd9cf1f8edb46c71e4f0eb3099e76d591a51 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 8 Apr 2026 11:46:07 +0100 Subject: [PATCH 29/42] add assert and add failure to chrome --- .../common/bidi_browsing_context_tests.py | 248 +++++++++++++----- 1 file changed, 186 insertions(+), 62 deletions(-) diff --git a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py index 51f745d1c98af..8038ee826aa74 100644 --- a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py @@ -60,7 +60,9 @@ def test_create_window(driver): def test_create_window_with_reference_context(driver): """Test creating a window with a reference context.""" reference_context = driver.current_window_handle - context_id = driver.browsing_context.create(type=WindowTypes.WINDOW, reference_context=reference_context) + context_id = driver.browsing_context.create( + type=WindowTypes.WINDOW, reference_context=reference_context + ) assert context_id is not None # Clean up @@ -79,7 +81,9 @@ def test_create_tab(driver): def test_create_tab_with_reference_context(driver): """Test creating a tab with a reference context.""" reference_context = driver.current_window_handle - context_id = driver.browsing_context.create(type=WindowTypes.TAB, reference_context=reference_context) + context_id = driver.browsing_context.create( + type=WindowTypes.TAB, reference_context=reference_context + ) assert context_id is not None # Clean up @@ -124,7 +128,9 @@ def test_navigate_to_url_with_readiness_state(driver, pages): context_id = driver.browsing_context.create(type=WindowTypes.TAB) url = pages.url("bidi/logEntryAdded.html") - result = driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + result = driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) assert context_id is not None assert "/bidi/logEntryAdded.html" in result["url"] @@ -138,7 +144,9 @@ def test_get_tree_with_child(driver, pages): reference_context = driver.current_window_handle url = pages.url("iframes.html") - driver.browsing_context.navigate(context=reference_context, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=reference_context, url=url, wait=ReadinessState.COMPLETE + ) context_info_list = driver.browsing_context.get_tree(root=reference_context) @@ -154,9 +162,13 @@ def test_get_tree_with_depth(driver, pages): reference_context = driver.current_window_handle url = pages.url("iframes.html") - driver.browsing_context.navigate(context=reference_context, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=reference_context, url=url, wait=ReadinessState.COMPLETE + ) - context_info_list = driver.browsing_context.get_tree(root=reference_context, max_depth=0) + context_info_list = driver.browsing_context.get_tree( + root=reference_context, max_depth=0 + ) assert len(context_info_list) == 1 info = context_info_list[0] @@ -227,7 +239,9 @@ def test_reload_browsing_context(driver, pages): context_id = driver.browsing_context.create(type=WindowTypes.TAB) url = pages.url("bidi/logEntryAdded.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) reload_info = driver.browsing_context.reload(context=context_id) @@ -242,9 +256,13 @@ def test_reload_with_readiness_state(driver, pages): context_id = driver.browsing_context.create(type=WindowTypes.TAB) url = pages.url("bidi/logEntryAdded.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) - reload_info = driver.browsing_context.reload(context=context_id, wait=ReadinessState.COMPLETE) + reload_info = driver.browsing_context.reload( + context=context_id, wait=ReadinessState.COMPLETE + ) assert reload_info["navigation"] is not None assert "/bidi/logEntryAdded.html" in reload_info["url"] @@ -341,7 +359,9 @@ def test_capture_screenshot_with_parameters(driver, pages): clip = {"type": "box", "x": rect["x"], "y": rect["y"], "width": 5, "height": 5} - screenshot = driver.browsing_context.capture_screenshot(context=context_id, origin="document", clip=clip) + screenshot = driver.browsing_context.capture_screenshot( + context=context_id, origin="document", clip=clip + ) assert len(screenshot) > 0 @@ -352,14 +372,20 @@ def test_set_viewport(driver, pages): driver.get(pages.url("formPage.html")) try: - driver.browsing_context.set_viewport(context=context_id, viewport={"width": 251, "height": 301}) + driver.browsing_context.set_viewport( + context=context_id, viewport={"width": 251, "height": 301} + ) - viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") + viewport_size = driver.execute_script( + "return [window.innerWidth, window.innerHeight];" + ) assert viewport_size[0] == 251 assert viewport_size[1] == 301 finally: - driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) + driver.browsing_context.set_viewport( + context=context_id, viewport=None, device_pixel_ratio=None + ) def test_set_viewport_with_device_pixel_ratio(driver, pages): @@ -374,7 +400,9 @@ def test_set_viewport_with_device_pixel_ratio(driver, pages): device_pixel_ratio=5, ) - viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") + viewport_size = driver.execute_script( + "return [window.innerWidth, window.innerHeight];" + ) assert viewport_size[0] == 252 assert viewport_size[1] == 302 @@ -383,7 +411,9 @@ def test_set_viewport_with_device_pixel_ratio(driver, pages): assert device_pixel_ratio == 5 finally: - driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) + driver.browsing_context.set_viewport( + context=context_id, viewport=None, device_pixel_ratio=None + ) def test_set_viewport_with_no_args_doesnt_change_values(driver, pages): @@ -400,7 +430,9 @@ def test_set_viewport_with_no_args_doesnt_change_values(driver, pages): driver.browsing_context.set_viewport(context=context_id) - viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") + viewport_size = driver.execute_script( + "return [window.innerWidth, window.innerHeight];" + ) assert viewport_size[0] == 253 assert viewport_size[1] == 303 @@ -409,7 +441,9 @@ def test_set_viewport_with_no_args_doesnt_change_values(driver, pages): assert device_pixel_ratio == 6 finally: - driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) + driver.browsing_context.set_viewport( + context=context_id, viewport=None, device_pixel_ratio=None + ) @pytest.mark.xfail_chrome @@ -418,7 +452,9 @@ def test_set_viewport_back_to_default(driver, pages): context_id = driver.current_window_handle driver.get(pages.url("formPage.html")) - default_viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") + default_viewport_size = driver.execute_script( + "return [window.innerWidth, window.innerHeight];" + ) default_device_pixel_ratio = driver.execute_script("return window.devicePixelRatio") try: @@ -428,16 +464,23 @@ def test_set_viewport_back_to_default(driver, pages): device_pixel_ratio=10, ) - driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) + driver.browsing_context.set_viewport( + context=context_id, viewport=None, device_pixel_ratio=None + ) - viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") + viewport_size = driver.execute_script( + "return [window.innerWidth, window.innerHeight];" + ) device_pixel_ratio = driver.execute_script("return window.devicePixelRatio") # Allow some tolerance since some window managers might not put it to the exact value assert abs(viewport_size[0] - default_viewport_size[0]) <= 5 assert abs(viewport_size[1] - default_viewport_size[1]) <= 5 + assert device_pixel_ratio == default_device_pixel_ratio finally: - driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) + driver.browsing_context.set_viewport( + context=context_id, viewport=None, device_pixel_ratio=None + ) def test_print_page(driver, pages): @@ -456,7 +499,9 @@ def test_print_page(driver, pages): def test_navigate_back_in_browser_history(driver, pages): """Test navigating back in the browser history.""" context_id = driver.current_window_handle - driver.browsing_context.navigate(context=context_id, url=pages.url("formPage.html"), wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=pages.url("formPage.html"), wait=ReadinessState.COMPLETE + ) # Navigate to another page by submitting a form driver.find_element(By.ID, "imageButton").submit() @@ -469,7 +514,9 @@ def test_navigate_back_in_browser_history(driver, pages): def test_navigate_forward_in_browser_history(driver, pages): """Test navigating forward in the browser history.""" context_id = driver.current_window_handle - driver.browsing_context.navigate(context=context_id, url=pages.url("formPage.html"), wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=pages.url("formPage.html"), wait=ReadinessState.COMPLETE + ) # Navigate to another page by submitting a form driver.find_element(By.ID, "imageButton").submit() @@ -491,7 +538,9 @@ def test_locate_nodes(driver, pages): driver.get(pages.url("xhtmlTest.html")) - elements = driver.browsing_context.locate_nodes(context=context_id, locator={"type": "css", "value": "div"}) + elements = driver.browsing_context.locate_nodes( + context=context_id, locator={"type": "css", "value": "div"} + ) assert len(elements) > 0 @@ -611,7 +660,9 @@ def test_add_event_handler_context_created(driver): def on_context_created(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("context_created", on_context_created) + callback_id = driver.browsing_context.add_event_handler( + "context_created", on_context_created + ) assert callback_id is not None # Create a new context to trigger the event @@ -619,7 +670,10 @@ def on_context_created(info): # Verify the event was received (might be > 1 since default context is also included) assert len(events_received) >= 1 - assert events_received[0].context == context_id or events_received[1].context == context_id + assert ( + events_received[0].context == context_id + or events_received[1].context == context_id + ) driver.browsing_context.close(context_id) driver.browsing_context.remove_event_handler("context_created", callback_id) @@ -632,7 +686,9 @@ def test_add_event_handler_context_destroyed(driver): def on_context_destroyed(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("context_destroyed", on_context_destroyed) + callback_id = driver.browsing_context.add_event_handler( + "context_destroyed", on_context_destroyed + ) assert callback_id is not None # Create and then close a context to trigger the event @@ -652,13 +708,17 @@ def test_add_event_handler_navigation_committed(driver, pages): def on_navigation_committed(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("navigation_committed", on_navigation_committed) + callback_id = driver.browsing_context.add_event_handler( + "navigation_committed", on_navigation_committed + ) assert callback_id is not None # Navigate to trigger the event context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) assert len(events_received) >= 1 assert any(url in event.url for event in events_received) @@ -673,13 +733,17 @@ def test_add_event_handler_dom_content_loaded(driver, pages): def on_dom_content_loaded(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("dom_content_loaded", on_dom_content_loaded) + callback_id = driver.browsing_context.add_event_handler( + "dom_content_loaded", on_dom_content_loaded + ) assert callback_id is not None # Navigate to trigger the event context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) assert len(events_received) == 1 assert any("simpleTest" in event.url for event in events_received) @@ -699,7 +763,9 @@ def on_load(info): context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) assert len(events_received) == 1 assert any("simpleTest" in event.url for event in events_received) @@ -714,12 +780,16 @@ def test_add_event_handler_navigation_started(driver, pages): def on_navigation_started(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("navigation_started", on_navigation_started) + callback_id = driver.browsing_context.add_event_handler( + "navigation_started", on_navigation_started + ) assert callback_id is not None context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) assert len(events_received) == 1 assert any("simpleTest" in event.url for event in events_received) @@ -734,17 +804,23 @@ def test_add_event_handler_fragment_navigated(driver, pages): def on_fragment_navigated(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("fragment_navigated", on_fragment_navigated) + callback_id = driver.browsing_context.add_event_handler( + "fragment_navigated", on_fragment_navigated + ) assert callback_id is not None # First navigate to a page context_id = driver.current_window_handle url = pages.url("linked_image.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) # Then navigate to the same page with a fragment to trigger the event fragment_url = url + "#link" - driver.browsing_context.navigate(context=context_id, url=fragment_url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=fragment_url, wait=ReadinessState.COMPLETE + ) assert len(events_received) == 1 assert any("link" in event.url for event in events_received) @@ -760,13 +836,17 @@ def test_add_event_handler_navigation_failed(driver): def on_navigation_failed(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("navigation_failed", on_navigation_failed) + callback_id = driver.browsing_context.add_event_handler( + "navigation_failed", on_navigation_failed + ) assert callback_id is not None # Navigate to an invalid URL to trigger the event context_id = driver.current_window_handle try: - driver.browsing_context.navigate(context=context_id, url="http://invalid-domain-that-does-not-exist.test/") + driver.browsing_context.navigate( + context=context_id, url="http://invalid-domain-that-does-not-exist.test/" + ) except Exception: # Expect an exception due to navigation failure pass @@ -785,7 +865,9 @@ def test_add_event_handler_user_prompt_opened(driver, pages): def on_user_prompt_opened(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("user_prompt_opened", on_user_prompt_opened) + callback_id = driver.browsing_context.add_event_handler( + "user_prompt_opened", on_user_prompt_opened + ) assert callback_id is not None # Create an alert to trigger the event @@ -810,7 +892,9 @@ def test_add_event_handler_user_prompt_closed(driver, pages): def on_user_prompt_closed(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("user_prompt_closed", on_user_prompt_closed) + callback_id = driver.browsing_context.add_event_handler( + "user_prompt_closed", on_user_prompt_closed + ) assert callback_id is not None create_prompt_page(driver, pages) @@ -835,12 +919,16 @@ def test_add_event_handler_history_updated(driver, pages): def on_history_updated(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("history_updated", on_history_updated) + callback_id = driver.browsing_context.add_event_handler( + "history_updated", on_history_updated + ) assert callback_id is not None context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) # Use history.pushState to trigger history updated event driver.script.execute("() => { history.pushState({}, '', '/new-path'); }") @@ -860,13 +948,17 @@ def test_add_event_handler_download_will_begin(driver, pages): def on_download_will_begin(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("download_will_begin", on_download_will_begin) + callback_id = driver.browsing_context.add_event_handler( + "download_will_begin", on_download_will_begin + ) assert callback_id is not None # click on a download link to trigger the event context_id = driver.current_window_handle url = pages.url("downloads/download.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) download_xpath_file_1_txt = '//*[@id="file-1"]' driver.find_element(By.XPATH, download_xpath_file_1_txt).click() @@ -886,12 +978,16 @@ def test_add_event_handler_download_end(driver, pages): def on_download_end(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("download_end", on_download_end) + callback_id = driver.browsing_context.add_event_handler( + "download_end", on_download_end + ) assert callback_id is not None context_id = driver.current_window_handle url = pages.url("downloads/download.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) driver.find_element(By.ID, "file-1").click() @@ -909,12 +1005,14 @@ def on_download_end(info): # we assert that atleast "file_1" is present in the downloaded file since multiple downloads # will have numbered suffix like file_1 (1) assert any( - "downloads/file_1.txt" in ev.download_params.url and "file_1" in ev.download_params.filepath + "downloads/file_1.txt" in ev.download_params.url + and "file_1" in ev.download_params.filepath for ev in events_received ) assert any( - "downloads/file_2.jpg" in ev.download_params.url and "file_2" in ev.download_params.filepath + "downloads/file_2.jpg" in ev.download_params.url + and "file_2" in ev.download_params.filepath for ev in events_received ) @@ -953,7 +1051,9 @@ def test_remove_event_handler(driver): def on_context_created(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("context_created", on_context_created) + callback_id = driver.browsing_context.add_event_handler( + "context_created", on_context_created + ) # Create a context to trigger the event context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB) @@ -985,8 +1085,12 @@ def on_context_created_2(info): events_received_2.append(info) # Add multiple event handlers for the same event - callback_id_1 = driver.browsing_context.add_event_handler("context_created", on_context_created_1) - callback_id_2 = driver.browsing_context.add_event_handler("context_created", on_context_created_2) + callback_id_1 = driver.browsing_context.add_event_handler( + "context_created", on_context_created_1 + ) + callback_id_2 = driver.browsing_context.add_event_handler( + "context_created", on_context_created_2 + ) # Create a context to trigger both handlers context_id = driver.browsing_context.create(type=WindowTypes.TAB) @@ -1015,8 +1119,12 @@ def on_context_created_2(info): events_received_2.append(info) # Add multiple event handlers - callback_id_1 = driver.browsing_context.add_event_handler("context_created", on_context_created_1) - callback_id_2 = driver.browsing_context.add_event_handler("context_created", on_context_created_2) + callback_id_1 = driver.browsing_context.add_event_handler( + "context_created", on_context_created_1 + ) + callback_id_2 = driver.browsing_context.add_event_handler( + "context_created", on_context_created_2 + ) # Create a context to trigger both handlers context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB) @@ -1098,7 +1206,9 @@ def callback(info): def register_handler(self, thread_id): try: callback = self.make_callback() - callback_id = self.driver.browsing_context.add_event_handler("context_created", callback) + callback_id = self.driver.browsing_context.add_event_handler( + "context_created", callback + ) with self.data_lock: self.callback_ids.append(callback_id) if len(self.callback_ids) == 5: @@ -1106,12 +1216,16 @@ def register_handler(self, thread_id): return callback_id except Exception as e: with self.data_lock: - self.thread_errors.append(f"Thread {thread_id}: Registration failed: {e}") + self.thread_errors.append( + f"Thread {thread_id}: Registration failed: {e}" + ) return None def remove_handler(self, callback_id, thread_id): try: - self.driver.browsing_context.remove_event_handler("context_created", callback_id) + self.driver.browsing_context.remove_event_handler( + "context_created", callback_id + ) except Exception as e: with self.data_lock: self.thread_errors.append(f"Thread {thread_id}: Removal failed: {e}") @@ -1121,13 +1235,19 @@ def test_concurrent_event_handler_registration(driver): helper = _EventHandlerTestHelper(driver) with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(helper.register_handler, f"reg-{i}") for i in range(5)] + futures = [ + executor.submit(helper.register_handler, f"reg-{i}") for i in range(5) + ] for future in futures: future.result(timeout=15) helper.registration_complete.wait(timeout=5) - assert len(helper.callback_ids) == 5, f"Expected 5 handlers, got {len(helper.callback_ids)}" - assert not helper.thread_errors, "Errors during registration: \n" + "\n".join(helper.thread_errors) + assert ( + len(helper.callback_ids) == 5 + ), f"Expected 5 handlers, got {len(helper.callback_ids)}" + assert not helper.thread_errors, "Errors during registration: \n" + "\n".join( + helper.thread_errors + ) def test_event_callback_data_consistency(driver): @@ -1145,7 +1265,9 @@ def test_event_callback_data_consistency(driver): driver.browsing_context.close(ctx) with helper.data_lock: - assert not helper.consistency_errors, "Consistency errors: " + str(helper.consistency_errors) + assert not helper.consistency_errors, "Consistency errors: " + str( + helper.consistency_errors + ) assert len(helper.events_received) > 0, "No events received" assert len(helper.events_received) == sum(helper.context_counts.values()) assert len(helper.events_received) == sum(helper.event_type_counts.values()) @@ -1166,7 +1288,9 @@ def test_concurrent_event_handler_removal(driver): for future in futures: future.result(timeout=15) - assert not helper.thread_errors, "Errors during removal: \n" + "\n".join(helper.thread_errors) + assert not helper.thread_errors, "Errors during removal: \n" + "\n".join( + helper.thread_errors + ) def test_no_event_after_handler_removal(driver): From 6638195fd77f981589b273b1659a4cdc572ac4fd Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 8 Apr 2026 17:55:35 +0100 Subject: [PATCH 30/42] do ruff format --- py/generate_bidi.py | 99 ++----- py/private/bidi_enhancements_manifest.py | 16 +- py/private/cdp.py | 4 +- py/selenium/common/exceptions.py | 12 +- .../webdriver/common/bidi/_event_manager.py | 2 +- py/selenium/webdriver/common/bidi/browser.py | 32 +-- .../webdriver/common/bidi/browsing_context.py | 40 +-- py/selenium/webdriver/common/bidi/cdp.py | 4 +- py/selenium/webdriver/common/bidi/common.py | 4 +- .../webdriver/common/bidi/emulation.py | 10 +- py/selenium/webdriver/common/bidi/input.py | 8 +- py/selenium/webdriver/common/bidi/log.py | 7 +- py/selenium/webdriver/common/bidi/network.py | 32 +-- .../webdriver/common/bidi/permissions.py | 3 +- py/selenium/webdriver/common/bidi/script.py | 46 +++- py/selenium/webdriver/common/bidi/session.py | 8 +- py/selenium/webdriver/common/bidi/storage.py | 15 +- .../webdriver/common/bidi/webextension.py | 15 +- py/selenium/webdriver/common/proxy.py | 34 +-- py/selenium/webdriver/remote/webdriver.py | 153 +++-------- .../webdriver/remote/websocket_connection.py | 12 +- .../common/bidi_browsing_context_tests.py | 247 +++++------------- 22 files changed, 268 insertions(+), 535 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 745c0f00ed890..194d94ba12d04 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -68,9 +68,7 @@ def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]: return {} try: - spec = importlib.util.spec_from_file_location( - "bidi_enhancements", manifest_file - ) + spec = importlib.util.spec_from_file_location("bidi_enhancements", manifest_file) if spec is None or spec.loader is None: logger.warning(f"Could not load manifest: {manifest_path}") return {} @@ -169,9 +167,7 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: if param_strs: # Check if full signature would exceed line length limit (120 chars) - single_line_signature = ( - f" def {method_name}(self, {', '.join(param_strs)}):" - ) + single_line_signature = f" def {method_name}(self, {', '.join(param_strs)}):" if len(single_line_signature) > 120: # Format parameters on multiple lines body = f" def {method_name}(\n" @@ -198,9 +194,7 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: body += f" if {snake_param} is None:\n" msg = f"{method_snake}() missing required argument:" error_message = f"{msg} {snake_param!r}" - body += ( - f" raise TypeError({error_message!r})\n" - ) + body += f" raise TypeError({error_message!r})\n" body += "\n" # Add validation if specified in enhancements (for additional business logic validation) @@ -220,9 +214,7 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: transform_func = transform_spec.get("func") result_param = transform_spec.get("result_param", "params") input_params = [ - transform_spec.get(k) - for k in ["allowed", "destination_folder"] - if transform_spec.get(k) + transform_spec.get(k) for k in ["allowed", "destination_folder"] if transform_spec.get(k) ] if transform_func and result_param: @@ -245,9 +237,7 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: snake_param = self._camel_to_snake(param_name) if preprocess_type == "check_serialize_method": body += f" if {snake_param} and hasattr({snake_param}, 'to_bidi_dict'):\n" - body += ( - f" {snake_param} = {snake_param}.to_bidi_dict()\n" - ) + body += f" {snake_param} = {snake_param}.to_bidi_dict()\n" body += "\n" # Build params dict @@ -538,11 +528,7 @@ def to_python_dataclass(self) -> str: # Extract the type name from params_type (e.g., "browsingContext.Info" -> "Info") # The params_type comes from the CDDL and includes module prefix - type_name = ( - self.params_type.split(".")[-1] - if "." in self.params_type - else self.params_type - ) + type_name = self.params_type.split(".")[-1] if "." in self.params_type else self.params_type # Special case: if the type is BaseNavigationInfo, use BaseNavigationInfo directly # (NavigationInfo will be created as an alias to it) @@ -604,9 +590,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: stdlib_imports.append("from typing import Any") if needs_command_builder: - local_imports.append( - "from selenium.webdriver.common.bidi.common import command_builder" - ) + local_imports.append("from selenium.webdriver.common.bidi.common import command_builder") if self.events: local_imports.append( "from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager" @@ -626,9 +610,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: method_enhancements = enhancements.get(method_name_snake, {}) if "validate" in method_enhancements: helper_funcs_to_add.add(("validate", method_enhancements["validate"])) - if "transform" in method_enhancements and isinstance( - method_enhancements["transform"], dict - ): + if "transform" in method_enhancements and isinstance(method_enhancements["transform"], dict): transform_spec = method_enhancements["transform"] if "func" in transform_spec: helper_funcs_to_add.add(("transform", transform_spec["func"])) @@ -636,10 +618,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Generate helper functions if needed if helper_funcs_to_add: for func_type, func_name in sorted(helper_funcs_to_add): - if ( - func_type == "validate" - and func_name == "validate_download_behavior" - ): + if func_type == "validate" and func_name == "validate_download_behavior": code += """def validate_download_behavior( allowed: bool | None, destination_folder: str | None, @@ -662,10 +641,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ - elif ( - func_type == "transform" - and func_name == "transform_download_params" - ): + elif func_type == "transform" and func_name == "transform_download_params": code += """def transform_download_params( allowed: bool | None, destination_folder: str | None, @@ -750,9 +726,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code += f' "{event_name}": "{event_def.method}",\n' # Extra events not in the CDDL spec (e.g. Chromium-specific events) for extra_evt in enhancements.get("extra_events", []): - code += ( - f' "{extra_evt["event_key"]}": "{extra_evt["bidi_event"]}",\n' - ) + code += f' "{extra_evt["event_key"]}": "{extra_evt["bidi_event"]}",\n' code += "}\n\n" # Add custom method function definitions before the class (for browsingContext) @@ -807,9 +781,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Add EVENT_CONFIGS dict if there are events if self.events: - code += ( - " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined - ) + code += " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined if self.name == "script": code += " def __init__(self, conn, driver=None) -> None:\n" @@ -928,8 +900,7 @@ def clear_event_handlers(self) -> None: # Build the entry line and check if it exceeds 120 chars single_line = ( - f' "{event_name}": ' - f'EventConfig("{event_name}", "{event_def.method}", {event_class}),' + f' "{event_name}": EventConfig("{event_name}", "{event_def.method}", {event_class}),' ) if len(single_line) > 120: @@ -1105,9 +1076,7 @@ def _extract_types(self) -> None: description=f"{type_name}", ) self.modules[module_name].enums.append(enum_def) - logger.debug( - f"Found enum: {def_name} with {len(values)} values" - ) + logger.debug(f"Found enum: {def_name} with {len(values)} values") else: # Extract fields from type definition fields = self._extract_type_fields(def_content) @@ -1120,9 +1089,7 @@ def _extract_types(self) -> None: description=f"{type_name}", ) self.modules[module_name].types.append(type_def) - logger.debug( - f"Found type: {def_name} with {len(fields)} fields" - ) + logger.debug(f"Found type: {def_name} with {len(fields)} fields") def _is_enum_definition(self, definition: str) -> bool: """Check if a definition is an enum (string union with /). @@ -1235,9 +1202,7 @@ def _extract_events(self) -> None: Event pattern: module.EventName = (method: "module.eventName", params: module.ParamType) """ # Find definitions that are in the event_names set - event_pattern = re.compile( - r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" - ) + event_pattern = re.compile(r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)") for def_name, def_content in self.definitions.items(): # Skip if not identified as an event @@ -1271,16 +1236,12 @@ def _extract_events(self) -> None: ) self.modules[module_name].events.append(event) - logger.debug( - f"Found event: {def_name} (method={method}, params={params_type})" - ) + logger.debug(f"Found event: {def_name} (method={method}, params={params_type})") def _extract_commands(self) -> None: """Extract command definitions from parsed definitions.""" # Find command definitions that follow pattern: module.Command = (method: "...", params: ...) - command_pattern = re.compile( - r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" - ) + command_pattern = re.compile(r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)") for def_name, def_content in self.definitions.items(): # Skip definitions that are events (they share the same pattern) @@ -1301,9 +1262,7 @@ def _extract_commands(self) -> None: self.modules[module_name] = CddlModule(name=module_name) # Extract parameters and required parameters - params, required_params = self._extract_parameters_and_required( - params_type - ) + params, required_params = self._extract_parameters_and_required(params_type) # Create command cmd = CddlCommand( @@ -1315,13 +1274,9 @@ def _extract_commands(self) -> None: ) self.modules[module_name].commands.append(cmd) - logger.debug( - f"Found command: {method} with params {params_type}" - ) + logger.debug(f"Found command: {method} with params {params_type}") - def _extract_parameters( - self, params_type: str, _seen: set[str] | None = None - ) -> dict[str, str]: + def _extract_parameters(self, params_type: str, _seen: set[str] | None = None) -> dict[str, str]: """Extract parameters from a parameter type definition. Handles both struct types ({...}) and top-level union types (TypeA / TypeB), @@ -1366,9 +1321,7 @@ def _extract_parameters_and_required( # For union types, collect parameters from all alternatives # but treat them as optional since the caller only needs to pass one alternative for alt_type in alternatives: - alt_params, _ = self._extract_parameters_and_required( - alt_type, _seen - ) + alt_params, _ = self._extract_parameters_and_required(alt_type, _seen) params.update(alt_params) # Note: We intentionally DON'T add to required, since these are union alternatives return params, required @@ -1470,9 +1423,7 @@ def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> Non for module_name in sorted(modules.keys()): class_name = module_name_to_class_name(module_name) filename = module_name_to_filename(module_name) - code += ( - f"from selenium.webdriver.common.bidi.{filename} import {class_name}\n" - ) + code += f"from selenium.webdriver.common.bidi.{filename} import {class_name}\n" code += "\n__all__ = [\n" for module_name in sorted(modules.keys()): @@ -1777,9 +1728,7 @@ def main( if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Generate Python WebDriver BiDi modules from CDDL specification" - ) + parser = argparse.ArgumentParser(description="Generate Python WebDriver BiDi modules from CDDL specification") parser.add_argument( "cddl_file", help="Path to CDDL specification file", diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index f8ade8b9b3ad8..06c0573db9083 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -37,7 +37,6 @@ # ============================================================================ ENHANCEMENTS: dict[str, dict[str, Any]] = { - "browser": { # Dataclass custom methods "__dataclass_methods__": { @@ -183,7 +182,6 @@ class SetClientWindowStateParameters: return self._conn.execute(cmd)''', ], }, - "browsingContext": { # Method enhancements "create": { @@ -268,7 +266,6 @@ def from_json(cls, params: dict) -> DownloadEndParams: ], # Download events are now in the CDDL spec, so no extra_events needed }, - "log": { # Make LogLevel an alias for Level so existing code using LogLevel works "aliases": {"LogLevel": "Level"}, @@ -332,7 +329,6 @@ def from_json(cls, params: dict) -> JavascriptLogEntry: "entry_added": "Entry", }, }, - "emulation": { "exclude_types": ["setNetworkConditionsParameters"], "extra_dataclasses": [ @@ -545,7 +541,6 @@ class SetNetworkConditionsParameters: return self._conn.execute(cmd)''', ], }, - "script": { "extra_methods": [ ''' def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: @@ -921,7 +916,6 @@ def from_json(self2, p): self._unsubscribe_log_entry(callback_id)''', ], }, - "network": { "exclude_types": ["disownDataParameters"], # Initialize intercepts tracking list and per-handler intercept map @@ -1124,7 +1118,6 @@ def _auth_callback(params): self._remove_intercept(intercept_id)''', ], }, - "storage": { # Exclude auto-generated dataclasses that need custom to_bidi_dict() # for JSON-over-WebSocket serialization, or custom constructors. @@ -1299,7 +1292,6 @@ def to_bidi_dict(self) -> dict: def to_dict(self) -> dict: """Backward-compatible alias for to_bidi_dict().""" return self.to_bidi_dict()''', - # StorageKeyPartitionDescriptor with camelCase serialization '''@dataclass class StorageKeyPartitionDescriptor: @@ -1380,7 +1372,6 @@ def to_dict(self) -> dict: ) return SetCookieResult(partition_key=pk) return result''', - ''' def delete_cookies(self, filter=None, partition=None): """Execute storage.deleteCookies.""" if filter and hasattr(filter, "to_bidi_dict"): @@ -1408,7 +1399,6 @@ def to_dict(self) -> dict: return result''', ], }, - "session": { # Override UserPromptHandler to add to_bidi_dict() for JSON serialization "exclude_types": ["UserPromptHandler"], @@ -1446,7 +1436,6 @@ def to_dict(self) -> dict: return self.to_bidi_dict()''', ], }, - "webExtension": { # Suppress the raw generated stubs; hand-written versions follow below "exclude_methods": ["install", "uninstall"], @@ -1527,7 +1516,6 @@ def to_dict(self) -> dict: return self._conn.execute(cmd)''', ], }, - "input": { # FileDialogInfo needs from_json for event deserialization "exclude_types": ["FileDialogInfo", "PointerMoveAction", "PointerDownAction"], @@ -1651,9 +1639,7 @@ def transform_download_params( "type": "allowed", # Convert pathlib.Path (or any path-like) to str so the BiDi # protocol always receives a plain JSON string. - "destinationFolder": ( - str(destination_folder) if destination_folder is not None else None - ), + "destinationFolder": (str(destination_folder) if destination_folder is not None else None), } elif allowed is False: return {"type": "denied"} diff --git a/py/private/cdp.py b/py/private/cdp.py index ba4a73298ee0a..bac00765f43ca 100644 --- a/py/private/cdp.py +++ b/py/private/cdp.py @@ -60,9 +60,7 @@ def import_devtools(ver): # because cdp has been updated but selenium python has not been released yet. devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) - available_versions = tuple( - x for x in versions if x == "latest" or (x.startswith("v") and x[1:].isdigit()) - ) + available_versions = tuple(x for x in versions if x == "latest" or (x.startswith("v") and x[1:].isdigit())) numeric_versions = tuple(x[1:] for x in available_versions if x.startswith("v")) if not numeric_versions: raise diff --git a/py/selenium/common/exceptions.py b/py/selenium/common/exceptions.py index 7ec809eb20b18..92526c3a701be 100644 --- a/py/selenium/common/exceptions.py +++ b/py/selenium/common/exceptions.py @@ -122,9 +122,7 @@ def __init__( screen: str | None = None, stacktrace: Sequence[str] | None = None, ) -> None: - with_support = ( - f"{msg}; {SUPPORT_MSG} {ERROR_URL}#staleelementreferenceexception" - ) + with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#staleelementreferenceexception" super().__init__(with_support, screen, stacktrace) @@ -191,9 +189,7 @@ def __init__( screen: str | None = None, stacktrace: Sequence[str] | None = None, ) -> None: - with_support = ( - f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementnotinteractableexception" - ) + with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementnotinteractableexception" super().__init__(with_support, screen, stacktrace) @@ -279,9 +275,7 @@ def __init__( screen: str | None = None, stacktrace: Sequence[str] | None = None, ) -> None: - with_support = ( - f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementclickinterceptedexception" - ) + with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementclickinterceptedexception" super().__init__(with_support, screen, stacktrace) diff --git a/py/selenium/webdriver/common/bidi/_event_manager.py b/py/selenium/webdriver/common/bidi/_event_manager.py index 3fb3a6a1ceb6b..1dcc8288ce683 100644 --- a/py/selenium/webdriver/common/bidi/_event_manager.py +++ b/py/selenium/webdriver/common/bidi/_event_manager.py @@ -177,4 +177,4 @@ def clear_event_handlers(self) -> None: session.unsubscribe(subscriptions=[sub_id]) else: session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() \ No newline at end of file + self.subscriptions.clear() diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 3811a2a2e97b7..6310f2e18c2ce 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -101,7 +101,6 @@ def get_y(self): return self.y - @dataclass class UserContextInfo: """UserContextInfo.""" @@ -181,6 +180,7 @@ class ClientWindowNamedState: MINIMIZED = "minimized" NORMAL = "normal" + @dataclass class SetClientWindowStateParameters: """SetClientWindowStateParameters. @@ -194,6 +194,7 @@ class SetClientWindowStateParameters: client_window: Any | None = None state: Any | None = None + class Browser: """WebDriver BiDi browser module.""" @@ -202,8 +203,7 @@ def __init__(self, conn) -> None: def close(self): """Execute browser.close.""" - params = { - } + params = {} params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("browser.close", params) result = self._conn.execute(cmd) @@ -216,10 +216,10 @@ def create_user_context( unhandled_prompt_behavior: Any | None = None, ): """Execute browser.createUserContext.""" - if proxy and hasattr(proxy, 'to_bidi_dict'): + if proxy and hasattr(proxy, "to_bidi_dict"): proxy = proxy.to_bidi_dict() - if unhandled_prompt_behavior and hasattr(unhandled_prompt_behavior, 'to_bidi_dict'): + if unhandled_prompt_behavior and hasattr(unhandled_prompt_behavior, "to_bidi_dict"): unhandled_prompt_behavior = unhandled_prompt_behavior.to_bidi_dict() params = { @@ -237,8 +237,7 @@ def create_user_context( def get_client_windows(self): """Execute browser.getClientWindows.""" - params = { - } + params = {} params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("browser.getClientWindows", params) result = self._conn.execute(cmd) @@ -252,7 +251,7 @@ def get_client_windows(self): state=item.get("state"), width=item.get("width"), x=item.get("x"), - y=item.get("y") + y=item.get("y"), ) for item in items if isinstance(item, dict) @@ -261,18 +260,13 @@ def get_client_windows(self): def get_user_contexts(self): """Execute browser.getUserContexts.""" - params = { - } + params = {} params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("browser.getUserContexts", params) result = self._conn.execute(cmd) if result and "userContexts" in result: items = result.get("userContexts", []) - return [ - item.get("userContext") - for item in items - if isinstance(item, dict) - ] + return [item.get("userContext") for item in items if isinstance(item, dict)] return [] def remove_user_context(self, user_context: Any | None = None): @@ -320,6 +314,7 @@ def set_download_behavior( params["userContexts"] = user_contexts cmd = command_builder("browser.setDownloadBehavior", params) return self._conn.execute(cmd) + def set_client_window_state( self, client_window: Any | None = None, @@ -344,12 +339,9 @@ def set_client_window_state( # Serialize ClientWindowRectState if needed state_param = state - if hasattr(state, '__dataclass_fields__'): + if hasattr(state, "__dataclass_fields__"): # It's a dataclass, convert to dict - state_param = { - k: v for k, v in state.__dict__.items() - if v is not None - } + state_param = {k: v for k, v in state.__dict__.items() if v is not None} params = { "clientWindow": client_window, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 175c511393098..59a9813e58124 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -366,12 +366,14 @@ class DownloadWillBeginParams: suggested_filename: str | None = None + @dataclass class DownloadCanceledParams: """DownloadCanceledParams.""" status: Any | None = None + @dataclass class DownloadParams: """DownloadParams - fields shared by all download end event variants.""" @@ -383,6 +385,7 @@ class DownloadParams: url: str | None = None filepath: str | None = None + @dataclass class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" @@ -402,6 +405,7 @@ def from_json(cls, params: dict) -> DownloadEndParams: ) return cls(download_params=dp) + # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "context_created": "browsingContext.contextCreated", @@ -420,6 +424,7 @@ def from_json(cls, params: dict) -> DownloadEndParams: "user_prompt_opened": "browsingContext.userPromptOpened", } + def _deserialize_info_list(items: list) -> list | None: """Recursively deserialize a list of dicts to Info objects. @@ -452,12 +457,11 @@ def _deserialize_info_list(items: list) -> list | None: return result if result else None - - class BrowsingContext: """WebDriver BiDi browsingContext module.""" EVENT_CONFIGS: dict[str, EventConfig] = {} + def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) @@ -558,7 +562,7 @@ def get_tree(self, max_depth: Any | None = None, root: Any | None = None): original_opener=item.get("originalOpener"), url=item.get("url"), user_context=item.get("userContext"), - parent=item.get("parent") + parent=item.get("parent"), ) for item in items if isinstance(item, dict) @@ -725,7 +729,6 @@ def traverse_history(self, context: Any | None = None, delta: Any | None = None) result = self._conn.execute(cmd) return result - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: """Add an event handler. @@ -752,48 +755,49 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() + # Event Info Type Aliases # Event: browsingContext.contextCreated -ContextCreated = globals().get('Info', dict) # Fallback to dict if type not defined +ContextCreated = globals().get("Info", dict) # Fallback to dict if type not defined # Event: browsingContext.contextDestroyed -ContextDestroyed = globals().get('Info', dict) # Fallback to dict if type not defined +ContextDestroyed = globals().get("Info", dict) # Fallback to dict if type not defined # Event: browsingContext.navigationStarted -NavigationStarted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +NavigationStarted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.fragmentNavigated -FragmentNavigated = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +FragmentNavigated = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.historyUpdated -HistoryUpdated = globals().get('HistoryUpdatedParameters', dict) # Fallback to dict if type not defined +HistoryUpdated = globals().get("HistoryUpdatedParameters", dict) # Fallback to dict if type not defined # Event: browsingContext.domContentLoaded -DomContentLoaded = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +DomContentLoaded = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.load -Load = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +Load = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.downloadWillBegin -DownloadWillBegin = globals().get('DownloadWillBeginParams', dict) # Fallback to dict if type not defined +DownloadWillBegin = globals().get("DownloadWillBeginParams", dict) # Fallback to dict if type not defined # Event: browsingContext.downloadEnd -DownloadEnd = globals().get('DownloadEndParams', dict) # Fallback to dict if type not defined +DownloadEnd = globals().get("DownloadEndParams", dict) # Fallback to dict if type not defined # Event: browsingContext.navigationAborted -NavigationAborted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +NavigationAborted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.navigationCommitted -NavigationCommitted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +NavigationCommitted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.navigationFailed -NavigationFailed = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +NavigationFailed = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.userPromptClosed -UserPromptClosed = globals().get('UserPromptClosedParameters', dict) # Fallback to dict if type not defined +UserPromptClosed = globals().get("UserPromptClosedParameters", dict) # Fallback to dict if type not defined # Event: browsingContext.userPromptOpened -UserPromptOpened = globals().get('UserPromptOpenedParameters', dict) # Fallback to dict if type not defined +UserPromptOpened = globals().get("UserPromptOpenedParameters", dict) # Fallback to dict if type not defined # Populate EVENT_CONFIGS with event configuration mappings diff --git a/py/selenium/webdriver/common/bidi/cdp.py b/py/selenium/webdriver/common/bidi/cdp.py index ba4a73298ee0a..bac00765f43ca 100644 --- a/py/selenium/webdriver/common/bidi/cdp.py +++ b/py/selenium/webdriver/common/bidi/cdp.py @@ -60,9 +60,7 @@ def import_devtools(ver): # because cdp has been updated but selenium python has not been released yet. devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) - available_versions = tuple( - x for x in versions if x == "latest" or (x.startswith("v") and x[1:].isdigit()) - ) + available_versions = tuple(x for x in versions if x == "latest" or (x.startswith("v") and x[1:].isdigit())) numeric_versions = tuple(x[1:] for x in available_versions if x.startswith("v")) if not numeric_versions: raise diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index fc75caa282a45..ff67b56622c35 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -23,9 +23,7 @@ from typing import Any -def command_builder( - method: str, params: dict[str, Any] | None = None -) -> Generator[dict[str, Any], Any, Any]: +def command_builder(method: str, params: dict[str, Any] | None = None) -> Generator[dict[str, Any], Any, Any]: """Build a BiDi command generator. Args: diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 1c48100cc343b..0860890abf41b 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -178,6 +178,7 @@ class SetNetworkConditionsParameters: # Backward-compatible alias for existing imports setNetworkConditionsParameters = SetNetworkConditionsParameters + class Emulation: """WebDriver BiDi emulation module.""" @@ -319,9 +320,7 @@ def set_geolocation_override( if isinstance(error, dict): params["error"] = error else: - params["error"] = { - "type": error.type if error.type is not None else "positionUnavailable" - } + params["error"] = {"type": error.type if error.type is not None else "positionUnavailable"} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -329,6 +328,7 @@ def set_geolocation_override( cmd = command_builder("emulation.setGeolocationOverride", params) result = self._conn.execute(cmd) return result + def set_timezone_override( self, timezone=None, @@ -353,6 +353,7 @@ def set_timezone_override( params["userContexts"] = user_contexts cmd = command_builder("emulation.setTimezoneOverride", params) return self._conn.execute(cmd) + def set_scripting_enabled( self, enabled=None, @@ -377,6 +378,7 @@ def set_scripting_enabled( params["userContexts"] = user_contexts cmd = command_builder("emulation.setScriptingEnabled", params) return self._conn.execute(cmd) + def set_user_agent_override( self, user_agent=None, @@ -400,6 +402,7 @@ def set_user_agent_override( params["userContexts"] = user_contexts cmd = command_builder("emulation.setUserAgentOverride", params) return self._conn.execute(cmd) + def set_screen_orientation_override( self, screen_orientation=None, @@ -436,6 +439,7 @@ def set_screen_orientation_override( params["userContexts"] = user_contexts cmd = command_builder("emulation.setScreenOrientationOverride", params) return self._conn.execute(cmd) + def set_network_conditions( self, network_conditions=None, diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 346ead5e49841..5d4c670490089 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -180,6 +180,7 @@ def from_json(cls, params: dict) -> FileDialogInfo: multiple=params.get("multiple"), ) + @dataclass class PointerMoveAction: """PointerMoveAction.""" @@ -191,6 +192,7 @@ class PointerMoveAction: origin: Any | None = None properties: Any | None = None + @dataclass class PointerDownAction: """PointerDownAction.""" @@ -199,15 +201,18 @@ class PointerDownAction: button: Any | None = None properties: Any | None = None + # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "file_dialog_opened": "input.fileDialogOpened", } + class Input: """WebDriver BiDi input module.""" EVENT_CONFIGS: dict[str, EventConfig] = {} + def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) @@ -305,9 +310,10 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() + # Event Info Type Aliases # Event: input.fileDialogOpened -FileDialogOpened = globals().get('FileDialogInfo', dict) # Fallback to dict if type not defined +FileDialogOpened = globals().get("FileDialogInfo", dict) # Fallback to dict if type not defined # Populate EVENT_CONFIGS with event configuration mappings diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index ca24d6e78d532..856d8561e706f 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -24,6 +24,7 @@ class Level: LogLevel = Level + @dataclass class BaseLogEntry: """BaseLogEntry.""" @@ -69,6 +70,7 @@ def from_json(cls, params: dict) -> ConsoleLogEntry: stack_trace=params.get("stackTrace"), ) + @dataclass class JavascriptLogEntry: """JavascriptLogEntry - a JavaScript error log entry from the browser.""" @@ -92,6 +94,7 @@ def from_json(cls, params: dict) -> JavascriptLogEntry: stacktrace=params.get("stackTrace"), ) + Entry = GenericLogEntry | ConsoleLogEntry | JavascriptLogEntry # BiDi Event Name to Parameter Type Mapping @@ -99,15 +102,16 @@ def from_json(cls, params: dict) -> JavascriptLogEntry: "entry_added": "log.entryAdded", } + class Log: """WebDriver BiDi log module.""" EVENT_CONFIGS: dict[str, EventConfig] = {} + def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: """Add an event handler. @@ -134,6 +138,7 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() + # Event Info Type Aliases # Event: log.entryAdded EntryAdded = Entry diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index d6875e14fa58a..e13fbe0f7a20b 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -360,6 +360,7 @@ class DisownDataParameters: # Backward-compatible alias for existing imports disownDataParameters = DisownDataParameters + class BytesValue: """A string or base64-encoded bytes value used in cookie operations. @@ -377,6 +378,7 @@ def __init__(self, type: Any | None, value: Any | None) -> None: def to_bidi_dict(self) -> dict: return {"type": self.type, "value": self.value} + class Request: """Wraps a BiDi network request event params and provides request action methods.""" @@ -395,16 +397,19 @@ def continue_request(self, **kwargs): params.update(kwargs) self._conn.execute(_cb("network.continueRequest", params)) + # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "auth_required": "network.authRequired", "before_request": "network.beforeRequestSent", } + class Network: """WebDriver BiDi network module.""" EVENT_CONFIGS: dict[str, EventConfig] = {} + def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) @@ -755,6 +760,7 @@ def _add_intercept(self, phases=None, url_patterns=None): if intercept_id and intercept_id not in self.intercepts: self.intercepts.append(intercept_id) return result + def _remove_intercept(self, intercept_id): """Remove a low-level network intercept.""" from selenium.webdriver.common.bidi.common import command_builder as _cb @@ -762,6 +768,7 @@ def _remove_intercept(self, intercept_id): self._conn.execute(_cb("network.removeIntercept", {"intercept": intercept_id})) if intercept_id in self.intercepts: self.intercepts.remove(intercept_id) + def add_request_handler(self, event, callback, url_patterns=None): """Add a handler for network requests at the specified phase. @@ -784,11 +791,7 @@ def add_request_handler(self, event, callback, url_patterns=None): intercept_id = intercept_result.get("intercept") if intercept_result else None def _request_callback(params): - raw = ( - params - if isinstance(params, dict) - else (params.__dict__ if hasattr(params, "__dict__") else {}) - ) + raw = params if isinstance(params, dict) else (params.__dict__ if hasattr(params, "__dict__") else {}) request = Request(self._conn, raw) callback(request) @@ -796,6 +799,7 @@ def _request_callback(params): if intercept_id: self._handler_intercepts[callback_id] = intercept_id return callback_id + def remove_request_handler(self, event, callback_id): """Remove a network request handler and its associated network intercept. @@ -807,11 +811,13 @@ def remove_request_handler(self, event, callback_id): intercept_id = self._handler_intercepts.pop(callback_id, None) if intercept_id: self._remove_intercept(intercept_id) + def clear_request_handlers(self): """Clear all request handlers and remove all tracked intercepts.""" self.clear_event_handlers() for intercept_id in list(self.intercepts): self._remove_intercept(intercept_id) + def add_auth_handler(self, username, password): """Add an auth handler that automatically provides credentials. @@ -829,16 +835,8 @@ def add_auth_handler(self, username, password): intercept_id = intercept_result.get("intercept") if intercept_result else None def _auth_callback(params): - raw = ( - params - if isinstance(params, dict) - else (params.__dict__ if hasattr(params, "__dict__") else {}) - ) - request_id = ( - raw.get("request", {}).get("request") - if isinstance(raw, dict) - else None - ) + raw = params if isinstance(params, dict) else (params.__dict__ if hasattr(params, "__dict__") else {}) + request_id = raw.get("request", {}).get("request") if isinstance(raw, dict) else None if request_id: self._conn.execute( _cb( @@ -859,6 +857,7 @@ def _auth_callback(params): if intercept_id: self._handler_intercepts[callback_id] = intercept_id return callback_id + def remove_auth_handler(self, callback_id): """Remove an auth handler by callback ID and its associated network intercept. @@ -896,9 +895,10 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() + # Event Info Type Aliases # Event: network.authRequired -AuthRequired = globals().get('AuthRequiredParameters', dict) # Fallback to dict if type not defined +AuthRequired = globals().get("AuthRequiredParameters", dict) # Fallback to dict if type not defined # Populate EVENT_CONFIGS with event configuration mappings diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py index acb8bdf65f0ef..98e25a1d2f856 100644 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ b/py/selenium/webdriver/common/bidi/permissions.py @@ -82,8 +82,7 @@ def set_permission( state_value = state.value if isinstance(state, PermissionState) else state if state_value not in _VALID_PERMISSION_STATES: raise ValueError( - f"Invalid permission state: {state_value!r}. " - f"Must be one of {sorted(_VALID_PERMISSION_STATES)}" + f"Invalid permission state: {state_value!r}. Must be one of {sorted(_VALID_PERMISSION_STATES)}" ) if isinstance(descriptor, str): diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index ecc2a75e0922d..38e43a6677470 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -620,10 +620,12 @@ class RealmDestroyedParameters: "realm_destroyed": "script.realmDestroyed", } + class Script: """WebDriver BiDi script module.""" EVENT_CONFIGS: dict[str, EventConfig] = {} + def __init__(self, conn, driver=None) -> None: self._conn = conn self._driver = driver @@ -845,6 +847,7 @@ def _serialize_arg(value): if raw.get("type") == "success": return raw.get("result") return raw + def _add_preload_script( self, function_declaration, @@ -880,6 +883,7 @@ def _add_preload_script( if isinstance(result, dict): return result.get("script") return result + def _remove_preload_script(self, script_id): """Remove a preload script by ID. @@ -887,6 +891,7 @@ def _remove_preload_script(self, script_id): script_id: The ID of the preload script to remove. """ return self.remove_preload_script(script=script_id) + def pin(self, function_declaration): """Pin (add) a preload script that runs on every page load. @@ -897,6 +902,7 @@ def pin(self, function_declaration): script_id: The ID of the pinned script (str). """ return self._add_preload_script(function_declaration) + def unpin(self, script_id): """Unpin (remove) a previously pinned preload script. @@ -904,6 +910,7 @@ def unpin(self, script_id): script_id: The ID returned by pin(). """ return self._remove_preload_script(script_id=script_id) + def _evaluate( self, expression, @@ -926,6 +933,7 @@ def _evaluate( Returns: An object with .realm, .result (dict or None), and .exception_details (or None). """ + class _EvalResult: def __init__(self2, realm, result, exception_details): self2.realm = realm @@ -947,6 +955,7 @@ def __init__(self2, realm, result, exception_details): return _EvalResult(realm=realm, result=None, exception_details=exc) return _EvalResult(realm=realm, result=raw.get("result"), exception_details=None) return _EvalResult(realm=None, result=raw, exception_details=None) + def _call_function( self, function_declaration, @@ -973,6 +982,7 @@ def _call_function( Returns: An object with .result (dict or None) and .exception_details (or None). """ + class _CallResult: def __init__(self2, result, exception_details): self2.result = result @@ -995,6 +1005,7 @@ def __init__(self2, result, exception_details): if raw.get("type") == "success": return _CallResult(result=raw.get("result"), exception_details=None) return _CallResult(result=raw, exception_details=None) + def _get_realms(self, context=None, type=None): """Get all realms, optionally filtered by context and type. @@ -1005,6 +1016,7 @@ def _get_realms(self, context=None, type=None): Returns: List of realm info objects with .realm, .origin, .type, .context attributes. """ + class _RealmInfo: def __init__(self2, realm, origin, type_, context): self2.realm = realm @@ -1017,13 +1029,16 @@ def __init__(self2, realm, origin, type_, context): result = [] for r in realms_list: if isinstance(r, dict): - result.append(_RealmInfo( - realm=r.get("realm"), - origin=r.get("origin"), - type_=r.get("type"), - context=r.get("context"), - )) + result.append( + _RealmInfo( + realm=r.get("realm"), + origin=r.get("origin"), + type_=r.get("type"), + context=r.get("context"), + ) + ) return result + def _disown(self, handles, target): """Disown handles in a browsing context. @@ -1032,6 +1047,7 @@ def _disown(self, handles, target): target: A dict like {"context": }. """ return self.disown(handles=handles, target=target) + def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading @@ -1068,9 +1084,7 @@ def _wrapped(raw): if entry_type_filter is None: callback(entry) else: - t = getattr(entry, "type_", None) or ( - entry.get("type") if isinstance(entry, dict) else None - ) + t = getattr(entry, "type_", None) or (entry.get("type") if isinstance(entry, dict) else None) if t == entry_type_filter: callback(entry) @@ -1086,15 +1100,14 @@ def from_json(self2, p): if bidi_event not in self._log_subscriptions: session = _Session(self._conn) result = session.subscribe([bidi_event]) - sub_id = ( - result.get("subscription") if isinstance(result, dict) else None - ) + sub_id = result.get("subscription") if isinstance(result, dict) else None self._log_subscriptions[bidi_event] = { "callbacks": [], "subscription_id": sub_id, } self._log_subscriptions[bidi_event]["callbacks"].append(callback_id) return callback_id + def _unsubscribe_log_entry(self, callback_id): """Unsubscribe a log entry callback by ID.""" from selenium.webdriver.common.bidi.session import Session as _Session @@ -1123,6 +1136,7 @@ def from_json(self2, p): else: session.unsubscribe(events=[bidi_event]) del self._log_subscriptions[bidi_event] + def add_console_message_handler(self, callback: Callable) -> int: """Add a handler for console log messages (log.entryAdded type=console). @@ -1133,9 +1147,11 @@ def add_console_message_handler(self, callback: Callable) -> int: callback_id for use with remove_console_message_handler. """ return self._subscribe_log_entry(callback, entry_type_filter="console") + def remove_console_message_handler(self, callback_id: int) -> None: """Remove a console message handler by callback ID.""" self._unsubscribe_log_entry(callback_id) + def add_javascript_error_handler(self, callback: Callable) -> int: """Add a handler for JavaScript error log messages (log.entryAdded type=javascript). @@ -1146,6 +1162,7 @@ def add_javascript_error_handler(self, callback: Callable) -> int: callback_id for use with remove_javascript_error_handler. """ return self._subscribe_log_entry(callback, entry_type_filter="javascript") + def remove_javascript_error_handler(self, callback_id: int) -> None: """Remove a JavaScript error handler by callback ID.""" self._unsubscribe_log_entry(callback_id) @@ -1176,12 +1193,13 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() + # Event Info Type Aliases # Event: script.realmCreated -RealmCreated = globals().get('RealmInfo', dict) # Fallback to dict if type not defined +RealmCreated = globals().get("RealmInfo", dict) # Fallback to dict if type not defined # Event: script.realmDestroyed -RealmDestroyed = globals().get('RealmDestroyedParameters', dict) # Fallback to dict if type not defined +RealmDestroyed = globals().get("RealmDestroyedParameters", dict) # Fallback to dict if type not defined # Populate EVENT_CONFIGS with event configuration mappings diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index e04d897e25deb..741faeb42bc43 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -180,6 +180,7 @@ def to_dict(self) -> dict: """Backward-compatible alias for to_bidi_dict().""" return self.to_bidi_dict() + class Session: """WebDriver BiDi session module.""" @@ -188,8 +189,7 @@ def __init__(self, conn) -> None: def status(self): """Execute session.status.""" - params = { - } + params = {} params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("session.status", params) result = self._conn.execute(cmd) @@ -210,8 +210,7 @@ def new(self, capabilities: Any | None = None): def end(self): """Execute session.end.""" - params = { - } + params = {} params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("session.end", params) result = self._conn.execute(cmd) @@ -247,4 +246,3 @@ def unsubscribe(self, events: list[Any] | None = None, subscriptions: list[Any] cmd = command_builder("session.unsubscribe", params) result = self._conn.execute(cmd) return result - diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 5ae8bf5aeb2d0..9825407c2eaf8 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -87,6 +87,7 @@ def to_dict(self) -> dict: """Backward-compatible alias for to_bidi_dict().""" return self.to_bidi_dict() + class SameSite: """SameSite cookie attribute values.""" @@ -95,6 +96,7 @@ class SameSite: NONE = "none" DEFAULT = "default" + @dataclass class StorageCookie: """A cookie object returned by storage.getCookies.""" @@ -129,6 +131,7 @@ def from_bidi_dict(cls, raw: dict) -> StorageCookie: expiry=raw.get("expiry"), ) + @dataclass class CookieFilter: """CookieFilter.""" @@ -170,6 +173,7 @@ def to_dict(self) -> dict: """Backward-compatible alias for to_bidi_dict().""" return self.to_bidi_dict() + @dataclass class PartialCookie: """PartialCookie.""" @@ -208,6 +212,7 @@ def to_dict(self) -> dict: """Backward-compatible alias for to_bidi_dict().""" return self.to_bidi_dict() + class BrowsingContextPartitionDescriptor: """BrowsingContextPartitionDescriptor. @@ -227,6 +232,7 @@ def to_dict(self) -> dict: """Backward-compatible alias for to_bidi_dict().""" return self.to_bidi_dict() + @dataclass class StorageKeyPartitionDescriptor: """StorageKeyPartitionDescriptor.""" @@ -248,6 +254,7 @@ def to_dict(self) -> dict: """Backward-compatible alias for to_bidi_dict().""" return self.to_bidi_dict() + class Storage: """WebDriver BiDi storage module.""" @@ -268,11 +275,7 @@ def get_cookies(self, filter=None, partition=None): cmd = command_builder("storage.getCookies", params) result = self._conn.execute(cmd) if result and "cookies" in result: - cookies = [ - StorageCookie.from_bidi_dict(c) - for c in result.get("cookies", []) - if isinstance(c, dict) - ] + cookies = [StorageCookie.from_bidi_dict(c) for c in result.get("cookies", []) if isinstance(c, dict)] pk_raw = result.get("partitionKey") pk = ( PartitionKey( @@ -284,6 +287,7 @@ def get_cookies(self, filter=None, partition=None): ) return GetCookiesResult(cookies=cookies, partition_key=pk) return GetCookiesResult(cookies=[], partition_key=None) + def set_cookie(self, cookie=None, partition=None): """Execute storage.setCookie.""" if cookie and hasattr(cookie, "to_bidi_dict"): @@ -309,6 +313,7 @@ def set_cookie(self, cookie=None, partition=None): ) return SetCookieResult(partition_key=pk) return result + def delete_cookies(self, filter=None, partition=None): """Execute storage.deleteCookies.""" if filter and hasattr(filter, "to_bidi_dict"): diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 0a28843e339f1..03fedab62e174 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -87,14 +87,16 @@ def install( ValueError: If more than one, or none, of the arguments is provided. """ provided = [ - k for k, v in { - "path": path, "archive_path": archive_path, "base64_value": base64_value, - }.items() if v is not None + k + for k, v in { + "path": path, + "archive_path": archive_path, + "base64_value": base64_value, + }.items() + if v is not None ] if len(provided) != 1: - raise ValueError( - f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}" - ) + raise ValueError(f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}") if path is not None: extension_data = {"type": "path", "path": path} elif archive_path is not None: @@ -115,6 +117,7 @@ def install( "in your WebDriver configuration." ) from e raise + def uninstall(self, extension: str | dict): """Uninstall a web extension. diff --git a/py/selenium/webdriver/common/proxy.py b/py/selenium/webdriver/common/proxy.py index 28de19afa5742..eadf1d069709f 100644 --- a/py/selenium/webdriver/common/proxy.py +++ b/py/selenium/webdriver/common/proxy.py @@ -35,23 +35,13 @@ class ProxyType: profile preference, 'string' is id of proxy type. """ - DIRECT = ProxyTypeFactory.make( - 0, "DIRECT" - ) # Direct connection, no proxy (default on Windows). - MANUAL = ProxyTypeFactory.make( - 1, "MANUAL" - ) # Manual proxy settings (e.g., for httpProxy). + DIRECT = ProxyTypeFactory.make(0, "DIRECT") # Direct connection, no proxy (default on Windows). + MANUAL = ProxyTypeFactory.make(1, "MANUAL") # Manual proxy settings (e.g., for httpProxy). PAC = ProxyTypeFactory.make(2, "PAC") # Proxy autoconfiguration from URL. RESERVED_1 = ProxyTypeFactory.make(3, "RESERVED1") # Never used. - AUTODETECT = ProxyTypeFactory.make( - 4, "AUTODETECT" - ) # Proxy autodetection (presumably with WPAD). - SYSTEM = ProxyTypeFactory.make( - 5, "SYSTEM" - ) # Use system settings (default on Linux). - UNSPECIFIED = ProxyTypeFactory.make( - 6, "UNSPECIFIED" - ) # Not initialized (for internal use). + AUTODETECT = ProxyTypeFactory.make(4, "AUTODETECT") # Proxy autodetection (presumably with WPAD). + SYSTEM = ProxyTypeFactory.make(5, "SYSTEM") # Use system settings (default on Linux). + UNSPECIFIED = ProxyTypeFactory.make(6, "UNSPECIFIED") # Not initialized (for internal use). @classmethod def load(cls, value): @@ -60,11 +50,7 @@ def load(cls, value): value = str(value).upper() for attr in dir(cls): attr_value = getattr(cls, attr) - if ( - isinstance(attr_value, dict) - and "string" in attr_value - and attr_value["string"] == value - ): + if isinstance(attr_value, dict) and "string" in attr_value and attr_value["string"] == value: return attr_value raise Exception(f"No proxy type is found for {value}") @@ -219,17 +205,13 @@ def to_bidi_dict(self) -> dict: if self.noProxy: # Convert comma-separated string to list if isinstance(self.noProxy, str): - result["noProxy"] = [ - host.strip() for host in self.noProxy.split(",") if host.strip() - ] + result["noProxy"] = [host.strip() for host in self.noProxy.split(",") if host.strip()] elif isinstance(self.noProxy, list): if not all(isinstance(h, str) for h in self.noProxy): raise TypeError("no_proxy list must contain only strings") result["noProxy"] = self.noProxy else: - raise TypeError( - "no_proxy must be a comma-separated string or a list of strings" - ) + raise TypeError("no_proxy must be a comma-separated string or a list of strings") elif proxy_type == "pac": if self.proxyAutoconfigUrl: diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 2c41897878075..4e426090883d4 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -116,9 +116,7 @@ def get_remote_connection( client_config: ClientConfig | None = None, ) -> RemoteConnection: if isinstance(command_executor, str): - client_config = client_config or ClientConfig( - remote_server_addr=command_executor - ) + client_config = client_config or ClientConfig(remote_server_addr=command_executor) client_config.remote_server_addr = command_executor command_executor = RemoteConnection(client_config=client_config) @@ -400,13 +398,9 @@ def create_web_element(self, element_id: str) -> WebElement: def _unwrap_value(self, value): if isinstance(value, dict): if "element-6066-11e4-a52e-4f735466cecf" in value: - return self.create_web_element( - value["element-6066-11e4-a52e-4f735466cecf"] - ) + return self.create_web_element(value["element-6066-11e4-a52e-4f735466cecf"]) if "shadow-6066-11e4-a52e-4f735466cecf" in value: - return self._shadowroot_cls( - self, value["shadow-6066-11e4-a52e-4f735466cecf"] - ) + return self._shadowroot_cls(self, value["shadow-6066-11e4-a52e-4f735466cecf"]) for key, val in value.items(): value[key] = self._unwrap_value(val) return value @@ -432,9 +426,7 @@ def execute_cdp_cmd(self, cmd: str, cmd_args: dict): Example: `driver.execute_cdp_cmd("Network.getResponseBody", {"requestId": requestId})` """ - return self.execute("executeCdpCommand", {"cmd": cmd, "params": cmd_args})[ - "value" - ] + return self.execute("executeCdpCommand", {"cmd": cmd, "params": cmd_args})["value"] def execute( self, @@ -470,9 +462,7 @@ def execute( elif "sessionId" not in params: params["sessionId"] = self.session_id - response = cast(RemoteConnection, self.command_executor).execute( - driver_command, params - ) + response = cast(RemoteConnection, self.command_executor).execute(driver_command, params) if response: self.error_handler.check_response(response) @@ -528,9 +518,7 @@ def unpin(self, script_key: ScriptKey) -> None: try: self.pinned_scripts.pop(script_key.id) except KeyError: - raise KeyError( - f"No script with key: {script_key} existed in {self.pinned_scripts}" - ) from None + raise KeyError(f"No script with key: {script_key} existed in {self.pinned_scripts}") from None def get_pinned_scripts(self) -> list[str]: """Return a list of all pinned scripts. @@ -563,9 +551,7 @@ def execute_script(self, script: str, *args) -> Any: converted_args = list(args) command = Command.W3C_EXECUTE_SCRIPT - return self.execute(command, {"script": script, "args": converted_args})[ - "value" - ] + return self.execute(command, {"script": script, "args": converted_args})["value"] def execute_async_script(self, script: str, *args) -> Any: """Asynchronously Executes JavaScript in the current window/frame. @@ -584,9 +570,7 @@ def execute_async_script(self, script: str, *args) -> Any: converted_args = list(args) command = Command.W3C_EXECUTE_SCRIPT_ASYNC - return self.execute(command, {"script": script, "args": converted_args})[ - "value" - ] + return self.execute(command, {"script": script, "args": converted_args})["value"] @property def current_url(self) -> str: @@ -763,9 +747,7 @@ def implicitly_wait(self, time_to_wait: float) -> None: Example: `driver.implicitly_wait(30)` """ - self.execute( - Command.SET_TIMEOUTS, {"implicit": int(float(time_to_wait) * 1000)} - ) + self.execute(Command.SET_TIMEOUTS, {"implicit": int(float(time_to_wait) * 1000)}) def set_script_timeout(self, time_to_wait: float) -> None: """Set the timeout for asynchronous script execution. @@ -794,9 +776,7 @@ def set_page_load_timeout(self, time_to_wait: float) -> None: `driver.set_page_load_timeout(30)` """ try: - self.execute( - Command.SET_TIMEOUTS, {"pageLoad": int(float(time_to_wait) * 1000)} - ) + self.execute(Command.SET_TIMEOUTS, {"pageLoad": int(float(time_to_wait) * 1000)}) except WebDriverException: self.execute( Command.SET_TIMEOUTS, @@ -837,9 +817,7 @@ def timeouts(self, timeouts) -> None: """ _ = self.execute(Command.SET_TIMEOUTS, timeouts._to_json())["value"] - def find_element( - self, by: str | RelativeBy = By.ID, value: str | None = None - ) -> WebElement: + def find_element(self, by: str | RelativeBy = By.ID, value: str | None = None) -> WebElement: """Find an element given a By strategy and locator. Args: @@ -860,18 +838,12 @@ def find_element( if isinstance(by, RelativeBy): elements = self.find_elements(by=by, value=value) if not elements: - raise NoSuchElementException( - f"Cannot locate relative element with: {by.root}" - ) + raise NoSuchElementException(f"Cannot locate relative element with: {by.root}") return elements[0] - return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})[ - "value" - ] + return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})["value"] - def find_elements( - self, by: str | RelativeBy = By.ID, value: str | None = None - ) -> list[WebElement]: + def find_elements(self, by: str | RelativeBy = By.ID, value: str | None = None) -> list[WebElement]: """Find elements given a By strategy and locator. Args: @@ -893,21 +865,14 @@ def find_elements( _pkg = ".".join(__name__.split(".")[:-1]) raw_data = pkgutil.get_data(_pkg, "findElements.js") if raw_data is None: - raise FileNotFoundError( - f"Could not find findElements.js in package {_pkg}" - ) + raise FileNotFoundError(f"Could not find findElements.js in package {_pkg}") raw_function = raw_data.decode("utf8") - find_element_js = ( - f"/* findElements */return ({raw_function}).apply(null, arguments);" - ) + find_element_js = f"/* findElements */return ({raw_function}).apply(null, arguments);" return self.execute_script(find_element_js, by.to_dict()) # Return empty list if driver returns null # See https://github.com/SeleniumHQ/selenium/issues/4555 - return ( - self.execute(Command.FIND_ELEMENTS, {"using": by, "value": value})["value"] - or [] - ) + return self.execute(Command.FIND_ELEMENTS, {"using": by, "value": value})["value"] or [] @property def capabilities(self) -> dict: @@ -1004,9 +969,7 @@ def get_window_size(self, windowHandle: str = "current") -> dict: return {k: size[k] for k in ("width", "height")} - def set_window_position( - self, x: float, y: float, windowHandle: str = "current" - ) -> dict: + def set_window_position(self, x: float, y: float, windowHandle: str = "current") -> dict: """Sets the x,y position of the current window. Args: @@ -1065,9 +1028,7 @@ def set_window_rect(self, x=None, y=None, width=None, height=None) -> dict: if (x is None and y is None) and (not height and not width): raise InvalidArgumentException("x and y or height and width need values") - return self.execute( - Command.SET_WINDOW_RECT, {"x": x, "y": y, "width": width, "height": height} - )["value"] + return self.execute(Command.SET_WINDOW_RECT, {"x": x, "y": y, "width": width, "height": height})["value"] @property def file_detector(self) -> FileDetector: @@ -1112,9 +1073,7 @@ def orientation(self, value) -> None: if value.upper() in allowed_values: self.execute(Command.SET_SCREEN_ORIENTATION, {"orientation": value}) else: - raise WebDriverException( - "You can only set the orientation to 'LANDSCAPE' and 'PORTRAIT'" - ) + raise WebDriverException("You can only set the orientation to 'LANDSCAPE' and 'PORTRAIT'") def start_devtools(self) -> tuple[Any, WebSocketConnection]: global cdp @@ -1129,9 +1088,7 @@ def start_devtools(self) -> tuple[Any, WebSocketConnection]: version, ws_url = self._get_cdp_details() if not ws_url: - raise WebDriverException( - "Unable to find url to connect to from capabilities" - ) + raise WebDriverException("Unable to find url to connect to from capabilities") if cdp is None: raise WebDriverException("CDP module not loaded") @@ -1140,28 +1097,20 @@ def start_devtools(self) -> tuple[Any, WebSocketConnection]: if self._websocket_connection: return self._devtools, self._websocket_connection if self.caps["browserName"].lower() == "firefox": - raise RuntimeError( - "CDP support for Firefox has been removed. Please switch to WebDriver BiDi." - ) + raise RuntimeError("CDP support for Firefox has been removed. Please switch to WebDriver BiDi.") if not isinstance(self.command_executor, RemoteConnection): - raise WebDriverException( - "command_executor must be a RemoteConnection instance for CDP support" - ) + raise WebDriverException("command_executor must be a RemoteConnection instance for CDP support") self._websocket_connection = WebSocketConnection( ws_url, self.command_executor.client_config.websocket_timeout, self.command_executor.client_config.websocket_interval, ) - targets = self._websocket_connection.execute( - self._devtools.target.get_targets() - ) + targets = self._websocket_connection.execute(self._devtools.target.get_targets()) for target in targets: if target.target_id == self.current_window_handle: target_id = target.target_id break - session = self._websocket_connection.execute( - self._devtools.target.attach_to_target(target_id, True) - ) + session = self._websocket_connection.execute(self._devtools.target.attach_to_target(target_id, True)) self._websocket_connection.session_id = session return self._devtools, self._websocket_connection @@ -1176,9 +1125,7 @@ async def bidi_connection(self): version, ws_url = self._get_cdp_details() if not ws_url: - raise WebDriverException( - "Unable to find url to connect to from capabilities" - ) + raise WebDriverException("Unable to find url to connect to from capabilities") devtools = cdp.import_devtools(version) async with cdp.open_cdp(ws_url) as conn: @@ -1204,14 +1151,10 @@ def _start_bidi(self) -> None: if self.caps.get("webSocketUrl"): ws_url = self.caps.get("webSocketUrl") else: - raise WebDriverException( - "Unable to find url to connect to from capabilities" - ) + raise WebDriverException("Unable to find url to connect to from capabilities") if not isinstance(self.command_executor, RemoteConnection): - raise WebDriverException( - "command_executor must be a RemoteConnection instance for BiDi support" - ) + raise WebDriverException("command_executor must be a RemoteConnection instance for BiDi support") self._websocket_connection = WebSocketConnection( ws_url, @@ -1427,13 +1370,9 @@ def _get_cdp_details(self): http = urllib3.PoolManager() try: if self.caps.get("browserName") == "chrome": - debugger_address = self.caps.get("goog:chromeOptions").get( - "debuggerAddress" - ) + debugger_address = self.caps.get("goog:chromeOptions").get("debuggerAddress") elif self.caps.get("browserName") in ("MicrosoftEdge", "webview2"): - debugger_address = self.caps.get("ms:edgeOptions").get( - "debuggerAddress" - ) + debugger_address = self.caps.get("ms:edgeOptions").get("debuggerAddress") except AttributeError: raise WebDriverException("Can't get debugger address.") @@ -1461,9 +1400,7 @@ def add_virtual_authenticator(self, options: VirtualAuthenticatorOptions) -> Non driver.add_virtual_authenticator(options) ``` """ - self._authenticator_id = self.execute( - Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict() - )["value"] + self._authenticator_id = self.execute(Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict())["value"] @property def virtual_authenticator_id(self) -> str | None: @@ -1503,12 +1440,8 @@ def add_credential(self, credential: Credential) -> None: @required_virtual_authenticator def get_credentials(self) -> list[Credential]: """Returns the list of credentials owned by the authenticator.""" - credential_data = self.execute( - Command.GET_CREDENTIALS, {"authenticatorId": self._authenticator_id} - ) - return [ - Credential.from_dict(credential) for credential in credential_data["value"] - ] + credential_data = self.execute(Command.GET_CREDENTIALS, {"authenticatorId": self._authenticator_id}) + return [Credential.from_dict(credential) for credential in credential_data["value"]] @required_virtual_authenticator def remove_credential(self, credential_id: str | bytearray) -> None: @@ -1530,9 +1463,7 @@ def remove_credential(self, credential_id: str | bytearray) -> None: @required_virtual_authenticator def remove_all_credentials(self) -> None: """Removes all credentials from the authenticator.""" - self.execute( - Command.REMOVE_ALL_CREDENTIALS, {"authenticatorId": self._authenticator_id} - ) + self.execute(Command.REMOVE_ALL_CREDENTIALS, {"authenticatorId": self._authenticator_id}) @required_virtual_authenticator def set_user_verified(self, verified: bool) -> None: @@ -1553,9 +1484,7 @@ def set_user_verified(self, verified: bool) -> None: def get_downloadable_files(self) -> list: """Retrieves the downloadable files as a list of file names.""" if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException( - "You must enable downloads in order to work with downloadable files." - ) + raise WebDriverException("You must enable downloads in order to work with downloadable files.") return self.execute(Command.GET_DOWNLOADABLE_FILES)["value"]["names"] @@ -1570,16 +1499,12 @@ def download_file(self, file_name: str, target_directory: str) -> None: `driver.download_file("example.zip", "/path/to/directory")` """ if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException( - "You must enable downloads in order to work with downloadable files." - ) + raise WebDriverException("You must enable downloads in order to work with downloadable files.") if not os.path.exists(target_directory): os.makedirs(target_directory) - contents = self.execute(Command.DOWNLOAD_FILE, {"name": file_name})["value"][ - "contents" - ] + contents = self.execute(Command.DOWNLOAD_FILE, {"name": file_name})["value"]["contents"] with tempfile.TemporaryDirectory() as tmp_dir: zip_file = os.path.join(tmp_dir, file_name + ".zip") @@ -1592,9 +1517,7 @@ def download_file(self, file_name: str, target_directory: str) -> None: def delete_downloadable_files(self) -> None: """Deletes all downloadable files.""" if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException( - "You must enable downloads in order to work with downloadable files." - ) + raise WebDriverException("You must enable downloads in order to work with downloadable files.") self.execute(Command.DELETE_DOWNLOADABLE_FILES) diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py index 44cb2adef7a0b..cd34c35db3696 100644 --- a/py/selenium/webdriver/remote/websocket_connection.py +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -158,9 +158,7 @@ def _serialize_command(self, command): def _deserialize_result(self, result, command): try: _ = command.send(result) - raise WebDriverException( - "The command's generator function did not exit when expected!" - ) + raise WebDriverException("The command's generator function did not exit when expected!") except StopIteration as exit: return exit.value @@ -177,15 +175,11 @@ def on_error(ws, error): def run_socket(): if self.url.startswith("wss://"): - self._ws.run_forever( - sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True - ) + self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True) else: self._ws.run_forever(suppress_origin=True) - self._ws = WebSocketApp( - self.url, on_open=on_open, on_message=on_message, on_error=on_error - ) + self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error) self._ws_thread = Thread(target=run_socket, daemon=True) self._ws_thread.start() diff --git a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py index 8038ee826aa74..86e3d11af0341 100644 --- a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py @@ -60,9 +60,7 @@ def test_create_window(driver): def test_create_window_with_reference_context(driver): """Test creating a window with a reference context.""" reference_context = driver.current_window_handle - context_id = driver.browsing_context.create( - type=WindowTypes.WINDOW, reference_context=reference_context - ) + context_id = driver.browsing_context.create(type=WindowTypes.WINDOW, reference_context=reference_context) assert context_id is not None # Clean up @@ -81,9 +79,7 @@ def test_create_tab(driver): def test_create_tab_with_reference_context(driver): """Test creating a tab with a reference context.""" reference_context = driver.current_window_handle - context_id = driver.browsing_context.create( - type=WindowTypes.TAB, reference_context=reference_context - ) + context_id = driver.browsing_context.create(type=WindowTypes.TAB, reference_context=reference_context) assert context_id is not None # Clean up @@ -128,9 +124,7 @@ def test_navigate_to_url_with_readiness_state(driver, pages): context_id = driver.browsing_context.create(type=WindowTypes.TAB) url = pages.url("bidi/logEntryAdded.html") - result = driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + result = driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) assert context_id is not None assert "/bidi/logEntryAdded.html" in result["url"] @@ -144,9 +138,7 @@ def test_get_tree_with_child(driver, pages): reference_context = driver.current_window_handle url = pages.url("iframes.html") - driver.browsing_context.navigate( - context=reference_context, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=reference_context, url=url, wait=ReadinessState.COMPLETE) context_info_list = driver.browsing_context.get_tree(root=reference_context) @@ -162,13 +154,9 @@ def test_get_tree_with_depth(driver, pages): reference_context = driver.current_window_handle url = pages.url("iframes.html") - driver.browsing_context.navigate( - context=reference_context, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=reference_context, url=url, wait=ReadinessState.COMPLETE) - context_info_list = driver.browsing_context.get_tree( - root=reference_context, max_depth=0 - ) + context_info_list = driver.browsing_context.get_tree(root=reference_context, max_depth=0) assert len(context_info_list) == 1 info = context_info_list[0] @@ -239,9 +227,7 @@ def test_reload_browsing_context(driver, pages): context_id = driver.browsing_context.create(type=WindowTypes.TAB) url = pages.url("bidi/logEntryAdded.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) reload_info = driver.browsing_context.reload(context=context_id) @@ -256,13 +242,9 @@ def test_reload_with_readiness_state(driver, pages): context_id = driver.browsing_context.create(type=WindowTypes.TAB) url = pages.url("bidi/logEntryAdded.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) - reload_info = driver.browsing_context.reload( - context=context_id, wait=ReadinessState.COMPLETE - ) + reload_info = driver.browsing_context.reload(context=context_id, wait=ReadinessState.COMPLETE) assert reload_info["navigation"] is not None assert "/bidi/logEntryAdded.html" in reload_info["url"] @@ -359,9 +341,7 @@ def test_capture_screenshot_with_parameters(driver, pages): clip = {"type": "box", "x": rect["x"], "y": rect["y"], "width": 5, "height": 5} - screenshot = driver.browsing_context.capture_screenshot( - context=context_id, origin="document", clip=clip - ) + screenshot = driver.browsing_context.capture_screenshot(context=context_id, origin="document", clip=clip) assert len(screenshot) > 0 @@ -372,20 +352,14 @@ def test_set_viewport(driver, pages): driver.get(pages.url("formPage.html")) try: - driver.browsing_context.set_viewport( - context=context_id, viewport={"width": 251, "height": 301} - ) + driver.browsing_context.set_viewport(context=context_id, viewport={"width": 251, "height": 301}) - viewport_size = driver.execute_script( - "return [window.innerWidth, window.innerHeight];" - ) + viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") assert viewport_size[0] == 251 assert viewport_size[1] == 301 finally: - driver.browsing_context.set_viewport( - context=context_id, viewport=None, device_pixel_ratio=None - ) + driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) def test_set_viewport_with_device_pixel_ratio(driver, pages): @@ -400,9 +374,7 @@ def test_set_viewport_with_device_pixel_ratio(driver, pages): device_pixel_ratio=5, ) - viewport_size = driver.execute_script( - "return [window.innerWidth, window.innerHeight];" - ) + viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") assert viewport_size[0] == 252 assert viewport_size[1] == 302 @@ -411,9 +383,7 @@ def test_set_viewport_with_device_pixel_ratio(driver, pages): assert device_pixel_ratio == 5 finally: - driver.browsing_context.set_viewport( - context=context_id, viewport=None, device_pixel_ratio=None - ) + driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) def test_set_viewport_with_no_args_doesnt_change_values(driver, pages): @@ -430,9 +400,7 @@ def test_set_viewport_with_no_args_doesnt_change_values(driver, pages): driver.browsing_context.set_viewport(context=context_id) - viewport_size = driver.execute_script( - "return [window.innerWidth, window.innerHeight];" - ) + viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") assert viewport_size[0] == 253 assert viewport_size[1] == 303 @@ -441,9 +409,7 @@ def test_set_viewport_with_no_args_doesnt_change_values(driver, pages): assert device_pixel_ratio == 6 finally: - driver.browsing_context.set_viewport( - context=context_id, viewport=None, device_pixel_ratio=None - ) + driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) @pytest.mark.xfail_chrome @@ -452,9 +418,7 @@ def test_set_viewport_back_to_default(driver, pages): context_id = driver.current_window_handle driver.get(pages.url("formPage.html")) - default_viewport_size = driver.execute_script( - "return [window.innerWidth, window.innerHeight];" - ) + default_viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") default_device_pixel_ratio = driver.execute_script("return window.devicePixelRatio") try: @@ -464,13 +428,9 @@ def test_set_viewport_back_to_default(driver, pages): device_pixel_ratio=10, ) - driver.browsing_context.set_viewport( - context=context_id, viewport=None, device_pixel_ratio=None - ) + driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) - viewport_size = driver.execute_script( - "return [window.innerWidth, window.innerHeight];" - ) + viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") device_pixel_ratio = driver.execute_script("return window.devicePixelRatio") # Allow some tolerance since some window managers might not put it to the exact value @@ -478,9 +438,7 @@ def test_set_viewport_back_to_default(driver, pages): assert abs(viewport_size[1] - default_viewport_size[1]) <= 5 assert device_pixel_ratio == default_device_pixel_ratio finally: - driver.browsing_context.set_viewport( - context=context_id, viewport=None, device_pixel_ratio=None - ) + driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) def test_print_page(driver, pages): @@ -499,9 +457,7 @@ def test_print_page(driver, pages): def test_navigate_back_in_browser_history(driver, pages): """Test navigating back in the browser history.""" context_id = driver.current_window_handle - driver.browsing_context.navigate( - context=context_id, url=pages.url("formPage.html"), wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=pages.url("formPage.html"), wait=ReadinessState.COMPLETE) # Navigate to another page by submitting a form driver.find_element(By.ID, "imageButton").submit() @@ -514,9 +470,7 @@ def test_navigate_back_in_browser_history(driver, pages): def test_navigate_forward_in_browser_history(driver, pages): """Test navigating forward in the browser history.""" context_id = driver.current_window_handle - driver.browsing_context.navigate( - context=context_id, url=pages.url("formPage.html"), wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=pages.url("formPage.html"), wait=ReadinessState.COMPLETE) # Navigate to another page by submitting a form driver.find_element(By.ID, "imageButton").submit() @@ -538,9 +492,7 @@ def test_locate_nodes(driver, pages): driver.get(pages.url("xhtmlTest.html")) - elements = driver.browsing_context.locate_nodes( - context=context_id, locator={"type": "css", "value": "div"} - ) + elements = driver.browsing_context.locate_nodes(context=context_id, locator={"type": "css", "value": "div"}) assert len(elements) > 0 @@ -660,9 +612,7 @@ def test_add_event_handler_context_created(driver): def on_context_created(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "context_created", on_context_created - ) + callback_id = driver.browsing_context.add_event_handler("context_created", on_context_created) assert callback_id is not None # Create a new context to trigger the event @@ -670,10 +620,7 @@ def on_context_created(info): # Verify the event was received (might be > 1 since default context is also included) assert len(events_received) >= 1 - assert ( - events_received[0].context == context_id - or events_received[1].context == context_id - ) + assert events_received[0].context == context_id or events_received[1].context == context_id driver.browsing_context.close(context_id) driver.browsing_context.remove_event_handler("context_created", callback_id) @@ -686,9 +633,7 @@ def test_add_event_handler_context_destroyed(driver): def on_context_destroyed(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "context_destroyed", on_context_destroyed - ) + callback_id = driver.browsing_context.add_event_handler("context_destroyed", on_context_destroyed) assert callback_id is not None # Create and then close a context to trigger the event @@ -708,17 +653,13 @@ def test_add_event_handler_navigation_committed(driver, pages): def on_navigation_committed(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "navigation_committed", on_navigation_committed - ) + callback_id = driver.browsing_context.add_event_handler("navigation_committed", on_navigation_committed) assert callback_id is not None # Navigate to trigger the event context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) assert len(events_received) >= 1 assert any(url in event.url for event in events_received) @@ -733,17 +674,13 @@ def test_add_event_handler_dom_content_loaded(driver, pages): def on_dom_content_loaded(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "dom_content_loaded", on_dom_content_loaded - ) + callback_id = driver.browsing_context.add_event_handler("dom_content_loaded", on_dom_content_loaded) assert callback_id is not None # Navigate to trigger the event context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) assert len(events_received) == 1 assert any("simpleTest" in event.url for event in events_received) @@ -763,9 +700,7 @@ def on_load(info): context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) assert len(events_received) == 1 assert any("simpleTest" in event.url for event in events_received) @@ -780,16 +715,12 @@ def test_add_event_handler_navigation_started(driver, pages): def on_navigation_started(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "navigation_started", on_navigation_started - ) + callback_id = driver.browsing_context.add_event_handler("navigation_started", on_navigation_started) assert callback_id is not None context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) assert len(events_received) == 1 assert any("simpleTest" in event.url for event in events_received) @@ -804,23 +735,17 @@ def test_add_event_handler_fragment_navigated(driver, pages): def on_fragment_navigated(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "fragment_navigated", on_fragment_navigated - ) + callback_id = driver.browsing_context.add_event_handler("fragment_navigated", on_fragment_navigated) assert callback_id is not None # First navigate to a page context_id = driver.current_window_handle url = pages.url("linked_image.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) # Then navigate to the same page with a fragment to trigger the event fragment_url = url + "#link" - driver.browsing_context.navigate( - context=context_id, url=fragment_url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=fragment_url, wait=ReadinessState.COMPLETE) assert len(events_received) == 1 assert any("link" in event.url for event in events_received) @@ -836,17 +761,13 @@ def test_add_event_handler_navigation_failed(driver): def on_navigation_failed(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "navigation_failed", on_navigation_failed - ) + callback_id = driver.browsing_context.add_event_handler("navigation_failed", on_navigation_failed) assert callback_id is not None # Navigate to an invalid URL to trigger the event context_id = driver.current_window_handle try: - driver.browsing_context.navigate( - context=context_id, url="http://invalid-domain-that-does-not-exist.test/" - ) + driver.browsing_context.navigate(context=context_id, url="http://invalid-domain-that-does-not-exist.test/") except Exception: # Expect an exception due to navigation failure pass @@ -865,9 +786,7 @@ def test_add_event_handler_user_prompt_opened(driver, pages): def on_user_prompt_opened(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "user_prompt_opened", on_user_prompt_opened - ) + callback_id = driver.browsing_context.add_event_handler("user_prompt_opened", on_user_prompt_opened) assert callback_id is not None # Create an alert to trigger the event @@ -892,9 +811,7 @@ def test_add_event_handler_user_prompt_closed(driver, pages): def on_user_prompt_closed(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "user_prompt_closed", on_user_prompt_closed - ) + callback_id = driver.browsing_context.add_event_handler("user_prompt_closed", on_user_prompt_closed) assert callback_id is not None create_prompt_page(driver, pages) @@ -919,16 +836,12 @@ def test_add_event_handler_history_updated(driver, pages): def on_history_updated(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "history_updated", on_history_updated - ) + callback_id = driver.browsing_context.add_event_handler("history_updated", on_history_updated) assert callback_id is not None context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) # Use history.pushState to trigger history updated event driver.script.execute("() => { history.pushState({}, '', '/new-path'); }") @@ -948,17 +861,13 @@ def test_add_event_handler_download_will_begin(driver, pages): def on_download_will_begin(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "download_will_begin", on_download_will_begin - ) + callback_id = driver.browsing_context.add_event_handler("download_will_begin", on_download_will_begin) assert callback_id is not None # click on a download link to trigger the event context_id = driver.current_window_handle url = pages.url("downloads/download.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) download_xpath_file_1_txt = '//*[@id="file-1"]' driver.find_element(By.XPATH, download_xpath_file_1_txt).click() @@ -978,16 +887,12 @@ def test_add_event_handler_download_end(driver, pages): def on_download_end(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "download_end", on_download_end - ) + callback_id = driver.browsing_context.add_event_handler("download_end", on_download_end) assert callback_id is not None context_id = driver.current_window_handle url = pages.url("downloads/download.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) driver.find_element(By.ID, "file-1").click() @@ -1005,14 +910,12 @@ def on_download_end(info): # we assert that atleast "file_1" is present in the downloaded file since multiple downloads # will have numbered suffix like file_1 (1) assert any( - "downloads/file_1.txt" in ev.download_params.url - and "file_1" in ev.download_params.filepath + "downloads/file_1.txt" in ev.download_params.url and "file_1" in ev.download_params.filepath for ev in events_received ) assert any( - "downloads/file_2.jpg" in ev.download_params.url - and "file_2" in ev.download_params.filepath + "downloads/file_2.jpg" in ev.download_params.url and "file_2" in ev.download_params.filepath for ev in events_received ) @@ -1051,9 +954,7 @@ def test_remove_event_handler(driver): def on_context_created(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "context_created", on_context_created - ) + callback_id = driver.browsing_context.add_event_handler("context_created", on_context_created) # Create a context to trigger the event context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB) @@ -1085,12 +986,8 @@ def on_context_created_2(info): events_received_2.append(info) # Add multiple event handlers for the same event - callback_id_1 = driver.browsing_context.add_event_handler( - "context_created", on_context_created_1 - ) - callback_id_2 = driver.browsing_context.add_event_handler( - "context_created", on_context_created_2 - ) + callback_id_1 = driver.browsing_context.add_event_handler("context_created", on_context_created_1) + callback_id_2 = driver.browsing_context.add_event_handler("context_created", on_context_created_2) # Create a context to trigger both handlers context_id = driver.browsing_context.create(type=WindowTypes.TAB) @@ -1119,12 +1016,8 @@ def on_context_created_2(info): events_received_2.append(info) # Add multiple event handlers - callback_id_1 = driver.browsing_context.add_event_handler( - "context_created", on_context_created_1 - ) - callback_id_2 = driver.browsing_context.add_event_handler( - "context_created", on_context_created_2 - ) + callback_id_1 = driver.browsing_context.add_event_handler("context_created", on_context_created_1) + callback_id_2 = driver.browsing_context.add_event_handler("context_created", on_context_created_2) # Create a context to trigger both handlers context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB) @@ -1206,9 +1099,7 @@ def callback(info): def register_handler(self, thread_id): try: callback = self.make_callback() - callback_id = self.driver.browsing_context.add_event_handler( - "context_created", callback - ) + callback_id = self.driver.browsing_context.add_event_handler("context_created", callback) with self.data_lock: self.callback_ids.append(callback_id) if len(self.callback_ids) == 5: @@ -1216,16 +1107,12 @@ def register_handler(self, thread_id): return callback_id except Exception as e: with self.data_lock: - self.thread_errors.append( - f"Thread {thread_id}: Registration failed: {e}" - ) + self.thread_errors.append(f"Thread {thread_id}: Registration failed: {e}") return None def remove_handler(self, callback_id, thread_id): try: - self.driver.browsing_context.remove_event_handler( - "context_created", callback_id - ) + self.driver.browsing_context.remove_event_handler("context_created", callback_id) except Exception as e: with self.data_lock: self.thread_errors.append(f"Thread {thread_id}: Removal failed: {e}") @@ -1235,19 +1122,13 @@ def test_concurrent_event_handler_registration(driver): helper = _EventHandlerTestHelper(driver) with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: - futures = [ - executor.submit(helper.register_handler, f"reg-{i}") for i in range(5) - ] + futures = [executor.submit(helper.register_handler, f"reg-{i}") for i in range(5)] for future in futures: future.result(timeout=15) helper.registration_complete.wait(timeout=5) - assert ( - len(helper.callback_ids) == 5 - ), f"Expected 5 handlers, got {len(helper.callback_ids)}" - assert not helper.thread_errors, "Errors during registration: \n" + "\n".join( - helper.thread_errors - ) + assert len(helper.callback_ids) == 5, f"Expected 5 handlers, got {len(helper.callback_ids)}" + assert not helper.thread_errors, "Errors during registration: \n" + "\n".join(helper.thread_errors) def test_event_callback_data_consistency(driver): @@ -1265,9 +1146,7 @@ def test_event_callback_data_consistency(driver): driver.browsing_context.close(ctx) with helper.data_lock: - assert not helper.consistency_errors, "Consistency errors: " + str( - helper.consistency_errors - ) + assert not helper.consistency_errors, "Consistency errors: " + str(helper.consistency_errors) assert len(helper.events_received) > 0, "No events received" assert len(helper.events_received) == sum(helper.context_counts.values()) assert len(helper.events_received) == sum(helper.event_type_counts.values()) @@ -1288,9 +1167,7 @@ def test_concurrent_event_handler_removal(driver): for future in futures: future.result(timeout=15) - assert not helper.thread_errors, "Errors during removal: \n" + "\n".join( - helper.thread_errors - ) + assert not helper.thread_errors, "Errors during removal: \n" + "\n".join(helper.thread_errors) def test_no_event_after_handler_removal(driver): From 818daba99b6d24efe49c722a2351ed4122a5b030 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Thu, 9 Apr 2026 16:06:25 +0100 Subject: [PATCH 31/42] Use sentinel pattern for set viewport --- py/private/bidi_enhancements_manifest.py | 28 ++++++ .../webdriver/common/bidi/browsing_context.py | 87 ++++++++++--------- 2 files changed, 72 insertions(+), 43 deletions(-) diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 06c0573db9083..57e3d4e35f0dc 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -184,6 +184,7 @@ class SetClientWindowStateParameters: }, "browsingContext": { # Method enhancements + "exclude_methods": ["set_viewport"], "create": { "extract_field": "context", }, @@ -223,6 +224,33 @@ class SetClientWindowStateParameters: "devicePixelRatio": "float", }, }, + "extra_methods": [ + ''' def set_viewport( + self, + context: str | None = None, + viewport: Any = ..., + user_contexts: Any | None = None, + device_pixel_ratio: Any = ..., + ): + """Execute browsingContext.setViewport. + + Uses sentinel defaults so explicit None is serialized for viewport/devicePixelRatio, + while omitted arguments are not sent. + """ + params = {} + if context is not None: + params["context"] = context + if user_contexts is not None: + params["userContexts"] = user_contexts + if viewport is not ...: + params["viewport"] = viewport + if device_pixel_ratio is not ...: + params["devicePixelRatio"] = device_pixel_ratio + + cmd = command_builder("browsingContext.setViewport", params) + result = self._conn.execute(cmd) + return result''', + ], # Non-CDDL download event dataclasses (Chromium-specific) "extra_dataclasses": [ '''@dataclass diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 59a9813e58124..177d727c97949 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class ReadinessState: """ReadinessState.""" @@ -366,14 +365,12 @@ class DownloadWillBeginParams: suggested_filename: str | None = None - @dataclass class DownloadCanceledParams: """DownloadCanceledParams.""" status: Any | None = None - @dataclass class DownloadParams: """DownloadParams - fields shared by all download end event variants.""" @@ -385,7 +382,6 @@ class DownloadParams: url: str | None = None filepath: str | None = None - @dataclass class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" @@ -405,7 +401,6 @@ def from_json(cls, params: dict) -> DownloadEndParams: ) return cls(download_params=dp) - # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "context_created": "browsingContext.contextCreated", @@ -424,7 +419,6 @@ def from_json(cls, params: dict) -> DownloadEndParams: "user_prompt_opened": "browsingContext.userPromptOpened", } - def _deserialize_info_list(items: list) -> list | None: """Recursively deserialize a list of dicts to Info objects. @@ -457,11 +451,12 @@ def _deserialize_info_list(items: list) -> list | None: return result if result else None + + class BrowsingContext: """WebDriver BiDi browsingContext module.""" EVENT_CONFIGS: dict[str, EventConfig] = {} - def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) @@ -562,7 +557,7 @@ def get_tree(self, max_depth: Any | None = None, root: Any | None = None): original_opener=item.get("originalOpener"), url=item.get("url"), user_context=item.get("userContext"), - parent=item.get("parent"), + parent=item.get("parent") ) for item in items if isinstance(item, dict) @@ -694,25 +689,6 @@ def set_bypass_csp( result = self._conn.execute(cmd) return result - def set_viewport( - self, - context: str | None = None, - viewport: Any | None = None, - user_contexts: Any | None = None, - device_pixel_ratio: Any | None = None, - ): - """Execute browsingContext.setViewport.""" - params = { - "context": context, - "viewport": viewport, - "userContexts": user_contexts, - "devicePixelRatio": device_pixel_ratio, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browsingContext.setViewport", params) - result = self._conn.execute(cmd) - return result - def traverse_history(self, context: Any | None = None, delta: Any | None = None): """Execute browsingContext.traverseHistory.""" if context is None: @@ -729,6 +705,32 @@ def traverse_history(self, context: Any | None = None, delta: Any | None = None) result = self._conn.execute(cmd) return result + def set_viewport( + self, + context: str | None = None, + viewport: Any = ..., + user_contexts: Any | None = None, + device_pixel_ratio: Any = ..., + ): + """Execute browsingContext.setViewport. + + Uses sentinel defaults so explicit None is serialized for viewport/devicePixelRatio, + while omitted arguments are not sent. + """ + params = {} + if context is not None: + params["context"] = context + if user_contexts is not None: + params["userContexts"] = user_contexts + if viewport is not ...: + params["viewport"] = viewport + if device_pixel_ratio is not ...: + params["devicePixelRatio"] = device_pixel_ratio + + cmd = command_builder("browsingContext.setViewport", params) + result = self._conn.execute(cmd) + return result + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: """Add an event handler. @@ -755,49 +757,48 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() - # Event Info Type Aliases # Event: browsingContext.contextCreated -ContextCreated = globals().get("Info", dict) # Fallback to dict if type not defined +ContextCreated = globals().get('Info', dict) # Fallback to dict if type not defined # Event: browsingContext.contextDestroyed -ContextDestroyed = globals().get("Info", dict) # Fallback to dict if type not defined +ContextDestroyed = globals().get('Info', dict) # Fallback to dict if type not defined # Event: browsingContext.navigationStarted -NavigationStarted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +NavigationStarted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.fragmentNavigated -FragmentNavigated = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +FragmentNavigated = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.historyUpdated -HistoryUpdated = globals().get("HistoryUpdatedParameters", dict) # Fallback to dict if type not defined +HistoryUpdated = globals().get('HistoryUpdatedParameters', dict) # Fallback to dict if type not defined # Event: browsingContext.domContentLoaded -DomContentLoaded = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +DomContentLoaded = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.load -Load = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +Load = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.downloadWillBegin -DownloadWillBegin = globals().get("DownloadWillBeginParams", dict) # Fallback to dict if type not defined +DownloadWillBegin = globals().get('DownloadWillBeginParams', dict) # Fallback to dict if type not defined # Event: browsingContext.downloadEnd -DownloadEnd = globals().get("DownloadEndParams", dict) # Fallback to dict if type not defined +DownloadEnd = globals().get('DownloadEndParams', dict) # Fallback to dict if type not defined # Event: browsingContext.navigationAborted -NavigationAborted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +NavigationAborted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.navigationCommitted -NavigationCommitted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +NavigationCommitted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.navigationFailed -NavigationFailed = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +NavigationFailed = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.userPromptClosed -UserPromptClosed = globals().get("UserPromptClosedParameters", dict) # Fallback to dict if type not defined +UserPromptClosed = globals().get('UserPromptClosedParameters', dict) # Fallback to dict if type not defined # Event: browsingContext.userPromptOpened -UserPromptOpened = globals().get("UserPromptOpenedParameters", dict) # Fallback to dict if type not defined +UserPromptOpened = globals().get('UserPromptOpenedParameters', dict) # Fallback to dict if type not defined # Populate EVENT_CONFIGS with event configuration mappings From cddfecb833003c068a9aff289e88ca1b978c6248 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Thu, 9 Apr 2026 16:28:11 +0100 Subject: [PATCH 32/42] formatting sigh --- .../webdriver/common/bidi/browsing_context.py | 42 +++++++++++-------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 177d727c97949..5491e157a87c2 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class ReadinessState: """ReadinessState.""" @@ -365,12 +366,14 @@ class DownloadWillBeginParams: suggested_filename: str | None = None + @dataclass class DownloadCanceledParams: """DownloadCanceledParams.""" status: Any | None = None + @dataclass class DownloadParams: """DownloadParams - fields shared by all download end event variants.""" @@ -382,6 +385,7 @@ class DownloadParams: url: str | None = None filepath: str | None = None + @dataclass class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" @@ -401,6 +405,7 @@ def from_json(cls, params: dict) -> DownloadEndParams: ) return cls(download_params=dp) + # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "context_created": "browsingContext.contextCreated", @@ -419,6 +424,7 @@ def from_json(cls, params: dict) -> DownloadEndParams: "user_prompt_opened": "browsingContext.userPromptOpened", } + def _deserialize_info_list(items: list) -> list | None: """Recursively deserialize a list of dicts to Info objects. @@ -451,12 +457,11 @@ def _deserialize_info_list(items: list) -> list | None: return result if result else None - - class BrowsingContext: """WebDriver BiDi browsingContext module.""" EVENT_CONFIGS: dict[str, EventConfig] = {} + def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) @@ -557,7 +562,7 @@ def get_tree(self, max_depth: Any | None = None, root: Any | None = None): original_opener=item.get("originalOpener"), url=item.get("url"), user_context=item.get("userContext"), - parent=item.get("parent") + parent=item.get("parent"), ) for item in items if isinstance(item, dict) @@ -757,48 +762,49 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() + # Event Info Type Aliases # Event: browsingContext.contextCreated -ContextCreated = globals().get('Info', dict) # Fallback to dict if type not defined +ContextCreated = globals().get("Info", dict) # Fallback to dict if type not defined # Event: browsingContext.contextDestroyed -ContextDestroyed = globals().get('Info', dict) # Fallback to dict if type not defined +ContextDestroyed = globals().get("Info", dict) # Fallback to dict if type not defined # Event: browsingContext.navigationStarted -NavigationStarted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +NavigationStarted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.fragmentNavigated -FragmentNavigated = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +FragmentNavigated = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.historyUpdated -HistoryUpdated = globals().get('HistoryUpdatedParameters', dict) # Fallback to dict if type not defined +HistoryUpdated = globals().get("HistoryUpdatedParameters", dict) # Fallback to dict if type not defined # Event: browsingContext.domContentLoaded -DomContentLoaded = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +DomContentLoaded = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.load -Load = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +Load = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.downloadWillBegin -DownloadWillBegin = globals().get('DownloadWillBeginParams', dict) # Fallback to dict if type not defined +DownloadWillBegin = globals().get("DownloadWillBeginParams", dict) # Fallback to dict if type not defined # Event: browsingContext.downloadEnd -DownloadEnd = globals().get('DownloadEndParams', dict) # Fallback to dict if type not defined +DownloadEnd = globals().get("DownloadEndParams", dict) # Fallback to dict if type not defined # Event: browsingContext.navigationAborted -NavigationAborted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +NavigationAborted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.navigationCommitted -NavigationCommitted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +NavigationCommitted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.navigationFailed -NavigationFailed = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +NavigationFailed = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.userPromptClosed -UserPromptClosed = globals().get('UserPromptClosedParameters', dict) # Fallback to dict if type not defined +UserPromptClosed = globals().get("UserPromptClosedParameters", dict) # Fallback to dict if type not defined # Event: browsingContext.userPromptOpened -UserPromptOpened = globals().get('UserPromptOpenedParameters', dict) # Fallback to dict if type not defined +UserPromptOpened = globals().get("UserPromptOpenedParameters", dict) # Fallback to dict if type not defined # Populate EVENT_CONFIGS with event configuration mappings From 3f3b8209bb0ab2f24ce63012cb2049da11399140 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Thu, 9 Apr 2026 16:47:58 +0100 Subject: [PATCH 33/42] more formatting because ruff format isn't enough --- py/BUILD.bazel | 1 - py/generate_bidi.py | 19 +++++++++- py/private/bidi_enhancements_manifest.py | 18 ++++++++++ py/private/cdp.py | 36 ++++++++----------- py/private/generate_bidi.bzl | 1 - py/selenium/webdriver/common/bidi/__init__.py | 19 ++++++++-- py/selenium/webdriver/common/bidi/browser.py | 20 ++++++++--- .../webdriver/common/bidi/browsing_context.py | 20 ++++++++--- .../webdriver/common/bidi/emulation.py | 20 ++++++++--- py/selenium/webdriver/common/bidi/input.py | 20 ++++++++--- py/selenium/webdriver/common/bidi/log.py | 20 ++++++++--- py/selenium/webdriver/common/bidi/network.py | 20 ++++++++--- py/selenium/webdriver/common/bidi/script.py | 20 ++++++++--- py/selenium/webdriver/common/bidi/session.py | 20 ++++++++--- py/selenium/webdriver/common/bidi/storage.py | 20 ++++++++--- .../webdriver/common/bidi/webextension.py | 20 ++++++++--- 16 files changed, 227 insertions(+), 67 deletions(-) diff --git a/py/BUILD.bazel b/py/BUILD.bazel index 292cde4981d74..186324560aade 100644 --- a/py/BUILD.bazel +++ b/py/BUILD.bazel @@ -810,7 +810,6 @@ BROWSER_TESTS = { ] ] - test_suite( name = "test-remote", tags = ["remote"], diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 194d94ba12d04..5b301d3ec7e40 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -1,4 +1,21 @@ -#!/usr/bin/env python3.10 +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + """ Generate Python WebDriver BiDi command modules from CDDL specification. diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 57e3d4e35f0dc..4b25688ed47c4 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -1,3 +1,21 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + """ Enhancement manifest for BiDi code generation. diff --git a/py/private/cdp.py b/py/private/cdp.py index bac00765f43ca..d94f0dac2e32b 100644 --- a/py/private/cdp.py +++ b/py/private/cdp.py @@ -1,26 +1,20 @@ -# The MIT License(MIT) +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Copyright(c) 2018 Hyperion Gray +# http://www.apache.org/licenses/LICENSE-2.0 # -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files(the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -# -# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import contextvars import importlib diff --git a/py/private/generate_bidi.bzl b/py/private/generate_bidi.bzl index e072279f85e94..8b4cc4e3e648f 100644 --- a/py/private/generate_bidi.bzl +++ b/py/private/generate_bidi.bzl @@ -72,7 +72,6 @@ def _generate_bidi_impl(ctx): return [DefaultInfo(files = depset(outputs))] - generate_bidi = rule( implementation = _generate_bidi_impl, attrs = { diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py index 79ba3dbf2f86f..b37319da3651b 100644 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ b/py/selenium/webdriver/common/bidi/__init__.py @@ -1,7 +1,20 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from __future__ import annotations diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 6310f2e18c2ce..440f13ed00072 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: browser +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from dataclasses import dataclass, field diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 5491e157a87c2..b5e14f19c6864 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: browsingContext +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from collections.abc import Callable diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 0860890abf41b..f1bc0c9efeb0a 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: emulation +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from dataclasses import dataclass, field diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 5d4c670490089..6c06fc4e7deaa 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: input +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from collections.abc import Callable diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 856d8561e706f..597936402f99c 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: log +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from collections.abc import Callable diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index e13fbe0f7a20b..6c24e399b0e54 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: network +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from collections.abc import Callable diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 38e43a6677470..ee6eb4f4a437a 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: script +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from collections.abc import Callable diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 741faeb42bc43..b00544d286546 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: session +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from dataclasses import dataclass, field diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 9825407c2eaf8..90e65ac3d5ffb 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: storage +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from dataclasses import dataclass, field diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 03fedab62e174..62f2dec130308 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: webExtension +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from dataclasses import dataclass, field From 86a9f45e9c9dcdb341f5825bb6b50de1a3e430b6 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Thu, 9 Apr 2026 20:23:08 +0100 Subject: [PATCH 34/42] correct signature --- py/private/bidi_enhancements_manifest.py | 32 ++++++++ .../webdriver/common/bidi/emulation.py | 73 ++++++++++--------- 2 files changed, 69 insertions(+), 36 deletions(-) diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 4b25688ed47c4..8cec1f9da245f 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -584,6 +584,38 @@ class SetNetworkConditionsParameters: if user_contexts is not None: params["userContexts"] = user_contexts cmd = command_builder("emulation.setNetworkConditions", params) + return self._conn.execute(cmd)''', + ''' def set_screen_settings_override( + self, + width: int | None = None, + height: int | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setScreenSettingsOverride. + + Sets or clears the screen settings override for specified browsing or user + contexts. + + Args: + width: The screen width in pixels, or ``None`` to clear the override. + height: The screen height in pixels, or ``None`` to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + screen_area = None + if width is not None or height is not None: + screen_area = {} + if width is not None: + screen_area["width"] = width + if height is not None: + screen_area["height"] = height + params: dict[str, Any] = {"screenArea": screen_area} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setScreenSettingsOverride", params) return self._conn.execute(cmd)''', ], }, diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index f1bc0c9efeb0a..c03d602f25670 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -1,21 +1,9 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - +# WebDriver BiDi module: emulation from __future__ import annotations from dataclasses import dataclass, field @@ -237,26 +225,6 @@ def set_locale_override( result = self._conn.execute(cmd) return result - def set_screen_settings_override( - self, - screen_area: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setScreenSettingsOverride.""" - if screen_area is None: - raise TypeError("set_screen_settings_override() missing required argument: 'screen_area'") - - params = { - "screenArea": screen_area, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setScreenSettingsOverride", params) - result = self._conn.execute(cmd) - return result - def set_scrollbar_type_override( self, scrollbar_type: Any | None = None, @@ -485,3 +453,36 @@ def set_network_conditions( params["userContexts"] = user_contexts cmd = command_builder("emulation.setNetworkConditions", params) return self._conn.execute(cmd) + + def set_screen_settings_override( + self, + width: int | None = None, + height: int | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setScreenSettingsOverride. + + Sets or clears the screen settings override for specified browsing or user + contexts. + + Args: + width: The screen width in pixels, or ``None`` to clear the override. + height: The screen height in pixels, or ``None`` to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + screen_area = None + if width is not None or height is not None: + screen_area = {} + if width is not None: + screen_area["width"] = width + if height is not None: + screen_area["height"] = height + params: dict[str, Any] = {"screenArea": screen_area} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setScreenSettingsOverride", params) + return self._conn.execute(cmd) From 8f954b30eb5e0a3c143d3f8cc61ea9bc67fbd502 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Thu, 9 Apr 2026 20:39:04 +0100 Subject: [PATCH 35/42] more formatting because ruff and ./go format do different things and hate people writing code --- .../webdriver/common/bidi/emulation.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index c03d602f25670..a3e6b4b6c4ddb 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: emulation +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from dataclasses import dataclass, field From c9fe8a4bc47f35e93be6e5b55d7854fa9c05182a Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Fri, 27 Mar 2026 13:57:23 +0000 Subject: [PATCH 36/42] [py] Use generated Bidi files instead of hand curated ones --- py/BUILD.bazel | 9 +- py/private/cdp.py | 94 ++++++--- py/selenium/webdriver/common/bidi/__init__.py | 43 ----- .../webdriver/common/bidi/_event_manager.py | 180 ------------------ py/selenium/webdriver/common/bidi/cdp.py | 58 ++++-- py/selenium/webdriver/common/bidi/common.py | 43 ----- py/selenium/webdriver/common/bidi/console.py | 24 --- .../webdriver/common/bidi/permissions.py | 103 ---------- py/selenium/webdriver/common/bidi/py.typed | 0 scripts/update_copyright.py | 1 - 10 files changed, 117 insertions(+), 438 deletions(-) delete mode 100644 py/selenium/webdriver/common/bidi/__init__.py delete mode 100644 py/selenium/webdriver/common/bidi/_event_manager.py delete mode 100644 py/selenium/webdriver/common/bidi/common.py delete mode 100644 py/selenium/webdriver/common/bidi/console.py delete mode 100644 py/selenium/webdriver/common/bidi/permissions.py delete mode 100644 py/selenium/webdriver/common/bidi/py.typed diff --git a/py/BUILD.bazel b/py/BUILD.bazel index 186324560aade..cfd3da8ad4e78 100644 --- a/py/BUILD.bazel +++ b/py/BUILD.bazel @@ -262,8 +262,11 @@ py_library( # BiDi protocol support py_library( name = "bidi", - srcs = glob(["selenium/webdriver/common/bidi/**/*.py"]), - data = [":mutation-listener"], + srcs = [], + data = [ + ":create-bidi-src", + ":mutation-listener", + ], imports = ["."], visibility = ["//visibility:public"], deps = [ @@ -617,7 +620,7 @@ generate_devtools_latest( browser_versions = BROWSER_VERSIONS, ) -# Pilot BiDi code generation from CDDL specification +# Generate BiDi source files from CDDL specification generate_bidi( name = "create-bidi-src", cddl_file = "//common/bidi/spec:all.cddl", diff --git a/py/private/cdp.py b/py/private/cdp.py index d94f0dac2e32b..9ca951479f657 100644 --- a/py/private/cdp.py +++ b/py/private/cdp.py @@ -1,20 +1,26 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# The MIT License(MIT) # -# http://www.apache.org/licenses/LICENSE-2.0 +# Copyright(c) 2018 Hyperion Gray # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files(the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp import contextvars import importlib @@ -54,7 +60,11 @@ def import_devtools(ver): # because cdp has been updated but selenium python has not been released yet. devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) - available_versions = tuple(x for x in versions if x == "latest" or (x.startswith("v") and x[1:].isdigit())) + available_versions = tuple( + x + for x in versions + if x == "latest" or (x.startswith("v") and x[1:].isdigit()) + ) numeric_versions = tuple(x[1:] for x in available_versions if x.startswith("v")) if not numeric_versions: raise @@ -65,7 +75,9 @@ def import_devtools(ver): return devtools -_connection_context: contextvars.ContextVar = contextvars.ContextVar("connection_context") +_connection_context: contextvars.ContextVar = contextvars.ContextVar( + "connection_context" +) _session_context: contextvars.ContextVar = contextvars.ContextVar("session_context") @@ -120,7 +132,9 @@ def set_global_connection(connection): certain use cases such as running inside Jupyter notebook. """ global _connection_context - _connection_context = contextvars.ContextVar("_connection_context", default=connection) + _connection_context = contextvars.ContextVar( + "_connection_context", default=connection + ) def set_global_session(session): @@ -217,7 +231,9 @@ async def execute(self, cmd: Generator[dict, T, Any]) -> T: logger.debug(f"Received CDP message: {response}") if isinstance(response, Exception): if logger.isEnabledFor(logging.DEBUG): - logger.debug(f"Exception raised by {cmd_event} message: {type(response).__name__}") + logger.debug( + f"Exception raised by {cmd_event} message: {type(response).__name__}" + ) raise response return response @@ -233,7 +249,9 @@ def listen(self, *event_types, buffer_size=10): return receiver @asynccontextmanager - async def wait_for(self, event_type: type[T], buffer_size=10) -> AsyncGenerator[CmEventProxy, None]: + async def wait_for( + self, event_type: type[T], buffer_size=10 + ) -> AsyncGenerator[CmEventProxy, None]: """Wait for an event of the given type and return it. This is an async context manager, so you should open it inside @@ -274,7 +292,9 @@ def _handle_cmd_response(self, data: dict): try: cmd, event = self.inflight_cmd.pop(cmd_id) except KeyError: - logger.warning("Got a message with a command ID that does not exist: %s", data) + logger.warning( + "Got a message with a command ID that does not exist: %s", data + ) return if "error" in data: # If the server reported an error, convert it to an exception and do @@ -285,7 +305,9 @@ def _handle_cmd_response(self, data: dict): # into a CDP object. try: _ = cmd.send(data["result"]) - raise InternalError("The command's generator function did not exit when expected!") + raise InternalError( + "The command's generator function did not exit when expected!" + ) except StopIteration as exit: return_ = exit.value self.inflight_result[cmd_id] = return_ @@ -299,7 +321,9 @@ def _handle_event(self, data: dict): """ global devtools if devtools is None: - raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + raise RuntimeError( + "CDP devtools module not loaded. Call import_devtools() first." + ) event = devtools.util.parse_json_event(data) logger.debug("Received event: %s", event) to_remove = set() @@ -307,7 +331,9 @@ def _handle_event(self, data: dict): try: sender.send_nowait(event) except trio.WouldBlock: - logger.error('Unable to send event "%r" due to full channel %s', event, sender) + logger.error( + 'Unable to send event "%r" due to full channel %s', event, sender + ) except trio.BrokenResourceError: to_remove.add(sender) if to_remove: @@ -425,8 +451,12 @@ async def connect_session(self, target_id) -> "CdpSession": """Returns a new :class:`CdpSession` connected to the specified target.""" global devtools if devtools is None: - raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") - session_id = await self.execute(devtools.target.attach_to_target(target_id, True)) + raise RuntimeError( + "CDP devtools module not loaded. Call import_devtools() first." + ) + session_id = await self.execute( + devtools.target.attach_to_target(target_id, True) + ) session = CdpSession(self.ws, session_id, target_id) self.sessions[session_id] = session return session @@ -438,7 +468,9 @@ async def _reader_task(self): """ global devtools if devtools is None: - raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + raise RuntimeError( + "CDP devtools module not loaded. Call import_devtools() first." + ) while True: try: message = await self.ws.get_message() @@ -451,7 +483,13 @@ async def _reader_task(self): try: data = json.loads(message) except json.JSONDecodeError: - raise BrowserError({"code": -32700, "message": "Client received invalid JSON", "data": message}) + raise BrowserError( + { + "code": -32700, + "message": "Client received invalid JSON", + "data": message, + } + ) logger.debug("Received message %r", data) if "sessionId" in data: session_id = devtools.target.SessionID(data["sessionId"]) diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py deleted file mode 100644 index b37319da3651b..0000000000000 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -from __future__ import annotations - -from selenium.webdriver.common.bidi.browser import Browser -from selenium.webdriver.common.bidi.browsing_context import BrowsingContext -from selenium.webdriver.common.bidi.emulation import Emulation -from selenium.webdriver.common.bidi.input import Input -from selenium.webdriver.common.bidi.log import Log -from selenium.webdriver.common.bidi.network import Network -from selenium.webdriver.common.bidi.script import Script -from selenium.webdriver.common.bidi.session import Session -from selenium.webdriver.common.bidi.storage import Storage -from selenium.webdriver.common.bidi.webextension import WebExtension - -__all__ = [ - "Browser", - "BrowsingContext", - "Emulation", - "Input", - "Log", - "Network", - "Script", - "Session", - "Storage", - "WebExtension", -] diff --git a/py/selenium/webdriver/common/bidi/_event_manager.py b/py/selenium/webdriver/common/bidi/_event_manager.py deleted file mode 100644 index 1dcc8288ce683..0000000000000 --- a/py/selenium/webdriver/common/bidi/_event_manager.py +++ /dev/null @@ -1,180 +0,0 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Shared event management helpers for generated WebDriver BiDi modules. - -``EventConfig``, ``_EventWrapper``, and ``_EventManager`` are emitted -identically into every generated module that exposes events. Rather than -duplicating this logic across those modules, they are defined once here and -copied into generated outputs by Bazel. -""" - -from __future__ import annotations - -import threading -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any - -from selenium.webdriver.common.bidi.session import Session - - -@dataclass -class EventConfig: - """Configuration for a BiDi event.""" - - event_key: str - bidi_event: str - event_class: type - - -class _EventWrapper: - """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" - - def __init__(self, bidi_event: str, event_class: type): - self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class - self._python_class = event_class # Keep reference to Python dataclass for deserialization - - def from_json(self, params: dict) -> Any: - """Deserialize event params into the wrapped Python dataclass. - - Args: - params: Raw BiDi event params with camelCase keys. - - Returns: - An instance of the dataclass, or the raw dict on failure. - """ - if self._python_class is None or self._python_class is dict: - return params - try: - # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, "from_json") and callable(self._python_class.from_json): - return self._python_class.from_json(params) - import dataclasses as dc - - snake_params = {self._camel_to_snake(k): v for k, v in params.items()} - if dc.is_dataclass(self._python_class): - valid_fields = {f.name for f in dc.fields(self._python_class)} - filtered = {k: v for k, v in snake_params.items() if k in valid_fields} - return self._python_class(**filtered) - return self._python_class(**snake_params) - except Exception: - return params - - @staticmethod - def _camel_to_snake(name: str) -> str: - result = [name[0].lower()] - for char in name[1:]: - if char.isupper(): - result.extend(["_", char.lower()]) - else: - result.append(char) - return "".join(result) - - -class _EventManager: - """Manages event subscriptions and callbacks.""" - - def __init__(self, conn, event_configs: dict[str, EventConfig]): - self.conn = conn - self.event_configs = event_configs - self.subscriptions: dict = {} - self._event_wrappers = {} # Cache of _EventWrapper objects - self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} - self._available_events = ", ".join(sorted(event_configs.keys())) - self._subscription_lock = threading.Lock() - - # Create event wrappers for each event - for config in event_configs.values(): - wrapper = _EventWrapper(config.bidi_event, config.event_class) - self._event_wrappers[config.bidi_event] = wrapper - - def validate_event(self, event: str) -> EventConfig: - event_config = self.event_configs.get(event) - if not event_config: - raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") - return event_config - - def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - """Subscribe to a BiDi event if not already subscribed.""" - with self._subscription_lock: - if bidi_event not in self.subscriptions: - session = Session(self.conn) - result = session.subscribe([bidi_event], contexts=contexts) - sub_id = result.get("subscription") if isinstance(result, dict) else None - self.subscriptions[bidi_event] = { - "callbacks": [], - "subscription_id": sub_id, - } - - def unsubscribe_from_event(self, bidi_event: str) -> None: - """Unsubscribe from a BiDi event if no more callbacks exist.""" - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry is not None and not entry["callbacks"]: - session = Session(self.conn) - sub_id = entry.get("subscription_id") - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - del self.subscriptions[bidi_event] - - def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - self.subscriptions[bidi_event]["callbacks"].append(callback_id) - - def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry and callback_id in entry["callbacks"]: - entry["callbacks"].remove(callback_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - event_config = self.validate_event(event) - # Use the event wrapper for add_callback - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - callback_id = self.conn.add_callback(event_wrapper, callback) - self.subscribe_to_event(event_config.bidi_event, contexts) - self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id - - def remove_event_handler(self, event: str, callback_id: int) -> None: - event_config = self.validate_event(event) - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - self.conn.remove_callback(event_wrapper, callback_id) - self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - self.unsubscribe_from_event(event_config.bidi_event) - - def clear_event_handlers(self) -> None: - """Clear all event handlers.""" - with self._subscription_lock: - if not self.subscriptions: - return - session = Session(self.conn) - for bidi_event, entry in list(self.subscriptions.items()): - event_wrapper = self._event_wrappers.get(bidi_event) - callbacks = entry["callbacks"] if isinstance(entry, dict) else entry - if event_wrapper: - for callback_id in callbacks: - self.conn.remove_callback(event_wrapper, callback_id) - sub_id = entry.get("subscription_id") if isinstance(entry, dict) else None - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() diff --git a/py/selenium/webdriver/common/bidi/cdp.py b/py/selenium/webdriver/common/bidi/cdp.py index bac00765f43ca..9ca951479f657 100644 --- a/py/selenium/webdriver/common/bidi/cdp.py +++ b/py/selenium/webdriver/common/bidi/cdp.py @@ -60,7 +60,11 @@ def import_devtools(ver): # because cdp has been updated but selenium python has not been released yet. devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) - available_versions = tuple(x for x in versions if x == "latest" or (x.startswith("v") and x[1:].isdigit())) + available_versions = tuple( + x + for x in versions + if x == "latest" or (x.startswith("v") and x[1:].isdigit()) + ) numeric_versions = tuple(x[1:] for x in available_versions if x.startswith("v")) if not numeric_versions: raise @@ -71,7 +75,9 @@ def import_devtools(ver): return devtools -_connection_context: contextvars.ContextVar = contextvars.ContextVar("connection_context") +_connection_context: contextvars.ContextVar = contextvars.ContextVar( + "connection_context" +) _session_context: contextvars.ContextVar = contextvars.ContextVar("session_context") @@ -126,7 +132,9 @@ def set_global_connection(connection): certain use cases such as running inside Jupyter notebook. """ global _connection_context - _connection_context = contextvars.ContextVar("_connection_context", default=connection) + _connection_context = contextvars.ContextVar( + "_connection_context", default=connection + ) def set_global_session(session): @@ -223,7 +231,9 @@ async def execute(self, cmd: Generator[dict, T, Any]) -> T: logger.debug(f"Received CDP message: {response}") if isinstance(response, Exception): if logger.isEnabledFor(logging.DEBUG): - logger.debug(f"Exception raised by {cmd_event} message: {type(response).__name__}") + logger.debug( + f"Exception raised by {cmd_event} message: {type(response).__name__}" + ) raise response return response @@ -239,7 +249,9 @@ def listen(self, *event_types, buffer_size=10): return receiver @asynccontextmanager - async def wait_for(self, event_type: type[T], buffer_size=10) -> AsyncGenerator[CmEventProxy, None]: + async def wait_for( + self, event_type: type[T], buffer_size=10 + ) -> AsyncGenerator[CmEventProxy, None]: """Wait for an event of the given type and return it. This is an async context manager, so you should open it inside @@ -280,7 +292,9 @@ def _handle_cmd_response(self, data: dict): try: cmd, event = self.inflight_cmd.pop(cmd_id) except KeyError: - logger.warning("Got a message with a command ID that does not exist: %s", data) + logger.warning( + "Got a message with a command ID that does not exist: %s", data + ) return if "error" in data: # If the server reported an error, convert it to an exception and do @@ -291,7 +305,9 @@ def _handle_cmd_response(self, data: dict): # into a CDP object. try: _ = cmd.send(data["result"]) - raise InternalError("The command's generator function did not exit when expected!") + raise InternalError( + "The command's generator function did not exit when expected!" + ) except StopIteration as exit: return_ = exit.value self.inflight_result[cmd_id] = return_ @@ -305,7 +321,9 @@ def _handle_event(self, data: dict): """ global devtools if devtools is None: - raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + raise RuntimeError( + "CDP devtools module not loaded. Call import_devtools() first." + ) event = devtools.util.parse_json_event(data) logger.debug("Received event: %s", event) to_remove = set() @@ -313,7 +331,9 @@ def _handle_event(self, data: dict): try: sender.send_nowait(event) except trio.WouldBlock: - logger.error('Unable to send event "%r" due to full channel %s', event, sender) + logger.error( + 'Unable to send event "%r" due to full channel %s', event, sender + ) except trio.BrokenResourceError: to_remove.add(sender) if to_remove: @@ -431,8 +451,12 @@ async def connect_session(self, target_id) -> "CdpSession": """Returns a new :class:`CdpSession` connected to the specified target.""" global devtools if devtools is None: - raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") - session_id = await self.execute(devtools.target.attach_to_target(target_id, True)) + raise RuntimeError( + "CDP devtools module not loaded. Call import_devtools() first." + ) + session_id = await self.execute( + devtools.target.attach_to_target(target_id, True) + ) session = CdpSession(self.ws, session_id, target_id) self.sessions[session_id] = session return session @@ -444,7 +468,9 @@ async def _reader_task(self): """ global devtools if devtools is None: - raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + raise RuntimeError( + "CDP devtools module not loaded. Call import_devtools() first." + ) while True: try: message = await self.ws.get_message() @@ -457,7 +483,13 @@ async def _reader_task(self): try: data = json.loads(message) except json.JSONDecodeError: - raise BrowserError({"code": -32700, "message": "Client received invalid JSON", "data": message}) + raise BrowserError( + { + "code": -32700, + "message": "Client received invalid JSON", + "data": message, + } + ) logger.debug("Received message %r", data) if "sessionId" in data: session_id = devtools.target.SessionID(data["sessionId"]) diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py deleted file mode 100644 index ff67b56622c35..0000000000000 --- a/py/selenium/webdriver/common/bidi/common.py +++ /dev/null @@ -1,43 +0,0 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Common utilities for BiDi command construction.""" - -from __future__ import annotations - -from collections.abc import Generator -from typing import Any - - -def command_builder(method: str, params: dict[str, Any] | None = None) -> Generator[dict[str, Any], Any, Any]: - """Build a BiDi command generator. - - Args: - method: The BiDi method name (e.g., "session.status", "browser.close") - params: The parameters for the command. If omitted, an empty - dictionary is sent. - - Yields: - A dictionary representing the BiDi command - - Returns: - The result from the BiDi command execution - """ - if params is None: - params = {} - result = yield {"method": method, "params": params} - return result diff --git a/py/selenium/webdriver/common/bidi/console.py b/py/selenium/webdriver/common/bidi/console.py deleted file mode 100644 index 93fe7d80d4de0..0000000000000 --- a/py/selenium/webdriver/common/bidi/console.py +++ /dev/null @@ -1,24 +0,0 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from enum import Enum - - -class Console(Enum): - ALL = "all" - LOG = "log" - ERROR = "error" diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py deleted file mode 100644 index 98e25a1d2f856..0000000000000 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ /dev/null @@ -1,103 +0,0 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""WebDriver BiDi Permissions module.""" - -from __future__ import annotations - -from enum import Enum -from typing import Any - -from selenium.webdriver.common.bidi.common import command_builder - -_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"} - - -class PermissionState(str, Enum): - """Permission state enumeration.""" - - GRANTED = "granted" - DENIED = "denied" - PROMPT = "prompt" - - -class PermissionDescriptor: - """Descriptor for a permission.""" - - def __init__(self, name: str) -> None: - """Initialize a PermissionDescriptor. - - Args: - name: The name of the permission (e.g., 'geolocation', 'microphone', 'camera') - """ - self.name = name - - def __repr__(self) -> str: - return f"PermissionDescriptor('{self.name}')" - - -class Permissions: - """WebDriver BiDi Permissions module.""" - - def __init__(self, websocket_connection: Any) -> None: - """Initialize the Permissions module. - - Args: - websocket_connection: The WebSocket connection for sending BiDi commands - """ - self._conn = websocket_connection - - def set_permission( - self, - descriptor: PermissionDescriptor | str, - state: PermissionState | str, - origin: str | None = None, - user_context: str | None = None, - ) -> None: - """Set a permission for a given origin. - - Args: - descriptor: The permission descriptor or permission name as a string - state: The desired permission state - origin: The origin for which to set the permission - user_context: Optional user context ID to scope the permission - - Raises: - ValueError: If the state is not a valid permission state - """ - state_value = state.value if isinstance(state, PermissionState) else state - if state_value not in _VALID_PERMISSION_STATES: - raise ValueError( - f"Invalid permission state: {state_value!r}. Must be one of {sorted(_VALID_PERMISSION_STATES)}" - ) - - if isinstance(descriptor, str): - descriptor_dict = {"name": descriptor} - else: - descriptor_dict = {"name": descriptor.name} - - params: dict[str, Any] = { - "descriptor": descriptor_dict, - "state": state_value, - } - if origin is not None: - params["origin"] = origin - if user_context is not None: - params["userContext"] = user_context - - cmd = command_builder("permissions.setPermission", params) - self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/bidi/py.typed b/py/selenium/webdriver/common/bidi/py.typed deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/scripts/update_copyright.py b/scripts/update_copyright.py index 4829dae8e07c5..32d5814d11813 100755 --- a/scripts/update_copyright.py +++ b/scripts/update_copyright.py @@ -102,7 +102,6 @@ def write_update_notice(self, file, lines): ] PY_EXCLUSIONS = [ - f"{ROOT}/py/selenium/webdriver/common/bidi/cdp.py", f"{ROOT}/py/generate.py", f"{ROOT}/py/selenium/webdriver/common/devtools/**/*", f"{ROOT}/py/venv/**/*", From c7874eabd0f70f16471461dafc218d027ece364a Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Tue, 31 Mar 2026 19:00:45 +0100 Subject: [PATCH 37/42] making sure we don't commit files to that directory again --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index c520e81a96551..02dee0ffb80fb 100644 --- a/.gitignore +++ b/.gitignore @@ -67,6 +67,7 @@ __pycache__ .tox *.pyc dist/ +py/selenium/webdriver/common/bidi/ py/selenium/webdriver/common/devtools/**/* !py/selenium/webdriver/common/devtools/util.py py/selenium/webdriver/common/linux/ From cf5de9e6cef93037d21fb69314d3c0a9dd5c183c Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 15 Apr 2026 13:56:16 +0100 Subject: [PATCH 38/42] delete files again --- py/selenium/webdriver/common/bidi/browser.py | 363 ----- .../webdriver/common/bidi/browsing_context.py | 891 ------------ py/selenium/webdriver/common/bidi/cdp.py | 551 -------- .../webdriver/common/bidi/emulation.py | 500 ------- py/selenium/webdriver/common/bidi/input.py | 339 ----- py/selenium/webdriver/common/bidi/log.py | 167 --- py/selenium/webdriver/common/bidi/network.py | 925 ------------- py/selenium/webdriver/common/bidi/script.py | 1230 ----------------- py/selenium/webdriver/common/bidi/session.py | 260 ---- py/selenium/webdriver/common/bidi/storage.py | 353 ----- .../webdriver/common/bidi/webextension.py | 154 --- 11 files changed, 5733 deletions(-) delete mode 100644 py/selenium/webdriver/common/bidi/browser.py delete mode 100644 py/selenium/webdriver/common/bidi/browsing_context.py delete mode 100644 py/selenium/webdriver/common/bidi/cdp.py delete mode 100644 py/selenium/webdriver/common/bidi/emulation.py delete mode 100644 py/selenium/webdriver/common/bidi/input.py delete mode 100644 py/selenium/webdriver/common/bidi/log.py delete mode 100644 py/selenium/webdriver/common/bidi/network.py delete mode 100644 py/selenium/webdriver/common/bidi/script.py delete mode 100644 py/selenium/webdriver/common/bidi/session.py delete mode 100644 py/selenium/webdriver/common/bidi/storage.py delete mode 100644 py/selenium/webdriver/common/bidi/webextension.py diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py deleted file mode 100644 index 440f13ed00072..0000000000000 --- a/py/selenium/webdriver/common/bidi/browser.py +++ /dev/null @@ -1,363 +0,0 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi.common import command_builder - - -def transform_download_params( - allowed: bool | None, - destination_folder: str | None, -) -> dict[str, Any] | None: - """Transform download parameters into download_behavior object. - - Args: - allowed: Whether downloads are allowed - destination_folder: Destination folder for downloads (accepts str or - pathlib.Path; will be coerced to str) - - Returns: - Dictionary representing the download_behavior object, or None if allowed is None - """ - if allowed is True: - return { - "type": "allowed", - # Coerce pathlib.Path (or any path-like) to str so the BiDi - # protocol always receives a plain JSON string. - "destinationFolder": str(destination_folder) if destination_folder is not None else None, - } - elif allowed is False: - return {"type": "denied"} - else: # None — reset to browser default (sent as JSON null) - return None - - -def validate_download_behavior( - allowed: bool | None, - destination_folder: str | None, - user_contexts: Any | None = None, -) -> None: - """Validate download behavior parameters. - - Args: - allowed: Whether downloads are allowed - destination_folder: Destination folder for downloads - user_contexts: Optional list of user contexts - - Raises: - ValueError: If parameters are invalid - """ - if allowed is True and not destination_folder: - raise ValueError("destination_folder is required when allowed=True") - if allowed is False and destination_folder: - raise ValueError("destination_folder should not be provided when allowed=False") - - -@dataclass -class ClientWindowInfo: - """ClientWindowInfo.""" - - active: bool | None = None - client_window: Any | None = None - height: Any | None = None - state: Any | None = None - width: Any | None = None - x: Any | None = None - y: Any | None = None - - def get_client_window(self): - """Get the client window ID.""" - return self.client_window - - def get_state(self): - """Get the client window state.""" - return self.state - - def get_width(self): - """Get the client window width.""" - return self.width - - def get_height(self): - """Get the client window height.""" - return self.height - - def is_active(self): - """Check if the client window is active.""" - return self.active - - def get_x(self): - """Get the client window X position.""" - return self.x - - def get_y(self): - """Get the client window Y position.""" - return self.y - - -@dataclass -class UserContextInfo: - """UserContextInfo.""" - - user_context: Any | None = None - - -@dataclass -class CreateUserContextParameters: - """CreateUserContextParameters.""" - - accept_insecure_certs: bool | None = None - proxy: Any | None = None - unhandled_prompt_behavior: Any | None = None - - -@dataclass -class GetClientWindowsResult: - """GetClientWindowsResult.""" - - client_windows: list[Any] = field(default_factory=list) - - -@dataclass -class GetUserContextsResult: - """GetUserContextsResult.""" - - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class RemoveUserContextParameters: - """RemoveUserContextParameters.""" - - user_context: Any | None = None - - -@dataclass -class ClientWindowRectState: - """ClientWindowRectState.""" - - state: str = field(default="normal", init=False) - width: Any | None = None - height: Any | None = None - x: Any | None = None - y: Any | None = None - - -@dataclass -class SetDownloadBehaviorParameters: - """SetDownloadBehaviorParameters.""" - - download_behavior: Any | None = None - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class DownloadBehaviorAllowed: - """DownloadBehaviorAllowed.""" - - type: str = field(default="allowed", init=False) - destination_folder: str | None = None - - -@dataclass -class DownloadBehaviorDenied: - """DownloadBehaviorDenied.""" - - type: str = field(default="denied", init=False) - - -class ClientWindowNamedState: - """Named states for a browser client window.""" - - FULLSCREEN = "fullscreen" - MAXIMIZED = "maximized" - MINIMIZED = "minimized" - NORMAL = "normal" - - -@dataclass -class SetClientWindowStateParameters: - """SetClientWindowStateParameters. - - The ``state`` field is required and must be either a named-state string - (e.g. ``ClientWindowNamedState.MAXIMIZED``) or a - :class:`ClientWindowRectState` instance. ``client_window`` is the ID of - the window to affect. - """ - - client_window: Any | None = None - state: Any | None = None - - -class Browser: - """WebDriver BiDi browser module.""" - - def __init__(self, conn) -> None: - self._conn = conn - - def close(self): - """Execute browser.close.""" - params = {} - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browser.close", params) - result = self._conn.execute(cmd) - return result - - def create_user_context( - self, - accept_insecure_certs: bool | None = None, - proxy: Any | None = None, - unhandled_prompt_behavior: Any | None = None, - ): - """Execute browser.createUserContext.""" - if proxy and hasattr(proxy, "to_bidi_dict"): - proxy = proxy.to_bidi_dict() - - if unhandled_prompt_behavior and hasattr(unhandled_prompt_behavior, "to_bidi_dict"): - unhandled_prompt_behavior = unhandled_prompt_behavior.to_bidi_dict() - - params = { - "acceptInsecureCerts": accept_insecure_certs, - "proxy": proxy, - "unhandledPromptBehavior": unhandled_prompt_behavior, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browser.createUserContext", params) - result = self._conn.execute(cmd) - if result and "userContext" in result: - extracted = result.get("userContext") - return extracted - return result - - def get_client_windows(self): - """Execute browser.getClientWindows.""" - params = {} - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browser.getClientWindows", params) - result = self._conn.execute(cmd) - if result and "clientWindows" in result: - items = result.get("clientWindows", []) - return [ - ClientWindowInfo( - active=item.get("active"), - client_window=item.get("clientWindow"), - height=item.get("height"), - state=item.get("state"), - width=item.get("width"), - x=item.get("x"), - y=item.get("y"), - ) - for item in items - if isinstance(item, dict) - ] - return [] - - def get_user_contexts(self): - """Execute browser.getUserContexts.""" - params = {} - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browser.getUserContexts", params) - result = self._conn.execute(cmd) - if result and "userContexts" in result: - items = result.get("userContexts", []) - return [item.get("userContext") for item in items if isinstance(item, dict)] - return [] - - def remove_user_context(self, user_context: Any | None = None): - """Execute browser.removeUserContext.""" - if user_context is None: - raise TypeError("remove_user_context() missing required argument: 'user_context'") - - params = { - "userContext": user_context, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browser.removeUserContext", params) - result = self._conn.execute(cmd) - return result - - def set_download_behavior( - self, - allowed: bool | None = None, - destination_folder: str | None = None, - user_contexts: list[Any] | None = None, - ): - """Set the download behavior for the browser. - - Args: - allowed: ``True`` to allow downloads, ``False`` to deny, or ``None`` - to reset to browser default (sends ``null`` to the protocol). - destination_folder: Destination folder for downloads. Required when - ``allowed=True``. Accepts a string or :class:`pathlib.Path`. - user_contexts: Optional list of user context IDs. - - Raises: - ValueError: If *allowed* is ``True`` and *destination_folder* is - omitted, or ``False`` and *destination_folder* is provided. - """ - validate_download_behavior( - allowed=allowed, - destination_folder=destination_folder, - user_contexts=user_contexts, - ) - download_behavior = transform_download_params(allowed, destination_folder) - # downloadBehavior is a REQUIRED field in the BiDi spec (can be null but - # must be present). Do NOT use a generic None-filter on it. - params: dict = {"downloadBehavior": download_behavior} - if user_contexts is not None: - params["userContexts"] = user_contexts - cmd = command_builder("browser.setDownloadBehavior", params) - return self._conn.execute(cmd) - - def set_client_window_state( - self, - client_window: Any | None = None, - state: Any | None = None, - ): - """Set the client window state. - - Args: - client_window: The client window ID to apply the state to. - state: The window state to set. Can be one of: - - A string: "fullscreen", "maximized", "minimized", "normal" - - A ClientWindowRectState object with width, height, x, y - - A dict representing the state - - Raises: - ValueError: If client_window is not provided or state is invalid. - """ - if client_window is None: - raise ValueError("client_window is required") - if state is None: - raise ValueError("state is required") - - # Serialize ClientWindowRectState if needed - state_param = state - if hasattr(state, "__dataclass_fields__"): - # It's a dataclass, convert to dict - state_param = {k: v for k, v in state.__dict__.items() if v is not None} - - params = { - "clientWindow": client_window, - "state": state_param, - } - cmd = command_builder("browser.setClientWindowState", params) - return self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py deleted file mode 100644 index b5e14f19c6864..0000000000000 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ /dev/null @@ -1,891 +0,0 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -from __future__ import annotations - -from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager -from selenium.webdriver.common.bidi.common import command_builder - - -class ReadinessState: - """ReadinessState.""" - - NONE = "none" - INTERACTIVE = "interactive" - COMPLETE = "complete" - - -class UserPromptType: - """UserPromptType.""" - - ALERT = "alert" - BEFOREUNLOAD = "beforeunload" - CONFIRM = "confirm" - PROMPT = "prompt" - - -class CreateType: - """CreateType.""" - - TAB = "tab" - WINDOW = "window" - - -class DownloadCompleteParams: - """DownloadCompleteParams.""" - - COMPLETE = "complete" - - -@dataclass -class Info: - """Info.""" - - children: Any | None = None - client_window: Any | None = None - context: Any | None = None - original_opener: Any | None = None - url: str | None = None - user_context: Any | None = None - parent: Any | None = None - - -@dataclass -class AccessibilityLocator: - """AccessibilityLocator.""" - - type: str = field(default="accessibility", init=False) - name: str | None = None - role: str | None = None - - -@dataclass -class CssLocator: - """CssLocator.""" - - type: str = field(default="css", init=False) - value: str | None = None - - -@dataclass -class ContextLocator: - """ContextLocator.""" - - type: str = field(default="context", init=False) - context: Any | None = None - - -@dataclass -class InnerTextLocator: - """InnerTextLocator.""" - - type: str = field(default="innerText", init=False) - value: str | None = None - ignore_case: bool | None = None - match_type: Any | None = None - max_depth: Any | None = None - - -@dataclass -class XPathLocator: - """XPathLocator.""" - - type: str = field(default="xpath", init=False) - value: str | None = None - - -@dataclass -class BaseNavigationInfo: - """BaseNavigationInfo.""" - - context: Any | None = None - navigation: Any | None = None - timestamp: Any | None = None - url: str | None = None - user_context: Any | None = None - - -@dataclass -class ActivateParameters: - """ActivateParameters.""" - - context: Any | None = None - - -@dataclass -class CaptureScreenshotParameters: - """CaptureScreenshotParameters.""" - - context: Any | None = None - format: Any | None = None - clip: Any | None = None - - -@dataclass -class ImageFormat: - """ImageFormat.""" - - type: str | None = None - quality: Any | None = None - - -@dataclass -class ElementClipRectangle: - """ElementClipRectangle.""" - - type: str = field(default="element", init=False) - element: Any | None = None - - -@dataclass -class BoxClipRectangle: - """BoxClipRectangle.""" - - type: str = field(default="box", init=False) - x: Any | None = None - y: Any | None = None - width: Any | None = None - height: Any | None = None - - -@dataclass -class CaptureScreenshotResult: - """CaptureScreenshotResult.""" - - data: str | None = None - - -@dataclass -class CloseParameters: - """CloseParameters.""" - - context: Any | None = None - prompt_unload: bool | None = None - - -@dataclass -class CreateParameters: - """CreateParameters.""" - - type: Any | None = None - reference_context: Any | None = None - background: bool | None = None - user_context: Any | None = None - - -@dataclass -class CreateResult: - """CreateResult.""" - - context: Any | None = None - user_context: Any | None = None - - -@dataclass -class GetTreeParameters: - """GetTreeParameters.""" - - max_depth: Any | None = None - root: Any | None = None - - -@dataclass -class GetTreeResult: - """GetTreeResult.""" - - contexts: Any | None = None - - -@dataclass -class HandleUserPromptParameters: - """HandleUserPromptParameters.""" - - context: Any | None = None - accept: bool | None = None - user_text: str | None = None - - -@dataclass -class LocateNodesParameters: - """LocateNodesParameters.""" - - context: Any | None = None - locator: Any | None = None - serialization_options: Any | None = None - start_nodes: list[Any] = field(default_factory=list) - - -@dataclass -class LocateNodesResult: - """LocateNodesResult.""" - - nodes: list[Any] = field(default_factory=list) - - -@dataclass -class NavigateParameters: - """NavigateParameters.""" - - context: Any | None = None - url: str | None = None - wait: Any | None = None - - -@dataclass -class NavigateResult: - """NavigateResult.""" - - navigation: Any | None = None - url: str | None = None - - -@dataclass -class PrintParameters: - """PrintParameters.""" - - context: Any | None = None - background: bool | None = None - margin: Any | None = None - page: Any | None = None - scale: Any | None = None - shrink_to_fit: bool | None = None - - -@dataclass -class PrintMarginParameters: - """PrintMarginParameters.""" - - bottom: Any | None = None - left: Any | None = None - right: Any | None = None - top: Any | None = None - - -@dataclass -class PrintPageParameters: - """PrintPageParameters.""" - - height: Any | None = None - width: Any | None = None - - -@dataclass -class PrintResult: - """PrintResult.""" - - data: str | None = None - - -@dataclass -class ReloadParameters: - """ReloadParameters.""" - - context: Any | None = None - ignore_cache: bool | None = None - wait: Any | None = None - - -@dataclass -class SetBypassCSPParameters: - """SetBypassCSPParameters.""" - - bypass: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class SetViewportParameters: - """SetViewportParameters.""" - - context: Any | None = None - viewport: Any | None = None - device_pixel_ratio: Any | None = None - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class Viewport: - """Viewport.""" - - width: Any | None = None - height: Any | None = None - - -@dataclass -class TraverseHistoryParameters: - """TraverseHistoryParameters.""" - - context: Any | None = None - delta: Any | None = None - - -@dataclass -class HistoryUpdatedParameters: - """HistoryUpdatedParameters.""" - - context: Any | None = None - timestamp: Any | None = None - url: str | None = None - user_context: Any | None = None - - -@dataclass -class UserPromptClosedParameters: - """UserPromptClosedParameters.""" - - context: Any | None = None - accepted: bool | None = None - type: Any | None = None - user_context: Any | None = None - user_text: str | None = None - - -@dataclass -class UserPromptOpenedParameters: - """UserPromptOpenedParameters.""" - - context: Any | None = None - handler: Any | None = None - message: str | None = None - type: Any | None = None - user_context: Any | None = None - default_value: str | None = None - - -@dataclass -class DownloadWillBeginParams: - """DownloadWillBeginParams.""" - - suggested_filename: str | None = None - - -@dataclass -class DownloadCanceledParams: - """DownloadCanceledParams.""" - - status: Any | None = None - - -@dataclass -class DownloadParams: - """DownloadParams - fields shared by all download end event variants.""" - - status: str | None = None - context: Any | None = None - navigation: Any | None = None - timestamp: Any | None = None - url: str | None = None - filepath: str | None = None - - -@dataclass -class DownloadEndParams: - """DownloadEndParams - params for browsingContext.downloadEnd event.""" - - download_params: DownloadParams | None = None - - @classmethod - def from_json(cls, params: dict) -> DownloadEndParams: - """Deserialize from BiDi wire-level params dict.""" - dp = DownloadParams( - status=params.get("status"), - context=params.get("context"), - navigation=params.get("navigation"), - timestamp=params.get("timestamp"), - url=params.get("url"), - filepath=params.get("filepath"), - ) - return cls(download_params=dp) - - -# BiDi Event Name to Parameter Type Mapping -EVENT_NAME_MAPPING = { - "context_created": "browsingContext.contextCreated", - "context_destroyed": "browsingContext.contextDestroyed", - "navigation_started": "browsingContext.navigationStarted", - "fragment_navigated": "browsingContext.fragmentNavigated", - "history_updated": "browsingContext.historyUpdated", - "dom_content_loaded": "browsingContext.domContentLoaded", - "load": "browsingContext.load", - "download_will_begin": "browsingContext.downloadWillBegin", - "download_end": "browsingContext.downloadEnd", - "navigation_aborted": "browsingContext.navigationAborted", - "navigation_committed": "browsingContext.navigationCommitted", - "navigation_failed": "browsingContext.navigationFailed", - "user_prompt_closed": "browsingContext.userPromptClosed", - "user_prompt_opened": "browsingContext.userPromptOpened", -} - - -def _deserialize_info_list(items: list) -> list | None: - """Recursively deserialize a list of dicts to Info objects. - - Args: - items: List of dicts from the API response - - Returns: - List of Info objects with properly nested children, or None if empty - """ - if not items or not isinstance(items, list): - return None - - result = [] - for item in items: - if isinstance(item, dict): - # Recursively deserialize children only if the key exists in response - children_list = None - if "children" in item: - children_list = _deserialize_info_list(item.get("children", [])) - info = Info( - children=children_list, - client_window=item.get("clientWindow"), - context=item.get("context"), - original_opener=item.get("originalOpener"), - url=item.get("url"), - user_context=item.get("userContext"), - parent=item.get("parent"), - ) - result.append(info) - return result if result else None - - -class BrowsingContext: - """WebDriver BiDi browsingContext module.""" - - EVENT_CONFIGS: dict[str, EventConfig] = {} - - def __init__(self, conn) -> None: - self._conn = conn - self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - - def activate(self, context: Any | None = None): - """Execute browsingContext.activate.""" - if context is None: - raise TypeError("activate() missing required argument: 'context'") - - params = { - "context": context, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browsingContext.activate", params) - result = self._conn.execute(cmd) - return result - - def capture_screenshot( - self, - context: str | None = None, - format: Any | None = None, - clip: Any | None = None, - origin: str | None = None, - ): - """Execute browsingContext.captureScreenshot.""" - if context is None: - raise TypeError("capture_screenshot() missing required argument: 'context'") - - params = { - "context": context, - "format": format, - "clip": clip, - "origin": origin, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browsingContext.captureScreenshot", params) - result = self._conn.execute(cmd) - if result and "data" in result: - extracted = result.get("data") - return extracted - return result - - def close(self, context: Any | None = None, prompt_unload: bool | None = None): - """Execute browsingContext.close.""" - if context is None: - raise TypeError("close() missing required argument: 'context'") - - params = { - "context": context, - "promptUnload": prompt_unload, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browsingContext.close", params) - result = self._conn.execute(cmd) - return result - - def create( - self, - type: Any | None = None, - reference_context: Any | None = None, - background: bool | None = None, - user_context: Any | None = None, - ): - """Execute browsingContext.create.""" - if type is None: - raise TypeError("create() missing required argument: 'type'") - - params = { - "type": type, - "referenceContext": reference_context, - "background": background, - "userContext": user_context, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browsingContext.create", params) - result = self._conn.execute(cmd) - if result and "context" in result: - extracted = result.get("context") - return extracted - return result - - def get_tree(self, max_depth: Any | None = None, root: Any | None = None): - """Execute browsingContext.getTree.""" - params = { - "maxDepth": max_depth, - "root": root, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browsingContext.getTree", params) - result = self._conn.execute(cmd) - if result and "contexts" in result: - items = result.get("contexts", []) - return [ - Info( - children=_deserialize_info_list(item.get("children", [])), - client_window=item.get("clientWindow"), - context=item.get("context"), - original_opener=item.get("originalOpener"), - url=item.get("url"), - user_context=item.get("userContext"), - parent=item.get("parent"), - ) - for item in items - if isinstance(item, dict) - ] - return [] - - def handle_user_prompt(self, context: Any | None = None, accept: bool | None = None, user_text: Any | None = None): - """Execute browsingContext.handleUserPrompt.""" - if context is None: - raise TypeError("handle_user_prompt() missing required argument: 'context'") - - params = { - "context": context, - "accept": accept, - "userText": user_text, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browsingContext.handleUserPrompt", params) - result = self._conn.execute(cmd) - return result - - def locate_nodes( - self, - context: str | None = None, - locator: Any | None = None, - serialization_options: Any | None = None, - start_nodes: Any | None = None, - max_node_count: int | None = None, - ): - """Execute browsingContext.locateNodes.""" - if context is None: - raise TypeError("locate_nodes() missing required argument: 'context'") - if locator is None: - raise TypeError("locate_nodes() missing required argument: 'locator'") - - params = { - "context": context, - "locator": locator, - "serializationOptions": serialization_options, - "startNodes": start_nodes, - "maxNodeCount": max_node_count, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browsingContext.locateNodes", params) - result = self._conn.execute(cmd) - if result and "nodes" in result: - extracted = result.get("nodes") - return extracted - return result - - def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any | None = None): - """Execute browsingContext.navigate.""" - if context is None: - raise TypeError("navigate() missing required argument: 'context'") - if url is None: - raise TypeError("navigate() missing required argument: 'url'") - - params = { - "context": context, - "url": url, - "wait": wait, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browsingContext.navigate", params) - result = self._conn.execute(cmd) - return result - - def print( - self, - context: Any | None = None, - background: bool | None = None, - margin: Any | None = None, - page: Any | None = None, - scale: Any | None = None, - shrink_to_fit: bool | None = None, - ): - """Execute browsingContext.print.""" - if context is None: - raise TypeError("print() missing required argument: 'context'") - - params = { - "context": context, - "background": background, - "margin": margin, - "page": page, - "scale": scale, - "shrinkToFit": shrink_to_fit, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browsingContext.print", params) - result = self._conn.execute(cmd) - if result and "data" in result: - extracted = result.get("data") - return extracted - return result - - def reload(self, context: Any | None = None, ignore_cache: bool | None = None, wait: Any | None = None): - """Execute browsingContext.reload.""" - if context is None: - raise TypeError("reload() missing required argument: 'context'") - - params = { - "context": context, - "ignoreCache": ignore_cache, - "wait": wait, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browsingContext.reload", params) - result = self._conn.execute(cmd) - return result - - def set_bypass_csp( - self, - bypass: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute browsingContext.setBypassCSP.""" - if bypass is None: - raise TypeError("set_bypass_csp() missing required argument: 'bypass'") - - params = { - "bypass": bypass, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browsingContext.setBypassCSP", params) - result = self._conn.execute(cmd) - return result - - def traverse_history(self, context: Any | None = None, delta: Any | None = None): - """Execute browsingContext.traverseHistory.""" - if context is None: - raise TypeError("traverse_history() missing required argument: 'context'") - if delta is None: - raise TypeError("traverse_history() missing required argument: 'delta'") - - params = { - "context": context, - "delta": delta, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browsingContext.traverseHistory", params) - result = self._conn.execute(cmd) - return result - - def set_viewport( - self, - context: str | None = None, - viewport: Any = ..., - user_contexts: Any | None = None, - device_pixel_ratio: Any = ..., - ): - """Execute browsingContext.setViewport. - - Uses sentinel defaults so explicit None is serialized for viewport/devicePixelRatio, - while omitted arguments are not sent. - """ - params = {} - if context is not None: - params["context"] = context - if user_contexts is not None: - params["userContexts"] = user_contexts - if viewport is not ...: - params["viewport"] = viewport - if device_pixel_ratio is not ...: - params["devicePixelRatio"] = device_pixel_ratio - - cmd = command_builder("browsingContext.setViewport", params) - result = self._conn.execute(cmd) - return result - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - """Add an event handler. - - Args: - event: The event to subscribe to. - callback: The callback function to execute on event. - contexts: The context IDs to subscribe to (optional). - - Returns: - The callback ID. - """ - return self._event_manager.add_event_handler(event, callback, contexts) - - def remove_event_handler(self, event: str, callback_id: int) -> None: - """Remove an event handler. - - Args: - event: The event to unsubscribe from. - callback_id: The callback ID. - """ - return self._event_manager.remove_event_handler(event, callback_id) - - def clear_event_handlers(self) -> None: - """Clear all event handlers.""" - return self._event_manager.clear_event_handlers() - - -# Event Info Type Aliases -# Event: browsingContext.contextCreated -ContextCreated = globals().get("Info", dict) # Fallback to dict if type not defined - -# Event: browsingContext.contextDestroyed -ContextDestroyed = globals().get("Info", dict) # Fallback to dict if type not defined - -# Event: browsingContext.navigationStarted -NavigationStarted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined - -# Event: browsingContext.fragmentNavigated -FragmentNavigated = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined - -# Event: browsingContext.historyUpdated -HistoryUpdated = globals().get("HistoryUpdatedParameters", dict) # Fallback to dict if type not defined - -# Event: browsingContext.domContentLoaded -DomContentLoaded = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined - -# Event: browsingContext.load -Load = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined - -# Event: browsingContext.downloadWillBegin -DownloadWillBegin = globals().get("DownloadWillBeginParams", dict) # Fallback to dict if type not defined - -# Event: browsingContext.downloadEnd -DownloadEnd = globals().get("DownloadEndParams", dict) # Fallback to dict if type not defined - -# Event: browsingContext.navigationAborted -NavigationAborted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined - -# Event: browsingContext.navigationCommitted -NavigationCommitted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined - -# Event: browsingContext.navigationFailed -NavigationFailed = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined - -# Event: browsingContext.userPromptClosed -UserPromptClosed = globals().get("UserPromptClosedParameters", dict) # Fallback to dict if type not defined - -# Event: browsingContext.userPromptOpened -UserPromptOpened = globals().get("UserPromptOpenedParameters", dict) # Fallback to dict if type not defined - - -# Populate EVENT_CONFIGS with event configuration mappings -_globals = globals() -BrowsingContext.EVENT_CONFIGS = { - "context_created": EventConfig( - "context_created", - "browsingContext.contextCreated", - _globals.get("ContextCreated", dict) if _globals.get("ContextCreated") else dict, - ), - "context_destroyed": EventConfig( - "context_destroyed", - "browsingContext.contextDestroyed", - _globals.get("ContextDestroyed", dict) if _globals.get("ContextDestroyed") else dict, - ), - "navigation_started": EventConfig( - "navigation_started", - "browsingContext.navigationStarted", - _globals.get("NavigationStarted", dict) if _globals.get("NavigationStarted") else dict, - ), - "fragment_navigated": EventConfig( - "fragment_navigated", - "browsingContext.fragmentNavigated", - _globals.get("FragmentNavigated", dict) if _globals.get("FragmentNavigated") else dict, - ), - "history_updated": EventConfig( - "history_updated", - "browsingContext.historyUpdated", - _globals.get("HistoryUpdated", dict) if _globals.get("HistoryUpdated") else dict, - ), - "dom_content_loaded": EventConfig( - "dom_content_loaded", - "browsingContext.domContentLoaded", - _globals.get("DomContentLoaded", dict) if _globals.get("DomContentLoaded") else dict, - ), - "load": EventConfig("load", "browsingContext.load", _globals.get("Load", dict) if _globals.get("Load") else dict), - "download_will_begin": EventConfig( - "download_will_begin", - "browsingContext.downloadWillBegin", - _globals.get("DownloadWillBegin", dict) if _globals.get("DownloadWillBegin") else dict, - ), - "download_end": EventConfig( - "download_end", - "browsingContext.downloadEnd", - _globals.get("DownloadEnd", dict) if _globals.get("DownloadEnd") else dict, - ), - "navigation_aborted": EventConfig( - "navigation_aborted", - "browsingContext.navigationAborted", - _globals.get("NavigationAborted", dict) if _globals.get("NavigationAborted") else dict, - ), - "navigation_committed": EventConfig( - "navigation_committed", - "browsingContext.navigationCommitted", - _globals.get("NavigationCommitted", dict) if _globals.get("NavigationCommitted") else dict, - ), - "navigation_failed": EventConfig( - "navigation_failed", - "browsingContext.navigationFailed", - _globals.get("NavigationFailed", dict) if _globals.get("NavigationFailed") else dict, - ), - "user_prompt_closed": EventConfig( - "user_prompt_closed", - "browsingContext.userPromptClosed", - _globals.get("UserPromptClosed", dict) if _globals.get("UserPromptClosed") else dict, - ), - "user_prompt_opened": EventConfig( - "user_prompt_opened", - "browsingContext.userPromptOpened", - _globals.get("UserPromptOpened", dict) if _globals.get("UserPromptOpened") else dict, - ), -} diff --git a/py/selenium/webdriver/common/bidi/cdp.py b/py/selenium/webdriver/common/bidi/cdp.py deleted file mode 100644 index 9ca951479f657..0000000000000 --- a/py/selenium/webdriver/common/bidi/cdp.py +++ /dev/null @@ -1,551 +0,0 @@ -# The MIT License(MIT) -# -# Copyright(c) 2018 Hyperion Gray -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files(the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -# -# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp - -import contextvars -import importlib -import itertools -import json -import logging -import pathlib -from collections import defaultdict -from collections.abc import AsyncGenerator, AsyncIterator, Generator -from contextlib import asynccontextmanager, contextmanager -from dataclasses import dataclass -from typing import Any, TypeVar - -import trio -from trio_websocket import ConnectionClosed as WsConnectionClosed -from trio_websocket import connect_websocket_url - -logger = logging.getLogger("trio_cdp") -T = TypeVar("T") -MAX_WS_MESSAGE_SIZE = 2**24 - -devtools = None -version = None - - -def import_devtools(ver): - """Attempt to load the current latest available devtools into the module cache for use later.""" - global devtools - global version - version = ver - base = "selenium.webdriver.common.devtools.v" - try: - devtools = importlib.import_module(f"{base}{ver}") - return devtools - except ModuleNotFoundError: - # Attempt to parse and load the 'most recent' devtools module. This is likely - # because cdp has been updated but selenium python has not been released yet. - devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") - versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) - available_versions = tuple( - x - for x in versions - if x == "latest" or (x.startswith("v") and x[1:].isdigit()) - ) - numeric_versions = tuple(x[1:] for x in available_versions if x.startswith("v")) - if not numeric_versions: - raise - latest = max(numeric_versions, key=int) - selenium_logger = logging.getLogger(__name__) - selenium_logger.debug("Falling back to loading `devtools`: v%s", latest) - devtools = importlib.import_module(f"{base}{latest}") - return devtools - - -_connection_context: contextvars.ContextVar = contextvars.ContextVar( - "connection_context" -) -_session_context: contextvars.ContextVar = contextvars.ContextVar("session_context") - - -def get_connection_context(fn_name): - """Look up the current connection. - - If there is no current connection, raise a ``RuntimeError`` with a - helpful message. - """ - try: - return _connection_context.get() - except LookupError: - raise RuntimeError(f"{fn_name}() must be called in a connection context.") - - -def get_session_context(fn_name): - """Look up the current session. - - If there is no current session, raise a ``RuntimeError`` with a - helpful message. - """ - try: - return _session_context.get() - except LookupError: - raise RuntimeError(f"{fn_name}() must be called in a session context.") - - -@contextmanager -def connection_context(connection): - """Context manager installs ``connection`` as the session context for the current Trio task.""" - token = _connection_context.set(connection) - try: - yield - finally: - _connection_context.reset(token) - - -@contextmanager -def session_context(session): - """Context manager installs ``session`` as the session context for the current Trio task.""" - token = _session_context.set(session) - try: - yield - finally: - _session_context.reset(token) - - -def set_global_connection(connection): - """Install ``connection`` in the root context so that it will become the default connection for all tasks. - - This is generally not recommended, except it may be necessary in - certain use cases such as running inside Jupyter notebook. - """ - global _connection_context - _connection_context = contextvars.ContextVar( - "_connection_context", default=connection - ) - - -def set_global_session(session): - """Install ``session`` in the root context so that it will become the default session for all tasks. - - This is generally not recommended, except it may be necessary in - certain use cases such as running inside Jupyter notebook. - """ - global _session_context - _session_context = contextvars.ContextVar("_session_context", default=session) - - -class BrowserError(Exception): - """This exception is raised when the browser's response to a command indicates that an error occurred.""" - - def __init__(self, obj): - self.code = obj.get("code") - self.message = obj.get("message") - self.detail = obj.get("data") - - def __str__(self): - return f"BrowserError {self.detail}" - - -class CdpConnectionClosed(WsConnectionClosed): - """Raised when a public method is called on a closed CDP connection.""" - - def __init__(self, reason): - """Constructor. - - Args: - reason: wsproto.frame_protocol.CloseReason - """ - self.reason = reason - - def __repr__(self): - """Return representation.""" - return f"{self.__class__.__name__}<{self.reason}>" - - -class InternalError(Exception): - """This exception is only raised when there is faulty logic in TrioCDP or the integration with PyCDP.""" - - pass - - -@dataclass -class CmEventProxy: - """A proxy object returned by :meth:`CdpBase.wait_for()``. - - After the context manager executes, this proxy object will have a - value set that contains the returned event. - """ - - value: Any = None - - -class CdpBase: - def __init__(self, ws, session_id, target_id): - self.ws = ws - self.session_id = session_id - self.target_id = target_id - self.channels = defaultdict(set) - self.id_iter = itertools.count() - self.inflight_cmd = {} - self.inflight_result = {} - - async def execute(self, cmd: Generator[dict, T, Any]) -> T: - """Execute a command on the server and wait for the result. - - Args: - cmd: any CDP command - - Returns: - a CDP result - """ - cmd_id = next(self.id_iter) - cmd_event = trio.Event() - self.inflight_cmd[cmd_id] = cmd, cmd_event - request = next(cmd) - request["id"] = cmd_id - if self.session_id: - request["sessionId"] = self.session_id - request_str = json.dumps(request) - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f"Sending CDP message: {cmd_id} {cmd_event}: {request_str}") - try: - await self.ws.send_message(request_str) - except WsConnectionClosed as wcc: - raise CdpConnectionClosed(wcc.reason) from None - await cmd_event.wait() - response = self.inflight_result.pop(cmd_id) - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f"Received CDP message: {response}") - if isinstance(response, Exception): - if logger.isEnabledFor(logging.DEBUG): - logger.debug( - f"Exception raised by {cmd_event} message: {type(response).__name__}" - ) - raise response - return response - - def listen(self, *event_types, buffer_size=10): - """Listen for events. - - Returns: - An async iterator that iterates over events matching the indicated types. - """ - sender, receiver = trio.open_memory_channel(buffer_size) - for event_type in event_types: - self.channels[event_type].add(sender) - return receiver - - @asynccontextmanager - async def wait_for( - self, event_type: type[T], buffer_size=10 - ) -> AsyncGenerator[CmEventProxy, None]: - """Wait for an event of the given type and return it. - - This is an async context manager, so you should open it inside - an async with block. The block will not exit until the indicated - event is received. - """ - sender: trio.MemorySendChannel - receiver: trio.MemoryReceiveChannel - sender, receiver = trio.open_memory_channel(buffer_size) - self.channels[event_type].add(sender) - proxy = CmEventProxy() - yield proxy - async with receiver: - event = await receiver.receive() - proxy.value = event - - def _handle_data(self, data): - """Handle incoming WebSocket data. - - Args: - data: a JSON dictionary - """ - if "id" in data: - self._handle_cmd_response(data) - else: - self._handle_event(data) - - def _handle_cmd_response(self, data: dict): - """Handle a response to a command. - - This will set an event flag that will return control to the - task that called the command. - - Args: - data: response as a JSON dictionary - """ - cmd_id = data["id"] - try: - cmd, event = self.inflight_cmd.pop(cmd_id) - except KeyError: - logger.warning( - "Got a message with a command ID that does not exist: %s", data - ) - return - if "error" in data: - # If the server reported an error, convert it to an exception and do - # not process the response any further. - self.inflight_result[cmd_id] = BrowserError(data["error"]) - else: - # Otherwise, continue the generator to parse the JSON result - # into a CDP object. - try: - _ = cmd.send(data["result"]) - raise InternalError( - "The command's generator function did not exit when expected!" - ) - except StopIteration as exit: - return_ = exit.value - self.inflight_result[cmd_id] = return_ - event.set() - - def _handle_event(self, data: dict): - """Handle an event. - - Args: - data: event as a JSON dictionary - """ - global devtools - if devtools is None: - raise RuntimeError( - "CDP devtools module not loaded. Call import_devtools() first." - ) - event = devtools.util.parse_json_event(data) - logger.debug("Received event: %s", event) - to_remove = set() - for sender in self.channels[type(event)]: - try: - sender.send_nowait(event) - except trio.WouldBlock: - logger.error( - 'Unable to send event "%r" due to full channel %s', event, sender - ) - except trio.BrokenResourceError: - to_remove.add(sender) - if to_remove: - self.channels[type(event)] -= to_remove - - -class CdpSession(CdpBase): - """Contains the state for a CDP session. - - Generally you should not instantiate this object yourself; you should call - :meth:`CdpConnection.open_session`. - """ - - def __init__(self, ws, session_id, target_id): - """Constructor. - - Args: - ws: trio_websocket.WebSocketConnection - session_id: devtools.target.SessionID - target_id: devtools.target.TargetID - """ - super().__init__(ws, session_id, target_id) - - self._dom_enable_count = 0 - self._dom_enable_lock = trio.Lock() - self._page_enable_count = 0 - self._page_enable_lock = trio.Lock() - - @asynccontextmanager - async def dom_enable(self): - """Context manager that executes ``dom.enable()`` when it enters and then calls ``dom.disable()``. - - This keeps track of concurrent callers and only disables DOM - events when all callers have exited. - """ - global devtools - async with self._dom_enable_lock: - self._dom_enable_count += 1 - if self._dom_enable_count == 1: - await self.execute(devtools.dom.enable()) - - yield - - async with self._dom_enable_lock: - self._dom_enable_count -= 1 - if self._dom_enable_count == 0: - await self.execute(devtools.dom.disable()) - - @asynccontextmanager - async def page_enable(self): - """Context manager executes ``page.enable()`` when it enters and then calls ``page.disable()`` when it exits. - - This keeps track of concurrent callers and only disables page - events when all callers have exited. - """ - global devtools - async with self._page_enable_lock: - self._page_enable_count += 1 - if self._page_enable_count == 1: - await self.execute(devtools.page.enable()) - - yield - - async with self._page_enable_lock: - self._page_enable_count -= 1 - if self._page_enable_count == 0: - await self.execute(devtools.page.disable()) - - -class CdpConnection(CdpBase, trio.abc.AsyncResource): - """Contains the connection state for a Chrome DevTools Protocol server. - - CDP can multiplex multiple "sessions" over a single connection. This - class corresponds to the "root" session, i.e. the implicitly created - session that has no session ID. This class is responsible for - reading incoming WebSocket messages and forwarding them to the - corresponding session, as well as handling messages targeted at the - root session itself. You should generally call the - :func:`open_cdp()` instead of instantiating this class directly. - """ - - def __init__(self, ws): - """Constructor. - - Args: - ws: trio_websocket.WebSocketConnection - """ - super().__init__(ws, session_id=None, target_id=None) - self.sessions = {} - - async def aclose(self): - """Close the underlying WebSocket connection. - - This will cause the reader task to gracefully exit when it tries - to read the next message from the WebSocket. All of the public - APIs (``execute()``, ``listen()``, etc.) will raise - ``CdpConnectionClosed`` after the CDP connection is closed. It - is safe to call this multiple times. - """ - await self.ws.aclose() - - @asynccontextmanager - async def open_session(self, target_id) -> AsyncIterator[CdpSession]: - """Context manager opens a session and enables the "simple" style of calling CDP APIs. - - For example, inside a session context, you can call ``await - dom.get_document()`` and it will execute on the current session - automatically. - """ - session = await self.connect_session(target_id) - with session_context(session): - yield session - - async def connect_session(self, target_id) -> "CdpSession": - """Returns a new :class:`CdpSession` connected to the specified target.""" - global devtools - if devtools is None: - raise RuntimeError( - "CDP devtools module not loaded. Call import_devtools() first." - ) - session_id = await self.execute( - devtools.target.attach_to_target(target_id, True) - ) - session = CdpSession(self.ws, session_id, target_id) - self.sessions[session_id] = session - return session - - async def _reader_task(self): - """Runs in the background and handles incoming messages. - - Dispatches responses to commands and events to listeners. - """ - global devtools - if devtools is None: - raise RuntimeError( - "CDP devtools module not loaded. Call import_devtools() first." - ) - while True: - try: - message = await self.ws.get_message() - except WsConnectionClosed: - # If the WebSocket is closed, we don't want to throw an - # exception from the reader task. Instead we will throw - # exceptions from the public API methods, and we can quietly - # exit the reader task here. - break - try: - data = json.loads(message) - except json.JSONDecodeError: - raise BrowserError( - { - "code": -32700, - "message": "Client received invalid JSON", - "data": message, - } - ) - logger.debug("Received message %r", data) - if "sessionId" in data: - session_id = devtools.target.SessionID(data["sessionId"]) - try: - session = self.sessions[session_id] - except KeyError: - raise BrowserError( - { - "code": -32700, - "message": "Browser sent a message for an invalid session", - "data": f"{session_id!r}", - } - ) - session._handle_data(data) - else: - self._handle_data(data) - - for _, session in self.sessions.items(): - for _, senders in session.channels.items(): - for sender in senders: - sender.close() - - -@asynccontextmanager -async def open_cdp(url) -> AsyncIterator[CdpConnection]: - """Async context manager opens a connection to the browser then closes the connection when the block exits. - - The context manager also sets the connection as the default - connection for the current task, so that commands like ``await - target.get_targets()`` will run on this connection automatically. If - you want to use multiple connections concurrently, it is recommended - to open each on in a separate task. - """ - async with trio.open_nursery() as nursery: - conn = await connect_cdp(nursery, url) - try: - with connection_context(conn): - yield conn - finally: - await conn.aclose() - - -async def connect_cdp(nursery, url) -> CdpConnection: - """Connect to the browser specified by ``url`` and spawn a background task in the specified nursery. - - The ``open_cdp()`` context manager is preferred in most situations. - You should only use this function if you need to specify a custom - nursery. This connection is not automatically closed! You can either - use the connection object as a context manager (``async with - conn:``) or else call ``await conn.aclose()`` on it when you are - done with it. If ``set_context`` is True, then the returned - connection will be installed as the default connection for the - current task. This argument is for unusual use cases, such as - running inside of a notebook. - """ - ws = await connect_websocket_url(nursery, url, max_message_size=MAX_WS_MESSAGE_SIZE) - cdp_conn = CdpConnection(ws) - nursery.start_soon(cdp_conn._reader_task) - return cdp_conn diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py deleted file mode 100644 index a3e6b4b6c4ddb..0000000000000 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ /dev/null @@ -1,500 +0,0 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi.common import command_builder - - -class ForcedColorsModeTheme: - """ForcedColorsModeTheme.""" - - LIGHT = "light" - DARK = "dark" - - -class ScreenOrientationNatural: - """ScreenOrientationNatural.""" - - PORTRAIT = "portrait" - LANDSCAPE = "landscape" - - -class ScreenOrientationType: - """ScreenOrientationType.""" - - PORTRAIT_PRIMARY = "portrait-primary" - PORTRAIT_SECONDARY = "portrait-secondary" - LANDSCAPE_PRIMARY = "landscape-primary" - LANDSCAPE_SECONDARY = "landscape-secondary" - - -@dataclass -class SetForcedColorsModeThemeOverrideParameters: - """SetForcedColorsModeThemeOverrideParameters.""" - - theme: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class SetGeolocationOverrideParameters: - """SetGeolocationOverrideParameters.""" - - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class GeolocationCoordinates: - """GeolocationCoordinates.""" - - latitude: Any | None = None - longitude: Any | None = None - accuracy: Any | None = None - altitude: Any | None = None - altitude_accuracy: Any | None = None - heading: Any | None = None - speed: Any | None = None - - -@dataclass -class GeolocationPositionError: - """GeolocationPositionError.""" - - type: str = field(default="positionUnavailable", init=False) - - -@dataclass -class SetLocaleOverrideParameters: - """SetLocaleOverrideParameters.""" - - locale: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class NetworkConditionsOffline: - """NetworkConditionsOffline.""" - - type: str = field(default="offline", init=False) - - -@dataclass -class ScreenArea: - """ScreenArea.""" - - width: Any | None = None - height: Any | None = None - - -@dataclass -class SetScreenSettingsOverrideParameters: - """SetScreenSettingsOverrideParameters.""" - - screen_area: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class ScreenOrientation: - """ScreenOrientation.""" - - natural: Any | None = None - type: Any | None = None - - -@dataclass -class SetScreenOrientationOverrideParameters: - """SetScreenOrientationOverrideParameters.""" - - screen_orientation: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class SetUserAgentOverrideParameters: - """SetUserAgentOverrideParameters.""" - - user_agent: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class SetScriptingEnabledParameters: - """SetScriptingEnabledParameters.""" - - enabled: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class SetScrollbarTypeOverrideParameters: - """SetScrollbarTypeOverrideParameters.""" - - scrollbar_type: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class SetTimezoneOverrideParameters: - """SetTimezoneOverrideParameters.""" - - timezone: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class SetTouchOverrideParameters: - """SetTouchOverrideParameters.""" - - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class SetNetworkConditionsParameters: - """SetNetworkConditionsParameters.""" - - network_conditions: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -# Backward-compatible alias for existing imports -setNetworkConditionsParameters = SetNetworkConditionsParameters - - -class Emulation: - """WebDriver BiDi emulation module.""" - - def __init__(self, conn) -> None: - self._conn = conn - - def set_forced_colors_mode_theme_override( - self, - theme: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setForcedColorsModeThemeOverride.""" - if theme is None: - raise TypeError("set_forced_colors_mode_theme_override() missing required argument: 'theme'") - - params = { - "theme": theme, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setForcedColorsModeThemeOverride", params) - result = self._conn.execute(cmd) - return result - - def set_locale_override( - self, - locale: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setLocaleOverride.""" - if locale is None: - raise TypeError("set_locale_override() missing required argument: 'locale'") - - params = { - "locale": locale, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setLocaleOverride", params) - result = self._conn.execute(cmd) - return result - - def set_scrollbar_type_override( - self, - scrollbar_type: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setScrollbarTypeOverride.""" - if scrollbar_type is None: - raise TypeError("set_scrollbar_type_override() missing required argument: 'scrollbar_type'") - - params = { - "scrollbarType": scrollbar_type, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setScrollbarTypeOverride", params) - result = self._conn.execute(cmd) - return result - - def set_touch_override(self, contexts: list[Any] | None = None, user_contexts: list[Any] | None = None): - """Execute emulation.setTouchOverride.""" - params = { - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setTouchOverride", params) - result = self._conn.execute(cmd) - return result - - def set_geolocation_override( - self, - coordinates=None, - error=None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setGeolocationOverride. - - Sets or clears the geolocation override for specified browsing or user contexts. - - Args: - coordinates: A GeolocationCoordinates instance (or dict) to override the - position, or ``None`` to clear a previously-set override. - error: A GeolocationPositionError instance (or dict) to simulate a - position-unavailable error. Mutually exclusive with *coordinates*. - contexts: List of browsing context IDs to target. - user_contexts: List of user context IDs to target. - """ - params: dict[str, Any] = {} - if coordinates is not None: - if isinstance(coordinates, dict): - coords_dict = coordinates - else: - coords_dict = {} - if coordinates.latitude is not None: - coords_dict["latitude"] = coordinates.latitude - if coordinates.longitude is not None: - coords_dict["longitude"] = coordinates.longitude - if coordinates.accuracy is not None: - coords_dict["accuracy"] = coordinates.accuracy - if coordinates.altitude is not None: - coords_dict["altitude"] = coordinates.altitude - if coordinates.altitude_accuracy is not None: - coords_dict["altitudeAccuracy"] = coordinates.altitude_accuracy - if coordinates.heading is not None: - coords_dict["heading"] = coordinates.heading - if coordinates.speed is not None: - coords_dict["speed"] = coordinates.speed - params["coordinates"] = coords_dict - if error is not None: - if isinstance(error, dict): - params["error"] = error - else: - params["error"] = {"type": error.type if error.type is not None else "positionUnavailable"} - if contexts is not None: - params["contexts"] = contexts - if user_contexts is not None: - params["userContexts"] = user_contexts - cmd = command_builder("emulation.setGeolocationOverride", params) - result = self._conn.execute(cmd) - return result - - def set_timezone_override( - self, - timezone=None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setTimezoneOverride. - - Sets or clears the timezone override for specified browsing or user contexts. - Pass ``timezone=None`` (or omit it) to clear a previously-set override. - - Args: - timezone: IANA timezone string (e.g. ``"America/New_York"``) or ``None`` - to clear the override. - contexts: List of browsing context IDs to target. - user_contexts: List of user context IDs to target. - """ - params: dict[str, Any] = {"timezone": timezone} - if contexts is not None: - params["contexts"] = contexts - if user_contexts is not None: - params["userContexts"] = user_contexts - cmd = command_builder("emulation.setTimezoneOverride", params) - return self._conn.execute(cmd) - - def set_scripting_enabled( - self, - enabled=None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setScriptingEnabled. - - Enables or disables scripting for specified browsing or user contexts. - Pass ``enabled=None`` to restore the default behaviour. - - Args: - enabled: ``True`` to enable scripting, ``False`` to disable it, or - ``None`` to clear the override. - contexts: List of browsing context IDs to target. - user_contexts: List of user context IDs to target. - """ - params: dict[str, Any] = {"enabled": enabled} - if contexts is not None: - params["contexts"] = contexts - if user_contexts is not None: - params["userContexts"] = user_contexts - cmd = command_builder("emulation.setScriptingEnabled", params) - return self._conn.execute(cmd) - - def set_user_agent_override( - self, - user_agent=None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setUserAgentOverride. - - Overrides the User-Agent string for specified browsing or user contexts. - Pass ``user_agent=None`` to clear a previously-set override. - - Args: - user_agent: Custom User-Agent string, or ``None`` to clear the override. - contexts: List of browsing context IDs to target. - user_contexts: List of user context IDs to target. - """ - params: dict[str, Any] = {"userAgent": user_agent} - if contexts is not None: - params["contexts"] = contexts - if user_contexts is not None: - params["userContexts"] = user_contexts - cmd = command_builder("emulation.setUserAgentOverride", params) - return self._conn.execute(cmd) - - def set_screen_orientation_override( - self, - screen_orientation=None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setScreenOrientationOverride. - - Sets or clears the screen orientation override for specified browsing or - user contexts. - - Args: - screen_orientation: A :class:`ScreenOrientation` instance (or dict with - ``natural`` and ``type`` keys) to lock the orientation, or ``None`` - to clear a previously-set override. - contexts: List of browsing context IDs to target. - user_contexts: List of user context IDs to target. - """ - if screen_orientation is None: - so_value = None - elif isinstance(screen_orientation, dict): - so_value = screen_orientation - else: - natural = screen_orientation.natural - orientation_type = screen_orientation.type - so_value = { - "natural": natural.lower() if isinstance(natural, str) else natural, - "type": orientation_type.lower() if isinstance(orientation_type, str) else orientation_type, - } - params: dict[str, Any] = {"screenOrientation": so_value} - if contexts is not None: - params["contexts"] = contexts - if user_contexts is not None: - params["userContexts"] = user_contexts - cmd = command_builder("emulation.setScreenOrientationOverride", params) - return self._conn.execute(cmd) - - def set_network_conditions( - self, - network_conditions=None, - offline: bool | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setNetworkConditions. - - Sets or clears network condition emulation for specified browsing or user - contexts. - - Args: - network_conditions: A dict with the raw ``networkConditions`` value - (e.g. ``{"type": "offline"}``), or ``None`` to clear the override. - Mutually exclusive with *offline*. - offline: Convenience bool — ``True`` sets offline conditions, - ``False`` clears them (sends ``null``). When provided, this takes - precedence over *network_conditions*. - contexts: List of browsing context IDs to target. - user_contexts: List of user context IDs to target. - """ - if offline is not None: - nc_value = {"type": "offline"} if offline else None - else: - nc_value = network_conditions - params: dict[str, Any] = {"networkConditions": nc_value} - if contexts is not None: - params["contexts"] = contexts - if user_contexts is not None: - params["userContexts"] = user_contexts - cmd = command_builder("emulation.setNetworkConditions", params) - return self._conn.execute(cmd) - - def set_screen_settings_override( - self, - width: int | None = None, - height: int | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setScreenSettingsOverride. - - Sets or clears the screen settings override for specified browsing or user - contexts. - - Args: - width: The screen width in pixels, or ``None`` to clear the override. - height: The screen height in pixels, or ``None`` to clear the override. - contexts: List of browsing context IDs to target. - user_contexts: List of user context IDs to target. - """ - screen_area = None - if width is not None or height is not None: - screen_area = {} - if width is not None: - screen_area["width"] = width - if height is not None: - screen_area["height"] = height - params: dict[str, Any] = {"screenArea": screen_area} - if contexts is not None: - params["contexts"] = contexts - if user_contexts is not None: - params["userContexts"] = user_contexts - cmd = command_builder("emulation.setScreenSettingsOverride", params) - return self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py deleted file mode 100644 index 6c06fc4e7deaa..0000000000000 --- a/py/selenium/webdriver/common/bidi/input.py +++ /dev/null @@ -1,339 +0,0 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -from __future__ import annotations - -from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager -from selenium.webdriver.common.bidi.common import command_builder - - -class PointerType: - """PointerType.""" - - MOUSE = "mouse" - PEN = "pen" - TOUCH = "touch" - - -class Origin: - """Origin.""" - - VIEWPORT = "viewport" - POINTER = "pointer" - - -@dataclass -class ElementOrigin: - """ElementOrigin.""" - - type: str = field(default="element", init=False) - element: Any | None = None - - -@dataclass -class PerformActionsParameters: - """PerformActionsParameters.""" - - context: Any | None = None - actions: list[Any] = field(default_factory=list) - - -@dataclass -class NoneSourceActions: - """NoneSourceActions.""" - - type: str = field(default="none", init=False) - id: str | None = None - actions: list[Any] = field(default_factory=list) - - -@dataclass -class KeySourceActions: - """KeySourceActions.""" - - type: str = field(default="key", init=False) - id: str | None = None - actions: list[Any] = field(default_factory=list) - - -@dataclass -class PointerSourceActions: - """PointerSourceActions.""" - - type: str = field(default="pointer", init=False) - id: str | None = None - parameters: Any | None = None - actions: list[Any] = field(default_factory=list) - - -@dataclass -class PointerParameters: - """PointerParameters.""" - - pointer_type: Any | None = None - - -@dataclass -class WheelSourceActions: - """WheelSourceActions.""" - - type: str = field(default="wheel", init=False) - id: str | None = None - actions: list[Any] = field(default_factory=list) - - -@dataclass -class PauseAction: - """PauseAction.""" - - type: str = field(default="pause", init=False) - duration: Any | None = None - - -@dataclass -class KeyDownAction: - """KeyDownAction.""" - - type: str = field(default="keyDown", init=False) - value: str | None = None - - -@dataclass -class KeyUpAction: - """KeyUpAction.""" - - type: str = field(default="keyUp", init=False) - value: str | None = None - - -@dataclass -class PointerUpAction: - """PointerUpAction.""" - - type: str = field(default="pointerUp", init=False) - button: Any | None = None - - -@dataclass -class WheelScrollAction: - """WheelScrollAction.""" - - type: str = field(default="scroll", init=False) - x: Any | None = None - y: Any | None = None - delta_x: Any | None = None - delta_y: Any | None = None - duration: Any | None = None - origin: Any | None = None - - -@dataclass -class PointerCommonProperties: - """PointerCommonProperties.""" - - width: Any | None = None - height: Any | None = None - pressure: Any | None = None - tangential_pressure: Any | None = None - twist: Any | None = None - altitude_angle: Any | None = None - azimuth_angle: Any | None = None - - -@dataclass -class ReleaseActionsParameters: - """ReleaseActionsParameters.""" - - context: Any | None = None - - -@dataclass -class SetFilesParameters: - """SetFilesParameters.""" - - context: Any | None = None - element: Any | None = None - files: list[Any] = field(default_factory=list) - - -@dataclass -class FileDialogInfo: - """FileDialogInfo - parameters for the input.fileDialogOpened event.""" - - context: Any | None = None - element: Any | None = None - multiple: bool | None = None - - @classmethod - def from_json(cls, params: dict) -> FileDialogInfo: - """Deserialize event params into FileDialogInfo.""" - return cls( - context=params.get("context"), - element=params.get("element"), - multiple=params.get("multiple"), - ) - - -@dataclass -class PointerMoveAction: - """PointerMoveAction.""" - - type: str = field(default="pointerMove", init=False) - x: Any | None = None - y: Any | None = None - duration: Any | None = None - origin: Any | None = None - properties: Any | None = None - - -@dataclass -class PointerDownAction: - """PointerDownAction.""" - - type: str = field(default="pointerDown", init=False) - button: Any | None = None - properties: Any | None = None - - -# BiDi Event Name to Parameter Type Mapping -EVENT_NAME_MAPPING = { - "file_dialog_opened": "input.fileDialogOpened", -} - - -class Input: - """WebDriver BiDi input module.""" - - EVENT_CONFIGS: dict[str, EventConfig] = {} - - def __init__(self, conn) -> None: - self._conn = conn - self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - - def perform_actions(self, context: Any | None = None, actions: list[Any] | None = None): - """Execute input.performActions.""" - if context is None: - raise TypeError("perform_actions() missing required argument: 'context'") - if actions is None: - raise TypeError("perform_actions() missing required argument: 'actions'") - - params = { - "context": context, - "actions": actions, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("input.performActions", params) - result = self._conn.execute(cmd) - return result - - def release_actions(self, context: Any | None = None): - """Execute input.releaseActions.""" - if context is None: - raise TypeError("release_actions() missing required argument: 'context'") - - params = { - "context": context, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("input.releaseActions", params) - result = self._conn.execute(cmd) - return result - - def set_files(self, context: Any | None = None, element: Any | None = None, files: list[Any] | None = None): - """Execute input.setFiles.""" - if context is None: - raise TypeError("set_files() missing required argument: 'context'") - if element is None: - raise TypeError("set_files() missing required argument: 'element'") - if files is None: - raise TypeError("set_files() missing required argument: 'files'") - - params = { - "context": context, - "element": element, - "files": files, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("input.setFiles", params) - result = self._conn.execute(cmd) - return result - - def add_file_dialog_handler(self, callback) -> int: - """Subscribe to the input.fileDialogOpened event. - - Args: - callback: Callable invoked with a FileDialogInfo when a file dialog opens. - - Returns: - A handler ID that can be passed to remove_file_dialog_handler. - """ - return self._event_manager.add_event_handler("file_dialog_opened", callback) - - def remove_file_dialog_handler(self, handler_id: int) -> None: - """Unsubscribe a previously registered file dialog event handler. - - Args: - handler_id: The handler ID returned by add_file_dialog_handler. - """ - return self._event_manager.remove_event_handler("file_dialog_opened", handler_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - """Add an event handler. - - Args: - event: The event to subscribe to. - callback: The callback function to execute on event. - contexts: The context IDs to subscribe to (optional). - - Returns: - The callback ID. - """ - return self._event_manager.add_event_handler(event, callback, contexts) - - def remove_event_handler(self, event: str, callback_id: int) -> None: - """Remove an event handler. - - Args: - event: The event to unsubscribe from. - callback_id: The callback ID. - """ - return self._event_manager.remove_event_handler(event, callback_id) - - def clear_event_handlers(self) -> None: - """Clear all event handlers.""" - return self._event_manager.clear_event_handlers() - - -# Event Info Type Aliases -# Event: input.fileDialogOpened -FileDialogOpened = globals().get("FileDialogInfo", dict) # Fallback to dict if type not defined - - -# Populate EVENT_CONFIGS with event configuration mappings -_globals = globals() -Input.EVENT_CONFIGS = { - "file_dialog_opened": EventConfig( - "file_dialog_opened", - "input.fileDialogOpened", - _globals.get("FileDialogOpened", dict) if _globals.get("FileDialogOpened") else dict, - ), -} diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py deleted file mode 100644 index 597936402f99c..0000000000000 --- a/py/selenium/webdriver/common/bidi/log.py +++ /dev/null @@ -1,167 +0,0 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -from __future__ import annotations - -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any - -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager - - -class Level: - """Level.""" - - DEBUG = "debug" - INFO = "info" - WARN = "warn" - ERROR = "error" - - -LogLevel = Level - - -@dataclass -class BaseLogEntry: - """BaseLogEntry.""" - - level: Any | None = None - source: Any | None = None - text: Any | None = None - timestamp: Any | None = None - stack_trace: Any | None = None - - -@dataclass -class GenericLogEntry: - """GenericLogEntry.""" - - type: str | None = None - - -@dataclass -class ConsoleLogEntry: - """ConsoleLogEntry - a console log entry from the browser.""" - - type_: str | None = None - method: str | None = None - args: list | None = None - level: Any | None = None - text: Any | None = None - source: Any | None = None - timestamp: Any | None = None - stack_trace: Any | None = None - - @classmethod - def from_json(cls, params: dict) -> ConsoleLogEntry: - """Deserialize from BiDi params dict.""" - return cls( - type_=params.get("type"), - method=params.get("method"), - args=params.get("args"), - level=params.get("level"), - text=params.get("text"), - source=params.get("source"), - timestamp=params.get("timestamp"), - stack_trace=params.get("stackTrace"), - ) - - -@dataclass -class JavascriptLogEntry: - """JavascriptLogEntry - a JavaScript error log entry from the browser.""" - - type_: str | None = None - level: Any | None = None - text: Any | None = None - source: Any | None = None - timestamp: Any | None = None - stacktrace: Any | None = None - - @classmethod - def from_json(cls, params: dict) -> JavascriptLogEntry: - """Deserialize from BiDi params dict.""" - return cls( - type_=params.get("type"), - level=params.get("level"), - text=params.get("text"), - source=params.get("source"), - timestamp=params.get("timestamp"), - stacktrace=params.get("stackTrace"), - ) - - -Entry = GenericLogEntry | ConsoleLogEntry | JavascriptLogEntry - -# BiDi Event Name to Parameter Type Mapping -EVENT_NAME_MAPPING = { - "entry_added": "log.entryAdded", -} - - -class Log: - """WebDriver BiDi log module.""" - - EVENT_CONFIGS: dict[str, EventConfig] = {} - - def __init__(self, conn) -> None: - self._conn = conn - self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - """Add an event handler. - - Args: - event: The event to subscribe to. - callback: The callback function to execute on event. - contexts: The context IDs to subscribe to (optional). - - Returns: - The callback ID. - """ - return self._event_manager.add_event_handler(event, callback, contexts) - - def remove_event_handler(self, event: str, callback_id: int) -> None: - """Remove an event handler. - - Args: - event: The event to unsubscribe from. - callback_id: The callback ID. - """ - return self._event_manager.remove_event_handler(event, callback_id) - - def clear_event_handlers(self) -> None: - """Clear all event handlers.""" - return self._event_manager.clear_event_handlers() - - -# Event Info Type Aliases -# Event: log.entryAdded -EntryAdded = Entry - - -# Populate EVENT_CONFIGS with event configuration mappings -_globals = globals() -Log.EVENT_CONFIGS = { - "entry_added": EventConfig( - "entry_added", - "log.entryAdded", - _globals.get("EntryAdded", dict) if _globals.get("EntryAdded") else dict, - ), -} diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py deleted file mode 100644 index 6c24e399b0e54..0000000000000 --- a/py/selenium/webdriver/common/bidi/network.py +++ /dev/null @@ -1,925 +0,0 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -from __future__ import annotations - -from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager -from selenium.webdriver.common.bidi.common import command_builder - - -class SameSite: - """SameSite.""" - - STRICT = "strict" - LAX = "lax" - NONE = "none" - DEFAULT = "default" - - -class DataType: - """DataType.""" - - REQUEST = "request" - RESPONSE = "response" - - -class InterceptPhase: - """InterceptPhase.""" - - BEFOREREQUESTSENT = "beforeRequestSent" - RESPONSESTARTED = "responseStarted" - AUTHREQUIRED = "authRequired" - - -class ContinueWithAuthNoCredentials: - """ContinueWithAuthNoCredentials.""" - - DEFAULT = "default" - CANCEL = "cancel" - - -@dataclass -class AuthChallenge: - """AuthChallenge.""" - - scheme: str | None = None - realm: str | None = None - - -@dataclass -class AuthCredentials: - """AuthCredentials.""" - - type: str = field(default="password", init=False) - username: str | None = None - password: str | None = None - - -@dataclass -class BaseParameters: - """BaseParameters.""" - - context: Any | None = None - is_blocked: bool | None = None - navigation: Any | None = None - redirect_count: Any | None = None - request: Any | None = None - timestamp: Any | None = None - user_context: Any | None = None - intercepts: list[Any] = field(default_factory=list) - - -@dataclass -class StringValue: - """StringValue.""" - - type: str = field(default="string", init=False) - value: str | None = None - - -@dataclass -class Base64Value: - """Base64Value.""" - - type: str = field(default="base64", init=False) - value: str | None = None - - -@dataclass -class Cookie: - """Cookie.""" - - name: str | None = None - value: Any | None = None - domain: str | None = None - path: str | None = None - size: Any | None = None - http_only: bool | None = None - secure: bool | None = None - same_site: Any | None = None - expiry: Any | None = None - - -@dataclass -class CookieHeader: - """CookieHeader.""" - - name: str | None = None - value: Any | None = None - - -@dataclass -class FetchTimingInfo: - """FetchTimingInfo.""" - - time_origin: Any | None = None - request_time: Any | None = None - redirect_start: Any | None = None - redirect_end: Any | None = None - fetch_start: Any | None = None - dns_start: Any | None = None - dns_end: Any | None = None - connect_start: Any | None = None - connect_end: Any | None = None - tls_start: Any | None = None - request_start: Any | None = None - response_start: Any | None = None - response_end: Any | None = None - - -@dataclass -class Header: - """Header.""" - - name: str | None = None - value: Any | None = None - - -@dataclass -class Initiator: - """Initiator.""" - - column_number: Any | None = None - line_number: Any | None = None - request: Any | None = None - stack_trace: Any | None = None - type: Any | None = None - - -@dataclass -class ResponseContent: - """ResponseContent.""" - - size: Any | None = None - - -@dataclass -class ResponseData: - """ResponseData.""" - - url: str | None = None - protocol: str | None = None - status: Any | None = None - status_text: str | None = None - from_cache: bool | None = None - headers: list[Any] = field(default_factory=list) - mime_type: str | None = None - bytes_received: Any | None = None - headers_size: Any | None = None - body_size: Any | None = None - content: Any | None = None - auth_challenges: list[Any] = field(default_factory=list) - - -@dataclass -class SetCookieHeader: - """SetCookieHeader.""" - - name: str | None = None - value: Any | None = None - domain: str | None = None - http_only: bool | None = None - expiry: str | None = None - max_age: Any | None = None - path: str | None = None - same_site: Any | None = None - secure: bool | None = None - - -@dataclass -class UrlPatternPattern: - """UrlPatternPattern.""" - - type: str = field(default="pattern", init=False) - protocol: str | None = None - hostname: str | None = None - port: str | None = None - pathname: str | None = None - search: str | None = None - - -@dataclass -class UrlPatternString: - """UrlPatternString.""" - - type: str = field(default="string", init=False) - pattern: str | None = None - - -@dataclass -class AddDataCollectorParameters: - """AddDataCollectorParameters.""" - - data_types: list[Any] = field(default_factory=list) - max_encoded_data_size: Any | None = None - collector_type: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class AddDataCollectorResult: - """AddDataCollectorResult.""" - - collector: Any | None = None - - -@dataclass -class AddInterceptParameters: - """AddInterceptParameters.""" - - phases: list[Any] = field(default_factory=list) - contexts: list[Any] = field(default_factory=list) - url_patterns: list[Any] = field(default_factory=list) - - -@dataclass -class AddInterceptResult: - """AddInterceptResult.""" - - intercept: Any | None = None - - -@dataclass -class ContinueResponseParameters: - """ContinueResponseParameters.""" - - request: Any | None = None - cookies: list[Any] = field(default_factory=list) - credentials: Any | None = None - headers: list[Any] = field(default_factory=list) - reason_phrase: str | None = None - status_code: Any | None = None - - -@dataclass -class ContinueWithAuthParameters: - """ContinueWithAuthParameters.""" - - request: Any | None = None - - -@dataclass -class ContinueWithAuthCredentials: - """ContinueWithAuthCredentials.""" - - action: str = field(default="provideCredentials", init=False) - credentials: Any | None = None - - -@dataclass -class FailRequestParameters: - """FailRequestParameters.""" - - request: Any | None = None - - -@dataclass -class GetDataParameters: - """GetDataParameters.""" - - data_type: Any | None = None - collector: Any | None = None - disown: bool | None = None - request: Any | None = None - - -@dataclass -class GetDataResult: - """GetDataResult.""" - - bytes: Any | None = None - - -@dataclass -class ProvideResponseParameters: - """ProvideResponseParameters.""" - - request: Any | None = None - body: Any | None = None - cookies: list[Any] = field(default_factory=list) - headers: list[Any] = field(default_factory=list) - reason_phrase: str | None = None - status_code: Any | None = None - - -@dataclass -class RemoveDataCollectorParameters: - """RemoveDataCollectorParameters.""" - - collector: Any | None = None - - -@dataclass -class RemoveInterceptParameters: - """RemoveInterceptParameters.""" - - intercept: Any | None = None - - -@dataclass -class SetCacheBehaviorParameters: - """SetCacheBehaviorParameters.""" - - cache_behavior: Any | None = None - contexts: list[Any] = field(default_factory=list) - - -@dataclass -class SetExtraHeadersParameters: - """SetExtraHeadersParameters.""" - - headers: list[Any] = field(default_factory=list) - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class ResponseStartedParameters: - """ResponseStartedParameters.""" - - response: Any | None = None - - -@dataclass -class DisownDataParameters: - """DisownDataParameters.""" - - data_type: Any | None = None - collector: Any | None = None - request: Any | None = None - - -# Backward-compatible alias for existing imports -disownDataParameters = DisownDataParameters - - -class BytesValue: - """A string or base64-encoded bytes value used in cookie operations. - - This corresponds to network.BytesValue in the WebDriver BiDi specification, - wrapping either a plain string or a base64-encoded binary value. - """ - - TYPE_STRING = "string" - TYPE_BASE64 = "base64" - - def __init__(self, type: Any | None, value: Any | None) -> None: - self.type = type - self.value = value - - def to_bidi_dict(self) -> dict: - return {"type": self.type, "value": self.value} - - -class Request: - """Wraps a BiDi network request event params and provides request action methods.""" - - def __init__(self, conn, params): - self._conn = conn - self._params = params if isinstance(params, dict) else {} - req = self._params.get("request", {}) or {} - self.url = req.get("url", "") - self._request_id = req.get("request") - - def continue_request(self, **kwargs): - """Continue the intercepted request.""" - from selenium.webdriver.common.bidi.common import command_builder as _cb - - params = {"request": self._request_id} - params.update(kwargs) - self._conn.execute(_cb("network.continueRequest", params)) - - -# BiDi Event Name to Parameter Type Mapping -EVENT_NAME_MAPPING = { - "auth_required": "network.authRequired", - "before_request": "network.beforeRequestSent", -} - - -class Network: - """WebDriver BiDi network module.""" - - EVENT_CONFIGS: dict[str, EventConfig] = {} - - def __init__(self, conn) -> None: - self._conn = conn - self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - self.intercepts: list[Any] = [] - self._handler_intercepts: dict[str, Any] = {} - - def add_data_collector( - self, - data_types: list[Any] | None = None, - max_encoded_data_size: Any | None = None, - collector_type: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute network.addDataCollector.""" - if data_types is None: - raise TypeError("add_data_collector() missing required argument: 'data_types'") - if max_encoded_data_size is None: - raise TypeError("add_data_collector() missing required argument: 'max_encoded_data_size'") - - params = { - "dataTypes": data_types, - "maxEncodedDataSize": max_encoded_data_size, - "collectorType": collector_type, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.addDataCollector", params) - result = self._conn.execute(cmd) - return result - - def add_intercept( - self, - phases: list[Any] | None = None, - contexts: list[Any] | None = None, - url_patterns: list[Any] | None = None, - ): - """Execute network.addIntercept.""" - if phases is None: - raise TypeError("add_intercept() missing required argument: 'phases'") - - params = { - "phases": phases, - "contexts": contexts, - "urlPatterns": url_patterns, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.addIntercept", params) - result = self._conn.execute(cmd) - return result - - def continue_request( - self, - request: Any | None = None, - body: Any | None = None, - cookies: list[Any] | None = None, - headers: list[Any] | None = None, - method: Any | None = None, - url: Any | None = None, - ): - """Execute network.continueRequest.""" - if request is None: - raise TypeError("continue_request() missing required argument: 'request'") - - params = { - "request": request, - "body": body, - "cookies": cookies, - "headers": headers, - "method": method, - "url": url, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.continueRequest", params) - result = self._conn.execute(cmd) - return result - - def continue_response( - self, - request: Any | None = None, - cookies: list[Any] | None = None, - credentials: Any | None = None, - headers: list[Any] | None = None, - reason_phrase: Any | None = None, - status_code: Any | None = None, - ): - """Execute network.continueResponse.""" - if request is None: - raise TypeError("continue_response() missing required argument: 'request'") - - params = { - "request": request, - "cookies": cookies, - "credentials": credentials, - "headers": headers, - "reasonPhrase": reason_phrase, - "statusCode": status_code, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.continueResponse", params) - result = self._conn.execute(cmd) - return result - - def continue_with_auth(self, request: Any | None = None): - """Execute network.continueWithAuth.""" - if request is None: - raise TypeError("continue_with_auth() missing required argument: 'request'") - - params = { - "request": request, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.continueWithAuth", params) - result = self._conn.execute(cmd) - return result - - def disown_data(self, data_type: Any | None = None, collector: Any | None = None, request: Any | None = None): - """Execute network.disownData.""" - if data_type is None: - raise TypeError("disown_data() missing required argument: 'data_type'") - if collector is None: - raise TypeError("disown_data() missing required argument: 'collector'") - if request is None: - raise TypeError("disown_data() missing required argument: 'request'") - - params = { - "dataType": data_type, - "collector": collector, - "request": request, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.disownData", params) - result = self._conn.execute(cmd) - return result - - def fail_request(self, request: Any | None = None): - """Execute network.failRequest.""" - if request is None: - raise TypeError("fail_request() missing required argument: 'request'") - - params = { - "request": request, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.failRequest", params) - result = self._conn.execute(cmd) - return result - - def get_data( - self, - data_type: Any | None = None, - collector: Any | None = None, - disown: bool | None = None, - request: Any | None = None, - ): - """Execute network.getData.""" - if data_type is None: - raise TypeError("get_data() missing required argument: 'data_type'") - if request is None: - raise TypeError("get_data() missing required argument: 'request'") - - params = { - "dataType": data_type, - "collector": collector, - "disown": disown, - "request": request, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.getData", params) - result = self._conn.execute(cmd) - return result - - def provide_response( - self, - request: Any | None = None, - body: Any | None = None, - cookies: list[Any] | None = None, - headers: list[Any] | None = None, - reason_phrase: Any | None = None, - status_code: Any | None = None, - ): - """Execute network.provideResponse.""" - if request is None: - raise TypeError("provide_response() missing required argument: 'request'") - - params = { - "request": request, - "body": body, - "cookies": cookies, - "headers": headers, - "reasonPhrase": reason_phrase, - "statusCode": status_code, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.provideResponse", params) - result = self._conn.execute(cmd) - return result - - def remove_data_collector(self, collector: Any | None = None): - """Execute network.removeDataCollector.""" - if collector is None: - raise TypeError("remove_data_collector() missing required argument: 'collector'") - - params = { - "collector": collector, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.removeDataCollector", params) - result = self._conn.execute(cmd) - return result - - def remove_intercept(self, intercept: Any | None = None): - """Execute network.removeIntercept.""" - if intercept is None: - raise TypeError("remove_intercept() missing required argument: 'intercept'") - - params = { - "intercept": intercept, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.removeIntercept", params) - result = self._conn.execute(cmd) - return result - - def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[Any] | None = None): - """Execute network.setCacheBehavior.""" - if cache_behavior is None: - raise TypeError("set_cache_behavior() missing required argument: 'cache_behavior'") - - params = { - "cacheBehavior": cache_behavior, - "contexts": contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.setCacheBehavior", params) - result = self._conn.execute(cmd) - return result - - def set_extra_headers( - self, - headers: list[Any] | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute network.setExtraHeaders.""" - if headers is None: - raise TypeError("set_extra_headers() missing required argument: 'headers'") - - params = { - "headers": headers, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.setExtraHeaders", params) - result = self._conn.execute(cmd) - return result - - def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): - """Execute network.beforeRequestSent.""" - if method is None: - raise TypeError("before_request_sent() missing required argument: 'method'") - if params is None: - raise TypeError("before_request_sent() missing required argument: 'params'") - - params = { - "initiator": initiator, - "method": method, - "params": params, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.beforeRequestSent", params) - result = self._conn.execute(cmd) - return result - - def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): - """Execute network.fetchError.""" - if error_text is None: - raise TypeError("fetch_error() missing required argument: 'error_text'") - if method is None: - raise TypeError("fetch_error() missing required argument: 'method'") - if params is None: - raise TypeError("fetch_error() missing required argument: 'params'") - - params = { - "errorText": error_text, - "method": method, - "params": params, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.fetchError", params) - result = self._conn.execute(cmd) - return result - - def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): - """Execute network.responseCompleted.""" - if response is None: - raise TypeError("response_completed() missing required argument: 'response'") - if method is None: - raise TypeError("response_completed() missing required argument: 'method'") - if params is None: - raise TypeError("response_completed() missing required argument: 'params'") - - params = { - "response": response, - "method": method, - "params": params, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.responseCompleted", params) - result = self._conn.execute(cmd) - return result - - def response_started(self, response: Any | None = None): - """Execute network.responseStarted.""" - if response is None: - raise TypeError("response_started() missing required argument: 'response'") - - params = { - "response": response, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.responseStarted", params) - result = self._conn.execute(cmd) - return result - - def _add_intercept(self, phases=None, url_patterns=None): - """Add a low-level network intercept. - - Args: - phases: list of intercept phases (default: ["beforeRequestSent"]) - url_patterns: optional URL patterns to filter - - Returns: - dict with "intercept" key containing the intercept ID - """ - from selenium.webdriver.common.bidi.common import command_builder as _cb - - if phases is None: - phases = ["beforeRequestSent"] - params = {"phases": phases} - if url_patterns: - params["urlPatterns"] = url_patterns - result = self._conn.execute(_cb("network.addIntercept", params)) - if result: - intercept_id = result.get("intercept") - if intercept_id and intercept_id not in self.intercepts: - self.intercepts.append(intercept_id) - return result - - def _remove_intercept(self, intercept_id): - """Remove a low-level network intercept.""" - from selenium.webdriver.common.bidi.common import command_builder as _cb - - self._conn.execute(_cb("network.removeIntercept", {"intercept": intercept_id})) - if intercept_id in self.intercepts: - self.intercepts.remove(intercept_id) - - def add_request_handler(self, event, callback, url_patterns=None): - """Add a handler for network requests at the specified phase. - - Args: - event: Event name, e.g. ``"before_request"``. - callback: Callable receiving a :class:`Request` instance. - url_patterns: optional list of URL pattern dicts to filter. - - Returns: - callback_id int for later removal via remove_request_handler. - """ - phase_map = { - "before_request": "beforeRequestSent", - "before_request_sent": "beforeRequestSent", - "response_started": "responseStarted", - "auth_required": "authRequired", - } - phase = phase_map.get(event, "beforeRequestSent") - intercept_result = self._add_intercept(phases=[phase], url_patterns=url_patterns) - intercept_id = intercept_result.get("intercept") if intercept_result else None - - def _request_callback(params): - raw = params if isinstance(params, dict) else (params.__dict__ if hasattr(params, "__dict__") else {}) - request = Request(self._conn, raw) - callback(request) - - callback_id = self.add_event_handler(event, _request_callback) - if intercept_id: - self._handler_intercepts[callback_id] = intercept_id - return callback_id - - def remove_request_handler(self, event, callback_id): - """Remove a network request handler and its associated network intercept. - - Args: - event: The event name used when adding the handler. - callback_id: The int returned by add_request_handler. - """ - self.remove_event_handler(event, callback_id) - intercept_id = self._handler_intercepts.pop(callback_id, None) - if intercept_id: - self._remove_intercept(intercept_id) - - def clear_request_handlers(self): - """Clear all request handlers and remove all tracked intercepts.""" - self.clear_event_handlers() - for intercept_id in list(self.intercepts): - self._remove_intercept(intercept_id) - - def add_auth_handler(self, username, password): - """Add an auth handler that automatically provides credentials. - - Args: - username: The username for basic authentication. - password: The password for basic authentication. - - Returns: - callback_id int for later removal via remove_auth_handler. - """ - from selenium.webdriver.common.bidi.common import command_builder as _cb - - # Set up network intercept for authRequired phase - intercept_result = self._add_intercept(phases=["authRequired"]) - intercept_id = intercept_result.get("intercept") if intercept_result else None - - def _auth_callback(params): - raw = params if isinstance(params, dict) else (params.__dict__ if hasattr(params, "__dict__") else {}) - request_id = raw.get("request", {}).get("request") if isinstance(raw, dict) else None - if request_id: - self._conn.execute( - _cb( - "network.continueWithAuth", - { - "request": request_id, - "action": "provideCredentials", - "credentials": { - "type": "password", - "username": username, - "password": password, - }, - }, - ) - ) - - callback_id = self.add_event_handler("auth_required", _auth_callback) - if intercept_id: - self._handler_intercepts[callback_id] = intercept_id - return callback_id - - def remove_auth_handler(self, callback_id): - """Remove an auth handler by callback ID and its associated network intercept. - - Args: - callback_id: The handler ID returned by add_auth_handler. - """ - self.remove_event_handler("auth_required", callback_id) - intercept_id = self._handler_intercepts.pop(callback_id, None) - if intercept_id: - self._remove_intercept(intercept_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - """Add an event handler. - - Args: - event: The event to subscribe to. - callback: The callback function to execute on event. - contexts: The context IDs to subscribe to (optional). - - Returns: - The callback ID. - """ - return self._event_manager.add_event_handler(event, callback, contexts) - - def remove_event_handler(self, event: str, callback_id: int) -> None: - """Remove an event handler. - - Args: - event: The event to unsubscribe from. - callback_id: The callback ID. - """ - return self._event_manager.remove_event_handler(event, callback_id) - - def clear_event_handlers(self) -> None: - """Clear all event handlers.""" - return self._event_manager.clear_event_handlers() - - -# Event Info Type Aliases -# Event: network.authRequired -AuthRequired = globals().get("AuthRequiredParameters", dict) # Fallback to dict if type not defined - - -# Populate EVENT_CONFIGS with event configuration mappings -_globals = globals() -Network.EVENT_CONFIGS = { - "auth_required": EventConfig( - "auth_required", - "network.authRequired", - _globals.get("AuthRequired", dict) if _globals.get("AuthRequired") else dict, - ), - "before_request": EventConfig("before_request", "network.beforeRequestSent", _globals.get("dict", dict)), -} diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py deleted file mode 100644 index ee6eb4f4a437a..0000000000000 --- a/py/selenium/webdriver/common/bidi/script.py +++ /dev/null @@ -1,1230 +0,0 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -from __future__ import annotations - -from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager -from selenium.webdriver.common.bidi.common import command_builder - - -class SpecialNumber: - """SpecialNumber.""" - - NAN = "NaN" - _0 = "-0" - INFINITY = "Infinity" - _INFINITY = "-Infinity" - - -class RealmType: - """RealmType.""" - - WINDOW = "window" - DEDICATED_WORKER = "dedicated-worker" - SHARED_WORKER = "shared-worker" - SERVICE_WORKER = "service-worker" - WORKER = "worker" - PAINT_WORKLET = "paint-worklet" - AUDIO_WORKLET = "audio-worklet" - WORKLET = "worklet" - - -class ResultOwnership: - """ResultOwnership.""" - - ROOT = "root" - NONE = "none" - - -@dataclass -class ChannelValue: - """ChannelValue.""" - - type: str = field(default="channel", init=False) - value: Any | None = None - - -@dataclass -class ChannelProperties: - """ChannelProperties.""" - - channel: Any | None = None - serialization_options: Any | None = None - ownership: Any | None = None - - -@dataclass -class EvaluateResultSuccess: - """EvaluateResultSuccess.""" - - type: str = field(default="success", init=False) - result: Any | None = None - realm: Any | None = None - - -@dataclass -class EvaluateResultException: - """EvaluateResultException.""" - - type: str = field(default="exception", init=False) - exception_details: Any | None = None - realm: Any | None = None - - -@dataclass -class ExceptionDetails: - """ExceptionDetails.""" - - column_number: Any | None = None - exception: Any | None = None - line_number: Any | None = None - stack_trace: Any | None = None - text: str | None = None - - -@dataclass -class ArrayLocalValue: - """ArrayLocalValue.""" - - type: str = field(default="array", init=False) - value: Any | None = None - - -@dataclass -class DateLocalValue: - """DateLocalValue.""" - - type: str = field(default="date", init=False) - value: str | None = None - - -@dataclass -class MapLocalValue: - """MapLocalValue.""" - - type: str = field(default="map", init=False) - value: Any | None = None - - -@dataclass -class ObjectLocalValue: - """ObjectLocalValue.""" - - type: str = field(default="object", init=False) - value: Any | None = None - - -@dataclass -class RegExpValue: - """RegExpValue.""" - - pattern: str | None = None - flags: str | None = None - - -@dataclass -class RegExpLocalValue: - """RegExpLocalValue.""" - - type: str = field(default="regexp", init=False) - value: Any | None = None - - -@dataclass -class SetLocalValue: - """SetLocalValue.""" - - type: str = field(default="set", init=False) - value: Any | None = None - - -@dataclass -class UndefinedValue: - """UndefinedValue.""" - - type: str = field(default="undefined", init=False) - - -@dataclass -class NullValue: - """NullValue.""" - - type: str = field(default="null", init=False) - - -@dataclass -class StringValue: - """StringValue.""" - - type: str = field(default="string", init=False) - value: str | None = None - - -@dataclass -class NumberValue: - """NumberValue.""" - - type: str = field(default="number", init=False) - value: Any | None = None - - -@dataclass -class BooleanValue: - """BooleanValue.""" - - type: str = field(default="boolean", init=False) - value: bool | None = None - - -@dataclass -class BigIntValue: - """BigIntValue.""" - - type: str = field(default="bigint", init=False) - value: str | None = None - - -@dataclass -class BaseRealmInfo: - """BaseRealmInfo.""" - - realm: Any | None = None - origin: str | None = None - - -@dataclass -class WindowRealmInfo: - """WindowRealmInfo.""" - - type: str = field(default="window", init=False) - context: Any | None = None - user_context: Any | None = None - sandbox: str | None = None - - -@dataclass -class DedicatedWorkerRealmInfo: - """DedicatedWorkerRealmInfo.""" - - type: str = field(default="dedicated-worker", init=False) - owners: list[Any] = field(default_factory=list) - - -@dataclass -class SharedWorkerRealmInfo: - """SharedWorkerRealmInfo.""" - - type: str = field(default="shared-worker", init=False) - - -@dataclass -class ServiceWorkerRealmInfo: - """ServiceWorkerRealmInfo.""" - - type: str = field(default="service-worker", init=False) - - -@dataclass -class WorkerRealmInfo: - """WorkerRealmInfo.""" - - type: str = field(default="worker", init=False) - - -@dataclass -class PaintWorkletRealmInfo: - """PaintWorkletRealmInfo.""" - - type: str = field(default="paint-worklet", init=False) - - -@dataclass -class AudioWorkletRealmInfo: - """AudioWorkletRealmInfo.""" - - type: str = field(default="audio-worklet", init=False) - - -@dataclass -class WorkletRealmInfo: - """WorkletRealmInfo.""" - - type: str = field(default="worklet", init=False) - - -@dataclass -class SharedReference: - """SharedReference.""" - - shared_id: Any | None = None - handle: Any | None = None - - -@dataclass -class RemoteObjectReference: - """RemoteObjectReference.""" - - handle: Any | None = None - shared_id: Any | None = None - - -@dataclass -class SymbolRemoteValue: - """SymbolRemoteValue.""" - - type: str = field(default="symbol", init=False) - handle: Any | None = None - internal_id: Any | None = None - - -@dataclass -class ArrayRemoteValue: - """ArrayRemoteValue.""" - - type: str = field(default="array", init=False) - handle: Any | None = None - internal_id: Any | None = None - value: Any | None = None - - -@dataclass -class ObjectRemoteValue: - """ObjectRemoteValue.""" - - type: str = field(default="object", init=False) - handle: Any | None = None - internal_id: Any | None = None - value: Any | None = None - - -@dataclass -class FunctionRemoteValue: - """FunctionRemoteValue.""" - - type: str = field(default="function", init=False) - handle: Any | None = None - internal_id: Any | None = None - - -@dataclass -class RegExpRemoteValue: - """RegExpRemoteValue.""" - - handle: Any | None = None - internal_id: Any | None = None - - -@dataclass -class DateRemoteValue: - """DateRemoteValue.""" - - handle: Any | None = None - internal_id: Any | None = None - - -@dataclass -class MapRemoteValue: - """MapRemoteValue.""" - - type: str = field(default="map", init=False) - handle: Any | None = None - internal_id: Any | None = None - value: Any | None = None - - -@dataclass -class SetRemoteValue: - """SetRemoteValue.""" - - type: str = field(default="set", init=False) - handle: Any | None = None - internal_id: Any | None = None - value: Any | None = None - - -@dataclass -class WeakMapRemoteValue: - """WeakMapRemoteValue.""" - - type: str = field(default="weakmap", init=False) - handle: Any | None = None - internal_id: Any | None = None - - -@dataclass -class WeakSetRemoteValue: - """WeakSetRemoteValue.""" - - type: str = field(default="weakset", init=False) - handle: Any | None = None - internal_id: Any | None = None - - -@dataclass -class GeneratorRemoteValue: - """GeneratorRemoteValue.""" - - type: str = field(default="generator", init=False) - handle: Any | None = None - internal_id: Any | None = None - - -@dataclass -class ErrorRemoteValue: - """ErrorRemoteValue.""" - - type: str = field(default="error", init=False) - handle: Any | None = None - internal_id: Any | None = None - - -@dataclass -class ProxyRemoteValue: - """ProxyRemoteValue.""" - - type: str = field(default="proxy", init=False) - handle: Any | None = None - internal_id: Any | None = None - - -@dataclass -class PromiseRemoteValue: - """PromiseRemoteValue.""" - - type: str = field(default="promise", init=False) - handle: Any | None = None - internal_id: Any | None = None - - -@dataclass -class TypedArrayRemoteValue: - """TypedArrayRemoteValue.""" - - type: str = field(default="typedarray", init=False) - handle: Any | None = None - internal_id: Any | None = None - - -@dataclass -class ArrayBufferRemoteValue: - """ArrayBufferRemoteValue.""" - - type: str = field(default="arraybuffer", init=False) - handle: Any | None = None - internal_id: Any | None = None - - -@dataclass -class NodeListRemoteValue: - """NodeListRemoteValue.""" - - type: str = field(default="nodelist", init=False) - handle: Any | None = None - internal_id: Any | None = None - value: Any | None = None - - -@dataclass -class HTMLCollectionRemoteValue: - """HTMLCollectionRemoteValue.""" - - type: str = field(default="htmlcollection", init=False) - handle: Any | None = None - internal_id: Any | None = None - value: Any | None = None - - -@dataclass -class NodeRemoteValue: - """NodeRemoteValue.""" - - type: str = field(default="node", init=False) - shared_id: Any | None = None - handle: Any | None = None - internal_id: Any | None = None - value: Any | None = None - - -@dataclass -class NodeProperties: - """NodeProperties.""" - - node_type: Any | None = None - child_node_count: Any | None = None - children: list[Any] = field(default_factory=list) - local_name: str | None = None - mode: Any | None = None - namespace_uri: str | None = None - node_value: str | None = None - shadow_root: Any | None = None - - -@dataclass -class WindowProxyRemoteValue: - """WindowProxyRemoteValue.""" - - type: str = field(default="window", init=False) - value: Any | None = None - handle: Any | None = None - internal_id: Any | None = None - - -@dataclass -class WindowProxyProperties: - """WindowProxyProperties.""" - - context: Any | None = None - - -@dataclass -class StackFrame: - """StackFrame.""" - - column_number: Any | None = None - function_name: str | None = None - line_number: Any | None = None - url: str | None = None - - -@dataclass -class StackTrace: - """StackTrace.""" - - call_frames: list[Any] = field(default_factory=list) - - -@dataclass -class Source: - """Source.""" - - realm: Any | None = None - context: Any | None = None - user_context: Any | None = None - - -@dataclass -class RealmTarget: - """RealmTarget.""" - - realm: Any | None = None - - -@dataclass -class ContextTarget: - """ContextTarget.""" - - context: Any | None = None - sandbox: str | None = None - - -@dataclass -class AddPreloadScriptParameters: - """AddPreloadScriptParameters.""" - - function_declaration: str | None = None - arguments: list[Any] = field(default_factory=list) - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - sandbox: str | None = None - - -@dataclass -class AddPreloadScriptResult: - """AddPreloadScriptResult.""" - - script: Any | None = None - - -@dataclass -class DisownParameters: - """DisownParameters.""" - - handles: list[Any] = field(default_factory=list) - target: Any | None = None - - -@dataclass -class CallFunctionParameters: - """CallFunctionParameters.""" - - function_declaration: str | None = None - await_promise: bool | None = None - target: Any | None = None - arguments: list[Any] = field(default_factory=list) - result_ownership: Any | None = None - serialization_options: Any | None = None - this: Any | None = None - user_activation: bool | None = None - - -@dataclass -class EvaluateParameters: - """EvaluateParameters.""" - - expression: str | None = None - target: Any | None = None - await_promise: bool | None = None - result_ownership: Any | None = None - serialization_options: Any | None = None - user_activation: bool | None = None - - -@dataclass -class GetRealmsParameters: - """GetRealmsParameters.""" - - context: Any | None = None - type: Any | None = None - - -@dataclass -class GetRealmsResult: - """GetRealmsResult.""" - - realms: list[Any] = field(default_factory=list) - - -@dataclass -class RemovePreloadScriptParameters: - """RemovePreloadScriptParameters.""" - - script: Any | None = None - - -@dataclass -class MessageParameters: - """MessageParameters.""" - - channel: Any | None = None - data: Any | None = None - source: Any | None = None - - -@dataclass -class RealmDestroyedParameters: - """RealmDestroyedParameters.""" - - realm: Any | None = None - - -# BiDi Event Name to Parameter Type Mapping -EVENT_NAME_MAPPING = { - "realm_created": "script.realmCreated", - "realm_destroyed": "script.realmDestroyed", -} - - -class Script: - """WebDriver BiDi script module.""" - - EVENT_CONFIGS: dict[str, EventConfig] = {} - - def __init__(self, conn, driver=None) -> None: - self._conn = conn - self._driver = driver - self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - - def add_preload_script( - self, - function_declaration: Any | None = None, - arguments: list[Any] | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - sandbox: Any | None = None, - ): - """Execute script.addPreloadScript.""" - if function_declaration is None: - raise TypeError("add_preload_script() missing required argument: 'function_declaration'") - - params = { - "functionDeclaration": function_declaration, - "arguments": arguments, - "contexts": contexts, - "userContexts": user_contexts, - "sandbox": sandbox, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("script.addPreloadScript", params) - result = self._conn.execute(cmd) - return result - - def disown(self, handles: list[Any] | None = None, target: Any | None = None): - """Execute script.disown.""" - if handles is None: - raise TypeError("disown() missing required argument: 'handles'") - if target is None: - raise TypeError("disown() missing required argument: 'target'") - - params = { - "handles": handles, - "target": target, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("script.disown", params) - result = self._conn.execute(cmd) - return result - - def call_function( - self, - function_declaration: Any | None = None, - await_promise: bool | None = None, - target: Any | None = None, - arguments: list[Any] | None = None, - result_ownership: Any | None = None, - serialization_options: Any | None = None, - this: Any | None = None, - user_activation: bool | None = None, - ): - """Execute script.callFunction.""" - if function_declaration is None: - raise TypeError("call_function() missing required argument: 'function_declaration'") - if await_promise is None: - raise TypeError("call_function() missing required argument: 'await_promise'") - if target is None: - raise TypeError("call_function() missing required argument: 'target'") - - params = { - "functionDeclaration": function_declaration, - "awaitPromise": await_promise, - "target": target, - "arguments": arguments, - "resultOwnership": result_ownership, - "serializationOptions": serialization_options, - "this": this, - "userActivation": user_activation, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("script.callFunction", params) - result = self._conn.execute(cmd) - return result - - def evaluate( - self, - expression: Any | None = None, - target: Any | None = None, - await_promise: bool | None = None, - result_ownership: Any | None = None, - serialization_options: Any | None = None, - user_activation: bool | None = None, - ): - """Execute script.evaluate.""" - if expression is None: - raise TypeError("evaluate() missing required argument: 'expression'") - if target is None: - raise TypeError("evaluate() missing required argument: 'target'") - if await_promise is None: - raise TypeError("evaluate() missing required argument: 'await_promise'") - - params = { - "expression": expression, - "target": target, - "awaitPromise": await_promise, - "resultOwnership": result_ownership, - "serializationOptions": serialization_options, - "userActivation": user_activation, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("script.evaluate", params) - result = self._conn.execute(cmd) - return result - - def get_realms(self, context: Any | None = None, type: Any | None = None): - """Execute script.getRealms.""" - params = { - "context": context, - "type": type, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("script.getRealms", params) - result = self._conn.execute(cmd) - return result - - def remove_preload_script(self, script: Any | None = None): - """Execute script.removePreloadScript.""" - if script is None: - raise TypeError("remove_preload_script() missing required argument: 'script'") - - params = { - "script": script, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("script.removePreloadScript", params) - result = self._conn.execute(cmd) - return result - - def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): - """Execute script.message.""" - if channel is None: - raise TypeError("message() missing required argument: 'channel'") - if data is None: - raise TypeError("message() missing required argument: 'data'") - if source is None: - raise TypeError("message() missing required argument: 'source'") - - params = { - "channel": channel, - "data": data, - "source": source, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("script.message", params) - result = self._conn.execute(cmd) - return result - - def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: - """Execute a function declaration in the browser context. - - Args: - function_declaration: The function as a string, e.g. ``"() => document.title"``. - *args: Optional Python values to pass as arguments to the function. - Each value is serialised to a BiDi ``LocalValue`` automatically. - Supported types: ``None``, ``bool``, ``int``, ``float`` - (including ``NaN`` and ``Infinity``), ``str``, ``list``, - ``dict``, and ``datetime.datetime``. - context_id: The browsing context ID to run in. Defaults to the - driver's current window handle when a driver was provided. - - Returns: - The inner RemoteValue result dict, or raises WebDriverException on exception. - """ - import datetime as _datetime - import math as _math - - from selenium.common.exceptions import WebDriverException as _WebDriverException - - def _serialize_arg(value): - """Serialise a Python value to a BiDi LocalValue dict.""" - if value is None: - return {"type": "null"} - if isinstance(value, bool): - return {"type": "boolean", "value": value} - if isinstance(value, _datetime.datetime): - return {"type": "date", "value": value.isoformat()} - if isinstance(value, float): - if _math.isnan(value): - return {"type": "number", "value": "NaN"} - if _math.isinf(value): - return {"type": "number", "value": "Infinity" if value > 0 else "-Infinity"} - return {"type": "number", "value": value} - if isinstance(value, int): - _MAX_SAFE_INT = 9007199254740991 - if abs(value) > _MAX_SAFE_INT: - return {"type": "bigint", "value": str(value)} - return {"type": "number", "value": value} - if isinstance(value, str): - return {"type": "string", "value": value} - if isinstance(value, list): - return {"type": "array", "value": [_serialize_arg(v) for v in value]} - if isinstance(value, dict): - return {"type": "object", "value": [[str(k), _serialize_arg(v)] for k, v in value.items()]} - return value - - if context_id is None and self._driver is not None: - try: - context_id = self._driver.current_window_handle - except Exception: - pass - target = {"context": context_id} if context_id else {} - serialized_args = [_serialize_arg(a) for a in args] if args else None - raw = self.call_function( - function_declaration=function_declaration, - await_promise=True, - target=target, - arguments=serialized_args, - ) - if isinstance(raw, dict): - if raw.get("type") == "exception": - exc = raw.get("exceptionDetails", {}) - msg = exc.get("text", str(exc)) if isinstance(exc, dict) else str(exc) - raise _WebDriverException(msg) - if raw.get("type") == "success": - return raw.get("result") - return raw - - def _add_preload_script( - self, - function_declaration, - arguments=None, - contexts=None, - user_contexts=None, - sandbox=None, - ): - """Add a preload script with validation. - - Args: - function_declaration: The JS function to run on page load. - arguments: Optional list of BiDi arguments. - contexts: Optional list of browsing context IDs. - user_contexts: Optional list of user context IDs. - sandbox: Optional sandbox name. - - Returns: - script_id: The ID of the added preload script (str). - - Raises: - ValueError: If both contexts and user_contexts are specified. - """ - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and user_contexts") - result = self.add_preload_script( - function_declaration=function_declaration, - arguments=arguments, - contexts=contexts, - user_contexts=user_contexts, - sandbox=sandbox, - ) - if isinstance(result, dict): - return result.get("script") - return result - - def _remove_preload_script(self, script_id): - """Remove a preload script by ID. - - Args: - script_id: The ID of the preload script to remove. - """ - return self.remove_preload_script(script=script_id) - - def pin(self, function_declaration): - """Pin (add) a preload script that runs on every page load. - - Args: - function_declaration: The JS function to execute on page load. - - Returns: - script_id: The ID of the pinned script (str). - """ - return self._add_preload_script(function_declaration) - - def unpin(self, script_id): - """Unpin (remove) a previously pinned preload script. - - Args: - script_id: The ID returned by pin(). - """ - return self._remove_preload_script(script_id=script_id) - - def _evaluate( - self, - expression, - target, - await_promise, - result_ownership=None, - serialization_options=None, - user_activation=None, - ): - """Evaluate a script expression and return a structured result. - - Args: - expression: The JavaScript expression to evaluate. - target: A dict like {"context": } or {"realm": }. - await_promise: Whether to await a returned promise. - result_ownership: Optional result ownership setting. - serialization_options: Optional serialization options dict. - user_activation: Optional user activation flag. - - Returns: - An object with .realm, .result (dict or None), and .exception_details (or None). - """ - - class _EvalResult: - def __init__(self2, realm, result, exception_details): - self2.realm = realm - self2.result = result - self2.exception_details = exception_details - - raw = self.evaluate( - expression=expression, - target=target, - await_promise=await_promise, - result_ownership=result_ownership, - serialization_options=serialization_options, - user_activation=user_activation, - ) - if isinstance(raw, dict): - realm = raw.get("realm") - if raw.get("type") == "exception": - exc = raw.get("exceptionDetails") - return _EvalResult(realm=realm, result=None, exception_details=exc) - return _EvalResult(realm=realm, result=raw.get("result"), exception_details=None) - return _EvalResult(realm=None, result=raw, exception_details=None) - - def _call_function( - self, - function_declaration, - await_promise, - target, - arguments=None, - result_ownership=None, - this=None, - user_activation=None, - serialization_options=None, - ): - """Call a function and return a structured result. - - Args: - function_declaration: The JS function string. - await_promise: Whether to await the return value. - target: A dict like {"context": }. - arguments: Optional list of BiDi arguments. - result_ownership: Optional result ownership. - this: Optional 'this' binding. - user_activation: Optional user activation flag. - serialization_options: Optional serialization options dict. - - Returns: - An object with .result (dict or None) and .exception_details (or None). - """ - - class _CallResult: - def __init__(self2, result, exception_details): - self2.result = result - self2.exception_details = exception_details - - raw = self.call_function( - function_declaration=function_declaration, - await_promise=await_promise, - target=target, - arguments=arguments, - result_ownership=result_ownership, - this=this, - user_activation=user_activation, - serialization_options=serialization_options, - ) - if isinstance(raw, dict): - if raw.get("type") == "exception": - exc = raw.get("exceptionDetails") - return _CallResult(result=None, exception_details=exc) - if raw.get("type") == "success": - return _CallResult(result=raw.get("result"), exception_details=None) - return _CallResult(result=raw, exception_details=None) - - def _get_realms(self, context=None, type=None): - """Get all realms, optionally filtered by context and type. - - Args: - context: Optional browsing context ID to filter by. - type: Optional realm type string to filter by (e.g. RealmType.WINDOW). - - Returns: - List of realm info objects with .realm, .origin, .type, .context attributes. - """ - - class _RealmInfo: - def __init__(self2, realm, origin, type_, context): - self2.realm = realm - self2.origin = origin - self2.type = type_ - self2.context = context - - raw = self.get_realms(context=context, type=type) - realms_list = raw.get("realms", []) if isinstance(raw, dict) else [] - result = [] - for r in realms_list: - if isinstance(r, dict): - result.append( - _RealmInfo( - realm=r.get("realm"), - origin=r.get("origin"), - type_=r.get("type"), - context=r.get("context"), - ) - ) - return result - - def _disown(self, handles, target): - """Disown handles in a browsing context. - - Args: - handles: List of handle strings to disown. - target: A dict like {"context": }. - """ - return self.disown(handles=handles, target=target) - - def _subscribe_log_entry(self, callback, entry_type_filter=None): - """Subscribe to log.entryAdded BiDi events with optional type filtering.""" - import threading as _threading - - from selenium.webdriver.common.bidi import log as _log_mod - from selenium.webdriver.common.bidi.session import Session as _Session - - bidi_event = "log.entryAdded" - - if not hasattr(self, "_log_subscriptions"): - self._log_subscriptions = {} - self._log_lock = _threading.Lock() - - def _deserialize(params): - t = params.get("type") if isinstance(params, dict) else None - if t == "console": - cls = getattr(_log_mod, "ConsoleLogEntry", None) - if cls is not None and hasattr(cls, "from_json"): - try: - return cls.from_json(params) - except Exception: - pass - elif t == "javascript": - cls = getattr(_log_mod, "JavascriptLogEntry", None) - if cls is not None and hasattr(cls, "from_json"): - try: - return cls.from_json(params) - except Exception: - pass - return params - - def _wrapped(raw): - entry = _deserialize(raw) - if entry_type_filter is None: - callback(entry) - else: - t = getattr(entry, "type_", None) or (entry.get("type") if isinstance(entry, dict) else None) - if t == entry_type_filter: - callback(entry) - - class _BidiRef: - event_class = bidi_event - - def from_json(self2, p): - return p - - _wrapper = _BidiRef() - callback_id = self._conn.add_callback(_wrapper, _wrapped) - with self._log_lock: - if bidi_event not in self._log_subscriptions: - session = _Session(self._conn) - result = session.subscribe([bidi_event]) - sub_id = result.get("subscription") if isinstance(result, dict) else None - self._log_subscriptions[bidi_event] = { - "callbacks": [], - "subscription_id": sub_id, - } - self._log_subscriptions[bidi_event]["callbacks"].append(callback_id) - return callback_id - - def _unsubscribe_log_entry(self, callback_id): - """Unsubscribe a log entry callback by ID.""" - from selenium.webdriver.common.bidi.session import Session as _Session - - bidi_event = "log.entryAdded" - if not hasattr(self, "_log_subscriptions"): - return - - class _BidiRef: - event_class = bidi_event - - def from_json(self2, p): - return p - - _wrapper = _BidiRef() - self._conn.remove_callback(_wrapper, callback_id) - with self._log_lock: - entry = self._log_subscriptions.get(bidi_event) - if entry and callback_id in entry["callbacks"]: - entry["callbacks"].remove(callback_id) - if entry is not None and not entry["callbacks"]: - session = _Session(self._conn) - sub_id = entry.get("subscription_id") - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - del self._log_subscriptions[bidi_event] - - def add_console_message_handler(self, callback: Callable) -> int: - """Add a handler for console log messages (log.entryAdded type=console). - - Args: - callback: Function called with a ConsoleLogEntry on each console message. - - Returns: - callback_id for use with remove_console_message_handler. - """ - return self._subscribe_log_entry(callback, entry_type_filter="console") - - def remove_console_message_handler(self, callback_id: int) -> None: - """Remove a console message handler by callback ID.""" - self._unsubscribe_log_entry(callback_id) - - def add_javascript_error_handler(self, callback: Callable) -> int: - """Add a handler for JavaScript error log messages (log.entryAdded type=javascript). - - Args: - callback: Function called with a JavascriptLogEntry on each JS error. - - Returns: - callback_id for use with remove_javascript_error_handler. - """ - return self._subscribe_log_entry(callback, entry_type_filter="javascript") - - def remove_javascript_error_handler(self, callback_id: int) -> None: - """Remove a JavaScript error handler by callback ID.""" - self._unsubscribe_log_entry(callback_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - """Add an event handler. - - Args: - event: The event to subscribe to. - callback: The callback function to execute on event. - contexts: The context IDs to subscribe to (optional). - - Returns: - The callback ID. - """ - return self._event_manager.add_event_handler(event, callback, contexts) - - def remove_event_handler(self, event: str, callback_id: int) -> None: - """Remove an event handler. - - Args: - event: The event to unsubscribe from. - callback_id: The callback ID. - """ - return self._event_manager.remove_event_handler(event, callback_id) - - def clear_event_handlers(self) -> None: - """Clear all event handlers.""" - return self._event_manager.clear_event_handlers() - - -# Event Info Type Aliases -# Event: script.realmCreated -RealmCreated = globals().get("RealmInfo", dict) # Fallback to dict if type not defined - -# Event: script.realmDestroyed -RealmDestroyed = globals().get("RealmDestroyedParameters", dict) # Fallback to dict if type not defined - - -# Populate EVENT_CONFIGS with event configuration mappings -_globals = globals() -Script.EVENT_CONFIGS = { - "realm_created": EventConfig( - "realm_created", - "script.realmCreated", - _globals.get("RealmCreated", dict) if _globals.get("RealmCreated") else dict, - ), - "realm_destroyed": EventConfig( - "realm_destroyed", - "script.realmDestroyed", - _globals.get("RealmDestroyed", dict) if _globals.get("RealmDestroyed") else dict, - ), -} diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py deleted file mode 100644 index b00544d286546..0000000000000 --- a/py/selenium/webdriver/common/bidi/session.py +++ /dev/null @@ -1,260 +0,0 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi.common import command_builder - - -class UserPromptHandlerType: - """UserPromptHandlerType.""" - - ACCEPT = "accept" - DISMISS = "dismiss" - IGNORE = "ignore" - - -@dataclass -class CapabilitiesRequest: - """CapabilitiesRequest.""" - - always_match: Any | None = None - first_match: list[Any] = field(default_factory=list) - - -@dataclass -class CapabilityRequest: - """CapabilityRequest.""" - - accept_insecure_certs: bool | None = None - browser_name: str | None = None - browser_version: str | None = None - platform_name: str | None = None - proxy: Any | None = None - unhandled_prompt_behavior: Any | None = None - - -@dataclass -class AutodetectProxyConfiguration: - """AutodetectProxyConfiguration.""" - - proxy_type: str = field(default="autodetect", init=False) - - -@dataclass -class DirectProxyConfiguration: - """DirectProxyConfiguration.""" - - proxy_type: str = field(default="direct", init=False) - - -@dataclass -class ManualProxyConfiguration: - """ManualProxyConfiguration.""" - - proxy_type: str = field(default="manual", init=False) - http_proxy: str | None = None - ssl_proxy: str | None = None - no_proxy: list[Any] = field(default_factory=list) - - -@dataclass -class SocksProxyConfiguration: - """SocksProxyConfiguration.""" - - socks_proxy: str | None = None - socks_version: Any | None = None - - -@dataclass -class PacProxyConfiguration: - """PacProxyConfiguration.""" - - proxy_type: str = field(default="pac", init=False) - proxy_autoconfig_url: str | None = None - - -@dataclass -class SystemProxyConfiguration: - """SystemProxyConfiguration.""" - - proxy_type: str = field(default="system", init=False) - - -@dataclass -class SubscribeParameters: - """SubscribeParameters.""" - - events: list[str] = field(default_factory=list) - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - -@dataclass -class UnsubscribeByIDRequest: - """UnsubscribeByIDRequest.""" - - subscriptions: list[Any] = field(default_factory=list) - - -@dataclass -class UnsubscribeByAttributesRequest: - """UnsubscribeByAttributesRequest.""" - - events: list[str] = field(default_factory=list) - - -@dataclass -class StatusResult: - """StatusResult.""" - - ready: bool | None = None - message: str | None = None - - -@dataclass -class NewParameters: - """NewParameters.""" - - capabilities: Any | None = None - - -@dataclass -class NewResult: - """NewResult.""" - - session_id: str | None = None - accept_insecure_certs: bool | None = None - browser_name: str | None = None - browser_version: str | None = None - platform_name: str | None = None - set_window_rect: bool | None = None - user_agent: str | None = None - proxy: Any | None = None - unhandled_prompt_behavior: Any | None = None - web_socket_url: str | None = None - - -@dataclass -class SubscribeResult: - """SubscribeResult.""" - - subscription: Any | None = None - - -@dataclass -class UserPromptHandler: - """UserPromptHandler.""" - - alert: Any | None = None - before_unload: Any | None = None - confirm: Any | None = None - default: Any | None = None - file: Any | None = None - prompt: Any | None = None - - def to_bidi_dict(self) -> dict: - """Convert to BiDi protocol dict with camelCase keys.""" - result = {} - if self.alert is not None: - result["alert"] = self.alert - if self.before_unload is not None: - result["beforeUnload"] = self.before_unload - if self.confirm is not None: - result["confirm"] = self.confirm - if self.default is not None: - result["default"] = self.default - if self.file is not None: - result["file"] = self.file - if self.prompt is not None: - result["prompt"] = self.prompt - return result - - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict() - - -class Session: - """WebDriver BiDi session module.""" - - def __init__(self, conn) -> None: - self._conn = conn - - def status(self): - """Execute session.status.""" - params = {} - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("session.status", params) - result = self._conn.execute(cmd) - return result - - def new(self, capabilities: Any | None = None): - """Execute session.new.""" - if capabilities is None: - raise TypeError("new() missing required argument: 'capabilities'") - - params = { - "capabilities": capabilities, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("session.new", params) - result = self._conn.execute(cmd) - return result - - def end(self): - """Execute session.end.""" - params = {} - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("session.end", params) - result = self._conn.execute(cmd) - return result - - def subscribe( - self, - events: list[Any] | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute session.subscribe.""" - if events is None: - raise TypeError("subscribe() missing required argument: 'events'") - - params = { - "events": events, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("session.subscribe", params) - result = self._conn.execute(cmd) - return result - - def unsubscribe(self, events: list[Any] | None = None, subscriptions: list[Any] | None = None): - """Execute session.unsubscribe.""" - params = { - "events": events, - "subscriptions": subscriptions, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("session.unsubscribe", params) - result = self._conn.execute(cmd) - return result diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py deleted file mode 100644 index 90e65ac3d5ffb..0000000000000 --- a/py/selenium/webdriver/common/bidi/storage.py +++ /dev/null @@ -1,353 +0,0 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi.common import command_builder - - -@dataclass -class PartitionKey: - """PartitionKey.""" - - user_context: str | None = None - source_origin: str | None = None - - -@dataclass -class GetCookiesParameters: - """GetCookiesParameters.""" - - filter: Any | None = None - partition: Any | None = None - - -@dataclass -class GetCookiesResult: - """GetCookiesResult.""" - - cookies: list[Any] = field(default_factory=list) - partition_key: Any | None = None - - -@dataclass -class SetCookieParameters: - """SetCookieParameters.""" - - cookie: Any | None = None - partition: Any | None = None - - -@dataclass -class SetCookieResult: - """SetCookieResult.""" - - partition_key: Any | None = None - - -@dataclass -class DeleteCookiesParameters: - """DeleteCookiesParameters.""" - - filter: Any | None = None - partition: Any | None = None - - -@dataclass -class DeleteCookiesResult: - """DeleteCookiesResult.""" - - partition_key: Any | None = None - - -class BytesValue: - """A string or base64-encoded bytes value used in cookie operations. - - This corresponds to network.BytesValue in the WebDriver BiDi specification, - wrapping either a plain string or a base64-encoded binary value. - """ - - TYPE_STRING = "string" - TYPE_BASE64 = "base64" - - def __init__(self, type: Any | None, value: Any | None) -> None: - self.type = type - self.value = value - - def to_bidi_dict(self) -> dict: - return {"type": self.type, "value": self.value} - - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict() - - -class SameSite: - """SameSite cookie attribute values.""" - - STRICT = "strict" - LAX = "lax" - NONE = "none" - DEFAULT = "default" - - -@dataclass -class StorageCookie: - """A cookie object returned by storage.getCookies.""" - - name: str | None = None - value: Any | None = None - domain: str | None = None - path: str | None = None - size: Any | None = None - http_only: bool | None = None - secure: bool | None = None - same_site: Any | None = None - expiry: Any | None = None - - @classmethod - def from_bidi_dict(cls, raw: dict) -> StorageCookie: - """Deserialize a wire-level cookie dict to a StorageCookie.""" - value_raw = raw.get("value") - if isinstance(value_raw, dict): - value: Any = BytesValue(value_raw.get("type"), value_raw.get("value")) - else: - value = value_raw - return cls( - name=raw.get("name"), - value=value, - domain=raw.get("domain"), - path=raw.get("path"), - size=raw.get("size"), - http_only=raw.get("httpOnly"), - secure=raw.get("secure"), - same_site=raw.get("sameSite"), - expiry=raw.get("expiry"), - ) - - -@dataclass -class CookieFilter: - """CookieFilter.""" - - name: str | None = None - value: Any | None = None - domain: str | None = None - path: str | None = None - size: Any | None = None - http_only: bool | None = None - secure: bool | None = None - same_site: Any | None = None - expiry: Any | None = None - - def to_bidi_dict(self) -> dict: - """Serialize to the BiDi wire-protocol dict.""" - result: dict = {} - if self.name is not None: - result["name"] = self.name - if self.value is not None: - result["value"] = self.value.to_bidi_dict() if hasattr(self.value, "to_bidi_dict") else self.value - if self.domain is not None: - result["domain"] = self.domain - if self.path is not None: - result["path"] = self.path - if self.size is not None: - result["size"] = self.size - if self.http_only is not None: - result["httpOnly"] = self.http_only - if self.secure is not None: - result["secure"] = self.secure - if self.same_site is not None: - result["sameSite"] = self.same_site - if self.expiry is not None: - result["expiry"] = self.expiry - return result - - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict() - - -@dataclass -class PartialCookie: - """PartialCookie.""" - - name: str | None = None - value: Any | None = None - domain: str | None = None - path: str | None = None - http_only: bool | None = None - secure: bool | None = None - same_site: Any | None = None - expiry: Any | None = None - - def to_bidi_dict(self) -> dict: - """Serialize to the BiDi wire-protocol dict.""" - result: dict = {} - if self.name is not None: - result["name"] = self.name - if self.value is not None: - result["value"] = self.value.to_bidi_dict() if hasattr(self.value, "to_bidi_dict") else self.value - if self.domain is not None: - result["domain"] = self.domain - if self.path is not None: - result["path"] = self.path - if self.http_only is not None: - result["httpOnly"] = self.http_only - if self.secure is not None: - result["secure"] = self.secure - if self.same_site is not None: - result["sameSite"] = self.same_site - if self.expiry is not None: - result["expiry"] = self.expiry - return result - - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict() - - -class BrowsingContextPartitionDescriptor: - """BrowsingContextPartitionDescriptor. - - The first positional argument is *context* (a browsing-context ID / window - handle), mirroring how the class is used throughout the test suite: - ``BrowsingContextPartitionDescriptor(driver.current_window_handle)``. - """ - - def __init__(self, context: Any = None, type: str = "context") -> None: - self.context = context - self.type = type - - def to_bidi_dict(self) -> dict: - return {"type": "context", "context": self.context} - - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict() - - -@dataclass -class StorageKeyPartitionDescriptor: - """StorageKeyPartitionDescriptor.""" - - type: Any | None = "storageKey" - user_context: str | None = None - source_origin: str | None = None - - def to_bidi_dict(self) -> dict: - """Serialize to the BiDi wire-protocol dict.""" - result: dict = {"type": "storageKey"} - if self.user_context is not None: - result["userContext"] = self.user_context - if self.source_origin is not None: - result["sourceOrigin"] = self.source_origin - return result - - def to_dict(self) -> dict: - """Backward-compatible alias for to_bidi_dict().""" - return self.to_bidi_dict() - - -class Storage: - """WebDriver BiDi storage module.""" - - def __init__(self, conn) -> None: - self._conn = conn - - def get_cookies(self, filter=None, partition=None): - """Execute storage.getCookies and return a GetCookiesResult.""" - if filter and hasattr(filter, "to_bidi_dict"): - filter = filter.to_bidi_dict() - if partition and hasattr(partition, "to_bidi_dict"): - partition = partition.to_bidi_dict() - params = { - "filter": filter, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.getCookies", params) - result = self._conn.execute(cmd) - if result and "cookies" in result: - cookies = [StorageCookie.from_bidi_dict(c) for c in result.get("cookies", []) if isinstance(c, dict)] - pk_raw = result.get("partitionKey") - pk = ( - PartitionKey( - user_context=pk_raw.get("userContext"), - source_origin=pk_raw.get("sourceOrigin"), - ) - if isinstance(pk_raw, dict) - else None - ) - return GetCookiesResult(cookies=cookies, partition_key=pk) - return GetCookiesResult(cookies=[], partition_key=None) - - def set_cookie(self, cookie=None, partition=None): - """Execute storage.setCookie.""" - if cookie and hasattr(cookie, "to_bidi_dict"): - cookie = cookie.to_bidi_dict() - if partition and hasattr(partition, "to_bidi_dict"): - partition = partition.to_bidi_dict() - params = { - "cookie": cookie, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.setCookie", params) - result = self._conn.execute(cmd) - if isinstance(result, dict): - pk_raw = result.get("partitionKey") - pk = ( - PartitionKey( - user_context=pk_raw.get("userContext"), - source_origin=pk_raw.get("sourceOrigin"), - ) - if isinstance(pk_raw, dict) - else None - ) - return SetCookieResult(partition_key=pk) - return result - - def delete_cookies(self, filter=None, partition=None): - """Execute storage.deleteCookies.""" - if filter and hasattr(filter, "to_bidi_dict"): - filter = filter.to_bidi_dict() - if partition and hasattr(partition, "to_bidi_dict"): - partition = partition.to_bidi_dict() - params = { - "filter": filter, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.deleteCookies", params) - result = self._conn.execute(cmd) - if isinstance(result, dict): - pk_raw = result.get("partitionKey") - pk = ( - PartitionKey( - user_context=pk_raw.get("userContext"), - source_origin=pk_raw.get("sourceOrigin"), - ) - if isinstance(pk_raw, dict) - else None - ) - return DeleteCookiesResult(partition_key=pk) - return result diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py deleted file mode 100644 index 62f2dec130308..0000000000000 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ /dev/null @@ -1,154 +0,0 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi.common import command_builder - - -@dataclass -class InstallParameters: - """InstallParameters.""" - - extension_data: Any | None = None - - -@dataclass -class ExtensionPath: - """ExtensionPath.""" - - type: str = field(default="path", init=False) - path: str | None = None - - -@dataclass -class ExtensionArchivePath: - """ExtensionArchivePath.""" - - type: str = field(default="archivePath", init=False) - path: str | None = None - - -@dataclass -class ExtensionBase64Encoded: - """ExtensionBase64Encoded.""" - - type: str = field(default="base64", init=False) - value: str | None = None - - -@dataclass -class InstallResult: - """InstallResult.""" - - extension: Any | None = None - - -@dataclass -class UninstallParameters: - """UninstallParameters.""" - - extension: Any | None = None - - -class WebExtension: - """WebDriver BiDi webExtension module.""" - - def __init__(self, conn) -> None: - self._conn = conn - - def install( - self, - path: str | None = None, - archive_path: str | None = None, - base64_value: str | None = None, - ): - """Install a web extension. - - Exactly one of the three keyword arguments must be provided. - - Args: - path: Directory path to an unpacked extension (also accepted for - signed ``.xpi`` / ``.crx`` archive files on Firefox). - archive_path: File-system path to a packed extension archive. - base64_value: Base64-encoded extension archive string. - - Returns: - The raw result dict from the BiDi ``webExtension.install`` command - (contains at least an ``"extension"`` key with the extension ID). - - Raises: - ValueError: If more than one, or none, of the arguments is provided. - """ - provided = [ - k - for k, v in { - "path": path, - "archive_path": archive_path, - "base64_value": base64_value, - }.items() - if v is not None - ] - if len(provided) != 1: - raise ValueError(f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}") - if path is not None: - extension_data = {"type": "path", "path": path} - elif archive_path is not None: - extension_data = {"type": "archivePath", "path": archive_path} - else: - assert base64_value is not None - extension_data = {"type": "base64", "value": base64_value} - params = {"extensionData": extension_data} - cmd = command_builder("webExtension.install", params) - try: - return self._conn.execute(cmd) - except Exception as e: - if "Method not available" in str(e): - raise RuntimeError( - "webExtension.install failed with 'Method not available'. " - "This likely means that web extension support is disabled. " - "Enable unsafe extension debugging and/or set options.enable_webextensions " - "in your WebDriver configuration." - ) from e - raise - - def uninstall(self, extension: str | dict): - """Uninstall a web extension. - - Args: - extension: Either the extension ID string returned by ``install``, - or the full result dict returned by ``install`` (the - ``"extension"`` value is extracted automatically). - - Raises: - ValueError: If extension is not provided or is None. - """ - if isinstance(extension, dict): - extension_id: Any = extension.get("extension") - else: - extension_id = extension - - if extension_id is None: - raise ValueError("extension parameter is required") - - params = {"extension": extension_id} - cmd = command_builder("webExtension.uninstall", params) - return self._conn.execute(cmd) From 4f30f6b440f29c17530875d0b02caba260cedb70 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Thu, 16 Apr 2026 13:55:43 +0200 Subject: [PATCH 39/42] fix formating --- py/private/cdp.py | 86 +++++++++++++++-------------------------------- 1 file changed, 27 insertions(+), 59 deletions(-) diff --git a/py/private/cdp.py b/py/private/cdp.py index 9ca951479f657..86341b3babd71 100644 --- a/py/private/cdp.py +++ b/py/private/cdp.py @@ -1,26 +1,20 @@ -# The MIT License(MIT) +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Copyright(c) 2018 Hyperion Gray +# http://www.apache.org/licenses/LICENSE-2.0 # -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files(the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -# -# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import contextvars import importlib @@ -60,11 +54,7 @@ def import_devtools(ver): # because cdp has been updated but selenium python has not been released yet. devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) - available_versions = tuple( - x - for x in versions - if x == "latest" or (x.startswith("v") and x[1:].isdigit()) - ) + available_versions = tuple(x for x in versions if x == "latest" or (x.startswith("v") and x[1:].isdigit())) numeric_versions = tuple(x[1:] for x in available_versions if x.startswith("v")) if not numeric_versions: raise @@ -75,9 +65,7 @@ def import_devtools(ver): return devtools -_connection_context: contextvars.ContextVar = contextvars.ContextVar( - "connection_context" -) +_connection_context: contextvars.ContextVar = contextvars.ContextVar("connection_context") _session_context: contextvars.ContextVar = contextvars.ContextVar("session_context") @@ -132,9 +120,7 @@ def set_global_connection(connection): certain use cases such as running inside Jupyter notebook. """ global _connection_context - _connection_context = contextvars.ContextVar( - "_connection_context", default=connection - ) + _connection_context = contextvars.ContextVar("_connection_context", default=connection) def set_global_session(session): @@ -231,9 +217,7 @@ async def execute(self, cmd: Generator[dict, T, Any]) -> T: logger.debug(f"Received CDP message: {response}") if isinstance(response, Exception): if logger.isEnabledFor(logging.DEBUG): - logger.debug( - f"Exception raised by {cmd_event} message: {type(response).__name__}" - ) + logger.debug(f"Exception raised by {cmd_event} message: {type(response).__name__}") raise response return response @@ -249,9 +233,7 @@ def listen(self, *event_types, buffer_size=10): return receiver @asynccontextmanager - async def wait_for( - self, event_type: type[T], buffer_size=10 - ) -> AsyncGenerator[CmEventProxy, None]: + async def wait_for(self, event_type: type[T], buffer_size=10) -> AsyncGenerator[CmEventProxy, None]: """Wait for an event of the given type and return it. This is an async context manager, so you should open it inside @@ -292,9 +274,7 @@ def _handle_cmd_response(self, data: dict): try: cmd, event = self.inflight_cmd.pop(cmd_id) except KeyError: - logger.warning( - "Got a message with a command ID that does not exist: %s", data - ) + logger.warning("Got a message with a command ID that does not exist: %s", data) return if "error" in data: # If the server reported an error, convert it to an exception and do @@ -305,9 +285,7 @@ def _handle_cmd_response(self, data: dict): # into a CDP object. try: _ = cmd.send(data["result"]) - raise InternalError( - "The command's generator function did not exit when expected!" - ) + raise InternalError("The command's generator function did not exit when expected!") except StopIteration as exit: return_ = exit.value self.inflight_result[cmd_id] = return_ @@ -321,9 +299,7 @@ def _handle_event(self, data: dict): """ global devtools if devtools is None: - raise RuntimeError( - "CDP devtools module not loaded. Call import_devtools() first." - ) + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") event = devtools.util.parse_json_event(data) logger.debug("Received event: %s", event) to_remove = set() @@ -331,9 +307,7 @@ def _handle_event(self, data: dict): try: sender.send_nowait(event) except trio.WouldBlock: - logger.error( - 'Unable to send event "%r" due to full channel %s', event, sender - ) + logger.error('Unable to send event "%r" due to full channel %s', event, sender) except trio.BrokenResourceError: to_remove.add(sender) if to_remove: @@ -451,12 +425,8 @@ async def connect_session(self, target_id) -> "CdpSession": """Returns a new :class:`CdpSession` connected to the specified target.""" global devtools if devtools is None: - raise RuntimeError( - "CDP devtools module not loaded. Call import_devtools() first." - ) - session_id = await self.execute( - devtools.target.attach_to_target(target_id, True) - ) + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + session_id = await self.execute(devtools.target.attach_to_target(target_id, True)) session = CdpSession(self.ws, session_id, target_id) self.sessions[session_id] = session return session @@ -468,9 +438,7 @@ async def _reader_task(self): """ global devtools if devtools is None: - raise RuntimeError( - "CDP devtools module not loaded. Call import_devtools() first." - ) + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") while True: try: message = await self.ws.get_message() From 79368639d9e96c8c5c470ba24933197b4397d974 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Fri, 17 Apr 2026 15:22:03 +0100 Subject: [PATCH 40/42] Fix how CDP is picked up --- py/generate_bidi.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 5b301d3ec7e40..8d5b75acfaea4 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -1699,9 +1699,12 @@ def main( logger.info(f"Parsed {len(modules)} modules") - # Clean up existing generated files + # Clean up existing generated files. + # Keep static helper modules that are staged by Bazel (for example cdp.py) + # as part of create-bidi-src.extra_srcs. + preserved_python_files = {"py.typed", "cdp.py"} for file_path in output_path.glob("*.py"): - if file_path.name != "py.typed" and not file_path.name.startswith("_"): + if file_path.name not in preserved_python_files and not file_path.name.startswith("_"): file_path.unlink() logger.debug(f"Removed: {file_path}") From 686f115ec1d7f759ce6ec10418df3f36cf9a4615 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 18 Apr 2026 13:02:23 +0100 Subject: [PATCH 41/42] handle comments --- py/BUILD.bazel | 3 +- py/private/bidi_enhancements_manifest.py | 26 ++++-- .../webdriver/common/bidi_network_tests.py | 80 +++++++++++++++++++ 3 files changed, 101 insertions(+), 8 deletions(-) create mode 100644 py/test/unit/selenium/webdriver/common/bidi_network_tests.py diff --git a/py/BUILD.bazel b/py/BUILD.bazel index cfd3da8ad4e78..4fcfddbcfe26d 100644 --- a/py/BUILD.bazel +++ b/py/BUILD.bazel @@ -262,9 +262,8 @@ py_library( # BiDi protocol support py_library( name = "bidi", - srcs = [], + srcs = [":create-bidi-src"], data = [ - ":create-bidi-src", ":mutation-listener", ], imports = ["."], diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 8cec1f9da245f..f8f033bcbc0fd 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -1087,24 +1087,37 @@ def continue_request(self, **kwargs): self._conn.execute(_cb("network.removeIntercept", {"intercept": intercept_id})) if intercept_id in self.intercepts: self.intercepts.remove(intercept_id)''', + ''' def _canonical_request_handler_event(self, event): + """Map public request-handler aliases to supported event keys.""" + event_aliases = { + "auth_required": "auth_required", + "before_request": "before_request", + "before_request_sent": "before_request", + } + canonical_event = event_aliases.get(event) + if canonical_event is None: + available_events = ", ".join(sorted(event_aliases)) + raise ValueError( + f"Unsupported request handler event '{event}'. Available events: {available_events}" + ) + return canonical_event''', ''' def add_request_handler(self, event, callback, url_patterns=None): """Add a handler for network requests at the specified phase. Args: - event: Event name, e.g. ``"before_request"``. + event: Event name, e.g. ``"before_request"`` or ``"before_request_sent"``. callback: Callable receiving a :class:`Request` instance. url_patterns: optional list of URL pattern dicts to filter. Returns: callback_id int for later removal via remove_request_handler. """ + canonical_event = self._canonical_request_handler_event(event) phase_map = { "before_request": "beforeRequestSent", - "before_request_sent": "beforeRequestSent", - "response_started": "responseStarted", "auth_required": "authRequired", } - phase = phase_map.get(event, "beforeRequestSent") + phase = phase_map[canonical_event] intercept_result = self._add_intercept(phases=[phase], url_patterns=url_patterns) intercept_id = intercept_result.get("intercept") if intercept_result else None @@ -1117,7 +1130,7 @@ def _request_callback(params): request = Request(self._conn, raw) callback(request) - callback_id = self.add_event_handler(event, _request_callback) + callback_id = self.add_event_handler(canonical_event, _request_callback) if intercept_id: self._handler_intercepts[callback_id] = intercept_id return callback_id''', @@ -1128,7 +1141,8 @@ def _request_callback(params): event: The event name used when adding the handler. callback_id: The int returned by add_request_handler. """ - self.remove_event_handler(event, callback_id) + canonical_event = self._canonical_request_handler_event(event) + self.remove_event_handler(canonical_event, callback_id) intercept_id = self._handler_intercepts.pop(callback_id, None) if intercept_id: self._remove_intercept(intercept_id)''', diff --git a/py/test/unit/selenium/webdriver/common/bidi_network_tests.py b/py/test/unit/selenium/webdriver/common/bidi_network_tests.py new file mode 100644 index 0000000000000..c8a4e41bc9fb5 --- /dev/null +++ b/py/test/unit/selenium/webdriver/common/bidi_network_tests.py @@ -0,0 +1,80 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from selenium.webdriver.common.bidi.network import Network + + +class FakeConnection: + def __init__(self): + self.commands = [] + self.added_callbacks = [] + self.removed_callbacks = [] + self._next_callback_id = 1 + + def add_callback(self, event_wrapper, callback): + callback_id = self._next_callback_id + self._next_callback_id += 1 + self.added_callbacks.append((callback_id, event_wrapper.event_class, callback)) + return callback_id + + def remove_callback(self, event_wrapper, callback_id): + self.removed_callbacks.append((callback_id, event_wrapper.event_class)) + + def execute(self, cmd): + payload = next(cmd) + self.commands.append(payload) + + if payload["method"] == "network.addIntercept": + response = {"intercept": "intercept-1"} + elif payload["method"] == "session.subscribe": + response = {"subscription": "subscription-1"} + else: + response = {} + + try: + cmd.send(response) + except StopIteration as exc: + return exc.value + + raise AssertionError("BiDi command generator did not finish") + + +def test_add_request_handler_accepts_before_request_sent_alias(): + conn = FakeConnection() + network = Network(conn) + + callback_id = network.add_request_handler("before_request_sent", lambda request: None) + network.remove_request_handler("before_request_sent", callback_id) + + assert callback_id == 1 + assert conn.added_callbacks[0][1] == "network.beforeRequestSent" + assert conn.removed_callbacks[0] == (1, "network.beforeRequestSent") + assert conn.commands == [ + {"method": "network.addIntercept", "params": {"phases": ["beforeRequestSent"]}}, + {"method": "session.subscribe", "params": {"events": ["network.beforeRequestSent"]}}, + {"method": "session.unsubscribe", "params": {"subscriptions": ["subscription-1"]}}, + {"method": "network.removeIntercept", "params": {"intercept": "intercept-1"}}, + ] + + +def test_add_request_handler_rejects_unsupported_alias(): + network = Network(FakeConnection()) + + with pytest.raises(ValueError, match="Unsupported request handler event 'response_started'"): + network.add_request_handler("response_started", lambda request: None) \ No newline at end of file From b71df8326f04102418e561075a094025abe56f5d Mon Sep 17 00:00:00 2001 From: Corey Goldberg <1113081+cgoldberg@users.noreply.github.com> Date: Sun, 19 Apr 2026 11:20:12 -0400 Subject: [PATCH 42/42] Add newline to make linter pass --- py/test/unit/selenium/webdriver/common/bidi_network_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/test/unit/selenium/webdriver/common/bidi_network_tests.py b/py/test/unit/selenium/webdriver/common/bidi_network_tests.py index c8a4e41bc9fb5..cd7c46380fd2f 100644 --- a/py/test/unit/selenium/webdriver/common/bidi_network_tests.py +++ b/py/test/unit/selenium/webdriver/common/bidi_network_tests.py @@ -77,4 +77,4 @@ def test_add_request_handler_rejects_unsupported_alias(): network = Network(FakeConnection()) with pytest.raises(ValueError, match="Unsupported request handler event 'response_started'"): - network.add_request_handler("response_started", lambda request: None) \ No newline at end of file + network.add_request_handler("response_started", lambda request: None)