Skip to content

Commit 457999a

Browse files
4ndrelimJunyi-99
andauthored
feat: XtraMCP Integration (#11)
* fix: CORS * chore: tested any schema * Fix search_papers tool * feat: Add XtraMCP loader to handle init and ack * feat: Add DynamicTool to represent generic XtraMCP tool * nit: Refactor for better structure * nit: remove unnecessary files and prints --------- Co-authored-by: Junyi Hou <hji200914@gmail.com>
1 parent b9a80a1 commit 457999a

5 files changed

Lines changed: 424 additions & 10 deletions

File tree

hack/values-dev.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
namespace: paperdebugger-dev
2+
3+
mongo:
4+
in_cluster: false

internal/api/gin.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ func NewGinServer(cfg *cfg.Cfg, oauthHandler *auth.OAuthHandler) *GinServer {
2121
ginServer := &GinServer{Engine: gin.New(), cfg: cfg}
2222
ginServer.Use(ginServer.ginLogMiddleware(), gin.Recovery())
2323
ginServer.Use(cors.New(cors.Config{
24-
AllowOrigins: []string{"https://overleaf.com", "https://*.overleaf.com", "https://*.paperdebugger.com", "http://localhost:3000", "http://127.0.0.1:3000"},
24+
AllowOrigins: []string{"*"},
2525
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
2626
AllowHeaders: []string{"*"},
2727
ExposeHeaders: []string{"*"},

internal/services/toolkit/client/client.go

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
"paperdebugger/internal/services"
1010
"paperdebugger/internal/services/toolkit/handler"
1111
"paperdebugger/internal/services/toolkit/registry"
12-
"paperdebugger/internal/services/toolkit/tools"
12+
"paperdebugger/internal/services/toolkit/tools/xtramcp"
1313

1414
"github.com/openai/openai-go/v2"
1515
"github.com/openai/openai-go/v2/option"
@@ -42,18 +42,37 @@ func NewAIClient(
4242
option.WithAPIKey(cfg.OpenAIAPIKey),
4343
)
4444
CheckOpenAIWorks(oaiClient, logger)
45-
46-
toolPaperScore := tools.NewPaperScoreTool(db, projectService)
47-
toolPaperScoreComment := tools.NewPaperScoreCommentTool(db, projectService, reverseCommentService)
45+
// toolPaperScore := tools.NewPaperScoreTool(db, projectService)
46+
// toolPaperScoreComment := tools.NewPaperScoreCommentTool(db, projectService, reverseCommentService)
4847

4948
toolRegistry := registry.NewToolRegistry()
50-
toolRegistry.Register("always_exception", tools.AlwaysExceptionToolDescription, tools.AlwaysExceptionTool)
51-
toolRegistry.Register("greeting", tools.GreetingToolDescription, tools.GreetingTool)
52-
toolRegistry.Register("paper_score", toolPaperScore.Description, toolPaperScore.Call)
53-
toolRegistry.Register("paper_score_comment", toolPaperScoreComment.Description, toolPaperScoreComment.Call)
5449

55-
toolCallHandler := handler.NewToolCallHandler(toolRegistry)
50+
// toolRegistry.Register("always_exception", tools.AlwaysExceptionToolDescription, tools.AlwaysExceptionTool)
51+
// toolRegistry.Register("greeting", tools.GreetingToolDescription, tools.GreetingTool)
52+
// toolRegistry.Register("paper_score", toolPaperScore.Description, toolPaperScore.Call)
53+
// toolRegistry.Register("paper_score_comment", toolPaperScoreComment.Description, toolPaperScoreComment.Call)
54+
55+
// Load tools dynamically from backend (TODO: Make URL configurable / Xtramcp url)
56+
xtraMCPLoader := xtramcp.NewXtraMCPLoader(db, projectService, "http://localhost:8080/mcp")
57+
58+
// initialize MCP session first and log session ID
59+
sessionID, err := xtraMCPLoader.InitializeMCP()
60+
if err != nil {
61+
logger.Errorf("[AI Client] Failed to initialize XtraMCP session: %v", err)
62+
// TODO: Fallback to static tools or exit?
63+
} else {
64+
logger.Info("[AI Client] XtraMCP session initialized", "sessionID", sessionID)
65+
66+
// dynamically load all tools from XtraMCP backend
67+
err = xtraMCPLoader.LoadToolsFromBackend(toolRegistry)
68+
if err != nil {
69+
logger.Errorf("[AI Client] Failed to load XtraMCP tools: %v", err)
70+
} else {
71+
logger.Info("[AI Client] Successfully loaded XtraMCP tools")
72+
}
73+
}
5674

75+
toolCallHandler := handler.NewToolCallHandler(toolRegistry)
5776
client := &AIClient{
5877
openaiClient: &oaiClient,
5978
toolCallHandler: toolCallHandler,
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
package xtramcp
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"strings"
10+
"paperdebugger/internal/libs/db"
11+
"paperdebugger/internal/services"
12+
"paperdebugger/internal/services/toolkit/registry"
13+
)
14+
15+
// MCPListToolsResponse represents the JSON-RPC response from tools/list method
16+
type MCPListToolsResponse struct {
17+
JSONRPC string `json:"jsonrpc"`
18+
ID int `json:"id"`
19+
Result struct {
20+
Tools []ToolSchema `json:"tools"`
21+
} `json:"result"`
22+
}
23+
24+
// loads tools dynamically from backend
25+
type XtraMCPLoader struct {
26+
db *db.DB
27+
projectService *services.ProjectService
28+
baseURL string
29+
client *http.Client
30+
sessionID string // Store the MCP session ID after initialization for re-use
31+
}
32+
33+
// NewXtraMCPLoader creates a new dynamic XtraMCP loader
34+
func NewXtraMCPLoader(db *db.DB, projectService *services.ProjectService, baseURL string) *XtraMCPLoader {
35+
return &XtraMCPLoader{
36+
db: db,
37+
projectService: projectService,
38+
baseURL: baseURL,
39+
client: &http.Client{},
40+
}
41+
}
42+
43+
// LoadToolsFromBackend fetches tool schemas from backend and registers them
44+
func (loader *XtraMCPLoader) LoadToolsFromBackend(toolRegistry *registry.ToolRegistry) error {
45+
if loader.sessionID == "" {
46+
return fmt.Errorf("MCP session not initialized - call InitializeMCP first")
47+
}
48+
49+
// Fetch tools from backend using the established session
50+
toolSchemas, err := loader.fetchAvailableTools()
51+
if err != nil {
52+
return fmt.Errorf("failed to fetch tools from backend: %w", err)
53+
}
54+
55+
// Register each tool dynamically, passing the session ID
56+
for _, toolSchema := range toolSchemas {
57+
dynamicTool := NewDynamicTool(loader.db, loader.projectService, toolSchema, loader.baseURL, loader.sessionID)
58+
59+
// Register the tool with the registry
60+
toolRegistry.Register(toolSchema.Name, dynamicTool.Description, dynamicTool.Call)
61+
62+
fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name)
63+
}
64+
65+
return nil
66+
}
67+
68+
// InitializeMCP performs the full MCP initialization handshake, stores session ID, and returns it
69+
func (loader *XtraMCPLoader) InitializeMCP() (string, error) {
70+
// Step 1: Initialize
71+
sessionID, err := loader.performInitialize()
72+
if err != nil {
73+
return "", fmt.Errorf("step 1 - initialize failed: %w", err)
74+
}
75+
76+
// Step 2: Send notifications/initialized
77+
err = loader.sendInitializedNotification(sessionID)
78+
if err != nil {
79+
return "", fmt.Errorf("step 2 - notifications/initialized failed: %w", err)
80+
}
81+
82+
// Store session ID for future use and return it
83+
loader.sessionID = sessionID
84+
85+
return sessionID, nil
86+
}
87+
88+
// performInitialize performs MCP initialization (1. establish connection)
89+
func (loader *XtraMCPLoader) performInitialize() (string, error) {
90+
initReq := map[string]interface{}{
91+
"jsonrpc": "2.0",
92+
"method": "initialize",
93+
"id": 1,
94+
"params": map[string]interface{}{
95+
"protocolVersion": "2024-11-05",
96+
"capabilities": map[string]interface{}{},
97+
"clientInfo": map[string]interface{}{
98+
"name": "paperdebugger-client",
99+
"version": "1.0.0",
100+
},
101+
},
102+
}
103+
104+
jsonData, err := json.Marshal(initReq)
105+
if err != nil {
106+
return "", fmt.Errorf("failed to marshal initialize request: %w", err)
107+
}
108+
109+
req, err := http.NewRequest("POST", loader.baseURL, bytes.NewBuffer(jsonData))
110+
if err != nil {
111+
return "", fmt.Errorf("failed to create initialize request: %w", err)
112+
}
113+
114+
req.Header.Set("Content-Type", "application/json")
115+
req.Header.Set("Accept", "application/json, text/event-stream")
116+
117+
resp, err := loader.client.Do(req)
118+
if err != nil {
119+
return "", fmt.Errorf("failed to make initialize request: %w", err)
120+
}
121+
defer resp.Body.Close()
122+
123+
// Extract session ID from response headers
124+
sessionID := resp.Header.Get("mcp-session-id")
125+
if sessionID == "" {
126+
return "", fmt.Errorf("no session ID returned from initialize")
127+
}
128+
129+
return sessionID, nil
130+
}
131+
132+
// sendInitializedNotification completes MCP initialization (acknowledges initialization)
133+
func (loader *XtraMCPLoader) sendInitializedNotification(sessionID string) error {
134+
notifyReq := map[string]interface{}{
135+
"jsonrpc": "2.0",
136+
"method": "notifications/initialized",
137+
"params": map[string]interface{}{},
138+
}
139+
140+
jsonData, err := json.Marshal(notifyReq)
141+
if err != nil {
142+
return fmt.Errorf("failed to marshal notification: %w", err)
143+
}
144+
145+
req, err := http.NewRequest("POST", loader.baseURL, bytes.NewBuffer(jsonData))
146+
if err != nil {
147+
return fmt.Errorf("failed to create notification request: %w", err)
148+
}
149+
150+
req.Header.Set("Content-Type", "application/json")
151+
req.Header.Set("Accept", "application/json, text/event-stream")
152+
req.Header.Set("mcp-session-id", sessionID)
153+
154+
resp, err := loader.client.Do(req)
155+
if err != nil {
156+
return fmt.Errorf("failed to send notification: %w", err)
157+
}
158+
defer resp.Body.Close()
159+
160+
return nil
161+
}
162+
163+
// fetchAvailableTools makes a request to get available tools from backend
164+
func (loader *XtraMCPLoader) fetchAvailableTools() ([]ToolSchema, error) {
165+
// List all tools using the established session
166+
requestBody := map[string]interface{}{
167+
"jsonrpc": "2.0",
168+
"method": "tools/list",
169+
"params": map[string]interface{}{},
170+
"id": 2,
171+
}
172+
173+
jsonData, err := json.Marshal(requestBody)
174+
if err != nil {
175+
return nil, fmt.Errorf("failed to marshal request: %w", err)
176+
}
177+
178+
req, err := http.NewRequest("POST", loader.baseURL, bytes.NewBuffer(jsonData))
179+
if err != nil {
180+
return nil, fmt.Errorf("failed to create request: %w", err)
181+
}
182+
183+
req.Header.Set("Content-Type", "application/json")
184+
req.Header.Set("Accept", "application/json, text/event-stream")
185+
req.Header.Set("mcp-session-id", loader.sessionID)
186+
187+
resp, err := loader.client.Do(req)
188+
if err != nil {
189+
return nil, fmt.Errorf("failed to make request: %w", err)
190+
}
191+
defer resp.Body.Close()
192+
193+
// Read the raw response body (SSE format) for debugging
194+
bodyBytes, err := io.ReadAll(resp.Body)
195+
if err != nil {
196+
return nil, fmt.Errorf("failed to read response body: %w", err)
197+
}
198+
199+
// Parse SSE format - extract JSON from "data: " lines
200+
lines := strings.Split(string(bodyBytes), "\n")
201+
var extractedJSON string
202+
for _, line := range lines {
203+
if strings.HasPrefix(line, "data: ") {
204+
extractedJSON = strings.TrimPrefix(line, "data: ")
205+
break
206+
}
207+
}
208+
209+
if extractedJSON == "" {
210+
return nil, fmt.Errorf("no data line found in SSE response")
211+
}
212+
213+
// Parse the extracted JSON
214+
var mcpResponse MCPListToolsResponse
215+
err = json.Unmarshal([]byte(extractedJSON), &mcpResponse)
216+
if err != nil {
217+
return nil, fmt.Errorf("failed to parse JSON from SSE data: %w. JSON data: %s", err, extractedJSON)
218+
}
219+
220+
return mcpResponse.Result.Tools, nil
221+
}

0 commit comments

Comments
 (0)