diff --git a/src/artifact_translation_package/databricks_job.py b/src/artifact_translation_package/databricks_job.py index 237aafa..57e2520 100644 --- a/src/artifact_translation_package/databricks_job.py +++ b/src/artifact_translation_package/databricks_job.py @@ -9,6 +9,22 @@ import json from typing import List, Dict, Any, Optional from pathlib import Path + +# Load .env file for local development only (skip on Databricks) +def _is_local_env() -> bool: + """Check if running locally (not on Databricks).""" + return not ( + "DATABRICKS_RUNTIME_VERSION" in os.environ or + os.path.exists("/databricks") + ) + +if _is_local_env(): + try: + from dotenv import load_dotenv + load_dotenv() + except ImportError: + pass # python-dotenv not installed + from artifact_translation_package.utils.output_utils import make_timestamped_output_path, is_databricks_env from artifact_translation_package.utils.sql_file_writer import save_sql_files from artifact_translation_package.utils.result_saver import save_results diff --git a/src/artifact_translation_package/graph_builder.py b/src/artifact_translation_package/graph_builder.py index b5f7062..58390ea 100644 --- a/src/artifact_translation_package/graph_builder.py +++ b/src/artifact_translation_package/graph_builder.py @@ -25,6 +25,10 @@ from artifact_translation_package.utils.types import ArtifactBatch, TranslationResult from artifact_translation_package.utils.observability import initialize, finalize, get_observability from artifact_translation_package.utils.logger import LogLevel +from artifact_translation_package.utils.result_merger import ( + create_empty_result, + merge_result_into, +) class TranslationState(TypedDict): @@ -326,6 +330,10 @@ def __init__(self, run_id: Optional[str] = None, log_level: LogLevel = LogLevel. def run(self, batch: ArtifactBatch) -> Dict[str, Any]: """Process a single batch through the translation graph.""" + # Reset metrics for this batch to prevent accumulation across runs + if self.obs: + self.obs.get_metrics().reset() + self.logger.info("Starting translation graph execution", context={ "artifact_type": batch.artifact_type, "batch_size": len(batch.items) @@ -361,109 +369,7 @@ def run(self, batch: ArtifactBatch) -> Dict[str, Any]: summary = finalize() raise - def _initialize_merged_result(self) -> Dict[str, Any]: - """ - Initialize empty merged result structure. - - Returns: - Dictionary with empty result structure - """ - return { - "databases": [], - "schemas": [], - "tables": [], - "views": [], - "stages": [], - "external_locations": [], - "streams": [], - "pipes": [], - "roles": [], - "grants": [], - "tags": [], - "comments": [], - "masking_policies": [], - "udfs": [], - "procedures": [], - "metadata": { - "total_results": 0, - "errors": [], - "processing_stats": {} - } - } - def _merge_result_into( - self, - merged_result: Dict[str, Any], - result: Dict[str, Any] - ) -> None: - """ - Merge a single result into the merged result structure. - - Args: - merged_result: The merged result dictionary to update - result: A single result dictionary to merge - """ - for key, value in result.items(): - if key == "metadata": - merged_result["metadata"]["total_results"] += result["metadata"].get("total_results", 0) - merged_result["metadata"]["errors"].extend(result["metadata"].get("errors", [])) - merged_result["metadata"]["processing_stats"].update(result["metadata"].get("processing_stats", {})) - elif key == "observability": - # Merge observability data - if "observability" not in merged_result: - merged_result["observability"] = { - "run_id": value.get("run_id"), - "total_duration": 0, - "total_errors": 0, - "total_warnings": 0, - "total_retries": 0, - "artifact_counts": {}, - "stages": {}, - "ai_metrics": {} - } - - obs = merged_result["observability"] - - # Aggregate artifact_counts - for artifact_type, count in value.get("artifact_counts", {}).items(): - obs["artifact_counts"][artifact_type] = obs["artifact_counts"].get(artifact_type, 0) + count - - # Aggregate total_errors, total_warnings, total_retries - obs["total_errors"] += value.get("total_errors", 0) - obs["total_warnings"] += value.get("total_warnings", 0) - obs["total_retries"] += value.get("total_retries", 0) - - # Merge stages (aggregate items_processed, error_count and duration) - for stage_name, stage_data in value.get("stages", {}).items(): - if stage_name not in obs["stages"]: - obs["stages"][stage_name] = stage_data.copy() - else: - # Aggregate metrics - obs["stages"][stage_name]["items_processed"] += stage_data.get("items_processed", 0) - obs["stages"][stage_name]["error_count"] += stage_data.get("error_count", 0) - - # Aggregate duration if present - if "duration" in stage_data and stage_data["duration"] is not None: - current_duration = obs["stages"][stage_name].get("duration", 0) or 0 - obs["stages"][stage_name]["duration"] = current_duration + stage_data["duration"] - - # Merge ai_metrics - for ai_key, ai_data in value.get("ai_metrics", {}).items(): - if ai_key not in obs["ai_metrics"]: - obs["ai_metrics"][ai_key] = ai_data - else: - # Aggregate AI metrics - obs["ai_metrics"][ai_key]["call_count"] += ai_data.get("call_count", 0) - obs["ai_metrics"][ai_key]["total_latency"] += ai_data.get("total_latency", 0) - obs["ai_metrics"][ai_key]["errors"] += ai_data.get("errors", 0) - # Recalculate average latency - if obs["ai_metrics"][ai_key]["call_count"] > 0: - obs["ai_metrics"][ai_key]["average_latency"] = ( - obs["ai_metrics"][ai_key]["total_latency"] / - obs["ai_metrics"][ai_key]["call_count"] - ) - elif key in merged_result: - merged_result[key].extend(value) def run_batches(self, batches: List[ArtifactBatch]) -> Dict[str, Any]: """ @@ -486,10 +392,10 @@ def run_batches(self, batches: List[ArtifactBatch]) -> Dict[str, Any]: all_results.append(result) if all_results: - merged_result = self._initialize_merged_result() + merged_result = create_empty_result() for result in all_results: - self._merge_result_into(merged_result, result) + merge_result_into(merged_result, result) # Calculate total duration end_time = time.time() diff --git a/src/artifact_translation_package/nodes/aggregator.py b/src/artifact_translation_package/nodes/aggregator.py index 5875e37..bb05f39 100644 --- a/src/artifact_translation_package/nodes/aggregator.py +++ b/src/artifact_translation_package/nodes/aggregator.py @@ -48,12 +48,19 @@ def aggregate_translations(*results: TranslationResult, evaluation_results: Opti if result.errors: merged["metadata"]["errors"].extend(result.errors) - # Update processing stats - merged["metadata"]["processing_stats"][artifact_type] = { - "count": len(result.results), - "errors": len(result.errors), - **result.metadata - } + # Update processing stats - accumulate if artifact_type already exists + if artifact_type in merged["metadata"]["processing_stats"]: + # Accumulate counts for same artifact type across batches + existing = merged["metadata"]["processing_stats"][artifact_type] + existing["count"] = existing.get("count", 0) + len(result.results) + existing["errors"] = existing.get("errors", 0) + len(result.errors) + existing["processed"] = existing.get("processed", 0) + result.metadata.get("processed", len(result.results)) + else: + merged["metadata"]["processing_stats"][artifact_type] = { + "count": len(result.results), + "errors": len(result.errors), + **result.metadata + } merged["metadata"]["total_results"] += len(result.results) diff --git a/src/artifact_translation_package/utils/metrics.py b/src/artifact_translation_package/utils/metrics.py index 39e6b12..4e34ab5 100644 --- a/src/artifact_translation_package/utils/metrics.py +++ b/src/artifact_translation_package/utils/metrics.py @@ -68,7 +68,15 @@ def __init__(self): """Initialize metrics collector.""" if self._initialized: return + self.reset() + self._initialized = True + + def reset(self): + """Reset all metrics for a new run. + This should be called at the start of each translation run to ensure + metrics don't accumulate across runs. + """ self.run_id: Optional[str] = None self.start_time: float = time.time() self.end_time: Optional[float] = None @@ -82,7 +90,6 @@ def __init__(self): self.total_retries: int = 0 self.logger = get_logger("metrics") - self._initialized = True def set_run_id(self, run_id: str): """Set run ID for this execution.""" diff --git a/src/artifact_translation_package/utils/observability.py b/src/artifact_translation_package/utils/observability.py index b30c8cb..56a9603 100644 --- a/src/artifact_translation_package/utils/observability.py +++ b/src/artifact_translation_package/utils/observability.py @@ -44,8 +44,9 @@ def __init__( self.logger = get_logger("observability", level=log_level, handlers=handlers) - # Setup metrics + # Setup metrics - reset to clear any accumulated state from previous runs self.metrics = get_metrics_collector() + self.metrics.reset() self.metrics.set_run_id(self.run_id) self.logger.info("Observability initialized", context={"run_id": self.run_id}) diff --git a/src/artifact_translation_package/utils/result_merger.py b/src/artifact_translation_package/utils/result_merger.py new file mode 100644 index 0000000..3caedf6 --- /dev/null +++ b/src/artifact_translation_package/utils/result_merger.py @@ -0,0 +1,196 @@ +"""Result merging utilities for translation outputs. + +Provides functions to merge and aggregate translation results from multiple +batch runs into a single consolidated result structure. +""" + +from typing import Dict, Any, List + + +# Default artifact types for result structure +ARTIFACT_TYPES = [ + "databases", "schemas", "tables", "views", "stages", "external_locations", + "streams", "pipes", "roles", "grants", "tags", "comments", + "masking_policies", "udfs", "procedures" +] + + +def create_empty_result() -> Dict[str, Any]: + """ + Create an empty result structure for merging. + + Returns: + Dictionary with empty result structure ready for merging + """ + result = {artifact_type: [] for artifact_type in ARTIFACT_TYPES} + result["metadata"] = { + "total_results": 0, + "errors": [], + "processing_stats": {} + } + return result + + +def accumulate_processing_stats( + target_stats: Dict[str, Any], + source_stats: Dict[str, Any] +) -> None: + """ + Accumulate processing stats from source into target. + + Args: + target_stats: Target stats dictionary to update + source_stats: Source stats dictionary to merge from + """ + for artifact_type, stats in source_stats.items(): + if artifact_type in target_stats: + existing = target_stats[artifact_type] + existing["count"] = existing.get("count", 0) + stats.get("count", 0) + existing["errors"] = existing.get("errors", 0) + stats.get("errors", 0) + existing["processed"] = existing.get("processed", 0) + stats.get("processed", 0) + else: + target_stats[artifact_type] = stats.copy() + + +def merge_metadata( + merged_metadata: Dict[str, Any], + source_metadata: Dict[str, Any] +) -> None: + """ + Merge metadata from a single result into merged metadata. + + Args: + merged_metadata: Target metadata dictionary to update + source_metadata: Source metadata dictionary to merge from + """ + merged_metadata["total_results"] += source_metadata.get("total_results", 0) + merged_metadata["errors"].extend(source_metadata.get("errors", [])) + accumulate_processing_stats( + merged_metadata["processing_stats"], + source_metadata.get("processing_stats", {}) + ) + + +def merge_observability_stages( + target_stages: Dict[str, Any], + source_stages: Dict[str, Any] +) -> None: + """ + Merge stage metrics from source into target. + + Args: + target_stages: Target stages dictionary to update + source_stages: Source stages dictionary to merge from + """ + for stage_name, stage_data in source_stages.items(): + if stage_name not in target_stages: + target_stages[stage_name] = stage_data.copy() + else: + target_stages[stage_name]["items_processed"] += stage_data.get("items_processed", 0) + target_stages[stage_name]["error_count"] += stage_data.get("error_count", 0) + if "duration" in stage_data and stage_data["duration"] is not None: + current = target_stages[stage_name].get("duration", 0) or 0 + target_stages[stage_name]["duration"] = current + stage_data["duration"] + + +def merge_ai_metrics( + target_metrics: Dict[str, Any], + source_metrics: Dict[str, Any] +) -> None: + """ + Merge AI metrics from source into target. + + Args: + target_metrics: Target AI metrics dictionary to update + source_metrics: Source AI metrics dictionary to merge from + """ + for ai_key, ai_data in source_metrics.items(): + if ai_key not in target_metrics: + target_metrics[ai_key] = ai_data.copy() + else: + target_metrics[ai_key]["call_count"] += ai_data.get("call_count", 0) + target_metrics[ai_key]["total_latency"] += ai_data.get("total_latency", 0) + target_metrics[ai_key]["errors"] += ai_data.get("errors", 0) + if target_metrics[ai_key]["call_count"] > 0: + target_metrics[ai_key]["average_latency"] = ( + target_metrics[ai_key]["total_latency"] / + target_metrics[ai_key]["call_count"] + ) + + +def merge_observability( + merged_result: Dict[str, Any], + source_obs: Dict[str, Any] +) -> None: + """ + Merge observability data from a single result into merged result. + + Args: + merged_result: Target result dictionary containing observability + source_obs: Source observability dictionary to merge from + """ + if "observability" not in merged_result: + merged_result["observability"] = { + "run_id": source_obs.get("run_id"), + "total_duration": 0, + "total_errors": 0, + "total_warnings": 0, + "total_retries": 0, + "artifact_counts": {}, + "stages": {}, + "ai_metrics": {} + } + + obs = merged_result["observability"] + + # Aggregate artifact_counts + for artifact_type, count in source_obs.get("artifact_counts", {}).items(): + obs["artifact_counts"][artifact_type] = obs["artifact_counts"].get(artifact_type, 0) + count + + # Aggregate totals + obs["total_errors"] += source_obs.get("total_errors", 0) + obs["total_warnings"] += source_obs.get("total_warnings", 0) + obs["total_retries"] += source_obs.get("total_retries", 0) + + # Merge nested data + merge_observability_stages(obs["stages"], source_obs.get("stages", {})) + merge_ai_metrics(obs["ai_metrics"], source_obs.get("ai_metrics", {})) + + +def merge_result_into( + merged_result: Dict[str, Any], + result: Dict[str, Any] +) -> None: + """ + Merge a single result into the merged result structure. + + Args: + merged_result: The merged result dictionary to update + result: A single result dictionary to merge + """ + for key, value in result.items(): + if key == "metadata": + merge_metadata(merged_result["metadata"], result["metadata"]) + elif key == "observability": + merge_observability(merged_result, value) + elif key in merged_result: + merged_result[key].extend(value) + + +def merge_results(results: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Merge multiple translation results into a single consolidated result. + + This is a convenience function that creates an empty result and merges + all provided results into it. + + Args: + results: List of result dictionaries to merge + + Returns: + Merged result dictionary + """ + merged = create_empty_result() + for result in results: + merge_result_into(merged, result) + return merged