Skip to content

Commit aa476f0

Browse files
committed
fix(tracing): migrate RemoteRolloutProcessor to /traces endpoint
The `/v1/traces/pointwise` endpoint was removed upstream in favor of a unified `/v1/traces` endpoint. This patches `FireworksTracingAdapter` and its default data loader to match the new contract.

Wire changes:

- URL: `/v1/traces/pointwise` → `/v1/traces` (both the default path and the `project_id`-scoped variant).
- Query params: the old endpoint accepted `tags=rollout_id:<id>` as the only way to scope a request; the new one expects `rollout_id` as a top-level query parameter. `get_evaluation_rows` now extracts the rollout id from the `tags` kwarg so existing callers don't break, and raises `ValueError` if no `rollout_id:<id>` tag is supplied.
- Response shape: the new endpoint returns flat row dicts with PascalCase keys (`Input`, `Output`, `Tags`, `InsertionId`) instead of the old nested snake_case shape with an `observations[]` array. The converter now reads the new keys and drops the "fall back to last GENERATION observation" branch, which has no equivalent server-side concept anymore.
- `session_data["langfuse_trace_id"]` is now sourced from `InsertionId` so downstream consumers that key on that field keep working.

The default data loader in `tracing_utils.py` now asks for `limit=1`, since `update_row_with_remote_trace` only consumes a single row and raises on multi-row responses. The previous `max_retries=5` argument is dropped: it was a no-op knob for the old Langfuse-polling path, and the new endpoint doesn't expose it.

Made-with: Cursor
1 parent 0655f89 commit aa476f0

2 files changed

Lines changed: 40 additions & 35 deletions

