-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathdialect.py
More file actions
446 lines (376 loc) · 15.8 KB
/
dialect.py
File metadata and controls
446 lines (376 loc) · 15.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
# -*- coding: utf-8; -*-
#
# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
# license agreements. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership. Crate licenses
# this file to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may
# obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# However, if you have executed another commercial license agreement
# with Crate these terms will supersede the license and you may use the
# software solely pursuant to the terms of the relevant commercial agreement.
import logging
import warnings
from datetime import date, datetime, timezone

from sqlalchemy import types as sqltypes
from sqlalchemy.engine import default, reflection
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.sql import functions
from sqlalchemy.util import asbool, to_list

from .compiler import (
    CrateDDLCompiler,
    CrateIdentifierPreparer,
    CrateTypeCompiler,
)
from .sa_version import SA_1_4, SA_2_0, SA_VERSION
from .type import FloatVector, ObjectArray, ObjectType
from .util import SSLMode
# Map CrateDB type names, as reported by `information_schema`, onto
# SQLAlchemy types for schema reflection.
TYPES_MAP = {
    "boolean": sqltypes.Boolean,
    "short": sqltypes.SmallInteger,
    "smallint": sqltypes.SmallInteger,
    "timestamp": sqltypes.TIMESTAMP(timezone=False),
    "timestamp with time zone": sqltypes.TIMESTAMP(timezone=True),
    "object": ObjectType,
    "integer": sqltypes.Integer,
    "long": sqltypes.NUMERIC,
    "bigint": sqltypes.NUMERIC,
    "double": sqltypes.DECIMAL,
    "double precision": sqltypes.DECIMAL,
    "object_array": ObjectArray,
    "float": sqltypes.Float,
    "real": sqltypes.Float,
    "string": sqltypes.String,
    "text": sqltypes.String,
    "float_vector": FloatVector,
}

# Needed for SQLAlchemy >= 1.1.
# TODO: Dissolve.
try:
    from sqlalchemy.types import ARRAY

    # Register an `ARRAY` variant for every scalar type name that
    # CrateDB can also report as an array column.
    for _base_name, _item_type in (
        ("integer", sqltypes.Integer),
        ("boolean", sqltypes.Boolean),
        ("short", sqltypes.SmallInteger),
        ("smallint", sqltypes.SmallInteger),
        ("timestamp", sqltypes.TIMESTAMP(timezone=False)),
        ("timestamp with time zone", sqltypes.TIMESTAMP(timezone=True)),
        ("long", sqltypes.NUMERIC),
        ("bigint", sqltypes.NUMERIC),
        ("double", sqltypes.DECIMAL),
        ("double precision", sqltypes.DECIMAL),
        ("float", sqltypes.Float),
        ("real", sqltypes.Float),
        ("string", sqltypes.String),
        ("text", sqltypes.String),
    ):
        TYPES_MAP[_base_name + "_array"] = ARRAY(_item_type)
except Exception:  # noqa: S110
    pass

log = logging.getLogger(__name__)
class Date(sqltypes.Date):
    """
    Date type for CrateDB.

    CrateDB stores dates as epoch-millisecond timestamps, so values are
    serialized to ISO date strings on the way in and converted back from
    epoch milliseconds on the way out.
    """

    def bind_processor(self, dialect):
        def process(value):
            if value is not None:
                assert isinstance(value, date)  # noqa: S101
                return value.strftime("%Y-%m-%d")
            return None

        return process

    def result_processor(self, dialect, coltype):
        def process(value):
            # `0`/empty values are treated as NULL.
            if not value:
                return None
            try:
                # CrateDB returns epoch milliseconds; convert to a UTC date.
                # `datetime.utcfromtimestamp` is deprecated since Python 3.12,
                # so use the timezone-aware variant instead; taking `.date()`
                # yields the same result as before.
                return datetime.fromtimestamp(value / 1e3, tz=timezone.utc).date()
            except TypeError:
                pass

            # Crate doesn't really have datetime or date types but a
            # timestamp type. The "date" mapping (conversion to long)
            # is only applied if the schema definition for the column exists
            # and if the sql insert statement was used.
            # In case of dynamic mapping or using the rest indexing endpoint
            # the date will be returned in the format it was inserted.
            log.warning(
                "Received timestamp isn't a long value. "
                "Trying to parse as date string and then as datetime string"
            )
            try:
                return datetime.strptime(value, "%Y-%m-%d").date()
            except ValueError:
                return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%fZ").date()

        return process
