Skip to content

Commit 2907468

Browse files
fix bugs
1 parent e239b61 commit 2907468

3 files changed

Lines changed: 77 additions & 60 deletions

File tree

packages/bigframes/bigframes/session/bq_caching_executor.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,8 @@ def __init__(
8787
self.cache: execution_cache.ExecutionCache = execution_cache.ExecutionCache()
8888
self.metrics = metrics
8989
self.loader = loader
90-
self.bqstoragereadclient = bqstoragereadclient
9190
self._enable_polars_execution = enable_polars_execution
9291
self._publisher = publisher
93-
self._labels = labels
9492

9593
# TODO(tswast): Send events from semi-executors, too.
9694
self._semi_executors: Sequence[semi_executor.SemiExecutor] = (
@@ -109,7 +107,12 @@ def __init__(
109107
)
110108
self._upload_lock = threading.Lock()
111109
self._gbq_executor = direct_gbq_execution.DirectGbqExecutor(
112-
bqclient, compile.compiler, metrics=self.metrics, publisher=self._publisher
110+
bqclient,
111+
compiler=compile.compiler,
112+
bqstoragereadclient=bqstoragereadclient,
113+
metrics=self.metrics,
114+
publisher=self._publisher,
115+
labels=labels,
113116
)
114117

115118
def to_sql(
@@ -132,7 +135,6 @@ def to_sql(
132135
)
133136
return compiled.sql
134137

135-
136138
def execute(
137139
self,
138140
array_value: bigframes.core.ArrayValue,
@@ -167,7 +169,6 @@ def _try_execute_semi_executors(
167169
return maybe_result
168170
return None
169171

170-
171172
def _execute_bigquery(
172173
self,
173174
array_value: bigframes.core.ArrayValue,
@@ -177,7 +178,10 @@ def _execute_bigquery(
177178
# Recursive handlers for different cases, maybe extract to explicit interface.
178179
if isinstance(dest_spec, ex_spec.GcsOutputSpec):
179180
execution_spec = dataclasses.replace(
180-
execution_spec, destination_spec=ex_spec.TempTableSpec(cluster_cols=dest_spec.cluster_cols, lifetime="ephemeral")
181+
execution_spec,
182+
destination_spec=ex_spec.TempTableSpec(
183+
cluster_cols=dest_spec.cluster_cols, lifetime="ephemeral"
184+
),
181185
)
182186
results = self._execute_bigquery(array_value, execution_spec)
183187
self._export_result_gcs(results, dest_spec)
@@ -191,7 +195,11 @@ def _execute_bigquery(
191195
existing_table.schema, array_value.schema
192196
):
193197
execution_spec = dataclasses.replace(
194-
execution_spec, destination_spec=ex_spec.TempTableSpec(cluster_cols=execution_spec.destination_spec.cluster_cols, lifetime="ephemeral")
198+
execution_spec,
199+
destination_spec=ex_spec.TempTableSpec(
200+
cluster_cols=execution_spec.destination_spec.cluster_cols,
201+
lifetime="ephemeral",
202+
),
195203
)
196204
results = self._execute_bigquery(array_value, execution_spec)
197205
self._export_gbq_with_dml(results, dest_spec)
@@ -213,10 +221,14 @@ def _execute_bigquery(
213221
)
214222
arr_value = bigframes.core.ArrayValue(plan)
215223
execution_spec = dataclasses.replace(
216-
execution_spec, destination_spec=ex_spec.TableOutputSpec(table=destination_table, cluster_cols=dest_spec.cluster_cols, if_exists="replace")
224+
execution_spec,
225+
destination_spec=ex_spec.TableOutputSpec(
226+
table=destination_table,
227+
cluster_cols=dest_spec.cluster_cols,
228+
if_exists="replace",
229+
),
217230
)
218231
return self._execute_bigquery(arr_value, execution_spec)
219-
220232

221233
# At this point, dst should be unspecified, a specific bq table, or an ephemeral temp table
222234
# Also, ordering mode will either be none or row-sorted
@@ -393,7 +405,11 @@ def _cache_with_cluster_cols(
393405
]
394406
cluster_cols = cluster_cols[:_MAX_CLUSTER_COLUMNS]
395407
execution_spec = ex_spec.ExecutionSpec(
396-
destination_spec=ex_spec.TempTableSpec(cluster_cols=tuple(cluster_cols), lifetime="session", ordering="order_key")
408+
destination_spec=ex_spec.TempTableSpec(
409+
cluster_cols=tuple(cluster_cols),
410+
lifetime="session",
411+
ordering="order_key",
412+
)
397413
)
398414
result_bq_data = self.execute(
399415
array_value,
@@ -405,7 +421,9 @@ def _cache_with_cluster_cols(
405421
def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue):
406422
"""Executes the query and uses the resulting table to rewrite future executions."""
407423
execution_spec = ex_spec.ExecutionSpec(
408-
destination_spec=ex_spec.TempTableSpec(cluster_cols=(), lifetime="session", ordering="offsets_col")
424+
destination_spec=ex_spec.TempTableSpec(
425+
cluster_cols=(), lifetime="session", ordering="offsets_col"
426+
)
409427
)
410428
result_bq_data = self.execute(
411429
array_value,

packages/bigframes/bigframes/session/direct_gbq_execution.py

Lines changed: 42 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626
import bigframes.session._io.bigquery as bq_io
2727
from bigframes.core import bq_data, compile, nodes
2828
from bigframes.session import executor, semi_executor, execution_spec
29+
from bigframes.core.compile.configs import CompileRequest, CompileResult
2930
from bigframes import exceptions as bfe
3031
import bigframes.core.schema as schemata
32+
import google.cloud.bigquery_storage_v1
3133

3234
import google.api_core.exceptions
3335

@@ -43,11 +45,13 @@ class DirectGbqExecutor(semi_executor.SemiExecutor):
4345
def __init__(
4446
self,
4547
bqclient: bigquery.Client,
46-
compiler: Literal["ibis", "sqlglot"]
47-
| Callable[[compile.CompileRequest], executor.CompileResult] = "ibis",
48+
bqstoragereadclient: google.cloud.bigquery_storage_v1.BigQueryReadClient,
4849
*,
50+
compiler: Literal["ibis", "sqlglot"]
51+
| Callable[[CompileRequest], CompileResult] = "ibis",
4952
metrics: Optional[bigframes.session.metrics.ExecutionMetrics] = None,
5053
publisher: Optional[bigframes.core.events.Publisher] = None,
54+
labels: Mapping[str, str] = {},
5155
):
5256
self.bqclient = bqclient
5357
if isinstance(compiler, str):
@@ -58,8 +62,10 @@ def __init__(
5862
)
5963
else:
6064
self._compile_fn = compiler
65+
self._bqstoragereadclient = bqstoragereadclient
6166
self._publisher = publisher
6267
self._metrics = metrics
68+
self._labels = labels
6369

6470
def execute(
6571
self,
@@ -69,36 +75,43 @@ def execute(
6975
"""Just execute whatever plan as is, without further caching or decomposition."""
7076

7177
og_schema = plan.schema
72-
compile_request = compile.CompileRequest(
73-
plan,
74-
sort_rows=spec.ordered,
75-
peek_count=spec.peek,
76-
)
78+
compile_request = CompileRequest(
79+
plan,
80+
sort_rows=spec.ordered,
81+
peek_count=spec.peek,
82+
)
7783

7884
compiled = self._compile_fn(compile_request)
7985
# might have more columns than og schema, for hidden ordering columns
8086
compiled_schema = compiled.sql_schema
8187

8288
job_config = bigquery.QueryJobConfig()
83-
if isinstance(spec.destination_spec, execution_spec.TableOutputSpec):
84-
job_config.destination = spec.destination_spec.table
85-
job_config.write_disposition = _WRITE_DISPOSITIONS[spec.destination_spec.if_exists]
86-
job_config.clustering_fields = spec.destination_spec.cluster_cols
87-
elif isinstance(spec.destination_spec, execution_spec.TempTableSpec) and spec.destination_spec.lifetime == "ephemeral":
88-
pass
89-
elif spec.destination_spec is not None:
90-
raise ValueError(f"Direct GBQ Executor does not support destination: {spec.destination_spec}")
89+
dest_spec = spec.destination_spec
90+
cluster_cols = ()
91+
if isinstance(dest_spec, execution_spec.TableOutputSpec):
92+
job_config.destination = dest_spec.table
93+
job_config.write_disposition = _WRITE_DISPOSITIONS[dest_spec.if_exists]
94+
cluster_cols = dest_spec.cluster_cols
95+
job_config.clustering_fields = dest_spec.cluster_cols
96+
elif (
97+
isinstance(dest_spec, execution_spec.TempTableSpec)
98+
and dest_spec.lifetime == "ephemeral"
99+
):
100+
cluster_cols = dest_spec.cluster_cols
101+
job_config.clustering_fields = dest_spec.cluster_cols
102+
elif dest_spec is not None:
103+
raise ValueError(
104+
f"Direct GBQ Executor does not support destination: {dest_spec}"
105+
)
91106

92107
job_config.labels["bigframes-dtypes"] = compiled.encoded_type_refs
93-
can_skip_job = spec.destination_spec is None and spec.promise_under_10gb
108+
can_skip_job = dest_spec is None and spec.promise_under_10gb
94109
iterator, query_job = self._run_execute_query(
95110
sql=compiled.sql,
96111
job_config=job_config,
97112
query_with_job=(not can_skip_job),
98113
session=plan.session,
99114
)
100-
101-
cluster_cols = spec.destination_spec.cluster_cols if spec.desination_spec else ()
102115
result_bq_data = None
103116
if query_job and query_job.destination:
104117
# we might add extra sql columns in compilation, esp if caching w ordering, infer a bigframes type for them
@@ -128,7 +141,7 @@ def execute(
128141
return executor.BQTableExecuteResult(
129142
data=result_bq_data,
130143
project_id=self.bqclient.project,
131-
storage_client=self.bqstoragereadclient,
144+
storage_client=self._bqstoragereadclient,
132145
execution_metadata=execution_metadata,
133146
selected_fields=tuple((col, col) for col in og_schema.names),
134147
)
@@ -159,34 +172,15 @@ def _run_execute_query(
159172
job_config.labels.update(self._labels)
160173

161174
try:
162-
# Trick the type checker into thinking we got a literal.
163-
if query_with_job:
164-
return bq_io.start_query_with_client(
165-
self.bqclient,
166-
sql,
167-
job_config=job_config,
168-
metrics=self._metrics,
169-
project=None,
170-
location=None,
171-
timeout=None,
172-
query_with_job=True,
173-
publisher=self._publisher,
174-
session=session,
175-
)
176-
else:
177-
return bq_io.start_query_with_client(
178-
self.bqclient,
179-
sql,
180-
job_config=job_config,
181-
metrics=self._metrics,
182-
project=None,
183-
location=None,
184-
timeout=None,
185-
query_with_job=False,
186-
publisher=self._publisher,
187-
session=session,
188-
)
189-
175+
return bq_io.start_query_with_client(
176+
self.bqclient,
177+
sql,
178+
job_config=job_config,
179+
metrics=self._metrics,
180+
query_with_job=query_with_job,
181+
publisher=self._publisher,
182+
session=session,
183+
)
190184
except google.api_core.exceptions.BadRequest as e:
191185
# Unfortunately, this error type does not have a separate error code or exception type
192186
if "Resources exceeded during query execution" in e.message:
@@ -195,6 +189,7 @@ def _run_execute_query(
195189
else:
196190
raise
197191

192+
198193
def _result_schema(
199194
logical_schema: schemata.ArraySchema, sql_schema: list[bigquery.SchemaField]
200195
) -> schemata.ArraySchema:

packages/bigframes/bigframes/session/execution_spec.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class ExecutionSpec:
2828
# implementation dependent and not stable.
2929
peek: Optional[int] = None
3030
# Controls whether output iterator is ordered. Cannot be true if destination is not
31-
# guaranteed to be ordered.
31+
# guaranteed to be ordered.
3232
ordered: bool = False
3333
# This is an optimization flag for gbq execution, it doesn't change semantics, but if promise is falsely made, errors may occur
3434
promise_under_10gb: bool = False
@@ -41,7 +41,10 @@ class TempTableSpec:
4141
Specifies that the result of an operation should be a session temp table.
4242
The table will be automatically deleted after the session ends.
4343
"""
44-
cluster_cols: tuple[str, ...] # if empty, will cluster using order key if ordering_key is set
44+
45+
cluster_cols: tuple[
46+
str, ...
47+
] # if empty, will cluster using order key if ordering_key is set
4548
lifetime: Literal["session", "ephemeral"] = "session"
4649
# Controls ordering and whether extra columns are materialized to preserve ordering
4750
# Any extra columns will be appended to the end of the schema.
@@ -59,6 +62,7 @@ class TableOutputSpec:
5962
6063
The executor is not responsible for managing lifecycle of the table.
6164
"""
65+
6266
table: bigquery.TableReference
6367
cluster_cols: tuple[str, ...]
6468
if_exists: Literal["fail", "replace", "append"] = "fail"

0 commit comments

Comments
 (0)