|
| 1 | +""" |
| 2 | +Artifact Validator Class |
| 3 | +Provides a clean interface for validating completeness and correctness of the ingested types of Snowflake artifacts. |
| 4 | +""" |
| 5 | +import json |
| 6 | +from abc import ABC, abstractmethod |
| 7 | +from typing import List, Dict, Any |
| 8 | +from snowflake.snowpark import Session |
| 9 | +from migration_accelerator_package.constants import ArtifactType, ArtifactFileName |
| 10 | +from databricks.sdk.runtime import * |
| 11 | + |
| 12 | +def normalize_column(col: Dict[str, Any]) -> Dict[str, Any]: |
| 13 | + """ |
| 14 | + Normalize column metadata for comparison. |
| 15 | + Only keeps the essential fields needed for correctness validation. |
| 16 | + """ |
| 17 | + col = {k.lower(): v for k, v in col.items()} |
| 18 | + |
| 19 | + return { |
| 20 | + "column_name": col.get("column_name"), |
| 21 | + "data_type": col.get("data_type"), |
| 22 | + "is_nullable": col.get("is_nullable"), |
| 23 | + } |
| 24 | + |
| 25 | + |
| 26 | +class MetadataValidator: |
| 27 | + """ |
| 28 | + Validates completeness and correctness of extracted Snowflake metadata. |
| 29 | + """ |
| 30 | + def __init__(self, session: Session, volume_path: str): |
| 31 | + self.session = session |
| 32 | + self.volume_path = volume_path |
| 33 | + |
| 34 | + def _load_extracted(self, filename: str) -> Dict[str, Any]: |
| 35 | + path = f"{self.volume_path}/{filename}" |
| 36 | + raw = dbutils.fs.head(path, 50_000_000) |
| 37 | + return json.loads(raw) |
| 38 | + |
| 39 | + |
| 40 | + |
| 41 | + def load_all_artifacts(self) -> Dict[str, Dict[str, Any]]: |
| 42 | + extracted = {} |
| 43 | + for artifact_type in ArtifactType: |
| 44 | + file_enum = ArtifactFileName[artifact_type.name] |
| 45 | + filename = file_enum.value |
| 46 | + extracted[artifact_type.value] = self._load_extracted(filename) |
| 47 | + return extracted |
| 48 | + |
| 49 | + def count_snowflake_artifacts(self, artifact_type: ArtifactType, db: str, schema: str) -> int: |
| 50 | + if artifact_type == ArtifactType.TABLES: |
| 51 | + query = f""" |
| 52 | + SELECT COUNT(*) FROM {db}.information_schema.tables |
| 53 | + WHERE table_schema = '{schema}' |
| 54 | + AND table_type = 'BASE TABLE' |
| 55 | + """ |
| 56 | + elif artifact_type == ArtifactType.VIEWS: |
| 57 | + query = f""" |
| 58 | + SELECT COUNT(*) |
| 59 | + FROM {db}.information_schema.views |
| 60 | + WHERE table_schema = '{schema}' |
| 61 | + """ |
| 62 | + elif artifact_type == ArtifactType.PROCEDURES: |
| 63 | + query = f""" |
| 64 | + SELECT COUNT(*) |
| 65 | + FROM {db}.information_schema.procedures |
| 66 | + WHERE procedure_schema = '{schema}' |
| 67 | + """ |
| 68 | + elif artifact_type == ArtifactType.FUNCTIONS: |
| 69 | + query = f""" |
| 70 | + SELECT COUNT(*) |
| 71 | + FROM {db}.information_schema.functions |
| 72 | + WHERE function_schema = '{schema}' |
| 73 | + """ |
| 74 | + elif artifact_type == ArtifactType.SEQUENCES: |
| 75 | + query = f""" |
| 76 | + SELECT COUNT(*) |
| 77 | + FROM {db}.information_schema.sequences |
| 78 | + WHERE sequence_schema = '{schema}' |
| 79 | + """ |
| 80 | + elif artifact_type == ArtifactType.STAGES: |
| 81 | + query = f"SHOW STAGES IN SCHEMA {db}.{schema}" |
| 82 | + return len(self.session.sql(query).collect()) |
| 83 | + elif artifact_type == ArtifactType.FILE_FORMATS: |
| 84 | + query = f"SHOW FILE FORMATS IN SCHEMA {db}.{schema}" |
| 85 | + return len(self.session.sql(query).collect()) |
| 86 | + elif artifact_type == ArtifactType.TASKS: |
| 87 | + query = f"SHOW TASKS IN SCHEMA {db}.{schema}" |
| 88 | + return len(self.session.sql(query).collect()) |
| 89 | + elif artifact_type == ArtifactType.STREAMS: |
| 90 | + query = f"SHOW STREAMS IN SCHEMA {db}.{schema}" |
| 91 | + return len(self.session.sql(query).collect()) |
| 92 | + elif artifact_type == ArtifactType.PIPES: |
| 93 | + query = f"SHOW PIPES IN SCHEMA {db}.{schema}" |
| 94 | + return len(self.session.sql(query).collect()) |
| 95 | + else: |
| 96 | + return 0 |
| 97 | + |
| 98 | + return self.session.sql(query).collect()[0][0] |
| 99 | + |
| 100 | + def validate_completeness(self, extracted: Dict[str, Any], db: str, schema: str): |
| 101 | + completeness = {} |
| 102 | + |
| 103 | + for artifact_type in ArtifactType: |
| 104 | + snowflake_count = self.count_snowflake_artifacts(artifact_type, db, schema) |
| 105 | + extracted_count = len(extracted[artifact_type.value][artifact_type.value]) |
| 106 | + |
| 107 | + coverage = (extracted_count / snowflake_count) if snowflake_count > 0 else 1.0 |
| 108 | + |
| 109 | + completeness[artifact_type.value] = { |
| 110 | + "snowflake": snowflake_count, |
| 111 | + "extracted": extracted_count, |
| 112 | + "coverage_pct": round(coverage * 100, 2), |
| 113 | + "perfect_match": extracted_count == snowflake_count |
| 114 | + } |
| 115 | + |
| 116 | + return completeness |
| 117 | + |
| 118 | + def validate_table_definition(self, db, schema, extracted_table: Dict[str, Any]) -> Dict[str, Any]: |
| 119 | + table_name = extracted_table["table_name"] |
| 120 | + |
| 121 | + query = f""" |
| 122 | + SELECT column_name, data_type, is_nullable, |
| 123 | + character_maximum_length, numeric_precision, numeric_scale, |
| 124 | + column_default, comment |
| 125 | + FROM {db}.information_schema.columns |
| 126 | + WHERE table_schema = '{schema}' |
| 127 | + AND table_name = '{table_name}' |
| 128 | + ORDER BY ordinal_position |
| 129 | + """ |
| 130 | + |
| 131 | + # Normalize Snowflake columns |
| 132 | + sf_columns_raw = [dict(row.as_dict()) for row in self.session.sql(query).collect()] |
| 133 | + sf_columns = [normalize_column(col) for col in sf_columns_raw] |
| 134 | + |
| 135 | + # Normalize extracted columns |
| 136 | + extracted_columns_raw = extracted_table.get("columns", []) |
| 137 | + extracted_columns = [normalize_column(col) for col in extracted_columns_raw] |
| 138 | + |
| 139 | + # Compute correctness statistics |
| 140 | + total = max(len(sf_columns), len(extracted_columns), 1) |
| 141 | + matches = sum(1 for sf, ex in zip(sf_columns, extracted_columns) if sf == ex) |
| 142 | + correctness_pct = round((matches / total) * 100, 2) |
| 143 | + |
| 144 | + return { |
| 145 | + "table": table_name, |
| 146 | + "snowflake_column_count": len(sf_columns), |
| 147 | + "extracted_column_count": len(extracted_columns), |
| 148 | + "matches": matches, |
| 149 | + "total_columns": total, |
| 150 | + "correctness_pct": correctness_pct, |
| 151 | + "columns_match_exactly": sf_columns == extracted_columns, |
| 152 | + "snowflake": sf_columns, |
| 153 | + "extracted": extracted_columns |
| 154 | + } |
| 155 | + |
| 156 | + |
| 157 | + |
| 158 | + def validate_view_definition(self, db, schema, extracted_view: Dict[str, Any]) -> Dict[str, Any]: |
| 159 | + view_name = extracted_view["view_name"] |
| 160 | + |
| 161 | + query = f""" |
| 162 | + SELECT view_definition |
| 163 | + FROM {db}.information_schema.views |
| 164 | + WHERE table_schema = '{schema}' |
| 165 | + AND table_name = '{view_name}' |
| 166 | + """ |
| 167 | + result = self.session.sql(query).collect() |
| 168 | + sf_def = result[0]["VIEW_DEFINITION"] if result else "" |
| 169 | + |
| 170 | + extracted_def = extracted_view.get("view_definition", "") |
| 171 | + |
| 172 | + # Normalize whitespace for fair comparison |
| 173 | + def normalize(s): |
| 174 | + return " ".join(s.lower().strip().split()) |
| 175 | + |
| 176 | + sf_norm = normalize(sf_def) |
| 177 | + ex_norm = normalize(extracted_def) |
| 178 | + |
| 179 | + match = sf_norm == ex_norm |
| 180 | + |
| 181 | + return { |
| 182 | + "view": view_name, |
| 183 | + "match": match, |
| 184 | + "snowflake_definition": sf_def, |
| 185 | + "extracted_definition": extracted_def |
| 186 | + } |
| 187 | + |
| 188 | + |
| 189 | + |
| 190 | + |
| 191 | + |
0 commit comments