class DateTime(sqltypes.DateTime):
    """
    DateTime type for CrateDB.

    CrateDB stores timestamps as epoch milliseconds, so values are
    serialized to ISO strings on the way in and converted back from
    epoch milliseconds on the way out.
    """

    def bind_processor(self, dialect):
        def process(value):
            # Serialize datetime/date objects; pass anything else
            # (e.g. pre-formatted strings) through unchanged.
            if isinstance(value, (datetime, date)):
                return value.strftime("%Y-%m-%dT%H:%M:%S.%f%z")
            return value

        return process

    def result_processor(self, dialect, coltype):
        def process(value):
            # `0`/empty values are treated as NULL.
            if not value:
                return None
            try:
                # CrateDB returns epoch milliseconds; convert to a naive
                # UTC datetime. `datetime.utcfromtimestamp` is deprecated
                # since Python 3.12, so use the timezone-aware variant and
                # strip the tzinfo to keep the previous (naive) behavior.
                return datetime.fromtimestamp(
                    value / 1e3, tz=timezone.utc
                ).replace(tzinfo=None)
            except TypeError:
                pass

            # Crate doesn't really have datetime or date types but a
            # timestamp type. The "date" mapping (conversion to long)
            # is only applied if the schema definition for the column exists
            # and if the sql insert statement was used.
            # In case of dynamic mapping or using the rest indexing endpoint
            # the date will be returned in the format it was inserted.
            log.warning(
                "Received timestamp isn't a long value. "
                "Trying to parse as datetime string and then as date string"
            )
            try:
                return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%fZ")
            except ValueError:
                return datetime.strptime(value, "%Y-%m-%d")

        return process
# Map generic SQLAlchemy date/time types onto the CrateDB-aware variants
# defined above, so bind parameters and result rows pass through the
# epoch-millisecond converters.
colspecs = {
    sqltypes.Date: Date,
    sqltypes.DateTime: DateTime,
    # TIMESTAMP columns share the DateTime converter.
    sqltypes.TIMESTAMP: DateTime,
}
# Pick the statement compiler matching the installed SQLAlchemy
# generation; each `compat` module adapts the CrateDB compiler to the
# corresponding SQLAlchemy core API (2.0 / 1.4 / 1.x).
if SA_VERSION >= SA_2_0:
    from .compat.core20 import CrateCompilerSA20
    statement_compiler = CrateCompilerSA20
elif SA_VERSION >= SA_1_4:
    from .compat.core14 import CrateCompilerSA14
    statement_compiler = CrateCompilerSA14
else:
    from .compat.core10 import CrateCompilerSA10
    statement_compiler = CrateCompilerSA10
