Skip to content

Commit a15e403

Browse files
committed
Merge main into 0x0elliot/encrypt-apikey-session
2 parents be2365f + 329d46b commit a15e403

17 files changed

Lines changed: 1730 additions & 490 deletions

agent_mock.go

Lines changed: 353 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,353 @@
1+
package shuffle
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"errors"
7+
"fmt"
8+
"io/ioutil"
9+
"log"
10+
"net/url"
11+
"os"
12+
"path/filepath"
13+
)
14+
15+
func RunAgentDecisionMockHandler(execution WorkflowExecution, decision AgentDecision) ([]byte, string, string, error) {
16+
log.Printf("[DEBUG][%s] Mock handler called for tool=%s, action=%s", execution.ExecutionId, decision.Tool, decision.Action)
17+
18+
// Get mock response
19+
response, err := GetMockSingulResponse(execution.ExecutionId, decision.Fields)
20+
if err != nil {
21+
log.Printf("[ERROR][%s] Failed to get mock response: %s", execution.ExecutionId, err)
22+
return nil, "", decision.Tool, err
23+
}
24+
25+
// Parse the response to extract raw_response
26+
var outputMapped SchemalessOutput
27+
err = json.Unmarshal(response, &outputMapped)
28+
if err != nil {
29+
log.Printf("[ERROR][%s] Failed to unmarshal mock response: %s", execution.ExecutionId, err)
30+
return response, "", decision.Tool, err
31+
}
32+
33+
// Extract the raw_response field
34+
body := response
35+
if val, ok := outputMapped.RawResponse.(string); ok {
36+
body = []byte(val)
37+
} else if val, ok := outputMapped.RawResponse.([]byte); ok {
38+
body = val
39+
} else if val, ok := outputMapped.RawResponse.(map[string]interface{}); ok {
40+
marshalledRawResp, err := json.MarshalIndent(val, "", " ")
41+
if err != nil {
42+
log.Printf("[ERROR][%s] Failed to marshal raw response: %s", execution.ExecutionId, err)
43+
} else {
44+
body = marshalledRawResp
45+
}
46+
}
47+
48+
log.Printf("[DEBUG][%s] Returning mock response for %s (success=%v, response_size=%d bytes)",
49+
execution.ExecutionId, decision.Tool, outputMapped.Success, len(body))
50+
51+
return body, "", decision.Tool, nil
52+
}
53+
54+
func GetMockSingulResponse(executionId string, fields []Valuereplace) ([]byte, error) {
55+
ctx := context.Background()
56+
mockCacheKey := fmt.Sprintf("agent_mock_%s", executionId)
57+
cache, err := GetCache(ctx, mockCacheKey)
58+
59+
if err == nil {
60+
cacheData := cache.([]uint8)
61+
log.Printf("[DEBUG][%s] Using cached mock data (%d bytes)", executionId, len(cacheData))
62+
63+
var toolCalls []MockToolCall
64+
err = json.Unmarshal(cacheData, &toolCalls)
65+
if err != nil {
66+
log.Printf("[ERROR][%s] Failed to unmarshal cached mock data: %s", executionId, err)
67+
return nil, fmt.Errorf("failed to unmarshal cached mock data: %w", err)
68+
}
69+
70+
return GetMockResponseFromToolCalls(toolCalls, fields)
71+
}
72+
73+
testDataPath := os.Getenv("AGENT_TEST_DATA_PATH")
74+
if testDataPath == "" {
75+
return nil, fmt.Errorf("no mock data in cache for execution %s and AGENT_TEST_DATA_PATH not set", executionId)
76+
}
77+
78+
log.Printf("[DEBUG][%s] Cache miss, using file-based mocks from: %s", executionId, testDataPath)
79+
80+
useCase := os.Getenv("AGENT_TEST_USE_CASE")
81+
if useCase == "" {
82+
return nil, errors.New("AGENT_TEST_USE_CASE not set")
83+
}
84+
85+
useCaseData, err := loadUseCaseData(useCase)
86+
if err != nil {
87+
return nil, err
88+
}
89+
90+
return GetMockResponseFromToolCalls(useCaseData.ToolCalls, fields)
91+
}
92+
93+
// GetMockResponseFromToolCalls finds and returns the matching mock response from tool calls
94+
func GetMockResponseFromToolCalls(toolCalls []MockToolCall, fields []Valuereplace) ([]byte, error) {
95+
requestURL := extractFieldValue(fields, "url")
96+
if requestURL == "" {
97+
return nil, errors.New("no URL found in request fields")
98+
}
99+
100+
log.Printf("[DEBUG] Looking for mock data with URL: %s", requestURL)
101+
102+
var candidates []MockToolCall
103+
reqURLParsed, err := url.Parse(requestURL)
104+
if err != nil {
105+
log.Printf("[ERROR] Invalid request URL %s: %v", requestURL, err)
106+
return nil, fmt.Errorf("invalid request URL: %w", err)
107+
}
108+
for _, tc := range toolCalls {
109+
if urlsEqual(reqURLParsed, tc.URL) {
110+
candidates = append(candidates, tc)
111+
}
112+
}
113+
114+
// If no exact matches, try fuzzy matching
115+
if len(candidates) == 0 {
116+
log.Printf("[DEBUG] No exact match, trying fuzzy matching...")
117+
bestMatch, score := findBestFuzzyMatch(reqURLParsed, toolCalls)
118+
if score >= 0.80 {
119+
log.Printf("[INFO] Found fuzzy match with %.1f%% similarity: %s", score*100, bestMatch.URL)
120+
candidates = append(candidates, bestMatch)
121+
} else {
122+
return nil, fmt.Errorf("no mock data found for URL: %s (best match: %.1f%%)", requestURL, score*100)
123+
}
124+
}
125+
126+
// If only one match, return it
127+
if len(candidates) == 1 {
128+
log.Printf("[DEBUG] Found exact match for URL: %s", requestURL)
129+
return marshalResponse(candidates[0].Response)
130+
}
131+
132+
// Multiple matches - compare fields to find exact match
133+
log.Printf("[DEBUG] Found %d candidates for URL, comparing fields...", len(candidates))
134+
for _, candidate := range candidates {
135+
if fieldsMatch(fields, candidate.Fields) {
136+
log.Printf("[DEBUG] Found exact match based on fields")
137+
return marshalResponse(candidate.Response)
138+
}
139+
}
140+
141+
// No exact match - return first candidate with a warning
142+
log.Printf("[WARNING] No exact field match found, returning first candidate")
143+
return marshalResponse(candidates[0].Response)
144+
}
145+
146+
func urlsEqual(req *url.URL, stored string) bool {
147+
storedURL, err := url.Parse(stored)
148+
if err != nil {
149+
log.Printf("[WARN] Invalid stored URL %s: %v", stored, err)
150+
return false
151+
}
152+
if req.Scheme != storedURL.Scheme || req.Host != storedURL.Host || req.Path != storedURL.Path {
153+
return false
154+
}
155+
reqQuery := req.Query()
156+
storedQuery := storedURL.Query()
157+
// If the number of parameters differs, not a match
158+
if len(reqQuery) != len(storedQuery) {
159+
return false
160+
}
161+
162+
for key, reqVals := range reqQuery {
163+
storedVals, ok := storedQuery[key]
164+
if !ok {
165+
return false
166+
}
167+
if len(reqVals) != len(storedVals) {
168+
return false
169+
}
170+
for i, v := range reqVals {
171+
if v != storedVals[i] {
172+
return false
173+
}
174+
}
175+
}
176+
return true
177+
}
178+
179+
func loadUseCaseData(useCase string) (*MockUseCaseData, error) {
180+
possiblePaths := []string{}
181+
182+
if envPath := os.Getenv("AGENT_TEST_DATA_PATH"); envPath != "" {
183+
possiblePaths = append(possiblePaths, envPath)
184+
}
185+
186+
possiblePaths = append(possiblePaths, "agent_test_data")
187+
possiblePaths = append(possiblePaths, "../shuffle-shared/agent_test_data")
188+
possiblePaths = append(possiblePaths, "../../shuffle-shared/agent_test_data")
189+
190+
if homeDir, err := os.UserHomeDir(); err == nil {
191+
possiblePaths = append(possiblePaths, filepath.Join(homeDir, "Documents", "shuffle-shared", "agent_test_data"))
192+
}
193+
194+
var filePath string
195+
var foundPath string
196+
197+
for _, basePath := range possiblePaths {
198+
testPath := filepath.Join(basePath, fmt.Sprintf("%s.json", useCase))
199+
if _, err := os.Stat(testPath); err == nil {
200+
filePath = testPath
201+
foundPath = basePath
202+
break
203+
}
204+
}
205+
206+
if filePath == "" {
207+
return nil, fmt.Errorf("could not find test data file %s.json in any of these paths: %v", useCase, possiblePaths)
208+
}
209+
210+
log.Printf("[DEBUG] Loading use case data from: %s", filePath)
211+
212+
data, err := ioutil.ReadFile(filePath)
213+
if err != nil {
214+
return nil, fmt.Errorf("failed to read use case file %s: %s", filePath, err)
215+
}
216+
217+
var useCaseData MockUseCaseData
218+
err = json.Unmarshal(data, &useCaseData)
219+
if err != nil {
220+
return nil, fmt.Errorf("failed to parse use case data: %s", err)
221+
}
222+
223+
log.Printf("[DEBUG] Loaded use case '%s' with %d tool calls from %s", useCaseData.UseCase, len(useCaseData.ToolCalls), foundPath)
224+
225+
return &useCaseData, nil
226+
}
227+
228+
func extractFieldValue(fields []Valuereplace, key string) string {
229+
for _, field := range fields {
230+
if field.Key == key {
231+
return field.Value
232+
}
233+
}
234+
return ""
235+
}
236+
237+
func fieldsMatch(requestFields []Valuereplace, storedFields map[string]string) bool {
238+
// Convert request fields to map for easier comparison
239+
requestMap := make(map[string]string)
240+
for _, field := range requestFields {
241+
requestMap[field.Key] = field.Value
242+
}
243+
244+
for key, storedValue := range storedFields {
245+
requestValue, exists := requestMap[key]
246+
if !exists || requestValue != storedValue {
247+
return false
248+
}
249+
}
250+
251+
return true
252+
}
253+
254+
func marshalResponse(response map[string]interface{}) ([]byte, error) {
255+
data, err := json.Marshal(response)
256+
if err != nil {
257+
return nil, fmt.Errorf("failed to marshal response: %s", err)
258+
}
259+
return data, nil
260+
}
261+
262+
func findBestFuzzyMatch(reqURL *url.URL, toolCalls []MockToolCall) (MockToolCall, float64) {
263+
var bestMatch MockToolCall
264+
bestScore := 0.0
265+
266+
for _, tc := range toolCalls {
267+
storedURL, err := url.Parse(tc.URL)
268+
if err != nil {
269+
continue
270+
}
271+
272+
score := calculateURLSimilarity(reqURL, storedURL)
273+
if score > bestScore {
274+
bestScore = score
275+
bestMatch = tc
276+
}
277+
}
278+
279+
return bestMatch, bestScore
280+
}
281+
282+
func calculateURLSimilarity(url1, url2 *url.URL) float64 {
283+
score := 0.0
284+
totalWeight := 0.0
285+
286+
// Scheme (10% weight)
287+
if url1.Scheme == url2.Scheme {
288+
score += 0.10
289+
}
290+
totalWeight += 0.10
291+
292+
// Host (20% weight)
293+
if url1.Host == url2.Host {
294+
score += 0.20
295+
}
296+
totalWeight += 0.20
297+
298+
// Path (20% weight)
299+
if url1.Path == url2.Path {
300+
score += 0.20
301+
}
302+
totalWeight += 0.20
303+
304+
// Query parameters (50% weight)
305+
query1 := url1.Query()
306+
query2 := url2.Query()
307+
308+
if len(query1) == 0 && len(query2) == 0 {
309+
score += 0.50
310+
} else if len(query1) > 0 || len(query2) > 0 {
311+
matchingParams := 0
312+
totalParams := 0
313+
314+
allKeys := make(map[string]bool)
315+
for k := range query1 {
316+
allKeys[k] = true
317+
}
318+
for k := range query2 {
319+
allKeys[k] = true
320+
}
321+
totalParams = len(allKeys)
322+
323+
// Count how many match
324+
for key := range allKeys {
325+
val1, ok1 := query1[key]
326+
val2, ok2 := query2[key]
327+
328+
if ok1 && ok2 {
329+
// Both have this key - check if values match
330+
if len(val1) == len(val2) {
331+
allMatch := true
332+
for i := range val1 {
333+
if val1[i] != val2[i] {
334+
allMatch = false
335+
break
336+
}
337+
}
338+
if allMatch {
339+
matchingParams++
340+
}
341+
}
342+
}
343+
}
344+
345+
if totalParams > 0 {
346+
paramScore := float64(matchingParams) / float64(totalParams)
347+
score += paramScore * 0.50
348+
}
349+
}
350+
totalWeight += 0.50
351+
352+
return score / totalWeight
353+
}

0 commit comments

Comments
 (0)