Skip to content

Commit bee0af4

Browse files
committed
SSL: Use SSLMode enum class from asyncpg
1 parent 3e1d84a commit bee0af4

File tree

3 files changed

+59
-8
lines changed

3 files changed

+59
-8
lines changed

src/sqlalchemy_cratedb/dialect.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from sqlalchemy import types as sqltypes
2727
from sqlalchemy.engine import default, reflection
28+
from sqlalchemy.exc import SQLAlchemyError
2829
from sqlalchemy.sql import functions
2930
from sqlalchemy.util import asbool, to_list
3031

@@ -35,6 +36,7 @@
3536
)
3637
from .sa_version import SA_1_4, SA_2_0, SA_VERSION
3738
from .type import FloatVector, ObjectArray, ObjectType
39+
from .util import SSLMode
3840

3941
TYPES_MAP = {
4042
"boolean": sqltypes.Boolean,
@@ -228,8 +230,7 @@ def connect(self, host=None, port=None, *args, **kwargs):
228230
server = kwargs.pop("servers")
229231
servers = to_list(server)
230232

231-
# Process SSL options, old and new.
232-
# TODO: Switch to the canonical default `sslmode=prefer` later.
233+
# Process legacy SSL option `ssl`.
233234
if "ssl" in kwargs:
234235
warnings.warn(
235236
"The `ssl=true` option will be deprecated, "
@@ -238,11 +239,25 @@ def connect(self, host=None, port=None, *args, **kwargs):
238239
stacklevel=2,
239240
)
240241
use_ssl = asbool(kwargs.pop("ssl", False))
241-
sslmode = kwargs.pop("sslmode", "disable")
242-
if sslmode in ["allow", "prefer", "require", "verify-ca", "verify-full"]:
243-
use_ssl = True
244-
if sslmode in ["allow", "prefer", "require"]:
245-
kwargs["verify_ssl_cert"] = False
242+
243+
# Process new SSL option `sslmode`.
244+
# Please consult https://www.postgresql.org/docs/18/libpq-connect.html.
245+
if "sslmode" in kwargs:
246+
try:
247+
sslmode = SSLMode.parse(kwargs.pop("sslmode"))
248+
except AttributeError as exc:
249+
modes = ", ".join(SSLMode.modes)
250+
raise SQLAlchemyError(
251+
"`sslmode` parameter must be one of: {}".format(modes)
252+
) from exc
253+
if sslmode < SSLMode.allow:
254+
use_ssl = False
255+
else:
256+
use_ssl = True
257+
if sslmode >= SSLMode.verify_ca:
258+
kwargs["verify_ssl_cert"] = True
259+
else:
260+
kwargs["verify_ssl_cert"] = False
246261

247262
if not servers:
248263
servers = [self.dbapi.http.Client.default_server.replace("http://", "")]

src/sqlalchemy_cratedb/util.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import enum
2+
3+
from sqlalchemy.util import classproperty
4+
5+
6+
class SSLMode(enum.IntEnum):
7+
"""
8+
SSLMode class from asyncpg, with a little improvement.
9+
https://github.com/MagicStack/asyncpg/blob/v0.31.0/asyncpg/connect_utils.py#L36-L48
10+
"""
11+
12+
disable = 0
13+
allow = 1
14+
prefer = 2
15+
require = 3
16+
verify_ca = 4
17+
verify_full = 5
18+
19+
@classmethod
20+
def parse(cls, sslmode):
21+
if isinstance(sslmode, cls):
22+
return sslmode
23+
return getattr(cls, sslmode.replace("-", "_"))
24+
25+
@classproperty
26+
def modes(cls):
27+
return [m.name.replace("_", "-") for m in cls]

tests/connection_test.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import pytest
2626
import sqlalchemy as sa
27-
from sqlalchemy.exc import NoSuchModuleError
27+
from sqlalchemy.exc import NoSuchModuleError, SQLAlchemyError
2828

2929
from sqlalchemy_cratedb import SA_1_4, SA_VERSION
3030
from tests.util import ExtraAssertions
@@ -104,6 +104,15 @@ def test_connection_server_uri_https_sslmode_disabled(self):
104104
self.assertEqual(["http://otherhost:19201"], servers)
105105
engine.dispose()
106106

107+
def test_connection_server_uri_https_sslmode_invalid(self):
108+
with pytest.raises(SQLAlchemyError) as exc_info:
109+
engine = sa.create_engine("crate://otherhost:19201/?sslmode=foo")
110+
engine.raw_connection()
111+
exc_info.match(
112+
"`sslmode` parameter must be one of: "
113+
"disable, allow, prefer, require, verify-ca, verify-full"
114+
)
115+
107116
def test_connection_server_uri_invalid_port(self):
108117
with self.assertRaises(ValueError) as context:
109118
sa.create_engine("crate://foo:bar")

0 commit comments

Comments
 (0)