Skip to content

Commit 7f64952

Browse files
authored
Merge pull request #36 from thisisqubika/feature/observability-agregation-fixes
Fix on observability metrics + code refactor
2 parents: 02aac77 + 2f08ab2 · commit 7f64952

6 files changed

Lines changed: 245 additions & 112 deletions

File tree

src/artifact_translation_package/databricks_job.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,22 @@
99
import json
1010
from typing import List, Dict, Any, Optional
1111
from pathlib import Path
12+
13+
# Load .env file for local development only (skip on Databricks)
14+
def _is_local_env() -> bool:
15+
"""Check if running locally (not on Databricks)."""
16+
return not (
17+
"DATABRICKS_RUNTIME_VERSION" in os.environ or
18+
os.path.exists("/databricks")
19+
)
20+
21+
if _is_local_env():
22+
try:
23+
from dotenv import load_dotenv
24+
load_dotenv()
25+
except ImportError:
26+
pass # python-dotenv not installed
27+
1228
from artifact_translation_package.utils.output_utils import make_timestamped_output_path, is_databricks_env
1329
from artifact_translation_package.utils.sql_file_writer import save_sql_files
1430
from artifact_translation_package.utils.result_saver import save_results

src/artifact_translation_package/graph_builder.py

Lines changed: 10 additions & 104 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,10 @@
2525
from artifact_translation_package.utils.types import ArtifactBatch, TranslationResult
2626
from artifact_translation_package.utils.observability import initialize, finalize, get_observability
2727
from artifact_translation_package.utils.logger import LogLevel
28+
from artifact_translation_package.utils.result_merger import (
29+
create_empty_result,
30+
merge_result_into,
31+
)
2832

2933

