Skip to content

Commit 1f658a2

Browse files
yroblataskbot
andauthored
Add RestoreHijackPrevention and RestoreSession interface stub (#4405)
* Add RestoreHijackPrevention and RestoreSession interface stub (RC-15, #4216) Add the infrastructure needed to reconstruct hijack-prevention state from Redis-persisted metadata, enabling cross-pod token validation in horizontal scaling scenarios. - security.go: Add RestoreHijackPrevention(), the restore counterpart to PreventSessionHijacking(). Rebuilds a hijackPreventionDecorator from persisted tokenHash + tokenSaltHex + hmacSecret without re-hashing a live token. Returns errors for nil session, missing salt on authenticated sessions, and invalid hex salt. - factory.go: Add RestoreSession() to the MultiSessionFactory interface with full doc comment (backend ID parsing, session hint lookup, routing-table rebuild, hijack-prevention re-application). Add a stub implementation on defaultMultiSessionFactory (returns "not yet implemented"); full reconnection logic is deferred. Document the cross-replica HMAC secret consistency requirement on defaultHMACSecret. - decorating_factory.go: Forward RestoreSession() to the base factory; decorators are not re-applied during restore. - mocks/mock_factory.go: Regenerate mock to include RestoreSession(). - restore_test.go: Unit tests covering nil session, missing salt, invalid hex, anonymous session round-trip, authenticated store→restore→validate round-trip, and cross-replica secret mismatch. Closes #4216. * changes from rebase --------- Co-authored-by: taskbot <taskbot@users.noreply.github.com>
1 parent 37e5bde commit 1f658a2

7 files changed

Lines changed: 407 additions & 34 deletions

File tree

pkg/vmcp/session/connector_integration_test.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth"
2222
"github.com/stacklok/toolhive/pkg/vmcp/auth/strategies"
2323
authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
24+
"github.com/stacklok/toolhive/pkg/vmcp/session/internal/security"
2425
sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types"
2526
)
2627

@@ -347,6 +348,96 @@ func TestTokenBinding_DifferentSecretsProduceDifferentHashes(t *testing.T) {
347348
"different HMAC secrets must produce different token hashes for the same input token")
348349
}
349350

351+
// TestRestoreHijackPrevention_Integration_RoundTrip verifies the full
352+
// store-then-restore flow across a real factory-created session:
353+
//
354+
// 1. Create a session via the factory (writes tokenHash + tokenSalt to metadata).
355+
// 2. Extract the persisted values.
356+
// 3. Wrap a fresh base session with RestoreHijackPrevention using those values.
357+
// 4. Confirm the restored decorator accepts the original token and rejects others.
358+
func TestRestoreHijackPrevention_Integration_RoundTrip(t *testing.T) {
359+
t.Parallel()
360+
361+
const rawToken = "integration-token"
362+
hmacSecret := []byte("test-hmac-secret-exactly-32bytes")
363+
identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}, Token: rawToken}
364+
365+
factory := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret(hmacSecret))
366+
sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil)
367+
require.NoError(t, err)
368+
t.Cleanup(func() { _ = sess.Close() })
369+
370+
// Extract persisted values — these simulate what would be read back from Redis.
371+
meta := sess.GetMetadata()
372+
persistedHash := meta[MetadataKeyTokenHash]
373+
persistedSalt := meta[sessiontypes.MetadataKeyTokenSalt]
374+
require.NotEmpty(t, persistedHash, "factory must write tokenHash to metadata")
375+
require.NotEmpty(t, persistedSalt, "factory must write tokenSalt to metadata")
376+
377+
// Simulate "Pod B": restore the decorator from persisted metadata.
378+
// We use a nil-connector session as the inner session (no real backend needed
379+
// to test auth path).
380+
innerSess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil)
381+
require.NoError(t, err)
382+
t.Cleanup(func() { _ = innerSess.Close() })
383+
384+
restored, err := security.RestoreHijackPrevention(innerSess, persistedHash, persistedSalt, hmacSecret)
385+
require.NoError(t, err)
386+
387+
ctx := context.Background()
388+
389+
// Original caller is accepted.
390+
_, err = restored.CallTool(ctx, identity, "any-tool", nil, nil)
391+
// ErrToolNotFound is expected (no backends), not an auth error.
392+
require.NotErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller)
393+
require.NotErrorIs(t, err, sessiontypes.ErrNilCaller)
394+
395+
// A different caller is rejected at the auth layer — before any backend routing.
396+
wrongCaller := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "eve"}, Token: "eve-token"}
397+
_, err = restored.CallTool(ctx, wrongCaller, "any-tool", nil, nil)
398+
require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller)
399+
400+
// Nil caller is rejected at the auth layer.
401+
_, err = restored.CallTool(ctx, nil, "any-tool", nil, nil)
402+
require.ErrorIs(t, err, sessiontypes.ErrNilCaller)
403+
}
404+
405+
// TestRestoreHijackPrevention_Integration_CrossReplicaSecretMismatch verifies
406+
// that a session restored on a replica with a different HMAC secret rejects
407+
// the original caller's token, documenting the operational requirement that
408+
// all replicas must share the same secret.
409+
func TestRestoreHijackPrevention_Integration_CrossReplicaSecretMismatch(t *testing.T) {
410+
t.Parallel()
411+
412+
secretA := []byte("secret-A-exactly-32-bytes-long!!")
413+
secretB := []byte("secret-B-exactly-32-bytes-long!!")
414+
415+
identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}, Token: "alice-token"}
416+
417+
// Pod A creates the session with secretA, persisting the hash.
418+
factoryA := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret(secretA))
419+
sessA, err := factoryA.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil)
420+
require.NoError(t, err)
421+
t.Cleanup(func() { _ = sessA.Close() })
422+
423+
persistedHash := sessA.GetMetadata()[MetadataKeyTokenHash]
424+
persistedSalt := sessA.GetMetadata()[sessiontypes.MetadataKeyTokenSalt]
425+
426+
// Pod B restores with secretB — the persisted hash was computed with secretA,
427+
// so validation will produce a different HMAC and reject the caller.
428+
factoryB := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret(secretB))
429+
innerSess, err := factoryB.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil)
430+
require.NoError(t, err)
431+
t.Cleanup(func() { _ = innerSess.Close() })
432+
433+
restored, err := security.RestoreHijackPrevention(innerSess, persistedHash, persistedSalt, secretB)
434+
require.NoError(t, err)
435+
436+
_, err = restored.CallTool(context.Background(), identity, "any-tool", nil, nil)
437+
require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller,
438+
"cross-replica secret mismatch must reject the original caller")
439+
}
440+
350441
// TestTokenBinding_MetadataEncoding verifies that the token hash and salt stored
351442
// in session metadata are valid hex strings of the expected lengths:
352443
// - token hash: 64 hex chars (32-byte HMAC-SHA256)

