-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathtest_remote_fireworks.py
More file actions
185 lines (143 loc) · 6.68 KB
/
Copy pathtest_remote_fireworks.py
File metadata and controls
185 lines (143 loc) · 6.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# AUTO SERVER STARTUP: Server is automatically started and stopped by the test
import logging
import subprocess
import socket
import time
from typing import List
import pytest
import requests
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
import eval_protocol.pytest.remote_rollout_processor as remote_rollout_processor_module
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
# rollout_ids seen by the wrapped default_fireworks_output_data_loader.
# Cleared at the start of check_rollout_coverage, filled in by its wrapper,
# and checked in that fixture's teardown.
ROLLOUT_IDS = set()
class StatusLogCaptureHandler(logging.Handler):
    """Logging handler that keeps the "status code 100" Fireworks log lines.

    Every record whose formatted message contains both "Found Fireworks log"
    and "with status code 100" is appended, as a string, to
    ``status_100_messages``. All other records are ignored.
    """

    def __init__(self):
        super().__init__()
        # Formatted messages that matched both markers, in emission order.
        self.status_100_messages: List[str] = []

    def emit(self, record):
        # getMessage() merges the record's %-style args into the message.
        # ``record.message`` only exists after a Formatter has run, so don't rely on it.
        text = record.getMessage()
        markers = ("Found Fireworks log", "with status code 100")
        if all(marker in text for marker in markers):
            self.status_100_messages.append(text)
@pytest.fixture(autouse=True)
def check_rollout_coverage(monkeypatch):
    """
    Verify that every rollout triggered a remote-trace fetch and produced a status log.

    Setup:
      * Replaces ``remote_rollout_processor_module.default_fireworks_output_data_loader``
        with a thin wrapper. The wrapper records each ``config.rollout_id`` in the
        module-level ``ROLLOUT_IDS`` set, then delegates to the original loader.
        The loader itself is not made configurable.
      * Attaches a ``StatusLogCaptureHandler`` to the
        ``eval_protocol.pytest.remote_rollout_processor`` logger and sets that
        logger to INFO so the status messages reach the handler.

    Teardown (after the test):
      * Detaches the handler and restores the logger's previous level.
      * Asserts that exactly 3 distinct rollout_ids were seen.
      * Asserts that exactly 3 "Found Fireworks log ... with status code 100"
        messages were captured.
    """
    ROLLOUT_IDS.clear()

    real_loader = remote_rollout_processor_module.default_fireworks_output_data_loader

    def recording_loader(config: DataLoaderConfig) -> DynamicDataLoader:
        # Record which rollout asked for its traces, then defer to the real loader.
        ROLLOUT_IDS.add(config.rollout_id)
        return real_loader(config)

    monkeypatch.setattr(remote_rollout_processor_module, "default_fireworks_output_data_loader", recording_loader)

    capture = StatusLogCaptureHandler()
    capture.setLevel(logging.INFO)
    processor_logger = logging.getLogger("eval_protocol.pytest.remote_rollout_processor")
    processor_logger.addHandler(capture)
    previous_level = processor_logger.level
    # The logger's own level filters records before any handler sees them.
    processor_logger.setLevel(logging.INFO)

    yield

    processor_logger.removeHandler(capture)
    processor_logger.setLevel(previous_level)

    seen = capture.status_100_messages
    assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"
    assert len(seen) == 3, (
        f"Expected 3 'Found Fireworks log ... with status code 100' messages, but found {len(seen)}. "
        f"This means the status logs from the remote server were not received. "
        f"Messages captured: {seen}"
    )
