|
3 | 3 |
|
4 | 4 | from __future__ import annotations |
5 | 5 |
|
6 | | -from faststream import Depends, Logger |
| 6 | +import asyncio |
| 7 | +from typing import cast |
| 8 | + |
| 9 | +from aiokafka import ConsumerRecord |
| 10 | +from faststream import Depends, Logger, NoCast |
| 11 | +from faststream.kafka import KafkaMessage |
| 12 | +from faststream.kafka.publisher.asyncapi import AsyncAPIDefaultPublisher |
| 13 | +from pydantic import TypeAdapter |
7 | 14 | from sqlalchemy.ext.asyncio import AsyncSession |
8 | 15 |
|
9 | | -from data_rentgen.consumer.extractors import BatchExtractionResult, extract_batch |
| 16 | +from data_rentgen.consumer.extractors import BatchExtractionResult, BatchExtractor |
10 | 17 | from data_rentgen.consumer.openlineage.run_event import OpenLineageRunEvent |
11 | 18 | from data_rentgen.dependencies import Stub |
12 | 19 | from data_rentgen.services.uow import UnitOfWork |
|
15 | 22 | "runs_events_subscriber", |
16 | 23 | ] |
17 | 24 |
|
18 | | - |
19 | | -def get_unit_of_work(session: AsyncSession = Depends(Stub(AsyncSession))) -> UnitOfWork: |
20 | | - return UnitOfWork(session) |
| 25 | +OpenLineageRunEventAdapter = TypeAdapter(OpenLineageRunEvent) |
21 | 26 |
|
22 | 27 |
|
23 | 28 | async def runs_events_subscriber( |
24 | | - events: list[OpenLineageRunEvent], |
| 29 | + _events: NoCast[list[OpenLineageRunEvent]], |
| 30 | + batch: KafkaMessage, |
25 | 31 | logger: Logger, |
26 | | - unit_of_work: UnitOfWork = Depends(get_unit_of_work), |
| 32 | + publisher: AsyncAPIDefaultPublisher = Depends(Stub(AsyncAPIDefaultPublisher)), |
| 33 | + session: AsyncSession = Depends(Stub(AsyncSession)), |
27 | 34 | ): |
28 | | - logger.info("Got %d events", len(events)) |
29 | | - extracted = extract_batch(events) |
30 | | - logger.info("Extracted: %r", extracted) |
| 35 | + logger.info("Extracting events") |
| 36 | + parsed, malformed = await extract_events(batch, logger) |
31 | 37 |
|
32 | 38 | logger.info("Saving to database") |
33 | | - await save_to_db(extracted, unit_of_work, logger) |
| 39 | + await save_to_db(parsed, session, logger) |
34 | 40 | logger.info("Saved successfully") |
35 | 41 |
|
| 42 | + if malformed: |
| 43 | + logger.warning("Malformed messages: %d", len(malformed)) |
| 44 | + await report_malformed(batch, malformed, publisher) |
| 45 | + |
| 46 | + |
| 47 | +async def extract_events( |
| 48 | + raw_data: KafkaMessage, |
| 49 | + logger: Logger, |
| 50 | + await_every: int = 50, |
| 51 | +) -> tuple[BatchExtractionResult, list[ConsumerRecord]]: |
| 52 | + messages = cast(tuple[ConsumerRecord], raw_data.raw_message) # https://github.com/airtai/faststream/issues/2102 |
| 53 | + total_bytes = sum(len(message.value or "") for message in messages) |
| 54 | + logger.info("Got %d messages (%dKiB)", len(messages), total_bytes / 1024) |
| 55 | + |
| 56 | + extractor = BatchExtractor() |
| 57 | + malformed: list[ConsumerRecord] = [] |
| 58 | + |
| 59 | + for i, message in enumerate(messages): |
| 60 | + try: |
| 61 | + if message.value is None: |
| 62 | + msg = "Message value cannot be empty" |
| 63 | + raise ValueError(msg) # noqa: TRY301 |
| 64 | + |
| 65 | + event = OpenLineageRunEventAdapter.validate_json(message.value) |
| 66 | + extractor.add_events([event]) |
| 67 | + except (ValueError, TypeError): |
| 68 | + logger.error( # noqa: TRY400 |
| 69 | + "Failed to parse message: ConsumerRecord(topic=%r, partition=%d, offset=%d)", |
| 70 | + message.topic, |
| 71 | + message.partition, |
| 72 | + message.offset, |
| 73 | + ) |
| 74 | + malformed.append(message) |
| 75 | + |
| 76 | + if await_every and i >= await_every and i % await_every == 0: |
| 77 | + # OpenLineage models are heavy, parsing is CPU bound task which may take some time. |
| 78 | + # Blocking event loop is not a good idea, so we need to await sometimes, |
| 79 | + await asyncio.sleep(0) |
| 80 | + |
| 81 | + return extractor.result, malformed |
| 82 | + |
36 | 83 |
|
37 | 84 | async def save_to_db( |
38 | 85 | data: BatchExtractionResult, |
39 | | - unit_of_work: UnitOfWork, |
| 86 | + session: AsyncSession, |
40 | 87 | logger: Logger, |
41 | 88 | ) -> None: |
42 | 89 | # To avoid deadlocks when parallel consumer instances insert/update the same row, |
43 | 90 | # commit changes for each row instead of committing the whole batch. Yes, this cloud be slow. |
44 | 91 |
|
| 92 | + unit_of_work = UnitOfWork(session) |
| 93 | + |
45 | 94 | logger.debug("Creating locations") |
46 | 95 | for location_dto in data.locations(): |
47 | 96 | async with unit_of_work: |
@@ -108,3 +157,25 @@ async def save_to_db( |
108 | 157 |
|
109 | 158 | logger.debug("Creating column lineage") |
110 | 159 | await unit_of_work.column_lineage.create_bulk(column_lineage) |
| 160 | + |
| 161 | + |
| 162 | +async def report_malformed( |
| 163 | + batch: KafkaMessage, |
| 164 | + messages: list[ConsumerRecord], |
| 165 | + publisher: AsyncAPIDefaultPublisher, |
| 166 | +): |
| 167 | + # Return malformed messages back to the broker |
| 168 | + for message in messages: |
| 169 | + headers: dict[str, str] = {} |
| 170 | + if message.headers: |
| 171 | + headers = {key: value.decode("utf-8") for key, value in message.headers} |
| 172 | + |
| 173 | + await publisher.publish( |
| 174 | + message.value, |
| 175 | + key=message.key, |
| 176 | + partition=message.partition, |
| 177 | + timestamp_ms=message.timestamp, |
| 178 | + headers=headers or None, |
| 179 | + reply_to=batch.message_id, |
| 180 | + correlation_id=batch.correlation_id, |
| 181 | + ) |
0 commit comments