diff --git a/data_rentgen/consumer/saver.py b/data_rentgen/consumer/saver.py
new file mode 100644
index 00000000..23aa0d2f
--- /dev/null
+++ b/data_rentgen/consumer/saver.py
@@ -0,0 +1,179 @@
+# SPDX-FileCopyrightText: 2024-2025 MTS PJSC
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from faststream import Logger
+from sqlalchemy.exc import DatabaseError, IntegrityError
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from data_rentgen.consumer.extractors import BatchExtractionResult
+from data_rentgen.services.uow import UnitOfWork
+
+
+class DatabaseSaver:
+    def __init__(
+        self,
+        session: AsyncSession,
+        logger: Logger,
+    ) -> None:
+        self.unit_of_work = UnitOfWork(session)
+        self.logger = logger
+
+    async def save(self, data: BatchExtractionResult):
+        self.logger.info("Saving to database")
+
+        await self.create_locations(data)
+        await self.create_datasets(data)
+        await self.create_dataset_symlinks(data)
+        await self.create_job_types(data)
+        await self.create_jobs(data)
+        await self.create_users(data)
+        await self.create_sql_queries(data)
+        await self.create_schemas(data)
+
+        try:
+            await self.create_runs_bulk(data)
+        except DatabaseError:
+            await self.create_runs_one_by_one(data)
+
+        await self.create_operations(data)
+        await self.create_inputs(data)
+        await self.create_outputs(data)
+        await self.create_column_lineage(data)
+
+        self.logger.info("Saved successfully")
+
+    async def create_locations(self, data: BatchExtractionResult):
+        self.logger.debug("Creating locations")
+        # It's hard to fetch locations in bulk, and the number of locations is usually small,
+        # so using a row-by-row approach
+        for location_dto in data.locations():
+            async with self.unit_of_work:
+                location = await self.unit_of_work.location.create_or_update(location_dto)
+                location_dto.id = location.id
+
+    # To avoid deadlocks when parallel consumer instances insert/update the same row,
+    # commit changes for each row instead of committing the whole batch. Yes, this could be slow.
+    # But most entities are unchanged after creation, so we can just fetch them and do nothing.
+    async def create_datasets(self, data: BatchExtractionResult):
+        self.logger.debug("Creating datasets")
+        dataset_pairs = await self.unit_of_work.dataset.fetch_bulk(data.datasets())
+        for dataset_dto, dataset in dataset_pairs:
+            if not dataset:
+                async with self.unit_of_work:
+                    dataset = await self.unit_of_work.dataset.create(dataset_dto)  # noqa: PLW2901
+            dataset_dto.id = dataset.id
+
+    async def create_dataset_symlinks(self, data: BatchExtractionResult):
+        self.logger.debug("Creating dataset symlinks")
+        dataset_symlinks_pairs = await self.unit_of_work.dataset_symlink.fetch_bulk(data.dataset_symlinks())
+        for dataset_symlink_dto, dataset_symlink in dataset_symlinks_pairs:
+            if not dataset_symlink:
+                async with self.unit_of_work:
+                    dataset_symlink = await self.unit_of_work.dataset_symlink.create(dataset_symlink_dto)  # noqa: PLW2901
+            dataset_symlink_dto.id = dataset_symlink.id
+
+    async def create_job_types(self, data: BatchExtractionResult):
+        self.logger.debug("Creating job types")
+        job_type_pairs = await self.unit_of_work.job_type.fetch_bulk(data.job_types())
+        for job_type_dto, job_type in job_type_pairs:
+            if not job_type:
+                async with self.unit_of_work:
+                    job_type = await self.unit_of_work.job_type.create(job_type_dto)  # noqa: PLW2901
+            job_type_dto.id = job_type.id
+
+    async def create_jobs(self, data: BatchExtractionResult):
+        self.logger.debug("Creating jobs")
+        job_pairs = await self.unit_of_work.job.fetch_bulk(data.jobs())
+        for job_dto, job in job_pairs:
+            async with self.unit_of_work:
+                if not job:
+                    job = await self.unit_of_work.job.create_or_update(job_dto)  # noqa: PLW2901
+                else:
+                    job = await self.unit_of_work.job.update(job, job_dto)  # noqa: PLW2901
+            job_dto.id = job.id
+
+    async def create_users(self, data: BatchExtractionResult):
+        self.logger.debug("Creating users")
+        user_pairs = await self.unit_of_work.user.fetch_bulk(data.users())
+        for user_dto, user in user_pairs:
+            if not user:
+                async with self.unit_of_work:
+                    user = await self.unit_of_work.user.create(user_dto)  # noqa: PLW2901
+            user_dto.id = user.id
+
+    async def create_sql_queries(self, data: BatchExtractionResult):
+        self.logger.debug("Creating sql queries")
+        sql_query_ids = await self.unit_of_work.sql_query.fetch_known_ids(data.sql_queries())
+        for sql_query_dto, sql_query_id in sql_query_ids:
+            if not sql_query_id:
+                async with self.unit_of_work:
+                    sql_query = await self.unit_of_work.sql_query.create(sql_query_dto)
+                    sql_query_dto.id = sql_query.id
+            else:
+                sql_query_dto.id = sql_query_id
+
+    async def create_schemas(self, data: BatchExtractionResult):
+        self.logger.debug("Creating schemas")
+        schema_ids = await self.unit_of_work.schema.fetch_known_ids(data.schemas())
+        for schema_dto, schema_id in schema_ids:
+            if not schema_id:
+                async with self.unit_of_work:
+                    schema = await self.unit_of_work.schema.create(schema_dto)
+                    schema_dto.id = schema.id
+            else:
+                schema_dto.id = schema_id
+
+    # In most cases, the whole run tree created by a parent is sent to one
+    # Kafka partition, and thus handled by just one worker.
+    # Cross fingers and create all runs in one transaction.
+    async def create_runs_bulk(self, data: BatchExtractionResult):
+        self.logger.debug("Creating runs in bulk")
+        async with self.unit_of_work:
+            await self.unit_of_work.run.create_or_update_bulk(data.runs())
+
+    # When child and parent runs end up in different partitions,
+    # multiple workers may try to create/update the same run, leading to a deadlock.
+    # Fall back to creating runs one by one.
+    async def create_runs_one_by_one(self, data: BatchExtractionResult):
+        self.logger.debug("Creating runs one by one")
+        run_pairs = await self.unit_of_work.run.fetch_bulk(data.runs())
+        for run_dto, run in run_pairs:
+            try:
+                async with self.unit_of_work:
+                    if not run:
+                        await self.unit_of_work.run.create(run_dto)
+                    else:
+                        await self.unit_of_work.run.update(run, run_dto)
+            except IntegrityError:  # noqa: PERF203
+                # deadlock occurred, states in DB and RAM are out of sync,
+                # so we have to fetch the run from DB
+                async with self.unit_of_work:
+                    await self.unit_of_work.run.create_or_update(run_dto)
+
+    # All events related to the same operation are always sent to the same Kafka partition,
+    # so other workers never insert/update the same operation in parallel.
+    # These rows can be inserted/updated in bulk, in one transaction.
+    async def create_operations(self, data: BatchExtractionResult):
+        async with self.unit_of_work:
+            self.logger.debug("Creating operations")
+            await self.unit_of_work.operation.create_or_update_bulk(data.operations())
+
+    async def create_inputs(self, data: BatchExtractionResult):
+        async with self.unit_of_work:
+            self.logger.debug("Creating inputs")
+            await self.unit_of_work.input.create_or_update_bulk(data.inputs())
+
+    async def create_outputs(self, data: BatchExtractionResult):
+        async with self.unit_of_work:
+            self.logger.debug("Creating outputs")
+            await self.unit_of_work.output.create_or_update_bulk(data.outputs())
+
+    async def create_column_lineage(self, data: BatchExtractionResult):
+        async with self.unit_of_work:
+            self.logger.debug("Creating dataset column relations")
+            await self.unit_of_work.dataset_column_relation.create_bulk_for_column_lineage(data.column_lineage())
+
+            self.logger.debug("Creating column lineage")
+            await self.unit_of_work.column_lineage.create_bulk(data.column_lineage())
diff --git a/data_rentgen/consumer/subscribers.py b/data_rentgen/consumer/subscribers.py
index 6e672220..a9f39552 100644
--- a/data_rentgen/consumer/subscribers.py
+++ b/data_rentgen/consumer/subscribers.py
@@ -14,10 +14,10 @@
 from pydantic import TypeAdapter
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from data_rentgen.consumer.extractors import BatchExtractionResult, BatchExtractor
+from data_rentgen.consumer.extractors import BatchExtractor
+from data_rentgen.consumer.saver import DatabaseSaver
 from data_rentgen.dependencies.stub import Stub
 from data_rentgen.openlineage.run_event import OpenLineageRunEvent