def find_available_port() -> int:
    """Return a TCP port number that was free on this host at call time.

    Binding to port 0 lets the OS pick an unused port. The probe socket is
    closed before returning, so the port is only *likely* to still be free
    when the caller binds it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("", 0))
        _host, chosen = probe.getsockname()
    return chosen
# Port picked once at import time. It is used for the server's --port argument,
# the startup poll URL, and RemoteRolloutProcessor's remote_base_url.
# find_available_port() releases the port before the server binds it, so another
# process could grab it in between (an inherent, usually harmless race).
SERVER_PORT = find_available_port()
def wait_for_server_to_startup(timeout: int = 120, process=None, poll_interval: float = 1.0):
    """Block until the remote server on ``SERVER_PORT`` answers an HTTP request.

    Any HTTP response, whatever its status code, counts as "up". Only
    connection-level failures (``requests.exceptions.RequestException``) are retried.

    Args:
        timeout: Maximum number of seconds to wait before giving up.
        process: Optional ``subprocess.Popen`` running the server. When given,
            the wait stops immediately if that process has already exited,
            instead of polling until ``timeout`` expires.
        poll_interval: Seconds to sleep between attempts.

    Raises:
        TimeoutError: If the server does not respond within ``timeout`` seconds.
        RuntimeError: If ``process`` exits before the server responds.
    """
    url = f"http://127.0.0.1:{SERVER_PORT}"
    # A monotonic clock cannot jump with wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while True:
        if process is not None and process.poll() is not None:
            raise RuntimeError(
                f"Server process exited with code {process.returncode} before it started responding"
            )
        remaining = deadline - time.monotonic()
        try:
            # Cap each request so a server that accepts connections but never
            # replies cannot block past the overall deadline.
            requests.get(url, timeout=max(min(remaining, 5.0), 0.1))
            return
        except requests.exceptions.RequestException:
            pass
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Server did not start within {timeout} seconds")
        time.sleep(poll_interval)


@pytest.fixture(autouse=True)
def setup_remote_server():
    """Start the remote server subprocess for the test and stop it afterwards.

    Stale servers from earlier runs (matched by command line) are killed first.
    The new server is always terminated: after the test, and also when it
    fails to come up (previously it leaked in that case).
    """
    # Best-effort cleanup of leftover servers. pkill is not available on every
    # platform, so a missing binary is not an error.
    try:
        subprocess.run(["pkill", "-f", "python -m tests.remote_server.remote_server"], capture_output=True)
    except FileNotFoundError:
        pass

    host = "127.0.0.1"
    process = subprocess.Popen(
        [
            "python",
            "-m",
            "tests.remote_server.remote_server",
            "--host",
            host,
            "--port",
            str(SERVER_PORT),
        ]
    )
    try:
        # Poll until the server answers; fail fast if the process dies.
        wait_for_server_to_startup(process=process)
        yield
    finally:
        process.terminate()
        try:
            process.wait(timeout=10)
        except subprocess.TimeoutExpired:
            # The server ignored SIGTERM; force it down so the port is released.
            process.kill()
            process.wait()
def rows() -> List[EvaluationRow]:
    """Generate local rows with rich input_metadata to verify it survives remote traces.

    Returns three *independent* ``EvaluationRow`` objects with identical
    prompt and ``dataset_info`` content. Previously the same object was
    returned three times. Each row now gets its own message list and its own
    ``dataset_info`` dict (including a separate ``requirements`` list). This
    way, any per-row state set downstream cannot leak between rows through
    aliasing.

    The count of 3 must match the expectations in ``check_rollout_coverage``.
    """
    prompt = "What is the capital of France?"

    def _make_row() -> EvaluationRow:
        # Build every mutable piece fresh so rows share no objects.
        row = EvaluationRow(messages=[Message(role="user", content=prompt)])
        row.input_metadata.dataset_info = {
            "requirements": ["Answer with the capital city of France."],
            "total_requirements": 1,
            "original_prompt": prompt,
        }
        return row

    return [_make_row() for _ in range(3)]
@pytest.mark.parametrize(
    "completion_params",
    [{"model": "accounts/fireworks/models/gpt-oss-120b", "temperature": 0.5}],
)
@evaluation_test(
    data_loaders=DynamicDataLoader(
        generators=[rows],
    ),
    rollout_processor=RemoteRolloutProcessor(
        remote_base_url=f"http://127.0.0.1:{SERVER_PORT}",
        timeout_seconds=180,
    ),
)
async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> EvaluationRow:
    """
    End-to-end check of a remote rollout whose output is reloaded from Fireworks traces.

    - AUTO SERVER STARTUP: the remote server is started and stopped automatically for the test
    - RemoteRolloutProcessor triggers the rollout on the remote server (calls init/status)
    - the output data loader fetches the traces back (from Langfuse via the Fireworks
      tracing proxy, filtered by metadata); the test FAILS if none are found

    Each row must come back with a model response, the expected completion
    params, and the original ``dataset_info`` plus data-loader bookkeeping keys.
    """
    # Placeholder score: this test checks the rollout round-trip, not answer quality.
    row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")

    # The original prompt must survive, followed by at least one response message.
    assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content"
    assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row."

    meta = row.input_metadata
    params = meta.completion_params
    assert params["model"] == "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"
    assert params["temperature"] == 0.5, "Row should have temperature at top level"
    assert meta.row_id is not None

    info = meta.dataset_info
    assert info is not None
    assert info["requirements"] == ["Answer with the capital city of France."]
    assert info["total_requirements"] == 1
    assert info["original_prompt"] == "What is the capital of France?"
    assert "data_loader_type" in info
    assert "data_loader_num_rows" in info

    finish = row.execution_metadata.finish_reason
    assert finish == "stop", (
        f"Expected finish_reason='stop', got {finish}"
    )
    return row