Skip to content

Commit 1e6ed9e

Browse files
committed
chore: refactor stac-loader handler into component functions
1 parent 85efcd0 commit 1e6ed9e

4 files changed

Lines changed: 292 additions & 160 deletions

File tree

lib/stac-loader/runtime/src/stac_loader/handler.py

Lines changed: 217 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import os
55
from collections import defaultdict
6+
from datetime import datetime
67
from typing import (
78
TYPE_CHECKING,
89
Annotated,
@@ -157,6 +158,211 @@ def process_s3_event(message_str: str) -> Dict[str, Any]:
157158
raise
158159

159160

161+
def parse_message_data(message_id: str, message_str: str) -> Dict[str, Any]:
162+
"""Parse message data, handling both S3 events and direct STAC JSON."""
163+
if is_s3_event(message_str):
164+
logger.debug(f"[{message_id}] Processing S3 event notification")
165+
return process_s3_event(message_str)
166+
else:
167+
return json.loads(message_str)
168+
169+
170+
def store_item_if_newer(
171+
items_by_collection: DefaultDict[
172+
str, Dict[str, Tuple[Dict[str, Any], str, datetime]]
173+
],
174+
item: Item,
175+
message_id: str,
176+
sns_timestamp: datetime,
177+
) -> None:
178+
"""Store item if it's newer than existing version."""
179+
if not item.collection:
180+
raise KeyError(f"item {item.id} is missing a collection id!")
181+
182+
existing = items_by_collection[item.collection].get(item.id)
183+
if existing is None or sns_timestamp > existing[2]:
184+
if existing:
185+
logger.debug(
186+
f"[{message_id}] Replacing older version of item {item.id} "
187+
f"(old timestamp: {existing[2]}, new timestamp: {sns_timestamp})"
188+
)
189+
items_by_collection[item.collection][item.id] = (
190+
item.model_dump(mode="json"),
191+
message_id,
192+
sns_timestamp,
193+
)
194+
else:
195+
logger.debug(
196+
f"[{message_id}] Skipping older version of item {item.id} "
197+
f"(existing timestamp: {existing[2]}, message timestamp: {sns_timestamp})"
198+
)
199+
200+
201+
def store_collection_if_newer(
202+
collections_dict: DefaultDict[str, Tuple[Dict[str, Any], str, datetime]],
203+
collection: Collection,
204+
message_id: str,
205+
sns_timestamp: datetime,
206+
) -> None:
207+
"""Store collection if it's newer than existing version."""
208+
existing = collections_dict.get(collection.id)
209+
if existing is None or sns_timestamp > existing[2]:
210+
if existing:
211+
logger.debug(
212+
f"[{message_id}] Replacing older version of collection {collection.id} "
213+
f"(old timestamp: {existing[2]}, new timestamp: {sns_timestamp})"
214+
)
215+
collections_dict[collection.id] = (
216+
collection.model_dump(mode="json"),
217+
message_id,
218+
sns_timestamp,
219+
)
220+
else:
221+
logger.debug(
222+
f"[{message_id}] Skipping older version of collection {collection.id} "
223+
f"(existing timestamp: {existing[2]}, message timestamp: {sns_timestamp})"
224+
)
225+
226+
227+
def process_record(
228+
record: Dict[str, Any],
229+
collections_dict: DefaultDict[str, Tuple[Dict[str, Any], str, datetime]],
230+
items_by_collection: DefaultDict[
231+
str, Dict[str, Tuple[Dict[str, Any], str, datetime]]
232+
],
233+
) -> Optional[BatchItemFailure]:
234+
"""Process a single SQS record and return failure if processing fails."""
235+
message_id = record.get("messageId")
236+
if not message_id:
237+
logger.warning("Record missing messageId, cannot report failure for it.")
238+
return None
239+
240+
try:
241+
sqs_body_str = record["body"]
242+
logger.debug(f"[{message_id}] SQS message body: {sqs_body_str}")
243+
sns_notification = json.loads(sqs_body_str)
244+
245+
message_str = sns_notification["Message"]
246+
logger.debug(f"[{message_id}] SNS Message content: {message_str}")
247+
248+
sns_timestamp_str = sns_notification["Timestamp"]
249+
sns_timestamp = datetime.fromisoformat(sns_timestamp_str.replace("Z", "+00:00"))
250+
logger.debug(f"[{message_id}] SNS Timestamp: {sns_timestamp}")
251+
252+
message_data = parse_message_data(message_id, message_str)
253+
254+
if message_data["type"] == "Feature":
255+
item = Item(**message_data)
256+
store_item_if_newer(items_by_collection, item, message_id, sns_timestamp)
257+
elif message_data["type"] == "Collection":
258+
collection = Collection(**message_data)
259+
store_collection_if_newer(
260+
collections_dict, collection, message_id, sns_timestamp
261+
)
262+
else:
263+
raise ValueError(
264+
f"expected either a 'Feature' or a 'Collection', received a {message_data['type']}"
265+
)
266+
267+
logger.debug(f"[{message_id}] Successfully processed.")
268+
return None
269+
270+
except (ValueError, KeyError, ValidationError, json.JSONDecodeError) as e:
271+
logger.error(f"[{message_id}] Failed with error: {e}", extra=record)
272+
return {"itemIdentifier": message_id}
273+
except Exception as e:
274+
logger.error(f"[{message_id}] Unexpected error: {e}", extra=record)
275+
return {"itemIdentifier": message_id}
276+
277+
278+
def load_collections_to_db(
279+
collections_dict: DefaultDict[str, Tuple[Dict[str, Any], str, datetime]],
280+
pgstac_dsn: str,
281+
) -> List[BatchItemFailure]:
282+
"""Load collections to database and return failures."""
283+
if not collections_dict:
284+
return []
285+
286+
collections = [collection for collection, _, _ in collections_dict.values()]
287+
message_ids = [msg_id for _, msg_id, _ in collections_dict.values()]
288+
289+
try:
290+
with PgstacDB(dsn=pgstac_dsn) as db:
291+
loader = Loader(db=db)
292+
logger.info("loading collections into database.")
293+
loader.load_collections(
294+
file=collections, # type: ignore
295+
insert_mode=Methods.upsert,
296+
)
297+
logger.info(f"successfully loaded {len(collections)} collections.")
298+
return []
299+
except Exception as e:
300+
logger.error(f"failed to load collections: {str(e)}")
301+
return [{"itemIdentifier": message_id} for message_id in message_ids]
302+
303+
304+
def ensure_collection_exists(
305+
db: PgstacDB, loader: Loader, collection_id: str, items: List[Dict[str, Any]]
306+
) -> None:
307+
"""Create a placeholder collection if it doesn't exist and environment allows."""
308+
if not os.getenv("CREATE_COLLECTIONS_IF_MISSING"):
309+
return
310+
311+
collection_exists = db.query_one(
312+
f"SELECT count(*) as count from collections where id = '{collection_id}'"
313+
)
314+
if not collection_exists:
315+
logger.info(
316+
f"[{collection_id}] loading collection into database because it is missing."
317+
)
318+
collection = Collection(
319+
id=collection_id,
320+
description=collection_id,
321+
links=Links([Link(href="placeholder", rel="self")]),
322+
type="Collection",
323+
license="proprietary",
324+
extent=Extent(
325+
spatial=SpatialExtent(bbox=[(-180, -90, 180, 90)]),
326+
temporal=TimeInterval(interval=[[None, None]]),
327+
),
328+
stac_version=items[0]["stac_version"],
329+
)
330+
loader.load_collections(
331+
[collection.model_dump()], # type: ignore
332+
insert_mode=Methods.upsert,
333+
)
334+
335+
336+
def load_items_for_collection(
337+
collection_id: str,
338+
items_dict: Dict[str, Tuple[Dict[str, Any], str, datetime]],
339+
pgstac_dsn: str,
340+
) -> List[BatchItemFailure]:
341+
"""Load items for a single collection to database and return failures."""
342+
items = [item_data for item_data, _, _ in items_dict.values()]
343+
message_ids = [msg_id for _, msg_id, _ in items_dict.values()]
344+
345+
logger.debug(
346+
f"[{collection_id}] Processing {len(items)} unique items from {len(items_dict)} dict entries. Item IDs: {list(items_dict.keys())}"
347+
)
348+
349+
try:
350+
with PgstacDB(dsn=pgstac_dsn) as db:
351+
loader = Loader(db=db)
352+
ensure_collection_exists(db, loader, collection_id, items)
353+
354+
logger.info(f"[{collection_id}] loading items into database.")
355+
loader.load_items(
356+
file=items, # type: ignore
357+
insert_mode=Methods.upsert,
358+
)
359+
logger.info(f"[{collection_id}] successfully loaded {len(items)} items.")
360+
return []
361+
except Exception as e:
362+
logger.error(f"[{collection_id}] failed to load items: {str(e)}")
363+
return [{"itemIdentifier": msg_id} for msg_id in message_ids]
364+
365+
160366
def handler(
161367
event: Dict[str, Any], context: Context
162368
) -> Optional[PartialBatchFailureResponse]:
@@ -171,131 +377,25 @@ def handler(
171377
pgstac_dsn = get_pgstac_dsn()
172378

173379
batch_failures: List[BatchItemFailure] = []
174-
175-
collections_dict: DefaultDict[str, Tuple[Dict[str, Any], str]] = defaultdict()
176-
# Track items by collection and item id to deduplicate within a batch
177-
# Maps: collection_id -> item_id -> (item_data, message_id)
178-
items_by_collection: DefaultDict[str, Dict[str, Tuple[Dict[str, Any], str]]] = (
179-
defaultdict(dict)
380+
collections_dict: DefaultDict[str, Tuple[Dict[str, Any], str, datetime]] = (
381+
defaultdict()
180382
)
383+
items_by_collection: DefaultDict[
384+
str, Dict[str, Tuple[Dict[str, Any], str, datetime]]
385+
] = defaultdict(dict)
181386

182387
for record in records:
183-
message_id = record.get("messageId")
184-
if not message_id:
185-
logger.warning("Record missing messageId, cannot report failure for it.")
186-
continue
388+
failure = process_record(record, collections_dict, items_by_collection)
389+
if failure:
390+
batch_failures.append(failure)
187391

188-
try:
189-
sqs_body_str = record["body"]
190-
logger.debug(f"[{message_id}] SQS message body: {sqs_body_str}")
191-
sns_notification = json.loads(sqs_body_str)
192-
193-
message_str = sns_notification["Message"]
194-
logger.debug(f"[{message_id}] SNS Message content: {message_str}")
195-
196-
if is_s3_event(message_str):
197-
logger.debug(f"[{message_id}] Processing S3 event notification")
198-
message_data = process_s3_event(message_str)
199-
else:
200-
message_data = json.loads(message_str)
201-
202-
if message_data["type"] == "Feature":
203-
item = Item(**message_data)
204-
205-
if not item.collection:
206-
raise KeyError(f"item {item.id} is missing a collection id!")
207-
208-
# Store item by id, replacing any previous version in this batch
209-
items_by_collection[item.collection][item.id] = (
210-
item.model_dump(mode="json"),
211-
message_id,
212-
)
213-
elif message_data["type"] == "Collection":
214-
collection = Collection(**message_data)
215-
collections_dict[collection.id] = (
216-
collection.model_dump(mode="json"),
217-
message_id,
218-
)
219-
else:
220-
raise ValueError(
221-
f"expected either a 'Feature' or a 'Collection', received a {message_data['type']}"
222-
)
223-
224-
logger.debug(f"[{message_id}] Successfully processed.")
225-
226-
except (ValueError, KeyError, ValidationError, json.JSONDecodeError) as e:
227-
logger.error(f"[{message_id}] Failed with error: {e}", extra=record)
228-
batch_failures.append({"itemIdentifier": message_id})
229-
except Exception as e:
230-
logger.error(f"[{message_id}] Unexpected error: {e}", extra=record)
231-
batch_failures.append({"itemIdentifier": message_id})
232-
233-
if collections_dict:
234-
collections = [collection for collection, _ in collections_dict.values()]
235-
message_ids = [msg_id for _, msg_id in collections_dict.values()]
236-
try:
237-
with PgstacDB(dsn=pgstac_dsn) as db:
238-
loader = Loader(db=db)
239-
logger.info("loading collections into database.")
240-
loader.load_collections(
241-
file=collections, # type: ignore
242-
insert_mode=Methods.upsert,
243-
)
244-
logger.info(f"successfully loaded {len(collections)} collections.")
245-
except Exception as e:
246-
logger.error(f"failed to load collections: {str(e)}")
247-
batch_failures.extend(
248-
[{"itemIdentifier": message_id} for message_id in message_ids]
249-
)
392+
batch_failures.extend(load_collections_to_db(collections_dict, pgstac_dsn))
250393

251394
for collection_id, items_dict in items_by_collection.items():
252-
# Extract items and message_ids from the dict structure
253-
items = [item_data for item_data, _ in items_dict.values()]
254-
message_ids = [msg_id for _, msg_id in items_dict.values()]
255-
256-
logger.debug(
257-
f"[{collection_id}] Processing {len(items)} unique items from {len(items_dict)} dict entries. Item IDs: {list(items_dict.keys())}"
395+
batch_failures.extend(
396+
load_items_for_collection(collection_id, items_dict, pgstac_dsn)
258397
)
259398

260-
try:
261-
with PgstacDB(dsn=pgstac_dsn) as db:
262-
loader = Loader(db=db)
263-
if os.getenv("CREATE_COLLECTIONS_IF_MISSING"):
264-
collection_exists = db.query_one(
265-
f"SELECT count(*) as count from collections where id = '{collection_id}'"
266-
)
267-
if not collection_exists:
268-
logger.info(
269-
f"[{collection_id}] loading collection into database because it is missing."
270-
)
271-
collection = Collection(
272-
id=collection_id,
273-
description=collection_id,
274-
links=Links([Link(href="placeholder", rel="self")]),
275-
type="Collection",
276-
license="proprietary",
277-
extent=Extent(
278-
spatial=SpatialExtent(bbox=[[-180, -90, 180, 90]]),
279-
temporal=TimeInterval(interval=[[None, None]]),
280-
),
281-
stac_version=items[0]["stac_version"],
282-
)
283-
loader.load_collections(
284-
[collection.model_dump()], # type: ignore
285-
insert_mode=Methods.upsert,
286-
)
287-
288-
logger.info(f"[{collection_id}] loading items into database.")
289-
loader.load_items(
290-
file=items, # type: ignore
291-
insert_mode=Methods.upsert,
292-
)
293-
logger.info(f"[{collection_id}] successfully loaded {len(items)} items.")
294-
except Exception as e:
295-
logger.error(f"[{collection_id}] failed to load items: {str(e)}")
296-
297-
batch_failures.extend([{"itemIdentifier": msg_id} for msg_id in message_ids])
298-
299399
if batch_failures:
300400
logger.warning(
301401
f"Finished processing batch. {len(batch_failures)} failure(s) reported."

0 commit comments

Comments
 (0)