Skip to content

Commit 62d779a

Browse files
committed
remote
1 parent e45ed3a commit 62d779a

2 files changed

Lines changed: 59 additions & 15 deletions

File tree

anton/chat.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -443,26 +443,38 @@ async def _handle_remote(
443443
status = result.get("status", "")
444444

445445
# Poll until ready — 5s intervals, 3 min max
446+
# Use /resolve to get the direct IP, then poll /health on it
446447
if status in ("provisioning", "starting"):
447448
max_wait = 180
448449
poll_interval = 5
449450
start_time = time.time()
451+
direct_endpoint = None
450452

451453
with Live(Spinner("dots", text=" Waiting for instance to be ready...", style="anton.cyan"), console=console, transient=True):
452454
while time.time() - start_time < max_wait:
453455
await asyncio.sleep(poll_interval)
454456
try:
455-
req = Request(
456-
f"{endpoint}/health",
457-
headers={
458-
"Authorization": f"Bearer {api_key}",
459-
"User-Agent": "anton/1.0",
460-
},
461-
)
462-
with urlopen(req, timeout=5) as resp:
463-
health = json.loads(resp.read().decode())
464-
if health.get("status") == "ok":
465-
break
457+
# First resolve the direct IP via Cloudflare Worker
458+
if not direct_endpoint:
459+
req = Request(
460+
f"{endpoint}/resolve",
461+
headers={"Authorization": f"Bearer {api_key}", "User-Agent": "anton/1.0"},
462+
)
463+
with urlopen(req, timeout=5) as resp:
464+
resolve_data = json.loads(resp.read().decode())
465+
if resolve_data.get("status") == "running":
466+
direct_endpoint = resolve_data.get("endpoint", "")
467+
468+
# Then check health directly
469+
if direct_endpoint:
470+
req = Request(
471+
f"{direct_endpoint}/health",
472+
headers={"Authorization": f"Bearer {api_key}", "User-Agent": "anton/1.0"},
473+
)
474+
with urlopen(req, timeout=5) as resp:
475+
health = json.loads(resp.read().decode())
476+
if health.get("status") == "ok":
477+
break
466478
except Exception:
467479
pass
468480
else:

anton/core/backends/remote.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,42 @@ def __init__(
4040
cells=cells,
4141
workspace_path=workspace_path,
4242
)
43-
self._endpoint = endpoint_url.rstrip("/")
43+
self._cloudflare_endpoint = endpoint_url.rstrip("/") # https://sp-xxx.4nton.ai
44+
self._direct_endpoint: str | None = None # resolved to http://IP:port
4445
self._api_key = api_key
4546

4647
# ------------------------------------------------------------------
4748
# HTTP helpers
4849
# ------------------------------------------------------------------
4950

51+
async def _resolve_endpoint(self) -> str:
52+
"""Resolve the Cloudflare endpoint to a direct IP endpoint.
53+
54+
Calls /resolve on the Cloudflare Worker which returns the instance's
55+
direct IP. Caches the result for subsequent calls.
56+
"""
57+
if self._direct_endpoint:
58+
return self._direct_endpoint
59+
60+
import aiohttp
61+
62+
url = f"{self._cloudflare_endpoint}/resolve"
63+
async with aiohttp.ClientSession() as session:
64+
async with session.get(
65+
url, headers=self._headers(), timeout=aiohttp.ClientTimeout(total=15)
66+
) as resp:
67+
if resp.status >= 400:
68+
text = await resp.text()
69+
raise RuntimeError(f"Failed to resolve remote scratchpad ({resp.status}): {text}")
70+
data = await resp.json()
71+
72+
endpoint = data.get("endpoint", "")
73+
if not endpoint:
74+
raise RuntimeError(f"No endpoint returned from /resolve: {data}")
75+
76+
self._direct_endpoint = endpoint.rstrip("/")
77+
return self._direct_endpoint
78+
5079
def _headers(self) -> dict[str, str]:
5180
return {
5281
"Authorization": f"Bearer {self._api_key}",
@@ -58,7 +87,8 @@ async def _post(self, path: str, body: dict | None = None) -> dict:
5887
"""POST to the remote service and return parsed JSON."""
5988
import aiohttp
6089

61-
url = f"{self._endpoint}{path}"
90+
endpoint = await self._resolve_endpoint()
91+
url = f"{endpoint}{path}"
6292
async with aiohttp.ClientSession() as session:
6393
async with session.post(
6494
url, json=body or {}, headers=self._headers(), timeout=aiohttp.ClientTimeout(total=300)
@@ -72,7 +102,8 @@ async def _get(self, path: str, params: dict | None = None) -> dict:
72102
"""GET from the remote service and return parsed JSON."""
73103
import aiohttp
74104

75-
url = f"{self._endpoint}{path}"
105+
endpoint = await self._resolve_endpoint()
106+
url = f"{endpoint}{path}"
76107
async with aiohttp.ClientSession() as session:
77108
async with session.get(
78109
url, params=params, headers=self._headers(), timeout=aiohttp.ClientTimeout(total=30)
@@ -86,7 +117,8 @@ async def _sse(self, path: str, body: dict) -> AsyncIterator[dict]:
86117
"""POST to an SSE endpoint and yield parsed events."""
87118
import aiohttp
88119

89-
url = f"{self._endpoint}{path}"
120+
endpoint = await self._resolve_endpoint()
121+
url = f"{endpoint}{path}"
90122
async with aiohttp.ClientSession() as session:
91123
async with session.post(
92124
url, json=body, headers=self._headers(), timeout=aiohttp.ClientTimeout(total=600)

0 commit comments

Comments
 (0)