Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/artifact_translation_package/databricks_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,22 @@
import json
from typing import List, Dict, Any, Optional
from pathlib import Path

# Load .env file for local development only (skip on Databricks)
def _is_local_env() -> bool:
"""Check if running locally (not on Databricks)."""
return not (
"DATABRICKS_RUNTIME_VERSION" in os.environ or
os.path.exists("/databricks")
)

if _is_local_env():
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
pass # python-dotenv not installed

from artifact_translation_package.utils.output_utils import make_timestamped_output_path, is_databricks_env
from artifact_translation_package.utils.sql_file_writer import save_sql_files
from artifact_translation_package.utils.result_saver import save_results
Expand Down
114 changes: 10 additions & 104 deletions src/artifact_translation_package/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
from artifact_translation_package.utils.types import ArtifactBatch, TranslationResult
from artifact_translation_package.utils.observability import initialize, finalize, get_observability
from artifact_translation_package.utils.logger import LogLevel
from artifact_translation_package.utils.result_merger import (
create_empty_result,
merge_result_into,
)


class TranslationState(TypedDict):
Expand Down Expand Up @@ -326,6 +330,10 @@ def __init__(self, run_id: Optional[str] = None, log_level: LogLevel = LogLevel.

def run(self, batch: ArtifactBatch) -> Dict[str, Any]:
"""Process a single batch through the translation graph."""
# Reset metrics for this batch to prevent accumulation across runs
if self.obs:
self.obs.get_metrics().reset()

self.logger.info("Starting translation graph execution", context={
"artifact_type": batch.artifact_type,
"batch_size": len(batch.items)
Expand Down Expand Up @@ -361,109 +369,7 @@ def run(self, batch: ArtifactBatch) -> Dict[str, Any]:
summary = finalize()
raise

def _initialize_merged_result(self) -> Dict[str, Any]:
"""
Initialize empty merged result structure.

Returns:
Dictionary with empty result structure
"""
return {
"databases": [],
"schemas": [],
"tables": [],
"views": [],
"stages": [],
"external_locations": [],
"streams": [],
"pipes": [],
"roles": [],
"grants": [],
"tags": [],
"comments": [],
"masking_policies": [],
"udfs": [],
"procedures": [],
"metadata": {
"total_results": 0,
"errors": [],
"processing_stats": {}
}
}

def _merge_result_into(
self,
merged_result: Dict[str, Any],
result: Dict[str, Any]
) -> None:
"""
Merge a single result into the merged result structure.

Args:
merged_result: The merged result dictionary to update
result: A single result dictionary to merge
"""
for key, value in result.items():
if key == "metadata":
merged_result["metadata"]["total_results"] += result["metadata"].get("total_results", 0)
merged_result["metadata"]["errors"].extend(result["metadata"].get("errors", []))
merged_result["metadata"]["processing_stats"].update(result["metadata"].get("processing_stats", {}))
elif key == "observability":
# Merge observability data
if "observability" not in merged_result:
merged_result["observability"] = {
"run_id": value.get("run_id"),
"total_duration": 0,
"total_errors": 0,
"total_warnings": 0,
"total_retries": 0,
"artifact_counts": {},
"stages": {},
"ai_metrics": {}
}

obs = merged_result["observability"]

# Aggregate artifact_counts
for artifact_type, count in value.get("artifact_counts", {}).items():
obs["artifact_counts"][artifact_type] = obs["artifact_counts"].get(artifact_type, 0) + count

# Aggregate total_errors, total_warnings, total_retries
obs["total_errors"] += value.get("total_errors", 0)
obs["total_warnings"] += value.get("total_warnings", 0)
obs["total_retries"] += value.get("total_retries", 0)

# Merge stages (aggregate items_processed, error_count and duration)
for stage_name, stage_data in value.get("stages", {}).items():
if stage_name not in obs["stages"]:
obs["stages"][stage_name] = stage_data.copy()
else:
# Aggregate metrics
obs["stages"][stage_name]["items_processed"] += stage_data.get("items_processed", 0)
obs["stages"][stage_name]["error_count"] += stage_data.get("error_count", 0)

# Aggregate duration if present
if "duration" in stage_data and stage_data["duration"] is not None:
current_duration = obs["stages"][stage_name].get("duration", 0) or 0
obs["stages"][stage_name]["duration"] = current_duration + stage_data["duration"]

# Merge ai_metrics
for ai_key, ai_data in value.get("ai_metrics", {}).items():
if ai_key not in obs["ai_metrics"]:
obs["ai_metrics"][ai_key] = ai_data
else:
# Aggregate AI metrics
obs["ai_metrics"][ai_key]["call_count"] += ai_data.get("call_count", 0)
obs["ai_metrics"][ai_key]["total_latency"] += ai_data.get("total_latency", 0)
obs["ai_metrics"][ai_key]["errors"] += ai_data.get("errors", 0)
# Recalculate average latency
if obs["ai_metrics"][ai_key]["call_count"] > 0:
obs["ai_metrics"][ai_key]["average_latency"] = (
obs["ai_metrics"][ai_key]["total_latency"] /
obs["ai_metrics"][ai_key]["call_count"]
)
elif key in merged_result:
merged_result[key].extend(value)

def run_batches(self, batches: List[ArtifactBatch]) -> Dict[str, Any]:
"""
Expand All @@ -486,10 +392,10 @@ def run_batches(self, batches: List[ArtifactBatch]) -> Dict[str, Any]:
all_results.append(result)

if all_results:
merged_result = self._initialize_merged_result()
merged_result = create_empty_result()

for result in all_results:
self._merge_result_into(merged_result, result)
merge_result_into(merged_result, result)

# Calculate total duration
end_time = time.time()
Expand Down
19 changes: 13 additions & 6 deletions src/artifact_translation_package/nodes/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,19 @@ def aggregate_translations(*results: TranslationResult, evaluation_results: Opti
if result.errors:
merged["metadata"]["errors"].extend(result.errors)

# Update processing stats
merged["metadata"]["processing_stats"][artifact_type] = {
"count": len(result.results),
"errors": len(result.errors),
**result.metadata
}
# Update processing stats - accumulate if artifact_type already exists
if artifact_type in merged["metadata"]["processing_stats"]:
# Accumulate counts for same artifact type across batches
existing = merged["metadata"]["processing_stats"][artifact_type]
existing["count"] = existing.get("count", 0) + len(result.results)
existing["errors"] = existing.get("errors", 0) + len(result.errors)
existing["processed"] = existing.get("processed", 0) + result.metadata.get("processed", len(result.results))
else:
merged["metadata"]["processing_stats"][artifact_type] = {
"count": len(result.results),
"errors": len(result.errors),
**result.metadata
}

merged["metadata"]["total_results"] += len(result.results)

Expand Down
9 changes: 8 additions & 1 deletion src/artifact_translation_package/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,15 @@ def __init__(self):
"""Initialize metrics collector."""
if self._initialized:
return
self.reset()
self._initialized = True

def reset(self):
"""Reset all metrics for a new run.

This should be called at the start of each translation run to ensure
metrics don't accumulate across runs.
"""
self.run_id: Optional[str] = None
self.start_time: float = time.time()
self.end_time: Optional[float] = None
Expand All @@ -82,7 +90,6 @@ def __init__(self):
self.total_retries: int = 0

self.logger = get_logger("metrics")
self._initialized = True

def set_run_id(self, run_id: str):
"""Set run ID for this execution."""
Expand Down
3 changes: 2 additions & 1 deletion src/artifact_translation_package/utils/observability.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ def __init__(

self.logger = get_logger("observability", level=log_level, handlers=handlers)

# Setup metrics
# Setup metrics - reset to clear any accumulated state from previous runs
self.metrics = get_metrics_collector()
self.metrics.reset()
self.metrics.set_run_id(self.run_id)

self.logger.info("Observability initialized", context={"run_id": self.run_id})
Expand Down
Loading