Skip to content

Commit 565a6fa

Browse files
Session store is not a singleton
Will allow us to swap in different session stores GitOrigin-RevId: 1c2470b
1 parent e6ad93a commit 565a6fa

3 files changed

Lines changed: 33 additions & 11 deletions

File tree

cmd/server.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ func Serve(transport string) *server.MCPServer {
3636
s.AddTools(logs.Tools(c)...)
3737

3838
if transport == "http" {
39-
if err := server.NewStreamableHTTPServer(s, server.WithHTTPContextFunc(session.ContextWithHTTPSession)).Start(":10000"); err != nil {
39+
sessionStore := session.NewInMemoryStore()
40+
if err := server.NewStreamableHTTPServer(s, server.WithHTTPContextFunc(session.ContextWithHTTPSession(sessionStore))).Start(":10000"); err != nil {
4041
log.Fatalf("Starting Streamable server: %v\n:", err)
4142
}
4243
} else {

pkg/session/http.go

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,10 @@ type sessionCtxKeyType struct{}
1212

1313
var sessionCtxKey sessionCtxKeyType
1414

15-
func ContextWithHTTPSession(ctx context.Context, _ *http.Request) context.Context {
16-
cs := server.ClientSessionFromContext(ctx)
17-
if _, ok := inMemoryDataSingleton[cs.SessionID()]; !ok {
18-
inMemoryDataSingleton[cs.SessionID()] = &HTTPSession{}
15+
func ContextWithHTTPSession(store Store) func(ctx context.Context, _ *http.Request) context.Context {
16+
return func(ctx context.Context, _ *http.Request) context.Context {
17+
return context.WithValue(ctx, sessionCtxKey, store.Get(server.ClientSessionFromContext(ctx).SessionID()))
1918
}
20-
return context.WithValue(ctx, sessionCtxKey, inMemoryDataSingleton[cs.SessionID()])
2119
}
2220

2321
type HTTPSession struct {
@@ -38,4 +36,25 @@ func (h *HTTPSession) SetWorkspace(s string) error {
3836
return nil
3937
}
4038

41-
var inMemoryDataSingleton = map[string]*HTTPSession{}
39+
type Store interface {
40+
Get(sessionID string) *HTTPSession
41+
}
42+
43+
type inMemoryStore struct {
44+
sessions map[string]*HTTPSession
45+
}
46+
47+
var _ Store = (*inMemoryStore)(nil)
48+
49+
func NewInMemoryStore() Store {
50+
return &inMemoryStore{
51+
sessions: make(map[string]*HTTPSession),
52+
}
53+
}
54+
55+
func (i *inMemoryStore) Get(sessionID string) *HTTPSession {
56+
if _, ok := i.sessions[sessionID]; !ok {
57+
i.sessions[sessionID] = &HTTPSession{}
58+
}
59+
return i.sessions[sessionID]
60+
}

pkg/session/http_test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@ import (
1010
)
1111

1212
func TestHTTPSession(t *testing.T) {
13+
sessionStore := session.NewInMemoryStore()
14+
contextWithHTTPSession := session.ContextWithHTTPSession(sessionStore)
1315
{
1416
ctxOne := (&server.MCPServer{}).WithContext(context.Background(), fakeSession{sessionID: "one"})
15-
ctxOne = session.ContextWithHTTPSession(ctxOne, nil)
17+
ctxOne = contextWithHTTPSession(ctxOne, nil)
1618

1719
sessionOne := session.FromContext(ctxOne)
1820

@@ -36,7 +38,7 @@ func TestHTTPSession(t *testing.T) {
3638

3739
{
3840
ctxOneAgain := (&server.MCPServer{}).WithContext(context.Background(), fakeSession{sessionID: "one"})
39-
ctxOneAgain = session.ContextWithHTTPSession(ctxOneAgain, nil)
41+
ctxOneAgain = contextWithHTTPSession(ctxOneAgain, nil)
4042
sessionOneAgain := session.FromContext(ctxOneAgain)
4143

4244
workspace, err := sessionOneAgain.GetWorkspace()
@@ -50,7 +52,7 @@ func TestHTTPSession(t *testing.T) {
5052

5153
{
5254
ctxTwo := (&server.MCPServer{}).WithContext(context.Background(), fakeSession{sessionID: "two"})
53-
ctxTwo = session.ContextWithHTTPSession(ctxTwo, nil)
55+
ctxTwo = contextWithHTTPSession(ctxTwo, nil)
5456

5557
sessionTwo := session.FromContext(ctxTwo)
5658

@@ -67,7 +69,7 @@ func TestHTTPSession(t *testing.T) {
6769
{
6870

6971
ctxOneFinal := (&server.MCPServer{}).WithContext(context.Background(), fakeSession{sessionID: "one"})
70-
ctxOneFinal = session.ContextWithHTTPSession(ctxOneFinal, nil)
72+
ctxOneFinal = contextWithHTTPSession(ctxOneFinal, nil)
7173
sessionOneFinal := session.FromContext(ctxOneFinal)
7274

7375
workspace, err := sessionOneFinal.GetWorkspace()

0 commit comments

Comments
 (0)