Skip to content

Commit 8db630e

Browse files
dcramercodex
andcommitted
Complete OAuth callbacks host-side by flow state
Co-Authored-By: GPT-5 Codex <codex@openai.com>
1 parent 34a087c commit 8db630e

8 files changed

Lines changed: 586 additions & 0 deletions

File tree

packages/ash-sandbox-cli/src/ash_sandbox_cli/commands/capability.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,52 @@ def auth_complete(
258258
)
259259

260260

261+
@auth_app.command("complete-callback")
262+
def auth_complete_callback(
263+
callback_url: Annotated[
264+
str | None,
265+
typer.Option("--callback-url", help="OAuth callback URL"),
266+
] = None,
267+
code: Annotated[
268+
str | None,
269+
typer.Option("--code", help="Authorization code"),
270+
] = None,
271+
capability: Annotated[
272+
str | None,
273+
typer.Option("--capability", "-c", help="Optional namespaced capability id"),
274+
] = None,
275+
account_hint: Annotated[
276+
str | None,
277+
typer.Option("--account", help="Optional account reference hint"),
278+
] = None,
279+
) -> None:
280+
"""Complete capability auth by callback/code with host-side flow resolution."""
281+
if not callback_url and not code:
282+
typer.echo("Error: Must specify either --callback-url or --code", err=True)
283+
raise typer.Exit(1)
284+
285+
params: dict[str, Any] = {}
286+
if callback_url:
287+
params["callback_url"] = callback_url
288+
if code:
289+
params["code"] = code
290+
if capability:
291+
params["capability"] = capability
292+
if account_hint:
293+
params["account_hint"] = account_hint
294+
295+
result = _call("capability.auth.complete_callback", params)
296+
if not result.get("ok"):
297+
typer.echo("Error: capability auth completion failed", err=True)
298+
raise typer.Exit(1)
299+
typer.echo(
300+
"Capability auth completed "
301+
f"(flow_id={result.get('flow_id', '')}, "
302+
f"capability={result.get('capability', '')}, "
303+
f"account_ref={result.get('account_ref', '')})"
304+
)
305+
306+
261307
@auth_app.command("poll")
262308
def auth_poll(
263309
flow_id: Annotated[

src/ash/capabilities/manager.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,125 @@ async def auth_complete(
454454

455455
return {"ok": True, "account_ref": account_ref}
456456

457+
async def auth_complete_callback(
458+
self,
459+
*,
460+
user_id: str,
461+
callback_url: str | None = None,
462+
code: str | None = None,
463+
capability_id: str | None = None,
464+
account_hint: str | None = None,
465+
chat_id: str | None = None,
466+
chat_type: str | None = None,
467+
provider: str | None = None,
468+
thread_id: str | None = None,
469+
session_key: str | None = None,
470+
source_username: str | None = None,
471+
source_display_name: str | None = None,
472+
) -> dict[str, Any]:
473+
"""Complete pending auth by callback URL/code with host-side flow resolution.
474+
475+
Prefers deterministic state matching when callback state is present.
476+
"""
477+
normalized_user_id = _required_text(
478+
value=user_id,
479+
code="capability_invalid_input",
480+
message="user_id is required",
481+
)
482+
normalized_capability_id = (
483+
_required_capability_id(capability_id)
484+
if capability_id is not None
485+
else None
486+
)
487+
normalized_account_hint = _optional_text(account_hint)
488+
normalized_chat_id = _optional_text(chat_id)
489+
normalized_chat_type = _optional_text(chat_type)
490+
normalized_provider = _optional_text(provider)
491+
normalized_thread_id = _optional_text(thread_id)
492+
normalized_session_key = _optional_text(session_key)
493+
normalized_source_username = _optional_text(source_username)
494+
normalized_source_display_name = _optional_text(source_display_name)
495+
496+
try:
497+
normalized_completion = normalize_auth_completion(
498+
callback_url=_optional_text(callback_url),
499+
code=_optional_text(code),
500+
expected_state=None,
501+
)
502+
except AuthNormalizationError as e:
503+
raise CapabilityError(e.code, str(e)) from e
504+
505+
callback_state = _optional_text(normalized_completion.state)
506+
507+
async with self._lock:
508+
self._prune_expired_flows_locked()
509+
eligible = [
510+
flow
511+
for flow in self._auth_flows.values()
512+
if flow.user_id == normalized_user_id
513+
and (
514+
normalized_capability_id is None
515+
or flow.capability_id == normalized_capability_id
516+
)
517+
and (
518+
normalized_account_hint is None
519+
or flow.account_hint == normalized_account_hint
520+
)
521+
]
522+
523+
if not eligible:
524+
raise CapabilityError(
525+
"capability_auth_flow_invalid",
526+
"no pending auth flows match caller scope",
527+
)
528+
529+
selected: CapabilityAuthFlow | None = None
530+
if callback_state is not None:
531+
state_matches = [
532+
flow
533+
for flow in eligible
534+
if _optional_text(flow.expected_callback_state) == callback_state
535+
]
536+
if not state_matches:
537+
raise CapabilityError(
538+
"capability_auth_state_mismatch",
539+
"callback_url state does not match auth flow",
540+
)
541+
if len(state_matches) > 1:
542+
raise CapabilityError(
543+
"capability_auth_flow_ambiguous",
544+
"multiple pending auth flows matched callback state",
545+
)
546+
selected = state_matches[0]
547+
elif len(eligible) == 1:
548+
selected = eligible[0]
549+
else:
550+
raise CapabilityError(
551+
"capability_auth_flow_ambiguous",
552+
"multiple pending auth flows; callback state is required",
553+
)
554+
555+
result = await self.auth_complete(
556+
flow_id=selected.flow_id,
557+
user_id=normalized_user_id,
558+
chat_id=normalized_chat_id,
559+
chat_type=normalized_chat_type,
560+
provider=normalized_provider,
561+
thread_id=normalized_thread_id,
562+
session_key=normalized_session_key,
563+
source_username=normalized_source_username,
564+
source_display_name=normalized_source_display_name,
565+
callback_url=normalized_completion.raw_callback_url,
566+
code=normalized_completion.authorization_code,
567+
)
568+
return {
569+
"ok": bool(result.get("ok")),
570+
"account_ref": result.get("account_ref"),
571+
"flow_id": selected.flow_id,
572+
"capability": selected.capability_id,
573+
"account_hint": selected.account_hint,
574+
}
575+
457576
async def auth_poll(
458577
self,
459578
*,

src/ash/providers/telegram/handlers/message_handler.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import asyncio
66
import inspect
77
import logging
8+
import re
89
from datetime import UTC, datetime, timedelta
910
from typing import TYPE_CHECKING, Any
1011

@@ -43,6 +44,7 @@
4344
from ash.tools.registry import ToolRegistry
4445

4546
logger = logging.getLogger("telegram")
47+
_LOCALHOST_CALLBACK_URL_PATTERN = re.compile(r"https?://localhost[^\s]*[?&]code=[^\s]*")
4648

4749

4850
def _extract_tool_calls_from_session(session: SessionState) -> list[dict[str, Any]]:
@@ -216,6 +218,126 @@ def _create_tool_tracker(self, message: IncomingMessage) -> ToolTracker:
216218
skill_registry=self._skill_registry,
217219
)
218220

221+
def _get_capability_manager(self) -> Any | None:
222+
"""Best-effort lookup of the capability manager via use_skill tool wiring."""
223+
if self._tool_registry is None:
224+
return None
225+
if not hasattr(self._tool_registry, "has") or not self._tool_registry.has(
226+
"use_skill"
227+
):
228+
return None
229+
try:
230+
skill_tool = self._tool_registry.get("use_skill")
231+
except Exception:
232+
return None
233+
return getattr(skill_tool, "_capability_manager", None)
234+
235+
@staticmethod
236+
def _extract_localhost_callback_url(message_text: str) -> str | None:
237+
"""Extract a localhost OAuth callback URL from message text."""
238+
match = _LOCALHOST_CALLBACK_URL_PATTERN.search(message_text)
239+
if not match:
240+
return None
241+
callback_url = match.group(0).strip().rstrip(".,;")
242+
return callback_url or None
243+
244+
async def _try_handle_capability_oauth_callback(
245+
self, message: IncomingMessage
246+
) -> bool:
247+
"""Complete capability auth from pasted callback URL without LLM orchestration."""
248+
callback_url = self._extract_localhost_callback_url(message.text or "")
249+
if not callback_url:
250+
return False
251+
252+
manager = self._get_capability_manager()
253+
if manager is None:
254+
return False
255+
256+
try:
257+
completion = await manager.auth_complete_callback(
258+
user_id=message.user_id,
259+
callback_url=callback_url,
260+
chat_id=message.chat_id,
261+
chat_type=message.metadata.get("chat_type"),
262+
provider=self._provider.name,
263+
thread_id=message.metadata.get("thread_id"),
264+
source_username=message.username,
265+
source_display_name=message.display_name,
266+
)
267+
pending_flows = await manager.list_auth_flows(user_id=message.user_id)
268+
except Exception as e:
269+
code = getattr(e, "code", "")
270+
if isinstance(code, str) and code.startswith("capability_auth_"):
271+
reply = (
272+
"I could not apply that OAuth callback yet "
273+
f"({code}). Please continue with the latest auth URL."
274+
)
275+
await self._provider.send(
276+
OutgoingMessage(
277+
chat_id=message.chat_id,
278+
text=reply,
279+
reply_to_message_id=message.id,
280+
)
281+
)
282+
return True
283+
return False
284+
285+
capability = str(completion.get("capability", "")).strip()
286+
account_hint = str(completion.get("account_hint", "")).strip() or "default"
287+
capability_label = {
288+
"gog.email": "Gmail",
289+
"gog.calendar": "Google Calendar",
290+
}.get(capability, capability or "Google capability")
291+
292+
pending_caps = sorted(
293+
{
294+
str(flow.get("capability", "")).strip()
295+
for flow in pending_flows
296+
if isinstance(flow, dict) and flow.get("capability")
297+
}
298+
)
299+
if pending_caps:
300+
pending_names = ", ".join(
301+
sorted(
302+
{
303+
{
304+
"gog.email": "Gmail",
305+
"gog.calendar": "Google Calendar",
306+
}.get(cap, cap)
307+
for cap in pending_caps
308+
}
309+
)
310+
)
311+
reply = (
312+
f"{capability_label} connected for account '{account_hint}'. "
313+
f"Still pending: {pending_names}. Paste the next callback URL when ready."
314+
)
315+
else:
316+
reply = (
317+
f"{capability_label} connected for account '{account_hint}'. "
318+
"Google setup is complete."
319+
)
320+
321+
response_external_id = await self._provider.send(
322+
OutgoingMessage(
323+
chat_id=message.chat_id,
324+
text=reply,
325+
reply_to_message_id=message.id,
326+
)
327+
)
328+
await self._session_handler.persist_messages(
329+
chat_id=message.chat_id,
330+
user_id=message.user_id,
331+
user_message=message.text,
332+
assistant_message=reply,
333+
external_id=message.id,
334+
response_external_id=response_external_id,
335+
thread_id=message.metadata.get("thread_id"),
336+
username=message.username,
337+
display_name=message.display_name,
338+
)
339+
return True
340+
219341
def _log_response(self, text: str | None) -> None:
220342
bot_name = self._provider.bot_username or "bot"
221343
logger.info(
@@ -370,6 +492,9 @@ async def _process_single_message_inner(
370492
if isinstance(candidate, IncomingMessage):
371493
message = candidate
372494

495+
if await self._try_handle_capability_oauth_callback(message):
496+
return
497+
373498
# Check if there's an active interactive subagent stack for this session
374499
thread_id = message.metadata.get("thread_id")
375500
session_key = make_session_key(

src/ash/rpc/methods/capability.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,39 @@ async def capability_auth_complete(params: dict[str, Any]) -> dict[str, Any]:
127127
raise ValueError(f"{e.code}: {e}") from e
128128
return {"ok": bool(result.get("ok")), "account_ref": result["account_ref"]}
129129

130+
async def capability_auth_complete_callback(
131+
params: dict[str, Any],
132+
) -> dict[str, Any]:
133+
user_id = _required_text(params, "user_id")
134+
callback_url = _optional_text(params, "callback_url")
135+
code = _optional_text(params, "code")
136+
capability_id = _optional_text(params, "capability")
137+
account_hint = _optional_text(params, "account_hint")
138+
try:
139+
result = await manager.auth_complete_callback(
140+
user_id=user_id,
141+
callback_url=callback_url,
142+
code=code,
143+
capability_id=capability_id,
144+
account_hint=account_hint,
145+
chat_id=_optional_text(params, "chat_id"),
146+
chat_type=_optional_text(params, "chat_type"),
147+
provider=_optional_text(params, "provider"),
148+
thread_id=_optional_text(params, "thread_id"),
149+
session_key=_optional_text(params, "session_key"),
150+
source_username=_optional_text(params, "source_username"),
151+
source_display_name=_optional_text(params, "source_display_name"),
152+
)
153+
except CapabilityError as e:
154+
raise ValueError(f"{e.code}: {e}") from e
155+
return {
156+
"ok": bool(result.get("ok")),
157+
"account_ref": result["account_ref"],
158+
"flow_id": result["flow_id"],
159+
"capability": result["capability"],
160+
"account_hint": result.get("account_hint"),
161+
}
162+
130163
async def capability_auth_poll(params: dict[str, Any]) -> dict[str, Any]:
131164
flow_id = _required_text(params, "flow_id")
132165
user_id = _required_text(params, "user_id")
@@ -150,6 +183,9 @@ async def capability_auth_poll(params: dict[str, Any]) -> dict[str, Any]:
150183
server.register("capability.auth.begin", capability_auth_begin)
151184
server.register("capability.auth.list", capability_auth_list)
152185
server.register("capability.auth.complete", capability_auth_complete)
186+
server.register(
187+
"capability.auth.complete_callback", capability_auth_complete_callback
188+
)
153189
server.register("capability.auth.poll", capability_auth_poll)
154190

155191
logger.debug("Registered capability RPC methods")

0 commit comments

Comments
 (0)