Skip to content

Commit a0b1904

Browse files
Trio (#138)
* Use mkdocs2 for documentation builds * Trio support * Update requirements * Clean server stream closes * Update requirements * NetworkClose is handled in the stream * Stream closing for sync client * Tests for sync/async variants * Add __init__.py * Cleaner test runs * Update tests * Test async * Test scripts for both httpx and ahttpx * Update test suite * Update test suite
1 parent a16006a commit a0b1904

36 files changed

+2219
-54
lines changed

.github/workflows/test-suite.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@ jobs:
2424
- name: "Install dependencies"
2525
run: "scripts/install"
2626
- name: "Run tests"
27-
run: "scripts/test"
27+
run: "scripts/test"

requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
-e .
22

3+
trio==0.33.0
4+
35
# Build...
46
build==1.2.2
57

68
# Test...
79
mypy==1.15.0
810
pytest==8.3.5
911
pytest-cov==6.1.1
12+
pytest-trio==0.8.0
1013

1114
# Sync & Async mirroring...
1215
unasync==0.6.0

scripts/test

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,5 @@ if [ -d 'venv' ] ; then
55
export PREFIX="venv/bin/"
66
fi
77

8-
${PREFIX}mypy src/httpx
98
${PREFIX}mypy src/ahttpx
10-
${PREFIX}pytest --cov src/httpx tests
9+
${PREFIX}pytest --cov src/ahttpx tests/test_ahttpx

scripts/unasync

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,51 @@ unasync.unasync_files(
2727
),
2828
]
2929
)
30+
31+
32+
unasync.unasync_files(
33+
fpath_list = [
34+
"tests/test_ahttpx/test_client.py",
35+
"tests/test_ahttpx/test_content.py",
36+
"tests/test_ahttpx/test_headers.py",
37+
"tests/test_ahttpx/test_network.py",
38+
"tests/test_ahttpx/test_parsers.py",
39+
"tests/test_ahttpx/test_pool.py",
40+
"tests/test_ahttpx/test_quickstart.py",
41+
"tests/test_ahttpx/test_request.py",
42+
"tests/test_ahttpx/test_response.py",
43+
"tests/test_ahttpx/test_streams.py",
44+
"tests/test_ahttpx/test_urlencode.py",
45+
"tests/test_ahttpx/test_urls.py",
46+
],
47+
rules = [
48+
unasync.Rule(
49+
"tests/test_ahttpx/",
50+
"tests/test_httpx/",
51+
additional_replacements={"ahttpx": "httpx"}
52+
),
53+
]
54+
)
55+
56+
57+
for path in [
58+
"tests/test_httpx/test_client.py",
59+
"tests/test_httpx/test_content.py",
60+
"tests/test_httpx/test_headers.py",
61+
"tests/test_httpx/test_network.py",
62+
"tests/test_httpx/test_parsers.py",
63+
"tests/test_httpx/test_pool.py",
64+
"tests/test_httpx/test_quickstart.py",
65+
"tests/test_httpx/test_request.py",
66+
"tests/test_httpx/test_response.py",
67+
"tests/test_httpx/test_streams.py",
68+
"tests/test_httpx/test_urlencode.py",
69+
"tests/test_httpx/test_urls.py",
70+
]:
71+
with open(path, "r") as fin:
72+
lines = fin.readlines()
73+
74+
lines = [line for line in lines if line != "@pytest.mark.trio\n"]
75+
76+
with open(path, "w") as fout:
77+
fout.writelines(lines)

src/ahttpx/_network.py

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
import asyncio
21
import ssl
32
import types
43
import typing
54

5+
import trio
66
import certifi
77

88
from ._streams import Stream
@@ -13,39 +13,37 @@
1313

1414
class NetworkStream(Stream):
1515
def __init__(
16-
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, address: str = ''
16+
self, trio_stream: trio.abc.Stream, address: str = ''
1717
) -> None:
18-
self._reader = reader
19-
self._writer = writer
18+
self._trio_stream = trio_stream
2019
self._address = address
21-
self._tls = False
2220
self._closed = False
2321

2422
async def read(self, size: int = -1) -> bytes:
2523
if size < 0:
2624
size = 64 * 1024
27-
return await self._reader.read(size)
25+
return await self._trio_stream.receive_some(size)
2826

2927
async def write(self, buffer: bytes) -> None:
30-
self._writer.write(buffer)
31-
await self._writer.drain()
28+
await self._trio_stream.send_all(buffer)
3229

