Skip to content

Commit 5d07395

Browse files
Restore conditional "dc/base/" prefixing for aggregations utils (#2061)
1 parent 327588e commit 5d07395

1 file changed

Lines changed: 36 additions & 8 deletions

File tree

import-automation/workflow/ingestion-helper/aggregation_utils.py

Lines changed: 36 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,23 @@
2222
logging.getLogger().setLevel(logging.INFO)
2323

2424

25+
def _escape_sql_literal(val: str) -> str:
26+
r"""Escapes a string literal for use in nested BigQuery/Spanner queries.
27+
28+
This is required because the query string travels through two SQL parsers:
29+
1. BigQuery parses the EXTERNAL_QUERY double-quoted string literal.
30+
2. Spanner parses the resulting inner query's single-quoted string literal.
31+
32+
To ensure the value is correctly matched and prevent SQL injection:
33+
- Backslashes (\) are escaped to 4 backslashes (\\\\) so they survive
34+
both decodings (\\\\ -> \\ -> \). Otherwise, they may escape quotes
35+
or be interpreted as control characters (like \b becoming backspace).
36+
- Double quotes (") are escaped to \\" to prevent terminating BQ string.
37+
- Single quotes (') are escaped to '' to prevent terminating Spanner string.
38+
"""
39+
return val.replace('\\', '\\\\\\\\').replace('"', '\\"').replace("'", "''")
40+
41+
2542
class BigQueryExecutor:
2643
"""Handles BigQuery client initialization and query execution."""
2744

@@ -153,8 +170,11 @@ def run_linked_contained_in_place(self,
153170
return
154171

155172
dest = self.executor.get_spanner_destination_uri()
156-
provenances = [f"'dc/base/{name}'" for name in import_names]
173+
safe_names = [_escape_sql_literal(name) for name in import_names]
174+
prefix = "dc/base/" if self.is_base_dc else ""
175+
provenances = [f"'{prefix}{name}'" for name in safe_names]
157176
provenance_filter = f" AND provenance IN ({', '.join(provenances)})"
177+
gen_graphs_prov = 'dc/base/GeneratedGraphs' if self.is_base_dc else 'GeneratedGraphs'
158178

159179
query = f"""
160180
-- Pull base edges needed for containedInPlace aggregation
@@ -201,7 +221,7 @@ def run_linked_contained_in_place(self,
201221
subject_id,
202222
'linkedContainedInPlace' as predicate,
203223
ancestor_place as object_id,
204-
'dc/base/GeneratedGraphs' as provenance
224+
'{gen_graphs_prov}' as provenance
205225
FROM
206226
Ancestors
207227
),
@@ -238,8 +258,11 @@ def run_linked_member_of(self, import_names: List[str] = None) -> None:
238258
return
239259

240260
dest = self.executor.get_spanner_destination_uri()
241-
provenances = [f"'dc/base/{name}'" for name in import_names]
261+
safe_names = [_escape_sql_literal(name) for name in import_names]
262+
prefix = "dc/base/" if self.is_base_dc else ""
263+
provenances = [f"'{prefix}{name}'" for name in safe_names]
242264
provenance_filter = f" AND provenance IN ({', '.join(provenances)})"
265+
gen_graphs_prov = 'dc/base/GeneratedGraphs' if self.is_base_dc else 'GeneratedGraphs'
243266

244267
query = f"""
245268
-- Pull base edges needed for memberOf aggregation
@@ -289,7 +312,7 @@ def run_linked_member_of(self, import_names: List[str] = None) -> None:
289312
subject_id,
290313
'linkedMemberOf' as predicate,
291314
ancestor as object_id,
292-
'dc/base/GeneratedGraphs' as provenance
315+
'{gen_graphs_prov}' as provenance
293316
FROM
294317
Ancestors
295318
),
@@ -326,8 +349,11 @@ def run_linked_member(self, import_names: List[str] = None) -> None:
326349
return
327350

328351
dest = self.executor.get_spanner_destination_uri()
329-
provenances = [f"'dc/base/{name}'" for name in import_names]
352+
safe_names = [_escape_sql_literal(name) for name in import_names]
353+
prefix = "dc/base/" if self.is_base_dc else ""
354+
provenances = [f"'{prefix}{name}'" for name in safe_names]
330355
provenance_filter = f" AND provenance IN ({', '.join(provenances)})"
356+
gen_graphs_prov = 'dc/base/GeneratedGraphs' if self.is_base_dc else 'GeneratedGraphs'
331357

332358
query = f"""
333359
-- Pull base edges needed for member aggregation
@@ -375,7 +401,7 @@ def run_linked_member(self, import_names: List[str] = None) -> None:
375401
descendant as subject_id,
376402
'linkedMember' as predicate,
377403
subject_id as object_id,
378-
'dc/base/GeneratedGraphs' as provenance
404+
'{gen_graphs_prov}' as provenance
379405
FROM
380406
Descendants
381407
WHERE subject_id LIKE 'dc/topic%'
@@ -439,8 +465,10 @@ def run_provenance_summary_aggregation(self,
439465
dest = self.executor.get_spanner_destination_uri()
440466
connection_id = self.executor.connection_id
441467

468+
safe_names = [_escape_sql_literal(name) for name in import_names]
442469
# Format import names for the SQL IN clause
443-
imports_str = ", ".join([f"'{name}'" for name in import_names])
470+
imports_str = ", ".join([f"'{name}'" for name in safe_names])
471+
provenance_dcid_expr = "CONCAT('dc/base/', raw.import_name)" if self.is_base_dc else "raw.import_name"
444472

445473
query = f"""
446474
-- Step 1: Fetch Observation rows for the specific import
@@ -504,7 +532,7 @@ def run_provenance_summary_aggregation(self,
504532
raw.is_dc_aggregate,
505533
JSON_VALUE(v, '$.key') as date_val,
506534
SAFE_CAST(JSON_VALUE(v, '$.value') AS FLOAT64) as value_num,
507-
CONCAT('dc/base/', raw.import_name) as provenance_dcid,
535+
{provenance_dcid_expr} as provenance_dcid,
508536
nodes.name as place_name,
509537
edges.place_type
510538
FROM `temp_obs_raw` raw

0 commit comments

Comments
 (0)