11import ssl
2- from typing import Dict , List , Tuple , Union
2+ from typing import List , Mapping , Sequence , Tuple , Union
33
44from .._exceptions import ProxyError
55from .._models import URL , Origin , Request , Response , enforce_headers , enforce_url
1111from .http11 import AsyncHTTP11Connection
1212from .interfaces import AsyncConnectionInterface
1313
14- HeadersAsList = List [Tuple [Union [bytes , str ], Union [bytes , str ]]]
15- HeadersAsDict = Dict [Union [bytes , str ], Union [bytes , str ]]
14+ HeadersAsSequence = Sequence [Tuple [Union [bytes , str ], Union [bytes , str ]]]
15+ HeadersAsMapping = Mapping [Union [bytes , str ], Union [bytes , str ]]
1616
1717
1818def merge_headers (
19- default_headers : List [Tuple [bytes , bytes ]] = None ,
20- override_headers : List [Tuple [bytes , bytes ]] = None ,
19+ default_headers : Sequence [Tuple [bytes , bytes ]] = None ,
20+ override_headers : Sequence [Tuple [bytes , bytes ]] = None ,
2121) -> List [Tuple [bytes , bytes ]]:
2222 """
2323 Append default_headers and override_headers, de-duplicating if a key exists
2424 in both cases.
2525 """
26- default_headers = [] if default_headers is None else default_headers
27- override_headers = [] if override_headers is None else override_headers
26+ default_headers = [] if default_headers is None else list ( default_headers )
27+ override_headers = [] if override_headers is None else list ( override_headers )
2828 has_override = set ([key .lower () for key , value in override_headers ])
2929 default_headers = [
3030 (key , value )
@@ -42,7 +42,7 @@ class AsyncHTTPProxy(AsyncConnectionPool):
4242 def __init__ (
4343 self ,
4444 proxy_url : Union [URL , bytes , str ],
45- proxy_headers : Union [HeadersAsDict , HeadersAsList ] = None ,
45+ proxy_headers : Union [HeadersAsMapping , HeadersAsSequence ] = None ,
4646 ssl_context : ssl .SSLContext = None ,
4747 max_connections : int = 10 ,
4848 max_keepalive_connections : int = None ,
@@ -106,6 +106,7 @@ def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
106106 )
107107 return AsyncTunnelHTTPConnection (
108108 proxy_origin = self ._proxy_url .origin ,
109+ proxy_headers = self ._proxy_headers ,
109110 remote_origin = origin ,
110111 ssl_context = self ._ssl_context ,
111112 keepalive_expiry = self ._keepalive_expiry ,
@@ -117,7 +118,7 @@ class AsyncForwardHTTPConnection(AsyncConnectionInterface):
117118 def __init__ (
118119 self ,
119120 proxy_origin : Origin ,
120- proxy_headers : Union [HeadersAsDict , HeadersAsList ] = None ,
121+ proxy_headers : Union [HeadersAsMapping , HeadersAsSequence ] = None ,
121122 keepalive_expiry : float = None ,
122123 network_backend : AsyncNetworkBackend = None ,
123124 ) -> None :
@@ -177,7 +178,7 @@ def __init__(
177178 proxy_origin : Origin ,
178179 remote_origin : Origin ,
179180 ssl_context : ssl .SSLContext ,
180- proxy_headers : List [Tuple [bytes , bytes ]] = None ,
181+ proxy_headers : Sequence [Tuple [bytes , bytes ]] = None ,
181182 keepalive_expiry : float = None ,
182183 network_backend : AsyncNetworkBackend = None ,
183184 ) -> None :
@@ -189,7 +190,7 @@ def __init__(
189190 self ._proxy_origin = proxy_origin
190191 self ._remote_origin = remote_origin
191192 self ._ssl_context = ssl_context
192- self ._proxy_headers = [] if proxy_headers is None else proxy_headers
193+ self ._proxy_headers = enforce_headers ( proxy_headers , name = " proxy_headers" )
193194 self ._keepalive_expiry = keepalive_expiry
194195 self ._connect_lock = AsyncLock ()
195196 self ._connected = False
@@ -208,7 +209,9 @@ async def handle_async_request(self, request: Request) -> Response:
208209 port = self ._proxy_origin .port ,
209210 target = target ,
210211 )
211- connect_headers = [(b"Host" , target ), (b"Accept" , b"*/*" )]
212+ connect_headers = merge_headers (
213+ [(b"Host" , target ), (b"Accept" , b"*/*" )], self ._proxy_headers
214+ )
212215 connect_request = Request (
213216 method = b"CONNECT" , url = connect_url , headers = connect_headers
214217 )
0 commit comments