Skip to content

Commit 7dcdd50

Browse files
committed
Expose set_groups functions
This allows pyOpenSSL to restrict the groups allowed to be used. E.g. for restricting the groups to post-quantum hybrid groups (e.g. X25519MLKEM768) to always ensure that post-quantum cryptography is used. This commit uses the set_groups as public API since that is the preferred naming. Internally for the API we use the set_curves API since this name is available on all OpenSSL implementations and OpenSSL forks. Signed-off-by: Arne Schwabe <arne@rfc2549.org>
1 parent f72218e commit 7dcdd50

3 files changed

Lines changed: 120 additions & 0 deletions

File tree

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Deprecations:
1818

1919
Changes:
2020
^^^^^^^^
21+
- Added ``OpenSSL.SSL.Context.set_groups`` and ``OpenSSL.SSL.Connection.set_groups`` to set allowed groups/curves.
2122

2223
- Added support for using aws-lc instead of OpenSSL.
2324
- Properly raise an error if a DTLS cookie callback returned a cookie longer than ``DTLS1_COOKIE_LENGTH`` bytes. Previously this would result in a buffer-overflow. Credit to **dark_haxor** for reporting the issue. **CVE-2026-27459**

src/OpenSSL/SSL.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,19 @@ def set_tls13_ciphersuites(self, ciphersuites: bytes) -> None:
15311531
_lib.SSL_CTX_set_ciphersuites(self._context, ciphersuites) == 1
15321532
)
15331533

1534+
@_require_not_used
1535+
def set_groups(self, groups: bytes):
1536+
"""
1537+
Set the supported groups/curves in this SSL Session.
1538+
"""
1539+
if not isinstance(groups, bytes):
1540+
raise TypeError("groups must be a byte string.")
1541+
1542+
# We use the newer name (groups) in our public API and
1543+
# use the legacy/more compatible name in the internal API
1544+
rc = _lib.SSL_CTX_set1_curves_list(self._context, groups)
1545+
_openssl_assert(rc == 1)
1546+
15341547
@_require_not_used
15351548
def set_client_ca_list(
15361549
self, certificate_authorities: Sequence[X509Name]
@@ -3249,6 +3262,18 @@ def get_group_name(self) -> str | None:
32493262

32503263
return _ffi.string(group_name).decode("utf-8")
32513264

3265+
def set_groups(self, groups: bytes):
3266+
"""
3267+
Set the supported groups/curves in this SSL Session.
3268+
"""
3269+
if not isinstance(groups, bytes):
3270+
raise TypeError("groups must be a byte string.")
3271+
3272+
# We use the newer name (groups) in our public API and
3273+
# use the legacy/more compatible name in the internal API
3274+
rc = _lib.SSL_set1_curves_list(self._ssl, groups)
3275+
_openssl_assert(rc == 1)
3276+
32523277
def request_ocsp(self) -> None:
32533278
"""
32543279
Called to request that the server sends stapled OCSP data, if

tests/test_ssl.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3581,6 +3581,100 @@ def test_get_group_name(self) -> None:
35813581

35823582
assert server_group_name == client_group_name
35833583

3584+
@pytest.mark.skipif(
3585+
not getattr(_lib, "Cryptography_HAS_SSL_GET0_GROUP_NAME", None),
3586+
reason="SSL_get0_group_name unavailable",
3587+
)
3588+
def test_set_groups_context(self) -> None:
3589+
"""
3590+
`Context.set_groups` forces the use of a specific curve/groups list.
3591+
"""
3592+
3593+
def loopback_x448_client_factory(
3594+
socket: socket, version: int = SSLv23_METHOD
3595+
) -> Connection:
3596+
context = Context(version)
3597+
context.set_groups(b"X448")
3598+
client = Connection(context, socket)
3599+
client.set_connect_state()
3600+
return client
3601+
3602+
server, client = loopback(client_factory=loopback_x448_client_factory)
3603+
server_group_name = server.get_group_name()
3604+
client_group_name = client.get_group_name()
3605+
3606+
assert isinstance(server_group_name, str)
3607+
assert isinstance(client_group_name, str)
3608+
3609+
assert server_group_name.lower() == "x448"
3610+
assert client_group_name.lower() == "x448"
3611+
3612+
def loopback_x448_server_factory(
3613+
socket: socket, version: int = SSLv23_METHOD
3614+
) -> Connection:
3615+
connection = loopback_server_factory(socket, version)
3616+
connection.set_groups(b"X448")
3617+
return connection
3618+
3619+
@pytest.mark.skipif(
3620+
not getattr(_lib, "Cryptography_HAS_SSL_GET0_GROUP_NAME", None),
3621+
reason="SSL_get0_group_name unavailable",
3622+
)
3623+
def test_set_groups_session(self) -> None:
3624+
"""
3625+
`Connection.set_groups` forces the use of a specific curve/groups list.
3626+
"""
3627+
3628+
def loopback_x448_server_factory(
3629+
socket: socket, version: int = SSLv23_METHOD
3630+
) -> Connection:
3631+
connection = loopback_server_factory(socket, version)
3632+
connection.set_groups(b"X448")
3633+
return connection
3634+
3635+
server, client = loopback(server_factory=loopback_x448_server_factory)
3636+
server_group_name = server.get_group_name()
3637+
client_group_name = client.get_group_name()
3638+
3639+
assert isinstance(server_group_name, str)
3640+
assert isinstance(client_group_name, str)
3641+
3642+
assert server_group_name.lower() == "x448"
3643+
assert client_group_name.lower() == "x448"
3644+
3645+
def test_set_groups_mismatch(self):
3646+
"""
3647+
Forces different group lists on client and server so that a connection
3648+
should not be possible.
3649+
"""
3650+
3651+
def loopback_x25519_client_factory(
3652+
socket: socket, version: int = SSLv23_METHOD
3653+
) -> Connection:
3654+
connection = loopback_client_factory(socket, version)
3655+
connection.set_groups(b"X25519")
3656+
return connection
3657+
3658+
def loopback_x448_server_factory(
3659+
socket: socket, version: int = SSLv23_METHOD
3660+
) -> Connection:
3661+
ctx = Context(version)
3662+
ctx.set_groups(b"X448")
3663+
3664+
ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
3665+
ctx.use_certificate(
3666+
load_certificate(FILETYPE_PEM, server_cert_pem)
3667+
)
3668+
server = Connection(ctx, socket)
3669+
server.set_accept_state()
3670+
return server
3671+
3672+
with pytest.raises(SSL.Error):
3673+
loopback(
3674+
client_factory=loopback_x25519_client_factory,
3675+
server_factory=loopback_x448_server_factory,
3676+
)
3677+
35843678
def test_wantReadError(self) -> None:
35853679
"""
35863680
`Connection.bio_read` raises `OpenSSL.SSL.WantReadError` if there are

0 commit comments

Comments
 (0)