Skip to content

Commit 5a4ecd0

Browse files
csg-pr-bot and Dev Agent
authored and committed
Fix mcp gateway session affinity, mcp server update issue and add inspection cron job (#960)
Co-authored-by: Dev Agent <dev-agent@example.com>
1 parent 2d5b822 commit 5a4ecd0

17 files changed

Lines changed: 809 additions & 10 deletions

File tree

_mocks/opencsg.com/csghub-server/builder/store/cache/mock_RedisClient.go

Lines changed: 59 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
//go:build ee || saas
2+
3+
package handler
4+
5+
import (
6+
"fmt"
7+
"log/slog"
8+
"net"
9+
"net/http"
10+
"net/http/httputil"
11+
"net/url"
12+
"sync/atomic"
13+
"time"
14+
15+
"opencsg.com/csghub-server/common/utils/trace"
16+
)
17+
18+
const (
	// proxyConnectTimeout bounds how long dialing a remote instance may take
	// when a request is proxied to it.
	proxyConnectTimeout = 3 * time.Second

	// headerInternalProxy is the request header set to "true" on requests that
	// have already been proxied by another instance; such requests are served
	// locally instead of being routed again.
	headerInternalProxy = "X-Internal-Proxy"
)
22+
23+
// MCPProxyAwareHandler wraps the local SDK handler with Redis-based session
24+
// routing. Requests for sessions owned by a remote instance are transparently
25+
// proxied; new sessions are registered in Redis after the SDK assigns an ID.
26+
type MCPProxyAwareHandler struct {
27+
sdkHandler http.Handler
28+
registry MCPSessionRegistry
29+
selfAddr string
30+
}
31+
32+
func NewMCPProxyAwareHandler(sdkHandler http.Handler, registry MCPSessionRegistry, selfAddr string) *MCPProxyAwareHandler {
33+
return &MCPProxyAwareHandler{
34+
sdkHandler: sdkHandler,
35+
registry: registry,
36+
selfAddr: selfAddr,
37+
}
38+
}
39+
40+
func (h *MCPProxyAwareHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
41+
sessionID := r.Header.Get(trace.HeaderMcpSessionID)
42+
43+
// No session ID: new session (initialize). Capture the session ID from the
44+
// response and register it in Redis.
45+
if sessionID == "" {
46+
cw := &sessionCapturingWriter{
47+
ResponseWriter: w,
48+
onSessionCreated: func(sid string) {
49+
if err := h.registry.Register(r.Context(), sid, h.selfAddr); err != nil {
50+
slog.ErrorContext(r.Context(), "failed to register new mcp session in redis", "session_id", sid, "error", err)
51+
} else {
52+
slog.InfoContext(r.Context(), "registered new mcp session", "session_id", sid, "instance", h.selfAddr)
53+
}
54+
},
55+
}
56+
h.sdkHandler.ServeHTTP(cw, r)
57+
return
58+
}
59+
60+
// Internal proxy request: already routed by another instance, handle locally.
61+
if r.Header.Get(headerInternalProxy) == "true" {
62+
h.sdkHandler.ServeHTTP(w, r)
63+
return
64+
}
65+
66+
// Existing session: look up which instance owns it.
67+
targetAddr, err := h.registry.Lookup(r.Context(), sessionID)
68+
if err != nil {
69+
slog.DebugContext(r.Context(), "mcp session not found in redis, trying local", "session_id", sessionID, "error", err)
70+
h.sdkHandler.ServeHTTP(w, r)
71+
return
72+
}
73+
74+
// Session is local.
75+
if targetAddr == h.selfAddr {
76+
h.sdkHandler.ServeHTTP(w, r)
77+
return
78+
}
79+
80+
// Session is on a remote instance: proxy the request.
81+
slog.InfoContext(r.Context(), "proxy to remote instance", "session_id", sessionID, "target", targetAddr)
82+
proxyErr := h.proxyToInstance(w, r, targetAddr)
83+
if proxyErr != nil {
84+
slog.WarnContext(r.Context(), "proxy to remote instance failed, cleaning stale session",
85+
"session_id", sessionID, "target", targetAddr, "error", proxyErr)
86+
_ = h.registry.Delete(r.Context(), sessionID)
87+
http.Error(w, "session not found", http.StatusNotFound)
88+
}
89+
}
90+
91+
func (h *MCPProxyAwareHandler) proxyToInstance(w http.ResponseWriter, r *http.Request, targetAddr string) error {
92+
targetURL, err := url.Parse(fmt.Sprintf("http://%s", targetAddr))
93+
if err != nil {
94+
return fmt.Errorf("parse target url: %w", err)
95+
}
96+
97+
var proxyErr atomic.Value
98+
proxy := httputil.NewSingleHostReverseProxy(targetURL)
99+
proxy.FlushInterval = -1 // immediate flush for SSE streaming
100+
proxy.Transport = &http.Transport{
101+
DialContext: (&net.Dialer{Timeout: proxyConnectTimeout}).DialContext,
102+
}
103+
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, e error) {
104+
proxyErr.Store(e)
105+
}
106+
107+
r.Header.Set(headerInternalProxy, "true")
108+
proxy.ServeHTTP(w, r)
109+
110+
if stored := proxyErr.Load(); stored != nil {
111+
return stored.(error)
112+
}
113+
return nil
114+
}
115+
116+
// sessionCapturingWriter intercepts WriteHeader to capture the Mcp-Session-Id
117+
// set by the SDK on initialize responses, then calls onSessionCreated.
118+
type sessionCapturingWriter struct {
119+
http.ResponseWriter
120+
onSessionCreated func(sessionID string)
121+
captured bool
122+
}
123+
124+
func (w *sessionCapturingWriter) WriteHeader(code int) {
125+
if !w.captured {
126+
w.captured = true
127+
if sid := w.Header().Get(trace.HeaderMcpSessionID); sid != "" {
128+
w.onSessionCreated(sid)
129+
}
130+
}
131+
w.ResponseWriter.WriteHeader(code)
132+
}
133+
134+
func (w *sessionCapturingWriter) Write(b []byte) (int, error) {
135+
if !w.captured {
136+
w.captured = true
137+
if sid := w.Header().Get(trace.HeaderMcpSessionID); sid != "" {
138+
w.onSessionCreated(sid)
139+
}
140+
}
141+
return w.ResponseWriter.Write(b)
142+
}
143+
144+
// Flush implements http.Flusher for SSE streaming compatibility.
145+
func (w *sessionCapturingWriter) Flush() {
146+
if f, ok := w.ResponseWriter.(http.Flusher); ok {
147+
f.Flush()
148+
}
149+
}
150+
151+
// Unwrap allows http.ResponseController to reach the underlying writer.
152+
func (w *sessionCapturingWriter) Unwrap() http.ResponseWriter {
153+
return w.ResponseWriter
154+
}

0 commit comments

Comments
 (0)