Skip to content

Commit a08ffff

Browse files
authored
[Core] Set kwarg explicitly in method signatures (#46633)
* [Core] Set kwarg explicitly in method signatures The `additional_allowed_query_params` kwarg should be in the signatures to improve type-hinting. Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com> * Fix imports Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com> --------- Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
1 parent 31fd1c3 commit a08ffff

2 files changed

Lines changed: 19 additions & 10 deletions

File tree

sdk/core/azure-core/azure/core/pipeline/policies/_distributed_tracing.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import logging
2828
import sys
2929
import urllib.parse
30-
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union, Any, Type, Mapping, Dict
30+
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union, Any, Type, Mapping, Dict, Iterable
3131
from types import TracebackType
3232

3333
from azure.core.pipeline import PipelineRequest, PipelineResponse
@@ -103,14 +103,19 @@ class DistributedTracingPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseTyp
103103
_RESPONSE_ID = "x-ms-request-id"
104104
_RESPONSE_ID_ATTR = "az.service_request_id"
105105

106-
def __init__(self, *, instrumentation_config: Optional[Mapping[str, Any]] = None, **kwargs: Any):
106+
def __init__(
107+
self,
108+
*,
109+
instrumentation_config: Optional[Mapping[str, Any]] = None,
110+
additional_allowed_query_params: Optional[Iterable[str]] = None,
111+
**kwargs: Any,
112+
):
107113
self._network_span_namer = kwargs.get("network_span_namer", _default_network_span_namer)
108114
self._tracing_attributes = kwargs.get("tracing_attributes", {})
109115
self._instrumentation_config = instrumentation_config
110116
self.allowed_query_params: set[str] = CaseInsensitiveSet(self.__class__.DEFAULT_QUERY_PARAMS_ALLOWLIST)
111-
additional_params = kwargs.get("additional_allowed_query_params")
112-
if additional_params:
113-
self.allowed_query_params.update(additional_params)
117+
if additional_allowed_query_params:
118+
self.allowed_query_params.update(additional_allowed_query_params)
114119

115120
def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
116121
"""Starts a span for the network call.

sdk/core/azure-core/azure/core/pipeline/policies/_universal.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
import types
3636
import re
3737
import uuid
38-
from typing import IO, cast, Union, Optional, AnyStr, Dict, Any, Set, MutableMapping
38+
from typing import IO, cast, Union, Optional, AnyStr, Dict, Any, Set, MutableMapping, Iterable
3939

4040
from azure.core import __version__ as azcore_version
4141
from azure.core.exceptions import DecodeError
@@ -444,14 +444,18 @@ class HttpLoggingPolicy(
444444
MULTI_RECORD_LOG: str = "AZURE_SDK_LOGGING_MULTIRECORD"
445445

446446
def __init__(
447-
self, logger: Optional[logging.Logger] = None, *, http_logging_level: int = logging.INFO, **kwargs: Any
447+
self,
448+
logger: Optional[logging.Logger] = None,
449+
*,
450+
http_logging_level: int = logging.INFO,
451+
additional_allowed_query_params: Optional[Iterable[str]] = None,
452+
**kwargs: Any
448453
): # pylint: disable=unused-argument
449454
self.logger: logging.Logger = logger or logging.getLogger("azure.core.pipeline.policies.http_logging_policy")
450455
self.http_logging_level: int = http_logging_level
451456
self.allowed_query_params: Set[str] = CaseInsensitiveSet(self.__class__.DEFAULT_QUERY_PARAMS_ALLOWLIST)
452-
additional_query_params = kwargs.get("additional_allowed_query_params")
453-
if additional_query_params:
454-
self.allowed_query_params.update(additional_query_params)
457+
if additional_allowed_query_params:
458+
self.allowed_query_params.update(additional_allowed_query_params)
455459
self.allowed_header_names: Set[str] = CaseInsensitiveSet(self.__class__.DEFAULT_HEADERS_ALLOWLIST)
456460

457461
def _redact_header(self, key: str, value: str) -> str:

0 commit comments

Comments
 (0)