Skip to content

Commit 5988d47

Browse files
committed
Expose set_groups functions
This allows pyOpenSSL to restrict the groups allowed to be used. E.g. for restricting the groups to post-quantum hybrid groups (e.g. X25519MLKEM768) to always ensure that post-quantum cryptography is used. This commit uses the set_groups as public API since that is the preferred naming. Internally for the API we use the set_curves API since this name is available on all OpenSSL implementations and OpenSSL forks. Signed-off-by: Arne Schwabe <arne@rfc2549.org>
1 parent 7b29beb commit 5988d47

3 files changed

Lines changed: 75 additions & 0 deletions

File tree

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Deprecations:
1818

1919
Changes:
2020
^^^^^^^^
21+
- Added ``OpenSSL.SSL.Context.set_groups`` and ``OpenSSL.SSL.Connection.set_groups`` to set allowed groups/curves.
2122

2223
- Added ``OpenSSL.SSL.Connection.get_group_name`` to determine which group name was negotiated.
2324

src/OpenSSL/SSL.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,6 +1515,21 @@ def set_tls13_ciphersuites(self, ciphersuites: bytes) -> None:
15151515
_lib.SSL_CTX_set_ciphersuites(self._context, ciphersuites) == 1
15161516
)
15171517

1518+
1519+
@_require_not_used
1520+
def set_groups(self, groups: bytes):
1521+
"""
1522+
Set the supported groups/curves in this SSL Session.
1523+
"""
1524+
if not isinstance(groups, bytes):
1525+
raise TypeError("groups must be a byte string.")
1526+
1527+
# We use the newer name (groups) in our public API and
1528+
# use the legacy/more compatible name in the internal API
1529+
rc = _lib.SSL_CTX_set1_curves_list(self._context, groups)
1530+
_openssl_assert(rc == 1)
1531+
1532+
15181533
@_require_not_used
15191534
def set_client_ca_list(
15201535
self, certificate_authorities: Sequence[X509Name]
@@ -3227,6 +3242,19 @@ def get_group_name(self) -> str | None:
32273242

32283243
return _ffi.string(group_name).decode("utf-8")
32293244

3245+
def set_groups(self, groups: bytes):
3246+
"""
3247+
Set the supported groups/curves in this SSL Session.
3248+
"""
3249+
if not isinstance(groups, bytes):
3250+
raise TypeError("groups must be a byte string.")
3251+
3252+
# We use the newer name (groups) in our public API and
3253+
# use the legacy/more compatible name in the internal API
3254+
rc = _lib.SSL_set1_curves_list(self._ssl, groups)
3255+
_openssl_assert(rc == 1)
3256+
3257+
32303258
def request_ocsp(self) -> None:
32313259
"""
32323260
Called to request that the server sends stapled OCSP data, if

tests/test_ssl.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3495,6 +3495,52 @@ def test_get_group_name(self) -> None:
34953495

34963496
assert server_group_name == client_group_name
34973497

3498+
def test_set_groups_context(self) -> None:
3499+
"""
3500+
`Context.set_groups` forces the use of a specific curve/groups list.
3501+
"""
3502+
def loopback_x448_client_factory(
3503+
socket: socket, version: int = SSLv23_METHOD
3504+
) -> Connection:
3505+
context = Context(version)
3506+
context.set_groups(b"X448")
3507+
client = Connection(context, socket)
3508+
client.set_connect_state()
3509+
return client
3510+
3511+
server, client = loopback(client_factory=loopback_x448_client_factory)
3512+
server_group_name = server.get_group_name()
3513+
client_group_name = client.get_group_name()
3514+
3515+
assert isinstance(server_group_name, str)
3516+
assert isinstance(client_group_name, str)
3517+
3518+
assert server_group_name.lower() == "x448"
3519+
assert client_group_name.lower() == "x448"
3520+
3521+
def test_set_groups_session(self) -> None:
3522+
"""
3523+
`Connection.set_groups` forces the use of a specific curve/groups list.
3524+
"""
3525+
3526+
def loopback_x448_server_factory(
3527+
socket: socket, version: int = SSLv23_METHOD
3528+
)-> Connection:
3529+
connection = loopback_server_factory(socket, version)
3530+
connection.set_groups(b"X448")
3531+
return connection
3532+
3533+
server, client = loopback(server_factory=loopback_x448_server_factory)
3534+
server_group_name = server.get_group_name()
3535+
client_group_name = client.get_group_name()
3536+
3537+
assert isinstance(server_group_name, str)
3538+
assert isinstance(client_group_name, str)
3539+
3540+
assert server_group_name.lower() == "x448"
3541+
assert client_group_name.lower() == "x448"
3542+
3543+
34983544
def test_wantReadError(self) -> None:
34993545
"""
35003546
`Connection.bio_read` raises `OpenSSL.SSL.WantReadError` if there are

0 commit comments

Comments
 (0)