Skip to content
This repository was archived by the owner on Mar 16, 2026. It is now read-only.

Commit 26dffb4

Browse files
committed
Merge remote-tracking branch 'upstream/master' into feature/add-flake8-code-check
2 parents 6665817 + 72fbe2b commit 26dffb4

File tree

4 files changed

+185
-30
lines changed

4 files changed

+185
-30
lines changed

README.rst

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,14 @@ To specify location of your datasets pass ``location`` to ``create_engine()``:
5454
Table names
5555
___________
5656

57-
To query tables from non-default projects, use the following format for the table name: ``project.dataset.table``, e.g.:
57+
To query tables from non-default projects or datasets, use the following format for the SQLAlchemy schema name: ``[project.]dataset``, e.g.:
5858

5959
.. code-block:: python
6060
61-
sample_table = Table('bigquery-public-data.samples.natality')
61+
# If neither dataset nor project are the default
62+
sample_table_1 = Table('natality', schema='bigquery-public-data.samples')
63+
# If just dataset is not the default
64+
sample_table_2 = Table('natality', schema='bigquery-public-data')
6265
6366
Batch size
6467
__________
@@ -85,7 +88,7 @@ When using a default dataset, don't include the dataset name in the table name,
8588
8689
table = Table('table_name')
8790
88-
Note that specyfing a default dataset doesn't restrict execution of queries to that particular dataset when using raw queries, e.g.:
91+
Note that specifying a default dataset doesn't restrict execution of queries to that particular dataset when using raw queries, e.g.:
8992

9093
.. code-block:: python
9194

pybigquery/sqlalchemy_bigquery.py

Lines changed: 83 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from __future__ import absolute_import
44
from __future__ import unicode_literals
55

