Skip to content

Commit 5260f83

Browse files
committed
Fix search_papers tool
1 parent bb7223d commit 5260f83

2 files changed

Lines changed: 39 additions & 45 deletions

File tree

internal/services/toolkit/client/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func NewAIClient(
5656
// toolRegistry.Register("export_papers")
5757
// toolRegistry.Register("get_conference_papers")
5858
// toolRegistry.Register("get_user_papers")
59-
toolRegistry.Register("search_papers", toolSearchPapers.Description, toolSearchPapers.Call)
59+
toolRegistry.Register("search_relevant_papers", toolSearchPapers.Description, toolSearchPapers.Call)
6060
// toolRegistry.Register("search_user")
6161
// toolRegistry.Register("identify_improvements")
6262
// toolRegistry.Register("suggest_improvement")

internal/services/toolkit/tools/xtragpt/search_papers.go

Lines changed: 38 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,6 @@ type MCPParams struct {
3333
Arguments map[string]interface{} `json:"arguments"`
3434
}
3535

36-
// Venue represents a conference venue with year
37-
type Venue struct {
38-
Venue string `json:"venue"`
39-
Year string `json:"year"`
40-
}
4136
type SearchPapersTool struct {
4237
Description responses.ToolUnionParam
4338
toolCallRecordDB *toolCallRecordDB.ToolCallRecordDB
@@ -47,37 +42,39 @@ type SearchPapersTool struct {
4742
client *http.Client
4843
}
4944

50-
var schema map[string]any
51-
52-
var SearchPapersToolDescription = responses.ToolUnionParam{
53-
OfFunction: &responses.FunctionToolParam{
54-
Name: "search_papers",
55-
Description: param.NewOpt("Search for papers by keywords within specific conference venues, with various matching modes."),
56-
Parameters: openai.FunctionParameters(schema),
57-
},
58-
}
59-
6045
func NewSearchPapersTool(db *db.DB, projectService *services.ProjectService) *SearchPapersTool {
61-
json.Unmarshal([]byte(`{"properties":{"query":{"description":"Keywords / topics or content to search for (e.g., 'time series token merging', 'neural networks').","title":"Query","type":"string"},"venues":{"description":"List of conference venues and years to search in. Each entry must be a dict with 'venue' (e.g., 'ICLR.cc', 'NeurIPS.cc', 'ICML.cc'; users may omit '.cc') and 'year' (e.g., '2024', '2025').","items":{"additionalProperties":{"type":"string"},"type":"object"},"minItems":1,"title":"Venues","type":"array"},"search_fields":{"default":["title","abstract"],"description":"Fields to search within each paper. Options: 'title', 'abstract', 'authors'.","items":{"enum":["title","abstract","authors"],"type":"string"},"title":"Search Fields","type":"array"},"match_mode":{"default":"majority","description":"Match mode:\n- any: At least one keyword must match\n- all: All keywords must match\n- exact: Match the entire phrase exactly\n- majority: Match majority of keywords (>50%)\n- threshold: Match percentage of terms based on 'match_threshold'.","enum":["any","all","exact","majority","threshold"],"title":"Match Mode","type":"string"},"match_threshold":{"default":0.5,"description":"Minimum fraction (0.0-1.0) of search terms that must match when using 'threshold' mode. Example: 0.5 = 50% of terms must match.","maximum":1,"minimum":0,"title":"Match Threshold","type":"number"},"limit":{"default":10,"description":"Maximum number of results to return (1-16).","maximum":16,"minimum":1,"title":"Limit","type":"integer"},"min_score":{"default":0.6,"description":"Minimum match score (0.0-1.0). Lower values allow looser matches; higher values enforce stricter matches.","maximum":1,"minimum":0,"title":"Min Score","type":"number"}},"required":["query","venues"],"title":"search_papers_toolArguments","type":"object"}`), &schema)
46+
// Create and populate schema
47+
var schema map[string]any
48+
json.Unmarshal([]byte(`{"properties":{"query":{"description":"Keywords, topics, content, or a chunk of text to search for.","examples":["time series token merging","neural networks","...when trained on first-order Markov chains, transformers with two or more layers consistently develop an induction head mechanism to estimate the in-context bigram conditional distribution"],"title":"Query","type":"string"},"top_k":{"description":"Number of top relevant or similar papers to return.","title":"Top K","type":"integer"},"date_min":{"description":"Minimum publication date (YYYY-MM-DD) to filter papers.","examples":["2023-01-01","2022-06-25"],"title":"Date Min","type":"string"},"date_max":{"description":"Maximum publication date (YYYY-MM-DD) to filter papers.","examples":["2024-12-31","2023-06-25"],"title":"Date Max","type":"string"},"countries":{"anyOf":[{"items":{"type":"string"},"type":"array"},{"type":"null"}],"description":"List of country codes in ISO ALPHA-3 format to filter papers by author affiliations.","examples":[["USA","CHN","SGP","GBR","DEU","KOR","JPN"]],"title":"Countries"},"min_similarity":{"description":"Minimum similarity score (0.0-1.0) for returned papers. Higher values yield more relevant results but fewer papers.","examples":[0.3,0.5,0.7,0.9],"title":"Min Similarity","type":"number"}},"required":["query","top_k","countries","min_similarity"],"title":"search_papers_toolArguments","type":"object"}`), &schema)
49+
50+
// Create tool description with populated schema
51+
description := responses.ToolUnionParam{
52+
OfFunction: &responses.FunctionToolParam{
53+
Name: "search_relevant_papers",
54+
Description: param.NewOpt("Search for similar or relevant papers by keywords against the local database of academic papers. This tool uses semantic search with vector embeddings to find the most relevant results. It is the default and recommended tool for paper searches."),
55+
Parameters: openai.FunctionParameters(schema),
56+
},
57+
}
58+
6259
toolCallRecordDB := toolCallRecordDB.NewToolCallRecordDB(db)
6360
return &SearchPapersTool{
64-
Description: SearchPapersToolDescription,
61+
Description: description,
6562
toolCallRecordDB: toolCallRecordDB,
6663
projectService: projectService,
6764
coolDownTime: 5 * time.Minute,
68-
baseURL: "http://xtragpt-mcp-server:8080/paper-score",
65+
// baseURL: "http://xtragpt-mcp-server:8080/mcp",
66+
baseURL: "http://localhost:8080/mcp", // For local development
6967
client: &http.Client{},
7068
}
7169
}
7270

7371
type SearchPapersToolArgs struct {
74-
Limit int `json:"limit"`
75-
MatchMode string `json:"matchMode"`
76-
MatchThreshold float64 `json:"matchThreshold"`
77-
MinScore float64 `json:"minScore"`
78-
Query string `json:"query"`
79-
Venues []Venue `json:"venues"`
80-
SearchFields []string `json:"searchFields"`
72+
Query string `json:"query"`
73+
TopK int `json:"top_k"`
74+
DateMin *string `json:"date_min,omitempty"`
75+
DateMax *string `json:"date_max,omitempty"`
76+
Countries []string `json:"countries"`
77+
MinSimilarity float64 `json:"min_similarity"`
8178
}
8279

8380
func (t *SearchPapersTool) Call(ctx context.Context, toolCallId string, args json.RawMessage) (string, string, error) {
@@ -89,19 +86,18 @@ func (t *SearchPapersTool) Call(ctx context.Context, toolCallId string, args jso
8986

9087
// Create function call record
9188
record, err := t.toolCallRecordDB.Create(ctx, toolCallId, *t.Description.GetName(), map[string]any{
92-
"limit": argsMap.Limit,
93-
"matchMode": argsMap.MatchMode,
94-
"matchThreshold": argsMap.MatchThreshold,
95-
"minScore": argsMap.MinScore,
9689
"query": argsMap.Query,
97-
"venues": argsMap.Venues,
98-
"searchFields": argsMap.SearchFields,
90+
"top_k": argsMap.TopK,
91+
"date_min": argsMap.DateMin,
92+
"date_max": argsMap.DateMax,
93+
"countries": argsMap.Countries,
94+
"min_similarity": argsMap.MinSimilarity,
9995
})
10096
if err != nil {
10197
return "", "", err
10298
}
10399

104-
respStr, err := t.SearchPaper(argsMap.Limit, argsMap.MatchMode, argsMap.MatchThreshold, argsMap.MinScore, argsMap.Query, argsMap.Venues, argsMap.SearchFields)
100+
respStr, err := t.SearchPaper(argsMap.Query, argsMap.TopK, argsMap.DateMin, argsMap.DateMax, argsMap.Countries, argsMap.MinSimilarity)
105101
if err != nil {
106102
err = fmt.Errorf("failed to search paper: %v", err)
107103
t.toolCallRecordDB.OnError(ctx, record, err)
@@ -119,7 +115,7 @@ func (t *SearchPapersTool) Call(ctx context.Context, toolCallId string, args jso
119115
return respStr, "", nil
120116
}
121117

122-
func (t *SearchPapersTool) SearchPaper(limit int, matchMode string, matchThreshold float64, minScore float64, query string, venues []Venue, searchFields []string) (string, error) {
118+
func (t *SearchPapersTool) SearchPaper(query string, topK int, dateMin *string, dateMax *string, countries []string, minSimilarity float64) (string, error) {
123119
sessionId, err := MCPInitialize(t.baseURL)
124120
if err != nil {
125121
fmt.Printf("Error initializing MCP: %v\n", err)
@@ -135,15 +131,14 @@ func (t *SearchPapersTool) SearchPaper(limit int, matchMode string, matchThresho
135131
Method: "tools/call",
136132
ID: 2,
137133
Params: MCPParams{
138-
Name: "search_papers",
134+
Name: "search_relevant_papers",
139135
Arguments: map[string]interface{}{
140-
"limit": limit,
141-
"match_mode": matchMode,
142-
"match_threshold": matchThreshold,
143-
"min_score": minScore,
144-
"query": query,
145-
"search_fields": searchFields,
146-
"venues": venues,
136+
"query": query,
137+
"top_k": topK,
138+
"date_min": dateMin,
139+
"date_max": dateMax,
140+
"countries": countries,
141+
"min_similarity": minSimilarity,
147142
},
148143
},
149144
}
@@ -155,7 +150,7 @@ func (t *SearchPapersTool) SearchPaper(limit int, matchMode string, matchThresho
155150
}
156151

157152
// Create HTTP request
158-
req, err := http.NewRequest("POST", "http://localhost:8080/mcp", bytes.NewBuffer(jsonData))
153+
req, err := http.NewRequest("POST", t.baseURL, bytes.NewBuffer(jsonData))
159154
if err != nil {
160155
return "", fmt.Errorf("failed to create HTTP request: %w", err)
161156
}
@@ -166,8 +161,7 @@ func (t *SearchPapersTool) SearchPaper(limit int, matchMode string, matchThresho
166161
req.Header.Set("mcp-session-id", sessionId)
167162

168163
// Make the request
169-
client := &http.Client{}
170-
resp, err := client.Do(req)
164+
resp, err := t.client.Do(req)
171165
if err != nil {
172166
return "", fmt.Errorf("failed to make request: %w", err)
173167
}

0 commit comments

Comments
 (0)