diff --git a/src/SmartTransactionsController.test.ts b/src/SmartTransactionsController.test.ts index 1693560e..8f7e8f85 100644 --- a/src/SmartTransactionsController.test.ts +++ b/src/SmartTransactionsController.test.ts @@ -415,7 +415,7 @@ describe('SmartTransactionsController', () => { await withController( { options: { - supportedChainIds: [ChainId.mainnet], + getSupportedChainIds: () => [ChainId.mainnet], }, }, ({ controller, triggerNetworStateChange }) => { @@ -1141,6 +1141,26 @@ describe('SmartTransactionsController', () => { }); }); + it('fetches liveness using custom getSentinelUrl', async () => { + const customSentinelUrl = 'https://custom-sentinel.example.com'; + await withController( + { + options: { + getSentinelUrl: (_chainId: Hex) => customSentinelUrl, + }, + }, + async ({ controller }) => { + nock(customSentinelUrl) + .get(`/network`) + .reply(200, createSuccessLivenessApiResponse()); + + const liveness = await controller.fetchLiveness(); + + expect(liveness).toBe(true); + }, + ); + }); + it('fetches liveness and sets in feesByChainId state for the Smart Transactions API for the chainId of the networkClientId passed in', async () => { await withController(async ({ controller }) => { nock(SENTINEL_API_BASE_URL_MAP[sepoliaChainIdDec]) @@ -1869,7 +1889,7 @@ describe('SmartTransactionsController', () => { await withController( { options: { - // pending transactions in state are required to test polling + getSupportedChainIds: () => [ChainId.mainnet, ChainId.sepolia], state: { smartTransactionsState: { ...getDefaultSmartTransactionsControllerState() @@ -1989,7 +2009,7 @@ describe('SmartTransactionsController', () => { await withController( { options: { - // pending transactions in state are required to test polling + getSupportedChainIds: () => [ChainId.mainnet], state: { smartTransactionsState: { ...getDefaultSmartTransactionsControllerState() @@ -2166,12 +2186,11 @@ describe('SmartTransactionsController', () => { ); }); - it('removes transactions from the current chainId (even if it is not in supportedChainIds) if ignoreNetwork is false', async () => { + it('removes transactions from the current chainId (even if it is not returned by getSupportedChainIds) if ignoreNetwork is false', async () => { await withController( { options: { - supportedChainIds: [ChainId.sepolia], - chainId: ChainId.mainnet, + getSupportedChainIds: () => [ChainId.sepolia], state: { smartTransactionsState: { ...getDefaultSmartTransactionsControllerState() @@ -2218,11 +2237,11 @@ describe('SmartTransactionsController', () => { ); }); - it('removes transactions from all chains (even if they are not in supportedChainIds) if ignoreNetwork is true', async () => { + it('removes transactions from all chains (even if they are not returned by getSupportedChainIds) if ignoreNetwork is true', async () => { await withController( { options: { - supportedChainIds: [], + getSupportedChainIds: () => [], state: { smartTransactionsState: { ...getDefaultSmartTransactionsControllerState() diff --git a/src/SmartTransactionsController.ts b/src/SmartTransactionsController.ts index b0645eb6..5ced1aad 100644 --- a/src/SmartTransactionsController.ts +++ b/src/SmartTransactionsController.ts @@ -191,7 +191,6 @@ type SmartTransactionsControllerOptions = { interval?: number; clientId: ClientId; chainId?: Hex; - supportedChainIds?: Hex[]; getNonceLock: TransactionController['getNonceLock']; confirmExternalTransaction: TransactionController['confirmExternalTransaction']; trackMetaMetricsEvent: ( @@ -212,6 +211,8 @@ type SmartTransactionsControllerOptions = { getFeatureFlags: () => FeatureFlags; updateTransaction: (transaction: TransactionMeta, note: string) => void; trace?: TraceCallback; + getSentinelUrl?: (chainId: Hex) => string | undefined; + getSupportedChainIds?: () => Hex[]; }; export type SmartTransactionsControllerPollingInput = { @@ -229,8 +230,6 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo #chainId: Hex; - #supportedChainIds: Hex[]; - timeoutHandle?: NodeJS.Timeout; readonly #getNonceLock: SmartTransactionsControllerOptions['getNonceLock']; @@ -253,6 +252,10 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo #trace: TraceCallback; + #getSentinelUrl?: SmartTransactionsControllerOptions['getSentinelUrl']; + + #getSupportedChainIds: () => Hex[]; + /* istanbul ignore next */ async #fetch(request: string, options?: RequestInit) { const fetchOptions = { @@ -270,7 +273,6 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo interval = DEFAULT_INTERVAL, clientId, chainId: InitialChainId = ChainId.mainnet, - supportedChainIds = [ChainId.mainnet, ChainId.sepolia], getNonceLock, confirmExternalTransaction, trackMetaMetricsEvent, @@ -281,6 +283,8 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo getFeatureFlags, updateTransaction, trace, + getSentinelUrl, + getSupportedChainIds = () => [ChainId.mainnet, ChainId.sepolia], }: SmartTransactionsControllerOptions) { super({ name: controllerName, @@ -294,7 +298,6 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo this.#interval = interval; this.#clientId = clientId; this.#chainId = InitialChainId; - this.#supportedChainIds = supportedChainIds; this.setIntervalLength(interval); this.#getNonceLock = getNonceLock; this.#ethQuery = undefined; @@ -305,6 +308,8 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo this.#getFeatureFlags = getFeatureFlags; this.#updateTransaction = updateTransaction; this.#trace = trace ?? (((_request, fn) => fn?.()) as TraceCallback); + this.#getSentinelUrl = getSentinelUrl; + this.#getSupportedChainIds = getSupportedChainIds; this.initializeSmartTransactionsForChainId(); @@ -339,7 +344,7 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo // wondering if we should add some kind of predicate to the polling controller to check whether // we should poll or not const filteredChainIds = (chainIds ?? []).filter((chainId) => - this.#supportedChainIds.includes(chainId), + this.#getSupportedChainIds().includes(chainId), ); if (filteredChainIds.length === 0) { @@ -365,7 +370,7 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo } initializeSmartTransactionsForChainId() { - if (this.#supportedChainIds.includes(this.#chainId)) { + if (this.#getSupportedChainIds().includes(this.#chainId)) { this.update((state) => { state.smartTransactionsState.smartTransactions[this.#chainId] = state.smartTransactionsState.smartTransactions[this.#chainId] ?? []; @@ -380,7 +385,7 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo this.timeoutHandle && clearInterval(this.timeoutHandle); - if (!this.#supportedChainIds.includes(this.#chainId)) { + if (!this.#getSupportedChainIds().includes(this.#chainId)) { return; } @@ -1059,7 +1064,7 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo ); return Object.keys(networkConfigurationsByChainId).filter( (chainId): chainId is Hex => - this.#supportedChainIds.includes(chainId as Hex), + this.#getSupportedChainIds().includes(chainId as Hex), ); } @@ -1122,14 +1127,18 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo const chainId = this.#getChainId({ networkClientId }); let liveness = false; try { + const url = getAPIRequestURL( + APIType.LIVENESS, + chainId, + this.#getSentinelUrl?.(chainId), + ); const response = await this.#trace( { name: SmartTransactionsTraceName.FetchLiveness }, - async () => - await this.#fetch(getAPIRequestURL(APIType.LIVENESS, chainId)), + async () => await this.#fetch(url), ); liveness = Boolean(response.smartTransactions); } catch (error) { - console.log('"fetchLiveness" API call failed'); + console.error('"fetchLiveness" API call failed:', error); } this.update((state) => { diff --git a/src/utils.test.ts b/src/utils.test.ts index 42a8aa46..f13ba0ba 100644 --- a/src/utils.test.ts +++ b/src/utils.test.ts @@ -621,4 +621,28 @@ describe('src/utils.js', () => { expect(updateTransactionMock).not.toHaveBeenCalled(); }); }); + + describe('getSentinelBaseUrl', () => { + it('returns the correct base URL for Ethereum Mainnet', () => { + const chainId = ChainId.mainnet; + const result = utils.getSentinelBaseUrl(chainId); + expect(result).toBe( + SENTINEL_API_BASE_URL_MAP[parseInt(ChainId.mainnet, 16)], + ); + }); + + it('returns the correct base URL for Sepolia', () => { + const chainId = ChainId.sepolia; + const result = utils.getSentinelBaseUrl(chainId); + expect(result).toBe( + SENTINEL_API_BASE_URL_MAP[parseInt(ChainId.sepolia, 16)], + ); + }); + + it('returns undefined for unsupported chainId', () => { + const unsupportedChainId = '0x999'; // Arbitrary unsupported chain + const result = utils.getSentinelBaseUrl(unsupportedChainId); + expect(result).toBeUndefined(); + }); + }); }); diff --git a/src/utils.ts b/src/utils.ts index cabdc178..8c54c2a4 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -36,39 +36,48 @@ export const isSmartTransactionStatusResolved = ( ) => stxStatus === 'uuid_not_found'; // TODO use actual url once API is defined -export function getAPIRequestURL(apiType: APIType, chainId: string): string { +export function getAPIRequestURL( + apiType: APIType, + chainId: string, + sentinelUrl?: string, +): string { const chainIdDec = parseInt(chainId, 16); switch (apiType) { + case APIType.LIVENESS: { + const effectiveSentinelUrl: string | undefined = + sentinelUrl ?? getSentinelBaseUrl(chainId); + + if (effectiveSentinelUrl === undefined) { + throw new Error(`No sentinel URL for chainId ${chainId}`); + } + return `${effectiveSentinelUrl}/network`; + } case APIType.GET_FEES: { return `${API_BASE_URL}/networks/${chainIdDec}/getFees`; } - case APIType.ESTIMATE_GAS: { return `${API_BASE_URL}/networks/${chainIdDec}/estimateGas`; } - case APIType.SUBMIT_TRANSACTIONS: { return `${API_BASE_URL}/networks/${chainIdDec}/submitTransactions?stxControllerVersion=${packageJson.version}`; } - case APIType.CANCEL: { return `${API_BASE_URL}/networks/${chainIdDec}/cancel`; } - case APIType.BATCH_STATUS: { return `${API_BASE_URL}/networks/${chainIdDec}/batchStatus`; } - - case APIType.LIVENESS: { - return `${SENTINEL_API_BASE_URL_MAP[chainIdDec]}/network`; - } - default: { - throw new Error(`Invalid APIType`); // It can never get here thanks to TypeScript. + throw new Error(`Invalid APIType`); } } } +export function getSentinelBaseUrl(chainId: string): string | undefined { + const chainIdDec = parseInt(chainId, 16); + return SENTINEL_API_BASE_URL_MAP[chainIdDec]; +} + export const calculateStatus = (stxStatus: SmartTransactionsStatus) => { if (isSmartTransactionStatusResolved(stxStatus)) { return SmartTransactionStatuses.RESOLVED;