|
22 | 22 | logging.getLogger().setLevel(logging.INFO) |
23 | 23 |
|
24 | 24 |
|
| 25 | +def _escape_sql_literal(val: str) -> str: |
| 26 | + r"""Escapes a string literal for use in nested BigQuery/Spanner queries. |
| 27 | +
|
| 28 | + This is required because the query string travels through two SQL parsers: |
| 29 | + 1. BigQuery parses the EXTERNAL_QUERY double-quoted string literal. |
| 30 | + 2. Spanner parses the resulting inner query's single-quoted string literal. |
| 31 | +
|
| 32 | + To ensure the value is correctly matched and prevent SQL injection: |
| 33 | + - Backslashes (\) are escaped to 4 backslashes (\\\\) so they survive |
| 34 | + both decodings (\\\\ -> \\ -> \). Otherwise, they may escape quotes |
| 35 | + or be interpreted as control characters (like \b becoming backspace). |
| 36 | + - Double quotes (") are escaped to \\" to prevent terminating BQ string. |
| 37 | + - Single quotes (') are escaped to '' to prevent terminating Spanner string. |
| 38 | + """ |
| 39 | + return val.replace('\\', '\\\\\\\\').replace('"', '\\"').replace("'", "''") |
| 40 | + |
| 41 | + |
25 | 42 | class BigQueryExecutor: |
26 | 43 | """Handles BigQuery client initialization and query execution.""" |
27 | 44 |
|
@@ -153,8 +170,11 @@ def run_linked_contained_in_place(self, |
153 | 170 | return |
154 | 171 |
|
155 | 172 | dest = self.executor.get_spanner_destination_uri() |
156 | | - provenances = [f"'dc/base/{name}'" for name in import_names] |
| 173 | + safe_names = [_escape_sql_literal(name) for name in import_names] |
| 174 | + prefix = "dc/base/" if self.is_base_dc else "" |
| 175 | + provenances = [f"'{prefix}{name}'" for name in safe_names] |
157 | 176 | provenance_filter = f" AND provenance IN ({', '.join(provenances)})" |
| 177 | + gen_graphs_prov = 'dc/base/GeneratedGraphs' if self.is_base_dc else 'GeneratedGraphs' |
158 | 178 |
|
159 | 179 | query = f""" |
160 | 180 | -- Pull base edges needed for containedInPlace aggregation |
@@ -201,7 +221,7 @@ def run_linked_contained_in_place(self, |
201 | 221 | subject_id, |
202 | 222 | 'linkedContainedInPlace' as predicate, |
203 | 223 | ancestor_place as object_id, |
204 | | - 'dc/base/GeneratedGraphs' as provenance |
| 224 | + '{gen_graphs_prov}' as provenance |
205 | 225 | FROM |
206 | 226 | Ancestors |
207 | 227 | ), |
@@ -238,8 +258,11 @@ def run_linked_member_of(self, import_names: List[str] = None) -> None: |
238 | 258 | return |
239 | 259 |
|
240 | 260 | dest = self.executor.get_spanner_destination_uri() |
241 | | - provenances = [f"'dc/base/{name}'" for name in import_names] |
| 261 | + safe_names = [_escape_sql_literal(name) for name in import_names] |
| 262 | + prefix = "dc/base/" if self.is_base_dc else "" |
| 263 | + provenances = [f"'{prefix}{name}'" for name in safe_names] |
242 | 264 | provenance_filter = f" AND provenance IN ({', '.join(provenances)})" |
| 265 | + gen_graphs_prov = 'dc/base/GeneratedGraphs' if self.is_base_dc else 'GeneratedGraphs' |
243 | 266 |
|
244 | 267 | query = f""" |
245 | 268 | -- Pull base edges needed for memberOf aggregation |
@@ -289,7 +312,7 @@ def run_linked_member_of(self, import_names: List[str] = None) -> None: |
289 | 312 | subject_id, |
290 | 313 | 'linkedMemberOf' as predicate, |
291 | 314 | ancestor as object_id, |
292 | | - 'dc/base/GeneratedGraphs' as provenance |
| 315 | + '{gen_graphs_prov}' as provenance |
293 | 316 | FROM |
294 | 317 | Ancestors |
295 | 318 | ), |
@@ -326,8 +349,11 @@ def run_linked_member(self, import_names: List[str] = None) -> None: |
326 | 349 | return |
327 | 350 |
|
328 | 351 | dest = self.executor.get_spanner_destination_uri() |
329 | | - provenances = [f"'dc/base/{name}'" for name in import_names] |
| 352 | + safe_names = [_escape_sql_literal(name) for name in import_names] |
| 353 | + prefix = "dc/base/" if self.is_base_dc else "" |
| 354 | + provenances = [f"'{prefix}{name}'" for name in safe_names] |
330 | 355 | provenance_filter = f" AND provenance IN ({', '.join(provenances)})" |
| 356 | + gen_graphs_prov = 'dc/base/GeneratedGraphs' if self.is_base_dc else 'GeneratedGraphs' |
331 | 357 |
|
332 | 358 | query = f""" |
333 | 359 | -- Pull base edges needed for member aggregation |
@@ -375,7 +401,7 @@ def run_linked_member(self, import_names: List[str] = None) -> None: |
375 | 401 | descendant as subject_id, |
376 | 402 | 'linkedMember' as predicate, |
377 | 403 | subject_id as object_id, |
378 | | - 'dc/base/GeneratedGraphs' as provenance |
| 404 | + '{gen_graphs_prov}' as provenance |
379 | 405 | FROM |
380 | 406 | Descendants |
381 | 407 | WHERE subject_id LIKE 'dc/topic%' |
@@ -439,8 +465,10 @@ def run_provenance_summary_aggregation(self, |
439 | 465 | dest = self.executor.get_spanner_destination_uri() |
440 | 466 | connection_id = self.executor.connection_id |
441 | 467 |
|
| 468 | + safe_names = [_escape_sql_literal(name) for name in import_names] |
442 | 469 | # Format import names for the SQL IN clause |
443 | | - imports_str = ", ".join([f"'{name}'" for name in import_names]) |
| 470 | + imports_str = ", ".join([f"'{name}'" for name in safe_names]) |
| 471 | + provenance_dcid_expr = "CONCAT('dc/base/', raw.import_name)" if self.is_base_dc else "raw.import_name" |
444 | 472 |
|
445 | 473 | query = f""" |
446 | 474 | -- Step 1: Fetch Observation rows for the specific import |
@@ -504,7 +532,7 @@ def run_provenance_summary_aggregation(self, |
504 | 532 | raw.is_dc_aggregate, |
505 | 533 | JSON_VALUE(v, '$.key') as date_val, |
506 | 534 | SAFE_CAST(JSON_VALUE(v, '$.value') AS FLOAT64) as value_num, |
507 | | - CONCAT('dc/base/', raw.import_name) as provenance_dcid, |
| 535 | + {provenance_dcid_expr} as provenance_dcid, |
508 | 536 | nodes.name as place_name, |
509 | 537 | edges.place_type |
510 | 538 | FROM `temp_obs_raw` raw |
|
0 commit comments