Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 78 additions & 7 deletions src/chat_sdk/adapters/discord/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
DiscordThreadId,
InteractionResponseType,
)
from chat_sdk.emoji import convert_emoji_placeholders
from chat_sdk.emoji import convert_emoji_placeholders, get_emoji, resolve_emoji_from_gchat
from chat_sdk.logger import ConsoleLogger, Logger
from chat_sdk.shared.adapter_utils import extract_card
from chat_sdk.shared.adapter_utils import extract_card, extract_files
from chat_sdk.shared.errors import NetworkError, ValidationError
from chat_sdk.types import (
ActionEvent,
Expand All @@ -50,6 +50,7 @@
EmojiValue,
FetchOptions,
FetchResult,
FileUpload,
FormattedContent,
Message,
MessageMetadata,
Expand Down Expand Up @@ -677,13 +678,21 @@ async def _handle_forwarded_reaction(
emoji_id = emoji_data.get("id")
raw_emoji = f"<:{emoji_name}:{emoji_id}>" if emoji_id else emoji_name

# Normalize emoji through the emoji resolver
if emoji_name and not emoji_id:
# Standard unicode emoji -- resolve through gchat (unicode) resolver
normalized = resolve_emoji_from_gchat(emoji_name)
else:
# Custom emoji -- use custom:{id} key or raw name
normalized = get_emoji(f"custom:{emoji_id}" if emoji_id else emoji_name)

self._chat.process_reaction(
ReactionEvent(
adapter=self,
thread=None,
thread_id=thread_id,
message_id=data.get("message_id", ""),
emoji=EmojiValue(name=emoji_name),
emoji=normalized,
raw_emoji=raw_emoji,
added=added,
user=Author(
Expand Down Expand Up @@ -730,20 +739,59 @@ async def post_message(
if components:
payload["components"] = components

# --- Handle file attachments via multipart/form-data ---
files = extract_files(message)

# --- Resolve deferred slash-command interaction if pending ---
req_ctx = self._request_context.get()
slash_ctx = req_ctx.slash_command if req_ctx else None
if slash_ctx and not slash_ctx.initial_response_sent:
slash_ctx.initial_response_sent = True
self._logger.debug(
"Discord API: PATCH deferred interaction response",
{
"channelId": channel_id,
"contentLength": len(payload.get("content", "")),
"embedCount": len(embeds),
"componentCount": len(components),
"fileCount": len(files),
},
)

result = await self._discord_fetch(
f"/webhooks/{self._application_id}/{slash_ctx.interaction_token}/messages/@original",
"PATCH",
payload,
files=files or None,
)

self._logger.debug(
"Discord API: PATCH deferred interaction response completed",
{"messageId": result.get("id") if result else None},
)

return RawMessage(
id=(result or {}).get("id", ""),
thread_id=thread_id,
raw=result or {},
)

self._logger.debug(
"Discord API: POST message",
{
"channelId": channel_id,
"contentLength": len(payload.get("content", "")),
"embedCount": len(embeds),
"componentCount": len(components),
"fileCount": len(files),
},
)

result = await self._discord_fetch(
f"/channels/{channel_id}/messages",
"POST",
payload,
files=files or None,
)

self._logger.debug(
Expand Down Expand Up @@ -1255,25 +1303,48 @@ async def _discord_fetch(
path: str,
method: str,
body: Any = None,
files: list[FileUpload] | None = None,
) -> Any:
"""Make a request to the Discord API using aiohttp (lazy import)."""
"""Make a request to the Discord API using aiohttp (lazy import).

When *files* is provided the request uses ``multipart/form-data``
with a ``payload_json`` field for the JSON body and one field per
file attachment, matching the Discord API multipart upload spec.
"""
import aiohttp # lazy import

url = f"{DISCORD_API_BASE}{path}"
headers: dict[str, str] = {
"Authorization": f"Bot {self._bot_token}",
}

if body is not None:
headers["Content-Type"] = "application/json"
# Build request kwargs depending on whether we have file uploads
request_kwargs: dict[str, Any] = {}
if files:
# Multipart form-data with payload_json + file parts
form = aiohttp.FormData()
form.add_field("payload_json", json.dumps(body or {}), content_type="application/json")
for idx, file in enumerate(files):
form.add_field(
f"files[{idx}]",
file.data,
filename=file.filename,
content_type=file.mime_type or "application/octet-stream",
)
request_kwargs["data"] = form
# Do NOT set Content-Type header -- aiohttp sets the multipart boundary
else:
if body is not None:
headers["Content-Type"] = "application/json"
request_kwargs["json"] = body

async with (
aiohttp.ClientSession() as session,
session.request(
method,
url,
headers=headers,
json=body if body is not None else None,
**request_kwargs,
) as response,
):
if not response.ok:
Expand Down
14 changes: 10 additions & 4 deletions src/chat_sdk/adapters/google_chat/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,16 +501,22 @@ async def _ensure_space_subscription(self, space_name: str) -> None:
"Subscription creation already in progress, waiting",
{"spaceName": space_name},
)
await self._pending_subscriptions[space_name].wait()
pending = self._pending_subscriptions[space_name]
await pending["event"].wait()
if pending.get("error"):
raise pending["error"]
return

# Create the subscription
event = asyncio.Event()
self._pending_subscriptions[space_name] = event
pending_entry: dict[str, Any] = {"event": asyncio.Event(), "error": None}
self._pending_subscriptions[space_name] = pending_entry
try:
await self._create_space_subscription_with_cache(space_name, cache_key)
except Exception as e:
pending_entry["error"] = e
raise
finally:
event.set()
pending_entry["event"].set()
self._pending_subscriptions.pop(space_name, None)

async def _create_space_subscription_with_cache(
Expand Down
3 changes: 2 additions & 1 deletion src/chat_sdk/adapters/linear/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,8 @@ async def fetch_thread(self, thread_id: str) -> ThreadInfo:
channel_name=f"{issue.get('identifier', '')}: {issue.get('title', '')}",
is_dm=False,
metadata={
"issue_id": decoded.issue_id,
"issueId": decoded.issue_id,
"issue_id": decoded.issue_id, # snake_case alias for compatibility
"identifier": issue.get("identifier"),
"title": issue.get("title"),
"url": issue.get("url"),
Expand Down
18 changes: 11 additions & 7 deletions src/chat_sdk/adapters/slack/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import asyncio
import base64
import contextvars
import hashlib
import hmac
import json
Expand Down Expand Up @@ -681,18 +682,21 @@ async def handle_webhook(self, request: Any, options: WebhookOptions | None = No
"headers": {"Content-Type": "application/json"},
}

# Multi-workspace: resolve token before processing events
# Multi-workspace: resolve token before processing events.
# Use contextvars.copy_context() so the ContextVar value persists into
# any async tasks spawned by _process_event_payload (e.g. process_message
# creates a task via asyncio.create_task). The copied context is
# isolated -- the ContextVar change does not leak back to the caller
# and does not need an explicit reset.
if not self._default_bot_token and payload.get("type") == "event_callback":
team_id_event = payload.get("team_id")
if team_id_event:
ctx = await self._resolve_token_for_team(team_id_event)
if ctx:
tok = self._request_context.set(ctx)
try:
self._process_event_payload(payload, options)
return {"body": "ok", "status": 200}
finally:
self._request_context.reset(tok)
isolated = contextvars.copy_context()
isolated.run(self._request_context.set, ctx)
isolated.run(self._process_event_payload, payload, options)
return {"body": "ok", "status": 200}
self._logger.warn("Could not resolve token for team", {"teamId": team_id_event})
return {"body": "ok", "status": 200}

Expand Down
6 changes: 4 additions & 2 deletions src/chat_sdk/adapters/teams/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,8 @@ def _handle_reaction_activity(
thread=None,
adapter=self,
raw=activity,
)
),
options,
)

for reaction in activity.get("reactionsRemoved", []):
Expand All @@ -485,7 +486,8 @@ def _handle_reaction_activity(
thread=None,
adapter=self,
raw=activity,
)
),
options,
)

def _parse_teams_message(
Expand Down
6 changes: 5 additions & 1 deletion src/chat_sdk/adapters/whatsapp/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,9 +660,13 @@ async def download_media(self, media_id: str) -> bytes:
f"Media download URL host is not an allowed Meta domain: {host}",
)

# Step 2: Download the actual file (no Bearer token -- CDN URLs are pre-signed)
# Step 2: Download the actual file.
# The WhatsApp Cloud API requires the Bearer token for media downloads
# (the URL is not pre-signed). The SSRF domain validation above ensures
# we only send the token to legitimate Meta/WhatsApp domains.
async with session.get(
download_url,
headers={"Authorization": f"Bearer {self._access_token}"},
) as data_response:
if data_response.status != 200:
self._logger.error(
Expand Down
56 changes: 51 additions & 5 deletions src/chat_sdk/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
ModalCloseEvent,
ModalResponse,
ModalSubmitEvent,
OnLockConflict,
QueueEntry,
ReactionEvent,
SlashCommandEvent,
Expand Down Expand Up @@ -228,6 +229,7 @@ def __init__(self, config: ChatConfig | None = None, **kwargs: Any) -> None:
self._fallback_streaming_placeholder_text = config.fallback_streaming_placeholder_text
self._dedupe_ttl_ms = config.dedupe_ttl_ms or DEDUPE_TTL_MS
self._lock_scope_config = config.lock_scope
self._on_lock_conflict: OnLockConflict | None = config.on_lock_conflict

# -- Concurrency config -----------------------------------------------
concurrency = config.concurrency
Expand Down Expand Up @@ -1440,11 +1442,14 @@ async def _handle_drop(
) -> None:
lock = await self._state_adapter.acquire_lock(lock_key, DEFAULT_LOCK_TTL_MS)
if lock is None:
self._logger.warn("Could not acquire lock on thread", {"thread_id": thread_id, "lock_key": lock_key})
raise LockError(
thread_id,
f"Could not acquire lock on thread {thread_id}. Another instance may be processing.",
)
# Lock acquisition failed -- consult on_lock_conflict policy
lock = await self._resolve_lock_conflict(thread_id, lock_key, message)
if lock is None:
self._logger.warn("Could not acquire lock on thread", {"thread_id": thread_id, "lock_key": lock_key})
raise LockError(
thread_id,
f"Could not acquire lock on thread {thread_id}. Another instance may be processing.",
)

self._logger.debug("Lock acquired", {"thread_id": thread_id, "lock_key": lock_key, "token": lock.token})
try:
Expand All @@ -1453,6 +1458,47 @@ async def _handle_drop(
await self._state_adapter.release_lock(lock)
self._logger.debug("Lock released", {"thread_id": thread_id, "lock_key": lock_key})

async def _resolve_lock_conflict(
self,
thread_id: str,
lock_key: str,
message: Message,
) -> Lock | None:
"""Attempt to resolve a lock conflict based on the ``on_lock_conflict`` policy.

Returns a :class:`Lock` if the conflict was resolved and the lock
was successfully re-acquired, or ``None`` if the message should be
dropped.
"""
conflict = self._on_lock_conflict

if conflict is None or conflict == "drop":
return None

if conflict == "force":
self._logger.info(
"Force-releasing lock due to on_lock_conflict='force'",
{"thread_id": thread_id, "lock_key": lock_key},
)
await self._state_adapter.force_release_lock(lock_key)
return await self._state_adapter.acquire_lock(lock_key, DEFAULT_LOCK_TTL_MS)

# Callable handler -- invoke and inspect result
if callable(conflict):
result = conflict(thread_id, message)
# Support both sync and async callables
if asyncio.iscoroutine(result) or asyncio.isfuture(result):
result = await result
if result:
self._logger.info(
"on_lock_conflict callback returned True, force-releasing lock",
{"thread_id": thread_id, "lock_key": lock_key},
)
await self._state_adapter.force_release_lock(lock_key)
return await self._state_adapter.acquire_lock(lock_key, DEFAULT_LOCK_TTL_MS)

return None

# -- Queue / Debounce strategy -------------------------------------------

async def _handle_queue_or_debounce(
Expand Down
Loading
Loading