Skip to content

Commit 38328b3

Browse files
committed
Add tests for 100-continue connection pool corruption scenario
1 parent bb4b904 commit 38328b3

3 files changed

Lines changed: 322 additions & 0 deletions

File tree

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#!/usr/bin/env python3
2+
"""Client that sends two requests on one TCP connection to reproduce
3+
100-continue connection pool corruption."""
4+
5+
# Licensed to the Apache Software Foundation (ASF) under one
6+
# or more contributor license agreements. See the NOTICE file
7+
# distributed with this work for additional information
8+
# regarding copyright ownership. The ASF licenses this file
9+
# to you under the Apache License, Version 2.0 (the
10+
# "License"); you may not use this file except in compliance
11+
# with the License. You may obtain a copy of the License at
12+
#
13+
# http://www.apache.org/licenses/LICENSE-2.0
14+
#
15+
# Unless required by applicable law or agreed to in writing, software
16+
# distributed under the License is distributed on an "AS IS" BASIS,
17+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
# See the License for the specific language governing permissions and
19+
# limitations under the License.
20+
21+
from http_utils import wait_for_headers_complete, determine_outstanding_bytes_to_read, drain_socket
22+
23+
import argparse
24+
import socket
25+
import sys
26+
import time
27+
28+
29+
def main() -> int:
    """Drive two HTTP requests over one client connection through the proxy.

    Request 1 is a POST with ``Expect: 100-continue`` whose body is sent
    after a short delay without waiting for the interim 100 response.
    Request 2 is a plain GET on the same TCP connection.  Prints whether the
    second request observed corruption (leftover POST body bytes on a pooled
    origin connection) and always returns 0 — the test harness judges
    pass/fail from stdout, not the exit code.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('proxy_address')
    parser.add_argument('proxy_port', type=int)
    parser.add_argument('-s', '--server-hostname', dest='server_hostname', default='example.com')
    args = parser.parse_args()

    host = args.server_hostname
    body_size = 103  # small fixed-size request body the origin never consumes
    body_data = b'X' * body_size

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect((args.proxy_address, args.proxy_port))

    with sock:
        # Request 1: POST with Expect: 100-continue and a body.
        request1 = (
            f'POST /expect-100-corrupted HTTP/1.1\r\n'
            f'Host: {host}\r\n'
            f'Connection: keep-alive\r\n'
            f'Content-Length: {body_size}\r\n'
            f'Expect: 100-continue\r\n'
            f'\r\n').encode()
        sock.sendall(request1)

        # Send the body after a short delay without waiting for 100-continue.
        time.sleep(0.5)
        sock.sendall(body_data)

        # Drain the response (might be 100 + 301, or just 301).
        resp1_data = wait_for_headers_complete(sock)

        # If we got a 100 Continue, read past it to the real response.
        if b'100' in resp1_data.split(b'\r\n')[0]:
            # Anything already buffered after the 100's blank line is the
            # start of the final response; read more only if its header
            # section is not yet complete.
            after_100 = resp1_data.split(b'\r\n\r\n', 1)[1] if b'\r\n\r\n' in resp1_data else b''
            if b'\r\n\r\n' not in after_100:
                after_100 += wait_for_headers_complete(sock)
            resp1_data = after_100

        # Drain the response body.
        try:
            outstanding = determine_outstanding_bytes_to_read(resp1_data)
            if outstanding > 0:
                drain_socket(sock, resp1_data, outstanding)
        except ValueError:
            # Response framing was not understood — best effort only;
            # continue with the second request regardless.
            pass

        # Let ATS pool the origin connection.
        time.sleep(0.5)

        # Request 2: plain GET on the same client connection.
        request2 = (f'GET /second-request HTTP/1.1\r\n'
                    f'Host: {host}\r\n'
                    f'Connection: close\r\n'
                    f'\r\n').encode()
        sock.sendall(request2)

        resp2_data = wait_for_headers_complete(sock)
        status_line = resp2_data.split(b'\r\n')[0]

        # A 400 carrying the origin's 'corrupted' marker, or a 502 from the
        # proxy itself, both indicate the pooled origin connection carried
        # leftover body bytes.
        if b'400' in status_line or b'corrupted' in resp2_data.lower():
            print('Corruption detected: second request saw corrupted data', flush=True)
        elif b'502' in status_line:
            print('Corruption detected: ATS returned 502 (origin parse error)', flush=True)
        else:
            print('No corruption: second request completed normally', flush=True)

    return 0
97+
98+
99+
# Script entry point: exit status mirrors main()'s return value (always 0).
if __name__ == '__main__':
    sys.exit(main())
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
#!/usr/bin/env python3
2+
"""Origin that sends a 301 without consuming the request body, then checks
3+
whether a reused connection carries leftover (corrupted) data. Handles
4+
multiple connections so that a fixed ATS can open a fresh one for the
5+
second request."""
6+
7+
# Licensed to the Apache Software Foundation (ASF) under one
8+
# or more contributor license agreements. See the NOTICE file
9+
# distributed with this work for additional information
10+
# regarding copyright ownership. The ASF licenses this file
11+
# to you under the Apache License, Version 2.0 (the
12+
# "License"); you may not use this file except in compliance
13+
# with the License. You may obtain a copy of the License at
14+
#
15+
# http://www.apache.org/licenses/LICENSE-2.0
16+
#
17+
# Unless required by applicable law or agreed to in writing, software
18+
# distributed under the License is distributed on an "AS IS" BASIS,
19+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20+
# See the License for the specific language governing permissions and
21+
# limitations under the License.
22+
23+
import argparse
24+
import socket
25+
import sys
26+
import threading
27+
import time
28+
29+
VALID_METHODS = {'GET', 'POST', 'PUT', 'DELETE', 'HEAD', 'OPTIONS', 'PATCH'}
30+
31+
32+
def read_until_headers_complete(conn: socket.socket) -> bytes:
    """Receive from *conn* until the HTTP header terminator (CRLF CRLF)
    has been seen, or the peer closes the connection.

    Returns everything received so far; on early close the terminator may
    be absent, and bytes past the terminator may be included.
    """
    received = bytearray()
    while b'\r\n\r\n' not in received:
        chunk = conn.recv(4096)
        if not chunk:
            # Peer closed before the header section completed.
            break
        received.extend(chunk)
    return bytes(received)
40+
41+
42+
def is_valid_http_request_line(line: str) -> bool:
    """Return True when *line* parses as an HTTP/1.x request line.

    A valid line has at least three space-separated tokens, starts with a
    method from VALID_METHODS, and ends with an ``HTTP/`` version token.
    """
    tokens = line.strip().split(' ')
    return (
        len(tokens) >= 3
        and tokens[0] in VALID_METHODS
        and tokens[-1].startswith('HTTP/'))
47+
48+
49+
def send_200(conn: socket.socket) -> None:
    """Write a minimal HTTP/1.1 200 response with body ``OK`` to *conn*."""
    payload = b'OK'
    header_lines = [
        b'HTTP/1.1 200 OK',
        b'Content-Length: ' + str(len(payload)).encode(),
        b'',
    ]
    conn.sendall(b'\r\n'.join(header_lines) + b'\r\n' + payload)
54+
55+
56+
def handle_connection(conn: socket.socket, args: argparse.Namespace, result: dict) -> None:
    """Serve one accepted connection.

    POST: after ``args.delay`` seconds, respond 301 WITHOUT reading the
    request body, then keep the connection open for up to ``args.timeout``
    seconds to observe what arrives next.  A well-formed second request line
    gets a 200; any other bytes (leftover, unconsumed POST body forwarded by
    the proxy) are treated as corruption — ``result['corrupted']`` is set and
    a 400 with body ``corrupted`` is sent.

    GET: answered 200 immediately and ``result['new_connection']`` is set,
    meaning the proxy opened a fresh origin connection for the second
    request (the fixed behavior).

    ``result`` is shared with ``main`` and mutated in place.  All exceptions
    are swallowed so a misbehaving peer cannot kill the worker thread.
    """
    try:
        data = read_until_headers_complete(conn)
        if not data:
            # Readiness probe: connection opened and closed with no bytes.
            conn.close()
            return

        first_line = data.split(b'\r\n')[0].decode('utf-8', errors='replace')

        if first_line.startswith('POST'):
            # First request: send 301 without consuming the body.
            time.sleep(args.delay)

            body = b'Redirecting'
            response = (
                b'HTTP/1.1 301 Moved Permanently\r\n'
                b'Location: http://example.com/\r\n'
                b'Connection: keep-alive\r\n'
                b'Content-Length: ' + str(len(body)).encode() + b'\r\n'
                b'\r\n' + body)
            conn.sendall(response)

            # Wait for potential reuse on this connection.
            conn.settimeout(args.timeout)
            try:
                second_data = b''
                while b'\r\n' not in second_data:
                    chunk = conn.recv(4096)
                    if not chunk:
                        break
                    second_data += chunk

                if second_data:
                    second_line = second_data.split(b'\r\n')[0].decode('utf-8', errors='replace')
                    if is_valid_http_request_line(second_line):
                        send_200(conn)
                    else:
                        # Leftover request-body bytes arrived on the pooled
                        # connection: the corruption this test detects.
                        result['corrupted'] = True
                        err_body = b'corrupted'
                        conn.sendall(
                            b'HTTP/1.1 400 Bad Request\r\n'
                            b'Content-Length: ' + str(len(err_body)).encode() + b'\r\n'
                            b'\r\n' + err_body)
            except socket.timeout:
                # No reuse within the window; nothing more to check here.
                pass

        elif first_line.startswith('GET'):
            # Second request on a new connection (fix is working).
            result['new_connection'] = True
            send_200(conn)

        conn.close()
    except Exception:
        # Best-effort cleanup; errors must not propagate out of the thread.
        try:
            conn.close()
        except Exception:
            pass
114+
115+
116+
def main() -> int:
    """Listen on the given port and hand each accepted connection to a
    daemon thread running ``handle_connection``.

    Accepts at most 10 connections (enough for the single test client plus
    readiness probes), stops early when accept times out, joins the worker
    threads, and always exits 0 — correctness is judged by the HTTP
    responses the client sees, not by this process's exit code.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('port', type=int)
    parser.add_argument('--delay', type=float, default=1.0)
    parser.add_argument('--timeout', type=float, default=5.0)
    args = parser.parse_args()

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(('', args.port))
    sock.listen(5)
    # Accept timeout exceeds the per-connection reuse window so handlers
    # can finish before the accept loop gives up.
    sock.settimeout(args.timeout + 5)

    # Shared with handler threads, which mutate it in place.
    result = {'corrupted': False, 'new_connection': False}
    threads = []
    connections_handled = 0

    try:
        while connections_handled < 10:
            try:
                conn, _ = sock.accept()
                t = threading.Thread(target=handle_connection, args=(conn, args, result))
                t.daemon = True
                t.start()
                threads.append(t)
                connections_handled += 1
            except socket.timeout:
                # No more connections expected; shut down the loop.
                break
    except Exception:
        pass

    for t in threads:
        t.join(timeout=args.timeout + 2)

    sock.close()
    return 0
