99from __future__ import unicode_literals
1010
1111import re
12+ import sqlalchemy
1213from sqlalchemy import exc
1314from sqlalchemy import types
1415from sqlalchemy import util
1516# TODO shouldn't use mysql type
16- from sqlalchemy .databases import mysql
17+ from sqlalchemy .sql import text
18+ try :
19+ from sqlalchemy .databases import mysql
20+ mysql_tinyinteger = mysql .MSTinyInteger
21+ except ImportError :
22+ # Required for SQLAlchemy>=2.0
23+ from sqlalchemy .dialects import mysql
24+ mysql_tinyinteger = mysql .base .MSTinyInteger
1725from sqlalchemy .engine import default
1826from sqlalchemy .sql import compiler
1927from sqlalchemy .sql .compiler import SQLCompiler
2028
2129from pyhive import presto
2230from pyhive .common import UniversalSet
2331
32+ sqlalchemy_version = float (re .search (r"^([\d]+\.[\d]+)\..+" , sqlalchemy .__version__ ).group (1 ))
2433
2534class PrestoIdentifierPreparer (compiler .IdentifierPreparer ):
2635 # Just quote everything to make things simpler / easier to upgrade
@@ -29,7 +38,7 @@ class PrestoIdentifierPreparer(compiler.IdentifierPreparer):
2938
3039_type_map = {
3140 'boolean' : types .Boolean ,
32- 'tinyint' : mysql . MSTinyInteger ,
41+ 'tinyint' : mysql_tinyinteger ,
3342 'smallint' : types .SmallInteger ,
3443 'integer' : types .Integer ,
3544 'bigint' : types .BigInteger ,
@@ -80,6 +89,7 @@ class PrestoDialect(default.DefaultDialect):
8089 supports_multivalues_insert = True
8190 supports_unicode_statements = True
8291 supports_unicode_binds = True
92+ supports_statement_cache = False
8393 returns_unicode_strings = True
8494 description_encoding = None
8595 supports_native_boolean = True
@@ -88,6 +98,10 @@ class PrestoDialect(default.DefaultDialect):
8898 @classmethod
8999 def dbapi (cls ):
90100 return presto
101+
102+ @classmethod
103+ def import_dbapi (cls ):
104+ return presto
91105
92106 def create_connect_args (self , url ):
93107 db_parts = (url .database or 'hive' ).split ('/' )
@@ -108,14 +122,14 @@ def create_connect_args(self, url):
108122 return [], kwargs
109123
110124 def get_schema_names (self , connection , ** kw ):
111- return [row .Schema for row in connection .execute ('SHOW SCHEMAS' )]
125+ return [row .Schema for row in connection .execute (text ( 'SHOW SCHEMAS' ) )]
112126
113127 def _get_table_columns (self , connection , table_name , schema ):
114128 full_table = self .identifier_preparer .quote_identifier (table_name )
115129 if schema :
116130 full_table = self .identifier_preparer .quote_identifier (schema ) + '.' + full_table
117131 try :
118- return connection .execute ('SHOW COLUMNS FROM {}' .format (full_table ))
132+ return connection .execute (text ( 'SHOW COLUMNS FROM {}' .format (full_table ) ))
119133 except (presto .DatabaseError , exc .DatabaseError ) as e :
120134 # Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which
121135 # it successfully does in the Hive version. The difference with Presto is that this
@@ -134,7 +148,7 @@ def _get_table_columns(self, connection, table_name, schema):
134148 else :
135149 raise
136150
137- def has_table (self , connection , table_name , schema = None ):
151+ def has_table (self , connection , table_name , schema = None , ** kw ):
138152 try :
139153 self ._get_table_columns (connection , table_name , schema )
140154 return True
@@ -176,6 +190,8 @@ def get_indexes(self, connection, table_name, schema=None, **kw):
176190 # - a boolean column named "Partition Key"
177191 # - a string in the "Comment" column
178192 # - a string in the "Extra" column
193+ if sqlalchemy_version >= 1.4 :
194+ row = row ._mapping
179195 is_partition_key = (
180196 (part_key in row and row [part_key ])
181197 or row ['Comment' ].startswith (part_key )
@@ -192,7 +208,7 @@ def get_table_names(self, connection, schema=None, **kw):
192208 query = 'SHOW TABLES'
193209 if schema :
194210 query += ' FROM ' + self .identifier_preparer .quote_identifier (schema )
195- return [row .Table for row in connection .execute (query )]
211+ return [row .Table for row in connection .execute (text ( query ) )]
196212
197213 def do_rollback (self , dbapi_connection ):
198214 # No transactions for Presto
0 commit comments