33from __future__ import absolute_import
44from __future__ import unicode_literals
55
6+ import operator
7+
68from google import auth
79from google .cloud import bigquery
810from google .cloud .bigquery import dbapi
@@ -184,12 +186,19 @@ def visit_column(self, column, add_to_result_map=None,
184186 self .preparer .quote (tablename ) + \
185187 "." + name
186188
def visit_label(self, *args, within_group_by=False, **kwargs):
    """Compile a label, rendering it by name inside a GROUP BY clause.

    Args:
        *args: positional arguments forwarded unchanged to the parent
            compiler's ``visit_label``; ``args[0]`` is the label itself.
        within_group_by: flag set by ``group_by_clause``. When true, the
            label is passed as ``render_label_as_label`` so the parent
            renders the label instead of the labelled expression.
        **kwargs: keyword arguments forwarded to the parent.

    Returns:
        Whatever the parent compiler's ``visit_label`` returns.
    """
    # Use labels in GROUP BY clause.
    #
    # The flag works around the missing equivalent of
    # supports_simple_order_by_label for GROUP BY.
    if within_group_by:
        kwargs['render_label_as_label'] = args[0]
    return super(BigQueryCompiler, self).visit_label(*args, **kwargs)
197+
def group_by_clause(self, select, **kw):
    """Compile the GROUP BY clause with ``within_group_by=True`` set.

    The extra keyword is forwarded to the parent compiler's
    ``group_by_clause`` so that ``visit_label`` can tell it is being
    invoked while a GROUP BY clause is rendered.

    Args:
        select: the select construct being compiled.
        **kw: compiler keyword arguments, forwarded to the parent.

    Returns:
        Whatever the parent compiler's ``group_by_clause`` returns.
    """
    return super(BigQueryCompiler, self).group_by_clause(
        select, **kw, within_group_by=True
    )
193202
194203
195204class BigQueryTypeCompiler (GenericTypeCompiler ):
@@ -206,6 +215,9 @@ def visit_text(self, type_, **kw):
def visit_string(self, type_, **kw):
    """Return the type name ``'STRING'``; ``type_`` and ``kw`` are unused."""
    return 'STRING'
208217
def visit_ARRAY(self, type_, **kw):
    """Return ``ARRAY<item>`` for an array type.

    Args:
        type_: array type object; its ``item_type`` attribute is compiled
            with ``self.process``.
        **kw: keyword arguments forwarded to ``self.process``.

    Returns:
        The string ``"ARRAY<...>"`` wrapping the compiled item type.
    """
    item_sql = self.process(type_.item_type, **kw)
    return "ARRAY<{}>".format(item_sql)
220+
def visit_BINARY(self, type_, **kw):
    """Return the type name ``'BYTES'``; ``type_`` and ``kw`` are unused."""
    return 'BYTES'
211223
@@ -284,6 +296,11 @@ def __init__(
284296 def dbapi (cls ):
285297 return dbapi
286298
@staticmethod
def _build_formatted_table_id(table):
    """Build '<dataset_id>.<table_id>' string using given table.

    Args:
        table: object exposing ``reference.dataset_id`` and ``table_id``.

    Returns:
        The two ids joined by a dot.
    """
    return "{}.{}".format(table.reference.dataset_id, table.table_id)
303+
287304 @staticmethod
288305 def _add_default_dataset_to_job_config (job_config , project_id , dataset_id ):
289306 # If dataset_id is set, then we know the job_config isn't None
@@ -349,6 +366,26 @@ def _json_deserializer(self, row):
349366 """
350367 return row
351368
def _get_table_or_view_names(self, connection, table_type, schema=None):
    """Return the names of all listed items whose type is ``table_type``.

    Args:
        connection: object whose ``connection._client`` provides
            ``list_datasets()`` and ``list_tables(dataset_reference)``.
        table_type: value compared against each listed item's
            ``table_type`` (callers in this file pass ``"TABLE"`` or
            ``"VIEW"``).
        schema: dataset id to restrict the listing to. When falsy,
            ``self.dataset_id`` is used; when that is also ``None`` every
            dataset is listed.

    Returns:
        A list of names. Names are the bare ``table_id`` when
        ``self.dataset_id`` is set, otherwise the value produced by
        ``self._build_formatted_table_id``.
    """
    current_schema = schema or self.dataset_id

    # Without a default dataset the bare table id would be ambiguous, so
    # the dataset-qualified form is used instead.
    if self.dataset_id is None:
        get_table_name = self._build_formatted_table_id
    else:
        get_table_name = operator.attrgetter("table_id")

    client = connection.connection._client

    result = []
    for dataset in client.list_datasets():
        if current_schema is not None and current_schema != dataset.dataset_id:
            continue

        result.extend(
            get_table_name(table)
            for table in client.list_tables(dataset.reference)
            if table_type == table.table_type
        )
    return result
388+
352389 @staticmethod
353390 def _split_table_name (full_table_name ):
354391 # Split full_table_name to get project, dataset and table name
@@ -363,22 +400,51 @@ def _split_table_name(full_table_name):
363400 dataset , table_name = table_name_split
364401 elif len (table_name_split ) == 3 :
365402 project , dataset , table_name = table_name_split
403+ else :
404+ raise ValueError ("Did not understand table_name: {}" .format (full_table_name ))
366405
367406 return (project , dataset , table_name )
368407
def _table_reference(self, provided_schema_name, provided_table_name,
                     client_project):
    """Build a ``TableReference`` from a schema name and a table name.

    The table name is split with ``self._split_table_name`` into optional
    project and dataset ids plus the table id. The schema name may be
    ``None``, a single component, or ``'<project>.<dataset>'``. A single
    schema component is taken as the project when the table name already
    carries a dataset, and as the dataset otherwise.

    Args:
        provided_schema_name: schema string or ``None``.
        provided_table_name: table name, passed to
            ``self._split_table_name``.
        client_project: project id used when neither the schema nor the
            table name supplies one.

    Returns:
        The result of ``TableReference.from_string`` for
        ``'<project>.<dataset>.<table>'``. The dataset falls back to
        ``self.dataset_id`` when not otherwise specified.

    Raises:
        ValueError: if the schema has more than two dot-separated parts,
            or if the schema and table name give different non-empty
            dataset ids or project ids.
    """
    project_id_from_table, dataset_id_from_table, table_id = \
        self._split_table_name(provided_table_name)
    project_id_from_schema = None
    dataset_id_from_schema = None
    if provided_schema_name is not None:
        schema_parts = provided_schema_name.split('.')
        # str.split always returns at least one element, so only the
        # one-part, two-part and too-many-parts cases can occur.
        if len(schema_parts) == 1:
            if dataset_id_from_table:
                project_id_from_schema = schema_parts[0]
            else:
                dataset_id_from_schema = schema_parts[0]
        elif len(schema_parts) == 2:
            project_id_from_schema, dataset_id_from_schema = schema_parts
        else:
            raise ValueError(
                "Did not understand schema: {}".format(provided_schema_name)
            )
    if (dataset_id_from_schema and dataset_id_from_table and
            dataset_id_from_schema != dataset_id_from_table):
        raise ValueError(
            "dataset_id specified in schema and table_name disagree: "
            "got {} in schema, and {} in table_name".format(
                dataset_id_from_schema, dataset_id_from_table
            )
        )
    if (project_id_from_schema and project_id_from_table and
            project_id_from_schema != project_id_from_table):
        raise ValueError(
            "project_id specified in schema and table_name disagree: "
            "got {} in schema, and {} in table_name".format(
                project_id_from_schema, project_id_from_table
            )
        )
    project_id = project_id_from_schema or project_id_from_table or client_project
    dataset_id = dataset_id_from_schema or dataset_id_from_table or self.dataset_id

    return TableReference.from_string("{}.{}.{}".format(
        project_id, dataset_id, table_id
    ))
440+
369441 def _get_table (self , connection , table_name , schema = None ):
370442 if isinstance (connection , Engine ):
371443 connection = connection .connect ()
372444
373445 client = connection .connection ._client
374446
375- project_id , dataset_id , table_id = self ._split_table_name (table_name )
376- project_id = project_id or client .project
377- dataset_id = dataset_id or schema or self .dataset_id
378-
379- table_ref = TableReference .from_string ("{}.{}.{}" .format (
380- project_id , dataset_id , table_id
381- ))
447+ table_ref = self ._table_reference (schema , table_name , client .project )
382448 try :
383449 table = client .get_table (table_ref )
384450 except NotFound :
@@ -464,23 +530,13 @@ def get_table_names(self, connection, schema=None, **kw):
464530 if isinstance (connection , Engine ):
465531 connection = connection .connect ()
466532
467- datasets = connection .connection ._client .list_datasets ()
468- result = []
469- for d in datasets :
470- if schema is not None and d .dataset_id != schema :
471- continue
533+ return self ._get_table_or_view_names (connection , "TABLE" , schema )
472534
473- if self .dataset_id is not None and d .dataset_id != self .dataset_id :
474- continue
def get_view_names(self, connection, schema=None, **kw):
    """Return the names of the views reachable through ``connection``.

    Args:
        connection: a connection, or an ``Engine``; an ``Engine`` is
            turned into a connection via ``connect()`` first.
        schema: optional dataset id, passed through to
            ``_get_table_or_view_names``.
        **kw: accepted and ignored.

    Returns:
        The list produced by ``_get_table_or_view_names`` for the
        ``"VIEW"`` table type.
    """
    if isinstance(connection, Engine):
        connection = connection.connect()

    return self._get_table_or_view_names(connection, "VIEW", schema)
484540
485541 def do_rollback (self , dbapi_connection ):
486542 # BigQuery has no support for transactions.
0 commit comments