Skip to content

Commit 348f5da

Browse files
authored
Merge pull request #2564 from trycompai/fix/azure-oauth-token-refresh
fix(cloud-tests): add OAuth token auto-refresh to Azure remediation
2 parents d1295d4 + 3f17765 commit 348f5da

2 files changed

Lines changed: 185 additions & 115 deletions

File tree

apps/api/src/cloud-security/azure-remediation.service.ts

Lines changed: 64 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import { Injectable, Logger } from '@nestjs/common';
22
import { db, Prisma } from '@db';
3+
import { getManifest } from '@trycompai/integration-platform';
34
import { CredentialVaultService } from '../integration-platform/services/credential-vault.service';
5+
import { OAuthCredentialsService } from '../integration-platform/services/oauth-credentials.service';
46
import { AiRemediationService } from './ai-remediation.service';
57
import { AzureSecurityService } from './providers/azure-security.service';
68
import { parseAzurePermissionError } from './remediation-error.utils';
@@ -36,6 +38,7 @@ export class AzureRemediationService {
3638

3739
constructor(
3840
private readonly credentialVaultService: CredentialVaultService,
41+
private readonly oauthCredentialsService: OAuthCredentialsService,
3942
private readonly aiRemediationService: AiRemediationService,
4043
private readonly azureSecurityService: AzureSecurityService,
4144
) {}
@@ -455,31 +458,13 @@ export class AzureRemediationService {
455458
throw new Error('No rollback steps available for this action.');
456459
}
457460

458-
// Get fresh access token
459-
const credentials = await this.resolveCredentials(
461+
// Get fresh access token (auto-refreshes if expired)
462+
const accessToken = await this.getValidAzureToken(
460463
action.connectionId,
461464
action.organizationId,
462465
);
463-
if (!credentials) {
464-
throw new Error('Cannot retrieve Azure credentials for rollback.');
465-
}
466-
467-
// OAuth flow: token from vault; legacy: SP client credentials
468-
let accessToken = credentials.access_token as string | undefined;
469-
if (
470-
!accessToken &&
471-
credentials.tenantId &&
472-
credentials.clientId &&
473-
credentials.clientSecret
474-
) {
475-
accessToken = await this.azureSecurityService.getAccessToken(
476-
credentials.tenantId as string,
477-
credentials.clientId as string,
478-
credentials.clientSecret as string,
479-
);
480-
}
481466
if (!accessToken) {
482-
throw new Error('Cannot obtain Azure access token for rollback.');
467+
throw new Error('Cannot obtain Azure access token for rollback. Please reconnect the integration.');
483468
}
484469

485470
this.logger.log(
@@ -638,6 +623,56 @@ export class AzureRemediationService {
638623

639624
// --- Private helpers ---
640625

626+
/**
627+
* Get a valid Azure access token, refreshing if expired.
628+
*/
629+
private async getValidAzureToken(
630+
connectionId: string,
631+
organizationId: string,
632+
): Promise<string | null> {
633+
const manifest = getManifest('azure');
634+
const oauthConfig = manifest?.auth?.type === 'oauth2' ? manifest.auth.config : null;
635+
636+
if (oauthConfig) {
637+
const oauthCreds = await this.oauthCredentialsService.getCredentials(
638+
'azure',
639+
organizationId,
640+
);
641+
if (oauthCreds) {
642+
const token = await this.credentialVaultService.getValidAccessToken(
643+
connectionId,
644+
{
645+
tokenUrl: oauthConfig.tokenUrl,
646+
clientId: oauthCreds.clientId,
647+
clientSecret: oauthCreds.clientSecret,
648+
clientAuthMethod: oauthConfig.clientAuthMethod,
649+
},
650+
);
651+
if (token) return token;
652+
}
653+
}
654+
655+
// Fallback: try raw credentials (legacy SP or expired token)
656+
const credentials =
657+
await this.credentialVaultService.getDecryptedCredentials(connectionId);
658+
if (!credentials) return null;
659+
660+
if (credentials.access_token) {
661+
return credentials.access_token as string;
662+
}
663+
664+
// Legacy service principal flow
665+
if (credentials.tenantId && credentials.clientId && credentials.clientSecret) {
666+
return this.azureSecurityService.getAccessToken(
667+
credentials.tenantId as string,
668+
credentials.clientId as string,
669+
credentials.clientSecret as string,
670+
);
671+
}
672+
673+
return null;
674+
}
675+
641676
private async resolveCredentials(
642677
connectionId: string,
643678
organizationId: string,
@@ -655,30 +690,16 @@ export class AzureRemediationService {
655690
organizationId: string,
656691
checkResultId: string,
657692
) {
658-
const credentials = await this.resolveCredentials(
659-
connectionId,
660-
organizationId,
661-
);
662-
663-
let accessToken: string | null = null;
664-
// OAuth flow: token from vault
665-
if (credentials?.access_token) {
666-
accessToken = credentials.access_token as string;
667-
}
668-
// Legacy SP flow fallback
669-
if (
670-
!accessToken &&
671-
credentials?.tenantId &&
672-
credentials?.clientId &&
673-
credentials?.clientSecret
674-
) {
675-
accessToken = await this.azureSecurityService.getAccessToken(
676-
credentials.tenantId as string,
677-
credentials.clientId as string,
678-
credentials.clientSecret as string,
679-
);
693+
const connection = await db.integrationConnection.findFirst({
694+
where: { id: connectionId, organizationId, status: 'active' },
695+
include: { provider: true },
696+
});
697+
if (!connection || connection.provider.slug !== 'azure') {
698+
throw new Error('Azure connection not found or not active');
680699
}
681700

701+
const accessToken = await this.getValidAzureToken(connectionId, organizationId);
702+
682703
const checkResult = await db.integrationCheckResult.findFirst({
683704
where: {
684705
id: checkResultId,

apps/api/src/integration-platform/services/credential-vault.service.ts

Lines changed: 121 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -197,14 +197,28 @@ export class CredentialVaultService {
197197
}
198198

199199
/**
200-
* Get decrypted credentials for a connection
200+
* Get decrypted credentials for a connection.
201+
* Prefers the explicitly marked active version, falls back to latest by version number.
201202
*/
202203
async getDecryptedCredentials(
203204
connectionId: string,
204205
): Promise<Record<string, string | string[]> | null> {
205-
const latestVersion =
206-
await this.credentialRepository.findLatestByConnection(connectionId);
207-
if (!latestVersion) return null;
206+
// Prefer the active credential version set during token storage/refresh
207+
const connection = await this.connectionRepository.findById(connectionId);
208+
let version = connection?.activeCredentialVersionId
209+
? await this.credentialRepository.findById(
210+
connection.activeCredentialVersionId,
211+
)
212+
: null;
213+
214+
// Fall back to latest version by version number
215+
if (!version) {
216+
version =
217+
await this.credentialRepository.findLatestByConnection(connectionId);
218+
}
219+
if (!version) return null;
220+
221+
const latestVersion = version;
208222

209223
const encryptedPayload = latestVersion.encryptedPayload as Record<
210224
string,
@@ -297,8 +311,66 @@ export class CredentialVaultService {
297311
}
298312

299313
/**
300-
* Refresh OAuth tokens using the refresh token
301-
* Returns the new access token, or null if refresh failed
314+
* Attempt a single token refresh request to the OAuth provider.
315+
* Returns the new access token on success, or null on failure.
316+
*/
317+
private async attemptTokenRefresh(
318+
connectionId: string,
319+
refreshToken: string,
320+
config: TokenRefreshConfig,
321+
): Promise<{ token?: string; status?: number; errorBody?: string }> {
322+
const body = new URLSearchParams({
323+
grant_type: 'refresh_token',
324+
refresh_token: refreshToken,
325+
});
326+
327+
const headers: Record<string, string> = {
328+
'Content-Type': 'application/x-www-form-urlencoded',
329+
Accept: 'application/json',
330+
};
331+
332+
// Per OAuth 2.0 RFC 6749 Section 2.3.1, when using HTTP Basic auth (header),
333+
// client credentials should NOT be included in the request body
334+
if (config.clientAuthMethod === 'header') {
335+
const credentials = Buffer.from(
336+
`${config.clientId}:${config.clientSecret}`,
337+
).toString('base64');
338+
headers['Authorization'] = `Basic ${credentials}`;
339+
} else {
340+
body.set('client_id', config.clientId);
341+
body.set('client_secret', config.clientSecret);
342+
}
343+
344+
const refreshEndpoint = config.refreshUrl || config.tokenUrl;
345+
const response = await fetch(refreshEndpoint, {
346+
method: 'POST',
347+
headers,
348+
body: body.toString(),
349+
});
350+
351+
if (!response.ok) {
352+
const errorBody = await response.text();
353+
return { status: response.status, errorBody };
354+
}
355+
356+
const tokens: OAuthTokens = await response.json();
357+
358+
const tokensToStore: OAuthTokens = {
359+
access_token: tokens.access_token,
360+
refresh_token: tokens.refresh_token || refreshToken,
361+
token_type: tokens.token_type,
362+
expires_in: tokens.expires_in,
363+
scope: tokens.scope,
364+
};
365+
366+
await this.storeOAuthTokens(connectionId, tokensToStore);
367+
return { token: tokens.access_token };
368+
}
369+
370+
/**
371+
* Refresh OAuth tokens using the refresh token.
372+
* Retries once after a short delay before marking the connection as error.
373+
* Returns the new access token, or null if refresh failed.
302374
*/
303375
async refreshOAuthTokens(
304376
connectionId: string,
@@ -315,76 +387,55 @@ export class CredentialVaultService {
315387
try {
316388
this.logger.log(`Refreshing OAuth tokens for connection ${connectionId}`);
317389

318-
// Build the token request
319-
const body = new URLSearchParams({
320-
grant_type: 'refresh_token',
321-
refresh_token: refreshToken,
322-
});
323-
324-
const headers: Record<string, string> = {
325-
'Content-Type': 'application/x-www-form-urlencoded',
326-
Accept: 'application/json',
327-
};
328-
329-
// Add client credentials based on auth method
330-
// Per OAuth 2.0 RFC 6749 Section 2.3.1, when using HTTP Basic auth (header),
331-
// client credentials should NOT be included in the request body
332-
if (config.clientAuthMethod === 'header') {
333-
const credentials = Buffer.from(
334-
`${config.clientId}:${config.clientSecret}`,
335-
).toString('base64');
336-
headers['Authorization'] = `Basic ${credentials}`;
337-
} else {
338-
// Default: send in body
339-
body.set('client_id', config.clientId);
340-
body.set('client_secret', config.clientSecret);
390+
// First attempt
391+
const first = await this.attemptTokenRefresh(
392+
connectionId,
393+
refreshToken,
394+
config,
395+
);
396+
if (first.token) {
397+
this.logger.log(
398+
`Successfully refreshed OAuth tokens for connection ${connectionId}`,
399+
);
400+
return first.token;
341401
}
342402

343-
// Use refreshUrl if provided, otherwise fall back to tokenUrl
344-
const refreshEndpoint = config.refreshUrl || config.tokenUrl;
345-
346-
const response = await fetch(refreshEndpoint, {
347-
method: 'POST',
348-
headers,
349-
body: body.toString(),
350-
});
403+
// Retry once after 2 seconds for transient failures (rate limits, network blips)
404+
this.logger.warn(
405+
`Token refresh attempt 1 failed for connection ${connectionId}: HTTP ${first.status}${first.errorBody ?? '(no body)'}. Retrying in 2s...`,
406+
);
407+
await new Promise((r) => setTimeout(r, 2000));
351408

352-
if (!response.ok) {
353-
await response.text(); // consume body
354-
this.logger.error(
355-
`Token refresh failed for connection ${connectionId}: ${response.status}`,
409+
const second = await this.attemptTokenRefresh(
410+
connectionId,
411+
refreshToken,
412+
config,
413+
);
414+
if (second.token) {
415+
this.logger.log(
416+
`Successfully refreshed OAuth tokens for connection ${connectionId} on retry`,
356417
);
357-
358-
// If refresh token is invalid/expired, mark connection as error
359-
if (response.status === 400 || response.status === 401) {
360-
await this.connectionRepository.update(connectionId, {
361-
status: 'error',
362-
errorMessage:
363-
'OAuth token expired. Please reconnect the integration.',
364-
});
365-
}
366-
367-
return null;
418+
return second.token;
368419
}
369420

370-
const tokens: OAuthTokens = await response.json();
371-
372-
// Store the new tokens
373-
// Note: Some providers return a new refresh token, some don't
374-
const tokensToStore: OAuthTokens = {
375-
access_token: tokens.access_token,
376-
refresh_token: tokens.refresh_token || refreshToken, // Keep old refresh token if not provided
377-
token_type: tokens.token_type,
378-
expires_in: tokens.expires_in,
379-
scope: tokens.scope,
380-
};
421+
// Both attempts failed — log the full error and mark connection
422+
this.logger.error(
423+
`Token refresh failed for connection ${connectionId} after 2 attempts: HTTP ${second.status}${second.errorBody ?? '(no body)'}`,
424+
);
381425

382-
await this.storeOAuthTokens(connectionId, tokensToStore);
426+
if (
427+
second.status === 400 ||
428+
second.status === 401 ||
429+
second.status === 403
430+
) {
431+
await this.connectionRepository.update(connectionId, {
432+
status: 'error',
433+
errorMessage:
434+
'OAuth token expired. Please reconnect the integration.',
435+
});
436+
}
383437

384-
this.logger.log(
385-
`Successfully refreshed OAuth tokens for connection ${connectionId}`,
386-
);
387-
return tokens.access_token;
438+
return null;
388439
} catch (error) {
389440
this.logger.error(
390441
`Error refreshing tokens for connection ${connectionId}:`,
@@ -402,10 +453,9 @@ export class CredentialVaultService {
402453
connectionId: string,
403454
refreshConfig?: TokenRefreshConfig,
404455
): Promise<string | null> {
405-
// Check if we need to refresh
406-
const needsRefresh = await this.needsRefresh(connectionId);
456+
const shouldRefresh = await this.needsRefresh(connectionId);
407457

408-
if (needsRefresh && refreshConfig) {
458+
if (shouldRefresh && refreshConfig) {
409459
const newToken = await this.refreshOAuthTokens(
410460
connectionId,
411461
refreshConfig,
@@ -416,7 +466,6 @@ export class CredentialVaultService {
416466
// If refresh failed, try to use existing token (might still work briefly)
417467
}
418468

419-
// Get current credentials
420469
const credentials = await this.getDecryptedCredentials(connectionId);
421470
return typeof credentials?.access_token === 'string'
422471
? credentials.access_token

0 commit comments

Comments
 (0)