@@ -285,61 +285,69 @@ async def handle_async_request(self, request: Request) -> Response:
285285 headers = connect_headers ,
286286 extensions = request .extensions ,
287287 )
288- connect_response = await self ._connection .handle_async_request (
289- connect_request
290- )
291-
292- if connect_response .status < 200 or connect_response .status > 299 :
293- reason_bytes = connect_response .extensions .get ("reason_phrase" , b"" )
294- reason_str = reason_bytes .decode ("ascii" , errors = "ignore" )
295- msg = "%d %s" % (connect_response .status , reason_str )
296- await self ._connection .aclose ()
297- raise ProxyError (msg )
298-
299- stream = connect_response .extensions ["network_stream" ]
300-
301- # Upgrade the stream to SSL
302- ssl_context = (
303- default_ssl_context ()
304- if self ._ssl_context is None
305- else self ._ssl_context
306- )
307- alpn_protocols = ["http/1.1" , "h2" ] if self ._http2 else ["http/1.1" ]
308- ssl_context .set_alpn_protocols (alpn_protocols )
309-
310- kwargs = {
311- "ssl_context" : ssl_context ,
312- "server_hostname" : self ._remote_origin .host .decode ("ascii" ),
313- "timeout" : timeout ,
314- }
315- async with Trace ("start_tls" , logger , request , kwargs ) as trace :
316- stream = await stream .start_tls (** kwargs )
317- trace .return_value = stream
318-
319- # Determine if we should be using HTTP/1.1 or HTTP/2
320- ssl_object = stream .get_extra_info ("ssl_object" )
321- http2_negotiated = (
322- ssl_object is not None
323- and ssl_object .selected_alpn_protocol () == "h2"
324- )
325288
326- # Create the HTTP/1.1 or HTTP/2 connection
327- if http2_negotiated or (self ._http2 and not self ._http1 ):
328- from .http2 import AsyncHTTP2Connection
289+ try :
290+ connect_response = await self ._connection .handle_async_request (
291+ connect_request
292+ )
329293
330- self ._connection = AsyncHTTP2Connection (
331- origin = self ._remote_origin ,
332- stream = stream ,
333- keepalive_expiry = self ._keepalive_expiry ,
294+ if connect_response .status < 200 or connect_response .status > 299 :
295+ reason_bytes = connect_response .extensions .get (
296+ "reason_phrase" , b""
297+ )
298+ reason_str = reason_bytes .decode ("ascii" , errors = "ignore" )
299+ msg = "%d %s" % (connect_response .status , reason_str )
300+ await self ._connection .aclose ()
301+ raise ProxyError (msg )
302+
303+ stream = connect_response .extensions ["network_stream" ]
304+
305+ # Upgrade the stream to SSL
306+ ssl_context = (
307+ default_ssl_context ()
308+ if self ._ssl_context is None
309+ else self ._ssl_context
334310 )
335- else :
336- self ._connection = AsyncHTTP11Connection (
337- origin = self ._remote_origin ,
338- stream = stream ,
339- keepalive_expiry = self ._keepalive_expiry ,
311+ alpn_protocols = ["http/1.1" , "h2" ] if self ._http2 else ["http/1.1" ]
312+ ssl_context .set_alpn_protocols (alpn_protocols )
313+
314+ kwargs = {
315+ "ssl_context" : ssl_context ,
316+ "server_hostname" : self ._remote_origin .host .decode ("ascii" ),
317+ "timeout" : timeout ,
318+ }
319+ async with Trace ("start_tls" , logger , request , kwargs ) as trace :
320+ stream = await stream .start_tls (** kwargs )
321+ trace .return_value = stream
322+
323+ # Determine if we should be using HTTP/1.1 or HTTP/2
324+ ssl_object = stream .get_extra_info ("ssl_object" )
325+ http2_negotiated = (
326+ ssl_object is not None
327+ and ssl_object .selected_alpn_protocol () == "h2"
340328 )
341329
342- self ._connected = True
330+ # Create the HTTP/1.1 or HTTP/2 connection
331+ if http2_negotiated or (self ._http2 and not self ._http1 ):
332+ from .http2 import AsyncHTTP2Connection
333+
334+ self ._connection = AsyncHTTP2Connection (
335+ origin = self ._remote_origin ,
336+ stream = stream ,
337+ keepalive_expiry = self ._keepalive_expiry ,
338+ )
339+ else :
340+ self ._connection = AsyncHTTP11Connection (
341+ origin = self ._remote_origin ,
342+ stream = stream ,
343+ keepalive_expiry = self ._keepalive_expiry ,
344+ )
345+
346+ self ._connected = True
347+ except Exception :
348+ await self ._connection .aclose ()
349+ raise
350+
343351 return await self ._connection .handle_async_request (request )
344352
345353 def can_handle_request (self , origin : Origin ) -> bool :
0 commit comments