6+
import operator
7+
68
from google import auth
79
from google.cloud import bigquery
810
from google.cloud.bigquery import dbapi
@@ -184,12 +186,19 @@ def visit_column(self, column, add_to_result_map=None,
184186
self.preparer.quote(tablename) + \
185187
"." + name
186188

187-
def visit_label(self, *args, **kwargs):
188-
# Use labels in GROUP BY clause
189-
if len(kwargs) == 0 or len(kwargs) == 1:
189+
def visit_label(self, *args, within_group_by=False, **kwargs):
190+
# Use labels in GROUP BY clause.
191+
#
192+
# Flag set in the group_by_clause method. Works around missing
193+
# equivalent to supports_simple_order_by_label for group by.
194+
if within_group_by:
190195
kwargs['render_label_as_label'] = args[0]
191-
result = super(BigQueryCompiler, self).visit_label(*args, **kwargs)
192-
return result
196+
return super(BigQueryCompiler, self).visit_label(*args, **kwargs)
197+
198+
def group_by_clause(self, select, **kw):
199+
return super(BigQueryCompiler, self).group_by_clause(
200+
select, **kw, within_group_by=True
201+
)
193202

194203

195204
class BigQueryTypeCompiler(GenericTypeCompiler):
@@ -206,6 +215,9 @@ def visit_text(self, type_, **kw):
206215
def visit_string(self, type_, **kw):
207216
return 'STRING'
208217

218+
def visit_ARRAY(self, type_, **kw):
219+
return "ARRAY<{}>".format(self.process(type_.item_type, **kw))
220+
209221
def visit_BINARY(self, type_, **kw):
210222
return 'BYTES'
211223

@@ -284,6 +296,11 @@ def __init__(
284296
def dbapi(cls):
285297
return dbapi
286298

299+
@staticmethod
300+
def _build_formatted_table_id(table):
301+
"""Build '<dataset_id>.<table_id>' string using given table."""
302+
return "{}.{}".format(table.reference.dataset_id, table.table_id)
303+
287304
@staticmethod
288305
def _add_default_dataset_to_job_config(job_config, project_id, dataset_id):
289306
# If dataset_id is set, then we know the job_config isn't None
@@ -349,6 +366,26 @@ def _json_deserializer(self, row):
349366
"""
350367
return row
351368

369+
def _get_table_or_view_names(self, connection, table_type, schema=None):
370+
current_schema = schema or self.dataset_id
371+
get_table_name = self._build_formatted_table_id \
372+
if self.dataset_id is None else \
373+
operator.attrgetter("table_id")
374+
375+
client = connection.connection._client
376+
datasets = client.list_datasets()
377+
378+
result = []
379+
for dataset in datasets:
380+
if current_schema is not None and current_schema != dataset.dataset_id:
381+
continue
382+
383+
tables = client.list_tables(dataset.reference)
384+
for table in tables:
385+
if table_type == table.table_type:
386+
result.append(get_table_name(table))
387+
return result
388+
352389
@staticmethod
353390
def _split_table_name(full_table_name):
354391
# Split full_table_name to get project, dataset and table name
@@ -363,22 +400,51 @@ def _split_table_name(full_table_name):
363400
dataset, table_name = table_name_split
364401
elif len(table_name_split) == 3:
365402
project, dataset, table_name = table_name_split
403+
else:
404+
raise ValueError("Did not understand table_name: {}".format(full_table_name))
366405

367406
return (project, dataset, table_name)
368407

408+
def _table_reference(self, provided_schema_name, provided_table_name,
409+
client_project):
410+
project_id_from_table, dataset_id_from_table, table_id = self._split_table_name(provided_table_name)
411+
project_id_from_schema = None
412+
dataset_id_from_schema = None
413+
if provided_schema_name is not None:
414+
provided_schema_name_split = provided_schema_name.split('.')
415+
if len(provided_schema_name_split) == 0:
416+
pass
417+
elif len(provided_schema_name_split) == 1:
418+
if dataset_id_from_table:
419+
project_id_from_schema = provided_schema_name_split[0]
420+
else:
421+
dataset_id_from_schema = provided_schema_name_split[0]
422+
elif len(provided_schema_name_split) == 2:
423+
project_id_from_schema = provided_schema_name_split[0]
424+
dataset_id_from_schema = provided_schema_name_split[1]
425+
else:
426+
raise ValueError("Did not understand schema: {}".format(provided_schema_name))
427+
if (dataset_id_from_schema and dataset_id_from_table and
428+
dataset_id_from_schema != dataset_id_from_table):
429+
raise ValueError("dataset_id specified in schema and table_name disagree: got {} in schema, and {} in table_name".format(dataset_id_from_schema, dataset_id_from_table))
430+
if (project_id_from_schema and project_id_from_table and
431+
project_id_from_schema != project_id_from_table):
432+
raise ValueError("project_id specified in schema and table_name disagree: got {} in schema, and {} in table_name".format(project_id_from_schema, project_id_from_table))
433+
project_id = project_id_from_schema or project_id_from_table or client_project
434+
dataset_id = dataset_id_from_schema or dataset_id_from_table or self.dataset_id
435+
436+
table_ref = TableReference.from_string("{}.{}.{}".format(
437+
project_id, dataset_id, table_id
438+
))
439+
return table_ref
440+
369441
def _get_table(self, connection, table_name, schema=None):
370442
if isinstance(connection, Engine):
371443
connection = connection.connect()
372444

373445
client = connection.connection._client
374446

375-
project_id, dataset_id, table_id = self._split_table_name(table_name)
376-
project_id = project_id or client.project
377-
dataset_id = dataset_id or schema or self.dataset_id
378-
379-
table_ref = TableReference.from_string("{}.{}.{}".format(
380-
project_id, dataset_id, table_id
381-
))
447+
table_ref = self._table_reference(schema, table_name, client.project)
382448
try:
383449
table = client.get_table(table_ref)
384450
except NotFound:
@@ -464,23 +530,13 @@ def get_table_names(self, connection, schema=None, **kw):
464530
if isinstance(connection, Engine):
465531
connection = connection.connect()
466532

467-
datasets = connection.connection._client.list_datasets()
468-
result = []
469-
for d in datasets:
470-
if schema is not None and d.dataset_id != schema:
471-
continue
533+
return self._get_table_or_view_names(connection, "TABLE", schema)
472534

473-
if self.dataset_id is not None and d.dataset_id != self.dataset_id:
474-
continue
535+
def get_view_names(self, connection, schema=None, **kw):
536+
if isinstance(connection, Engine):
537+
connection = connection.connect()
475538

476-
tables = connection.connection._client.list_tables(d.reference)
477-
for t in tables:
478-
if self.dataset_id is None:
479-
table_name = d.dataset_id + '.' + t.table_id
480-
else:
481-
table_name = t.table_id
482-
result.append(table_name)
483-
return result
539+
return self._get_table_or_view_names(connection, "VIEW", schema)
484540

485541
def do_rollback(self, dbapi_connection):
486542
# BigQuery has no support for transactions.

scripts/load_test_data.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ bq rm -f -t test_pybigquery.sample
66
bq rm -f -t test_pybigquery_alt.sample_alt
77
bq rm -f -t test_pybigquery.sample_one_row
88
bq rm -f -t test_pybigquery.sample_dml
9+
bq rm -f -t test_pybigquery.sample_view
910
bq rm -f -t test_pybigquery_location.sample_one_row
1011

1112
bq mk --table --schema=$(dirname $0)/schema.json --time_partitioning_field timestamp --clustering_fields integer,string test_pybigquery.sample
@@ -17,3 +18,5 @@ bq load --source_format=NEWLINE_DELIMITED_JSON --schema=$(dirname $0)/schema.jso
1718

1819
bq --location=asia-northeast1 load --source_format=NEWLINE_DELIMITED_JSON --schema=$(dirname $0)/schema.json test_pybigquery_location.sample_one_row $(dirname $0)/sample_one_row.json
1920
bq mk --schema=$(dirname $0)/schema.json -t test_pybigquery.sample_dml
21+
22+
bq mk --use_legacy_sql=false --view 'SELECT string FROM test_pybigquery.sample' test_pybigquery.sample_view

test/test_sqlalchemy_bigquery.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from google.api_core.exceptions import BadRequest
55
from pybigquery.api import ApiClient
6+
from pybigquery.sqlalchemy_bigquery import BigQueryDialect
67
from sqlalchemy.engine import create_engine
78
from sqlalchemy.schema import Table, MetaData, Column
89
from sqlalchemy.ext.declarative import declarative_base
@@ -102,6 +103,11 @@ def engine():
102103
return engine
103104

104105

106+
@pytest.fixture(scope='session')
107+
def dialect():
108+
return BigQueryDialect()
109+
110+
105111
@pytest.fixture(scope='session')
106112
def engine_using_test_dataset():
107113
engine = create_engine('bigquery:///test_pybigquery', echo=True)
@@ -163,10 +169,14 @@ def query():
163169
def query(table):
164170
col1 = literal_column("TIMESTAMP_TRUNC(timestamp, DAY)").label("timestamp_label")
165171
col2 = func.sum(table.c.integer)
172+
# Test rendering of nested labels. Full expression should render in SELECT, but
173+
# ORDER/GROUP BY should use label only.
174+
col3 = func.sum(func.sum(table.c.integer.label("inner")).label("outer")).over().label('outer')
166175
query = (
167176
select([
168177
col1,
169178
col2,
179+
col3,
170180
])
171181
.where(col1 < '2017-01-01 00:00:00')
172182
.group_by(col1)
@@ -284,11 +294,13 @@ def test_tables_list(engine, engine_using_test_dataset):
284294
assert 'test_pybigquery.sample' in tables
285295
assert 'test_pybigquery.sample_one_row' in tables
286296
assert 'test_pybigquery.sample_dml' in tables
297+
assert 'test_pybigquery.sample_view' not in tables
287298

288299
tables = engine_using_test_dataset.table_names()
289300
assert 'sample' in tables
290301
assert 'sample_one_row' in tables
291302
assert 'sample_dml' in tables
303+
assert 'sample_view' not in tables
292304

293305

294306
def test_group_by(session, table, session_using_test_dataset, table_using_test_dataset):
@@ -298,6 +310,33 @@ def test_group_by(session, table, session_using_test_dataset, table_using_test_d
298310
assert len(result) > 0
299311

300312

313+
def test_nested_labels(engine, table):
314+
col = table.c.integer
315+
exprs = [
316+
sqlalchemy.func.sum(
317+
sqlalchemy.func.sum(col.label("inner")
318+
).label("outer")).over(),
319+
sqlalchemy.func.sum(
320+
sqlalchemy.case([[
321+
sqlalchemy.literal(True),
322+
col.label("inner"),
323+
]]).label("outer")
324+
),
325+
sqlalchemy.func.sum(
326+
sqlalchemy.func.sum(
327+
sqlalchemy.case([[
328+
sqlalchemy.literal(True), col.label("inner")
329+
]]).label("middle")
330+
).label("outer")
331+
).over(),
332+
]
333+
for expr in exprs:
334+
sql = str(expr.compile(engine))
335+
assert "inner" not in sql
336+
assert "middle" not in sql
337+
assert "outer" not in sql
338+
339+
301340
def test_session_query(session, table, session_using_test_dataset, table_using_test_dataset):
302341
for session, table in [(session, table), (session_using_test_dataset, table_using_test_dataset)]:
303342
col_concat = func.concat(table.c.string).label('concat')
@@ -359,6 +398,16 @@ def test_compiled_query_literal_binds(engine, engine_using_test_dataset, table,
359398
assert len(result) > 0
360399

361400

401+
@pytest.mark.parametrize(["column", "processed"], [
402+
(types.String(), "STRING"),
403+
(types.NUMERIC(), "NUMERIC"),
404+
(types.ARRAY(types.String), "ARRAY<STRING>"),
405+
])
406+
def test_compile_types(engine, column, processed):
407+
result = engine.dialect.type_compiler.process(column)
408+
assert result == processed
409+
410+
362411
def test_joins(session, table, table_one_row):
363412
result = (session.query(table.c.string, func.count(table_one_row.c.integer))
364413
.join(table_one_row, table_one_row.c.string == table.c.string)
@@ -438,15 +487,27 @@ def test_table_names_in_schema(inspector, inspector_using_test_dataset):
438487
assert 'test_pybigquery.sample' in tables
439488
assert 'test_pybigquery.sample_one_row' in tables
440489
assert 'test_pybigquery.sample_dml' in tables
490+
assert 'test_pybigquery.sample_view' not in tables
441491
assert len(tables) == 3
442492

443493
tables = inspector_using_test_dataset.get_table_names()
444494
assert 'sample' in tables
445495
assert 'sample_one_row' in tables
446496
assert 'sample_dml' in tables
497+
assert 'sample_view' not in tables
447498
assert len(tables) == 3
448499

449500

501+
def test_view_names(inspector, inspector_using_test_dataset):
502+
view_names = inspector.get_view_names()
503+
assert "test_pybigquery.sample_view" in view_names
504+
assert "test_pybigquery.sample" not in view_names
505+
506+
view_names = inspector_using_test_dataset.get_view_names()
507+
assert "sample_view" in view_names
508+
assert "sample" not in view_names
509+
510+
450511
def test_get_indexes(inspector, inspector_using_test_dataset):
451512
for _ in ['test_pybigquery.sample', 'test_pybigquery.sample_one_row']:
452513
indexes = inspector.get_indexes('test_pybigquery.sample')
@@ -479,6 +540,38 @@ def test_get_columns(inspector, inspector_using_test_dataset):
479540
assert col['type'].__class__.__name__ == sample_col['type'].__class__.__name__
480541

481542

543+
@pytest.mark.parametrize('provided_schema_name,provided_table_name,client_project',
544+
[
545+
('dataset', 'table', 'project'),
546+
(None, 'dataset.table', 'project'),
547+
(None, 'project.dataset.table', 'other_project'),
548+
('project', 'dataset.table', 'other_project'),
549+
('project.dataset', 'table', 'other_project'),
550+
])
551+
def test_table_reference(dialect, provided_schema_name,
552+
provided_table_name, client_project):
553+
ref = dialect._table_reference(provided_schema_name,
554+
provided_table_name,
555+
client_project)
556+
assert ref.table_id == 'table'
557+
assert ref.dataset_id == 'dataset'
558+
assert ref.project == 'project'
559+
560+
@pytest.mark.parametrize('provided_schema_name,provided_table_name,client_project',
561+
[
562+
('project.dataset', 'other_dataset.table', 'project'),
563+
('project.dataset', 'other_project.dataset.table', 'project'),
564+
('project.dataset.something_else', 'table', 'project'),
565+
(None, 'project.dataset.table.something_else', 'project'),
566+
])
567+
def test_invalid_table_reference(dialect, provided_schema_name,
568+
provided_table_name, client_project):
569+
with pytest.raises(ValueError):
570+
dialect._table_reference(provided_schema_name,
571+
provided_table_name,
572+
client_project)
573+
574+
482575
def test_has_table(engine, engine_using_test_dataset):
483576
assert engine.has_table('sample', 'test_pybigquery') is True
484577
assert engine.has_table('test_pybigquery.sample') is True

0 commit comments

Comments
 (0)