Skip to content

Commit 34acea6

Browse files
authored
feat: add support for sparksql dialect (acryldata#4)
1 parent 4d352fd commit 34acea6

11 files changed

Lines changed: 379 additions & 19 deletions

dev_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ pytest-random==0.2
99
pytest-timeout==1.2.0
1010

1111
# actual dependencies: let things break if a package changes
12+
sqlalchemy==1.3.24
1213
requests>=1.0.0
1314
requests_kerberos>=0.12.0
1415
sasl>=0.2.1

pyhive/sqlalchemy_hive.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from sqlalchemy import exc
1616
from sqlalchemy import processors
1717
from sqlalchemy import types
18-
from sqlalchemy import sql
1918
from sqlalchemy.sql import sqltypes
2019
from sqlalchemy import util
2120
# TODO shouldn't use mysql type
@@ -249,6 +248,8 @@ class HiveDialect(default.DefaultDialect):
249248
supports_multivalues_insert = True
250249
type_compiler = HiveTypeCompiler
251250
supports_sane_rowcount = False
251+
info_rows_delimiter = ('# Detailed Table Information', None, None)
252+
partition_columns_names = ['# Partition Information']
252253

253254
@classmethod
254255
def dbapi(cls):
@@ -316,7 +317,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
316317
rows = [row for row in rows if row[0] and row[0] != '# col_name']
317318
result = []
318319
for (col_name, full_col_type, comment) in rows:
319-
if col_name == '# Partition Information':
320+
if col_name in self.partition_columns_names:
320321
break
321322
# Take out the more detailed type information
322323
# e.g. 'map<int,int>' -> 'map'
@@ -353,7 +354,7 @@ def get_indexes(self, connection, table_name, schema=None, **kw):
353354
# Filter out empty rows and comment
354355
rows = [row for row in rows if row[0] and row[0] != '# col_name']
355356
for i, (col_name, _col_type, _comment) in enumerate(rows):
356-
if col_name == '# Partition Information':
357+
if col_name in self.partition_columns_names:
357358
break
358359
# Handle partition columns
359360
col_names = []
@@ -369,12 +370,12 @@ def get_table_names(self, connection, schema=None, **kw):
369370
if schema:
370371
query += ' IN ' + self.identifier_preparer.quote_identifier(schema)
371372
return [row[0] for row in connection.execute(query)]
372-
373+
373374
def get_table_comment(self, connection, table_name, schema=None, **kw):
374375
rows = self._get_table_columns(connection, table_name, schema, extended=True)
375376

376377
# Remove the column type specs.
377-
start_detailed_info_index = rows.index(('# Detailed Table Information', None, None))
378+
start_detailed_info_index = rows.index(self.info_rows_delimiter)
378379
assert start_detailed_info_index >= 0
379380
rows = rows[start_detailed_info_index:]
380381

@@ -396,7 +397,7 @@ def get_table_comment(self, connection, table_name, schema=None, **kw):
396397
# col_name == "", data_type is not None
397398
prop_name = "{} {}".format(active_heading, data_type.rstrip())
398399
properties[prop_name] = value.rstrip()
399-
400+
400401
return {'text': properties.get('Table Parameters: comment', None), 'properties': properties}
401402

402403
def do_rollback(self, dbapi_connection):

pyhive/sqlalchemy_sparksql.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import re
2+
3+
from . import sqlalchemy_hive
4+
5+
6+
from sqlalchemy import exc
7+
from sqlalchemy.engine import default
8+
9+
10+
class SparkSqlDialect(sqlalchemy_hive.HiveDialect):
    """SQLAlchemy dialect for Spark SQL (Spark Thrift Server).

    Reuses the Hive dialect's reflection machinery, overriding only the
    pieces where Spark's ``DESCRIBE`` / ``SHOW TABLES`` output differs
    from Hive's.
    """

    name = b'sparksql'
    execution_ctx_cls = default.DefaultExecutionContext
    # Spark's DESCRIBE FORMATTED emits empty strings (not NULLs) in the
    # "# Detailed Table Information" delimiter row, unlike Hive.
    info_rows_delimiter = ('# Detailed Table Information', '', '')
    # Spark labels the partition-column section with either heading,
    # depending on version / table provider.
    partition_columns_names = ['# Partition Information', '# Partitioning']

    def _get_table_columns(self, connection, table_name, schema, extended=False):
        """Return the raw DESCRIBE rows for *table_name*.

        :param connection: SQLAlchemy connection to execute against.
        :param table_name: unqualified table name.
        :param schema: optional schema (database) to qualify the table with.
        :param extended: if True, run ``DESCRIBE FORMATTED`` to include the
            detailed table-information section.
        :raises exc.NoSuchTableError: if Spark reports the table/view missing.
        :raises exc.UnreflectableTableError: if Spark cannot interpret the
            table's Hive type metadata.
        """
        full_table = table_name
        if schema:
            full_table = schema + '.' + table_name
        # TODO using TGetColumnsReq hangs after sending TFetchResultsReq.
        # Using DESCRIBE works but is uglier.
        try:
            # we need to set this to avoid sparksql truncating long column types (i.e. structs), arbitrarily chose 1kk
            connection.execute("SET spark.sql.debug.maxToStringFields=1000000")
            extended = " FORMATTED" if extended else ""
            rows = connection.execute('DESCRIBE{} {}'.format(extended, full_table)).fetchall()
        except exc.OperationalError as e:
            # Spark surfaces reflection failures as OperationalError; inspect
            # the Thrift error text to translate into SQLAlchemy exceptions.
            regex_fmt = r'TExecuteStatementResp.*AnalysisException.*Table or view not found:.*{}'
            hive_regex_fmt = r'org.apache.spark.SparkException: Cannot recognize hive type ' \
                             r'string'
            if re.search(regex_fmt.format(re.escape(table_name)), e.args[0]):
                raise exc.NoSuchTableError(full_table)
            elif re.search(hive_regex_fmt, e.args[0]):
                # Include the table name for consistency with NoSuchTableError
                # (the original raised the bare class with no message).
                raise exc.UnreflectableTableError(full_table)
            else:
                raise
        else:
            return rows

    def get_table_names(self, connection, schema=None, **kw):
        """Return the names of permanent (non-temporary) tables and views.

        SHOW TABLES will show tables and views, SHOW VIEWS only views; if it
        is needed we could potentially subtract the set of views from the set
        of tables to get a proper list of only tables. Since the hive dialect
        implementation does not support views extraction, get_view_names
        would need to be reimplemented too.
        """
        query = 'SHOW TABLES'
        if schema:
            query += ' IN ' + self.identifier_preparer.quote_identifier(schema)
        # Each result row is ('database', 'tableName', 'isTemporary');
        # skip temporary tables/views.
        result = connection.execute(query)
        return [row[1] for row in result if not row[-1]]

    def has_table(self, connection, table_name, schema=None):
        """Return True if the table exists and its metadata is readable.

        A table whose type metadata Spark cannot parse is treated as absent
        rather than letting UnreflectableTableError escape to callers that
        only expect a boolean.
        """
        try:
            return super().has_table(connection, table_name, schema)
        except exc.UnreflectableTableError:
            return False

pyhive/tests/sqlalchemy_test_case.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ def wrapped_fn(self, *args, **kwargs):
3434

3535

3636
class SqlAlchemyTestCase(with_metaclass(abc.ABCMeta, object)):
37+
complex_table = "one_row_complex"
38+
complex_null_table = "one_row_complex_null"
39+
3740
@with_engine_connection
3841
def test_basic_query(self, engine, connection):
3942
rows = connection.execute('SELECT * FROM one_row').fetchall()
@@ -43,7 +46,7 @@ def test_basic_query(self, engine, connection):
4346

4447
@with_engine_connection
4548
def test_one_row_complex_null(self, engine, connection):
46-
one_row_complex_null = Table('one_row_complex_null', MetaData(bind=engine), autoload=True)
49+
one_row_complex_null = Table(self.complex_null_table, MetaData(bind=engine), autoload=True)
4750
rows = one_row_complex_null.select().execute().fetchall()
4851
self.assertEqual(len(rows), 1)
4952
self.assertEqual(list(rows[0]), [None] * len(rows[0]))
@@ -62,7 +65,7 @@ def test_reflect_no_such_table(self, engine, connection):
6265
@with_engine_connection
6366
def test_reflect_include_columns(self, engine, connection):
6467
"""When passed include_columns, reflecttable should filter out other columns"""
65-
one_row_complex = Table('one_row_complex', MetaData(bind=engine))
68+
one_row_complex = Table(self.complex_table, MetaData(bind=engine))
6669
engine.dialect.reflecttable(
6770
connection, one_row_complex, include_columns=['int'],
6871
exclude_columns=[], resolve_fks=True)
@@ -123,7 +126,7 @@ def test_get_table_names(self, engine, connection):
123126
meta = MetaData()
124127
meta.reflect(bind=engine)
125128
self.assertIn('one_row', meta.tables)
126-
self.assertIn('one_row_complex', meta.tables)
129+
self.assertIn(self.complex_table, meta.tables)
127130

128131
insp = sqlalchemy.inspect(engine)
129132
self.assertIn(
@@ -138,7 +141,7 @@ def test_has_table(self, engine, connection):
138141

139142
@with_engine_connection
140143
def test_char_length(self, engine, connection):
141-
one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True)
144+
one_row_complex = Table(self.complex_table, MetaData(bind=engine), autoload=True)
142145
result = sqlalchemy.select([
143146
sqlalchemy.func.char_length(one_row_complex.c.string)
144147
]).execute().scalar()

0 commit comments

Comments
 (0)