Skip to content

Commit 052d903

Browse files
committed
Expose set_groups functions
This allows pyOpenSSL to restrict the groups allowed to be used. E.g. for restricting the groups to post-quantum hybrid groups (e.g. X25519MLKEM768) to always ensure that post-quantum cryptography is used. This commit uses the set_groups as public API since that is the preferred naming. Internally for the API we use the set_curves API since this name is available on all OpenSSL implementations and OpenSSL forks. Signed-off-by: Arne Schwabe <arne@rfc2549.org>
1 parent 7b29beb commit 052d903

File tree

3 files changed

+120
-0
lines changed

3 files changed

+120
-0
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Deprecations:
1818

1919
Changes:
2020
^^^^^^^^
21+
- Added ``OpenSSL.SSL.Context.set_groups`` and ``OpenSSL.SSL.Connection.set_groups`` to set allowed groups/curves.
2122

2223
- Added ``OpenSSL.SSL.Connection.get_group_name`` to determine which group name was negotiated.
2324

src/OpenSSL/SSL.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,6 +1515,19 @@ def set_tls13_ciphersuites(self, ciphersuites: bytes) -> None:
15151515
_lib.SSL_CTX_set_ciphersuites(self._context, ciphersuites) == 1
15161516
)
15171517

1518+
@_require_not_used
1519+
def set_groups(self, groups: bytes):
1520+
"""
1521+
Set the supported groups/curves in this SSL Session.
1522+
"""
1523+
if not isinstance(groups, bytes):
1524+
raise TypeError("groups must be a byte string.")
1525+
1526+
# We use the newer name (groups) in our public API and
1527+
# use the legacy/more compatible name in the internal API
1528+
rc = _lib.SSL_CTX_set1_curves_list(self._context, groups)
1529+
_openssl_assert(rc == 1)
1530+
15181531
@_require_not_used
15191532
def set_client_ca_list(
15201533
self, certificate_authorities: Sequence[X509Name]
@@ -3227,6 +3240,18 @@ def get_group_name(self) -> str | None:
32273240

32283241
return _ffi.string(group_name).decode("utf-8")
32293242

3243+
def set_groups(self, groups: bytes):
3244+
"""
3245+
Set the supported groups/curves in this SSL Session.
3246+
"""
3247+
if not isinstance(groups, bytes):
3248+
raise TypeError("groups must be a byte string.")
3249+
3250+
# We use the newer name (groups) in our public API and
3251+
# use the legacy/more compatible name in the internal API
3252+
rc = _lib.SSL_set1_curves_list(self._ssl, groups)
3253+
_openssl_assert(rc == 1)
3254+
32303255
def request_ocsp(self) -> None:
32313256
"""
32323257
Called to request that the server sends stapled OCSP data, if

tests/test_ssl.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3495,6 +3495,100 @@ def test_get_group_name(self) -> None:
34953495

34963496
assert server_group_name == client_group_name
34973497

3498+
@pytest.mark.skipif(
3499+
not getattr(_lib, "Cryptography_HAS_SSL_GET0_GROUP_NAME", None),
3500+
reason="SSL_get0_group_name unavailable",
3501+
)
3502+
def test_set_groups_context(self) -> None:
3503+
"""
3504+
`Context.set_groups` forces the use of a specific curve/groups list.
3505+
"""
3506+
3507+
def loopback_x448_client_factory(
3508+
socket: socket, version: int = SSLv23_METHOD
3509+
) -> Connection:
3510+
context = Context(version)
3511+
context.set_groups(b"X448")
3512+
client = Connection(context, socket)
3513+
client.set_connect_state()
3514+
return client
3515+
3516+
server, client = loopback(client_factory=loopback_x448_client_factory)
3517+
server_group_name = server.get_group_name()
3518+
client_group_name = client.get_group_name()
3519+
3520+
assert isinstance(server_group_name, str)
3521+
assert isinstance(client_group_name, str)
3522+
3523+
assert server_group_name.lower() == "x448"
3524+
assert client_group_name.lower() == "x448"
3525+
3526+
def loopback_x448_server_factory(
3527+
socket: socket, version: int = SSLv23_METHOD
3528+
) -> Connection:
3529+
connection = loopback_server_factory(socket, version)
3530+
connection.set_groups(b"X448")
3531+
return connection
3532+
3533+
@pytest.mark.skipif(
3534+
not getattr(_lib, "Cryptography_HAS_SSL_GET0_GROUP_NAME", None),
3535+
reason="SSL_get0_group_name unavailable",
3536+
)
3537+
def test_set_groups_session(self) -> None:
3538+
"""
3539+
`Connection.set_groups` forces the use of a specific curve/groups list.
3540+
"""
3541+
3542+
def loopback_x448_server_factory(
3543+
socket: socket, version: int = SSLv23_METHOD
3544+
) -> Connection:
3545+
connection = loopback_server_factory(socket, version)
3546+
connection.set_groups(b"X448")
3547+
return connection
3548+
3549+
server, client = loopback(server_factory=loopback_x448_server_factory)
3550+
server_group_name = server.get_group_name()
3551+
client_group_name = client.get_group_name()
3552+
3553+
assert isinstance(server_group_name, str)
3554+
assert isinstance(client_group_name, str)
3555+
3556+
assert server_group_name.lower() == "x448"
3557+
assert client_group_name.lower() == "x448"
3558+
3559+
def test_set_groups_mismatch(self):
3560+
"""
3561+
Forces different group lists on client and server so that a connection
3562+
should not be possible.
3563+
"""
3564+
3565+
def loopback_x25519_client_factory(
3566+
socket: socket, version: int = SSLv23_METHOD
3567+
) -> Connection:
3568+
connection = loopback_client_factory(socket, version)
3569+
connection.set_groups(b"X25519")
3570+
return connection
3571+
3572+
def loopback_x448_server_factory(
3573+
socket: socket, version: int = SSLv23_METHOD
3574+
) -> Connection:
3575+
ctx = Context(version)
3576+
ctx.set_groups(b"X448")
3577+
3578+
ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
3579+
ctx.use_certificate(
3580+
load_certificate(FILETYPE_PEM, server_cert_pem)
3581+
)
3582+
server = Connection(ctx, socket)
3583+
server.set_accept_state()
3584+
return server
3585+
3586+
with pytest.raises(SSL.Error):
3587+
loopback(
3588+
client_factory=loopback_x25519_client_factory,
3589+
server_factory=loopback_x448_server_factory,
3590+
)
3591+
34983592
def test_wantReadError(self) -> None:
34993593
"""
35003594
`Connection.bio_read` raises `OpenSSL.SSL.WantReadError` if there are

0 commit comments

Comments
 (0)