|
| 1 | +import http.server |
| 2 | +import os |
| 3 | +import socket |
| 4 | +import ssl |
| 5 | +import threading |
| 6 | + |
| 7 | +import pytest |
| 8 | + |
| 9 | +from utils.test_service import FluentBitTestService |
| 10 | + |
| 11 | + |
| 12 | +class IPv6ThreadingHTTPServer(http.server.ThreadingHTTPServer): |
| 13 | + address_family = socket.AF_INET6 |
| 14 | + allow_reuse_address = True |
| 15 | + |
| 16 | + |
| 17 | +class TLSReceiver: |
| 18 | + def __init__(self, port, cert_file, key_file): |
| 19 | + self.port = port |
| 20 | + self.cert_file = cert_file |
| 21 | + self.key_file = key_file |
| 22 | + self.server = None |
| 23 | + self.thread = None |
| 24 | + self.requests = [] |
| 25 | + self.sni_values = [] |
| 26 | + |
| 27 | + def start(self): |
| 28 | + receiver = self |
| 29 | + |
| 30 | + class Handler(http.server.BaseHTTPRequestHandler): |
| 31 | + def do_POST(self): |
| 32 | + content_length = int(self.headers.get("Content-Length", 0)) |
| 33 | + body = self.rfile.read(content_length) |
| 34 | + receiver.requests.append( |
| 35 | + { |
| 36 | + "path": self.path, |
| 37 | + "headers": dict(self.headers), |
| 38 | + "body": body, |
| 39 | + } |
| 40 | + ) |
| 41 | + self.send_response(200) |
| 42 | + self.end_headers() |
| 43 | + self.wfile.write(b"ok") |
| 44 | + |
| 45 | + def log_message(self, fmt, *args): |
| 46 | + pass |
| 47 | + |
| 48 | + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) |
| 49 | + context.load_cert_chain(self.cert_file, self.key_file) |
| 50 | + |
| 51 | + def record_sni(_socket, server_name, _context): |
| 52 | + self.sni_values.append(server_name) |
| 53 | + |
| 54 | + context.set_servername_callback( |
| 55 | + record_sni |
| 56 | + ) |
| 57 | + |
| 58 | + self.server = IPv6ThreadingHTTPServer(("::1", self.port), Handler) |
| 59 | + self.server.socket = context.wrap_socket(self.server.socket, server_side=True) |
| 60 | + self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) |
| 61 | + self.thread.start() |
| 62 | + |
| 63 | + def stop(self): |
| 64 | + if self.server is not None: |
| 65 | + self.server.shutdown() |
| 66 | + self.server.server_close() |
| 67 | + self.server = None |
| 68 | + |
| 69 | + if self.thread is not None: |
| 70 | + self.thread.join(timeout=5) |
| 71 | + self.thread = None |
| 72 | + |
| 73 | + |
| 74 | +class Service: |
| 75 | + def __init__(self): |
| 76 | + config_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), |
| 77 | + "../config")) |
| 78 | + cert_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), |
| 79 | + "../../in_splunk/certificate")) |
| 80 | + self.config_file = os.path.join(config_dir, "out_http_tls_ipv6_literal.yaml") |
| 81 | + self.tls_crt_file = os.path.join(cert_dir, "certificate.pem") |
| 82 | + self.tls_key_file = os.path.join(cert_dir, "private_key.pem") |
| 83 | + self.receiver = None |
| 84 | + self.service = FluentBitTestService( |
| 85 | + self.config_file, |
| 86 | + pre_start=self._start_receiver, |
| 87 | + post_stop=self._stop_receiver, |
| 88 | + ) |
| 89 | + |
| 90 | + def _start_receiver(self, service): |
| 91 | + self.receiver = TLSReceiver(service.test_suite_http_port, |
| 92 | + self.tls_crt_file, |
| 93 | + self.tls_key_file) |
| 94 | + self.receiver.start() |
| 95 | + |
| 96 | + def _stop_receiver(self, service): |
| 97 | + if self.receiver is not None: |
| 98 | + self.receiver.stop() |
| 99 | + |
| 100 | + def start(self): |
| 101 | + self.service.start() |
| 102 | + |
| 103 | + def stop(self): |
| 104 | + self.service.stop() |
| 105 | + |
| 106 | + def wait_for_requests(self, minimum_count, timeout=10): |
| 107 | + return self.service.wait_for_condition( |
| 108 | + lambda: self.receiver.requests |
| 109 | + if len(self.receiver.requests) >= minimum_count |
| 110 | + else None, |
| 111 | + timeout=timeout, |
| 112 | + interval=0.5, |
| 113 | + description=f"{minimum_count} outbound HTTPS requests", |
| 114 | + ) |
| 115 | + |
| 116 | + |
| 117 | +def ipv6_loopback_available(): |
| 118 | + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) |
| 119 | + try: |
| 120 | + sock.bind(("::1", 0)) |
| 121 | + return True |
| 122 | + except OSError: |
| 123 | + return False |
| 124 | + finally: |
| 125 | + sock.close() |
| 126 | + |
| 127 | + |
| 128 | +def test_tls_sni_omits_ipv6_literals(): |
| 129 | + if not ipv6_loopback_available(): |
| 130 | + pytest.skip("IPv6 loopback is not available") |
| 131 | + |
| 132 | + service = Service() |
| 133 | + try: |
| 134 | + service.start() |
| 135 | + service.wait_for_requests(1, timeout=30) |
| 136 | + finally: |
| 137 | + service.stop() |
| 138 | + |
| 139 | + assert service.receiver.sni_values |
| 140 | + assert service.receiver.sni_values[0] is None |
0 commit comments