Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions .github/workflows/pytest-suite.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
name: pytest-suite

on:
push:
paths:
- 'src/**'
- 'tests/**'
- 'pyproject.toml'
- '.github/workflows/pytest-suite.yml'
pull_request:
paths:
- 'src/**'
- 'tests/**'
- 'pyproject.toml'
- '.github/workflows/pytest-suite.yml'
workflow_dispatch:

jobs:
pytest:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ['3.11', '3.12']

steps:
- name: Check out repository
uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v4
with:
enable-cache: true

- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}

- name: Install dependencies
run: uv sync --extra dev

- name: Run pytest suite
run: uv run python -m pytest tests/ -v --tb=short --junitxml=pytest-results-${{ matrix.python-version }}.xml

- name: Upload test results
if: always()
uses: actions/upload-artifact@v4
with:
name: pytest-results-${{ matrix.python-version }}
path: pytest-results-${{ matrix.python-version }}.xml
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
.PHONY: test certification-env-freeze phase9-release-workflow
.PHONY: test test-pytest certification-env-freeze phase9-release-workflow

test:
PYTHONPATH=src python -m unittest discover -s tests -p 'test_*.py' -v

test-pytest:
uv run python -m pytest tests/ -v


certification-env-freeze:
PYTHONPATH=src python tools/freeze_certification_environment.py
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ full-featured = [
]
dev = [
"pytest>=8.0",
"pytest-asyncio>=0.23",
"aioquic>=1.3.0",
"h2>=4.1.0",
"websockets>=12.0",
Expand All @@ -75,3 +76,8 @@ tigrcorn = ["py.typed"]

[tool.setuptools.packages.find]
where = ["src"]

[tool.pytest.ini_options]
asyncio_mode = "auto"
python_files = ["test_*_pytest.py"]
testpaths = ["tests"]
21 changes: 13 additions & 8 deletions tests/test_additional_remaining_work.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import base64
from contextlib import suppress
import os
import socket
import unittest
Expand All @@ -17,8 +18,10 @@
from tigrcorn.protocols.websocket.frames import encode_frame, read_frame


async def _start_http_server(app):
async def _start_http_server(app, websocket_compression: str | None = None):
config = build_config(host='127.0.0.1', port=0, lifespan='off', http_versions=['1.1'])
if websocket_compression is not None:
config.websocket.compression = websocket_compression
server = TigrCornServer(app, config)
await server.start()
port = server._listeners[0].server.sockets[0].getsockname()[1]
Expand Down Expand Up @@ -114,9 +117,9 @@ async def app(scope, receive, send):
await send({'type': 'websocket.send', 'text': event['text']})
await send({'type': 'websocket.close', 'code': 1000})

server, port = await _start_http_server(app)
server, port = await _start_http_server(app, websocket_compression='permessage-deflate')
try:
reader, writer = await asyncio.open_connection('127.0.0.1', port)
reader, writer = await asyncio.wait_for(asyncio.open_connection('127.0.0.1', port), 1.0)
key = base64.b64encode(os.urandom(16))
writer.write(
b'GET /ws HTTP/1.1\r\n'
Expand All @@ -127,22 +130,24 @@ async def app(scope, receive, send):
b'Sec-WebSocket-Key: ' + key + b'\r\n'
b'Sec-WebSocket-Extensions: permessage-deflate\r\n\r\n'
)
await writer.drain()
response = await reader.readuntil(b'\r\n\r\n')
await asyncio.wait_for(writer.drain(), 1.0)
response = await asyncio.wait_for(reader.readuntil(b'\r\n\r\n'), 1.0)
self.assertIn(b'sec-websocket-extensions: permessage-deflate', response.lower())
compressed = _compress_ws_message(b'hello compressed')
writer.write(encode_frame(0x1, compressed, fin=True, masked=True, rsv1=True))
await writer.drain()
await asyncio.wait_for(writer.drain(), 1.0)
frame = await asyncio.wait_for(read_frame(reader, max_payload_size=4096, expect_masked=False, allow_rsv1=True), 1.0)
self.assertTrue(frame.rsv1)
decompressor = zlib.decompressobj(wbits=-15)
echoed = decompressor.decompress(frame.payload + b'\x00\x00\xff\xff')
self.assertEqual(echoed, b'hello compressed')
self.assertEqual(seen['event']['text'], 'hello compressed')
writer.close()
await writer.wait_closed()
await asyncio.wait_for(writer.wait_closed(), 1.0)
finally:
await server.close()
server.request_shutdown()
with suppress(asyncio.TimeoutError, asyncio.CancelledError):
await asyncio.wait_for(server.close(), 2.0)


class RemainingWorkQuicRoutingTests(unittest.IsolatedAsyncioTestCase):
Expand Down
242 changes: 242 additions & 0 deletions tests/test_additional_remaining_work_pytest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
import asyncio
import base64
from contextlib import suppress
import os
import zlib

import pytest

from tigrcorn.config.defaults import default_config
from tigrcorn.config.load import build_config
from tigrcorn.config.model import ListenerConfig
from tigrcorn.observability.logging import AccessLogger, configure_logging
from tigrcorn.protocols.http3.handler import HTTP3DatagramHandler
from tigrcorn.protocols.http3.streams import HTTP3ConnectionCore
from tigrcorn.protocols.websocket.frames import encode_frame, read_frame
from tigrcorn.server.runner import TigrCornServer
from tigrcorn.transports.quic import QuicConnection
from tigrcorn.transports.udp.packet import UDPPacket


async def _start_http_server(app, websocket_compression: str | None = None):
config = build_config(
host="127.0.0.1", port=0, lifespan="off", http_versions=["1.1"]
)
if websocket_compression is not None:
config.websocket.compression = websocket_compression
server = TigrCornServer(app, config)
await server.start()
port = server._listeners[0].server.sockets[0].getsockname()[1]
return server, port


def _compress_ws_message(payload: bytes) -> bytes:
compressor = zlib.compressobj(wbits=-15)
compressed = compressor.compress(payload) + compressor.flush(zlib.Z_SYNC_FLUSH)
return compressed[:-4]


@pytest.mark.asyncio
async def test_chunked_request_trailers_are_exposed() -> None:
seen = {}

async def app(scope, receive, send):
seen["extensions"] = scope["extensions"]
seen["events"] = [await receive(), await receive(), await receive()]
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [(b"content-type", b"text/plain")],
}
)
await send({"type": "http.response.body", "body": b"ok", "more_body": False})

