-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathtracing_utils.py
More file actions
259 lines (218 loc) · 11 KB
/
Copy pathtracing_utils.py
File metadata and controls
259 lines (218 loc) · 11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
"""
Shared utilities for rollout processors.
"""
import base64
import os
from typing import Any, Callable, Dict, List, Optional
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
from eval_protocol.models import EvaluationRow, Status
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig, RolloutMetadata, InitRequest
from eval_protocol.pytest.types import RolloutProcessorConfig
def default_fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
    """Default output data loader that fetches traces from Fireworks tracing proxy.

    The returned loader does the following:

    * Fetches rows tagged ``rollout_id:<config.rollout_id>`` from
      ``config.model_base_url``, falling back to ``https://tracing.fireworks.ai``.
      It uses up to 5 retries.
    * Authenticates with ``EP_REMOTE_API_KEY`` when set, otherwise
      ``FIREWORKS_API_KEY``.
    * Post-processes the fetched rows with ``filter_longest_conversation``.
    * When ``config.include_payloads`` is truthy and at least one row survives,
      merges per-turn payload metadata from *all* fetched rows into the first
      surviving row.
    """

    def fetch_traces() -> List[EvaluationRow]:
        # EP_REMOTE_API_KEY wins; FIREWORKS_API_KEY is the legacy fallback.
        adapter = FireworksTracingAdapter(
            base_url=config.model_base_url or "https://tracing.fireworks.ai",
            api_key=os.environ.get("EP_REMOTE_API_KEY") or os.environ.get("FIREWORKS_API_KEY"),
        )
        return adapter.get_evaluation_rows(
            tags=[f"rollout_id:{config.rollout_id}"],
            max_retries=5,
            include_payloads=config.include_payloads,
        )

    def preprocess_traces(rows: List[EvaluationRow]) -> List[EvaluationRow]:
        longest = filter_longest_conversation(rows)
        if not (config.include_payloads and longest):
            return longest
        # The unfiltered `rows` are passed so earlier turns' payloads are kept.
        _merge_payloads_into_longest_row(longest[0], rows)
        return longest

    return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=preprocess_traces)
def _merge_payloads_into_longest_row(longest_row: EvaluationRow, rows: List[EvaluationRow]) -> None:
    """
    Fold per-turn payload metadata from every trace row into ``longest_row`` in place.

    Each row in ``rows`` is attributed to the assistant turn at index
    ``len(row.get_assistant_messages()) - 1``. Rows are visited in ascending
    order of message count.

    For every row whose turn index is a valid index into ``longest_row``'s
    assistant messages:

    * if the row's last assistant message has logprobs and the matching
      assistant message in ``longest_row`` has none, they are copied over;
    * any payload keys present in the row's ``execution_metadata.extra`` are
      collected, tagged with ``assistant_turn_index``.

    The collected payloads are stored under
    ``longest_row.execution_metadata.extra["assistant_turn_payloads"]``. The
    row's existing top-level extra keys are left untouched for backward
    compatibility. Nothing is written when no payloads were found.
    """
    payload_keys = (
        "completion_logprobs",
        "completion_token_ids",
        "logprobs_metadata",
        "routing_matrices",
        "routing_metadata",
    )
    targets = longest_row.get_assistant_messages()
    turn_count = len(targets)
    collected = []

    for candidate in sorted(rows, key=lambda r: len(r.messages)):
        last_msg = candidate.last_assistant_message()
        turn_idx = len(candidate.get_assistant_messages()) - 1
        # Skip rows without assistant turns or with more turns than longest_row.
        if not 0 <= turn_idx < turn_count:
            continue

        target_msg = targets[turn_idx]
        if last_msg and last_msg.logprobs and not target_msg.logprobs:
            target_msg.logprobs = last_msg.logprobs

        candidate_extra = candidate.execution_metadata.extra or {}
        payload = {k: candidate_extra[k] for k in payload_keys if k in candidate_extra}
        if not payload:
            continue
        payload["assistant_turn_index"] = turn_idx
        collected.append(payload)

    if not collected:
        return

    meta = longest_row.execution_metadata
    if meta.extra is None:
        meta.extra = {}
    meta.extra["assistant_turn_payloads"] = collected