-from data_rentgen.services.uow import UnitOfWork
 
 __all__ = [
     "runs_events_subscriber",
@@ -54,9 +54,8 @@ async def runs_events_subscriber(
     extracted = extractor.result
     logger.info("Got %r", extracted)
 
-    logger.info("Saving to database")
-    await save_to_db(extracted, session, logger)
-    logger.info("Saved successfully")
+    saver = DatabaseSaver(session, logger)
+    await saver.save(extracted)
 
     if malformed:
         logger.warning("Malformed messages: %d", len(malformed))
@@ -88,96 +87,6 @@ async def parse_messages(
         await asyncio.sleep(0)
 
 
-async def save_to_db(
-    data: BatchExtractionResult,
-    session: AsyncSession,
-    logger: Logger,
-) -> None:
-    # To avoid deadlocks when parallel consumer instances insert/update the same row,
-    # commit changes for each row instead of committing the whole batch. Yes, this cloud be slow.
-
-    unit_of_work = UnitOfWork(session)
-
-    logger.debug("Creating locations")
-    for location_dto in data.locations():
-        async with unit_of_work:
-            location = await unit_of_work.location.create_or_update(location_dto)
-            location_dto.id = location.id
-
-    logger.debug("Creating datasets")
-    for dataset_dto in data.datasets():
-        async with unit_of_work:
-            dataset = await unit_of_work.dataset.get_or_create(dataset_dto)
-            dataset_dto.id = dataset.id
-
-    logger.debug("Creating symlinks")
-    for dataset_symlink_dto in data.dataset_symlinks():
-        async with unit_of_work:
-            dataset_symlink = await unit_of_work.dataset_symlink.get_or_create(dataset_symlink_dto)
-            dataset_symlink_dto.id = dataset_symlink.id
-
-    logger.debug("Creating job types")
-    for job_type_dto in data.job_types():
-        async with unit_of_work:
-            job_type = await unit_of_work.job_type.get_or_create(job_type_dto)
-            job_type_dto.id = job_type.id
-
-    logger.debug("Creating jobs")
-    for job_dto in data.jobs():
-        async with unit_of_work:
-            job = await unit_of_work.job.create_or_update(job_dto)
-            job_dto.id = job.id
-
-    logger.debug("Creating sql queries")
-    for sql_query_dto in data.sql_queries():
-        async with unit_of_work:
-            sql_query = await unit_of_work.sql_query.get_or_create(sql_query_dto)
-            sql_query_dto.id = sql_query.id
-
-    logger.debug("Creating users")
-    for user_dto in data.users():
-        async with unit_of_work:
-            user = await unit_of_work.user.get_or_create(user_dto)
-            user_dto.id = user.id
-
-    logger.debug("Creating schemas")
-    for schema_dto in data.schemas():
-        async with unit_of_work:
-            schema = await unit_of_work.schema.get_or_create(schema_dto)
-            schema_dto.id = schema.id
-
-    # Some events related to specific run are send to the same Kafka partition,
-    # but at the same time we have parent_run which may be already inserted/updated by other worker
-    # (Kafka key maybe different for run and it's parent).
-    # In this case we cannot insert all the rows in one transaction, as it may lead to deadlocks.
-    logger.debug("Creating runs")
-    for run_dto in data.runs():
-        async with unit_of_work:
-            await unit_of_work.run.create_or_update(run_dto)
-
-    # All events related to same operation are always send to the same Kafka partition,
-    # so other workers never insert/update the same operation in parallel.
-    # These rows can be inserted/updated in bulk, in one transaction.
-    async with unit_of_work:
-        logger.debug("Creating operations")
-        await unit_of_work.operation.create_or_update_bulk(data.operations())
-
-        logger.debug("Creating inputs")
-        await unit_of_work.input.create_or_update_bulk(data.inputs())
-
-        logger.debug("Creating outputs")
-        await unit_of_work.output.create_or_update_bulk(data.outputs())
-
-    # If something went wrong here, at least we will have inputs/outputs
-    async with unit_of_work:
-        column_lineage = data.column_lineage()
-        logger.debug("Creating dataset column relations")
-        await unit_of_work.dataset_column_relation.create_bulk_for_column_lineage(column_lineage)
-
-        logger.debug("Creating column lineage")
-        await unit_of_work.column_lineage.create_bulk(column_lineage)
-
-
 async def report_malformed(
     messages: list[ConsumerRecord],
     message_id: str,
diff --git a/data_rentgen/db/repositories/column_lineage.py b/data_rentgen/db/repositories/column_lineage.py
index 4fe0a7cb..27b98c30 100644
--- a/data_rentgen/db/repositories/column_lineage.py
+++ b/data_rentgen/db/repositories/column_lineage.py
@@ -6,7 +6,7 @@
 from typing import NamedTuple
 from uuid import UUID
 
-from sqlalchemy import ColumnElement, any_, func, select, tuple_
+from sqlalchemy import ARRAY, ColumnElement, Integer, any_, cast, func, select, tuple_
 from sqlalchemy.dialects.postgresql import insert
 
 from data_rentgen.db.models import ColumnLineage, DatasetColumnRelation
@@ -123,9 +123,20 @@ async def list_by_dataset_pairs(
         if not dataset_ids_pairs:
             return []
 
+        source_dataset_ids = [pair[0] for pair in dataset_ids_pairs]
+        target_dataset_ids = [pair[1] for pair in dataset_ids_pairs]
+        pairs = (
+            func.unnest(
+                cast(source_dataset_ids, ARRAY(Integer())),
+                cast(target_dataset_ids, ARRAY(Integer())),
+            )
+            .table_valued("source_dataset_id", "target_dataset_id")
+            .render_derived()
+        )
+
         where = [
             ColumnLineage.created_at >= since,
-            tuple_(ColumnLineage.source_dataset_id, ColumnLineage.target_dataset_id).in_(dataset_ids_pairs),
+            tuple_(ColumnLineage.source_dataset_id, ColumnLineage.target_dataset_id).in_(select(pairs)),
         ]
         if until:
             where.append(
diff --git a/data_rentgen/db/repositories/dataset.py b/data_rentgen/db/repositories/dataset.py
index 0fb08dce..8f38e81c 100644
--- a/data_rentgen/db/repositories/dataset.py
+++ b/data_rentgen/db/repositories/dataset.py
@@ -3,17 +3,22 @@
 from collections.abc import Collection
 
 from sqlalchemy import (
+    ARRAY,
     ColumnElement,
     CompoundSelect,
+    Integer,
     Row,
     Select,
     SQLColumnExpression,
+    String,
     any_,
     asc,
+    cast,
     desc,
     distinct,
     func,
     select,
+    tuple_,
     union,
 )
 from sqlalchemy.orm import selectinload
@@ -25,18 +30,36 @@
 
 
 class DatasetRepository(Repository[Dataset]):
-    async def get_or_create(self, dataset: DatasetDTO) -> Dataset:
-        result = await self._get(dataset)
+    async def fetch_bulk(self, datasets_dto: list[DatasetDTO]) -> list[tuple[DatasetDTO, Dataset | None]]:
+        if not datasets_dto:
+            return []
 
-        if not result:
-            # try one more time, but with lock acquired.
-            # if another worker already created the same row, just use it. if not - create with holding the lock.
-            await self._lock(dataset.location.id, dataset.name)
-            result = await self._get(dataset)
+        location_ids = [dataset_dto.location.id for dataset_dto in datasets_dto]
+        names = [dataset_dto.name.lower() for dataset_dto in datasets_dto]
+        pairs = (
+            func.unnest(
+                cast(location_ids, ARRAY(Integer())),
+                cast(names, ARRAY(String())),
+            )
+            .table_valued("location_id", "name")
+            .render_derived()
+        )
 
-        if not result:
-            return await self._create(dataset)
-        return result
+        statement = select(Dataset).where(tuple_(Dataset.location_id, func.lower(Dataset.name)).in_(select(pairs)))
+        scalars = await self._session.scalars(statement)
+        existing = {(dataset.location_id, dataset.name.lower()): dataset for dataset in scalars.all()}
+        return [
+            (
+                dto,
+                existing.get((dto.location.id, dto.name.lower())),  # type: ignore[arg-type]
+            )
+            for dto in datasets_dto
+        ]
+
+    async def create(self, dataset: DatasetDTO) -> Dataset:
+        # If another worker already created the same row, just use it; if not, create it while holding the lock.
+        await self._lock(dataset.location.id, dataset.name.lower())
+        return await self._get(dataset) or await self._create(dataset)
 
     async def paginate(
         self,
diff --git a/data_rentgen/db/repositories/dataset_symlink.py b/data_rentgen/db/repositories/dataset_symlink.py
index 1c6810b8..4a8a6f45 100644
--- a/data_rentgen/db/repositories/dataset_symlink.py
+++ b/data_rentgen/db/repositories/dataset_symlink.py
@@ -3,7 +3,7 @@
 
 from collections.abc import Collection
 
-from sqlalchemy import BindParameter, any_, bindparam, or_, select
+from sqlalchemy import ARRAY, BindParameter, Integer, any_, bindparam, cast, func, or_, select, tuple_
 
 from data_rentgen.db.models.dataset_symlink import DatasetSymlink, DatasetSymlinkType
 from data_rentgen.db.repositories.base import Repository
@@ -11,14 +11,42 @@
 
 
 class DatasetSymlinkRepository(Repository[DatasetSymlink]):
-    async def get_or_create(self, dataset_symlink: DatasetSymlinkDTO) -> DatasetSymlink:
-        result = await self._get(dataset_symlink)
-        if not result:
-            # try one more time, but with lock acquired.
-            # if another worker already created the same row, just use it. if not - create with holding the lock.
-            await self._lock(dataset_symlink.from_dataset.id, dataset_symlink.to_dataset.id)
-            result = await self._get(dataset_symlink) or await self._create(dataset_symlink)
-        return result
+    async def fetch_bulk(
+        self,
+        dataset_symlinks_dto: list[DatasetSymlinkDTO],
+    ) -> list[tuple[DatasetSymlinkDTO, DatasetSymlink | None]]:
+        if not dataset_symlinks_dto:
+            return []
+
+        from_dataset_ids = [dataset_symlink_dto.from_dataset.id for dataset_symlink_dto in dataset_symlinks_dto]
+        to_dataset_ids = [dataset_symlink_dto.to_dataset.id for dataset_symlink_dto in dataset_symlinks_dto]
+
+        pairs = (
+            func.unnest(
+                cast(from_dataset_ids, ARRAY(Integer())),
+                cast(to_dataset_ids, ARRAY(Integer())),
+            )
+            .table_valued("from_dataset_id", "to_dataset_id")
+            .render_derived()
+        )
+
+        statement = select(DatasetSymlink).where(
+            tuple_(DatasetSymlink.from_dataset_id, DatasetSymlink.to_dataset_id).in_(select(pairs)),
+        )
+        scalars = await self._session.scalars(statement)
+        existing = {(item.from_dataset_id, item.to_dataset_id): item for item in scalars.all()}
+        return [
+            (
+                dto,
+                existing.get((dto.from_dataset.id, dto.to_dataset.id)),  # type: ignore[arg-type]
+            )
+            for dto in dataset_symlinks_dto
+        ]
+
+    async def create(self, dataset_symlink: DatasetSymlinkDTO) -> DatasetSymlink:
+        # If another worker already created the same row, just use it; if not, create it while holding the lock.
+        await self._lock(dataset_symlink.from_dataset.id, dataset_symlink.to_dataset.id)
+        return await self._get(dataset_symlink) or await self._create(dataset_symlink)
 
     async def list_by_dataset_ids(self, dataset_ids: Collection[int]) -> list[DatasetSymlink]:
         if not dataset_ids:
diff --git a/data_rentgen/db/repositories/job.py b/data_rentgen/db/repositories/job.py
index 38ae7c5f..66163080 100644
--- a/data_rentgen/db/repositories/job.py
+++ b/data_rentgen/db/repositories/job.py
@@ -3,16 +3,21 @@
 from collections.abc import Collection
 
 from sqlalchemy import (
+    ARRAY,
     ColumnElement,
     CompoundSelect,
+    Integer,
     Row,
     Select,
     SQLColumnExpression,
+    String,
     any_,
     asc,
+    cast,
     desc,
     func,
     select,
+    tuple_,
     union,
 )
 from sqlalchemy.orm import selectinload
@@ -81,17 +86,63 @@ async def paginate(
             page_size=page_size,
         )
 
+    async def fetch_bulk(self, jobs_dto: list[JobDTO]) -> list[tuple[JobDTO, Job | None]]:
+        if not jobs_dto:
+            return []
+
+        location_ids = [job_dto.location.id for job_dto in jobs_dto]
+        names = [job_dto.name.lower() for job_dto in jobs_dto]
+        pairs = (
+            func.unnest(
+                cast(location_ids, ARRAY(Integer())),
+                cast(names, ARRAY(String())),
+            )
+            .table_valued("location_id", "name")
+            .render_derived()
+        )
+
+        statement = select(Job).where(tuple_(Job.location_id, func.lower(Job.name)).in_(select(pairs)))
+        scalars = await self._session.scalars(statement)
+        existing = {(job.location_id, job.name.lower()): job for job in scalars.all()}
+        return [
+            (
+                job_dto,
+                existing.get((job_dto.location.id, job_dto.name.lower())),  # type: ignore[arg-type]
+            )
+            for job_dto in jobs_dto
+        ]
+
     async def create_or_update(self, job: JobDTO) -> Job:
+        # If another worker already created the same row, just use it; if not, create it while holding the lock.
+        await self._lock(job.location.id, job.name.lower())
         result = await self._get(job)
-        if not result:
-            # try one more time, but with lock acquired.
-            # if another worker already created the same row, just use it. if not - create with holding the lock.
-            await self._lock(job.location.id, job.name)
-            result = await self._get(job)
-
         if not result:
             return await self._create(job)
-        return await self._update(result, job)
+        return await self.update(result, job)
+
+    async def _get(self, job: JobDTO) -> Job | None:
+        statement = select(Job).where(
+            Job.location_id == job.location.id,
+            func.lower(Job.name) == job.name.lower(),
+        )
+        return await self._session.scalar(statement)
+
+    async def _create(self, job: JobDTO) -> Job:
+        result = Job(
+            location_id=job.location.id,
+            name=job.name,
+            type_id=job.type.id if job.type else UNKNOWN_JOB_TYPE,
+        )
+        self._session.add(result)
+        await self._session.flush([result])
+        return result
+
+    async def update(self, existing: Job, new: JobDTO) -> Job:
+        # almost all fields are immutable, so we can avoid UPDATE statements if the row is unchanged
+        if new.type and new.type.id and existing.type_id != new.type.id:
+            existing.type_id = new.type.id
+            await self._session.flush([existing])
+        return existing
 
     async def list_by_ids(self, job_ids: Collection[int]) -> list[Job]:
         if not job_ids:
@@ -121,27 +172,3 @@ async def get_stats_by_location_ids(self, location_ids: Collection[int]) -> dict
 
         query_result = await self._session.execute(query)
         return {row.location_id: row for row in query_result.all()}
-
-    async def _get(self, job: JobDTO) -> Job | None:
-        statement = select(Job).where(
-            Job.location_id == job.location.id,
-            func.lower(Job.name) == job.name.lower(),
-        )
-        return await self._session.scalar(statement)
-
-    async def _create(self, job: JobDTO) -> Job:
-        result = Job(
-            location_id=job.location.id,
-            name=job.name,
-            type_id=job.type.id if job.type else UNKNOWN_JOB_TYPE,
-        )
-        self._session.add(result)
-        await self._session.flush([result])
-        return result
-
-    async def _update(self, existing: Job, new: JobDTO) -> Job:
-        # almost of fields are immutable, so we can avoid UPDATE statements if row is unchanged
-        if new.type and new.type.id:
-            existing.type_id = new.type.id
-        await self._session.flush([existing])
-        return existing
diff --git a/data_rentgen/db/repositories/job_type.py b/data_rentgen/db/repositories/job_type.py
index 451285d4..c74c4dcd 100644
--- a/data_rentgen/db/repositories/job_type.py
+++ b/data_rentgen/db/repositories/job_type.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from sqlalchemy import (
+    any_,
     select,
 )
 
@@ -11,14 +12,19 @@
 
 
 class JobTypeRepository(Repository[JobType]):
-    async def get_or_create(self, job_type_dto: JobTypeDTO) -> JobType:
-        result = await self._get(job_type_dto)
-        if not result:
-            # try one more time, but with lock acquired.
-            # if another worker already created the same row, just use it. if not - create with holding the lock.
-            await self._lock("job_type", job_type_dto.type)
-            result = await self._get(job_type_dto) or await self._create(job_type_dto)
-        return result
+    async def fetch_bulk(self, job_types_dto: list[JobTypeDTO]) -> list[tuple[JobTypeDTO, JobType | None]]:
+        unique_keys = [job_type_dto.type for job_type_dto in job_types_dto]
+        statement = select(JobType).where(
+            JobType.type == any_(unique_keys),  # type: ignore[arg-type]
+        )
+        scalars = await self._session.scalars(statement)
+        existing = {job.type: job for job in scalars.all()}
+        return [(job_type_dto, existing.get(job_type_dto.type)) for job_type_dto in job_types_dto]
+
+    async def create(self, job_type_dto: JobTypeDTO) -> JobType:
+        # If another worker already created the same row, just use it; if not, create it while holding the lock.
+        await self._lock("job_type", job_type_dto.type)
+        return await self._get(job_type_dto) or await self._create(job_type_dto)
 
     async def _get(self, job_type_dto: JobTypeDTO) -> JobType | None:
         query = select(JobType).where(JobType.type == job_type_dto.type)
diff --git a/data_rentgen/db/repositories/location.py b/data_rentgen/db/repositories/location.py
index 5b522256..4cd907c5 100644
--- a/data_rentgen/db/repositories/location.py
+++ b/data_rentgen/db/repositories/location.py
@@ -24,20 +24,6 @@
 
 
 class LocationRepository(Repository[Location]):
-    async def create_or_update(self, location: LocationDTO) -> Location:
-        result = await self._get(location)
-        if not result:
-            # try one more time, but with lock acquired.
-            # if another worker already created the same row, just use it. if not - create with holding the lock.
-            await self._lock(location.type, location.name)
-            result = await self._get(location)
-
-        if not result:
-            result = await self._create(location)
-
-        await self._update_addresses(result, location)
-        return result
-
     async def paginate(
         self,
         page: int,
@@ -101,6 +87,20 @@ async def update_external_id(self, location_id: int, external_id: str | None) ->
         await self._session.flush([location])
         return location
 
+    async def create_or_update(self, location: LocationDTO) -> Location:
+        result = await self._get(location)
+        if not result:
+            # try one more time, but with lock acquired.
+            # If another worker already created the same row, just use it; if not, create it while holding the lock.
+            await self._lock(location.type, location.name)
+            result = await self._get(location)
+
+        if not result:
+            result = await self._create(location)
+
+        await self._update_addresses(result, location)
+        return result
+
     async def _get(self, location: LocationDTO) -> Location | None:
         by_name = select(Location).where(Location.type == location.type, Location.name == location.name)
         by_addresses = (
@@ -114,7 +114,6 @@ async def _get(self, location: LocationDTO) -> Location | None:
         statement = (
             select(Location).from_statement(by_name.union(by_addresses)).options(selectinload(Location.addresses))
         )
-
         return await self._session.scalar(statement)
 
     async def _create(self, location: LocationDTO) -> Location:
diff --git a/data_rentgen/db/repositories/run.py b/data_rentgen/db/repositories/run.py
index 6b465bb6..d2cffac0 100644
--- a/data_rentgen/db/repositories/run.py
+++ b/data_rentgen/db/repositories/run.py
@@ -27,18 +27,6 @@
 
 
 class RunRepository(Repository[Run]):
-    async def create_or_update(self, run: RunDTO) -> Run:
-        result = await self._get(run)
-        if not result:
-            # try one more time, but with lock acquired.
-            # if another worker already created the same row, just use it. if not - create with holding the lock.
-            await self._lock(run.id)
-            result = await self._get(run)
-
-        if not result:
-            return await self._create(run)
-        return await self._update(result, run)
-
     async def paginate(
         self,
         page: int,
@@ -147,11 +135,30 @@ async def list_by_job_ids(self, job_ids: Collection[int], since: datetime, until
         result = await self._session.scalars(query)
         return list(result.all())
 
+    async def fetch_bulk(self, runs_dto: list[RunDTO]) -> list[tuple[RunDTO, Run | None]]:
+        if not runs_dto:
+            return []
+        ids = [run_dto.id for run_dto in runs_dto]
+        min_created_at = extract_timestamp_from_uuid(min(ids))
+        statement = select(Run).where(
+            Run.created_at >= min_created_at,
+            Run.id == any_(ids),  # type: ignore[arg-type]
+        )
+        scalars = await self._session.scalars(statement)
+        existing = {run.id: run for run in scalars.all()}
+        return [(run_dto, existing.get(run_dto.id)) for run_dto in runs_dto]
+
+    async def create_or_update(self, run: RunDTO) -> Run:
+        result = await self._get(run)
+        if not result:
+            return await self.create(run)
+        return await self.update(result, run)
+
     async def _get(self, run: RunDTO) -> Run | None:
         query = select(Run).where(Run.id == run.id, Run.created_at == run.created_at)
         return await self._session.scalar(query)
 
-    async def _create(self, run: RunDTO) -> Run:
+    async def create(self, run: RunDTO) -> Run:
         result = Run(
             created_at=run.created_at,
             id=run.id,
@@ -171,7 +178,7 @@ async def _create(self, run: RunDTO) -> Run:
         await self._session.flush([result])
         return result
 
-    async def _update(
+    async def update(
         self,
         existing: Run,
         new: RunDTO,
diff --git a/data_rentgen/db/repositories/schema.py b/data_rentgen/db/repositories/schema.py
index 68b6974a..b45acfa8 100644
--- a/data_rentgen/db/repositories/schema.py
+++ b/data_rentgen/db/repositories/schema.py
@@ -11,15 +11,6 @@
 
 
 class SchemaRepository(Repository[Schema]):
-    async def get_or_create(self, schema: SchemaDTO) -> Schema:
-        result = await self._get(schema)
-        if not result:
-            # try one more time, but with lock acquired.
-            # if another worker already created the same row, just use it. if not - create with holding the lock.
-            await self._lock(schema.digest)
-            result = await self._get(schema) or await self._create(schema)
-        return result
-
     async def list_by_ids(self, schema_ids: Collection[int]) -> list[Schema]:
         if not schema_ids:
             return []
@@ -28,6 +19,30 @@ async def list_by_ids(self, schema_ids: Collection[int]) -> list[Schema]:
         result = await self._session.scalars(query)
         return list(result.all())
 
+    async def fetch_known_ids(self, schemas_dto: list[SchemaDTO]) -> list[tuple[SchemaDTO, int | None]]:
+        if not schemas_dto:
+            return []
+
+        unique_digests = [schema_dto.digest for schema_dto in schemas_dto]
+        # schema JSON can be heavy, avoid loading it if not needed
+        statement = select(Schema.digest, Schema.id).where(
+            Schema.digest == any_(unique_digests),  # type: ignore[arg-type]
+        )
+        scalars = await self._session.execute(statement)
+        known_ids = {item.digest: item.id for item in scalars.all()}
+        return [
+            (
+                schema_dto,
+                known_ids.get(schema_dto.digest),
+            )
+            for schema_dto in schemas_dto
+        ]
+
+    async def create(self, schema: SchemaDTO) -> Schema:
+        # If another worker already created the same row, just use it; if not, create it while holding the lock.
+        await self._lock(schema.digest)
+        return await self._get(schema) or await self._create(schema)
+
     async def _get(self, schema: SchemaDTO) -> Schema | None:
         result = select(Schema).where(Schema.digest == schema.digest)
         return await self._session.scalar(result)
diff --git a/data_rentgen/db/repositories/sql_query.py b/data_rentgen/db/repositories/sql_query.py
index 1c2ac985..dddf5a67 100644
--- a/data_rentgen/db/repositories/sql_query.py
+++ b/data_rentgen/db/repositories/sql_query.py
@@ -1,7 +1,8 @@
 # SPDX-FileCopyrightText: 2024-2025 MTS PJSC
 # SPDX-License-Identifier: Apache-2.0
 
-from sqlalchemy import select
+
+from sqlalchemy import any_, select
 
 from data_rentgen.db.models.sql_query import SQLQuery
 from data_rentgen.db.repositories.base import Repository
@@ -9,14 +10,29 @@
 
 
 class SQLQueryRepository(Repository[SQLQuery]):
-    async def get_or_create(self, sql_query: SQLQueryDTO) -> SQLQuery:
-        result = await self._get(sql_query)
-        if not result:
-            # try one more time, but with lock acquired.
-            # if another worker already created the same row, just use it. if not - create with holding the lock.
-            await self._lock(sql_query.fingerprint)
-            result = await self._get(sql_query) or await self._create(sql_query)
-        return result
+    async def fetch_known_ids(self, sql_queries_dto: list[SQLQueryDTO]) -> list[tuple[SQLQueryDTO, int | None]]:
+        if not sql_queries_dto:
+            return []
+
+        unique_fingerprints = [sql_query_dto.fingerprint for sql_query_dto in sql_queries_dto]
+        # query text can be heavy, avoid loading it if not needed
+        statement = select(SQLQuery.fingerprint, SQLQuery.id).where(
+            SQLQuery.fingerprint == any_(unique_fingerprints),  # type: ignore[arg-type]
+        )
+        scalars = await self._session.execute(statement)
+        known_ids = {item.fingerprint: item.id for item in scalars.all()}
+        return [
+            (
+                sql_query_dto,
+                known_ids.get(sql_query_dto.fingerprint),
+            )
+            for sql_query_dto in sql_queries_dto
+        ]
+
+    async def create(self, sql_query: SQLQueryDTO) -> SQLQuery:
+        # If another worker already created the same row, just use it; if not, create it while holding the lock.
+        await self._lock(sql_query.fingerprint)
+        return await self._get(sql_query) or await self._create(sql_query)
 
     async def _get(self, sql_query: SQLQueryDTO) -> SQLQuery | None:
         result = select(SQLQuery).where(SQLQuery.fingerprint == sql_query.fingerprint)
diff --git a/data_rentgen/db/repositories/user.py b/data_rentgen/db/repositories/user.py
index 16b172b8..5064d973 100644
--- a/data_rentgen/db/repositories/user.py
+++ b/data_rentgen/db/repositories/user.py
@@ -1,7 +1,7 @@
 # SPDX-FileCopyrightText: 2024-2025 MTS PJSC
 # SPDX-License-Identifier: Apache-2.0
 
-from sqlalchemy import func, select
+from sqlalchemy import any_, func, select
 
 from data_rentgen.db.models import User
 from data_rentgen.db.repositories.base import Repository
@@ -9,16 +9,24 @@
 
 
 class UserRepository(Repository[User]):
-    async def get_or_create(self, user: UserDTO) -> User:
-        result = await self._get(user.name)
-        if not result:
-            await self._lock(user.name)
-            result = await self._get(user.name) or await self._create(user)
-        return result
+    async def fetch_bulk(self, users_dto: list[UserDTO]) -> list[tuple[UserDTO, User | None]]:
+        if not users_dto:
+            return []
+        unique_keys = [user_dto.name.lower() for user_dto in users_dto]
+        statement = select(User).where(
+            func.lower(User.name) == any_(unique_keys),  # type: ignore[arg-type]
+        )
+        scalars = await self._session.scalars(statement)
+        existing = {user.name.lower(): user for user in scalars.all()}
+        return [(user_dto, existing.get(user_dto.name.lower())) for user_dto in users_dto]
 
-    async def read_by_id(self, id_: int) -> User | None:
-        statement = select(User).where(User.id == id_)
-        return await self._session.scalar(statement)
+    async def create(self, user_dto: UserDTO) -> User:
+        # If another worker already created the same row, just use it; if not, create it while holding the lock.
+        await self._lock(user_dto.name)
+        return await self.get_or_create(user_dto)
+
+    async def get_or_create(self, user_dto: UserDTO) -> User:
+        return await self._get(user_dto.name) or await self._create(user_dto)
 
     async def _get(self, name: str) -> User | None:
         statement = select(User).where(func.lower(User.name) == name.lower())
@@ -29,3 +37,7 @@ async def _create(self, user: UserDTO) -> User:
         self._session.add(result)
         await self._session.flush([result])
         return result
+
+    async def read_by_id(self, id_: int) -> User | None:
+        statement = select(User).where(User.id == id_)
+        return await self._session.scalar(statement)
diff --git a/data_rentgen/db/scripts/seed/__main__.py b/data_rentgen/db/scripts/seed/__main__.py
index 99e79b0d..cda8983d 100755
--- a/data_rentgen/db/scripts/seed/__main__.py
+++ b/data_rentgen/db/scripts/seed/__main__.py
@@ -14,11 +14,11 @@
 from faker import Faker
 
 from data_rentgen.consumer.extractors import BatchExtractionResult
+from data_rentgen.consumer.saver import DatabaseSaver
 from data_rentgen.db.factory import create_session_factory
 from data_rentgen.db.scripts.seed.dbt import generate_dbt_run
 from data_rentgen.db.scripts.seed.flink import generate_flink_run
 from data_rentgen.db.scripts.seed.hive import generate_hive_run
-from data_rentgen.db.scripts.seed.save import save_to_db
 from data_rentgen.db.scripts.seed.spark_local import generate_spark_run_local
 from data_rentgen.db.scripts.seed.spark_yarn import generate_spark_run_yarn
 from data_rentgen.db.settings import DatabaseSettings
@@ -110,8 +110,8 @@ async def main(args: list[str]) -> None:
     db_settings = DatabaseSettings()  # type: ignore[call-arg]
     session_factory = create_session_factory(db_settings)
     async with session_factory() as session:
-        await save_to_db(result, session, logger=logger)
-        await session.commit()
+        saver = DatabaseSaver(session, logger)
+        await saver.save(result)
 
     logger.info(" Done!")
 
diff --git a/data_rentgen/db/scripts/seed/save.py b/data_rentgen/db/scripts/seed/save.py
deleted file mode 100644
index d64e8ed8..00000000
--- a/data_rentgen/db/scripts/seed/save.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# SPDX-FileCopyrightText: 2024-2025 MTS PJSC
-# SPDX-License-Identifier: Apache-2.0
-from __future__ import annotations
-
-from faststream import Logger
-from sqlalchemy.ext.asyncio import AsyncSession
-
-from data_rentgen.consumer.extractors import BatchExtractionResult
-from data_rentgen.services.uow import UnitOfWork
-
-
-async def save_to_db(
-    data: BatchExtractionResult,
-    session: AsyncSession,
-    logger: Logger,
-) -> None:
-    """Save data to database.
-
-    This is different from consumer's method because here we generate random data (unique),
-    and there are no writers we conflict with. So we can use one transaction + bulk insert for runs.
- """ - async with UnitOfWork(session) as unit_of_work: - logger.debug("Creating locations") - for location_dto in data.locations(): - location = await unit_of_work.location.create_or_update(location_dto) - location_dto.id = location.id - - logger.debug("Creating datasets") - for dataset_dto in data.datasets(): - dataset = await unit_of_work.dataset.get_or_create(dataset_dto) - dataset_dto.id = dataset.id - - logger.debug("Creating symlinks") - for dataset_symlink_dto in data.dataset_symlinks(): - dataset_symlink = await unit_of_work.dataset_symlink.get_or_create(dataset_symlink_dto) - dataset_symlink_dto.id = dataset_symlink.id - - logger.debug("Creating job types") - for job_type_dto in data.job_types(): - job_type = await unit_of_work.job_type.get_or_create(job_type_dto) - job_type_dto.id = job_type.id - - logger.debug("Creating jobs") - for job_dto in data.jobs(): - job = await unit_of_work.job.create_or_update(job_dto) - job_dto.id = job.id - - logger.debug("Creating sql queries") - for sql_query_dto in data.sql_queries(): - sql_query = await unit_of_work.sql_query.get_or_create(sql_query_dto) - sql_query_dto.id = sql_query.id - - logger.debug("Creating users") - for user_dto in data.users(): - user = await unit_of_work.user.get_or_create(user_dto) - user_dto.id = user.id - - logger.debug("Creating schemas") - for schema_dto in data.schemas(): - schema = await unit_of_work.schema.get_or_create(schema_dto) - schema_dto.id = schema.id - - logger.debug("Creating runs") - await unit_of_work.run.create_or_update_bulk(data.runs()) - - logger.debug("Creating operations") - await unit_of_work.operation.create_or_update_bulk(data.operations()) - - logger.debug("Creating inputs") - await unit_of_work.input.create_or_update_bulk(data.inputs()) - - logger.debug("Creating outputs") - await unit_of_work.output.create_or_update_bulk(data.outputs()) - - column_lineage = data.column_lineage() - logger.debug("Creating dataset column relations") - await unit_of_work.dataset_column_relation.create_bulk_for_column_lineage(column_lineage) - - logger.debug("Creating column lineage") - await unit_of_work.column_lineage.create_bulk(column_lineage) diff --git a/docs/changelog/next_release/314.improvement.rst b/docs/changelog/next_release/314.improvement.rst new file mode 100644 index 00000000..1944382e --- /dev/null +++ b/docs/changelog/next_release/314.improvement.rst @@ -0,0 +1 @@ +Improve consumer performance by reducing DB load on reading operations.