152+
153+
154+
# Script entry point: exit status mirrors main()'s return value (always 0).
if __name__ == '__main__':
    sys.exit(main())
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import sys
18+
19+
Test.Summary = '''
Verify that when an origin responds before consuming the request body on a
connection with Expect: 100-continue, ATS does not return the origin connection
to the pool with unconsumed data.
'''

tr = Test.AddTestRun('Verify 100-continue with early origin response does not corrupt pooled connections.')

# DNS: resolve every hostname to loopback so the remap target reaches the
# local origin process.
dns = tr.MakeDNServer('dns', default='127.0.0.1')

# Origin: corruption_origin.py answers the POST with a 301 without reading
# its body; --delay/--timeout correspond to that script's options.
Test.GetTcpPort('origin_port')
tr.Setup.CopyAs('corruption_origin.py')
origin = tr.Processes.Process(
    'origin', f'{sys.executable} corruption_origin.py '
    f'{Test.Variables.origin_port} --delay 1.0 --timeout 5.0')
origin.Ready = When.PortOpen(Test.Variables.origin_port)

# ATS: cache disabled so both requests reach the origin; 100-continue
# responses enabled to exercise the scenario under test.
ts = tr.MakeATSProcess('ts', enable_cache=False)
ts.Disk.remap_config.AddLine(f'map / http://backend.example.com:{Test.Variables.origin_port}')
ts.Disk.records_config.update(
    {
        'proxy.config.diags.debug.enabled': 1,
        'proxy.config.diags.debug.tags': 'http',
        'proxy.config.dns.nameservers': f'127.0.0.1:{dns.Variables.Port}',
        'proxy.config.dns.resolv_conf': 'NULL',
        'proxy.config.http.send_100_continue_response': 1,
    })

# Client: sends the Expect: 100-continue POST and then a GET on the same
# client connection (see corruption_client.py).
tr.Setup.CopyAs('corruption_client.py')
tr.Setup.CopyAs('http_utils.py')
tr.Processes.Default.Command = (
    f'{sys.executable} corruption_client.py '
    f'127.0.0.1 {ts.Variables.port} '
    f'-s backend.example.com')
tr.Processes.Default.ReturnCode = 0
tr.Processes.Default.StartBefore(dns)
tr.Processes.Default.StartBefore(origin)
tr.Processes.Default.StartBefore(ts)

# With the fix, ATS should not pool the origin connection when the
# request body was not fully consumed, preventing corruption.
tr.Processes.Default.Streams.stdout += Testers.ContainsExpression(
    'No corruption', 'The second request should complete normally because ATS '
    'does not pool origin connections with unconsumed body data.')
tr.Processes.Default.Streams.stdout += Testers.ExcludesExpression('Corruption detected', 'No corruption should be detected.')

0 commit comments

Comments
 (0)