179 changes: 179 additions & 0 deletions data_rentgen/consumer/saver.py
@@ -0,0 +1,179 @@
# SPDX-FileCopyrightText: 2024-2025 MTS PJSC
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from faststream import Logger
from sqlalchemy.exc import DatabaseError, IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from data_rentgen.consumer.extractors import BatchExtractionResult
from data_rentgen.services.uow import UnitOfWork


class DatabaseSaver:
def __init__(
self,
session: AsyncSession,
logger: Logger,
) -> None:
self.unit_of_work = UnitOfWork(session)
self.logger = logger

async def save(self, data: BatchExtractionResult):
self.logger.info("Saving to database")

await self.create_locations(data)
await self.create_datasets(data)
await self.create_dataset_symlinks(data)
await self.create_job_types(data)
await self.create_jobs(data)
await self.create_users(data)
await self.create_sql_queries(data)
await self.create_schemas(data)

try:
await self.create_runs_bulk(data)
except DatabaseError:
await self.create_runs_one_by_one(data)

await self.create_operations(data)
await self.create_inputs(data)
await self.create_outputs(data)
await self.create_column_lineage(data)

self.logger.info("Saved successfully")

async def create_locations(self, data: BatchExtractionResult):
self.logger.debug("Creating locations")
# It's hard to fetch locations in bulk, and the number of locations is usually small,
# so use a row-by-row approach
for location_dto in data.locations():
async with self.unit_of_work:
location = await self.unit_of_work.location.create_or_update(location_dto)
location_dto.id = location.id

# To avoid deadlocks when parallel consumer instances insert/update the same row,
# commit changes for each row instead of committing the whole batch. Yes, this could be slow.
# But most entities are unchanged after creation, so we can just fetch them and do nothing.
async def create_datasets(self, data: BatchExtractionResult):
self.logger.debug("Creating datasets")
dataset_pairs = await self.unit_of_work.dataset.fetch_bulk(data.datasets())
for dataset_dto, dataset in dataset_pairs:
if not dataset:
async with self.unit_of_work:
dataset = await self.unit_of_work.dataset.create(dataset_dto) # noqa: PLW2901
dataset_dto.id = dataset.id

async def create_dataset_symlinks(self, data: BatchExtractionResult):
self.logger.debug("Creating dataset symlinks")
dataset_symlinks_pairs = await self.unit_of_work.dataset_symlink.fetch_bulk(data.dataset_symlinks())
for dataset_symlink_dto, dataset_symlink in dataset_symlinks_pairs:
if not dataset_symlink:
async with self.unit_of_work:
dataset_symlink = await self.unit_of_work.dataset_symlink.create(dataset_symlink_dto) # noqa: PLW2901
dataset_symlink_dto.id = dataset_symlink.id

async def create_job_types(self, data: BatchExtractionResult):
self.logger.debug("Creating job types")
job_type_pairs = await self.unit_of_work.job_type.fetch_bulk(data.job_types())
for job_type_dto, job_type in job_type_pairs:
if not job_type:
async with self.unit_of_work:
job_type = await self.unit_of_work.job_type.create(job_type_dto) # noqa: PLW2901
job_type_dto.id = job_type.id

async def create_jobs(self, data: BatchExtractionResult):
self.logger.debug("Creating jobs")
job_pairs = await self.unit_of_work.job.fetch_bulk(data.jobs())
for job_dto, job in job_pairs:
async with self.unit_of_work:
if not job:
job = await self.unit_of_work.job.create_or_update(job_dto) # noqa: PLW2901
else:
job = await self.unit_of_work.job.update(job, job_dto) # noqa: PLW2901
job_dto.id = job.id

async def create_users(self, data: BatchExtractionResult):
self.logger.debug("Creating users")
user_pairs = await self.unit_of_work.user.fetch_bulk(data.users())
for user_dto, user in user_pairs:
if not user:
async with self.unit_of_work:
user = await self.unit_of_work.user.create(user_dto) # noqa: PLW2901
user_dto.id = user.id

async def create_sql_queries(self, data: BatchExtractionResult):
self.logger.debug("Creating sql queries")
sql_query_ids = await self.unit_of_work.sql_query.fetch_known_ids(data.sql_queries())
for sql_query_dto, sql_query_id in sql_query_ids:
if not sql_query_id:
async with self.unit_of_work:
sql_query = await self.unit_of_work.sql_query.create(sql_query_dto)
sql_query_dto.id = sql_query.id
else:
sql_query_dto.id = sql_query_id

async def create_schemas(self, data: BatchExtractionResult):
self.logger.debug("Creating schemas")
schema_ids = await self.unit_of_work.schema.fetch_known_ids(data.schemas())
for schema_dto, schema_id in schema_ids:
if not schema_id:
async with self.unit_of_work:
schema = await self.unit_of_work.schema.create(schema_dto)
schema_dto.id = schema.id
else:
schema_dto.id = schema_id

# In most cases, the whole run tree created by a parent is sent to one
# Kafka partition, and thus handled by just one worker.
# Cross fingers and create all runs in one transaction.
async def create_runs_bulk(self, data: BatchExtractionResult):
self.logger.debug("Creating runs in bulk")
async with self.unit_of_work:
await self.unit_of_work.run.create_or_update_bulk(data.runs())

# In case child and parent runs land in different partitions,
# multiple workers may try to create/update the same run, leading to a deadlock.
# Fall back to creating runs one by one.
async def create_runs_one_by_one(self, data: BatchExtractionResult):
self.logger.debug("Creating runs in one-by-one")
run_pairs = await self.unit_of_work.run.fetch_bulk(data.runs())
for run_dto, run in run_pairs:
try:
async with self.unit_of_work:
if not run:
await self.unit_of_work.run.create(run_dto)
else:
await self.unit_of_work.run.update(run, run_dto)
except IntegrityError: # noqa: PERF203
# a deadlock occurred and the states in the DB and in RAM are out of sync,
# so we have to re-fetch the run from the DB
async with self.unit_of_work:
await self.unit_of_work.run.create_or_update(run_dto)

# All events related to the same operation are always sent to the same Kafka partition,
# so other workers never insert/update the same operation in parallel.
# These rows can be inserted/updated in bulk, in one transaction.
async def create_operations(self, data: BatchExtractionResult):
async with self.unit_of_work:
self.logger.debug("Creating operations")
await self.unit_of_work.operation.create_or_update_bulk(data.operations())

async def create_inputs(self, data: BatchExtractionResult):
async with self.unit_of_work:
self.logger.debug("Creating inputs")
await self.unit_of_work.input.create_or_update_bulk(data.inputs())

async def create_outputs(self, data: BatchExtractionResult):
async with self.unit_of_work:
self.logger.debug("Creating outputs")
await self.unit_of_work.output.create_or_update_bulk(data.outputs())

async def create_column_lineage(self, data: BatchExtractionResult):
async with self.unit_of_work:
self.logger.debug("Creating dataset column relations")
await self.unit_of_work.dataset_column_relation.create_bulk_for_column_lineage(data.column_lineage())

self.logger.debug("Creating column lineage")
await self.unit_of_work.column_lineage.create_bulk(data.column_lineage())
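The save path above leans on one pattern throughout: try a single bulk transaction first, and on a DatabaseError fall back to per-row transactions, retrying any row that hits an IntegrityError. A condensed sketch of that pattern, assuming UnitOfWork is an async context manager that commits on clean exit and rolls back on exception (the repository method names mirror the diff; the standalone helper itself is illustrative):

```python
from sqlalchemy.exc import DatabaseError, IntegrityError

async def save_runs(uow, run_dtos):
    try:
        # Fast path: the whole batch in one transaction. Works whenever this
        # worker is the only one touching these rows (same Kafka partition).
        async with uow:
            await uow.run.create_or_update_bulk(run_dtos)
    except DatabaseError:
        # Slow path: one short transaction per row, so concurrent consumers
        # contend on a single row at a time instead of deadlocking on a batch.
        for run_dto in run_dtos:
            try:
                async with uow:
                    await uow.run.create_or_update(run_dto)
            except IntegrityError:
                # Another worker inserted the row in between; our transaction
                # was rolled back, so the same upsert now finds the row and
                # updates it instead of inserting.
                async with uow:
                    await uow.run.create_or_update(run_dto)
```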
99 changes: 4 additions & 95 deletions data_rentgen/consumer/subscribers.py
@@ -14,10 +14,10 @@
from pydantic import TypeAdapter
from sqlalchemy.ext.asyncio import AsyncSession

from data_rentgen.consumer.extractors import BatchExtractionResult, BatchExtractor
from data_rentgen.consumer.extractors import BatchExtractor
from data_rentgen.consumer.saver import DatabaseSaver
from data_rentgen.dependencies.stub import Stub
from data_rentgen.openlineage.run_event import OpenLineageRunEvent
from data_rentgen.services.uow import UnitOfWork

__all__ = [
"runs_events_subscriber",
@@ -54,9 +54,8 @@ async def runs_events_subscriber(
extracted = extractor.result
logger.info("Got %r", extracted)

logger.info("Saving to database")
await save_to_db(extracted, session, logger)
logger.info("Saved successfully")
saver = DatabaseSaver(session, logger)
await saver.save(extracted)

if malformed:
logger.warning("Malformed messages: %d", len(malformed))
@@ -88,96 +87,6 @@ async def parse_messages(
await asyncio.sleep(0)


async def save_to_db(
data: BatchExtractionResult,
session: AsyncSession,
logger: Logger,
) -> None:
# To avoid deadlocks when parallel consumer instances insert/update the same row,
# commit changes for each row instead of committing the whole batch. Yes, this could be slow.

unit_of_work = UnitOfWork(session)

logger.debug("Creating locations")
for location_dto in data.locations():
async with unit_of_work:
location = await unit_of_work.location.create_or_update(location_dto)
location_dto.id = location.id

logger.debug("Creating datasets")
for dataset_dto in data.datasets():
async with unit_of_work:
dataset = await unit_of_work.dataset.get_or_create(dataset_dto)
dataset_dto.id = dataset.id

logger.debug("Creating symlinks")
for dataset_symlink_dto in data.dataset_symlinks():
async with unit_of_work:
dataset_symlink = await unit_of_work.dataset_symlink.get_or_create(dataset_symlink_dto)
dataset_symlink_dto.id = dataset_symlink.id

logger.debug("Creating job types")
for job_type_dto in data.job_types():
async with unit_of_work:
job_type = await unit_of_work.job_type.get_or_create(job_type_dto)
job_type_dto.id = job_type.id

logger.debug("Creating jobs")
for job_dto in data.jobs():
async with unit_of_work:
job = await unit_of_work.job.create_or_update(job_dto)
job_dto.id = job.id

logger.debug("Creating sql queries")
for sql_query_dto in data.sql_queries():
async with unit_of_work:
sql_query = await unit_of_work.sql_query.get_or_create(sql_query_dto)
sql_query_dto.id = sql_query.id

logger.debug("Creating users")
for user_dto in data.users():
async with unit_of_work:
user = await unit_of_work.user.get_or_create(user_dto)
user_dto.id = user.id

logger.debug("Creating schemas")
for schema_dto in data.schemas():
async with unit_of_work:
schema = await unit_of_work.schema.get_or_create(schema_dto)
schema_dto.id = schema.id

# Some events related to a specific run are sent to the same Kafka partition,
# but at the same time we have a parent_run which may already be inserted/updated by another worker
# (the Kafka key may differ between a run and its parent).
# In this case we cannot insert all the rows in one transaction, as it may lead to deadlocks.
logger.debug("Creating runs")
for run_dto in data.runs():
async with unit_of_work:
await unit_of_work.run.create_or_update(run_dto)

# All events related to the same operation are always sent to the same Kafka partition,
# so other workers never insert/update the same operation in parallel.
# These rows can be inserted/updated in bulk, in one transaction.
async with unit_of_work:
logger.debug("Creating operations")
await unit_of_work.operation.create_or_update_bulk(data.operations())

logger.debug("Creating inputs")
await unit_of_work.input.create_or_update_bulk(data.inputs())

logger.debug("Creating outputs")
await unit_of_work.output.create_or_update_bulk(data.outputs())

# If something went wrong here, at least we will have inputs/outputs
async with unit_of_work:
column_lineage = data.column_lineage()
logger.debug("Creating dataset column relations")
await unit_of_work.dataset_column_relation.create_bulk_for_column_lineage(column_lineage)

logger.debug("Creating column lineage")
await unit_of_work.column_lineage.create_bulk(column_lineage)


async def report_malformed(
messages: list[ConsumerRecord],
message_id: str,
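With save_to_db folded into DatabaseSaver, the subscriber only builds the saver and delegates to it. A minimal sketch of the resulting call site (the dependency wiring is simplified; the real handler also batches messages and reports malformed ones):

```python
from faststream import Logger
from sqlalchemy.ext.asyncio import AsyncSession

from data_rentgen.consumer.extractors import BatchExtractionResult
from data_rentgen.consumer.saver import DatabaseSaver

async def persist(
    extracted: BatchExtractionResult,
    session: AsyncSession,
    logger: Logger,
) -> None:
    # The saver now owns the UnitOfWork and all commit/rollback decisions,
    # so the subscriber module no longer imports UnitOfWork at all.
    saver = DatabaseSaver(session, logger)
    await saver.save(extracted)
```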
15 changes: 13 additions & 2 deletions data_rentgen/db/repositories/column_lineage.py
@@ -6,7 +6,7 @@
from typing import NamedTuple
from uuid import UUID

from sqlalchemy import ColumnElement, any_, func, select, tuple_
from sqlalchemy import ARRAY, ColumnElement, Integer, any_, cast, func, select, tuple_
from sqlalchemy.dialects.postgresql import insert

from data_rentgen.db.models import ColumnLineage, DatasetColumnRelation
@@ -123,9 +123,20 @@ async def list_by_dataset_pairs(
if not dataset_ids_pairs:
return []

source_dataset_ids = [pair[0] for pair in dataset_ids_pairs]
target_dataset_ids = [pair[1] for pair in dataset_ids_pairs]
pairs = (
func.unnest(
cast(source_dataset_ids, ARRAY(Integer())),
cast(target_dataset_ids, ARRAY(Integer())),
)
.table_valued("source_dataset_id", "target_dataset_id")
.render_derived()
)

where = [
ColumnLineage.created_at >= since,
tuple_(ColumnLineage.source_dataset_id, ColumnLineage.target_dataset_id).in_(dataset_ids_pairs),
tuple_(ColumnLineage.source_dataset_id, ColumnLineage.target_dataset_id).in_(select(pairs)),
]
if until:
where.append(
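The repository change swaps a literal IN-list of tuples for two array bind parameters expanded server-side with unnest(). Presumably the point is a fixed statement shape: the old tuple_(...).in_(dataset_ids_pairs) rendered one bind parameter per pair, while the new form always binds exactly two arrays regardless of how many pairs are passed in. A runnable sketch of the construct in isolation (values are illustrative):

```python
from sqlalchemy import ARRAY, Integer, cast, func, select
from sqlalchemy.dialects import postgresql

source_ids = [1, 3, 5]  # illustrative dataset ids
target_ids = [2, 4, 6]

pairs = (
    func.unnest(
        cast(source_ids, ARRAY(Integer())),
        cast(target_ids, ARRAY(Integer())),
    )
    .table_valued("source_dataset_id", "target_dataset_id")
    .render_derived()
)

# Renders roughly as:
#   SELECT anon_1.source_dataset_id, anon_1.target_dataset_id
#   FROM unnest(CAST(%(param_1)s AS INTEGER[]), CAST(%(param_2)s AS INTEGER[]))
#        AS anon_1(source_dataset_id, target_dataset_id)
print(select(pairs).compile(dialect=postgresql.dialect()))
```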