Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 94 additions & 3 deletions src/tls/openssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <fluent-bit/flb_info.h>
#include <fluent-bit/flb_str.h>
Expand All @@ -40,11 +41,15 @@
#ifdef FLB_SYSTEM_WINDOWS
#define strtok_r(str, delimiter, context) \
strtok_s(str, delimiter, context)
#include <winsock2.h>
#include <ws2tcpip.h>
#include <wincrypt.h>
#ifndef CERT_FIND_SHA256_HASH
/* Older SDKs may not define this */
#define CERT_FIND_SHA256_HASH 0x0001000d
#endif
#else
#include <arpa/inet.h>
#endif

/*
Expand Down Expand Up @@ -78,6 +83,83 @@ struct tls_session {
struct tls_context *parent; /* parent struct tls_context ref */
};

static int host_is_ip_literal(const char *hostname, char *normalized, size_t normalized_size)
{
char buffer[256];
size_t hostname_len;
size_t lookup_len;
const char *lookup;
const char *bracket_end;
const char *zone_id;
struct in_addr addr4;
struct in6_addr addr6;
int ret;

if (hostname == NULL || hostname[0] == '\0') {
return FLB_FALSE;
}

ret = FLB_FALSE;
lookup = hostname;
hostname_len = strlen(hostname);

if (hostname[0] == '[') {
bracket_end = strchr(hostname + 1, ']');
if (bracket_end == NULL) {
return FLB_FALSE;
}

lookup = hostname + 1;
lookup_len = bracket_end - lookup;
}
else {
lookup_len = hostname_len;
}

zone_id = memchr(lookup, '%', lookup_len);
if (zone_id != NULL) {
lookup_len = zone_id - lookup;
}

if (lookup_len == 0 || lookup_len >= sizeof(buffer)) {
return FLB_FALSE;
}

memcpy(buffer, lookup, lookup_len);
buffer[lookup_len] = '\0';

if (inet_pton(AF_INET, buffer, &addr4) == 1) {
ret = FLB_TRUE;
}

if (inet_pton(AF_INET6, buffer, &addr6) == 1) {
ret = FLB_TRUE;
}

if (ret != FLB_TRUE) {
return FLB_FALSE;
}

if (normalized != NULL) {
if (normalized_size <= lookup_len) {
return FLB_FALSE;
}

memcpy(normalized, buffer, lookup_len + 1);
}

return FLB_TRUE;
}

static void setup_sni(struct tls_session *session, const char *hostname)
{
if (host_is_ip_literal(hostname, NULL, 0) == FLB_TRUE) {
return;
}

SSL_set_tlsext_host_name(session->ssl, hostname);
}