server, port = await _start_http_server(app)
try:
reader, writer = await asyncio.wait_for(
asyncio.open_connection("127.0.0.1", port), 1.0
)
writer.write(
b"POST /trailers HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n\r\n"
b"5\r\nhello\r\n"
b"0\r\nX-Trailer-One: yes\r\nX-Trailer-Two: done\r\n\r\n"
)
await asyncio.wait_for(writer.drain(), 1.0)
await reader.readuntil(b"\r\n\r\n")
writer.close()
await writer.wait_closed()
finally:
await server.close()
assert "tigrcorn.http.request_trailers" in seen["extensions"]
assert seen["events"][0]["type"] == "http.request"
assert seen["events"][1]["type"] == "http.request"
assert not seen["events"][1]["more_body"]
assert seen["events"][2]["type"] == "http.request.trailers"
assert seen["events"][2]["trailers"] == [
(b"x-trailer-one", b"yes"),
(b"x-trailer-two", b"done"),
]


@pytest.mark.asyncio
async def test_connect_tunnel_relays_bytes() -> None:
received = bytearray()

async def upstream_handler(
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
):
data = await reader.read(1024)
received.extend(data)
writer.write(data[::-1])
await asyncio.wait_for(writer.drain(), 1.0)
writer.close()
await writer.wait_closed()

upstream = await asyncio.start_server(upstream_handler, "127.0.0.1", 0)
upstream_port = upstream.sockets[0].getsockname()[1]

async def app(scope, receive, send):
raise AssertionError(
"CONNECT tunnel should be handled before ASGI app dispatch"
)

server, port = await _start_http_server(app)
try:
reader, writer = await asyncio.open_connection("127.0.0.1", port)
writer.write(
f"CONNECT 127.0.0.1:{upstream_port} HTTP/1.1\r\nHost: localhost\r\n\r\n".encode(
"ascii"
)
)
await writer.drain()
head = await reader.readuntil(b"\r\n\r\n")
assert b"200 Connection Established" in head
writer.write(b"abcdef")
await writer.drain()
echoed = await asyncio.wait_for(reader.readexactly(6), 1.0)
assert echoed == b"fedcba"
assert bytes(received) == b"abcdef"
writer.close()
await writer.wait_closed()
finally:
server.request_shutdown()
await server.close()
upstream.close()
await upstream.wait_closed()


