Skip to content

Commit eb9b818

Browse files
committed
Fix: AsyncTunnelHTTPConnection leaks connection on handshake failure
1 parent 794bc96 commit eb9b818

2 files changed

Lines changed: 116 additions & 100 deletions

File tree

httpcore/_async/http_proxy.py

Lines changed: 58 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -285,61 +285,69 @@ async def handle_async_request(self, request: Request) -> Response:
285285
headers=connect_headers,
286286
extensions=request.extensions,
287287
)
288-
connect_response = await self._connection.handle_async_request(
289-
connect_request
290-
)
291-
292-
if connect_response.status < 200 or connect_response.status > 299:
293-
reason_bytes = connect_response.extensions.get("reason_phrase", b"")
294-
reason_str = reason_bytes.decode("ascii", errors="ignore")
295-
msg = "%d %s" % (connect_response.status, reason_str)
296-
await self._connection.aclose()
297-
raise ProxyError(msg)
298-
299-
stream = connect_response.extensions["network_stream"]
300-
301-
# Upgrade the stream to SSL
302-
ssl_context = (
303-
default_ssl_context()
304-
if self._ssl_context is None
305-
else self._ssl_context
306-
)
307-
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
308-
ssl_context.set_alpn_protocols(alpn_protocols)
309-
310-
kwargs = {
311-
"ssl_context": ssl_context,
312-
"server_hostname": self._remote_origin.host.decode("ascii"),
313-
"timeout": timeout,
314-
}
315-
async with Trace("start_tls", logger, request, kwargs) as trace:
316-
stream = await stream.start_tls(**kwargs)
317-
trace.return_value = stream
318-
319-
# Determine if we should be using HTTP/1.1 or HTTP/2
320-
ssl_object = stream.get_extra_info("ssl_object")
321-
http2_negotiated = (
322-
ssl_object is not None
323-
and ssl_object.selected_alpn_protocol() == "h2"
324-
)
325288

326-
# Create the HTTP/1.1 or HTTP/2 connection
327-
if http2_negotiated or (self._http2 and not self._http1):
328-
from .http2 import AsyncHTTP2Connection
289+
try:
290+
connect_response = await self._connection.handle_async_request(
291+
connect_request
292+
)
329293

