Skip to content

Commit 64e43a8

Browse files
fix: run modal sandbox in thread pool to avoid blocking event loop (#25)
* fix: run modal sandbox in thread pool to avoid blocking event loop The _stream_modal_sandbox function was running synchronous Modal SDK calls inside an async generator, blocking the entire FastAPI app. Also renamed DEMO_USE_MODAL → AGENT_USE_MODAL in env files. * fix: move imports to top of file in analysis.py
1 parent 3201b67 commit 64e43a8

4 files changed

Lines changed: 84 additions & 26 deletions

File tree

.env.example

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,5 @@ MODAL_TOKEN_SECRET=as-...
2727

2828
# Demo agent
2929
ANTHROPIC_API_KEY=sk-ant-...
30-
DEMO_USE_MODAL=false
30+
AGENT_USE_MODAL=false
3131
POLICYENGINE_API_URL=http://localhost:8000

src/policyengine_api/agent_sandbox.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ def run_claude_code_in_sandbox(
8282
return sb, process
8383

8484

85-
@app.function(image=sandbox_image, secrets=[anthropic_secret, logfire_secret], timeout=600)
85+
@app.function(
86+
image=sandbox_image, secrets=[anthropic_secret, logfire_secret], timeout=600
87+
)
8688
def run_policy_analysis(
8789
question: str, api_base_url: str = "https://v2.api.policyengine.org"
8890
) -> dict:
@@ -98,11 +100,15 @@ def run_policy_analysis(
98100

99101
logfire.configure(service_name="policyengine-agent-sandbox")
100102

101-
with logfire.span("run_policy_analysis", question=question[:100], api_base_url=api_base_url):
103+
with logfire.span(
104+
"run_policy_analysis", question=question[:100], api_base_url=api_base_url
105+
):
102106
# Write MCP config
103107
os.makedirs("/root/.claude", exist_ok=True)
104108
mcp_config = {
105-
"mcpServers": {"policyengine": {"type": "sse", "url": f"{api_base_url}/mcp"}}
109+
"mcpServers": {
110+
"policyengine": {"type": "sse", "url": f"{api_base_url}/mcp"}
111+
}
106112
}
107113
with open("/root/.claude/mcp_servers.json", "w") as f:
108114
json.dump(mcp_config, f)

src/policyengine_api/api/agent.py

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -86,38 +86,91 @@ async def _stream_claude_code(question: str, api_base_url: str):
8686

8787
async def _stream_modal_sandbox(question: str, api_base_url: str):
8888
"""Stream output from Claude Code running in Modal Sandbox."""
89+
from concurrent.futures import ThreadPoolExecutor
90+
8991
import logfire
9092

9193
sb = None
94+
executor = ThreadPoolExecutor(max_workers=1)
9295
try:
9396
from policyengine_api.agent_sandbox import run_claude_code_in_sandbox
9497

95-
logfire.info("Creating Modal sandbox", question=question[:100], api_base_url=api_base_url)
96-
sb, process = run_claude_code_in_sandbox(question, api_base_url)
98+
logfire.info(
99+
"Creating Modal sandbox", question=question[:100], api_base_url=api_base_url
100+
)
101+
102+
# Run blocking Modal SDK calls in thread pool to avoid blocking event loop
103+
loop = asyncio.get_event_loop()
104+
sb, process = await loop.run_in_executor(
105+
executor, run_claude_code_in_sandbox, question, api_base_url
106+
)
97107
logfire.info("Modal sandbox created, streaming output")
98108

99-
# Stream stdout line by line
100-
for line in process.stdout:
101-
yield f"data: {json.dumps({'type': 'output', 'content': line})}\n\n"
109+
# Poll for lines with timeout to allow other async tasks
110+
import queue
111+
import threading
102112

103-
process.wait()
113+
line_queue = queue.Queue()
104114

105-
if process.returncode != 0:
106-
stderr = process.stderr.read()
107-
logfire.error("Claude Code failed in sandbox", returncode=process.returncode, stderr=stderr[:500])
108-
yield f"data: {json.dumps({'type': 'error', 'content': stderr})}\n\n"
115+
def stream_reader():
116+
try:
117+
for line in process.stdout:
118+
line_queue.put(("line", line))
119+
process.wait()
120+
if process.returncode != 0:
121+
stderr = process.stderr.read()
122+
line_queue.put(("error", (process.returncode, stderr)))
123+
else:
124+
line_queue.put(("done", process.returncode))
125+
except Exception as e:
126+
line_queue.put(("exception", str(e)))
127+
128+
reader_thread = threading.Thread(target=stream_reader, daemon=True)
129+
reader_thread.start()
130+
131+
while True:
132+
try:
133+
# Non-blocking check with short timeout
134+
item = await loop.run_in_executor(
135+
executor, lambda: line_queue.get(timeout=0.1)
136+
)
137+
event_type, data = item
138+
139+
if event_type == "line":
140+
yield f"data: {json.dumps({'type': 'output', 'content': data})}\n\n"
141+
elif event_type == "error":
142+
returncode, stderr = data
143+
logfire.error(
144+
"Claude Code failed in sandbox",
145+
returncode=returncode,
146+
stderr=stderr[:500],
147+
)
148+
yield f"data: {json.dumps({'type': 'error', 'content': stderr})}\n\n"
149+
yield f"data: {json.dumps({'type': 'done', 'returncode': returncode})}\n\n"
150+
break
151+
elif event_type == "done":
152+
yield f"data: {json.dumps({'type': 'done', 'returncode': data})}\n\n"
153+
break
154+
elif event_type == "exception":
155+
raise Exception(data)
156+
except Exception as e:
157+
if "Empty" in type(e).__name__:
158+
# Queue timeout, continue polling
159+
await asyncio.sleep(0)
160+
continue
161+
raise
109162

110-
yield f"data: {json.dumps({'type': 'done', 'returncode': process.returncode})}\n\n"
111163
except Exception as e:
112164
logfire.exception("Modal sandbox failed", error=str(e))
113165
yield f"data: {json.dumps({'type': 'error', 'content': f'Sandbox error: {str(e)}'})}\n\n"
114166
yield f"data: {json.dumps({'type': 'done', 'returncode': 1})}\n\n"
115167
finally:
116168
if sb is not None:
117169
try:
118-
sb.terminate()
170+
await loop.run_in_executor(executor, sb.terminate)
119171
except Exception:
120172
pass
173+
executor.shutdown(wait=False)
121174

122175

123176
@router.post("/stream")

src/policyengine_api/api/analysis.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,6 @@
2424
from pydantic import BaseModel, Field
2525
from sqlmodel import Session, select
2626

27-
28-
def _safe_float(value: float | None) -> float | None:
29-
"""Convert NaN/inf to None for JSON serialization."""
30-
if value is None:
31-
return None
32-
if math.isnan(value) or math.isinf(value):
33-
return None
34-
return value
35-
36-
3727
from policyengine_api.models import (
3828
Dataset,
3929
DecileImpact,
@@ -49,6 +39,15 @@ def _safe_float(value: float | None) -> float | None:
4939
)
5040
from policyengine_api.services.database import get_session
5141

42+
43+
def _safe_float(value: float | None) -> float | None:
44+
"""Convert NaN/inf to None for JSON serialization."""
45+
if value is None:
46+
return None
47+
if math.isnan(value) or math.isinf(value):
48+
return None
49+
return value
50+
5251
# Namespace for deterministic UUIDs
5352
SIMULATION_NAMESPACE = UUID("a1b2c3d4-e5f6-7890-abcd-ef1234567890")
5453
REPORT_NAMESPACE = UUID("b2c3d4e5-f6a7-8901-bcde-f12345678901")

0 commit comments

Comments
 (0)