Skip to content
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*/__pycache__/*
presto_python_client.egg-info/*
prestodb/sqlalchemy/__pycache__/*
25 changes: 25 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,33 @@ The transaction is created when the first SQL statement is executed.
exits the *with* context and the queries succeed, otherwise
`prestodb.dbapi.Connection.rollback()` will be called.


# SQLAlchemy Support

The client also provides a SQLAlchemy dialect.

## Installation

```
$ pip install presto-python-client[sqlalchemy]
```

## Usage

To connect to Presto using SQLAlchemy:

```python
from sqlalchemy import create_engine, text

engine = create_engine('presto://user:password@host:port/catalog/schema')
connection = engine.connect()

rows = connection.execute(text("SELECT * FROM system.runtime.nodes")).fetchall()
```

# Running Tests


There is a helper script, `run`, that provides commands to run tests.
Type `./run tests` to run both unit and integration tests.

Expand Down
1 change: 1 addition & 0 deletions prestodb/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

# DB-API 2.0 (PEP 249) module-level globals.
apilevel = "2.0"
# PEP 249 level 2: threads may share the module and connections.
threadsafety = 2
# PEP 249 "pyformat" marker style: named placeholders written as %(name)s.
paramstyle = "pyformat"

logger = logging.getLogger(__name__)

Expand Down
11 changes: 11 additions & 0 deletions prestodb/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
198 changes: 198 additions & 0 deletions prestodb/sqlalchemy/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Attribution:
# This code is adapted from the trino-python-client project (Apache 2.0 License).
# https://github.com/trinodb/trino-python-client/blob/master/trino/sqlalchemy/dialect.py

import re

from sqlalchemy import types, util
from sqlalchemy.engine import default
from sqlalchemy.sql import sqltypes

from prestodb import auth, dbapi
from prestodb.sqlalchemy import compiler, datatype

# Lookup table used when reflecting columns: lower-case Presto type name ->
# SQLAlchemy type class from ``prestodb.sqlalchemy.datatype``.

# Standard types
_standard_types = {
    "boolean": datatype.BOOLEAN,
    "tinyint": datatype.TINYINT,
    "smallint": datatype.SMALLINT,
    "integer": datatype.INTEGER,
    "bigint": datatype.BIGINT,
    "real": datatype.REAL,
    "double": datatype.DOUBLE,
    "decimal": datatype.DECIMAL,
    "varchar": datatype.VARCHAR,
    "char": datatype.CHAR,
    "varbinary": datatype.VARBINARY,
    "json": datatype.JSON,
    "date": datatype.DATE,
    "time": datatype.TIME,
    # TODO: time with time zone
    "time with time zone": datatype.TIME,
    "timestamp": datatype.TIMESTAMP,
    # TODO: timestamp with time zone
    "timestamp with time zone": datatype.TIMESTAMP,
    "interval year to month": datatype.INTERVAL,
    "interval day to second": datatype.INTERVAL,
}

# Specific types
_specific_types = {
    "array": datatype.ARRAY,
    "map": datatype.MAP,
    "row": datatype.ROW,
    "hyperloglog": datatype.HYPERLOGLOG,
    "p4hyperloglog": datatype.P4HYPERLOGLOG,
    "qdigest": datatype.QDIGEST,
}

# Merge in the original order: standard types first, then specific types.
_type_map = dict(_standard_types)
_type_map.update(_specific_types)


class PrestoDialect(default.DefaultDialect):
    """SQLAlchemy dialect for Presto, backed by ``prestodb.dbapi``.

    Connection URLs have the form
    ``presto://user:password@host:port/catalog/schema``; see
    :meth:`create_connect_args` for how each part is translated.
    """

    name = "presto"
    driver = "presto"
    author = "Presto Team"
    supports_alter = False
    supports_pk_on_update = False
    supports_full_outer_join = True
    supports_simple_order_by_label = False
    supports_sane_rowcount = False
    supports_sane_multi_rowcount = False
    supports_native_boolean = True

    statement_compiler = compiler.PrestoSQLCompiler
    type_compiler = compiler.PrestoTypeCompiler
    preparer = compiler.PrestoIdentifierPreparer

    # A (possibly multi-word) type name, optionally followed by "(args)".
    # Group 1 is the name, group 2 the raw argument text. Applied to an
    # already lower-cased string.
    _TYPE_PATTERN = re.compile(
        r"^([a-z0-9_]+(?: +[a-z0-9_]+)*)\s*(?:\((.+)\))?$"
    )

    def create_connect_args(self, url):
        """Translate a SQLAlchemy URL into arguments for ``dbapi.connect``.

        The URL's database part is read as ``catalog`` or ``catalog/schema``
        and defaults to the ``system`` catalog when absent. Supplying a
        password switches the connection to HTTPS with basic authentication.

        :param url: the SQLAlchemy URL object for the engine.
        :returns: an ``(args, kwargs)`` tuple; everything is passed by keyword.
        :raises ValueError: if the database part has more than two
            ``/``-separated components.
        """
        kwargs = {"host": url.host}
        if url.port:
            kwargs["port"] = url.port
        if url.username:
            kwargs["user"] = url.username
        if url.password:
            kwargs["http_scheme"] = "https"
            kwargs["auth"] = auth.BasicAuthentication(url.username, url.password)

        db_parts = (url.database or "system").split("/")
        if len(db_parts) > 2:
            raise ValueError("Unexpected database format: {}".format(url.database))
        kwargs["catalog"] = db_parts[0]
        if len(db_parts) == 2:
            kwargs["schema"] = db_parts[1]

        # SQLAlchemy calls ``dbapi.connect(*args, **kwargs)``. The settings
        # must go in the kwargs slot; putting the dict in the args list would
        # hand the whole dict to ``connect`` as its first positional argument.
        return ([], kwargs)

    @classmethod
    def import_dbapi(cls):
        """Return the DBAPI module used by this dialect."""
        return dbapi

    @staticmethod
    def _execute(connection, statement):
        """Run a raw SQL string on *connection* and return its result.

        Prefers ``Connection.exec_driver_sql`` when the connection offers it,
        because SQLAlchemy 2.0 no longer accepts plain strings in
        ``Connection.execute``. Falls back to ``execute`` otherwise.
        """
        run = getattr(connection, "exec_driver_sql", None)
        if run is None:
            run = connection.execute
        return run(statement)

    @staticmethod
    def _quote_literal(value):
        """Render *value* as a single-quoted SQL string literal."""
        return "'{}'".format(str(value).replace("'", "''"))

    @staticmethod
    def _quote_identifier(name):
        """Render *name* as one double-quoted SQL identifier."""
        return '"{}"'.format(str(name).replace('"', '""'))

    @classmethod
    def _quote_schema(cls, schema):
        """Quote a schema reference, treating dots as path separators.

        Splitting on ``.`` keeps a ``catalog.schema`` style argument working
        the way it did when the value was interpolated unquoted.
        """
        return ".".join(cls._quote_identifier(p) for p in str(schema).split("."))

    def has_table(self, connection, table_name, schema=None, **kw):
        """Return True if *table_name* exists in *schema*.

        Extra keyword arguments supplied by SQLAlchemy are accepted and
        ignored.
        """
        return self._has_object(connection, "TABLE", table_name, schema)

    def has_sequence(self, connection, sequence_name, schema=None, **kw):
        """Always return False; this dialect does not look up sequences."""
        return False

    def _has_object(self, connection, object_type, object_name, schema=None):
        """Check ``information_schema.tables`` for *object_name*.

        :param object_type: currently unused; kept for interface stability.
        :param schema: schema to search; defaults to the dialect's
            ``default_schema_name``.
        :returns: True when at least one matching row exists, False when
            none exists or no schema can be determined.
        """
        if schema is None:
            schema = self.default_schema_name
        if schema is None:
            # Without a schema there is nothing meaningful to match against.
            return False

        query = (
            "SELECT count(*) FROM information_schema.tables "
            "WHERE table_schema = {} AND table_name = {}".format(
                self._quote_literal(schema), self._quote_literal(object_name)
            )
        )
        return self._execute(connection, query).scalar() > 0

    def get_schema_names(self, connection, **kw):
        """Return the first column of every ``SHOW SCHEMAS`` row."""
        return [row[0] for row in self._execute(connection, "SHOW SCHEMAS")]

    def get_table_names(self, connection, schema=None, **kw):
        """Return the table names reported by ``SHOW TABLES FROM <schema>``.

        :raises ValueError: if neither *schema* nor a default schema is set.
        """
        schema = schema or self.default_schema_name
        if schema is None:
            raise ValueError("schema argument is required")
        query = "SHOW TABLES FROM {}".format(self._quote_schema(schema))
        return [row[0] for row in self._execute(connection, query)]

    def get_columns(self, connection, table_name, schema=None, **kw):
        """Reflect the columns of *table_name* via ``SHOW COLUMNS``.

        :returns: a list of dicts with ``name``, ``type``, ``nullable`` and
            ``default`` keys.
        :raises ValueError: if neither *schema* nor a default schema is set.
        """
        schema = schema or self.default_schema_name
        if schema is None:
            raise ValueError("schema argument is required")
        query = "SHOW COLUMNS FROM {}.{}".format(
            self._quote_schema(schema), self._quote_identifier(table_name)
        )

        # Rows are (Column, Type, Extra, Comment); only the first two are used.
        return [
            {
                "name": row[0],
                "type": self._parse_type(row[1]),
                "nullable": True,  # TODO: check nullability
                "default": None,
            }
            for row in self._execute(connection, query)
        ]

    def _parse_type(self, type_str):
        """Convert a Presto type string into a SQLAlchemy type instance.

        The name is matched case-insensitively and may contain spaces (for
        example ``timestamp with time zone``). Unknown or unparseable types
        yield ``NullType``.
        """
        match = self._TYPE_PATTERN.match(type_str.strip().lower())
        if not match:
            return sqltypes.NullType()

        type_name, type_args = match.group(1), match.group(2)
        type_class = _type_map.get(type_name)
        if type_class is None:
            return sqltypes.NullType()
        if type_args:
            return type_class(*self._parse_type_args(type_args))
        return type_class()

    def _parse_type_args(self, type_args):
        """Split a comma-separated argument string into a list.

        Purely numeric arguments become ``int``; anything else is kept as a
        string. Surrounding whitespace is removed first, so ``"10, 2"`` gives
        ``[10, 2]``.
        """
        # TODO: improve parsing for nested types
        parsed = []
        for raw in type_args.split(","):
            arg = raw.strip()
            parsed.append(int(arg) if arg.isdigit() else arg)
        return parsed

    def do_rollback(self, dbapi_connection):
        """Deliberately do nothing when SQLAlchemy requests a rollback."""
        # Presto transactions usually auto-commit or are read-only
        pass

    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        """Return no foreign keys."""
        # Presto doesn't enforce foreign keys
        return []

    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        """Return an empty primary-key constraint."""
        # Presto doesn't enforce primary keys
        return {"constrained_columns": [], "name": None}

    def get_indexes(self, connection, table_name, schema=None, **kw):
        """Return no indexes."""
        # TODO: Implement index reflection
        return []

    def do_ping(self, dbapi_connection):
        """Return True if ``SELECT 1`` succeeds on *dbapi_connection*.

        Any exception from opening the cursor or running the query is
        reported as False. The cursor is closed in both cases.
        """
        cursor = None
        try:
            cursor = dbapi_connection.cursor()
            cursor.execute("SELECT 1")
            cursor.fetchone()
        except Exception:
            return False
        else:
            return True
        finally:
            if cursor is not None:
                cursor.close()
Loading