Skip to content

Commit e8c3d33

Browse files
feat: add unified Logfire distributed tracing across FastAPI and Modal (#56)
Adds proper observability with trace context propagation so all spans from a single request appear together in Logfire, regardless of whether they run in FastAPI or Modal. Changes: - Add logfire to Modal base images and agent sandbox - Configure logfire with service names (policyengine-modal-uk/us, policyengine-agent) - Pass W3C traceparent from FastAPI to Modal spawn calls - Agent sandbox passes traceparent in HTTP requests back to API - Replace rich.console logging with logfire spans in all Modal functions
1 parent 298a6a5 commit e8c3d33

5 files changed

Lines changed: 939 additions & 810 deletions

File tree

src/policyengine_api/agent_sandbox.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,42 @@
1010
import requests
1111

1212
image = modal.Image.debian_slim(python_version="3.12").pip_install(
13-
"anthropic", "requests"
13+
"anthropic", "requests", "logfire[httpx]"
1414
)
1515

1616
app = modal.App("policyengine-sandbox")
1717
anthropic_secret = modal.Secret.from_name("anthropic-api-key")
18+
logfire_secrets = modal.Secret.from_name("policyengine-logfire")
19+
20+
21+
def configure_logfire(traceparent: str | None = None):
22+
"""Configure logfire with optional trace context propagation."""
23+
import os
24+
25+
import logfire
26+
27+
token = os.environ.get("LOGFIRE_TOKEN", "")
28+
if not token:
29+
return None
30+
31+
logfire.configure(
32+
service_name="policyengine-agent",
33+
token=token,
34+
environment=os.environ.get("LOGFIRE_ENVIRONMENT", "production"),
35+
console=False,
36+
)
37+
38+
# If traceparent provided, attach to the current context
39+
if traceparent:
40+
from opentelemetry.trace.propagation.tracecontext import (
41+
TraceContextTextMapPropagator,
42+
)
43+
44+
propagator = TraceContextTextMapPropagator()
45+
ctx = propagator.extract(carrier={"traceparent": traceparent})
46+
return ctx
47+
48+
return None
1849

1950
SYSTEM_PROMPT = """You are a PolicyEngine assistant that helps users understand tax and benefit policies.
2051
@@ -256,6 +287,7 @@ def execute_api_tool(
256287
tool_input: dict,
257288
api_base_url: str,
258289
log_fn: Callable,
290+
trace_headers: dict | None = None,
259291
) -> str:
260292
"""Execute an API tool by making the HTTP request."""
261293
meta = tool.get("_meta", {})
@@ -267,6 +299,8 @@ def execute_api_tool(
267299
url = f"{api_base_url}{path}"
268300
query_params = {}
269301
headers = {"Content-Type": "application/json"}
302+
if trace_headers:
303+
headers.update(trace_headers)
270304

271305
# Separate path, query, and body parameters
272306
body_data = {}
@@ -344,16 +378,26 @@ def _run_agent_impl(
344378
call_id: str = "",
345379
history: list[dict] | None = None,
346380
max_turns: int = 30,
381+
traceparent: str | None = None,
347382
) -> dict:
348383
"""Core agent implementation."""
384+
import logfire
385+
386+
# Get traceparent for HTTP requests
387+
def get_trace_headers() -> dict:
388+
if traceparent:
389+
return {"traceparent": traceparent}
390+
return {}
349391

350392
def log(msg: str) -> None:
393+
logfire.info(msg, call_id=call_id)
351394
print(msg)
352395
if call_id:
353396
try:
354397
requests.post(
355398
f"{api_base_url}/agent/log/{call_id}",
356399
json={"message": msg},
400+
headers=get_trace_headers(),
357401
timeout=5,
358402
)
359403
except Exception:
@@ -425,7 +469,9 @@ def log(msg: str) -> None:
425469
else:
426470
tool = tool_lookup.get(block.name)
427471
if tool:
428-
result = execute_api_tool(tool, block.input, api_base_url, log)
472+
result = execute_api_tool(
473+
tool, block.input, api_base_url, log, get_trace_headers()
474+
)
429475
else:
430476
result = f"Unknown tool: {block.name}"
431477

@@ -457,6 +503,7 @@ def log(msg: str) -> None:
457503
requests.post(
458504
f"{api_base_url}/agent/complete/{call_id}",
459505
json=result,
506+
headers=get_trace_headers(),
460507
timeout=10,
461508
)
462509
except Exception:
@@ -465,22 +512,29 @@ def log(msg: str) -> None:
465512
return result
466513

467514

468-
@app.function(image=image, secrets=[anthropic_secret], timeout=600)
515+
@app.function(image=image, secrets=[anthropic_secret, logfire_secrets], timeout=600)
469516
def run_agent(
470517
question: str,
471518
api_base_url: str = "https://v2.api.policyengine.org",
472519
call_id: str = "",
473520
history: list[dict] | None = None,
474521
max_turns: int = 30,
522+
traceparent: str | None = None,
475523
) -> dict:
476524
"""Run agentic loop to answer a policy question (Modal wrapper)."""
477-
return _run_agent_impl(
478-
question,
479-
api_base_url,
480-
call_id,
481-
history=history,
482-
max_turns=max_turns,
483-
)
525+
import logfire
526+
527+
ctx = configure_logfire(traceparent)
528+
529+
with logfire.span("run_agent", call_id=call_id, question=question[:200], _context=ctx):
530+
return _run_agent_impl(
531+
question,
532+
api_base_url,
533+
call_id,
534+
history=history,
535+
max_turns=max_turns,
536+
traceparent=traceparent,
537+
)
484538

485539

486540
if __name__ == "__main__":

src/policyengine_api/api/agent.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,18 @@
1212

1313
import logfire
1414
from fastapi import APIRouter, HTTPException
15+
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
1516
from pydantic import BaseModel
1617

1718
from policyengine_api.config import settings
1819

20+
21+
def get_traceparent() -> str | None:
22+
"""Get the current W3C traceparent header for distributed tracing."""
23+
carrier: dict[str, str] = {}
24+
TraceContextTextMapPropagator().inject(carrier)
25+
return carrier.get("traceparent")
26+
1927
router = APIRouter(prefix="/agent", tags=["agent"])
2028

2129

@@ -126,9 +134,12 @@ async def run_agent(request: RunRequest) -> RunResponse:
126134
# Production: use Modal
127135
import modal
128136

137+
traceparent = get_traceparent()
129138
run_fn = modal.Function.from_name("policyengine-sandbox", "run_agent")
130139
history_dicts = [{"role": m.role, "content": m.content} for m in request.history]
131-
call = run_fn.spawn(request.question, api_base_url, call_id, history_dicts)
140+
call = run_fn.spawn(
141+
request.question, api_base_url, call_id, history_dicts, traceparent=traceparent
142+
)
132143

133144
_calls[call_id] = {
134145
"call": call,
@@ -137,6 +148,7 @@ async def run_agent(request: RunRequest) -> RunResponse:
137148
"started_at": datetime.utcnow().isoformat(),
138149
"status": "running",
139150
"result": None,
151+
"trace_id": traceparent, # Store for linking
140152
}
141153
logfire.info("agent_spawned", call_id=call_id, modal_call_id=call.object_id)
142154
else:

src/policyengine_api/api/analysis.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import logfire
2323
from fastapi import APIRouter, Depends, HTTPException
24+
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
2425
from pydantic import BaseModel, Field
2526
from sqlmodel import Session, select
2627

@@ -40,6 +41,13 @@
4041
from policyengine_api.services.database import get_session
4142

4243

44+
def get_traceparent() -> str | None:
45+
"""Get the current W3C traceparent header for distributed tracing."""
46+
carrier: dict[str, str] = {}
47+
TraceContextTextMapPropagator().inject(carrier)
48+
return carrier.get("traceparent")
49+
50+
4351
def _safe_float(value: float | None) -> float | None:
4452
"""Convert NaN/inf to None for JSON serialization."""
4553
if value is None:
@@ -522,6 +530,8 @@ def _trigger_economy_comparison(
522530
"""Trigger economy comparison analysis (local or Modal)."""
523531
from policyengine_api.config import settings
524532

533+
traceparent = get_traceparent()
534+
525535
if not settings.agent_use_modal and session is not None:
526536
# Run locally
527537
if tax_benefit_model_name == "policyengine_uk":
@@ -531,7 +541,7 @@ def _trigger_economy_comparison(
531541
import modal
532542

533543
fn = modal.Function.from_name("policyengine", "economy_comparison_us")
534-
fn.spawn(job_id=job_id)
544+
fn.spawn(job_id=job_id, traceparent=traceparent)
535545
else:
536546
# Use Modal
537547
import modal
@@ -541,7 +551,7 @@ def _trigger_economy_comparison(
541551
else:
542552
fn = modal.Function.from_name("policyengine", "economy_comparison_us")
543553

544-
fn.spawn(job_id=job_id)
554+
fn.spawn(job_id=job_id, traceparent=traceparent)
545555

546556

547557
@router.post("/economic-impact", response_model=EconomicImpactResponse)

src/policyengine_api/api/household.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import logfire
1111
from fastapi import APIRouter, Depends, HTTPException
12+
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
1213
from pydantic import BaseModel, Field
1314
from sqlmodel import Session
1415

@@ -20,6 +21,13 @@
2021
)
2122
from policyengine_api.services.database import get_session
2223

24+
25+
def get_traceparent() -> str | None:
26+
"""Get the current W3C traceparent header for distributed tracing."""
27+
carrier: dict[str, str] = {}
28+
TraceContextTextMapPropagator().inject(carrier)
29+
return carrier.get("traceparent")
30+
2331
router = APIRouter(prefix="/household", tags=["household"])
2432

2533

@@ -400,6 +408,8 @@ def _trigger_modal_household(
400408
# Use Modal
401409
import modal
402410

411+
traceparent = get_traceparent()
412+
403413
if request.tax_benefit_model_name == "policyengine_uk":
404414
fn = modal.Function.from_name("policyengine", "simulate_household_uk")
405415
fn.spawn(
@@ -410,6 +420,7 @@ def _trigger_modal_household(
410420
year=request.year or 2026,
411421
policy_data=policy_data,
412422
dynamic_data=dynamic_data,
423+
traceparent=traceparent,
413424
)
414425
else:
415426
fn = modal.Function.from_name("policyengine", "simulate_household_us")
@@ -424,6 +435,7 @@ def _trigger_modal_household(
424435
year=request.year or 2024,
425436
policy_data=policy_data,
426437
dynamic_data=dynamic_data,
438+
traceparent=traceparent,
427439
)
428440

429441

0 commit comments

Comments
 (0)