Skip to content

Commit 5a212ea

Browse files
author
Shrey Modi
committed
merge
1 parent 114d714 commit 5a212ea

6 files changed

Lines changed: 48 additions & 20 deletions

File tree

eval_protocol/pytest/github_action_rollout_processor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import time
44
from typing import Any, Callable, Dict, List, Optional
5-
5+
import json
66
import requests
77
from datetime import datetime, timezone, timedelta
88
from eval_protocol.models import EvaluationRow, Status
@@ -87,10 +87,14 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
8787

8888
def _dispatch_workflow():
8989
url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/dispatches"
90+
91+
model = init_request.completion_params.get("model")
92+
if not model:
93+
raise ValueError("model is required in completion_params")
9094
payload = {
9195
"ref": self.ref,
9296
"inputs": {
93-
"model": init_request.model,
97+
"model": model,
9498
"metadata": init_request.metadata.model_dump_json(),
9599
"model_base_url": init_request.model_base_url,
96100
},

eval_protocol/pytest/tracing_utils.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -80,21 +80,25 @@ def build_init_request(
8080
row_id=row.input_metadata.row_id,
8181
)
8282

83-
# Extract model
84-
model: Optional[str] = None
83+
# Build completion_params from row and config
84+
completion_params_dict: Dict[str, Any] = {}
85+
86+
# Start with config-level completion_params
87+
if config.completion_params and isinstance(config.completion_params, dict):
88+
completion_params_dict.update(config.completion_params)
89+
90+
# Override with row-specific completion_params
8591
if row.input_metadata and row.input_metadata.completion_params:
86-
model = row.input_metadata.completion_params.get("model")
87-
if model is None and config.completion_params:
88-
model = config.completion_params.get("model")
89-
if model is None:
90-
raise ValueError("Model must be provided in row.input_metadata.completion_params or config.completion_params")
91-
92+
row_cp = row.input_metadata.completion_params
93+
if isinstance(row_cp, dict):
94+
completion_params_dict.update(row_cp)
95+
96+
# Validate model is present
97+
if not completion_params_dict.get("model"):
98+
raise ValueError("Model must be provided in completion_params")
99+
92100
# Extract base_url from completion_params
93-
completion_params_base_url: Optional[str] = None
94-
if row.input_metadata and row.input_metadata.completion_params:
95-
completion_params_base_url = row.input_metadata.completion_params.get("base_url")
96-
if completion_params_base_url is None and config.completion_params:
97-
completion_params_base_url = config.completion_params.get("base_url")
101+
completion_params_base_url: Optional[str] = completion_params_dict.get("base_url")
98102

99103
# Strip non-OpenAI fields from messages
100104
allowed_message_fields = {"role", "content", "tool_calls", "tool_call_id", "name"}
@@ -124,7 +128,7 @@ def build_init_request(
124128
final_model_base_url = build_fireworks_tracing_url(model_base_url, meta, completion_params_base_url)
125129

126130
return InitRequest(
127-
model=model,
131+
completion_params=completion_params_dict,
128132
messages=clean_messages,
129133
tools=row.tools,
130134
metadata=meta,

tests/github_actions/rollout_worker.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,20 @@ def main():
1818

1919
# Required arguments from workflow inputs
2020
parser.add_argument("--model", required=True, help="Model to use")
21+
parser.add_argument("--completion-params", required=False, help="JSON completion params (optional)")
2122
parser.add_argument("--metadata", required=True, help="JSON serialized metadata object")
2223
parser.add_argument("--model-base-url", required=True, help="Base URL for the model API")
2324

2425
args = parser.parse_args()
2526

2627
# Parse the metadata
28+
completion_params = {}
29+
if args.completion_params:
30+
try:
31+
completion_params = json.loads(args.completion_params)
32+
except Exception as e:
33+
print(f"⚠️ Failed to parse completion_params: {e}")
34+
2735
try:
2836
metadata = json.loads(args.metadata)
2937
except Exception as e:
@@ -51,6 +59,9 @@ def main():
5159
try:
5260
completion_kwargs = {"model": args.model, "messages": messages}
5361

62+
if completion_params.get("model_kwargs"):
63+
completion_kwargs.update(completion_params["model_kwargs"])
64+
5465
client = OpenAI(base_url=args.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))
5566

5667
print("📡 Calling OpenAI completion...")

tests/remote_server/remote_server.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def _worker():
3535
try:
3636
if not req.messages:
3737
raise ValueError("messages is required")
38-
38+
3939
model = req.completion_params.get("model")
4040
if not model:
4141
raise ValueError("model is required in completion_params")
@@ -50,10 +50,12 @@ def _worker():
5050
model_kwargs = req.completion_params["model_kwargs"]
5151
if isinstance(model_kwargs, dict):
5252
completion_kwargs.update(model_kwargs)
53-
53+
5454
if req.tools:
5555
completion_kwargs["tools"] = req.tools
5656

57+
logger.info(f"Final completion_kwargs: {completion_kwargs}")
58+
5759
client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))
5860

5961
logger.info(f"Sending completion request to model {model}")

tests/remote_server/remote_server_multi_turn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ def _worker():
3131
try:
3232
if not req.messages:
3333
raise ValueError("messages is required")
34+
35+
model = req.completion_params.get("model")
36+
if not model:
37+
raise ValueError("model is required in completion_params")
3438

3539
client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))
3640

tests/remote_server/test_remote_fireworks.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ def rows() -> List[EvaluationRow]:
5858

5959

6060
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
61-
@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}])
61+
@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
62+
"model_kwargs": {"temperature": 0.5}
63+
}])
6264
@evaluation_test(
6365
data_loaders=DynamicDataLoader(
6466
generators=[rows],
@@ -82,5 +84,6 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat
8284
assert row.execution_metadata.rollout_id in ROLLOUT_IDS, (
8385
f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}"
8486
)
85-
87+
assert row.input_metadata.completion_params["model_kwargs"] == {"temperature": 0.5}, "Row should have correct model_kwargs"
88+
8689
return row

0 commit comments

Comments
 (0)