-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathrollout_worker.py
More file actions
84 lines (63 loc) · 2.8 KB
/
Copy pathrollout_worker.py
File metadata and controls
84 lines (63 loc) · 2.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#!/usr/bin/env python3
"""
GitHub Actions rollout worker script.
This script is called by the GitHub Actions workflow to perform the actual rollout.
It makes an OpenAI completion call that gets automatically traced via the tracing proxy.
"""
import argparse
import json
import os
import sys

from openai import OpenAI
# In this example the worker owns the dataset, and rollouts refer to rows by index.
_DATASET = (
    "What is the capital of France?",
    "What is the capital of Germany?",
    "What is the capital of Italy?",
)


def _load_json_object(raw, label):
    """Parse *raw* as JSON and return it, requiring a JSON object (dict).

    Raises:
        ValueError: if *raw* is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass) or does not decode to an object.
    """
    value = json.loads(raw)
    if not isinstance(value, dict):
        raise ValueError(f"{label} must be a JSON object, got {type(value).__name__}")
    return value


def _resolve_prompt(row_id):
    """Return the dataset prompt for *row_id*.

    Raises:
        ValueError / TypeError: if *row_id* cannot be converted with ``int()``,
            or is outside ``0..len(_DATASET) - 1``. Negative ids are rejected
            explicitly so they do not silently wrap to a row at the end.
    """
    index = int(row_id)
    if not 0 <= index < len(_DATASET):
        raise ValueError(f"row_id {row_id!r} is out of range (0..{len(_DATASET) - 1})")
    return _DATASET[index]


def main():
    """Run one rollout: look up the row's prompt and request a chat completion.

    Command-line arguments (passed through from the workflow inputs):
        --model: model name sent in the completion request.
        --completion-params: optional JSON object. If it has a non-empty
            "model_kwargs" object, those entries are merged into the request
            and may override "model"/"messages". Problems here are reported
            and ignored.
        --metadata: JSON object that must contain "rollout_id" and "row_id".
        --model-base-url: base URL passed to the OpenAI client.

    Exits with status 1 when the metadata or row id is unusable, or when
    creating the client or calling the completion API raises, so the calling
    workflow run is marked as failed.
    """
    parser = argparse.ArgumentParser(description="GitHub Actions rollout worker")
    # Required arguments from workflow inputs
    parser.add_argument("--model", required=True, help="Model to use")
    parser.add_argument("--completion-params", required=False, help="JSON completion params (optional)")
    parser.add_argument("--metadata", required=True, help="JSON serialized metadata object")
    parser.add_argument("--model-base-url", required=True, help="Base URL for the model API")
    args = parser.parse_args()

    # Completion params are optional and best-effort: parse once, warn on failure.
    completion_params = {}
    if args.completion_params:
        try:
            completion_params = _load_json_object(args.completion_params, "completion_params")
        except ValueError as e:
            print(f"⚠️ Failed to parse completion_params: {e}")

    # Metadata is mandatory: without rollout/row ids there is nothing to run.
    try:
        metadata = _load_json_object(args.metadata, "metadata")
    except ValueError as e:
        print(f"❌ Failed to parse metadata: {e}")
        sys.exit(1)
    try:
        rollout_id = metadata["rollout_id"]
        row_id = metadata["row_id"]
    except KeyError as e:
        print(f"❌ Metadata is missing required key {e}")
        sys.exit(1)

    print(f"🚀 Starting rollout {rollout_id}")
    print(f" Model: {args.model}")
    print(f" Row ID: {row_id}")

    try:
        user_content = _resolve_prompt(row_id)
    except (TypeError, ValueError) as e:
        print(f"❌ Invalid row_id for rollout {rollout_id}: {e}")
        sys.exit(1)
    messages = [{"role": "user", "content": user_content}]
    print(f" Messages: {len(messages)} messages")

    completion_kwargs = {"model": args.model, "messages": messages}
    model_kwargs = completion_params.get("model_kwargs")
    if isinstance(model_kwargs, dict) and model_kwargs:
        completion_kwargs.update(model_kwargs)
        print(f" Applied model_kwargs: {model_kwargs}")
    elif model_kwargs:
        # Best-effort like the rest of completion_params: warn instead of crashing.
        print(f"⚠️ Ignoring model_kwargs that is not a JSON object: {model_kwargs!r}")

    try:
        client = OpenAI(base_url=args.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))
        print("📡 Calling OpenAI completion...")
        print(f" Completion kwargs: {completion_kwargs}")
        client.chat.completions.create(**completion_kwargs)
    except Exception as e:  # top-level boundary: any client/API failure fails the rollout
        print(f"❌ Error in rollout {rollout_id}: {e}")
        sys.exit(1)
    print(f"✅ Rollout {rollout_id} completed successfully")
if __name__ == "__main__":
    # Script entry point: the workflow invokes this file directly.
    main()