pkg/vmcp/session/decorating_factory.go

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -44,24 +44,7 @@ func (f *decoratingMultiSessionFactory) RestoreSession(
4444
if err != nil {
4545
return nil, err
4646
}
47-
for _, dec := range f.decorators {
48-
var decorated MultiSession
49-
decorated, err = dec(ctx, sess)
50-
if err != nil {
51-
if closeErr := sess.Close(); closeErr != nil {
52-
slog.Warn("failed to close session after decorator error", "error", closeErr)
53-
}
54-
return nil, err
55-
}
56-
if decorated == nil {
57-
if closeErr := sess.Close(); closeErr != nil {
58-
slog.Warn("failed to close session after decorator returned nil", "error", closeErr)
59-
}
60-
return nil, fmt.Errorf("decorator returned nil session without error")
61-
}
62-
sess = decorated
63-
}
64-
return sess, nil
47+
return f.applyDecorators(ctx, sess)
6548
}
6649

6750
func (f *decoratingMultiSessionFactory) MakeSessionWithID(
@@ -75,6 +58,13 @@ func (f *decoratingMultiSessionFactory) MakeSessionWithID(
7558
if err != nil {
7659
return nil, err
7760
}
61+
return f.applyDecorators(ctx, sess)
62+
}
63+
64+
// applyDecorators runs the decorator chain over sess in order, closing sess on
65+
// any error and returning the fully-decorated session on success.
66+
func (f *decoratingMultiSessionFactory) applyDecorators(ctx context.Context, sess MultiSession) (MultiSession, error) {
67+
var err error
7868
for _, dec := range f.decorators {
7969
var decorated MultiSession
8070
decorated, err = dec(ctx, sess)

pkg/vmcp/session/decorating_factory_test.go

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,158 @@ func TestNewDecoratingFactory_HappyPath_ReturnsFinalSession(t *testing.T) {
160160
require.NoError(t, err)
161161
assert.Equal(t, finalSess, got)
162162
}
163+
164+
// ---------------------------------------------------------------------------
165+
// RestoreSession — mirrors the MakeSessionWithID tests above
166+
// ---------------------------------------------------------------------------
167+
168+
func TestRestoreSession_BaseError_Propagated(t *testing.T) {
169+
t.Parallel()
170+
171+
ctrl := gomock.NewController(t)
172+
base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl)
173+
baseErr := errors.New("restore failed")
174+
base.EXPECT().
175+
RestoreSession(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
176+
Return(nil, baseErr)
177+
178+
factory := session.NewDecoratingFactory(base,
179+
func(_ context.Context, s session.MultiSession) (session.MultiSession, error) { return s, nil },
180+
)
181+
182+
_, err := factory.RestoreSession(context.Background(), "id", nil, nil)
183+
require.ErrorIs(t, err, baseErr)
184+
}
185+
186+
func TestRestoreSession_DecoratorsAppliedInOrder(t *testing.T) {
187+
t.Parallel()
188+
189+
ctrl := gomock.NewController(t)
190+
sess := sessionmocks.NewMockMultiSession(ctrl)
191+
base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl)
192+
base.EXPECT().
193+
RestoreSession(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
194+
Return(sess, nil)
195+
196+
var order []int
197+
dec1 := func(_ context.Context, s session.MultiSession) (session.MultiSession, error) {
198+
order = append(order, 1)
199+
return s, nil
200+
}
201+
dec2 := func(_ context.Context, s session.MultiSession) (session.MultiSession, error) {
202+
order = append(order, 2)
203+
return s, nil
204+
}
205+
206+
factory := session.NewDecoratingFactory(base, dec1, dec2)
207+
_, err := factory.RestoreSession(context.Background(), "id", nil, nil)
208+
require.NoError(t, err)
209+
assert.Equal(t, []int{1, 2}, order)
210+
}
211+
212+
func TestRestoreSession_DecoratorError_ClosesSession(t *testing.T) {
213+
t.Parallel()
214+
215+
ctrl := gomock.NewController(t)
216+
sess := sessionmocks.NewMockMultiSession(ctrl)
217+
sess.EXPECT().Close().Return(nil)
218+
219+
base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl)
220+
base.EXPECT().
221+
RestoreSession(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
222+
Return(sess, nil)
223+
224+
decErr := errors.New("decorator boom")
225+
factory := session.NewDecoratingFactory(base, func(_ context.Context, _ session.MultiSession) (session.MultiSession, error) {
226+
return nil, decErr
227+
})
228+
229+
_, err := factory.RestoreSession(context.Background(), "id", nil, nil)
230+
require.ErrorIs(t, err, decErr)
231+
}
232+
233+
func TestRestoreSession_SecondDecoratorError_ClosesCurrentSession(t *testing.T) {
234+
t.Parallel()
235+
236+
ctrl := gomock.NewController(t)
237+
sess := sessionmocks.NewMockMultiSession(ctrl)
238+
wrappedSess := sessionmocks.NewMockMultiSession(ctrl)
239+
// Only the session that is current at the time of failure should be closed.
240+
wrappedSess.EXPECT().Close().Return(nil)
241+
242+
base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl)
243+
base.EXPECT().
244+
RestoreSession(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
245+
Return(sess, nil)
246+
247+
decErr := errors.New("second decorator boom")
248+
dec1 := func(_ context.Context, _ session.MultiSession) (session.MultiSession, error) { return wrappedSess, nil }
249+
dec2 := func(_ context.Context, _ session.MultiSession) (session.MultiSession, error) { return nil, decErr }
250+
251+
factory := session.NewDecoratingFactory(base, dec1, dec2)
252+
_, err := factory.RestoreSession(context.Background(), "id", nil, nil)
253+
require.ErrorIs(t, err, decErr)
254+
}
255+
256+
func TestRestoreSession_NilReturnWithNoError_ClosesSession(t *testing.T) {
257+
t.Parallel()
258+
259+
ctrl := gomock.NewController(t)
260+
sess := sessionmocks.NewMockMultiSession(ctrl)
261+
sess.EXPECT().Close().Return(nil)
262+
263+
base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl)
264+
base.EXPECT().
265+
RestoreSession(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
266+
Return(sess, nil)
267+
268+
factory := session.NewDecoratingFactory(base, func(_ context.Context, _ session.MultiSession) (session.MultiSession, error) {
269+
return nil, nil // buggy decorator
270+
})
271+
272+
_, err := factory.RestoreSession(context.Background(), "id", nil, nil)
273+
require.Error(t, err)
274+
assert.Contains(t, err.Error(), "nil session")
275+
}
276+
277+
func TestRestoreSession_CloseErrorDoesNotSuppressOriginalError(t *testing.T) {
278+
t.Parallel()
279+
280+
ctrl := gomock.NewController(t)
281+
sess := sessionmocks.NewMockMultiSession(ctrl)
282+
sess.EXPECT().Close().Return(errors.New("close failed"))
283+
284+
base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl)
285+
base.EXPECT().
286+
RestoreSession(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
287+
Return(sess, nil)
288+
289+
decErr := errors.New("decorator error")
290+
factory := session.NewDecoratingFactory(base, func(_ context.Context, _ session.MultiSession) (session.MultiSession, error) {
291+
return nil, decErr
292+
})
293+
294+
_, err := factory.RestoreSession(context.Background(), "id", nil, nil)
295+
require.ErrorIs(t, err, decErr)
296+
}
297+
298+
func TestRestoreSession_HappyPath_ReturnsFinalSession(t *testing.T) {
299+
t.Parallel()
300+
301+
ctrl := gomock.NewController(t)
302+
sess := sessionmocks.NewMockMultiSession(ctrl)
303+
finalSess := sessionmocks.NewMockMultiSession(ctrl)
304+
305+
base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl)
306+
base.EXPECT().
307+
RestoreSession(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
308+
Return(sess, nil)
309+
310+
factory := session.NewDecoratingFactory(base,
311+
func(_ context.Context, _ session.MultiSession) (session.MultiSession, error) { return finalSess, nil },
312+
)
313+
314+
got, err := factory.RestoreSession(context.Background(), "id", nil, nil)
315+
require.NoError(t, err)
316+
assert.Equal(t, finalSess, got)
317+
}

pkg/vmcp/session/factory.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ var (
5050
// defaultHMACSecret is the fallback HMAC secret used when WithHMACSecret is not provided.
5151
// WARNING: This is INSECURE and should ONLY be used for testing/development.
5252
// Production deployments MUST provide a secure secret via WithHMACSecret option.
53+
//
54+
// NOTE: In multi-replica deployments, all replicas must use the same HMAC secret,
55+
// injected via the VMCP_SESSION_HMAC_SECRET environment variable. If replicas use
56+
// different secrets, cross-pod token validation will silently reject legitimate
57+
// callers. The default insecure secret must NOT be used in production.
5358
defaultHMACSecret = []byte("insecure-default-for-testing-only-change-in-production")
5459
)
5560

@@ -582,7 +587,7 @@ func (f *defaultMultiSessionFactory) RestoreSession(
582587
return nil, fmt.Errorf("RestoreSession: token hash metadata key absent (corrupted session metadata)")
583588
}
584589
storedSalt := storedMetadata[sessiontypes.MetadataKeyTokenSalt]
585-
restored, err := security.RestoreHijackPrevention(baseSession, f.hmacSecret, storedHash, storedSalt)
590+
restored, err := security.RestoreHijackPrevention(baseSession, storedHash, storedSalt, f.hmacSecret)
586591
if err != nil {
587592
_ = baseSession.Close()
588593
return nil, fmt.Errorf("RestoreSession: failed to restore hijack prevention: %w", err)

pkg/vmcp/session/internal/security/hijack_prevention_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ func TestRestoreHijackPrevention(t *testing.T) {
227227
t.Parallel()
228228

229229
base := newMockSession("s1")
230-
restored, err := RestoreHijackPrevention(base, testSecret, "", "")
230+
restored, err := RestoreHijackPrevention(base, "", "", testSecret)
231231
require.NoError(t, err)
232232
require.NotNil(t, restored)
233233
})
@@ -236,7 +236,7 @@ func TestRestoreHijackPrevention(t *testing.T) {
236236
t.Parallel()
237237

238238
base := newMockSession("s2")
239-
_, err := RestoreHijackPrevention(base, testSecret, "somehash", "")
239+
_, err := RestoreHijackPrevention(base, "somehash", "", testSecret)
240240
require.Error(t, err)
241241
assert.Contains(t, err.Error(), "salt is missing")
242242
})
@@ -245,15 +245,15 @@ func TestRestoreHijackPrevention(t *testing.T) {
245245
t.Parallel()
246246

247247
base := newMockSession("s3")
248-
_, err := RestoreHijackPrevention(base, testSecret, "", "deadbeef")
248+
_, err := RestoreHijackPrevention(base, "", "deadbeef", testSecret)
249249
require.Error(t, err)
250250
assert.Contains(t, err.Error(), "hash is missing")
251251
})
252252

253253
t.Run("nil session is rejected", func(t *testing.T) {
254254
t.Parallel()
255255

256-
_, err := RestoreHijackPrevention(nil, testSecret, "", "")
256+
_, err := RestoreHijackPrevention(nil, "", "", testSecret)
257257
require.Error(t, err)
258258
})
259259
}

0 commit comments

Comments
 (0)