Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions authentik/sources/ldap/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,19 +242,19 @@ def server(self, **kwargs) -> ServerPool:
tls_kwargs["local_certificate_file"] = certificate_file
if ciphers := CONFIG.get("ldap.tls.ciphers", None):
tls_kwargs["ciphers"] = ciphers.strip()
if self.sni:
tls_kwargs["sni"] = self.server_uri.split(",", maxsplit=1)[0].strip()
server_kwargs = {
"get_info": ALL,
"connect_timeout": LDAP_TIMEOUT,
"tls": Tls(**tls_kwargs),
}
server_kwargs.update(kwargs)
if "," in self.server_uri:
for server in self.server_uri.split(","):
servers.append(Server(server, **server_kwargs))
else:
servers = [Server(self.server_uri, **server_kwargs)]
for server_uri in self.server_uri.split(","):
server = Server(server_uri.strip(), **server_kwargs)
# The TLS SNI server name must be a bare hostname. ldap3 has already
# parsed the scheme and port out of the URI into `server.host`;
# passing the raw URI (e.g. `ldaps://host`) as the SNI name breaks
# the handshake against SNI-strict servers. See #7756.
server.tls = Tls(**tls_kwargs, sni=server.host if self.sni else None)
servers.append(server)
return ServerPool(servers, RANDOM, active=5, exhaust=True)

def connection(
Expand Down
50 changes: 50 additions & 0 deletions authentik/sources/ldap/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""LDAP Source model tests"""

from django.test import TestCase

from authentik.lib.generators import generate_id
from authentik.sources.ldap.models import LDAPSource


class LDAPModelTests(TestCase):
"""LDAP Source model tests"""

def test_server_sni(self):
"""Test that the SNI name is the bare hostname, not the full server URI"""
source = LDAPSource.objects.create(
name=generate_id(),
slug=generate_id(),
server_uri="ldaps://ldap.example.com:636",
base_dn="dc=example,dc=com",
sni=True,
)
pool = source.server()
self.assertEqual([server.tls.sni for server in pool.servers], ["ldap.example.com"])

def test_server_sni_multiple(self):
"""Test that each server in a pool gets its own hostname as the SNI name"""
source = LDAPSource.objects.create(
name=generate_id(),
slug=generate_id(),
server_uri="ldaps://ldap1.example.com,ldaps://ldap2.example.com:636",
base_dn="dc=example,dc=com",
sni=True,
)
pool = source.server()
self.assertEqual(
[server.tls.sni for server in pool.servers],
["ldap1.example.com", "ldap2.example.com"],
)

def test_server_sni_disabled(self):
"""Test that no SNI name is set when the SNI option is disabled"""
source = LDAPSource.objects.create(
name=generate_id(),
slug=generate_id(),
server_uri="ldaps://ldap.example.com",
base_dn="dc=example,dc=com",
sni=False,
)
pool = source.server()
for server in pool.servers:
self.assertIsNone(server.tls.sni)