3034
class TranslationState(TypedDict):
@@ -326,6 +330,10 @@ def __init__(self, run_id: Optional[str] = None, log_level: LogLevel = LogLevel.
326330

327331
def run(self, batch: ArtifactBatch) -> Dict[str, Any]:
328332
"""Process a single batch through the translation graph."""
333+
# Reset metrics for this batch to prevent accumulation across runs
334+
if self.obs:
335+
self.obs.get_metrics().reset()
336+
329337
self.logger.info("Starting translation graph execution", context={
330338
"artifact_type": batch.artifact_type,
331339
"batch_size": len(batch.items)
@@ -361,109 +369,7 @@ def run(self, batch: ArtifactBatch) -> Dict[str, Any]:
361369
summary = finalize()
362370
raise
363371

364-
def _initialize_merged_result(self) -> Dict[str, Any]:
365-
"""
366-
Initialize empty merged result structure.
367-
368-
Returns:
369-
Dictionary with empty result structure
370-
"""
371-
return {
372-
"databases": [],
373-
"schemas": [],
374-
"tables": [],
375-
"views": [],
376-
"stages": [],
377-
"external_locations": [],
378-
"streams": [],
379-
"pipes": [],
380-
"roles": [],
381-
"grants": [],
382-
"tags": [],
383-
"comments": [],
384-
"masking_policies": [],
385-
"udfs": [],
386-
"procedures": [],
387-
"metadata": {
388-
"total_results": 0,
389-
"errors": [],
390-
"processing_stats": {}
391-
}
392-
}
393372

394-
def _merge_result_into(
395-
self,
396-
merged_result: Dict[str, Any],
397-
result: Dict[str, Any]
398-
) -> None:
399-
"""
400-
Merge a single result into the merged result structure.
401-
402-
Args:
403-
merged_result: The merged result dictionary to update
404-
result: A single result dictionary to merge
405-
"""
406-
for key, value in result.items():
407-
if key == "metadata":
408-
merged_result["metadata"]["total_results"] += result["metadata"].get("total_results", 0)
409-
merged_result["metadata"]["errors"].extend(result["metadata"].get("errors", []))
410-
merged_result["metadata"]["processing_stats"].update(result["metadata"].get("processing_stats", {}))
411-
elif key == "observability":
412-
# Merge observability data
413-
if "observability" not in merged_result:
414-
merged_result["observability"] = {
415-
"run_id": value.get("run_id"),
416-
"total_duration": 0,
417-
"total_errors": 0,
418-
"total_warnings": 0,
419-
"total_retries": 0,
420-
"artifact_counts": {},
421-
"stages": {},
422-
"ai_metrics": {}
423-
}
424-
425-
obs = merged_result["observability"]
426-
427-
# Aggregate artifact_counts
428-
for artifact_type, count in value.get("artifact_counts", {}).items():
429-
obs["artifact_counts"][artifact_type] = obs["artifact_counts"].get(artifact_type, 0) + count
430-
431-
# Aggregate total_errors, total_warnings, total_retries
432-
obs["total_errors"] += value.get("total_errors", 0)
433-
obs["total_warnings"] += value.get("total_warnings", 0)
434-
obs["total_retries"] += value.get("total_retries", 0)
435-
436-
# Merge stages (aggregate items_processed, error_count and duration)
437-
for stage_name, stage_data in value.get("stages", {}).items():
438-
if stage_name not in obs["stages"]:
439-
obs["stages"][stage_name] = stage_data.copy()
440-
else:
441-
# Aggregate metrics
442-
obs["stages"][stage_name]["items_processed"] += stage_data.get("items_processed", 0)
443-
obs["stages"][stage_name]["error_count"] += stage_data.get("error_count", 0)
444-
445-
# Aggregate duration if present
446-
if "duration" in stage_data and stage_data["duration"] is not None:
447-
current_duration = obs["stages"][stage_name].get("duration", 0) or 0
448-
obs["stages"][stage_name]["duration"] = current_duration + stage_data["duration"]
449-
450-
# Merge ai_metrics
451-
for ai_key, ai_data in value.get("ai_metrics", {}).items():
452-
if ai_key not in obs["ai_metrics"]:
453-
obs["ai_metrics"][ai_key] = ai_data
454-
else:
455-
# Aggregate AI metrics
456-
obs["ai_metrics"][ai_key]["call_count"] += ai_data.get("call_count", 0)
457-
obs["ai_metrics"][ai_key]["total_latency"] += ai_data.get("total_latency", 0)
458-
obs["ai_metrics"][ai_key]["errors"] += ai_data.get("errors", 0)
459-
# Recalculate average latency
460-
if obs["ai_metrics"][ai_key]["call_count"] > 0:
461-
obs["ai_metrics"][ai_key]["average_latency"] = (
462-
obs["ai_metrics"][ai_key]["total_latency"] /
463-
obs["ai_metrics"][ai_key]["call_count"]
464-
)
465-
elif key in merged_result:
466-
merged_result[key].extend(value)
467373

468374
def run_batches(self, batches: List[ArtifactBatch]) -> Dict[str, Any]:
469375
"""
@@ -486,10 +392,10 @@ def run_batches(self, batches: List[ArtifactBatch]) -> Dict[str, Any]:
486392
all_results.append(result)
487393

488394
if all_results:
489-
merged_result = self._initialize_merged_result()
395+
merged_result = create_empty_result()
490396

491397
for result in all_results:
492-
self._merge_result_into(merged_result, result)
398+
merge_result_into(merged_result, result)
493399

494400
# Calculate total duration
495401
end_time = time.time()

src/artifact_translation_package/nodes/aggregator.py

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -48,12 +48,19 @@ def aggregate_translations(*results: TranslationResult, evaluation_results: Opti
4848
if result.errors:
4949
merged["metadata"]["errors"].extend(result.errors)
5050

51-
# Update processing stats
52-
merged["metadata"]["processing_stats"][artifact_type] = {
53-
"count": len(result.results),
54-
"errors": len(result.errors),
55-
**result.metadata
56-
}
51+
# Update processing stats - accumulate if artifact_type already exists
52+
if artifact_type in merged["metadata"]["processing_stats"]:
53+
# Accumulate counts for same artifact type across batches
54+
existing = merged["metadata"]["processing_stats"][artifact_type]
55+
existing["count"] = existing.get("count", 0) + len(result.results)
56+
existing["errors"] = existing.get("errors", 0) + len(result.errors)
57+
existing["processed"] = existing.get("processed", 0) + result.metadata.get("processed", len(result.results))
58+
else:
59+
merged["metadata"]["processing_stats"][artifact_type] = {
60+
"count": len(result.results),
61+
"errors": len(result.errors),
62+
**result.metadata
63+
}
5764

5865
merged["metadata"]["total_results"] += len(result.results)
5966

src/artifact_translation_package/utils/metrics.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,15 @@ def __init__(self):
6868
"""Initialize metrics collector."""
6969
if self._initialized:
7070
return
71+
self.reset()
72+
self._initialized = True
73+
74+
def reset(self):
75+
"""Reset all metrics for a new run.
7176
77+
This should be called at the start of each translation run to ensure
78+
metrics don't accumulate across runs.
79+
"""
7280
self.run_id: Optional[str] = None
7381
self.start_time: float = time.time()
7482
self.end_time: Optional[float] = None
@@ -82,7 +90,6 @@ def __init__(self):
8290
self.total_retries: int = 0
8391

8492
self.logger = get_logger("metrics")
85-
self._initialized = True
8693

8794
def set_run_id(self, run_id: str):
8895
"""Set run ID for this execution."""

src/artifact_translation_package/utils/observability.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,8 +44,9 @@ def __init__(
4444

4545
self.logger = get_logger("observability", level=log_level, handlers=handlers)
4646

47-
# Setup metrics
47+
# Setup metrics - reset to clear any accumulated state from previous runs
4848
self.metrics = get_metrics_collector()
49+
self.metrics.reset()
4950
self.metrics.set_run_id(self.run_id)
5051

5152
self.logger.info("Observability initialized", context={"run_id": self.run_id})

0 commit comments

Comments
 (0)