Skip to content

Commit bce982d

Browse files
committed
commit for trail proxy setup
1 parent f162be3 commit bce982d

4 files changed

Lines changed: 166 additions & 33 deletions

File tree

eval_protocol/proxy/proxy_core/app.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,39 @@ async def require_auth(request: Request) -> None:
173173
# =====================
174174
# Chat completion routes
175175
# =====================
176+
177+
# ============ Trail Management System Routes (New) ============
178+
@app.post("/trails/{trail_id}/chat/completions")
@app.post("/v1/trails/{trail_id}/chat/completions")
@app.post("/trails/{trail_id}/v1/chat/completions")
@app.post("/project_id/{project_id}/trails/{trail_id}/chat/completions")
@app.post("/v1/project_id/{project_id}/trails/{trail_id}/chat/completions")
async def trail_chat_completion(
    trail_id: str,
    request: Request,
    project_id: Optional[str] = None,
    config: ProxyConfig = Depends(get_config),
    redis_client: redis.Redis = Depends(get_redis),
    _: None = Depends(require_auth),
):
    """
    Chat completion entry point for the Trail Management System.

    The trail_id (and project_id, when the route supplies one) taken from the
    URL path is packed into a ChatParams object and the request is delegated
    to handle_chat_completion, which tags the call with trail_id and a fresh
    insertion_id so every request made under one trail can be traced and
    queried together.

    Args:
        trail_id: Trail identifier from the URL path.
        request: Incoming HTTP request, forwarded unchanged.
        project_id: Optional project identifier; None on routes without a
            project_id path segment.
        config: Proxy configuration dependency.
        redis_client: Redis client dependency.
        _: Placeholder for the authentication dependency; its value is unused.

    Returns:
        Whatever handle_chat_completion returns for this request.
    """
    return await handle_chat_completion(
        config=config,
        redis_client=redis_client,
        request=request,
        # Only the trail-related path parameters are set; the remaining
        # ChatParams fields keep their defaults.
        params=ChatParams(project_id=project_id, trail_id=trail_id),
    )
