|
33 | 33 | import fastapi |
34 | 34 | import pydantic |
35 | 35 | from aiobotocore import session |
36 | | -from pydantic import field_validator |
37 | 36 | from fastapi import FastAPI |
| 37 | +from pydantic import field_validator |
38 | 38 | from pydantic_settings import BaseSettings |
39 | 39 | from tortoise import functions, transactions |
40 | 40 | from tortoise.contrib.fastapi import RegisterTortoise |
@@ -205,7 +205,40 @@ def timestamp_to_reverse_alphabetical(timestamp: datetime) -> str: |
205 | 205 | return inverted_str + "-" + timestamp.isoformat() |
206 | 206 |
|
207 | 207 |
|
208 | | -class SQLiteS3Backend(BackendBase, IndexingBackendMixin, SnapshottingBackendMixin, EventDrivenBackendMixin): |
| 208 | +def _parse_sqs_message_events( |
| 209 | + body: dict, |
| 210 | +) -> Optional[List[Tuple[str, datetime.datetime]]]: |
| 211 | + """Parse EventBridge-wrapped or native S3 notification bodies from SQS. |
| 212 | +
|
| 213 | + Returns None if the format is not recognized. Multiple S3 records in one |
| 214 | + message yield one tuple per record. |
| 215 | + """ |
| 216 | + if "detail" in body: |
| 217 | + return [ |
| 218 | + ( |
| 219 | + body["detail"]["object"]["key"], |
| 220 | + datetime.datetime.fromisoformat(body["time"].replace("Z", "+00:00")), |
| 221 | + ) |
| 222 | + ] |
| 223 | + if "Records" in body: |
| 224 | + out: List[Tuple[str, datetime.datetime]] = [] |
| 225 | + for record in body["Records"]: |
| 226 | + out.append( |
| 227 | + ( |
| 228 | + record["s3"]["object"]["key"], |
| 229 | + datetime.datetime.fromisoformat(record["eventTime"].replace("Z", "+00:00")), |
| 230 | + ) |
| 231 | + ) |
| 232 | + return out |
| 233 | + return None |
| 234 | + |
| 235 | + |
| 236 | +class SQLiteS3Backend( |
| 237 | + BackendBase, |
| 238 | + IndexingBackendMixin, |
| 239 | + SnapshottingBackendMixin, |
| 240 | + EventDrivenBackendMixin, |
| 241 | +): |
209 | 242 | def __init__( |
210 | 243 | self, |
211 | 244 | s3_bucket: str, |
@@ -790,27 +823,14 @@ async def start_event_consumer(self) -> None: |
790 | 823 | for message in messages: |
791 | 824 | try: |
792 | 825 | body = json.loads(message["Body"]) |
793 | | - s3_key = None |
794 | | - event_time = None |
795 | | - |
796 | | - # Handle EventBridge wrapped S3 events |
797 | | - if "detail" in body: |
798 | | - s3_key = body["detail"]["object"]["key"] |
799 | | - event_time = datetime.datetime.fromisoformat( |
800 | | - body["time"].replace("Z", "+00:00") |
801 | | - ) |
802 | | - elif "Records" in body: |
803 | | - record = body["Records"][0] |
804 | | - s3_key = record["s3"]["object"]["key"] |
805 | | - event_time = datetime.datetime.fromisoformat( |
806 | | - record["eventTime"].replace("Z", "+00:00") |
807 | | - ) |
808 | | - else: |
809 | | - logger.warning(f"Unknown message format: {body}") |
| 826 | + events = _parse_sqs_message_events(body) |
| 827 | + if events is None: |
| 828 | + logger.warning("Unknown message format: %s", body) |
810 | 829 | continue |
811 | 830 |
|
812 | | - if s3_key and s3_key.endswith(".jsonl"): |
813 | | - await self._handle_s3_event(s3_key, event_time) |
| 831 | + for s3_key, event_time in events: |
| 832 | + if s3_key and s3_key.endswith(".jsonl"): |
| 833 | + await self._handle_s3_event(s3_key, event_time) |
814 | 834 |
|
815 | 835 | await sqs_client.delete_message( |
816 | 836 | QueueUrl=self._sqs_queue_url, |
@@ -865,7 +885,6 @@ async def indexing_jobs( |
865 | 885 |
|
866 | 886 | if __name__ == "__main__": |
867 | 887 | os.environ["BURR_LOAD_SNAPSHOT_ON_START"] = "True" |
868 | | - import asyncio |
869 | 888 |
|
870 | 889 | be = SQLiteS3Backend.from_settings(S3Settings()) |
871 | 890 | # coro = be.snapshot() # save to s3 |
|
0 commit comments