File tree

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -65,19 +65,22 @@ def convert_trace_dict_to_evaluation_row(
6565
# Extract messages from trace input and output
6666
messages = extract_messages_from_trace_dict(trace, include_tool_calls, span_name)
6767

68-
# Extract tools if available
68+
# Extract tools if available. `Input` carries the request payload,
69+
# which optionally includes a `tools` array when tool-calling is used.
6970
tools = None
70-
if include_tool_calls and isinstance(trace.get("input"), dict) and "tools" in trace["input"]:
71-
tools = trace["input"]["tools"]
71+
trace_input = trace.get("Input")
72+
if include_tool_calls and isinstance(trace_input, dict) and "tools" in trace_input:
73+
tools = trace_input["tools"]
7274

7375
if not messages:
7476
return None
7577

7678
execution_metadata = ExecutionMetadata()
7779
row_id = None
7880

79-
# Extract metadata from tags
80-
tags = trace.get("tags", [])
81+
# Extract metadata from tags. `Tags` may be absent or null on a row
82+
# that was written without any, so coalesce to an empty list.
83+
tags = trace.get("Tags") or []
8184
if tags:
8285
for tag in tags:
8386
if tag.startswith("invocation_id:"):
@@ -106,14 +109,16 @@ def convert_trace_dict_to_evaluation_row(
106109
input_metadata=InputMetadata(
107110
row_id=row_id,
108111
session_data={
109-
"langfuse_trace_id": trace.get("id"), # Store the trace ID here
112+
# Historical key name kept for downstream compatibility;
113+
# sourced from the per-LLM-call identifier on the trace.
114+
"langfuse_trace_id": trace.get("InsertionId"),
110115
},
111116
),
112117
execution_metadata=execution_metadata,
113118
)
114119

115120
except (AttributeError, ValueError, KeyError) as e:
116-
logger.error("Error converting trace %s: %s", trace.get("id"), e)
121+
logger.error("Error converting trace %s: %s", trace.get("InsertionId"), e)
117122
return None
118123

119124

@@ -153,28 +158,15 @@ def extract_messages_from_trace_dict(
153158

154159
else:
155160
try:
156-
# Extract messages from trace input and output
157-
if trace.get("input"):
158-
messages.extend(extract_messages_from_data(trace["input"], include_tool_calls))
159-
if trace.get("output"):
160-
messages.extend(extract_messages_from_data(trace["output"], include_tool_calls))
161+
# `Input` carries the request messages; `Output` carries the
162+
# assistant message returned for this call. `extract_messages_from_data`
163+
# accepts both `{"messages": [...]}` and single message dicts.
164+
if trace.get("Input"):
165+
messages.extend(extract_messages_from_data(trace["Input"], include_tool_calls))
166+
if trace.get("Output"):
167+
messages.extend(extract_messages_from_data(trace["Output"], include_tool_calls))
161168
except (AttributeError, ValueError, KeyError) as e:
162-
logger.warning("Error processing trace %s: %s", trace.get("id"), e)
163-
164-
# Fallback: use the last GENERATION observation which typically contains full chat history
165-
if not messages:
166-
try:
167-
all_observations = trace.get("observations", [])
168-
gens = [obs for obs in all_observations if obs.get("type") == "GENERATION"]
169-
if gens:
170-
gens.sort(key=lambda x: x.get("start_time", ""))
171-
last_gen = gens[-1]
172-
if last_gen.get("input"):
173-
messages.extend(extract_messages_from_data(last_gen["input"], include_tool_calls))
174-
if last_gen.get("output"):
175-
messages.extend(extract_messages_from_data(last_gen["output"], include_tool_calls))
176-
except Exception as e:
177-
logger.warning("Failed to extract from last generation for trace %s: %s", trace.get("id"), e)
169+
logger.warning("Error processing trace %s: %s", trace.get("InsertionId"), e)
178170

179171
return messages
180172

@@ -429,13 +421,21 @@ def get_evaluation_rows(
429421
if not tags or len(tags) == 0:
430422
raise ValueError("At least one tag is required to fetch traces")
431423

424+
# Pull out rollout_id only, since that is the task-level id needed to fetch traces.
425+
rollout_id = next(
426+
(t.split(":", 1)[1] for t in tags if t.startswith("rollout_id:")),
427+
None,
428+
)
429+
if not rollout_id:
430+
raise ValueError("tags must contain a 'rollout_id:<id>' entry")
431+
432432
eval_rows = []
433433

434434
# Build query parameters for GET request
435435
params = {
436+
"rollout_id": rollout_id,
436437
"limit": limit,
437438
"sample_size": sample_size,
438-
"tags": tags,
439439
"user_id": user_id,
440440
"session_id": session_id,
441441
"name": name,
@@ -453,11 +453,11 @@ def get_evaluation_rows(
453453
# Remove None values
454454
params = {k: v for k, v in params.items() if v is not None}
455455

456-
# Make request to proxy (using pointwise for efficiency)
456+
# Make request to proxy
457457
if self.project_id:
458-
url = f"{self.base_url}/v1/project_id/{self.project_id}/traces/pointwise"
458+
url = f"{self.base_url}/v1/project_id/{self.project_id}/traces"
459459
else:
460-
url = f"{self.base_url}/v1/traces/pointwise"
460+
url = f"{self.base_url}/v1/traces"
461461

462462
headers = {
463463
"Authorization": f"Bearer {self._get_api_key()}",
@@ -500,7 +500,7 @@ def get_evaluation_rows(
500500
if eval_row:
501501
eval_rows.append(eval_row)
502502
except (AttributeError, ValueError, KeyError) as e:
503-
logger.warning("Failed to convert trace %s: %s", trace.get("id"), e)
503+
logger.warning("Failed to convert trace %s: %s", trace.get("InsertionId"), e)
504504
continue
505505

506506
logger.info("Successfully converted %d traces to evaluation rows", len(eval_rows))

eval_protocol/pytest/tracing_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,19 @@
1515

1616

1717
def default_fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
18-
"""Default output data loader that fetches traces from Fireworks tracing proxy."""
18+
"""Default output data loader that fetches traces from Fireworks tracing proxy.
19+
20+
Requests a single trace per rollout — `update_row_with_remote_trace` in
21+
this module only consumes one row and raises if more come back, so
22+
pulling the full list would just waste bytes on the wire.
23+
"""
1924

2025
def fetch_traces() -> List[EvaluationRow]:
2126
base_url = config.model_base_url or "https://tracing.fireworks.ai"
2227
# Use EP_REMOTE_API_KEY for fetching remote traces, falling back to FIREWORKS_API_KEY
2328
api_key = os.environ.get("EP_REMOTE_API_KEY") or os.environ.get("FIREWORKS_API_KEY")
2429
adapter = FireworksTracingAdapter(base_url=base_url, api_key=api_key)
25-
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)
30+
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], limit=1)
2631

2732
return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation)
2833

0 commit comments

Comments
 (0)