207+
208+
# ============ Legacy Evaluation System Routes ============
176209
@app.post(
177210
"/project_id/{project_id}/rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}/chat/completions"
178211
)
@@ -246,6 +279,71 @@ async def chat_completion_with_project_only(
246279
# ===============
247280
# Traces routes
248281
# ===============
282+
283+
# ============ Trail Traces Routes (New) ============
284+
@app.get("/trails/{trail_id}/traces", response_model=LangfuseTracesResponse)
@app.get("/v1/trails/{trail_id}/traces", response_model=LangfuseTracesResponse)
@app.get("/project_id/{project_id}/trails/{trail_id}/traces", response_model=LangfuseTracesResponse)
@app.get("/v1/project_id/{project_id}/trails/{trail_id}/traces", response_model=LangfuseTracesResponse)
async def get_trail_traces(
    trail_id: str,
    request: Request,
    params: TracesParams = Depends(get_traces_params),
    project_id: Optional[str] = None,
    config: ProxyConfig = Depends(get_config),
    redis_client: redis.Redis = Depends(get_redis),
    _: None = Depends(require_auth),
) -> LangfuseTracesResponse:
    """
    Return the Langfuse traces recorded under a single trail.

    A "trail_id:<trail_id>" tag is added to the query tags and the lookup is
    delegated to fetch_langfuse_traces.

    Args:
        trail_id: Trail identifier from the URL path.
        request: Incoming HTTP request, forwarded unchanged.
        params: Query parameters resolved by get_traces_params; mutated in
            place with the path-derived project_id and trail tag.
        project_id: Optional project identifier from the URL path; when
            present it overrides params.project_id.
        config: Proxy configuration dependency.
        redis_client: Redis client dependency.
        _: Placeholder for the authentication dependency; its value is unused.

    Returns:
        The LangfuseTracesResponse produced by fetch_langfuse_traces.
    """
    # A project_id in the path takes precedence over the query-derived one.
    if project_id is not None:
        params.project_id = project_id

    # Scope the query to this trail by adding its tag to any caller tags.
    trail_tag = f"trail_id:{trail_id}"
    if params.tags is None:
        params.tags = [trail_tag]
    else:
        params.tags.append(trail_tag)

    return await fetch_langfuse_traces(
        config=config,
        redis_client=redis_client,
        request=request,
        params=params,
    )
314+
315+
@app.get("/trails/{trail_id}/traces/pointwise", response_model=LangfuseTracesResponse)
@app.get("/v1/trails/{trail_id}/traces/pointwise", response_model=LangfuseTracesResponse)
@app.get("/project_id/{project_id}/trails/{trail_id}/traces/pointwise", response_model=LangfuseTracesResponse)
@app.get("/v1/project_id/{project_id}/trails/{trail_id}/traces/pointwise", response_model=LangfuseTracesResponse)
async def get_trail_pointwise_trace(
    trail_id: str,
    request: Request,
    params: TracesParams = Depends(get_traces_params),
    project_id: Optional[str] = None,
    config: ProxyConfig = Depends(get_config),
    redis_client: redis.Redis = Depends(get_redis),
    _: None = Depends(require_auth),
) -> LangfuseTracesResponse:
    """
    Return only the most recent Langfuse trace for a trail.

    A "trail_id:<trail_id>" tag is added to the query tags and the lookup is
    delegated to pointwise_fetch_langfuse_trace, which targets the latest
    insertion_id (UUID7 values are time-ordered). Intended for real-time
    monitoring of a trail.

    Args:
        trail_id: Trail identifier from the URL path.
        request: Incoming HTTP request, forwarded unchanged.
        params: Query parameters resolved by get_traces_params; mutated in
            place with the path-derived project_id and trail tag.
        project_id: Optional project identifier from the URL path; when
            present it overrides params.project_id.
        config: Proxy configuration dependency.
        redis_client: Redis client dependency.
        _: Placeholder for the authentication dependency; its value is unused.

    Returns:
        The LangfuseTracesResponse produced by pointwise_fetch_langfuse_trace.
    """
    # A project_id in the path takes precedence over the query-derived one.
    if project_id is not None:
        params.project_id = project_id

    # Scope the query to this trail by adding its tag to any caller tags.
    trail_tag = f"trail_id:{trail_id}"
    if params.tags is None:
        params.tags = [trail_tag]
    else:
        params.tags.append(trail_tag)

    return await pointwise_fetch_langfuse_trace(
        config=config,
        redis_client=redis_client,
        request=request,
        params=params,
    )
345+
346+
# ============ Legacy Traces Routes ============
249347
@app.get("/traces", response_model=LangfuseTracesResponse)
250348
@app.get("/v1/traces", response_model=LangfuseTracesResponse)
251349
@app.get("/project_id/{project_id}/traces", response_model=LangfuseTracesResponse)

eval_protocol/proxy/proxy_core/langfuse.py

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,16 @@ async def _fetch_trace_list_with_retry(
7676
) -> Any:
7777
"""Fetch trace list with rate limit retry logic."""
7878
list_retries = 0
79-
rollout_id: Optional[str] = None
79+
tracking_key: Optional[str] = None # Could be rollout_id or trail_id
8080
if tags:
8181
for t in tags:
82-
if isinstance(t, str) and t.startswith("rollout_id:"):
83-
rollout_id = t.split(":", 1)[1] if ":" in t else t
84-
break
82+
if isinstance(t, str):
83+
if t.startswith("rollout_id:"):
84+
tracking_key = t.split(":", 1)[1] if ":" in t else t
85+
break
86+
elif t.startswith("trail_id:"):
87+
tracking_key = t.split(":", 1)[1] if ":" in t else t
88+
break
8589
while list_retries < max_retries:
8690
try:
8791
traces = langfuse_client.api.trace.list(
@@ -124,17 +128,17 @@ async def _fetch_trace_list_with_retry(
124128
# Return 404 if we've retried max_retries
125129
# TODO: write some tests around proxy exception handling
126130
logger.error(
127-
"Failed to fetch trace list after %d retries (rollout_id=%s): %s",
131+
"Failed to fetch trace list after %d retries (tracking_key=%s): %s",
128132
max_retries,
129-
rollout_id,
133+
tracking_key,
130134
e,
131135
)
132136
raise HTTPException(
133137
status_code=404, detail=f"Failed to fetch traces after {max_retries} retries: {str(e)}"
134138
)
135139
else:
136140
# Catch all other exceptions
137-
logger.error("Failed to fetch trace list (rollout_id=%s): %s", rollout_id, e)
141+
logger.error("Failed to fetch trace list (tracking_key=%s): %s", tracking_key, e)
138142
raise HTTPException(status_code=500, detail=f"Failed to fetch traces: {str(e)}")
139143

140144

@@ -247,16 +251,16 @@ async def fetch_langfuse_traces(
247251

248252
# Get expected insertion_ids from Redis for completeness checking
249253
expected_ids: Set[str] = set()
250-
if rollout_id:
251-
expected_ids = get_insertion_ids(redis_client, rollout_id)
252-
logger.info(f"Fetching traces for rollout_id '{rollout_id}', expecting {len(expected_ids)} insertion_ids")
254+
if tracking_key:
255+
expected_ids = get_insertion_ids(redis_client, tracking_key)
256+
logger.info(f"Fetching traces for {tracking_label} '{tracking_key}', expecting {len(expected_ids)} insertion_ids")
253257
if not expected_ids:
254258
logger.warning(
255-
f"No expected insertion_ids found in Redis for rollout '{rollout_id}'. Returning empty traces."
259+
f"No expected insertion_ids found in Redis for {tracking_label} '{tracking_key}'. Returning empty traces."
256260
)
257261
raise HTTPException(
258262
status_code=500,
259-
detail=f"No expected insertion_ids found in Redis for rollout '{rollout_id}'. Returning empty traces.",
263+
detail=f"No expected insertion_ids found in Redis for {tracking_label} '{tracking_key}'. Returning empty traces.",
260264
)
261265

262266
# Track all traces we've collected across retry attempts
@@ -265,15 +269,15 @@ async def fetch_langfuse_traces(
265269
insertion_ids: Set[str] = set() # Insertion IDs extracted from traces (for completeness check)
266270

267271
for retry in range(max_retries):
268-
# On first attempt, use rollout_id tag. On retries, target missing insertion_ids
272+
# On first attempt, use tracking tag. On retries, target missing insertion_ids
269273
if retry == 0:
270274
fetch_tags = tags
271275
else:
272276
# Build targeted tags for missing insertion_ids
273277
missing_ids = expected_ids - insertion_ids
274278
fetch_tags = [f"insertion_id:{id}" for id in missing_ids]
275279
logger.info(
276-
f"Retry {retry}: Targeting {len(fetch_tags)} missing insertion_ids for rollout '{rollout_id}' (last5): {[id[-5:] for id in sorted(missing_ids)[:10]]}{'...' if len(missing_ids) > 10 else ''}"
280+
f"Retry {retry}: Targeting {len(fetch_tags)} missing insertion_ids for {tracking_label} '{tracking_key}' (last5): {[id[-5:] for id in sorted(missing_ids)[:10]]}{'...' if len(missing_ids) > 10 else ''}"
277281
)
278282

279283
current_page = 1
@@ -329,7 +333,7 @@ async def fetch_langfuse_traces(
329333
insertion_id = _extract_tag_value(trace_dict.get("tags", []), "insertion_id:")
330334
if insertion_id:
331335
insertion_ids.add(insertion_id)
332-
logger.debug(f"Found insertion_id '{insertion_id}' for rollout '{rollout_id}'")
336+
logger.debug(f"Found insertion_id '{insertion_id}' for {tracking_label} '{tracking_key}'")
333337

334338
except Exception as e:
335339
logger.warning("Failed to serialize trace %s: %s", trace_info.id, e)
@@ -349,7 +353,7 @@ async def fetch_langfuse_traces(
349353
# If we have all expected completions or more, return traces. At least once is ok.
350354
if expected_ids <= insertion_ids:
351355
logger.info(
352-
f"Traces complete for rollout '{rollout_id}': {len(insertion_ids)}/{len(expected_ids)} insertion_ids found, returning {len(all_traces)} traces"
356+
f"Traces complete for {tracking_label} '{tracking_key}': {len(insertion_ids)}/{len(expected_ids)} insertion_ids found, returning {len(all_traces)} traces"
353357
)
354358
if sample_size is not None and len(all_traces) > sample_size:
355359
all_traces = random.sample(all_traces, sample_size)
@@ -366,16 +370,16 @@ async def fetch_langfuse_traces(
366370
wait_time = 2**retry
367371
still_missing = expected_ids - insertion_ids
368372
logger.info(
369-
f"Attempt {retry + 1}/{max_retries}. Found {len(insertion_ids)}/{len(expected_ids)} for rollout '{rollout_id}'. Still missing (last5): {[id[-5:] for id in sorted(still_missing)[:10]]}{'...' if len(still_missing) > 10 else ''}. Waiting {wait_time}s..."
373+
f"Attempt {retry + 1}/{max_retries}. Found {len(insertion_ids)}/{len(expected_ids)} for {tracking_label} '{tracking_key}'. Still missing (last5): {[id[-5:] for id in sorted(still_missing)[:10]]}{'...' if len(still_missing) > 10 else ''}. Waiting {wait_time}s..."
370374
)
371375
await asyncio.sleep(wait_time)
372376

373377
logger.error(
374-
f"Incomplete traces for rollout_id '{rollout_id}': Found {len(insertion_ids)}/{len(expected_ids)} completions."
378+
f"Incomplete traces for {tracking_label} '{tracking_key}': Found {len(insertion_ids)}/{len(expected_ids)} completions."
375379
)
376380
raise HTTPException(
377381
status_code=404,
378-
detail=f"Incomplete traces for rollout_id '{rollout_id}': Found {len(insertion_ids)}/{len(expected_ids)} completions.",
382+
detail=f"Incomplete traces for {tracking_label} '{tracking_key}': Found {len(insertion_ids)}/{len(expected_ids)} completions.",
379383
)
380384

381385
except ImportError:
@@ -431,8 +435,11 @@ async def pointwise_fetch_langfuse_trace(
431435
detail=f"Project ID '{project_id}' not found. Available projects: {list(config.langfuse_keys.keys())}",
432436
)
433437

434-
# Extract rollout_id from tags for Redis lookup
438+
# Extract tracking key (rollout_id or trail_id) from tags for Redis lookup
435439
rollout_id = _extract_tag_value(tags, "rollout_id:")
440+
trail_id = _extract_tag_value(tags, "trail_id:")
441+
tracking_key = trail_id if trail_id else rollout_id
442+
tracking_label = "trail_id" if trail_id else "rollout_id"
436443

437444
try:
438445
# Import the Langfuse adapter
@@ -461,23 +468,23 @@ async def pointwise_fetch_langfuse_trace(
461468

462469
# Get insertion_ids from Redis to find the latest one
463470
expected_ids: Set[str] = set()
464-
if rollout_id:
465-
expected_ids = get_insertion_ids(redis_client, rollout_id)
471+
if tracking_key:
472+
expected_ids = get_insertion_ids(redis_client, tracking_key)
466473
logger.info(
467-
f"Pointwise fetch for rollout_id '{rollout_id}', found {len(expected_ids)} insertion_ids in Redis"
474+
f"Pointwise fetch for {tracking_label} '{tracking_key}', found {len(expected_ids)} insertion_ids in Redis"
468475
)
469476
if not expected_ids:
470477
logger.warning(
471-
f"No insertion_ids found in Redis for rollout '{rollout_id}'. Cannot determine latest trace."
478+
f"No insertion_ids found in Redis for {tracking_label} '{tracking_key}'. Cannot determine latest trace."
472479
)
473480
raise HTTPException(
474481
status_code=500,
475-
detail=f"No insertion_ids found in Redis for rollout '{rollout_id}'. Cannot determine latest trace.",
482+
detail=f"No insertion_ids found in Redis for {tracking_label} '{tracking_key}'. Cannot determine latest trace.",
476483
)
477484

478485
# Get the latest (last) insertion_id since UUID v7 is time-ordered
479486
latest_insertion_id = max(expected_ids) # UUID v7 max = newest
480-
logger.info(f"Targeting latest insertion_id: {latest_insertion_id} for rollout '{rollout_id}'")
487+
logger.info(f"Targeting latest insertion_id: {latest_insertion_id} for {tracking_label} '{tracking_key}'")
481488

482489
for retry in range(max_retries):
483490
# Fetch trace list targeting the latest insertion_id
@@ -513,7 +520,7 @@ async def pointwise_fetch_langfuse_trace(
513520
if trace_full:
514521
trace_dict = _serialize_trace_to_dict(trace_full)
515522
logger.info(
516-
f"Successfully fetched latest trace for rollout '{rollout_id}', insertion_id: {latest_insertion_id}"
523+
f"Successfully fetched latest trace for {tracking_label} '{tracking_key}', insertion_id: {latest_insertion_id}"
517524
)
518525
return LangfuseTracesResponse(
519526
project_id=project_id,
@@ -525,17 +532,17 @@ async def pointwise_fetch_langfuse_trace(
525532
if retry < max_retries - 1:
526533
wait_time = 2**retry
527534
logger.info(
528-
f"Pointwise fetch attempt {retry + 1}/{max_retries} failed for rollout '{rollout_id}', insertion_id: {latest_insertion_id}. Retrying in {wait_time}s..."
535+
f"Pointwise fetch attempt {retry + 1}/{max_retries} failed for {tracking_label} '{tracking_key}', insertion_id: {latest_insertion_id}. Retrying in {wait_time}s..."
529536
)
530537
await asyncio.sleep(wait_time)
531538

532539
# After all retries failed
533540
logger.error(
534-
f"Failed to fetch latest trace for rollout '{rollout_id}', insertion_id: {latest_insertion_id} after {max_retries} retries"
541+
f"Failed to fetch latest trace for {tracking_label} '{tracking_key}', insertion_id: {latest_insertion_id} after {max_retries} retries"
535542
)
536543
raise HTTPException(
537544
status_code=404,
538-
detail=f"Failed to fetch latest trace for rollout '{rollout_id}' after {max_retries} retries",
545+
detail=f"Failed to fetch latest trace for {tracking_label} '{tracking_key}' after {max_retries} retries",
539546
)
540547

541548
except ImportError:

eval_protocol/proxy/proxy_core/litellm.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ async def handle_chat_completion(
3636
data, params = config.preprocess_chat_request(data, request, params)
3737

3838
project_id = params.project_id
39+
trail_id = params.trail_id
3940
rollout_id = params.rollout_id
4041
invocation_id = params.invocation_id
4142
experiment_id = params.experiment_id
@@ -70,8 +71,31 @@ async def handle_chat_completion(
7071

7172
# If metadata IDs are provided, add them as tags
7273
insertion_id = None
73-
if rollout_id is not None:
74+
tracking_key = None # Key for Redis tracking (trail_id or rollout_id)
75+
76+
if trail_id is not None:
77+
# Trail Management System: Simple tagging with just trail_id
78+
insertion_id = str(uuid7())
79+
tracking_key = trail_id
80+
81+
if "metadata" not in data:
82+
data["metadata"] = {}
83+
if "tags" not in data["metadata"]:
84+
data["metadata"]["tags"] = []
85+
86+
# Add trail metadata as tags
87+
data["metadata"]["tags"].extend(
88+
[
89+
f"trail_id:{trail_id}",
90+
f"insertion_id:{insertion_id}",
91+
]
92+
)
93+
logger.debug(f"Trail request: trail_id={trail_id}, insertion_id={insertion_id}")
94+
95+
elif rollout_id is not None:
96+
# Legacy evaluation system: Complex tagging with multiple IDs
7497
insertion_id = str(uuid7())
98+
tracking_key = rollout_id
7599

76100
if "metadata" not in data:
77101
data["metadata"] = {}
@@ -89,6 +113,7 @@ async def handle_chat_completion(
89113
f"row_id:{row_id}",
90114
]
91115
)
116+
logger.debug(f"Rollout request: rollout_id={rollout_id}, insertion_id={insertion_id}")
92117

93118
# Add Langfuse configuration
94119
data["langfuse_public_key"] = config.langfuse_keys[project_id]["public_key"]
@@ -115,8 +140,8 @@ async def handle_chat_completion(
115140
)
116141

117142
# Register insertion_id in Redis only on successful response
118-
if response.status_code == 200 and insertion_id is not None and rollout_id is not None:
119-
register_insertion_id(redis_client, rollout_id, insertion_id)
143+
if response.status_code == 200 and insertion_id is not None and tracking_key is not None:
144+
register_insertion_id(redis_client, tracking_key, insertion_id)
120145

121146
# Return the response
122147
return Response(

eval_protocol/proxy/proxy_core/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ class ChatParams(BaseModel):
2121
"""Typed container for chat completion URL path parameters."""
2222

2323
project_id: Optional[str] = None
24+
# Trail Management System (simpler path)
25+
trail_id: Optional[str] = None
26+
# Legacy evaluation system (complex path)
2427
rollout_id: Optional[str] = None
2528
invocation_id: Optional[str] = None
2629
experiment_id: Optional[str] = None

0 commit comments

Comments
 (0)