3330
async def close(self) -> None:
34-
if not self._closed:
35-
self._writer.close()
36-
await self._writer.wait_closed()
31+
# Close the NetworkStream.
32+
# If the stream is already closed this is a checkpointed no-op.
33+
try:
34+
await self._trio_stream.aclose()
35+
finally:
3736
self._closed = True
3837

3938
def __repr__(self):
4039
description = ""
41-
description += " TLS" if self._tls else ""
4240
description += " CLOSED" if self._closed else ""
43-
return f"<NetworkStream [{self._address!r}{description}]>"
41+
return f"<NetworkStream [{self._address}{description}]>"
4442

4543
def __del__(self):
4644
if not self._closed:
4745
import warnings
48-
warnings.warn("NetworkStream was garbage collected without being closed.")
46+
warnings.warn(f"{self!r} was garbage collected without being closed.")
4947

5048
# Context managed usage...
5149
async def __aenter__(self) -> "NetworkStream":
@@ -61,13 +59,17 @@ async def __aexit__(
6159

6260

6361
class NetworkServer:
64-
def __init__(self, host: str, port: int, server: asyncio.Server):
62+
def __init__(self, host: str, port: int, handler, listeners: list[trio.SocketListener]):
6563
self.host = host
6664
self.port = port
67-
self._server = server
65+
self._handler = handler
66+
self._listeners = listeners
6867

6968
# Context managed usage...
7069
async def __aenter__(self) -> "NetworkServer":
70+
self._nursery_manager = trio.open_nursery()
71+
self._nursery = await self._nursery_manager.__aenter__()
72+
self._nursery.start_soon(trio.serve_listeners, self._handler, self._listeners)
7173
return self
7274

7375
async def __aexit__(
@@ -76,8 +78,8 @@ async def __aexit__(
7678
exc_value: BaseException | None = None,
7779
traceback: types.TracebackType | None = None,
7880
):
79-
self._server.close()
80-
await self._server.wait_closed()
81+
self._nursery.cancel_scope.cancel()
82+
await self._nursery_manager.__aexit__(exc_type, exc_value, traceback)
8183

8284

8385
class NetworkBackend:
@@ -92,29 +94,42 @@ async def connect(self, host: str, port: int) -> NetworkStream:
9294
"""
9395
Connect to the given address, returning a Stream instance.
9496
"""
97+
# Create the TCP stream
9598
address = f"{host}:{port}"
96-
reader, writer = await asyncio.open_connection(host, port)
97-
return NetworkStream(reader, writer, address=address)
99+
trio_stream = await trio.open_tcp_stream(host, port)
100+
return NetworkStream(trio_stream, address=address)
98101

99102
async def connect_tls(self, host: str, port: int, hostname: str = '') -> NetworkStream:
100103
"""
101104
Connect to the given address, returning a Stream instance.
102105
"""
106+
# Create the TCP stream
103107
address = f"{host}:{port}"
104-
reader, writer = await asyncio.open_connection(host, port)
105-
await writer.start_tls(self._ssl_ctx, server_hostname=hostname)
106-
return NetworkStream(reader, writer, address=address)
108+
trio_stream = await trio.open_tcp_stream(host, port)
109+
110+
# Establish SSL over TCP
111+
hostname = hostname or host
112+
ssl_stream = trio.SSLStream(trio_stream, ssl_context=self._ssl_ctx, server_hostname=hostname)
113+
await ssl_stream.do_handshake()
114+
115+
return NetworkStream(ssl_stream, address=address)
107116

108117
async def serve(self, host: str, port: int, handler: typing.Callable[[NetworkStream], None]) -> NetworkServer:
109-
async def callback(reader, writer):
110-
stream = NetworkStream(reader, writer)
111-
await handler(stream)
118+
async def callback(trio_stream):
119+
stream = NetworkStream(trio_stream, address=f"{host}:{port}")
120+
try:
121+
await handler(stream)
122+
finally:
123+
await stream.close()
112124

113-
server = await asyncio.start_server(callback, host, port)
114-
return NetworkServer(host, port, server)
125+
listeners = await trio.open_tcp_listeners(port=port, host=host)
126+
return NetworkServer(host, port, callback, listeners)
127+
128+
def __repr__(self):
129+
return f"<NetworkBackend [trio]>"
115130

116131

117-
Semaphore = asyncio.Semaphore
118-
Lock = asyncio.Lock
119-
timeout = asyncio.timeout
120-
sleep = asyncio.sleep
132+
Semaphore = trio.Semaphore
133+
Lock = trio.Lock
134+
timeout = trio.move_on_after
135+
sleep = trio.sleep

src/ahttpx/_parsers.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,16 @@ async def send_body(self, body: bytes) -> None:
224224
# Handle body close
225225
self.send_state = State.DONE
226226

227+
async def recv_close(self) -> bool:
228+
# ...
229+
if self.is_closed():
230+
return True
231+
232+
if await self.parser.read_eof():
233+
await self.close()
234+
return True
235+
return False
236+
227237
async def recv_method_line(self) -> tuple[bytes, bytes, bytes]:
228238
"""
229239
Receive the initial request method line:
@@ -463,6 +473,18 @@ async def read(self, size: int) -> bytes:
463473
self._push_back(bytes(push_back))
464474
return bytes(buffer)
465475

476+
async def read_eof(self) -> bool:
477+
"""
478+
Attempt to read the closing EOF.
479+
Return True if the stream is EOF, or False otherwise.
480+
"""
481+
if not self._buffer:
482+
chunk = await self._read_some()
483+
if not chunk:
484+
return True
485+
self._push_back(chunk)
486+
return False
487+
466488
async def read_until(self, marker: bytes, max_size: int, exc_text: str) -> bytes:
467489
"""
468490
Read and return bytes from the stream, delimited by marker.

src/ahttpx/_server.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(self, stream, endpoint):
3131
# API entry points...
3232
async def handle_requests(self):
3333
try:
34-
while not self._parser.is_closed():
34+
while not await self._parser.recv_close():
3535
method, url, headers = await self._recv_head()
3636
stream = HTTPStream(self._recv_body, self._complete)
3737
# TODO: Handle endpoint exceptions
@@ -43,13 +43,13 @@ async def handle_requests(self):
4343
except Exception:
4444
logger.error("Internal Server Error", exc_info=True)
4545
content = Text("Internal Server Error")
46-
err = Response(code=500, content=content)
46+
err = Response(500, content=content)
4747
await self._send_head(err)
4848
await self._send_body(err)
4949
else:
5050
await self._send_head(response)
5151
await self._send_body(response)
52-
except Exception:
52+
except BaseException:
5353
logger.error("Internal Server Error", exc_info=True)
5454

5555
async def close(self):
@@ -89,7 +89,7 @@ async def _send_body(self, response: Response):
8989

9090
# Start it all over again...
9191
async def _complete(self):
92-
await self._parser.complete
92+
await self._parser.complete()
9393
self._idle_expiry = time.monotonic() + self._keepalive_duration
9494

9595

src/httpx/_network.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def __init__(self, listener: NetworkListener, handler: typing.Callable[[NetworkS
160160
self._max_workers = 5
161161
self._executor = None
162162
self._thread = None
163-
self._streams = list[NetworkStream]
163+
self._streams: list[NetworkStream] = []
164164

165165
@property
166166
def host(self):
@@ -176,6 +176,8 @@ def __enter__(self):
176176
return self
177177

178178
def __exit__(self, exc_type, exc_val, exc_tb):
179+
for stream in self._streams:
180+
stream.close()
179181
self.listener.close()
180182
self._executor.shutdown(wait=True)
181183

@@ -185,9 +187,11 @@ def _serve(self):
185187

186188
def _handler(self, stream):
187189
try:
190+
self._streams.append(stream)
188191
self.handler(stream)
189192
finally:
190193
stream.close()
194+
self._streams.remove(stream)
191195

192196

193197
class NetworkBackend:

src/httpx/_parsers.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,16 @@ def send_body(self, body: bytes) -> None:
224224
# Handle body close
225225
self.send_state = State.DONE
226226

227+
def recv_close(self) -> bool:
228+
# ...
229+
if self.is_closed():
230+
return True
231+
232+
if self.parser.read_eof():
233+
self.close()
234+
return True
235+
return False
236+
227237
def recv_method_line(self) -> tuple[bytes, bytes, bytes]:
228238
"""
229239
Receive the initial request method line:
@@ -463,6 +473,18 @@ def read(self, size: int) -> bytes:
463473
self._push_back(bytes(push_back))
464474
return bytes(buffer)
465475

476+
def read_eof(self) -> bool:
477+
"""
478+
Attempt to read the closing EOF.
479+
Return True if the stream is EOF, or False otherwise.
480+
"""
481+
if not self._buffer:
482+
chunk = self._read_some()
483+
if not chunk:
484+
return True
485+
self._push_back(chunk)
486+
return False
487+
466488
def read_until(self, marker: bytes, max_size: int, exc_text: str) -> bytes:
467489
"""
468490
Read and return bytes from the stream, delimited by marker.

0 commit comments

Comments
 (0)