Skip to content

Commit e629990

Browse files
Merge pull request #6 from thisisqubika/DC-308-phase-1-metadata-extraction-validation
Dc 308 phase 1 metadata extraction validation
2 parents 2af06fd + 5b82f59 commit e629990

5 files changed

Lines changed: 319 additions & 0 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,6 @@ build-backend = "poetry.core.masonry.api"
2222
[tool.poetry.scripts]
2323
migration-accelerator = "migration_accelerator_package.main:main"
2424
snowpark-reader = "migration_accelerator_package.snowpark:main"
25+
snowflake-validator = "migration_accelerator_package.ingestion_validation:main"
26+
2527

resources/jobs.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@ resources:
1414
entry_point: snowpark-reader
1515
package_name: migration_accelerator_package
1616
environment_key: "serverless_wheel"
17+
18+
- task_key: snowflake_ingestion_validation_task
19+
depends_on:
20+
- task_key: snowflake_ingestion_task
21+
python_wheel_task:
22+
entry_point: snowflake-validator
23+
package_name: migration_accelerator_package
24+
environment_key: "serverless_wheel"
25+
1726

1827
environments:
1928
- environment_key: "serverless_wheel"
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
"""
2+
Artifact Validator Class
3+
Provides a clean interface for validating completeness and correctness of the ingested types of Snowflake artifacts.
4+
"""
5+
import json
6+
from abc import ABC, abstractmethod
7+
from typing import List, Dict, Any
8+
from snowflake.snowpark import Session
9+
from migration_accelerator_package.constants import ArtifactType, ArtifactFileName
10+
from databricks.sdk.runtime import *
11+
12+
def normalize_column(col: Dict[str, Any]) -> Dict[str, Any]:
13+
"""
14+
Normalize column metadata for comparison.
15+
Only keeps the essential fields needed for correctness validation.
16+
"""
17+
col = {k.lower(): v for k, v in col.items()}
18+
19+
return {
20+
"column_name": col.get("column_name"),
21+
"data_type": col.get("data_type"),
22+
"is_nullable": col.get("is_nullable"),
23+
}
24+
25+
26+
class MetadataValidator:
27+
"""
28+
Validates completeness and correctness of extracted Snowflake metadata.
29+
"""
30+
def __init__(self, session: Session, volume_path: str):
31+
self.session = session
32+
self.volume_path = volume_path
33+
34+
def _load_extracted(self, filename: str) -> Dict[str, Any]:
35+
path = f"{self.volume_path}/{filename}"
36+
raw = dbutils.fs.head(path, 50_000_000)
37+
return json.loads(raw)
38+
39+
40+
41+
def load_all_artifacts(self) -> Dict[str, Dict[str, Any]]:
42+
extracted = {}
43+
for artifact_type in ArtifactType:
44+
file_enum = ArtifactFileName[artifact_type.name]
45+
filename = file_enum.value
46+
extracted[artifact_type.value] = self._load_extracted(filename)
47+
return extracted
48+
49+
def count_snowflake_artifacts(self, artifact_type: ArtifactType, db: str, schema: str) -> int:
50+
if artifact_type == ArtifactType.TABLES:
51+
query = f"""
52+
SELECT COUNT(*) FROM {db}.information_schema.tables
53+
WHERE table_schema = '{schema}'
54+
AND table_type = 'BASE TABLE'
55+
"""
56+
elif artifact_type == ArtifactType.VIEWS:
57+
query = f"""
58+
SELECT COUNT(*)
59+
FROM {db}.information_schema.views
60+
WHERE table_schema = '{schema}'
61+
"""
62+
elif artifact_type == ArtifactType.PROCEDURES:
63+
query = f"""
64+
SELECT COUNT(*)
65+
FROM {db}.information_schema.procedures
66+
WHERE procedure_schema = '{schema}'
67+
"""
68+
elif artifact_type == ArtifactType.FUNCTIONS:
69+
query = f"""
70+
SELECT COUNT(*)
71+
FROM {db}.information_schema.functions
72+
WHERE function_schema = '{schema}'
73+
"""
74+
elif artifact_type == ArtifactType.SEQUENCES:
75+
query = f"""
76+
SELECT COUNT(*)
77+
FROM {db}.information_schema.sequences
78+
WHERE sequence_schema = '{schema}'
79+
"""
80+
elif artifact_type == ArtifactType.STAGES:
81+
query = f"SHOW STAGES IN SCHEMA {db}.{schema}"
82+
return len(self.session.sql(query).collect())
83+
elif artifact_type == ArtifactType.FILE_FORMATS:
84+
query = f"SHOW FILE FORMATS IN SCHEMA {db}.{schema}"
85+
return len(self.session.sql(query).collect())
86+
elif artifact_type == ArtifactType.TASKS:
87+
query = f"SHOW TASKS IN SCHEMA {db}.{schema}"
88+
return len(self.session.sql(query).collect())
89+
elif artifact_type == ArtifactType.STREAMS:
90+
query = f"SHOW STREAMS IN SCHEMA {db}.{schema}"
91+
return len(self.session.sql(query).collect())
92+
elif artifact_type == ArtifactType.PIPES:
93+
query = f"SHOW PIPES IN SCHEMA {db}.{schema}"
94+
return len(self.session.sql(query).collect())
95+
else:
96+
return 0
97+
98+
return self.session.sql(query).collect()[0][0]
99+
100+
def validate_completeness(self, extracted: Dict[str, Any], db: str, schema: str):
101+
completeness = {}
102+
103+
for artifact_type in ArtifactType:
104+
snowflake_count = self.count_snowflake_artifacts(artifact_type, db, schema)
105+
extracted_count = len(extracted[artifact_type.value][artifact_type.value])
106+
107+
coverage = (extracted_count / snowflake_count) if snowflake_count > 0 else 1.0
108+
109+
completeness[artifact_type.value] = {
110+
"snowflake": snowflake_count,
111+
"extracted": extracted_count,
112+
"coverage_pct": round(coverage * 100, 2),
113+
"perfect_match": extracted_count == snowflake_count
114+
}
115+
116+
return completeness
117+
118+
def validate_table_definition(self, db, schema, extracted_table: Dict[str, Any]) -> Dict[str, Any]:
119+
table_name = extracted_table["table_name"]
120+
121+
query = f"""
122+
SELECT column_name, data_type, is_nullable,
123+
character_maximum_length, numeric_precision, numeric_scale,
124+
column_default, comment
125+
FROM {db}.information_schema.columns
126+
WHERE table_schema = '{schema}'
127+
AND table_name = '{table_name}'
128+
ORDER BY ordinal_position
129+
"""
130+
131+
# Normalize Snowflake columns
132+
sf_columns_raw = [dict(row.as_dict()) for row in self.session.sql(query).collect()]
133+
sf_columns = [normalize_column(col) for col in sf_columns_raw]
134+
135+
# Normalize extracted columns
136+
extracted_columns_raw = extracted_table.get("columns", [])
137+
extracted_columns = [normalize_column(col) for col in extracted_columns_raw]
138+
139+
# Compute correctness statistics
140+
total = max(len(sf_columns), len(extracted_columns), 1)
141+
matches = sum(1 for sf, ex in zip(sf_columns, extracted_columns) if sf == ex)
142+
correctness_pct = round((matches / total) * 100, 2)
143+
144+
return {
145+
"table": table_name,
146+
"snowflake_column_count": len(sf_columns),
147+
"extracted_column_count": len(extracted_columns),
148+
"matches": matches,
149+
"total_columns": total,
150+
"correctness_pct": correctness_pct,
151+
"columns_match_exactly": sf_columns == extracted_columns,
152+
"snowflake": sf_columns,
153+
"extracted": extracted_columns
154+
}
155+
156+
157+
158+
def validate_view_definition(self, db, schema, extracted_view: Dict[str, Any]) -> Dict[str, Any]:
159+
view_name = extracted_view["view_name"]
160+
161+
query = f"""
162+
SELECT view_definition
163+
FROM {db}.information_schema.views
164+
WHERE table_schema = '{schema}'
165+
AND table_name = '{view_name}'
166+
"""
167+
result = self.session.sql(query).collect()
168+
sf_def = result[0]["VIEW_DEFINITION"] if result else ""
169+
170+
extracted_def = extracted_view.get("view_definition", "")
171+
172+
# Normalize whitespace for fair comparison
173+
def normalize(s):
174+
return " ".join(s.lower().strip().split())
175+
176+
sf_norm = normalize(sf_def)
177+
ex_norm = normalize(extracted_def)
178+
179+
match = sf_norm == ex_norm
180+
181+
return {
182+
"view": view_name,
183+
"match": match,
184+
"snowflake_definition": sf_def,
185+
"extracted_definition": extracted_def
186+
}
187+
188+
189+
190+
191+
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""
2+
Entry point for Snowflake metadata validation.
3+
Runs as a Databricks wheel task using 'snowflake-validator'.
4+
"""
5+
6+
import json
7+
from snowflake.snowpark import Session
8+
from databricks.sdk.runtime import dbutils
9+
10+
from migration_accelerator_package.snowpark_utils import (
11+
build_snowflake_connection_params,
12+
get_uc_volume_path,
13+
)
14+
15+
from migration_accelerator_package.artifact_validators import MetadataValidator
16+
from migration_accelerator_package.constants import SnowflakeConfig
17+
18+
19+
def main():
20+
print("=" * 80)
21+
print(" SNOWFLAKE METADATA VALIDATION ")
22+
print("=" * 80)
23+
24+
connection_parameters = build_snowflake_connection_params()
25+
session = Session.builder.configs(connection_parameters).create()
26+
27+
db = SnowflakeConfig.SNOWFLAKE_DATABASE.value
28+
schema = SnowflakeConfig.SNOWFLAKE_SCHEMA.value
29+
30+
volume_path = get_uc_volume_path()
31+
print(f"UC Volume Path: {volume_path}")
32+
33+
34+
validator = MetadataValidator(session, volume_path)
35+
36+
print("Loading extracted metadata...")
37+
extracted = validator.load_all_artifacts()
38+
print("✓ Loaded all JSON files")
39+
40+
print("Running completeness validation...")
41+
completeness_report = validator.validate_completeness(extracted, db, schema)
42+
print("✓ Completeness check done")
43+
44+
print("Running correctness checks...")
45+
46+
sample_tables = extracted["tables"]["tables"][:5]
47+
table_results = [
48+
validator.validate_table_definition(db, schema, t)
49+
for t in sample_tables
50+
]
51+
52+
sample_views = extracted["views"]["views"][:5]
53+
view_results = [
54+
validator.validate_view_definition(db, schema, v)
55+
for v in sample_views
56+
]
57+
58+
report = {
59+
"database": db,
60+
"schema": schema,
61+
"completeness": completeness_report,
62+
"correctness": {
63+
"tables": table_results,
64+
"views": view_results,
65+
}
66+
}
67+
68+
print("✓ Validation complete")
69+
print(json.dumps(report, indent=2))
70+
71+
output_path = f"{volume_path}/validation_report.json"
72+
dbutils.fs.put(output_path, json.dumps(report, indent=2), overwrite=True)
73+
print(f"✓ Validation report saved to {output_path}")
74+
75+
session.close()
76+
77+
78+
if __name__ == "__main__":
79+
main()
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""
2+
Utility functions shared across ingestion + validation entrypoints.
3+
"""
4+
5+
import os
6+
from databricks.sdk.runtime import dbutils
7+
from migration_accelerator_package.constants import SnowflakeConfig, UnityCatalogConfig
8+
9+
10+
def get_secret(secret_name: str):
11+
"""Retrieve secrets from Databricks secret scope or fallback to env variables."""
12+
try:
13+
return dbutils.secrets.get("migration-accelerator", secret_name)
14+
except Exception:
15+
return os.getenv(secret_name, "")
16+
17+
18+
def build_snowflake_connection_params():
19+
"""Return Snowflake connection parameters used by all wheel entrypoints."""
20+
return {
21+
"account": get_secret("SNOWFLAKE_ACCOUNT"),
22+
"user": get_secret("SNOWFLAKE_USER"),
23+
"password": get_secret("SNOWFLAKE_PASSWORD"),
24+
"role": SnowflakeConfig.SNOWFLAKE_ROLE.value,
25+
"warehouse": SnowflakeConfig.SNOWFLAKE_WAREHOUSE.value,
26+
"database": SnowflakeConfig.SNOWFLAKE_DATABASE.value,
27+
"schema": SnowflakeConfig.SNOWFLAKE_SCHEMA.value,
28+
}
29+
30+
31+
def get_uc_volume_path() -> str:
32+
"""Return the base UC volume path where JSON artifacts live."""
33+
return (
34+
f"/Volumes/"
35+
f"{UnityCatalogConfig.CATALOG.value}/"
36+
f"{UnityCatalogConfig.SCHEMA.value}/"
37+
f"{UnityCatalogConfig.RAW_VOLUME.value}"
38+
)

0 commit comments

Comments
 (0)