Skip to content

Commit b2b01dd

Browse files
authored
Error safely if header values are invalid (#739)
1 parent 52a9b89 commit b2b01dd

3 files changed

Lines changed: 145 additions & 1 deletion

File tree

source/websocket.c

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,12 +209,24 @@ static void s_websocket_on_connection_setup(
209209
PyObject *tuple_py = PyTuple_New(2);
210210
AWS_FATAL_ASSERT(tuple_py && "header tuple allocation failed");
211211

212+
/* Header names are tokens as per RFC 7230 Section 3.2 (strict ASCII),
213+
* which means aws-c-http rejects them on the wire if they contain non-ASCII bytes.
214+
* So errors related to http header decoding will be caught at the protocol level.
215+
* We should never fail wrangling the header name. */
212216
PyObject *name_py = PyUnicode_FromAwsByteCursor(&header_i->name);
213217
AWS_FATAL_ASSERT(name_py && "header name wrangling failed");
214218
PyTuple_SetItem(tuple_py, 0, name_py); /* Steals a reference */
215219

220+
/* Header value can contain RFC 7230 obs-text (0x80-0xFF), which is
221+
* not guaranteed valid UTF-8. On decode failure, log it and drop
222+
* the whole header list rather than aborting the process. */
216223
PyObject *value_py = PyUnicode_FromAwsByteCursor(&header_i->value);
217-
AWS_FATAL_ASSERT(value_py && "header value wrangling failed");
224+
if (!value_py) {
225+
PyErr_WriteUnraisable(websocket_core_py);
226+
Py_DECREF(tuple_py);
227+
Py_CLEAR(headers_py);
228+
break;
229+
}
218230
PyTuple_SetItem(tuple_py, 1, value_py); /* Steals a reference */
219231

220232
PyList_SetItem(headers_py, i, tuple_py); /* Steals a reference */

test/test_websocket.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from queue import Empty, Queue
1414
import secrets
1515
import socket
16+
import subprocess
17+
import sys
1618
from test import NativeResourceTest
1719
import threading
1820
from time import sleep, time
@@ -182,6 +184,54 @@ def send_async(self, msg):
182184
asyncio.run_coroutine_threadsafe(self._current_connection.send(msg), self._server_loop)
183185

184186

187+
class MockHandshakeServer:
    # A raw-socket server that accepts one connection, drains the client's
    # HTTP handshake request, and sends back a caller-provided response.
    # Use this when tests need to send byte sequences that the 3rdparty
    # `websockets` library can't produce (e.g. malformed headers).
    #
    # Usage:
    #   with MockHandshakeServer(host, response=b"HTTP/1.1 ...") as server:
    #       # spawn a client that connects to (host, server.port)
    #       ...

    # How often (in seconds) the server thread wakes from accept() to check
    # whether it has been asked to stop.
    _ACCEPT_POLL_SECONDS = 0.1

    def __init__(self, host, response):
        # host: local address to bind. Port 0 lets the OS pick a free port,
        #       which is then published as `self.port`.
        # response: raw bytes sent verbatim once the client's header block
        #           (terminated by a blank line) has been received.
        self._host = host
        self._response = response
        self._stop = threading.Event()
        self._listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self._listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self._listener.bind((host, 0))
            self._listener.listen(1)
            # Bounded accept(): closing a listener from another thread does
            # not reliably interrupt a blocking accept() on every platform,
            # so the server thread polls and checks the stop flag instead.
            self._listener.settimeout(self._ACCEPT_POLL_SECONDS)
            self.port = self._listener.getsockname()[1]
        except Exception:
            # Don't leak the socket if setup fails part-way through.
            self._listener.close()
            raise
        self._thread = threading.Thread(target=self._serve, daemon=True)

    def __enter__(self):
        self._thread.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Signal the thread first and only close the listener after the join,
        # so the descriptor is never closed while accept() may still use it.
        self._stop.set()
        try:
            self._thread.join(TIMEOUT)
        finally:
            self._listener.close()

    def _accept_one(self):
        # Wait for a single client. Returns the connected socket, or None if
        # the server was asked to stop or the listener failed.
        while not self._stop.is_set():
            try:
                conn, _ = self._listener.accept()
            except socket.timeout:
                # Poll interval elapsed; loop to re-check the stop flag.
                continue
            except OSError:
                return None
            return conn
        return None

    def _serve(self):
        # Thread body: read the request up to the end of its header block,
        # then reply with the canned response and close the connection.
        conn = self._accept_one()
        if conn is None:
            return
        with closing(conn):
            conn.settimeout(TIMEOUT)
            try:
                buf = b""
                while b"\r\n\r\n" not in buf:
                    chunk = conn.recv(4096)
                    if not chunk:
                        # Client went away before finishing its request.
                        return
                    buf += chunk
                conn.sendall(self._response)
            except OSError:
                # Best effort: a client that resets or times out is expected
                # in these tests and must not raise from the server thread.
                pass
233+
234+
185235
class TestClient(NativeResourceTest):
186236
def setUp(self):
187237
super().setUp()
@@ -324,6 +374,60 @@ def test_connect_failure_with_response(self):
324374
# check that body is a valid string
325375
self.assertGreater(len(setup_data.handshake_response_body.decode()), 0)
326376

377+
def test_connect_response_header_with_invalid_name_is_protocol_error(self):
    # Header names must consist of token characters only. The 0xE9 byte in
    # the name below makes aws-c-http's HTTP/1.1 decoder reject the response
    # before the binding ever sees it, so the connection is expected to fail
    # with AWS_ERROR_HTTP_PROTOCOL_ERROR.
    response = (
        b"HTTP/1.1 403 Forbidden\r\n"
        b"Content-Length: 0\r\n"
        b"X-Bad\xe9Name: whatever\r\n"
        b"\r\n"
    )
    with MockHandshakeServer(self.host, response=response) as server:
        setup_future = Future()

        def on_setup(data):
            return setup_future.set_result(data)

        handshake = create_handshake_request(host=self.host)
        connect(
            host=self.host,
            port=server.port,
            handshake_request=handshake,
            on_connection_setup=on_setup)

        # block until on_connection_setup has fired
        setup_data: OnConnectionSetupData = setup_future.result(TIMEOUT)

        self.assertIsNone(setup_data.websocket)
        self.assertIsNotNone(setup_data.exception)
        self.assertEqual("AWS_ERROR_HTTP_PROTOCOL_ERROR", setup_data.exception.name)
        # the parser rejected the response, so no header list is handed to Python
        self.assertIsNone(setup_data.handshake_response_headers)
402+
403+
def test_connect_response_header_with_obs_text_does_not_abort(self):
    # The header value below ends in a lone 0xE9 byte (obs-text), which is
    # not valid UTF-8. Handling it must not bring the process down. The
    # client runs in a child process so that an abort would show up here
    # as a non-zero exit status.
    obs_text_response = (
        b"HTTP/1.1 403 Forbidden\r\n"
        b"Content-Length: 0\r\n"
        b"X-Reason: caf\xe9\r\n"
        b"\r\n"
    )
    with MockHandshakeServer(self.host, response=obs_text_response) as server:
        argv = [sys.executable, '-m', 'test.ws_connect_helper', self.host, str(server.port)]
        proc = subprocess.Popen(argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        try:
            out, err = proc.communicate(timeout=TIMEOUT)
        except subprocess.TimeoutExpired:
            # Child never finished: kill it, reap it, and report the hang.
            proc.kill()
            out, err = proc.communicate()
            self.fail("client subprocess hung")

        failure_message = (
            f"client subprocess crashed (returncode={proc.returncode}). "
            f"stdout={out!r} stderr={err!r}"
        )
        self.assertEqual(0, proc.returncode, failure_message)
430+
327431
def test_exception_in_setup_callback_closes_websocket(self):
328432
with WebSocketServer(self.host, self.port) as server:
329433
setup_future = Future()

test/ws_connect_helper.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0.
3+
4+
# Helper for test_websocket subprocess scenarios.
5+
# Runs awscrt.websocket.connect() against a host:port given on the command
6+
# line and waits for on_connection_setup to fire. Used by tests that need
7+
# to observe whether a malformed server response crashes the client process.
8+
9+
import sys
10+
from concurrent.futures import Future
11+
12+
from awscrt.websocket import connect, create_handshake_request
13+
14+
TIMEOUT = 10.0
15+
16+
17+
def main(host, port):
    # Connect to host:port and block until on_connection_setup has fired.
    # The setup outcome itself is not inspected; the caller only observes
    # whether this process exits normally.
    setup_future = Future()

    def on_setup(setup_data):
        return setup_future.set_result(setup_data)

    request = create_handshake_request(host=host)
    connect(
        host=host,
        port=port,
        handshake_request=request,
        on_connection_setup=on_setup)

    # Raises if the callback has not fired within TIMEOUT seconds.
    setup_future.result(TIMEOUT)
25+
26+
27+
if __name__ == '__main__':
28+
main(sys.argv[1], int(sys.argv[2]))

0 commit comments

Comments
 (0)