def build_fireworks_tracing_url(
    base_url: str, metadata: "RolloutMetadata", completion_params_base_url: Optional[str] = None
) -> str:
    """Build a Fireworks tracing URL by appending rollout metadata to the base URL path,
    allowing the Fireworks tracing proxy to automatically tag traces.

    Format: {base_url}/rollout_id/{id}/invocation_id/{id}/experiment_id/{id}/run_id/{id}/row_id/{id}
    When ``completion_params_base_url`` is non-empty, ``/encoded_base_url/{b64}`` is
    appended, where ``{b64}`` is its URL-safe base64 encoding.

    Args:
        base_url: Fireworks tracing proxy URL (e.g., https://tracing.fireworks.ai).
            Trailing slashes are ignored.
        metadata: Rollout metadata containing IDs to embed in the URL.
        completion_params_base_url: Optional LLM base URL to encode and append to the final URL.

    Returns:
        The tracing URL string.
    """
    # Without this, "https://host/" would yield "https://host//rollout_id/...".
    root = base_url.rstrip("/")
    url = (
        f"{root}/rollout_id/{metadata.rollout_id}"
        f"/invocation_id/{metadata.invocation_id}"
        f"/experiment_id/{metadata.experiment_id}"
        f"/run_id/{metadata.run_id}"
        f"/row_id/{metadata.row_id}"
    )
    if completion_params_base_url:
        # URL-safe alphabet so the encoded value can live inside a path segment.
        encoded_base_url = base64.urlsafe_b64encode(completion_params_base_url.encode()).decode()
        url = f"{url}/encoded_base_url/{encoded_base_url}"
    return url
# Message keys forwarded to the remote rollout server; any other key is dropped.
# A tuple (not a set) so the fallback dict below keeps a stable key order.
_ALLOWED_MESSAGE_FIELDS = ("role", "content", "tool_calls", "tool_call_id", "name")


def _message_to_request_dict(m: Any) -> Dict[str, Any]:
    """Serialize one message to a dict of allowed fields whose values are not None."""
    md: Dict[str, Any]
    if hasattr(m, "dump_mdoel_for_chat_completion_request"):
        # Message-model serializer that the original comment says already excludes
        # unsupported fields (weight, control_plane_step, reasoning_content).
        # NOTE(review): the misspelled name is the one this code has always probed
        # for; confirm it matches the method on Message before renaming either side.
        md = m.dump_mdoel_for_chat_completion_request()
    elif hasattr(m, "model_dump"):
        md = m.model_dump()
    elif isinstance(m, dict):
        md = m
    else:
        # Fallback to constructing a dict from a Message-like object.
        md = {field: getattr(m, field, None) for field in _ALLOWED_MESSAGE_FIELDS}
    # Final whitelist pass (a no-op for the filtering serializer above).
    return {k: v for k, v in md.items() if k in _ALLOWED_MESSAGE_FIELDS and v is not None}


def _is_tracing_base_url(model_base_url: Optional[str]) -> bool:
    """Return True if rollout metadata should be embedded in ``model_base_url``'s path."""
    if not model_base_url:
        return False
    return (
        "tracing.fireworks.ai" in model_base_url
        or model_base_url.startswith("http://localhost")
        or "litellm-gateway" in model_base_url
    )


