diff --git a/.github/workflows/build-ios-e2e.yml b/.github/workflows/build-ios-e2e.yml index 44a894f2f8e..b15e262b24b 100644 --- a/.github/workflows/build-ios-e2e.yml +++ b/.github/workflows/build-ios-e2e.yml @@ -2,6 +2,21 @@ name: Build iOS E2E Apps on: workflow_call: + outputs: + app-uploaded: + description: 'Whether the app was successfully uploaded' + value: ${{ jobs.build-ios-apps.outputs.app-uploaded }} + inputs: + build_type: + description: 'The type of build to perform' + required: false + default: 'main' + type: string + metamask_environment: + description: 'The environment to build for' + required: false + default: 'qa' + type: string permissions: contents: read @@ -20,8 +35,8 @@ jobs: XCODE_BUILD_SETTINGS: 'COMPILER_INDEX_STORE_ENABLE=NO' GITHUB_CI: 'true' # This ensures it's available during pod install PLATFORM: ios - METAMASK_ENVIRONMENT: qa - METAMASK_BUILD_TYPE: main + METAMASK_ENVIRONMENT: ${{ inputs.metamask_environment }} + METAMASK_BUILD_TYPE: ${{ inputs.build_type }} IS_TEST: true E2E: 'true' IGNORE_BOXLOGS_DEVELOPMENT: true @@ -226,7 +241,7 @@ jobs: id: upload-app uses: actions/upload-artifact@v4 with: - name: MetaMask.app + name: ${{ inputs.build_type }}-${{ inputs.metamask_environment }}-MetaMask.app path: ios/build/Build/Products/Release-iphonesimulator/MetaMask.app retention-days: 7 if-no-files-found: error @@ -239,7 +254,7 @@ jobs: if: ${{ steps.cache-restore.outputs.cache-hit == 'true' || steps.cache-restore-main.outputs.cache-hit == 'true' }} uses: actions/upload-artifact@v4 with: - name: index.js.map + name: ${{ inputs.build_type }}-${{ inputs.metamask_environment }}-index.js.map path: sourcemaps/ios/index.js.map retention-days: 7 if-no-files-found: error diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9e0e3a0faec..cd671f1a84f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -382,6 +382,23 @@ jobs: }} secrets: inherit + e2e-smoke-tests-ios-flask: + name: 'iOS Flask E2E Smoke Tests' + if: ${{ github.event_name != 'merge_group' && needs.needs_e2e_build.outputs.ios_changed == 'true' }} + permissions: + contents: read + id-token: write + needs: [needs_e2e_build, ios-tests-ready, smart-e2e-selection] + uses: ./.github/workflows/run-e2e-smoke-tests-ios-flask.yml + with: + changed_files: ${{ needs.needs_e2e_build.outputs.changed_files }} + selected_tags: >- + ${{ + (needs.smart-e2e-selection.outputs.ai_confidence >= 80 && needs.smart-e2e-selection.outputs.ai_e2e_test_tags) || + '["ALL"]' + }} + secrets: inherit + js-bundle-size-check: runs-on: ubuntu-latest steps: @@ -604,6 +621,7 @@ jobs: - e2e-smoke-tests-android - e2e-smoke-tests-ios - e2e-smoke-tests-android-flask + - e2e-smoke-tests-ios-flask steps: - run: | # Check if all non-E2E jobs passed @@ -629,9 +647,15 @@ jobs: exit 1 fi - FLASK_RESULT="${{ needs.e2e-smoke-tests-android-flask.result }}" - if [[ "$FLASK_RESULT" == "failure" ]] || [[ "$FLASK_RESULT" == "cancelled" ]]; then - echo "Android Flask E2E tests failed (result: $FLASK_RESULT)" + ANDROID_FLASK_RESULT="${{ needs.e2e-smoke-tests-android-flask.result }}" + if [[ "$ANDROID_FLASK_RESULT" == "failure" ]] || [[ "$ANDROID_FLASK_RESULT" == "cancelled" ]]; then + echo "Android Flask E2E tests failed (result: $ANDROID_FLASK_RESULT)" + exit 1 + fi + + IOS_FLASK_RESULT="${{ needs.e2e-smoke-tests-ios-flask.result }}" + if [[ "$IOS_FLASK_RESULT" == "failure" ]] || [[ "$IOS_FLASK_RESULT" == "cancelled" ]]; then + echo "iOS Flask E2E tests failed (result: $IOS_FLASK_RESULT)" exit 1 fi fi diff --git a/.github/workflows/run-e2e-smoke-tests-ios-flask.yml b/.github/workflows/run-e2e-smoke-tests-ios-flask.yml new file mode 100644 index 00000000000..96fbb2cd1ed --- /dev/null +++ b/.github/workflows/run-e2e-smoke-tests-ios-flask.yml @@ -0,0 +1,167 @@ +name: iOS Flask E2E Smoke Tests + +on: + workflow_call: + inputs: + selected_tags: + description: 'JSON array of selected tags from Smart E2E selection' + required: false + type: string + default: '["ALL"]' + changed_files: + description: 'Changed files' + required: false + type: string + default: '' + +permissions: + contents: read + id-token: write + +jobs: + repack-ios-flask-apps: + if: contains(fromJson(inputs.selected_tags), 'ALL') || contains(fromJson(inputs.selected_tags), 'FlaskBuildTests') + name: 'Repack iOS Flask Apps' + runs-on: ghcr.io/cirruslabs/macos-runner:sequoia + permissions: + contents: read + id-token: write + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup iOS Environment + timeout-minutes: 15 + uses: MetaMask/github-tools/.github/actions/setup-e2e-env@v1 + with: + platform: ios + setup-simulator: false + + - name: Install dependencies + uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 #v3.0.2 + with: + timeout_minutes: 10 + max_attempts: 3 + retry_wait_seconds: 30 + command: yarn install --immutable + + - name: Setup project + run: yarn setup:github-ci --no-build-android + + - name: Download Main iOS App artifacts + uses: actions/download-artifact@v4 + with: + path: artifacts/ + pattern: main-*-MetaMask.app + + - name: Setup Main iOS App artifacts + run: | + mkdir -p ios/build/Build/Products/Release-iphonesimulator/ + cp -R artifacts/main-qa-MetaMask.app ios/build/Build/Products/Release-iphonesimulator/MetaMask.app + + - name: Repack Main iOS App + run: node scripts/repack.js + env: + PLATFORM: ios + METAMASK_ENVIRONMENT: e2e + METAMASK_BUILD_TYPE: flask + IS_TEST: 'true' + E2E: 'true' + IGNORE_BOXLOGS_DEVELOPMENT: 'true' + GITHUB_CI: 'true' + CI: 'true' + NODE_OPTIONS: '--max-old-space-size=8192' + BRIDGE_USE_DEV_APIS: 'true' + RAMP_INTERNAL_BUILD: 'true' + SEEDLESS_ONBOARDING_ENABLED: 'true' + MM_NOTIFICATIONS_UI_ENABLED: 'true' + MM_SECURITY_ALERTS_API_ENABLED: 'true' + FEATURES_ANNOUNCEMENTS_ACCESS_TOKEN: ${{ secrets.FEATURES_ANNOUNCEMENTS_ACCESS_TOKEN }} + FEATURES_ANNOUNCEMENTS_SPACE_ID: ${{ secrets.FEATURES_ANNOUNCEMENTS_SPACE_ID }} + SEGMENT_WRITE_KEY_QA: ${{ secrets.SEGMENT_WRITE_KEY_QA }} + SEGMENT_WRITE_KEY_FLASK: ${{ secrets.SEGMENT_WRITE_KEY_FLASK }} + SEGMENT_PROXY_URL_QA: ${{ secrets.SEGMENT_PROXY_URL_QA }} + SEGMENT_PROXY_URL_FLASK: ${{ secrets.SEGMENT_PROXY_URL_FLASK }} + SEGMENT_DELETE_API_SOURCE_ID_QA: ${{ secrets.SEGMENT_DELETE_API_SOURCE_ID_QA }} + SEGMENT_DELETE_API_SOURCE_ID_FLASK: ${{ secrets.SEGMENT_DELETE_API_SOURCE_ID_FLASK }} + SEGMENT_REGULATIONS_ENDPOINT_QA: ${{ secrets.SEGMENT_REGULATIONS_ENDPOINT_QA }} + SEGMENT_REGULATIONS_ENDPOINT_FLASK: ${{ secrets.SEGMENT_REGULATIONS_ENDPOINT_FLASK }} + MM_SENTRY_DSN_TEST: ${{ secrets.MM_SENTRY_DSN_TEST }} + MM_SENTRY_AUTH_TOKEN: ${{ secrets.MM_SENTRY_AUTH_TOKEN }} + MAIN_IOS_GOOGLE_CLIENT_ID_UAT: ${{ secrets.MAIN_IOS_GOOGLE_CLIENT_ID_UAT }} + FLASK_IOS_GOOGLE_CLIENT_ID_PROD: ${{ secrets.FLASK_IOS_GOOGLE_CLIENT_ID_PROD }} + MAIN_IOS_GOOGLE_REDIRECT_URI_UAT: ${{ secrets.MAIN_IOS_GOOGLE_REDIRECT_URI_UAT }} + FLASK_IOS_GOOGLE_REDIRECT_URI_PROD: ${{ secrets.FLASK_IOS_GOOGLE_REDIRECT_URI_PROD }} + MAIN_ANDROID_APPLE_CLIENT_ID_UAT: ${{ secrets.MAIN_ANDROID_APPLE_CLIENT_ID_UAT }} + FLASK_ANDROID_APPLE_CLIENT_ID_PROD: ${{ secrets.FLASK_ANDROID_APPLE_CLIENT_ID_PROD }} + MAIN_ANDROID_GOOGLE_CLIENT_ID_UAT: ${{ secrets.MAIN_ANDROID_GOOGLE_CLIENT_ID_UAT }} + FLASK_ANDROID_GOOGLE_CLIENT_ID_PROD: ${{ secrets.FLASK_ANDROID_GOOGLE_CLIENT_ID_PROD }} + MAIN_ANDROID_GOOGLE_SERVER_CLIENT_ID_UAT: ${{ secrets.MAIN_ANDROID_GOOGLE_SERVER_CLIENT_ID_UAT }} + FLASK_ANDROID_GOOGLE_SERVER_CLIENT_ID_PROD: ${{ secrets.FLASK_ANDROID_GOOGLE_SERVER_CLIENT_ID_PROD }} + GOOGLE_SERVICES_B64_IOS: ${{ secrets.GOOGLE_SERVICES_B64_IOS }} + GOOGLE_SERVICES_B64_ANDROID: ${{ secrets.GOOGLE_SERVICES_B64_ANDROID }} + MM_INFURA_PROJECT_ID: ${{ secrets.MM_INFURA_PROJECT_ID }} + + - name: Upload Repacked iOS Flask App + uses: actions/upload-artifact@v4 + with: + name: flask-e2e-MetaMask.app + path: ios/build/Build/Products/Release-iphonesimulator/MetaMask.app + retention-days: 1 + + flask-ios-smoke: + if: contains(fromJson(inputs.selected_tags), 'ALL') || contains(fromJson(inputs.selected_tags), 'FlaskBuildTests') + needs: [repack-ios-flask-apps] + strategy: + matrix: + split: [1, 2, 3] + fail-fast: false + uses: ./.github/workflows/run-e2e-workflow.yml + with: + test-suite-name: flask-ios-smoke-${{ matrix.split }} + platform: ios + test_suite_tag: 'FlaskBuildTests' + split_number: ${{ matrix.split }} + total_splits: 3 + changed_files: ${{ inputs.changed_files }} + build_type: 'flask' + metamask_environment: 'e2e' + secrets: inherit + + report-ios-smoke-tests: + name: Report iOS Flask Smoke Tests + runs-on: ubuntu-latest + if: ${{ !cancelled() && (contains(fromJson(inputs.selected_tags), 'ALL') || contains(fromJson(inputs.selected_tags), 'FlaskBuildTests')) }} + needs: + - flask-ios-smoke + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version-file: '.nvmrc' + + - name: Download shards test artifacts (XMLs + Screenshots) + uses: actions/download-artifact@v4 + continue-on-error: true + with: + path: all-test-artifacts/ + pattern: 'test-e2e-*-ios-*' + + - name: Post Test Report + uses: dorny/test-reporter@dc3a92680fcc15842eef52e8c4606ea7ce6bd3f3 + with: + name: 'iOS Flask E2E Smoke Test Results' + path: 'all-test-artifacts/**/junit.xml' + reporter: 'jest-junit' + fail-on-error: false + list-suites: 'failed' + list-tests: 'failed' + + - name: Upload all test artifacts (XMLs + Screenshots) + uses: actions/upload-artifact@v4 + with: + name: e2e-smoke-ios-flask-all-test-artifacts + path: all-test-artifacts/ diff --git a/.github/workflows/run-e2e-smoke-tests-ios.yml b/.github/workflows/run-e2e-smoke-tests-ios.yml index 4bff0ac6607..841334841cf 100644 --- a/.github/workflows/run-e2e-smoke-tests-ios.yml +++ b/.github/workflows/run-e2e-smoke-tests-ios.yml @@ -33,6 +33,8 @@ jobs: split_number: ${{ matrix.split }} total_splits: 4 changed_files: ${{ inputs.changed_files }} + build_type: 'main' + metamask_environment: 'qa' secrets: inherit trade-ios-smoke: @@ -49,6 +51,8 @@ jobs: split_number: ${{ matrix.split }} total_splits: 1 changed_files: ${{ inputs.changed_files }} + build_type: 'main' + metamask_environment: 'qa' secrets: inherit perps-ios-smoke: @@ -65,6 +69,8 @@ jobs: split_number: ${{ matrix.split }} total_splits: 1 changed_files: ${{ inputs.changed_files }} + build_type: 'main' + metamask_environment: 'qa' secrets: inherit wallet-platform-ios-smoke: @@ -81,6 +87,8 @@ jobs: split_number: ${{ matrix.split }} total_splits: 2 changed_files: ${{ inputs.changed_files }} + build_type: 'main' + metamask_environment: 'qa' secrets: inherit identity-ios-smoke: @@ -97,6 +105,8 @@ jobs: split_number: ${{ matrix.split }} total_splits: 2 changed_files: ${{ inputs.changed_files }} + build_type: 'main' + metamask_environment: 'qa' secrets: inherit accounts-ios-smoke: @@ -113,6 +123,8 @@ jobs: split_number: ${{ matrix.split }} total_splits: 1 changed_files: ${{ inputs.changed_files }} + build_type: 'main' + metamask_environment: 'qa' secrets: inherit network-abstraction-ios-smoke: @@ -129,6 +141,8 @@ jobs: split_number: ${{ matrix.split }} total_splits: 2 changed_files: ${{ inputs.changed_files }} + build_type: 'main' + metamask_environment: 'qa' secrets: inherit network-expansion-ios-smoke: @@ -145,6 +159,8 @@ jobs: split_number: ${{ matrix.split }} total_splits: 2 changed_files: ${{ inputs.changed_files }} + build_type: 'main' + metamask_environment: 'qa' secrets: inherit prediction-market-ios-smoke: @@ -161,6 +177,8 @@ jobs: split_number: ${{ matrix.split }} total_splits: 1 changed_files: ${{ inputs.changed_files }} + build_type: 'main' + metamask_environment: 'qa' secrets: inherit card-ios-smoke: @@ -177,6 +195,8 @@ jobs: split_number: ${{ matrix.split }} total_splits: 1 changed_files: ${{ inputs.changed_files }} + build_type: 'main' + metamask_environment: 'qa' secrets: inherit ramps-ios-smoke: @@ -193,6 +213,8 @@ jobs: split_number: ${{ matrix.split }} total_splits: 1 changed_files: ${{ inputs.changed_files }} + build_type: 'main' + metamask_environment: 'qa' secrets: inherit multichain-api-ios-smoke: @@ -209,6 +231,8 @@ jobs: split_number: ${{ matrix.split }} total_splits: 1 changed_files: ${{ inputs.changed_files }} + build_type: 'main' + metamask_environment: 'qa' secrets: inherit report-ios-smoke-tests: diff --git a/.github/workflows/run-e2e-workflow.yml b/.github/workflows/run-e2e-workflow.yml index d973d640e5c..b09e35fb3e0 100644 --- a/.github/workflows/run-e2e-workflow.yml +++ b/.github/workflows/run-e2e-workflow.yml @@ -58,7 +58,7 @@ jobs: test-apk-target-path: ${{ steps.determine-target-paths.outputs.test-apk-target-path }} env: - PREBUILT_IOS_APP_PATH: artifacts/MetaMask.app + PREBUILT_IOS_APP_PATH: artifacts/${{ inputs.build_type }}-${{ inputs.metamask_environment }}-MetaMask.app METAMASK_ENVIRONMENT: ${{ inputs.metamask_environment }} METAMASK_BUILD_TYPE: ${{ inputs.build_type }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/android/app/build.gradle b/android/app/build.gradle index 38bfba42ffd..a206ec5276d 100644 --- a/android/app/build.gradle +++ b/android/app/build.gradle @@ -188,7 +188,7 @@ android { minSdkVersion rootProject.ext.minSdkVersion targetSdkVersion rootProject.ext.targetSdkVersion versionName "7.65.0" - versionCode 3418 + versionCode 3607 testBuildType System.getProperty('testBuildType', 'debug') testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" manifestPlaceholders.MM_BRANCH_KEY_TEST = "$System.env.MM_BRANCH_KEY_TEST" diff --git a/android/app/src/main/AndroidManifest.xml b/android/app/src/main/AndroidManifest.xml index bc632e82252..501368a500c 100644 --- a/android/app/src/main/AndroidManifest.xml +++ b/android/app/src/main/AndroidManifest.xml @@ -184,5 +184,12 @@ android:resource="@xml/filepaths" /> + + + + diff --git a/app.config.js b/app.config.js index de9362b11e2..f7baabdd8ed 100644 --- a/app.config.js +++ b/app.config.js @@ -89,7 +89,10 @@ module.exports = { : 'io.metamask', // Required for @expo/repack-app Android repacking }, ios: { - bundleIdentifier: 'io.metamask.MetaMask', + bundleIdentifier: + process.env.METAMASK_BUILD_TYPE === 'flask' + ? 'io.metamask.MetaMask-Flask' + : 'io.metamask.MetaMask', // Required for @expo/repack-app iOS repacking usesAppleSignIn: true, jsEngine: 'hermes', }, diff --git a/app/components/Base/RemoteImage/index.test.tsx b/app/components/Base/RemoteImage/index.test.tsx index adbbc29b18d..64eb4941b75 100644 --- a/app/components/Base/RemoteImage/index.test.tsx +++ b/app/components/Base/RemoteImage/index.test.tsx @@ -95,7 +95,9 @@ describe('RemoteImage', () => { // Wait for IPFS URL resolution }); - expect(wrapper).toMatchSnapshot(); + await waitFor(() => { + expect(wrapper).toMatchSnapshot(); + }); }); it('renders with Solana network badge when on Solana network', async () => { @@ -129,7 +131,9 @@ describe('RemoteImage', () => { // Wait for component to render }); - expect(wrapper).toMatchSnapshot(); + await waitFor(() => { + expect(wrapper).toMatchSnapshot(); + }); }); describe('Error State Reset', () => { @@ -154,7 +158,9 @@ describe('RemoteImage', () => { jest.runAllTimers(); }); - expect(queryByTestId('remote-image')).toBeOnTheScreen(); + await waitFor(() => { + expect(queryByTestId('remote-image')).toBeOnTheScreen(); + }); await act(async () => { rerender( @@ -166,7 +172,9 @@ describe('RemoteImage', () => { jest.runAllTimers(); }); - expect(queryByTestId('remote-image')).toBeOnTheScreen(); + await waitFor(() => { + expect(queryByTestId('remote-image')).toBeOnTheScreen(); + }); }); it('renders Identicon when address is provided', async () => { @@ -182,7 +190,9 @@ describe('RemoteImage', () => { jest.runAllTimers(); }); - expect(queryByTestId('remote-image')).toBeOnTheScreen(); + await waitFor(() => { + expect(queryByTestId('remote-image')).toBeOnTheScreen(); + }); }); it('renders new image after source changes', async () => { @@ -197,7 +207,9 @@ describe('RemoteImage', () => { jest.runAllTimers(); }); - expect(queryByTestId('remote-image-1')).toBeOnTheScreen(); + await waitFor(() => { + expect(queryByTestId('remote-image-1')).toBeOnTheScreen(); + }); await act(async () => { rerender( @@ -209,7 +221,9 @@ describe('RemoteImage', () => { jest.runAllTimers(); }); - expect(queryByTestId('remote-image-2')).toBeOnTheScreen(); + await waitFor(() => { + expect(queryByTestId('remote-image-2')).toBeOnTheScreen(); + }); }); }); @@ -228,8 +242,10 @@ describe('RemoteImage', () => { image.props.onError({ error: 'Failed to load image' }); }); - const identicon = await findByTestId('identicon'); - expect(identicon).toBeOnTheScreen(); + await waitFor(async () => { + const identicon = await findByTestId('identicon'); + expect(identicon).toBeOnTheScreen(); + }); }); it('calls onError callback when image fails to load', async () => { @@ -247,7 +263,9 @@ describe('RemoteImage', () => { image.props.onError({ error: 'Failed to load image' }); }); - expect(mockOnError).toHaveBeenCalledTimes(1); + await waitFor(() => { + expect(mockOnError).toHaveBeenCalledTimes(1); + }); }); it('resets error state when source URI changes', async () => { @@ -265,8 +283,10 @@ describe('RemoteImage', () => { }); // After error, Identicon should be rendered - const identicon = await findByTestId('identicon'); - expect(identicon).toBeOnTheScreen(); + await waitFor(async () => { + const identicon = await findByTestId('identicon'); + expect(identicon).toBeOnTheScreen(); + }); await act(async () => { rerender( @@ -278,9 +298,11 @@ describe('RemoteImage', () => { }); // After source change, error should be reset and Image should render - expect(queryByTestId('identicon')).not.toBeOnTheScreen(); - const image = UNSAFE_getByType(Image); - expect(image).toBeDefined(); + await waitFor(() => { + expect(queryByTestId('identicon')).not.toBeOnTheScreen(); + const image = UNSAFE_getByType(Image); + expect(image).toBeDefined(); + }); }); }); @@ -296,18 +318,14 @@ describe('RemoteImage', () => { ); await waitFor(() => { - expect(mockGetFormattedIpfsUrl).toHaveBeenCalled(); + expect(mockGetFormattedIpfsUrl).toHaveBeenCalledWith( + expect.any(String), + ipfsUri, + false, + ); + const image = UNSAFE_getByType(Image); + expect(image.props.source.uri).toBe(resolvedUrl); }); - - // Verify the function was called with the IPFS URI - expect(mockGetFormattedIpfsUrl).toHaveBeenCalledWith( - expect.any(String), - ipfsUri, - false, - ); - - const image = UNSAFE_getByType(Image); - expect(image.props.source.uri).toBe(resolvedUrl); }); it('handles IPFS URL resolution failure', async () => { @@ -330,8 +348,10 @@ describe('RemoteImage', () => { // Wait for component to render }); - const image = UNSAFE_getByType(Image); - expect(image.props.source.uri).toBe(''); + await waitFor(() => { + const image = UNSAFE_getByType(Image); + expect(image.props.source.uri).toBe(''); + }); }); }); @@ -339,6 +359,7 @@ describe('RemoteImage', () => { let dimensionsSpy: jest.SpyInstance; beforeEach(() => { + jest.useFakeTimers(); dimensionsSpy = jest.spyOn(Dimensions, 'get').mockReturnValue({ width: 400, height: 800, @@ -348,6 +369,9 @@ describe('RemoteImage', () => { }); afterEach(() => { + jest.runOnlyPendingTimers(); + jest.useRealTimers(); + jest.restoreAllMocks(); dimensionsSpy.mockRestore(); }); @@ -361,6 +385,9 @@ describe('RemoteImage', () => { />, ); + // Clear any pending timers from the mock's automatic onLoad + jest.clearAllTimers(); + await act(async () => { const image = UNSAFE_getByType(Image); image.props.onLoad({ @@ -368,9 +395,11 @@ describe('RemoteImage', () => { }); }); - const image = UNSAFE_getByType(Image); - expect(image.props.style.width).toBe(368); - expect(image.props.style.height).toBe(184); + await waitFor(() => { + const image = UNSAFE_getByType(Image); + expect(image.props.style.width).toBe(368); + expect(image.props.style.height).toBe(184); + }); }); it('calculates dimensions for vertical image', async () => { @@ -383,6 +412,9 @@ describe('RemoteImage', () => { />, ); + // Clear any pending timers from the mock's automatic onLoad + jest.clearAllTimers(); + await act(async () => { const image = UNSAFE_getByType(Image); image.props.onLoad({ @@ -390,9 +422,11 @@ describe('RemoteImage', () => { }); }); - const image = UNSAFE_getByType(Image); - expect(image.props.style.width).toBe(138); - expect(image.props.style.height).toBe(276); + await waitFor(() => { + const image = UNSAFE_getByType(Image); + expect(image.props.style.width).toBe(138); + expect(image.props.style.height).toBe(276); + }); }); it('calculates dimensions for square image', async () => { @@ -405,6 +439,9 @@ describe('RemoteImage', () => { />, ); + // Clear any pending timers from the mock's automatic onLoad + jest.clearAllTimers(); + await act(async () => { const image = UNSAFE_getByType(Image); image.props.onLoad({ @@ -412,9 +449,11 @@ describe('RemoteImage', () => { }); }); - const image = UNSAFE_getByType(Image); - expect(image.props.style.width).toBe(276); - expect(image.props.style.height).toBe(276); + await waitFor(() => { + const image = UNSAFE_getByType(Image); + expect(image.props.style.width).toBe(276); + expect(image.props.style.height).toBe(276); + }); }); it('does not update dimensions when they remain the same', async () => { @@ -427,6 +466,9 @@ describe('RemoteImage', () => { />, ); + // Clear any pending timers from the mock's automatic onLoad + jest.clearAllTimers(); + await act(async () => { const image = UNSAFE_getByType(Image); image.props.onLoad({ @@ -434,6 +476,11 @@ describe('RemoteImage', () => { }); }); + await waitFor(() => { + const image = UNSAFE_getByType(Image); + expect(image.props.style.width).toBe(276); + }); + const firstImage = UNSAFE_getByType(Image); const firstStyle = firstImage.props.style; @@ -444,11 +491,16 @@ describe('RemoteImage', () => { }); }); - const secondImage = UNSAFE_getByType(Image); - const secondStyle = secondImage.props.style; + await waitFor(() => { + const image = UNSAFE_getByType(Image); + expect(image.props.style.width).toBe(276); + + const secondImage = UNSAFE_getByType(Image); + const secondStyle = secondImage.props.style; - expect(firstStyle.width).toBe(secondStyle.width); - expect(firstStyle.height).toBe(secondStyle.height); + expect(firstStyle.width).toBe(secondStyle.width); + expect(firstStyle.height).toBe(secondStyle.height); + }); }); it('handles onLoad without width and height', async () => { @@ -461,17 +513,26 @@ describe('RemoteImage', () => { />, ); + // Clear any pending timers from the mock's automatic onLoad + jest.clearAllTimers(); + await act(async () => { const image = UNSAFE_getByType(Image); image.props.onLoad({ source: {} }); }); - const image = UNSAFE_getByType(Image); - expect(image).toBeDefined(); + await waitFor(() => { + const image = UNSAFE_getByType(Image); + expect(image).toBeDefined(); + }); }); }); describe('Rendering Modes', () => { + afterEach(() => { + jest.restoreAllMocks(); + }); + it('renders default image without fadeIn', () => { const { UNSAFE_getByType } = render( , @@ -496,9 +557,11 @@ describe('RemoteImage', () => { // Wait for component to render }); - const image = UNSAFE_getByType(Image); - expect(image).toBeDefined(); - expect(image.props.source.uri).toBe('https://example.com/image.png'); + await waitFor(() => { + const image = UNSAFE_getByType(Image); + expect(image).toBeDefined(); + expect(image.props.source.uri).toBe('https://example.com/image.png'); + }); }); it('renders token image without full ratio', async () => { @@ -516,9 +579,11 @@ describe('RemoteImage', () => { // Wait for component to render }); - const image = UNSAFE_getByType(Image); - expect(image).toBeDefined(); - expect(image.props.source.uri).toBe('https://example.com/token.png'); + await waitFor(() => { + const image = UNSAFE_getByType(Image); + expect(image).toBeDefined(); + expect(image.props.source.uri).toBe('https://example.com/token.png'); + }); }); it('renders token image with full ratio and dimensions', async () => { @@ -545,9 +610,11 @@ describe('RemoteImage', () => { }); }); - const image = UNSAFE_getByType(Image); - expect(image.props.style.width).toBe(368); - expect(image.props.style.height).toBeCloseTo(245.33, 1); + await waitFor(() => { + const image = UNSAFE_getByType(Image); + expect(image.props.style.width).toBe(368); + expect(image.props.style.height).toBeCloseTo(245.33, 1); + }); }); it('renders token image with chainId prop', async () => { @@ -564,9 +631,11 @@ describe('RemoteImage', () => { // Wait for component to render }); - const image = UNSAFE_getByType(Image); - expect(image).toBeDefined(); - expect(image.props.source.uri).toBe('https://example.com/token.png'); + await waitFor(() => { + const image = UNSAFE_getByType(Image); + expect(image).toBeDefined(); + expect(image.props.source.uri).toBe('https://example.com/token.png'); + }); }); }); }); diff --git a/app/components/UI/Bridge/Views/BridgeView/index.tsx b/app/components/UI/Bridge/Views/BridgeView/index.tsx index ffad90dad16..0ed5d0b315e 100644 --- a/app/components/UI/Bridge/Views/BridgeView/index.tsx +++ b/app/components/UI/Bridge/Views/BridgeView/index.tsx @@ -600,6 +600,7 @@ const BridgeView = () => { token={sourceToken} tokenBalance={latestSourceBalance} onMaxPress={handleSourceMaxPress} + isQuoteSponsored={isQuoteSponsored} /> ) : null} diff --git a/app/components/UI/Bridge/components/SwapsKeypad/index.test.tsx b/app/components/UI/Bridge/components/SwapsKeypad/index.test.tsx index 596043042cb..51cfa9b48e4 100644 --- a/app/components/UI/Bridge/components/SwapsKeypad/index.test.tsx +++ b/app/components/UI/Bridge/components/SwapsKeypad/index.test.tsx @@ -5,34 +5,17 @@ import { Keys } from '../../../../Base/Keypad'; import { BridgeToken } from '../../types'; import { CHAIN_IDS } from '@metamask/transaction-controller'; import { BigNumber } from 'ethers'; -import { useSelector } from 'react-redux'; -import { useTokenAddress } from '../../hooks/useTokenAddress'; -import { isNativeAddress } from '@metamask/bridge-controller'; // Mock dependencies -jest.mock('react-redux', () => ({ - useSelector: jest.fn(), +jest.mock('../../hooks/useShouldRenderMaxOption', () => ({ + useShouldRenderMaxOption: jest.fn(() => true), })); -jest.mock('../../hooks/useTokenAddress', () => ({ - useTokenAddress: jest.fn(), -})); - -jest.mock('@metamask/bridge-controller', () => ({ - isNativeAddress: jest.fn(), -})); - -jest.mock('../../../../../core/redux/slices/bridge', () => ({ - selectIsGaslessSwapEnabled: jest.fn(), -})); - -const mockUseSelector = useSelector as jest.MockedFunction; -const mockUseTokenAddress = useTokenAddress as jest.MockedFunction< - typeof useTokenAddress ->; -const mockIsNativeAddress = isNativeAddress as jest.MockedFunction< - typeof isNativeAddress ->; +import { useShouldRenderMaxOption } from '../../hooks/useShouldRenderMaxOption'; +const mockUseShouldRenderMaxOption = + useShouldRenderMaxOption as jest.MockedFunction< + typeof useShouldRenderMaxOption + >; describe('SwapsKeypad', () => { const mockOnChange = jest.fn(); @@ -51,10 +34,8 @@ describe('SwapsKeypad', () => { }; beforeEach(() => { - jest.clearAllMocks(); - mockUseSelector.mockReturnValue(false); - mockUseTokenAddress.mockReturnValue(mockToken.address); - mockIsNativeAddress.mockReturnValue(false); + mockUseShouldRenderMaxOption.mockReset(); + mockUseShouldRenderMaxOption.mockReturnValue(true); }); afterEach(() => { @@ -119,8 +100,8 @@ describe('SwapsKeypad', () => { expect(queryByText('Max')).toBeNull(); }); - it('renders Max button for gasless swap enabled', () => { - mockUseSelector.mockReturnValue(true); + it('renders Max button when useShouldRenderMaxOption returns true', () => { + mockUseShouldRenderMaxOption.mockReturnValue(true); const { getByText, queryByText } = render( { expect(getByText('Max')).toBeTruthy(); expect(queryByText('90%')).toBeNull(); + expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith( + mockToken, + mockTokenBalance.displayBalance, + undefined, + ); }); - it('renders Max button for non-native token', () => { - mockIsNativeAddress.mockReturnValue(false); + it('renders 90% button when useShouldRenderMaxOption returns false', () => { + mockUseShouldRenderMaxOption.mockReturnValue(false); const { getByText, queryByText } = render( { />, ); - expect(getByText('Max')).toBeTruthy(); - expect(queryByText('90%')).toBeNull(); + expect(getByText('90%')).toBeTruthy(); + expect(queryByText('Max')).toBeNull(); + expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith( + mockToken, + mockTokenBalance.displayBalance, + undefined, + ); }); - it('renders 90% button for native token without gasless swap', () => { - mockIsNativeAddress.mockReturnValue(true); - mockUseSelector.mockImplementation(() => false); + it('passes isQuoteSponsored to useShouldRenderMaxOption', () => { + mockUseShouldRenderMaxOption.mockReturnValue(true); - const { getByText, queryByText } = render( + const { getByText } = render( { token={mockToken} tokenBalance={mockTokenBalance} onMaxPress={mockOnMaxPress} + isQuoteSponsored />, ); - expect(getByText('90%')).toBeTruthy(); - expect(queryByText('Max')).toBeNull(); + expect(getByText('Max')).toBeTruthy(); + expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith( + mockToken, + mockTokenBalance.displayBalance, + true, + ); }); }); @@ -280,7 +275,7 @@ describe('SwapsKeypad', () => { describe('Max button functionality', () => { it('calls onMaxPress when Max button is clicked', () => { - mockIsNativeAddress.mockReturnValue(false); + mockUseShouldRenderMaxOption.mockReturnValue(true); const { getByText } = render( { expect(mockOnMaxPress).toHaveBeenCalledTimes(1); }); - - it('renders Max button when gasless swap is enabled for native token', () => { - mockIsNativeAddress.mockReturnValue(true); - mockUseSelector.mockImplementation((selector) => { - if (typeof selector === 'function') { - return true; - } - return undefined; - }); - - const { getByText } = render( - , - ); - - expect(getByText('Max')).toBeTruthy(); - }); }); describe('edge cases', () => { @@ -463,12 +434,11 @@ describe('SwapsKeypad', () => { }); }); - describe('quick pick button selection logic', () => { - it('selects gasless quick pick options when not native asset', () => { - mockIsNativeAddress.mockReturnValue(false); - mockUseSelector.mockImplementation(() => false); + describe('Quick pick options with useShouldRenderMaxOption hook', () => { + it('shows Max button when useShouldRenderMaxOption returns true', () => { + mockUseShouldRenderMaxOption.mockReturnValue(true); - const { getByText } = render( + const { getByText, queryByText } = render( { ); expect(getByText('Max')).toBeTruthy(); + expect(queryByText('90%')).toBeNull(); + expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith( + mockToken, + mockTokenBalance.displayBalance, + undefined, + ); }); - it('selects standard quick pick options when native asset and gasless disabled', () => { - mockIsNativeAddress.mockReturnValue(true); - mockUseSelector.mockImplementation(() => false); + it('shows 90% button when useShouldRenderMaxOption returns false', () => { + mockUseShouldRenderMaxOption.mockReturnValue(false); - const { getByText } = render( + const { getByText, queryByText } = render( { ); expect(getByText('90%')).toBeTruthy(); + expect(queryByText('Max')).toBeNull(); + expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith( + mockToken, + mockTokenBalance.displayBalance, + undefined, + ); }); - it('selects gasless quick pick options when native asset but gasless enabled', () => { - mockIsNativeAddress.mockReturnValue(true); - mockUseSelector.mockImplementation((selector) => { - if (typeof selector === 'function') { - return true; - } - return undefined; - }); + it('passes isQuoteSponsored to useShouldRenderMaxOption', () => { + mockUseShouldRenderMaxOption.mockReturnValue(true); const { getByText } = render( { token={mockToken} tokenBalance={mockTokenBalance} onMaxPress={mockOnMaxPress} + isQuoteSponsored />, ); expect(getByText('Max')).toBeTruthy(); + expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith( + mockToken, + mockTokenBalance.displayBalance, + true, + ); }); - }); - describe('token address handling', () => { - it('uses correct token address from useTokenAddress hook', () => { - const customAddress = '0xabcdef1234567890abcdef1234567890abcdef12'; - mockUseTokenAddress.mockReturnValue(customAddress); + it('hides quick pick buttons when displayBalance is zero', () => { + const zeroBalance = { + displayBalance: '0', + atomicBalance: BigNumber.from('0'), + }; + mockUseShouldRenderMaxOption.mockReturnValue(false); - render( + const { queryByText } = render( , ); - expect(mockUseTokenAddress).toHaveBeenCalledWith(mockToken); - expect(mockIsNativeAddress).toHaveBeenCalledWith(customAddress); + expect(queryByText('25%')).toBeNull(); + expect(queryByText('50%')).toBeNull(); + expect(queryByText('75%')).toBeNull(); + expect(queryByText('Max')).toBeNull(); + expect(queryByText('90%')).toBeNull(); + expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith( + mockToken, + zeroBalance.displayBalance, + undefined, + ); }); - it('handles token address changes correctly', () => { - const { rerender } = render( + it('quick pick buttons calculate correct percentages with Max button', () => { + mockUseShouldRenderMaxOption.mockReturnValue(true); + + const { getByText } = render( { />, ); - const newToken = { ...mockToken, address: '0xnewaddress' }; - mockUseTokenAddress.mockReturnValue(newToken.address); + // Test 25% button + act(() => { + fireEvent.press(getByText('25%')); + }); - rerender( + expect(mockOnChange).toHaveBeenCalledWith( + expect.objectContaining({ + value: '25.125', // 25% of 100.5 + valueAsNumber: 25.125, + pressedKey: Keys.Initial, + }), + ); + + // Test 50% button + act(() => { + fireEvent.press(getByText('50%')); + }); + + expect(mockOnChange).toHaveBeenCalledWith( + expect.objectContaining({ + value: '50.25', // 50% of 100.5 + valueAsNumber: 50.25, + pressedKey: Keys.Initial, + }), + ); + + // Test 75% button + act(() => { + fireEvent.press(getByText('75%')); + }); + + expect(mockOnChange).toHaveBeenCalledWith( + expect.objectContaining({ + value: '75.375', // 75% of 100.5 + valueAsNumber: 75.375, + pressedKey: Keys.Initial, + }), + ); + + // Test Max button calls onMaxPress + act(() => { + fireEvent.press(getByText('Max')); + }); + + expect(mockOnMaxPress).toHaveBeenCalledTimes(1); + }); + + it('quick pick buttons calculate correct percentages with 90% button', () => { + mockUseShouldRenderMaxOption.mockReturnValue(false); + + const { getByText } = render( , ); - expect(mockUseTokenAddress).toHaveBeenCalledWith(newToken); + // Test 90% button + act(() => { + fireEvent.press(getByText('90%')); + }); + + expect(mockOnChange).toHaveBeenCalledWith( + expect.objectContaining({ + value: '90.45', // 90% of 100.5 + valueAsNumber: 90.45, + pressedKey: Keys.Initial, + }), + ); }); }); }); diff --git a/app/components/UI/Bridge/components/SwapsKeypad/index.tsx b/app/components/UI/Bridge/components/SwapsKeypad/index.tsx index 1dfc0afb457..bc2b9aa28c6 100644 --- a/app/components/UI/Bridge/components/SwapsKeypad/index.tsx +++ b/app/components/UI/Bridge/components/SwapsKeypad/index.tsx @@ -3,15 +3,11 @@ import Keypad, { KeypadChangeData, Keys } from '../../../../Base/Keypad'; import { Box } from '../../../Box/Box'; import { swapsKeypadStyles as styles } from './styles'; import { QuickPickButtons } from './QuickPickButtons'; -import { useSelector } from 'react-redux'; -import { RootState } from '../../../../../reducers'; import { BridgeToken } from '../../types'; -import { selectIsGaslessSwapEnabled } from '../../../../../core/redux/slices/bridge'; import { QuickPickButtonOption } from './types'; -import { useTokenAddress } from '../../hooks/useTokenAddress'; -import { isNativeAddress } from '@metamask/bridge-controller'; import { useLatestBalance } from '../../hooks/useLatestBalance'; import { BigNumber } from 'bignumber.js'; +import { useShouldRenderMaxOption } from '../../hooks/useShouldRenderMaxOption'; interface SwapsKeypadProps { value: string; @@ -21,6 +17,7 @@ interface SwapsKeypadProps { token?: BridgeToken; tokenBalance: ReturnType; onMaxPress: () => void; + isQuoteSponsored?: boolean; } export const SwapsKeypad = ({ @@ -31,17 +28,8 @@ export const SwapsKeypad = ({ token, tokenBalance, onMaxPress, + isQuoteSponsored, }: SwapsKeypadProps) => { - const tokenAddress = useTokenAddress(token); - const isNativeAsset = useMemo( - () => isNativeAddress(tokenAddress), - [tokenAddress], - ); - - const isGaslessSwapEnabled = useSelector((state: RootState) => - token?.chainId ? selectIsGaslessSwapEnabled(state, token.chainId) : false, - ); - const onQuickOptionPress = useCallback( (percentage: number) => () => { if (!tokenBalance?.displayBalance) return '0'; @@ -105,14 +93,18 @@ export const SwapsKeypad = ({ ); const shouldHideQuickPickOptions = useMemo( - () => new BigNumber(tokenBalance?.displayBalance ?? 0).eq(0), + () => new BigNumber(tokenBalance?.displayBalance || 0).eq(0), [tokenBalance], ); - const quickPickOptions = - !isNativeAsset || isGaslessSwapEnabled - ? gasslessQuickPickOptions - : standardQuickPickOptions; + const shouldRenderMaxOption = useShouldRenderMaxOption( + token, + tokenBalance?.displayBalance, + isQuoteSponsored, + ); + const quickPickOptions = shouldRenderMaxOption + ? gasslessQuickPickOptions + : standardQuickPickOptions; return ( diff --git a/app/components/UI/Bridge/components/TokenInputArea/TokenInputArea.test.tsx b/app/components/UI/Bridge/components/TokenInputArea/TokenInputArea.test.tsx index d463fd7a2f4..a2405b9e6e8 100644 --- a/app/components/UI/Bridge/components/TokenInputArea/TokenInputArea.test.tsx +++ b/app/components/UI/Bridge/components/TokenInputArea/TokenInputArea.test.tsx @@ -11,6 +11,16 @@ jest.mock('../../hooks/useLatestBalance', () => ({ useLatestBalance: jest.fn(), })); +jest.mock('../../hooks/useShouldRenderMaxOption', () => ({ + useShouldRenderMaxOption: jest.fn(() => true), +})); + +import { useShouldRenderMaxOption } from '../../hooks/useShouldRenderMaxOption'; +const mockUseShouldRenderMaxOption = + useShouldRenderMaxOption as jest.MockedFunction< + typeof useShouldRenderMaxOption + >; + const mockOnTokenPress = jest.fn(); const mockOnFocus = jest.fn(); const mockOnBlur = jest.fn(); @@ -20,6 +30,7 @@ const mockOnMaxPress = jest.fn(); describe('TokenInputArea', () => { beforeEach(() => { jest.clearAllMocks(); + mockUseShouldRenderMaxOption.mockReturnValue(true); }); it('renders with initial state', () => { @@ -174,6 +185,9 @@ describe('TokenInputArea', () => { }, }; + // Mock hook to return false since gasless is disabled for native token + mockUseShouldRenderMaxOption.mockReturnValue(false); + const { queryByText } = renderScreen( () => ( { }, }; + // Mock hook to return false since gasless is disabled for native Polygon token + mockUseShouldRenderMaxOption.mockReturnValue(false); + const { queryByText } = renderScreen( () => ( { // Native tokens show Max button when gasless swaps are enabled expect(getByText('Max')).toBeTruthy(); }); + + describe('Max button visibility with useShouldRenderMaxOption hook', () => { + const nativeToken: BridgeToken = { + address: '0x0000000000000000000000000000000000000000', + symbol: 'ETH', + decimals: 18, + chainId: '0x1' as `0x${string}`, + }; + const tokenBalance = '1.5'; + + beforeEach(() => { + mockUseShouldRenderMaxOption.mockReturnValue(true); + }); + + afterEach(() => { + mockUseShouldRenderMaxOption.mockReturnValue(true); + }); + + it('does not display max button when useShouldRenderMaxOption returns false', () => { + mockUseShouldRenderMaxOption.mockReturnValue(false); + + const { queryByText } = renderScreen( + () => ( + + ), + { + name: 'TokenInputArea', + }, + { state: initialState }, + ); + + expect(queryByText('Max')).toBeNull(); + expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith( + nativeToken, + tokenBalance, + false, + ); + }); + + it('displays max button when useShouldRenderMaxOption returns true', () => { + mockUseShouldRenderMaxOption.mockReturnValue(true); + + const { getByText } = renderScreen( + () => ( + + ), + { + name: 'TokenInputArea', + }, + { state: initialState }, + ); + + expect(getByText('Max')).toBeTruthy(); + expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith( + nativeToken, + tokenBalance, + false, + ); + }); + + it('passes isQuoteSponsored to useShouldRenderMaxOption', () => { + mockUseShouldRenderMaxOption.mockReturnValue(true); + + renderScreen( + () => ( + + ), + { + name: 'TokenInputArea', + }, + { state: initialState }, + ); + + expect(mockUseShouldRenderMaxOption).toHaveBeenCalledWith( + nativeToken, + tokenBalance, + true, + ); + }); + + it('does not display max button for destination token', () => { + const { queryByText } = renderScreen( + () => ( + + ), + { + name: 'TokenInputArea', + }, + { state: initialState }, + ); + + // Destination tokens never show max button + expect(queryByText('Max')).toBeNull(); + }); + }); }); describe('getDisplayAmount', () => { diff --git a/app/components/UI/Bridge/components/TokenInputArea/index.tsx b/app/components/UI/Bridge/components/TokenInputArea/index.tsx index c2cd91d30f2..209559b6ae0 100644 --- a/app/components/UI/Bridge/components/TokenInputArea/index.tsx +++ b/app/components/UI/Bridge/components/TokenInputArea/index.tsx @@ -1,4 +1,4 @@ -import React, { forwardRef, useImperativeHandle, useRef, useMemo } from 'react'; +import React, { forwardRef, useImperativeHandle, useRef } from 'react'; import { StyleSheet, ImageSourcePropType, @@ -32,9 +32,7 @@ import { useNavigation } from '@react-navigation/native'; import { setDestTokenExchangeRate, setSourceTokenExchangeRate, - selectIsGaslessSwapEnabled, } from '../../../../../core/redux/slices/bridge'; -import { RootState } from '../../../../../reducers'; ///: BEGIN:ONLY_INCLUDE_IF(keyring-snaps) import { selectMultichainAssetsRates } from '../../../../../selectors/multichain'; ///: END:ONLY_INCLUDE_IF(keyring-snaps) @@ -48,6 +46,7 @@ import { isNativeAddress } from '@metamask/bridge-controller'; import { Theme } from '../../../../../util/theme/models'; import parseAmount from '../../../../../util/parseAmount'; import { useTokenAddress } from '../../hooks/useTokenAddress'; +import { useShouldRenderMaxOption } from '../../hooks/useShouldRenderMaxOption'; import { calculateInputFontSize } from '../../utils/calculateInputFontSize'; import { formatAmountWithLocaleSeparators } from '../../utils/formatAmountWithLocaleSeparators'; @@ -184,10 +183,6 @@ export const TokenInputArea = forwardRef< ) => { const currentCurrency = useSelector(selectCurrentCurrency); - const isGaslessSwapEnabled = useSelector((state: RootState) => - token?.chainId ? selectIsGaslessSwapEnabled(state, token.chainId) : false, - ); - // Need to fetch the exchange rate for the token if we don't have it already useBridgeExchangeRates({ token, @@ -266,11 +261,12 @@ export const TokenInputArea = forwardRef< const isNativeAsset = isNativeAddress(tokenAddress); - // Show max button for native tokens if gasless swap is enabled OR quote is sponsored - const shouldShowMaxButton = useMemo(() => { - if (!isNativeAsset) return true; // Always show for non-native tokens - return isGaslessSwapEnabled || isQuoteSponsored; - }, [isNativeAsset, isGaslessSwapEnabled, isQuoteSponsored]); + const shouldShowMaxButton = useShouldRenderMaxOption( + token, + tokenBalance, + isQuoteSponsored, + ); + const formattedAddress = tokenAddress && !isNativeAsset ? formatAddress(tokenAddress) : undefined; @@ -359,7 +355,6 @@ export const TokenInputArea = forwardRef< { + const isGaslessSwapEnabled = useSelector((state: RootState) => + token?.chainId ? selectIsGaslessSwapEnabled(state, token.chainId) : false, + ); + const stxEnabled = useSelector((state: RootState) => + token?.chainId && !isNonEvmChainId(token.chainId) + ? selectShouldUseSmartTransaction( + state, + formatChainIdToHex(token.chainId), + ) + : false, + ); + const tokenAddress = useTokenAddress(token); + const isNativeAsset = isNativeAddress(tokenAddress); + const isZeroDisplayBalance = new BigNumber(displayBalance || 0).eq(0); + + // Do not render on zero balance or undefined token + if (isZeroDisplayBalance || !token) { + return false; + } + + // Always show for non-native tokens + if (!isNativeAsset) { + return true; + } + + // Show for EVM native tokens if gasless swap is enabled OR quote is sponsored + // while smart transactions is enabled. + // For non-EVM native tokens stxEnabled will be false evaluating the whole + // expression to false. We do not know the fees beforehand so we cannot + // max out the input amount. + return (isGaslessSwapEnabled || isQuoteSponsored) && stxEnabled; +}; diff --git a/app/components/UI/Bridge/hooks/useShouldRenderMaxOption/useShouldRenderMaxOption.test.ts b/app/components/UI/Bridge/hooks/useShouldRenderMaxOption/useShouldRenderMaxOption.test.ts new file mode 100644 index 00000000000..68d0c3cd747 --- /dev/null +++ b/app/components/UI/Bridge/hooks/useShouldRenderMaxOption/useShouldRenderMaxOption.test.ts @@ -0,0 +1,594 @@ +import { renderHook } from '@testing-library/react-hooks'; +import { useShouldRenderMaxOption } from '.'; +import { BridgeToken } from '../../types'; +import { CHAIN_IDS } from '@metamask/transaction-controller'; +import { useSelector } from 'react-redux'; +import { useTokenAddress } from '../useTokenAddress'; +import { isNativeAddress } from '@metamask/bridge-controller'; + +// Mock dependencies +jest.mock('react-redux', () => ({ + useSelector: jest.fn(), +})); + +jest.mock('../useTokenAddress', () => ({ + useTokenAddress: jest.fn(), +})); + +jest.mock('@metamask/bridge-controller', () => ({ + isNativeAddress: jest.fn(), +})); + +jest.mock('../../../../../core/redux/slices/bridge', () => ({ + selectIsGaslessSwapEnabled: jest.fn(), +})); + +jest.mock('../../../../../selectors/smartTransactionsController', () => ({ + selectShouldUseSmartTransaction: jest.fn(), +})); + +const mockUseSelector = useSelector as jest.MockedFunction; +const mockUseTokenAddress = useTokenAddress as jest.MockedFunction< + typeof useTokenAddress +>; +const mockIsNativeAddress = isNativeAddress as jest.MockedFunction< + typeof isNativeAddress +>; + +/** + * IMPORTANT: useSelector call order in the hook: + * 1. First call: isGaslessSwapEnabled (line 12 in hook) + * 2. Second call: stxEnabled (line 15 in hook) + */ +describe('useShouldRenderMaxOption', () => { + const mockToken: BridgeToken = { + address: '0x1234567890123456789012345678901234567890', + symbol: 'TEST', + decimals: 18, + chainId: CHAIN_IDS.MAINNET, + }; + + const nativeToken: BridgeToken = { + address: '0x0000000000000000000000000000000000000000', + symbol: 'ETH', + decimals: 18, + chainId: CHAIN_IDS.MAINNET, + }; + + beforeEach(() => { + jest.clearAllMocks(); + mockUseTokenAddress.mockReturnValue(mockToken.address); + mockIsNativeAddress.mockReturnValue(false); + // Default: isGaslessSwapEnabled = false, stxEnabled = true + let callCount = 0; + mockUseSelector.mockImplementation(() => { + callCount++; + // First call: isGaslessSwapEnabled = false + // Second call: stxEnabled = true + return callCount === 2; + }); + }); + + describe('Zero balance scenarios', () => { + it('returns false when displayBalance is undefined', () => { + const { result } = renderHook(() => + useShouldRenderMaxOption(mockToken, undefined, false), + ); + + expect(result.current).toBe(false); + }); + + it('returns false when displayBalance is "0"', () => { + const { result } = renderHook(() => + useShouldRenderMaxOption(mockToken, '0', false), + ); + + expect(result.current).toBe(false); + }); + + it('returns false when displayBalance is "0.0"', () => { + const { result } = renderHook(() => + useShouldRenderMaxOption(mockToken, '0.0', false), + ); + + expect(result.current).toBe(false); + }); + + it('returns false when displayBalance is empty string', () => { + const { result } = renderHook(() => + useShouldRenderMaxOption(mockToken, '', false), + ); + + expect(result.current).toBe(false); + }); + }); + + describe('Non-native token scenarios', () => { + beforeEach(() => { + mockIsNativeAddress.mockReturnValue(false); + mockUseTokenAddress.mockReturnValue(mockToken.address); + }); + + it('returns true for non-native token with balance regardless of gasless', () => { + mockUseSelector.mockImplementation(() => false); // Both selectors false + + const { result } = renderHook(() => + useShouldRenderMaxOption(mockToken, '100.5', false), + ); + + expect(result.current).toBe(true); + }); + + it('returns true for non-native token with balance regardless of stxEnabled', () => { + mockUseSelector.mockImplementation(() => false); // stxEnabled = false + + const { result } = renderHook(() => + useShouldRenderMaxOption(mockToken, '50', false), + ); + + expect(result.current).toBe(true); + }); + + it('returns true for non-native token with balance regardless of isQuoteSponsored', () => { + const { result } = renderHook(() => + useShouldRenderMaxOption(mockToken, '25.75', true), + ); + + expect(result.current).toBe(true); + }); + + it('returns true for non-native token with very small balance', () => { + const { result } = renderHook(() => + useShouldRenderMaxOption(mockToken, '0.000001', false), + ); + + expect(result.current).toBe(true); + }); + + it('returns false for non-native token with zero balance', () => { + const { result } = renderHook(() => + useShouldRenderMaxOption(mockToken, '0', false), + ); + + expect(result.current).toBe(false); + }); + }); + + describe('Native token scenarios', () => { + beforeEach(() => { + mockIsNativeAddress.mockReturnValue(true); + mockUseTokenAddress.mockReturnValue(nativeToken.address); + }); + + it('returns false when native token has zero balance', () => { + const { result } = renderHook(() => + useShouldRenderMaxOption(nativeToken, '0', false), + ); + + expect(result.current).toBe(false); + }); + + it('returns false when native token has zero balance even with all conditions favorable', () => { + mockIsNativeAddress.mockReturnValue(true); + mockUseTokenAddress.mockReturnValue(nativeToken.address); + mockUseSelector.mockReturnValue(true); // stxEnabled=true, gasless=true + + const { result } = renderHook( + () => useShouldRenderMaxOption(nativeToken, '0', true), // sponsored=true + ); + + // Zero balance always returns false + expect(result.current).toBe(false); + }); + }); + + describe('Edge cases', () => { + it('returns false when token is undefined', () => { + mockUseTokenAddress.mockReturnValue(undefined); + + const { result } = renderHook(() => + useShouldRenderMaxOption(undefined, '100', false), + ); + + expect(result.current).toBe(false); + }); + + it('returns false when token is undefined but has balance', () => { + mockUseTokenAddress.mockReturnValue(undefined); + mockIsNativeAddress.mockReturnValue(false); + + const { result } = renderHook(() => + useShouldRenderMaxOption(undefined, '500', false), + ); + + expect(result.current).toBe(false); + }); + + it('handles large balance values correctly for non-native tokens', () => { + mockIsNativeAddress.mockReturnValue(false); + mockUseTokenAddress.mockReturnValue(mockToken.address); + + const { result } = renderHook(() => + useShouldRenderMaxOption(mockToken, '1000000.123456789', false), + ); + + expect(result.current).toBe(true); + }); + + it('handles very small but non-zero balance for non-native tokens', () => { + mockIsNativeAddress.mockReturnValue(false); + mockUseTokenAddress.mockReturnValue(mockToken.address); + + const { result } = renderHook(() => + useShouldRenderMaxOption(mockToken, '0.000000001', false), + ); + + expect(result.current).toBe(true); + }); + + it('correctly identifies native token without chainId in token object', () => { + const tokenWithoutChainId = { + address: '0x0000000000000000000000000000000000000000', + symbol: 'ETH', + decimals: 18, + } as BridgeToken; + + mockIsNativeAddress.mockReturnValue(true); + mockUseTokenAddress.mockReturnValue(tokenWithoutChainId.address); + mockUseSelector.mockReturnValue(false); // stxEnabled = false, isGaslessSwapEnabled = false + + const { result } = renderHook(() => + useShouldRenderMaxOption(tokenWithoutChainId, '100', false), + ); + + // Should return false (native + no gasless + no sponsored + no stx) + expect(result.current).toBe(false); + }); + }); + + describe('Hook parameter validation', () => { + it('uses useTokenAddress hook to get token address', () => { + const customToken = { + ...mockToken, + address: '0xabcdef1234567890abcdef1234567890abcdef12', + }; + mockUseTokenAddress.mockReturnValue(customToken.address); + + renderHook(() => useShouldRenderMaxOption(customToken, '100', false)); + + expect(mockUseTokenAddress).toHaveBeenCalledWith(customToken); + }); + + it('checks if token address is native using isNativeAddress', () => { + const tokenAddress = '0x1234567890123456789012345678901234567890'; + mockUseTokenAddress.mockReturnValue(tokenAddress); + mockIsNativeAddress.mockReturnValue(false); + + renderHook(() => useShouldRenderMaxOption(mockToken, '100', false)); + + expect(mockIsNativeAddress).toHaveBeenCalledWith(tokenAddress); + }); + + it('calls selectIsGaslessSwapEnabled with correct chainId', () => { + const tokenWithChainId = { + ...mockToken, + chainId: '0xa' as `0x${string}`, // Optimism + }; + mockUseSelector.mockReturnValue(true); + + renderHook(() => + useShouldRenderMaxOption(tokenWithChainId, '100', false), + ); + + // Verify useSelector was called (it uses selectIsGaslessSwapEnabled) + expect(mockUseSelector).toHaveBeenCalled(); + }); + }); + + describe('Default parameter values', () => { + it('uses false as default for isQuoteSponsored when not provided', () => { + mockIsNativeAddress.mockReturnValue(true); + mockUseTokenAddress.mockReturnValue(nativeToken.address); + let callCount = 0; + mockUseSelector.mockImplementation(() => { + callCount++; + // First call: isGaslessSwapEnabled = false + // Second call: stxEnabled = true + return callCount === 2; + }); + + // Call without isQuoteSponsored parameter + const { result } = renderHook(() => + useShouldRenderMaxOption(nativeToken, '100'), + ); + + // Should return false (native + stx=true + gasless=false + sponsored=false) + expect(result.current).toBe(false); + }); + }); + + describe('Integration scenarios', () => { + it('returns correct value for typical non-native ERC20 token', () => { + const usdcToken: BridgeToken = { + address: '0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48', + symbol: 'USDC', + decimals: 6, + chainId: CHAIN_IDS.MAINNET, + }; + + mockIsNativeAddress.mockReturnValue(false); + mockUseTokenAddress.mockReturnValue(usdcToken.address); + mockUseSelector.mockReturnValue(false); // Everything disabled + + const { result } = renderHook(() => + useShouldRenderMaxOption(usdcToken, '1000', false), + ); + + // Non-native tokens always show max (even with everything disabled) + expect(result.current).toBe(true); + }); + + it('returns correct value for native ETH with gasless swap enabled', () => { + mockIsNativeAddress.mockReturnValue(true); + mockUseTokenAddress.mockReturnValue(nativeToken.address); + mockUseSelector.mockReturnValue(true); // stxEnabled=true, gasless=true + + const { result } = renderHook(() => + useShouldRenderMaxOption(nativeToken, '5.5', false), + ); + + expect(result.current).toBe(true); + }); + + it('returns correct value for native ETH with sponsored quote', () => { + mockIsNativeAddress.mockReturnValue(true); + mockUseTokenAddress.mockReturnValue(nativeToken.address); + let callCount = 0; + mockUseSelector.mockImplementation(() => { + callCount++; + // First call: isGaslessSwapEnabled = false + // Second call: stxEnabled = true + return callCount === 2; + }); + + const { result } = renderHook(() => + useShouldRenderMaxOption(nativeToken, '2.25', true), + ); + + expect(result.current).toBe(true); + }); + + it('returns correct value for native ETH without gasless or sponsored but stx enabled', () => { + mockIsNativeAddress.mockReturnValue(true); + mockUseTokenAddress.mockReturnValue(nativeToken.address); + let callCount = 0; + mockUseSelector.mockImplementation(() => { + callCount++; + // First call: isGaslessSwapEnabled = false + // Second call: stxEnabled = true + return callCount === 2; + }); + + const { result } = renderHook(() => + useShouldRenderMaxOption(nativeToken, '10', false), + ); + + // Should be false (native + stx=true + gasless=false + sponsored=false) + expect(result.current).toBe(false); + }); + }); + + describe('Boundary conditions', () => { + it('handles balance exactly equal to zero', () => { + const { result } = renderHook(() => + useShouldRenderMaxOption(mockToken, '0.00000', false), + ); + + expect(result.current).toBe(false); + }); + + it('handles extremely small but non-zero balance', () => { + mockIsNativeAddress.mockReturnValue(false); + mockUseTokenAddress.mockReturnValue(mockToken.address); + + const { result } = renderHook(() => + useShouldRenderMaxOption(mockToken, '0.00000001', false), + ); + + expect(result.current).toBe(true); + }); + + it('handles extremely large balance', () => { + mockIsNativeAddress.mockReturnValue(false); + mockUseTokenAddress.mockReturnValue(mockToken.address); + + const { result } = renderHook(() => + useShouldRenderMaxOption( + mockToken, + '999999999999999999999999999.999999', + false, + ), + ); + + expect(result.current).toBe(true); + }); + + it('handles balance with many decimal places', () => { + mockIsNativeAddress.mockReturnValue(false); + mockUseTokenAddress.mockReturnValue(mockToken.address); + + const { result } = renderHook(() => + useShouldRenderMaxOption( + mockToken, + '123.456789012345678901234567', + false, + ), + ); + + expect(result.current).toBe(true); + }); + }); + + describe('Truth table - Native token with all combinations', () => { + beforeEach(() => { + mockIsNativeAddress.mockReturnValue(true); + mockUseTokenAddress.mockReturnValue(nativeToken.address); + }); + + const truthTable = [ + { stxEnabled: true, gasless: true, sponsored: false, expected: true }, + { stxEnabled: true, gasless: false, sponsored: true, expected: true }, + { stxEnabled: true, gasless: true, sponsored: true, expected: true }, + { stxEnabled: true, gasless: false, sponsored: false, expected: false }, + { stxEnabled: false, gasless: true, sponsored: false, expected: false }, + { stxEnabled: false, gasless: false, sponsored: true, expected: false }, + { stxEnabled: false, gasless: true, sponsored: true, expected: false }, + { stxEnabled: false, gasless: false, sponsored: false, expected: false }, + ]; + + truthTable.forEach(({ stxEnabled, gasless, sponsored, expected }) => { + it(`stxEnabled=${stxEnabled}, gasless=${gasless}, sponsored=${sponsored} → returns ${expected}`, () => { + let callCount = 0; + mockUseSelector.mockImplementation(() => { + callCount++; + // First call: isGaslessSwapEnabled (line 12 in hook) + // Second call: stxEnabled (line 15 in hook) + if (callCount === 1) { + return gasless; + } + return stxEnabled; + }); + + const { result } = renderHook(() => + useShouldRenderMaxOption(nativeToken, '100', sponsored), + ); + + expect(result.current).toBe(expected); + }); + }); + }); + + describe('Non-EVM native token scenarios', () => { + const solanaToken: BridgeToken = { + address: '0x0000000000000000000000000000000000000000', + symbol: 'SOL', + decimals: 9, + chainId: 'solana:5eykt4UsFv8P8NJdTREpY1vzqKqZKvdp', // Solana mainnet CAIP-2 + }; + + const bitcoinToken: BridgeToken = { + address: '0x0000000000000000000000000000000000000000', + symbol: 'BTC', + decimals: 8, + chainId: 'bip122:000000000019d6689c085ae165831e93', // Bitcoin mainnet CAIP-2 + }; + + beforeEach(() => { + mockIsNativeAddress.mockReturnValue(true); + }); + + it('returns false for Solana native token even with gasless enabled', () => { + mockUseTokenAddress.mockReturnValue(solanaToken.address); + let callCount = 0; + mockUseSelector.mockImplementation(() => { + callCount++; + // First call: isGaslessSwapEnabled = true + // Second call: stxEnabled = false (non-EVM chain) + if (callCount === 1) { + return true; // gasless enabled + } + return false; // stxEnabled is false for non-EVM + }); + + const { result } = renderHook(() => + useShouldRenderMaxOption(solanaToken, '100', false), + ); + + // Should return false because stxEnabled is false for non-EVM chains + expect(result.current).toBe(false); + }); + + it('returns false for Solana native token even with sponsored quote', () => { + mockUseTokenAddress.mockReturnValue(solanaToken.address); + mockUseSelector.mockImplementation( + () => + // First call: isGaslessSwapEnabled = false + // Second call: stxEnabled = false (non-EVM chain) + false, + ); + + const { result } = renderHook( + () => useShouldRenderMaxOption(solanaToken, '50', true), // sponsored = true + ); + + // Should return false because stxEnabled is false for non-EVM chains + expect(result.current).toBe(false); + }); + + it('returns false for Solana native token with both gasless and sponsored enabled', () => { + mockUseTokenAddress.mockReturnValue(solanaToken.address); + let callCount = 0; + mockUseSelector.mockImplementation(() => { + callCount++; + // First call: isGaslessSwapEnabled = true + // Second call: stxEnabled = false (non-EVM chain) + if (callCount === 1) { + return true; // gasless enabled + } + return false; // stxEnabled is false for non-EVM + }); + + const { result } = renderHook( + () => useShouldRenderMaxOption(solanaToken, '25.5', true), // sponsored = true + ); + + // Should return false because stxEnabled is false for non-EVM chains + expect(result.current).toBe(false); + }); + + it('returns false for Bitcoin native token with gasless enabled', () => { + mockUseTokenAddress.mockReturnValue(bitcoinToken.address); + let callCount = 0; + mockUseSelector.mockImplementation(() => { + callCount++; + // First call: isGaslessSwapEnabled = true + // Second call: stxEnabled = false (non-EVM chain) + if (callCount === 1) { + return true; // gasless enabled + } + return false; // stxEnabled is false for non-EVM + }); + + const { result } = renderHook(() => + useShouldRenderMaxOption(bitcoinToken, '1.5', false), + ); + + // Should return false because stxEnabled is false for non-EVM chains + expect(result.current).toBe(false); + }); + + it('returns false for Bitcoin native token without any flags enabled', () => { + mockUseTokenAddress.mockReturnValue(bitcoinToken.address); + mockUseSelector.mockReturnValue(false); // All flags disabled + + const { result } = renderHook(() => + useShouldRenderMaxOption(bitcoinToken, '0.5', false), + ); + + // Should return false because stxEnabled is false for non-EVM chains + expect(result.current).toBe(false); + }); + + it('returns false for non-EVM native token with zero balance', () => { + mockUseTokenAddress.mockReturnValue(solanaToken.address); + mockUseSelector.mockReturnValue(false); + + const { result } = renderHook(() => + useShouldRenderMaxOption(solanaToken, '0', false), + ); + + // Should return false due to zero balance (checked before non-EVM logic) + expect(result.current).toBe(false); + }); + }); +}); diff --git a/app/components/UI/DeepLinkModal/DeepLinkModal.test.tsx b/app/components/UI/DeepLinkModal/DeepLinkModal.test.tsx index 815c030a967..4e89e9f1d01 100644 --- a/app/components/UI/DeepLinkModal/DeepLinkModal.test.tsx +++ b/app/components/UI/DeepLinkModal/DeepLinkModal.test.tsx @@ -143,7 +143,7 @@ describe('DeepLinkModal', () => { }), build: jest.fn().mockImplementation(function (this: MockBuilder) { return { - name: 'Deep link Used', + name: 'Deep Link Used', properties: { route: 'invalid', was_app_installed: true, diff --git a/app/components/UI/Perps/Views/PerpsOrderBookView/PerpsOrderBookView.test.tsx b/app/components/UI/Perps/Views/PerpsOrderBookView/PerpsOrderBookView.test.tsx index 003e0dacd12..37e6c063e7f 100644 --- a/app/components/UI/Perps/Views/PerpsOrderBookView/PerpsOrderBookView.test.tsx +++ b/app/components/UI/Perps/Views/PerpsOrderBookView/PerpsOrderBookView.test.tsx @@ -119,6 +119,22 @@ jest.mock('../../hooks/stream/usePerpsLiveOrderBook', () => ({ usePerpsLiveOrderBook: (params: unknown) => mockUsePerpsLiveOrderBook(params), })); +// Mock usePerpsLivePrices for live price updates in header (TAT-2441) +const mockUsePerpsLivePrices = jest.fn< + Record, + [unknown] +>(() => ({ + BTC: { + price: '50000', + percentChange24h: '2.5', + symbol: 'BTC', + }, +})); + +jest.mock('../../hooks/stream/usePerpsLivePrices', () => ({ + usePerpsLivePrices: (params: unknown) => mockUsePerpsLivePrices(params), +})); + // Mock usePerpsMeasurement jest.mock('../../hooks/usePerpsMeasurement', () => ({ usePerpsMeasurement: jest.fn(), @@ -246,7 +262,7 @@ jest.mock('../../components/PerpsMarketHeader', () => { jest.mock( '../../../../../component-library/components/BottomSheets/BottomSheet', () => { - const { View } = jest.requireActual('react-native'); + const { View, TouchableOpacity, Text } = jest.requireActual('react-native'); const ReactMock = jest.requireActual('react'); return { __esModule: true, @@ -257,7 +273,17 @@ jest.mock( onClose?: () => void; }, _ref: unknown, - ) => {props.children}, + ) => ( + + {props.children} + + Backdrop + + + ), ), }; }, @@ -286,12 +312,23 @@ jest.mock('../PerpsSelectModifyActionView', () => { jest.mock( '../../components/PerpsBottomSheetTooltip/PerpsBottomSheetTooltip', () => { - const { View } = jest.requireActual('react-native'); + const { View, TouchableOpacity, Text } = jest.requireActual('react-native'); return { __esModule: true, - default: (props: { testID?: string; isVisible?: boolean }) => + default: (props: { + testID?: string; + isVisible?: boolean; + onClose?: () => void; + }) => props.isVisible ? ( - + + + Close + + ) : null, }; }, @@ -321,6 +358,13 @@ describe('PerpsOrderBookView', () => { bestAsk: '50001', spread: '1.00000', }); + mockUsePerpsLivePrices.mockReturnValue({ + BTC: { + price: '50000', + percentChange24h: '2.5', + symbol: 'BTC', + }, + }); }); describe('rendering', () => { @@ -1070,6 +1114,17 @@ describe('PerpsOrderBookView', () => { ); }); + it('subscribes to live prices for header price display (TAT-2441)', () => { + renderWithProvider(, { state: initialState }); + + expect(mockUsePerpsLivePrices).toHaveBeenCalledWith( + expect.objectContaining({ + symbols: ['BTC'], + throttleMs: 1000, + }), + ); + }); + it('subscribes to live order book with MAX_ORDER_BOOK_LEVELS (20) for server-side aggregation', () => { renderWithProvider(, { state: initialState }); @@ -1139,5 +1194,293 @@ describe('PerpsOrderBookView', () => { getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER), ).toBeOnTheScreen(); }); + + it('falls back to static market price when live price is not available', () => { + mockUsePerpsLivePrices.mockReturnValue({}); + + const { getByTestId } = renderWithProvider(, { + state: initialState, + }); + + expect( + getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER), + ).toBeOnTheScreen(); + }); + + it('falls back to static market price when live price is invalid (NaN)', () => { + mockUsePerpsLivePrices.mockReturnValue({ + BTC: { + price: 'invalid-price', + percentChange24h: '2.5', + symbol: 'BTC', + }, + }); + + const { getByTestId } = renderWithProvider(, { + state: initialState, + }); + + expect( + getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER), + ).toBeOnTheScreen(); + }); + + it('falls back to static market price when live price is zero or negative', () => { + mockUsePerpsLivePrices.mockReturnValue({ + BTC: { + price: '0', + percentChange24h: '2.5', + symbol: 'BTC', + }, + }); + + const { getByTestId } = renderWithProvider(, { + state: initialState, + }); + + expect( + getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER), + ).toBeOnTheScreen(); + }); + + it('handles invalid topOfBook values gracefully', () => { + mockUsePerpsTopOfBook.mockReturnValue({ + bestBid: '0', + bestAsk: '-1', + spread: '0', + }); + + const { getByTestId } = renderWithProvider(, { + state: initialState, + }); + + expect( + getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER), + ).toBeOnTheScreen(); + }); + + it('handles non-finite topOfBook values', () => { + mockUsePerpsTopOfBook.mockReturnValue({ + bestBid: 'NaN', + bestAsk: 'Infinity', + spread: '0', + }); + + const { getByTestId } = renderWithProvider(, { + state: initialState, + }); + + expect( + getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER), + ).toBeOnTheScreen(); + }); + + it('syncs saved grouping when it loads after mount', async () => { + const { usePerpsOrderBookGrouping } = jest.requireMock( + '../../hooks/usePerpsOrderBookGrouping', + ); + + // First render with undefined savedGrouping + usePerpsOrderBookGrouping.mockReturnValue({ + savedGrouping: undefined, + saveGrouping: mockSaveGrouping, + }); + + const { rerender, getByTestId } = renderWithProvider( + , + { + state: initialState, + }, + ); + + expect( + getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER), + ).toBeOnTheScreen(); + + // Simulate savedGrouping loading + usePerpsOrderBookGrouping.mockReturnValue({ + savedGrouping: 5, + saveGrouping: mockSaveGrouping, + }); + + rerender(); + + expect( + getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER), + ).toBeOnTheScreen(); + }); + }); + + describe('spread tooltip', () => { + it('opens spread tooltip when info button is pressed', async () => { + const { getByTestId } = renderWithProvider(, { + state: initialState, + }); + + const spreadInfoButton = getByTestId( + PerpsOrderBookViewSelectorsIDs.SPREAD_INFO_BUTTON, + ); + + fireEvent.press(spreadInfoButton); + + await waitFor(() => { + expect( + getByTestId(PerpsOrderBookViewSelectorsIDs.BOTTOM_SHEET_TOOLTIP), + ).toBeOnTheScreen(); + }); + }); + + it('closes spread tooltip when onClose is called', async () => { + const { getByTestId, queryByTestId } = renderWithProvider( + , + { + state: initialState, + }, + ); + + // Open tooltip + const spreadInfoButton = getByTestId( + PerpsOrderBookViewSelectorsIDs.SPREAD_INFO_BUTTON, + ); + fireEvent.press(spreadInfoButton); + + await waitFor(() => { + expect( + getByTestId(PerpsOrderBookViewSelectorsIDs.BOTTOM_SHEET_TOOLTIP), + ).toBeOnTheScreen(); + }); + + // Close the tooltip via the close button in the mock + const closeButton = getByTestId( + `${PerpsOrderBookViewSelectorsIDs.BOTTOM_SHEET_TOOLTIP}-close`, + ); + fireEvent.press(closeButton); + + await waitFor(() => { + expect( + queryByTestId(PerpsOrderBookViewSelectorsIDs.BOTTOM_SHEET_TOOLTIP), + ).toBeNull(); + }); + }); + }); + + describe('depth band sheet close', () => { + it('closes depth band sheet via onClose callback', async () => { + const { getByTestId, queryByText } = renderWithProvider( + , + { + state: initialState, + }, + ); + + // Open the depth band sheet + const depthBandButton = getByTestId( + PerpsOrderBookViewSelectorsIDs.DEPTH_BAND_BUTTON, + ); + fireEvent.press(depthBandButton); + + await waitFor(() => { + expect(queryByText('Depth Band')).toBeOnTheScreen(); + }); + + // Close via backdrop (onClose callback) + const backdrop = getByTestId('bottom-sheet-backdrop'); + fireEvent.press(backdrop); + + await waitFor(() => { + expect(queryByText('Depth Band')).toBeNull(); + }); + }); + }); + + describe('geo-block modal interactions', () => { + it('closes geo-block modal when onClose is called', async () => { + const { selectPerpsEligibility } = jest.requireMock( + '../../selectors/perpsController', + ); + selectPerpsEligibility.mockReturnValue(false); + + // Ensure no existing position so Long/Short buttons are shown + const { useHasExistingPosition } = jest.requireMock( + '../../hooks/useHasExistingPosition', + ); + useHasExistingPosition.mockReturnValue({ + isLoading: false, + existingPosition: null, + }); + + const { getByTestId, queryByTestId } = renderWithProvider( + , + { + state: initialState, + }, + ); + + // Trigger geo-block modal by pressing Long button when not eligible + const longButton = getByTestId( + PerpsOrderBookViewSelectorsIDs.LONG_BUTTON, + ); + fireEvent.press(longButton); + + // Verify modal is shown + const geoBlockTooltipId = `${PerpsOrderBookViewSelectorsIDs.CONTAINER}-geo-block-tooltip`; + expect(queryByTestId(geoBlockTooltipId)).toBeOnTheScreen(); + + // Close the modal via the close button + const closeButton = getByTestId(`${geoBlockTooltipId}-close`); + fireEvent.press(closeButton); + + await waitFor(() => { + expect(queryByTestId(geoBlockTooltipId)).toBeNull(); + }); + }); + }); + + describe('grouping with no market price', () => { + it('returns null grouping when market price is unavailable', () => { + const { usePerpsMarkets } = jest.requireMock('../../hooks'); + usePerpsMarkets.mockReturnValue({ + markets: [ + { + symbol: 'BTC', + price: null, + leverage: 50, + }, + ], + isLoading: false, + error: null, + }); + + const { getByTestId } = renderWithProvider(, { + state: initialState, + }); + + expect( + getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER), + ).toBeOnTheScreen(); + }); + + it('handles invalid market price format', () => { + const { usePerpsMarkets } = jest.requireMock('../../hooks'); + usePerpsMarkets.mockReturnValue({ + markets: [ + { + symbol: 'BTC', + price: 'invalid', + leverage: 50, + }, + ], + isLoading: false, + error: null, + }); + + const { getByTestId } = renderWithProvider(, { + state: initialState, + }); + + expect( + getByTestId(PerpsOrderBookViewSelectorsIDs.CONTAINER), + ).toBeOnTheScreen(); + }); }); }); diff --git a/app/components/UI/Perps/Views/PerpsOrderBookView/PerpsOrderBookView.tsx b/app/components/UI/Perps/Views/PerpsOrderBookView/PerpsOrderBookView.tsx index e420cbffe1d..a170292f743 100644 --- a/app/components/UI/Perps/Views/PerpsOrderBookView/PerpsOrderBookView.tsx +++ b/app/components/UI/Perps/Views/PerpsOrderBookView/PerpsOrderBookView.tsx @@ -66,6 +66,7 @@ import { } from '../../hooks'; import { useHasExistingPosition } from '../../hooks/useHasExistingPosition'; import { usePerpsLiveOrderBook } from '../../hooks/stream/usePerpsLiveOrderBook'; +import { usePerpsLivePrices } from '../../hooks/stream/usePerpsLivePrices'; import { usePerpsTopOfBook } from '../../hooks/stream/usePerpsTopOfBook'; import { usePerpsEventTracking } from '../../hooks/usePerpsEventTracking'; import { usePerpsMeasurement } from '../../hooks/usePerpsMeasurement'; @@ -203,6 +204,27 @@ const PerpsOrderBookView: React.FC = ({ // This is intentionally independent from order book aggregation/grouping. const topOfBook = usePerpsTopOfBook({ symbol: symbol || '' }); + // Subscribe to live price updates for header display (TAT-2441) + // This ensures the price in the header updates in real-time + const livePrices = usePerpsLivePrices({ + symbols: symbol ? [symbol] : [], + throttleMs: 1000, + }); + + // Current price for header - use live price with fallback to static market price + const currentPrice = useMemo(() => { + const priceData = livePrices[symbol || '']; + if (priceData?.price) { + const parsed = parseFloat(priceData.price); + // Validate parsed value - fallback to marketPrice if invalid + if (Number.isFinite(parsed) && parsed > 0) { + return parsed; + } + } + // Fallback to static market price if live price not available or invalid + return marketPrice ?? 0; + }, [livePrices, symbol, marketPrice]); + const spreadMetrics = useMemo(() => { const bidStr = topOfBook?.bestBid; const askStr = topOfBook?.bestAsk; @@ -470,7 +492,7 @@ const PerpsOrderBookView: React.FC = ({ ) : ( @@ -504,7 +526,7 @@ const PerpsOrderBookView: React.FC = ({ )} diff --git a/app/components/UI/Perps/providers/PerpsStreamManager.test.tsx b/app/components/UI/Perps/providers/PerpsStreamManager.test.tsx index 435c084d28c..5d6bdd70783 100644 --- a/app/components/UI/Perps/providers/PerpsStreamManager.test.tsx +++ b/app/components/UI/Perps/providers/PerpsStreamManager.test.tsx @@ -84,6 +84,9 @@ describe('PerpsStreamManager', () => { subscribeToPositions: mockSubscribeToPositions, subscribeToAccount: mockSubscribeToAccount, isCurrentlyReinitializing: jest.fn().mockReturnValue(false), + getMarkets: jest + .fn() + .mockResolvedValue([{ name: 'BTC-PERP' }, { name: 'ETH-PERP' }]), } as unknown as typeof mockEngine.context.PerpsController; // Mock AccountTreeController for getEvmAccountFromSelectedAccountGroup @@ -940,6 +943,241 @@ describe('PerpsStreamManager', () => { }); }); + describe('PriceStreamChannel.prewarm non-blocking behavior', () => { + it('returns immediately without waiting for getMarkets', async () => { + // Create a promise that we can control + let resolveGetMarkets: (value: { name: string }[]) => void = jest.fn(); + const getMarketsPromise = new Promise<{ name: string }[]>((resolve) => { + resolveGetMarkets = resolve; + }); + + mockEngine.context.PerpsController.getMarkets = jest + .fn() + .mockReturnValue(getMarketsPromise); + + // prewarm should return immediately + const cleanupPromise = testStreamManager.prices.prewarm(); + + // The promise should resolve immediately (before getMarkets completes) + const cleanup = await cleanupPromise; + expect(typeof cleanup).toBe('function'); + + // getMarkets was called but not awaited + expect(mockEngine.context.PerpsController.getMarkets).toHaveBeenCalled(); + + // subscribeToPrices should NOT have been called yet + expect(mockSubscribeToPrices).not.toHaveBeenCalled(); + + // Now resolve getMarkets + resolveGetMarkets([{ name: 'BTC-PERP' }, { name: 'ETH-PERP' }]); + + // Wait for the promise chain to complete + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + // Now subscribeToPrices should have been called + expect(mockSubscribeToPrices).toHaveBeenCalledWith({ + symbols: ['BTC-PERP', 'ETH-PERP'], + callback: expect.any(Function), + }); + + cleanup(); + }); + + it('skips subscription when cleanup occurs before getMarkets completes', async () => { + // Create a promise that we can control + let resolveGetMarkets: (value: { name: string }[]) => void = jest.fn(); + const getMarketsPromise = new Promise<{ name: string }[]>((resolve) => { + resolveGetMarkets = resolve; + }); + + mockEngine.context.PerpsController.getMarkets = jest + .fn() + .mockReturnValue(getMarketsPromise); + + // prewarm should return immediately + const cleanup = await testStreamManager.prices.prewarm(); + + // Call cleanup before getMarkets resolves + cleanup(); + + // Now resolve getMarkets + resolveGetMarkets([{ name: 'BTC-PERP' }, { name: 'ETH-PERP' }]); + + // Wait for the promise chain to complete + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + // subscribeToPrices should NOT have been called (cleaned up before it could subscribe) + expect(mockSubscribeToPrices).not.toHaveBeenCalled(); + }); + + it('logs error and returns cleanup function when getMarkets fails', async () => { + mockEngine.context.PerpsController.getMarkets = jest + .fn() + .mockRejectedValue(new Error('Network error')); + + // prewarm should return immediately + const cleanup = await testStreamManager.prices.prewarm(); + expect(typeof cleanup).toBe('function'); + + // Wait for the promise chain to complete + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + // subscribeToPrices should NOT have been called due to error + expect(mockSubscribeToPrices).not.toHaveBeenCalled(); + + // Logger.error should have been called + expect(mockLogger.error).toHaveBeenCalled(); + + cleanup(); + }); + + it('calls actual unsubscribe when cleanupPrewarm runs after subscription is established', async () => { + const mockActualUnsubscribe = jest.fn(); + mockSubscribeToPrices.mockReturnValue(mockActualUnsubscribe); + + mockEngine.context.PerpsController.getMarkets = jest + .fn() + .mockResolvedValue([{ name: 'BTC-PERP' }]); + + // prewarm and wait for subscription to be established + const cleanup = await testStreamManager.prices.prewarm(); + + // Wait for the background subscription to be set up + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(mockSubscribeToPrices).toHaveBeenCalled(); + + // Call cleanup + cleanup(); + + // The actual unsubscribe should have been called + expect(mockActualUnsubscribe).toHaveBeenCalled(); + }); + + it('returns same cleanup when prewarm called twice without cleanup (already pre-warmed guard)', async () => { + const mockUnsubscribe = jest.fn(); + mockSubscribeToPrices.mockReturnValue(mockUnsubscribe); + + // Create controlled promise + let resolveGetMarkets: (value: { name: string }[]) => void = jest.fn(); + const getMarketsPromise = new Promise<{ name: string }[]>((resolve) => { + resolveGetMarkets = resolve; + }); + + mockEngine.context.PerpsController.getMarkets = jest + .fn() + .mockReturnValue(getMarketsPromise); + + // Cycle 1: User enters Perps + const cleanup1 = await testStreamManager.prices.prewarm(); + + // Cycle 2: Called again without cleanup (rapid navigation) + // The guard at line 362 returns the existing prewarmUnsubscribe + const cleanup2 = await testStreamManager.prices.prewarm(); + + // Both cleanups should be the same function (guarded by "already pre-warmed" check) + expect(cleanup1).toBe(cleanup2); + + // getMarkets only called once (second call was short-circuited) + expect( + mockEngine.context.PerpsController.getMarkets, + ).toHaveBeenCalledTimes(1); + + // Resolve and cleanup + resolveGetMarkets([{ name: 'BTC-PERP' }]); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(mockSubscribeToPrices).toHaveBeenCalledTimes(1); + cleanup1(); + expect(mockUnsubscribe).toHaveBeenCalled(); + }); + + it('prevents stale promise from creating subscription after cleanup and new prewarm', async () => { + const mockUnsubscribe1 = jest.fn(); + const mockUnsubscribe2 = jest.fn(); + let subscribeCallCount = 0; + + mockSubscribeToPrices.mockImplementation(() => { + subscribeCallCount++; + return subscribeCallCount === 1 ? mockUnsubscribe1 : mockUnsubscribe2; + }); + + // Create controlled promises for each cycle + let resolveGetMarkets1: (value: { name: string }[]) => void = jest.fn(); + let resolveGetMarkets2: (value: { name: string }[]) => void = jest.fn(); + + const getMarketsPromise1 = new Promise<{ name: string }[]>((resolve) => { + resolveGetMarkets1 = resolve; + }); + const getMarketsPromise2 = new Promise<{ name: string }[]>((resolve) => { + resolveGetMarkets2 = resolve; + }); + + let getMarketsCallCount = 0; + (mockEngine.context.PerpsController.getMarkets as jest.Mock) = jest.fn( + () => { + getMarketsCallCount++; + return getMarketsCallCount === 1 + ? getMarketsPromise1 + : getMarketsPromise2; + }, + ); + + // Cycle 1: User enters Perps + const cleanup1 = await testStreamManager.prices.prewarm(); + + // User leaves before markets load - this resets prewarmUnsubscribe to undefined + cleanup1(); + + // Cycle 2: User enters Perps again + const cleanup2 = await testStreamManager.prices.prewarm(); + + // Now cycle 1's promise resolves (STALE - should be ignored due to cycle ID mismatch) + resolveGetMarkets1([{ name: 'BTC-PERP' }]); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + // Stale promise should NOT create subscription (cycle ID mismatch) + expect(mockSubscribeToPrices).not.toHaveBeenCalled(); + + // Cycle 2's promise resolves (active) + resolveGetMarkets2([{ name: 'ETH-PERP' }]); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + // Only one subscription should be created (from cycle 2) + expect(mockSubscribeToPrices).toHaveBeenCalledTimes(1); + expect(mockSubscribeToPrices).toHaveBeenCalledWith({ + symbols: ['ETH-PERP'], + callback: expect.any(Function), + }); + + // Cleanup cycle 2 - since this is the first subscription call, it uses mockUnsubscribe1 + cleanup2(); + expect(mockUnsubscribe1).toHaveBeenCalled(); + // mockUnsubscribe2 was never created because only one subscription was made + }); + }); + it('should throttle subsequent updates', async () => { const onUpdate = jest.fn(); let priceCallback: (data: PriceUpdate[]) => void = jest.fn(); diff --git a/app/components/UI/Perps/providers/PerpsStreamManager.tsx b/app/components/UI/Perps/providers/PerpsStreamManager.tsx index 0cc19de3876..eaec6098be5 100644 --- a/app/components/UI/Perps/providers/PerpsStreamManager.tsx +++ b/app/components/UI/Perps/providers/PerpsStreamManager.tsx @@ -236,7 +236,10 @@ abstract class StreamChannel { class PriceStreamChannel extends StreamChannel> { private symbols = new Set(); private prewarmUnsubscribe?: () => void; + private actualPriceUnsubscribe?: () => void; private allMarketSymbols: string[] = []; + // Unique ID per prewarm cycle to detect stale promises and prevent subscription leaks + private prewarmCycleId: number = 0; // Override cache to store individual PriceUpdate objects protected priceCache = new Map(); @@ -355,6 +358,7 @@ class PriceStreamChannel extends StreamChannel> { /** * Pre-warm the channel by subscribing to all market prices * This keeps a single WebSocket connection alive with all price updates + * Non-blocking: Returns immediately while market fetch happens in background * @returns Cleanup function to call when leaving Perps environment */ public async prewarm(): Promise<() => void> { @@ -364,57 +368,104 @@ class PriceStreamChannel extends StreamChannel> { } try { - // Get all available market symbols const controller = Engine.context.PerpsController; - const markets = await controller.getMarkets(); - this.allMarketSymbols = markets.map((market) => market.name); - DevLogger.log('PriceStreamChannel: Pre-warming with all market symbols', { - symbolCount: this.allMarketSymbols.length, - symbols: this.allMarketSymbols.slice(0, 10), // Log first 10 for debugging - }); + // Increment cycle ID to detect stale promises from previous prewarm cycles + // This prevents subscription leaks when user navigates: Perps → away → back quickly + this.prewarmCycleId++; + const currentCycleId = this.prewarmCycleId; + + // Start market fetch in background (non-blocking) + // We need the symbols to register subscribers, but we can return immediately + const marketsPromise = controller.getMarkets(); + + // Set up subscription once markets arrive (fire-and-forget) + marketsPromise + .then((markets) => { + // If this promise is from a stale cycle, don't set up subscription + // This prevents leaks when prewarm is called multiple times rapidly + if (currentCycleId !== this.prewarmCycleId) { + DevLogger.log('PriceStreamChannel: Skipping stale prewarm cycle', { + currentCycleId, + activeCycleId: this.prewarmCycleId, + }); + return; + } + + // If already cleaned up, don't set up subscription + if (this.prewarmUnsubscribe === undefined) { + return; + } + + this.allMarketSymbols = markets.map((market) => market.name); + + DevLogger.log( + 'PriceStreamChannel: Pre-warming with all market symbols', + { + symbolCount: this.allMarketSymbols.length, + symbols: this.allMarketSymbols.slice(0, 10), + }, + ); - // Subscribe to all market prices - this.prewarmUnsubscribe = controller.subscribeToPrices({ - symbols: this.allMarketSymbols, - callback: (updates: PriceUpdate[]) => { - // Update cache and build price map - const priceMap: Record = {}; - updates.forEach((update) => { - const priceUpdate: PriceUpdate = { - symbol: update.symbol, - price: update.price, - timestamp: Date.now(), - percentChange24h: update.percentChange24h, - bestBid: update.bestBid, - bestAsk: update.bestAsk, - spread: update.spread, - markPrice: update.markPrice, - funding: update.funding, - openInterest: update.openInterest, - volume24h: update.volume24h, - }; - this.priceCache.set(update.symbol, priceUpdate); - priceMap[update.symbol] = priceUpdate; + // Subscribe to all market prices + const unsub = controller.subscribeToPrices({ + symbols: this.allMarketSymbols, + callback: (updates: PriceUpdate[]) => { + const priceMap: Record = {}; + updates.forEach((update) => { + const priceUpdate: PriceUpdate = { + symbol: update.symbol, + price: update.price, + timestamp: Date.now(), + percentChange24h: update.percentChange24h, + bestBid: update.bestBid, + bestAsk: update.bestAsk, + spread: update.spread, + markPrice: update.markPrice, + funding: update.funding, + openInterest: update.openInterest, + volume24h: update.volume24h, + }; + this.priceCache.set(update.symbol, priceUpdate); + priceMap[update.symbol] = priceUpdate; + }); + + if (this.subscribers.size > 0) { + this.notifySubscribers(priceMap); + } + }, }); - // Notify any active subscribers with all updates + // Store the actual unsubscribe function + this.actualPriceUnsubscribe = unsub; + }) + .catch((error) => { + Logger.error( + ensureError(error, 'PriceStreamChannel.prewarm.backgroundFetch'), + { + context: 'PriceStreamChannel.prewarm.backgroundFetch', + }, + ); + // Reset state so subsequent prewarm/connect calls can recover + this.prewarmUnsubscribe = undefined; + this.allMarketSymbols = []; + // Reconnect waiting subscribers that were skipped because prewarm was pending if (this.subscribers.size > 0) { - this.notifySubscribers(priceMap); + this.connect(); } - }, - }); + }); - // Return a cleanup function that properly clears internal state - return () => { + // Return cleanup function immediately (before markets load) + this.prewarmUnsubscribe = () => { DevLogger.log('PriceStreamChannel: Cleaning up prewarm subscription'); this.cleanupPrewarm(); }; + + return this.prewarmUnsubscribe; } catch (error) { Logger.error(ensureError(error, 'PriceStreamChannel.prewarm'), { context: 'PriceStreamChannel.prewarm', }); - // Return no-op cleanup function return () => { // No-op }; @@ -425,11 +476,12 @@ class PriceStreamChannel extends StreamChannel> { * Cleanup pre-warm subscription */ public cleanupPrewarm(): void { - if (this.prewarmUnsubscribe) { - this.prewarmUnsubscribe(); - this.prewarmUnsubscribe = undefined; - this.allMarketSymbols = []; + if (this.actualPriceUnsubscribe) { + this.actualPriceUnsubscribe(); + this.actualPriceUnsubscribe = undefined; } + this.prewarmUnsubscribe = undefined; + this.allMarketSymbols = []; } } @@ -1285,7 +1337,7 @@ class MarketDataChannel extends StreamChannel { public prewarm(): () => void { // Fetch data immediately to populate cache this.fetchMarketData().catch((error) => { - Logger.error(error instanceof Error ? error : new Error(String(error)), { + Logger.error(ensureError(error, 'MarketDataChannel.prewarm'), { context: 'MarketDataChannel.prewarm', }); }); diff --git a/app/components/UI/Predict/controllers/PredictController.getActivity.test.ts b/app/components/UI/Predict/controllers/PredictController.getActivity.test.ts index 22f2a990f26..fe3e6984b6f 100644 --- a/app/components/UI/Predict/controllers/PredictController.getActivity.test.ts +++ b/app/components/UI/Predict/controllers/PredictController.getActivity.test.ts @@ -1,31 +1,5 @@ import { PredictController } from './PredictController'; -// Mock Engine AccountsController for selected address -jest.mock('../../../../core/Engine', () => ({ - context: { - AccountsController: { - getSelectedAccount: jest.fn(() => ({ address: '0xselected' })), - }, - AccountTreeController: { - getAccountsFromSelectedAccountGroup: jest.fn().mockReturnValue([ - { - address: '0xselected', - id: 'mock-account-id', - type: 'eip155:eoa', - options: {}, - metadata: { - name: 'Test Account', - importTime: Date.now(), - keyring: { type: 'HD Key Tree' }, - }, - scopes: ['eip155:1'], - methods: ['eth_sendTransaction'], - }, - ]), - }, - }, -})); - interface ActivityEntry { id: string; providerId: string; @@ -55,6 +29,11 @@ describe('PredictController.getActivity', () => { ) as MockPredictController; controller.providers = new Map(Object.entries(providers)); controller.update = jest.fn(); + ( + controller as unknown as { getSigner: () => { address: string } } + ).getSigner = jest.fn(() => ({ + address: '0xselected', + })); return controller; }; diff --git a/app/components/UI/Predict/controllers/PredictController.test.ts b/app/components/UI/Predict/controllers/PredictController.test.ts index 1a4a9300876..03a41eb07c6 100644 --- a/app/components/UI/Predict/controllers/PredictController.test.ts +++ b/app/components/UI/Predict/controllers/PredictController.test.ts @@ -7,14 +7,10 @@ import { type MessengerEvents, type MockAnyNamespace, } from '@metamask/messenger'; -import { - GasFeeEstimateLevel, - GasFeeEstimateType, -} from '@metamask/transaction-controller'; + import type { NetworkState } from '@metamask/network-controller'; import type { InternalAccount } from '@metamask/keyring-internal-api'; -import Engine from '../../../../core/Engine'; import DevLogger from '../../../../core/SDKConnect/utils/DevLogger'; import { addTransaction, @@ -46,71 +42,33 @@ jest.mock('../../../../util/transaction-controller', () => ({ addTransactionBatch: jest.fn(), })); -// Mock Engine -jest.mock('../../../../core/Engine', () => ({ - context: { - KeyringController: { - signTypedMessage: jest.fn(), +// Default mock values for messenger actions +const DEFAULT_REMOTE_FEATURE_FLAG_STATE = { + remoteFeatureFlags: { + predictFeeCollection: { + enabled: true, + collector: '0x100c7b833bbd604a77890783439bbb9d65e31de7', + metamaskFee: 0.02, + providerFee: 0.02, + waiveList: [], }, - AccountsController: { - getSelectedAccount: jest.fn().mockReturnValue({ - id: 'mock-account-id', - address: '0x1234567890123456789012345678901234567890', - metadata: { name: 'Test Account' }, - }), + predictLiveSports: { + enabled: false, + leagues: [], }, - AccountTreeController: { - getAccountsFromSelectedAccountGroup: jest.fn(() => [ - { - id: 'mock-account-id', - address: '0x1234567890123456789012345678901234567890', - type: 'eip155:eoa', - name: 'Test Account', - metadata: { - lastSelected: 0, - }, - }, - ]), - }, - NetworkController: { - getState: jest.fn().mockReturnValue({ - selectedNetworkClientId: 'mainnet', - }), - findNetworkClientIdByChainId: jest.fn().mockReturnValue('mainnet'), - getNetworkClientById: jest.fn().mockReturnValue({ - blockTracker: { - checkForLatestBlock: jest.fn().mockResolvedValue(undefined), - }, - }), - }, - TransactionController: { - estimateGas: jest.fn(), - estimateGasFee: jest.fn(), - }, - AccountTrackerController: { - state: { - accountsByChainId: {}, - }, - }, - RemoteFeatureFlagController: { - state: { - remoteFeatureFlags: { - predictFeeCollection: { - enabled: true, - collector: '0x100c7b833bbd604a77890783439bbb9d65e31de7', - metamaskFee: 0.02, - providerFee: 0.02, - waiveList: [], - }, - predictLiveSports: { - enabled: false, - leagues: [], - }, - }, - }, + predictMarketHighlights: { + enabled: false, + highlights: [], }, }, -})); + cacheTimestamp: Date.now(), +}; + +const DEFAULT_NETWORK_CLIENT = { + blockTracker: { + checkForLatestBlock: jest.fn().mockResolvedValue(undefined), + }, +}; // Mock DevLogger (default export) jest.mock('../../../../core/SDKConnect/utils/DevLogger', () => ({ @@ -266,6 +224,11 @@ describe('PredictController', () => { mocks?: { getSelectedAccount?: jest.MockedFunction<() => InternalAccount>; getNetworkState?: jest.MockedFunction<() => NetworkState>; + getRemoteFeatureFlagState?: jest.MockedFunction<() => any>; + findNetworkClientIdByChainId?: jest.MockedFunction< + (chainId: string) => string + >; + getNetworkClientById?: jest.MockedFunction<(clientId: string) => any>; }; } = {}, ): ReturnValue { @@ -284,6 +247,19 @@ describe('PredictController', () => { }), ); + rootMessenger.registerActionHandler( + 'AccountTreeController:getAccountsFromSelectedAccountGroup', + jest.fn().mockReturnValue([ + { + id: 'mock-account-id', + address: '0x1234567890123456789012345678901234567890', + type: 'eip155:eoa', + name: 'Test Account', + metadata: { lastSelected: 0 }, + }, + ]), + ); + rootMessenger.registerActionHandler( 'NetworkController:getState', mocks.getNetworkState ?? @@ -299,6 +275,34 @@ describe('PredictController', () => { }), ); + rootMessenger.registerActionHandler( + 'KeyringController:signTypedMessage', + jest.fn().mockResolvedValue('0xmocksignature'), + ); + + rootMessenger.registerActionHandler( + 'KeyringController:signPersonalMessage', + jest.fn().mockResolvedValue('0xmocksignature'), + ); + + rootMessenger.registerActionHandler( + 'NetworkController:findNetworkClientIdByChainId', + mocks.findNetworkClientIdByChainId ?? + jest.fn().mockReturnValue('polygon-mainnet'), + ); + + rootMessenger.registerActionHandler( + 'NetworkController:getNetworkClientById', + mocks.getNetworkClientById ?? + jest.fn().mockReturnValue(DEFAULT_NETWORK_CLIENT), + ); + + rootMessenger.registerActionHandler( + 'RemoteFeatureFlagController:getState', + mocks.getRemoteFeatureFlagState ?? + jest.fn().mockReturnValue(DEFAULT_REMOTE_FEATURE_FLAG_STATE), + ); + const messenger = new Messenger< 'PredictController', AllPredictControllerMessengerActions, @@ -312,14 +316,21 @@ describe('PredictController', () => { rootMessenger.delegate({ actions: [ 'AccountsController:getSelectedAccount', + 'AccountTreeController:getAccountsFromSelectedAccountGroup', 'NetworkController:getState', + 'NetworkController:findNetworkClientIdByChainId', + 'NetworkController:getNetworkClientById', 'TransactionController:estimateGas', + 'KeyringController:signTypedMessage', + 'KeyringController:signPersonalMessage', + 'RemoteFeatureFlagController:getState', ], events: [ 'TransactionController:transactionSubmitted', 'TransactionController:transactionConfirmed', 'TransactionController:transactionFailed', 'TransactionController:transactionRejected', + 'RemoteFeatureFlagController:stateChange', ], messenger, }); @@ -1794,32 +1805,28 @@ describe('PredictController', () => { outcomes: ['YES', 'NO'], }); - const setMarketHighlightsFlag = (flag: { + const createFlagState = (flag: { enabled: boolean; highlights: { category: string; markets: string[] }[]; - }) => { - ( - Engine.context.RemoteFeatureFlagController as any - ).state.remoteFeatureFlags.predictMarketHighlights = flag; - }; - - const clearMarketHighlightsFlag = () => { - delete (Engine.context.RemoteFeatureFlagController as any).state - .remoteFeatureFlags.predictMarketHighlights; - }; - - afterEach(() => { - clearMarketHighlightsFlag(); + }) => ({ + remoteFeatureFlags: { + predictFeeCollection: { + enabled: true, + collector: '0x100c7b833bbd604a77890783439bbb9d65e31de7', + metamaskFee: 0.02, + providerFee: 0.02, + waiveList: [], + }, + predictLiveSports: { + enabled: false, + leagues: [], + }, + predictMarketHighlights: flag, + }, + cacheTimestamp: Date.now(), }); it('prepends highlighted markets when flag is enabled and offset is 0', async () => { - setMarketHighlightsFlag({ - enabled: true, - highlights: [ - { category: 'trending', markets: ['highlight-1', 'highlight-2'] }, - ], - }); - const regularMarkets = [ createMockMarket('regular-1'), createMockMarket('regular-2'), @@ -1829,134 +1836,180 @@ describe('PredictController', () => { createMockMarket('highlight-2'), ]; - await withController(async ({ controller }) => { - mockPolymarketProvider.getMarkets.mockResolvedValue( - regularMarkets as any, - ); - mockPolymarketProvider.getMarketsByIds.mockResolvedValue( - highlightedMarkets as any, - ); + await withController( + async ({ controller }) => { + mockPolymarketProvider.getMarkets.mockResolvedValue( + regularMarkets as any, + ); + mockPolymarketProvider.getMarketsByIds.mockResolvedValue( + highlightedMarkets as any, + ); - const result = await controller.getMarkets({ - providerId: 'polymarket', - category: 'trending', - offset: 0, - }); - - expect(result).toHaveLength(4); - expect(result[0].id).toBe('highlight-1'); - expect(result[1].id).toBe('highlight-2'); - expect(result[2].id).toBe('regular-1'); - expect(result[3].id).toBe('regular-2'); - expect(mockPolymarketProvider.getMarketsByIds).toHaveBeenCalledWith( - ['highlight-1', 'highlight-2'], - [], - ); - }); + const result = await controller.getMarkets({ + providerId: 'polymarket', + category: 'trending', + offset: 0, + }); + + expect(result).toHaveLength(4); + expect(result[0].id).toBe('highlight-1'); + expect(result[1].id).toBe('highlight-2'); + expect(result[2].id).toBe('regular-1'); + expect(result[3].id).toBe('regular-2'); + expect(mockPolymarketProvider.getMarketsByIds).toHaveBeenCalledWith( + ['highlight-1', 'highlight-2'], + [], + ); + }, + { + mocks: { + getRemoteFeatureFlagState: jest.fn().mockReturnValue( + createFlagState({ + enabled: true, + highlights: [ + { + category: 'trending', + markets: ['highlight-1', 'highlight-2'], + }, + ], + }), + ), + }, + }, + ); }); it('skips highlights when offset is greater than 0', async () => { - setMarketHighlightsFlag({ - enabled: true, - highlights: [{ category: 'trending', markets: ['highlight-1'] }], - }); - const regularMarkets = [createMockMarket('regular-1')]; - await withController(async ({ controller }) => { - mockPolymarketProvider.getMarkets.mockResolvedValue( - regularMarkets as any, - ); + await withController( + async ({ controller }) => { + mockPolymarketProvider.getMarkets.mockResolvedValue( + regularMarkets as any, + ); - const result = await controller.getMarkets({ - providerId: 'polymarket', - category: 'trending', - offset: 10, - }); + const result = await controller.getMarkets({ + providerId: 'polymarket', + category: 'trending', + offset: 10, + }); - expect(result).toHaveLength(1); - expect(result[0].id).toBe('regular-1'); - expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled(); - }); + expect(result).toHaveLength(1); + expect(result[0].id).toBe('regular-1'); + expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled(); + }, + { + mocks: { + getRemoteFeatureFlagState: jest.fn().mockReturnValue( + createFlagState({ + enabled: true, + highlights: [ + { category: 'trending', markets: ['highlight-1'] }, + ], + }), + ), + }, + }, + ); }); it('skips highlights when flag is disabled', async () => { - setMarketHighlightsFlag({ - enabled: false, - highlights: [{ category: 'trending', markets: ['highlight-1'] }], - }); - const regularMarkets = [createMockMarket('regular-1')]; - await withController(async ({ controller }) => { - mockPolymarketProvider.getMarkets.mockResolvedValue( - regularMarkets as any, - ); + await withController( + async ({ controller }) => { + mockPolymarketProvider.getMarkets.mockResolvedValue( + regularMarkets as any, + ); - const result = await controller.getMarkets({ - providerId: 'polymarket', - category: 'trending', - offset: 0, - }); + const result = await controller.getMarkets({ + providerId: 'polymarket', + category: 'trending', + offset: 0, + }); - expect(result).toHaveLength(1); - expect(result[0].id).toBe('regular-1'); - expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled(); - }); + expect(result).toHaveLength(1); + expect(result[0].id).toBe('regular-1'); + expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled(); + }, + { + mocks: { + getRemoteFeatureFlagState: jest.fn().mockReturnValue( + createFlagState({ + enabled: false, + highlights: [ + { category: 'trending', markets: ['highlight-1'] }, + ], + }), + ), + }, + }, + ); }); it('skips highlights when category is not provided', async () => { - setMarketHighlightsFlag({ - enabled: true, - highlights: [{ category: 'trending', markets: ['highlight-1'] }], - }); - const regularMarkets = [createMockMarket('regular-1')]; - await withController(async ({ controller }) => { - mockPolymarketProvider.getMarkets.mockResolvedValue( - regularMarkets as any, - ); + await withController( + async ({ controller }) => { + mockPolymarketProvider.getMarkets.mockResolvedValue( + regularMarkets as any, + ); - const result = await controller.getMarkets({ - providerId: 'polymarket', - }); + const result = await controller.getMarkets({ + providerId: 'polymarket', + }); - expect(result).toHaveLength(1); - expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled(); - }); + expect(result).toHaveLength(1); + expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled(); + }, + { + mocks: { + getRemoteFeatureFlagState: jest.fn().mockReturnValue( + createFlagState({ + enabled: true, + highlights: [ + { category: 'trending', markets: ['highlight-1'] }, + ], + }), + ), + }, + }, + ); }); it('skips highlights when category has no configured highlights', async () => { - setMarketHighlightsFlag({ - enabled: true, - highlights: [{ category: 'crypto', markets: ['highlight-1'] }], - }); - const regularMarkets = [createMockMarket('regular-1')]; - await withController(async ({ controller }) => { - mockPolymarketProvider.getMarkets.mockResolvedValue( - regularMarkets as any, - ); + await withController( + async ({ controller }) => { + mockPolymarketProvider.getMarkets.mockResolvedValue( + regularMarkets as any, + ); - const result = await controller.getMarkets({ - providerId: 'polymarket', - category: 'trending', - offset: 0, - }); + const result = await controller.getMarkets({ + providerId: 'polymarket', + category: 'trending', + offset: 0, + }); - expect(result).toHaveLength(1); - expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled(); - }); + expect(result).toHaveLength(1); + expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled(); + }, + { + mocks: { + getRemoteFeatureFlagState: jest.fn().mockReturnValue( + createFlagState({ + enabled: true, + highlights: [{ category: 'crypto', markets: ['highlight-1'] }], + }), + ), + }, + }, + ); }); it('filters duplicates from regular results when highlighted market appears in both', async () => { - setMarketHighlightsFlag({ - enabled: true, - highlights: [{ category: 'trending', markets: ['duplicate-market'] }], - }); - const regularMarkets = [ createMockMarket('regular-1'), createMockMarket('duplicate-market'), @@ -1964,135 +2017,183 @@ describe('PredictController', () => { ]; const highlightedMarkets = [createMockMarket('duplicate-market')]; - await withController(async ({ controller }) => { - mockPolymarketProvider.getMarkets.mockResolvedValue( - regularMarkets as any, - ); - mockPolymarketProvider.getMarketsByIds.mockResolvedValue( - highlightedMarkets as any, - ); + await withController( + async ({ controller }) => { + mockPolymarketProvider.getMarkets.mockResolvedValue( + regularMarkets as any, + ); + mockPolymarketProvider.getMarketsByIds.mockResolvedValue( + highlightedMarkets as any, + ); - const result = await controller.getMarkets({ - providerId: 'polymarket', - category: 'trending', - offset: 0, - }); + const result = await controller.getMarkets({ + providerId: 'polymarket', + category: 'trending', + offset: 0, + }); - expect(result).toHaveLength(3); - expect(result[0].id).toBe('duplicate-market'); - expect(result[1].id).toBe('regular-1'); - expect(result[2].id).toBe('regular-2'); - }); + expect(result).toHaveLength(3); + expect(result[0].id).toBe('duplicate-market'); + expect(result[1].id).toBe('regular-1'); + expect(result[2].id).toBe('regular-2'); + }, + { + mocks: { + getRemoteFeatureFlagState: jest.fn().mockReturnValue( + createFlagState({ + enabled: true, + highlights: [ + { category: 'trending', markets: ['duplicate-market'] }, + ], + }), + ), + }, + }, + ); }); it('handles empty highlights array gracefully', async () => { - setMarketHighlightsFlag({ - enabled: true, - highlights: [{ category: 'trending', markets: [] }], - }); - const regularMarkets = [createMockMarket('regular-1')]; - await withController(async ({ controller }) => { - mockPolymarketProvider.getMarkets.mockResolvedValue( - regularMarkets as any, - ); + await withController( + async ({ controller }) => { + mockPolymarketProvider.getMarkets.mockResolvedValue( + regularMarkets as any, + ); - const result = await controller.getMarkets({ - providerId: 'polymarket', - category: 'trending', - offset: 0, - }); + const result = await controller.getMarkets({ + providerId: 'polymarket', + category: 'trending', + offset: 0, + }); - expect(result).toHaveLength(1); - expect(result[0].id).toBe('regular-1'); - expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled(); - }); + expect(result).toHaveLength(1); + expect(result[0].id).toBe('regular-1'); + expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled(); + }, + { + mocks: { + getRemoteFeatureFlagState: jest.fn().mockReturnValue( + createFlagState({ + enabled: true, + highlights: [{ category: 'trending', markets: [] }], + }), + ), + }, + }, + ); }); it('handles getMarketsByIds failure gracefully', async () => { - setMarketHighlightsFlag({ - enabled: true, - highlights: [{ category: 'trending', markets: ['highlight-1'] }], - }); - const regularMarkets = [createMockMarket('regular-1')]; - await withController(async ({ controller }) => { - mockPolymarketProvider.getMarkets.mockResolvedValue( - regularMarkets as any, - ); - mockPolymarketProvider.getMarketsByIds.mockResolvedValue([]); + await withController( + async ({ controller }) => { + mockPolymarketProvider.getMarkets.mockResolvedValue( + regularMarkets as any, + ); + mockPolymarketProvider.getMarketsByIds.mockResolvedValue([]); - const result = await controller.getMarkets({ - providerId: 'polymarket', - category: 'trending', - offset: 0, - }); + const result = await controller.getMarkets({ + providerId: 'polymarket', + category: 'trending', + offset: 0, + }); - expect(result).toHaveLength(1); - expect(result[0].id).toBe('regular-1'); - }); + expect(result).toHaveLength(1); + expect(result[0].id).toBe('regular-1'); + }, + { + mocks: { + getRemoteFeatureFlagState: jest.fn().mockReturnValue( + createFlagState({ + enabled: true, + highlights: [ + { category: 'trending', markets: ['highlight-1'] }, + ], + }), + ), + }, + }, + ); }); it('uses default flag when predictMarketHighlights is not in remote config', async () => { - clearMarketHighlightsFlag(); - const regularMarkets = [createMockMarket('regular-1')]; - await withController(async ({ controller }) => { - mockPolymarketProvider.getMarkets.mockResolvedValue( - regularMarkets as any, - ); + await withController( + async ({ controller }) => { + mockPolymarketProvider.getMarkets.mockResolvedValue( + regularMarkets as any, + ); - const result = await controller.getMarkets({ - providerId: 'polymarket', - category: 'trending', - offset: 0, - }); + const result = await controller.getMarkets({ + providerId: 'polymarket', + category: 'trending', + offset: 0, + }); - expect(result).toHaveLength(1); - expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled(); - }); + expect(result).toHaveLength(1); + expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled(); + }, + { + mocks: { + getRemoteFeatureFlagState: jest.fn().mockReturnValue({ + remoteFeatureFlags: { + predictFeeCollection: + DEFAULT_REMOTE_FEATURE_FLAG_STATE.remoteFeatureFlags + .predictFeeCollection, + predictLiveSports: + DEFAULT_REMOTE_FEATURE_FLAG_STATE.remoteFeatureFlags + .predictLiveSports, + // predictMarketHighlights intentionally omitted to test fallback + }, + cacheTimestamp: Date.now(), + }), + }, + }, + ); }); it('fetches highlights for first page when offset is undefined', async () => { - setMarketHighlightsFlag({ - enabled: true, - highlights: [{ category: 'trending', markets: ['highlight-1'] }], - }); - const regularMarkets = [createMockMarket('regular-1')]; const highlightedMarkets = [createMockMarket('highlight-1')]; - await withController(async ({ controller }) => { - mockPolymarketProvider.getMarkets.mockResolvedValue( - regularMarkets as any, - ); - mockPolymarketProvider.getMarketsByIds.mockResolvedValue( - highlightedMarkets as any, - ); + await withController( + async ({ controller }) => { + mockPolymarketProvider.getMarkets.mockResolvedValue( + regularMarkets as any, + ); + mockPolymarketProvider.getMarketsByIds.mockResolvedValue( + highlightedMarkets as any, + ); - const result = await controller.getMarkets({ - providerId: 'polymarket', - category: 'trending', - }); + const result = await controller.getMarkets({ + providerId: 'polymarket', + category: 'trending', + }); - expect(result).toHaveLength(2); - expect(result[0].id).toBe('highlight-1'); - expect(result[1].id).toBe('regular-1'); - expect(mockPolymarketProvider.getMarketsByIds).toHaveBeenCalled(); - }); + expect(result).toHaveLength(2); + expect(result[0].id).toBe('highlight-1'); + expect(result[1].id).toBe('regular-1'); + expect(mockPolymarketProvider.getMarketsByIds).toHaveBeenCalled(); + }, + { + mocks: { + getRemoteFeatureFlagState: jest.fn().mockReturnValue( + createFlagState({ + enabled: true, + highlights: [ + { category: 'trending', markets: ['highlight-1'] }, + ], + }), + ), + }, + }, + ); }); it('preserves order of highlighted markets from config', async () => { - setMarketHighlightsFlag({ - enabled: true, - highlights: [ - { category: 'trending', markets: ['third', 'first', 'second'] }, - ], - }); - const regularMarkets = [createMockMarket('regular-1')]; const highlightedMarkets = [ createMockMarket('third'), @@ -2100,85 +2201,131 @@ describe('PredictController', () => { createMockMarket('second'), ]; - await withController(async ({ controller }) => { - mockPolymarketProvider.getMarkets.mockResolvedValue( - regularMarkets as any, - ); - mockPolymarketProvider.getMarketsByIds.mockResolvedValue( - highlightedMarkets as any, - ); + await withController( + async ({ controller }) => { + mockPolymarketProvider.getMarkets.mockResolvedValue( + regularMarkets as any, + ); + mockPolymarketProvider.getMarketsByIds.mockResolvedValue( + highlightedMarkets as any, + ); - const result = await controller.getMarkets({ - providerId: 'polymarket', - category: 'trending', - offset: 0, - }); + const result = await controller.getMarkets({ + providerId: 'polymarket', + category: 'trending', + offset: 0, + }); - expect(result[0].id).toBe('third'); - expect(result[1].id).toBe('first'); - expect(result[2].id).toBe('second'); - }); + expect(result[0].id).toBe('third'); + expect(result[1].id).toBe('first'); + expect(result[2].id).toBe('second'); + }, + { + mocks: { + getRemoteFeatureFlagState: jest.fn().mockReturnValue( + createFlagState({ + enabled: true, + highlights: [ + { + category: 'trending', + markets: ['third', 'first', 'second'], + }, + ], + }), + ), + }, + }, + ); }); it('handles missing highlights array in flag gracefully', async () => { - ( - Engine.context.RemoteFeatureFlagController as any - ).state.remoteFeatureFlags.predictMarketHighlights = { enabled: true }; - const regularMarkets = [createMockMarket('regular-1')]; - await withController(async ({ controller }) => { - mockPolymarketProvider.getMarkets.mockResolvedValue( - regularMarkets as any, - ); + await withController( + async ({ controller }) => { + mockPolymarketProvider.getMarkets.mockResolvedValue( + regularMarkets as any, + ); - const result = await controller.getMarkets({ - providerId: 'polymarket', - category: 'trending', - offset: 0, - }); + const result = await controller.getMarkets({ + providerId: 'polymarket', + category: 'trending', + offset: 0, + }); - expect(result).toHaveLength(1); - expect(result[0].id).toBe('regular-1'); - expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled(); - }); + expect(result).toHaveLength(1); + expect(result[0].id).toBe('regular-1'); + expect(mockPolymarketProvider.getMarketsByIds).not.toHaveBeenCalled(); + }, + { + mocks: { + getRemoteFeatureFlagState: jest.fn().mockReturnValue({ + remoteFeatureFlags: { + predictFeeCollection: { + enabled: true, + collector: '0x100c7b833bbd604a77890783439bbb9d65e31de7', + metamaskFee: 0.02, + providerFee: 0.02, + waiveList: [], + }, + predictLiveSports: { + enabled: false, + leagues: [], + }, + predictMarketHighlights: { enabled: true }, + }, + cacheTimestamp: Date.now(), + }), + }, + }, + ); }); it('keeps market in regular results when highlight fetch fails for that market', async () => { - setMarketHighlightsFlag({ - enabled: true, - highlights: [ - { category: 'trending', markets: ['highlight-1', 'highlight-2'] }, - ], - }); - const regularMarketsIncludingFailedHighlight = [ createMockMarket('highlight-1'), createMockMarket('regular-1'), ]; const onlySuccessfullyFetchedHighlights = [ createMockMarket('highlight-2'), - ]; - - await withController(async ({ controller }) => { - mockPolymarketProvider.getMarkets.mockResolvedValue( - regularMarketsIncludingFailedHighlight as any, - ); - mockPolymarketProvider.getMarketsByIds.mockResolvedValue( - onlySuccessfullyFetchedHighlights as any, - ); + ]; - const result = await controller.getMarkets({ - providerId: 'polymarket', - category: 'trending', - offset: 0, - }); + await withController( + async ({ controller }) => { + mockPolymarketProvider.getMarkets.mockResolvedValue( + regularMarketsIncludingFailedHighlight as any, + ); + mockPolymarketProvider.getMarketsByIds.mockResolvedValue( + onlySuccessfullyFetchedHighlights as any, + ); - expect(result).toHaveLength(3); - expect(result[0].id).toBe('highlight-2'); - expect(result[1].id).toBe('highlight-1'); - expect(result[2].id).toBe('regular-1'); - }); + const result = await controller.getMarkets({ + providerId: 'polymarket', + category: 'trending', + offset: 0, + }); + + expect(result).toHaveLength(3); + expect(result[0].id).toBe('highlight-2'); + expect(result[1].id).toBe('highlight-1'); + expect(result[2].id).toBe('regular-1'); + }, + { + mocks: { + getRemoteFeatureFlagState: jest.fn().mockReturnValue( + createFlagState({ + enabled: true, + highlights: [ + { + category: 'trending', + markets: ['highlight-1', 'highlight-2'], + }, + ], + }), + ), + }, + }, + ); }); }); @@ -2546,64 +2693,71 @@ describe('PredictController', () => { it('throws error when network client not found', async () => { // Arrange - await withController(async ({ controller }) => { - mockPolymarketProvider.getPositions = jest.fn().mockResolvedValue([ - { - marketId: 'test-market', - providerId: 'polymarket', - outcomeId: 'test-outcome', - balance: '100', - }, - ]); - const MockedEngine = jest.requireMock('../../../../core/Engine'); - MockedEngine.context.NetworkController.findNetworkClientIdByChainId = - jest.fn().mockReturnValue(undefined); + await withController( + async ({ controller }) => { + mockPolymarketProvider.getPositions = jest.fn().mockResolvedValue([ + { + marketId: 'test-market', + providerId: 'polymarket', + outcomeId: 'test-outcome', + balance: '100', + }, + ]); - mockPolymarketProvider.prepareClaim = jest - .fn() - .mockResolvedValue(mockClaim); - await controller.getPositions({ claimable: true }); + mockPolymarketProvider.prepareClaim = jest + .fn() + .mockResolvedValue(mockClaim); + await controller.getPositions({ claimable: true }); - // Act & Assert - await expect( - controller.claimWithConfirmation({ - providerId: 'polymarket', - }), - ).rejects.toThrow('Network client not found for chain ID'); - }); + // Act & Assert + await expect( + controller.claimWithConfirmation({ + providerId: 'polymarket', + }), + ).rejects.toThrow('Network client not found for chain ID'); + }, + { + mocks: { + findNetworkClientIdByChainId: jest.fn().mockReturnValue(undefined), + }, + }, + ); }); it('throws error when transaction batch returns no batchId', async () => { // Arrange - await withController(async ({ controller }) => { - mockPolymarketProvider.getPositions = jest.fn().mockResolvedValue([ - { - marketId: 'test-market', - providerId: 'polymarket', - outcomeId: 'test-outcome', - balance: '100', - }, - ]); - mockPolymarketProvider.prepareClaim = jest - .fn() - .mockResolvedValue(mockClaim); - - const MockedEngine = jest.requireMock('../../../../core/Engine'); - MockedEngine.context.NetworkController.findNetworkClientIdByChainId = - jest.fn().mockReturnValue('mainnet'); + await withController( + async ({ controller }) => { + mockPolymarketProvider.getPositions = jest.fn().mockResolvedValue([ + { + marketId: 'test-market', + providerId: 'polymarket', + outcomeId: 'test-outcome', + balance: '100', + }, + ]); + mockPolymarketProvider.prepareClaim = jest + .fn() + .mockResolvedValue(mockClaim); - (addTransactionBatch as jest.Mock).mockResolvedValue({}); - await controller.getPositions({ claimable: true }); + (addTransactionBatch as jest.Mock).mockResolvedValue({}); + await controller.getPositions({ claimable: true }); - // Act & Assert - await expect( - controller.claimWithConfirmation({ - providerId: 'polymarket', - }), - ).rejects.toThrow( - 'Failed to get batch ID from claim transaction submission', - ); - }); + // Act & Assert + await expect( + controller.claimWithConfirmation({ + providerId: 'polymarket', + }), + ).rejects.toThrow( + 'Failed to get batch ID from claim transaction submission', + ); + }, + { + mocks: { + findNetworkClientIdByChainId: jest.fn().mockReturnValue('mainnet'), + }, + }, + ); }); it('throws error when prepareClaim returns no transactions', async () => { @@ -2667,42 +2821,45 @@ describe('PredictController', () => { it('clears error state on successful claim', async () => { // Arrange const mockBatchId = 'claim-batch-1'; - await withController(async ({ controller }) => { - controller.updateStateForTesting((state) => { - state.lastError = 'Previous error'; - }); - - mockPolymarketProvider.getPositions = jest.fn().mockResolvedValue([ - { - marketId: 'test-market', - providerId: 'polymarket', - outcomeId: 'test-outcome', - balance: '100', - }, - ]); - mockPolymarketProvider.prepareClaim = jest - .fn() - .mockResolvedValue(mockClaim); + await withController( + async ({ controller }) => { + controller.updateStateForTesting((state) => { + state.lastError = 'Previous error'; + }); - const MockedEngine = jest.requireMock('../../../../core/Engine'); - MockedEngine.context.NetworkController.findNetworkClientIdByChainId = - jest.fn().mockReturnValue('mainnet'); + mockPolymarketProvider.getPositions = jest.fn().mockResolvedValue([ + { + marketId: 'test-market', + providerId: 'polymarket', + outcomeId: 'test-outcome', + balance: '100', + }, + ]); + mockPolymarketProvider.prepareClaim = jest + .fn() + .mockResolvedValue(mockClaim); - (addTransactionBatch as jest.Mock).mockResolvedValue({ - batchId: mockBatchId, - }); - await controller.getPositions({ claimable: true }); + (addTransactionBatch as jest.Mock).mockResolvedValue({ + batchId: mockBatchId, + }); + await controller.getPositions({ claimable: true }); - // Act - const result = await controller.claimWithConfirmation({ - providerId: 'polymarket', - }); + // Act + const result = await controller.claimWithConfirmation({ + providerId: 'polymarket', + }); - // Assert - expect(result.batchId).toBe(mockBatchId); - expect(controller.state.lastError).toBeNull(); - expect(controller.state.lastUpdateTimestamp).toBeGreaterThan(0); - }); + // Assert + expect(result.batchId).toBe(mockBatchId); + expect(controller.state.lastError).toBeNull(); + expect(controller.state.lastUpdateTimestamp).toBeGreaterThan(0); + }, + { + mocks: { + findNetworkClientIdByChainId: jest.fn().mockReturnValue('mainnet'), + }, + }, + ); }); }); @@ -3009,37 +3166,6 @@ describe('PredictController', () => { batchId: mockBatchId, }); - Engine.context.AccountTrackerController.state.accountsByChainId = { - [mockChainId]: { - '0x1234567890123456789012345678901234567890': { - balance: '0x0', - }, - }, - }; - - Engine.context.TransactionController.estimateGas = jest - .fn() - .mockResolvedValue({ gas: '0x5208' }); - Engine.context.TransactionController.estimateGasFee = jest - .fn() - .mockResolvedValue({ - estimates: { - type: GasFeeEstimateType.FeeMarket, - [GasFeeEstimateLevel.Low]: { - maxFeePerGas: '0x3b9aca00', - maxPriorityFeePerGas: '0x1', - }, - [GasFeeEstimateLevel.Medium]: { - maxFeePerGas: '0x3b9aca00', - maxPriorityFeePerGas: '0x1', - }, - [GasFeeEstimateLevel.High]: { - maxFeePerGas: '0x3b9aca00', - maxPriorityFeePerGas: '0x1', - }, - }, - }); - await withController(async ({ controller }) => { // When calling depositWithConfirmation const result = await controller.depositWithConfirmation({ @@ -3064,7 +3190,7 @@ describe('PredictController', () => { expect(addTransactionBatch).toHaveBeenCalledWith({ from: '0x1234567890123456789012345678901234567890', origin: 'metamask', - networkClientId: 'mainnet', + networkClientId: 'polygon-mainnet', disableHook: true, disableSequential: true, disableUpgrade: true, @@ -3175,6 +3301,9 @@ describe('PredictController', () => { getNetworkState: jest.fn().mockReturnValue({ selectedNetworkClientId: mockNetworkClientId, }), + findNetworkClientIdByChainId: jest + .fn() + .mockReturnValue(mockNetworkClientId), }, }, ); @@ -4826,108 +4955,156 @@ describe('PredictController', () => { }); it('calls NetworkController.findNetworkClientIdByChainId with hex chain ID', async () => { - await withController(async ({ controller }) => { - // Arrange - const chainId = 137; + const mockFindNetworkClientIdByChainId = jest + .fn() + .mockReturnValue('polygon-mainnet'); + await withController( + async ({ controller }) => { + // Arrange + const chainId = 137; - // Act - // eslint-disable-next-line dot-notation - await controller['invalidateQueryCache'](chainId); + // Act + // eslint-disable-next-line dot-notation + await controller['invalidateQueryCache'](chainId); - // Assert - expect( - Engine.context.NetworkController.findNetworkClientIdByChainId, - ).toHaveBeenCalledWith('0x89'); - }); + // Assert + expect(mockFindNetworkClientIdByChainId).toHaveBeenCalledWith('0x89'); + }, + { + mocks: { + findNetworkClientIdByChainId: mockFindNetworkClientIdByChainId, + }, + }, + ); }); it('calls NetworkController.getNetworkClientById with network client ID', async () => { - await withController(async ({ controller }) => { - // Arrange - const chainId = 137; + const mockGetNetworkClientById = jest + .fn() + .mockReturnValue(DEFAULT_NETWORK_CLIENT); + await withController( + async ({ controller }) => { + // Arrange + const chainId = 137; - // Act - // eslint-disable-next-line dot-notation - await controller['invalidateQueryCache'](chainId); + // Act + // eslint-disable-next-line dot-notation + await controller['invalidateQueryCache'](chainId); - // Assert - expect( - Engine.context.NetworkController.getNetworkClientById, - ).toHaveBeenCalledWith('mainnet'); - }); + // Assert + expect(mockGetNetworkClientById).toHaveBeenCalledWith( + 'polygon-mainnet', + ); + }, + { + mocks: { + findNetworkClientIdByChainId: jest + .fn() + .mockReturnValue('polygon-mainnet'), + getNetworkClientById: mockGetNetworkClientById, + }, + }, + ); }); it('calls blockTracker.checkForLatestBlock to invalidate cache', async () => { const mockCheckForLatestBlock = jest.fn().mockResolvedValue(undefined); - // @ts-expect-error - Mocking Engine for test - Engine.context.NetworkController.getNetworkClientById.mockReturnValue({ + const mockGetNetworkClientById = jest.fn().mockReturnValue({ blockTracker: { checkForLatestBlock: mockCheckForLatestBlock, }, }); - await withController(async ({ controller }) => { - // Arrange - const chainId = 137; + await withController( + async ({ controller }) => { + // Arrange + const chainId = 137; - // Act - // eslint-disable-next-line dot-notation - await controller['invalidateQueryCache'](chainId); + // Act + // eslint-disable-next-line dot-notation + await controller['invalidateQueryCache'](chainId); - // Assert - expect(mockCheckForLatestBlock).toHaveBeenCalledWith(); - }); + // Assert + expect(mockCheckForLatestBlock).toHaveBeenCalledWith(); + }, + { + mocks: { + findNetworkClientIdByChainId: jest + .fn() + .mockReturnValue('polygon-mainnet'), + getNetworkClientById: mockGetNetworkClientById, + }, + }, + ); }); it('logs error when blockTracker.checkForLatestBlock fails', async () => { const mockError = new Error('Block tracker error'); const mockCheckForLatestBlock = jest.fn().mockRejectedValue(mockError); - // @ts-expect-error - Mocking Engine for test - Engine.context.NetworkController.getNetworkClientById.mockReturnValue({ + const mockGetNetworkClientById = jest.fn().mockReturnValue({ blockTracker: { checkForLatestBlock: mockCheckForLatestBlock, }, }); - await withController(async ({ controller }) => { - // Arrange - const chainId = 137; + await withController( + async ({ controller }) => { + // Arrange + const chainId = 137; - // Act - // eslint-disable-next-line dot-notation - await controller['invalidateQueryCache'](chainId); + // Act + // eslint-disable-next-line dot-notation + await controller['invalidateQueryCache'](chainId); - // Assert - expect(DevLogger.log).toHaveBeenCalledWith( - 'PredictController: Error invalidating query cache', - expect.objectContaining({ - error: 'Block tracker error', - timestamp: expect.any(String), - }), - ); - }); + // Assert + expect(DevLogger.log).toHaveBeenCalledWith( + 'PredictController: Error invalidating query cache', + expect.objectContaining({ + error: 'Block tracker error', + timestamp: expect.any(String), + }), + ); + }, + { + mocks: { + findNetworkClientIdByChainId: jest + .fn() + .mockReturnValue('polygon-mainnet'), + getNetworkClientById: mockGetNetworkClientById, + }, + }, + ); }); it('continues execution when invalidation fails', async () => { const mockError = new Error('Block tracker error'); const mockCheckForLatestBlock = jest.fn().mockRejectedValue(mockError); - // @ts-expect-error - Mocking Engine for test - Engine.context.NetworkController.getNetworkClientById.mockReturnValue({ + const mockGetNetworkClientById = jest.fn().mockReturnValue({ blockTracker: { checkForLatestBlock: mockCheckForLatestBlock, }, }); - await withController(async ({ controller }) => { - // Arrange - const chainId = 137; + await withController( + async ({ controller }) => { + // Arrange + const chainId = 137; - // Act & Assert - should not throw - await expect( - // eslint-disable-next-line dot-notation - controller['invalidateQueryCache'](chainId), - ).resolves.not.toThrow(); - }); + // Act & Assert - should not throw + await expect( + // eslint-disable-next-line dot-notation + controller['invalidateQueryCache'](chainId), + ).resolves.not.toThrow(); + }, + { + mocks: { + findNetworkClientIdByChainId: jest + .fn() + .mockReturnValue('polygon-mainnet'), + getNetworkClientById: mockGetNetworkClientById, + }, + }, + ); }); }); @@ -5371,24 +5548,33 @@ describe('PredictController', () => { it('calls invalidateQueryCache before fetching fresh balance', async () => { const mockCheckForLatestBlock = jest.fn().mockResolvedValue(undefined); - // @ts-expect-error - Mocking Engine for test - Engine.context.NetworkController.getNetworkClientById.mockReturnValue({ + const mockGetNetworkClientById = jest.fn().mockReturnValue({ blockTracker: { checkForLatestBlock: mockCheckForLatestBlock, }, }); mockPolymarketProvider.getBalance.mockResolvedValue(1000); - await withController(async ({ controller }) => { - // Act - await controller.getBalance({ - providerId: 'polymarket', - }); + await withController( + async ({ controller }) => { + // Act + await controller.getBalance({ + providerId: 'polymarket', + }); - // Assert - expect(mockCheckForLatestBlock).toHaveBeenCalled(); - expect(mockPolymarketProvider.getBalance).toHaveBeenCalled(); - }); + // Assert + expect(mockCheckForLatestBlock).toHaveBeenCalled(); + expect(mockPolymarketProvider.getBalance).toHaveBeenCalled(); + }, + { + mocks: { + findNetworkClientIdByChainId: jest + .fn() + .mockReturnValue('polygon-mainnet'), + getNetworkClientById: mockGetNetworkClientById, + }, + }, + ); }); it('fetches balance when validUntil equals current time', async () => { diff --git a/app/components/UI/Predict/controllers/PredictController.ts b/app/components/UI/Predict/controllers/PredictController.ts index 3ec84f8273d..4e9c5f23a3a 100644 --- a/app/components/UI/Predict/controllers/PredictController.ts +++ b/app/components/UI/Predict/controllers/PredictController.ts @@ -1,4 +1,6 @@ import { AccountsControllerGetSelectedAccountAction } from '@metamask/accounts-controller'; +import { AccountTreeControllerGetAccountsFromSelectedAccountGroupAction } from '@metamask/account-tree-controller'; +import { isEvmAccountType } from '@metamask/keyring-api'; import { BaseController, ControllerGetStateAction, @@ -11,8 +13,14 @@ import { PersonalMessageParams, SignTypedDataVersion, TypedMessageParams, + KeyringControllerSignTypedMessageAction, + KeyringControllerSignPersonalMessageAction, } from '@metamask/keyring-controller'; -import { NetworkControllerGetStateAction } from '@metamask/network-controller'; +import { + NetworkControllerGetStateAction, + NetworkControllerFindNetworkClientIdByChainIdAction, + NetworkControllerGetNetworkClientByIdAction, +} from '@metamask/network-controller'; import { TransactionControllerEstimateGasAction, TransactionControllerTransactionConfirmedEvent, @@ -22,11 +30,14 @@ import { TransactionMeta, TransactionType, } from '@metamask/transaction-controller'; +import { + RemoteFeatureFlagControllerGetStateAction, + RemoteFeatureFlagControllerStateChangeEvent, +} from '@metamask/remote-feature-flag-controller'; import { Hex, hexToNumber, numberToHex } from '@metamask/utils'; import performance from 'react-native-performance'; import { MetaMetrics, MetaMetricsEvents } from '../../../../core/Analytics'; import { MetricsEventBuilder } from '../../../../core/Analytics/MetricsEventBuilder'; -import Engine from '../../../../core/Engine'; import DevLogger from '../../../../core/SDKConnect/utils/DevLogger'; import Logger, { type LoggerErrorOptions } from '../../../../util/Logger'; import { @@ -82,7 +93,6 @@ import { } from '../types'; import { ensureError } from '../utils/predictErrorHandler'; import { PREDICT_CONSTANTS, PREDICT_ERROR_CODES } from '../constants/errors'; -import { getEvmAccountFromSelectedAccountGroup } from '../utils/accounts'; import { GEO_BLOCKED_COUNTRIES } from '../constants/geoblock'; import { MATIC_CONTRACTS } from '../providers/polymarket/constants'; import { @@ -228,8 +238,14 @@ export type PredictControllerActions = */ type AllowedActions = | AccountsControllerGetSelectedAccountAction + | AccountTreeControllerGetAccountsFromSelectedAccountGroupAction | NetworkControllerGetStateAction - | TransactionControllerEstimateGasAction; + | NetworkControllerFindNetworkClientIdByChainIdAction + | NetworkControllerGetNetworkClientByIdAction + | TransactionControllerEstimateGasAction + | KeyringControllerSignTypedMessageAction + | KeyringControllerSignPersonalMessageAction + | RemoteFeatureFlagControllerGetStateAction; /** * External events the PredictController can subscribe to @@ -238,7 +254,8 @@ type AllowedEvents = | TransactionControllerTransactionSubmittedEvent | TransactionControllerTransactionConfirmedEvent | TransactionControllerTransactionFailedEvent - | TransactionControllerTransactionRejectedEvent; + | TransactionControllerTransactionRejectedEvent + | RemoteFeatureFlagControllerStateChangeEvent; /** * PredictController messenger constraints @@ -407,27 +424,42 @@ export class PredictController extends BaseController< * @private */ private getSigner(address?: string): Signer { - const { KeyringController } = Engine.context; - const selectedAddress = - address ?? getEvmAccountFromSelectedAccountGroup()?.address ?? '0x0'; + const selectedAddress = address ?? this.getEvmAccountAddress(); return { address: selectedAddress, signTypedMessage: ( _params: TypedMessageParams, _version: SignTypedDataVersion, - ) => KeyringController.signTypedMessage(_params, _version), + ) => + this.messenger.call( + 'KeyringController:signTypedMessage', + _params, + _version, + ), signPersonalMessage: (_params: PersonalMessageParams) => - KeyringController.signPersonalMessage(_params), + this.messenger.call('KeyringController:signPersonalMessage', _params), }; } + private getEvmAccountAddress(): string { + const accounts = this.messenger.call( + 'AccountTreeController:getAccountsFromSelectedAccountGroup', + ); + const evmAccount = accounts.find( + (account) => account && isEvmAccountType(account.type), + ); + return evmAccount?.address ?? '0x0'; + } + private async invalidateQueryCache(chainId: number) { - const { NetworkController } = Engine.context; - const networkClientId = NetworkController.findNetworkClientIdByChainId( + const networkClientId = this.messenger.call( + 'NetworkController:findNetworkClientIdByChainId', numberToHex(chainId), ); - const networkClient = - NetworkController.getNetworkClientById(networkClientId); + const networkClient = this.messenger.call( + 'NetworkController:getNetworkClientById', + networkClientId, + ); try { await networkClient.blockTracker.checkForLatestBlock(); } catch (error) { @@ -474,9 +506,11 @@ export class PredictController extends BaseController< throw new Error('Provider not available'); } - const { RemoteFeatureFlagController } = Engine.context; + const remoteFeatureFlagState = this.messenger.call( + 'RemoteFeatureFlagController:getState', + ); const liveSportsFlag = - (RemoteFeatureFlagController.state.remoteFeatureFlags + (remoteFeatureFlagState.remoteFeatureFlags .predictLiveSports as unknown as PredictLiveSportsFlag | undefined) ?? DEFAULT_LIVE_SPORTS_FLAG; const liveSportsLeagues = liveSportsFlag.enabled @@ -484,7 +518,7 @@ export class PredictController extends BaseController< : []; const marketHighlightsFlag = - (RemoteFeatureFlagController.state.remoteFeatureFlags + (remoteFeatureFlagState.remoteFeatureFlags .predictMarketHighlights as unknown as | PredictMarketHighlightsFlag | undefined) ?? DEFAULT_MARKET_HIGHLIGHTS_FLAG; @@ -619,9 +653,11 @@ export class PredictController extends BaseController< throw new Error('Provider not available'); } - const { RemoteFeatureFlagController } = Engine.context; + const remoteFeatureFlagState = this.messenger.call( + 'RemoteFeatureFlagController:getState', + ); const liveSportsFlag = - (RemoteFeatureFlagController.state.remoteFeatureFlags + (remoteFeatureFlagState.remoteFeatureFlags .predictLiveSports as unknown as PredictLiveSportsFlag | undefined) ?? DEFAULT_LIVE_SPORTS_FLAG; const liveSportsLeagues = liveSportsFlag.enabled @@ -1427,9 +1463,11 @@ export class PredictController extends BaseController< throw new Error('Provider not available'); } - const { RemoteFeatureFlagController } = Engine.context; + const remoteFeatureFlagState = this.messenger.call( + 'RemoteFeatureFlagController:getState', + ); const feeCollection = - (RemoteFeatureFlagController.state.remoteFeatureFlags + (remoteFeatureFlagState.remoteFeatureFlags .predictFeeCollection as unknown as | PredictFeeCollection | undefined) ?? DEFAULT_FEE_COLLECTION_FLAG; @@ -1683,8 +1721,8 @@ export class PredictController extends BaseController< } // Find network client - can fail if chain is not supported - const { NetworkController } = Engine.context; - const networkClientId = NetworkController.findNetworkClientIdByChainId( + const networkClientId = this.messenger.call( + 'NetworkController:findNetworkClientIdByChainId', numberToHex(chainId), ); @@ -1977,9 +2015,10 @@ export class PredictController extends BaseController< providerId: params.providerId, }); - const { NetworkController } = Engine.context; - const networkClientId = - NetworkController.findNetworkClientIdByChainId(chainId); + const networkClientId = this.messenger.call( + 'NetworkController:findNetworkClientIdByChainId', + chainId, + ); if (!networkClientId) { throw new Error(`Network client not found for chain ID: ${chainId}`); @@ -2222,13 +2261,13 @@ export class PredictController extends BaseController< }; }); - const { NetworkController } = Engine.context; - const { batchId } = await addTransactionBatch({ from: signer.address as Hex, origin: ORIGIN_METAMASK, - networkClientId: - NetworkController.findNetworkClientIdByChainId(chainId), + networkClientId: this.messenger.call( + 'NetworkController:findNetworkClientIdByChainId', + chainId, + ), disableHook: true, disableSequential: true, requireApproval: true, @@ -2324,8 +2363,8 @@ export class PredictController extends BaseController< const chainId = this.state.withdrawTransaction.chainId; - const { NetworkController } = Engine.context; - const networkClientId = NetworkController.findNetworkClientIdByChainId( + const networkClientId = this.messenger.call( + 'NetworkController:findNetworkClientIdByChainId', numberToHex(chainId), ); diff --git a/app/components/UI/Ramp/hooks/useAnalytics.ts b/app/components/UI/Ramp/hooks/useAnalytics.ts index 90de87141f6..b5e35c48969 100644 --- a/app/components/UI/Ramp/hooks/useAnalytics.ts +++ b/app/components/UI/Ramp/hooks/useAnalytics.ts @@ -5,6 +5,7 @@ import { AnalyticsEvents as DepositEvents } from '../Deposit/types'; import { MetaMetricsEvents } from '../../../../core/Analytics'; import { analytics } from '../../../../util/analytics/analytics'; import { AnalyticsEventBuilder } from '../../../../util/analytics/AnalyticsEventBuilder'; +import type { AnalyticsUnfilteredProperties } from '../../../../util/analytics/analytics.types'; interface MergedRampEvents extends AggregatorEvents, DepositEvents {} @@ -14,7 +15,7 @@ export function trackEvent( ) { analytics.trackEvent( AnalyticsEventBuilder.createEventBuilder(MetaMetricsEvents[eventType]) - .addProperties({ ...params }) + .addProperties(params as AnalyticsUnfilteredProperties) .build(), ); } diff --git a/app/components/UI/Ramp/index.test.tsx b/app/components/UI/Ramp/index.test.tsx index b54e56368b8..ee26683de5c 100644 --- a/app/components/UI/Ramp/index.test.tsx +++ b/app/components/UI/Ramp/index.test.tsx @@ -14,6 +14,11 @@ import getAggregatorAnalyticsPayload from './Aggregator/utils/getAggregatorAnaly const mockNavigate = jest.fn(); +jest.mock('./hooks/useHydrateRampsController', () => ({ + __esModule: true, + default: jest.fn(), +})); + jest.mock('@react-navigation/native', () => { const actual = jest.requireActual('@react-navigation/native'); return { diff --git a/app/components/UI/Ramp/index.tsx b/app/components/UI/Ramp/index.tsx index fc175659349..bd0f817a3dd 100644 --- a/app/components/UI/Ramp/index.tsx +++ b/app/components/UI/Ramp/index.tsx @@ -34,6 +34,7 @@ import Routes from '../../../constants/navigation/Routes'; import getOrderAnalyticsPayload from './utils/getOrderAnalyticsPayload'; import { NativeRampsSdk } from '@consensys/native-ramps-sdk'; import useDetectGeolocation from './hooks/useDetectGeolocation'; +import useHydrateRampsController from './hooks/useHydrateRampsController'; import useRampsSmartRouting from './hooks/useRampsSmartRouting'; const POLLING_FREQUENCY = AppConstants.FIAT_ORDERS.POLLING_FREQUENCY; @@ -117,6 +118,7 @@ const styles = StyleSheet.create({ }); function FiatOrders() { + useHydrateRampsController(); useFetchRampNetworks(); useDetectGeolocation(); useRampsSmartRouting(); diff --git a/app/components/UI/Rewards/Views/RewardsDashboard.test.tsx b/app/components/UI/Rewards/Views/RewardsDashboard.test.tsx index 3a4f45d6965..5142d41c985 100644 --- a/app/components/UI/Rewards/Views/RewardsDashboard.test.tsx +++ b/app/components/UI/Rewards/Views/RewardsDashboard.test.tsx @@ -53,6 +53,10 @@ jest.mock('../../../../selectors/rewards', () => ({ selectRewardsSubscriptionId: jest.fn(), })); +jest.mock('../../../../selectors/featureFlagController/rewards', () => ({ + selectSnapshotsRewardsEnabledFlag: jest.fn(), +})); + jest.mock( '../../../../selectors/multichainAccounts/accountTreeController', () => ({ @@ -69,6 +73,7 @@ import { } from '../../../../reducers/rewards/selectors'; import { selectRewardsSubscriptionId } from '../../../../selectors/rewards'; import { selectSelectedAccountGroup } from '../../../../selectors/multichainAccounts/accountTreeController'; +import { selectSnapshotsRewardsEnabledFlag } from '../../../../selectors/featureFlagController/rewards'; const mockSelectActiveTab = selectActiveTab as jest.MockedFunction< typeof selectActiveTab @@ -95,6 +100,10 @@ const mockSelectSelectedAccountGroup = selectSelectedAccountGroup as jest.MockedFunction< typeof selectSelectedAccountGroup >; +const mockSelectSnapshotsRewardsEnabledFlag = + selectSnapshotsRewardsEnabledFlag as jest.MockedFunction< + typeof selectSnapshotsRewardsEnabledFlag + >; // Mock theme jest.mock('../../../../util/theme', () => ({ @@ -168,7 +177,7 @@ jest.mock('../../../../../locales/i18n', () => ({ const translations: Record = { 'rewards.main_title': 'Rewards', 'rewards.tab_overview_title': 'Overview', - 'rewards.tab_levels_title': 'Levels', + 'rewards.tab_snapshots_title': 'Snapshots', 'rewards.tab_activity_title': 'Activity', 'rewards.not_implemented': 'Not implemented yet', }; @@ -235,16 +244,16 @@ jest.mock('../components/Tabs/RewardsOverview', () => ({ }, })); -jest.mock('../components/Tabs/RewardsLevels', () => ({ +jest.mock('../components/Tabs/RewardsSnapshots', () => ({ __esModule: true, - default: function MockRewardsLevels({ tabLabel }: { tabLabel: string }) { + default: function MockRewardsSnapshots({ tabLabel }: { tabLabel: string }) { const ReactActual = jest.requireActual('react'); const { View, Text } = jest.requireActual('react-native'); return ReactActual.createElement( View, - { testID: 'rewards-levels-tab' }, - ReactActual.createElement(Text, null, tabLabel || 'Levels'), + { testID: 'rewards-snapshots-tab' }, + ReactActual.createElement(Text, null, tabLabel || 'Snapshots'), ); }, })); @@ -359,6 +368,10 @@ jest.mock('../hooks/useRewardDashboardModals', () => ({ useRewardDashboardModals: jest.fn(), })); +jest.mock('../hooks/useBulkLinkState', () => ({ + useBulkLinkState: jest.fn(), +})); + jest.mock('../utils', () => ({ convertInternalAccountToCaipAccountId: jest.fn(), })); @@ -544,6 +557,7 @@ jest.spyOn(Alert, 'alert').mockImplementation(mockAlert); import { useRewardOptinSummary } from '../hooks/useRewardOptinSummary'; import { useLinkAccountGroup } from '../hooks/useLinkAccountGroup'; import { useRewardDashboardModals } from '../hooks/useRewardDashboardModals'; +import { useBulkLinkState } from '../hooks/useBulkLinkState'; import { convertInternalAccountToCaipAccountId } from '../utils'; import { InternalAccount } from '@metamask/keyring-internal-api'; import { AccountGroupType, AccountWalletType } from '@metamask/account-api'; @@ -558,6 +572,9 @@ const mockUseRewardDashboardModals = useRewardDashboardModals as jest.MockedFunction< typeof useRewardDashboardModals >; +const mockUseBulkLinkState = useBulkLinkState as jest.MockedFunction< + typeof useBulkLinkState +>; const mockConvertInternalAccountToCaipAccountId = convertInternalAccountToCaipAccountId as jest.MockedFunction< typeof convertInternalAccountToCaipAccountId @@ -571,6 +588,10 @@ describe('RewardsDashboard', () => { const mockShowNotSupportedModal = jest.fn(); const mockHasShownModal = jest.fn(); const mockResetSessionTracking = jest.fn(); + const mockResumeBulkLink = jest.fn(); + const mockStartBulkLink = jest.fn(); + const mockCancelBulkLink = jest.fn(); + const mockResetBulkLink = jest.fn(); const mockSelectedAccount = { id: 'account-1', @@ -610,6 +631,7 @@ describe('RewardsDashboard', () => { hideCurrentAccountNotOptedInBannerArray: [], selectedAccount: mockSelectedAccount, selectedAccountGroup: mockSelectedAccountGroup, + isSnapshotsEnabled: true, // Enable snapshots by default in tests }; const defaultHookValues = { @@ -637,6 +659,22 @@ describe('RewardsDashboard', () => { resetSessionTrackingForCurrentAccountGroup: jest.fn(), resetAllSessionTracking: jest.fn(), }, + useBulkLinkState: { + startBulkLink: mockStartBulkLink, + cancelBulkLink: mockCancelBulkLink, + resetBulkLink: mockResetBulkLink, + resumeBulkLink: mockResumeBulkLink, + isRunning: false, + wasInterrupted: false, + isCompleted: false, + hasFailures: false, + isFullySuccessful: false, + totalAccounts: 0, + linkedAccounts: 0, + failedAccounts: 0, + accountProgress: 0, + processedAccounts: 0, + }, }; beforeEach(() => { @@ -648,6 +686,10 @@ describe('RewardsDashboard', () => { mockShowNotSupportedModal.mockClear(); mockHasShownModal.mockClear(); mockResetSessionTracking.mockClear(); + mockResumeBulkLink.mockClear(); + mockStartBulkLink.mockClear(); + mockCancelBulkLink.mockClear(); + mockResetBulkLink.mockClear(); mockTrackEvent.mockClear(); mockCreateEventBuilder.mockClear(); mockBuild.mockClear(); @@ -682,6 +724,9 @@ describe('RewardsDashboard', () => { mockSelectSelectedAccountGroup.mockReturnValue( defaultSelectorValues.selectedAccountGroup, ); + mockSelectSnapshotsRewardsEnabledFlag.mockReturnValue( + defaultSelectorValues.isSnapshotsEnabled, + ); // Setup hook mocks mockUseRewardOptinSummary.mockReturnValue( @@ -693,6 +738,7 @@ describe('RewardsDashboard', () => { mockUseRewardDashboardModals.mockReturnValue( defaultHookValues.useRewardDashboardModals, ); + mockUseBulkLinkState.mockReturnValue(defaultHookValues.useBulkLinkState); mockConvertInternalAccountToCaipAccountId.mockReturnValue('eip155:1:0x123'); // Setup default modal hook behavior - return false for all modal types by default @@ -711,6 +757,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); }); @@ -764,6 +812,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -796,6 +846,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -828,6 +880,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -858,6 +912,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -900,11 +956,11 @@ describe('RewardsDashboard', () => { it('should handle tab change when user selects different tab', () => { // Act const { getByTestId } = render(); - const levelsTab = getByTestId('tab-1'); - fireEvent.press(levelsTab); + const snapshotsTab = getByTestId('tab-1'); + fireEvent.press(snapshotsTab); // Assert - expect(mockDispatch).toHaveBeenCalledWith(setActiveTab('levels')); + expect(mockDispatch).toHaveBeenCalledWith(setActiveTab('snapshots')); }); it('should render all tab options', () => { @@ -926,14 +982,14 @@ describe('RewardsDashboard', () => { expect(getByTestId('rewards-overview-tab')).toBeTruthy(); }); - it('should switch to levels tab when levels tab is pressed', () => { + it('switches to snapshots tab when snapshots tab is pressed', () => { // Act const { getByTestId } = render(); - const levelsTab = getByTestId('tab-1'); - fireEvent.press(levelsTab); + const snapshotsTab = getByTestId('tab-1'); + fireEvent.press(snapshotsTab); // Assert - expect(getByTestId('rewards-levels-tab')).toBeTruthy(); + expect(getByTestId('rewards-snapshots-tab')).toBeTruthy(); }); it('should switch to activity tab when activity tab is pressed', () => { @@ -946,7 +1002,7 @@ describe('RewardsDashboard', () => { expect(getByTestId('rewards-activity-tab')).toBeTruthy(); }); - it('should not allow tab switching when user is not opted in', () => { + it('allows tab switching when user is not opted in', () => { // Arrange const futureDateObj = new Date(futureDate); mockSelectRewardsSubscriptionId.mockReturnValue(null); @@ -957,16 +1013,145 @@ describe('RewardsDashboard', () => { if (selector === selectRewardsSubscriptionId) return null; if (selector === selectSeasonId) return currentSeasonId; if (selector === selectSeasonEndDate) return futureDateObj; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); // Act const { getByTestId } = render(); - const levelsTab = getByTestId('tab-1'); - fireEvent.press(levelsTab); + const snapshotsTab = getByTestId('tab-1'); + fireEvent.press(snapshotsTab); - // Assert - should show levels tab (tab change occurred) - expect(getByTestId('rewards-levels-tab')).toBeTruthy(); + // Assert - tab change occurred + expect(getByTestId('rewards-snapshots-tab')).toBeTruthy(); + }); + }); + + describe('tabComponents when isSnapshotsEnabled is false', () => { + beforeEach(() => { + mockSelectSnapshotsRewardsEnabledFlag.mockReturnValue(false); + mockUseSelector.mockImplementation((selector) => { + if (selector === selectActiveTab) + return defaultSelectorValues.activeTab; + if (selector === selectRewardsSubscriptionId) + return defaultSelectorValues.subscriptionId; + if (selector === selectSeasonId) return currentSeasonId; + if (selector === selectSeasonEndDate) + return defaultSelectorValues.seasonEndDate; + if (selector === selectHideUnlinkedAccountsBanner) + return defaultSelectorValues.hideUnlinkedAccountsBanner; + if (selector === selectHideCurrentAccountNotOptedInBannerArray) + return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; + if (selector === selectSelectedAccountGroup) + return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) return false; + return undefined; + }); + }); + + it('renders only overview and activity tabs when snapshots is disabled', () => { + // Act + const { getByTestId, queryByTestId } = render(); + + // Assert - verify only 2 tabs are rendered + expect(getByTestId('tab-headers')).toBeTruthy(); + expect(getByTestId('tab-0')).toBeTruthy(); + expect(getByTestId('tab-1')).toBeTruthy(); + expect(queryByTestId('tab-2')).toBeNull(); + }); + + it('does not render snapshots tab when snapshots is disabled', () => { + // Act + const { queryByTestId } = render(); + + // Assert - snapshots tab should not be visible by default + expect(queryByTestId('rewards-snapshots-tab')).toBeNull(); + }); + + it('renders overview tab as first tab when snapshots is disabled', () => { + // Act + const { getByTestId } = render(); + + // Assert + expect(getByTestId('rewards-overview-tab')).toBeTruthy(); + }); + + it('switches directly to activity tab at index 1 when snapshots is disabled', () => { + // Act + const { getByTestId } = render(); + const activityTab = getByTestId('tab-1'); + fireEvent.press(activityTab); + + // Assert - activity tab is now at index 1 instead of index 2 + expect(getByTestId('rewards-activity-tab')).toBeTruthy(); + }); + + it('dispatches setActiveTab with activity when tab-1 is pressed and snapshots is disabled', () => { + // Act + const { getByTestId } = render(); + const activityTab = getByTestId('tab-1'); + fireEvent.press(activityTab); + + // Assert - tab-1 should now be activity, not snapshots + expect(mockDispatch).toHaveBeenCalledWith(setActiveTab('activity')); + }); + + it('resets activeTab to overview when snapshots tab becomes unavailable', () => { + // Arrange - activeTab is 'snapshots' but isSnapshotsEnabled is false + mockSelectActiveTab.mockReturnValue('snapshots'); + mockSelectSnapshotsRewardsEnabledFlag.mockReturnValue(false); + mockUseSelector.mockImplementation((selector) => { + if (selector === selectActiveTab) return 'snapshots'; + if (selector === selectRewardsSubscriptionId) + return defaultSelectorValues.subscriptionId; + if (selector === selectSeasonId) return currentSeasonId; + if (selector === selectSeasonEndDate) + return defaultSelectorValues.seasonEndDate; + if (selector === selectHideUnlinkedAccountsBanner) + return defaultSelectorValues.hideUnlinkedAccountsBanner; + if (selector === selectHideCurrentAccountNotOptedInBannerArray) + return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; + if (selector === selectSelectedAccountGroup) + return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) return false; + return undefined; + }); + + // Act + render(); + + // Assert - should dispatch setActiveTab('overview') to reset the invalid tab + expect(mockDispatch).toHaveBeenCalledWith(setActiveTab('overview')); + }); + + it('does not reset activeTab when current tab is still available', () => { + // Arrange - activeTab is 'activity' and isSnapshotsEnabled is false + // activity tab should still be available + mockSelectActiveTab.mockReturnValue('activity'); + mockSelectSnapshotsRewardsEnabledFlag.mockReturnValue(false); + mockUseSelector.mockImplementation((selector) => { + if (selector === selectActiveTab) return 'activity'; + if (selector === selectRewardsSubscriptionId) + return defaultSelectorValues.subscriptionId; + if (selector === selectSeasonId) return currentSeasonId; + if (selector === selectSeasonEndDate) + return defaultSelectorValues.seasonEndDate; + if (selector === selectHideUnlinkedAccountsBanner) + return defaultSelectorValues.hideUnlinkedAccountsBanner; + if (selector === selectHideCurrentAccountNotOptedInBannerArray) + return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; + if (selector === selectSelectedAccountGroup) + return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) return false; + return undefined; + }); + + // Act + render(); + + // Assert - should NOT dispatch setActiveTab since 'activity' is still valid + expect(mockDispatch).not.toHaveBeenCalledWith(setActiveTab('overview')); }); }); @@ -989,6 +1174,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -1020,6 +1207,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -1051,6 +1240,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -1079,6 +1270,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -1192,6 +1385,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -1251,6 +1446,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -1286,6 +1483,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -1393,6 +1592,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -1467,6 +1668,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -1497,6 +1700,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -1577,6 +1782,8 @@ describe('RewardsDashboard', () => { return [{ accountGroupId: 'keyring:wallet1/1', hide: true }]; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -1607,6 +1814,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -1686,6 +1895,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -1758,6 +1969,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -1926,6 +2139,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -1987,6 +2202,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -2048,6 +2265,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); @@ -2136,7 +2355,7 @@ describe('RewardsDashboard', () => { expect(mockCreateEventBuilder).not.toHaveBeenCalled(); }); - it('should track tab viewed event when activeTab changes', () => { + it('tracks tab viewed event when activeTab changes', () => { // Arrange const { rerender } = render(); mockTrackEvent.mockClear(); @@ -2144,9 +2363,9 @@ describe('RewardsDashboard', () => { mockBuild.mockClear(); // Act - change active tab - mockSelectActiveTab.mockReturnValue('levels'); + mockSelectActiveTab.mockReturnValue('snapshots'); mockUseSelector.mockImplementation((selector) => { - if (selector === selectActiveTab) return 'levels'; + if (selector === selectActiveTab) return 'snapshots'; if (selector === selectRewardsSubscriptionId) return defaultSelectorValues.subscriptionId; if (selector === selectSeasonId) return currentSeasonId; @@ -2158,6 +2377,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); rerender(); @@ -2166,12 +2387,12 @@ describe('RewardsDashboard', () => { expect(mockCreateEventBuilder).toHaveBeenCalledWith( 'rewards_dashboard_tab_viewed', ); - expect(mockAddProperties).toHaveBeenCalledWith({ tab: 'levels' }); + expect(mockAddProperties).toHaveBeenCalledWith({ tab: 'snapshots' }); expect(mockBuild).toHaveBeenCalled(); expect(mockTrackEvent).toHaveBeenCalledWith({ event: 'mock-event' }); }); - it('should track tab viewed event for each tab change', () => { + it('tracks tab viewed event for each tab change', () => { // Arrange const { rerender } = render(); mockTrackEvent.mockClear(); @@ -2179,10 +2400,10 @@ describe('RewardsDashboard', () => { mockBuild.mockClear(); mockAddProperties.mockClear(); - // Act - change to levels tab - mockSelectActiveTab.mockReturnValue('levels'); + // Act - change to snapshots tab + mockSelectActiveTab.mockReturnValue('snapshots'); mockUseSelector.mockImplementation((selector) => { - if (selector === selectActiveTab) return 'levels'; + if (selector === selectActiveTab) return 'snapshots'; if (selector === selectRewardsSubscriptionId) return defaultSelectorValues.subscriptionId; if (selector === selectSeasonId) return currentSeasonId; @@ -2194,12 +2415,14 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); rerender(); - // Assert - levels tab - expect(mockAddProperties).toHaveBeenCalledWith({ tab: 'levels' }); + // Assert - snapshots tab + expect(mockAddProperties).toHaveBeenCalledWith({ tab: 'snapshots' }); // Act - change to activity tab mockSelectActiveTab.mockReturnValue('activity'); @@ -2216,6 +2439,8 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); rerender(); @@ -2226,15 +2451,15 @@ describe('RewardsDashboard', () => { }); describe('TabsList ref functionality', () => { - it('should handle Redux state changes for activeTab without crashing', () => { + it('handles Redux state changes for activeTab without crashing', () => { // Arrange mockSelectActiveTab.mockReturnValue('overview'); const { rerender } = render(); - // Act - change activeTab in Redux to levels - mockSelectActiveTab.mockReturnValue('levels'); + // Act - change activeTab in Redux to snapshots + mockSelectActiveTab.mockReturnValue('snapshots'); mockUseSelector.mockImplementation((selector) => { - if (selector === selectActiveTab) return 'levels'; + if (selector === selectActiveTab) return 'snapshots'; if (selector === selectRewardsSubscriptionId) return defaultSelectorValues.subscriptionId; if (selector === selectSeasonId) return currentSeasonId; @@ -2246,21 +2471,23 @@ describe('RewardsDashboard', () => { return defaultSelectorValues.hideCurrentAccountNotOptedInBannerArray; if (selector === selectSelectedAccountGroup) return defaultSelectorValues.selectedAccountGroup; + if (selector === selectSnapshotsRewardsEnabledFlag) + return defaultSelectorValues.isSnapshotsEnabled; return undefined; }); - // Assert - should not crash when activeTab changes + // Assert - does not crash when activeTab changes expect(() => rerender()).not.toThrow(); }); }); describe('component lifecycle', () => { - it('should render without crashing', () => { + it('renders without crashing', () => { // Act & Assert expect(() => render()).not.toThrow(); }); - it('should cleanup properly when unmounted', () => { + it('cleans up properly when unmounted', () => { // Act const { unmount } = render(); @@ -2268,4 +2495,81 @@ describe('RewardsDashboard', () => { expect(() => unmount()).not.toThrow(); }); }); + + describe('bulk link auto-resume', () => { + it('calls resumeBulkLink when wasInterrupted is true and isRunning is false', () => { + // Arrange + mockUseBulkLinkState.mockReturnValue({ + ...defaultHookValues.useBulkLinkState, + wasInterrupted: true, + isRunning: false, + }); + + // Act + render(); + + // Assert + expect(mockResumeBulkLink).toHaveBeenCalled(); + }); + + it('does not call resumeBulkLink when wasInterrupted is false', () => { + // Arrange + mockUseBulkLinkState.mockReturnValue({ + ...defaultHookValues.useBulkLinkState, + wasInterrupted: false, + isRunning: false, + }); + + // Act + render(); + + // Assert + expect(mockResumeBulkLink).not.toHaveBeenCalled(); + }); + + it('does not call resumeBulkLink when isRunning is true', () => { + // Arrange + mockUseBulkLinkState.mockReturnValue({ + ...defaultHookValues.useBulkLinkState, + wasInterrupted: true, + isRunning: true, + }); + + // Act + render(); + + // Assert + expect(mockResumeBulkLink).not.toHaveBeenCalled(); + }); + + it('does not call resumeBulkLink when both wasInterrupted and isRunning are false', () => { + // Arrange + mockUseBulkLinkState.mockReturnValue({ + ...defaultHookValues.useBulkLinkState, + wasInterrupted: false, + isRunning: false, + }); + + // Act + render(); + + // Assert + expect(mockResumeBulkLink).not.toHaveBeenCalled(); + }); + + it('does not call resumeBulkLink when both wasInterrupted and isRunning are true', () => { + // Arrange + mockUseBulkLinkState.mockReturnValue({ + ...defaultHookValues.useBulkLinkState, + wasInterrupted: true, + isRunning: true, + }); + + // Act + render(); + + // Assert + expect(mockResumeBulkLink).not.toHaveBeenCalled(); + }); + }); }); diff --git a/app/components/UI/Rewards/Views/RewardsDashboard.tsx b/app/components/UI/Rewards/Views/RewardsDashboard.tsx index 47d242b53cc..d359ba230e8 100644 --- a/app/components/UI/Rewards/Views/RewardsDashboard.tsx +++ b/app/components/UI/Rewards/Views/RewardsDashboard.tsx @@ -34,13 +34,15 @@ import { } from '../../../../reducers/rewards/selectors'; import SeasonStatus from '../components/SeasonStatus/SeasonStatus'; import { selectRewardsSubscriptionId } from '../../../../selectors/rewards'; +import { selectSnapshotsRewardsEnabledFlag } from '../../../../selectors/featureFlagController/rewards'; import { useRewardOptinSummary } from '../hooks/useRewardOptinSummary'; import { useRewardDashboardModals, RewardsDashboardModalType, } from '../hooks/useRewardDashboardModals'; +import { useBulkLinkState } from '../hooks/useBulkLinkState'; import RewardsOverview from '../components/Tabs/RewardsOverview'; -import RewardsLevels from '../components/Tabs/RewardsLevels'; +import RewardsSnapshots from '../components/Tabs/RewardsSnapshots'; import RewardsActivity from '../components/Tabs/RewardsActivity'; import { TabsList } from '../../../../component-library/components-temp/Tabs'; import { TabsListRef } from '../../../../component-library/components-temp/Tabs/TabsList/TabsList.types'; @@ -65,6 +67,7 @@ const RewardsDashboard: React.FC = () => { ); const seasonId = useSelector(selectSeasonId); const seasonEndDate = useSelector(selectSeasonEndDate); + const isSnapshotsEnabled = useSelector(selectSnapshotsRewardsEnabledFlag); const hideCurrentAccountNotOptedInBannerMap = useSelector( selectHideCurrentAccountNotOptedInBannerArray, ); @@ -100,6 +103,9 @@ const RewardsDashboard: React.FC = () => { currentAccountGroupOptedInStatus, } = useRewardOptinSummary(); + // Use the bulk link state hook for resuming interrupted opt-in processes + const { wasInterrupted, isRunning, resumeBulkLink } = useBulkLinkState(); + const totalOptedInAccountsSelectedGroup = useMemo( () => optInBySelectedAccountGroup?.optedInAccounts?.length, [optInBySelectedAccountGroup], @@ -131,29 +137,48 @@ const RewardsDashboard: React.FC = () => { ); }, [colors, navigation]); - const tabOptions = useMemo( - () => [ + const tabOptions = useMemo(() => { + const options: { + value: 'overview' | 'snapshots' | 'activity'; + label: string; + }[] = [ { value: 'overview' as const, label: strings('rewards.tab_overview_title'), }, - { - value: 'levels' as const, - label: strings('rewards.tab_levels_title'), - }, - { - value: 'activity' as const, - label: strings('rewards.tab_activity_title'), - }, - ], - [], - ); + ]; + + if (isSnapshotsEnabled) { + options.push({ + value: 'snapshots' as const, + label: strings('rewards.tab_snapshots_title'), + }); + } + + options.push({ + value: 'activity' as const, + label: strings('rewards.tab_activity_title'), + }); + + return options; + }, [isSnapshotsEnabled]); const getActiveIndex = useCallback( () => tabOptions.findIndex((tab) => tab.value === activeTab), [tabOptions, activeTab], ); + // Reset activeTab to 'overview' if current tab becomes unavailable (e.g., snapshots disabled) + // This ensures Redux state stays in sync with the visible tab and analytics events are accurate + useEffect(() => { + const isCurrentTabAvailable = tabOptions.some( + (tab) => tab.value === activeTab, + ); + if (!isCurrentTabAvailable) { + dispatch(setActiveTab('overview')); + } + }, [tabOptions, activeTab, dispatch]); + // Sync TabsList with Redux state changes useEffect(() => { const activeIndex = tabOptions.findIndex((tab) => tab.value === activeTab); @@ -190,10 +215,48 @@ const RewardsDashboard: React.FC = () => { [getActiveIndex, handleTabChange], ); + const tabComponents = useMemo(() => { + const tabs: React.ReactElement[] = [ + , + ]; + + if (isSnapshotsEnabled) { + tabs.push( + , + ); + } + + tabs.push( + , + ); + + return tabs; + }, [isSnapshotsEnabled]); + const [showPreviousSeasonSummary, setShowPreviousSeasonSummary] = useState< boolean | null >(null); + // Auto-resume interrupted bulk link process when screen comes into focus. + // This handles the case where the app was closed during a bulk opt-in process. + // The saga is idempotent - it re-fetches opt-in status to skip already-linked accounts. + useFocusEffect( + useCallback(() => { + if (wasInterrupted && !isRunning) { + resumeBulkLink(); + } + }, [wasInterrupted, isRunning, resumeBulkLink]), + ); + // Evaluate showPreviousSeasonSummary when screen comes into focus useFocusEffect( useCallback(() => { @@ -328,23 +391,12 @@ const RewardsDashboard: React.FC = () => { ) : ( <> - + + + {/* Tab View */} - - - - - + {tabComponents} )} diff --git a/app/components/UI/Rewards/Views/RewardsView.constants.ts b/app/components/UI/Rewards/Views/RewardsView.constants.ts index f104c8c4a8a..51a21845453 100644 --- a/app/components/UI/Rewards/Views/RewardsView.constants.ts +++ b/app/components/UI/Rewards/Views/RewardsView.constants.ts @@ -61,4 +61,10 @@ export const REWARDS_VIEW_SELECTORS = { ACTIVITY_EVENT_ROW_DETAILS: 'activity-event-row-details', ACTIVITY_EVENT_ROW_DATE: 'activity-event-row-date', ACTIVITY_EVENT_ROW_BONUS: 'activity-event-row-bonus', + // Snapshots + TAB_CONTENT_SNAPSHOTS: 'rewards-view-tab-content-snapshots', + SNAPSHOTS_SECTION: 'rewards-view-snapshots-section', + SNAPSHOTS_ACTIVE_SECTION: 'rewards-view-snapshots-active-section', + SNAPSHOTS_UPCOMING_SECTION: 'rewards-view-snapshots-upcoming-section', + SNAPSHOTS_PREVIOUS_SECTION: 'rewards-view-snapshots-previous-section', } as const; diff --git a/app/components/UI/Rewards/components/Onboarding/OnboardingNoActiveSeasonStep.tsx b/app/components/UI/Rewards/components/Onboarding/OnboardingNoActiveSeasonStep.tsx index b8f13910c0c..cf7314b30bc 100644 --- a/app/components/UI/Rewards/components/Onboarding/OnboardingNoActiveSeasonStep.tsx +++ b/app/components/UI/Rewards/components/Onboarding/OnboardingNoActiveSeasonStep.tsx @@ -1,9 +1,10 @@ -import React, { useCallback } from 'react'; +import React, { useCallback, useState } from 'react'; import { Image, useWindowDimensions } from 'react-native'; import { useSelector } from 'react-redux'; import { useTailwind } from '@metamask/design-system-twrnc-preset'; import { useOptin } from '../../hooks/useOptIn'; import { Box, Text, TextVariant } from '@metamask/design-system-react-native'; +import Checkbox from '../../../../../component-library/components/Checkbox'; import step1Img from '../../../../../images/rewards/rewards-onboarding-step1.png'; import Step1BgImg from '../../../../../images/rewards/rewards-onboarding-step1-bg.svg'; import { strings } from '../../../../../../locales/i18n'; @@ -35,39 +36,62 @@ const OnboardingNoActiveSeasonStep: React.FC< const navigation = useNavigation(); const { width: screenWidth, height: screenHeight } = useWindowDimensions(); const { optin, optinError, optinLoading } = useOptin(); + const [bulkLink, setBulkLink] = useState(false); + + const handleBulkLinkToggle = useCallback(() => { + setBulkLink((prev) => !prev); + }, []); const handleNext = useCallback(() => { if (!canContinue()) { return; } - optin({}); - }, [optin, canContinue]); + optin({ bulkLink }); + }, [optin, canContinue, bulkLink]); const renderStepInfo = () => ( - - {/* Opt in error message */} - {optinError && ( - - )} + <> + + {/* Opt in error message */} + {optinError && ( + + )} - {/* Title and Description */} - - - {strings('rewards.onboarding.no_active_season.title')} - - - - {strings('rewards.onboarding.no_active_season.description')} - + {/* Title and Description */} + + + {strings('rewards.onboarding.no_active_season.title')} + + + + {strings('rewards.onboarding.no_active_season.description')} + + + + {/* Opt-in all accounts checkbox */} + + + {strings('rewards.onboarding.step4_bulk_link_checkbox')} + + } + /> - + ); const renderLegalDisclaimer = () => ( diff --git a/app/components/UI/Rewards/components/Onboarding/OnboardingStep.tsx b/app/components/UI/Rewards/components/Onboarding/OnboardingStep.tsx index da36b3870f4..5bb15817359 100644 --- a/app/components/UI/Rewards/components/Onboarding/OnboardingStep.tsx +++ b/app/components/UI/Rewards/components/Onboarding/OnboardingStep.tsx @@ -140,7 +140,9 @@ const OnboardingStepComponent: React.FC = ({ - {renderStepInfo()} + + {renderStepInfo()} +