Skip to content

Commit 372cb9a

Browse files
dcramerclaude
andcommitted
Implement device code auth flow (RFC 8628) for Google capabilities
Add device code grant support across the capability stack so Google auth actually works instead of returning a placeholder URL. When GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET are configured, the bridge requests a device code from Google, returns a verification URL + user code, and polls for completion. Falls back to the legacy authorization_code placeholder when credentials are not configured. - Add flow_type, user_code, poll_interval_seconds to CapabilityAuthBeginResult - Add CapabilityAuthPollResult type and auth_poll to CapabilityProvider protocol - Add CapabilityManager.auth_poll with flow validation and account storage - Add SubprocessCapabilityProvider.auth_poll bridge dispatch - Register capability.auth.poll RPC method - Add ash-sb capability auth poll CLI command with blocking --timeout mode - Implement real device code flow in gogcli_bridge with scope mapping, token polling, and token refresh before operations - Update Google SKILL.md for device code UX workflow Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 22cb02e commit 372cb9a

12 files changed

Lines changed: 683 additions & 16 deletions

File tree

packages/ash-sandbox-cli/src/ash_sandbox_cli/commands/capability.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,12 @@ def auth_begin(
145145
typer.echo(f"Started capability auth flow (flow_id={result.get('flow_id', '?')})")
146146
typer.echo(f" Capability: {capability}")
147147
typer.echo(f" Auth URL: {result.get('auth_url', '')}")
148+
flow_type = result.get("flow_type", "authorization_code")
149+
typer.echo(f" Flow type: {flow_type}")
150+
if result.get("user_code"):
151+
typer.echo(f" User code: {result['user_code']}")
152+
if result.get("poll_interval_seconds") is not None:
153+
typer.echo(f" Poll interval: {result['poll_interval_seconds']}s")
148154
typer.echo(f" Expires: {result.get('expires_at', '')}")
149155

150156

@@ -184,4 +190,61 @@ def auth_complete(
184190
)
185191

186192

193+
@auth_app.command("poll")
194+
def auth_poll(
195+
flow_id: Annotated[
196+
str,
197+
typer.Option("--flow-id", help="Auth flow id from auth-begin"),
198+
],
199+
timeout: Annotated[
200+
int | None,
201+
typer.Option(
202+
"--timeout", help="Block and poll until complete or timeout (seconds)"
203+
),
204+
] = None,
205+
interval: Annotated[
206+
int | None,
207+
typer.Option("--interval", help="Override poll interval (seconds)"),
208+
] = None,
209+
) -> None:
210+
"""Poll a device code auth flow for completion."""
211+
import time
212+
213+
params: dict[str, Any] = {"flow_id": flow_id}
214+
result = _call("capability.auth.poll", params)
215+
216+
if timeout is None:
217+
# Single poll
218+
if result.get("ok"):
219+
typer.echo(
220+
f"Capability auth completed "
221+
f"(flow_id={flow_id}, account_ref={result.get('account_ref', '')})"
222+
)
223+
else:
224+
retry = result.get("retry_after_seconds", 5)
225+
typer.echo(f"Auth flow pending (flow_id={flow_id}, retry_after={retry}s)")
226+
return
227+
228+
# Blocking poll loop
229+
deadline = time.monotonic() + max(1, timeout)
230+
while True:
231+
if result.get("ok"):
232+
typer.echo(
233+
f"Capability auth completed "
234+
f"(flow_id={flow_id}, account_ref={result.get('account_ref', '')})"
235+
)
236+
return
237+
238+
retry_after = result.get("retry_after_seconds") or 5
239+
sleep_seconds = interval if interval is not None else retry_after
240+
sleep_seconds = max(1, sleep_seconds)
241+
242+
if time.monotonic() + sleep_seconds > deadline:
243+
typer.echo(f"Auth flow timed out (flow_id={flow_id})", err=True)
244+
raise typer.Exit(1)
245+
246+
time.sleep(sleep_seconds)
247+
result = _call("capability.auth.poll", params)
248+
249+
187250
app.add_typer(auth_app, name="auth")

src/ash/capabilities/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ash.capabilities.providers import (
2323
CapabilityAuthBeginResult,
2424
CapabilityAuthCompleteResult,
25+
CapabilityAuthPollResult,
2526
CapabilityCallContext,
2627
CapabilityProvider,
2728
SubprocessCapabilityProvider,
@@ -41,6 +42,7 @@
4142
"CapabilityCallContext",
4243
"CapabilityAuthBeginResult",
4344
"CapabilityAuthCompleteResult",
45+
"CapabilityAuthPollResult",
4446
"CapabilityAccount",
4547
"CapabilityAuthFlow",
4648
"CapabilityDefinition",

src/ash/capabilities/manager.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ash.capabilities.providers import (
1515
CapabilityAuthBeginResult,
1616
CapabilityAuthCompleteResult,
17+
CapabilityAuthPollResult,
1718
CapabilityCallContext,
1819
CapabilityProvider,
1920
)
@@ -271,6 +272,7 @@ async def auth_begin(
271272
expires_at = begin_result.expires_at or (
272273
datetime.now(UTC) + timedelta(seconds=self._auth_flow_ttl_seconds)
273274
)
275+
flow_type = begin_result.flow_type or "authorization_code"
274276
async with self._lock:
275277
self._auth_flows[flow_id] = CapabilityAuthFlow(
276278
flow_id=flow_id,
@@ -279,13 +281,20 @@ async def auth_begin(
279281
account_hint=normalized_account_hint,
280282
expires_at=expires_at,
281283
flow_state=dict(begin_result.flow_state),
284+
flow_type=flow_type,
282285
)
283286

284-
return {
287+
result: dict[str, Any] = {
285288
"flow_id": flow_id,
286289
"auth_url": begin_result.auth_url,
287290
"expires_at": expires_at.isoformat().replace("+00:00", "Z"),
291+
"flow_type": flow_type,
288292
}
293+
if begin_result.user_code is not None:
294+
result["user_code"] = begin_result.user_code
295+
if begin_result.poll_interval_seconds is not None:
296+
result["poll_interval_seconds"] = begin_result.poll_interval_seconds
297+
return result
289298

290299
async def auth_complete(
291300
self,
@@ -386,6 +395,96 @@ async def auth_complete(
386395

387396
return {"ok": True, "account_ref": account_ref}
388397

398+
async def auth_poll(
399+
self,
400+
*,
401+
flow_id: str,
402+
user_id: str,
403+
chat_id: str | None = None,
404+
chat_type: str | None = None,
405+
provider: str | None = None,
406+
thread_id: str | None = None,
407+
session_key: str | None = None,
408+
source_username: str | None = None,
409+
source_display_name: str | None = None,
410+
) -> dict[str, Any]:
411+
"""Poll a pending device code auth flow."""
412+
normalized_user_id = _required_text(
413+
value=user_id,
414+
code="capability_invalid_input",
415+
message="user_id is required",
416+
)
417+
normalized_flow_id = _required_text(
418+
value=flow_id,
419+
code="capability_invalid_input",
420+
message="flow_id is required",
421+
)
422+
423+
async with self._lock:
424+
self._prune_expired_flows_locked()
425+
flow = self._auth_flows.get(normalized_flow_id)
426+
if flow is None:
427+
raise CapabilityError(
428+
"capability_auth_flow_invalid",
429+
f"auth flow is invalid or expired: {normalized_flow_id}",
430+
)
431+
if flow.user_id != normalized_user_id:
432+
raise CapabilityError(
433+
"capability_auth_flow_invalid",
434+
"auth flow does not belong to caller",
435+
)
436+
if flow.flow_type != "device_code":
437+
raise CapabilityError(
438+
"capability_invalid_input",
439+
"auth_poll is only supported for device_code flows",
440+
)
441+
_, provider_impl = self._get_definition_and_provider_locked(
442+
flow.capability_id
443+
)
444+
445+
call_context = CapabilityCallContext(
446+
user_id=normalized_user_id,
447+
chat_id=_optional_text(chat_id),
448+
chat_type=_optional_text(chat_type),
449+
provider=_optional_text(provider),
450+
thread_id=_optional_text(thread_id),
451+
session_key=_optional_text(session_key),
452+
source_username=_optional_text(source_username),
453+
source_display_name=_optional_text(source_display_name),
454+
)
455+
poll_result = await self._provider_auth_poll(
456+
provider_impl,
457+
capability_id=flow.capability_id,
458+
flow_state=dict(flow.flow_state),
459+
context=call_context,
460+
)
461+
462+
if poll_result.status == "complete":
463+
account_ref = _required_text(
464+
value=poll_result.account_ref,
465+
code="capability_invalid_output",
466+
message="auth poll completion must return account_ref",
467+
)
468+
now = datetime.now(UTC)
469+
async with self._lock:
470+
self._accounts[(flow.user_id, flow.capability_id, account_ref)] = (
471+
CapabilityAccount(
472+
capability_id=flow.capability_id,
473+
user_id=flow.user_id,
474+
account_ref=account_ref,
475+
created_at=now,
476+
credential_material=dict(poll_result.credential_material),
477+
metadata=dict(poll_result.metadata),
478+
)
479+
)
480+
del self._auth_flows[normalized_flow_id]
481+
return {"ok": True, "account_ref": account_ref}
482+
483+
return {
484+
"status": "pending",
485+
"retry_after_seconds": poll_result.retry_after_seconds,
486+
}
487+
389488
async def invoke(
390489
self,
391490
*,
@@ -543,6 +642,28 @@ async def _provider_auth_begin(
543642
auth_url=auth_url,
544643
expires_at=result.expires_at,
545644
flow_state=dict(result.flow_state),
645+
flow_type=result.flow_type or "authorization_code",
646+
user_code=result.user_code,
647+
poll_interval_seconds=result.poll_interval_seconds,
648+
)
649+
650+
async def _provider_auth_poll(
651+
self,
652+
provider_impl: CapabilityProvider | None,
653+
*,
654+
capability_id: str,
655+
flow_state: dict[str, Any],
656+
context: CapabilityCallContext,
657+
) -> CapabilityAuthPollResult:
658+
if provider_impl is None:
659+
raise CapabilityError(
660+
"capability_invalid_input",
661+
"auth_poll requires a provider-backed capability",
662+
)
663+
return await provider_impl.auth_poll(
664+
capability_id=capability_id,
665+
flow_state=flow_state,
666+
context=context,
546667
)
547668

548669
async def _provider_auth_complete(

src/ash/capabilities/providers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ash.capabilities.providers.base import (
44
CapabilityAuthBeginResult,
55
CapabilityAuthCompleteResult,
6+
CapabilityAuthPollResult,
67
CapabilityCallContext,
78
CapabilityProvider,
89
)
@@ -11,6 +12,7 @@
1112
__all__ = [
1213
"CapabilityAuthBeginResult",
1314
"CapabilityAuthCompleteResult",
15+
"CapabilityAuthPollResult",
1416
"CapabilityCallContext",
1517
"CapabilityProvider",
1618
"SubprocessCapabilityProvider",

src/ash/capabilities/providers/base.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ class CapabilityAuthBeginResult:
3333
auth_url: str
3434
expires_at: datetime | None = None
3535
flow_state: dict[str, Any] = field(default_factory=dict)
36+
flow_type: str = "authorization_code" # or "device_code"
37+
user_code: str | None = None
38+
poll_interval_seconds: int | None = None
3639

3740

3841
@dataclass(slots=True)
@@ -44,6 +47,17 @@ class CapabilityAuthCompleteResult:
4447
metadata: dict[str, Any] = field(default_factory=dict)
4548

4649

50+
@dataclass(slots=True)
51+
class CapabilityAuthPollResult:
52+
"""Provider response for device code auth polling."""
53+
54+
status: str # "pending" | "complete"
55+
retry_after_seconds: int | None = None
56+
account_ref: str | None = None
57+
credential_material: dict[str, Any] = field(default_factory=dict)
58+
metadata: dict[str, Any] = field(default_factory=dict)
59+
60+
4761
class CapabilityProvider(Protocol):
4862
"""Interface for capability provider backends."""
4963

@@ -85,3 +99,12 @@ async def auth_complete(
8599
context: CapabilityCallContext,
86100
) -> CapabilityAuthCompleteResult:
87101
"""Complete auth flow and return linked account result."""
102+
103+
async def auth_poll(
104+
self,
105+
*,
106+
capability_id: str,
107+
flow_state: dict[str, Any],
108+
context: CapabilityCallContext,
109+
) -> CapabilityAuthPollResult:
110+
"""Poll a device code auth flow for completion."""

src/ash/capabilities/providers/subprocess.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ash.capabilities.providers.base import (
2121
CapabilityAuthBeginResult,
2222
CapabilityAuthCompleteResult,
23+
CapabilityAuthPollResult,
2324
CapabilityCallContext,
2425
CapabilityProvider,
2526
)
@@ -90,10 +91,63 @@ async def auth_begin(
9091
code="capability_invalid_output",
9192
message="bridge auth_begin must return auth_url",
9293
)
94+
flow_type = str(result.get("flow_type") or "authorization_code").strip()
95+
raw_user_code = result.get("user_code")
96+
user_code = str(raw_user_code).strip() if raw_user_code is not None else None
97+
raw_poll_interval = result.get("poll_interval_seconds")
98+
poll_interval: int | None = None
99+
if raw_poll_interval is not None:
100+
try:
101+
poll_interval = int(raw_poll_interval)
102+
except (TypeError, ValueError):
103+
pass
93104
return CapabilityAuthBeginResult(
94105
auth_url=auth_url,
95106
expires_at=_parse_optional_datetime(result.get("expires_at")),
96107
flow_state=_as_object(result.get("flow_state"), default={}),
108+
flow_type=flow_type,
109+
user_code=user_code,
110+
poll_interval_seconds=poll_interval,
111+
)
112+
113+
async def auth_poll(
114+
self,
115+
*,
116+
capability_id: str,
117+
flow_state: dict[str, Any],
118+
context: CapabilityCallContext,
119+
) -> CapabilityAuthPollResult:
120+
result = await self._call_bridge(
121+
"auth_poll",
122+
{
123+
"capability_id": capability_id,
124+
"flow_state": dict(flow_state),
125+
"context_token": self._issue_context_token(context),
126+
},
127+
)
128+
status = _required_text(
129+
value=result.get("status"),
130+
code="capability_invalid_output",
131+
message="bridge auth_poll must return status",
132+
)
133+
raw_retry = result.get("retry_after_seconds")
134+
retry_after: int | None = None
135+
if raw_retry is not None:
136+
try:
137+
retry_after = int(raw_retry)
138+
except (TypeError, ValueError):
139+
pass
140+
account_ref = result.get("account_ref")
141+
if account_ref is not None:
142+
account_ref = str(account_ref).strip() or None
143+
return CapabilityAuthPollResult(
144+
status=status,
145+
retry_after_seconds=retry_after,
146+
account_ref=account_ref,
147+
credential_material=_as_object(
148+
result.get("credential_material"), default={}
149+
),
150+
metadata=_as_object(result.get("metadata"), default={}),
97151
)
98152

99153
async def auth_complete(

src/ash/capabilities/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class CapabilityAuthFlow:
4343
account_hint: str | None
4444
expires_at: datetime
4545
flow_state: dict[str, Any] = field(default_factory=dict)
46+
flow_type: str = "authorization_code"
4647

4748

4849
@dataclass(slots=True)

0 commit comments

Comments
 (0)