def build_init_request(
    row: EvaluationRow,
    config: RolloutProcessorConfig,
    model_base_url: str,
) -> InitRequest:
    """Build an InitRequest from an EvaluationRow and config (shared logic).

    Completion params are the config-level ``completion_params`` overlaid with the
    row's ``input_metadata.completion_params``; row values win on conflicts.
    Messages are reduced to role/content/tool_calls/tool_call_id/name, dropping
    None values.

    If ``model_base_url`` points at the Fireworks tracing proxy, localhost, or a
    litellm gateway, rollout IDs (and an encoded ``completion_params["base_url"]``,
    if any) are appended to its path. The API key is read from
    ``EP_REMOTE_API_KEY``, falling back to ``FIREWORKS_API_KEY``.

    Raises:
        ValueError: if any of invocation/experiment/rollout/run/row ID is missing,
            or if the merged completion params have no ``model``.
    """
    # Validation (checked in this order so the first missing ID is reported).
    if row.execution_metadata.invocation_id is None:
        raise ValueError("Invocation ID is required")
    if row.execution_metadata.experiment_id is None:
        raise ValueError("Experiment ID is required")
    if row.execution_metadata.rollout_id is None:
        raise ValueError("Rollout ID is required")
    if row.execution_metadata.run_id is None:
        raise ValueError("Run ID is required")
    if row.input_metadata.row_id is None:
        raise ValueError("Row ID is required")

    meta = RolloutMetadata(
        invocation_id=row.execution_metadata.invocation_id,
        experiment_id=row.execution_metadata.experiment_id,
        rollout_id=row.execution_metadata.rollout_id,
        run_id=row.execution_metadata.run_id,
        row_id=row.input_metadata.row_id,
    )

    # Config-level completion_params first, then row-specific overrides.
    completion_params_dict: Dict[str, Any] = {}
    if config.completion_params and isinstance(config.completion_params, dict):
        completion_params_dict.update(config.completion_params)
    if row.input_metadata and row.input_metadata.completion_params:
        row_cp = row.input_metadata.completion_params
        if isinstance(row_cp, dict):
            completion_params_dict.update(row_cp)

    if not completion_params_dict.get("model"):
        raise ValueError("Model must be provided in completion_params")

    # Forwarded (base64-encoded) to the tracing gateway so it knows the real LLM endpoint.
    completion_params_base_url: Optional[str] = completion_params_dict.get("base_url")

    clean_messages = [_message_to_request_dict(m) for m in row.messages]

    final_model_base_url = model_base_url
    if _is_tracing_base_url(model_base_url):
        final_model_base_url = build_fireworks_tracing_url(model_base_url, meta, completion_params_base_url)

    # EP_REMOTE_API_KEY takes precedence for remote rollout processors,
    # falling back to FIREWORKS_API_KEY for backwards compatibility.
    api_key = os.environ.get("EP_REMOTE_API_KEY") or os.environ.get("FIREWORKS_API_KEY")

    return InitRequest(
        completion_params=completion_params_dict,
        messages=clean_messages,
        tools=row.tools,
        metadata=meta,
        model_base_url=final_model_base_url,
        api_key=api_key,
    )
def update_row_with_remote_trace(
    row: EvaluationRow,
    output_data_loader: Callable[[DataLoaderConfig], DynamicDataLoader],
    model_base_url: str,
    include_payloads: bool = False,
) -> None:
    """Update row with remote trace data using output_data_loader (shared logic).

    The row is mutated in place and the function always returns None. Behavior
    by case:

    * **No rollout_id:** nothing happens.
    * **Zero remote rows:** ``rollout_status`` is set to a not-found error.
    * **One remote row with the same message count as ``row``:**
      ``rollout_status`` is set to an internal error.
    * **One remote row otherwise:** messages, tools and session_data are taken
      from the remote row. Missing dataset_info keys are filled in from it.
      execution_metadata is replaced by a deep copy of the remote one, with the
      local ``extra`` merged on top (local keys win).

    Raises:
        ValueError: if the loader yields more than one row.
    """
    rollout_id = row.execution_metadata.rollout_id
    if not rollout_id:
        return None

    loader = output_data_loader(
        DataLoaderConfig(
            rollout_id=rollout_id,
            model_base_url=model_base_url,
            include_payloads=include_payloads,
        )
    )
    output_rows: List[EvaluationRow] = []
    for result in loader.load():
        output_rows.extend(result.rows)

    if not output_rows:
        # Keep the original row content but flag that no trace was recorded.
        row.rollout_status = Status.rollout_not_found_error("No remote data found for rollout")
        return None
    if len(output_rows) > 1:
        raise ValueError("Output data loader should return exactly one row.")

    remote_row = output_rows[0]
    # An unchanged message count means the rollout added nothing to the conversation.
    if len(remote_row.messages) == len(row.messages):
        row.rollout_status = Status.rollout_internal_error(
            "Rollout finished with the same number of messages as the original row"
        )
        return None

    row.messages = remote_row.messages
    row.tools = remote_row.tools
    row.input_metadata.session_data = remote_row.input_metadata.session_data

    remote_info = remote_row.input_metadata.dataset_info or {}
    if row.input_metadata.dataset_info is None:
        row.input_metadata.dataset_info = dict(remote_info)
    else:
        # Only fill gaps; existing local dataset_info keys are kept.
        local_info = row.input_metadata.dataset_info
        for key, value in remote_info.items():
            local_info.setdefault(key, value)

    local_extra = row.execution_metadata.extra
    row.execution_metadata = remote_row.execution_metadata.model_copy(deep=True)
    if not local_extra:
        return None

    remote_extra = row.execution_metadata.extra
    if remote_extra:
        # Merge remote and local extras; local takes precedence on conflicts.
        remote_extra.update(local_extra)
        row.execution_metadata.extra = remote_extra
    else:
        row.execution_metadata.extra = local_extra
    return None