1717import math
1818import threading
1919from typing import Literal , Mapping , Optional , Sequence , Tuple
20- import weakref
2120
2221import google .api_core .exceptions
2322from google .cloud import bigquery
4746 semi_executor ,
4847)
4948import bigframes .session ._io .bigquery as bq_io
49+ import bigframes .session .execution_cache as execution_cache
5050import bigframes .session .execution_spec as ex_spec
5151import bigframes .session .metrics
5252import bigframes .session .planner
# Cap on clustering columns for cache tables — 4 is BigQuery's documented
# per-table clustering-column limit. NOTE(review): confirm against current BQ docs.
_MAX_CLUSTER_COLUMNS = 4
# Threshold below which a result is considered "small": 10 GiB.
MAX_SMALL_RESULT_BYTES = 10 * 1024 * 1024 * 1024  # 10G
6161
# Maps a local column id (str) to the corresponding BigQuery column name (str).
SourceIdMapping = Mapping[str, str]


class ExecutionCache:
    """Weak-keyed cache of materialized results for logical plan nodes.

    Holds two maps:
      * plan node -> the ``CachedTableNode`` that re-reads its materialized
        BigQuery result (see :meth:`cache_results_table`)
      * uploaded local table -> its BigQuery data source plus a
        local-column -> BQ-column name mapping
        (see :meth:`cache_remote_replacement`)

    Both use ``weakref.WeakKeyDictionary``, so entries are dropped
    automatically once the keyed node/table is no longer referenced elsewhere.
    """

    def __init__(self):
        # current assumption is only 1 cache of a given node
        # in future, might have multiple caches, with different layout, localities
        self._cached_executions: weakref.WeakKeyDictionary[
            nodes.BigFrameNode, nodes.CachedTableNode
        ] = weakref.WeakKeyDictionary()
        # Uploaded local table -> (BigQuery source, local-id -> BQ-name mapping).
        self._uploaded_local_data: weakref.WeakKeyDictionary[
            local_data.ManagedArrowTable,
            tuple[bq_data.BigqueryDataSource, SourceIdMapping],
        ] = weakref.WeakKeyDictionary()

    @property
    def mapping(self) -> Mapping[nodes.BigFrameNode, nodes.BigFrameNode]:
        """Live (weak) view of node -> cached-replacement entries."""
        return self._cached_executions

    def cache_results_table(
        self,
        original_root: nodes.BigFrameNode,
        data: bq_data.BigqueryDataSource,
    ):
        """Record ``data`` as the materialized form of ``original_root``.

        Builds a ``CachedTableNode`` scanning every field of the original node
        and stores it so later plans can substitute the cached table for the
        original subtree. The replacement must be a drop-in: schemas are
        asserted equal.
        """
        # Assumption: GBQ cached table uses field name as bq column name
        scan_list = nodes.ScanList(
            tuple(
                nodes.ScanItem(field.id, field.id.sql) for field in original_root.fields
            )
        )
        cached_replacement = nodes.CachedTableNode(
            source=data,
            scan_list=scan_list,
            table_session=original_root.session,
            original_node=original_root,
        )
        # Invariant, not input validation: a schema mismatch here is a bug in
        # the materialization path. (Stripped under ``python -O``.)
        assert original_root.schema == cached_replacement.schema
        self._cached_executions[original_root] = cached_replacement

    def cache_remote_replacement(
        self,
        local_data: local_data.ManagedArrowTable,
        bq_data: bq_data.BigqueryDataSource,
    ):
        """Record that local table ``local_data`` was uploaded as ``bq_data``.

        Also stores the positional local-column -> BQ-column-name mapping so
        scans over the local table can later be rewritten against the uploaded
        table.

        NOTE(review): the parameters shadow the ``local_data``/``bq_data``
        module imports within this method — renaming would change the keyword
        interface, so left as-is.
        """
        # bq table has one extra column for offsets, those are implicit for local data
        assert len(local_data.schema.items) + 1 == len(bq_data.table.physical_schema)
        # Columns correspond positionally; the trailing BQ offsets column is
        # intentionally excluded from the mapping.
        mapping = {
            local_data.schema.items[i].column: bq_data.table.physical_schema[i].name
            for i in range(len(local_data.schema))
        }
        self._uploaded_local_data[local_data] = (bq_data, mapping)
