|
11 | 11 | Disposition, |
12 | 12 | StatementState, |
13 | 13 | ) |
| 14 | +from google.cloud import storage |
| 15 | +from .validation_extension import generate_extension_schema |
14 | 16 | from .config import databricks_vars, gcs_vars |
15 | 17 | from .utilities import databricksify_inst_name, SchemaType |
16 | | -from typing import List, Any, Dict |
| 18 | +from typing import List, Any, Dict, IO, cast, Optional |
17 | 19 | from databricks.sdk.errors import DatabricksError |
| 20 | +from fastapi import HTTPException |
| 21 | + |
| 22 | +try: |
| 23 | + import tomllib as _toml # Py 3.11+ |
| 24 | +except ModuleNotFoundError: |
| 25 | + import tomli as _toml # Py ≤ 3.10 |
| 26 | +import pandas as pd |
| 27 | +import re |
18 | 28 |
|
19 | 29 | # Setting up logger |
20 | 30 | LOGGER = logging.getLogger(__name__) |
@@ -366,3 +376,180 @@ def fetch_table_data( |
366 | 376 |
|
367 | 377 | # Combine column names with corresponding row values |
368 | 378 | return [dict(zip(column_names, row)) for row in data_rows] |
| 379 | + |
def get_key_for_file(
    self, mapping: Dict[str, Any], file_name: str
) -> Optional[str]:
    r"""Return the first mapping key whose pattern(s) match ``file_name``.

    Matching is case-insensitive and is done against the base name of
    ``file_name`` (directories stripped, Windows separators accepted,
    surrounding whitespace removed).

    Each mapping value may be:
      - a str literal (e.g. "student.csv"): matches exactly, or with an
        optional suffix introduced by ".", "_" or "-" before the extension
        (e.g. "student_2024.csv").
      - a str regex (e.g. "^course_.*\.csv$"): a string is treated as a
        regex when it starts with "^", ends with "$", or contains one of
        the metacharacters ( ) [ ] { } | ? + * \ ; it is applied with
        re.fullmatch and re.IGNORECASE. An invalid regex matches nothing.
      - a compiled regex (re.Pattern): fullmatch, adding IGNORECASE if missing.
      - a list or tuple of any of the above.

    Blank (empty / whitespace-only) string patterns and values of any
    other type never match.

    Args:
        mapping: key -> pattern (or list/tuple of patterns).
        file_name: file name or path to classify.

    Returns:
        The first key, in mapping iteration order, that has a matching
        pattern, or None when nothing matches.
    """
    # normalize filename (handles windows paths + stray whitespace)
    name = os.path.basename(file_name.replace("\\", "/")).strip()

    regex_meta = re.compile(r"[()\[\]\{\}\|\?\+\*\\]")

    def looks_like_regex(s: str) -> bool:
        s = s.strip()
        return (
            s.startswith("^") or s.endswith("$") or regex_meta.search(s) is not None
        )

    def matches_one(pat: Any) -> bool:
        # compiled regex
        if isinstance(pat, re.Pattern):
            # A bytes pattern cannot be applied to a str name (TypeError);
            # treat it as a non-match rather than crashing the lookup.
            if not isinstance(pat.pattern, str):
                return False
            # ensure case-insensitive, recompiling only when necessary
            if pat.flags & re.IGNORECASE:
                rx = pat
            else:
                rx = re.compile(pat.pattern, pat.flags | re.IGNORECASE)
            return rx.fullmatch(name) is not None

        # unsupported type
        if not isinstance(pat, str):
            return False

        p = pat.strip()

        # A blank pattern would otherwise degrade to
        # ^(?:[._-].+)?(?:\..+)?$ and match any name starting with
        # ".", "_" or "-"; it must match nothing.
        if not p:
            return False

        # exact literal (case-insensitive)
        if name.casefold() == p.casefold():
            return True

        if looks_like_regex(p):
            try:
                return re.fullmatch(p, name, flags=re.IGNORECASE) is not None
            except re.error:
                return False

        # literal with suffix tolerance
        p_base, p_ext = os.path.splitext(p)
        if p_ext:
            # ^base(?:[._-].+)?ext$
            rx = re.compile(
                rf"^{re.escape(p_base)}(?:[._-].+)?{re.escape(p_ext)}$",
                re.IGNORECASE,
            )
        else:
            # ^literal(?:[._-].+)?(?:\..+)?$
            rx = re.compile(
                rf"^{re.escape(p)}(?:[._-].+)?(?:\..+)?$",
                re.IGNORECASE,
            )
        return rx.fullmatch(name) is not None

    for key, val in mapping.items():
        candidates = val if isinstance(val, (list, tuple)) else [val]
        if any(matches_one(pat) for pat in candidates):
            return key

    return None
| 449 | + |
def create_custom_schema_extension(
    self,
    bucket_name: str,
    inst_query: Any,
    file_name: str,
    base_schema: Dict[str, Any],  # pass base schema dict in
    extension_schema: Optional[dict] = None,  # existing extension or None
) -> Any:
    """Generate an extension schema entry for one uploaded file.

    The institution's ``config.toml`` is downloaded from its Databricks
    bronze volume, ``[webapp].validation_mapping`` is used to resolve
    ``file_name`` to a model key, and, unless that model is already present
    in ``extension_schema`` for the institution, the unvalidated CSV is read
    from GCS and handed to ``generate_extension_schema``.

    Args:
        bucket_name: GCS bucket holding ``unvalidated/<file_name>``.
        inst_query: query result; ``inst_query[0][0]`` must expose ``.name``
            and ``.id`` for the institution.
        file_name: name of the uploaded file to build the extension for.
        base_schema: base schema dict, forwarded to
            ``generate_extension_schema``.
        extension_schema: existing extension dict, or None.

    Returns:
        The value returned by ``generate_extension_schema``, or None when
        generation is skipped (``SST_SKIP_EXT_GEN=1``, or the model already
        exists for this institution in ``extension_schema``).

    Raises:
        ValueError: the Databricks WorkspaceClient could not be created.
        HTTPException: 500 when config.toml or the CSV cannot be fetched;
            400 when config.toml is not valid TOML, ``validation_mapping``
            is not a table, or ``extension_schema`` is not a dict;
            404 when ``[webapp].validation_mapping`` is missing or no
            mapping entry matches ``file_name``.
    """
    if (
        os.getenv("SST_SKIP_EXT_GEN") == "1"
    ):  # skip using workspace client for tests
        LOGGER.info("SST_SKIP_EXT_GEN=1; skipping Databricks extension generation.")
        return None

    # 1) Databricks client
    try:
        w = WorkspaceClient(
            host=databricks_vars["DATABRICKS_HOST_URL"],
            google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
        )
        LOGGER.info("Successfully created Databricks WorkspaceClient.")
    except Exception as e:
        LOGGER.exception("WorkspaceClient init failed")
        raise ValueError(f"Workspace client initialization failed: {e}") from e

    # 2) Fetch config.toml from the institution's bronze volume
    try:
        inst_name = inst_query[0][0].name
        inst_id_raw = inst_query[0][0].id
        inst_id = str(inst_id_raw)  # be robust if id is not a string
        config_volume_path = (
            f"/Volumes/staging_sst_01/"
            f"{databricksify_inst_name(inst_name)}_bronze/bronze_volume/config.toml"
        )
        LOGGER.info("Attempting to download from %s", config_volume_path)
        response = w.files.download(config_volume_path)
        stream = cast(IO[bytes], response.contents)
        try:
            file_bytes = stream.read()
        finally:
            # Release the underlying response once the payload is in memory.
            stream.close()
        LOGGER.info("Download successful, received %d bytes", len(file_bytes))
    except Exception as e:
        LOGGER.exception("Failed to fetch config.toml")
        raise HTTPException(500, detail=f"Failed to fetch config: {e}") from e

    # 3) Parse the TOML and extract validation_mapping
    try:
        cfg = _toml.loads(file_bytes.decode("utf-8"))
    except Exception as e:
        LOGGER.exception("Invalid TOML")
        # The document that failed to parse is config.toml, not the
        # uploaded data file, so name it in the error detail.
        raise HTTPException(400, detail=f"Invalid TOML in config.toml: {e}") from e

    try:
        mapping = cfg["webapp"]["validation_mapping"]
    except (KeyError, TypeError):
        # TypeError: "webapp" exists but is not a table, so the mapping
        # is just as absent as when the key is missing.
        raise HTTPException(
            404, detail="Missing [webapp].validation_mapping in config.toml"
        ) from None

    if not isinstance(mapping, dict):
        raise HTTPException(
            400, detail="validation_mapping must be a TOML table (dictionary)"
        )

    key = self.get_key_for_file(mapping, file_name)  # e.g., "student"
    if key is None:
        raise HTTPException(
            404, detail=f"{file_name} not found in {inst_name} validation_mapping"
        )

    key_lc = key.lower()

    # 4) If this model already exists in the provided extension for this institution, skip
    if extension_schema is not None:
        if not isinstance(extension_schema, dict):
            raise HTTPException(
                400, detail="extension_schema must be a dict if provided"
            )

        # `or {}` also covers keys that are present but set to None.
        institutions = extension_schema.get("institutions") or {}
        inst_block = institutions.get(inst_id) or {}
        data_models = inst_block.get("data_models") or {}
        existing_keys_lc = {str(k).lower() for k in data_models}

        if key_lc in existing_keys_lc:
            LOGGER.info(
                "Model '%s' already present for institution '%s' — skipping (return None).",
                key,
                inst_id,
            )
            return None  # <-- sentinel: do not write

    # 5) Read the unvalidated CSV from GCS
    try:
        client = storage.Client()
        bucket = client.bucket(bucket_name)
        blob = bucket.blob(f"unvalidated/{file_name}")
        with blob.open("r") as fh:
            df = pd.read_csv(fh)
    except Exception as e:
        LOGGER.exception("Failed to read %s from GCS", file_name)
        raise HTTPException(
            500, detail=f"Failed to read {file_name} from GCS: {e}"
        ) from e

    # 6) Generate the extension for exactly this one model
    return generate_extension_schema(
        df=df,
        models=key,  # exactly one model
        institution_id=inst_id,
        base_schema=base_schema,
        existing_extension=extension_schema,  # may be None
    )
0 commit comments