|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import asyncio |
| 4 | +import base64 |
4 | 5 | import inspect |
| 6 | +import json |
5 | 7 | import time |
6 | 8 | import traceback |
7 | 9 | from collections.abc import Awaitable, Callable, Mapping |
| 10 | +from dataclasses import dataclass |
8 | 11 | from datetime import datetime, timezone |
9 | | -from typing import Any |
| 12 | +from typing import Any, cast |
10 | 13 |
|
11 | 14 | from . import serializer |
12 | 15 | from .errors import ActivityCancelled, NonRetryableError |
13 | 16 | from .external_storage import ExternalPayloadCache, ExternalStorageDriver |
14 | 17 | from .external_task_input import ExternalTaskInput, parse_external_task_input |
15 | | -from .external_task_result import EXTERNAL_TASK_RESULT_SCHEMA, EXTERNAL_TASK_RESULT_VERSION |
| 18 | +from .external_task_result import ( |
| 19 | + EXTERNAL_TASK_RESULT_MEDIA_TYPE, |
| 20 | + EXTERNAL_TASK_RESULT_SCHEMA, |
| 21 | + EXTERNAL_TASK_RESULT_VERSION, |
| 22 | +) |
16 | 23 |
|
17 | 24 | InvocableActivityCallable = Callable[..., Any | Awaitable[Any]] |
18 | 25 |
|
@@ -315,3 +322,123 @@ async def handle_invocable_activity( |
315 | 322 | """Handle one invocable activity task with a temporary adapter instance.""" |
316 | 323 |
|
317 | 324 | return await InvocableActivityHandler(handlers, **options).handle(envelope) |
| 325 | + |
| 326 | + |
| 327 | +@dataclass(frozen=True) |
| 328 | +class InvocableHttpResponse: |
| 329 | + """Structured HTTP response from an invocable activity carrier endpoint.""" |
| 330 | + |
| 331 | + status_code: int |
| 332 | + headers: Mapping[str, str] |
| 333 | + body: str |
| 334 | + |
| 335 | + def json(self) -> dict[str, Any]: |
| 336 | + decoded = json.loads(self.body) |
| 337 | + if not isinstance(decoded, dict): |
| 338 | + raise ValueError("invocable HTTP response body is not a JSON object") |
| 339 | + return cast(dict[str, Any], decoded) |
| 340 | + |
| 341 | + |
| 342 | +async def handle_invocable_http_request( |
| 343 | + body: bytes | str | Mapping[str, Any], |
| 344 | + handlers: Mapping[str, InvocableActivityCallable], |
| 345 | + **options: Any, |
| 346 | +) -> InvocableHttpResponse: |
| 347 | + """Handle one HTTP-addressed invocable activity request. |
| 348 | +
|
| 349 | + The server expects a structured external-task result envelope on HTTP 200. |
| 350 | + Bad request bodies return HTTP 400 because no durable task identity can be |
| 351 | + recovered for a valid failure envelope. |
| 352 | + """ |
| 353 | + |
| 354 | + try: |
| 355 | + envelope = _coerce_json_object(body) |
| 356 | + except (TypeError, ValueError) as exc: |
| 357 | + return InvocableHttpResponse( |
| 358 | + status_code=400, |
| 359 | + headers={"Content-Type": "application/json"}, |
| 360 | + body=_json_dump({"error": "invalid_invocable_request", "message": str(exc)}), |
| 361 | + ) |
| 362 | + |
| 363 | + try: |
| 364 | + result = await handle_invocable_activity(envelope, handlers, **options) |
| 365 | + except Exception as exc: |
| 366 | + return InvocableHttpResponse( |
| 367 | + status_code=400, |
| 368 | + headers={"Content-Type": "application/json"}, |
| 369 | + body=_json_dump({"error": "invalid_invocable_request", "message": str(exc)}), |
| 370 | + ) |
| 371 | + |
| 372 | + return InvocableHttpResponse( |
| 373 | + status_code=200, |
| 374 | + headers={"Content-Type": EXTERNAL_TASK_RESULT_MEDIA_TYPE}, |
| 375 | + body=_json_dump(result), |
| 376 | + ) |
| 377 | + |
| 378 | + |
| 379 | +async def handle_invocable_lambda_event( |
| 380 | + event: Mapping[str, Any], |
| 381 | + handlers: Mapping[str, InvocableActivityCallable], |
| 382 | + **options: Any, |
| 383 | +) -> dict[str, Any]: |
| 384 | + """Handle an AWS Lambda / API Gateway style invocable activity event.""" |
| 385 | + |
| 386 | + try: |
| 387 | + body = _lambda_event_body(event) |
| 388 | + except (TypeError, ValueError) as exc: |
| 389 | + response = InvocableHttpResponse( |
| 390 | + status_code=400, |
| 391 | + headers={"Content-Type": "application/json"}, |
| 392 | + body=_json_dump({"error": "invalid_invocable_request", "message": str(exc)}), |
| 393 | + ) |
| 394 | + else: |
| 395 | + response = await handle_invocable_http_request(body, handlers, **options) |
| 396 | + |
| 397 | + return { |
| 398 | + "statusCode": response.status_code, |
| 399 | + "headers": dict(response.headers), |
| 400 | + "body": response.body, |
| 401 | + "isBase64Encoded": False, |
| 402 | + } |
| 403 | + |
| 404 | + |
| 405 | +def lambda_invocable_activity_handler( |
| 406 | + handlers: Mapping[str, InvocableActivityCallable], |
| 407 | + **options: Any, |
| 408 | +) -> Callable[[Mapping[str, Any], Any], dict[str, Any]]: |
| 409 | + """Build a synchronous AWS Lambda handler for invocable activities.""" |
| 410 | + |
| 411 | + def _handler(event: Mapping[str, Any], context: Any) -> dict[str, Any]: |
| 412 | + return asyncio.run(handle_invocable_lambda_event(event, handlers, **options)) |
| 413 | + |
| 414 | + return _handler |
| 415 | + |
| 416 | + |
| 417 | +def _coerce_json_object(body: bytes | str | Mapping[str, Any]) -> dict[str, Any]: |
| 418 | + if isinstance(body, Mapping): |
| 419 | + return dict(body) |
| 420 | + |
| 421 | + if isinstance(body, bytes): |
| 422 | + body = body.decode("utf-8") |
| 423 | + |
| 424 | + if not isinstance(body, str): |
| 425 | + raise TypeError("request body must be bytes, str, or a JSON object") |
| 426 | + |
| 427 | + decoded = json.loads(body) |
| 428 | + if not isinstance(decoded, dict): |
| 429 | + raise ValueError("request body must decode to a JSON object") |
| 430 | + |
| 431 | + return cast(dict[str, Any], decoded) |
| 432 | + |
| 433 | + |
| 434 | +def _lambda_event_body(event: Mapping[str, Any]) -> bytes | str | Mapping[str, Any]: |
| 435 | + body = event.get("body") |
| 436 | + if isinstance(body, str) and event.get("isBase64Encoded") is True: |
| 437 | + return base64.b64decode(body) |
| 438 | + if isinstance(body, (bytes, str, Mapping)): |
| 439 | + return cast(bytes | str | Mapping[str, Any], body) |
| 440 | + raise ValueError("Lambda event body must contain the invocable request JSON") |
| 441 | + |
| 442 | + |
| 443 | +def _json_dump(value: Mapping[str, Any]) -> str: |
| 444 | + return json.dumps(value, separators=(",", ":"), sort_keys=True) |
0 commit comments