static int tls_init(void)
{
/*
Expand Down Expand Up @@ -1521,6 +1603,8 @@ static int tls_net_write(struct flb_tls_session *session,
int setup_hostname_validation(struct tls_session *session, const char *hostname)
{
X509_VERIFY_PARAM *param;
char normalized_ip[256];
int ret;

param = SSL_get0_param(session->ssl);

Expand All @@ -1530,7 +1614,14 @@ int setup_hostname_validation(struct tls_session *session, const char *hostname)
}

X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
if (!X509_VERIFY_PARAM_set1_host(param, hostname, 0)) {
if (host_is_ip_literal(hostname, normalized_ip, sizeof(normalized_ip)) == FLB_TRUE) {
ret = X509_VERIFY_PARAM_set1_ip_asc(param, normalized_ip);
}
else {
ret = X509_VERIFY_PARAM_set1_host(param, hostname, 0);
}

if (!ret) {
flb_error("[tls] error: hostname parameter vailidation is failed : %s",
hostname);
return -1;
Expand Down Expand Up @@ -1581,10 +1672,10 @@ static int tls_net_handshake(struct flb_tls *tls,
}

if (vhost != NULL) {
SSL_set_tlsext_host_name(session->ssl, vhost);
setup_sni(session, vhost);
}
else if (tls->vhost) {
SSL_set_tlsext_host_name(session->ssl, tls->vhost);
setup_sni(session, tls->vhost);
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
service:
flush: 1
log_level: info
http_server: on
http_port: ${FLUENT_BIT_HTTP_MONITORING_PORT}

pipeline:
inputs:
- name: dummy
tag: tls_sni
dummy: '{"message":"hello over tls","source":"dummy"}'
samples: 1

outputs:
- name: http
match: tls_sni
host: "::1"
port: ${TEST_SUITE_HTTP_PORT}
uri: /data
format: json
json_date_key: false
retry_limit: 1
tls: on
tls.verify: off
140 changes: 140 additions & 0 deletions tests/integration/scenarios/tls/tests/test_tls_sni_001.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import http.server
import os
import socket
import ssl
import threading

import pytest

from utils.test_service import FluentBitTestService


class IPv6ThreadingHTTPServer(http.server.ThreadingHTTPServer):
address_family = socket.AF_INET6
allow_reuse_address = True


class TLSReceiver:
def __init__(self, port, cert_file, key_file):
self.port = port
self.cert_file = cert_file
self.key_file = key_file
self.server = None
self.thread = None
self.requests = []
self.sni_values = []

def start(self):
receiver = self

class Handler(http.server.BaseHTTPRequestHandler):
def do_POST(self):
content_length = int(self.headers.get("Content-Length", 0))
body = self.rfile.read(content_length)
receiver.requests.append(
{
"path": self.path,
"headers": dict(self.headers),
"body": body,
}
)
self.send_response(200)
self.end_headers()
self.wfile.write(b"ok")

def log_message(self, fmt, *args):
pass

context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(self.cert_file, self.key_file)

def record_sni(_socket, server_name, _context):
self.sni_values.append(server_name)

context.set_servername_callback(
record_sni
)

self.server = IPv6ThreadingHTTPServer(("::1", self.port), Handler)
self.server.socket = context.wrap_socket(self.server.socket, server_side=True)
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
self.thread.start()

def stop(self):
if self.server is not None:
self.server.shutdown()
self.server.server_close()
self.server = None

if self.thread is not None:
self.thread.join(timeout=5)
self.thread = None


class Service:
def __init__(self):
config_dir = os.path.abspath(os.path.join(os.path.dirname(__file__),
"../config"))
cert_dir = os.path.abspath(os.path.join(os.path.dirname(__file__),
"../../in_splunk/certificate"))
self.config_file = os.path.join(config_dir, "out_http_tls_ipv6_literal.yaml")
self.tls_crt_file = os.path.join(cert_dir, "certificate.pem")
self.tls_key_file = os.path.join(cert_dir, "private_key.pem")
self.receiver = None
self.service = FluentBitTestService(
self.config_file,
pre_start=self._start_receiver,
post_stop=self._stop_receiver,
)

def _start_receiver(self, service):
self.receiver = TLSReceiver(service.test_suite_http_port,
self.tls_crt_file,
self.tls_key_file)
self.receiver.start()

def _stop_receiver(self, service):
if self.receiver is not None:
self.receiver.stop()

def start(self):
self.service.start()

def stop(self):
self.service.stop()

def wait_for_requests(self, minimum_count, timeout=10):
return self.service.wait_for_condition(
lambda: self.receiver.requests
if len(self.receiver.requests) >= minimum_count
else None,
timeout=timeout,
interval=0.5,
description=f"{minimum_count} outbound HTTPS requests",
)


def ipv6_loopback_available():
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
try:
sock.bind(("::1", 0))
return True
except OSError:
return False
finally:
sock.close()


def test_tls_sni_omits_ipv6_literals():
if not ipv6_loopback_available():
pytest.skip("IPv6 loopback is not available")

service = Service()
try:
service.start()
service.wait_for_requests(1, timeout=30)
finally:
service.stop()

assert service.receiver.sni_values
assert service.receiver.sni_values[0] is None
Loading