11462
11563class BigQueryCachingExecutor (executor .Executor ):
11664 """Computes BigFrames values using BigQuery Engine.
@@ -128,20 +76,20 @@ def __init__(
12876 bqstoragereadclient : google .cloud .bigquery_storage_v1 .BigQueryReadClient ,
12977 loader : loader .GbqDataLoader ,
13078 * ,
131- strictly_ordered : bool = True ,
13279 metrics : Optional [bigframes .session .metrics .ExecutionMetrics ] = None ,
13380 enable_polars_execution : bool = False ,
13481 publisher : bigframes .core .events .Publisher ,
82+ labels : Mapping [str , str ] = {},
13583 ):
13684 self .bqclient = bqclient
13785 self .storage_manager = storage_manager
138- self .strictly_ordered : bool = strictly_ordered
139- self .cache : ExecutionCache = ExecutionCache ()
86+ self .cache : execution_cache .ExecutionCache = execution_cache .ExecutionCache ()
14087 self .metrics = metrics
14188 self .loader = loader
14289 self .bqstoragereadclient = bqstoragereadclient
14390 self ._enable_polars_execution = enable_polars_execution
14491 self ._publisher = publisher
92+ self ._labels = labels
14593
14694 # TODO(tswast): Send events from semi-executors, too.
14795 self ._semi_executors : Sequence [semi_executor .SemiExecutor ] = (
@@ -409,8 +357,8 @@ def _run_execute_query(
409357 bigframes .options .compute .maximum_bytes_billed
410358 )
411359
412- if not self .strictly_ordered :
413- job_config .labels [ "bigframes-mode" ] = "unordered"
360+ if self ._labels :
361+ job_config .labels . update ( self . _labels )
414362
415363 try :
416364 # Trick the type checker into thinking we got a literal.
@@ -449,9 +397,6 @@ def _run_execute_query(
449397 else :
450398 raise
451399
452- def replace_cached_subtrees (self , node : nodes .BigFrameNode ) -> nodes .BigFrameNode :
453- return nodes .top_down (node , lambda x : self .cache .mapping .get (x , x ))
454-
455400 def _is_trivially_executable (self , array_value : bigframes .core .ArrayValue ):
456401 """
457402 Can the block be evaluated very cheaply?
@@ -481,7 +426,7 @@ def prepare_plan(
481426 ):
482427 self ._simplify_with_caching (plan )
483428
484- plan = self .replace_cached_subtrees (plan )
429+ plan = self .cache . subsitute_cached_subplans (plan )
485430 plan = rewrite .column_pruning (plan )
486431 plan = plan .top_down (rewrite .fold_row_counts )
487432
@@ -526,7 +471,7 @@ def _cache_with_session_awareness(
526471 self ._cache_with_cluster_cols (
527472 bigframes .core .ArrayValue (target ), cluster_cols_sql_names
528473 )
529- elif self . strictly_ordered :
474+ elif not target . order_ambiguous :
530475 self ._cache_with_offsets (bigframes .core .ArrayValue (target ))
531476 else :
532477 self ._cache_with_cluster_cols (bigframes .core .ArrayValue (target ), [])
@@ -551,7 +496,7 @@ def _cache_most_complex_subtree(self, node: nodes.BigFrameNode) -> bool:
551496 node ,
552497 min_complexity = (QUERY_COMPLEXITY_LIMIT / 500 ),
553498 max_complexity = QUERY_COMPLEXITY_LIMIT ,
554- cache = dict ( self .cache . mapping ) ,
499+ cache = self .cache ,
555500 # Heuristic: subtree_compleixty * (copies of subtree)^2
556501 heuristic = lambda complexity , count : math .log (complexity )
557502 + 2 * math .log (count ),
@@ -580,32 +525,37 @@ def _substitute_large_local_sources(self, original_root: nodes.BigFrameNode):
580525 def map_local_scans (node : nodes .BigFrameNode ):
581526 if not isinstance (node , nodes .ReadLocalNode ):
582527 return node
583- if node .local_data_source not in self .cache ._uploaded_local_data :
584- return node
585- bq_source , source_mapping = self .cache ._uploaded_local_data [
528+ uploaded_local_data = self .cache .get_uploaded_local_data (
586529 node .local_data_source
587- ]
588- scan_list = node .scan_list .remap_source_ids (source_mapping )
530+ )
531+ if uploaded_local_data is None :
532+ return node
533+
534+ scan_list = node .scan_list .remap_source_ids (
535+ uploaded_local_data .source_mapping
536+ )
589537 # offsets_col isn't part of ReadTableNode, so emulate by adding to end of scan_list
590538 if node .offsets_col is not None :
591539 # Offsets are always implicitly the final column of uploaded data
592540 # See: Loader.load_data
593541 scan_list = scan_list .append (
594- bq_source .table .physical_schema [- 1 ].name ,
542+ uploaded_local_data . bq_source .table .physical_schema [- 1 ].name ,
595543 bigframes .dtypes .INT_DTYPE ,
596544 node .offsets_col ,
597545 )
598- return nodes .ReadTableNode (bq_source , scan_list , node .session )
546+ return nodes .ReadTableNode (
547+ uploaded_local_data .bq_source , scan_list , node .session
548+ )
599549
600550 return original_root .bottom_up (map_local_scans )
601551
602552 def _upload_local_data (self , local_table : local_data .ManagedArrowTable ):
603- if local_table in self .cache ._uploaded_local_data :
553+ if self .cache .get_uploaded_local_data ( local_table ) is not None :
604554 return
605555 # Lock prevents concurrent repeated work, but slows things down.
606556 # Might be better as a queue and a worker thread
607557 with self ._upload_lock :
608- if local_table not in self .cache ._uploaded_local_data :
558+ if self .cache .get_uploaded_local_data ( local_table ) is None :
609559 uploaded = self .loader .load_data_or_write_data (
610560 local_table , bigframes .core .guid .generate_guid ()
611561 )
0 commit comments