diff --git a/src/chat_sdk/adapters/discord/adapter.py b/src/chat_sdk/adapters/discord/adapter.py index ce4da92..8d452e0 100644 --- a/src/chat_sdk/adapters/discord/adapter.py +++ b/src/chat_sdk/adapters/discord/adapter.py @@ -36,9 +36,9 @@ DiscordThreadId, InteractionResponseType, ) -from chat_sdk.emoji import convert_emoji_placeholders +from chat_sdk.emoji import convert_emoji_placeholders, get_emoji, resolve_emoji_from_gchat from chat_sdk.logger import ConsoleLogger, Logger -from chat_sdk.shared.adapter_utils import extract_card +from chat_sdk.shared.adapter_utils import extract_card, extract_files from chat_sdk.shared.errors import NetworkError, ValidationError from chat_sdk.types import ( ActionEvent, @@ -50,6 +50,7 @@ EmojiValue, FetchOptions, FetchResult, + FileUpload, FormattedContent, Message, MessageMetadata, @@ -677,13 +678,21 @@ async def _handle_forwarded_reaction( emoji_id = emoji_data.get("id") raw_emoji = f"<:{emoji_name}:{emoji_id}>" if emoji_id else emoji_name + # Normalize emoji through the emoji resolver + if emoji_name and not emoji_id: + # Standard unicode emoji -- resolve through gchat (unicode) resolver + normalized = resolve_emoji_from_gchat(emoji_name) + else: + # Custom emoji -- use custom:{id} key or raw name + normalized = get_emoji(f"custom:{emoji_id}" if emoji_id else emoji_name) + self._chat.process_reaction( ReactionEvent( adapter=self, thread=None, thread_id=thread_id, message_id=data.get("message_id", ""), - emoji=EmojiValue(name=emoji_name), + emoji=normalized, raw_emoji=raw_emoji, added=added, user=Author( @@ -730,6 +739,43 @@ async def post_message( if components: payload["components"] = components + # --- Handle file attachments via multipart/form-data --- + files = extract_files(message) + + # --- Resolve deferred slash-command interaction if pending --- + req_ctx = self._request_context.get() + slash_ctx = req_ctx.slash_command if req_ctx else None + if slash_ctx and not slash_ctx.initial_response_sent: + slash_ctx.initial_response_sent = True + self._logger.debug( + "Discord API: PATCH deferred interaction response", + { + "channelId": channel_id, + "contentLength": len(payload.get("content", "")), + "embedCount": len(embeds), + "componentCount": len(components), + "fileCount": len(files), + }, + ) + + result = await self._discord_fetch( + f"/webhooks/{self._application_id}/{slash_ctx.interaction_token}/messages/@original", + "PATCH", + payload, + files=files or None, + ) + + self._logger.debug( + "Discord API: PATCH deferred interaction response completed", + {"messageId": result.get("id") if result else None}, + ) + + return RawMessage( + id=(result or {}).get("id", ""), + thread_id=thread_id, + raw=result or {}, + ) + self._logger.debug( "Discord API: POST message", { @@ -737,6 +783,7 @@ async def post_message( "contentLength": len(payload.get("content", "")), "embedCount": len(embeds), "componentCount": len(components), + "fileCount": len(files), }, ) @@ -744,6 +791,7 @@ async def post_message( f"/channels/{channel_id}/messages", "POST", payload, + files=files or None, ) self._logger.debug( @@ -1255,8 +1303,14 @@ async def _discord_fetch( path: str, method: str, body: Any = None, + files: list[FileUpload] | None = None, ) -> Any: - """Make a request to the Discord API using aiohttp (lazy import).""" + """Make a request to the Discord API using aiohttp (lazy import). + + When *files* is provided the request uses ``multipart/form-data`` + with a ``payload_json`` field for the JSON body and one field per + file attachment, matching the Discord API multipart upload spec. + """ import aiohttp # lazy import url = f"{DISCORD_API_BASE}{path}" @@ -1264,8 +1318,25 @@ async def _discord_fetch( "Authorization": f"Bot {self._bot_token}", } - if body is not None: - headers["Content-Type"] = "application/json" + # Build request kwargs depending on whether we have file uploads + request_kwargs: dict[str, Any] = {} + if files: + # Multipart form-data with payload_json + file parts + form = aiohttp.FormData() + form.add_field("payload_json", json.dumps(body or {}), content_type="application/json") + for idx, file in enumerate(files): + form.add_field( + f"files[{idx}]", + file.data, + filename=file.filename, + content_type=file.mime_type or "application/octet-stream", + ) + request_kwargs["data"] = form + # Do NOT set Content-Type header -- aiohttp sets the multipart boundary + else: + if body is not None: + headers["Content-Type"] = "application/json" + request_kwargs["json"] = body async with ( aiohttp.ClientSession() as session, @@ -1273,7 +1344,7 @@ async def _discord_fetch( method, url, headers=headers, - json=body if body is not None else None, + **request_kwargs, ) as response, ): if not response.ok: diff --git a/src/chat_sdk/adapters/google_chat/adapter.py b/src/chat_sdk/adapters/google_chat/adapter.py index f0cfcd7..ec11f1f 100644 --- a/src/chat_sdk/adapters/google_chat/adapter.py +++ b/src/chat_sdk/adapters/google_chat/adapter.py @@ -501,16 +501,22 @@ async def _ensure_space_subscription(self, space_name: str) -> None: "Subscription creation already in progress, waiting", {"spaceName": space_name}, ) - await self._pending_subscriptions[space_name].wait() + pending = self._pending_subscriptions[space_name] + await pending["event"].wait() + if pending.get("error"): + raise pending["error"] return # Create the subscription - event = asyncio.Event() - self._pending_subscriptions[space_name] = event + pending_entry: dict[str, Any] = {"event": asyncio.Event(), "error": None} + self._pending_subscriptions[space_name] = pending_entry try: await self._create_space_subscription_with_cache(space_name, cache_key) + except Exception as e: + pending_entry["error"] = e + raise finally: - event.set() + pending_entry["event"].set() self._pending_subscriptions.pop(space_name, None) async def _create_space_subscription_with_cache( diff --git a/src/chat_sdk/adapters/linear/adapter.py b/src/chat_sdk/adapters/linear/adapter.py index 6217b6c..d0f840e 100644 --- a/src/chat_sdk/adapters/linear/adapter.py +++ b/src/chat_sdk/adapters/linear/adapter.py @@ -785,7 +785,8 @@ async def fetch_thread(self, thread_id: str) -> ThreadInfo: channel_name=f"{issue.get('identifier', '')}: {issue.get('title', '')}", is_dm=False, metadata={ - "issue_id": decoded.issue_id, + "issueId": decoded.issue_id, + "issue_id": decoded.issue_id, # snake_case alias for compatibility "identifier": issue.get("identifier"), "title": issue.get("title"), "url": issue.get("url"), diff --git a/src/chat_sdk/adapters/slack/adapter.py b/src/chat_sdk/adapters/slack/adapter.py index 3d99353..d75b00b 100644 --- a/src/chat_sdk/adapters/slack/adapter.py +++ b/src/chat_sdk/adapters/slack/adapter.py @@ -10,6 +10,7 @@ import asyncio import base64 +import contextvars import hashlib import hmac import json @@ -681,18 +682,21 @@ async def handle_webhook(self, request: Any, options: WebhookOptions | None = No "headers": {"Content-Type": "application/json"}, } - # Multi-workspace: resolve token before processing events + # Multi-workspace: resolve token before processing events. + # Use contextvars.copy_context() so the ContextVar value persists into + # any async tasks spawned by _process_event_payload (e.g. process_message + # creates a task via asyncio.create_task). The copied context is + # isolated -- the ContextVar change does not leak back to the caller + # and does not need an explicit reset. if not self._default_bot_token and payload.get("type") == "event_callback": team_id_event = payload.get("team_id") if team_id_event: ctx = await self._resolve_token_for_team(team_id_event) if ctx: - tok = self._request_context.set(ctx) - try: - self._process_event_payload(payload, options) - return {"body": "ok", "status": 200} - finally: - self._request_context.reset(tok) + isolated = contextvars.copy_context() + isolated.run(self._request_context.set, ctx) + isolated.run(self._process_event_payload, payload, options) + return {"body": "ok", "status": 200} self._logger.warn("Could not resolve token for team", {"teamId": team_id_event}) return {"body": "ok", "status": 200} diff --git a/src/chat_sdk/adapters/teams/adapter.py b/src/chat_sdk/adapters/teams/adapter.py index 53d13e2..2984e82 100644 --- a/src/chat_sdk/adapters/teams/adapter.py +++ b/src/chat_sdk/adapters/teams/adapter.py @@ -469,7 +469,8 @@ def _handle_reaction_activity( thread=None, adapter=self, raw=activity, - ) + ), + options, ) for reaction in activity.get("reactionsRemoved", []): @@ -485,7 +486,8 @@ def _handle_reaction_activity( thread=None, adapter=self, raw=activity, - ) + ), + options, ) def _parse_teams_message( diff --git a/src/chat_sdk/adapters/whatsapp/adapter.py b/src/chat_sdk/adapters/whatsapp/adapter.py index f12a46b..8aec191 100644 --- a/src/chat_sdk/adapters/whatsapp/adapter.py +++ b/src/chat_sdk/adapters/whatsapp/adapter.py @@ -660,9 +660,13 @@ async def download_media(self, media_id: str) -> bytes: f"Media download URL host is not an allowed Meta domain: {host}", ) - # Step 2: Download the actual file (no Bearer token -- CDN URLs are pre-signed) + # Step 2: Download the actual file. + # The WhatsApp Cloud API requires the Bearer token for media downloads + # (the URL is not pre-signed). The SSRF domain validation above ensures + # we only send the token to legitimate Meta/WhatsApp domains. async with session.get( download_url, + headers={"Authorization": f"Bearer {self._access_token}"}, ) as data_response: if data_response.status != 200: self._logger.error( diff --git a/src/chat_sdk/chat.py b/src/chat_sdk/chat.py index 661651f..d19c086 100644 --- a/src/chat_sdk/chat.py +++ b/src/chat_sdk/chat.py @@ -47,6 +47,7 @@ ModalCloseEvent, ModalResponse, ModalSubmitEvent, + OnLockConflict, QueueEntry, ReactionEvent, SlashCommandEvent, @@ -228,6 +229,7 @@ def __init__(self, config: ChatConfig | None = None, **kwargs: Any) -> None: self._fallback_streaming_placeholder_text = config.fallback_streaming_placeholder_text self._dedupe_ttl_ms = config.dedupe_ttl_ms or DEDUPE_TTL_MS self._lock_scope_config = config.lock_scope + self._on_lock_conflict: OnLockConflict | None = config.on_lock_conflict # -- Concurrency config ----------------------------------------------- concurrency = config.concurrency @@ -1440,11 +1442,14 @@ async def _handle_drop( ) -> None: lock = await self._state_adapter.acquire_lock(lock_key, DEFAULT_LOCK_TTL_MS) if lock is None: - self._logger.warn("Could not acquire lock on thread", {"thread_id": thread_id, "lock_key": lock_key}) - raise LockError( - thread_id, - f"Could not acquire lock on thread {thread_id}. Another instance may be processing.", - ) + # Lock acquisition failed -- consult on_lock_conflict policy + lock = await self._resolve_lock_conflict(thread_id, lock_key, message) + if lock is None: + self._logger.warn("Could not acquire lock on thread", {"thread_id": thread_id, "lock_key": lock_key}) + raise LockError( + thread_id, + f"Could not acquire lock on thread {thread_id}. Another instance may be processing.", + ) self._logger.debug("Lock acquired", {"thread_id": thread_id, "lock_key": lock_key, "token": lock.token}) try: @@ -1453,6 +1458,47 @@ async def _handle_drop( await self._state_adapter.release_lock(lock) self._logger.debug("Lock released", {"thread_id": thread_id, "lock_key": lock_key}) + async def _resolve_lock_conflict( + self, + thread_id: str, + lock_key: str, + message: Message, + ) -> Lock | None: + """Attempt to resolve a lock conflict based on the ``on_lock_conflict`` policy. + + Returns a :class:`Lock` if the conflict was resolved and the lock + was successfully re-acquired, or ``None`` if the message should be + dropped. + """ + conflict = self._on_lock_conflict + + if conflict is None or conflict == "drop": + return None + + if conflict == "force": + self._logger.info( + "Force-releasing lock due to on_lock_conflict='force'", + {"thread_id": thread_id, "lock_key": lock_key}, + ) + await self._state_adapter.force_release_lock(lock_key) + return await self._state_adapter.acquire_lock(lock_key, DEFAULT_LOCK_TTL_MS) + + # Callable handler -- invoke and inspect result + if callable(conflict): + result = conflict(thread_id, message) + # Support both sync and async callables + if asyncio.iscoroutine(result) or asyncio.isfuture(result): + result = await result + if result: + self._logger.info( + "on_lock_conflict callback returned True, force-releasing lock", + {"thread_id": thread_id, "lock_key": lock_key}, + ) + await self._state_adapter.force_release_lock(lock_key) + return await self._state_adapter.acquire_lock(lock_key, DEFAULT_LOCK_TTL_MS) + + return None + # -- Queue / Debounce strategy ------------------------------------------- async def _handle_queue_or_debounce( diff --git a/src/chat_sdk/types.py b/src/chat_sdk/types.py index d68da7a..1ce712f 100644 --- a/src/chat_sdk/types.py +++ b/src/chat_sdk/types.py @@ -39,6 +39,7 @@ LockScope = Literal["thread", "channel"] ConcurrencyStrategy = Literal["drop", "queue", "debounce", "concurrent"] +OnLockConflict = Literal["drop", "force"] | Callable[..., Awaitable[bool] | bool] FetchDirection = Literal["forward", "backward"] # Well-known emoji names @@ -1252,6 +1253,7 @@ class ChatConfig: lock_scope: LockScope | Callable[..., LockScope | Awaitable[LockScope]] | None = None logger: Logger | LogLevel | None = None message_history: dict[str, Any] | None = None + on_lock_conflict: OnLockConflict | None = None streaming_update_interval_ms: int = 500 @@ -1373,6 +1375,16 @@ async def fetch_metadata(self) -> ChannelInfo: """Fetch channel metadata from the platform.""" ... + def messages(self) -> AsyncIterable[Message]: + """Iterate messages newest first (backward from most recent). + + Auto-paginates lazily -- only fetches pages as consumed. + + Note: This is a method, not a property. Call with ``()``: + ``async for msg in channel.messages(): ...`` + """ + ... + def threads(self) -> AsyncIterable[ThreadSummary]: """Iterate threads in this channel, most recently active first. @@ -1405,6 +1417,9 @@ def messages(self) -> AsyncIterable[Message]: """Iterate messages newest first (backward from most recent). Auto-paginates lazily -- only fetches pages as consumed. + + Note: This is a method, not a property. Call with ``()``: + ``async for msg in thread.messages(): ...`` """ ... @@ -1412,6 +1427,9 @@ def all_messages(self) -> AsyncIterable[Message]: """Iterate messages oldest first (forward from beginning). Auto-paginates lazily. + + Note: This is a method, not a property. Call with ``()``: + ``async for msg in thread.all_messages(): ...`` """ ... diff --git a/tests/test_gchat_api.py b/tests/test_gchat_api.py index 58bed8d..4d1f98b 100644 --- a/tests/test_gchat_api.py +++ b/tests/test_gchat_api.py @@ -903,9 +903,9 @@ async def test_skips_duplicate_in_flight_subscription(self): """If a subscription is already being created, should wait rather than duplicate.""" adapter, api, state = await _init_adapter(pubsub_topic="projects/test/topics/test") - # Simulate an in-progress subscription + # Simulate an in-progress subscription (dict with event + error) event = asyncio.Event() - adapter._pending_subscriptions["spaces/TEST1"] = event + adapter._pending_subscriptions["spaces/TEST1"] = {"event": event, "error": None} # Should wait on the event and return without creating a new one async def wait_and_set(): diff --git a/tests/test_gchat_comprehensive.py b/tests/test_gchat_comprehensive.py index 64c69f8..9e2b9f4 100644 --- a/tests/test_gchat_comprehensive.py +++ b/tests/test_gchat_comprehensive.py @@ -1745,7 +1745,7 @@ async def test_skips_duplicate_in_flight_subscription(self): adapter, api, state = await _init_adapter(pubsub_topic="projects/test/topics/test") event = asyncio.Event() - adapter._pending_subscriptions["spaces/TEST1"] = event + adapter._pending_subscriptions["spaces/TEST1"] = {"event": event, "error": None} async def wait_and_set(): await asyncio.sleep(0.01)