330-
self._connection = AsyncHTTP2Connection(
331-
origin=self._remote_origin,
332-
stream=stream,
333-
keepalive_expiry=self._keepalive_expiry,
294+
if connect_response.status < 200 or connect_response.status > 299:
295+
reason_bytes = connect_response.extensions.get(
296+
"reason_phrase", b""
297+
)
298+
reason_str = reason_bytes.decode("ascii", errors="ignore")
299+
msg = "%d %s" % (connect_response.status, reason_str)
300+
await self._connection.aclose()
301+
raise ProxyError(msg)
302+
303+
stream = connect_response.extensions["network_stream"]
304+
305+
# Upgrade the stream to SSL
306+
ssl_context = (
307+
default_ssl_context()
308+
if self._ssl_context is None
309+
else self._ssl_context
334310
)
335-
else:
336-
self._connection = AsyncHTTP11Connection(
337-
origin=self._remote_origin,
338-
stream=stream,
339-
keepalive_expiry=self._keepalive_expiry,
311+
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
312+
ssl_context.set_alpn_protocols(alpn_protocols)
313+
314+
kwargs = {
315+
"ssl_context": ssl_context,
316+
"server_hostname": self._remote_origin.host.decode("ascii"),
317+
"timeout": timeout,
318+
}
319+
async with Trace("start_tls", logger, request, kwargs) as trace:
320+
stream = await stream.start_tls(**kwargs)
321+
trace.return_value = stream
322+
323+
# Determine if we should be using HTTP/1.1 or HTTP/2
324+
ssl_object = stream.get_extra_info("ssl_object")
325+
http2_negotiated = (
326+
ssl_object is not None
327+
and ssl_object.selected_alpn_protocol() == "h2"
340328
)
341329

342-
self._connected = True
330+
# Create the HTTP/1.1 or HTTP/2 connection
331+
if http2_negotiated or (self._http2 and not self._http1):
332+
from .http2 import AsyncHTTP2Connection
333+
334+
self._connection = AsyncHTTP2Connection(
335+
origin=self._remote_origin,
336+
stream=stream,
337+
keepalive_expiry=self._keepalive_expiry,
338+
)
339+
else:
340+
self._connection = AsyncHTTP11Connection(
341+
origin=self._remote_origin,
342+
stream=stream,
343+
keepalive_expiry=self._keepalive_expiry,
344+
)
345+
346+
self._connected = True
347+
except Exception:
348+
await self._connection.aclose()
349+
raise
350+
343351
return await self._connection.handle_async_request(request)
344352

345353
def can_handle_request(self, origin: Origin) -> bool:

httpcore/_sync/http_proxy.py

Lines changed: 58 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -285,61 +285,69 @@ def handle_request(self, request: Request) -> Response:
285285
headers=connect_headers,
286286
extensions=request.extensions,
287287
)
288-
connect_response = self._connection.handle_request(
289-
connect_request
290-
)
291-
292-
if connect_response.status < 200 or connect_response.status > 299:
293-
reason_bytes = connect_response.extensions.get("reason_phrase", b"")
294-
reason_str = reason_bytes.decode("ascii", errors="ignore")
295-
msg = "%d %s" % (connect_response.status, reason_str)
296-
self._connection.close()
297-
raise ProxyError(msg)
298-
299-
stream = connect_response.extensions["network_stream"]
300-
301-
# Upgrade the stream to SSL
302-
ssl_context = (
303-
default_ssl_context()
304-
if self._ssl_context is None
305-
else self._ssl_context
306-
)
307-
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
308-
ssl_context.set_alpn_protocols(alpn_protocols)
309-
310-
kwargs = {
311-
"ssl_context": ssl_context,
312-
"server_hostname": self._remote_origin.host.decode("ascii"),
313-
"timeout": timeout,
314-
}
315-
with Trace("start_tls", logger, request, kwargs) as trace:
316-
stream = stream.start_tls(**kwargs)
317-
trace.return_value = stream
318-
319-
# Determine if we should be using HTTP/1.1 or HTTP/2
320-
ssl_object = stream.get_extra_info("ssl_object")
321-
http2_negotiated = (
322-
ssl_object is not None
323-
and ssl_object.selected_alpn_protocol() == "h2"
324-
)
325288

326-
# Create the HTTP/1.1 or HTTP/2 connection
327-
if http2_negotiated or (self._http2 and not self._http1):
328-
from .http2 import HTTP2Connection
289+
try:
290+
connect_response = self._connection.handle_request(
291+
connect_request
292+
)
329293

330-
self._connection = HTTP2Connection(
331-
origin=self._remote_origin,
332-
stream=stream,
333-
keepalive_expiry=self._keepalive_expiry,
294+
if connect_response.status < 200 or connect_response.status > 299:
295+
reason_bytes = connect_response.extensions.get(
296+
"reason_phrase", b""
297+
)
298+
reason_str = reason_bytes.decode("ascii", errors="ignore")
299+
msg = "%d %s" % (connect_response.status, reason_str)
300+
self._connection.close()
301+
raise ProxyError(msg)
302+
303+
stream = connect_response.extensions["network_stream"]
304+
305+
# Upgrade the stream to SSL
306+
ssl_context = (
307+
default_ssl_context()
308+
if self._ssl_context is None
309+
else self._ssl_context
334310
)
335-
else:
336-
self._connection = HTTP11Connection(
337-
origin=self._remote_origin,
338-
stream=stream,
339-
keepalive_expiry=self._keepalive_expiry,
311+
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
312+
ssl_context.set_alpn_protocols(alpn_protocols)
313+
314+
kwargs = {
315+
"ssl_context": ssl_context,
316+
"server_hostname": self._remote_origin.host.decode("ascii"),
317+
"timeout": timeout,
318+
}
319+
with Trace("start_tls", logger, request, kwargs) as trace:
320+
stream = stream.start_tls(**kwargs)
321+
trace.return_value = stream
322+
323+
# Determine if we should be using HTTP/1.1 or HTTP/2
324+
ssl_object = stream.get_extra_info("ssl_object")
325+
http2_negotiated = (
326+
ssl_object is not None
327+
and ssl_object.selected_alpn_protocol() == "h2"
340328
)
341329

342-
self._connected = True
330+
# Create the HTTP/1.1 or HTTP/2 connection
331+
if http2_negotiated or (self._http2 and not self._http1):
332+
from .http2 import HTTP2Connection
333+
334+
self._connection = HTTP2Connection(
335+
origin=self._remote_origin,
336+
stream=stream,
337+
keepalive_expiry=self._keepalive_expiry,
338+
)
339+
else:
340+
self._connection = HTTP11Connection(
341+
origin=self._remote_origin,
342+
stream=stream,
343+
keepalive_expiry=self._keepalive_expiry,
344+
)
345+
346+
self._connected = True
347+
except Exception:
348+
self._connection.close()
349+
raise
350+
343351
return self._connection.handle_request(request)
344352

345353
def can_handle_request(self, origin: Origin) -> bool:

0 commit comments

Comments
 (0)