@pytest.mark.asyncio
async def test_permessage_deflate_negotiates_and_roundtrips() -> None:
seen = {}

async def app(scope, receive, send):
await receive()
await send(
{
"type": "websocket.accept",
"headers": [(b"sec-websocket-extensions", b"permessage-deflate")],
}
)
event = await receive()
seen["event"] = event
await send({"type": "websocket.send", "text": event["text"]})
await send({"type": "websocket.close", "code": 1000})

server, port = await _start_http_server(
app, websocket_compression="permessage-deflate"
)
try:
reader, writer = await asyncio.wait_for(
asyncio.open_connection("127.0.0.1", port), 1.0
)
key = base64.b64encode(os.urandom(16))
writer.write(
b"GET /ws HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Upgrade: websocket\r\n"
b"Connection: Upgrade\r\n"
b"Sec-WebSocket-Version: 13\r\n"
b"Sec-WebSocket-Key: " + key + b"\r\n"
b"Sec-WebSocket-Extensions: permessage-deflate\r\n\r\n"
)
await asyncio.wait_for(writer.drain(), 1.0)
response = await asyncio.wait_for(reader.readuntil(b"\r\n\r\n"), 1.0)
assert b"sec-websocket-extensions: permessage-deflate" in response.lower()
compressed = _compress_ws_message(b"hello compressed")
writer.write(encode_frame(0x1, compressed, fin=True, masked=True, rsv1=True))
await asyncio.wait_for(writer.drain(), 1.0)
frame = await asyncio.wait_for(
read_frame(
reader, max_payload_size=4096, expect_masked=False, allow_rsv1=True
),
1.0,
)
assert frame.rsv1
decompressor = zlib.decompressobj(wbits=-15)
echoed = decompressor.decompress(frame.payload + b"\x00\x00\xff\xff")
assert echoed == b"hello compressed"
assert seen["event"]["text"] == "hello compressed"
writer.close()
await asyncio.wait_for(writer.wait_closed(), 1.0)
finally:
server.request_shutdown()
with suppress(asyncio.TimeoutError, asyncio.CancelledError):
await asyncio.wait_for(server.close(), 2.0)


@pytest.mark.asyncio
async def test_http3_session_survives_address_rebinding_via_connection_id() -> None:
async def app(scope, receive, send):
await send({"type": "http.response.start", "status": 204, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})

handler = HTTP3DatagramHandler(
app=app,
config=default_config(),
listener=ListenerConfig(
kind="udp",
host="127.0.0.1",
port=1,
protocols=["http3"],
quic_secret=b"shared",
),
access_logger=AccessLogger(configure_logging("warning"), enabled=False),
)

class Endpoint:
def __init__(self):
self.sent = []
self.local_addr = ("127.0.0.1", 4433)

def send(self, data, addr):
self.sent.append((data, addr))

endpoint = Endpoint()
client = QuicConnection(is_client=True, secret=b"shared", local_cid=b"cli1")
await handler.handle_packet(
UDPPacket(data=client.build_initial(), addr=("127.0.0.1", 50000)), endpoint
)
assert len(handler.sessions_by_local_cid) == 1
core = HTTP3ConnectionCore()
for raw, _addr in endpoint.sent:
for event in client.receive_datagram(raw):
if event.kind == "stream":
core.receive_stream_data(event.stream_id, event.data, fin=event.fin)
endpoint.sent.clear()
request_payload = core.get_request(0).encode_request(
[(b":method", b"POST"), (b":path", b"/rebind"), (b":scheme", b"https")], b"hi"
)
await handler.handle_packet(
UDPPacket(
data=client.send_stream_data(0, request_payload, fin=True),
addr=("127.0.0.1", 50001),
),
endpoint,
)
assert len(handler.sessions_by_local_cid) == 1
assert len(handler.sessions) == 1
session = next(iter(handler.sessions.values()))
assert session.addr == ("127.0.0.1", 50001)
Loading