11import { Test , TestingModule } from '@nestjs/testing' ;
22import { HttpException } from '@nestjs/common' ;
3+ import type { Request } from 'express' ;
34import { OAuthController } from './oauth.controller' ;
45import { HybridAuthGuard } from '../../auth/hybrid-auth.guard' ;
56import { PermissionGuard } from '../../auth/permission.guard' ;
@@ -12,10 +13,21 @@ import { OAuthCredentialsService } from '../services/oauth-credentials.service';
1213import { AutoCheckRunnerService } from '../services/auto-check-runner.service' ;
1314import { CloudSecurityService } from '../../cloud-security/cloud-security.service' ;
1415
16+ jest . mock ( '@db' , ( ) => ( {
17+ ...jest . requireActual ( '@prisma/client' ) ,
18+ db : { } ,
19+ } ) ) ;
20+
1521jest . mock ( '../../auth/auth.server' , ( ) => ( {
1622 auth : { api : { getSession : jest . fn ( ) } } ,
1723} ) ) ;
1824
25+ import { auth } from '../../auth/auth.server' ;
26+
27+ const mockedGetSession = auth . api . getSession as jest . MockedFunction <
28+ typeof auth . api . getSession
29+ > ;
30+
1931jest . mock ( '../../auth/hybrid-auth.guard' , ( ) => ( {
2032 HybridAuthGuard : class HybridAuthGuard { } ,
2133} ) ) ;
@@ -306,8 +318,36 @@ describe('OAuthController', () => {
306318 redirect : jest . fn ( ) ,
307319 } as unknown as import ( 'express' ) . Response ;
308320
321+ const buildRequest = ( overrides ?: Partial < Request [ 'headers' ] > ) =>
322+ ( {
323+ headers : {
324+ cookie : 'better-auth.session_token=valid_cookie' ,
325+ ...overrides ,
326+ } ,
327+ } ) as unknown as Request ;
328+
329+ const mockRequest = buildRequest ( ) ;
330+
331+ const setMatchingSession = ( overrides ?: {
332+ userId ?: string ;
333+ activeOrganizationId ?: string | null ;
334+ } ) => {
335+ mockedGetSession . mockResolvedValue ( {
336+ user : { id : overrides ?. userId ?? 'user_1' } ,
337+ session : {
338+ id : 'sess_1' ,
339+ activeOrganizationId :
340+ overrides ?. activeOrganizationId === null
341+ ? undefined
342+ : ( overrides ?. activeOrganizationId ?? 'org_1' ) ,
343+ } ,
344+ } as never ) ;
345+ } ;
346+
309347 beforeEach ( ( ) => {
310348 ( mockResponse . redirect as jest . Mock ) . mockClear ( ) ;
349+ mockedGetSession . mockReset ( ) ;
350+ setMatchingSession ( ) ;
311351 } ) ;
312352
313353 it ( 'should redirect with error when OAuth error is present' , async ( ) => {
@@ -318,6 +358,7 @@ describe('OAuthController', () => {
318358 error : 'access_denied' ,
319359 error_description : 'User denied access' ,
320360 } ,
361+ mockRequest ,
321362 mockResponse ,
322363 ) ;
323364
@@ -327,7 +368,11 @@ describe('OAuthController', () => {
327368 } ) ;
328369
329370 it ( 'should redirect with error when code or state is missing' , async ( ) => {
330- await controller . oauthCallback ( { code : '' , state : '' } , mockResponse ) ;
371+ await controller . oauthCallback (
372+ { code : '' , state : '' } ,
373+ mockRequest ,
374+ mockResponse ,
375+ ) ;
331376
332377 expect ( mockResponse . redirect ) . toHaveBeenCalled ( ) ;
333378 const redirectUrl = ( mockResponse . redirect as jest . Mock ) . mock . calls [ 0 ] [ 0 ] ;
@@ -339,6 +384,7 @@ describe('OAuthController', () => {
339384
340385 await controller . oauthCallback (
341386 { code : 'auth_code' , state : 'invalid_state' } ,
387+ mockRequest ,
342388 mockResponse ,
343389 ) ;
344390
@@ -359,6 +405,7 @@ describe('OAuthController', () => {
359405
360406 await controller . oauthCallback (
361407 { code : 'auth_code' , state : 'expired_state' } ,
408+ mockRequest ,
362409 mockResponse ,
363410 ) ;
364411
@@ -385,6 +432,7 @@ describe('OAuthController', () => {
385432
386433 await controller . oauthCallback (
387434 { code : 'auth_code' , state : 'valid_state' } ,
435+ mockRequest ,
388436 mockResponse ,
389437 ) ;
390438
@@ -396,7 +444,7 @@ describe('OAuthController', () => {
396444 expect ( redirectUrl ) . toContain ( 'error=token_exchange_failed' ) ;
397445 } ) ;
398446
399- it ( 'should trigger initial GCP service discovery scan on successful first connect ' , async ( ) => {
447+ it ( 'should redirect to success URL for GCP without triggering service detection or scan (GCP auto-detection runs after project selection, not after OAuth) ' , async ( ) => {
400448 const futureDate = new Date ( Date . now ( ) + 600000 ) ;
401449 mockOAuthStateRepository . findByState . mockResolvedValue ( {
402450 state : 'valid_gcp_state' ,
@@ -455,23 +503,18 @@ describe('OAuthController', () => {
455503
456504 await controller . oauthCallback (
457505 { code : 'auth_code' , state : 'valid_gcp_state' } ,
506+ mockRequest ,
458507 mockResponse ,
459508 ) ;
460509
461510 await new Promise < void > ( ( resolve ) => setImmediate ( resolve ) ) ;
462511
463- expect ( mockCloudSecurityService . detectServices ) . toHaveBeenCalledWith (
464- 'conn_1' ,
465- 'org_1' ,
466- ) ;
467- expect ( mockedTriggerTask ) . toHaveBeenCalledWith (
512+ // GCP service detection / scan is now triggered AFTER the user picks
513+ // projects on the integrations page, not automatically after OAuth.
514+ expect ( mockCloudSecurityService . detectServices ) . not . toHaveBeenCalled ( ) ;
515+ expect ( mockedTriggerTask ) . not . toHaveBeenCalledWith (
468516 'run-cloud-security-scan' ,
469- {
470- connectionId : 'conn_1' ,
471- organizationId : 'org_1' ,
472- providerSlug : 'gcp' ,
473- connectionName : 'conn_1' ,
474- } ,
517+ expect . anything ( ) ,
475518 ) ;
476519 expect ( mockResponse . redirect ) . toHaveBeenCalled ( ) ;
477520 const redirectUrl = ( mockResponse . redirect as jest . Mock ) . mock . calls [ 0 ] [ 0 ] ;
@@ -546,6 +589,7 @@ describe('OAuthController', () => {
546589
547590 await controller . oauthCallback (
548591 { code : 'auth_code' , state : 'valid_gcp_state' } ,
592+ mockRequest ,
549593 mockResponse ,
550594 ) ;
551595
@@ -558,5 +602,119 @@ describe('OAuthController', () => {
558602
559603 fetchSpy . mockRestore ( ) ;
560604 } ) ;
605+
606+ describe ( 'session defense-in-depth' , ( ) => {
607+ const futureDate = new Date ( Date . now ( ) + 600000 ) ;
608+ const validState = {
609+ state : 'valid_state' ,
610+ providerSlug : 'github' ,
611+ organizationId : 'org_1' ,
612+ userId : 'user_1' ,
613+ codeVerifier : null ,
614+ redirectUrl : null ,
615+ expiresAt : futureDate ,
616+ } ;
617+
618+ it ( 'redirects with session_mismatch when no session cookie/auth header is present' , async ( ) => {
619+ mockOAuthStateRepository . findByState . mockResolvedValue ( validState ) ;
620+ const reqWithoutCookie = {
621+ headers : { } ,
622+ } as unknown as Request ;
623+
624+ await controller . oauthCallback (
625+ { code : 'auth_code' , state : 'valid_state' } ,
626+ reqWithoutCookie ,
627+ mockResponse ,
628+ ) ;
629+
630+ // getSession must not even be called when no auth headers are present
631+ expect ( mockedGetSession ) . not . toHaveBeenCalled ( ) ;
632+ expect ( mockOAuthStateRepository . delete ) . toHaveBeenCalledWith (
633+ 'valid_state' ,
634+ ) ;
635+ const redirectUrl = ( mockResponse . redirect as jest . Mock ) . mock
636+ . calls [ 0 ] [ 0 ] ;
637+ expect ( redirectUrl ) . toContain ( 'error=session_mismatch' ) ;
638+ } ) ;
639+
640+ it ( 'redirects with session_mismatch when getSession returns null' , async ( ) => {
641+ mockOAuthStateRepository . findByState . mockResolvedValue ( validState ) ;
642+ mockedGetSession . mockResolvedValue ( null ) ;
643+
644+ await controller . oauthCallback (
645+ { code : 'auth_code' , state : 'valid_state' } ,
646+ mockRequest ,
647+ mockResponse ,
648+ ) ;
649+
650+ expect ( mockOAuthStateRepository . delete ) . toHaveBeenCalledWith (
651+ 'valid_state' ,
652+ ) ;
653+ const redirectUrl = ( mockResponse . redirect as jest . Mock ) . mock
654+ . calls [ 0 ] [ 0 ] ;
655+ expect ( redirectUrl ) . toContain ( 'error=session_mismatch' ) ;
656+ } ) ;
657+
658+ it ( 'redirects with session_mismatch when session.user.id does not match oauthState.userId' , async ( ) => {
659+ mockOAuthStateRepository . findByState . mockResolvedValue ( validState ) ;
660+ setMatchingSession ( { userId : 'different_user' } ) ;
661+
662+ await controller . oauthCallback (
663+ { code : 'auth_code' , state : 'valid_state' } ,
664+ mockRequest ,
665+ mockResponse ,
666+ ) ;
667+
668+ expect ( mockOAuthStateRepository . delete ) . toHaveBeenCalledWith (
669+ 'valid_state' ,
670+ ) ;
671+ // We do NOT proceed to token exchange when session doesn't match
672+ expect (
673+ mockOAuthCredentialsService . getCredentials ,
674+ ) . not . toHaveBeenCalled ( ) ;
675+ const redirectUrl = ( mockResponse . redirect as jest . Mock ) . mock
676+ . calls [ 0 ] [ 0 ] ;
677+ expect ( redirectUrl ) . toContain ( 'error=session_mismatch' ) ;
678+ } ) ;
679+
680+ it ( 'redirects with session_mismatch when session.activeOrganizationId is set and does not match oauthState.organizationId' , async ( ) => {
681+ mockOAuthStateRepository . findByState . mockResolvedValue ( validState ) ;
682+ setMatchingSession ( { activeOrganizationId : 'org_other' } ) ;
683+
684+ await controller . oauthCallback (
685+ { code : 'auth_code' , state : 'valid_state' } ,
686+ mockRequest ,
687+ mockResponse ,
688+ ) ;
689+
690+ expect ( mockOAuthStateRepository . delete ) . toHaveBeenCalledWith (
691+ 'valid_state' ,
692+ ) ;
693+ const redirectUrl = ( mockResponse . redirect as jest . Mock ) . mock
694+ . calls [ 0 ] [ 0 ] ;
695+ expect ( redirectUrl ) . toContain ( 'error=session_mismatch' ) ;
696+ } ) ;
697+
698+ it ( 'proceeds when session.user.id matches and activeOrganizationId is absent' , async ( ) => {
699+ mockOAuthStateRepository . findByState . mockResolvedValue ( validState ) ;
700+ // Session with userId match but no activeOrganizationId — still allowed,
701+ // since the state itself already binds the organization.
702+ setMatchingSession ( { activeOrganizationId : null } ) ;
703+ mockedGetManifest . mockReturnValue ( undefined as never ) ;
704+
705+ await controller . oauthCallback (
706+ { code : 'auth_code' , state : 'valid_state' } ,
707+ mockRequest ,
708+ mockResponse ,
709+ ) ;
710+
711+ // Session check passed → we reach the manifest lookup, fail there,
712+ // redirect with token_exchange_failed (NOT session_mismatch).
713+ const redirectUrl = ( mockResponse . redirect as jest . Mock ) . mock
714+ . calls [ 0 ] [ 0 ] ;
715+ expect ( redirectUrl ) . toContain ( 'error=token_exchange_failed' ) ;
716+ expect ( redirectUrl ) . not . toContain ( 'error=session_mismatch' ) ;
717+ } ) ;
718+ } ) ;
561719 } ) ;
562720} ) ;
0 commit comments