11from __future__ import annotations
22from fastapi import FastAPI , HTTPException
33from fastapi .middleware .cors import CORSMiddleware
4- from fastapi .encoders import jsonable_encoder
54import os , json , re
65from threading import Lock
76from pydantic import BaseModel
87from .costs import estimate_prediction_cost
98from importlib .metadata import version as _pkg_version , PackageNotFoundError
109from .config import configure_lm
1110from .agent import MicroAgent
12- from .runtime import dump_trace , new_trace_id
11+ from .runtime import dump_trace , new_trace_id , to_jsonable
1312from .logging_setup import setup_logging
1413
1514app = FastAPI (title = "DSPy Micro Agent" )
15+ origins_env = os .getenv ("MICRO_AGENT_CORS_ORIGINS" , "*" ).strip ()
16+ if origins_env == "*" :
17+ allow_origins = ["*" ]
18+ allow_credentials = False
19+ else :
20+ allow_origins = [o .strip () for o in origins_env .split ("," ) if o .strip ()]
21+ allow_credentials = os .getenv ("MICRO_AGENT_CORS_CREDENTIALS" , "0" ).strip ().lower () in {"1" , "true" , "yes" , "on" }
1622app .add_middleware (
1723 CORSMiddleware ,
18- allow_origins = [ "*" ] ,
19- allow_credentials = True ,
24+ allow_origins = allow_origins ,
25+ allow_credentials = allow_credentials ,
2026 allow_methods = ["*" ],
2127 allow_headers = ["*" ],
2228)
@@ -63,7 +69,9 @@ def ask(req: AskRequest):
6369 raise HTTPException (status_code = 400 , detail = "max_steps must be between 1 and 20" )
6470
6571 def _call_agent ():
66- agent = _agent if req .use_tool_calls is None and req .max_steps == _agent .max_steps else MicroAgent (max_steps = req .max_steps , use_tool_calls = req .use_tool_calls )
72+ if _serialize and req .use_tool_calls is None and req .max_steps == _agent .max_steps :
73+ return _agent (question )
74+ agent = MicroAgent (max_steps = req .max_steps , use_tool_calls = req .use_tool_calls , use_global_trace = _serialize )
6775 return agent (question )
6876
6977 if _serialize :
@@ -76,7 +84,7 @@ def _call_agent():
7684 usage = getattr (pred , "usage" , {}) or {}
7785 est = estimate_prediction_cost (question , pred .trace , pred .answer , usage )
7886 path = dump_trace (trace_id , question , pred .trace , pred .answer , usage = usage , cost_usd = est .get ("cost_usd" ))
79- steps = jsonable_encoder (pred .trace )
87+ steps = to_jsonable (pred .trace )
8088 return AskResponse (answer = pred .answer , trace_id = trace_id , trace_path = path , steps = steps , usage = usage , cost_usd = est .get ("cost_usd" ))
8189
8290@app .get ("/trace/{trace_id}" )
0 commit comments