Skip to content

Commit e2bb80b

Browse files
Prefer Sequence to List in annotations (#427)
1 parent abee287 commit e2bb80b

4 files changed

Lines changed: 48 additions & 31 deletions

File tree

httpcore/_async/http_proxy.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import ssl
2-
from typing import Dict, List, Tuple, Union
2+
from typing import List, Mapping, Sequence, Tuple, Union
33

44
from .._exceptions import ProxyError
55
from .._models import URL, Origin, Request, Response, enforce_headers, enforce_url
@@ -11,20 +11,20 @@
1111
from .http11 import AsyncHTTP11Connection
1212
from .interfaces import AsyncConnectionInterface
1313

14-
HeadersAsList = List[Tuple[Union[bytes, str], Union[bytes, str]]]
15-
HeadersAsDict = Dict[Union[bytes, str], Union[bytes, str]]
14+
HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
15+
HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]
1616

1717

1818
def merge_headers(
19-
default_headers: List[Tuple[bytes, bytes]] = None,
20-
override_headers: List[Tuple[bytes, bytes]] = None,
19+
default_headers: Sequence[Tuple[bytes, bytes]] = None,
20+
override_headers: Sequence[Tuple[bytes, bytes]] = None,
2121
) -> List[Tuple[bytes, bytes]]:
2222
"""
2323
Append default_headers and override_headers, de-duplicating if a key exists
2424
in both cases.
2525
"""
26-
default_headers = [] if default_headers is None else default_headers
27-
override_headers = [] if override_headers is None else override_headers
26+
default_headers = [] if default_headers is None else list(default_headers)
27+
override_headers = [] if override_headers is None else list(override_headers)
2828
has_override = set([key.lower() for key, value in override_headers])
2929
default_headers = [
3030
(key, value)
@@ -42,7 +42,7 @@ class AsyncHTTPProxy(AsyncConnectionPool):
4242
def __init__(
4343
self,
4444
proxy_url: Union[URL, bytes, str],
45-
proxy_headers: Union[HeadersAsDict, HeadersAsList] = None,
45+
proxy_headers: Union[HeadersAsMapping, HeadersAsSequence] = None,
4646
ssl_context: ssl.SSLContext = None,
4747
max_connections: int = 10,
4848
max_keepalive_connections: int = None,
@@ -106,6 +106,7 @@ def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
106106
)
107107
return AsyncTunnelHTTPConnection(
108108
proxy_origin=self._proxy_url.origin,
109+
proxy_headers=self._proxy_headers,
109110
remote_origin=origin,
110111
ssl_context=self._ssl_context,
111112
keepalive_expiry=self._keepalive_expiry,
@@ -117,7 +118,7 @@ class AsyncForwardHTTPConnection(AsyncConnectionInterface):
117118
def __init__(
118119
self,
119120
proxy_origin: Origin,
120-
proxy_headers: Union[HeadersAsDict, HeadersAsList] = None,
121+
proxy_headers: Union[HeadersAsMapping, HeadersAsSequence] = None,
121122
keepalive_expiry: float = None,
122123
network_backend: AsyncNetworkBackend = None,
123124
) -> None:
@@ -177,7 +178,7 @@ def __init__(
177178
proxy_origin: Origin,
178179
remote_origin: Origin,
179180
ssl_context: ssl.SSLContext,
180-
proxy_headers: List[Tuple[bytes, bytes]] = None,
181+
proxy_headers: Sequence[Tuple[bytes, bytes]] = None,
181182
keepalive_expiry: float = None,
182183
network_backend: AsyncNetworkBackend = None,
183184
) -> None:
@@ -189,7 +190,7 @@ def __init__(
189190
self._proxy_origin = proxy_origin
190191
self._remote_origin = remote_origin
191192
self._ssl_context = ssl_context
192-
self._proxy_headers = [] if proxy_headers is None else proxy_headers
193+
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
193194
self._keepalive_expiry = keepalive_expiry
194195
self._connect_lock = AsyncLock()
195196
self._connected = False
@@ -208,7 +209,9 @@ async def handle_async_request(self, request: Request) -> Response:
208209
port=self._proxy_origin.port,
209210
target=target,
210211
)
211-
connect_headers = [(b"Host", target), (b"Accept", b"*/*")]
212+
connect_headers = merge_headers(
213+
[(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers
214+
)
212215
connect_request = Request(
213216
method=b"CONNECT", url=connect_url, headers=connect_headers
214217
)

httpcore/_models.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
Iterable,
66
Iterator,
77
List,
8+
Mapping,
89
Optional,
10+
Sequence,
911
Tuple,
1012
Union,
1113
)
@@ -14,6 +16,10 @@
1416
# Functions for typechecking...
1517

1618

19+
HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
20+
HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]
21+
22+
1723
def enforce_bytes(value: Union[bytes, str], *, name: str) -> bytes:
1824
"""
1925
Any arguments that are ultimately represented as bytes can be specified
@@ -49,33 +55,35 @@ def enforce_url(value: Union["URL", bytes, str], *, name: str) -> "URL":
4955

5056

5157
def enforce_headers(
52-
value: Union[dict, list] = None, *, name: str
58+
value: Union[HeadersAsMapping, HeadersAsSequence] = None, *, name: str
5359
) -> List[Tuple[bytes, bytes]]:
5460
"""
5561
Convienence function that ensure all items in request or response headers
5662
are either bytes or strings in the plain ASCII range.
5763
"""
5864
if value is None:
5965
return []
60-
elif isinstance(value, (list, tuple)):
66+
elif isinstance(value, Mapping):
6167
return [
6268
(
6369
enforce_bytes(k, name="header name"),
6470
enforce_bytes(v, name="header value"),
6571
)
66-
for k, v in value
72+
for k, v in value.items()
6773
]
68-
elif isinstance(value, dict):
74+
elif isinstance(value, Sequence):
6975
return [
7076
(
7177
enforce_bytes(k, name="header name"),
7278
enforce_bytes(v, name="header value"),
7379
)
74-
for k, v in value.items()
80+
for k, v in value
7581
]
7682

7783
seen_type = type(value).__name__
78-
raise TypeError(f"{name} must be a list, but got {seen_type}.")
84+
raise TypeError(
85+
f"{name} must be a mapping or sequence of two-tuples, but got {seen_type}."
86+
)
7987

8088

8189
def enforce_stream(

httpcore/_sync/http_proxy.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import ssl
2-
from typing import Dict, List, Tuple, Union
2+
from typing import List, Mapping, Sequence, Tuple, Union
33

44
from .._exceptions import ProxyError
55
from .._models import URL, Origin, Request, Response, enforce_headers, enforce_url
@@ -11,20 +11,20 @@
1111
from .http11 import HTTP11Connection
1212
from .interfaces import ConnectionInterface
1313

14-
HeadersAsList = List[Tuple[Union[bytes, str], Union[bytes, str]]]
15-
HeadersAsDict = Dict[Union[bytes, str], Union[bytes, str]]
14+
HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
15+
HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]
1616

1717

1818
def merge_headers(
19-
default_headers: List[Tuple[bytes, bytes]] = None,
20-
override_headers: List[Tuple[bytes, bytes]] = None,
19+
default_headers: Sequence[Tuple[bytes, bytes]] = None,
20+
override_headers: Sequence[Tuple[bytes, bytes]] = None,
2121
) -> List[Tuple[bytes, bytes]]:
2222
"""
2323
Append default_headers and override_headers, de-duplicating if a key exists
2424
in both cases.
2525
"""
26-
default_headers = [] if default_headers is None else default_headers
27-
override_headers = [] if override_headers is None else override_headers
26+
default_headers = [] if default_headers is None else list(default_headers)
27+
override_headers = [] if override_headers is None else list(override_headers)
2828
has_override = set([key.lower() for key, value in override_headers])
2929
default_headers = [
3030
(key, value)
@@ -42,7 +42,7 @@ class HTTPProxy(ConnectionPool):
4242
def __init__(
4343
self,
4444
proxy_url: Union[URL, bytes, str],
45-
proxy_headers: Union[HeadersAsDict, HeadersAsList] = None,
45+
proxy_headers: Union[HeadersAsMapping, HeadersAsSequence] = None,
4646
ssl_context: ssl.SSLContext = None,
4747
max_connections: int = 10,
4848
max_keepalive_connections: int = None,
@@ -106,6 +106,7 @@ def create_connection(self, origin: Origin) -> ConnectionInterface:
106106
)
107107
return TunnelHTTPConnection(
108108
proxy_origin=self._proxy_url.origin,
109+
proxy_headers=self._proxy_headers,
109110
remote_origin=origin,
110111
ssl_context=self._ssl_context,
111112
keepalive_expiry=self._keepalive_expiry,
@@ -117,7 +118,7 @@ class ForwardHTTPConnection(ConnectionInterface):
117118
def __init__(
118119
self,
119120
proxy_origin: Origin,
120-
proxy_headers: Union[HeadersAsDict, HeadersAsList] = None,
121+
proxy_headers: Union[HeadersAsMapping, HeadersAsSequence] = None,
121122
keepalive_expiry: float = None,
122123
network_backend: NetworkBackend = None,
123124
) -> None:
@@ -177,7 +178,7 @@ def __init__(
177178
proxy_origin: Origin,
178179
remote_origin: Origin,
179180
ssl_context: ssl.SSLContext,
180-
proxy_headers: List[Tuple[bytes, bytes]] = None,
181+
proxy_headers: Sequence[Tuple[bytes, bytes]] = None,
181182
keepalive_expiry: float = None,
182183
network_backend: NetworkBackend = None,
183184
) -> None:
@@ -189,7 +190,7 @@ def __init__(
189190
self._proxy_origin = proxy_origin
190191
self._remote_origin = remote_origin
191192
self._ssl_context = ssl_context
192-
self._proxy_headers = [] if proxy_headers is None else proxy_headers
193+
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
193194
self._keepalive_expiry = keepalive_expiry
194195
self._connect_lock = Lock()
195196
self._connected = False
@@ -208,7 +209,9 @@ def handle_request(self, request: Request) -> Response:
208209
port=self._proxy_origin.port,
209210
target=target,
210211
)
211-
connect_headers = [(b"Host", target), (b"Accept", b"*/*")]
212+
connect_headers = merge_headers(
213+
[(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers
214+
)
212215
connect_request = Request(
213216
method=b"CONNECT", url=connect_url, headers=connect_headers
214217
)

tests/test_models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ def test_request_with_invalid_url():
7373
def test_request_with_invalid_headers():
7474
with pytest.raises(TypeError) as exc_info:
7575
httpcore.Request("GET", "https://www.example.com/", headers=123) # type: ignore
76-
assert str(exc_info.value) == "headers must be a list, but got int."
76+
assert (
77+
str(exc_info.value)
78+
== "headers must be a mapping or sequence of two-tuples, but got int."
79+
)
7780

7881

7982
# Response

0 commit comments

Comments
 (0)