|
13 | 13 | from queue import Empty, Queue |
14 | 14 | import secrets |
15 | 15 | import socket |
| 16 | +import subprocess |
| 17 | +import sys |
16 | 18 | from test import NativeResourceTest |
17 | 19 | import threading |
18 | 20 | from time import sleep, time |
@@ -182,6 +184,54 @@ def send_async(self, msg): |
182 | 184 | asyncio.run_coroutine_threadsafe(self._current_connection.send(msg), self._server_loop) |
183 | 185 |
|
184 | 186 |
|
| 187 | +class MockHandshakeServer: |
| 188 | + # A raw-socket server that accepts one connection, drains the client's |
| 189 | + # HTTP handshake request, and sends back a caller-provided response. |
| 190 | + # Use this when tests need to send byte sequences that the 3rdparty |
| 191 | + # `websockets` library can't produce (e.g. malformed headers). |
| 192 | + # |
| 193 | + # Usage: |
| 194 | + # with MockHandshakeServer(host, response=b"HTTP/1.1 ...") as server: |
| 195 | + # # spawn a client that connects to (host, server.port) |
| 196 | + # ... |
| 197 | + |
| 198 | + def __init__(self, host, response): |
| 199 | + self._host = host |
| 200 | + self._response = response |
| 201 | + self._listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| 202 | + self._listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |
| 203 | + self._listener.bind((host, 0)) |
| 204 | + self._listener.listen(1) |
| 205 | + self.port = self._listener.getsockname()[1] |
| 206 | + self._thread = threading.Thread(target=self._serve, daemon=True) |
| 207 | + |
| 208 | + def __enter__(self): |
| 209 | + self._thread.start() |
| 210 | + return self |
| 211 | + |
| 212 | + def __exit__(self, exc_type, exc_value, exc_tb): |
| 213 | + self._listener.close() |
| 214 | + self._thread.join(TIMEOUT) |
| 215 | + |
| 216 | + def _serve(self): |
| 217 | + try: |
| 218 | + conn, _ = self._listener.accept() |
| 219 | + except OSError: |
| 220 | + return |
| 221 | + with closing(conn): |
| 222 | + conn.settimeout(TIMEOUT) |
| 223 | + try: |
| 224 | + buf = b"" |
| 225 | + while b"\r\n\r\n" not in buf: |
| 226 | + chunk = conn.recv(4096) |
| 227 | + if not chunk: |
| 228 | + return |
| 229 | + buf += chunk |
| 230 | + conn.sendall(self._response) |
| 231 | + except OSError: |
| 232 | + pass |
| 233 | + |
| 234 | + |
185 | 235 | class TestClient(NativeResourceTest): |
186 | 236 | def setUp(self): |
187 | 237 | super().setUp() |
@@ -324,6 +374,60 @@ def test_connect_failure_with_response(self): |
324 | 374 | # check that body is a valid string |
325 | 375 | self.assertGreater(len(setup_data.handshake_response_body.decode()), 0) |
326 | 376 |
|
| 377 | + def test_connect_response_header_with_invalid_name_is_protocol_error(self): |
| 378 | + # A response header whose name contains a non-tchar byte (e.g. 0xE9) is |
| 379 | + # rejected by aws-c-http's HTTP/1.1 decoder before reaching the binding. |
| 380 | + # The connection should fail with AWS_ERROR_HTTP_PROTOCOL_ERROR. |
| 381 | + response = ( |
| 382 | + b"HTTP/1.1 403 Forbidden\r\n" |
| 383 | + b"Content-Length: 0\r\n" |
| 384 | + b"X-Bad\xe9Name: whatever\r\n" |
| 385 | + b"\r\n" |
| 386 | + ) |
| 387 | + with MockHandshakeServer(self.host, response=response) as server: |
| 388 | + setup_future = Future() |
| 389 | + connect( |
| 390 | + host=self.host, |
| 391 | + port=server.port, |
| 392 | + handshake_request=create_handshake_request(host=self.host), |
| 393 | + on_connection_setup=lambda x: setup_future.set_result(x)) |
| 394 | + |
| 395 | + setup_data: OnConnectionSetupData = setup_future.result(TIMEOUT) |
| 396 | + |
| 397 | + self.assertIsNone(setup_data.websocket) |
| 398 | + self.assertIsNotNone(setup_data.exception) |
| 399 | + self.assertEqual("AWS_ERROR_HTTP_PROTOCOL_ERROR", setup_data.exception.name) |
| 400 | + # bad-name response is rejected at the parser, so no headers reach Python |
| 401 | + self.assertIsNone(setup_data.handshake_response_headers) |
| 402 | + |
| 403 | + def test_connect_response_header_with_obs_text_does_not_abort(self): |
| 404 | + # A response header value containing a non-UTF-8 obs-text byte (e.g. lone 0xE9) |
| 405 | + # must not crash the process. Run the client in a subprocess so that an abort, |
| 406 | + # if it happens, is observable as a non-zero exit code. |
| 407 | + response = ( |
| 408 | + b"HTTP/1.1 403 Forbidden\r\n" |
| 409 | + b"Content-Length: 0\r\n" |
| 410 | + b"X-Reason: caf\xe9\r\n" |
| 411 | + b"\r\n" |
| 412 | + ) |
| 413 | + with MockHandshakeServer(self.host, response=response) as server: |
| 414 | + proc = subprocess.Popen( |
| 415 | + [sys.executable, '-m', 'test.ws_connect_helper', self.host, str(server.port)], |
| 416 | + stdout=subprocess.PIPE, |
| 417 | + stderr=subprocess.PIPE) |
| 418 | + |
| 419 | + try: |
| 420 | + stdout, stderr = proc.communicate(timeout=TIMEOUT) |
| 421 | + except subprocess.TimeoutExpired: |
| 422 | + proc.kill() |
| 423 | + stdout, stderr = proc.communicate() |
| 424 | + self.fail("client subprocess hung") |
| 425 | + |
| 426 | + self.assertEqual( |
| 427 | + 0, proc.returncode, |
| 428 | + f"client subprocess crashed (returncode={proc.returncode}). " |
| 429 | + f"stdout={stdout!r} stderr={stderr!r}") |
| 430 | + |
327 | 431 | def test_exception_in_setup_callback_closes_websocket(self): |
328 | 432 | with WebSocketServer(self.host, self.port) as server: |
329 | 433 | setup_future = Future() |
|
0 commit comments