diff --git a/.github/workflows/build-ios-e2e.yml b/.github/workflows/build-ios-e2e.yml
index 44a894f2f8e..b15e262b24b 100644
--- a/.github/workflows/build-ios-e2e.yml
+++ b/.github/workflows/build-ios-e2e.yml
@@ -2,6 +2,21 @@ name: Build iOS E2E Apps
on:
workflow_call:
+ outputs:
+ app-uploaded:
+ description: 'Whether the app was successfully uploaded'
+ value: ${{ jobs.build-ios-apps.outputs.app-uploaded }}
+ inputs:
+ build_type:
+ description: 'The type of build to perform'
+ required: false
+ default: 'main'
+ type: string
+ metamask_environment:
+ description: 'The environment to build for'
+ required: false
+ default: 'qa'
+ type: string
permissions:
contents: read
@@ -20,8 +35,8 @@ jobs:
XCODE_BUILD_SETTINGS: 'COMPILER_INDEX_STORE_ENABLE=NO'
GITHUB_CI: 'true' # This ensures it's available during pod install
PLATFORM: ios
- METAMASK_ENVIRONMENT: qa
- METAMASK_BUILD_TYPE: main
+ METAMASK_ENVIRONMENT: ${{ inputs.metamask_environment }}
+ METAMASK_BUILD_TYPE: ${{ inputs.build_type }}
IS_TEST: true
E2E: 'true'
IGNORE_BOXLOGS_DEVELOPMENT: true
@@ -226,7 +241,7 @@ jobs:
id: upload-app
uses: actions/upload-artifact@v4
with:
- name: MetaMask.app
+ name: ${{ inputs.build_type }}-${{ inputs.metamask_environment }}-MetaMask.app
path: ios/build/Build/Products/Release-iphonesimulator/MetaMask.app
retention-days: 7
if-no-files-found: error
@@ -239,7 +254,7 @@ jobs:
if: ${{ steps.cache-restore.outputs.cache-hit == 'true' || steps.cache-restore-main.outputs.cache-hit == 'true' }}
uses: actions/upload-artifact@v4
with:
- name: index.js.map
+ name: ${{ inputs.build_type }}-${{ inputs.metamask_environment }}-index.js.map
path: sourcemaps/ios/index.js.map
retention-days: 7
if-no-files-found: error
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 9e0e3a0faec..cd671f1a84f 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -382,6 +382,23 @@ jobs:
}}
secrets: inherit
+ e2e-smoke-tests-ios-flask:
+ name: 'iOS Flask E2E Smoke Tests'
+ if: ${{ github.event_name != 'merge_group' && needs.needs_e2e_build.outputs.ios_changed == 'true' }}
+ permissions:
+ contents: read
+ id-token: write
+ needs: [needs_e2e_build, ios-tests-ready, smart-e2e-selection]
+ uses: ./.github/workflows/run-e2e-smoke-tests-ios-flask.yml
+ with:
+ changed_files: ${{ needs.needs_e2e_build.outputs.changed_files }}
+ selected_tags: >-
+ ${{
+ (needs.smart-e2e-selection.outputs.ai_confidence >= 80 && needs.smart-e2e-selection.outputs.ai_e2e_test_tags) ||
+ '["ALL"]'
+ }}
+ secrets: inherit
+
js-bundle-size-check:
runs-on: ubuntu-latest
steps:
@@ -604,6 +621,7 @@ jobs:
- e2e-smoke-tests-android
- e2e-smoke-tests-ios
- e2e-smoke-tests-android-flask
+ - e2e-smoke-tests-ios-flask
steps:
- run: |
# Check if all non-E2E jobs passed
@@ -629,9 +647,15 @@ jobs:
exit 1
fi
- FLASK_RESULT="${{ needs.e2e-smoke-tests-android-flask.result }}"
- if [[ "$FLASK_RESULT" == "failure" ]] || [[ "$FLASK_RESULT" == "cancelled" ]]; then
- echo "Android Flask E2E tests failed (result: $FLASK_RESULT)"
+ ANDROID_FLASK_RESULT="${{ needs.e2e-smoke-tests-android-flask.result }}"
+ if [[ "$ANDROID_FLASK_RESULT" == "failure" ]] || [[ "$ANDROID_FLASK_RESULT" == "cancelled" ]]; then
+ echo "Android Flask E2E tests failed (result: $ANDROID_FLASK_RESULT)"
+ exit 1
+ fi
+
+ IOS_FLASK_RESULT="${{ needs.e2e-smoke-tests-ios-flask.result }}"
+ if [[ "$IOS_FLASK_RESULT" == "failure" ]] || [[ "$IOS_FLASK_RESULT" == "cancelled" ]]; then
+ echo "iOS Flask E2E tests failed (result: $IOS_FLASK_RESULT)"
exit 1
fi
fi
diff --git a/.github/workflows/run-e2e-smoke-tests-ios-flask.yml b/.github/workflows/run-e2e-smoke-tests-ios-flask.yml
new file mode 100644
index 00000000000..96fbb2cd1ed
--- /dev/null
+++ b/.github/workflows/run-e2e-smoke-tests-ios-flask.yml
@@ -0,0 +1,167 @@
+name: iOS Flask E2E Smoke Tests
+
+on:
+ workflow_call:
+ inputs:
+ selected_tags:
+ description: 'JSON array of selected tags from Smart E2E selection'
+ required: false
+ type: string
+ default: '["ALL"]'
+ changed_files:
+ description: 'Changed files'
+ required: false
+ type: string
+ default: ''
+
+permissions:
+ contents: read
+ id-token: write
+
+jobs:
+ repack-ios-flask-apps:
+ if: contains(fromJson(inputs.selected_tags), 'ALL') || contains(fromJson(inputs.selected_tags), 'FlaskBuildTests')
+ name: 'Repack iOS Flask Apps'
+ runs-on: ghcr.io/cirruslabs/macos-runner:sequoia
+ permissions:
+ contents: read
+ id-token: write
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+
+ - name: Setup iOS Environment
+ timeout-minutes: 15
+ uses: MetaMask/github-tools/.github/actions/setup-e2e-env@v1
+ with:
+ platform: ios
+ setup-simulator: false
+
+ - name: Install dependencies
+ uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 #v3.0.2
+ with:
+ timeout_minutes: 10
+ max_attempts: 3
+ retry_wait_seconds: 30
+ command: yarn install --immutable
+
+ - name: Setup project
+ run: yarn setup:github-ci --no-build-android
+
+ - name: Download Main iOS App artifacts
+ uses: actions/download-artifact@v4
+ with:
+ path: artifacts/
+ pattern: main-*-MetaMask.app
+
+ - name: Setup Main iOS App artifacts
+ run: |
+ mkdir -p ios/build/Build/Products/Release-iphonesimulator/
+ cp -R artifacts/main-qa-MetaMask.app ios/build/Build/Products/Release-iphonesimulator/MetaMask.app
+
+ - name: Repack Main iOS App
+ run: node scripts/repack.js
+ env:
+ PLATFORM: ios
+ METAMASK_ENVIRONMENT: e2e
+ METAMASK_BUILD_TYPE: flask
+ IS_TEST: 'true'
+ E2E: 'true'
+ IGNORE_BOXLOGS_DEVELOPMENT: 'true'
+ GITHUB_CI: 'true'
+ CI: 'true'
+ NODE_OPTIONS: '--max-old-space-size=8192'
+ BRIDGE_USE_DEV_APIS: 'true'
+ RAMP_INTERNAL_BUILD: 'true'
+ SEEDLESS_ONBOARDING_ENABLED: 'true'
+ MM_NOTIFICATIONS_UI_ENABLED: 'true'
+ MM_SECURITY_ALERTS_API_ENABLED: 'true'
+ FEATURES_ANNOUNCEMENTS_ACCESS_TOKEN: ${{ secrets.FEATURES_ANNOUNCEMENTS_ACCESS_TOKEN }}
+ FEATURES_ANNOUNCEMENTS_SPACE_ID: ${{ secrets.FEATURES_ANNOUNCEMENTS_SPACE_ID }}
+ SEGMENT_WRITE_KEY_QA: ${{ secrets.SEGMENT_WRITE_KEY_QA }}
+ SEGMENT_WRITE_KEY_FLASK: ${{ secrets.SEGMENT_WRITE_KEY_FLASK }}
+ SEGMENT_PROXY_URL_QA: ${{ secrets.SEGMENT_PROXY_URL_QA }}
+ SEGMENT_PROXY_URL_FLASK: ${{ secrets.SEGMENT_PROXY_URL_FLASK }}
+ SEGMENT_DELETE_API_SOURCE_ID_QA: ${{ secrets.SEGMENT_DELETE_API_SOURCE_ID_QA }}
+ SEGMENT_DELETE_API_SOURCE_ID_FLASK: ${{ secrets.SEGMENT_DELETE_API_SOURCE_ID_FLASK }}
+ SEGMENT_REGULATIONS_ENDPOINT_QA: ${{ secrets.SEGMENT_REGULATIONS_ENDPOINT_QA }}
+ SEGMENT_REGULATIONS_ENDPOINT_FLASK: ${{ secrets.SEGMENT_REGULATIONS_ENDPOINT_FLASK }}
+ MM_SENTRY_DSN_TEST: ${{ secrets.MM_SENTRY_DSN_TEST }}
+ MM_SENTRY_AUTH_TOKEN: ${{ secrets.MM_SENTRY_AUTH_TOKEN }}
+ MAIN_IOS_GOOGLE_CLIENT_ID_UAT: ${{ secrets.MAIN_IOS_GOOGLE_CLIENT_ID_UAT }}
+ FLASK_IOS_GOOGLE_CLIENT_ID_PROD: ${{ secrets.FLASK_IOS_GOOGLE_CLIENT_ID_PROD }}
+ MAIN_IOS_GOOGLE_REDIRECT_URI_UAT: ${{ secrets.MAIN_IOS_GOOGLE_REDIRECT_URI_UAT }}
+ FLASK_IOS_GOOGLE_REDIRECT_URI_PROD: ${{ secrets.FLASK_IOS_GOOGLE_REDIRECT_URI_PROD }}
+ MAIN_ANDROID_APPLE_CLIENT_ID_UAT: ${{ secrets.MAIN_ANDROID_APPLE_CLIENT_ID_UAT }}
+ FLASK_ANDROID_APPLE_CLIENT_ID_PROD: ${{ secrets.FLASK_ANDROID_APPLE_CLIENT_ID_PROD }}
+ MAIN_ANDROID_GOOGLE_CLIENT_ID_UAT: ${{ secrets.MAIN_ANDROID_GOOGLE_CLIENT_ID_UAT }}
+ FLASK_ANDROID_GOOGLE_CLIENT_ID_PROD: ${{ secrets.FLASK_ANDROID_GOOGLE_CLIENT_ID_PROD }}
+ MAIN_ANDROID_GOOGLE_SERVER_CLIENT_ID_UAT: ${{ secrets.MAIN_ANDROID_GOOGLE_SERVER_CLIENT_ID_UAT }}
+ FLASK_ANDROID_GOOGLE_SERVER_CLIENT_ID_PROD: ${{ secrets.FLASK_ANDROID_GOOGLE_SERVER_CLIENT_ID_PROD }}
+ GOOGLE_SERVICES_B64_IOS: ${{ secrets.GOOGLE_SERVICES_B64_IOS }}
+ GOOGLE_SERVICES_B64_ANDROID: ${{ secrets.GOOGLE_SERVICES_B64_ANDROID }}
+ MM_INFURA_PROJECT_ID: ${{ secrets.MM_INFURA_PROJECT_ID }}
+
+ - name: Upload Repacked iOS Flask App
+ uses: actions/upload-artifact@v4
+ with:
+ name: flask-e2e-MetaMask.app
+ path: ios/build/Build/Products/Release-iphonesimulator/MetaMask.app
+ retention-days: 1
+
+ flask-ios-smoke:
+ if: contains(fromJson(inputs.selected_tags), 'ALL') || contains(fromJson(inputs.selected_tags), 'FlaskBuildTests')
+ needs: [repack-ios-flask-apps]
+ strategy:
+ matrix:
+ split: [1, 2, 3]
+ fail-fast: false
+ uses: ./.github/workflows/run-e2e-workflow.yml
+ with:
+ test-suite-name: flask-ios-smoke-${{ matrix.split }}
+ platform: ios
+ test_suite_tag: 'FlaskBuildTests'
+ split_number: ${{ matrix.split }}
+ total_splits: 3
+ changed_files: ${{ inputs.changed_files }}
+ build_type: 'flask'
+ metamask_environment: 'e2e'
+ secrets: inherit
+
+ report-ios-smoke-tests:
+ name: Report iOS Flask Smoke Tests
+ runs-on: ubuntu-latest
+ if: ${{ !cancelled() && (contains(fromJson(inputs.selected_tags), 'ALL') || contains(fromJson(inputs.selected_tags), 'FlaskBuildTests')) }}
+ needs:
+ - flask-ios-smoke
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+
+ - name: Setup Node.js
+ uses: actions/setup-node@v4
+ with:
+ node-version-file: '.nvmrc'
+
+ - name: Download shards test artifacts (XMLs + Screenshots)
+ uses: actions/download-artifact@v4
+ continue-on-error: true
+ with:
+ path: all-test-artifacts/
+ pattern: 'test-e2e-*-ios-*'
+
+ - name: Post Test Report
+ uses: dorny/test-reporter@dc3a92680fcc15842eef52e8c4606ea7ce6bd3f3
+ with:
+ name: 'iOS Flask E2E Smoke Test Results'
+ path: 'all-test-artifacts/**/junit.xml'
+ reporter: 'jest-junit'
+ fail-on-error: false
+ list-suites: 'failed'
+ list-tests: 'failed'
+
+ - name: Upload all test artifacts (XMLs + Screenshots)
+ uses: actions/upload-artifact@v4
+ with:
+ name: e2e-smoke-ios-flask-all-test-artifacts
+ path: all-test-artifacts/
diff --git a/.github/workflows/run-e2e-smoke-tests-ios.yml b/.github/workflows/run-e2e-smoke-tests-ios.yml
index 4bff0ac6607..841334841cf 100644
--- a/.github/workflows/run-e2e-smoke-tests-ios.yml
+++ b/.github/workflows/run-e2e-smoke-tests-ios.yml
@@ -33,6 +33,8 @@ jobs:
split_number: ${{ matrix.split }}
total_splits: 4
changed_files: ${{ inputs.changed_files }}
+ build_type: 'main'
+ metamask_environment: 'qa'
secrets: inherit
trade-ios-smoke:
@@ -49,6 +51,8 @@ jobs:
split_number: ${{ matrix.split }}
total_splits: 1
changed_files: ${{ inputs.changed_files }}
+ build_type: 'main'
+ metamask_environment: 'qa'
secrets: inherit
perps-ios-smoke:
@@ -65,6 +69,8 @@ jobs:
split_number: ${{ matrix.split }}
total_splits: 1
changed_files: ${{ inputs.changed_files }}
+ build_type: 'main'
+ metamask_environment: 'qa'
secrets: inherit
wallet-platform-ios-smoke:
@@ -81,6 +87,8 @@ jobs:
split_number: ${{ matrix.split }}
total_splits: 2
changed_files: ${{ inputs.changed_files }}
+ build_type: 'main'
+ metamask_environment: 'qa'
secrets: inherit
identity-ios-smoke:
@@ -97,6 +105,8 @@ jobs:
split_number: ${{ matrix.split }}
total_splits: 2
changed_files: ${{ inputs.changed_files }}
+ build_type: 'main'
+ metamask_environment: 'qa'
secrets: inherit
accounts-ios-smoke:
@@ -113,6 +123,8 @@ jobs:
split_number: ${{ matrix.split }}
total_splits: 1
changed_files: ${{ inputs.changed_files }}
+ build_type: 'main'
+ metamask_environment: 'qa'
secrets: inherit
network-abstraction-ios-smoke:
@@ -129,6 +141,8 @@ jobs:
split_number: ${{ matrix.split }}
total_splits: 2
changed_files: ${{ inputs.changed_files }}
+ build_type: 'main'
+ metamask_environment: 'qa'
secrets: inherit
network-expansion-ios-smoke:
@@ -145,6 +159,8 @@ jobs:
split_number: ${{ matrix.split }}
total_splits: 2
changed_files: ${{ inputs.changed_files }}
+ build_type: 'main'
+ metamask_environment: 'qa'
secrets: inherit
prediction-market-ios-smoke:
@@ -161,6 +177,8 @@ jobs:
split_number: ${{ matrix.split }}
total_splits: 1
changed_files: ${{ inputs.changed_files }}
+ build_type: 'main'
+ metamask_environment: 'qa'
secrets: inherit
card-ios-smoke:
@@ -177,6 +195,8 @@ jobs:
split_number: ${{ matrix.split }}
total_splits: 1
changed_files: ${{ inputs.changed_files }}
+ build_type: 'main'
+ metamask_environment: 'qa'
secrets: inherit
ramps-ios-smoke:
@@ -193,6 +213,8 @@ jobs:
split_number: ${{ matrix.split }}
total_splits: 1
changed_files: ${{ inputs.changed_files }}
+ build_type: 'main'
+ metamask_environment: 'qa'
secrets: inherit
multichain-api-ios-smoke:
@@ -209,6 +231,8 @@ jobs:
split_number: ${{ matrix.split }}
total_splits: 1
changed_files: ${{ inputs.changed_files }}
+ build_type: 'main'
+ metamask_environment: 'qa'
secrets: inherit
report-ios-smoke-tests:
diff --git a/.github/workflows/run-e2e-workflow.yml b/.github/workflows/run-e2e-workflow.yml
index d973d640e5c..b09e35fb3e0 100644
--- a/.github/workflows/run-e2e-workflow.yml
+++ b/.github/workflows/run-e2e-workflow.yml
@@ -58,7 +58,7 @@ jobs:
test-apk-target-path: ${{ steps.determine-target-paths.outputs.test-apk-target-path }}
env:
- PREBUILT_IOS_APP_PATH: artifacts/MetaMask.app
+ PREBUILT_IOS_APP_PATH: artifacts/${{ inputs.build_type }}-${{ inputs.metamask_environment }}-MetaMask.app
METAMASK_ENVIRONMENT: ${{ inputs.metamask_environment }}
METAMASK_BUILD_TYPE: ${{ inputs.build_type }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/android/app/build.gradle b/android/app/build.gradle
index 38bfba42ffd..a206ec5276d 100644
--- a/android/app/build.gradle
+++ b/android/app/build.gradle
@@ -188,7 +188,7 @@ android {
minSdkVersion rootProject.ext.minSdkVersion
targetSdkVersion rootProject.ext.targetSdkVersion
versionName "7.65.0"
- versionCode 3418
+ versionCode 3607
testBuildType System.getProperty('testBuildType', 'debug')
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
manifestPlaceholders.MM_BRANCH_KEY_TEST = "$System.env.MM_BRANCH_KEY_TEST"
diff --git a/android/app/src/main/AndroidManifest.xml b/android/app/src/main/AndroidManifest.xml
index bc632e82252..501368a500c 100644
--- a/android/app/src/main/AndroidManifest.xml
+++ b/android/app/src/main/AndroidManifest.xml
@@ -184,5 +184,12 @@
android:resource="@xml/filepaths"
/>
+
+
+
+
diff --git a/app.config.js b/app.config.js
index de9362b11e2..f7baabdd8ed 100644
--- a/app.config.js
+++ b/app.config.js
@@ -89,7 +89,10 @@ module.exports = {
: 'io.metamask', // Required for @expo/repack-app Android repacking
},
ios: {
- bundleIdentifier: 'io.metamask.MetaMask',
+ bundleIdentifier:
+ process.env.METAMASK_BUILD_TYPE === 'flask'
+ ? 'io.metamask.MetaMask-Flask'
+ : 'io.metamask.MetaMask', // Required for @expo/repack-app iOS repacking
usesAppleSignIn: true,
jsEngine: 'hermes',
},
diff --git a/app/components/Base/RemoteImage/index.test.tsx b/app/components/Base/RemoteImage/index.test.tsx
index adbbc29b18d..64eb4941b75 100644
--- a/app/components/Base/RemoteImage/index.test.tsx
+++ b/app/components/Base/RemoteImage/index.test.tsx
@@ -95,7 +95,9 @@ describe('RemoteImage', () => {
// Wait for IPFS URL resolution
});
- expect(wrapper).toMatchSnapshot();
+ await waitFor(() => {
+ expect(wrapper).toMatchSnapshot();
+ });
});
it('renders with Solana network badge when on Solana network', async () => {
@@ -129,7 +131,9 @@ describe('RemoteImage', () => {
// Wait for component to render
});
- expect(wrapper).toMatchSnapshot();
+ await waitFor(() => {
+ expect(wrapper).toMatchSnapshot();
+ });
});
describe('Error State Reset', () => {
@@ -154,7 +158,9 @@ describe('RemoteImage', () => {
jest.runAllTimers();
});
- expect(queryByTestId('remote-image')).toBeOnTheScreen();
+ await waitFor(() => {
+ expect(queryByTestId('remote-image')).toBeOnTheScreen();
+ });
await act(async () => {
rerender(
@@ -166,7 +172,9 @@ describe('RemoteImage', () => {
jest.runAllTimers();
});
- expect(queryByTestId('remote-image')).toBeOnTheScreen();
+ await waitFor(() => {
+ expect(queryByTestId('remote-image')).toBeOnTheScreen();
+ });
});
it('renders Identicon when address is provided', async () => {
@@ -182,7 +190,9 @@ describe('RemoteImage', () => {
jest.runAllTimers();
});
- expect(queryByTestId('remote-image')).toBeOnTheScreen();
+ await waitFor(() => {
+ expect(queryByTestId('remote-image')).toBeOnTheScreen();
+ });
});
it('renders new image after source changes', async () => {
@@ -197,7 +207,9 @@ describe('RemoteImage', () => {
jest.runAllTimers();
});
- expect(queryByTestId('remote-image-1')).toBeOnTheScreen();
+ await waitFor(() => {
+ expect(queryByTestId('remote-image-1')).toBeOnTheScreen();
+ });
await act(async () => {
rerender(
@@ -209,7 +221,9 @@ describe('RemoteImage', () => {
jest.runAllTimers();
});
- expect(queryByTestId('remote-image-2')).toBeOnTheScreen();
+ await waitFor(() => {
+ expect(queryByTestId('remote-image-2')).toBeOnTheScreen();
+ });
});
});
@@ -228,8 +242,10 @@ describe('RemoteImage', () => {
image.props.onError({ error: 'Failed to load image' });
});
- const identicon = await findByTestId('identicon');
- expect(identicon).toBeOnTheScreen();
+ await waitFor(async () => {
+ const identicon = await findByTestId('identicon');
+ expect(identicon).toBeOnTheScreen();
+ });
});
it('calls onError callback when image fails to load', async () => {
@@ -247,7 +263,9 @@ describe('RemoteImage', () => {
image.props.onError({ error: 'Failed to load image' });
});
- expect(mockOnError).toHaveBeenCalledTimes(1);
+ await waitFor(() => {
+ expect(mockOnError).toHaveBeenCalledTimes(1);
+ });
});
it('resets error state when source URI changes', async () => {
@@ -265,8 +283,10 @@ describe('RemoteImage', () => {
});
// After error, Identicon should be rendered
- const identicon = await findByTestId('identicon');
- expect(identicon).toBeOnTheScreen();
+ await waitFor(async () => {
+ const identicon = await findByTestId('identicon');
+ expect(identicon).toBeOnTheScreen();
+ });
await act(async () => {
rerender(
@@ -278,9 +298,11 @@ describe('RemoteImage', () => {
});
// After source change, error should be reset and Image should render
- expect(queryByTestId('identicon')).not.toBeOnTheScreen();
- const image = UNSAFE_getByType(Image);
- expect(image).toBeDefined();
+ await waitFor(() => {
+ expect(queryByTestId('identicon')).not.toBeOnTheScreen();
+ const image = UNSAFE_getByType(Image);
+ expect(image).toBeDefined();
+ });
});
});
@@ -296,18 +318,14 @@ describe('RemoteImage', () => {
);
await waitFor(() => {
- expect(mockGetFormattedIpfsUrl).toHaveBeenCalled();
+ expect(mockGetFormattedIpfsUrl).toHaveBeenCalledWith(
+ expect.any(String),
+ ipfsUri,
+ false,
+ );
+ const image = UNSAFE_getByType(Image);
+ expect(image.props.source.uri).toBe(resolvedUrl);
});
-
- // Verify the function was called with the IPFS URI
- expect(mockGetFormattedIpfsUrl).toHaveBeenCalledWith(
- expect.any(String),
- ipfsUri,
- false,
- );
-
- const image = UNSAFE_getByType(Image);
- expect(image.props.source.uri).toBe(resolvedUrl);
});
it('handles IPFS URL resolution failure', async () => {
@@ -330,8 +348,10 @@ describe('RemoteImage', () => {
// Wait for component to render
});
- const image = UNSAFE_getByType(Image);
- expect(image.props.source.uri).toBe('');
+ await waitFor(() => {
+ const image = UNSAFE_getByType(Image);
+ expect(image.props.source.uri).toBe('');
+ });
});
});
@@ -339,6 +359,7 @@ describe('RemoteImage', () => {
let dimensionsSpy: jest.SpyInstance;
beforeEach(() => {
+ jest.useFakeTimers();
dimensionsSpy = jest.spyOn(Dimensions, 'get').mockReturnValue({
width: 400,
height: 800,
@@ -348,6 +369,9 @@ describe('RemoteImage', () => {
});
afterEach(() => {
+ jest.runOnlyPendingTimers();
+ jest.useRealTimers();
+ jest.restoreAllMocks();
dimensionsSpy.mockRestore();
});
@@ -361,6 +385,9 @@ describe('RemoteImage', () => {
/>,
);
+ // Clear any pending timers from the mock's automatic onLoad
+ jest.clearAllTimers();
+
await act(async () => {
const image = UNSAFE_getByType(Image);
image.props.onLoad({
@@ -368,9 +395,11 @@ describe('RemoteImage', () => {
});
});
- const image = UNSAFE_getByType(Image);
- expect(image.props.style.width).toBe(368);
- expect(image.props.style.height).toBe(184);
+ await waitFor(() => {
+ const image = UNSAFE_getByType(Image);
+ expect(image.props.style.width).toBe(368);
+ expect(image.props.style.height).toBe(184);
+ });
});
it('calculates dimensions for vertical image', async () => {
@@ -383,6 +412,9 @@ describe('RemoteImage', () => {
/>,
);
+ // Clear any pending timers from the mock's automatic onLoad
+ jest.clearAllTimers();
+
await act(async () => {
const image = UNSAFE_getByType(Image);
image.props.onLoad({
@@ -390,9 +422,11 @@ describe('RemoteImage', () => {
});
});
- const image = UNSAFE_getByType(Image);
- expect(image.props.style.width).toBe(138);
- expect(image.props.style.height).toBe(276);
+ await waitFor(() => {
+ const image = UNSAFE_getByType(Image);
+ expect(image.props.style.width).toBe(138);
+ expect(image.props.style.height).toBe(276);
+ });
});
it('calculates dimensions for square image', async () => {
@@ -405,6 +439,9 @@ describe('RemoteImage', () => {
/>,
);
+ // Clear any pending timers from the mock's automatic onLoad
+ jest.clearAllTimers();
+
await act(async () => {
const image = UNSAFE_getByType(Image);
image.props.onLoad({
@@ -412,9 +449,11 @@ describe('RemoteImage', () => {
});
});
- const image = UNSAFE_getByType(Image);
- expect(image.props.style.width).toBe(276);
- expect(image.props.style.height).toBe(276);
+ await waitFor(() => {
+ const image = UNSAFE_getByType(Image);
+ expect(image.props.style.width).toBe(276);
+ expect(image.props.style.height).toBe(276);
+ });
});
it('does not update dimensions when they remain the same', async () => {
@@ -427,6 +466,9 @@ describe('RemoteImage', () => {
/>,
);
+ // Clear any pending timers from the mock's automatic onLoad
+ jest.clearAllTimers();
+
await act(async () => {
const image = UNSAFE_getByType(Image);
image.props.onLoad({
@@ -434,6 +476,11 @@ describe('RemoteImage', () => {
});
});
+ await waitFor(() => {
+ const image = UNSAFE_getByType(Image);
+ expect(image.props.style.width).toBe(276);
+ });
+
const firstImage = UNSAFE_getByType(Image);
const firstStyle = firstImage.props.style;
@@ -444,11 +491,16 @@ describe('RemoteImage', () => {
});
});
- const secondImage = UNSAFE_getByType(Image);
- const secondStyle = secondImage.props.style;
+ await waitFor(() => {
+ const image = UNSAFE_getByType(Image);
+ expect(image.props.style.width).toBe(276);
+
+ const secondImage = UNSAFE_getByType(Image);
+ const secondStyle = secondImage.props.style;
- expect(firstStyle.width).toBe(secondStyle.width);
- expect(firstStyle.height).toBe(secondStyle.height);
+ expect(firstStyle.width).toBe(secondStyle.width);
+ expect(firstStyle.height).toBe(secondStyle.height);
+ });
});
it('handles onLoad without width and height', async () => {
@@ -461,17 +513,26 @@ describe('RemoteImage', () => {
/>,
);
+ // Clear any pending timers from the mock's automatic onLoad
+ jest.clearAllTimers();
+
await act(async () => {
const image = UNSAFE_getByType(Image);
image.props.onLoad({ source: {} });
});
- const image = UNSAFE_getByType(Image);
- expect(image).toBeDefined();
+ await waitFor(() => {
+ const image = UNSAFE_getByType(Image);
+ expect(image).toBeDefined();
+ });
});
});
describe('Rendering Modes', () => {
+ afterEach(() => {
+ jest.restoreAllMocks();
+ });
+
it('renders default image without fadeIn', () => {
const { UNSAFE_getByType } = render(
,
@@ -496,9 +557,11 @@ describe('RemoteImage', () => {
// Wait for component to render
});
- const image = UNSAFE_getByType(Image);
- expect(image).toBeDefined();
- expect(image.props.source.uri).toBe('https://example.com/image.png');
+ await waitFor(() => {
+ const image = UNSAFE_getByType(Image);
+ expect(image).toBeDefined();
+ expect(image.props.source.uri).toBe('https://example.com/image.png');
+ });
});
it('renders token image without full ratio', async () => {
@@ -516,9 +579,11 @@ describe('RemoteImage', () => {
// Wait for component to render
});
- const image = UNSAFE_getByType(Image);
- expect(image).toBeDefined();
- expect(image.props.source.uri).toBe('https://example.com/token.png');
+ await waitFor(() => {
+ const image = UNSAFE_getByType(Image);
+ expect(image).toBeDefined();
+ expect(image.props.source.uri).toBe('https://example.com/token.png');
+ });
});
it('renders token image with full ratio and dimensions', async () => {
@@ -545,9 +610,11 @@ describe('RemoteImage', () => {
});
});
- const image = UNSAFE_getByType(Image);
- expect(image.props.style.width).toBe(368);
- expect(image.props.style.height).toBeCloseTo(245.33, 1);
+ await waitFor(() => {
+ const image = UNSAFE_getByType(Image);
+ expect(image.props.style.width).toBe(368);
+ expect(image.props.style.height).toBeCloseTo(245.33, 1);
+ });
});
it('renders token image with chainId prop', async () => {
@@ -564,9 +631,11 @@ describe('RemoteImage', () => {
// Wait for component to render
});
- const image = UNSAFE_getByType(Image);
- expect(image).toBeDefined();
- expect(image.props.source.uri).toBe('https://example.com/token.png');
+ await waitFor(() => {
+ const image = UNSAFE_getByType(Image);
+ expect(image).toBeDefined();
+ expect(image.props.source.uri).toBe('https://example.com/token.png');
+ });
});
});
});
diff --git a/app/components/UI/Bridge/Views/BridgeView/index.tsx b/app/components/UI/Bridge/Views/BridgeView/index.tsx
index ffad90dad16..0ed5d0b315e 100644
--- a/app/components/UI/Bridge/Views/BridgeView/index.tsx
+++ b/app/components/UI/Bridge/Views/BridgeView/index.tsx
@@ -600,6 +600,7 @@ const BridgeView = () => {
token={sourceToken}
tokenBalance={latestSourceBalance}
onMaxPress={handleSourceMaxPress}
+ isQuoteSponsored={isQuoteSponsored}
/>
) : null}
diff --git a/app/components/UI/Bridge/components/SwapsKeypad/index.test.tsx b/app/components/UI/Bridge/components/SwapsKeypad/index.test.tsx
index 596043042cb..51cfa9b48e4 100644
--- a/app/components/UI/Bridge/components/SwapsKeypad/index.test.tsx
+++ b/app/components/UI/Bridge/components/SwapsKeypad/index.test.tsx
@@ -5,34 +5,17 @@ import { Keys } from '../../../../Base/Keypad';
import { BridgeToken } from '../../types';
import { CHAIN_IDS } from '@metamask/transaction-controller';
import { BigNumber } from 'ethers';
-import { useSelector } from 'react-redux';
-import { useTokenAddress } from '../../hooks/useTokenAddress';
-import { isNativeAddress } from '@metamask/bridge-controller';
// Mock dependencies
-jest.mock('react-redux', () => ({
- useSelector: jest.fn(),
+jest.mock('../../hooks/useShouldRenderMaxOption', () => ({
+ useShouldRenderMaxOption: jest.fn(() => true),
}));
-jest.mock('../../hooks/useTokenAddress', () => ({
- useTokenAddress: jest.fn(),
-}));
-
-jest.mock('@metamask/bridge-controller', () => ({
- isNativeAddress: jest.fn(),
-}));
-
-jest.mock('../../../../../core/redux/slices/bridge', () => ({
- selectIsGaslessSwapEnabled: jest.fn(),
-}));
-
-const mockUseSelector = useSelector as jest.MockedFunction;
-const mockUseTokenAddress = useTokenAddress as jest.MockedFunction<
- typeof useTokenAddress
->;
-const mockIsNativeAddress = isNativeAddress as jest.MockedFunction<
- typeof isNativeAddress
->;
+import { useShouldRenderMaxOption } from '../../hooks/useShouldRenderMaxOption';
+const mockUseShouldRenderMaxOption =
+ useShouldRenderMaxOption as jest.MockedFunction<
+ typeof useShouldRenderMaxOption
+ >;
describe('SwapsKeypad', () => {
const mockOnChange = jest.fn();
@@ -51,10 +34,8 @@ describe('SwapsKeypad', () => {
};
beforeEach(() => {
- jest.clearAllMocks();
- mockUseSelector.mockReturnValue(false);
- mockUseTokenAddress.mockReturnValue(mockToken.address);
- mockIsNativeAddress.mockReturnValue(false);
+ mockUseShouldRenderMaxOption.mockReset();
+ mockUseShouldRenderMaxOption.mockReturnValue(true);
});
afterEach(() => {
@@ -119,8 +100,8 @@ describe('SwapsKeypad', () => {
expect(queryByText('Max')).toBeNull();
});
- it('renders Max button for gasless swap enabled', () => {
- mockUseSelector.mockReturnValue(true);
+ it('renders Max button when useShouldRenderMaxOption returns true', () => {
+ mockUseShouldRenderMaxOption.mockReturnValue(true);
const { getByText, queryByText } = render(
{
expect(getByText('Max')).toBeTruthy();
expect(queryByText('90%')).toBeNull();
+ expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith(
+ mockToken,
+ mockTokenBalance.displayBalance,
+ undefined,
+ );
});
- it('renders Max button for non-native token', () => {
- mockIsNativeAddress.mockReturnValue(false);
+ it('renders 90% button when useShouldRenderMaxOption returns false', () => {
+ mockUseShouldRenderMaxOption.mockReturnValue(false);
const { getByText, queryByText } = render(
{
/>,
);
- expect(getByText('Max')).toBeTruthy();
- expect(queryByText('90%')).toBeNull();
+ expect(getByText('90%')).toBeTruthy();
+ expect(queryByText('Max')).toBeNull();
+ expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith(
+ mockToken,
+ mockTokenBalance.displayBalance,
+ undefined,
+ );
});
- it('renders 90% button for native token without gasless swap', () => {
- mockIsNativeAddress.mockReturnValue(true);
- mockUseSelector.mockImplementation(() => false);
+ it('passes isQuoteSponsored to useShouldRenderMaxOption', () => {
+ mockUseShouldRenderMaxOption.mockReturnValue(true);
- const { getByText, queryByText } = render(
+ const { getByText } = render(
{
token={mockToken}
tokenBalance={mockTokenBalance}
onMaxPress={mockOnMaxPress}
+ isQuoteSponsored
/>,
);
- expect(getByText('90%')).toBeTruthy();
- expect(queryByText('Max')).toBeNull();
+ expect(getByText('Max')).toBeTruthy();
+ expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith(
+ mockToken,
+ mockTokenBalance.displayBalance,
+ true,
+ );
});
});
@@ -280,7 +275,7 @@ describe('SwapsKeypad', () => {
describe('Max button functionality', () => {
it('calls onMaxPress when Max button is clicked', () => {
- mockIsNativeAddress.mockReturnValue(false);
+ mockUseShouldRenderMaxOption.mockReturnValue(true);
const { getByText } = render(
{
expect(mockOnMaxPress).toHaveBeenCalledTimes(1);
});
-
- it('renders Max button when gasless swap is enabled for native token', () => {
- mockIsNativeAddress.mockReturnValue(true);
- mockUseSelector.mockImplementation((selector) => {
- if (typeof selector === 'function') {
- return true;
- }
- return undefined;
- });
-
- const { getByText } = render(
- ,
- );
-
- expect(getByText('Max')).toBeTruthy();
- });
});
describe('edge cases', () => {
@@ -463,12 +434,11 @@ describe('SwapsKeypad', () => {
});
});
- describe('quick pick button selection logic', () => {
- it('selects gasless quick pick options when not native asset', () => {
- mockIsNativeAddress.mockReturnValue(false);
- mockUseSelector.mockImplementation(() => false);
+ describe('Quick pick options with useShouldRenderMaxOption hook', () => {
+ it('shows Max button when useShouldRenderMaxOption returns true', () => {
+ mockUseShouldRenderMaxOption.mockReturnValue(true);
- const { getByText } = render(
+ const { getByText, queryByText } = render(
{
);
expect(getByText('Max')).toBeTruthy();
+ expect(queryByText('90%')).toBeNull();
+ expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith(
+ mockToken,
+ mockTokenBalance.displayBalance,
+ undefined,
+ );
});
- it('selects standard quick pick options when native asset and gasless disabled', () => {
- mockIsNativeAddress.mockReturnValue(true);
- mockUseSelector.mockImplementation(() => false);
+ it('shows 90% button when useShouldRenderMaxOption returns false', () => {
+ mockUseShouldRenderMaxOption.mockReturnValue(false);
- const { getByText } = render(
+ const { getByText, queryByText } = render(
{
);
expect(getByText('90%')).toBeTruthy();
+ expect(queryByText('Max')).toBeNull();
+ expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith(
+ mockToken,
+ mockTokenBalance.displayBalance,
+ undefined,
+ );
});
- it('selects gasless quick pick options when native asset but gasless enabled', () => {
- mockIsNativeAddress.mockReturnValue(true);
- mockUseSelector.mockImplementation((selector) => {
- if (typeof selector === 'function') {
- return true;
- }
- return undefined;
- });
+ it('passes isQuoteSponsored to useShouldRenderMaxOption', () => {
+ mockUseShouldRenderMaxOption.mockReturnValue(true);
const { getByText } = render(
{
token={mockToken}
tokenBalance={mockTokenBalance}
onMaxPress={mockOnMaxPress}
+ isQuoteSponsored
/>,
);
expect(getByText('Max')).toBeTruthy();
+ expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith(
+ mockToken,
+ mockTokenBalance.displayBalance,
+ true,
+ );
});
- });
- describe('token address handling', () => {
- it('uses correct token address from useTokenAddress hook', () => {
- const customAddress = '0xabcdef1234567890abcdef1234567890abcdef12';
- mockUseTokenAddress.mockReturnValue(customAddress);
+ it('hides quick pick buttons when displayBalance is zero', () => {
+ const zeroBalance = {
+ displayBalance: '0',
+ atomicBalance: BigNumber.from('0'),
+ };
+ mockUseShouldRenderMaxOption.mockReturnValue(false);
- render(
+ const { queryByText } = render(
,
);
- expect(mockUseTokenAddress).toHaveBeenCalledWith(mockToken);
- expect(mockIsNativeAddress).toHaveBeenCalledWith(customAddress);
+ expect(queryByText('25%')).toBeNull();
+ expect(queryByText('50%')).toBeNull();
+ expect(queryByText('75%')).toBeNull();
+ expect(queryByText('Max')).toBeNull();
+ expect(queryByText('90%')).toBeNull();
+ expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith(
+ mockToken,
+ zeroBalance.displayBalance,
+ undefined,
+ );
});
- it('handles token address changes correctly', () => {
- const { rerender } = render(
+ it('quick pick buttons calculate correct percentages with Max button', () => {
+ mockUseShouldRenderMaxOption.mockReturnValue(true);
+
+ const { getByText } = render(
{
/>,
);
- const newToken = { ...mockToken, address: '0xnewaddress' };
- mockUseTokenAddress.mockReturnValue(newToken.address);
+ // Test 25% button
+ act(() => {
+ fireEvent.press(getByText('25%'));
+ });
- rerender(
+ expect(mockOnChange).toHaveBeenCalledWith(
+ expect.objectContaining({
+ value: '25.125', // 25% of 100.5
+ valueAsNumber: 25.125,
+ pressedKey: Keys.Initial,
+ }),
+ );
+
+ // Test 50% button
+ act(() => {
+ fireEvent.press(getByText('50%'));
+ });
+
+ expect(mockOnChange).toHaveBeenCalledWith(
+ expect.objectContaining({
+ value: '50.25', // 50% of 100.5
+ valueAsNumber: 50.25,
+ pressedKey: Keys.Initial,
+ }),
+ );
+
+ // Test 75% button
+ act(() => {
+ fireEvent.press(getByText('75%'));
+ });
+
+ expect(mockOnChange).toHaveBeenCalledWith(
+ expect.objectContaining({
+ value: '75.375', // 75% of 100.5
+ valueAsNumber: 75.375,
+ pressedKey: Keys.Initial,
+ }),
+ );
+
+ // Test Max button calls onMaxPress
+ act(() => {
+ fireEvent.press(getByText('Max'));
+ });
+
+ expect(mockOnMaxPress).toHaveBeenCalledTimes(1);
+ });
+
+ it('quick pick buttons calculate correct percentages with 90% button', () => {
+ mockUseShouldRenderMaxOption.mockReturnValue(false);
+
+ const { getByText } = render(
,
);
- expect(mockUseTokenAddress).toHaveBeenCalledWith(newToken);
+ // Test 90% button
+ act(() => {
+ fireEvent.press(getByText('90%'));
+ });
+
+ expect(mockOnChange).toHaveBeenCalledWith(
+ expect.objectContaining({
+ value: '90.45', // 90% of 100.5
+ valueAsNumber: 90.45,
+ pressedKey: Keys.Initial,
+ }),
+ );
});
});
});
diff --git a/app/components/UI/Bridge/components/SwapsKeypad/index.tsx b/app/components/UI/Bridge/components/SwapsKeypad/index.tsx
index 1dfc0afb457..bc2b9aa28c6 100644
--- a/app/components/UI/Bridge/components/SwapsKeypad/index.tsx
+++ b/app/components/UI/Bridge/components/SwapsKeypad/index.tsx
@@ -3,15 +3,11 @@ import Keypad, { KeypadChangeData, Keys } from '../../../../Base/Keypad';
import { Box } from '../../../Box/Box';
import { swapsKeypadStyles as styles } from './styles';
import { QuickPickButtons } from './QuickPickButtons';
-import { useSelector } from 'react-redux';
-import { RootState } from '../../../../../reducers';
import { BridgeToken } from '../../types';
-import { selectIsGaslessSwapEnabled } from '../../../../../core/redux/slices/bridge';
import { QuickPickButtonOption } from './types';
-import { useTokenAddress } from '../../hooks/useTokenAddress';
-import { isNativeAddress } from '@metamask/bridge-controller';
import { useLatestBalance } from '../../hooks/useLatestBalance';
import { BigNumber } from 'bignumber.js';
+import { useShouldRenderMaxOption } from '../../hooks/useShouldRenderMaxOption';
interface SwapsKeypadProps {
value: string;
@@ -21,6 +17,7 @@ interface SwapsKeypadProps {
token?: BridgeToken;
tokenBalance: ReturnType;
onMaxPress: () => void;
+ isQuoteSponsored?: boolean;
}
export const SwapsKeypad = ({
@@ -31,17 +28,8 @@ export const SwapsKeypad = ({
token,
tokenBalance,
onMaxPress,
+ isQuoteSponsored,
}: SwapsKeypadProps) => {
- const tokenAddress = useTokenAddress(token);
- const isNativeAsset = useMemo(
- () => isNativeAddress(tokenAddress),
- [tokenAddress],
- );
-
- const isGaslessSwapEnabled = useSelector((state: RootState) =>
- token?.chainId ? selectIsGaslessSwapEnabled(state, token.chainId) : false,
- );
-
const onQuickOptionPress = useCallback(
(percentage: number) => () => {
if (!tokenBalance?.displayBalance) return '0';
@@ -105,14 +93,18 @@ export const SwapsKeypad = ({
);
const shouldHideQuickPickOptions = useMemo(
- () => new BigNumber(tokenBalance?.displayBalance ?? 0).eq(0),
+ () => new BigNumber(tokenBalance?.displayBalance || 0).eq(0),
[tokenBalance],
);
- const quickPickOptions =
- !isNativeAsset || isGaslessSwapEnabled
- ? gasslessQuickPickOptions
- : standardQuickPickOptions;
+ const shouldRenderMaxOption = useShouldRenderMaxOption(
+ token,
+ tokenBalance?.displayBalance,
+ isQuoteSponsored,
+ );
+ const quickPickOptions = shouldRenderMaxOption
+ ? gasslessQuickPickOptions
+ : standardQuickPickOptions;
return (
diff --git a/app/components/UI/Bridge/components/TokenInputArea/TokenInputArea.test.tsx b/app/components/UI/Bridge/components/TokenInputArea/TokenInputArea.test.tsx
index d463fd7a2f4..a2405b9e6e8 100644
--- a/app/components/UI/Bridge/components/TokenInputArea/TokenInputArea.test.tsx
+++ b/app/components/UI/Bridge/components/TokenInputArea/TokenInputArea.test.tsx
@@ -11,6 +11,16 @@ jest.mock('../../hooks/useLatestBalance', () => ({
useLatestBalance: jest.fn(),
}));
+jest.mock('../../hooks/useShouldRenderMaxOption', () => ({
+ useShouldRenderMaxOption: jest.fn(() => true),
+}));
+
+import { useShouldRenderMaxOption } from '../../hooks/useShouldRenderMaxOption';
+const mockUseShouldRenderMaxOption =
+ useShouldRenderMaxOption as jest.MockedFunction<
+ typeof useShouldRenderMaxOption
+ >;
+
const mockOnTokenPress = jest.fn();
const mockOnFocus = jest.fn();
const mockOnBlur = jest.fn();
@@ -20,6 +30,7 @@ const mockOnMaxPress = jest.fn();
describe('TokenInputArea', () => {
beforeEach(() => {
jest.clearAllMocks();
+ mockUseShouldRenderMaxOption.mockReturnValue(true);
});
it('renders with initial state', () => {
@@ -174,6 +185,9 @@ describe('TokenInputArea', () => {
},
};
+ // Mock hook to return false since gasless is disabled for native token
+ mockUseShouldRenderMaxOption.mockReturnValue(false);
+
const { queryByText } = renderScreen(
() => (
{
},
};
+ // Mock hook to return false since gasless is disabled for native Polygon token
+ mockUseShouldRenderMaxOption.mockReturnValue(false);
+
const { queryByText } = renderScreen(
() => (
{
// Native tokens show Max button when gasless swaps are enabled
expect(getByText('Max')).toBeTruthy();
});
+
+ describe('Max button visibility with useShouldRenderMaxOption hook', () => {
+ const nativeToken: BridgeToken = {
+ address: '0x0000000000000000000000000000000000000000',
+ symbol: 'ETH',
+ decimals: 18,
+ chainId: '0x1' as `0x${string}`,
+ };
+ const tokenBalance = '1.5';
+
+ beforeEach(() => {
+ mockUseShouldRenderMaxOption.mockReturnValue(true);
+ });
+
+ afterEach(() => {
+ mockUseShouldRenderMaxOption.mockReturnValue(true);
+ });
+
+ it('does not display max button when useShouldRenderMaxOption returns false', () => {
+ mockUseShouldRenderMaxOption.mockReturnValue(false);
+
+ const { queryByText } = renderScreen(
+ () => (
+
+ ),
+ {
+ name: 'TokenInputArea',
+ },
+ { state: initialState },
+ );
+
+ expect(queryByText('Max')).toBeNull();
+ expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith(
+ nativeToken,
+ tokenBalance,
+ false,
+ );
+ });
+
+ it('displays max button when useShouldRenderMaxOption returns true', () => {
+ mockUseShouldRenderMaxOption.mockReturnValue(true);
+
+ const { getByText } = renderScreen(
+ () => (
+
+ ),
+ {
+ name: 'TokenInputArea',
+ },
+ { state: initialState },
+ );
+
+ expect(getByText('Max')).toBeTruthy();
+ expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith(
+ nativeToken,
+ tokenBalance,
+ false,
+ );
+ });
+
+ it('passes isQuoteSponsored to useShouldRenderMaxOption', () => {
+ mockUseShouldRenderMaxOption.mockReturnValue(true);
+
+ renderScreen(
+ () => (
+
+ ),
+ {
+ name: 'TokenInputArea',
+ },
+ { state: initialState },
+ );
+
+ expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith(
+ nativeToken,
+ tokenBalance,
+ true,
+ );
+ });
+
+ it('does not display max button for destination token', () => {
+ const { queryByText } = renderScreen(
+ () => (
+
+ ),
+ {
+ name: 'TokenInputArea',
+ },
+ { state: initialState },
+ );
+
+ // Destination tokens never show max button
+ expect(queryByText('Max')).toBeNull();
+ });
+ });
});
describe('getDisplayAmount', () => {
diff --git a/app/components/UI/Bridge/components/TokenInputArea/index.tsx b/app/components/UI/Bridge/components/TokenInputArea/index.tsx
index c2cd91d30f2..209559b6ae0 100644
--- a/app/components/UI/Bridge/components/TokenInputArea/index.tsx
+++ b/app/components/UI/Bridge/components/TokenInputArea/index.tsx
@@ -1,4 +1,4 @@
-import React, { forwardRef, useImperativeHandle, useRef, useMemo } from 'react';
+import React, { forwardRef, useImperativeHandle, useRef } from 'react';
import {
StyleSheet,
ImageSourcePropType,
@@ -32,9 +32,7 @@ import { useNavigation } from '@react-navigation/native';
import {
setDestTokenExchangeRate,
setSourceTokenExchangeRate,
- selectIsGaslessSwapEnabled,
} from '../../../../../core/redux/slices/bridge';
-import { RootState } from '../../../../../reducers';
///: BEGIN:ONLY_INCLUDE_IF(keyring-snaps)
import { selectMultichainAssetsRates } from '../../../../../selectors/multichain';
///: END:ONLY_INCLUDE_IF(keyring-snaps)
@@ -48,6 +46,7 @@ import { isNativeAddress } from '@metamask/bridge-controller';
import { Theme } from '../../../../../util/theme/models';
import parseAmount from '../../../../../util/parseAmount';
import { useTokenAddress } from '../../hooks/useTokenAddress';
+import { useShouldRenderMaxOption } from '../../hooks/useShouldRenderMaxOption';
import { calculateInputFontSize } from '../../utils/calculateInputFontSize';
import { formatAmountWithLocaleSeparators } from '../../utils/formatAmountWithLocaleSeparators';
@@ -184,10 +183,6 @@ export const TokenInputArea = forwardRef<
) => {
const currentCurrency = useSelector(selectCurrentCurrency);
- const isGaslessSwapEnabled = useSelector((state: RootState) =>
- token?.chainId ? selectIsGaslessSwapEnabled(state, token.chainId) : false,
- );
-
// Need to fetch the exchange rate for the token if we don't have it already
useBridgeExchangeRates({
token,
@@ -266,11 +261,12 @@ export const TokenInputArea = forwardRef<
const isNativeAsset = isNativeAddress(tokenAddress);
- // Show max button for native tokens if gasless swap is enabled OR quote is sponsored
- const shouldShowMaxButton = useMemo(() => {
- if (!isNativeAsset) return true; // Always show for non-native tokens
- return isGaslessSwapEnabled || isQuoteSponsored;
- }, [isNativeAsset, isGaslessSwapEnabled, isQuoteSponsored]);
+ const shouldShowMaxButton = useShouldRenderMaxOption(
+ token,
+ tokenBalance,
+ isQuoteSponsored,
+ );
+
const formattedAddress =
tokenAddress && !isNativeAsset ? formatAddress(tokenAddress) : undefined;
@@ -359,7 +355,6 @@ export const TokenInputArea = forwardRef<
{
+ const isGaslessSwapEnabled = useSelector((state: RootState) =>
+ token?.chainId ? selectIsGaslessSwapEnabled(state, token.chainId) : false,
+ );
+ const stxEnabled = useSelector((state: RootState) =>
+ token?.chainId && !isNonEvmChainId(token.chainId)
+ ? selectShouldUseSmartTransaction(
+ state,
+ formatChainIdToHex(token.chainId),
+ )
+ : false,
+ );
+ const tokenAddress = useTokenAddress(token);
+ const isNativeAsset = isNativeAddress(tokenAddress);
+ const isZeroDisplayBalance = new BigNumber(displayBalance || 0).eq(0);
+
+ // Do not render on zero balance or undefined token
+ if (isZeroDisplayBalance || !token) {
+ return false;
+ }
+
+ // Always show for non-native tokens
+ if (!isNativeAsset) {
+ return true;
+ }
+
+ // Show for EVM native tokens if gasless swap is enabled OR quote is sponsored
+ // while smart transactions is enabled.
+ // For non-EVM native tokens stxEnabled will be false evaluating the whole
+ // expression to false. We do not know the fees beforehand so we cannot
+ // max out the input amount.
+ return (isGaslessSwapEnabled || isQuoteSponsored) && stxEnabled;
+};
diff --git a/app/components/UI/Bridge/hooks/useShouldRenderMaxOption/useShouldRenderMaxOption.test.ts b/app/components/UI/Bridge/hooks/useShouldRenderMaxOption/useShouldRenderMaxOption.test.ts
new file mode 100644
index 00000000000..68d0c3cd747
--- /dev/null
+++ b/app/components/UI/Bridge/hooks/useShouldRenderMaxOption/useShouldRenderMaxOption.test.ts
@@ -0,0 +1,594 @@
+import { renderHook } from '@testing-library/react-hooks';
+import { useShouldRenderMaxOption } from '.';
+import { BridgeToken } from '../../types';
+import { CHAIN_IDS } from '@metamask/transaction-controller';
+import { useSelector } from 'react-redux';
+import { useTokenAddress } from '../useTokenAddress';
+import { isNativeAddress } from '@metamask/bridge-controller';
+
+// Mock dependencies
+jest.mock('react-redux', () => ({
+ useSelector: jest.fn(),
+}));
+
+jest.mock('../useTokenAddress', () => ({
+ useTokenAddress: jest.fn(),
+}));
+
+jest.mock('@metamask/bridge-controller', () => ({
+ isNativeAddress: jest.fn(),
+}));
+
+jest.mock('../../../../../core/redux/slices/bridge', () => ({
+ selectIsGaslessSwapEnabled: jest.fn(),
+}));
+
+jest.mock('../../../../../selectors/smartTransactionsController', () => ({
+ selectShouldUseSmartTransaction: jest.fn(),
+}));
+
+const mockUseSelector = useSelector as jest.MockedFunction;
+const mockUseTokenAddress = useTokenAddress as jest.MockedFunction<
+ typeof useTokenAddress
+>;
+const mockIsNativeAddress = isNativeAddress as jest.MockedFunction<
+ typeof isNativeAddress
+>;
+
+/**
+ * IMPORTANT: useSelector call order in the hook:
+ * 1. First call: isGaslessSwapEnabled (line 12 in hook)
+ * 2. Second call: stxEnabled (line 15 in hook)
+ */
+describe('useShouldRenderMaxOption', () => {
+ const mockToken: BridgeToken = {
+ address: '0x1234567890123456789012345678901234567890',
+ symbol: 'TEST',
+ decimals: 18,
+ chainId: CHAIN_IDS.MAINNET,
+ };
+
+ const nativeToken: BridgeToken = {
+ address: '0x0000000000000000000000000000000000000000',
+ symbol: 'ETH',
+ decimals: 18,
+ chainId: CHAIN_IDS.MAINNET,
+ };
+
+ beforeEach(() => {
+ jest.clearAllMocks();
+ mockUseTokenAddress.mockReturnValue(mockToken.address);
+ mockIsNativeAddress.mockReturnValue(false);
+ // Default: isGaslessSwapEnabled = false, stxEnabled = true
+ let callCount = 0;
+ mockUseSelector.mockImplementation(() => {
+ callCount++;
+ // First call: isGaslessSwapEnabled = false
+ // Second call: stxEnabled = true
+ return callCount === 2;
+ });
+ });
+
+ describe('Zero balance scenarios', () => {
+ it('returns false when displayBalance is undefined', () => {
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(mockToken, undefined, false),
+ );
+
+ expect(result.current).toBe(false);
+ });
+
+ it('returns false when displayBalance is "0"', () => {
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(mockToken, '0', false),
+ );
+
+ expect(result.current).toBe(false);
+ });
+
+ it('returns false when displayBalance is "0.0"', () => {
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(mockToken, '0.0', false),
+ );
+
+ expect(result.current).toBe(false);
+ });
+
+ it('returns false when displayBalance is empty string', () => {
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(mockToken, '', false),
+ );
+
+ expect(result.current).toBe(false);
+ });
+ });
+
+ describe('Non-native token scenarios', () => {
+ beforeEach(() => {
+ mockIsNativeAddress.mockReturnValue(false);
+ mockUseTokenAddress.mockReturnValue(mockToken.address);
+ });
+
+ it('returns true for non-native token with balance regardless of gasless', () => {
+ mockUseSelector.mockImplementation(() => false); // Both selectors false
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(mockToken, '100.5', false),
+ );
+
+ expect(result.current).toBe(true);
+ });
+
+ it('returns true for non-native token with balance regardless of stxEnabled', () => {
+ mockUseSelector.mockImplementation(() => false); // stxEnabled = false
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(mockToken, '50', false),
+ );
+
+ expect(result.current).toBe(true);
+ });
+
+ it('returns true for non-native token with balance regardless of isQuoteSponsored', () => {
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(mockToken, '25.75', true),
+ );
+
+ expect(result.current).toBe(true);
+ });
+
+ it('returns true for non-native token with very small balance', () => {
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(mockToken, '0.000001', false),
+ );
+
+ expect(result.current).toBe(true);
+ });
+
+ it('returns false for non-native token with zero balance', () => {
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(mockToken, '0', false),
+ );
+
+ expect(result.current).toBe(false);
+ });
+ });
+
+ describe('Native token scenarios', () => {
+ beforeEach(() => {
+ mockIsNativeAddress.mockReturnValue(true);
+ mockUseTokenAddress.mockReturnValue(nativeToken.address);
+ });
+
+ it('returns false when native token has zero balance', () => {
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(nativeToken, '0', false),
+ );
+
+ expect(result.current).toBe(false);
+ });
+
+ it('returns false when native token has zero balance even with all conditions favorable', () => {
+ mockIsNativeAddress.mockReturnValue(true);
+ mockUseTokenAddress.mockReturnValue(nativeToken.address);
+ mockUseSelector.mockReturnValue(true); // stxEnabled=true, gasless=true
+
+ const { result } = renderHook(
+ () => useShouldRenderMaxOption(nativeToken, '0', true), // sponsored=true
+ );
+
+ // Zero balance always returns false
+ expect(result.current).toBe(false);
+ });
+ });
+
+ describe('Edge cases', () => {
+ it('returns false when token is undefined', () => {
+ mockUseTokenAddress.mockReturnValue(undefined);
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(undefined, '100', false),
+ );
+
+ expect(result.current).toBe(false);
+ });
+
+ it('returns false when token is undefined but has balance', () => {
+ mockUseTokenAddress.mockReturnValue(undefined);
+ mockIsNativeAddress.mockReturnValue(false);
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(undefined, '500', false),
+ );
+
+ expect(result.current).toBe(false);
+ });
+
+ it('handles large balance values correctly for non-native tokens', () => {
+ mockIsNativeAddress.mockReturnValue(false);
+ mockUseTokenAddress.mockReturnValue(mockToken.address);
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(mockToken, '1000000.123456789', false),
+ );
+
+ expect(result.current).toBe(true);
+ });
+
+ it('handles very small but non-zero balance for non-native tokens', () => {
+ mockIsNativeAddress.mockReturnValue(false);
+ mockUseTokenAddress.mockReturnValue(mockToken.address);
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(mockToken, '0.000000001', false),
+ );
+
+ expect(result.current).toBe(true);
+ });
+
+ it('correctly identifies native token without chainId in token object', () => {
+ const tokenWithoutChainId = {
+ address: '0x0000000000000000000000000000000000000000',
+ symbol: 'ETH',
+ decimals: 18,
+ } as BridgeToken;
+
+ mockIsNativeAddress.mockReturnValue(true);
+ mockUseTokenAddress.mockReturnValue(tokenWithoutChainId.address);
+ mockUseSelector.mockReturnValue(false); // stxEnabled = false, isGaslessSwapEnabled = false
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(tokenWithoutChainId, '100', false),
+ );
+
+ // Should return false (native + no gasless + no sponsored + no stx)
+ expect(result.current).toBe(false);
+ });
+ });
+
+ describe('Hook parameter validation', () => {
+ it('uses useTokenAddress hook to get token address', () => {
+ const customToken = {
+ ...mockToken,
+ address: '0xabcdef1234567890abcdef1234567890abcdef12',
+ };
+ mockUseTokenAddress.mockReturnValue(customToken.address);
+
+ renderHook(() => useShouldRenderMaxOption(customToken, '100', false));
+
+ expect(mockUseTokenAddress).toHaveBeenCalledWith(customToken);
+ });
+
+ it('checks if token address is native using isNativeAddress', () => {
+ const tokenAddress = '0x1234567890123456789012345678901234567890';
+ mockUseTokenAddress.mockReturnValue(tokenAddress);
+ mockIsNativeAddress.mockReturnValue(false);
+
+ renderHook(() => useShouldRenderMaxOption(mockToken, '100', false));
+
+ expect(mockIsNativeAddress).toHaveBeenCalledWith(tokenAddress);
+ });
+
+ it('calls selectIsGaslessSwapEnabled with correct chainId', () => {
+ const tokenWithChainId = {
+ ...mockToken,
+ chainId: '0xa' as `0x${string}`, // Optimism
+ };
+ mockUseSelector.mockReturnValue(true);
+
+ renderHook(() =>
+ useShouldRenderMaxOption(tokenWithChainId, '100', false),
+ );
+
+ // Verify useSelector was called (it uses selectIsGaslessSwapEnabled)
+ expect(mockUseSelector).toHaveBeenCalled();
+ });
+ });
+
+ describe('Default parameter values', () => {
+ it('uses false as default for isQuoteSponsored when not provided', () => {
+ mockIsNativeAddress.mockReturnValue(true);
+ mockUseTokenAddress.mockReturnValue(nativeToken.address);
+ let callCount = 0;
+ mockUseSelector.mockImplementation(() => {
+ callCount++;
+ // First call: isGaslessSwapEnabled = false
+ // Second call: stxEnabled = true
+ return callCount === 2;
+ });
+
+ // Call without isQuoteSponsored parameter
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(nativeToken, '100'),
+ );
+
+ // Should return false (native + stx=true + gasless=false + sponsored=false)
+ expect(result.current).toBe(false);
+ });
+ });
+
+ describe('Integration scenarios', () => {
+ it('returns correct value for typical non-native ERC20 token', () => {
+ const usdcToken: BridgeToken = {
+ address: '0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48',
+ symbol: 'USDC',
+ decimals: 6,
+ chainId: CHAIN_IDS.MAINNET,
+ };
+
+ mockIsNativeAddress.mockReturnValue(false);
+ mockUseTokenAddress.mockReturnValue(usdcToken.address);
+ mockUseSelector.mockReturnValue(false); // Everything disabled
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(usdcToken, '1000', false),
+ );
+
+ // Non-native tokens always show max (even with everything disabled)
+ expect(result.current).toBe(true);
+ });
+
+ it('returns correct value for native ETH with gasless swap enabled', () => {
+ mockIsNativeAddress.mockReturnValue(true);
+ mockUseTokenAddress.mockReturnValue(nativeToken.address);
+ mockUseSelector.mockReturnValue(true); // stxEnabled=true, gasless=true
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(nativeToken, '5.5', false),
+ );
+
+ expect(result.current).toBe(true);
+ });
+
+ it('returns correct value for native ETH with sponsored quote', () => {
+ mockIsNativeAddress.mockReturnValue(true);
+ mockUseTokenAddress.mockReturnValue(nativeToken.address);
+ let callCount = 0;
+ mockUseSelector.mockImplementation(() => {
+ callCount++;
+ // First call: isGaslessSwapEnabled = false
+ // Second call: stxEnabled = true
+ return callCount === 2;
+ });
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(nativeToken, '2.25', true),
+ );
+
+ expect(result.current).toBe(true);
+ });
+
+ it('returns correct value for native ETH without gasless or sponsored but stx enabled', () => {
+ mockIsNativeAddress.mockReturnValue(true);
+ mockUseTokenAddress.mockReturnValue(nativeToken.address);
+ let callCount = 0;
+ mockUseSelector.mockImplementation(() => {
+ callCount++;
+ // First call: isGaslessSwapEnabled = false
+ // Second call: stxEnabled = true
+ return callCount === 2;
+ });
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(nativeToken, '10', false),
+ );
+
+ // Should be false (native + stx=true + gasless=false + sponsored=false)
+ expect(result.current).toBe(false);
+ });
+ });
+
+ describe('Boundary conditions', () => {
+ it('handles balance exactly equal to zero', () => {
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(mockToken, '0.00000', false),
+ );
+
+ expect(result.current).toBe(false);
+ });
+
+ it('handles extremely small but non-zero balance', () => {
+ mockIsNativeAddress.mockReturnValue(false);
+ mockUseTokenAddress.mockReturnValue(mockToken.address);
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(mockToken, '0.00000001', false),
+ );
+
+ expect(result.current).toBe(true);
+ });
+
+ it('handles extremely large balance', () => {
+ mockIsNativeAddress.mockReturnValue(false);
+ mockUseTokenAddress.mockReturnValue(mockToken.address);
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(
+ mockToken,
+ '999999999999999999999999999.999999',
+ false,
+ ),
+ );
+
+ expect(result.current).toBe(true);
+ });
+
+ it('handles balance with many decimal places', () => {
+ mockIsNativeAddress.mockReturnValue(false);
+ mockUseTokenAddress.mockReturnValue(mockToken.address);
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(
+ mockToken,
+ '123.456789012345678901234567',
+ false,
+ ),
+ );
+
+ expect(result.current).toBe(true);
+ });
+ });
+
+ describe('Truth table - Native token with all combinations', () => {
+ beforeEach(() => {
+ mockIsNativeAddress.mockReturnValue(true);
+ mockUseTokenAddress.mockReturnValue(nativeToken.address);
+ });
+
+ const truthTable = [
+ { stxEnabled: true, gasless: true, sponsored: false, expected: true },
+ { stxEnabled: true, gasless: false, sponsored: true, expected: true },
+ { stxEnabled: true, gasless: true, sponsored: true, expected: true },
+ { stxEnabled: true, gasless: false, sponsored: false, expected: false },
+ { stxEnabled: false, gasless: true, sponsored: false, expected: false },
+ { stxEnabled: false, gasless: false, sponsored: true, expected: false },
+ { stxEnabled: false, gasless: true, sponsored: true, expected: false },
+ { stxEnabled: false, gasless: false, sponsored: false, expected: false },
+ ];
+
+ truthTable.forEach(({ stxEnabled, gasless, sponsored, expected }) => {
+ it(`stxEnabled=${stxEnabled}, gasless=${gasless}, sponsored=${sponsored} → returns ${expected}`, () => {
+ let callCount = 0;
+ mockUseSelector.mockImplementation(() => {
+ callCount++;
+ // First call: isGaslessSwapEnabled (line 12 in hook)
+ // Second call: stxEnabled (line 15 in hook)
+ if (callCount === 1) {
+ return gasless;
+ }
+ return stxEnabled;
+ });
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(nativeToken, '100', sponsored),
+ );
+
+ expect(result.current).toBe(expected);
+ });
+ });
+ });
+
+ describe('Non-EVM native token scenarios', () => {
+ const solanaToken: BridgeToken = {
+ address: '0x0000000000000000000000000000000000000000',
+ symbol: 'SOL',
+ decimals: 9,
+ chainId: 'solana:5eykt4UsFv8P8NJdTREpY1vzqKqZKvdp', // Solana mainnet CAIP-2
+ };
+
+ const bitcoinToken: BridgeToken = {
+ address: '0x0000000000000000000000000000000000000000',
+ symbol: 'BTC',
+ decimals: 8,
+ chainId: 'bip122:000000000019d6689c085ae165831e93', // Bitcoin mainnet CAIP-2
+ };
+
+ beforeEach(() => {
+ mockIsNativeAddress.mockReturnValue(true);
+ });
+
+ it('returns false for Solana native token even with gasless enabled', () => {
+ mockUseTokenAddress.mockReturnValue(solanaToken.address);
+ let callCount = 0;
+ mockUseSelector.mockImplementation(() => {
+ callCount++;
+ // First call: isGaslessSwapEnabled = true
+ // Second call: stxEnabled = false (non-EVM chain)
+ if (callCount === 1) {
+ return true; // gasless enabled
+ }
+ return false; // stxEnabled is false for non-EVM
+ });
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(solanaToken, '100', false),
+ );
+
+ // Should return false because stxEnabled is false for non-EVM chains
+ expect(result.current).toBe(false);
+ });
+
+ it('returns false for Solana native token even with sponsored quote', () => {
+ mockUseTokenAddress.mockReturnValue(solanaToken.address);
+ mockUseSelector.mockImplementation(
+ () =>
+ // First call: isGaslessSwapEnabled = false
+ // Second call: stxEnabled = false (non-EVM chain)
+ false,
+ );
+
+ const { result } = renderHook(
+ () => useShouldRenderMaxOption(solanaToken, '50', true), // sponsored = true
+ );
+
+ // Should return false because stxEnabled is false for non-EVM chains
+ expect(result.current).toBe(false);
+ });
+
+ it('returns false for Solana native token with both gasless and sponsored enabled', () => {
+ mockUseTokenAddress.mockReturnValue(solanaToken.address);
+ let callCount = 0;
+ mockUseSelector.mockImplementation(() => {
+ callCount++;
+ // First call: isGaslessSwapEnabled = true
+ // Second call: stxEnabled = false (non-EVM chain)
+ if (callCount === 1) {
+ return true; // gasless enabled
+ }
+ return false; // stxEnabled is false for non-EVM
+ });
+
+ const { result } = renderHook(
+ () => useShouldRenderMaxOption(solanaToken, '25.5', true), // sponsored = true
+ );
+
+ // Should return false because stxEnabled is false for non-EVM chains
+ expect(result.current).toBe(false);
+ });
+
+ it('returns false for Bitcoin native token with gasless enabled', () => {
+ mockUseTokenAddress.mockReturnValue(bitcoinToken.address);
+ let callCount = 0;
+ mockUseSelector.mockImplementation(() => {
+ callCount++;
+ // First call: isGaslessSwapEnabled = true
+ // Second call: stxEnabled = false (non-EVM chain)
+ if (callCount === 1) {
+ return true; // gasless enabled
+ }
+ return false; // stxEnabled is false for non-EVM
+ });
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(bitcoinToken, '1.5', false),
+ );
+
+ // Should return false because stxEnabled is false for non-EVM chains
+ expect(result.current).toBe(false);
+ });
+
+ it('returns false for Bitcoin native token without any flags enabled', () => {
+ mockUseTokenAddress.mockReturnValue(bitcoinToken.address);
+ mockUseSelector.mockReturnValue(false); // All flags disabled
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(bitcoinToken, '0.5', false),
+ );
+
+ // Should return false because stxEnabled is false for non-EVM chains
+ expect(result.current).toBe(false);
+ });
+
+ it('returns false for non-EVM native token with zero balance', () => {
+ mockUseTokenAddress.mockReturnValue(solanaToken.address);
+ mockUseSelector.mockReturnValue(false);
+
+ const { result } = renderHook(() =>
+ useShouldRenderMaxOption(solanaToken, '0', false),
+ );
+
+ // Should return false due to zero balance (checked before non-EVM logic)
+ expect(result.current).toBe(false);
+ });
+ });
+});
diff --git a/app/components/UI/DeepLinkModal/DeepLinkModal.test.tsx b/app/components/UI/DeepLinkModal/DeepLinkModal.test.tsx
index 815c030a967..4e89e9f1d01 100644
--- a/app/components/UI/DeepLinkModal/DeepLinkModal.test.tsx
+++ b/app/components/UI/DeepLinkModal/DeepLinkModal.test.tsx
@@ -143,7 +143,7 @@ describe('DeepLinkModal', () => {
}),
build: jest.fn().mockImplementation(function (this: MockBuilder) {
return {
- name: 'Deep link Used',
+ name: 'Deep Link Used',
properties: {
route: 'invalid',
was_app_installed: true,
diff --git a/app/components/UI/Perps/Views/PerpsOrderBookView/PerpsOrderBookView.test.tsx b/app/components/UI/Perps/Views/PerpsOrderBookView/PerpsOrderBookView.test.tsx
index 003e0dacd12..37e6c063e7f 100644
--- a/app/components/UI/Perps/Views/PerpsOrderBookView/PerpsOrderBookView.test.tsx
+++ b/app/components/UI/Perps/Views/PerpsOrderBookView/PerpsOrderBookView.test.tsx
@@ -119,6 +119,22 @@ jest.mock('../../hooks/stream/usePerpsLiveOrderBook', () => ({
usePerpsLiveOrderBook: (params: unknown) => mockUsePerpsLiveOrderBook(params),
}));
+// Mock usePerpsLivePrices for live price updates in header (TAT-2441)
+const mockUsePerpsLivePrices = jest.fn<
+ Record,
+ [unknown]
+>(() => ({
+ BTC: {
+ price: '50000',
+ percentChange24h: '2.5',
+ symbol: 'BTC',
+ },
+}));
+
+jest.mock('../../hooks/stream/usePerpsLivePrices', () => ({
+ usePerpsLivePrices: (params: unknown) => mockUsePerpsLivePrices(params),
+}));
+
// Mock usePerpsMeasurement
jest.mock('../../hooks/usePerpsMeasurement', () => ({
usePerpsMeasurement: jest.fn(),
@@ -246,7 +262,7 @@ jest.mock('../../components/PerpsMarketHeader', () => {
jest.mock(
'../../../../../component-library/components/BottomSheets/BottomSheet',
() => {
- const { View } = jest.requireActual('react-native');
+ const { View, TouchableOpacity, Text } = jest.requireActual('react-native');
const ReactMock = jest.requireActual('react');
return {
__esModule: true,
@@ -257,7 +273,17 @@ jest.mock(
onClose?: () => void;
},
_ref: unknown,
- ) => {props.children},
+ ) => (
+
+ {props.children}
+
+ Backdrop
+
+
+ ),
),
};
},
@@ -286,12 +312,23 @@ jest.mock('../PerpsSelectModifyActionView', () => {
jest.mock(
'../../components/PerpsBottomSheetTooltip/PerpsBottomSheetTooltip',
() => {
- const { View } = jest.requireActual('react-native');
+ const { View, TouchableOpacity, Text } = jest.requireActual('react-native');
return {
__esModule: true,
- default: (props: { testID?: string; isVisible?: boolean }) =>
+ default: (props: {
+ testID?: string;
+ isVisible?: boolean;
+ onClose?: () => void;
+ }) =>
props.isVisible ? (
-
+
+
+ Close
+
+
) : null,
};
},
@@ -321,6 +358,13 @@ describe('PerpsOrderBookView', () => {
bestAsk: '50001',
spread: '1.00000',
});
+ mockUsePerpsLivePrices.mockReturnValue({
+ BTC: {
+ price: '50000',
+ percentChange24h: '2.5',
+ symbol: 'BTC',
+ },
+ });
});
describe('rendering', () => {
@@ -1070,6 +1114,17 @@ describe('PerpsOrderBookView', () => {
);
});
+ it('subscribes to live prices for header price display (TAT-2441)', () => {
+ renderWithProvider(, { state: initialState });
+
+ expect(mockUsePerpsLivePrices).toHaveBeenCalledWith(
+ expect.objectContaining({
+ symbols: ['BTC'],
+ throttleMs: 1000,
+ }),
+ );
+ });
+
it('subscribes to live order book with MAX_ORDER_BOOK_LEVELS (20) for server-side aggregation', () => {
renderWithProvider(, { state: initialState });
@@ -1139,5 +1194,293 @@ describe('PerpsOrderBookView', () => {
getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER),
).toBeOnTheScreen();
});
+
+ it('falls back to static market price when live price is not available', () => {
+ mockUsePerpsLivePrices.mockReturnValue({});
+
+ const { getByTestId } = renderWithProvider(, {
+ state: initialState,
+ });
+
+ expect(
+ getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER),
+ ).toBeOnTheScreen();
+ });
+
+ it('falls back to static market price when live price is invalid (NaN)', () => {
+ mockUsePerpsLivePrices.mockReturnValue({
+ BTC: {
+ price: 'invalid-price',
+ percentChange24h: '2.5',
+ symbol: 'BTC',
+ },
+ });
+
+ const { getByTestId } = renderWithProvider(, {
+ state: initialState,
+ });
+
+ expect(
+ getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER),
+ ).toBeOnTheScreen();
+ });
+
+ it('falls back to static market price when live price is zero or negative', () => {
+ mockUsePerpsLivePrices.mockReturnValue({
+ BTC: {
+ price: '0',
+ percentChange24h: '2.5',
+ symbol: 'BTC',
+ },
+ });
+
+ const { getByTestId } = renderWithProvider(, {
+ state: initialState,
+ });
+
+ expect(
+ getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER),
+ ).toBeOnTheScreen();
+ });
+
+ it('handles invalid topOfBook values gracefully', () => {
+ mockUsePerpsTopOfBook.mockReturnValue({
+ bestBid: '0',
+ bestAsk: '-1',
+ spread: '0',
+ });
+
+ const { getByTestId } = renderWithProvider(, {
+ state: initialState,
+ });
+
+ expect(
+ getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER),
+ ).toBeOnTheScreen();
+ });
+
+ it('handles non-finite topOfBook values', () => {
+ mockUsePerpsTopOfBook.mockReturnValue({
+ bestBid: 'NaN',
+ bestAsk: 'Infinity',
+ spread: '0',
+ });
+
+ const { getByTestId } = renderWithProvider(, {
+ state: initialState,
+ });
+
+ expect(
+ getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER),
+ ).toBeOnTheScreen();
+ });
+
+ it('syncs saved grouping when it loads after mount', async () => {
+ const { usePerpsOrderBookGrouping } = jest.requireMock(
+ '../../hooks/usePerpsOrderBookGrouping',
+ );
+
+ // First render with undefined savedGrouping
+ usePerpsOrderBookGrouping.mockReturnValue({
+ savedGrouping: undefined,
+ saveGrouping: mockSaveGrouping,
+ });
+
+ const { rerender, getByTestId } = renderWithProvider(
+ ,
+ {
+ state: initialState,
+ },
+ );
+
+ expect(
+ getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER),
+ ).toBeOnTheScreen();
+
+ // Simulate savedGrouping loading
+ usePerpsOrderBookGrouping.mockReturnValue({
+ savedGrouping: 5,
+ saveGrouping: mockSaveGrouping,
+ });
+
+ rerender();
+
+ expect(
+ getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER),
+ ).toBeOnTheScreen();
+ });
+ });
+
+ describe('spread tooltip', () => {
+ it('opens spread tooltip when info button is pressed', async () => {
+ const { getByTestId } = renderWithProvider(, {
+ state: initialState,
+ });
+
+ const spreadInfoButton = getByTestId(
+ PerpsOrderBookViewSelectorsIDs.SPREAD_INFO_BUTTON,
+ );
+
+ fireEvent.press(spreadInfoButton);
+
+ await waitFor(() => {
+ expect(
+ getByTestId(PerpsOrderBookViewSelectorsIDs.BOTTOM_SHEET_TOOLTIP),
+ ).toBeOnTheScreen();
+ });
+ });
+
+ it('closes spread tooltip when onClose is called', async () => {
+ const { getByTestId, queryByTestId } = renderWithProvider(
+ ,
+ {
+ state: initialState,
+ },
+ );
+
+ // Open tooltip
+ const spreadInfoButton = getByTestId(
+ PerpsOrderBookViewSelectorsIDs.SPREAD_INFO_BUTTON,
+ );
+ fireEvent.press(spreadInfoButton);
+
+ await waitFor(() => {
+ expect(
+ getByTestId(PerpsOrderBookViewSelectorsIDs.BOTTOM_SHEET_TOOLTIP),
+ ).toBeOnTheScreen();
+ });
+
+ // Close the tooltip via the close button in the mock
+ const closeButton = getByTestId(
+ `${PerpsOrderBookViewSelectorsIDs.BOTTOM_SHEET_TOOLTIP}-close`,
+ );
+ fireEvent.press(closeButton);
+
+ await waitFor(() => {
+ expect(
+ queryByTestId(PerpsOrderBookViewSelectorsIDs.BOTTOM_SHEET_TOOLTIP),
+ ).toBeNull();
+ });
+ });
+ });
+
+ describe('depth band sheet close', () => {
+ it('closes depth band sheet via onClose callback', async () => {
+ const { getByTestId, queryByText } = renderWithProvider(
+ ,
+ {
+ state: initialState,
+ },
+ );
+
+ // Open the depth band sheet
+ const depthBandButton = getByTestId(
+ PerpsOrderBookViewSelectorsIDs.DEPTH_BAND_BUTTON,
+ );
+ fireEvent.press(depthBandButton);
+
+ await waitFor(() => {
+ expect(queryByText('Depth Band')).toBeOnTheScreen();
+ });
+
+ // Close via backdrop (onClose callback)
+ const backdrop = getByTestId('bottom-sheet-backdrop');
+ fireEvent.press(backdrop);
+
+ await waitFor(() => {
+ expect(queryByText('Depth Band')).toBeNull();
+ });
+ });
+ });
+
+ describe('geo-block modal interactions', () => {
+ it('closes geo-block modal when onClose is called', async () => {
+ const { selectPerpsEligibility } = jest.requireMock(
+ '../../selectors/perpsController',
+ );
+ selectPerpsEligibility.mockReturnValue(false);
+
+ // Ensure no existing position so Long/Short buttons are shown
+ const { useHasExistingPosition } = jest.requireMock(
+ '../../hooks/useHasExistingPosition',
+ );
+ useHasExistingPosition.mockReturnValue({
+ isLoading: false,
+ existingPosition: null,
+ });
+
+ const { getByTestId, queryByTestId } = renderWithProvider(
+ ,
+ {
+ state: initialState,
+ },
+ );
+
+ // Trigger geo-block modal by pressing Long button when not eligible
+ const longButton = getByTestId(
+ PerpsOrderBookViewSelectorsIDs.LONG_BUTTON,
+ );
+ fireEvent.press(longButton);
+
+ // Verify modal is shown
+ const geoBlockTooltipId = `${PerpsOrderBookViewSelectorsIDs.CONTAINER}-geo-block-tooltip`;
+ expect(queryByTestId(geoBlockTooltipId)).toBeOnTheScreen();
+
+ // Close the modal via the close button
+ const closeButton = getByTestId(`${geoBlockTooltipId}-close`);
+ fireEvent.press(closeButton);
+
+ await waitFor(() => {
+ expect(queryByTestId(geoBlockTooltipId)).toBeNull();
+ });
+ });
+ });
+
+ describe('grouping with no market price', () => {
+ it('returns null grouping when market price is unavailable', () => {
+ const { usePerpsMarkets } = jest.requireMock('../../hooks');
+ usePerpsMarkets.mockReturnValue({
+ markets: [
+ {
+ symbol: 'BTC',
+ price: null,
+ leverage: 50,
+ },
+ ],
+ isLoading: false,
+ error: null,
+ });
+
+ const { getByTestId } = renderWithProvider(, {
+ state: initialState,
+ });
+
+ expect(
+ getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER),
+ ).toBeOnTheScreen();
+ });
+
+ it('handles invalid market price format', () => {
+ const { usePerpsMarkets } = jest.requireMock('../../hooks');
+ usePerpsMarkets.mockReturnValue({
+ markets: [
+ {
+ symbol: 'BTC',
+ price: 'invalid',
+ leverage: 50,
+ },
+ ],
+ isLoading: false,
+ error: null,
+ });
+
+ const { getByTestId } = renderWithProvider(, {
+ state: initialState,
+ });
+
+ expect(
+ getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER),
+ ).toBeOnTheScreen();
+ });
});
});
diff --git a/app/components/UI/Perps/Views/PerpsOrderBookView/PerpsOrderBookView.tsx b/app/components/UI/Perps/Views/PerpsOrderBookView/PerpsOrderBookView.tsx
index e420cbffe1d..a170292f743 100644
--- a/app/components/UI/Perps/Views/PerpsOrderBookView/PerpsOrderBookView.tsx
+++ b/app/components/UI/Perps/Views/PerpsOrderBookView/PerpsOrderBookView.tsx
@@ -66,6 +66,7 @@ import {
} from '../../hooks';
import { useHasExistingPosition } from '../../hooks/useHasExistingPosition';
import { usePerpsLiveOrderBook } from '../../hooks/stream/usePerpsLiveOrderBook';
+import { usePerpsLivePrices } from '../../hooks/stream/usePerpsLivePrices';
import { usePerpsTopOfBook } from '../../hooks/stream/usePerpsTopOfBook';
import { usePerpsEventTracking } from '../../hooks/usePerpsEventTracking';
import { usePerpsMeasurement } from '../../hooks/usePerpsMeasurement';
@@ -203,6 +204,27 @@ const PerpsOrderBookView: React.FC = ({
// This is intentionally independent from order book aggregation/grouping.
const topOfBook = usePerpsTopOfBook({ symbol: symbol || '' });
+ // Subscribe to live price updates for header display (TAT-2441)
+ // This ensures the price in the header updates in real-time
+ const livePrices = usePerpsLivePrices({
+ symbols: symbol ? [symbol] : [],
+ throttleMs: 1000,
+ });
+
+ // Current price for header - use live price with fallback to static market price
+ const currentPrice = useMemo(() => {
+ const priceData = livePrices[symbol || ''];
+ if (priceData?.price) {
+ const parsed = parseFloat(priceData.price);
+ // Validate parsed value - fallback to marketPrice if invalid
+ if (Number.isFinite(parsed) && parsed > 0) {
+ return parsed;
+ }
+ }
+ // Fallback to static market price if live price not available or invalid
+ return marketPrice ?? 0;
+ }, [livePrices, symbol, marketPrice]);
+
const spreadMetrics = useMemo(() => {
const bidStr = topOfBook?.bestBid;
const askStr = topOfBook?.bestAsk;
@@ -470,7 +492,7 @@ const PerpsOrderBookView: React.FC = ({
) : (
@@ -504,7 +526,7 @@ const PerpsOrderBookView: React.FC = ({
)}
diff --git a/app/components/UI/Perps/providers/PerpsStreamManager.test.tsx b/app/components/UI/Perps/providers/PerpsStreamManager.test.tsx
index 435c084d28c..5d6bdd70783 100644
--- a/app/components/UI/Perps/providers/PerpsStreamManager.test.tsx
+++ b/app/components/UI/Perps/providers/PerpsStreamManager.test.tsx
@@ -84,6 +84,9 @@ describe('PerpsStreamManager', () => {
subscribeToPositions: mockSubscribeToPositions,
subscribeToAccount: mockSubscribeToAccount,
isCurrentlyReinitializing: jest.fn().mockReturnValue(false),
+ getMarkets: jest
+ .fn()
+ .mockResolvedValue([{ name: 'BTC-PERP' }, { name: 'ETH-PERP' }]),
} as unknown as typeof mockEngine.context.PerpsController;
// Mock AccountTreeController for getEvmAccountFromSelectedAccountGroup
@@ -940,6 +943,241 @@ describe('PerpsStreamManager', () => {
});
});
+ describe('PriceStreamChannel.prewarm non-blocking behavior', () => {
+ it('returns immediately without waiting for getMarkets', async () => {
+ // Create a promise that we can control
+ let resolveGetMarkets: (value: { name: string }[]) => void = jest.fn();
+ const getMarketsPromise = new Promise<{ name: string }[]>((resolve) => {
+ resolveGetMarkets = resolve;
+ });
+
+ mockEngine.context.PerpsController.getMarkets = jest
+ .fn()
+ .mockReturnValue(getMarketsPromise);
+
+ // prewarm should return immediately
+ const cleanupPromise = testStreamManager.prices.prewarm();
+
+ // The promise should resolve immediately (before getMarkets completes)
+ const cleanup = await cleanupPromise;
+ expect(typeof cleanup).toBe('function');
+
+ // getMarkets was called but not awaited
+ expect(mockEngine.context.PerpsController.getMarkets).toHaveBeenCalled();
+
+ // subscribeToPrices should NOT have been called yet
+ expect(mockSubscribeToPrices).not.toHaveBeenCalled();
+
+ // Now resolve getMarkets
+ resolveGetMarkets([{ name: 'BTC-PERP' }, { name: 'ETH-PERP' }]);
+
+ // Wait for the promise chain to complete
+ await act(async () => {
+ await Promise.resolve();
+ await Promise.resolve();
+ });
+
+ // Now subscribeToPrices should have been called
+ expect(mockSubscribeToPrices).toHaveBeenCalledWith({
+ symbols: ['BTC-PERP', 'ETH-PERP'],
+ callback: expect.any(Function),
+ });
+
+ cleanup();
+ });
+
+ it('skips subscription when cleanup occurs before getMarkets completes', async () => {
+ // Create a promise that we can control
+ let resolveGetMarkets: (value: { name: string }[]) => void = jest.fn();
+ const getMarketsPromise = new Promise<{ name: string }[]>((resolve) => {
+ resolveGetMarkets = resolve;
+ });
+
+ mockEngine.context.PerpsController.getMarkets = jest
+ .fn()
+ .mockReturnValue(getMarketsPromise);
+
+ // prewarm should return immediately
+ const cleanup = await testStreamManager.prices.prewarm();
+
+ // Call cleanup before getMarkets resolves
+ cleanup();
+
+ // Now resolve getMarkets
+ resolveGetMarkets([{ name: 'BTC-PERP' }, { name: 'ETH-PERP' }]);
+
+ // Wait for the promise chain to complete
+ await act(async () => {
+ await Promise.resolve();
+ await Promise.resolve();
+ });
+
+ // subscribeToPrices should NOT have been called (cleaned up before it could subscribe)
+ expect(mockSubscribeToPrices).not.toHaveBeenCalled();
+ });
+
+ it('logs error and returns cleanup function when getMarkets fails', async () => {
+ mockEngine.context.PerpsController.getMarkets = jest
+ .fn()
+ .mockRejectedValue(new Error('Network error'));
+
+ // prewarm should return immediately
+ const cleanup = await testStreamManager.prices.prewarm();
+ expect(typeof cleanup).toBe('function');
+
+ // Wait for the promise chain to complete
+ await act(async () => {
+ await Promise.resolve();
+ await Promise.resolve();
+ });
+
+ // subscribeToPrices should NOT have been called due to error
+ expect(mockSubscribeToPrices).not.toHaveBeenCalled();
+
+ // Logger.error should have been called
+ expect(mockLogger.error).toHaveBeenCalled();
+
+ cleanup();
+ });
+
+ it('calls actual unsubscribe when cleanupPrewarm runs after subscription is established', async () => {
+ const mockActualUnsubscribe = jest.fn();
+ mockSubscribeToPrices.mockReturnValue(mockActualUnsubscribe);
+
+ mockEngine.context.PerpsController.getMarkets = jest
+ .fn()
+ .mockResolvedValue([{ name: 'BTC-PERP' }]);
+
+ // prewarm and wait for subscription to be established
+ const cleanup = await testStreamManager.prices.prewarm();
+
+ // Wait for the background subscription to be set up
+ await act(async () => {
+ await Promise.resolve();
+ await Promise.resolve();
+ });
+
+ expect(mockSubscribeToPrices).toHaveBeenCalled();
+
+ // Call cleanup
+ cleanup();
+
+ // The actual unsubscribe should have been called
+ expect(mockActualUnsubscribe).toHaveBeenCalled();
+ });
+
+ it('returns same cleanup when prewarm called twice without cleanup (already pre-warmed guard)', async () => {
+ const mockUnsubscribe = jest.fn();
+ mockSubscribeToPrices.mockReturnValue(mockUnsubscribe);
+
+ // Create controlled promise
+ let resolveGetMarkets: (value: { name: string }[]) => void = jest.fn();
+ const getMarketsPromise = new Promise<{ name: string }[]>((resolve) => {
+ resolveGetMarkets = resolve;
+ });
+
+ mockEngine.context.PerpsController.getMarkets = jest
+ .fn()
+ .mockReturnValue(getMarketsPromise);
+
+ // Cycle 1: User enters Perps
+ const cleanup1 = await testStreamManager.prices.prewarm();
+
+ // Cycle 2: Called again without cleanup (rapid navigation)
+ // The guard at line 362 returns the existing prewarmUnsubscribe
+ const cleanup2 = await testStreamManager.prices.prewarm();
+
+ // Both cleanups should be the same function (guarded by "already pre-warmed" check)
+ expect(cleanup1).toBe(cleanup2);
+
+ // getMarkets only called once (second call was short-circuited)
+ expect(
+ mockEngine.context.PerpsController.getMarkets,
+ ).toHaveBeenCalledTimes(1);
+
+ // Resolve and cleanup
+ resolveGetMarkets([{ name: 'BTC-PERP' }]);
+ await act(async () => {
+ await Promise.resolve();
+ await Promise.resolve();
+ });
+
+ expect(mockSubscribeToPrices).toHaveBeenCalledTimes(1);
+ cleanup1();
+ expect(mockUnsubscribe).toHaveBeenCalled();
+ });
+
+ it('prevents stale promise from creating subscription after cleanup and new prewarm', async () => {
+ const mockUnsubscribe1 = jest.fn();
+ const mockUnsubscribe2 = jest.fn();
+ let subscribeCallCount = 0;
+
+ mockSubscribeToPrices.mockImplementation(() => {
+ subscribeCallCount++;
+ return subscribeCallCount === 1 ? mockUnsubscribe1 : mockUnsubscribe2;
+ });
+
+ // Create controlled promises for each cycle
+ let resolveGetMarkets1: (value: { name: string }[]) => void = jest.fn();
+ let resolveGetMarkets2: (value: { name: string }[]) => void = jest.fn();
+
+ const getMarketsPromise1 = new Promise<{ name: string }[]>((resolve) => {
+ resolveGetMarkets1 = resolve;
+ });
+ const getMarketsPromise2 = new Promise<{ name: string }[]>((resolve) => {
+ resolveGetMarkets2 = resolve;
+ });
+
+ let getMarketsCallCount = 0;
+ (mockEngine.context.PerpsController.getMarkets as jest.Mock) = jest.fn(
+ () => {
+ getMarketsCallCount++;
+ return getMarketsCallCount === 1
+ ? getMarketsPromise1
+ : getMarketsPromise2;
+ },
+ );
+
+ // Cycle 1: User enters Perps
+ const cleanup1 = await testStreamManager.prices.prewarm();
+
+ // User leaves before markets load - this resets prewarmUnsubscribe to undefined
+ cleanup1();
+
+ // Cycle 2: User enters Perps again
+ const cleanup2 = await testStreamManager.prices.prewarm();
+
+ // Now cycle 1's promise resolves (STALE - should be ignored due to cycle ID mismatch)
+ resolveGetMarkets1([{ name: 'BTC-PERP' }]);
+ await act(async () => {
+ await Promise.resolve();
+ await Promise.resolve();
+ });
+
+ // Stale promise should NOT create subscription (cycle ID mismatch)
+ expect(mockSubscribeToPrices).not.toHaveBeenCalled();
+
+ // Cycle 2's promise resolves (active)
+ resolveGetMarkets2([{ name: 'ETH-PERP' }]);
+ await act(async () => {
+ await Promise.resolve();
+ await Promise.resolve();
+ });
+
+ // Only one subscription should be created (from cycle 2)
+ expect(mockSubscribeToPrices).toHaveBeenCalledTimes(1);
+ expect(mockSubscribeToPrices).toHaveBeenCalledWith({
+ symbols: ['ETH-PERP'],
+ callback: expect.any(Function),
+ });
+
+ // Cleanup cycle 2 - since this is the first subscription call, it uses mockUnsubscribe1
+ cleanup2();
+ expect(mockUnsubscribe1).toHaveBeenCalled();
+ // mockUnsubscribe2 was never created because only one subscription was made
+ });
+ });
+
it('should throttle subsequent updates', async () => {
const onUpdate = jest.fn();
let priceCallback: (data: PriceUpdate[]) => void = jest.fn();
diff --git a/app/components/UI/Perps/providers/PerpsStreamManager.tsx b/app/components/UI/Perps/providers/PerpsStreamManager.tsx
index 0cc19de3876..eaec6098be5 100644
--- a/app/components/UI/Perps/providers/PerpsStreamManager.tsx
+++ b/app/components/UI/Perps/providers/PerpsStreamManager.tsx
@@ -236,7 +236,10 @@ abstract class StreamChannel {
class PriceStreamChannel extends StreamChannel> {
private symbols = new Set();
private prewarmUnsubscribe?: () => void;
+ private actualPriceUnsubscribe?: () => void;
private allMarketSymbols: string[] = [];
+ // Unique ID per prewarm cycle to detect stale promises and prevent subscription leaks
+ private prewarmCycleId: number = 0;
// Override cache to store individual PriceUpdate objects
protected priceCache = new Map();
@@ -355,6 +358,7 @@ class PriceStreamChannel extends StreamChannel> {
/**
* Pre-warm the channel by subscribing to all market prices
* This keeps a single WebSocket connection alive with all price updates
+ * Non-blocking: Returns immediately while market fetch happens in background
* @returns Cleanup function to call when leaving Perps environment
*/
public async prewarm(): Promise<() => void> {
@@ -364,57 +368,104 @@ class PriceStreamChannel extends StreamChannel> {
}
try {
- // Get all available market symbols
const controller = Engine.context.PerpsController;
- const markets = await controller.getMarkets();
- this.allMarketSymbols = markets.map((market) => market.name);
- DevLogger.log('PriceStreamChannel: Pre-warming with all market symbols', {
- symbolCount: this.allMarketSymbols.length,
- symbols: this.allMarketSymbols.slice(0, 10), // Log first 10 for debugging
- });
+ // Increment cycle ID to detect stale promises from previous prewarm cycles
+ // This prevents subscription leaks when user navigates: Perps → away → back quickly
+ this.prewarmCycleId++;
+ const currentCycleId = this.prewarmCycleId;
+
+ // Start market fetch in background (non-blocking)
+ // We need the symbols to register subscribers, but we can return immediately
+ const marketsPromise = controller.getMarkets();
+
+ // Set up subscription once markets arrive (fire-and-forget)
+ marketsPromise
+ .then((markets) => {
+ // If this promise is from a stale cycle, don't set up subscription
+ // This prevents leaks when prewarm is called multiple times rapidly
+ if (currentCycleId !== this.prewarmCycleId) {
+ DevLogger.log('PriceStreamChannel: Skipping stale prewarm cycle', {
+ currentCycleId,
+ activeCycleId: this.prewarmCycleId,
+ });
+ return;
+ }
+
+ // If already cleaned up, don't set up subscription
+ if (this.prewarmUnsubscribe === undefined) {
+ return;
+ }
+
+ this.allMarketSymbols = markets.map((market) => market.name);
+
+ DevLogger.log(
+ 'PriceStreamChannel: Pre-warming with all market symbols',
+ {
+ symbolCount: this.allMarketSymbols.length,
+ symbols: this.allMarketSymbols.slice(0, 10),
+ },
+ );
- // Subscribe to all market prices
- this.prewarmUnsubscribe = controller.subscribeToPrices({
- symbols: this.allMarketSymbols,
- callback: (updates: PriceUpdate[]) => {
- // Update cache and build price map
- const priceMap: Record = {};
- updates.forEach((update) => {
- const priceUpdate: PriceUpdate = {
- symbol: update.symbol,
- price: update.price,
- timestamp: Date.now(),
- percentChange24h: update.percentChange24h,
- bestBid: update.bestBid,
- bestAsk: update.bestAsk,
- spread: update.spread,
- markPrice: update.markPrice,
- funding: update.funding,
- openInterest: update.openInterest,
- volume24h: update.volume24h,
- };
- this.priceCache.set(update.symbol, priceUpdate);
- priceMap[update.symbol] = priceUpdate;
+ // Subscribe to all market prices
+ const unsub = controller.subscribeToPrices({
+ symbols: this.allMarketSymbols,
+ callback: (updates: PriceUpdate[]) => {
+ const priceMap: Record = {};
+ updates.forEach((update) => {
+ const priceUpdate: PriceUpdate = {
+ symbol: update.symbol,
+ price: update.price,
+ timestamp: Date.now(),
+ percentChange24h: update.percentChange24h,
+ bestBid: update.bestBid,
+ bestAsk: update.bestAsk,
+ spread: update.spread,
+ markPrice: update.markPrice,
+ funding: update.funding,
+ openInterest: update.openInterest,
+ volume24h: update.volume24h,
+ };
+ this.priceCache.set(update.symbol, priceUpdate);
+ priceMap[update.symbol] = priceUpdate;
+ });
+
+ if (this.subscribers.size > 0) {
+ this.notifySubscribers(priceMap);
+ }
+ },
});
- // Notify any active subscribers with all updates
+ // Store the actual unsubscribe function
+ this.actualPriceUnsubscribe = unsub;
+ })
+ .catch((error) => {
+ Logger.error(
+ ensureError(error, 'PriceStreamChannel.prewarm.backgroundFetch'),
+ {
+ context: 'PriceStreamChannel.prewarm.backgroundFetch',
+ },
+ );
+ // Reset state so subsequent prewarm/connect calls can recover
+ this.prewarmUnsubscribe = undefined;
+ this.allMarketSymbols = [];
+ // Reconnect waiting subscribers that were skipped because prewarm was pending
if (this.subscribers.size > 0) {
- this.notifySubscribers(priceMap);
+ this.connect();
}
- },
- });
+ });
- // Return a cleanup function that properly clears internal state
- return () => {
+ // Return cleanup function immediately (before markets load)
+ this.prewarmUnsubscribe = () => {
DevLogger.log('PriceStreamChannel: Cleaning up prewarm subscription');
this.cleanupPrewarm();
};
+
+ return this.prewarmUnsubscribe;
} catch (error) {
Logger.error(ensureError(error, 'PriceStreamChannel.prewarm'), {
context: 'PriceStreamChannel.prewarm',
});
- // Return no-op cleanup function
return () => {
// No-op
};
@@ -425,11 +476,12 @@ class PriceStreamChannel extends StreamChannel> {
* Cleanup pre-warm subscription
*/
public cleanupPrewarm(): void {
- if (this.prewarmUnsubscribe) {
- this.prewarmUnsubscribe();
- this.prewarmUnsubscribe = undefined;
- this.allMarketSymbols = [];
+ if (this.actualPriceUnsubscribe) {
+ this.actualPriceUnsubscribe();
+ this.actualPriceUnsubscribe = undefined;
}
+ this.prewarmUnsubscribe = undefined;
+ this.allMarketSymbols = [];
}
}
@@ -1285,7 +1337,7 @@ class MarketDataChannel extends StreamChannel {
public prewarm(): () => void {
// Fetch data immediately to populate cache
this.fetchMarketData().catch((error) => {
- Logger.error(error instanceof Error ? error : new Error(String(error)), {
+ Logger.error(ensureError(error, 'MarketDataChannel.prewarm'), {
context: 'MarketDataChannel.prewarm',
});
});
diff --git a/app/components/UI/Predict/controllers/PredictController.getActivity.test.ts b/app/components/UI/Predict/controllers/PredictController.getActivity.test.ts
index 22f2a990f26..fe3e6984b6f 100644
--- a/app/components/UI/Predict/controllers/PredictController.getActivity.test.ts
+++ b/app/components/UI/Predict/controllers/PredictController.getActivity.test.ts
@@ -1,31 +1,5 @@
import { PredictController } from './PredictController';
-// Mock Engine AccountsController for selected address
-jest.mock('../../../../core/Engine', () => ({
- context: {
- AccountsController: {
- getSelectedAccount: jest.fn(() => ({ address: '0xselected' })),
- },
- AccountTreeController: {
- getAccountsFromSelectedAccountGroup: jest.fn().mockReturnValue([
- {
- address: '0xselected',
- id: 'mock-account-id',
- type: 'eip155:eoa',
- options: {},
- metadata: {
- name: 'Test Account',
- importTime: Date.now(),
- keyring: { type: 'HD Key Tree' },
- },
- scopes: ['eip155:1'],
- methods: ['eth_sendTransaction'],
- },
- ]),
- },
- },
-}));
-
interface ActivityEntry {
id: string;
providerId: string;
@@ -55,6 +29,11 @@ describe('PredictController.getActivity', () => {
) as MockPredictController;
controller.providers = new Map(Object.entries(providers));
controller.update = jest.fn();
+ (
+ controller as unknown as { getSigner: () => { address: string } }
+ ).getSigner = jest.fn(() => ({
+ address: '0xselected',
+ }));
return controller;
};
diff --git a/app/components/UI/Predict/controllers/PredictController.test.ts b/app/components/UI/Predict/controllers/PredictController.test.ts
index 1a4a9300876..03a41eb07c6 100644
--- a/app/components/UI/Predict/controllers/PredictController.test.ts
+++ b/app/components/UI/Predict/controllers/PredictController.test.ts
@@ -7,14 +7,10 @@ import {
type MessengerEvents,
type MockAnyNamespace,
} from '@metamask/messenger';
-import {
- GasFeeEstimateLevel,
- GasFeeEstimateType,
-} from '@metamask/transaction-controller';
+
import type { NetworkState } from '@metamask/network-controller';
import type { InternalAccount } from '@metamask/keyring-internal-api';
-import Engine from '../../../../core/Engine';
import DevLogger from '../../../../core/SDKConnect/utils/DevLogger';
import {
addTransaction,
@@ -46,71 +42,33 @@ jest.mock('../../../../util/transaction-controller', () => ({
addTransactionBatch: jest.fn(),
}));
-// Mock Engine
-jest.mock('../../../../core/Engine', () => ({
- context: {
- KeyringController: {
- signTypedMessage: jest.fn(),
+// Default mock values for messenger actions
+const DEFAULT_REMOTE_FEATURE_FLAG_STATE = {
+ remoteFeatureFlags: {
+ predictFeeCollection: {
+ enabled: true,
+ collector: '0x100c7b833bbd604a77890783439bbb9d65e31de7',
+ metamaskFee: 0.02,
+ providerFee: 0.02,
+ waiveList: [],
},
- AccountsController: {
- getSelectedAccount: jest.fn().mockReturnValue({
- id: 'mock-account-id',
- address: '0x1234567890123456789012345678901234567890',
- metadata: { name: 'Test Account' },
- }),
+ predictLiveSports: {
+ enabled: false,
+ leagues: [],
},
- AccountTreeController: {
- getAccountsFromSelectedAccountGroup: jest.fn(() => [
- {
- id: 'mock-account-id',
- address: '0x1234567890123456789012345678901234567890',
- type: 'eip155:eoa',
- name: 'Test Account',
- metadata: {
- lastSelected: 0,
- },
- },
- ]),
- },
- NetworkController: {
- getState: jest.fn().mockReturnValue({
- selectedNetworkClientId: 'mainnet',
- }),
- findNetworkClientIdByChainId: jest.fn().mockReturnValue('mainnet'),
- getNetworkClientById: jest.fn().mockReturnValue({
- blockTracker: {
- checkForLatestBlock: jest.fn().mockResolvedValue(undefined),
- },
- }),
- },
- TransactionController: {
- estimateGas: jest.fn(),
- estimateGasFee: jest.fn(),
- },
- AccountTrackerController: {
- state: {
- accountsByChainId: {},
- },
- },
- RemoteFeatureFlagController: {
- state: {
- remoteFeatureFlags: {
- predictFeeCollection: {
- enabled: true,
- collector: '0x100c7b833bbd604a77890783439bbb9d65e31de7',
- metamaskFee: 0.02,
- providerFee: 0.02,
- waiveList: [],
- },
- predictLiveSports: {
- enabled: false,
- leagues: [],
- },
- },
- },
+ predictMarketHighlights: {
+ enabled: false,
+ highlights: [],
},
},
-}));
+ cacheTimestamp: Date.now(),
+};
+
+const DEFAULT_NETWORK_CLIENT = {
+ blockTracker: {
+ checkForLatestBlock: jest.fn().mockResolvedValue(undefined),
+ },
+};
// Mock DevLogger (default export)
jest.mock('../../../../core/SDKConnect/utils/DevLogger', () => ({
@@ -266,6 +224,11 @@ describe('PredictController', () => {
mocks?: {
getSelectedAccount?: jest.MockedFunction<() => InternalAccount>;
getNetworkState?: jest.MockedFunction<() => NetworkState>;
+ getRemoteFeatureFlagState?: jest.MockedFunction<() => any>;
+ findNetworkClientIdByChainId?: jest.MockedFunction<
+ (chainId: string) => string
+ >;
+ getNetworkClientById?: jest.MockedFunction<(clientId: string) => any>;
};
} = {},
): ReturnValue {
@@ -284,6 +247,19 @@ describe('PredictController', () => {
}),
);
+ rootMessenger.registerActionHandler(
+ 'AccountTreeController:getAccountsFromSelectedAccountGroup',
+ jest.fn().mockReturnValue([
+ {
+ id: 'mock-account-id',
+ address: '0x1234567890123456789012345678901234567890',
+ type: 'eip155:eoa',
+ name: 'Test Account',
+ metadata: { lastSelected: 0 },
+ },
+ ]),
+ );
+
rootMessenger.registerActionHandler(
'NetworkController:getState',
mocks.getNetworkState ??
@@ -299,6 +275,34 @@ describe('PredictController', () => {
}),
);
+ rootMessenger.registerActionHandler(
+ 'KeyringController:signTypedMessage',
+ jest.fn().mockResolvedValue('0xmocksignature'),
+ );
+
+ rootMessenger.registerActionHandler(
+ 'KeyringController:signPersonalMessage',
+ jest.fn().mockResolvedValue('0xmocksignature'),
+ );
+
+ rootMessenger.registerActionHandler(
+ 'NetworkController:findNetworkClientIdByChainId',
+ mocks.findNetworkClientIdByChainId ??
+ jest.fn().mockReturnValue('polygon-mainnet'),
+ );
+
+ rootMessenger.registerActionHandler(
+ 'NetworkController:getNetworkClientById',
+ mocks.getNetworkClientById ??
+ jest.fn().mockReturnValue(DEFAULT_NETWORK_CLIENT),
+ );
+
+ rootMessenger.registerActionHandler(
+ 'RemoteFeatureFlagController:getState',
+ mocks.getRemoteFeatureFlagState ??
+ jest.fn().mockReturnValue(DEFAULT_REMOTE_FEATURE_FLAG_STATE),
+ );
+
const messenger = new Messenger<
'PredictController',
AllPredictControllerMessengerActions,
@@ -312,14 +316,21 @@ describe('PredictController', () => {
rootMessenger.delegate({
actions: [
'AccountsController:getSelectedAccount',
+ 'AccountTreeController:getAccountsFromSelectedAccountGroup',
'NetworkController:getState',
+ 'NetworkController:findNetworkClientIdByChainId',
+ 'NetworkController:getNetworkClientById',
'TransactionController:estimateGas',
+ 'KeyringController:signTypedMessage',
+ 'KeyringController:signPersonalMessage',
+ 'RemoteFeatureFlagController:getState',
],
events: [
'TransactionController:transactionSubmitted',
'TransactionController:transactionConfirmed',
'TransactionController:transactionFailed',
'TransactionController:transactionRejected',
+ 'RemoteFeatureFlagController:stateChange',
],
messenger,
});
@@ -1794,32 +1805,28 @@ describe('PredictController', () => {
outcomes: ['YES', 'NO'],
});
- const setMarketHighlightsFlag = (flag: {
+ const createFlagState = (flag: {
enabled: boolean;
highlights: { category: string; markets: string[] }[];
- }) => {
- (
- Engine.context.RemoteFeatureFlagController as any
- ).state.remoteFeatureFlags.predictMarketHighlights = flag;
- };
-
- const clearMarketHighlightsFlag = () => {
- delete (Engine.context.RemoteFeatureFlagController as any).state
- .remoteFeatureFlags.predictMarketHighlights;
- };
-
- afterEach(() => {
- clearMarketHighlightsFlag();
+ }) => ({
+ remoteFeatureFlags: {
+ predictFeeCollection: {
+ enabled: true,
+ collector: '0x100c7b833bbd604a77890783439bbb9d65e31de7',
+ metamaskFee: 0.02,
+ providerFee: 0.02,
+ waiveList: [],
+ },
+ predictLiveSports: {
+ enabled: false,
+ leagues: [],
+ },
+ predictMarketHighlights: flag,
+ },
+ cacheTimestamp: Date.now(),
});
it('prepends highlighted markets when flag is enabled and offset is 0', async () => {
- setMarketHighlightsFlag({
- enabled: true,
- highlights: [
- { category: 'trending', markets: ['highlight-1', 'highlight-2'] },
- ],
- });
-
const regularMarkets = [
createMockMarket('regular-1'),
createMockMarket('regular-2'),
@@ -1829,134 +1836,180 @@ describe('PredictController', () => {
createMockMarket('highlight-2'),
];
- await withController(async ({ controller }) => {
- mockPolymarketProvider.getMarkets.mockResolvedValue(
- regularMarkets as any,
- );
- mockPolymarketProvider.getMarketsByIds.mockResolvedValue(
- highlightedMarkets as any,
- );
+ await withController(
+ async ({ controller }) => {
+ mockPolymarketProvider.getMarkets.mockResolvedValue(
+ regularMarkets as any,
+ );
+ mockPolymarketProvider.getMarketsByIds.mockResolvedValue(
+ highlightedMarkets as any,
+ );
- const result = await controller.getMarkets({
- providerId: 'polymarket',
- category: 'trending',
- offset: 0,
- });
-
- expect(result).toHaveLength(4);
- expect(result[0].id).toBe('highlight-1');
- expect(result[1].id).toBe('highlight-2');
- expect(result[2].id).toBe('regular-1');
- expect(result[3].id).toBe('regular-2');
- expect(mockPolymarketProvider.getMarketsByIds).toHaveBeenCalledWith(
- ['highlight-1', 'highlight-2'],
- [],
- );
- });
+ const result = await controller.getMarkets({
+ providerId: 'polymarket',
+ category: 'trending',
+ offset: 0,
+ });
+
+ expect(result).toHaveLength(4);
+ expect(result[0].id).toBe('highlight-1');
+ expect(result[1].id).toBe('highlight-2');
+ expect(result[2].id).toBe('regular-1');
+ expect(result[3].id).toBe('regular-2');
+ expect(mockPolymarketProvider.getMarketsByIds).toHaveBeenCalledWith(
+ ['highlight-1', 'highlight-2'],
+ [],
+ );
+ },
+ {
+ mocks: {
+ getRemoteFeatureFlagState: jest.fn().mockReturnValue(
+ createFlagState({
+ enabled: true,
+ highlights: [
+ {
+ category: 'trending',
+ markets: ['highlight-1', 'highlight-2'],
+ },
+ ],
+ }),
+ ),
+ },
+ },
+ );
});
it('skips highlights when offset is greater than 0', async () => {
- setMarketHighlightsFlag({
- enabled: true,
- highlights: [{ category: 'trending', markets: ['highlight-1'] }],
- });
-
const regularMarkets = [createMockMarket('regular-1')];
- await withController(async ({ controller }) => {
- mockPolymarketProvider.getMarkets.mockResolvedValue(
- regularMarkets as any,
- );
+ await withController(
+ async ({ controller }) => {
+ mockPolymarketProvider.getMarkets.mockResolvedValue(
+ regularMarkets as any,
+ );
- const result = await controller.getMarkets({
- providerId: 'polymarket',
- category: 'trending',
- offset: 10,
- });
+ const result = await controller.getMarkets({
+ providerId: 'polymarket',
+ category: 'trending',
+ offset: 10,
+ });
- expect(result).toHaveLength(1);
- expect(result[0].id).toBe('regular-1');
- expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled();
- });
+ expect(result).toHaveLength(1);
+ expect(result[0].id).toBe('regular-1');
+ expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled();
+ },
+ {
+ mocks: {
+ getRemoteFeatureFlagState: jest.fn().mockReturnValue(
+ createFlagState({
+ enabled: true,
+ highlights: [
+ { category: 'trending', markets: ['highlight-1'] },
+ ],
+ }),
+ ),
+ },
+ },
+ );
});
it('skips highlights when flag is disabled', async () => {
- setMarketHighlightsFlag({
- enabled: false,
- highlights: [{ category: 'trending', markets: ['highlight-1'] }],
- });
-
const regularMarkets = [createMockMarket('regular-1')];
- await withController(async ({ controller }) => {
- mockPolymarketProvider.getMarkets.mockResolvedValue(
- regularMarkets as any,
- );
+ await withController(
+ async ({ controller }) => {
+ mockPolymarketProvider.getMarkets.mockResolvedValue(
+ regularMarkets as any,
+ );
- const result = await controller.getMarkets({
- providerId: 'polymarket',
- category: 'trending',
- offset: 0,
- });
+ const result = await controller.getMarkets({
+ providerId: 'polymarket',
+ category: 'trending',
+ offset: 0,
+ });
- expect(result).toHaveLength(1);
- expect(result[0].id).toBe('regular-1');
- expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled();
- });
+ expect(result).toHaveLength(1);
+ expect(result[0].id).toBe('regular-1');
+ expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled();
+ },
+ {
+ mocks: {
+ getRemoteFeatureFlagState: jest.fn().mockReturnValue(
+ createFlagState({
+ enabled: false,
+ highlights: [
+ { category: 'trending', markets: ['highlight-1'] },
+ ],
+ }),
+ ),
+ },
+ },
+ );
});
it('skips highlights when category is not provided', async () => {
- setMarketHighlightsFlag({
- enabled: true,
- highlights: [{ category: 'trending', markets: ['highlight-1'] }],
- });
-
const regularMarkets = [createMockMarket('regular-1')];
- await withController(async ({ controller }) => {
- mockPolymarketProvider.getMarkets.mockResolvedValue(
- regularMarkets as any,
- );
+ await withController(
+ async ({ controller }) => {
+ mockPolymarketProvider.getMarkets.mockResolvedValue(
+ regularMarkets as any,
+ );
- const result = await controller.getMarkets({
- providerId: 'polymarket',
- });
+ const result = await controller.getMarkets({
+ providerId: 'polymarket',
+ });
- expect(result).toHaveLength(1);
- expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled();
- });
+ expect(result).toHaveLength(1);
+ expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled();
+ },
+ {
+ mocks: {
+ getRemoteFeatureFlagState: jest.fn().mockReturnValue(
+ createFlagState({
+ enabled: true,
+ highlights: [
+ { category: 'trending', markets: ['highlight-1'] },
+ ],
+ }),
+ ),
+ },
+ },
+ );
});
it('skips highlights when category has no configured highlights', async () => {
- setMarketHighlightsFlag({
- enabled: true,
- highlights: [{ category: 'crypto', markets: ['highlight-1'] }],
- });
-
const regularMarkets = [createMockMarket('regular-1')];
- await withController(async ({ controller }) => {
- mockPolymarketProvider.getMarkets.mockResolvedValue(
- regularMarkets as any,
- );
+ await withController(
+ async ({ controller }) => {
+ mockPolymarketProvider.getMarkets.mockResolvedValue(
+ regularMarkets as any,
+ );
- const result = await controller.getMarkets({
- providerId: 'polymarket',
- category: 'trending',
- offset: 0,
- });
+ const result = await controller.getMarkets({
+ providerId: 'polymarket',
+ category: 'trending',
+ offset: 0,
+ });
- expect(result).toHaveLength(1);
- expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled();
- });
+ expect(result).toHaveLength(1);
+ expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled();
+ },
+ {
+ mocks: {
+ getRemoteFeatureFlagState: jest.fn().mockReturnValue(
+ createFlagState({
+ enabled: true,
+ highlights: [{ category: 'crypto', markets: ['highlight-1'] }],
+ }),
+ ),
+ },
+ },
+ );
});
it('filters duplicates from regular results when highlighted market appears in both', async () => {
- setMarketHighlightsFlag({
- enabled: true,
- highlights: [{ category: 'trending', markets: ['duplicate-market'] }],
- });
-
const regularMarkets = [
createMockMarket('regular-1'),
createMockMarket('duplicate-market'),
@@ -1964,135 +2017,183 @@ describe('PredictController', () => {
];
const highlightedMarkets = [createMockMarket('duplicate-market')];
- await withController(async ({ controller }) => {
- mockPolymarketProvider.getMarkets.mockResolvedValue(
- regularMarkets as any,
- );
- mockPolymarketProvider.getMarketsByIds.mockResolvedValue(
- highlightedMarkets as any,
- );
+ await withController(
+ async ({ controller }) => {
+ mockPolymarketProvider.getMarkets.mockResolvedValue(
+ regularMarkets as any,
+ );
+ mockPolymarketProvider.getMarketsByIds.mockResolvedValue(
+ highlightedMarkets as any,
+ );
- const result = await controller.getMarkets({
- providerId: 'polymarket',
- category: 'trending',
- offset: 0,
- });
+ const result = await controller.getMarkets({
+ providerId: 'polymarket',
+ category: 'trending',
+ offset: 0,
+ });
- expect(result).toHaveLength(3);
- expect(result[0].id).toBe('duplicate-market');
- expect(result[1].id).toBe('regular-1');
- expect(result[2].id).toBe('regular-2');
- });
+ expect(result).toHaveLength(3);
+ expect(result[0].id).toBe('duplicate-market');
+ expect(result[1].id).toBe('regular-1');
+ expect(result[2].id).toBe('regular-2');
+ },
+ {
+ mocks: {
+ getRemoteFeatureFlagState: jest.fn().mockReturnValue(
+ createFlagState({
+ enabled: true,
+ highlights: [
+ { category: 'trending', markets: ['duplicate-market'] },
+ ],
+ }),
+ ),
+ },
+ },
+ );
});
it('handles empty highlights array gracefully', async () => {
- setMarketHighlightsFlag({
- enabled: true,
- highlights: [{ category: 'trending', markets: [] }],
- });
-
const regularMarkets = [createMockMarket('regular-1')];
- await withController(async ({ controller }) => {
- mockPolymarketProvider.getMarkets.mockResolvedValue(
- regularMarkets as any,
- );
+ await withController(
+ async ({ controller }) => {
+ mockPolymarketProvider.getMarkets.mockResolvedValue(
+ regularMarkets as any,
+ );
- const result = await controller.getMarkets({
- providerId: 'polymarket',
- category: 'trending',
- offset: 0,
- });
+ const result = await controller.getMarkets({
+ providerId: 'polymarket',
+ category: 'trending',
+ offset: 0,
+ });
- expect(result).toHaveLength(1);
- expect(result[0].id).toBe('regular-1');
- expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled();
- });
+ expect(result).toHaveLength(1);
+ expect(result[0].id).toBe('regular-1');
+ expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled();
+ },
+ {
+ mocks: {
+ getRemoteFeatureFlagState: jest.fn().mockReturnValue(
+ createFlagState({
+ enabled: true,
+ highlights: [{ category: 'trending', markets: [] }],
+ }),
+ ),
+ },
+ },
+ );
});
it('handles getMarketsByIds failure gracefully', async () => {
- setMarketHighlightsFlag({
- enabled: true,
- highlights: [{ category: 'trending', markets: ['highlight-1'] }],
- });
-
const regularMarkets = [createMockMarket('regular-1')];
- await withController(async ({ controller }) => {
- mockPolymarketProvider.getMarkets.mockResolvedValue(
- regularMarkets as any,
- );
- mockPolymarketProvider.getMarketsByIds.mockResolvedValue([]);
+ await withController(
+ async ({ controller }) => {
+ mockPolymarketProvider.getMarkets.mockResolvedValue(
+ regularMarkets as any,
+ );
+ mockPolymarketProvider.getMarketsByIds.mockResolvedValue([]);
- const result = await controller.getMarkets({
- providerId: 'polymarket',
- category: 'trending',
- offset: 0,
- });
+ const result = await controller.getMarkets({
+ providerId: 'polymarket',
+ category: 'trending',
+ offset: 0,
+ });
- expect(result).toHaveLength(1);
- expect(result[0].id).toBe('regular-1');
- });
+ expect(result).toHaveLength(1);
+ expect(result[0].id).toBe('regular-1');
+ },
+ {
+ mocks: {
+ getRemoteFeatureFlagState: jest.fn().mockReturnValue(
+ createFlagState({
+ enabled: true,
+ highlights: [
+ { category: 'trending', markets: ['highlight-1'] },
+ ],
+ }),
+ ),
+ },
+ },
+ );
});
it('uses default flag when predictMarketHighlights is not in remote config', async () => {
- clearMarketHighlightsFlag();
-
const regularMarkets = [createMockMarket('regular-1')];
- await withController(async ({ controller }) => {
- mockPolymarketProvider.getMarkets.mockResolvedValue(
- regularMarkets as any,
- );
+ await withController(
+ async ({ controller }) => {
+ mockPolymarketProvider.getMarkets.mockResolvedValue(
+ regularMarkets as any,
+ );
- const result = await controller.getMarkets({
- providerId: 'polymarket',
- category: 'trending',
- offset: 0,
- });
+ const result = await controller.getMarkets({
+ providerId: 'polymarket',
+ category: 'trending',
+ offset: 0,
+ });
- expect(result).toHaveLength(1);
- expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled();
- });
+ expect(result).toHaveLength(1);
+ expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled();
+ },
+ {
+ mocks: {
+ getRemoteFeatureFlagState: jest.fn().mockReturnValue({
+ remoteFeatureFlags: {
+ predictFeeCollection:
+ DEFAULT_REMOTE_FEATURE_FLAG_STATE.remoteFeatureFlags
+ .predictFeeCollection,
+ predictLiveSports:
+ DEFAULT_REMOTE_FEATURE_FLAG_STATE.remoteFeatureFlags
+ .predictLiveSports,
+ // predictMarketHighlights intentionally omitted to test fallback
+ },
+ cacheTimestamp: Date.now(),
+ }),
+ },
+ },
+ );
});
it('fetches highlights for first page when offset is undefined', async () => {
- setMarketHighlightsFlag({
- enabled: true,
- highlights: [{ category: 'trending', markets: ['highlight-1'] }],
- });
-
const regularMarkets = [createMockMarket('regular-1')];
const highlightedMarkets = [createMockMarket('highlight-1')];
- await withController(async ({ controller }) => {
- mockPolymarketProvider.getMarkets.mockResolvedValue(
- regularMarkets as any,
- );
- mockPolymarketProvider.getMarketsByIds.mockResolvedValue(
- highlightedMarkets as any,
- );
+ await withController(
+ async ({ controller }) => {
+ mockPolymarketProvider.getMarkets.mockResolvedValue(
+ regularMarkets as any,
+ );
+ mockPolymarketProvider.getMarketsByIds.mockResolvedValue(
+ highlightedMarkets as any,
+ );
- const result = await controller.getMarkets({
- providerId: 'polymarket',
- category: 'trending',
- });
+ const result = await controller.getMarkets({
+ providerId: 'polymarket',
+ category: 'trending',
+ });
- expect(result).toHaveLength(2);
- expect(result[0].id).toBe('highlight-1');
- expect(result[1].id).toBe('regular-1');
- expect(mockPolymarketProvider.getMarketsByIds).toHaveBeenCalled();
- });
+ expect(result).toHaveLength(2);
+ expect(result[0].id).toBe('highlight-1');
+ expect(result[1].id).toBe('regular-1');
+ expect(mockPolymarketProvider.getMarketsByIds).toHaveBeenCalled();
+ },
+ {
+ mocks: {
+ getRemoteFeatureFlagState: jest.fn().mockReturnValue(
+ createFlagState({
+ enabled: true,
+ highlights: [
+ { category: 'trending', markets: ['highlight-1'] },
+ ],
+ }),
+ ),
+ },
+ },
+ );
});
it('preserves order of highlighted markets from config', async () => {
- setMarketHighlightsFlag({
- enabled: true,
- highlights: [
- { category: 'trending', markets: ['third', 'first', 'second'] },
- ],
- });
-
const regularMarkets = [createMockMarket('regular-1')];
const highlightedMarkets = [
createMockMarket('third'),
@@ -2100,85 +2201,131 @@ describe('PredictController', () => {
createMockMarket('second'),
];
- await withController(async ({ controller }) => {
- mockPolymarketProvider.getMarkets.mockResolvedValue(
- regularMarkets as any,
- );
- mockPolymarketProvider.getMarketsByIds.mockResolvedValue(
- highlightedMarkets as any,
- );
+ await withController(
+ async ({ controller }) => {
+ mockPolymarketProvider.getMarkets.mockResolvedValue(
+ regularMarkets as any,
+ );
+ mockPolymarketProvider.getMarketsByIds.mockResolvedValue(
+ highlightedMarkets as any,
+ );
- const result = await controller.getMarkets({
- providerId: 'polymarket',
- category: 'trending',
- offset: 0,
- });
+ const result = await controller.getMarkets({
+ providerId: 'polymarket',
+ category: 'trending',
+ offset: 0,
+ });
- expect(result[0].id).toBe('third');
- expect(result[1].id).toBe('first');
- expect(result[2].id).toBe('second');
- });
+ expect(result[0].id).toBe('third');
+ expect(result[1].id).toBe('first');
+ expect(result[2].id).toBe('second');
+ },
+ {
+ mocks: {
+ getRemoteFeatureFlagState: jest.fn().mockReturnValue(
+ createFlagState({
+ enabled: true,
+ highlights: [
+ {
+ category: 'trending',
+ markets: ['third', 'first', 'second'],
+ },
+ ],
+ }),
+ ),
+ },
+ },
+ );
});
it('handles missing highlights array in flag gracefully', async () => {
- (
- Engine.context.RemoteFeatureFlagController as any
- ).state.remoteFeatureFlags.predictMarketHighlights = { enabled: true };
-
const regularMarkets = [createMockMarket('regular-1')];
- await withController(async ({ controller }) => {
- mockPolymarketProvider.getMarkets.mockResolvedValue(
- regularMarkets as any,
- );
+ await withController(
+ async ({ controller }) => {
+ mockPolymarketProvider.getMarkets.mockResolvedValue(
+ regularMarkets as any,
+ );
- const result = await controller.getMarkets({
- providerId: 'polymarket',
- category: 'trending',
- offset: 0,
- });
+ const result = await controller.getMarkets({
+ providerId: 'polymarket',
+ category: 'trending',
+ offset: 0,
+ });
- expect(result).toHaveLength(1);
- expect(result[0].id).toBe('regular-1');
- expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled();
- });
+ expect(result).toHaveLength(1);
+ expect(result[0].id).toBe('regular-1');
+ expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled();
+ },
+ {
+ mocks: {
+ getRemoteFeatureFlagState: jest.fn().mockReturnValue({
+ remoteFeatureFlags: {
+ predictFeeCollection: {
+ enabled: true,
+ collector: '0x100c7b833bbd604a77890783439bbb9d65e31de7',
+ metamaskFee: 0.02,
+ providerFee: 0.02,
+ waiveList: [],
+ },
+ predictLiveSports: {
+ enabled: false,
+ leagues: [],
+ },
+ predictMarketHighlights: { enabled: true },
+ },
+ cacheTimestamp: Date.now(),
+ }),
+ },
+ },
+ );
});
it('keeps market in regular results when highlight fetch fails for that market', async () => {
- setMarketHighlightsFlag({
- enabled: true,
- highlights: [
- { category: 'trending', markets: ['highlight-1', 'highlight-2'] },
- ],
- });
-
const regularMarketsIncludingFailedHighlight = [
createMockMarket('highlight-1'),
createMockMarket('regular-1'),
];
const onlySuccessfullyFetchedHighlights = [
createMockMarket('highlight-2'),
- ];
-
- await withController(async ({ controller }) => {
- mockPolymarketProvider.getMarkets.mockResolvedValue(
- regularMarketsIncludingFailedHighlight as any,
- );
- mockPolymarketProvider.getMarketsByIds.mockResolvedValue(
- onlySuccessfullyFetchedHighlights as any,
- );
+ ];
- const result = await controller.getMarkets({
- providerId: 'polymarket',
- category: 'trending',
- offset: 0,
- });
+ await withController(
+ async ({ controller }) => {
+ mockPolymarketProvider.getMarkets.mockResolvedValue(
+ regularMarketsIncludingFailedHighlight as any,
+ );
+ mockPolymarketProvider.getMarketsByIds.mockResolvedValue(
+ onlySuccessfullyFetchedHighlights as any,
+ );
- expect(result).toHaveLength(3);
- expect(result[0].id).toBe('highlight-2');
- expect(result[1].id).toBe('highlight-1');
- expect(result[2].id).toBe('regular-1');
- });
+ const result = await controller.getMarkets({
+ providerId: 'polymarket',
+ category: 'trending',
+ offset: 0,
+ });
+
+ expect(result).toHaveLength(3);
+ expect(result[0].id).toBe('highlight-2');
+ expect(result[1].id).toBe('highlight-1');
+ expect(result[2].id).toBe('regular-1');
+ },
+ {
+ mocks: {
+ getRemoteFeatureFlagState: jest.fn().mockReturnValue(
+ createFlagState({
+ enabled: true,
+ highlights: [
+ {
+ category: 'trending',
+ markets: ['highlight-1', 'highlight-2'],
+ },
+ ],
+ }),
+ ),
+ },
+ },
+ );
});
});
@@ -2546,64 +2693,71 @@ describe('PredictController', () => {
it('throws error when network client not found', async () => {
// Arrange
- await withController(async ({ controller }) => {
- mockPolymarketProvider.getPositions = jest.fn().mockResolvedValue([
- {
- marketId: 'test-market',
- providerId: 'polymarket',
- outcomeId: 'test-outcome',
- balance: '100',
- },
- ]);
- const MockedEngine = jest.requireMock('../../../../core/Engine');
- MockedEngine.context.NetworkController.findNetworkClientIdByChainId =
- jest.fn().mockReturnValue(undefined);
+ await withController(
+ async ({ controller }) => {
+ mockPolymarketProvider.getPositions = jest.fn().mockResolvedValue([
+ {
+ marketId: 'test-market',
+ providerId: 'polymarket',
+ outcomeId: 'test-outcome',
+ balance: '100',
+ },
+ ]);
- mockPolymarketProvider.prepareClaim = jest
- .fn()
- .mockResolvedValue(mockClaim);
- await controller.getPositions({ claimable: true });
+ mockPolymarketProvider.prepareClaim = jest
+ .fn()
+ .mockResolvedValue(mockClaim);
+ await controller.getPositions({ claimable: true });
- // Act & Assert
- await expect(
- controller.claimWithConfirmation({
- providerId: 'polymarket',
- }),
- ).rejects.toThrow('Network client not found for chain ID');
- });
+ // Act & Assert
+ await expect(
+ controller.claimWithConfirmation({
+ providerId: 'polymarket',
+ }),
+ ).rejects.toThrow('Network client not found for chain ID');
+ },
+ {
+ mocks: {
+ findNetworkClientIdByChainId: jest.fn().mockReturnValue(undefined),
+ },
+ },
+ );
});
it('throws error when transaction batch returns no batchId', async () => {
// Arrange
- await withController(async ({ controller }) => {
- mockPolymarketProvider.getPositions = jest.fn().mockResolvedValue([
- {
- marketId: 'test-market',
- providerId: 'polymarket',
- outcomeId: 'test-outcome',
- balance: '100',
- },
- ]);
- mockPolymarketProvider.prepareClaim = jest
- .fn()
- .mockResolvedValue(mockClaim);
-
- const MockedEngine = jest.requireMock('../../../../core/Engine');
- MockedEngine.context.NetworkController.findNetworkClientIdByChainId =
- jest.fn().mockReturnValue('mainnet');
+ await withController(
+ async ({ controller }) => {
+ mockPolymarketProvider.getPositions = jest.fn().mockResolvedValue([
+ {
+ marketId: 'test-market',
+ providerId: 'polymarket',
+ outcomeId: 'test-outcome',
+ balance: '100',
+ },
+ ]);
+ mockPolymarketProvider.prepareClaim = jest
+ .fn()
+ .mockResolvedValue(mockClaim);
- (addTransactionBatch as jest.Mock).mockResolvedValue({});
- await controller.getPositions({ claimable: true });
+ (addTransactionBatch as jest.Mock).mockResolvedValue({});
+ await controller.getPositions({ claimable: true });
- // Act & Assert
- await expect(
- controller.claimWithConfirmation({
- providerId: 'polymarket',
- }),
- ).rejects.toThrow(
- 'Failed to get batch ID from claim transaction submission',
- );
- });
+ // Act & Assert
+ await expect(
+ controller.claimWithConfirmation({
+ providerId: 'polymarket',
+ }),
+ ).rejects.toThrow(
+ 'Failed to get batch ID from claim transaction submission',
+ );
+ },
+ {
+ mocks: {
+ findNetworkClientIdByChainId: jest.fn().mockReturnValue('mainnet'),
+ },
+ },
+ );
});
it('throws error when prepareClaim returns no transactions', async () => {
@@ -2667,42 +2821,45 @@ describe('PredictController', () => {
it('clears error state on successful claim', async () => {
// Arrange
const mockBatchId = 'claim-batch-1';
- await withController(async ({ controller }) => {
- controller.updateStateForTesting((state) => {
- state.lastError = 'Previous error';
- });
-
- mockPolymarketProvider.getPositions = jest.fn().mockResolvedValue([
- {
- marketId: 'test-market',
- providerId: 'polymarket',
- outcomeId: 'test-outcome',
- balance: '100',
- },
- ]);
- mockPolymarketProvider.prepareClaim = jest
- .fn()
- .mockResolvedValue(mockClaim);
+ await withController(
+ async ({ controller }) => {
+ controller.updateStateForTesting((state) => {
+ state.lastError = 'Previous error';
+ });
- const MockedEngine = jest.requireMock('../../../../core/Engine');
- MockedEngine.context.NetworkController.findNetworkClientIdByChainId =
- jest.fn().mockReturnValue('mainnet');
+ mockPolymarketProvider.getPositions = jest.fn().mockResolvedValue([
+ {
+ marketId: 'test-market',
+ providerId: 'polymarket',
+ outcomeId: 'test-outcome',
+ balance: '100',
+ },
+ ]);
+ mockPolymarketProvider.prepareClaim = jest
+ .fn()
+ .mockResolvedValue(mockClaim);
- (addTransactionBatch as jest.Mock).mockResolvedValue({
- batchId: mockBatchId,
- });
- await controller.getPositions({ claimable: true });
+ (addTransactionBatch as jest.Mock).mockResolvedValue({
+ batchId: mockBatchId,
+ });
+ await controller.getPositions({ claimable: true });
- // Act
- const result = await controller.claimWithConfirmation({
- providerId: 'polymarket',
- });
+ // Act
+ const result = await controller.claimWithConfirmation({
+ providerId: 'polymarket',
+ });
- // Assert
- expect(result.batchId).toBe(mockBatchId);
- expect(controller.state.lastError).toBeNull();
- expect(controller.state.lastUpdateTimestamp).toBeGreaterThan(0);
- });
+ // Assert
+ expect(result.batchId).toBe(mockBatchId);
+ expect(controller.state.lastError).toBeNull();
+ expect(controller.state.lastUpdateTimestamp).toBeGreaterThan(0);
+ },
+ {
+ mocks: {
+ findNetworkClientIdByChainId: jest.fn().mockReturnValue('mainnet'),
+ },
+ },
+ );
});
});
@@ -3009,37 +3166,6 @@ describe('PredictController', () => {
batchId: mockBatchId,
});
- Engine.context.AccountTrackerController.state.accountsByChainId = {
- [mockChainId]: {
- '0x1234567890123456789012345678901234567890': {
- balance: '0x0',
- },
- },
- };
-
- Engine.context.TransactionController.estimateGas = jest
- .fn()
- .mockResolvedValue({ gas: '0x5208' });
- Engine.context.TransactionController.estimateGasFee = jest
- .fn()
- .mockResolvedValue({
- estimates: {
- type: GasFeeEstimateType.FeeMarket,
- [GasFeeEstimateLevel.Low]: {
- maxFeePerGas: '0x3b9aca00',
- maxPriorityFeePerGas: '0x1',
- },
- [GasFeeEstimateLevel.Medium]: {
- maxFeePerGas: '0x3b9aca00',
- maxPriorityFeePerGas: '0x1',
- },
- [GasFeeEstimateLevel.High]: {
- maxFeePerGas: '0x3b9aca00',
- maxPriorityFeePerGas: '0x1',
- },
- },
- });
-
await withController(async ({ controller }) => {
// When calling depositWithConfirmation
const result = await controller.depositWithConfirmation({
@@ -3064,7 +3190,7 @@ describe('PredictController', () => {
expect(addTransactionBatch).toHaveBeenCalledWith({
from: '0x1234567890123456789012345678901234567890',
origin: 'metamask',
- networkClientId: 'mainnet',
+ networkClientId: 'polygon-mainnet',
disableHook: true,
disableSequential: true,
disableUpgrade: true,
@@ -3175,6 +3301,9 @@ describe('PredictController', () => {
getNetworkState: jest.fn().mockReturnValue({
selectedNetworkClientId: mockNetworkClientId,
}),
+ findNetworkClientIdByChainId: jest
+ .fn()
+ .mockReturnValue(mockNetworkClientId),
},
},
);
@@ -4826,108 +4955,156 @@ describe('PredictController', () => {
});
it('calls NetworkController.findNetworkClientIdByChainId with hex chain ID', async () => {
- await withController(async ({ controller }) => {
- // Arrange
- const chainId = 137;
+ const mockFindNetworkClientIdByChainId = jest
+ .fn()
+ .mockReturnValue('polygon-mainnet');
+ await withController(
+ async ({ controller }) => {
+ // Arrange
+ const chainId = 137;
- // Act
- // eslint-disable-next-line dot-notation
- await controller['invalidateQueryCache'](chainId);
+ // Act
+ // eslint-disable-next-line dot-notation
+ await controller['invalidateQueryCache'](chainId);
- // Assert
- expect(
- Engine.context.NetworkController.findNetworkClientIdByChainId,
- ).toHaveBeenCalledWith('0x89');
- });
+ // Assert
+ expect(mockFindNetworkClientIdByChainId).toHaveBeenCalledWith('0x89');
+ },
+ {
+ mocks: {
+ findNetworkClientIdByChainId: mockFindNetworkClientIdByChainId,
+ },
+ },
+ );
});
it('calls NetworkController.getNetworkClientById with network client ID', async () => {
- await withController(async ({ controller }) => {
- // Arrange
- const chainId = 137;
+ const mockGetNetworkClientById = jest
+ .fn()
+ .mockReturnValue(DEFAULT_NETWORK_CLIENT);
+ await withController(
+ async ({ controller }) => {
+ // Arrange
+ const chainId = 137;
- // Act
- // eslint-disable-next-line dot-notation
- await controller['invalidateQueryCache'](chainId);
+ // Act
+ // eslint-disable-next-line dot-notation
+ await controller['invalidateQueryCache'](chainId);
- // Assert
- expect(
- Engine.context.NetworkController.getNetworkClientById,
- ).toHaveBeenCalledWith('mainnet');
- });
+ // Assert
+ expect(mockGetNetworkClientById).toHaveBeenCalledWith(
+ 'polygon-mainnet',
+ );
+ },
+ {
+ mocks: {
+ findNetworkClientIdByChainId: jest
+ .fn()
+ .mockReturnValue('polygon-mainnet'),
+ getNetworkClientById: mockGetNetworkClientById,
+ },
+ },
+ );
});
it('calls blockTracker.checkForLatestBlock to invalidate cache', async () => {
const mockCheckForLatestBlock = jest.fn().mockResolvedValue(undefined);
- // @ts-expect-error - Mocking Engine for test
- Engine.context.NetworkController.getNetworkClientById.mockReturnValue({
+ const mockGetNetworkClientById = jest.fn().mockReturnValue({
blockTracker: {
checkForLatestBlock: mockCheckForLatestBlock,
},
});
- await withController(async ({ controller }) => {
- // Arrange
- const chainId = 137;
+ await withController(
+ async ({ controller }) => {
+ // Arrange
+ const chainId = 137;
- // Act
- // eslint-disable-next-line dot-notation
- await controller['invalidateQueryCache'](chainId);
+ // Act
+ // eslint-disable-next-line dot-notation
+ await controller['invalidateQueryCache'](chainId);
- // Assert
- expect(mockCheckForLatestBlock).toHaveBeenCalledWith();
- });
+ // Assert
+ expect(mockCheckForLatestBlock).toHaveBeenCalledWith();
+ },
+ {
+ mocks: {
+ findNetworkClientIdByChainId: jest
+ .fn()
+ .mockReturnValue('polygon-mainnet'),
+ getNetworkClientById: mockGetNetworkClientById,
+ },
+ },
+ );
});
it('logs error when blockTracker.checkForLatestBlock fails', async () => {
const mockError = new Error('Block tracker error');
const mockCheckForLatestBlock = jest.fn().mockRejectedValue(mockError);
- // @ts-expect-error - Mocking Engine for test
- Engine.context.NetworkController.getNetworkClientById.mockReturnValue({
+ const mockGetNetworkClientById = jest.fn().mockReturnValue({
blockTracker: {
checkForLatestBlock: mockCheckForLatestBlock,
},
});
- await withController(async ({ controller }) => {
- // Arrange
- const chainId = 137;
+ await withController(
+ async ({ controller }) => {
+ // Arrange
+ const chainId = 137;
- // Act
- // eslint-disable-next-line dot-notation
- await controller['invalidateQueryCache'](chainId);
+ // Act
+ // eslint-disable-next-line dot-notation
+ await controller['invalidateQueryCache'](chainId);
- // Assert
- expect(DevLogger.log).toHaveBeenCalledWith(
- 'PredictController: Error invalidating query cache',
- expect.objectContaining({
- error: 'Block tracker error',
- timestamp: expect.any(String),
- }),
- );
- });
+ // Assert
+ expect(DevLogger.log).toHaveBeenCalledWith(
+ 'PredictController: Error invalidating query cache',
+ expect.objectContaining({
+ error: 'Block tracker error',
+ timestamp: expect.any(String),
+ }),
+ );
+ },
+ {
+ mocks: {
+ findNetworkClientIdByChainId: jest
+ .fn()
+ .mockReturnValue('polygon-mainnet'),
+ getNetworkClientById: mockGetNetworkClientById,
+ },
+ },
+ );
});
it('continues execution when invalidation fails', async () => {
const mockError = new Error('Block tracker error');
const mockCheckForLatestBlock = jest.fn().mockRejectedValue(mockError);
- // @ts-expect-error - Mocking Engine for test
- Engine.context.NetworkController.getNetworkClientById.mockReturnValue({
+ const mockGetNetworkClientById = jest.fn().mockReturnValue({
blockTracker: {
checkForLatestBlock: mockCheckForLatestBlock,
},
});
- await withController(async ({ controller }) => {
- // Arrange
- const chainId = 137;
+ await withController(
+ async ({ controller }) => {
+ // Arrange
+ const chainId = 137;
- // Act & Assert - should not throw
- await expect(
- // eslint-disable-next-line dot-notation
- controller['invalidateQueryCache'](chainId),
- ).resolves.not.toThrow();
- });
+ // Act & Assert - should not throw
+ await expect(
+ // eslint-disable-next-line dot-notation
+ controller['invalidateQueryCache'](chainId),
+ ).resolves.not.toThrow();
+ },
+ {
+ mocks: {
+ findNetworkClientIdByChainId: jest
+ .fn()
+ .mockReturnValue('polygon-mainnet'),
+ getNetworkClientById: mockGetNetworkClientById,
+ },
+ },
+ );
});
});
@@ -5371,24 +5548,33 @@ describe('PredictController', () => {
it('calls invalidateQueryCache before fetching fresh balance', async () => {
const mockCheckForLatestBlock = jest.fn().mockResolvedValue(undefined);
- // @ts-expect-error - Mocking Engine for test
- Engine.context.NetworkController.getNetworkClientById.mockReturnValue({
+ const mockGetNetworkClientById = jest.fn().mockReturnValue({
blockTracker: {
checkForLatestBlock: mockCheckForLatestBlock,
},
});
mockPolymarketProvider.getBalance.mockResolvedValue(1000);
- await withController(async ({ controller }) => {
- // Act
- await controller.getBalance({
- providerId: 'polymarket',
- });
+ await withController(
+ async ({ controller }) => {
+ // Act
+ await controller.getBalance({
+ providerId: 'polymarket',
+ });
- // Assert
- expect(mockCheckForLatestBlock).toHaveBeenCalled();
- expect(mockPolymarketProvider.getBalance).toHaveBeenCalled();
- });
+ // Assert
+ expect(mockCheckForLatestBlock).toHaveBeenCalled();
+ expect(mockPolymarketProvider.getBalance).toHaveBeenCalled();
+ },
+ {
+ mocks: {
+ findNetworkClientIdByChainId: jest
+ .fn()
+ .mockReturnValue('polygon-mainnet'),
+ getNetworkClientById: mockGetNetworkClientById,
+ },
+ },
+ );
});
it('fetches balance when validUntil equals current time', async () => {
diff --git a/app/components/UI/Predict/controllers/PredictController.ts b/app/components/UI/Predict/controllers/PredictController.ts
index 3ec84f8273d..4e9c5f23a3a 100644
--- a/app/components/UI/Predict/controllers/PredictController.ts
+++ b/app/components/UI/Predict/controllers/PredictController.ts
@@ -1,4 +1,6 @@
import { AccountsControllerGetSelectedAccountAction } from '@metamask/accounts-controller';
+import { AccountTreeControllerGetAccountsFromSelectedAccountGroupAction } from '@metamask/account-tree-controller';
+import { isEvmAccountType } from '@metamask/keyring-api';
import {
BaseController,
ControllerGetStateAction,
@@ -11,8 +13,14 @@ import {
PersonalMessageParams,
SignTypedDataVersion,
TypedMessageParams,
+ KeyringControllerSignTypedMessageAction,
+ KeyringControllerSignPersonalMessageAction,
} from '@metamask/keyring-controller';
-import { NetworkControllerGetStateAction } from '@metamask/network-controller';
+import {
+ NetworkControllerGetStateAction,
+ NetworkControllerFindNetworkClientIdByChainIdAction,
+ NetworkControllerGetNetworkClientByIdAction,
+} from '@metamask/network-controller';
import {
TransactionControllerEstimateGasAction,
TransactionControllerTransactionConfirmedEvent,
@@ -22,11 +30,14 @@ import {
TransactionMeta,
TransactionType,
} from '@metamask/transaction-controller';
+import {
+ RemoteFeatureFlagControllerGetStateAction,
+ RemoteFeatureFlagControllerStateChangeEvent,
+} from '@metamask/remote-feature-flag-controller';
import { Hex, hexToNumber, numberToHex } from '@metamask/utils';
import performance from 'react-native-performance';
import { MetaMetrics, MetaMetricsEvents } from '../../../../core/Analytics';
import { MetricsEventBuilder } from '../../../../core/Analytics/MetricsEventBuilder';
-import Engine from '../../../../core/Engine';
import DevLogger from '../../../../core/SDKConnect/utils/DevLogger';
import Logger, { type LoggerErrorOptions } from '../../../../util/Logger';
import {
@@ -82,7 +93,6 @@ import {
} from '../types';
import { ensureError } from '../utils/predictErrorHandler';
import { PREDICT_CONSTANTS, PREDICT_ERROR_CODES } from '../constants/errors';
-import { getEvmAccountFromSelectedAccountGroup } from '../utils/accounts';
import { GEO_BLOCKED_COUNTRIES } from '../constants/geoblock';
import { MATIC_CONTRACTS } from '../providers/polymarket/constants';
import {
@@ -228,8 +238,14 @@ export type PredictControllerActions =
*/
type AllowedActions =
| AccountsControllerGetSelectedAccountAction
+ | AccountTreeControllerGetAccountsFromSelectedAccountGroupAction
| NetworkControllerGetStateAction
- | TransactionControllerEstimateGasAction;
+ | NetworkControllerFindNetworkClientIdByChainIdAction
+ | NetworkControllerGetNetworkClientByIdAction
+ | TransactionControllerEstimateGasAction
+ | KeyringControllerSignTypedMessageAction
+ | KeyringControllerSignPersonalMessageAction
+ | RemoteFeatureFlagControllerGetStateAction;
/**
* External events the PredictController can subscribe to
@@ -238,7 +254,8 @@ type AllowedEvents =
| TransactionControllerTransactionSubmittedEvent
| TransactionControllerTransactionConfirmedEvent
| TransactionControllerTransactionFailedEvent
- | TransactionControllerTransactionRejectedEvent;
+ | TransactionControllerTransactionRejectedEvent
+ | RemoteFeatureFlagControllerStateChangeEvent;
/**
* PredictController messenger constraints
@@ -407,27 +424,42 @@ export class PredictController extends BaseController<
* @private
*/
private getSigner(address?: string): Signer {
- const { KeyringController } = Engine.context;
- const selectedAddress =
- address ?? getEvmAccountFromSelectedAccountGroup()?.address ?? '0x0';
+ const selectedAddress = address ?? this.getEvmAccountAddress();
return {
address: selectedAddress,
signTypedMessage: (
_params: TypedMessageParams,
_version: SignTypedDataVersion,
- ) => KeyringController.signTypedMessage(_params, _version),
+ ) =>
+ this.messenger.call(
+ 'KeyringController:signTypedMessage',
+ _params,
+ _version,
+ ),
signPersonalMessage: (_params: PersonalMessageParams) =>
- KeyringController.signPersonalMessage(_params),
+ this.messenger.call('KeyringController:signPersonalMessage', _params),
};
}
+ private getEvmAccountAddress(): string {
+ const accounts = this.messenger.call(
+ 'AccountTreeController:getAccountsFromSelectedAccountGroup',
+ );
+ const evmAccount = accounts.find(
+ (account) => account && isEvmAccountType(account.type),
+ );
+ return evmAccount?.address ?? '0x0';
+ }
+
private async invalidateQueryCache(chainId: number) {
- const { NetworkController } = Engine.context;
- const networkClientId = NetworkController.findNetworkClientIdByChainId(
+ const networkClientId = this.messenger.call(
+ 'NetworkController:findNetworkClientIdByChainId',
numberToHex(chainId),
);
- const networkClient =
- NetworkController.getNetworkClientById(networkClientId);
+ const networkClient = this.messenger.call(
+ 'NetworkController:getNetworkClientById',
+ networkClientId,
+ );
try {
await networkClient.blockTracker.checkForLatestBlock();
} catch (error) {
@@ -474,9 +506,11 @@ export class PredictController extends BaseController<
throw new Error('Provider not available');
}
- const { RemoteFeatureFlagController } = Engine.context;
+ const remoteFeatureFlagState = this.messenger.call(
+ 'RemoteFeatureFlagController:getState',
+ );
const liveSportsFlag =
- (RemoteFeatureFlagController.state.remoteFeatureFlags
+ (remoteFeatureFlagState.remoteFeatureFlags
.predictLiveSports as unknown as PredictLiveSportsFlag | undefined) ??
DEFAULT_LIVE_SPORTS_FLAG;
const liveSportsLeagues = liveSportsFlag.enabled
@@ -484,7 +518,7 @@ export class PredictController extends BaseController<
: [];
const marketHighlightsFlag =
- (RemoteFeatureFlagController.state.remoteFeatureFlags
+ (remoteFeatureFlagState.remoteFeatureFlags
.predictMarketHighlights as unknown as
| PredictMarketHighlightsFlag
| undefined) ?? DEFAULT_MARKET_HIGHLIGHTS_FLAG;
@@ -619,9 +653,11 @@ export class PredictController extends BaseController<
throw new Error('Provider not available');
}
- const { RemoteFeatureFlagController } = Engine.context;
+ const remoteFeatureFlagState = this.messenger.call(
+ 'RemoteFeatureFlagController:getState',
+ );
const liveSportsFlag =
- (RemoteFeatureFlagController.state.remoteFeatureFlags
+ (remoteFeatureFlagState.remoteFeatureFlags
.predictLiveSports as unknown as PredictLiveSportsFlag | undefined) ??
DEFAULT_LIVE_SPORTS_FLAG;
const liveSportsLeagues = liveSportsFlag.enabled
@@ -1427,9 +1463,11 @@ export class PredictController extends BaseController<
throw new Error('Provider not available');
}
- const { RemoteFeatureFlagController } = Engine.context;
+ const remoteFeatureFlagState = this.messenger.call(
+ 'RemoteFeatureFlagController:getState',
+ );
const feeCollection =
- (RemoteFeatureFlagController.state.remoteFeatureFlags
+ (remoteFeatureFlagState.remoteFeatureFlags
.predictFeeCollection as unknown as
| PredictFeeCollection
| undefined) ?? DEFAULT_FEE_COLLECTION_FLAG;
@@ -1683,8 +1721,8 @@ export class PredictController extends BaseController<
}
// Find network client - can fail if chain is not supported
- const { NetworkController } = Engine.context;
- const networkClientId = NetworkController.findNetworkClientIdByChainId(
+ const networkClientId = this.messenger.call(
+ 'NetworkController:findNetworkClientIdByChainId',
numberToHex(chainId),
);
@@ -1977,9 +2015,10 @@ export class PredictController extends BaseController<
providerId: params.providerId,
});
- const { NetworkController } = Engine.context;
- const networkClientId =
- NetworkController.findNetworkClientIdByChainId(chainId);
+ const networkClientId = this.messenger.call(
+ 'NetworkController:findNetworkClientIdByChainId',
+ chainId,
+ );
if (!networkClientId) {
throw new Error(`Network client not found for chain ID: ${chainId}`);
@@ -2222,13 +2261,13 @@ export class PredictController extends BaseController<
};
});
- const { NetworkController } = Engine.context;
-
const { batchId } = await addTransactionBatch({
from: signer.address as Hex,
origin: ORIGIN_METAMASK,
- networkClientId:
- NetworkController.findNetworkClientIdByChainId(chainId),
+ networkClientId: this.messenger.call(
+ 'NetworkController:findNetworkClientIdByChainId',
+ chainId,
+ ),
disableHook: true,
disableSequential: true,
requireApproval: true,
@@ -2324,8 +2363,8 @@ export class PredictController extends BaseController<
const chainId = this.state.withdrawTransaction.chainId;
- const { NetworkController } = Engine.context;
- const networkClientId = NetworkController.findNetworkClientIdByChainId(
+ const networkClientId = this.messenger.call(
+ 'NetworkController:findNetworkClientIdByChainId',
numberToHex(chainId),
);
diff --git a/app/components/UI/Ramp/hooks/useAnalytics.ts b/app/components/UI/Ramp/hooks/useAnalytics.ts
index 90de87141f6..b5e35c48969 100644
--- a/app/components/UI/Ramp/hooks/useAnalytics.ts
+++ b/app/components/UI/Ramp/hooks/useAnalytics.ts
@@ -5,6 +5,7 @@ import { AnalyticsEvents as DepositEvents } from '../Deposit/types';
import { MetaMetricsEvents } from '../../../../core/Analytics';
import { analytics } from '../../../../util/analytics/analytics';
import { AnalyticsEventBuilder } from '../../../../util/analytics/AnalyticsEventBuilder';
+import type { AnalyticsUnfilteredProperties } from '../../../../util/analytics/analytics.types';
interface MergedRampEvents extends AggregatorEvents, DepositEvents {}
@@ -14,7 +15,7 @@ export function trackEvent(
) {
analytics.trackEvent(
AnalyticsEventBuilder.createEventBuilder(MetaMetricsEvents[eventType])
- .addProperties({ ...params })
+ .addProperties(params as AnalyticsUnfilteredProperties)
.build(),
);
}
diff --git a/app/components/UI/Ramp/index.test.tsx b/app/components/UI/Ramp/index.test.tsx
index b54e56368b8..ee26683de5c 100644
--- a/app/components/UI/Ramp/index.test.tsx
+++ b/app/components/UI/Ramp/index.test.tsx
@@ -14,6 +14,11 @@ import getAggregatorAnalyticsPayload from './Aggregator/utils/getAggregatorAnaly
const mockNavigate = jest.fn();
+jest.mock('./hooks/useHydrateRampsController', () => ({
+ __esModule: true,
+ default: jest.fn(),
+}));
+
jest.mock('@react-navigation/native', () => {
const actual = jest.requireActual('@react-navigation/native');
return {
diff --git a/app/components/UI/Ramp/index.tsx b/app/components/UI/Ramp/index.tsx
index fc175659349..bd0f817a3dd 100644
--- a/app/components/UI/Ramp/index.tsx
+++ b/app/components/UI/Ramp/index.tsx
@@ -34,6 +34,7 @@ import Routes from '../../../constants/navigation/Routes';
import getOrderAnalyticsPayload from './utils/getOrderAnalyticsPayload';
import { NativeRampsSdk } from '@consensys/native-ramps-sdk';
import useDetectGeolocation from './hooks/useDetectGeolocation';
+import useHydrateRampsController from './hooks/useHydrateRampsController';
import useRampsSmartRouting from './hooks/useRampsSmartRouting';
const POLLING_FREQUENCY = AppConstants.FIAT_ORDERS.POLLING_FREQUENCY;
@@ -117,6 +118,7 @@ const styles = StyleSheet.create({
});
function FiatOrders() {
+ useHydrateRampsController();
useFetchRampNetworks();
useDetectGeolocation();
useRampsSmartRouting();
diff --git a/app/components/UI/Rewards/Views/RewardsDashboard.test.tsx b/app/components/UI/Rewards/Views/RewardsDashboard.test.tsx
index 3a4f45d6965..5142d41c985 100644
--- a/app/components/UI/Rewards/Views/RewardsDashboard.test.tsx
+++ b/app/components/UI/Rewards/Views/RewardsDashboard.test.tsx
@@ -53,6 +53,10 @@ jest.mock('../../../../selectors/rewards', () => ({
selectRewardsSubscriptionId: jest.fn(),
}));
+jest.mock('../../../../selectors/featureFlagController/rewards', () => ({
+ selectSnapshotsRewardsEnabledFlag: jest.fn(),
+}));
+
jest.mock(
'../../../../selectors/multichainAccounts/accountTreeController',
() => ({
@@ -69,6 +73,7 @@ import {
} from '../../../../reducers/rewards/selectors';
import { selectRewardsSubscriptionId } from '../../../../selectors/rewards';
import { selectSelectedAccountGroup } from '../../../../selectors/multichainAccounts/accountTreeController';
+import { selectSnapshotsRewardsEnabledFlag } from '../../../../selectors/featureFlagController/rewards';
const mockSelectActiveTab = selectActiveTab as jest.MockedFunction<
typeof selectActiveTab
@@ -95,6 +100,10 @@ const mockSelectSelectedAccountGroup =
selectSelectedAccountGroup as jest.MockedFunction<
typeof selectSelectedAccountGroup
>;
+const mockSelectSnapshotsRewardsEnabledFlag =
+ selectSnapshotsRewardsEnabledFlag as jest.MockedFunction<
+ typeof selectSnapshotsRewardsEnabledFlag
+ >;
// Mock theme
jest.mock('../../../../util/theme', () => ({
@@ -168,7 +177,7 @@ jest.mock('../../../../../locales/i18n', () => ({
const translations: Record = {
'rewards.main_title': 'Rewards',
'rewards.tab_overview_title': 'Overview',
- 'rewards.tab_levels_title': 'Levels',
+ 'rewards.tab_snapshots_title': 'Snapshots',
'rewards.tab_activity_title': 'Activity',
'rewards.not_implemented': 'Not implemented yet',
};
@@ -235,16 +244,16 @@ jest.mock('../components/Tabs/RewardsOverview', () => ({
},
}));
-jest.mock('../components/Tabs/RewardsLevels', () => ({
+jest.mock('../components/Tabs/RewardsSnapshots', () => ({
__esModule: true,
- default: function MockRewardsLevels({ tabLabel }: { tabLabel: string }) {
+ default: function MockRewardsSnapshots({ tabLabel }: { tabLabel: string }) {
const ReactActual = jest.requireActual('react');
const { View, Text } = jest.requireActual('react-native');
return ReactActual.createElement(
View,
- { testID: 'rewards-levels-tab' },
- ReactActual.createElement(Text, null, tabLabel || 'Levels'),
+ { testID: 'rewards-snapshots-tab' },
+ ReactActual.createElement(Text, null, tabLabel || 'Snapshots'),
);
},
}));
@@ -359,6 +368,10 @@ jest.mock('../hooks/useRewardDashboardModals', () => ({
useRewardDashboardModals: jest.fn(),
}));
+jest.mock('../hooks/useBulkLinkState', () => ({
+ useBulkLinkState: jest.fn(),
+}));
+
jest.mock('../utils', () => ({
convertInternalAccountToCaipAccountId: jest.fn(),
}));
@@ -544,6 +557,7 @@ jest.spyOn(Alert, 'alert').mockImplementation(mockAlert);
import { useRewardOptinSummary } from '../hooks/useRewardOptinSummary';
import { useLinkAccountGroup } from '../hooks/useLinkAccountGroup';
import { useRewardDashboardModals } from '../hooks/useRewardDashboardModals';
+import { useBulkLinkState } from '../hooks/useBulkLinkState';
import { convertInternalAccountToCaipAccountId } from '../utils';
import { InternalAccount } from '@metamask/keyring-internal-api';
import { AccountGroupType, AccountWalletType } from '@metamask/account-api';
@@ -558,6 +572,9 @@ const mockUseRewardDashboardModals =
useRewardDashboardModals as jest.MockedFunction<
typeof useRewardDashboardModals
>;
+const mockUseBulkLinkState = useBulkLinkState as jest.MockedFunction<
+ typeof useBulkLinkState
+>;
const mockConvertInternalAccountToCaipAccountId =
convertInternalAccountToCaipAccountId as jest.MockedFunction<
typeof convertInternalAccountToCaipAccountId
@@ -571,6 +588,10 @@ describe('RewardsDashboard', () => {
const mockShowNotSupportedModal = jest.fn();
const mockHasShownModal = jest.fn();
const mockResetSessionTracking = jest.fn();
+ const mockResumeBulkLink = jest.fn();
+ const mockStartBulkLink = jest.fn();
+ const mockCancelBulkLink = jest.fn();
+ const mockResetBulkLink = jest.fn();
const mockSelectedAccount = {
id: 'account-1',
@@ -610,6 +631,7 @@ describe('RewardsDashboard', () => {
hideCurrentAccountNotOptedInBannerArray: [],
selectedAccount: mockSelectedAccount,
selectedAccountGroup: mockSelectedAccountGroup,
+ isSnapshotsEnabled: true, // Enable snapshots by default in tests
};
const defaultHookValues = {
@@ -637,6 +659,22 @@ describe('RewardsDashboard', () => {
resetSessionTrackingForCurrentAccountGroup: jest.fn(),
resetAllSessionTracking: jest.fn(),
},
+ useBulkLinkState: {
+ startBulkLink: mockStartBulkLink,
+ cancelBulkLink: mockCancelBulkLink,
+ resetBulkLink: mockResetBulkLink,
+ resumeBulkLink: mockResumeBulkLink,
+ isRunning: false,
+ wasInterrupted: false,
+ isCompleted: false,
+ hasFailures: false,
+ isFullySuccessful: false,
+ totalAccounts: 0,
+ linkedAccounts: 0,
+ failedAccounts: 0,
+ accountProgress: 0,
+ processedAccounts: 0,
+ },
};
beforeEach(() => {
@@ -648,6 +686,10 @@ describe('RewardsDashboard', () => {
mockShowNotSupportedModal.mockClear();
mockHasShownModal.mockClear();
mockResetSessionTracking.mockClear();
+ mockResumeBulkLink.mockClear();
+ mockStartBulkLink.mockClear();
+ mockCancelBulkLink.mockClear();
+ mockResetBulkLink.mockClear();
mockTrackEvent.mockClear();
mockCreateEventBuilder.mockClear();
mockBuild.mockClear();
@@ -682,6 +724,9 @@ describe('RewardsDashboard', () => {
mockSelectSelectedAccountGroup.mockReturnValue(
defaultSelectorValues.selectedAccountGroup,
);
+ mockSelectSnapshotsRewardsEnabledFlag.mockReturnValue(
+ defaultSelectorValues.isSnapshotsEnabled,
+ );
// Setup hook mocks
mockUseRewardOptinSummary.mockReturnValue(
@@ -693,6 +738,7 @@ describe('RewardsDashboard', () => {
mockUseRewardDashboardModals.mockReturnValue(
defaultHookValues.useRewardDashboardModals,
);
+ mockUseBulkLinkState.mockReturnValue(defaultHookValues.useBulkLinkState);
mockConvertInternalAccountToCaipAccountId.mockReturnValue('eip155:1:0x123');
// Setup default modal hook behavior - return false for all modal types by default
@@ -711,6 +757,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
});
@@ -764,6 +812,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -796,6 +846,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -828,6 +880,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -858,6 +912,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -900,11 +956,11 @@ describe('RewardsDashboard', () => {
it('should handle tab change when user selects different tab', () => {
// Act
const { getByTestId } = render();
- const levelsTab = getByTestId('tab-1');
- fireEvent.press(levelsTab);
+ const snapshotsTab = getByTestId('tab-1');
+ fireEvent.press(snapshotsTab);
// Assert
- expect(mockDispatch).toHaveBeenCalledWith(setActiveTab('levels'));
+ expect(mockDispatch).toHaveBeenCalledWith(setActiveTab('snapshots'));
});
it('should render all tab options', () => {
@@ -926,14 +982,14 @@ describe('RewardsDashboard', () => {
expect(getByTestId('rewards-overview-tab')).toBeTruthy();
});
- it('should switch to levels tab when levels tab is pressed', () => {
+ it('switches to snapshots tab when snapshots tab is pressed', () => {
// Act
const { getByTestId } = render();
- const levelsTab = getByTestId('tab-1');
- fireEvent.press(levelsTab);
+ const snapshotsTab = getByTestId('tab-1');
+ fireEvent.press(snapshotsTab);
// Assert
- expect(getByTestId('rewards-levels-tab')).toBeTruthy();
+ expect(getByTestId('rewards-snapshots-tab')).toBeTruthy();
});
it('should switch to activity tab when activity tab is pressed', () => {
@@ -946,7 +1002,7 @@ describe('RewardsDashboard', () => {
expect(getByTestId('rewards-activity-tab')).toBeTruthy();
});
- it('should not allow tab switching when user is not opted in', () => {
+ it('allows tab switching when user is not opted in', () => {
// Arrange
const futureDateObj = new Date(futureDate);
mockSelectRewardsSubscriptionId.mockReturnValue(null);
@@ -957,16 +1013,145 @@ describe('RewardsDashboard', () => {
if (selector === selectRewardsSubscriptionId) return null;
if (selector === selectSeasonId) return currentSeasonId;
if (selector === selectSeasonEndDate) return futureDateObj;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
// Act
const { getByTestId } = render();
- const levelsTab = getByTestId('tab-1');
- fireEvent.press(levelsTab);
+ const snapshotsTab = getByTestId('tab-1');
+ fireEvent.press(snapshotsTab);
- // Assert - should show levels tab (tab change occurred)
- expect(getByTestId('rewards-levels-tab')).toBeTruthy();
+ // Assert - tab change occurred
+ expect(getByTestId('rewards-snapshots-tab')).toBeTruthy();
+ });
+ });
+
+ describe('tabComponents when isSnapshotsEnabled is false', () => {
+ beforeEach(() => {
+ mockSelectSnapshotsRewardsEnabledFlag.mockReturnValue(false);
+ mockUseSelector.mockImplementation((selector) => {
+ if (selector === selectActiveTab)
+ return defaultSelectorValues.activeTab;
+ if (selector === selectRewardsSubscriptionId)
+ return defaultSelectorValues.subscriptionId;
+ if (selector === selectSeasonId) return currentSeasonId;
+ if (selector === selectSeasonEndDate)
+ return defaultSelectorValues.seasonEndDate;
+ if (selector === selectHideUnlinkedAccountsBanner)
+ return defaultSelectorValues.hideUnlinkedAccountsBanner;
+ if (selector === selectHideCurrentAccountNotOptedInBannerArray)
+ return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
+ if (selector === selectSelectedAccountGroup)
+ return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag) return false;
+ return undefined;
+ });
+ });
+
+ it('renders only overview and activity tabs when snapshots is disabled', () => {
+ // Act
+ const { getByTestId, queryByTestId } = render();
+
+ // Assert - verify only 2 tabs are rendered
+ expect(getByTestId('tab-headers')).toBeTruthy();
+ expect(getByTestId('tab-0')).toBeTruthy();
+ expect(getByTestId('tab-1')).toBeTruthy();
+ expect(queryByTestId('tab-2')).toBeNull();
+ });
+
+ it('does not render snapshots tab when snapshots is disabled', () => {
+ // Act
+ const { queryByTestId } = render();
+
+ // Assert - snapshots tab should not be visible by default
+ expect(queryByTestId('rewards-snapshots-tab')).toBeNull();
+ });
+
+ it('renders overview tab as first tab when snapshots is disabled', () => {
+ // Act
+ const { getByTestId } = render();
+
+ // Assert
+ expect(getByTestId('rewards-overview-tab')).toBeTruthy();
+ });
+
+ it('switches directly to activity tab at index 1 when snapshots is disabled', () => {
+ // Act
+ const { getByTestId } = render();
+ const activityTab = getByTestId('tab-1');
+ fireEvent.press(activityTab);
+
+ // Assert - activity tab is now at index 1 instead of index 2
+ expect(getByTestId('rewards-activity-tab')).toBeTruthy();
+ });
+
+ it('dispatches setActiveTab with activity when tab-1 is pressed and snapshots is disabled', () => {
+ // Act
+ const { getByTestId } = render();
+ const activityTab = getByTestId('tab-1');
+ fireEvent.press(activityTab);
+
+ // Assert - tab-1 should now be activity, not snapshots
+ expect(mockDispatch).toHaveBeenCalledWith(setActiveTab('activity'));
+ });
+
+ it('resets activeTab to overview when snapshots tab becomes unavailable', () => {
+ // Arrange - activeTab is 'snapshots' but isSnapshotsEnabled is false
+ mockSelectActiveTab.mockReturnValue('snapshots');
+ mockSelectSnapshotsRewardsEnabledFlag.mockReturnValue(false);
+ mockUseSelector.mockImplementation((selector) => {
+ if (selector === selectActiveTab) return 'snapshots';
+ if (selector === selectRewardsSubscriptionId)
+ return defaultSelectorValues.subscriptionId;
+ if (selector === selectSeasonId) return currentSeasonId;
+ if (selector === selectSeasonEndDate)
+ return defaultSelectorValues.seasonEndDate;
+ if (selector === selectHideUnlinkedAccountsBanner)
+ return defaultSelectorValues.hideUnlinkedAccountsBanner;
+ if (selector === selectHideCurrentAccountNotOptedInBannerArray)
+ return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
+ if (selector === selectSelectedAccountGroup)
+ return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag) return false;
+ return undefined;
+ });
+
+ // Act
+ render();
+
+ // Assert - should dispatch setActiveTab('overview') to reset the invalid tab
+ expect(mockDispatch).toHaveBeenCalledWith(setActiveTab('overview'));
+ });
+
+ it('does not reset activeTab when current tab is still available', () => {
+ // Arrange - activeTab is 'activity' and isSnapshotsEnabled is false
+ // activity tab should still be available
+ mockSelectActiveTab.mockReturnValue('activity');
+ mockSelectSnapshotsRewardsEnabledFlag.mockReturnValue(false);
+ mockUseSelector.mockImplementation((selector) => {
+ if (selector === selectActiveTab) return 'activity';
+ if (selector === selectRewardsSubscriptionId)
+ return defaultSelectorValues.subscriptionId;
+ if (selector === selectSeasonId) return currentSeasonId;
+ if (selector === selectSeasonEndDate)
+ return defaultSelectorValues.seasonEndDate;
+ if (selector === selectHideUnlinkedAccountsBanner)
+ return defaultSelectorValues.hideUnlinkedAccountsBanner;
+ if (selector === selectHideCurrentAccountNotOptedInBannerArray)
+ return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
+ if (selector === selectSelectedAccountGroup)
+ return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag) return false;
+ return undefined;
+ });
+
+ // Act
+ render();
+
+ // Assert - should NOT dispatch setActiveTab since 'activity' is still valid
+ expect(mockDispatch).not.toHaveBeenCalledWith(setActiveTab('overview'));
});
});
@@ -989,6 +1174,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -1020,6 +1207,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -1051,6 +1240,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -1079,6 +1270,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -1192,6 +1385,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -1251,6 +1446,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -1286,6 +1483,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -1393,6 +1592,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -1467,6 +1668,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -1497,6 +1700,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -1577,6 +1782,8 @@ describe('RewardsDashboard', () => {
return [{ accountGroupId: 'keyring:wallet1/1', hide: true }];
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -1607,6 +1814,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -1686,6 +1895,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -1758,6 +1969,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -1926,6 +2139,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -1987,6 +2202,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -2048,6 +2265,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
@@ -2136,7 +2355,7 @@ describe('RewardsDashboard', () => {
expect(mockCreateEventBuilder).not.toHaveBeenCalled();
});
- it('should track tab viewed event when activeTab changes', () => {
+ it('tracks tab viewed event when activeTab changes', () => {
// Arrange
const { rerender } = render();
mockTrackEvent.mockClear();
@@ -2144,9 +2363,9 @@ describe('RewardsDashboard', () => {
mockBuild.mockClear();
// Act - change active tab
- mockSelectActiveTab.mockReturnValue('levels');
+ mockSelectActiveTab.mockReturnValue('snapshots');
mockUseSelector.mockImplementation((selector) => {
- if (selector === selectActiveTab) return 'levels';
+ if (selector === selectActiveTab) return 'snapshots';
if (selector === selectRewardsSubscriptionId)
return defaultSelectorValues.subscriptionId;
if (selector === selectSeasonId) return currentSeasonId;
@@ -2158,6 +2377,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
rerender();
@@ -2166,12 +2387,12 @@ describe('RewardsDashboard', () => {
expect(mockCreateEventBuilder).toHaveBeenCalledWith(
'rewards_dashboard_tab_viewed',
);
- expect(mockAddProperties).toHaveBeenCalledWith({ tab: 'levels' });
+ expect(mockAddProperties).toHaveBeenCalledWith({ tab: 'snapshots' });
expect(mockBuild).toHaveBeenCalled();
expect(mockTrackEvent).toHaveBeenCalledWith({ event: 'mock-event' });
});
- it('should track tab viewed event for each tab change', () => {
+ it('tracks tab viewed event for each tab change', () => {
// Arrange
const { rerender } = render();
mockTrackEvent.mockClear();
@@ -2179,10 +2400,10 @@ describe('RewardsDashboard', () => {
mockBuild.mockClear();
mockAddProperties.mockClear();
- // Act - change to levels tab
- mockSelectActiveTab.mockReturnValue('levels');
+ // Act - change to snapshots tab
+ mockSelectActiveTab.mockReturnValue('snapshots');
mockUseSelector.mockImplementation((selector) => {
- if (selector === selectActiveTab) return 'levels';
+ if (selector === selectActiveTab) return 'snapshots';
if (selector === selectRewardsSubscriptionId)
return defaultSelectorValues.subscriptionId;
if (selector === selectSeasonId) return currentSeasonId;
@@ -2194,12 +2415,14 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
rerender();
- // Assert - levels tab
- expect(mockAddProperties).toHaveBeenCalledWith({ tab: 'levels' });
+ // Assert - snapshots tab
+ expect(mockAddProperties).toHaveBeenCalledWith({ tab: 'snapshots' });
// Act - change to activity tab
mockSelectActiveTab.mockReturnValue('activity');
@@ -2216,6 +2439,8 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
rerender();
@@ -2226,15 +2451,15 @@ describe('RewardsDashboard', () => {
});
describe('TabsList ref functionality', () => {
- it('should handle Redux state changes for activeTab without crashing', () => {
+ it('handles Redux state changes for activeTab without crashing', () => {
// Arrange
mockSelectActiveTab.mockReturnValue('overview');
const { rerender } = render();
- // Act - change activeTab in Redux to levels
- mockSelectActiveTab.mockReturnValue('levels');
+ // Act - change activeTab in Redux to snapshots
+ mockSelectActiveTab.mockReturnValue('snapshots');
mockUseSelector.mockImplementation((selector) => {
- if (selector === selectActiveTab) return 'levels';
+ if (selector === selectActiveTab) return 'snapshots';
if (selector === selectRewardsSubscriptionId)
return defaultSelectorValues.subscriptionId;
if (selector === selectSeasonId) return currentSeasonId;
@@ -2246,21 +2471,23 @@ describe('RewardsDashboard', () => {
return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray;
if (selector === selectSelectedAccountGroup)
return defaultSelectorValues.selectedAccountGroup;
+ if (selector === selectSnapshotsRewardsEnabledFlag)
+ return defaultSelectorValues.isSnapshotsEnabled;
return undefined;
});
- // Assert - should not crash when activeTab changes
+ // Assert - does not crash when activeTab changes
expect(() => rerender()).not.toThrow();
});
});
describe('component lifecycle', () => {
- it('should render without crashing', () => {
+ it('renders without crashing', () => {
// Act & Assert
expect(() => render()).not.toThrow();
});
- it('should cleanup properly when unmounted', () => {
+ it('cleans up properly when unmounted', () => {
// Act
const { unmount } = render();
@@ -2268,4 +2495,81 @@ describe('RewardsDashboard', () => {
expect(() => unmount()).not.toThrow();
});
});
+
+ describe('bulk link auto-resume', () => {
+ it('calls resumeBulkLink when wasInterrupted is true and isRunning is false', () => {
+ // Arrange
+ mockUseBulkLinkState.mockReturnValue({
+ ...defaultHookValues.useBulkLinkState,
+ wasInterrupted: true,
+ isRunning: false,
+ });
+
+ // Act
+ render();
+
+ // Assert
+ expect(mockResumeBulkLink).toHaveBeenCalled();
+ });
+
+ it('does not call resumeBulkLink when wasInterrupted is false', () => {
+ // Arrange
+ mockUseBulkLinkState.mockReturnValue({
+ ...defaultHookValues.useBulkLinkState,
+ wasInterrupted: false,
+ isRunning: false,
+ });
+
+ // Act
+ render();
+
+ // Assert
+ expect(mockResumeBulkLink).not.toHaveBeenCalled();
+ });
+
+ it('does not call resumeBulkLink when isRunning is true', () => {
+ // Arrange
+ mockUseBulkLinkState.mockReturnValue({
+ ...defaultHookValues.useBulkLinkState,
+ wasInterrupted: true,
+ isRunning: true,
+ });
+
+ // Act
+ render();
+
+ // Assert
+ expect(mockResumeBulkLink).not.toHaveBeenCalled();
+ });
+
+ it('does not call resumeBulkLink when both wasInterrupted and isRunning are false', () => {
+ // Arrange
+ mockUseBulkLinkState.mockReturnValue({
+ ...defaultHookValues.useBulkLinkState,
+ wasInterrupted: false,
+ isRunning: false,
+ });
+
+ // Act
+ render();
+
+ // Assert
+ expect(mockResumeBulkLink).not.toHaveBeenCalled();
+ });
+
+ it('does not call resumeBulkLink when both wasInterrupted and isRunning are true', () => {
+ // Arrange
+ mockUseBulkLinkState.mockReturnValue({
+ ...defaultHookValues.useBulkLinkState,
+ wasInterrupted: true,
+ isRunning: true,
+ });
+
+ // Act
+ render();
+
+ // Assert
+ expect(mockResumeBulkLink).not.toHaveBeenCalled();
+ });
+ });
});
diff --git a/app/components/UI/Rewards/Views/RewardsDashboard.tsx b/app/components/UI/Rewards/Views/RewardsDashboard.tsx
index 47d242b53cc..d359ba230e8 100644
--- a/app/components/UI/Rewards/Views/RewardsDashboard.tsx
+++ b/app/components/UI/Rewards/Views/RewardsDashboard.tsx
@@ -34,13 +34,15 @@ import {
} from '../../../../reducers/rewards/selectors';
import SeasonStatus from '../components/SeasonStatus/SeasonStatus';
import { selectRewardsSubscriptionId } from '../../../../selectors/rewards';
+import { selectSnapshotsRewardsEnabledFlag } from '../../../../selectors/featureFlagController/rewards';
import { useRewardOptinSummary } from '../hooks/useRewardOptinSummary';
import {
useRewardDashboardModals,
RewardsDashboardModalType,
} from '../hooks/useRewardDashboardModals';
+import { useBulkLinkState } from '../hooks/useBulkLinkState';
import RewardsOverview from '../components/Tabs/RewardsOverview';
-import RewardsLevels from '../components/Tabs/RewardsLevels';
+import RewardsSnapshots from '../components/Tabs/RewardsSnapshots';
import RewardsActivity from '../components/Tabs/RewardsActivity';
import { TabsList } from '../../../../component-library/components-temp/Tabs';
import { TabsListRef } from '../../../../component-library/components-temp/Tabs/TabsList/TabsList.types';
@@ -65,6 +67,7 @@ const RewardsDashboard: React.FC = () => {
);
const seasonId = useSelector(selectSeasonId);
const seasonEndDate = useSelector(selectSeasonEndDate);
+ const isSnapshotsEnabled = useSelector(selectSnapshotsRewardsEnabledFlag);
const hideCurrentAccountNotOptedInBannerMap = useSelector(
selectHideCurrentAccountNotOptedInBannerArray,
);
@@ -100,6 +103,9 @@ const RewardsDashboard: React.FC = () => {
currentAccountGroupOptedInStatus,
} = useRewardOptinSummary();
+ // Use the bulk link state hook for resuming interrupted opt-in processes
+ const { wasInterrupted, isRunning, resumeBulkLink } = useBulkLinkState();
+
const totalOptedInAccountsSelectedGroup = useMemo(
() => optInBySelectedAccountGroup?.optedInAccounts?.length,
[optInBySelectedAccountGroup],
@@ -131,29 +137,48 @@ const RewardsDashboard: React.FC = () => {
);
}, [colors, navigation]);
- const tabOptions = useMemo(
- () => [
+ const tabOptions = useMemo(() => {
+ const options: {
+ value: 'overview' | 'snapshots' | 'activity';
+ label: string;
+ }[] = [
{
value: 'overview' as const,
label: strings('rewards.tab_overview_title'),
},
- {
- value: 'levels' as const,
- label: strings('rewards.tab_levels_title'),
- },
- {
- value: 'activity' as const,
- label: strings('rewards.tab_activity_title'),
- },
- ],
- [],
- );
+ ];
+
+ if (isSnapshotsEnabled) {
+ options.push({
+ value: 'snapshots' as const,
+ label: strings('rewards.tab_snapshots_title'),
+ });
+ }
+
+ options.push({
+ value: 'activity' as const,
+ label: strings('rewards.tab_activity_title'),
+ });
+
+ return options;
+ }, [isSnapshotsEnabled]);
const getActiveIndex = useCallback(
() => tabOptions.findIndex((tab) => tab.value === activeTab),
[tabOptions, activeTab],
);
+ // Reset activeTab to 'overview' if current tab becomes unavailable (e.g., snapshots disabled)
+ // This ensures Redux state stays in sync with the visible tab and analytics events are accurate
+ useEffect(() => {
+ const isCurrentTabAvailable = tabOptions.some(
+ (tab) => tab.value === activeTab,
+ );
+ if (!isCurrentTabAvailable) {
+ dispatch(setActiveTab('overview'));
+ }
+ }, [tabOptions, activeTab, dispatch]);
+
// Sync TabsList with Redux state changes
useEffect(() => {
const activeIndex = tabOptions.findIndex((tab) => tab.value === activeTab);
@@ -190,10 +215,48 @@ const RewardsDashboard: React.FC = () => {
[getActiveIndex, handleTabChange],
);
+ const tabComponents = useMemo(() => {
+ const tabs: React.ReactElement[] = [
+ ,
+ ];
+
+ if (isSnapshotsEnabled) {
+ tabs.push(
+ ,
+ );
+ }
+
+ tabs.push(
+ ,
+ );
+
+ return tabs;
+ }, [isSnapshotsEnabled]);
+
const [showPreviousSeasonSummary, setShowPreviousSeasonSummary] = useState<
boolean | null
>(null);
+ // Auto-resume interrupted bulk link process when screen comes into focus.
+ // This handles the case where the app was closed during a bulk opt-in process.
+ // The saga is idempotent - it re-fetches opt-in status to skip already-linked accounts.
+ useFocusEffect(
+ useCallback(() => {
+ if (wasInterrupted && !isRunning) {
+ resumeBulkLink();
+ }
+ }, [wasInterrupted, isRunning, resumeBulkLink]),
+ );
+
// Evaluate showPreviousSeasonSummary when screen comes into focus
useFocusEffect(
useCallback(() => {
@@ -328,23 +391,12 @@ const RewardsDashboard: React.FC = () => {
) : (
<>
-
+
+
+
{/* Tab View */}
-
-
-
-
-
+ {tabComponents}
>
)}
diff --git a/app/components/UI/Rewards/Views/RewardsView.constants.ts b/app/components/UI/Rewards/Views/RewardsView.constants.ts
index f104c8c4a8a..51a21845453 100644
--- a/app/components/UI/Rewards/Views/RewardsView.constants.ts
+++ b/app/components/UI/Rewards/Views/RewardsView.constants.ts
@@ -61,4 +61,10 @@ export const REWARDS_VIEW_SELECTORS = {
ACTIVITY_EVENT_ROW_DETAILS: 'activity-event-row-details',
ACTIVITY_EVENT_ROW_DATE: 'activity-event-row-date',
ACTIVITY_EVENT_ROW_BONUS: 'activity-event-row-bonus',
+ // Snapshots
+ TAB_CONTENT_SNAPSHOTS: 'rewards-view-tab-content-snapshots',
+ SNAPSHOTS_SECTION: 'rewards-view-snapshots-section',
+ SNAPSHOTS_ACTIVE_SECTION: 'rewards-view-snapshots-active-section',
+ SNAPSHOTS_UPCOMING_SECTION: 'rewards-view-snapshots-upcoming-section',
+ SNAPSHOTS_PREVIOUS_SECTION: 'rewards-view-snapshots-previous-section',
} as const;
diff --git a/app/components/UI/Rewards/components/Onboarding/OnboardingNoActiveSeasonStep.tsx b/app/components/UI/Rewards/components/Onboarding/OnboardingNoActiveSeasonStep.tsx
index b8f13910c0c..cf7314b30bc 100644
--- a/app/components/UI/Rewards/components/Onboarding/OnboardingNoActiveSeasonStep.tsx
+++ b/app/components/UI/Rewards/components/Onboarding/OnboardingNoActiveSeasonStep.tsx
@@ -1,9 +1,10 @@
-import React, { useCallback } from 'react';
+import React, { useCallback, useState } from 'react';
import { Image, useWindowDimensions } from 'react-native';
import { useSelector } from 'react-redux';
import { useTailwind } from '@metamask/design-system-twrnc-preset';
import { useOptin } from '../../hooks/useOptIn';
import { Box, Text, TextVariant } from '@metamask/design-system-react-native';
+import Checkbox from '../../../../../component-library/components/Checkbox';
import step1Img from '../../../../../images/rewards/rewards-onboarding-step1.png';
import Step1BgImg from '../../../../../images/rewards/rewards-onboarding-step1-bg.svg';
import { strings } from '../../../../../../locales/i18n';
@@ -35,39 +36,62 @@ const OnboardingNoActiveSeasonStep: React.FC<
const navigation = useNavigation();
const { width: screenWidth, height: screenHeight } = useWindowDimensions();
const { optin, optinError, optinLoading } = useOptin();
+ const [bulkLink, setBulkLink] = useState(false);
+
+ const handleBulkLinkToggle = useCallback(() => {
+ setBulkLink((prev) => !prev);
+ }, []);
const handleNext = useCallback(() => {
if (!canContinue()) {
return;
}
- optin({});
- }, [optin, canContinue]);
+ optin({ bulkLink });
+ }, [optin, canContinue, bulkLink]);
const renderStepInfo = () => (
-
- {/* Opt in error message */}
- {optinError && (
-
- )}
+ <>
+
+ {/* Opt in error message */}
+ {optinError && (
+
+ )}
- {/* Title and Description */}
-
-
- {strings('rewards.onboarding.no_active_season.title')}
-
-
-
- {strings('rewards.onboarding.no_active_season.description')}
-
+ {/* Title and Description */}
+
+
+ {strings('rewards.onboarding.no_active_season.title')}
+
+
+
+ {strings('rewards.onboarding.no_active_season.description')}
+
+
+
+ {/* Opt-in all accounts checkbox */}
+
+
+ {strings('rewards.onboarding.step4_bulk_link_checkbox')}
+
+ }
+ />
-
+ >
);
const renderLegalDisclaimer = () => (
diff --git a/app/components/UI/Rewards/components/Onboarding/OnboardingStep.tsx b/app/components/UI/Rewards/components/Onboarding/OnboardingStep.tsx
index da36b3870f4..5bb15817359 100644
--- a/app/components/UI/Rewards/components/Onboarding/OnboardingStep.tsx
+++ b/app/components/UI/Rewards/components/Onboarding/OnboardingStep.tsx
@@ -140,7 +140,9 @@ const OnboardingStepComponent: React.FC = ({
- {renderStepInfo()}
+
+ {renderStepInfo()}
+