Skip to content

Commit 6d55393

Browse files
committed
Expose set_groups functions
This allows pyOpenSSL to restrict the groups allowed to be used. E.g. for restricting the groups to post-quantum hybrid groups (e.g. X25519MLKEM768) to always ensure that post-quantum cryptography is used. This commit uses the set_groups as public API since that is the preferred naming. Internally for the API we use the set_curves API since this name is available on all OpenSSL implementations and OpenSSL forks. Signed-off-by: Arne Schwabe <arne@rfc2549.org>
1 parent 7b29beb commit 6d55393

File tree

4 files changed

+114
-1
lines changed

4 files changed

+114
-1
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Deprecations:
1818

1919
Changes:
2020
^^^^^^^^
21+
- Added ``OpenSSL.SSL.Context.set_groups`` and ``OpenSSL.SSL.Connection.set_groups`` to set allowed groups/curves.
2122

2223
- Added ``OpenSSL.SSL.Connection.get_group_name`` to determine which group name was negotiated.
2324

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def find_meta(meta):
9393
packages=find_packages(where="src"),
9494
package_dir={"": "src"},
9595
install_requires=[
96-
"cryptography>=46.0.0,<47",
96+
"cryptography>=47.0.0,<48",
9797
(
9898
"typing-extensions>=4.9; "
9999
"python_version < '3.13' and python_version >= '3.8'"

src/OpenSSL/SSL.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,6 +1515,19 @@ def set_tls13_ciphersuites(self, ciphersuites: bytes) -> None:
15151515
_lib.SSL_CTX_set_ciphersuites(self._context, ciphersuites) == 1
15161516
)
15171517

1518+
@_require_not_used
1519+
def set_groups(self, groups: bytes) -> None:
1520+
"""
1521+
Set the supported groups/curves in this SSL Session.
1522+
"""
1523+
if not isinstance(groups, bytes):
1524+
raise TypeError("groups must be a byte string.")
1525+
1526+
# We use the newer name (groups) in our public API and
1527+
# use the legacy/more compatible name in the internal API
1528+
rc = _lib.SSL_CTX_set1_curves_list(self._context, groups)
1529+
_openssl_assert(rc == 1)
1530+
15181531
@_require_not_used
15191532
def set_client_ca_list(
15201533
self, certificate_authorities: Sequence[X509Name]
@@ -3227,6 +3240,18 @@ def get_group_name(self) -> str | None:
32273240

32283241
return _ffi.string(group_name).decode("utf-8")
32293242

3243+
def set_groups(self, groups: bytes) -> None:
3244+
"""
3245+
Set the supported groups/curves in this SSL Session.
3246+
"""
3247+
if not isinstance(groups, bytes):
3248+
raise TypeError("groups must be a byte string.")
3249+
3250+
# We use the newer name (groups) in our public API and
3251+
# use the legacy/more compatible name in the internal API
3252+
rc = _lib.SSL_set1_curves_list(self._ssl, groups)
3253+
_openssl_assert(rc == 1)
3254+
32303255
def request_ocsp(self) -> None:
32313256
"""
32323257
Called to request that the server sends stapled OCSP data, if

tests/test_ssl.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3495,6 +3495,93 @@ def test_get_group_name(self) -> None:
34953495

34963496
assert server_group_name == client_group_name
34973497

3498+
@pytest.mark.skipif(
3499+
not getattr(_lib, "Cryptography_HAS_SSL_GET0_GROUP_NAME", None),
3500+
reason="SSL_get0_group_name unavailable",
3501+
)
3502+
def test_set_groups_context(self) -> None:
3503+
"""
3504+
`Context.set_groups` forces the use of a specific curve/groups list.
3505+
"""
3506+
3507+
def loopback_x448_client_factory(
3508+
socket: socket, version: int = SSLv23_METHOD
3509+
) -> Connection:
3510+
context = Context(version)
3511+
context.set_groups(b"X448")
3512+
client = Connection(context, socket)
3513+
client.set_connect_state()
3514+
return client
3515+
3516+
server, client = loopback(client_factory=loopback_x448_client_factory)
3517+
server_group_name = server.get_group_name()
3518+
client_group_name = client.get_group_name()
3519+
3520+
assert isinstance(server_group_name, str)
3521+
assert isinstance(client_group_name, str)
3522+
3523+
assert server_group_name.lower() == "x448"
3524+
assert client_group_name.lower() == "x448"
3525+
3526+
@pytest.mark.skipif(
3527+
not getattr(_lib, "Cryptography_HAS_SSL_GET0_GROUP_NAME", None),
3528+
reason="SSL_get0_group_name unavailable",
3529+
)
3530+
def test_set_groups_session(self) -> None:
3531+
"""
3532+
`Connection.set_groups` forces the use of a specific curve/groups list.
3533+
"""
3534+
3535+
def loopback_x448_server_factory(
3536+
socket: socket, version: int = SSLv23_METHOD
3537+
) -> Connection:
3538+
connection = loopback_server_factory(socket, version)
3539+
connection.set_groups(b"X448")
3540+
return connection
3541+
3542+
server, client = loopback(server_factory=loopback_x448_server_factory)
3543+
server_group_name = server.get_group_name()
3544+
client_group_name = client.get_group_name()
3545+
3546+
assert isinstance(server_group_name, str)
3547+
assert isinstance(client_group_name, str)
3548+
3549+
assert server_group_name.lower() == "x448"
3550+
assert client_group_name.lower() == "x448"
3551+
3552+
def test_set_groups_mismatch(self) -> None:
3553+
"""
3554+
Forces different group lists on client and server so that a connection
3555+
should not be possible.
3556+
"""
3557+
3558+
def loopback_x25519_client_factory(
3559+
socket: socket, version: int = SSLv23_METHOD
3560+
) -> Connection:
3561+
connection = loopback_client_factory(socket, version)
3562+
connection.set_groups(b"X25519")
3563+
return connection
3564+
3565+
def loopback_x448_server_factory(
3566+
socket: socket, version: int = SSLv23_METHOD
3567+
) -> Connection:
3568+
ctx = Context(version)
3569+
ctx.set_groups(b"X448")
3570+
3571+
ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
3572+
ctx.use_certificate(
3573+
load_certificate(FILETYPE_PEM, server_cert_pem)
3574+
)
3575+
server = Connection(ctx, socket)
3576+
server.set_accept_state()
3577+
return server
3578+
3579+
with pytest.raises(SSL.Error):
3580+
loopback(
3581+
client_factory=loopback_x25519_client_factory,
3582+
server_factory=loopback_x448_server_factory,
3583+
)
3584+
34983585
def test_wantReadError(self) -> None:
34993586
"""
35003587
`Connection.bio_read` raises `OpenSSL.SSL.WantReadError` if there are

0 commit comments

Comments
 (0)