diff --git a/.gitignore b/.gitignore index 1e37f2c..fc490f0 100644 --- a/.gitignore +++ b/.gitignore @@ -143,6 +143,7 @@ celerybeat.pid # Environments .env .envrc +env.example_2 .venv env/ venv/ diff --git a/pyproject.toml b/pyproject.toml index b91332f..40f5b9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,4 +22,6 @@ build-backend = "poetry.core.masonry.api" [tool.poetry.scripts] migration-accelerator = "migration_accelerator_package.main:main" snowpark-reader = "migration_accelerator_package.snowpark:main" +snowflake-validator = "migration_accelerator_package.ingestion_validation:main" + diff --git a/resources/jobs.yml b/resources/jobs.yml index f344585..5ecf484 100644 --- a/resources/jobs.yml +++ b/resources/jobs.yml @@ -14,6 +14,15 @@ resources: entry_point: snowpark-reader package_name: migration_accelerator_package environment_key: "serverless_wheel" + + - task_key: snowflake_ingestion_validation_task + depends_on: + - task_key: snowflake_ingestion_task + python_wheel_task: + entry_point: snowflake-validator + package_name: migration_accelerator_package + environment_key: "serverless_wheel" + environments: - environment_key: "serverless_wheel" diff --git a/src/migration_accelerator_package/artifact_readers.py b/src/migration_accelerator_package/artifact_readers.py new file mode 100644 index 0000000..8d4b346 --- /dev/null +++ b/src/migration_accelerator_package/artifact_readers.py @@ -0,0 +1,251 @@ +""" +Artifact Reader Facade Classes +Provides a clean interface for reading different types of Snowflake artifacts. +""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any +from snowflake.snowpark import Session +from migration_accelerator_package.constants import ArtifactType + + +class ArtifactReader(ABC): + """Abstract base class for artifact readers.""" + + def __init__(self, session: Session, database: str, schema: str): + """Initialize the artifact reader.""" + self.session = session + self.database = database + self.schema = schema + + @abstractmethod + def read(self) -> List[Dict[str, Any]]: + """Read artifacts of this type.""" + pass + + def _normalize_keys(self, row_dict: Dict[str, Any]) -> Dict[str, Any]: + """Normalize dictionary keys to lowercase.""" + return {k.lower(): v for k, v in row_dict.items()} + + def _normalize_rows(self, rows: List) -> List[Dict[str, Any]]: + """Normalize a list of rows to dictionaries with lowercase keys.""" + return [self._normalize_keys(dict(row.as_dict())) for row in rows] + + +class TablesReader(ArtifactReader): + """Reader for Snowflake tables.""" + + def read(self) -> List[Dict[str, Any]]: + """Get all tables in the schema.""" + query = f""" + SELECT + table_catalog as database_name, + table_schema as schema_name, + table_name, + table_type, + row_count, + bytes, + created, + last_altered, + comment + FROM information_schema.tables + WHERE table_schema = '{self.schema}' + AND table_type = 'BASE TABLE' + ORDER BY table_name + """ + result = self.session.sql(query).collect() + return [self._normalize_keys(dict(row.as_dict())) for row in result] + + def read_columns(self, table_name: str) -> List[Dict[str, Any]]: + """Get columns for a specific table.""" + query = f""" + SELECT + column_name, + data_type, + character_maximum_length, + numeric_precision, + numeric_scale, + is_nullable, + column_default, + comment + FROM information_schema.columns + WHERE table_schema = '{self.schema}' + AND table_name = '{table_name}' + ORDER BY ordinal_position + """ + result = self.session.sql(query).collect() + return self._normalize_rows(result) + + def read_table_data(self, table_name: str, limit: int = None) -> List[Dict[str, Any]]: + """Get data from a specific table.""" + query = f"SELECT * FROM {self.database}.{self.schema}.{table_name}" + if limit: + query += f" LIMIT {limit}" + result = self.session.sql(query).collect() + return [dict(row.as_dict()) for row in result] + + +class ViewsReader(ArtifactReader): + """Reader for Snowflake views.""" + + def read(self) -> List[Dict[str, Any]]: + """Get all views in the schema.""" + query = f""" + SELECT + table_catalog as database_name, + table_schema as schema_name, + table_name as view_name, + view_definition, + created, + comment + FROM information_schema.views + WHERE table_schema = '{self.schema}' + ORDER BY view_name + """ + result = self.session.sql(query).collect() + return self._normalize_rows(result) + + +class ProceduresReader(ArtifactReader): + """Reader for Snowflake stored procedures.""" + + def read(self) -> List[Dict[str, Any]]: + """Get all stored procedures in the schema.""" + query = f""" + SELECT + procedure_catalog as database_name, + procedure_schema as schema_name, + procedure_name, + procedure_definition, + created, + last_altered, + comment + FROM information_schema.procedures + WHERE procedure_schema = '{self.schema}' + ORDER BY procedure_name + """ + result = self.session.sql(query).collect() + return self._normalize_rows(result) + + +class FunctionsReader(ArtifactReader): + """Reader for Snowflake user-defined functions.""" + + def read(self) -> List[Dict[str, Any]]: + """Get all user-defined functions in the schema.""" + query = f""" + SELECT + function_catalog as database_name, + function_schema as schema_name, + function_name, + function_definition, + created, + last_altered, + comment + FROM information_schema.functions + WHERE function_schema = '{self.schema}' + ORDER BY function_name + """ + result = self.session.sql(query).collect() + return self._normalize_rows(result) + + +class SequencesReader(ArtifactReader): + """Reader for Snowflake sequences.""" + + def read(self) -> List[Dict[str, Any]]: + """Get all sequences in the schema.""" + query = f"SHOW SEQUENCES IN SCHEMA {self.database}.{self.schema}" + try: + result = self.session.sql(query).collect() + return self._normalize_rows(result) + except Exception as e: + # Fallback: try information_schema with basic columns only + print(f" ⚠ Warning: SHOW SEQUENCES failed, trying information_schema: {e}") + query = f""" + SELECT + sequence_catalog as database_name, + sequence_schema as schema_name, + sequence_name + FROM information_schema.sequences + WHERE sequence_schema = '{self.schema}' + ORDER BY sequence_name + """ + result = self.session.sql(query).collect() + return self._normalize_rows(result) + + +class StagesReader(ArtifactReader): + """Reader for Snowflake stages.""" + + def read(self) -> List[Dict[str, Any]]: + """Get all stages in the schema.""" + query = f"SHOW STAGES IN SCHEMA {self.database}.{self.schema}" + result = self.session.sql(query).collect() + return self._normalize_rows(result) + + +class FileFormatsReader(ArtifactReader): + """Reader for Snowflake file formats.""" + + def read(self) -> List[Dict[str, Any]]: + """Get all file formats in the schema.""" + query = f"SHOW FILE FORMATS IN SCHEMA {self.database}.{self.schema}" + result = self.session.sql(query).collect() + return self._normalize_rows(result) + + +class TasksReader(ArtifactReader): + """Reader for Snowflake tasks.""" + + def read(self) -> List[Dict[str, Any]]: + """Get all tasks in the schema.""" + query = f"SHOW TASKS IN SCHEMA {self.database}.{self.schema}" + result = self.session.sql(query).collect() + return self._normalize_rows(result) + + +class StreamsReader(ArtifactReader): + """Reader for Snowflake streams.""" + + def read(self) -> List[Dict[str, Any]]: + """Get all streams in the schema.""" + query = f"SHOW STREAMS IN SCHEMA {self.database}.{self.schema}" + result = self.session.sql(query).collect() + return self._normalize_rows(result) + + +class PipesReader(ArtifactReader): + """Reader for Snowflake pipes.""" + + def read(self) -> List[Dict[str, Any]]: + """Get all pipes in the schema.""" + query = f"SHOW PIPES IN SCHEMA {self.database}.{self.schema}" + result = self.session.sql(query).collect() + return self._normalize_rows(result) + + +class ArtifactReaderFactory: + """Factory for creating artifact readers.""" + + _readers = { + ArtifactType.TABLES: TablesReader, + ArtifactType.VIEWS: ViewsReader, + ArtifactType.PROCEDURES: ProceduresReader, + ArtifactType.FUNCTIONS: FunctionsReader, + ArtifactType.SEQUENCES: SequencesReader, + ArtifactType.STAGES: StagesReader, + ArtifactType.FILE_FORMATS: FileFormatsReader, + ArtifactType.TASKS: TasksReader, + ArtifactType.STREAMS: StreamsReader, + ArtifactType.PIPES: PipesReader, + } + + @classmethod + def create_reader(cls, artifact_type: ArtifactType, session: Session, database: str, schema: str) -> ArtifactReader: + """Create an artifact reader for the given type.""" + reader_class = cls._readers.get(artifact_type) + if not reader_class: + raise ValueError(f"No reader available for artifact type: {artifact_type}") + return reader_class(session, database, schema) + diff --git a/src/migration_accelerator_package/artifact_validators.py b/src/migration_accelerator_package/artifact_validators.py new file mode 100644 index 0000000..37a8219 --- /dev/null +++ b/src/migration_accelerator_package/artifact_validators.py @@ -0,0 +1,191 @@ +""" +Artifact Validator Class +Provides a clean interface for validating completeness and correctness of the ingested types of Snowflake artifacts. +""" +import json +from abc import ABC, abstractmethod +from typing import List, Dict, Any +from snowflake.snowpark import Session +from migration_accelerator_package.constants import ArtifactType, ArtifactFileName +from databricks.sdk.runtime import * + +def normalize_column(col: Dict[str, Any]) -> Dict[str, Any]: + """ + Normalize column metadata for comparison. + Only keeps the essential fields needed for correctness validation. + """ + col = {k.lower(): v for k, v in col.items()} + + return { + "column_name": col.get("column_name"), + "data_type": col.get("data_type"), + "is_nullable": col.get("is_nullable"), + } + + +class MetadataValidator: + """ + Validates completeness and correctness of extracted Snowflake metadata. + """ + def __init__(self, session: Session, volume_path: str): + self.session = session + self.volume_path = volume_path + + def _load_extracted(self, filename: str) -> Dict[str, Any]: + path = f"{self.volume_path}/{filename}" + raw = dbutils.fs.head(path, 50_000_000) + return json.loads(raw) + + + + def load_all_artifacts(self) -> Dict[str, Dict[str, Any]]: + extracted = {} + for artifact_type in ArtifactType: + file_enum = ArtifactFileName[artifact_type.name] + filename = file_enum.value + extracted[artifact_type.value] = self._load_extracted(filename) + return extracted + + def count_snowflake_artifacts(self, artifact_type: ArtifactType, db: str, schema: str) -> int: + if artifact_type == ArtifactType.TABLES: + query = f""" + SELECT COUNT(*) FROM {db}.information_schema.tables + WHERE table_schema = '{schema}' + AND table_type = 'BASE TABLE' + """ + elif artifact_type == ArtifactType.VIEWS: + query = f""" + SELECT COUNT(*) + FROM {db}.information_schema.views + WHERE table_schema = '{schema}' + """ + elif artifact_type == ArtifactType.PROCEDURES: + query = f""" + SELECT COUNT(*) + FROM {db}.information_schema.procedures + WHERE procedure_schema = '{schema}' + """ + elif artifact_type == ArtifactType.FUNCTIONS: + query = f""" + SELECT COUNT(*) + FROM {db}.information_schema.functions + WHERE function_schema = '{schema}' + """ + elif artifact_type == ArtifactType.SEQUENCES: + query = f""" + SELECT COUNT(*) + FROM {db}.information_schema.sequences + WHERE sequence_schema = '{schema}' + """ + elif artifact_type == ArtifactType.STAGES: + query = f"SHOW STAGES IN SCHEMA {db}.{schema}" + return len(self.session.sql(query).collect()) + elif artifact_type == ArtifactType.FILE_FORMATS: + query = f"SHOW FILE FORMATS IN SCHEMA {db}.{schema}" + return len(self.session.sql(query).collect()) + elif artifact_type == ArtifactType.TASKS: + query = f"SHOW TASKS IN SCHEMA {db}.{schema}" + return len(self.session.sql(query).collect()) + elif artifact_type == ArtifactType.STREAMS: + query = f"SHOW STREAMS IN SCHEMA {db}.{schema}" + return len(self.session.sql(query).collect()) + elif artifact_type == ArtifactType.PIPES: + query = f"SHOW PIPES IN SCHEMA {db}.{schema}" + return len(self.session.sql(query).collect()) + else: + return 0 + + return self.session.sql(query).collect()[0][0] + + def validate_completeness(self, extracted: Dict[str, Any], db: str, schema: str): + completeness = {} + + for artifact_type in ArtifactType: + snowflake_count = self.count_snowflake_artifacts(artifact_type, db, schema) + extracted_count = len(extracted[artifact_type.value][artifact_type.value]) + + coverage = (extracted_count / snowflake_count) if snowflake_count > 0 else 1.0 + + completeness[artifact_type.value] = { + "snowflake": snowflake_count, + "extracted": extracted_count, + "coverage_pct": round(coverage * 100, 2), + "perfect_match": extracted_count == snowflake_count + } + + return completeness + + def validate_table_definition(self, db, schema, extracted_table: Dict[str, Any]) -> Dict[str, Any]: + table_name = extracted_table["table_name"] + + query = f""" + SELECT column_name, data_type, is_nullable, + character_maximum_length, numeric_precision, numeric_scale, + column_default, comment + FROM {db}.information_schema.columns + WHERE table_schema = '{schema}' + AND table_name = '{table_name}' + ORDER BY ordinal_position + """ + + # Normalize Snowflake columns + sf_columns_raw = [dict(row.as_dict()) for row in self.session.sql(query).collect()] + sf_columns = [normalize_column(col) for col in sf_columns_raw] + + # Normalize extracted columns + extracted_columns_raw = extracted_table.get("columns", []) + extracted_columns = [normalize_column(col) for col in extracted_columns_raw] + + # Compute correctness statistics + total = max(len(sf_columns), len(extracted_columns), 1) + matches = sum(1 for sf, ex in zip(sf_columns, extracted_columns) if sf == ex) + correctness_pct = round((matches / total) * 100, 2) + + return { + "table": table_name, + "snowflake_column_count": len(sf_columns), + "extracted_column_count": len(extracted_columns), + "matches": matches, + "total_columns": total, + "correctness_pct": correctness_pct, + "columns_match_exactly": sf_columns == extracted_columns, + "snowflake": sf_columns, + "extracted": extracted_columns + } + + + + def validate_view_definition(self, db, schema, extracted_view: Dict[str, Any]) -> Dict[str, Any]: + view_name = extracted_view["view_name"] + + query = f""" + SELECT view_definition + FROM {db}.information_schema.views + WHERE table_schema = '{schema}' + AND table_name = '{view_name}' + """ + result = self.session.sql(query).collect() + sf_def = result[0]["VIEW_DEFINITION"] if result else "" + + extracted_def = extracted_view.get("view_definition", "") + + # Normalize whitespace for fair comparison + def normalize(s): + return " ".join(s.lower().strip().split()) + + sf_norm = normalize(sf_def) + ex_norm = normalize(extracted_def) + + match = sf_norm == ex_norm + + return { + "view": view_name, + "match": match, + "snowflake_definition": sf_def, + "extracted_definition": extracted_def + } + + + + + diff --git a/src/migration_accelerator_package/constants.py b/src/migration_accelerator_package/constants.py index 85a0bb9..e762eb7 100644 --- a/src/migration_accelerator_package/constants.py +++ b/src/migration_accelerator_package/constants.py @@ -9,4 +9,30 @@ class SnowflakeConfig(Enum): class UnityCatalogConfig(Enum): CATALOG = "qubika_partner_solutions" SCHEMA = "migration_accelerator" - RAW_VOLUME = "snowflake_artifacts_raw" \ No newline at end of file + RAW_VOLUME = "snowflake_artifacts_raw" + +class ArtifactType(Enum): + """Enumeration of Snowflake artifact types.""" + TABLES = "tables" + VIEWS = "views" + PROCEDURES = "procedures" + FUNCTIONS = "functions" + SEQUENCES = "sequences" + STAGES = "stages" + FILE_FORMATS = "file_formats" + TASKS = "tasks" + STREAMS = "streams" + PIPES = "pipes" + +class ArtifactFileName(Enum): + """Enumeration of output file names for each artifact type.""" + TABLES = "tables.json" + VIEWS = "views.json" + PROCEDURES = "procedures.json" + FUNCTIONS = "functions.json" + SEQUENCES = "sequences.json" + STAGES = "stages.json" + FILE_FORMATS = "file_formats.json" + TASKS = "tasks.json" + STREAMS = "streams.json" + PIPES = "pipelines.json" # pipes saved as pipelines.json \ No newline at end of file diff --git a/src/migration_accelerator_package/ingestion_validation.py b/src/migration_accelerator_package/ingestion_validation.py new file mode 100644 index 0000000..b5dd6f9 --- /dev/null +++ b/src/migration_accelerator_package/ingestion_validation.py @@ -0,0 +1,79 @@ +""" +Entry point for Snowflake metadata validation. +Runs as a Databricks wheel task using 'snowflake-validator'. +""" + +import json +from snowflake.snowpark import Session +from databricks.sdk.runtime import dbutils + +from migration_accelerator_package.snowpark_utils import ( + build_snowflake_connection_params, + get_uc_volume_path, +) + +from migration_accelerator_package.artifact_validators import MetadataValidator +from migration_accelerator_package.constants import SnowflakeConfig + + +def main(): + print("=" * 80) + print(" SNOWFLAKE METADATA VALIDATION ") + print("=" * 80) + + connection_parameters = build_snowflake_connection_params() + session = Session.builder.configs(connection_parameters).create() + + db = SnowflakeConfig.SNOWFLAKE_DATABASE.value + schema = SnowflakeConfig.SNOWFLAKE_SCHEMA.value + + volume_path = get_uc_volume_path() + print(f"UC Volume Path: {volume_path}") + + + validator = MetadataValidator(session, volume_path) + + print("Loading extracted metadata...") + extracted = validator.load_all_artifacts() + print("✓ Loaded all JSON files") + + print("Running completeness validation...") + completeness_report = validator.validate_completeness(extracted, db, schema) + print("✓ Completeness check done") + + print("Running correctness checks...") + + sample_tables = extracted["tables"]["tables"][:5] + table_results = [ + validator.validate_table_definition(db, schema, t) + for t in sample_tables + ] + + sample_views = extracted["views"]["views"][:5] + view_results = [ + validator.validate_view_definition(db, schema, v) + for v in sample_views + ] + + report = { + "database": db, + "schema": schema, + "completeness": completeness_report, + "correctness": { + "tables": table_results, + "views": view_results, + } + } + + print("✓ Validation complete") + print(json.dumps(report, indent=2)) + + output_path = f"{volume_path}/validation_report.json" + dbutils.fs.put(output_path, json.dumps(report, indent=2), overwrite=True) + print(f"✓ Validation report saved to {output_path}") + + session.close() + + +if __name__ == "__main__": + main() diff --git a/src/migration_accelerator_package/snowpark.py b/src/migration_accelerator_package/snowpark.py index 56ed4cc..5c11c0b 100644 --- a/src/migration_accelerator_package/snowpark.py +++ b/src/migration_accelerator_package/snowpark.py @@ -8,9 +8,13 @@ import json from typing import Dict, List, Any, Optional from snowflake.snowpark import Session -from snowflake.snowpark.functions import col from databricks.sdk.runtime import dbutils from migration_accelerator_package import constants +from migration_accelerator_package.artifact_readers import ( + ArtifactReaderFactory, + TablesReader +) +from migration_accelerator_package.constants import ArtifactType, ArtifactFileName def get_secret(secret_name): """Retrieve secrets from Databricks secret scope""" @@ -64,248 +68,163 @@ def __init__(self, session: Session): self.session = session self.database = SFLKdatabase self.schema = SFLKschema + self._readers = {} + self._initialize_readers() + + def _initialize_readers(self): + """Initialize artifact readers using the factory pattern.""" + for artifact_type in ArtifactType: + self._readers[artifact_type] = ArtifactReaderFactory.create_reader( + artifact_type, self.session, self.database, self.schema + ) def get_tables(self) -> List[Dict[str, Any]]: """Get all tables in the schema.""" - query = f""" - SELECT - table_catalog as database_name, - table_schema as schema_name, - table_name, - table_type, - row_count, - bytes, - created, - last_altered, - comment - FROM information_schema.tables - WHERE table_schema = '{self.schema}' - AND table_type = 'BASE TABLE' - ORDER BY table_name - """ - result = self.session.sql(query).collect() - # Convert to dict and normalize keys (handle case sensitivity) - tables = [] - for row in result: - row_dict = dict(row.as_dict()) - # Normalize keys to lowercase for consistency - normalized = {k.lower(): v for k, v in row_dict.items()} - tables.append(normalized) - return tables + return self._readers[ArtifactType.TABLES].read() def get_table_columns(self, table_name: str) -> List[Dict[str, Any]]: """Get columns for a specific table.""" - query = f""" - SELECT - column_name, - data_type, - character_maximum_length, - numeric_precision, - numeric_scale, - is_nullable, - column_default, - comment - FROM information_schema.columns - WHERE table_schema = '{self.schema}' - AND table_name = '{table_name}' - ORDER BY ordinal_position - """ - result = self.session.sql(query).collect() - # Normalize keys to lowercase - return [{k.lower(): v for k, v in dict(row.as_dict()).items()} for row in result] + tables_reader = self._readers[ArtifactType.TABLES] + if isinstance(tables_reader, TablesReader): + return tables_reader.read_columns(table_name) + return [] def get_views(self) -> List[Dict[str, Any]]: """Get all views in the schema.""" - query = f""" - SELECT - table_catalog as database_name, - table_schema as schema_name, - table_name as view_name, - view_definition, - created, - comment - FROM information_schema.views - WHERE table_schema = '{self.schema}' - ORDER BY view_name - """ - result = self.session.sql(query).collect() - # Normalize keys to lowercase - return [{k.lower(): v for k, v in dict(row.as_dict()).items()} for row in result] + return self._readers[ArtifactType.VIEWS].read() def get_procedures(self) -> List[Dict[str, Any]]: """Get all stored procedures in the schema.""" - query = f""" - SELECT - procedure_catalog as database_name, - procedure_schema as schema_name, - procedure_name, - procedure_definition, - created, - last_altered, - comment - FROM information_schema.procedures - WHERE procedure_schema = '{self.schema}' - ORDER BY procedure_name - """ - result = self.session.sql(query).collect() - # Normalize keys to lowercase - return [{k.lower(): v for k, v in dict(row.as_dict()).items()} for row in result] + return self._readers[ArtifactType.PROCEDURES].read() def get_functions(self) -> List[Dict[str, Any]]: """Get all user-defined functions in the schema.""" - query = f""" - SELECT - function_catalog as database_name, - function_schema as schema_name, - function_name, - function_definition, - created, - last_altered, - comment - FROM information_schema.functions - WHERE function_schema = '{self.schema}' - ORDER BY function_name - """ - result = self.session.sql(query).collect() - # Normalize keys to lowercase - return [{k.lower(): v for k, v in dict(row.as_dict()).items()} for row in result] + return self._readers[ArtifactType.FUNCTIONS].read() def get_sequences(self) -> List[Dict[str, Any]]: """Get all sequences in the schema.""" - # Use SHOW command for sequences as information_schema may not have all columns - query = f"SHOW SEQUENCES IN SCHEMA {self.database}.{self.schema}" - try: - result = self.session.sql(query).collect() - # Normalize keys to lowercase - return [{k.lower(): v for k, v in dict(row.as_dict()).items()} for row in result] - except Exception as e: - # Fallback: try information_schema with basic columns only - print(f" ⚠ Warning: SHOW SEQUENCES failed, trying information_schema: {e}") - query = f""" - SELECT - sequence_catalog as database_name, - sequence_schema as schema_name, - sequence_name - FROM information_schema.sequences - WHERE sequence_schema = '{self.schema}' - ORDER BY sequence_name - """ - result = self.session.sql(query).collect() - # Normalize keys to lowercase - return [{k.lower(): v for k, v in dict(row.as_dict()).items()} for row in result] + return self._readers[ArtifactType.SEQUENCES].read() def get_stages(self) -> List[Dict[str, Any]]: """Get all stages in the schema.""" - # Use SHOW command for stages - query = f"SHOW STAGES IN SCHEMA {self.database}.{self.schema}" - result = self.session.sql(query).collect() - # Normalize keys to lowercase - return [{k.lower(): v for k, v in dict(row.as_dict()).items()} for row in result] + return self._readers[ArtifactType.STAGES].read() def get_file_formats(self) -> List[Dict[str, Any]]: """Get all file formats in the schema.""" - query = f"SHOW FILE FORMATS IN SCHEMA {self.database}.{self.schema}" - result = self.session.sql(query).collect() - # Normalize keys to lowercase - return [{k.lower(): v for k, v in dict(row.as_dict()).items()} for row in result] + return self._readers[ArtifactType.FILE_FORMATS].read() def get_tasks(self) -> List[Dict[str, Any]]: """Get all tasks in the schema.""" - query = f"SHOW TASKS IN SCHEMA {self.database}.{self.schema}" - result = self.session.sql(query).collect() - # Normalize keys to lowercase - return [{k.lower(): v for k, v in dict(row.as_dict()).items()} for row in result] + return self._readers[ArtifactType.TASKS].read() def get_streams(self) -> List[Dict[str, Any]]: """Get all streams in the schema.""" - query = f"SHOW STREAMS IN SCHEMA {self.database}.{self.schema}" - result = self.session.sql(query).collect() - # Normalize keys to lowercase - return [{k.lower(): v for k, v in dict(row.as_dict()).items()} for row in result] + return self._readers[ArtifactType.STREAMS].read() def get_pipes(self) -> List[Dict[str, Any]]: """Get all pipes in the schema.""" - query = f"SHOW PIPES IN SCHEMA {self.database}.{self.schema}" - result = self.session.sql(query).collect() - # Normalize keys to lowercase - return [{k.lower(): v for k, v in dict(row.as_dict()).items()} for row in result] + return self._readers[ArtifactType.PIPES].read() def get_table_data(self, table_name: str, limit: Optional[int] = None) -> List[Dict[str, Any]]: """Get data from a specific table.""" - query = f"SELECT * FROM {self.database}.{self.schema}.{table_name}" - if limit: - query += f" LIMIT {limit}" - result = self.session.sql(query).collect() - return [dict(row.as_dict()) for row in result] + tables_reader = self._readers[ArtifactType.TABLES] + if isinstance(tables_reader, TablesReader): + return tables_reader.read_table_data(table_name, limit) + return [] def get_all_objects(self) -> Dict[str, Any]: - """Get all database objects in one call.""" + """Get all database objects in one call using artifact readers.""" print("\n📊 Reading all Snowflake objects using Snowpark...") objects = { 'database': self.database, 'schema': self.schema, - 'tables': self.get_tables(), - 'views': self.get_views(), - 'procedures': self.get_procedures(), - 'functions': self.get_functions(), - 'sequences': self.get_sequences(), - 'stages': self.get_stages(), - 'file_formats': self.get_file_formats(), - 'tasks': self.get_tasks(), - 'streams': self.get_streams(), - 'pipes': self.get_pipes(), } - # Add column details for each table - for table in objects['tables']: + # Read all artifacts using facade pattern + for artifact_type in ArtifactType: + try: + artifacts = self._readers[artifact_type].read() + objects[artifact_type.value] = artifacts + print(f"✓ Found {len(artifacts)} {artifact_type.value}") + except Exception as e: + print(f" ⚠ Warning: Error reading {artifact_type.value}: {str(e)[:100]}") + objects[artifact_type.value] = [] + + # Add column details and sample data for each table + for table in objects[ArtifactType.TABLES.value]: # Handle both lowercase and uppercase keys table_name = table.get('table_name') or table.get('TABLE_NAME') if table_name: table['columns'] = self.get_table_columns(table_name) + + # Add sample data (limit to 10 rows each) + try: + table['sample_data'] = self.get_table_data(table_name, limit=10) + except Exception as e: + table['sample_data'] = f"Error retrieving data: {str(e)}" else: print(f" ⚠ Warning: Could not find table_name in table object: {list(table.keys())}") table['columns'] = [] - - # Add sample data for tables (limit to 10 rows each) - for table in objects['tables']: - try: - # Handle both lowercase and uppercase keys - table_name = table.get('table_name') or table.get('TABLE_NAME') - if table_name: - table['sample_data'] = self.get_table_data(table_name, limit=10) - else: - table['sample_data'] = "Error: table_name not found" - except Exception as e: - table['sample_data'] = f"Error retrieving data: {str(e)}" - - print(f"✓ Found {len(objects['tables'])} tables") - print(f"✓ Found {len(objects['views'])} views") - print(f"✓ Found {len(objects['procedures'])} procedures") - print(f"✓ Found {len(objects['functions'])} functions") - print(f"✓ Found {len(objects['sequences'])} sequences") - print(f"✓ Found {len(objects['stages'])} stages") - print(f"✓ Found {len(objects['file_formats'])} file formats") - print(f"✓ Found {len(objects['tasks'])} tasks") - print(f"✓ Found {len(objects['streams'])} streams") - print(f"✓ Found {len(objects['pipes'])} pipes") + table['sample_data'] = "Error: table_name not found" return objects - def save_to_json(self, output_file: str = 'snowflake_objects_snowpark.json'): - """Save all objects to a JSON file in Unity Catalog Volume using dbutils.""" + def save_to_json(self, output_dir: str = None): + """Save all objects to separate JSON files in Unity Catalog Volume, one per artifact type.""" objects = self.get_all_objects() - # Convert objects to JSON string - json_data = json.dumps(objects, indent=2, default=str) + # Use default volume path if output_dir not specified + if output_dir is None: + base_volume_path = f"/Volumes/{constants.UnityCatalogConfig.CATALOG.value}/{constants.UnityCatalogConfig.SCHEMA.value}/{constants.UnityCatalogConfig.RAW_VOLUME.value}" + else: + base_volume_path = output_dir + + # Metadata to include in each file + metadata = { + 'database': objects['database'], + 'schema': objects['schema'] + } + + # Mapping of artifact types to file names + artifact_file_mapping = { + ArtifactType.TABLES: ArtifactFileName.TABLES, + ArtifactType.VIEWS: ArtifactFileName.VIEWS, + ArtifactType.PROCEDURES: ArtifactFileName.PROCEDURES, + ArtifactType.FUNCTIONS: ArtifactFileName.FUNCTIONS, + ArtifactType.SEQUENCES: ArtifactFileName.SEQUENCES, + ArtifactType.STAGES: ArtifactFileName.STAGES, + ArtifactType.FILE_FORMATS: ArtifactFileName.FILE_FORMATS, + ArtifactType.TASKS: ArtifactFileName.TASKS, + ArtifactType.STREAMS: ArtifactFileName.STREAMS, + ArtifactType.PIPES: ArtifactFileName.PIPES, + } - # Define the volume path - volume_path = f"/Volumes/{constants.UnityCatalogConfig.CATALOG.value}/{constants.UnityCatalogConfig.SCHEMA.value}/{constants.UnityCatalogConfig.RAW_VOLUME.value}/{output_file}" + saved_files = [] - # Write using dbutils - dbutils.fs.put(volume_path, json_data, overwrite=True) + # Save each artifact type to its own file + for artifact_type, file_name_enum in artifact_file_mapping.items(): + artifact_key = artifact_type.value + if artifact_key in objects: + filename = file_name_enum.value + volume_path = f"{base_volume_path}/{filename}" + + # Prepare data with metadata + artifact_data = { + **metadata, + artifact_key: objects[artifact_key] + } + + # Convert to JSON string + json_data = json.dumps(artifact_data, indent=2, default=str) + + # Write using dbutils + dbutils.fs.put(volume_path, json_data, overwrite=True) + + saved_files.append(filename) + print(f" ✓ Saved {artifact_key} to {filename}") - print(f"\n✓ Saved all objects to Unity Catalog Volume: {volume_path}") + print(f"\n✓ Saved {len(saved_files)} artifact files to Unity Catalog Volume: {base_volume_path}") def object_exists(self, object_name: str, object_type: str = 'TABLE') -> bool: """Check if an object exists in the schema.""" diff --git a/src/migration_accelerator_package/snowpark_utils.py b/src/migration_accelerator_package/snowpark_utils.py new file mode 100644 index 0000000..a5a32a0 --- /dev/null +++ b/src/migration_accelerator_package/snowpark_utils.py @@ -0,0 +1,38 @@ +""" +Utility functions shared across ingestion + validation entrypoints. +""" + +import os +from databricks.sdk.runtime import dbutils +from migration_accelerator_package.constants import SnowflakeConfig, UnityCatalogConfig + + +def get_secret(secret_name: str): + """Retrieve secrets from Databricks secret scope or fallback to env variables.""" + try: + return dbutils.secrets.get("migration-accelerator", secret_name) + except Exception: + return os.getenv(secret_name, "") + + +def build_snowflake_connection_params(): + """Return Snowflake connection parameters used by all wheel entrypoints.""" + return { + "account": get_secret("SNOWFLAKE_ACCOUNT"), + "user": get_secret("SNOWFLAKE_USER"), + "password": get_secret("SNOWFLAKE_PASSWORD"), + "role": SnowflakeConfig.SNOWFLAKE_ROLE.value, + "warehouse": SnowflakeConfig.SNOWFLAKE_WAREHOUSE.value, + "database": SnowflakeConfig.SNOWFLAKE_DATABASE.value, + "schema": SnowflakeConfig.SNOWFLAKE_SCHEMA.value, + } + + +def get_uc_volume_path() -> str: + """Return the base UC volume path where JSON artifacts live.""" + return ( + f"/Volumes/" + f"{UnityCatalogConfig.CATALOG.value}/" + f"{UnityCatalogConfig.SCHEMA.value}/" + f"{UnityCatalogConfig.RAW_VOLUME.value}" + ) \ No newline at end of file