-
Notifications
You must be signed in to change notification settings - Fork 222
Expand file tree
/
Copy pathsession_manager.go
More file actions
1016 lines (922 loc) · 43.4 KB
/
session_manager.go
File metadata and controls
1016 lines (922 loc) · 43.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0
// Package sessionmanager provides session lifecycle management.
//
// This package implements the two-phase session creation pattern that bridges
// the MCP SDK's session management with the vMCP server's backend lifecycle:
// - Phase 1 (Generate): Creates a placeholder session with no context
// - Phase 2 (CreateSession): Replaces placeholder with fully-initialized MultiSession
//
// The Manager type implements the server.SessionManager interface and is used by
// the server package.
package sessionmanager
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"strings"
"time"
"github.com/google/uuid"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/cache"
transportsession "github.com/stacklok/toolhive/pkg/transport/session"
"github.com/stacklok/toolhive/pkg/vmcp"
"github.com/stacklok/toolhive/pkg/vmcp/conversion"
vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session"
sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types"
)
// Termination marker stored in session metadata. Terminate() and
// cleanupFailedPlaceholder() write MetadataKeyTerminated=MetadataValTrue into a
// placeholder's stored metadata; Validate(), CreateSession(), checkSession()
// and NotifyBackendExpired() read it back to recognise a session that has
// already been ended.
const (
	// MetadataKeyTerminated is the session metadata key that flags a
	// placeholder session as explicitly terminated by the client.
	MetadataKeyTerminated = "terminated"

	// MetadataValTrue is the value stored under MetadataKeyTerminated once a
	// session has been terminated.
	MetadataValTrue = "true"
)
// Manager bridges the domain session lifecycle (MultiSession / MultiSessionFactory)
// to the mark3labs SDK's SessionIdManager interface.
//
// It implements a two-phase session-creation pattern:
//
//   - Generate(): called by SDK during initialize without context;
//     stores an empty placeholder via storage.
//   - CreateSession(): called from OnRegisterSession hook once
//     context is available; calls factory.MakeSessionWithID(), then
//     persists the session metadata to storage.
//
// # Storage split
//
// MultiSession holds live in-process state (backend HTTP connections, routing
// table) that cannot be serialized. A separate node-local cache (the sessions
// field, a cache.ValidatingCache) holds the authoritative MultiSession
// reference for this pod. The pluggable SessionDataStorage
// (LocalSessionDataStorage or RedisSessionDataStorage) carries only the
// lightweight, serialisable session metadata required for TTL management,
// Validate(), and cross-pod visibility.
//
// Because MultiSession objects are node-local, horizontal scaling requires
// sticky routing when session-affinity is desired. When Redis is used as the
// session-storage backend the metadata is durable across pod restarts, and the
// live MultiSession can be re-created via factory.RestoreSession() on a cache miss.
//
// TODO: Long-term, the cache and storage should be layered behind a single
// interface so the session manager does not need to coordinate between them.
// Reads would go through the cache (handling misses, singleflight, and liveness
// transparently); writes go to storage; caching is an implementation detail
// hidden from the caller.
type Manager struct {
	// storage holds the serialisable metadata for both Phase 1 placeholders
	// and Phase 2 sessions. Generate, CreateSession, Validate, Terminate and
	// the cache liveness check (checkSession) all go through it.
	storage transportsession.DataStorage
	// factory builds fully-formed MultiSessions; CreateSession calls
	// MakeSessionWithID on it. It is assigned in New after the cache exists,
	// because the decorating factory needs a reference to sm.Terminate.
	factory vmcpsession.MultiSessionFactory
	// backendReg is the backend registry supplied to New.
	// NOTE(review): presumably consulted by listAllBackends (defined outside
	// this view) when CreateSession enumerates backends — confirm.
	backendReg vmcp.BackendRegistry
	// sessions is a node-local cache of live MultiSession objects, separate
	// from storage because MultiSession contains un-serialisable runtime state
	// (HTTP connections, routing tables). On a cache miss it restores the
	// session from stored metadata; on a cache hit it confirms liveness via
	// storage.Load, which also refreshes the Redis TTL.
	sessions *cache.ValidatingCache[string, vmcpsession.MultiSession]
}
// New builds a Manager on top of the given SessionDataStorage and backend
// registry. The decorating session factory (optimizer and composite-tool
// layers) is assembled internally from cfg.
//
// Validation, in order:
//   - cfg and cfg.Base must be non-nil;
//   - cfg.CacheCapacity must not be negative (0 selects defaultCacheCapacity);
//   - cfg.ComposerFactory is required whenever cfg.WorkflowDefs is non-empty.
//
// The second return value is a cleanup function that releases resources
// acquired during construction (e.g. the optimizer's SQLite store). Callers
// must invoke it on shutdown; it is a no-op when nothing needs releasing.
func New(
	storage transportsession.DataStorage,
	cfg *FactoryConfig,
	backendRegistry vmcp.BackendRegistry,
) (*Manager, func(context.Context) error, error) {
	switch {
	case cfg == nil || cfg.Base == nil:
		return nil, nil, fmt.Errorf("sessionmanager.New: FactoryConfig.Base (SessionFactory) is required")
	case cfg.CacheCapacity < 0:
		return nil, nil, fmt.Errorf("sessionmanager.New: CacheCapacity must be >= 0 (got %d)", cfg.CacheCapacity)
	case len(cfg.WorkflowDefs) > 0 && cfg.ComposerFactory == nil:
		return nil, nil, fmt.Errorf("sessionmanager.New: ComposerFactory is required when WorkflowDefs are provided")
	}

	// Zero means "use the package default".
	capacity := cfg.CacheCapacity
	if capacity == 0 {
		capacity = defaultCacheCapacity
	}

	// The optimizer factory comes from config, telemetry-wrapped when needed.
	optimizerFactory, optimizerCleanup, err := resolveOptimizer(cfg)
	if err != nil {
		return nil, nil, err
	}

	// Workflow telemetry instruments are created exactly once here and shared
	// by every per-session executor wrapper, so metrics are registered once.
	var instruments *workflowExecutorInstruments
	if cfg.TelemetryProvider != nil && len(cfg.WorkflowDefs) > 0 {
		tp := cfg.TelemetryProvider
		instruments, err = newWorkflowExecutorInstruments(tp.MeterProvider(), tp.TracerProvider())
		if err != nil {
			// Release the optimizer resources acquired above before bailing out.
			if cleanupErr := optimizerCleanup(context.Background()); cleanupErr != nil {
				slog.Warn("failed to clean up optimizer after instrument creation error", "error", cleanupErr)
			}
			return nil, nil, fmt.Errorf("failed to create workflow executor telemetry: %w", err)
		}
	}

	// The Manager exists before the cache and factory are wired so that
	// method values (sm.loadSession, sm.checkSession, sm.Terminate) can be
	// handed to them directly.
	sm := &Manager{storage: storage, backendReg: backendRegistry}

	// onEvict closes the evicted session's backend connections.
	onEvict := func(id string, sess vmcpsession.MultiSession) {
		if closeErr := sess.Close(); closeErr != nil {
			slog.Warn("session cache: error closing evicted session",
				"session_id", id, "error", closeErr)
		}
		slog.Warn("session cache: session evicted from node-local cache",
			"session_id", id)
	}
	sm.sessions = cache.New(capacity, sm.loadSession, sm.checkSession, onEvict)
	sm.factory = buildDecoratingFactory(cfg, optimizerFactory, instruments, sm.Terminate)

	return sm, func(ctx context.Context) error { return optimizerCleanup(ctx) }, nil
}
// Per-call deadlines for storage and backend operations. Each constant bounds
// a specific call site so that a slow or unreachable Redis (or backend) cannot
// stall the calling goroutine indefinitely.
const (
	// generateTimeout is the deadline for the storage.Create call made on each
	// attempt inside Generate(). It is a safety net on top of the go-redis
	// client-level read/write timeouts.
	generateTimeout = 5 * time.Second

	// createSessionStorageTimeout bounds each individual storage operation in
	// CreateSession() (two placeholder Loads and the final metadata Update) and
	// the Update issued by cleanupFailedPlaceholder(). Inside CreateSession the
	// caller's ctx is the parent, so auth values and request-level
	// cancellation still propagate; this constant only adds an upper bound so
	// a slow or unreachable Redis cannot block session creation indefinitely.
	// 5 s is consistent with generateTimeout and terminateTimeout — all are
	// single-key Redis operations.
	createSessionStorageTimeout = 5 * time.Second

	// validateTimeout is the deadline for the storage Load inside Validate().
	// Validate() is called on every incoming HTTP request, so a tight timeout
	// bounds how long a slow or unreachable Redis can stall a request goroutine.
	validateTimeout = 3 * time.Second

	// restoreStorageTimeout bounds storage.Load calls (GETEX) in the
	// GetMultiSession restore path (loadSession) and in the checkSession
	// liveness check. Both are single-key Redis reads; 3 s is generous.
	restoreStorageTimeout = 3 * time.Second

	// restoreMetadataWriteTimeout bounds the storage.Update call that persists
	// the restored session's metadata back to Redis after a successful
	// RestoreSession. Single-key Redis SET XX operation; 5 s is consistent
	// with the other write timeouts (createSessionStorageTimeout,
	// terminateTimeout, decorateTimeout, notifyBackendExpiredTimeout).
	restoreMetadataWriteTimeout = 5 * time.Second

	// restoreSessionTimeout bounds factory.RestoreSession in the
	// GetMultiSession cache-miss path. RestoreSession opens HTTP connections
	// to each backend, so it gets more time than a plain storage read.
	// Aligned with discoveryTimeout (15 s) since both involve backend HTTP
	// round-trips.
	restoreSessionTimeout = 15 * time.Second

	// terminateTimeout is the deadline for storage operations inside
	// Terminate(). Terminate() runs on client DELETE requests and on auth
	// failures; it performs one Load followed by either a Delete (full
	// session) or an Update (placeholder). If that Update fails, a fallback
	// Delete runs under its own fresh terminateTimeout. All are single-key
	// Redis operations; 5 s matches generateTimeout.
	terminateTimeout = 5 * time.Second

	// decorateTimeout bounds the storage.Store call inside DecorateSession().
	// DecorateSession runs during session setup (OnRegisterSession hook) and
	// performs a single Redis SET. 5 s is consistent with terminateTimeout.
	decorateTimeout = 5 * time.Second

	// notifyBackendExpiredTimeout bounds the storage.Update issued (via
	// updateMetadata) by NotifyBackendExpired() — a single-key Redis
	// operation, consistent with terminateTimeout and decorateTimeout.
	notifyBackendExpiredTimeout = 5 * time.Second
)
// Generate implements the SDK's SessionIdManager.Generate().
//
// This is Phase 1 of the two-phase creation pattern. No request context exists
// yet, so Generate only mints a fresh UUID, atomically claims it in storage
// with an empty placeholder, and hands the ID back to the SDK. CreateSession()
// later replaces the placeholder (Phase 2) from the OnRegisterSession hook.
//
// Two attempts are made. The second covers both transient storage errors and
// the astronomically unlikely UUID collision, which storage.Create reports as
// (false, nil). An empty string is returned if both attempts fail.
func (sm *Manager) Generate() string {
	const maxAttempts = 2

	// claim performs a single attempt and reports the new ID plus whether
	// the placeholder was stored. Each attempt gets its own deadline, so an
	// expired deadline on the first attempt cannot abort the second.
	claim := func(attempt int) (string, bool) {
		ctx, cancel := context.WithTimeout(context.Background(), generateTimeout)
		defer cancel()
		candidate := uuid.New().String()

		// Create is an atomic SET NX on Redis, which avoids the TOCTOU race
		// that a Load+Upsert would have in a multi-pod deployment.
		created, err := sm.storage.Create(ctx, candidate, map[string]string{})
		switch {
		case err != nil:
			slog.Error("Manager: failed to store placeholder session",
				"session_id", candidate, "attempt", attempt+1, "error", err)
			return "", false
		case !created:
			slog.Warn("Manager: UUID collision detected; retrying", "session_id", candidate)
			return "", false
		}
		return candidate, true
	}

	for attempt := 0; attempt < maxAttempts; attempt++ {
		if id, ok := claim(attempt); ok {
			slog.Debug("Manager: generated placeholder session", "session_id", id)
			return id
		}
	}
	slog.Error("Manager: failed to generate unique session ID after 2 attempts")
	return ""
}
// CreateSession is Phase 2 of the two-phase creation pattern.
//
// It is called from the OnRegisterSession hook once the request context is
// available. It:
//  1. Verifies the Phase 1 placeholder still exists and is not terminated,
//     failing fast before any backend connection is opened.
//  2. Resolves the caller identity from the context (may be nil).
//  3. Lists available backends from the registry.
//  4. Calls MultiSessionFactory.MakeSessionWithID() to build a fully-formed
//     MultiSession (which opens real HTTP connections to each backend).
//  5. Re-checks the placeholder, persists the session metadata with a
//     conditional Update (SET XX), and caches the live MultiSession in the
//     node-local session cache.
//
// Any failure after MakeSessionWithID succeeds closes the new session before
// returning, so its backend connections are released. Close errors are logged,
// not returned. When the failure is a factory or storage error, the
// placeholder is additionally marked terminated via cleanupFailedPlaceholder so
// Validate() stops treating it as live.
//
// The returned MultiSession can be retrieved later via GetMultiSession().
func (sm *Manager) CreateSession(
	ctx context.Context,
	sessionID string,
) (vmcpsession.MultiSession, error) {
	if sessionID == "" {
		return nil, fmt.Errorf("Manager.CreateSession: session ID must not be empty")
	}

	// Fast-fail before opening any backend connections: verify the phase-1
	// placeholder still exists and has not been marked terminated. A client
	// DELETE between Generate() and this hook sets terminated=true on the
	// placeholder (or removes it entirely). Opening backend connections first
	// and checking afterwards would waste those resources and could silently
	// resurrect a session the client intentionally ended.
	loadCtx1, loadCancel1 := context.WithTimeout(ctx, createSessionStorageTimeout)
	placeholder, err := sm.storage.Load(loadCtx1, sessionID)
	loadCancel1()
	if errors.Is(err, transportsession.ErrSessionNotFound) {
		return nil, fmt.Errorf(
			"Manager.CreateSession: placeholder for session %q not found (terminated concurrently?)",
			sessionID,
		)
	}
	if err != nil {
		return nil, fmt.Errorf("Manager.CreateSession: failed to load placeholder for session %q: %w", sessionID, err)
	}
	if placeholder[MetadataKeyTerminated] == MetadataValTrue {
		return nil, fmt.Errorf(
			"Manager.CreateSession: session %q was terminated before backend connections could be opened",
			sessionID,
		)
	}

	// Resolve the caller identity (may be nil for anonymous access).
	identity, _ := auth.IdentityFromContext(ctx)

	// List all available backends from the registry.
	backends := sm.listAllBackends(ctx)

	// Build the fully-formed MultiSession using the SDK-assigned session ID.
	sess, err := sm.factory.MakeSessionWithID(ctx, sessionID, identity, backends)
	if err != nil {
		sm.cleanupFailedPlaceholder(sessionID, placeholder)
		return nil, fmt.Errorf("Manager.CreateSession: failed to create multi-session: %w", err)
	}

	// discardSession closes a session that will not be returned to the caller,
	// releasing its backend connections. Close errors are logged rather than
	// returned because the caller already has a more relevant error to report.
	discardSession := func() {
		if closeErr := sess.Close(); closeErr != nil {
			slog.Warn("Manager.CreateSession: error closing abandoned session",
				"session_id", sessionID, "error", closeErr)
		}
	}

	// Re-check that the placeholder is still present AND not terminated after
	// the (potentially slow) MakeSessionWithID call. A concurrent DELETE could:
	//  1. Delete the placeholder entirely (caught by ErrSessionNotFound), OR
	//  2. Mark it terminated=true (caught by terminated flag check)
	// Without this second check, the metadata write below could revive a
	// session the client already terminated, wasting backend connections.
	loadCtx2, loadCancel2 := context.WithTimeout(ctx, createSessionStorageTimeout)
	placeholder2, err := sm.storage.Load(loadCtx2, sessionID)
	loadCancel2()
	if errors.Is(err, transportsession.ErrSessionNotFound) {
		discardSession()
		return nil, fmt.Errorf(
			"Manager.CreateSession: placeholder for session %q disappeared during backend init (terminated concurrently)",
			sessionID,
		)
	}
	if err != nil {
		discardSession()
		sm.cleanupFailedPlaceholder(sessionID, placeholder)
		return nil, fmt.Errorf(
			"Manager.CreateSession: failed to re-check placeholder for session %q after backend init: %w",
			sessionID, err,
		)
	}
	if placeholder2[MetadataKeyTerminated] == MetadataValTrue {
		discardSession()
		return nil, fmt.Errorf(
			"Manager.CreateSession: session %q was terminated during backend init (marked after first check)",
			sessionID,
		)
	}

	// Persist the serialisable session metadata to the pluggable backend (e.g.
	// Redis) so that Validate() and TTL management work correctly. The live
	// MultiSession itself goes into the node-local session cache below.
	//
	// Use Update (SET XX) rather than Upsert to close the TOCTOU window between
	// the second placeholder check above and this write. If Terminate deleted the
	// key in that window, Update returns (false, nil) and we bail without
	// resurrecting the deleted session.
	storeCtx, storeCancel := context.WithTimeout(ctx, createSessionStorageTimeout)
	defer storeCancel()
	stored, err := sm.storage.Update(storeCtx, sessionID, sess.GetMetadata())
	if err != nil {
		discardSession()
		sm.cleanupFailedPlaceholder(sessionID, placeholder2)
		return nil, fmt.Errorf("Manager.CreateSession: failed to store session metadata: %w", err)
	}
	if !stored {
		discardSession()
		return nil, fmt.Errorf(
			"Manager.CreateSession: session %q was terminated between placeholder check and metadata store",
			sessionID,
		)
	}

	// Cache the live MultiSession so that GetMultiSession can retrieve it.
	sm.sessions.Set(sessionID, sess)
	slog.Debug("Manager: created multi-session",
		"session_id", sessionID,
		"backend_count", len(backends))
	return sess, nil
}
// cleanupFailedPlaceholder flags a placeholder as terminated in storage after
// CreateSession fails. Without this, an orphaned placeholder would make
// Validate() return (false, nil), so the SDK would keep treating it as a valid
// session. Each such Validate() call would also refresh the Redis TTL and keep
// the placeholder alive indefinitely.
//
// The write is a conditional Update (SET XX): if Terminate() already deleted
// the key, nothing is written, so the session is never resurrected as a
// terminated entry.
//
// This is best-effort. Failures are logged and swallowed because the caller is
// already returning an error of its own. The caller's metadata map is never
// modified.
func (sm *Manager) cleanupFailedPlaceholder(sessionID string, metadata map[string]string) {
	ctx, cancel := context.WithTimeout(context.Background(), createSessionStorageTimeout)
	defer cancel()

	// Build a fresh map (safe even when metadata is nil) with the flag set.
	marked := make(map[string]string, len(metadata)+1)
	for key, val := range metadata {
		marked[key] = val
	}
	marked[MetadataKeyTerminated] = MetadataValTrue

	_, err := sm.storage.Update(ctx, sessionID, marked)
	if err == nil {
		return
	}
	slog.Warn("Manager.CreateSession: failed to mark failed placeholder as terminated; it will linger until TTL expires",
		"session_id", sessionID, "error", err)
}
// Validate implements the SDK's SessionIdManager.Validate().
//
// Results:
//   - (false, error): empty ID, unknown session, or a storage failure. Per the
//     SDK interface contract, a failed lookup is reported via err, never via
//     isTerminated.
//   - (true, nil): the session exists but was explicitly terminated.
//   - (false, nil): the session exists and is active.
//
// The storage lookup is bounded by validateTimeout.
func (sm *Manager) Validate(sessionID string) (isTerminated bool, err error) {
	if sessionID == "" {
		return false, fmt.Errorf("Manager.Validate: empty session ID")
	}

	ctx, cancel := context.WithTimeout(context.Background(), validateTimeout)
	defer cancel()

	metadata, loadErr := sm.storage.Load(ctx, sessionID)
	switch {
	case errors.Is(loadErr, transportsession.ErrSessionNotFound):
		slog.Debug("Manager.Validate: session not found", "session_id", sessionID)
		return false, fmt.Errorf("session not found")
	case loadErr != nil:
		return false, fmt.Errorf("Manager.Validate: storage error for session %q: %w", sessionID, loadErr)
	case metadata[MetadataKeyTerminated] == MetadataValTrue:
		slog.Debug("Manager.Validate: session is terminated", "session_id", sessionID)
		return true, nil
	default:
		return false, nil
	}
}
// Terminate implements the SDK's SessionIdManager.Terminate().
//
// The two session types are handled asymmetrically to prevent a race condition
// where client termination during the Phase 1→Phase 2 window could resurrect
// sessions with open backend connections:
//
//   - MultiSession (Phase 2, identified by the identity-binding metadata key):
//     the storage key is deleted. The node-local cache self-heals on the next
//     Get: checkSession detects ErrSessionNotFound, evicts the entry, and
//     onEvict closes backend connections. After deletion Validate() returns
//     (false, error) — the same response as "never existed".
//
//   - Placeholder (Phase 1): the session is marked terminated=true and left
//     for TTL cleanup. This prevents CreateSession() from opening backend
//     connections for an already-terminated session (see fast-fail check in
//     CreateSession). The terminated flag also lets Validate() return
//     (isTerminated=true, nil) during the window between termination and TTL
//     expiry, allowing the SDK to distinguish "actively terminated" from
//     "never existed". If persisting the flag fails, the placeholder is
//     deleted instead, under a fresh terminateTimeout.
//
// A session that is already gone is treated as successfully terminated.
// Returns (isNotAllowed=false, nil) on success; client termination is always
// permitted. When both the flag write and the fallback delete fail, the
// returned error wraps both causes.
func (sm *Manager) Terminate(sessionID string) (isNotAllowed bool, err error) {
	if sessionID == "" {
		return false, fmt.Errorf("Manager.Terminate: empty session ID")
	}

	ctx, cancel := context.WithTimeout(context.Background(), terminateTimeout)
	defer cancel()

	// Load current metadata to determine session phase.
	metadata, loadErr := sm.storage.Load(ctx, sessionID)
	if errors.Is(loadErr, transportsession.ErrSessionNotFound) {
		// Already gone (concurrent termination or TTL expiry).
		slog.Debug("Manager.Terminate: session not found (already expired?)", "session_id", sessionID)
		return false, nil
	}
	if loadErr != nil {
		return false, fmt.Errorf("Manager.Terminate: failed to load session %q: %w", sessionID, loadErr)
	}

	if _, isFullSession := metadata[sessiontypes.MetadataKeyIdentityBinding]; isFullSession {
		// Phase 2 (full MultiSession): delete from storage. The cache entry will be
		// evicted lazily on the next Get when checkSession finds the session gone.
		if deleteErr := sm.storage.Delete(ctx, sessionID); deleteErr != nil {
			return false, fmt.Errorf("Manager.Terminate: failed to delete session from storage: %w", deleteErr)
		}
		slog.Info("Manager.Terminate: session terminated", "session_id", sessionID)
		return false, nil
	}

	// Phase 1 (placeholder): mark terminated so CreateSession fast-fails and
	// Validate returns isTerminated=true during the TTL window.
	//
	// Copy before mutating, as cleanupFailedPlaceholder does: the map returned
	// by Load may be shared with the storage implementation (or nil for an
	// empty placeholder), and writing into it directly could corrupt stored
	// state or panic.
	terminated := make(map[string]string, len(metadata)+1)
	for k, v := range metadata {
		terminated[k] = v
	}
	terminated[MetadataKeyTerminated] = MetadataValTrue

	// Use Update (SET XX) rather than Upsert so we never resurrect a key that
	// was concurrently deleted or expired between the Load above and this write.
	// (false, nil) means already gone — treat as success.
	updated, storeErr := sm.storage.Update(ctx, sessionID, terminated)
	if storeErr != nil {
		slog.Warn("Manager.Terminate: failed to persist terminated flag for placeholder; attempting delete fallback",
			"session_id", sessionID, "error", storeErr)
		// Fresh deadline: the shared ctx may already be exhausted by the failed Update.
		deleteCtx, deleteCancel := context.WithTimeout(context.Background(), terminateTimeout)
		deleteErr := sm.storage.Delete(deleteCtx, sessionID)
		deleteCancel()
		if deleteErr != nil {
			return false, fmt.Errorf(
				"Manager.Terminate: failed to persist terminated flag and delete placeholder: storeErr=%w, deleteErr=%w",
				storeErr, deleteErr)
		}
	} else if !updated {
		// Session expired or was concurrently deleted between Load and Update — already gone.
		slog.Debug("Manager.Terminate: placeholder already gone before terminated flag could be set", "session_id", sessionID)
	}

	slog.Info("Manager.Terminate: session terminated", "session_id", sessionID)
	return false, nil
}
// NotifyBackendExpired records in storage that the backend identified by
// workloadID is no longer connected for sessionID. It drops that backend's
// per-backend session-ID key and rewrites MetadataKeyBackendIDs without it, so
// a later cross-pod RestoreSession does not try to reconnect to the expired
// backend session.
//
// metadata is the snapshot the caller already holds (e.g. from
// MultiSession.GetMetadata); it is never modified. The call is a silent no-op
// when metadata is nil or already marked terminated. It is skipped with a
// warning when MetadataKeyBackendIDs is missing, because overwriting an absent
// key with "" would silently drop every remaining backend from later restores.
//
// The cached session is not evicted here. On the next GetMultiSession,
// checkSession sees that the stored MetadataKeyBackendIDs no longer matches the
// cached session's, evicts the stale entry via onEvict, and RestoreSession runs
// with the updated metadata. If the storage write fails, nothing is evicted and
// the caller retries on the next access.
//
// Best-effort: a missing storage key (terminated or expired session) makes
// updateMetadata's SET XX a no-op, and storage errors are logged, not returned.
func (sm *Manager) NotifyBackendExpired(sessionID, workloadID string, metadata map[string]string) {
	// Reading a nil map is safe, so both early-exit conditions combine.
	if metadata == nil || metadata[MetadataKeyTerminated] == MetadataValTrue {
		return
	}

	backendIDs, ok := metadata[vmcpsession.MetadataKeyBackendIDs]
	if !ok {
		slog.Warn("NotifyBackendExpired: MetadataKeyBackendIDs absent from session metadata; skipping update",
			"session_id", sessionID,
			"workload_id", workloadID)
		return
	}

	// Keep every non-blank, trimmed backend ID except the expired one. The key
	// is always rewritten, even as "", to match populateBackendMetadata, which
	// relies on key presence to tell an explicit zero-backend state apart from
	// absent/corrupted metadata in RestoreSession.
	kept := make([]string, 0, strings.Count(backendIDs, ",")+1)
	for _, raw := range strings.Split(backendIDs, ",") {
		id := strings.TrimSpace(raw)
		if id == "" || id == workloadID {
			continue
		}
		kept = append(kept, id)
	}

	// Work on a copy. Mutating the caller's map would also change the cached
	// session's in-memory metadata, so checkSession would see no drift and
	// never evict the stale entry.
	next := make(map[string]string, len(metadata))
	for key, val := range metadata {
		next[key] = val
	}
	delete(next, vmcpsession.MetadataKeyBackendSessionPrefix+workloadID)
	next[vmcpsession.MetadataKeyBackendIDs] = strings.Join(kept, ",")

	if err := sm.updateMetadata(sessionID, next); err != nil {
		slog.Warn("NotifyBackendExpired: failed to persist backend expiry to storage",
			"session_id", sessionID,
			"workload_id", workloadID,
			"error", err)
	}
}
// updateMetadata overwrites a session's stored metadata with a complete
// snapshot using a conditional Update (SET XX), bounded by
// notifyBackendExpiredTimeout.
//
// Only storage errors are returned. A missing key (concurrent Delete, on this
// pod or another) makes Update report (false, nil); that is deliberately not
// an error, because the session is gone and must not be resurrected. The
// node-local cache is left alone: on the next GetMultiSession, checkSession
// notices the absent key or the changed MetadataKeyBackendIDs, evicts the
// entry, and RestoreSession reloads fresh state.
func (sm *Manager) updateMetadata(sessionID string, metadata map[string]string) error {
	ctx, cancel := context.WithTimeout(context.Background(), notifyBackendExpiredTimeout)
	defer cancel()

	_, err := sm.storage.Update(ctx, sessionID, metadata)
	return err
}
// GetMultiSession returns the live MultiSession for an SDK session ID, or
// (nil, false) when the session does not exist or is still only a placeholder.
//
// The lookup goes through the node-local session cache. A hit is confirmed
// against storage by checkSession, which also refreshes the Redis TTL. A miss
// falls through to loadSession, which calls factory.RestoreSession; this is
// what allows a session created on another pod to be recovered when Redis
// backs storage.
//
// Known limitation: the signature is fixed by the MultiSessionGetter interface
// and carries no context. Both the liveness check and the restore path run on
// context.Background() with their own timeouts (restoreStorageTimeout /
// restoreSessionTimeout). They are therefore bounded, but the caller's HTTP
// request cancellation cannot propagate here.
// TODO: add context propagation through MultiSessionGetter so the caller's
// deadline can further bound these operations.
func (sm *Manager) GetMultiSession(sessionID string) (vmcpsession.MultiSession, bool) {
	sess, found := sm.sessions.Get(sessionID)
	return sess, found
}
// checkSession is the liveness check handed to the session cache. It reloads
// the session's metadata from storage, which also refreshes the Redis TTL, and
// tells the cache what to do with the cached entry:
//
//   - cache.ErrExpired when the key is gone or marked terminated (possibly by
//     another pod). The cache then evicts the entry and onEvict closes its
//     backend connections.
//   - cache.ErrExpired when the stored backend ID list differs from the cached
//     session's. This lets a backend-expiry update written by one pod reach the
//     others on their next access, via RestoreSession, instead of waiting for
//     TTL expiry.
//   - any other storage error unchanged (transient: the entry is kept).
//   - nil when the cached session is still current.
func (sm *Manager) checkSession(sessionID string, sess vmcpsession.MultiSession) error {
	ctx, cancel := context.WithTimeout(context.Background(), restoreStorageTimeout)
	defer cancel()

	stored, err := sm.storage.Load(ctx, sessionID)
	switch {
	case errors.Is(err, transportsession.ErrSessionNotFound):
		return cache.ErrExpired
	case err != nil:
		// Transient storage failure: keep the cached session.
		return err
	case stored[MetadataKeyTerminated] == MetadataValTrue:
		return cache.ErrExpired
	}

	// Only MetadataKeyBackendIDs is compared, not the whole map. Per-backend
	// session IDs (MetadataKeyBackendSessionPrefix+*) come from each pod's own
	// RestoreSession call. Backends that ignore Mcp-Session-Id hints (e.g. SSE
	// transports, some StreamableHTTP backends) assign a fresh ID on every
	// restore, so pods legitimately disagree on those values. Comparing the full
	// map would let every pod's loadSession write-back invalidate every other
	// pod's cache, an endless eviction storm that keeps tools from ever being
	// served in multi-pod deployments.
	cachedBackendIDs := sess.GetMetadata()[vmcpsession.MetadataKeyBackendIDs]
	if cachedBackendIDs != stored[vmcpsession.MetadataKeyBackendIDs] {
		return cache.ErrExpired
	}
	return nil
}
// loadSession is the restore function supplied to sessions. It loads session
// metadata from storage and calls factory.RestoreSession to reconnect to
// backends, returning the fully-formed MultiSession on success.
//
// It returns transportsession.ErrSessionNotFound when:
//   - the stored record is marked terminated;
//   - the record is still a placeholder (no identity binding);
//   - the record disappears before the restored metadata can be written back.
//
// Storage load errors and RestoreSession errors are logged and returned as-is.
func (sm *Manager) loadSession(sessionID string) (vmcpsession.MultiSession, error) {
	loadCtx, loadCancel := context.WithTimeout(context.Background(), restoreStorageTimeout)
	defer loadCancel()
	metadata, loadErr := sm.storage.Load(loadCtx, sessionID)
	if loadErr != nil {
		if !errors.Is(loadErr, transportsession.ErrSessionNotFound) {
			// The raw storage error is returned. Presumably the cache reports
			// any restore error to GetMultiSession callers as a miss, hence
			// "treating as not found" — TODO confirm against the cache.
			slog.Warn("Manager.loadSession: storage error; treating as not found",
				"session_id", sessionID, "error", loadErr)
		}
		return nil, loadErr
	}
	// Don't restore terminated sessions.
	if metadata[MetadataKeyTerminated] == MetadataValTrue {
		return nil, transportsession.ErrSessionNotFound
	}
	// Don't restore placeholder sessions (Phase 2 never ran).
	// BindSession always writes MetadataKeyIdentityBinding during Phase 2:
	// the unauthenticated sentinel for anonymous sessions, or a bound (iss, sub)
	// binding for authenticated ones. If the key is absent, Generate() stored
	// this record but CreateSession() never completed. Treat that as "not
	// found", not "corrupted".
	//
	// Note: this is intentionally different from RestoreSession's fail-closed
	// check (absent key → error). Here we know a placeholder's empty metadata
	// is valid storage state produced by Generate(), so we return the
	// SDK-standard ErrSessionNotFound instead of an error.
	if _, bindingPresent := metadata[sessiontypes.MetadataKeyIdentityBinding]; !bindingPresent {
		return nil, transportsession.ErrSessionNotFound
	}
	restoreCtx, restoreCancel := context.WithTimeout(context.Background(), restoreSessionTimeout)
	defer restoreCancel()
	restored, restoreErr := sm.factory.RestoreSession(restoreCtx, sessionID, metadata, sm.listAllBackends(restoreCtx))
	if restoreErr != nil {
		slog.Warn("Manager.loadSession: failed to restore session from storage",
			"session_id", sessionID, "error", restoreErr)
		return nil, restoreErr
	}
	// Persist the restored session's metadata back to Redis so that
	// per-backend session IDs are kept current. Backends that do not honor
	// Mcp-Session-Id hints (e.g. SSE transports) assign a fresh ID on every
	// restore. Without this write, the stale IDs would stay in Redis
	// indefinitely.
	//
	// We use Update (SET XX) rather than Upsert so we never resurrect a key
	// that was concurrently deleted (Terminate / TTL expiry). A (false, nil)
	// result means the key is already gone. We treat it as not found so the
	// cache never serves a session that no longer exists in storage.
	//
	// NOTE(review): SET XX only protects against deletion. The write replaces
	// the whole value with restored.GetMetadata(), so it would overwrite a
	// concurrent termination recorded as MetadataKeyTerminated instead of a
	// delete. Confirm how Terminate marks sessions.
	updateCtx, updateCancel := context.WithTimeout(context.Background(), restoreMetadataWriteTimeout)
	defer updateCancel()
	updated, updateErr := sm.storage.Update(updateCtx, sessionID, restored.GetMetadata())
	if updateErr != nil {
		slog.Warn("Manager.loadSession: failed to persist restored session metadata",
			"session_id", sessionID, "error", updateErr)
		// Non-fatal: the session is still usable on this pod, but storage
		// keeps the previous per-backend session IDs. checkSession compares
		// only MetadataKeyBackendIDs, so it will NOT notice this drift and
		// will not force a retry. The write is retried only when a later
		// restore runs, e.g. after the entry is evicted because the backend
		// ID list changed or the record went away.
	} else if !updated {
		// Session was concurrently deleted (Terminate / TTL expiry) between
		// RestoreSession and this write — do not cache the restored session.
		slog.Debug("Manager.loadSession: session already gone before metadata could be persisted; treating as not found",
			"session_id", sessionID)
		if closeErr := restored.Close(); closeErr != nil {
			slog.Warn("Manager.loadSession: failed to close restored session after concurrent deletion",
				"session_id", sessionID, "error", closeErr)
		}
		return nil, transportsession.ErrSessionNotFound
	}
	slog.Debug("Manager.loadSession: restored session from storage", "session_id", sessionID)
	return restored, nil
}
// DecorateSession applies fn to the cached MultiSession for sessionID. If the
// decorated session's metadata is persisted, the cached entry is replaced with
// the decorated value.
//
// It fails when:
//   - the session is absent or still a placeholder;
//   - fn returns nil or a session with a different ID;
//   - the storage write errors;
//   - the storage record no longer exists.
//
// Storage is written first and the node-local cache only afterwards. A failed
// or rejected write therefore never leaves an unpersisted decoration in the
// cache, where retries could stack decorations on top of it.
//
// storage.Update (SET XX) only guards against resurrecting a deleted session.
// This method takes no lock, so two concurrent decorations of the same session
// can both start from the same cached value, and the later Set wins.
//
// When Update reports (false, nil), the cached entry is left in place. It is
// evicted lazily on the next Get, once checkSession finds the record gone.
func (sm *Manager) DecorateSession(sessionID string, fn func(sessiontypes.MultiSession) sessiontypes.MultiSession) error {
	current, found := sm.GetMultiSession(sessionID)
	if !found {
		return fmt.Errorf("DecorateSession: session %q not found or not a multi-session", sessionID)
	}

	wrapped := fn(current)
	switch {
	case wrapped == nil:
		return errors.New("DecorateSession: decorator returned nil session")
	case wrapped.ID() != sessionID:
		return fmt.Errorf("DecorateSession: decorator changed session ID from %q to %q", sessionID, wrapped.ID())
	}

	ctx, cancel := context.WithTimeout(context.Background(), decorateTimeout)
	defer cancel()

	persisted, err := sm.storage.Update(ctx, sessionID, wrapped.GetMetadata())
	switch {
	case err != nil:
		return fmt.Errorf("DecorateSession: failed to store decorated session metadata: %w", err)
	case !persisted:
		// Deleted (Terminate or TTL) between the lookup and the write.
		return fmt.Errorf("DecorateSession: session %q was deleted during decoration", sessionID)
	}

	sm.sessions.Set(sessionID, wrapped)
	return nil
}
// GetAdaptedTools returns SDK-format tools for the given session, with handlers
// that delegate tool invocations directly to the session's CallTool() method.
//
// When the session factory is configured with an aggregator (WithAggregator),
// tools are in their final resolved form: overrides and conflict resolution
// have been applied via ProcessPreQueriedCapabilities. Each handler passes the
// resolved tool name to CallTool, which translates it back to the original
// backend name via GetBackendCapabilityName.
//
// Without an aggregator, raw backend tool names are used as-is (no overrides
// or conflict resolution applied).
//
// Handler behavior:
//   - Tool-call arguments: MCP makes "arguments" optional on tools/call, so an
//     omitted (nil) value is passed to CallTool as an empty object instead of
//     being rejected. Any other non-object value is reported as a tool error
//     wrapping vmcp.ErrInvalidInput.
//   - Authorization failures (ErrUnauthorizedCaller / ErrNilCaller) terminate
//     the session. They are returned as a tool error result, not a Go error.
//
// Returns an error if the session is missing or not yet a MultiSession, or if
// a tool's input schema cannot be marshaled. An output schema that fails to
// marshal is only logged, and the tool is exposed without it.
func (sm *Manager) GetAdaptedTools(sessionID string) ([]mcpserver.ServerTool, error) {
	multiSess, ok := sm.GetMultiSession(sessionID)
	if !ok {
		return nil, fmt.Errorf("Manager.GetAdaptedTools: session %q not found or not a multi-session", sessionID)
	}
	domainTools := multiSess.Tools()
	sdkTools := make([]mcpserver.ServerTool, 0, len(domainTools))
	for _, domainTool := range domainTools {
		schemaJSON, err := json.Marshal(domainTool.InputSchema)
		if err != nil {
			return nil, fmt.Errorf("Manager.GetAdaptedTools: failed to marshal schema for tool %s: %w", domainTool.Name, err)
		}
		tool := mcp.Tool{
			Name:           domainTool.Name,
			Description:    domainTool.Description,
			RawInputSchema: schemaJSON,
			Annotations:    conversion.ToMCPToolAnnotations(domainTool.Annotations),
		}
		if domainTool.OutputSchema != nil {
			outputSchemaJSON, marshalErr := json.Marshal(domainTool.OutputSchema)
			if marshalErr != nil {
				slog.Warn("failed to marshal tool output schema",
					"tool", domainTool.Name, "error", marshalErr)
			} else {
				tool.RawOutputSchema = outputSchemaJSON
			}
		}
		// Per-iteration copies so each handler is bound to its own tool,
		// regardless of the module's Go loop-variable semantics.
		capturedSess := multiSess
		capturedSessionID := sessionID
		capturedToolName := domainTool.Name
		handler := func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
			var args map[string]any
			switch raw := req.Params.Arguments.(type) {
			case nil:
				// "arguments" omitted by the client; the MCP spec allows this.
				// Forward an empty object rather than rejecting the call.
				args = map[string]any{}
			case map[string]any:
				args = raw
			default:
				wrappedErr := fmt.Errorf("%w: arguments must be object, got %T", vmcp.ErrInvalidInput, req.Params.Arguments)
				slog.Warn("invalid arguments for tool", "tool", capturedToolName, "error", wrappedErr)
				return mcp.NewToolResultError(wrappedErr.Error()), nil
			}
			meta := conversion.FromMCPMeta(req.Params.Meta)
			caller, _ := auth.IdentityFromContext(ctx)
			result, callErr := capturedSess.CallTool(ctx, caller, capturedToolName, args, meta)
			if callErr != nil {
				if errors.Is(callErr, sessiontypes.ErrUnauthorizedCaller) || errors.Is(callErr, sessiontypes.ErrNilCaller) {
					slog.Warn("caller authorization failed, terminating session",
						"session_id", capturedSessionID, "tool", capturedToolName, "error", callErr)
					if _, termErr := sm.Terminate(capturedSessionID); termErr != nil {
						slog.Error("failed to terminate session after auth failure",
							"session_id", capturedSessionID, "error", termErr)
					}
					return mcp.NewToolResultError(fmt.Sprintf("Unauthorized: %v", callErr)), nil
				}
				// Backend/tool failures are surfaced to the client as tool
				// error results, not protocol-level errors.
				return mcp.NewToolResultError(callErr.Error()), nil
			}
			return &mcp.CallToolResult{
				Result: mcp.Result{
					Meta: conversion.ToMCPMeta(result.Meta),
				},
				Content:           conversion.ToMCPContents(result.Content),
				StructuredContent: result.StructuredContent,
				IsError:           result.IsError,
			}, nil
		}
		sdkTools = append(sdkTools, mcpserver.ServerTool{
			Tool:    tool,
			Handler: handler,
		})
		slog.Debug("Manager.GetAdaptedTools: adapted tool", "session_id", sessionID, "tool", domainTool.Name)
	}
	return sdkTools, nil
}
// GetAdaptedResources converts the session's domain resources into SDK
// ServerResource entries. Each entry's read handler forwards to the session's
// ReadResource() for that entry's own URI; the incoming request is not
// inspected.
//
// If ReadResource rejects the caller (ErrUnauthorizedCaller or ErrNilCaller),
// the handler terminates the session and returns an "unauthorized" error
// wrapping the cause. Other read errors are returned unchanged.
//
// Returns an error if the session is missing or not yet a MultiSession.
func (sm *Manager) GetAdaptedResources(sessionID string) ([]mcpserver.ServerResource, error) {
	sess, found := sm.GetMultiSession(sessionID)
	if !found {
		return nil, fmt.Errorf("Manager.GetAdaptedResources: session %q not found or not a multi-session", sessionID)
	}

	// readHandlerFor builds a read handler bound to a single resource URI.
	// Taking the URI as a parameter gives every handler its own copy.
	readHandlerFor := func(uri string) func(context.Context, mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
		return func(ctx context.Context, _ mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
			caller, _ := auth.IdentityFromContext(ctx)
			result, err := sess.ReadResource(ctx, caller, uri)
			if err == nil {
				return conversion.ToMCPResourceContents(result.Contents), nil
			}

			authFailed := errors.Is(err, sessiontypes.ErrUnauthorizedCaller) ||
				errors.Is(err, sessiontypes.ErrNilCaller)
			if !authFailed {
				return nil, err
			}

			slog.Warn("caller authorization failed, terminating session",
				"session_id", sessionID, "resource", uri, "error", err)
			if _, termErr := sm.Terminate(sessionID); termErr != nil {
				slog.Error("failed to terminate session after auth failure",
					"session_id", sessionID, "error", termErr)
			}
			return nil, fmt.Errorf("unauthorized: %w", err)
		}
	}

	domain := sess.Resources()
	adapted := make([]mcpserver.ServerResource, 0, len(domain))
	for _, res := range domain {
		adapted = append(adapted, mcpserver.ServerResource{
			Resource: mcp.Resource{
				Name:        res.Name,
				URI:         res.URI,
				Description: res.Description,
				MIMEType:    res.MimeType,
			},
			Handler: readHandlerFor(res.URI),
		})
		slog.Debug("Manager.GetAdaptedResources: adapted resource", "session_id", sessionID, "uri", res.URI)
	}
	return adapted, nil
}
// GetAdaptedPrompts returns SDK-format prompts for the given session, with handlers
// that delegate prompt requests directly to the session's GetPrompt() method.
func (sm *Manager) GetAdaptedPrompts(sessionID string) ([]mcpserver.ServerPrompt, error) {
multiSess, ok := sm.GetMultiSession(sessionID)
if !ok {
return nil, fmt.Errorf("Manager.GetAdaptedPrompts: session %q not found or not a multi-session", sessionID)
}
domainPrompts := multiSess.Prompts()
sdkPrompts := make([]mcpserver.ServerPrompt, 0, len(domainPrompts))
for _, domainPrompt := range domainPrompts {
prompt := mcp.Prompt{
Name: domainPrompt.Name,
Description: domainPrompt.Description,
}
for _, arg := range domainPrompt.Arguments {
prompt.Arguments = append(prompt.Arguments, mcp.PromptArgument{
Name: arg.Name,
Description: arg.Description,
Required: arg.Required,
})
}
capturedSess := multiSess
capturedSessionID := sessionID
capturedPromptName := domainPrompt.Name
handler := func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
caller, _ := auth.IdentityFromContext(ctx)
args := make(map[string]any, len(req.Params.Arguments))
for k, v := range req.Params.Arguments {
args[k] = v
}
result, getErr := capturedSess.GetPrompt(ctx, caller, capturedPromptName, args)
if getErr != nil {
if errors.Is(getErr, sessiontypes.ErrUnauthorizedCaller) || errors.Is(getErr, sessiontypes.ErrNilCaller) {
slog.Warn("caller authorization failed, terminating session",
"session_id", capturedSessionID, "prompt", capturedPromptName, "error", getErr)
if _, termErr := sm.Terminate(capturedSessionID); termErr != nil {
slog.Error("failed to terminate session after auth failure",
"session_id", capturedSessionID, "error", termErr)
}
return nil, fmt.Errorf("unauthorized: %w", getErr)
}
return nil, getErr
}
mcpMessages := make([]mcp.PromptMessage, 0, len(result.Messages))
for _, msg := range result.Messages {
mcpMessages = append(mcpMessages, mcp.PromptMessage{
Role: mcp.Role(msg.Role),
Content: conversion.ToMCPContent(msg.Content),
})
}
return &mcp.GetPromptResult{
Description: result.Description,
Messages: mcpMessages,
}, nil
}
sdkPrompts = append(sdkPrompts, mcpserver.ServerPrompt{
Prompt: prompt,
Handler: handler,