Skip to content

Commit 5d97ef0

Browse files
committed
testing to see if explicit file handling will make this work
1 parent 1324110 commit 5d97ef0

1 file changed

Lines changed: 112 additions & 3 deletions

File tree

evaluation_function/main.py

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,127 @@
1+
"""
2+
Main entry point for the FSA evaluation function.
3+
4+
Supports two communication modes with shimmy:
5+
1. File-based (recommended for large payloads): shimmy passes input/output file paths as args
6+
2. RPC/IPC (default): Uses lf_toolkit's server for stdio/IPC communication
7+
"""
8+
9+
import sys
10+
import json
11+
from typing import Any, Dict
112

213
from lf_toolkit import create_server, run
14+
from lf_toolkit.evaluation import Params, Result as LFResult
315

416
from .evaluation import evaluation_function
517
from .preview import preview_function
618

19+
20+
def handle_file_based_communication(input_path: str, output_path: str) -> None:
21+
"""
22+
Handle file-based communication with shimmy.
23+
24+
Reads input JSON from input_path, processes it, and writes result to output_path.
25+
This is used when shimmy is configured with --interface file.
26+
27+
Args:
28+
input_path: Path to the input JSON file
29+
output_path: Path to write the output JSON file
30+
"""
31+
# Read input from file
32+
with open(input_path, 'r', encoding='utf-8') as f:
33+
input_data = json.load(f)
34+
35+
# Extract command and request data
36+
command = input_data.get('command', 'eval')
37+
request_id = input_data.get('$id')
38+
39+
# Build response structure
40+
response_data: Dict[str, Any] = {}
41+
if request_id is not None:
42+
response_data['$id'] = request_id
43+
response_data['command'] = command
44+
45+
try:
46+
if command == 'eval':
47+
# Extract evaluation inputs
48+
response = input_data.get('response')
49+
answer = input_data.get('answer')
50+
params_dict = input_data.get('params', {})
51+
52+
# Create params object
53+
params = Params(**params_dict) if params_dict else Params()
54+
55+
# Call evaluation function
56+
result = evaluation_function(response, answer, params)
57+
58+
# Convert result to dict
59+
if hasattr(result, 'to_dict'):
60+
response_data['result'] = result.to_dict()
61+
elif isinstance(result, dict):
62+
response_data['result'] = result
63+
else:
64+
response_data['result'] = {'is_correct': False, 'feedback': str(result)}
65+
66+
elif command == 'preview':
67+
# Extract preview inputs
68+
response = input_data.get('response')
69+
params_dict = input_data.get('params', {})
70+
71+
params = Params(**params_dict) if params_dict else Params()
72+
73+
# Call preview function
74+
result = preview_function(response, params)
75+
76+
if hasattr(result, 'to_dict'):
77+
response_data['result'] = result.to_dict()
78+
elif isinstance(result, dict):
79+
response_data['result'] = result
80+
else:
81+
response_data['result'] = {'preview': str(result)}
82+
83+
else:
84+
response_data['result'] = {
85+
'is_correct': False,
86+
'feedback': f'Unknown command: {command}'
87+
}
88+
89+
except Exception as e:
90+
response_data['result'] = {
91+
'is_correct': False,
92+
'feedback': f'Error processing request: {str(e)}'
93+
}
94+
95+
# Write output to file
96+
with open(output_path, 'w', encoding='utf-8') as f:
97+
json.dump(response_data, f, ensure_ascii=False)
98+
99+
7100
def main():
8-
"""Run the IPC server with the evaluation and preview functions.
9101
"""
102+
Run the evaluation function.
103+
104+
Detects communication mode based on command-line arguments:
105+
- If 2+ args provided: File-based communication (last 2 args are input/output paths)
106+
- Otherwise: RPC/IPC server mode using lf_toolkit
107+
"""
108+
# Check for file-based communication
109+
# shimmy passes input and output file paths as the last two arguments
110+
if len(sys.argv) >= 3:
111+
input_path = sys.argv[-2]
112+
output_path = sys.argv[-1]
113+
114+
# Verify they look like file paths (basic check)
115+
if not input_path.startswith('-') and not output_path.startswith('-'):
116+
handle_file_based_communication(input_path, output_path)
117+
return
118+
119+
# Fall back to RPC/IPC server mode
10120
server = create_server()
11-
12121
server.eval(evaluation_function)
13122
server.preview(preview_function)
14-
15123
run(server)
16124

125+
17126
if __name__ == "__main__":
18127
main()

0 commit comments

Comments
 (0)