class CrateDialect(default.DefaultDialect):
    """
    SQLAlchemy dialect for CrateDB.

    Wires the CrateDB-specific compilers, type mappings, and
    `information_schema`-based reflection queries into SQLAlchemy,
    on top of the `crate-python` DBAPI driver.
    """

    name = "crate"
    driver = "crate-python"
    default_paramstyle = "qmark"
    statement_compiler = statement_compiler
    ddl_compiler = CrateDDLCompiler
    type_compiler = CrateTypeCompiler
    preparer = CrateIdentifierPreparer
    use_insertmanyvalues = True
    use_insertmanyvalues_wo_returning = True
    supports_multivalues_insert = True
    supports_native_boolean = True
    supports_statement_cache = True
    colspecs = colspecs
    implicit_returning = True
    insert_returning = True
    update_returning = True

    def __init__(self, **kwargs):
        default.DefaultDialect.__init__(self, **kwargs)

        # CrateDB does not need `OBJECT` types to be serialized as JSON.
        # Corresponding data is forwarded 1:1, and will get marshalled
        # by the low-level driver.
        self._json_deserializer = lambda x: x
        self._json_serializer = lambda x: x

        # Currently, our SQL parser doesn't support unquoted column names that
        # start with _. Adding it here causes sqlalchemy to quote such columns.
        self.identifier_preparer.illegal_initial_characters.add("_")

    def initialize(self, connection):
        """Cache server version and default schema on first connect."""
        # Get lowest server version among the cluster nodes.
        self.server_version_info = self._get_server_version_info(connection)
        # Get default schema name.
        self.default_schema_name = self._get_default_schema_name(connection)

    def do_rollback(self, connection):
        # If any exception is raised by the DBAPI, SQLAlchemy by default
        # attempts to do a rollback, but CrateDB doesn't support rollbacks.
        # Implementing this as a no-op seems to cause SQLAlchemy to propagate
        # the original exception to the user.
        pass

    def connect(self, host=None, port=None, *args, **kwargs):
        """
        Establish a DBAPI connection.

        Resolves the server address list from `host`/`port` or the
        `servers` keyword argument, and translates the SSL options
        (legacy `ssl`, new `sslmode`) into driver settings.
        """
        server = None
        if host:
            server = "{0}:{1}".format(host, port or "4200")
        if "servers" in kwargs:
            server = kwargs.pop("servers")
        servers = to_list(server)

        # Default to plain HTTP unless an SSL option says otherwise.
        # NOTE: Without this initialization, `use_ssl` would be unbound
        # below when neither `ssl` nor `sslmode` is supplied.
        use_ssl = False

        # Process legacy SSL option `ssl`.
        if "ssl" in kwargs:
            warnings.warn(
                "The `ssl=true` option will be deprecated, "
                "please use `sslmode=require` going forward.",
                DeprecationWarning,
                stacklevel=2,
            )
            use_ssl = asbool(kwargs.pop("ssl", False))

        # Process new SSL option `sslmode`.
        # Please consult https://www.postgresql.org/docs/18/libpq-connect.html.
        if "sslmode" in kwargs:
            try:
                sslmode = SSLMode.parse(kwargs.pop("sslmode"))
            except AttributeError as exc:
                modes = ", ".join(SSLMode.modes)
                raise SQLAlchemyError(
                    "`sslmode` parameter must be one of: {}".format(modes)
                ) from exc
            # `disable` connects without SSL; everything from `allow`
            # upward turns it on.
            if sslmode < SSLMode.allow:
                use_ssl = False
            else:
                use_ssl = True
            # Only `verify-ca` and stricter modes validate certificates.
            if sslmode >= SSLMode.verify_ca:
                kwargs["verify_ssl_cert"] = True
            else:
                kwargs["verify_ssl_cert"] = False

        if not servers:
            servers = [self.dbapi.http.Client.default_server.replace("http://", "")]
        if use_ssl:
            servers = ["https://" + server for server in servers]
        return self.dbapi.connect(servers=servers, **kwargs)

    def do_execute(self, cursor, statement, parameters, context=None):
        """
        Slightly amended to store its response into the request context instance.
        """
        result = cursor.execute(statement, parameters)
        if context is not None:
            context.last_result = result

    def do_execute_no_params(self, cursor, statement, context=None):
        """
        Slightly amended to store its response into the request context instance.
        """
        result = cursor.execute(statement)
        if context is not None:
            context.last_result = result

    def do_executemany(self, cursor, statement, parameters, context=None):
        """
        Slightly amended to store its response into the request context instance.
        """
        result = cursor.executemany(statement, parameters)
        if context is not None:
            context.last_result = result

    def _get_default_schema_name(self, connection):
        # CrateDB's default schema is always `doc`.
        return "doc"

    def _get_effective_schema_name(self, connection):
        # The schema may be passed as a `?schema=...` URL query parameter;
        # depending on the SQLAlchemy version it arrives as str or tuple.
        schema_name_raw = connection.engine.url.query.get("schema")
        schema_name = None
        if isinstance(schema_name_raw, str):
            schema_name = schema_name_raw
        elif isinstance(schema_name_raw, tuple):
            schema_name = schema_name_raw[0]
        return schema_name

    def _get_server_version_info(self, connection):
        # The driver exposes the lowest version among all cluster nodes,
        # keeping feature switches safe on mixed-version clusters.
        return tuple(connection.connection.lowest_server_version.version)

    @classmethod
    def import_dbapi(cls):
        """Return the DBAPI module (SQLAlchemy >= 2.0 entry point)."""
        from crate import client

        return client

    @classmethod
    def dbapi(cls):
        # Legacy SQLAlchemy < 2.0 entry point; delegates to `import_dbapi`.
        return cls.import_dbapi()

    def has_schema(self, connection, schema, **kw):
        return schema in self.get_schema_names(connection, **kw)

    def has_table(self, connection, table_name, schema=None, **kw):
        return table_name in self.get_table_names(connection, schema=schema, **kw)

    @reflection.cache
    def get_schema_names(self, connection, **kw):
        cursor = connection.exec_driver_sql(
            "select schema_name from information_schema.schemata order by schema_name asc"
        )
        return [row[0] for row in cursor.fetchall()]

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        if schema is None:
            schema = self._get_effective_schema_name(connection)
        cursor = connection.exec_driver_sql(
            "SELECT table_name FROM information_schema.tables "
            "WHERE {0} = ? "
            "AND table_type = 'BASE TABLE' "
            "ORDER BY table_name ASC, {0} ASC".format(self.schema_column),
            (schema or self.default_schema_name,),
        )
        return [row[0] for row in cursor.fetchall()]

    @reflection.cache
    def get_view_names(self, connection, schema=None, **kw):
        # BUGFIX: The schema argument was previously passed as a bind
        # parameter without a corresponding `?` placeholder in the SQL,
        # so views were never filtered by schema. Filter like
        # `get_table_names` does.
        cursor = connection.exec_driver_sql(
            "SELECT table_name FROM information_schema.views "
            "WHERE {0} = ? "
            "ORDER BY table_name ASC, {0} ASC".format(self.schema_column),
            (schema or self.default_schema_name,),
        )
        return [row[0] for row in cursor.fetchall()]

    @reflection.cache
    def get_columns(self, connection, table_name, schema=None, **kw):
        query = (
            "SELECT column_name, data_type "
            "FROM information_schema.columns "
            "WHERE table_name = ? AND {0} = ? "
            "AND column_name !~ ?".format(self.schema_column)
        )
        cursor = connection.exec_driver_sql(
            query,
            (
                table_name,
                schema or self.default_schema_name,
                r"(.*)\[\'(.*)\'\]",
            ),  # regex to filter subscript
        )
        return [self._create_column_info(row) for row in cursor.fetchall()]

    @reflection.cache
    def get_pk_constraint(self, engine, table_name, schema=None, **kw):
        # The location of primary-key information moved across CrateDB
        # versions, hence the three query variants below.
        if self.server_version_info >= (3, 0, 0):
            query = """SELECT column_name
                    FROM information_schema.key_column_usage
                    WHERE table_name = ? AND table_schema = ?"""

            def result_fun(result):
                rows = result.fetchall()
                return set(map(lambda el: el[0], rows))

        elif self.server_version_info >= (2, 3, 0):
            query = """SELECT column_name
                    FROM information_schema.key_column_usage
                    WHERE table_name = ? AND table_catalog = ?"""

            def result_fun(result):
                rows = result.fetchall()
                return set(map(lambda el: el[0], rows))

        else:
            query = """SELECT constraint_name
                    FROM information_schema.table_constraints
                    WHERE table_name = ? AND {schema_col} = ?
                    AND constraint_type='PRIMARY_KEY'
                    """.format(schema_col=self.schema_column)

            def result_fun(result):
                # NOTE(review): presumably `constraint_name` holds the list
                # of PK columns on these old servers, so `set(rows[0])`
                # yields column names — confirm against a < 2.3 cluster.
                rows = result.fetchone()
                return set(rows[0] if rows else [])

        pk_result = engine.exec_driver_sql(query, (table_name, schema or self.default_schema_name))
        pks = result_fun(pk_result)
        return {"constrained_columns": sorted(pks), "name": "PRIMARY KEY"}

    @reflection.cache
    def get_foreign_keys(
        self, connection, table_name, schema=None, postgresql_ignore_search_path=False, **kw
    ):
        # Crate doesn't support Foreign Keys, so this stays empty
        return []

    @reflection.cache
    def get_indexes(self, connection, table_name, schema, **kw):
        # Index reflection is not supported.
        return []

    @property
    def schema_column(self):
        # Name of the schema column in `information_schema` tables.
        return "table_schema"

    def _create_column_info(self, row):
        """Build a SQLAlchemy column-info dict from a (name, data_type) row."""
        return {
            "name": row[0],
            "type": self._resolve_type(row[1]),
            # In Crate every column is nullable except PK
            # Primary Key Constraints are not nullable anyway, no matter what
            # we return here, so it's fine to return always `True`
            "nullable": True,
        }

    def _resolve_type(self, type_):
        # Fall back to a generic user-defined type for unknown type names.
        return TYPES_MAP.get(type_, sqltypes.UserDefinedType)

    def has_ilike_operator(self):
        """
        Only CrateDB 4.1.0 and higher implements the `ILIKE` operator.
        """
        server_version_info = self.server_version_info
        return server_version_info is not None and server_version_info >= (4, 1, 0)
class DateTrunc(functions.GenericFunction):
    """Register CrateDB's `date_trunc` SQL function with a TIMESTAMP return type."""

    name = "date_trunc"
    type = sqltypes.TIMESTAMP

# Entry point consumed by SQLAlchemy's dialect registry.
dialect = CrateDialect