Skip to content

Commit a692e20

Browse files
committed
fix(tracing): fail loud on span_name in Fireworks adapter
The /v1/traces endpoint returns flat PascalCase rows (one LLM call each) — there is no nested span/generation structure to walk. The leftover snake_case observations branch in extract_messages_from_trace_dict and the entire get_final_generation_in_span_dict helper would always return no messages, so any caller passing span_name silently got an empty EvaluationRow (returning None). Raise NotImplementedError instead: - get_evaluation_rows rejects span_name upfront unless a custom converter is supplied (skips the pointless HTTP round-trip) - convert_trace_dict_to_evaluation_row and extract_messages_from_trace_dict also raise, for defense in depth - delete the dead get_final_generation_in_span_dict helper Made-with: Cursor
1 parent 893c551 commit a692e20

1 file changed

Lines changed: 52 additions & 82 deletions

File tree

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 52 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,24 @@ def convert_trace_dict_to_evaluation_row(
5656
Args:
5757
trace: Trace dictionary from Fireworks proxy API
5858
include_tool_calls: Whether to include tool calling information
59-
span_name: If provided, extract messages from generations within this named span
59+
span_name: Not supported by this converter. Each row returned by the
60+
Fireworks tracing endpoint is a single LLM call, so there is no
61+
nested-span structure to walk. Pass a custom ``TraceDictConverter``
62+
via ``get_evaluation_rows(..., converter=...)`` if you need
63+
span-specific extraction logic.
6064
6165
Returns:
6266
EvaluationRow or None if conversion fails
67+
68+
Raises:
69+
NotImplementedError: If ``span_name`` is provided.
6370
"""
71+
if span_name:
72+
raise NotImplementedError(
73+
"span_name is not supported by the default Fireworks tracing converter. "
74+
"Each trace row is already a single LLM call; provide a custom "
75+
"TraceDictConverter to get_evaluation_rows() for span-aware logic."
76+
)
6477
try:
6578
# Extract messages from trace input and output
6679
messages = extract_messages_from_trace_dict(trace, include_tool_calls, span_name)
@@ -125,97 +138,44 @@ def convert_trace_dict_to_evaluation_row(
125138
def extract_messages_from_trace_dict(
126139
trace: Dict[str, Any], include_tool_calls: bool = True, span_name: Optional[str] = None
127140
) -> List[Message]:
128-
"""Extract messages from trace dictionary.
141+
"""Extract messages from a Fireworks trace row.
142+
143+
The Fireworks tracing endpoint returns one row per LLM call with the
144+
request on ``Input`` and the response on ``Output``. There is no nested
145+
span/generation structure at the row level.
129146
130147
Args:
131148
trace: Trace dictionary from proxy API
132149
include_tool_calls: Whether to include tool calling information
133-
span_name: If provided, extract messages from generations within this named span
150+
span_name: Not supported. Pass a custom ``TraceDictConverter`` to
151+
``FireworksTracingAdapter.get_evaluation_rows`` if you need
152+
span-specific extraction.
134153
135154
Returns:
136155
List of Message objects
137-
"""
138-
messages = []
139-
140-
if span_name: # Look for a generation tied to a span name
141-
try:
142-
# Find the final generation in the named span
143-
gen = get_final_generation_in_span_dict(trace, span_name)
144-
if not gen:
145-
return messages
146-
147-
# Extract messages from generation input and output
148-
if gen.get("input"):
149-
messages.extend(extract_messages_from_data(gen["input"], include_tool_calls))
150-
if gen.get("output"):
151-
messages.extend(extract_messages_from_data(gen["output"], include_tool_calls))
152-
153-
return messages
154156
155-
except Exception as e:
156-
logger.error("Failed to extract messages from span '%s' in trace %s: %s", span_name, trace.get("id"), e)
157-
return messages
158-
159-
else:
160-
try:
161-
# `Input` carries the request messages; `Output` carries the
162-
# assistant message returned for this call. `extract_messages_from_data`
163-
# accepts both `{"messages": [...]}` and single message dicts.
164-
if trace.get("Input"):
165-
messages.extend(extract_messages_from_data(trace["Input"], include_tool_calls))
166-
if trace.get("Output"):
167-
messages.extend(extract_messages_from_data(trace["Output"], include_tool_calls))
168-
except (AttributeError, ValueError, KeyError) as e:
169-
logger.warning("Error processing trace %s: %s", trace.get("InsertionId"), e)
170-
171-
return messages
172-
173-
174-
def get_final_generation_in_span_dict(trace: Dict[str, Any], span_name: str) -> Optional[Dict[str, Any]]:
175-
"""Get the final generation within a named span from trace dictionary.
176-
177-
Args:
178-
trace: Trace dictionary
179-
span_name: Name of the span to search for
180-
181-
Returns:
182-
The final generation dictionary, or None if not found
157+
Raises:
158+
NotImplementedError: If ``span_name`` is provided.
183159
"""
184-
# Get all observations from the trace
185-
all_observations = trace.get("observations", [])
186-
187-
# Find a span with the given name that has generation children
188-
parent_span = None
189-
for obs in all_observations:
190-
if obs.get("name") == span_name and obs.get("type") == "SPAN":
191-
# Check if this span has generation children
192-
has_generations = any(
193-
child.get("type") == "GENERATION" and child.get("parent_observation_id") == obs.get("id")
194-
for child in all_observations
195-
)
196-
if has_generations:
197-
parent_span = obs
198-
break
199-
200-
if not parent_span:
201-
logger.warning("No span named '%s' found in trace %s", span_name, trace.get("id"))
202-
return None
203-
204-
# Find all generations within this span
205-
generations = []
206-
for obs in all_observations:
207-
if obs.get("type") == "GENERATION" and obs.get("parent_observation_id") == parent_span.get("id"):
208-
generations.append(obs)
209-
210-
if not generations:
211-
logger.warning("No generations found in span '%s' in trace %s", span_name, trace.get("id"))
212-
return None
160+
if span_name:
161+
raise NotImplementedError(
162+
"span_name is not supported by extract_messages_from_trace_dict for "
163+
"Fireworks traces; each row is already a single LLM call."
164+
)
213165

214-
# Sort generations by start time for chronological order
215-
generations.sort(key=lambda x: x.get("start_time", ""))
166+
messages: List[Message] = []
167+
try:
168+
# `Input` carries the request messages; `Output` carries the
169+
# assistant message returned for this call. `extract_messages_from_data`
170+
# accepts both `{"messages": [...]}` and single message dicts.
171+
if trace.get("Input"):
172+
messages.extend(extract_messages_from_data(trace["Input"], include_tool_calls))
173+
if trace.get("Output"):
174+
messages.extend(extract_messages_from_data(trace["Output"], include_tool_calls))
175+
except (AttributeError, ValueError, KeyError) as e:
176+
logger.warning("Error processing trace %s: %s", trace.get("InsertionId"), e)
216177

217-
# Return the final generation (contains full message history)
218-
return generations[-1]
178+
return messages
219179

220180

221181
class FireworksTracingAdapter(BaseAdapter):
@@ -427,19 +387,29 @@ def get_evaluation_rows(
427387
include_tool_calls: Whether to include tool calling traces
428388
sleep_between_gets: Sleep time between polling attempts (default: 2.5s)
429389
max_retries: Max retry attempts used by proxy (default: 3)
390+
span_name: Only supported when a custom ``converter`` is supplied.
391+
The default Fireworks converter does not walk nested spans
392+
(each trace row is already a single LLM call).
430393
converter: Optional custom converter implementing TraceDictConverter protocol.
431394
If provided, this will be used instead of the default conversion logic.
432395
433396
Returns:
434397
List[EvaluationRow]: Converted evaluation rows
435398
436399
Raises:
437-
ValueError: If tags list is empty
400+
ValueError: If tags list is empty or no ``rollout_id`` tag is present.
401+
NotImplementedError: If ``span_name`` is provided without a custom ``converter``.
438402
"""
439403
# Validate that tags are provided
440404
if not tags or len(tags) == 0:
441405
raise ValueError("At least one tag is required to fetch traces")
442406

407+
if span_name and converter is None:
408+
raise NotImplementedError(
409+
"span_name is not supported by the default Fireworks tracing converter. "
410+
"Pass a custom converter=TraceDictConverter(...) if you need span-aware logic."
411+
)
412+
443413
# Pull out rollout_id only, since that is the task-level id needed to fetch traces.
444414
rollout_id = next(
445415
(t.split(":", 1)[1] for t in tags if t.startswith("rollout_id:")),

0 commit comments

Comments
 (0)