77"""
88
99import argparse
10- import base64
1110import json
1211import os
1312
@@ -20,44 +19,24 @@ def main():
2019 # Required arguments from workflow inputs
2120 parser .add_argument ("--model" , required = True , help = "Model to use" )
2221 parser .add_argument ("--rollout-id" , required = True , help = "Rollout ID for tracking" )
23- parser .add_argument ("--messages-b64" , required = True , help = "Base64 encoded JSON messages" )
24- parser .add_argument ("--tools-b64" , required = False , help = "Base64 encoded JSON tools (optional)" )
22+ parser .add_argument ("--prompt" , required = True , help = "User prompt for the rollout" )
2523
2624 args = parser .parse_args ()
2725
2826 print (f"🚀 Starting rollout { args .rollout_id } " )
2927 print (f" Model: { args .model } " )
28+ print (f" Prompt: { args .prompt } " )
3029
31- # Decode messages and tools
32- try :
33- messages = json .loads (base64 .b64decode (args .messages_b64 ).decode ("utf-8" ))
34- tools = None
35- if args .tools_b64 :
36- tools = json .loads (base64 .b64decode (args .tools_b64 ).decode ("utf-8" ))
37- except Exception as e :
38- print (f"❌ Failed to decode inputs: { e } " )
39- # Save error trace
40- error_data = {
41- "status" : "error" ,
42- "rollout_id" : args .rollout_id ,
43- "model" : args .model ,
44- "messages" : [],
45- "error" : f"Failed to decode inputs: { e } " ,
46- }
47- with open (f"rollout_trace_{ args .rollout_id } .json" , "w" ) as f :
48- json .dump (error_data , f , indent = 2 )
49- exit (1 )
30+ # Build messages array
31+ messages = [{"role" : "user" , "content" : args .prompt }]
5032
5133 print (f" Messages: { len (messages )} messages" )
52- print (f" Tools: { len (tools ) if tools else 0 } tools" )
5334
5435 # Perform the rollout
5536 conversation = messages .copy ()
5637
5738 try :
5839 completion_kwargs = {"model" : args .model , "messages" : messages }
59- if tools :
60- completion_kwargs ["tools" ] = tools
6140
6241 client = OpenAI (api_key = os .environ .get ("FIREWORKS_API_KEY" ))
6342
@@ -76,7 +55,6 @@ def main():
7655 "rollout_id" : args .rollout_id ,
7756 "model" : args .model ,
7857 "messages" : conversation ,
79- "tools" : tools ,
8058 "usage" : completion .usage .model_dump () if completion .usage else None ,
8159 }
8260
@@ -91,7 +69,6 @@ def main():
9169 "rollout_id" : args .rollout_id ,
9270 "model" : args .model ,
9371 "messages" : conversation ,
94- "tools" : tools ,
9572 "error" : str (e ),
9673 }
9774
0 commit comments