diff --git a/.github/workflows/build_and_test_maxtext.yml b/.github/workflows/build_and_test_maxtext.yml index bf36401e50..44fa0dff46 100644 --- a/.github/workflows/build_and_test_maxtext.yml +++ b/.github/workflows/build_and_test_maxtext.yml @@ -218,7 +218,7 @@ jobs: base_image: maxtext-unit-test-tpu:py312 cloud_runner: linux-x86-ct6e-180-4tpu pytest_marker: 'not cpu_only and not gpu_only and integration_test and not post_training' - pytest_addopts: '--ignore=tests/post_training' + pytest_addopts: '--ignore=tests/post_training --ignore=tests/integration/hlo_diff_test.py' xla_python_client_mem_fraction: 0.75 tf_force_gpu_allow_growth: false container_resource_option: "--privileged" diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index 7817e2d678..a10803992c 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -70,6 +70,10 @@ on: description: 'If false, maxtext_sha must be provided for checkout' type: boolean default: false + is_update_hlo: + required: false + type: boolean + default: false permissions: contents: read @@ -167,13 +171,19 @@ jobs: else SPLIT_ARGS="" fi - $PYTHON_EXE -m pytest ${INPUTS_PYTEST_ADDOPTS} \ - -v \ - -m "${FINAL_PYTEST_MARKER}" \ - --durations=0 \ - $PYTEST_COV_ARGS \ - $SPLIT_ARGS \ - ${INPUTS_PYTEST_EXTRA_ARGS} + + # Setup substitution: If manually updating HLO, skip tests execution and run only the update script instead! + if [ "${INPUTS_IS_UPDATE_HLO}" == "true" ]; then + python3 tests/utils/update_hlo_references.py + else + $PYTHON_EXE -m pytest ${INPUTS_PYTEST_ADDOPTS} \ + -v \ + -m "${FINAL_PYTEST_MARKER}" \ + --durations=0 \ + $PYTEST_COV_ARGS \ + $SPLIT_ARGS \ + ${INPUTS_PYTEST_EXTRA_ARGS} + fi env: PYTHONPATH: "${{ github.workspace }}/src" @@ -185,6 +195,14 @@ jobs: INPUTS_WORKER_GROUP: ${{ inputs.worker_group }} INPUTS_PYTEST_EXTRA_ARGS: ${{ inputs.pytest_extra_args }} INPUTS_MAXTEXT_INSTALLED: ${{ inputs.maxtext_installed }} + INPUTS_IS_UPDATE_HLO: ${{ inputs.is_update_hlo }} + - name: Upload Reference HLO + if: ${{ inputs.is_update_hlo }} + uses: actions/upload-artifact@v4 + with: + name: reference-hlo + path: tests/utils/reference_hlo_*.txt + if-no-files-found: ignore - name: Upload results to Codecov if: ${{ !inputs.maxtext_installed }} # Skip code coverage upload for maxtext image testing uses: codecov/codecov-action@v5 diff --git a/.github/workflows/run_tests_coordinator.yml b/.github/workflows/run_tests_coordinator.yml index 163c0d6763..c229a49c07 100644 --- a/.github/workflows/run_tests_coordinator.yml +++ b/.github/workflows/run_tests_coordinator.yml @@ -47,6 +47,10 @@ on: description: 'If false, maxtext_sha must be provided for checkout' type: boolean default: false + is_update_hlo: + required: false + type: boolean + default: false permissions: contents: read @@ -150,3 +154,4 @@ jobs: worker_group: ${{ matrix.worker_group }} total_workers: ${{ contains(inputs.flavor, 'cpu-unit') && 2 || 1 }} maxtext_sha: ${{ inputs.maxtext_sha }} + is_update_hlo: ${{ inputs.is_update_hlo }} diff --git a/.github/workflows/update_reference_hlo.yml b/.github/workflows/update_reference_hlo.yml new file mode 100644 index 0000000000..d9a473ba42 --- /dev/null +++ b/.github/workflows/update_reference_hlo.yml @@ -0,0 +1,49 @@ +name: "Update HLO References (for hlo_diff_test.py)" + +on: + workflow_dispatch: +permissions: + contents: read + +jobs: + build-wheel: + uses: ./.github/workflows/build_package.yml + with: + device_type: tpu + device_name: v6e-4 + cloud_runner: linux-x86-n2-16-buildkit + + run-tests: + needs: build-wheel + uses: ./.github/workflows/run_tests_coordinator.yml + with: + flavor: tpu-integration + base_image: maxtext-unit-test-tpu:py312 + is_scheduled_run: false + maxtext_sha: ${{ github.sha }} + is_update_hlo: true + + commit-changes: + needs: run-tests # Wait for tests to finish + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ github.ref }} + + - name: Download Reference HLO + uses: actions/download-artifact@v4 + with: + name: reference-hlo + path: tests/utils/ + + - name: Commit and Push changes + run: | + git config --global user.name "github-actions[bot]" + git config --global user.email "github-actions[bot]@users.noreply.github.com" + git add tests/utils/reference_hlo_*.txt + git commit -m "Update reference HLO from CI artifact" + git push diff --git a/docs/development/hlo_diff_testing.md b/docs/development/hlo_diff_testing.md new file mode 100644 index 0000000000..f486f3e6de --- /dev/null +++ b/docs/development/hlo_diff_testing.md @@ -0,0 +1,90 @@ +# HLO Graph Diff Verification Testing + +This document provides context for the HLO Graph Diff tests, what HLO is, and how to manage reference baselines. + +## Related Files + +- **Test Logic**: `tests/integration/hlo_diff_test.py` +- **Reference Checkpoints baselines**: `tests/utils/reference_hlo_*.txt` +- **Update Helper script**: `tests/utils/update_hlo_references.py` +- **GitHub Action Trigger Workflow**: `.github/workflows/update_reference_hlo.yml` + +## What is HLO? + +**HLO (High-Level Optimizer)** is the intermediate representation used by XLA (Accelerated Linear Algebra) to capture the lowering compiler graph structures. + +An HLO module records: + +- The sequences of low-level math operations (dot products, convolutions, additions). +- Array tensor shapes and numerical precisions. +- Multipod TPU cluster partitioning array sharding mappings. + +## Purpose of HloDiffTest + +The primary purpose of the `TestHloDiff` validation checks is to ensure that **refactoring PRs are purely refactoring code** and not unintentionally impacting graph compiler lowering or performance. + +- **For pure refactors:** The HLO graph layout should remain *strictly identical*. Any detected deviation flags that execution boundaries or operation pipelines might have changed under the hood. +- **For dependency updates:** Changes to framework dependencies (like updating JAX or XLA versions) *are expected* to slightly alter compiled HLO output layouts, which makes baseline updates appropriate in those scenarios. + +______________________________________________________________________ + +## How the Test Works + +This test runs automatically as part of the [`tpu-integration`](https://github.com/AI-Hypercomputer/maxtext/actions/workflows/build_and_test_maxtext.yml) CI test suite on every Pull Request. + +When the test method executes, it performs the following sequence of actions: + +1. **Triggers Compilation**: It runs the model training lifecycle compilation-only phase (invoking `train_compile.main()`) without actually allocating hardware compute nodes or running optimization passes. +2. **Dumps HLO modules**: Instructs the XLA compiler back-end to capture optimizer operations lowering structure graphs and dump them to text files. +3. **Strict comparison matches**: Compares the structural lines of the generated representation graph directly against baseline `.txt` copies stored under `tests/utils/`. + +______________________________________________________________________ + +## Updating HLO reference files + +When intended architectures transformations alter graph lowering, reference file baselines require updates. + +> [!IMPORTANT]\ +> While running the update script locally is not the end of the world, **relying on local execution can cause remote CI tests to fail.** +> The PR verification pipelines run the tests in a strictly locked GitHub Actions environment. The smallest discrepancies in local library installations will introduce slight backend lowering graph deviations. If your local execution leads to a remote CI check failure, rely on the GitHub Action trigger described below to generate environment-matching baselines. + +### Method 1: Run the manual GitHub Action Workflow (Highly Recommended) + +Triggering the CI workflow guarantees execution runs within the correct environment isolation scope. + +#### Option A: Using the GitHub UI + +1. Go to the Actions tab in the repository browser. +2. Find the manual workflow: `Update HLO References (for hlo_diff_test.py)`. +3. Run it targeting your PR workspace branch. It compiles the graph layout and commits the baseline update files back to the branch automatically. + +#### Option B: Using the GitHub CLI (`gh`) + +Alternatively, you can trigger the remote workflow via terminal CLI execution: + +```bash +gh workflow run update_reference_hlo.yml --ref +``` + +> [!NOTE] +> A successful run of the manual update workflow will add a new commit to your Pull Request branch. Once complete, you must: +> +> 1. Pull the new commit from remote. +> 2. Squash the commits in your branch once again to keep your PR history clean. +> 3. Push the squashed commit to remote. +> 4. Retry the `tpu-integration` workflow to verify tests pass on your PR. + +### Method 2: Local Execution + +If you need to test or update baselines manually during development: + +```bash +source .venv/bin/activate +pytest tests/integration/hlo_diff_test.py -v +``` + +Or to force update the local baselines: + +```bash +python3 tests/utils/update_hlo_references.py +``` diff --git a/tests/integration/hlo_diff_test.py b/tests/integration/hlo_diff_test.py new file mode 100644 index 0000000000..5cbbc2c1b5 --- /dev/null +++ b/tests/integration/hlo_diff_test.py @@ -0,0 +1,188 @@ +# Copyright 2026 Google LLC +"""Tests for HLO Graph Diff Verification. + +Integrates validation checks system that detects unintended compiler graph transformations +or model graph deviations without breaking isolated PR constraints validations. + +This is part 1 of the automated compiler validation pipeline: +- Ensures the PR isn't changing compiler performance and model graph generation unintentionally. +- Checks current runtime dumps against base reference checkpoints (in tests/utils/reference_hlo_*.txt). +""" + +import os +import unittest +import shutil +import jax +import pytest +from maxtext.trainers.pre_train import train_compile +from absl import flags +import difflib +import re + +pytestmark = [pytest.mark.integration_test] + +# Maximum number of filtered lines from HLO to compare or store in the reference file +MAX_LINES = 2000 + +# Matches operation sections content like: 1234 {"sharding parameters", ...} or 1234 "operation metadata" +SECTION_CONTENT_REGEX = re.compile(r'^\d+ (?:"[^"]*"|\{[^\}]*\})$') + + +# Filters and normalizes an HLO line to ignore environment-specific details. +def filter_line(line): + # Matches operation sections content like: 1234 {"sharding parameters", ...} or 1234 "operation metadata" + if SECTION_CONTENT_REGEX.match(line): + return None + return line + + +@pytest.mark.tpu_backend +class TestHloDiff: + """Tests for HLO Graph Diff Verification.""" + + def setup_method(self): + # Disable cache to ensure compilation occurs every time + jax.config.update("jax_enable_compilation_cache", False) + + def check_files_equal_ignoring_sections(self, file_path1, file_path2): + """Asserts that two text files are identical, ignoring specific sections.""" + + def get_filtered_lines(file_path): + with open(file_path, "r", encoding="utf-8") as f: + lines = [] + for line in f: + line = filter_line(line) + if line is None: + continue + lines.append(line) + return lines + + lines1 = get_filtered_lines(file_path1)[:MAX_LINES] + lines2 = get_filtered_lines(file_path2)[:MAX_LINES] + + if lines1 == lines2: + return True + + print("\n" + "=" * 20 + " HLO Diff " + "=" * 20) + for line in difflib.unified_diff(lines1, lines2, fromfile=file_path1, tofile=file_path2): + print(line, end="") + print("=" * 50 + "\n") + return False + + @pytest.mark.parametrize( + "test_id, config_file, overrides", + [ + ( + "deepseek3", + "src/maxtext/configs/models/deepseek3-test.yml", + { + "compile_topology": "v6e-4", + "base_num_decoder_layers": 4, + "per_device_batch_size": 1, + "max_target_length": 128, + }, + ), + ( + "llama3_8b", + "src/maxtext/configs/models/llama3-8b.yml", + { + "compile_topology": "v6e-4", + "base_num_decoder_layers": 4, + "per_device_batch_size": 1, + "max_target_length": 128, + }, + ), + ( + "qwen3_1.7b", + "src/maxtext/configs/models/qwen3-1.7b.yml", + { + "compile_topology": "v6e-4", + "base_num_decoder_layers": 4, + "per_device_batch_size": 1, + "max_target_length": 128, + }, + ), + ], + ) + def test_hlo_diff(self, test_id, config_file, overrides): + """Test HLO diff for parameterized configurations.""" + local_landing_dir = os.path.join(os.path.dirname(__file__), f"hlo_diff_dump_{test_id}") + + # Clean up before run + if os.path.exists(local_landing_dir): + shutil.rmtree(local_landing_dir) + os.makedirs(local_landing_dir, exist_ok=True) + + try: + base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) + config_path = os.path.join(base_dir, config_file) + + # Arguments for train_compile + test_args = [ + None, + config_path, + "dataset_type=synthetic", + "override_model_config=true", + "compile_topology_num_slices=1", + ] + + for key, value in overrides.items(): + test_args.append(f"{key}={value}") + + test_args.append( + f'compile_xla_flags="--xla_dump_to={local_landing_dir} ' + f'--xla_dump_hlo_as_text=True --xla_dump_hlo_module_re=jit_train_step"' + ) + + print(f"Running train_compile for {test_id} with args:", test_args) + + # Initialize flags if not already done to avoid ParseCommandLine error + if not flags.FLAGS.is_parsed(): + flags.FLAGS(["dummy_program", "--grain_train_files=dummy"], known_only=True) + + train_compile.main(tuple(test_args)) + + print(f"Files in landing dir for {test_id}:", os.listdir(local_landing_dir)) + + # Locate the dumped HLO file + files = os.listdir(local_landing_dir) + matches = [f for f in files if f.endswith(".after_optimizations.txt")] + dumped_hlo = matches[0] if matches else None + + assert dumped_hlo, f"Dumped HLO file not found for {test_id}!" + + dumped_hlo_path = os.path.join(local_landing_dir, dumped_hlo) + + reference_hlo_path = os.path.join(os.path.dirname(__file__), f"../utils/reference_hlo_{test_id}.txt") + + if not os.path.exists(reference_hlo_path): + print(f"Reference file not found. Creating it at {reference_hlo_path}") + with ( + open(dumped_hlo_path, "r", encoding="utf-8") as f_in, + open(reference_hlo_path, "w", encoding="utf-8") as f_out, + ): + count = 0 + for line in f_in: + line = filter_line(line) + if line is None: + continue + f_out.write(line) + count += 1 + if count >= MAX_LINES: + break + print(f"Reference file created for {test_id}. Please commit it.") + else: + assert self.check_files_equal_ignoring_sections(dumped_hlo_path, reference_hlo_path), ( + f"HLO deviation detected in {test_id}! If this is intended, please run the 'Update HLO References' workflow " + f"from Github Actions against your branch to update the reference HLO. " + f"For more details, see docs/development/hlo_diff_testing.md." + ) + print(f"HLO comparison successful against {reference_hlo_path}") + + finally: + if os.path.exists(local_landing_dir): + shutil.rmtree(local_landing_dir) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/reference_hlo_deepseek3.txt b/tests/utils/reference_hlo_deepseek3.txt new file mode 100644 index 0000000000..a1dd390424 --- /dev/null +++ b/tests/utils/reference_hlo_deepseek3.txt @@ -0,0 +1,2000 @@ +HloModule jit_train_step, is_scheduled=true, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias), {4}: (4, {}, may-alias), {5}: (5, {}, may-alias), {6}: (6, {}, may-alias), {7}: (7, {}, may-alias), {8}: (8, {}, may-alias), {9}: (9, {}, may-alias), {10}: (10, {}, may-alias), {11}: (11, {}, may-alias), {12}: (12, {}, may-alias), {13}: (13, {}, may-alias), {14}: (14, {}, may-alias), {15}: (15, {}, may-alias), {16}: (16, {}, may-alias), {17}: (17, {}, may-alias), {18}: (18, {}, may-alias), {19}: (19, {}, may-alias), {20}: (20, {}, may-alias), {21}: (21, {}, may-alias), {22}: (22, {}, may-alias), {23}: (23, {}, may-alias), {24}: (24, {}, may-alias), {25}: (25, {}, may-alias), {26}: (26, {}, may-alias), {27}: (27, {}, may-alias), {28}: (28, {}, may-alias), {29}: (29, {}, may-alias), {30}: (30, {}, may-alias), {31}: (31, {}, may-alias), {32}: (32, {}, may-alias), {33}: (33, {}, may-alias), {34}: (34, {}, may-alias), {35}: (35, {}, may-alias), {36}: (36, {}, may-alias), {37}: (37, {}, may-alias), {38}: (38, {}, may-alias), {39}: (39, {}, may-alias), {40}: (40, {}, may-alias), {41}: (41, {}, may-alias), {42}: (42, {}, may-alias), {43}: (43, {}, may-alias), {44}: (44, {}, may-alias), {45}: (45, {}, may-alias), {46}: (46, {}, may-alias), {47}: (47, {}, may-alias), {48}: (48, {}, may-alias), {49}: (49, {}, may-alias), {50}: (50, {}, may-alias), {51}: (51, {}, may-alias), {52}: (52, {}, may-alias), {53}: (53, {}, may-alias), {54}: (54, {}, may-alias), {55}: (55, {}, may-alias), {56}: (56, {}, may-alias), {57}: (57, {}, may-alias), {58}: (58, {}, may-alias), {59}: (59, {}, may-alias), {60}: (60, {}, may-alias), {61}: (61, {}, may-alias), {62}: (62, {}, may-alias), {63}: (63, {}, may-alias), {64}: (64, {}, may-alias), {65}: (65, {}, may-alias), {66}: (66, {}, may-alias), {67}: (67, {}, may-alias), {68}: (68, {}, may-alias), {69}: (69, {}, may-alias), {70}: (70, {}, may-alias), {71}: (71, {}, may-alias), {72}: (72, {}, may-alias), {73}: (73, {}, may-alias), {74}: (74, {}, may-alias), {75}: (75, {}, may-alias), {76}: (76, {}, may-alias), {77}: (77, {}, may-alias), {78}: (78, {}, may-alias), {79}: (79, {}, may-alias), {80}: (80, {}, may-alias), {81}: (81, {}, may-alias), {82}: (82, {}, may-alias), {83}: (83, {}, may-alias), {84}: (84, {}, may-alias), {85}: (85, {}, may-alias), {86}: (86, {}, may-alias), {87}: (87, {}, may-alias), {88}: (88, {}, may-alias), {89}: (89, {}, may-alias), {90}: (90, {}, may-alias), {91}: (91, {}, may-alias), {92}: (92, {}, may-alias), {93}: (93, {}, may-alias), {94}: (94, {}, may-alias), {95}: (95, {}, may-alias), {96}: (96, {}, may-alias), {97}: (97, {}, may-alias), {98}: (98, {}, may-alias) }, entry_computation_layout={(s32[]{:T(128)}, f32[512]{0:T(512)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[18432,3,512]{2,0,1:T(8,128)}, /*index=5*/f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, f32[128,3,128,512]{3,2,1,0:T(8,128)}, f32[1536,3]{0,1:T(4,128)}, /*index=10*/f32[512,3,576]{0,2,1:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,1536]{2,0,1:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, /*index=15*/f32[256,1]{0,1:T(1,128)}, f32[512,1,256]{2,1,0:T(1,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, /*index=20*/f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[512,1]{0,1:T(1,128)}, /*index=25*/f32[512,1]{0,1:T(1,128)}, f32[128,1,128,512]{3,2,1,0:T(8,128)}, f32[1536,1]{0,1:T(1,128)}, f32[512,1,576]{0,2,1:T(8,128)}, f32[512,1,128,256]{3,2,1,0:T(8,128)}, /*index=30*/f32[512,1,1536]{2,1,0:T(1,128)}, f32[1536,1,128,192]{2,3,1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, s32[]{:T(128)}, f32[512]{0:T(512)}, /*index=35*/f32[512,3,18432]{2,0,1:T(8,128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[18432,3,512]{2,0,1:T(8,128)}, f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, /*index=40*/f32[512,3]{0,1:T(4,128)}, f32[128,3,128,512]{3,2,1,0:T(8,128)}, f32[1536,3]{0,1:T(4,128)}, f32[512,3,576]{0,2,1:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, /*index=45*/f32[512,3,1536]{2,0,1:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[256,1]{0,1:T(1,128)}, f32[512,1,256]{2,1,0:T(1,128)}, /*index=50*/f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, /*index=55*/f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[128,1,128,512]{3,2,1,0:T(8,128)}, /*index=60*/f32[1536,1]{0,1:T(1,128)}, f32[512,1,576]{0,2,1:T(8,128)}, f32[512,1,128,256]{3,2,1,0:T(8,128)}, f32[512,1,1536]{2,1,0:T(1,128)}, f32[1536,1,128,192]{2,3,1,0:T(8,128)}, /*index=65*/f32[129280,512]{1,0:T(8,128)}, f32[512]{0:T(512)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[18432,3,512]{2,0,1:T(8,128)}, /*index=70*/f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, f32[128,3,128,512]{3,2,1,0:T(8,128)}, f32[1536,3]{0,1:T(4,128)}, /*index=75*/f32[512,3,576]{0,2,1:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,1536]{2,0,1:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, /*index=80*/f32[256,1]{0,1:T(1,128)}, f32[512,1,256]{2,1,0:T(1,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, /*index=85*/f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[512,1]{0,1:T(1,128)}, /*index=90*/f32[512,1]{0,1:T(1,128)}, f32[128,1,128,512]{3,2,1,0:T(8,128)}, f32[1536,1]{0,1:T(1,128)}, f32[512,1,576]{0,2,1:T(8,128)}, f32[512,1,128,256]{3,2,1,0:T(8,128)}, /*index=95*/f32[512,1,1536]{2,1,0:T(1,128)}, f32[1536,1,128,192]{2,3,1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, s32[]{:T(128)}, s32[4,128]{1,0:T(4,128)}, /*index=100*/s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)})->(s32[]{:T(128)}, f32[512]{0:T(512)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[18432,3,512]{2,0,1:T(8,128)}, /*index=5*/f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, f32[128,3,128,512]{3,2,1,0:T(8,128)}, f32[1536,3]{0,1:T(4,128)}, /*index=10*/f32[512,3,576]{0,2,1:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,1536]{2,0,1:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, /*index=15*/f32[256,1]{0,1:T(1,128)}, f32[512,1,256]{2,1,0:T(1,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, /*index=20*/f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[512,1]{0,1:T(1,128)}, /*index=25*/f32[512,1]{0,1:T(1,128)}, f32[128,1,128,512]{3,2,1,0:T(8,128)}, f32[1536,1]{0,1:T(1,128)}, f32[512,1,576]{0,2,1:T(8,128)}, f32[512,1,128,256]{3,2,1,0:T(8,128)}, /*index=30*/f32[512,1,1536]{2,1,0:T(1,128)}, f32[1536,1,128,192]{2,3,1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, s32[]{:T(128)}, f32[512]{0:T(512)}, /*index=35*/f32[512,3,18432]{2,0,1:T(8,128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[18432,3,512]{2,0,1:T(8,128)}, f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, /*index=40*/f32[512,3]{0,1:T(4,128)}, f32[128,3,128,512]{3,2,1,0:T(8,128)}, f32[1536,3]{0,1:T(4,128)}, f32[512,3,576]{0,2,1:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, /*index=45*/f32[512,3,1536]{2,0,1:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[256,1]{0,1:T(1,128)}, f32[512,1,256]{2,1,0:T(1,128)}, /*index=50*/f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, /*index=55*/f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[128,1,128,512]{3,2,1,0:T(8,128)}, /*index=60*/f32[1536,1]{0,1:T(1,128)}, f32[512,1,576]{0,2,1:T(8,128)}, f32[512,1,128,256]{3,2,1,0:T(8,128)}, f32[512,1,1536]{2,1,0:T(1,128)}, f32[1536,1,128,192]{2,3,1,0:T(8,128)}, /*index=65*/f32[129280,512]{1,0:T(8,128)}, f32[512]{0:T(512)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[18432,3,512]{2,0,1:T(8,128)}, /*index=70*/f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, f32[128,3,128,512]{3,2,1,0:T(8,128)}, f32[1536,3]{0,1:T(4,128)}, /*index=75*/f32[512,3,576]{0,2,1:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,1536]{2,0,1:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, /*index=80*/f32[256,1]{0,1:T(1,128)}, f32[512,1,256]{2,1,0:T(1,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, /*index=85*/f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[512,1]{0,1:T(1,128)}, /*index=90*/f32[512,1]{0,1:T(1,128)}, f32[128,1,128,512]{3,2,1,0:T(8,128)}, f32[1536,1]{0,1:T(1,128)}, f32[512,1,576]{0,2,1:T(8,128)}, f32[512,1,128,256]{3,2,1,0:T(8,128)}, /*index=95*/f32[512,1,1536]{2,1,0:T(1,128)}, f32[1536,1,128,192]{2,3,1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, s32[]{:T(128)}, f32[]{:T(128)}, /*index=100*/f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, /*index=105*/f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, s32[]{:T(128)}, f32[]{:T(128)})}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false}, allow_spmd_sharding_propagation_to_output={false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,true,true,true,true,true,true,true,true,true,true,true}, num_partitions=4 + +FileNames + +FunctionNames + +FileLocations + +StackFrames + + +%region_46.56 (top_k.25: bf16[], top_k.26: bf16[], top_k.27: s32[], top_k.28: s32[]) -> pred[] { + %constant.1536 = s32[]{:T(128)} constant(0) + %constant.1537 = s32[]{:T(128)} constant(2147483647) + %top_k.25 = bf16[]{:T(256)} parameter(0), metadata={op_name="top_k"} + %top_k.26 = bf16[]{:T(256)} parameter(1), metadata={op_name="top_k"} + %top_k.27 = s32[]{:T(128)} parameter(2), metadata={op_name="top_k"} + %top_k.28 = s32[]{:T(128)} parameter(3), metadata={op_name="top_k"} + %convert.393 = f32[]{:T(128)S(6)} convert(%top_k.25), metadata={op_name="convert.18"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %bitcast-convert.39 = s32[]{:T(128)S(6)} bitcast-convert(%convert.393), metadata={op_name="bitcast-convert.8"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.144 = pred[]{:T(512)S(6)} compare(%bitcast-convert.39, %constant.1536), direction=LT, metadata={op_name="compare.38"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.40 = s32[]{:T(128)S(6)} xor(%constant.1537, %bitcast-convert.39), metadata={op_name="xor.8"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %select.127 = s32[]{:T(128)S(6)} select(%compare.144, %xor.40, %bitcast-convert.39), metadata={op_name="select.16"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} + %convert.394 = f32[]{:T(128)S(6)} convert(%top_k.26), metadata={op_name="convert.19"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %bitcast-convert.40 = s32[]{:T(128)S(6)} bitcast-convert(%convert.394), metadata={op_name="bitcast-convert.9"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.145 = pred[]{:T(512)S(6)} compare(%bitcast-convert.40, %constant.1536), direction=LT, metadata={op_name="compare.39"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.41 = s32[]{:T(128)S(6)} xor(%constant.1537, %bitcast-convert.40), metadata={op_name="xor.9"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %select.128 = s32[]{:T(128)S(6)} select(%compare.145, %xor.41, %bitcast-convert.40), metadata={op_name="select.17"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} + %compare.146 = pred[]{:T(512)S(6)} compare(%select.127, %select.128), direction=GT, metadata={op_name="compare.0"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.147 = pred[]{:T(512)S(6)} compare(%select.128, %select.127), direction=GT, metadata={op_name="compare.117"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.148 = pred[]{:T(512)S(6)} compare(%compare.146, %compare.147), direction=EQ, metadata={op_name="compare.118"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.149 = pred[]{:T(512)S(6)} compare(%top_k.27, %top_k.28), direction=LT, metadata={op_name="compare.119"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %select.129 = pred[]{:T(512)} select(%compare.148, %compare.149, %compare.146), metadata={op_name="select.113"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_47.57 (sort.64: s32[], sort.65: s32[], sort.66: s32[], sort.67: s32[]) -> pred[] { + %sort.64 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(argsort)/sort"} + %sort.65 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(argsort)/sort"} + %sort.66 = s32[]{:T(128)} parameter(2), metadata={op_name="jit(argsort)/sort"} + %sort.67 = s32[]{:T(128)} parameter(3), metadata={op_name="jit(argsort)/sort"} + %lt_to.32 = pred[]{:T(512)S(6)} compare(%sort.64, %sort.65), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %lt_to.33 = pred[]{:T(512)S(6)} compare(%sort.65, %sort.64), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.150 = pred[]{:T(512)S(6)} compare(%lt_to.32, %lt_to.33), direction=EQ, metadata={op_name="compare.120"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.151 = pred[]{:T(512)S(6)} compare(%sort.66, %sort.67), direction=LT, metadata={op_name="compare.121"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %select.130 = pred[]{:T(512)} select(%compare.150, %compare.151, %lt_to.32), metadata={op_name="select.114"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_48.58 (sort.68: s32[], sort.69: s32[], sort.70: s32[], sort.71: s32[]) -> pred[] { + %sort.68 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(argsort)/sort"} + %sort.69 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(argsort)/sort"} + %sort.70 = s32[]{:T(128)} parameter(2), metadata={op_name="jit(argsort)/sort"} + %sort.71 = s32[]{:T(128)} parameter(3), metadata={op_name="jit(argsort)/sort"} + %lt_to.34 = pred[]{:T(512)S(6)} compare(%sort.68, %sort.69), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %lt_to.35 = pred[]{:T(512)S(6)} compare(%sort.69, %sort.68), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.152 = pred[]{:T(512)S(6)} compare(%lt_to.34, %lt_to.35), direction=EQ, metadata={op_name="compare.122"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.153 = pred[]{:T(512)S(6)} compare(%sort.70, %sort.71), direction=LT, metadata={op_name="compare.123"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %select.131 = pred[]{:T(512)} select(%compare.152, %compare.153, %lt_to.34), metadata={op_name="select.115"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_67.80 (sort.78: s32[], sort.79: s32[], sort.80: s32[], sort.81: s32[]) -> pred[] { + %sort.78 = s32[]{:T(128)} parameter(0), metadata={op_name="sort_activations/jit(argsort)/sort"} + %sort.79 = s32[]{:T(128)} parameter(1), metadata={op_name="sort_activations/jit(argsort)/sort"} + %sort.80 = s32[]{:T(128)} parameter(2), metadata={op_name="sort_activations/jit(argsort)/sort"} + %sort.81 = s32[]{:T(128)} parameter(3), metadata={op_name="sort_activations/jit(argsort)/sort"} + %lt_to.37 = pred[]{:T(512)S(6)} compare(%sort.78, %sort.79), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %lt_to.38 = pred[]{:T(512)S(6)} compare(%sort.79, %sort.78), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.156 = pred[]{:T(512)S(6)} compare(%lt_to.37, %lt_to.38), direction=EQ, metadata={op_name="compare.124"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.157 = pred[]{:T(512)S(6)} compare(%sort.80, %sort.81), direction=LT, metadata={op_name="compare.125"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %select.134 = pred[]{:T(512)} select(%compare.156, %compare.157, %lt_to.37), metadata={op_name="select.116"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_119.141 (reduce_sum.157: bf16[], reduce_sum.158: bf16[]) -> bf16[] { + %reduce_sum.157 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/reduce_sum"} + %reduce_sum.158 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/reduce_sum"} + ROOT %reduce_sum.159 = bf16[]{:T(256)} add(%reduce_sum.157, %reduce_sum.158), metadata={op_name="checkpoint/moe_layers/reduce_sum" stack_frame_id=1244}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_107.126 (psum.6: bf16[], psum.9: bf16[]) -> bf16[] { + %psum.6 = bf16[]{:T(256)} parameter(0), metadata={op_name="psum"} + %psum.9 = bf16[]{:T(256)} parameter(1), metadata={op_name="psum"} + ROOT %add.1407 = bf16[]{:T(256)} add(%psum.6, %psum.9), metadata={op_name="add" stack_frame_id=1207}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_108.127 (psum.10: bf16[], psum.11: bf16[]) -> bf16[] { + %psum.10 = bf16[]{:T(256)} parameter(0), metadata={op_name="psum"} + %psum.11 = bf16[]{:T(256)} parameter(1), metadata={op_name="psum"} + ROOT %add.1408 = bf16[]{:T(256)} add(%psum.10, %psum.11), metadata={op_name="add" stack_frame_id=1207}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_109.128 (psum.14: bf16[], psum.15: bf16[]) -> bf16[] { + %psum.14 = bf16[]{:T(256)} parameter(0), metadata={op_name="psum"} + %psum.15 = bf16[]{:T(256)} parameter(1), metadata={op_name="psum"} + ROOT %add.1409 = bf16[]{:T(256)} add(%psum.14, %psum.15), metadata={op_name="add" stack_frame_id=1207}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_62.73 (reduce-window.111: s32[], reduce-window.112: s32[]) -> s32[] { + %reduce-window.111 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.35"} + %reduce-window.112 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.35"} + ROOT %reduce_window_sum.108 = s32[]{:T(128)} add(%reduce-window.111, %reduce-window.112), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_64.75 (reduce-window.113: s32[], reduce-window.114: s32[]) -> s32[] { + %reduce-window.113 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.36"} + %reduce-window.114 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.36"} + ROOT %reduce_window_sum.109 = s32[]{:T(128)} add(%reduce-window.113, %reduce-window.114), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_65.76 (reduce-window.115: s32[], reduce-window.116: s32[]) -> s32[] { + %reduce-window.115 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.63"} + %reduce-window.116 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.63"} + ROOT %reduce_window_sum.110 = s32[]{:T(128)} add(%reduce-window.115, %reduce-window.116), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_68.81.clone (reduce-window.333: s32[], reduce-window.334: s32[]) -> s32[] { + %reduce-window.333 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.38"} + %reduce-window.334 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.38"} + ROOT %reduce_window_sum.261 = s32[]{:T(128)} add(%reduce-window.333, %reduce-window.334), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_69.82.clone (reduce-window.335: s32[], reduce-window.336: s32[]) -> s32[] { + %reduce-window.335 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.64"} + %reduce-window.336 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.64"} + ROOT %reduce_window_sum.262 = s32[]{:T(128)} add(%reduce-window.335, %reduce-window.336), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_71.84.clone (reduce-window.337: s32[], reduce-window.338: s32[]) -> s32[] { + %reduce-window.337 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.40"} + %reduce-window.338 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.40"} + ROOT %reduce_window_sum.263 = s32[]{:T(128)} add(%reduce-window.337, %reduce-window.338), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_72.85.clone (reduce-window.339: s32[], reduce-window.340: s32[]) -> s32[] { + %reduce-window.339 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.65"} + %reduce-window.340 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.65"} + ROOT %reduce_window_sum.264 = s32[]{:T(128)} add(%reduce-window.339, %reduce-window.340), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_74.87.clone (reduce-window.341: s32[], reduce-window.342: s32[]) -> s32[] { + %reduce-window.341 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.42"} + %reduce-window.342 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.42"} + ROOT %reduce_window_sum.265 = s32[]{:T(128)} add(%reduce-window.341, %reduce-window.342), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_75.88.clone (reduce-window.343: s32[], reduce-window.344: s32[]) -> s32[] { + %reduce-window.343 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.66"} + %reduce-window.344 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.66"} + ROOT %reduce_window_sum.266 = s32[]{:T(128)} add(%reduce-window.343, %reduce-window.344), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_80.96.clone (reduce-window.345: s32[], reduce-window.346: s32[]) -> s32[] { + %reduce-window.345 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.44"} + %reduce-window.346 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.44"} + ROOT %reduce_window_sum.267 = s32[]{:T(128)} add(%reduce-window.345, %reduce-window.346), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_82.98.clone (reduce-window.347: s32[], reduce-window.348: s32[]) -> s32[] { + %reduce-window.347 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.45"} + %reduce-window.348 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.45"} + ROOT %reduce_window_sum.268 = s32[]{:T(128)} add(%reduce-window.347, %reduce-window.348), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_83.99.clone (reduce-window.349: s32[], reduce-window.350: s32[]) -> s32[] { + %reduce-window.349 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.67"} + %reduce-window.350 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.67"} + ROOT %reduce_window_sum.269 = s32[]{:T(128)} add(%reduce-window.349, %reduce-window.350), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_94.112 (reduce-window.174: s32[], reduce-window.175: s32[]) -> s32[] { + %reduce-window.174 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.49"} + %reduce-window.175 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.49"} + ROOT %reduce_window_sum.138 = s32[]{:T(128)} add(%reduce-window.174, %reduce-window.175), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_95.113 (reduce-window.179: s32[], reduce-window.180: s32[]) -> s32[] { + %reduce-window.179 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.69"} + %reduce-window.180 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.69"} + ROOT %reduce_window_sum.139 = s32[]{:T(128)} add(%reduce-window.179, %reduce-window.180), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_97.115 (reduce-window.184: s32[], reduce-window.185: s32[]) -> s32[] { + %reduce-window.184 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.51"} + %reduce-window.185 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.51"} + ROOT %reduce_window_sum.140 = s32[]{:T(128)} add(%reduce-window.184, %reduce-window.185), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_98.116 (reduce-window.189: s32[], reduce-window.190: s32[]) -> s32[] { + %reduce-window.189 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.70"} + %reduce-window.190 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.70"} + ROOT %reduce_window_sum.141 = s32[]{:T(128)} add(%reduce-window.189, %reduce-window.190), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_103.121 (reduce-window.194: s32[], reduce-window.195: s32[]) -> s32[] { + %reduce-window.194 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.53"} + %reduce-window.195 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.53"} + ROOT %reduce_window_sum.142 = s32[]{:T(128)} add(%reduce-window.194, %reduce-window.195), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_105.123 (reduce-window.199: s32[], reduce-window.200: s32[]) -> s32[] { + %reduce-window.199 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.54"} + %reduce-window.200 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.54"} + ROOT %reduce_window_sum.143 = s32[]{:T(128)} add(%reduce-window.199, %reduce-window.200), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_106.124 (reduce-window.204: s32[], reduce-window.205: s32[]) -> s32[] { + %reduce-window.204 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.71"} + %reduce-window.205 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.71"} + ROOT %reduce_window_sum.144 = s32[]{:T(128)} add(%reduce-window.204, %reduce-window.205), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.5 (param_0.17: bf16[129280,512], param_1.108: s32[1024]) -> bf16[512,512] { + %param_0.17 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(0) + %param_1.108 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.13 = s32[1024]{0:T(1024)} custom-call(%param_1.108), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=104} + %slice.1060 = s32[512]{0:T(512)} slice(%custom-call.13), slice={[0:512]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=104} + %reshape.3813 = s32[4,128]{1,0:T(4,128)} reshape(%slice.1060), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=104} + %transpose.849 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3813), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=104} + %gather.213 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} gather(%param_0.17, %transpose.849), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=104} + %transpose.848 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} transpose(%gather.213), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=104} + ROOT %reshape.3812 = bf16[512,512]{1,0:T(8,128)(2,1)} reshape(%transpose.848), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=104} +} + +%fused_computation.6 (param_0.20: f32[163840,32], param_1.110: s32[1024]) -> f32[512,32] { + %param_0.20 = f32[163840,32]{1,0:T(8,128)} parameter(0) + %param_1.110 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.15 = s32[1024]{0:T(1024)} custom-call(%param_1.110), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=180} + %slice.1062 = s32[512]{0:T(512)} slice(%custom-call.15), slice={[0:512]}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=180} + %reshape.3821 = s32[4,128]{1,0:T(4,128)} reshape(%slice.1062), metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=180} + %transpose.855 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3821), dimensions={0,1}, metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=180} + %gather.215 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.20, %transpose.855), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=180} + %transpose.854 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.215), dimensions={0,1,2}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=180} + ROOT %reshape.3820 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.854), metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=180} +} + +%fused_computation.7 (param_0.23: f32[163840,32], param_1.112: s32[1024]) -> f32[512,32] { + %param_0.23 = f32[163840,32]{1,0:T(8,128)} parameter(0) + %param_1.112 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.17 = s32[1024]{0:T(1024)} custom-call(%param_1.112), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=180} + %slice.1064 = s32[512]{0:T(512)} slice(%custom-call.17), slice={[0:512]}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=180} + %reshape.3829 = s32[4,128]{1,0:T(4,128)} reshape(%slice.1064), metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=180} + %transpose.861 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3829), dimensions={0,1}, metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=180} + %gather.217 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.23, %transpose.861), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=180} + %transpose.860 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.217), dimensions={0,1,2}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=180} + ROOT %reshape.3828 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.860), metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=180} +} + +%fused_computation.8 (param_0.26: f32[163840,32], param_1.120: s32[1024]) -> f32[512,32] { + %param_0.26 = f32[163840,32]{1,0:T(8,128)} parameter(0) + %param_1.120 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.25 = s32[1024]{0:T(1024)} custom-call(%param_1.120), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=981} + %slice.1072 = s32[512]{0:T(512)} slice(%custom-call.25), slice={[0:512]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=981} + %reshape.3837 = s32[4,128]{1,0:T(4,128)} reshape(%slice.1072), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=981} + %transpose.867 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3837), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=981} + %gather.219 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.26, %transpose.867), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=981} + %transpose.866 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.219), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=981} + ROOT %reshape.3836 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.866), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=981} +} + +%fused_computation.9 (param_0.29: f32[163840,32], param_1.122: s32[1024]) -> f32[512,32] { + %param_0.29 = f32[163840,32]{1,0:T(8,128)} parameter(0) + %param_1.122 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.27 = s32[1024]{0:T(1024)} custom-call(%param_1.122), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=981} + %slice.1074 = s32[512]{0:T(512)} slice(%custom-call.27), slice={[0:512]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=981} + %reshape.3845 = s32[4,128]{1,0:T(4,128)} reshape(%slice.1074), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=981} + %transpose.873 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3845), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=981} + %gather.221 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.29, %transpose.873), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=981} + %transpose.872 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.221), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=981} + ROOT %reshape.3844 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.872), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=981} +} + +%fused_computation.10 (param_0.32: bf16[4096,512], param_1.126: s32[4096]) -> bf16[4096,512] { + %param_0.32 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) + %param_1.126 = s32[4096]{0:T(1024)S(1)} parameter(1) + %custom-call.31 = s32[4096]{0:T(1024)} custom-call(%param_1.126), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1710} + %slice.1078 = s32[4096]{0:T(1024)} slice(%custom-call.31), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1710} + %reshape.3853 = s32[4096]{0:T(1024)} reshape(%slice.1078), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=1710} + %transpose.879 = s32[4096]{0:T(1024)} transpose(%reshape.3853), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=1710} + %gather.223 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.32, %transpose.879), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1710} + %transpose.878 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.223), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1710} + ROOT %reshape.3852 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.878), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1710} +} + +%fused_computation.11 (param_0.35: bf16[4096,512], param_1.128: s32[4096]) -> bf16[4096,512] { + %param_0.35 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) + %param_1.128 = s32[4096]{0:T(1024)S(1)} parameter(1) + %custom-call.33 = s32[4096]{0:T(1024)} custom-call(%param_1.128), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1733} + %slice.1080 = s32[4096]{0:T(1024)} slice(%custom-call.33), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1733} + %reshape.3861 = s32[4096]{0:T(1024)} reshape(%slice.1080), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=1733} + %transpose.885 = s32[4096]{0:T(1024)} transpose(%reshape.3861), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=1733} + %gather.225 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.35, %transpose.885), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1733} + %transpose.884 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.225), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1733} + ROOT %reshape.3860 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.884), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1733} +} + +%fused_computation.12 (param_0.38: bf16[4096,512], param_1.130: s32[4096]) -> bf16[4096,512] { + %param_0.38 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) + %param_1.130 = s32[4096]{0:T(1024)S(1)} parameter(1) + %custom-call.35 = s32[4096]{0:T(1024)} custom-call(%param_1.130), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1710} + %slice.1082 = s32[4096]{0:T(1024)} slice(%custom-call.35), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1710} + %reshape.3869 = s32[4096]{0:T(1024)} reshape(%slice.1082), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=1710} + %transpose.891 = s32[4096]{0:T(1024)} transpose(%reshape.3869), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=1710} + %gather.227 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.38, %transpose.891), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1710} + %transpose.890 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.227), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1710} + ROOT %reshape.3868 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.890), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1710} +} + +%fused_computation.13 (param_0.41: bf16[4096,512], param_1.132: s32[4096]) -> bf16[4096,512] { + %param_0.41 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) + %param_1.132 = s32[4096]{0:T(1024)S(1)} parameter(1) + %custom-call.37 = s32[4096]{0:T(1024)} custom-call(%param_1.132), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1733} + %slice.1084 = s32[4096]{0:T(1024)} slice(%custom-call.37), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1733} + %reshape.3877 = s32[4096]{0:T(1024)} reshape(%slice.1084), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=1710} + %transpose.897 = s32[4096]{0:T(1024)} transpose(%reshape.3877), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=1710} + %gather.229 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.41, %transpose.897), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1733} + %transpose.896 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.229), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1733} + ROOT %reshape.3876 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.896), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1733} +} + +%fused_computation.15 (param_0.47: s32[256], param_1.124: s32[1024]) -> s32[263] { + %param_0.47 = s32[256]{0:T(256)S(1)} parameter(0) + %param_1.124 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.29 = s32[1024]{0:T(1024)} custom-call(%param_1.124), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=1956} + %slice.1076 = s32[263]{0:T(512)} slice(%custom-call.29), slice={[0:263]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=1956} + %reshape.3908 = s32[263]{0:T(512)} reshape(%slice.1076), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=1956} + %transpose.913 = s32[263]{0:T(512)} transpose(%reshape.3908), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=1956} + %gather.234 = s32[263]{0:T(512)} gather(%param_0.47, %transpose.913), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=1956} + %transpose.912 = s32[263]{0:T(512)} transpose(%gather.234), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=1956} + ROOT %reshape.3907 = s32[263]{0:T(512)S(1)} reshape(%transpose.912), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=1956} +} + +%fused_computation.16 (param_0.50: s32[256], param_1.134: s32[1024]) -> s32[263] { + %param_0.50 = s32[256]{0:T(256)S(1)} parameter(0) + %param_1.134 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.39 = s32[1024]{0:T(1024)} custom-call(%param_1.134), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=1992} + %slice.1086 = s32[263]{0:T(512)} slice(%custom-call.39), slice={[0:263]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=1992} + %reshape.3931 = s32[263]{0:T(512)} reshape(%slice.1086), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/broadcast_in_dim" stack_frame_id=1992} + %transpose.923 = s32[263]{0:T(512)} transpose(%reshape.3931), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/broadcast_in_dim" stack_frame_id=1992} + %gather.237 = s32[263]{0:T(512)} gather(%param_0.50, %transpose.923), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=1992} + %transpose.922 = s32[263]{0:T(512)} transpose(%gather.237), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=1992} + ROOT %reshape.3930 = s32[263]{0:T(512)S(1)} reshape(%transpose.922), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=1992} +} + +%region_173.198.clone (scatter-add.94: bf16[], scatter-add.96: bf16[]) -> bf16[] { + %scatter-add.94 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} + %scatter-add.96 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} + ROOT %add.1875 = bf16[]{:T(256)} add(%scatter-add.94, %scatter-add.96), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=1554}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.21 (param_0.55: bf16[129280,512], param_1.65: s32[512], param_2.24: bf16[512,512]) -> bf16[129280,512] { + %param_0.55 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(0) + %param_1.65 = s32[512]{0:T(512)S(1)} parameter(1) + %reshape.3985 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.65), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=104} + %transpose.956 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3985), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=104} + %param_2.24 = bf16[512,512]{1,0:T(8,128)(2,1)S(1)} parameter(2) + %reshape.3986 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} reshape(%param_2.24), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/while" stack_frame_id=116} + %transpose.957 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} transpose(%reshape.3986), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/while" stack_frame_id=116} + ROOT %scatter.93 = bf16[129280,512]{1,0:T(8,128)(2,1)} scatter(%param_0.55, %transpose.956, %transpose.957), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_173.198.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=104} +} + +%region_12.18 (top_k.0: bf16[], top_k.6: bf16[], top_k.7: s32[], top_k.8: s32[]) -> pred[] { + %constant.1497 = s32[]{:T(128)} constant(0) + %constant.1498 = s32[]{:T(128)} constant(2147483647) + %top_k.0 = bf16[]{:T(256)} parameter(0), metadata={op_name="top_k"} + %top_k.6 = bf16[]{:T(256)} parameter(1), metadata={op_name="top_k"} + %top_k.7 = s32[]{:T(128)} parameter(2), metadata={op_name="top_k"} + %top_k.8 = s32[]{:T(128)} parameter(3), metadata={op_name="top_k"} + %convert.385 = f32[]{:T(128)S(6)} convert(%top_k.0), metadata={op_name="convert.16"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %bitcast-convert.35 = s32[]{:T(128)S(6)} bitcast-convert(%convert.385), metadata={op_name="bitcast-convert.6"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.128 = pred[]{:T(512)S(6)} compare(%bitcast-convert.35, %constant.1497), direction=LT, metadata={op_name="compare.35"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.36 = s32[]{:T(128)S(6)} xor(%constant.1498, %bitcast-convert.35), metadata={op_name="xor.6"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %select.118 = s32[]{:T(128)S(6)} select(%compare.128, %xor.36, %bitcast-convert.35), metadata={op_name="select.14"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} + %convert.386 = f32[]{:T(128)S(6)} convert(%top_k.6), metadata={op_name="convert.17"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %bitcast-convert.36 = s32[]{:T(128)S(6)} bitcast-convert(%convert.386), metadata={op_name="bitcast-convert.7"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.129 = pred[]{:T(512)S(6)} compare(%bitcast-convert.36, %constant.1497), direction=LT, metadata={op_name="compare.36"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.37 = s32[]{:T(128)S(6)} xor(%constant.1498, %bitcast-convert.36), metadata={op_name="xor.7"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %select.119 = s32[]{:T(128)S(6)} select(%compare.129, %xor.37, %bitcast-convert.36), metadata={op_name="select.15"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} + %compare.130 = pred[]{:T(512)S(6)} compare(%select.118, %select.119), direction=GT, metadata={op_name="compare.1"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.131 = pred[]{:T(512)S(6)} compare(%select.119, %select.118), direction=GT, metadata={op_name="compare.108"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.132 = pred[]{:T(512)S(6)} compare(%compare.130, %compare.131), direction=EQ, metadata={op_name="compare.109"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.133 = pred[]{:T(512)S(6)} compare(%top_k.7, %top_k.8), direction=LT, metadata={op_name="compare.110"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %select.120 = pred[]{:T(512)} select(%compare.132, %compare.133, %compare.130), metadata={op_name="select.108"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_15.21.clone.1 (reduce-window.252: s32[], reduce-window.253: s32[]) -> s32[] { + %reduce-window.252 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.20"} + %reduce-window.253 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.20"} + ROOT %reduce_window_sum.210 = s32[]{:T(128)} add(%reduce-window.252, %reduce-window.253), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_16.22.clone.1 (reduce-window.254: s32[], reduce-window.255: s32[]) -> s32[] { + %reduce-window.254 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.56"} + %reduce-window.255 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.56"} + ROOT %reduce_window_sum.211 = s32[]{:T(128)} add(%reduce-window.254, %reduce-window.255), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_18.24.clone.1 (reduce-window.256: s32[], reduce-window.257: s32[]) -> s32[] { + %reduce-window.256 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.22"} + %reduce-window.257 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.22"} + ROOT %reduce_window_sum.212 = s32[]{:T(128)} add(%reduce-window.256, %reduce-window.257), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_19.25.clone.1 (reduce-window.258: s32[], reduce-window.259: s32[]) -> s32[] { + %reduce-window.258 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.57"} + %reduce-window.259 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.57"} + ROOT %reduce_window_sum.213 = s32[]{:T(128)} add(%reduce-window.258, %reduce-window.259), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_21.27.clone.1 (reduce-window.260: s32[], reduce-window.261: s32[]) -> s32[] { + %reduce-window.260 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.24"} + %reduce-window.261 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.24"} + ROOT %reduce_window_sum.214 = s32[]{:T(128)} add(%reduce-window.260, %reduce-window.261), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_22.28.clone.1 (reduce-window.262: s32[], reduce-window.263: s32[]) -> s32[] { + %reduce-window.262 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.58"} + %reduce-window.263 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.58"} + ROOT %reduce_window_sum.215 = s32[]{:T(128)} add(%reduce-window.262, %reduce-window.263), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.4.clone (param_0.68: s32[256], param_1.114: s32[1024]) -> s32[263] { + %param_0.68 = s32[256]{0:T(256)S(1)} parameter(0) + %param_1.114 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.19 = s32[1024]{0:T(1024)} custom-call(%param_1.114), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=1872} + %slice.1066 = s32[263]{0:T(512)} slice(%custom-call.19), slice={[0:263]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=1872} + %reshape.4129 = s32[263]{0:T(512)} reshape(%slice.1066), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=1872} + %transpose.1039 = s32[263]{0:T(512)} transpose(%reshape.4129), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=1872} + %gather.239 = s32[263]{0:T(512)} gather(%param_0.68, %transpose.1039), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=1872} + %transpose.1038 = s32[263]{0:T(512)} transpose(%gather.239), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=1872} + ROOT %reshape.4128 = s32[263]{0:T(512)S(1)} reshape(%transpose.1038), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=1872} +} + +%region_27.34.clone.1 (reduce-window.264: s32[], reduce-window.265: s32[]) -> s32[] { + %reduce-window.264 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.26"} + %reduce-window.265 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.26"} + ROOT %reduce_window_sum.216 = s32[]{:T(128)} add(%reduce-window.264, %reduce-window.265), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_29.36.clone.1 (reduce-window.266: s32[], reduce-window.267: s32[]) -> s32[] { + %reduce-window.266 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.27"} + %reduce-window.267 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.27"} + ROOT %reduce_window_sum.217 = s32[]{:T(128)} add(%reduce-window.266, %reduce-window.267), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_30.37.clone.1 (reduce-window.268: s32[], reduce-window.269: s32[]) -> s32[] { + %reduce-window.268 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.59"} + %reduce-window.269 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.59"} + ROOT %reduce_window_sum.218 = s32[]{:T(128)} add(%reduce-window.268, %reduce-window.269), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_13.19 (sort.44: s32[], sort.45: s32[], sort.46: s32[], sort.47: s32[], sort.48: s32[], sort.49: s32[]) -> pred[] { + %sort.46 = s32[]{:T(128)} parameter(2), metadata={op_name="jit(argsort)/sort"} + %sort.47 = s32[]{:T(128)} parameter(3), metadata={op_name="jit(argsort)/sort"} + %sort.44 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(argsort)/sort"} + %sort.45 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(argsort)/sort"} + %sort.48 = s32[]{:T(128)} parameter(4), metadata={op_name="jit(argsort)/sort"} + %sort.49 = s32[]{:T(128)} parameter(5), metadata={op_name="jit(argsort)/sort"} + %lt_to.27 = pred[]{:T(512)S(6)} compare(%sort.44, %sort.45), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %lt_to.28 = pred[]{:T(512)S(6)} compare(%sort.45, %sort.44), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.134 = pred[]{:T(512)S(6)} compare(%lt_to.27, %lt_to.28), direction=EQ, metadata={op_name="compare.111"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.135 = pred[]{:T(512)S(6)} compare(%sort.48, %sort.49), direction=LT, metadata={op_name="compare.112"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %select.121 = pred[]{:T(512)} select(%compare.134, %compare.135, %lt_to.27), metadata={op_name="select.109"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.2.clone (param_0.71: bf16[4096,512], param_1.116: s32[4096]) -> bf16[4096,512] { + %param_0.71 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) + %param_1.116 = s32[4096]{0:T(1024)S(1)} parameter(1) + %custom-call.21 = s32[4096]{0:T(1024)} custom-call(%param_1.116), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1659} + %slice.1068 = s32[4096]{0:T(1024)} slice(%custom-call.21), slice={[0:4096]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1659} + %reshape.4152 = s32[4096]{0:T(1024)} reshape(%slice.1068), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=1659} + %transpose.1045 = s32[4096]{0:T(1024)} transpose(%reshape.4152), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=1659} + %gather.240 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.71, %transpose.1045), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1659} + %transpose.1044 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.240), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1659} + ROOT %reshape.4151 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.1044), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1659} +} + +%region_31.39 (sort.50: s32[], sort.51: s32[], sort.52: s32[], sort.53: s32[], sort.54: s32[], sort.55: s32[]) -> pred[] { + %sort.52 = s32[]{:T(128)} parameter(2), metadata={op_name="jit(argsort)/sort"} + %sort.53 = s32[]{:T(128)} parameter(3), metadata={op_name="jit(argsort)/sort"} + %sort.50 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(argsort)/sort"} + %sort.51 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(argsort)/sort"} + %sort.54 = s32[]{:T(128)} parameter(4), metadata={op_name="jit(argsort)/sort"} + %sort.55 = s32[]{:T(128)} parameter(5), metadata={op_name="jit(argsort)/sort"} + %lt_to.30 = pred[]{:T(512)S(6)} compare(%sort.50, %sort.51), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %lt_to.31 = pred[]{:T(512)S(6)} compare(%sort.51, %sort.50), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.142 = pred[]{:T(512)S(6)} compare(%lt_to.30, %lt_to.31), direction=EQ, metadata={op_name="compare.113"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.143 = pred[]{:T(512)S(6)} compare(%sort.54, %sort.55), direction=LT, metadata={op_name="compare.114"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %select.126 = pred[]{:T(512)} select(%compare.142, %compare.143, %lt_to.30), metadata={op_name="select.110"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.3.clone (param_0.72: bf16[4096,512], param_1.118: s32[4096]) -> bf16[4096,512] { + %param_0.72 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) + %param_1.118 = s32[4096]{0:T(1024)S(1)} parameter(1) + %custom-call.23 = s32[4096]{0:T(1024)} custom-call(%param_1.118), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1687} + %slice.1070 = s32[4096]{0:T(1024)} slice(%custom-call.23), slice={[0:4096]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1687} + %reshape.4154 = s32[4096]{0:T(1024)} reshape(%slice.1070), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=1687} + %transpose.1047 = s32[4096]{0:T(1024)} transpose(%reshape.4154), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=1687} + %gather.241 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.72, %transpose.1047), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1687} + %transpose.1046 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.241), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1687} + ROOT %reshape.4153 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.1046), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=1687} +} + +%compare (name: s32[], name.1: s32[], name.2: bf16[], name.3: bf16[]) -> pred[] { + %name.2 = bf16[] parameter(2) + %name.3 = bf16[] parameter(3) + %name = s32[] parameter(0) + %name.1 = s32[] parameter(1) + ROOT %compare.393 = pred[] compare(%name, %name.1), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_154.179 (reduce_sum.431: f32[], reduce_sum.254: f32[]) -> f32[] { + %reduce_sum.431 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.254 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.258 = f32[]{:T(128)} add(%reduce_sum.431, %reduce_sum.254), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.488 (param_0.4566: f32[3,1536,128,192]) -> f32[] { + %param_0.4566 = f32[3,1536,128,192]{2,3,0,1:T(8,128)} parameter(0) + %bitcast.672 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} bitcast(%param_0.4566), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=120} + %square.564 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%bitcast.672, %bitcast.672), metadata={op_name="jit(train_step)/square" stack_frame_id=860} + %constant.6397 = f32[]{:T(128)} constant(0) + ROOT %reduce.669 = f32[]{:T(128)} reduce(%square.564, %constant.6397), dimensions={0,1,2,3}, to_apply=%region_154.179, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862} +} + +%fused_computation.489 (param_0.1427: f32[1536,3,128,192]) -> bf16[3,1536,128,192] { + %param_0.1427 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(0) + %copy.1550 = bf16[1536,3,128,192]{2,0,3,1:T(8,128)(2,1)} copy(%param_0.1427), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'dense_layers\'][\'self_attention\'][\'wq_b\'][\'kernel\']"} + ROOT %bitcast.673 = bf16[3,1536,128,192]{2,1,3,0:T(8,128)(2,1)} bitcast(%copy.1550), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=120} +} + +%region_221.246 (reduce_sum.893: f32[], reduce_sum.603: f32[]) -> f32[] { + %reduce_sum.893 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.603 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.604 = f32[]{:T(128)} add(%reduce_sum.893, %reduce_sum.603), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_187.212 (reduce_sum.655: f32[], reduce_sum.449: f32[]) -> f32[] { + %reduce_sum.655 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.449 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.450 = f32[]{:T(128)} add(%reduce_sum.655, %reduce_sum.449), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.490 (param_0.4536: f32[1536,3,128,192], param_1.5952: f32[], param_2.5469: f32[], param_3.4115: f32[], param_4.3388: f32[1536,3,128,192], param_5.3053: f32[], param_6.2563: f32[3,1536,128,192], param_7.2484: pred[], param_8.1878: f32[1536,3,128,192]) -> (f32[], f32[1536,3,128,192], f32[1536,3,128,192], f32[1536,3,128,192], f32[]) { + %param_0.4536 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(0) + %param_3.4115 = f32[]{:T(128)S(6)} parameter(3) + %mul.4804.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_3.4115), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.2484 = pred[]{:T(512)S(6)} parameter(7) + %select_n.4669.clone.1 = pred[1536,3,128,192]{2,3,1,0:T(8,128)(4,1)} broadcast(%param_7.2484), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %param_6.2563 = f32[3,1536,128,192]{2,3,0,1:T(8,128)} parameter(6) + %bitcast.1374.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} bitcast(%param_6.2563), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=120} + %param_5.3053 = f32[]{:T(128)} parameter(5) + %div.2846.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_5.3053), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %div.2845.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%bitcast.1374.clone.1, %div.2846.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %select_n.4668.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} select(%select_n.4669.clone.1, %bitcast.1374.clone.1, %div.2845.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %ne.1540.clone.1 = pred[1536,3,128,192]{2,3,1,0:T(8,128)(4,1)} compare(%select_n.4668.clone.1, %select_n.4668.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=1566} + %constant.6367 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.2552.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.6367), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4667.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} select(%ne.1540.clone.1, %broadcast_in_dim.2552.clone.1, %select_n.4668.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6059.clone.1 = f32[]{:T(128)} constant(inf) + %eq.2299.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.6059.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2298.clone.1 = pred[1536,3,128,192]{2,3,1,0:T(8,128)(4,1)} compare(%select_n.4667.clone.1, %eq.2299.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6058.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.2551.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.6058.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4666.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} select(%eq.2298.clone.1, %broadcast_in_dim.2551.clone.1, %select_n.4667.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6057.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.2297.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.6057.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2296.clone.1 = pred[1536,3,128,192]{2,3,1,0:T(8,128)(4,1)} compare(%select_n.4666.clone.1, %eq.2297.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6056.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.2550.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.6056.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4665.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} select(%eq.2296.clone.1, %broadcast_in_dim.2550.clone.1, %select_n.4666.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6052.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.5797.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.6052.clone.1), dimensions={}, metadata={op_name="broadcast.334"} + %mul.4810.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%select_n.4665.clone.1, %broadcast.5797.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1569} + %param_8.1878 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(8) + %constant.6060.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.4811.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.6060.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %mul.4809.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_8.1878, %mul.4811.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %add.3632.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%mul.4810.clone.1, %mul.4809.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1569} + %param_2.5469 = f32[]{:T(128)S(6)} parameter(2) + %div.2842.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_2.5469), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1580} + %integer_pow.399.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%select_n.4665.clone.1, %select_n.4665.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=1584} + %constant.6055.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.4808.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.6055.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %mul.4806.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%integer_pow.399.clone.1, %mul.4808.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %param_4.3388 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(4) + %constant.6054.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.4807.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.6054.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %mul.4805.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_4.3388, %mul.4807.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %add.3631.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%mul.4806.clone.1, %mul.4805.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1587} + %param_1.5952 = f32[]{:T(128)S(6)} parameter(1) + %div.2841.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_1.5952), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %div.2840.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%add.3631.clone.1, %div.2841.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %sqrt.157.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} sqrt(%div.2840.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=1591} + %constant.6053.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3630.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.6053.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %add.3629.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%sqrt.157.clone.1, %add.3630.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %multiply.1294.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%div.2842.clone.1, %add.3629.clone.1), metadata={op_name="multiply.290"} + %div.2839.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%add.3632.clone.1, %multiply.1294.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1592} + %mul.4803.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_0.4536, %broadcast.5797.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1594} + %add.3628.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%div.2839.clone.1, %mul.4803.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1595} + %mul.4802.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%mul.4804.clone.1, %add.3628.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.3627.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%param_0.4536, %mul.4802.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1599} + %square.565 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%add.3627.clone.1, %add.3627.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=1630} + %reduce.670 = f32[]{:T(128)} reduce(%square.565, %constant.6367), dimensions={0,1,2,3}, to_apply=%region_221.246, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631} + %reduce.671.clone.1 = f32[]{:T(128)} reduce(%integer_pow.399.clone.1, %constant.6367), dimensions={0,1,2,3}, to_apply=%region_187.212, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607} + ROOT %tuple.671 = (f32[]{:T(128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.670, %add.3627.clone.1, %add.3631.clone.1, %add.3632.clone.1, %reduce.671.clone.1) +} + +%region_160.185 (reduce_sum.473: f32[], reduce_sum.293: f32[]) -> f32[] { + %reduce_sum.473 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.293 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.300 = f32[]{:T(128)} add(%reduce_sum.473, %reduce_sum.293), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_158.183 (reduce_sum.459: f32[], reduce_sum.460: f32[]) -> f32[] { + %reduce_sum.459 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.460 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.461 = f32[]{:T(128)} add(%reduce_sum.459, %reduce_sum.460), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.531 (param_0.4562: bf16[256,512,512], param_1.5974: bf16[256,512,512]) -> (f32[], f32[]) { + %param_0.4562 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(0) + %broadcast_in_dim.1696 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_0.4562), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=480} + %bitcast.695 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1696), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=427} + %square.570 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.695, %bitcast.695), metadata={op_name="jit(train_step)/square" stack_frame_id=860} + %constant.6393 = f32[]{:T(128)} constant(0) + %reduce.672 = f32[]{:T(128)} reduce(%square.570, %constant.6393), dimensions={0,1,2,3}, to_apply=%region_160.185, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862} + %param_1.5974 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(1) + %broadcast_in_dim.1704.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_1.5974), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=480} + %bitcast.703.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1704.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=427} + %square.576.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.703.clone.1, %bitcast.703.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=860} + %reduce.674.clone.1 = f32[]{:T(128)} reduce(%square.576.clone.1, %constant.6393), dimensions={0,1,2,3}, to_apply=%region_158.183, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862} + ROOT %tuple.778 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.672, %reduce.674.clone.1) +} + +%region_159.184 (reduce_sum.466: f32[], reduce_sum.279: f32[]) -> f32[] { + %reduce_sum.466 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.279 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.286 = f32[]{:T(128)} add(%reduce_sum.466, %reduce_sum.279), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.533 (param_0.4561: bf16[256,512,512]) -> f32[] { + %param_0.4561 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(0) + %broadcast_in_dim.1700 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_0.4561), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=480} + %bitcast.699 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1700), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=427} + %square.573 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.699, %bitcast.699), metadata={op_name="jit(train_step)/square" stack_frame_id=860} + %constant.6392 = f32[]{:T(128)} constant(0) + ROOT %reduce.673 = f32[]{:T(128)} reduce(%square.573, %constant.6392), dimensions={0,1,2,3}, to_apply=%region_159.184, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862} +} + +%region_227.252 (reduce_sum.935: f32[], reduce_sum.631: f32[]) -> f32[] { + %reduce_sum.935 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.631 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.632 = f32[]{:T(128)} add(%reduce_sum.935, %reduce_sum.631), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_193.218 (reduce_sum.697: f32[], reduce_sum.471: f32[]) -> f32[] { + %reduce_sum.697 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.471 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.472 = f32[]{:T(128)} add(%reduce_sum.697, %reduce_sum.471), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.551 (param_0.4530: f32[], param_1.5946: f32[256,1,512,512], param_2.5463: f32[], param_3.4109: f32[256,1,512,512], param_4.3382: f32[], param_5.3047: bf16[256,512,512], param_6.2557: pred[], param_7.2478: f32[], param_8.1872: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { + %param_8.1872 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(8) + %bitcast.1359.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.1872), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} + %param_7.2478 = f32[]{:T(128)S(6)} parameter(7) + %mul.4753.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.2478), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_6.2557 = pred[]{:T(512)S(6)} parameter(6) + %select_n.4624.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.2557), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %param_5.3047 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) + %broadcast_in_dim.2523.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.3047), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=480} + %bitcast.1361.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.2523.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=427} + %param_4.3382 = f32[]{:T(128)} parameter(4) + %div.2804.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.3382), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %div.2803.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1361.clone.1, %div.2804.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %select_n.4623.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.4624.clone.1, %bitcast.1361.clone.1, %div.2803.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %ne.1531.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} compare(%select_n.4623.clone.1, %select_n.4623.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=1566} + %constant.6361 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.2522.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.6361), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4622.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%ne.1531.clone.1, %broadcast_in_dim.2522.clone.1, %select_n.4623.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.5994.clone.1 = f32[]{:T(128)} constant(inf) + %eq.2263.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.5994.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2262.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} compare(%select_n.4622.clone.1, %eq.2263.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.5993.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.2521.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.5993.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4621.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%eq.2262.clone.1, %broadcast_in_dim.2521.clone.1, %select_n.4622.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.5992.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.2261.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.5992.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2260.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} compare(%select_n.4621.clone.1, %eq.2261.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.5991.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.2520.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.5991.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4620.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%eq.2260.clone.1, %broadcast_in_dim.2520.clone.1, %select_n.4621.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.5990.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.5777.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.5990.clone.1), dimensions={}, metadata={op_name="broadcast.3475"} + %mul.4755.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.4620.clone.1, %broadcast.5777.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1569} + %param_3.4109 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) + %bitcast.1360.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.4109), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} + %constant.5989.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.5776.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.5989.clone.1), dimensions={}, metadata={op_name="broadcast.329"} + %mul.4754.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1360.clone.1, %broadcast.5776.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %add.3597.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4755.clone.1, %mul.4754.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1569} + %param_2.5463 = f32[]{:T(128)S(6)} parameter(2) + %div.2802.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.5463), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1580} + %integer_pow.393.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.4620.clone.1, %select_n.4620.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=1584} + %constant.5988.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.5779.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.5988.clone.1), dimensions={}, metadata={op_name="broadcast.3483"} + %mul.4757.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.393.clone.1, %broadcast.5779.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %param_1.5946 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) + %bitcast.1362.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5946), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} + %constant.5987.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.5778.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.5987.clone.1), dimensions={}, metadata={op_name="broadcast.312"} + %mul.4756.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1362.clone.1, %broadcast.5778.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %add.3598.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4757.clone.1, %mul.4756.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1587} + %param_0.4530 = f32[]{:T(128)S(6)} parameter(0) + %div.2801.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4530), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %div.2800.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3598.clone.1, %div.2801.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %sqrt.151.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2800.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=1591} + %constant.5995.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.5775.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.5995.clone.1), dimensions={}, metadata={op_name="broadcast.305"} + %add.3596.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.151.clone.1, %broadcast.5775.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %multiply.1288.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2802.clone.1, %add.3596.clone.1), metadata={op_name="multiply.296"} + %div.2799.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3597.clone.1, %multiply.1288.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1592} + %mul.4752.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1359.clone.1, %broadcast.5777.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1594} + %add.3595.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2799.clone.1, %mul.4752.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1595} + %mul.4751.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.4753.clone.1, %add.3595.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.3594.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1359.clone.1, %mul.4751.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1599} + %square.577 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3594.clone.1, %add.3594.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=1630} + %reduce.675 = f32[]{:T(128)} reduce(%square.577, %constant.6361), dimensions={0,1,2,3}, to_apply=%region_227.252, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631} + %bitcast.856.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3598.clone.1) + %bitcast.829.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3597.clone.1) + %reduce.684.clone.1 = f32[]{:T(128)} reduce(%integer_pow.393.clone.1, %constant.6361), dimensions={0,1,2,3}, to_apply=%region_193.218, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607} + ROOT %tuple.681 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.675, %add.3594.clone.1, %bitcast.856.clone.1, %bitcast.829.clone.1, %reduce.684.clone.1) +} + +%region_226.251 (reduce_sum.928: f32[], reduce_sum.625: f32[]) -> f32[] { + %reduce_sum.928 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.625 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.626 = f32[]{:T(128)} add(%reduce_sum.928, %reduce_sum.625), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_192.217 (reduce_sum.690: f32[], reduce_sum.465: f32[]) -> f32[] { + %reduce_sum.690 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.465 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.470 = f32[]{:T(128)} add(%reduce_sum.690, %reduce_sum.465), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.552 (param_0.4531: f32[], param_1.5947: f32[256,1,512,512], param_2.5464: f32[], param_3.4110: f32[256,1,512,512], param_4.3383: f32[], param_5.3048: bf16[256,512,512], param_6.2558: pred[], param_7.2479: f32[], param_8.1873: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { + %param_8.1873 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(8) + %bitcast.1363.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.1873), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} + %param_7.2479 = f32[]{:T(128)S(6)} parameter(7) + %mul.4760.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.2479), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_6.2558 = pred[]{:T(512)S(6)} parameter(6) + %select_n.4629.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.2558), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %param_5.3048 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) + %broadcast_in_dim.2527.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.3048), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=480} + %bitcast.1365.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.2527.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=427} + %param_4.3383 = f32[]{:T(128)} parameter(4) + %div.2810.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.3383), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %div.2809.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1365.clone.1, %div.2810.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %select_n.4628.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.4629.clone.1, %bitcast.1365.clone.1, %div.2809.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %ne.1532.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} compare(%select_n.4628.clone.1, %select_n.4628.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=1566} + %constant.6362 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.2526.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.6362), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4627.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%ne.1532.clone.1, %broadcast_in_dim.2526.clone.1, %select_n.4628.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6003.clone.1 = f32[]{:T(128)} constant(inf) + %eq.2267.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.6003.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2266.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} compare(%select_n.4627.clone.1, %eq.2267.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6002.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.2525.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.6002.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4626.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%eq.2266.clone.1, %broadcast_in_dim.2525.clone.1, %select_n.4627.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6001.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.2265.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.6001.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2264.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} compare(%select_n.4626.clone.1, %eq.2265.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6000.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.2524.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.6000.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4625.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%eq.2264.clone.1, %broadcast_in_dim.2524.clone.1, %select_n.4626.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.5999.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.5782.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.5999.clone.1), dimensions={}, metadata={op_name="broadcast.3475"} + %mul.4762.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.4625.clone.1, %broadcast.5782.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1569} + %param_3.4110 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) + %bitcast.1364.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.4110), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} + %constant.5998.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.5781.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.5998.clone.1), dimensions={}, metadata={op_name="broadcast.329"} + %mul.4761.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1364.clone.1, %broadcast.5781.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %add.3602.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4762.clone.1, %mul.4761.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1569} + %param_2.5464 = f32[]{:T(128)S(6)} parameter(2) + %div.2808.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.5464), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1580} + %integer_pow.394.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.4625.clone.1, %select_n.4625.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=1584} + %constant.5997.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.5784.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.5997.clone.1), dimensions={}, metadata={op_name="broadcast.3483"} + %mul.4764.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.394.clone.1, %broadcast.5784.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %param_1.5947 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) + %bitcast.1366.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5947), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} + %constant.5996.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.5783.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.5996.clone.1), dimensions={}, metadata={op_name="broadcast.312"} + %mul.4763.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1366.clone.1, %broadcast.5783.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %add.3603.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4764.clone.1, %mul.4763.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1587} + %param_0.4531 = f32[]{:T(128)S(6)} parameter(0) + %div.2807.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4531), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %div.2806.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3603.clone.1, %div.2807.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %sqrt.152.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2806.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=1591} + %constant.6004.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.5780.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.6004.clone.1), dimensions={}, metadata={op_name="broadcast.305"} + %add.3601.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.152.clone.1, %broadcast.5780.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %multiply.1289.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2808.clone.1, %add.3601.clone.1), metadata={op_name="multiply.295"} + %div.2805.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3602.clone.1, %multiply.1289.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1592} + %mul.4759.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1363.clone.1, %broadcast.5782.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1594} + %add.3600.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2805.clone.1, %mul.4759.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1595} + %mul.4758.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.4760.clone.1, %add.3600.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.3599.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1363.clone.1, %mul.4758.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1599} + %square.578 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3599.clone.1, %add.3599.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=1630} + %reduce.676 = f32[]{:T(128)} reduce(%square.578, %constant.6362), dimensions={0,1,2,3}, to_apply=%region_226.251, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631} + %bitcast.847.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3603.clone.1) + %bitcast.820.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3602.clone.1) + %reduce.685.clone.1 = f32[]{:T(128)} reduce(%integer_pow.394.clone.1, %constant.6362), dimensions={0,1,2,3}, to_apply=%region_192.217, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607} + ROOT %tuple.680 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.676, %add.3599.clone.1, %bitcast.847.clone.1, %bitcast.820.clone.1, %reduce.685.clone.1) +} + +%region_225.250 (reduce_sum.921: f32[], reduce_sum.619: f32[]) -> f32[] { + %reduce_sum.921 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.619 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.624 = f32[]{:T(128)} add(%reduce_sum.921, %reduce_sum.619), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_191.216 (reduce_sum.683: f32[], reduce_sum.463: f32[]) -> f32[] { + %reduce_sum.683 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.463 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.464 = f32[]{:T(128)} add(%reduce_sum.683, %reduce_sum.463), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.553 (param_0.4532: f32[], param_1.5948: f32[256,1,512,512], param_2.5465: f32[], param_3.4111: f32[256,1,512,512], param_4.3384: f32[], param_5.3049: bf16[256,512,512], param_6.2559: pred[], param_7.2480: f32[], param_8.1874: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { + %param_8.1874 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(8) + %bitcast.1367.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.1874), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} + %param_7.2480 = f32[]{:T(128)S(6)} parameter(7) + %mul.4767.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.2480), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_6.2559 = pred[]{:T(512)S(6)} parameter(6) + %select_n.4634.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.2559), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %param_5.3049 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) + %broadcast_in_dim.2531.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.3049), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=480} + %bitcast.1369.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.2531.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=427} + %param_4.3384 = f32[]{:T(128)} parameter(4) + %div.2816.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.3384), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %div.2815.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1369.clone.1, %div.2816.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %select_n.4633.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.4634.clone.1, %bitcast.1369.clone.1, %div.2815.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %ne.1533.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} compare(%select_n.4633.clone.1, %select_n.4633.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=1566} + %constant.6363 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.2530.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.6363), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4632.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%ne.1533.clone.1, %broadcast_in_dim.2530.clone.1, %select_n.4633.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6012.clone.1 = f32[]{:T(128)} constant(inf) + %eq.2271.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.6012.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2270.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} compare(%select_n.4632.clone.1, %eq.2271.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6011.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.2529.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.6011.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4631.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%eq.2270.clone.1, %broadcast_in_dim.2529.clone.1, %select_n.4632.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6010.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.2269.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.6010.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2268.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} compare(%select_n.4631.clone.1, %eq.2269.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6009.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.2528.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.6009.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4630.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%eq.2268.clone.1, %broadcast_in_dim.2528.clone.1, %select_n.4631.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6008.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.5787.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.6008.clone.1), dimensions={}, metadata={op_name="broadcast.3475"} + %mul.4769.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.4630.clone.1, %broadcast.5787.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1569} + %param_3.4111 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) + %bitcast.1368.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.4111), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} + %constant.6007.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.5786.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.6007.clone.1), dimensions={}, metadata={op_name="broadcast.329"} + %mul.4768.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1368.clone.1, %broadcast.5786.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %add.3607.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4769.clone.1, %mul.4768.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1569} + %param_2.5465 = f32[]{:T(128)S(6)} parameter(2) + %div.2814.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.5465), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1580} + %integer_pow.395.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.4630.clone.1, %select_n.4630.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=1584} + %constant.6006.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.5789.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.6006.clone.1), dimensions={}, metadata={op_name="broadcast.3483"} + %mul.4771.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.395.clone.1, %broadcast.5789.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %param_1.5948 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) + %bitcast.1370.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5948), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} + %constant.6005.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.5788.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.6005.clone.1), dimensions={}, metadata={op_name="broadcast.312"} + %mul.4770.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1370.clone.1, %broadcast.5788.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %add.3608.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4771.clone.1, %mul.4770.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1587} + %param_0.4532 = f32[]{:T(128)S(6)} parameter(0) + %div.2813.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4532), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %div.2812.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3608.clone.1, %div.2813.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %sqrt.153.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2812.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=1591} + %constant.6013.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.5785.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.6013.clone.1), dimensions={}, metadata={op_name="broadcast.305"} + %add.3606.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.153.clone.1, %broadcast.5785.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %multiply.1290.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2814.clone.1, %add.3606.clone.1), metadata={op_name="multiply.294"} + %div.2811.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3607.clone.1, %multiply.1290.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1592} + %mul.4766.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1367.clone.1, %broadcast.5787.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1594} + %add.3605.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2811.clone.1, %mul.4766.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1595} + %mul.4765.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.4767.clone.1, %add.3605.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.3604.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1367.clone.1, %mul.4765.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1599} + %square.579 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3604.clone.1, %add.3604.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=1630} + %reduce.677 = f32[]{:T(128)} reduce(%square.579, %constant.6363), dimensions={0,1,2,3}, to_apply=%region_225.250, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631} + %bitcast.838.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3608.clone.1) + %bitcast.811.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3607.clone.1) + %reduce.686.clone.1 = f32[]{:T(128)} reduce(%integer_pow.395.clone.1, %constant.6363), dimensions={0,1,2,3}, to_apply=%region_191.216, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607} + ROOT %tuple.679 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.677, %add.3604.clone.1, %bitcast.838.clone.1, %bitcast.811.clone.1, %reduce.686.clone.1) +} + +%region_155.180 (reduce_sum.438: f32[], reduce_sum.259: f32[]) -> f32[] { + %reduce_sum.438 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.259 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.260 = f32[]{:T(128)} add(%reduce_sum.438, %reduce_sum.259), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.565.clone.clone.clone (param_0.4475: bf16[4,128,129280], param_1.5880: s32[4,128], param_2.5396: f32[4,128], param_3.4077: f32[4,128], param_4.3355: bf16[4,128], param_5.3025: f32[4,128]) -> bf16[4,128,129280] { + %param_5.3025 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %mul.4980 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_5.3025), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=90} + %param_3.4077 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.4979 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_3.4077), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=90} + %param_0.4475 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.3215 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4475), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=834} + %param_4.3355 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.926 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_4.3355), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=845} + %sub.925 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.3215, %sub.926), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=845} + %exp.534 = f32[4,128,129280]{2,1,0:T(8,128)} exponential(%sub.925), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=846} + %mul.4978 = f32[4,128,129280]{2,1,0:T(8,128)} multiply(%mul.4979, %exp.534), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=90} + %param_2.5396 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.2969 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_2.5396), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=90} + %div.2968 = f32[4,128,129280]{2,1,0:T(8,128)} divide(%mul.4978, %div.2969), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=90} + %param_1.5880 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.2413 = s32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.5880), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=850} + %eq.2412 = s32[4,128,129280]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=850} + %eq.2411 = pred[4,128,129280]{2,1,0:T(8,128)(4,1)} compare(%eq.2413, %eq.2412), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=850} + %convert_element_type.3214 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%eq.2411), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=850} + %sub.924 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%div.2968, %convert_element_type.3214), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=90} + %mul.4977 = f32[4,128,129280]{2,1,0:T(8,128)} multiply(%mul.4980, %sub.924), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=90} + ROOT %convert_element_type.3213 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} convert(%mul.4977), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=834} +} + +%fused_computation.1005.clone.clone (param_0.4476: f32[4,128], param_1.5881: bf16[4,128,512], param_2.5398: bf16[512]) -> bf16[4,128,512] { + %param_2.5398 = bf16[512]{0:T(512)(128)(2,1)S(1)} parameter(2) + %dot_general.832 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.5398), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=99} + %param_1.5881 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.3217 = f32[4,128,512]{2,1,0:T(8,128)} convert(%param_1.5881), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=813} + %param_0.4476 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.4982 = f32[4,128,512]{2,1,0:T(8,128)} broadcast(%param_0.4476), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=818} + %mul.4981 = f32[4,128,512]{2,1,0:T(8,128)} multiply(%convert_element_type.3217, %mul.4982), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=818} + %convert_element_type.3216 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} convert(%mul.4981), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=819} + ROOT %dot_general.831 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.832, %convert_element_type.3216), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=99} +} + +%fused_computation.554 (param_0.4565: bf16[4,128,129280], param_1.5976: s32[4,128], param_2.5490: f32[4,128], param_3.4133: f32[4,128], param_4.3404: bf16[4,128], param_5.3067: f32[4,128], param_6.2577: f32[4,128], param_7.2498: bf16[4,128,512], param_8.1891: bf16[512]) -> (f32[], bf16[512,129280,1]) { + %param_6.2577 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %param_7.2498 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)S(1)} parameter(7) + %param_8.1891 = bf16[512]{0:T(512)(128)(2,1)S(1)} parameter(8) + %fusion.577.clone.1 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} fusion(%param_6.2577, %param_7.2498, %param_8.1891), kind=kLoop, calls=%fused_computation.1005.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=99} + %param_0.4565 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.5976 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.5490 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.4133 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %param_4.3404 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %param_5.3067 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %multiply_convert_fusion.1.clone.1 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} fusion(%param_0.4565, %param_1.5976, %param_2.5490, %param_3.4133, %param_4.3404, /*index=5*/%param_5.3067), kind=kLoop, calls=%fused_computation.565.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=834} + %convolution.141.clone.1 = bf16[512,129280,1]{1,0,2:T(8,128)(2,1)} convolution(%fusion.577.clone.1, %multiply_convert_fusion.1.clone.1), window={size=4}, dim_labels=0fb_0io->bf0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=832} + %bitcast.770 = bf16[512,129280]{1,0:T(8,128)(2,1)} bitcast(%convolution.141.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=832} + %convert_element_type.2699 = f32[512,129280]{1,0:T(8,128)} convert(%bitcast.770), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=822} + %square.581 = f32[512,129280]{1,0:T(8,128)} multiply(%convert_element_type.2699, %convert_element_type.2699), metadata={op_name="jit(train_step)/square" stack_frame_id=860} + %constant.6396 = f32[]{:T(128)} constant(0) + %reduce.678 = f32[]{:T(128)} reduce(%square.581, %constant.6396), dimensions={0,1}, to_apply=%region_155.180, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862} + ROOT %tuple.768 = (f32[]{:T(128)}, bf16[512,129280,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.678, %convolution.141.clone.1) +} + +%region_174.199 (reduce_sum.564: f32[], reduce_sum.387: f32[]) -> f32[] { + %reduce_sum.564 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.387 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.388 = f32[]{:T(128)} add(%reduce_sum.564, %reduce_sum.387), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.555 (param_0.4549: bf16[129280,512]) -> f32[] { + %param_0.4549 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.2701 = f32[129280,512]{1,0:T(8,128)} convert(%param_0.4549), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=103} + %square.583 = f32[129280,512]{1,0:T(8,128)} multiply(%convert_element_type.2701, %convert_element_type.2701), metadata={op_name="jit(train_step)/square" stack_frame_id=860} + %constant.6380 = f32[]{:T(128)} constant(0) + ROOT %reduce.679 = f32[]{:T(128)} reduce(%square.583, %constant.6380), dimensions={0,1}, to_apply=%region_174.199, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862} +} + +%region_240.265 (reduce_sum.1026: f32[], reduce_sum.689: f32[]) -> f32[] { + %reduce_sum.1026 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.689 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.694 = f32[]{:T(128)} add(%reduce_sum.1026, %reduce_sum.689), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_206.231 (reduce_sum.788: f32[], reduce_sum.533: f32[]) -> f32[] { + %reduce_sum.788 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.533 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.534 = f32[]{:T(128)} add(%reduce_sum.788, %reduce_sum.533), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.556 (param_0.4517: f32[129280,512], param_1.5933: f32[], param_2.5450: f32[], param_3.4096: f32[], param_4.3369: f32[129280,512], param_5.3034: f32[], param_6.2544: bf16[129280,512], param_7.2465: pred[], param_8.1859: f32[129280,512]) -> (f32[], f32[129280,512], f32[129280,512], f32[129280,512], f32[]) { + %param_0.4517 = f32[129280,512]{1,0:T(8,128)} parameter(0) + %param_3.4096 = f32[]{:T(128)S(6)} parameter(3) + %mul.4641.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_3.4096), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.2465 = pred[]{:T(512)S(6)} parameter(7) + %select_n.4519.clone.1 = pred[129280,512]{1,0:T(8,128)(4,1)} broadcast(%param_7.2465), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %param_6.2544 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(6) + %convert_element_type.3158.clone.1 = f32[129280,512]{1,0:T(8,128)} convert(%param_6.2544), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=103} + %param_5.3034 = f32[]{:T(128)} parameter(5) + %div.2710.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_5.3034), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %div.2709.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%convert_element_type.3158.clone.1, %div.2710.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %select_n.4518.clone.1 = f32[129280,512]{1,0:T(8,128)} select(%select_n.4519.clone.1, %convert_element_type.3158.clone.1, %div.2709.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %ne.1510.clone.1 = pred[129280,512]{1,0:T(8,128)(4,1)} compare(%select_n.4518.clone.1, %select_n.4518.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=1566} + %constant.6348 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.2444.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.6348), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4517.clone.1 = f32[129280,512]{1,0:T(8,128)} select(%ne.1510.clone.1, %broadcast_in_dim.2444.clone.1, %select_n.4518.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.5829.clone.1 = f32[]{:T(128)} constant(inf) + %eq.2179.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.5829.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2177.clone.1 = pred[129280,512]{1,0:T(8,128)(4,1)} compare(%select_n.4517.clone.1, %eq.2179.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.5828.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.2443.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.5828.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4516.clone.1 = f32[129280,512]{1,0:T(8,128)} select(%eq.2177.clone.1, %broadcast_in_dim.2443.clone.1, %select_n.4517.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.5827.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.2178.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.5827.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2176.clone.1 = pred[129280,512]{1,0:T(8,128)(4,1)} compare(%select_n.4516.clone.1, %eq.2178.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.5826.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.2442.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.5826.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4515.clone.1 = f32[129280,512]{1,0:T(8,128)} select(%eq.2176.clone.1, %broadcast_in_dim.2442.clone.1, %select_n.4516.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.5822.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.5727.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.5822.clone.1), dimensions={}, metadata={op_name="broadcast.318"} + %mul.4647.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%select_n.4515.clone.1, %broadcast.5727.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1569} + %param_8.1859 = f32[129280,512]{1,0:T(8,128)} parameter(8) + %constant.5830.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.4648.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.5830.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %mul.4646.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_8.1859, %mul.4648.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %add.3527.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%mul.4647.clone.1, %mul.4646.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1569} + %param_2.5450 = f32[]{:T(128)S(6)} parameter(2) + %div.2706.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_2.5450), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1580} + %integer_pow.380.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%select_n.4515.clone.1, %select_n.4515.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=1584} + %constant.5825.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.4645.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.5825.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %mul.4643.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%integer_pow.380.clone.1, %mul.4645.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %param_4.3369 = f32[129280,512]{1,0:T(8,128)} parameter(4) + %constant.5824.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.4644.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.5824.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %mul.4642.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_4.3369, %mul.4644.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %add.3526.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%mul.4643.clone.1, %mul.4642.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1587} + %param_1.5933 = f32[]{:T(128)S(6)} parameter(1) + %div.2705.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_1.5933), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %div.2704.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%add.3526.clone.1, %div.2705.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %sqrt.138.clone.1 = f32[129280,512]{1,0:T(8,128)} sqrt(%div.2704.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=1591} + %constant.5823.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3525.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.5823.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %add.3524.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%sqrt.138.clone.1, %add.3525.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %multiply.1275.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%div.2706.clone.1, %add.3524.clone.1), metadata={op_name="multiply.309"} + %div.2703.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%add.3527.clone.1, %multiply.1275.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1592} + %mul.4640.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_0.4517, %broadcast.5727.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1594} + %add.3523.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%div.2703.clone.1, %mul.4640.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1595} + %mul.4639.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%mul.4641.clone.1, %add.3523.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.3522.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%param_0.4517, %mul.4639.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1599} + %square.584 = f32[129280,512]{1,0:T(8,128)} multiply(%add.3522.clone.1, %add.3522.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=1630} + %reduce.680 = f32[]{:T(128)} reduce(%square.584, %constant.6348), dimensions={0,1}, to_apply=%region_240.265, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631} + %reduce.687.clone.1 = f32[]{:T(128)} reduce(%integer_pow.380.clone.1, %constant.6348), dimensions={0,1}, to_apply=%region_206.231, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607} + ROOT %tuple.682 = (f32[]{:T(128)}, f32[129280,512]{1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.680, %add.3522.clone.1, %add.3526.clone.1, %add.3527.clone.1, %reduce.687.clone.1) +} + +%region_222.247 (reduce_sum.900: f32[], reduce_sum.605: f32[]) -> f32[] { + %reduce_sum.900 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.605 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.610 = f32[]{:T(128)} add(%reduce_sum.900, %reduce_sum.605), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_188.213 (reduce_sum.662: f32[], reduce_sum.451: f32[]) -> f32[] { + %reduce_sum.662 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.451 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.455 = f32[]{:T(128)} add(%reduce_sum.662, %reduce_sum.451), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.557 (param_0.4535: f32[512,129280], param_1.5951: f32[], param_2.5468: f32[], param_3.4114: f32[], param_4.3387: f32[512,129280], param_5.3052: f32[], param_6.2562: bf16[512,129280,1], param_7.2483: pred[], param_8.1877: f32[512,129280]) -> (f32[], f32[512,129280], f32[512,129280], f32[512,129280], f32[]) { + %param_0.4535 = f32[512,129280]{1,0:T(8,128)} parameter(0) + %param_3.4114 = f32[]{:T(128)S(6)} parameter(3) + %mul.4794.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_3.4114), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.2483 = pred[]{:T(512)S(6)} parameter(7) + %select_n.4659.clone.1 = pred[512,129280]{1,0:T(8,128)(4,1)} broadcast(%param_7.2483), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %param_6.2562 = bf16[512,129280,1]{1,0,2:T(8,128)(2,1)} parameter(6) + %bitcast.1372.clone.1 = bf16[512,129280]{1,0:T(8,128)(2,1)} bitcast(%param_6.2562), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=832} + %convert_element_type.3160.clone.1 = f32[512,129280]{1,0:T(8,128)} convert(%bitcast.1372.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=822} + %param_5.3052 = f32[]{:T(128)} parameter(5) + %div.2838.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_5.3052), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %div.2837.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%convert_element_type.3160.clone.1, %div.2838.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %select_n.4658.clone.1 = f32[512,129280]{1,0:T(8,128)} select(%select_n.4659.clone.1, %convert_element_type.3160.clone.1, %div.2837.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %ne.1538.clone.1 = pred[512,129280]{1,0:T(8,128)(4,1)} compare(%select_n.4658.clone.1, %select_n.4658.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=1566} + %constant.6366 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.2546.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.6366), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4657.clone.1 = f32[512,129280]{1,0:T(8,128)} select(%ne.1538.clone.1, %broadcast_in_dim.2546.clone.1, %select_n.4658.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6045.clone.1 = f32[]{:T(128)} constant(inf) + %eq.2291.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.6045.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2289.clone.1 = pred[512,129280]{1,0:T(8,128)(4,1)} compare(%select_n.4657.clone.1, %eq.2291.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6044.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.2545.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.6044.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4656.clone.1 = f32[512,129280]{1,0:T(8,128)} select(%eq.2289.clone.1, %broadcast_in_dim.2545.clone.1, %select_n.4657.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6043.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.2290.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.6043.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2288.clone.1 = pred[512,129280]{1,0:T(8,128)(4,1)} compare(%select_n.4656.clone.1, %eq.2290.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6042.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.2544.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.6042.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4655.clone.1 = f32[512,129280]{1,0:T(8,128)} select(%eq.2288.clone.1, %broadcast_in_dim.2544.clone.1, %select_n.4656.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6038.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.5795.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.6038.clone.1), dimensions={}, metadata={op_name="broadcast.333"} + %mul.4800.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%select_n.4655.clone.1, %broadcast.5795.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1569} + %param_8.1877 = f32[512,129280]{1,0:T(8,128)} parameter(8) + %constant.6046.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.4801.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.6046.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %mul.4799.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_8.1877, %mul.4801.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %add.3626.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%mul.4800.clone.1, %mul.4799.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1569} + %param_2.5468 = f32[]{:T(128)S(6)} parameter(2) + %div.2834.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_2.5468), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1580} + %integer_pow.398.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%select_n.4655.clone.1, %select_n.4655.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=1584} + %constant.6041.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.4798.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.6041.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %mul.4796.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%integer_pow.398.clone.1, %mul.4798.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %param_4.3387 = f32[512,129280]{1,0:T(8,128)} parameter(4) + %constant.6040.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.4797.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.6040.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %mul.4795.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_4.3387, %mul.4797.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %add.3625.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%mul.4796.clone.1, %mul.4795.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1587} + %param_1.5951 = f32[]{:T(128)S(6)} parameter(1) + %div.2833.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_1.5951), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %div.2832.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%add.3625.clone.1, %div.2833.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %sqrt.156.clone.1 = f32[512,129280]{1,0:T(8,128)} sqrt(%div.2832.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=1591} + %constant.6039.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3624.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.6039.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %add.3623.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%sqrt.156.clone.1, %add.3624.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %multiply.1293.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%div.2834.clone.1, %add.3623.clone.1), metadata={op_name="multiply.291"} + %div.2831.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%add.3626.clone.1, %multiply.1293.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1592} + %mul.4793.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_0.4535, %broadcast.5795.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1594} + %add.3622.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%div.2831.clone.1, %mul.4793.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1595} + %mul.4792.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%mul.4794.clone.1, %add.3622.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.3621.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%param_0.4535, %mul.4792.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1599} + %square.585 = f32[512,129280]{1,0:T(8,128)} multiply(%add.3621.clone.1, %add.3621.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=1630} + %reduce.681 = f32[]{:T(128)} reduce(%square.585, %constant.6366), dimensions={0,1}, to_apply=%region_222.247, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631} + %reduce.688.clone.1 = f32[]{:T(128)} reduce(%integer_pow.398.clone.1, %constant.6366), dimensions={0,1}, to_apply=%region_188.213, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607} + ROOT %tuple.683 = (f32[]{:T(128)}, f32[512,129280]{1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.681, %add.3621.clone.1, %add.3625.clone.1, %add.3626.clone.1, %reduce.688.clone.1) +} + +%region_207.232 (reduce_sum.795: f32[], reduce_sum.535: f32[]) -> f32[] { + %reduce_sum.795 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.535 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.540 = f32[]{:T(128)} add(%reduce_sum.795, %reduce_sum.535), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=1613}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.558 (param_0.4586: bf16[4,128,129280], param_1.5990: f32[4,128], param_2.5500: s32[4,128], param_3.4141: bf16[4,128]) -> f32[4,128] { + %param_2.5500 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %eq.1371 = s32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_2.5500), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=850} + %eq.1268 = s32[4,128,129280]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=850} + %eq.1267 = pred[4,128,129280]{2,1,0:T(8,128)(4,1)} compare(%eq.1371, %eq.1268), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=850} + %param_0.4586 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.2718 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4586), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=834} + %param_3.4141 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) + %sub.787 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_3.4141), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=845} + %sub.778 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.2718, %sub.787), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=845} + %param_1.5990 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %sub.785 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.5990), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=1611} + %sub.774 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%sub.778, %sub.785), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=1611} + %constant.6420 = f32[]{:T(128)} constant(0) + %broadcast.5280 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%constant.6420), dimensions={}, metadata={op_name="broadcast.518"} + %mul.3701 = f32[4,128,129280]{2,1,0:T(8,128)} select(%eq.1267, %sub.774, %broadcast.5280), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=1612} + ROOT %reduce.682 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.3701, %constant.6420), dimensions={2}, to_apply=%region_207.232, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=1613} +} + +%region_37.47 (reduce_sum.76: f32[], reduce_sum.80: f32[]) -> f32[] { + %reduce_sum.76 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.80 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.83 = f32[]{:T(128)} add(%reduce_sum.76, %reduce_sum.80), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=847}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.569 (param_0.4587: bf16[4,128,129280], param_1.5991: bf16[4,128]) -> f32[4,128] { + %param_0.4587 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.2724 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4587), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=834} + %param_1.5991 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) + %sub.788 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.5991), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=845} + %sub.784 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.2724, %sub.788), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=845} + %exp.448 = f32[4,128,129280]{2,1,0:T(8,128)} exponential(%sub.784), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=846} + %constant.6421 = f32[]{:T(128)} constant(0) + ROOT %reduce.683 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.448, %constant.6421), dimensions={2}, to_apply=%region_37.47, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=847} +} + +%region_152.177 (reduce_sum.417: f32[], reduce_sum.244: f32[]) -> f32[] { + %reduce_sum.417 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.244 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.251 = f32[]{:T(128)} add(%reduce_sum.417, %reduce_sum.244), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.577 (param_0.4568: f32[3,512,128,256]) -> f32[] { + %param_0.4568 = f32[3,512,128,256]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.734 = f32[512,3,128,256]{3,2,1,0:T(8,128)} bitcast(%param_0.4568), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=120} + %square.588 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%bitcast.734, %bitcast.734), metadata={op_name="jit(train_step)/square" stack_frame_id=860} + %constant.6399 = f32[]{:T(128)} constant(0) + ROOT %reduce.689 = f32[]{:T(128)} reduce(%square.588, %constant.6399), dimensions={0,1,2,3}, to_apply=%region_152.177, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862} +} + +%fused_computation.578 (param_0.1646: f32[512,3,128,256]) -> bf16[3,512,128,256] { + %param_0.1646 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(0) + %copy.1551 = bf16[512,3,128,256]{3,0,2,1:T(8,128)(2,1)} copy(%param_0.1646), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'dense_layers\'][\'self_attention\'][\'wkv_b\'][\'kernel\']"} + ROOT %bitcast.735 = bf16[3,512,128,256]{3,1,2,0:T(8,128)(2,1)} bitcast(%copy.1551), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=120} +} + +%region_219.244 (reduce_sum.879: f32[], reduce_sum.591: f32[]) -> f32[] { + %reduce_sum.879 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.591 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.596 = f32[]{:T(128)} add(%reduce_sum.879, %reduce_sum.591), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_185.210 (reduce_sum.641: f32[], reduce_sum.437: f32[]) -> f32[] { + %reduce_sum.641 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.437 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.442 = f32[]{:T(128)} add(%reduce_sum.641, %reduce_sum.437), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.579 (param_0.4538: f32[512,3,128,256], param_1.5954: f32[], param_2.5471: f32[], param_3.4117: f32[], param_4.3390: f32[512,3,128,256], param_5.3055: f32[], param_6.2565: f32[3,512,128,256], param_7.2486: pred[], param_8.1880: f32[512,3,128,256]) -> (f32[], f32[512,3,128,256], f32[512,3,128,256], f32[512,3,128,256], f32[]) { + %param_0.4538 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(0) + %param_3.4117 = f32[]{:T(128)S(6)} parameter(3) + %mul.4824.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_3.4117), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.2486 = pred[]{:T(512)S(6)} parameter(7) + %select_n.4689.clone.1 = pred[512,3,128,256]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.2486), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %param_6.2565 = f32[3,512,128,256]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.1378.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} bitcast(%param_6.2565), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=120} + %param_5.3055 = f32[]{:T(128)} parameter(5) + %div.2862.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_5.3055), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %div.2861.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%bitcast.1378.clone.1, %div.2862.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %select_n.4688.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} select(%select_n.4689.clone.1, %bitcast.1378.clone.1, %div.2861.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %ne.1544.clone.1 = pred[512,3,128,256]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.4688.clone.1, %select_n.4688.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=1566} + %constant.6369 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.2564.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.6369), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4687.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} select(%ne.1544.clone.1, %broadcast_in_dim.2564.clone.1, %select_n.4688.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6087.clone.1 = f32[]{:T(128)} constant(inf) + %eq.2315.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.6087.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2314.clone.1 = pred[512,3,128,256]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.4687.clone.1, %eq.2315.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6086.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.2563.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.6086.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4686.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} select(%eq.2314.clone.1, %broadcast_in_dim.2563.clone.1, %select_n.4687.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6085.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.2313.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.6085.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2312.clone.1 = pred[512,3,128,256]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.4686.clone.1, %eq.2313.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6084.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.2562.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.6084.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4685.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} select(%eq.2312.clone.1, %broadcast_in_dim.2562.clone.1, %select_n.4686.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6080.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.5801.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.6080.clone.1), dimensions={}, metadata={op_name="broadcast.336"} + %mul.4830.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%select_n.4685.clone.1, %broadcast.5801.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1569} + %param_8.1880 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(8) + %constant.6088.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.4831.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.6088.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %mul.4829.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_8.1880, %mul.4831.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %add.3644.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%mul.4830.clone.1, %mul.4829.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1569} + %param_2.5471 = f32[]{:T(128)S(6)} parameter(2) + %div.2858.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_2.5471), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1580} + %integer_pow.401.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%select_n.4685.clone.1, %select_n.4685.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=1584} + %constant.6083.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.4828.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.6083.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %mul.4826.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%integer_pow.401.clone.1, %mul.4828.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %param_4.3390 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(4) + %constant.6082.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.4827.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.6082.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %mul.4825.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_4.3390, %mul.4827.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %add.3643.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%mul.4826.clone.1, %mul.4825.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1587} + %param_1.5954 = f32[]{:T(128)S(6)} parameter(1) + %div.2857.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_1.5954), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %div.2856.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%add.3643.clone.1, %div.2857.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %sqrt.159.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} sqrt(%div.2856.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=1591} + %constant.6081.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3642.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.6081.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %add.3641.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%sqrt.159.clone.1, %add.3642.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %multiply.1296.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%div.2858.clone.1, %add.3641.clone.1), metadata={op_name="multiply.288"} + %div.2855.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%add.3644.clone.1, %multiply.1296.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1592} + %mul.4823.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_0.4538, %broadcast.5801.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1594} + %add.3640.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%div.2855.clone.1, %mul.4823.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1595} + %mul.4822.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%mul.4824.clone.1, %add.3640.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.3639.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%param_0.4538, %mul.4822.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1599} + %square.589 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%add.3639.clone.1, %add.3639.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=1630} + %reduce.690 = f32[]{:T(128)} reduce(%square.589, %constant.6369), dimensions={0,1,2,3}, to_apply=%region_219.244, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631} + %reduce.691.clone.1 = f32[]{:T(128)} reduce(%integer_pow.401.clone.1, %constant.6369), dimensions={0,1,2,3}, to_apply=%region_185.210, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607} + ROOT %tuple.678 = (f32[]{:T(128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.690, %add.3639.clone.1, %add.3643.clone.1, %add.3644.clone.1, %reduce.691.clone.1) +} + +%region_172.197 (reduce_sum.557: f32[], reduce_sum.381: f32[]) -> f32[] { + %reduce_sum.557 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.381 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.386 = f32[]{:T(128)} add(%reduce_sum.557, %reduce_sum.381), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.846.clone.clone (param_0.4502: f32[4,128], param_1.5925: bf16[4,128,1536], param_2.5432: bf16[1536]) -> bf16[4,128,1536,1] { + %param_2.5432 = bf16[1536]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %dot_general.852 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.5432), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=499} + %param_1.5925 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.3239 = f32[4,128,1536]{2,1,0:T(8,128)} convert(%param_1.5925), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=528} + %param_0.4502 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.5028 = f32[4,128,1536]{2,1,0:T(8,128)} broadcast(%param_0.4502), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=537} + %mul.5027 = f32[4,128,1536]{2,1,0:T(8,128)} multiply(%convert_element_type.3239, %mul.5028), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=537} + %convert_element_type.3238 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} convert(%mul.5027), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=538} + %dot_general.851 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.852, %convert_element_type.3238), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=499} + ROOT %bitcast.1466 = bf16[4,128,1536,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%dot_general.851), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=499} +} + +%bitcast_fusion.12 (bitcast_input.12: bf16[4,128,128,192]) -> bf16[4,128,128,192] { + %bitcast_input.12 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)S(1)} parameter(0) + ROOT %bitcast.1488 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)} bitcast(%bitcast_input.12) +} + +%fused_computation.591 (param_0.4550: bf16[4,128,128,192], param_1.5965: f32[4,128], param_2.5482: bf16[4,128,1536], param_3.4128: bf16[1536]) -> (f32[], bf16[1536,128,192,1]) { + %param_1.5965 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.5482 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %param_3.4128 = bf16[1536]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %fusion.460.clone.1 = bf16[4,128,1536,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_1.5965, %param_2.5482, %param_3.4128), kind=kLoop, calls=%fused_computation.846.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=499} + %param_0.4550 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)S(1)} parameter(0) + %fusion.777 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.4550), kind=kLoop, calls=%bitcast_fusion.12 + %convolution.146.clone.1 = bf16[1536,128,192,1]{1,0,3,2:T(8,128)(2,1)} convolution(%fusion.460.clone.1, %fusion.777), window={size=192x4 pad=191_191x0_0 rhs_reversal=1x0}, dim_labels=1fb0_1io0->bf01, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=937} + %bitcast.801 = bf16[1536,128,192]{1,0,2:T(8,128)(2,1)} bitcast(%convolution.146.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=937} + %broadcast_in_dim.1711 = f32[1536,128,192]{1,0,2:T(8,128)} convert(%bitcast.801), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=480} + %bitcast.739 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} bitcast(%broadcast_in_dim.1711), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=427} + %square.592 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} multiply(%bitcast.739, %bitcast.739), metadata={op_name="jit(train_step)/square" stack_frame_id=860} + %constant.6381 = f32[]{:T(128)} constant(0) + %reduce.692 = f32[]{:T(128)} reduce(%square.592, %constant.6381), dimensions={0,1,2,3}, to_apply=%region_172.197, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862} + ROOT %tuple.777 = (f32[]{:T(128)}, bf16[1536,128,192,1]{1,0,3,2:T(8,128)(2,1)}) tuple(%reduce.692, %convolution.146.clone.1) +} + +%region_239.264 (reduce_sum.1019: f32[], reduce_sum.687: f32[]) -> f32[] { + %reduce_sum.1019 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.687 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.688 = f32[]{:T(128)} add(%reduce_sum.1019, %reduce_sum.687), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_205.230 (reduce_sum.781: f32[], reduce_sum.527: f32[]) -> f32[] { + %reduce_sum.781 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.527 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.528 = f32[]{:T(128)} add(%reduce_sum.781, %reduce_sum.527), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.596 (param_0.4518: f32[], param_1.5934: f32[], param_2.5451: f32[], param_3.4097: f32[1536,1,128,192], param_4.3370: f32[1536,1,128,192], param_5.3035: f32[], param_6.2545: bf16[1536,128,192,1], param_7.2466: pred[], param_8.1860: f32[1536,1,128,192]) -> (f32[], f32[1536,1,128,192], f32[1536,1,128,192], f32[1536,1,128,192], f32[]) { + %param_3.4097 = f32[1536,1,128,192]{2,3,1,0:T(8,128)} parameter(3) + %copy.1673.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} copy(%param_3.4097), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'self_attention\'][\'wq_b\'][\'kernel\']"} + %param_2.5451 = f32[]{:T(128)S(6)} parameter(2) + %mul.4651.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} broadcast(%param_2.5451), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.2466 = pred[]{:T(512)S(6)} parameter(7) + %select_n.4529.clone.1 = pred[1536,1,128,192]{2,0,3,1:T(8,128)(4,1)} broadcast(%param_7.2466), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %param_6.2545 = bf16[1536,128,192,1]{1,0,3,2:T(8,128)(2,1)} parameter(6) + %bitcast.1337.clone.1 = bf16[1536,128,192]{1,0,2:T(8,128)(2,1)} bitcast(%param_6.2545), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=937} + %broadcast_in_dim.2449.clone.1 = f32[1536,128,192]{1,0,2:T(8,128)} convert(%bitcast.1337.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=480} + %bitcast.1336.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} bitcast(%broadcast_in_dim.2449.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=427} + %param_5.3035 = f32[]{:T(128)} parameter(5) + %div.2718.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} broadcast(%param_5.3035), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %div.2717.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} divide(%bitcast.1336.clone.1, %div.2718.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %select_n.4528.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} select(%select_n.4529.clone.1, %bitcast.1336.clone.1, %div.2717.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %ne.1512.clone.1 = pred[1536,1,128,192]{2,0,3,1:T(8,128)(4,1)} compare(%select_n.4528.clone.1, %select_n.4528.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=1566} + %constant.6349 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.2452.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} broadcast(%constant.6349), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4527.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} select(%ne.1512.clone.1, %broadcast_in_dim.2452.clone.1, %select_n.4528.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.5843.clone.1 = f32[]{:T(128)} constant(inf) + %eq.2187.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} broadcast(%constant.5843.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2185.clone.1 = pred[1536,1,128,192]{2,0,3,1:T(8,128)(4,1)} compare(%select_n.4527.clone.1, %eq.2187.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.5842.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.2451.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} broadcast(%constant.5842.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4526.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} select(%eq.2185.clone.1, %broadcast_in_dim.2451.clone.1, %select_n.4527.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.5841.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.2186.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} broadcast(%constant.5841.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2184.clone.1 = pred[1536,1,128,192]{2,0,3,1:T(8,128)(4,1)} compare(%select_n.4526.clone.1, %eq.2186.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.5840.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.2450.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} broadcast(%constant.5840.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4525.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} select(%eq.2184.clone.1, %broadcast_in_dim.2450.clone.1, %select_n.4526.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.5836.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.5729.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} broadcast(%constant.5836.clone.1), dimensions={}, metadata={op_name="broadcast.3699"} + %mul.4657.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} multiply(%select_n.4525.clone.1, %broadcast.5729.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1569} + %param_8.1860 = f32[1536,1,128,192]{2,3,1,0:T(8,128)} parameter(8) + %copy.1675.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} copy(%param_8.1860), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'self_attention\'][\'wq_b\'][\'kernel\']"} + %constant.5844.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.4658.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} broadcast(%constant.5844.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %mul.4656.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} multiply(%copy.1675.clone.1, %mul.4658.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %add.3533.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} add(%mul.4657.clone.1, %mul.4656.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1569} + %param_1.5934 = f32[]{:T(128)S(6)} parameter(1) + %div.2714.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} broadcast(%param_1.5934), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1580} + %integer_pow.381.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} multiply(%select_n.4525.clone.1, %select_n.4525.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=1584} + %constant.5839.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.4655.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} broadcast(%constant.5839.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %mul.4653.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} multiply(%integer_pow.381.clone.1, %mul.4655.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %param_4.3370 = f32[1536,1,128,192]{2,3,1,0:T(8,128)} parameter(4) + %copy.1674.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} copy(%param_4.3370), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'self_attention\'][\'wq_b\'][\'kernel\']"} + %constant.5838.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.4654.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} broadcast(%constant.5838.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %mul.4652.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} multiply(%copy.1674.clone.1, %mul.4654.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %add.3532.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} add(%mul.4653.clone.1, %mul.4652.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1587} + %param_0.4518 = f32[]{:T(128)S(6)} parameter(0) + %div.2713.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} broadcast(%param_0.4518), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %div.2712.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} divide(%add.3532.clone.1, %div.2713.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %sqrt.139.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} sqrt(%div.2712.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=1591} + %constant.5837.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3531.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} broadcast(%constant.5837.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %add.3530.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} add(%sqrt.139.clone.1, %add.3531.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %multiply.1276.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} multiply(%div.2714.clone.1, %add.3530.clone.1), metadata={op_name="multiply.308"} + %div.2711.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} divide(%add.3533.clone.1, %multiply.1276.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1592} + %mul.4650.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} multiply(%copy.1673.clone.1, %broadcast.5729.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1594} + %add.3529.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} add(%div.2711.clone.1, %mul.4650.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1595} + %mul.4649.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} multiply(%mul.4651.clone.1, %add.3529.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.3528.clone.1 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} add(%copy.1673.clone.1, %mul.4649.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1599} + %square.593 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} multiply(%add.3528.clone.1, %add.3528.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=1630} + %reduce.693 = f32[]{:T(128)} reduce(%square.593, %constant.6349), dimensions={0,1,2,3}, to_apply=%region_239.264, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631} + %reduce.694.clone.1 = f32[]{:T(128)} reduce(%integer_pow.381.clone.1, %constant.6349), dimensions={0,1,2,3}, to_apply=%region_205.230, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607} + ROOT %tuple.691 = (f32[]{:T(128)}, f32[1536,1,128,192]{2,0,3,1:T(8,128)}, f32[1536,1,128,192]{2,0,3,1:T(8,128)}, f32[1536,1,128,192]{2,0,3,1:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.693, %add.3528.clone.1, %add.3532.clone.1, %add.3533.clone.1, %reduce.694.clone.1) +} + +%fused_computation.600 (param_0.1709: f32[256,1,512,512]) -> bf16[256,512,512] { + %param_0.1709 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(0) + %bitcast.760 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_0.1709), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} + %convert_element_type.2738 = bf16[256,1,512,512]{3,2,0,1:T(8,128)(2,1)} convert(%bitcast.760), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=717} + ROOT %bitcast.758 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} bitcast(%convert_element_type.2738), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=717} +} + +%fused_computation.601 (param_0.1713: f32[256,1,512,512]) -> bf16[256,512,512] { + %param_0.1713 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(0) + %bitcast.763 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_0.1713), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} + %convert_element_type.2740 = bf16[256,1,512,512]{3,2,0,1:T(8,128)(2,1)} convert(%bitcast.763), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=722} + ROOT %bitcast.761 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} bitcast(%convert_element_type.2740), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=722} +} + +%fused_computation.602 (param_0.1717: f32[256,1,512,512]) -> bf16[256,512,512] { + %param_0.1717 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(0) + %bitcast.766 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_0.1717), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} + %convert_element_type.2742 = bf16[256,1,512,512]{3,2,0,1:T(8,128)(2,1)} convert(%bitcast.766), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=727} + ROOT %bitcast.764 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} bitcast(%convert_element_type.2742), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=727} +} + +%region_145.170 (reduce_sum.368: f32[], reduce_sum.198: f32[]) -> f32[] { + %reduce_sum.368 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.198 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.204 = f32[]{:T(128)} add(%reduce_sum.368, %reduce_sum.198), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.610 (param_0.4574: f32[3,18432,512]) -> f32[] { + %param_0.4574 = f32[3,18432,512]{2,1,0:T(8,128)} parameter(0) + %bitcast.778 = f32[18432,3,512]{2,0,1:T(8,128)} bitcast(%param_0.4574), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=120} + %square.596 = f32[18432,3,512]{2,0,1:T(8,128)} multiply(%bitcast.778, %bitcast.778), metadata={op_name="jit(train_step)/square" stack_frame_id=860} + %constant.6405 = f32[]{:T(128)} constant(0) + ROOT %reduce.695 = f32[]{:T(128)} reduce(%square.596, %constant.6405), dimensions={0,1,2}, to_apply=%region_145.170, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862} +} + +%region_144.169 (reduce_sum.361: f32[], reduce_sum.193: f32[]) -> f32[] { + %reduce_sum.361 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.193 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.197 = f32[]{:T(128)} add(%reduce_sum.361, %reduce_sum.193), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_143.168 (reduce_sum.354: f32[], reduce_sum.188: f32[]) -> f32[] { + %reduce_sum.354 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.188 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.192 = f32[]{:T(128)} add(%reduce_sum.354, %reduce_sum.188), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.612 (param_0.4575: f32[3,512,18432], param_1.5978: f32[3,512,18432]) -> (f32[], f32[]) { + %param_0.4575 = f32[3,512,18432]{2,1,0:T(8,128)} parameter(0) + %bitcast.782 = f32[512,3,18432]{2,0,1:T(8,128)} bitcast(%param_0.4575), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=120} + %square.599 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%bitcast.782, %bitcast.782), metadata={op_name="jit(train_step)/square" stack_frame_id=860} + %constant.6406 = f32[]{:T(128)} constant(0) + %reduce.696 = f32[]{:T(128)} reduce(%square.599, %constant.6406), dimensions={0,1,2}, to_apply=%region_144.169, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862} + %param_1.5978 = f32[3,512,18432]{2,1,0:T(8,128)} parameter(1) + %bitcast.786.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} bitcast(%param_1.5978), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=120} + %square.602.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%bitcast.786.clone.1, %bitcast.786.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=860} + %reduce.697.clone.1 = f32[]{:T(128)} reduce(%square.602.clone.1, %constant.6406), dimensions={0,1,2}, to_apply=%region_143.168, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862} + ROOT %tuple.779 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.696, %reduce.697.clone.1) +} + +%region_212.237 (reduce_sum.830: f32[], reduce_sum.561: f32[]) -> f32[] { + %reduce_sum.830 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.561 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.562 = f32[]{:T(128)} add(%reduce_sum.830, %reduce_sum.561), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_178.203 (reduce_sum.592: f32[], reduce_sum.407: f32[]) -> f32[] { + %reduce_sum.592 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.407 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.408 = f32[]{:T(128)} add(%reduce_sum.592, %reduce_sum.407), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.615 (param_0.4545: f32[18432,3,512], param_1.5961: f32[], param_2.5478: f32[], param_3.4124: f32[], param_4.3397: f32[18432,3,512], param_5.3062: f32[], param_6.2572: f32[3,18432,512], param_7.2493: pred[], param_8.1887: f32[18432,3,512]) -> (f32[], f32[18432,3,512], f32[18432,3,512], f32[18432,3,512], f32[]) { + %param_0.4545 = f32[18432,3,512]{2,0,1:T(8,128)} parameter(0) + %param_3.4124 = f32[]{:T(128)S(6)} parameter(3) + %mul.4885.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} broadcast(%param_3.4124), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.2493 = pred[]{:T(512)S(6)} parameter(7) + %select_n.4759.clone.1 = pred[18432,3,512]{2,0,1:T(8,128)(4,1)} broadcast(%param_7.2493), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %param_6.2572 = f32[3,18432,512]{2,1,0:T(8,128)} parameter(6) + %bitcast.1392.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} bitcast(%param_6.2572), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=120} + %param_5.3062 = f32[]{:T(128)} parameter(5) + %div.2918.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} broadcast(%param_5.3062), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %div.2917.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} divide(%bitcast.1392.clone.1, %div.2918.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %select_n.4758.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} select(%select_n.4759.clone.1, %bitcast.1392.clone.1, %div.2917.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %ne.1558.clone.1 = pred[18432,3,512]{2,0,1:T(8,128)(4,1)} compare(%select_n.4758.clone.1, %select_n.4758.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=1566} + %constant.6376 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.2606.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} broadcast(%constant.6376), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4757.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} select(%ne.1558.clone.1, %broadcast_in_dim.2606.clone.1, %select_n.4758.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6185.clone.1 = f32[]{:T(128)} constant(inf) + %eq.2371.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} broadcast(%constant.6185.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2370.clone.1 = pred[18432,3,512]{2,0,1:T(8,128)(4,1)} compare(%select_n.4757.clone.1, %eq.2371.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6184.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.2605.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} broadcast(%constant.6184.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4756.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} select(%eq.2370.clone.1, %broadcast_in_dim.2605.clone.1, %select_n.4757.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6183.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.2369.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} broadcast(%constant.6183.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2368.clone.1 = pred[18432,3,512]{2,0,1:T(8,128)(4,1)} compare(%select_n.4756.clone.1, %eq.2369.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6182.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.2604.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} broadcast(%constant.6182.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4755.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} select(%eq.2368.clone.1, %broadcast_in_dim.2604.clone.1, %select_n.4756.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6178.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.5827.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} broadcast(%constant.6178.clone.1), dimensions={}, metadata={op_name="broadcast.342"} + %mul.4891.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} multiply(%select_n.4755.clone.1, %broadcast.5827.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1569} + %param_8.1887 = f32[18432,3,512]{2,0,1:T(8,128)} parameter(8) + %constant.6186.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.4892.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} broadcast(%constant.6186.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %mul.4890.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} multiply(%param_8.1887, %mul.4892.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %add.3683.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} add(%mul.4891.clone.1, %mul.4890.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1569} + %param_2.5478 = f32[]{:T(128)S(6)} parameter(2) + %div.2914.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} broadcast(%param_2.5478), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1580} + %integer_pow.408.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} multiply(%select_n.4755.clone.1, %select_n.4755.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=1584} + %constant.6181.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.4889.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} broadcast(%constant.6181.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %mul.4887.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} multiply(%integer_pow.408.clone.1, %mul.4889.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %param_4.3397 = f32[18432,3,512]{2,0,1:T(8,128)} parameter(4) + %constant.6180.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.4888.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} broadcast(%constant.6180.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %mul.4886.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} multiply(%param_4.3397, %mul.4888.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %add.3682.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} add(%mul.4887.clone.1, %mul.4886.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1587} + %param_1.5961 = f32[]{:T(128)S(6)} parameter(1) + %div.2913.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} broadcast(%param_1.5961), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %div.2912.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} divide(%add.3682.clone.1, %div.2913.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %sqrt.166.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} sqrt(%div.2912.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=1591} + %constant.6179.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3681.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} broadcast(%constant.6179.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %add.3680.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} add(%sqrt.166.clone.1, %add.3681.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %multiply.1303.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} multiply(%div.2914.clone.1, %add.3680.clone.1), metadata={op_name="multiply.281"} + %div.2911.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} divide(%add.3683.clone.1, %multiply.1303.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1592} + %mul.4884.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} multiply(%param_0.4545, %broadcast.5827.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1594} + %add.3679.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} add(%div.2911.clone.1, %mul.4884.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1595} + %mul.4883.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} multiply(%mul.4885.clone.1, %add.3679.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.3678.clone.1 = f32[18432,3,512]{2,0,1:T(8,128)} add(%param_0.4545, %mul.4883.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1599} + %square.603 = f32[18432,3,512]{2,0,1:T(8,128)} multiply(%add.3678.clone.1, %add.3678.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=1630} + %reduce.698 = f32[]{:T(128)} reduce(%square.603, %constant.6376), dimensions={0,1,2}, to_apply=%region_212.237, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631} + %reduce.701.clone.1 = f32[]{:T(128)} reduce(%integer_pow.408.clone.1, %constant.6376), dimensions={0,1,2}, to_apply=%region_178.203, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607} + ROOT %tuple.684 = (f32[]{:T(128)}, f32[18432,3,512]{2,0,1:T(8,128)}, f32[18432,3,512]{2,0,1:T(8,128)}, f32[18432,3,512]{2,0,1:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.698, %add.3678.clone.1, %add.3682.clone.1, %add.3683.clone.1, %reduce.701.clone.1) +} + +%region_211.236 (reduce_sum.823: f32[], reduce_sum.555: f32[]) -> f32[] { + %reduce_sum.823 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.555 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.556 = f32[]{:T(128)} add(%reduce_sum.823, %reduce_sum.555), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_177.202 (reduce_sum.585: f32[], reduce_sum.401: f32[]) -> f32[] { + %reduce_sum.585 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.401 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.402 = f32[]{:T(128)} add(%reduce_sum.585, %reduce_sum.401), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.616 (param_0.4546: f32[512,3,18432], param_1.5962: f32[], param_2.5479: f32[], param_3.4125: f32[], param_4.3398: f32[512,3,18432], param_5.3063: f32[], param_6.2573: f32[3,512,18432], param_7.2494: pred[], param_8.1888: f32[512,3,18432]) -> (f32[], f32[512,3,18432], f32[512,3,18432], f32[512,3,18432], f32[]) { + %param_0.4546 = f32[512,3,18432]{2,0,1:T(8,128)} parameter(0) + %param_3.4125 = f32[]{:T(128)S(6)} parameter(3) + %mul.4895.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%param_3.4125), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.2494 = pred[]{:T(512)S(6)} parameter(7) + %select_n.4769.clone.1 = pred[512,3,18432]{2,0,1:T(8,128)(4,1)} broadcast(%param_7.2494), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %param_6.2573 = f32[3,512,18432]{2,1,0:T(8,128)} parameter(6) + %bitcast.1394.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} bitcast(%param_6.2573), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=120} + %param_5.3063 = f32[]{:T(128)} parameter(5) + %div.2926.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%param_5.3063), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %div.2925.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} divide(%bitcast.1394.clone.1, %div.2926.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %select_n.4768.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} select(%select_n.4769.clone.1, %bitcast.1394.clone.1, %div.2925.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %ne.1560.clone.1 = pred[512,3,18432]{2,0,1:T(8,128)(4,1)} compare(%select_n.4768.clone.1, %select_n.4768.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=1566} + %constant.6377 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.2612.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6377), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4767.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} select(%ne.1560.clone.1, %broadcast_in_dim.2612.clone.1, %select_n.4768.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6199.clone.1 = f32[]{:T(128)} constant(inf) + %eq.2379.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6199.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2378.clone.1 = pred[512,3,18432]{2,0,1:T(8,128)(4,1)} compare(%select_n.4767.clone.1, %eq.2379.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6198.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.2611.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6198.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4766.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} select(%eq.2378.clone.1, %broadcast_in_dim.2611.clone.1, %select_n.4767.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6197.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.2377.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6197.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2376.clone.1 = pred[512,3,18432]{2,0,1:T(8,128)(4,1)} compare(%select_n.4766.clone.1, %eq.2377.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6196.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.2610.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6196.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4765.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} select(%eq.2376.clone.1, %broadcast_in_dim.2610.clone.1, %select_n.4766.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6192.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.5833.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6192.clone.1), dimensions={}, metadata={op_name="broadcast.344"} + %mul.4899.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%select_n.4765.clone.1, %broadcast.5833.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1569} + %param_8.1888 = f32[512,3,18432]{2,0,1:T(8,128)} parameter(8) + %constant.6200.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.5832.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6200.clone.1), dimensions={}, metadata={op_name="broadcast.343"} + %mul.4898.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%param_8.1888, %broadcast.5832.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %add.3688.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} add(%mul.4899.clone.1, %mul.4898.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1569} + %param_2.5479 = f32[]{:T(128)S(6)} parameter(2) + %div.2922.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%param_2.5479), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1580} + %integer_pow.409.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%select_n.4765.clone.1, %select_n.4765.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=1584} + %constant.6195.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.5831.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6195.clone.1), dimensions={}, metadata={op_name="broadcast.317"} + %mul.4897.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%integer_pow.409.clone.1, %broadcast.5831.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %param_4.3398 = f32[512,3,18432]{2,0,1:T(8,128)} parameter(4) + %constant.6194.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.5830.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6194.clone.1), dimensions={}, metadata={op_name="broadcast.316"} + %mul.4896.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%param_4.3398, %broadcast.5830.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %add.3687.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} add(%mul.4897.clone.1, %mul.4896.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1587} + %param_1.5962 = f32[]{:T(128)S(6)} parameter(1) + %div.2921.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%param_1.5962), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %div.2920.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} divide(%add.3687.clone.1, %div.2921.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %sqrt.167.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} sqrt(%div.2920.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=1591} + %constant.6193.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.5828.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6193.clone.1), dimensions={}, metadata={op_name="broadcast.307"} + %add.3686.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} add(%sqrt.167.clone.1, %broadcast.5828.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %multiply.1304.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%div.2922.clone.1, %add.3686.clone.1), metadata={op_name="multiply.280"} + %div.2919.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} divide(%add.3688.clone.1, %multiply.1304.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1592} + %mul.4894.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%param_0.4546, %broadcast.5833.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1594} + %add.3685.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} add(%div.2919.clone.1, %mul.4894.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1595} + %mul.4893.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%mul.4895.clone.1, %add.3685.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.3684.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} add(%param_0.4546, %mul.4893.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1599} + %square.604 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%add.3684.clone.1, %add.3684.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=1630} + %reduce.699 = f32[]{:T(128)} reduce(%square.604, %constant.6377), dimensions={0,1,2}, to_apply=%region_211.236, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631} + %reduce.702.clone.1 = f32[]{:T(128)} reduce(%integer_pow.409.clone.1, %constant.6377), dimensions={0,1,2}, to_apply=%region_177.202, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607} + ROOT %tuple.685 = (f32[]{:T(128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.699, %add.3684.clone.1, %add.3687.clone.1, %add.3688.clone.1, %reduce.702.clone.1) +} + +%region_210.235 (reduce_sum.816: f32[], reduce_sum.549: f32[]) -> f32[] { + %reduce_sum.816 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.549 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.554 = f32[]{:T(128)} add(%reduce_sum.816, %reduce_sum.549), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_176.201 (reduce_sum.578: f32[], reduce_sum.395: f32[]) -> f32[] { + %reduce_sum.578 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.395 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.400 = f32[]{:T(128)} add(%reduce_sum.578, %reduce_sum.395), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.617 (param_0.4547: f32[512,3,18432], param_1.5963: f32[], param_2.5480: f32[], param_3.4126: f32[], param_4.3399: f32[512,3,18432], param_5.3064: f32[], param_6.2574: f32[3,512,18432], param_7.2495: pred[], param_8.1889: f32[512,3,18432]) -> (f32[], f32[512,3,18432], f32[512,3,18432], f32[512,3,18432], f32[]) { + %param_0.4547 = f32[512,3,18432]{2,0,1:T(8,128)} parameter(0) + %param_3.4126 = f32[]{:T(128)S(6)} parameter(3) + %mul.4902.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%param_3.4126), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.2495 = pred[]{:T(512)S(6)} parameter(7) + %select_n.4779.clone.1 = pred[512,3,18432]{2,0,1:T(8,128)(4,1)} broadcast(%param_7.2495), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %param_6.2574 = f32[3,512,18432]{2,1,0:T(8,128)} parameter(6) + %bitcast.1396.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} bitcast(%param_6.2574), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=120} + %param_5.3064 = f32[]{:T(128)} parameter(5) + %div.2934.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%param_5.3064), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %div.2933.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} divide(%bitcast.1396.clone.1, %div.2934.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %select_n.4778.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} select(%select_n.4779.clone.1, %bitcast.1396.clone.1, %div.2933.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %ne.1562.clone.1 = pred[512,3,18432]{2,0,1:T(8,128)(4,1)} compare(%select_n.4778.clone.1, %select_n.4778.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=1566} + %constant.6378 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.2618.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6378), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4777.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} select(%ne.1562.clone.1, %broadcast_in_dim.2618.clone.1, %select_n.4778.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6213.clone.1 = f32[]{:T(128)} constant(inf) + %eq.2387.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6213.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2386.clone.1 = pred[512,3,18432]{2,0,1:T(8,128)(4,1)} compare(%select_n.4777.clone.1, %eq.2387.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6212.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.2617.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6212.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4776.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} select(%eq.2386.clone.1, %broadcast_in_dim.2617.clone.1, %select_n.4777.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6211.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.2385.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6211.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2384.clone.1 = pred[512,3,18432]{2,0,1:T(8,128)(4,1)} compare(%select_n.4776.clone.1, %eq.2385.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6210.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.2616.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6210.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4775.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} select(%eq.2384.clone.1, %broadcast_in_dim.2616.clone.1, %select_n.4776.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6206.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.5839.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6206.clone.1), dimensions={}, metadata={op_name="broadcast.344"} + %mul.4906.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%select_n.4775.clone.1, %broadcast.5839.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1569} + %param_8.1889 = f32[512,3,18432]{2,0,1:T(8,128)} parameter(8) + %constant.6214.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.5838.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6214.clone.1), dimensions={}, metadata={op_name="broadcast.343"} + %mul.4905.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%param_8.1889, %broadcast.5838.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %add.3693.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} add(%mul.4906.clone.1, %mul.4905.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1569} + %param_2.5480 = f32[]{:T(128)S(6)} parameter(2) + %div.2930.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%param_2.5480), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1580} + %integer_pow.410.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%select_n.4775.clone.1, %select_n.4775.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=1584} + %constant.6209.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.5837.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6209.clone.1), dimensions={}, metadata={op_name="broadcast.317"} + %mul.4904.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%integer_pow.410.clone.1, %broadcast.5837.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %param_4.3399 = f32[512,3,18432]{2,0,1:T(8,128)} parameter(4) + %constant.6208.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.5836.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6208.clone.1), dimensions={}, metadata={op_name="broadcast.316"} + %mul.4903.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%param_4.3399, %broadcast.5836.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %add.3692.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} add(%mul.4904.clone.1, %mul.4903.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1587} + %param_1.5963 = f32[]{:T(128)S(6)} parameter(1) + %div.2929.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%param_1.5963), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %div.2928.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} divide(%add.3692.clone.1, %div.2929.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %sqrt.168.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} sqrt(%div.2928.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=1591} + %constant.6207.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.5834.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} broadcast(%constant.6207.clone.1), dimensions={}, metadata={op_name="broadcast.307"} + %add.3691.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} add(%sqrt.168.clone.1, %broadcast.5834.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %multiply.1305.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%div.2930.clone.1, %add.3691.clone.1), metadata={op_name="multiply.279"} + %div.2927.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} divide(%add.3693.clone.1, %multiply.1305.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1592} + %mul.4901.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%param_0.4547, %broadcast.5839.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1594} + %add.3690.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} add(%div.2927.clone.1, %mul.4901.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1595} + %mul.4900.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%mul.4902.clone.1, %add.3690.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.3689.clone.1 = f32[512,3,18432]{2,0,1:T(8,128)} add(%param_0.4547, %mul.4900.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1599} + %square.605 = f32[512,3,18432]{2,0,1:T(8,128)} multiply(%add.3689.clone.1, %add.3689.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=1630} + %reduce.700 = f32[]{:T(128)} reduce(%square.605, %constant.6378), dimensions={0,1,2}, to_apply=%region_210.235, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631} + %reduce.703.clone.1 = f32[]{:T(128)} reduce(%integer_pow.410.clone.1, %constant.6378), dimensions={0,1,2}, to_apply=%region_176.201, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607} + ROOT %tuple.686 = (f32[]{:T(128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.700, %add.3689.clone.1, %add.3692.clone.1, %add.3693.clone.1, %reduce.703.clone.1) +} + +%region_149.174 (reduce_sum.396: f32[], reduce_sum.225: f32[]) -> f32[] { + %reduce_sum.396 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.225 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.229 = f32[]{:T(128)} add(%reduce_sum.396, %reduce_sum.225), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.631 (param_0.4571: f32[3,128,128,512]) -> f32[] { + %param_0.4571 = f32[3,128,128,512]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.790 = f32[128,3,128,512]{3,2,1,0:T(8,128)} bitcast(%param_0.4571), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=120} + %square.608 = f32[128,3,128,512]{3,2,1,0:T(8,128)} multiply(%bitcast.790, %bitcast.790), metadata={op_name="jit(train_step)/square" stack_frame_id=860} + %constant.6402 = f32[]{:T(128)} constant(0) + ROOT %reduce.704 = f32[]{:T(128)} reduce(%square.608, %constant.6402), dimensions={0,1,2,3}, to_apply=%region_149.174, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862} +} + +%fused_computation.632 (param_0.1791: f32[128,3,128,512]) -> bf16[3,128,128,512] { + %param_0.1791 = f32[128,3,128,512]{3,2,1,0:T(8,128)} parameter(0) + %copy.1559 = bf16[128,3,128,512]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.1791), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'dense_layers\'][\'self_attention\'][\'out\'][\'kernel\']"} + ROOT %bitcast.791 = bf16[3,128,128,512]{3,2,1,0:T(8,128)(2,1)} bitcast(%copy.1559), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=120} +} + +%region_216.241 (reduce_sum.858: f32[], reduce_sum.577: f32[]) -> f32[] { + %reduce_sum.858 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.577 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.582 = f32[]{:T(128)} add(%reduce_sum.858, %reduce_sum.577), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_182.207 (reduce_sum.620: f32[], reduce_sum.423: f32[]) -> f32[] { + %reduce_sum.620 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.423 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.428 = f32[]{:T(128)} add(%reduce_sum.620, %reduce_sum.423), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.633 (param_0.4541: f32[128,3,128,512], param_1.5957: f32[], param_2.5474: f32[], param_3.4120: f32[], param_4.3393: f32[128,3,128,512], param_5.3058: f32[], param_6.2568: f32[3,128,128,512], param_7.2489: pred[], param_8.1883: f32[128,3,128,512]) -> (f32[], f32[128,3,128,512], f32[128,3,128,512], f32[128,3,128,512], f32[]) { + %param_0.4541 = f32[128,3,128,512]{3,2,1,0:T(8,128)} parameter(0) + %param_3.4120 = f32[]{:T(128)S(6)} parameter(3) + %mul.4854.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} broadcast(%param_3.4120), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.2489 = pred[]{:T(512)S(6)} parameter(7) + %select_n.4719.clone.1 = pred[128,3,128,512]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.2489), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %param_6.2568 = f32[3,128,128,512]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.1384.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} bitcast(%param_6.2568), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=120} + %param_5.3058 = f32[]{:T(128)} parameter(5) + %div.2886.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} broadcast(%param_5.3058), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %div.2885.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} divide(%bitcast.1384.clone.1, %div.2886.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %select_n.4718.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} select(%select_n.4719.clone.1, %bitcast.1384.clone.1, %div.2885.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %ne.1550.clone.1 = pred[128,3,128,512]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.4718.clone.1, %select_n.4718.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=1566} + %constant.6372 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.2582.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} broadcast(%constant.6372), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4717.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} select(%ne.1550.clone.1, %broadcast_in_dim.2582.clone.1, %select_n.4718.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6129.clone.1 = f32[]{:T(128)} constant(inf) + %eq.2339.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} broadcast(%constant.6129.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2338.clone.1 = pred[128,3,128,512]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.4717.clone.1, %eq.2339.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6128.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.2581.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} broadcast(%constant.6128.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4716.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} select(%eq.2338.clone.1, %broadcast_in_dim.2581.clone.1, %select_n.4717.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6127.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.2337.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} broadcast(%constant.6127.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2336.clone.1 = pred[128,3,128,512]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.4716.clone.1, %eq.2337.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.6126.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.2580.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} broadcast(%constant.6126.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4715.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} select(%eq.2336.clone.1, %broadcast_in_dim.2580.clone.1, %select_n.4716.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.6122.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.5807.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} broadcast(%constant.6122.clone.1), dimensions={}, metadata={op_name="broadcast.339"} + %mul.4860.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} multiply(%select_n.4715.clone.1, %broadcast.5807.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1569} + %param_8.1883 = f32[128,3,128,512]{3,2,1,0:T(8,128)} parameter(8) + %constant.6130.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.4861.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} broadcast(%constant.6130.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %mul.4859.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} multiply(%param_8.1883, %mul.4861.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %add.3662.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} add(%mul.4860.clone.1, %mul.4859.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1569} + %param_2.5474 = f32[]{:T(128)S(6)} parameter(2) + %div.2882.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} broadcast(%param_2.5474), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1580} + %integer_pow.404.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} multiply(%select_n.4715.clone.1, %select_n.4715.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=1584} + %constant.6125.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.4858.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} broadcast(%constant.6125.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %mul.4856.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} multiply(%integer_pow.404.clone.1, %mul.4858.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %param_4.3393 = f32[128,3,128,512]{3,2,1,0:T(8,128)} parameter(4) + %constant.6124.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.4857.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} broadcast(%constant.6124.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %mul.4855.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} multiply(%param_4.3393, %mul.4857.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %add.3661.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} add(%mul.4856.clone.1, %mul.4855.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1587} + %param_1.5957 = f32[]{:T(128)S(6)} parameter(1) + %div.2881.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} broadcast(%param_1.5957), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %div.2880.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} divide(%add.3661.clone.1, %div.2881.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %sqrt.162.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} sqrt(%div.2880.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=1591} + %constant.6123.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3660.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} broadcast(%constant.6123.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %add.3659.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} add(%sqrt.162.clone.1, %add.3660.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %multiply.1299.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} multiply(%div.2882.clone.1, %add.3659.clone.1), metadata={op_name="multiply.285"} + %div.2879.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} divide(%add.3662.clone.1, %multiply.1299.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1592} + %mul.4853.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} multiply(%param_0.4541, %broadcast.5807.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1594} + %add.3658.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} add(%div.2879.clone.1, %mul.4853.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1595} + %mul.4852.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} multiply(%mul.4854.clone.1, %add.3658.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.3657.clone.1 = f32[128,3,128,512]{3,2,1,0:T(8,128)} add(%param_0.4541, %mul.4852.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1599} + %square.609 = f32[128,3,128,512]{3,2,1,0:T(8,128)} multiply(%add.3657.clone.1, %add.3657.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=1630} + %reduce.705 = f32[]{:T(128)} reduce(%square.609, %constant.6372), dimensions={0,1,2,3}, to_apply=%region_216.241, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631} + %reduce.706.clone.1 = f32[]{:T(128)} reduce(%integer_pow.404.clone.1, %constant.6372), dimensions={0,1,2,3}, to_apply=%region_182.207, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607} + ROOT %tuple.690 = (f32[]{:T(128)}, f32[128,3,128,512]{3,2,1,0:T(8,128)}, f32[128,3,128,512]{3,2,1,0:T(8,128)}, f32[128,3,128,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.705, %add.3657.clone.1, %add.3661.clone.1, %add.3662.clone.1, %reduce.706.clone.1) +} + +%fused_computation.642 (param_0.4596: f32[32]) -> (f32[163840,32], f32[163840,32]) { + %mul.3805 = f32[163840,32]{1,0:T(8,128)} iota(), iota_dimension=0, metadata={op_name="jit(train_step)/dense_layers/mul" stack_frame_id=155} + %param_0.4596 = f32[32]{0:T(128)S(1)} parameter(0) + %mul.3855 = f32[163840,32]{1,0:T(8,128)} broadcast(%param_0.4596), dimensions={1}, metadata={op_name="jit(train_step)/dense_layers/mul" stack_frame_id=155} + %mul.3804 = f32[163840,32]{1,0:T(8,128)} multiply(%mul.3805, %mul.3855), metadata={op_name="jit(train_step)/dense_layers/mul" stack_frame_id=155} + %constant.6433 = f32[]{:T(128)} constant(0) + %convert_element_type.2752 = f32[163840,32]{1,0:T(8,128)} broadcast(%constant.6433), dimensions={}, metadata={op_name="jit(train_step)/dense_layers/convert_element_type" stack_frame_id=78} + %exp.482 = pred[163840,32]{1,0:T(8,128)(4,1)} compare(%mul.3804, %convert_element_type.2752), direction=EQ, metadata={op_name="jit(train_step)/dense_layers/exp" stack_frame_id=175} + %mul.3785 = f32[163840,32]{1,0:T(8,128)} multiply(%mul.3804, %convert_element_type.2752), metadata={op_name="jit(train_step)/dense_layers/mul" stack_frame_id=78} + %exp.475 = f32[163840,32]{1,0:T(8,128)} exponential(%mul.3785), metadata={op_name="jit(train_step)/dense_layers/exp" stack_frame_id=175} + %constant.5437 = f32[]{:T(128)} constant(inf) + %broadcast.5329 = f32[163840,32]{1,0:T(8,128)} broadcast(%constant.5437), dimensions={}, metadata={op_name="broadcast.365"} + %exp.474 = pred[163840,32]{1,0:T(8,128)(4,1)} compare(%exp.475, %broadcast.5329), direction=EQ, metadata={op_name="jit(train_step)/dense_layers/exp" stack_frame_id=175} + %exp.462 = f32[163840,32]{1,0:T(8,128)} sine(%mul.3804), metadata={op_name="jit(train_step)/dense_layers/exp" stack_frame_id=175} + %exp.461 = f32[163840,32]{1,0:T(8,128)} multiply(%exp.475, %exp.462), metadata={op_name="jit(train_step)/dense_layers/exp" stack_frame_id=175} + %exp.460 = f32[163840,32]{1,0:T(8,128)} multiply(%exp.461, %exp.475), metadata={op_name="jit(train_step)/dense_layers/exp" stack_frame_id=175} + %exp.459 = f32[163840,32]{1,0:T(8,128)} select(%exp.474, %exp.460, %exp.461), metadata={op_name="jit(train_step)/dense_layers/exp" stack_frame_id=175} + %exp.455 = f32[163840,32]{1,0:T(8,128)} select(%exp.482, %convert_element_type.2752, %exp.459), metadata={op_name="jit(train_step)/dense_layers/exp" stack_frame_id=175} + %exp.463.clone.1 = f32[163840,32]{1,0:T(8,128)} cosine(%mul.3804), metadata={op_name="jit(train_step)/dense_layers/exp" stack_frame_id=175} + %exp.454.clone.1 = f32[163840,32]{1,0:T(8,128)} multiply(%exp.475, %exp.463.clone.1), metadata={op_name="jit(train_step)/dense_layers/exp" stack_frame_id=175} + %exp.453.clone.1 = f32[163840,32]{1,0:T(8,128)} multiply(%exp.454.clone.1, %exp.475), metadata={op_name="jit(train_step)/dense_layers/exp" stack_frame_id=175} + %exp.452.clone.1 = f32[163840,32]{1,0:T(8,128)} select(%exp.474, %exp.453.clone.1, %exp.454.clone.1), metadata={op_name="jit(train_step)/dense_layers/exp" stack_frame_id=175} + ROOT %tuple.755 = (f32[163840,32]{1,0:T(8,128)}, f32[163840,32]{1,0:T(8,128)}) tuple(%exp.455, %exp.452.clone.1) +} + +%fused_computation.650 (param_0.1853: bf16[1536,128,192]) -> bf16[1536,128,192,1] { + %param_0.1853 = bf16[1536,128,192]{1,2,0:T(8,128)(2,1)} parameter(0) + %copy.1567 = bf16[1536,128,192]{1,0,2:T(8,128)(2,1)} copy(%param_0.1853), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=541} + ROOT %bitcast.805 = bf16[1536,128,192,1]{1,0,2,3:T(8,128)(2,1)} bitcast(%copy.1567), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=541} +} + +%fused_computation.841 (param_0.2654: f32[4,128], param_1.3265: bf16[4,128,1536], param_2.2626: bf16[1536]) -> bf16[4,128,1536,1] { + %param_2.2626 = bf16[1536]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %dot_general.719 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.2626), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=499} + %param_1.3265 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.2863 = f32[4,128,1536]{2,1,0:T(8,128)} convert(%param_1.3265), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=528} + %param_0.2654 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.4131 = f32[4,128,1536]{2,1,0:T(8,128)} broadcast(%param_0.2654), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=537} + %mul.4130 = f32[4,128,1536]{2,1,0:T(8,128)} multiply(%convert_element_type.2863, %mul.4131), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=537} + %convert_element_type.2862 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} convert(%mul.4130), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=538} + %dot_general.711 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.719, %convert_element_type.2862), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=499} + ROOT %bitcast.1074 = bf16[4,128,1536,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.711), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=499} +} + +%fused_computation.649 (param_0.2414: bf16[1536,128,192], param_1.3264: bf16[4,128,1536], param_2.2625: f32[4,128], param_3.1842: bf16[1536]) -> bf16[4,128,128,192] { + %param_2.2625 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_1.3264 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %param_3.1842 = bf16[1536]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %fusion.457 = bf16[4,128,1536,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.2625, %param_1.3264, %param_3.1842), kind=kLoop, calls=%fused_computation.841, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=499} + %param_0.2414 = bf16[1536,128,192]{1,2,0:T(8,128)(2,1)} parameter(0) + %fusion.374 = bf16[1536,128,192,1]{1,0,2,3:T(8,128)(2,1)} fusion(%param_0.2414), kind=kLoop, calls=%fused_computation.650, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=541} + ROOT %convolution.133 = bf16[4,128,128,192]{2,1,3,0:T(8,128)(2,1)S(1)} convolution(%fusion.457, %fusion.374), window={size=1x192 pad=0_0x191_191 rhs_reversal=0x1}, dim_labels=0bf1_io10->0bf1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/dot_general" stack_frame_id=543} +} + +%fused_computation.651 (param_0.1844: f32[1536,1,128,192]) -> bf16[1536,128,192] { + %param_0.1844 = f32[1536,1,128,192]{2,3,1,0:T(8,128)} parameter(0) + %convert_element_type.2743 = bf16[1536,1,128,192]{2,3,1,0:T(8,128)(2,1)} convert(%param_0.1844), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=541} + ROOT %bitcast.806 = bf16[1536,128,192]{1,2,0:T(8,128)(2,1)} bitcast(%convert_element_type.2743), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=541} +} + +%region_170.195 (reduce_sum.543: f32[], reduce_sum.373: f32[]) -> f32[] { + %reduce_sum.543 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.373 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.374 = f32[]{:T(128)} add(%reduce_sum.543, %reduce_sum.373), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.685.clone.clone.clone (param_0.4495: bf16[4,128,128,128], param_1.5914: bf16[4,128,128,128]) -> bf16[4,128,128,256] { + %param_1.5914 = bf16[4,128,128,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.6323 = bf16[]{:T(256)} constant(-inf) + %pad.441 = bf16[4,128,128,256]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.5914, %constant.6323), padding=0_0x0_0x0_0x0_128, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/concatenate" stack_frame_id=1039} + %param_0.4495 = bf16[4,128,128,128]{3,1,2,0:T(8,128)(2,1)} parameter(0) + %pad.440 = bf16[4,128,128,256]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.4495, %constant.6323), padding=0_0x0_0x0_0x128_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/concatenate" stack_frame_id=1039} + ROOT %maximum.99 = bf16[4,128,128,256]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.441, %pad.440), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/concatenate" stack_frame_id=1039} +} + +%fused_computation.976.clone.clone (param_0.4496: f32[4,128], param_1.5915: bf16[4,128,576], param_2.5421: bf16[512]) -> bf16[4,128,512,1] { + %param_2.5421 = bf16[512]{0:T(512)(128)(2,1)S(1)} parameter(2) + %dot_general.840 = bf16[4,128,512]{1,2,0:T(8,128)(2,1)} broadcast(%param_2.5421), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=584} + %param_1.5915 = bf16[4,128,576]{1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %split.456 = bf16[4,128,512]{1,2,0:T(8,128)(2,1)} slice(%param_1.5915), slice={[0:4], [0:128], [0:512]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/split" stack_frame_id=590} + %convert_element_type.3225 = f32[4,128,512]{1,2,0:T(8,128)} convert(%split.456), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=591} + %param_0.4496 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.5012 = f32[4,128,512]{1,2,0:T(8,128)} broadcast(%param_0.4496), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=600} + %mul.5011 = f32[4,128,512]{1,2,0:T(8,128)} multiply(%convert_element_type.3225, %mul.5012), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=600} + %convert_element_type.3224 = bf16[4,128,512]{1,2,0:T(8,128)(2,1)} convert(%mul.5011), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=601} + %dot_general.839 = bf16[4,128,512]{1,2,0:T(8,128)(2,1)} multiply(%dot_general.840, %convert_element_type.3224), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=584} + ROOT %bitcast.1462 = bf16[4,128,512,1]{1,2,3,0:T(8,128)(2,1)} bitcast(%dot_general.839), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=584} +} + +%fused_computation.655 (param_0.4552: bf16[4,128,128,128], param_1.5967: bf16[4,128,128,128], param_2.5484: f32[4,128], param_3.4130: bf16[4,128,576], param_4.3402: bf16[512]) -> (f32[], bf16[512,128,256,1]) { + %param_2.5484 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.4130 = bf16[4,128,576]{1,2,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.3402 = bf16[512]{0:T(512)(128)(2,1)S(1)} parameter(4) + %fusion.560.clone.1 = bf16[4,128,512,1]{1,2,3,0:T(8,128)(2,1)} fusion(%param_2.5484, %param_3.4130, %param_4.3402), kind=kLoop, calls=%fused_computation.976.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=584} + %param_0.4552 = bf16[4,128,128,128]{3,1,2,0:T(8,128)(2,1)} parameter(0) + %param_1.5967 = bf16[4,128,128,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %pad_maximum_fusion.19.clone.1 = bf16[4,128,128,256]{3,1,2,0:T(8,128)(2,1)} fusion(%param_0.4552, %param_1.5967), kind=kLoop, calls=%fused_computation.685.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/concatenate" stack_frame_id=1039} + %convolution.136.clone.1 = bf16[512,128,256,1]{2,0,1,3:T(8,128)(2,1)} convolution(%fusion.560.clone.1, %pad_maximum_fusion.19.clone.1), window={size=128x4 pad=127_127x0_0 rhs_reversal=1x0}, dim_labels=1fb0_1i0o->b0f1, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=1038} + %bitcast.878 = bf16[512,128,256]{2,0,1:T(8,128)(2,1)} bitcast(%convolution.136.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=1038} + %broadcast_in_dim.1796 = f32[512,128,256]{2,0,1:T(8,128)} convert(%bitcast.878), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=480} + %bitcast.810 = f32[512,1,128,256]{3,0,2,1:T(8,128)} bitcast(%broadcast_in_dim.1796), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=427} + %square.612 = f32[512,1,128,256]{3,0,2,1:T(8,128)} multiply(%bitcast.810, %bitcast.810), metadata={op_name="jit(train_step)/square" stack_frame_id=860} + %constant.6383 = f32[]{:T(128)} constant(0) + %reduce.707 = f32[]{:T(128)} reduce(%square.612, %constant.6383), dimensions={0,1,2,3}, to_apply=%region_170.195, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=862} + ROOT %tuple.775 = (f32[]{:T(128)}, bf16[512,128,256,1]{2,0,1,3:T(8,128)(2,1)}) tuple(%reduce.707, %convolution.136.clone.1) +} + +%region_237.262 (reduce_sum.1005: f32[], reduce_sum.675: f32[]) -> f32[] { + %reduce_sum.1005 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.675 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.680 = f32[]{:T(128)} add(%reduce_sum.1005, %reduce_sum.675), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_203.228 (reduce_sum.767: f32[], reduce_sum.519: f32[]) -> f32[] { + %reduce_sum.767 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.519 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.520 = f32[]{:T(128)} add(%reduce_sum.767, %reduce_sum.519), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.660 (param_0.4520: f32[], param_1.5936: f32[], param_2.5453: f32[], param_3.4099: f32[512,1,128,256], param_4.3372: f32[512,1,128,256], param_5.3037: f32[], param_6.2547: bf16[512,128,256,1], param_7.2468: pred[], param_8.1862: f32[512,1,128,256]) -> (f32[], f32[512,1,128,256], f32[512,1,128,256], f32[512,1,128,256], f32[]) { + %param_3.4099 = f32[512,1,128,256]{3,2,1,0:T(8,128)S(1)} parameter(3) + %copy.1679.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} copy(%param_3.4099), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'self_attention\'][\'wkv_b\'][\'kernel\']"} + %param_2.5453 = f32[]{:T(128)S(6)} parameter(2) + %mul.4671.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} broadcast(%param_2.5453), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.2468 = pred[]{:T(512)S(6)} parameter(7) + %select_n.4549.clone.1 = pred[512,1,128,256]{3,0,2,1:T(8,128)(4,1)} broadcast(%param_7.2468), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %param_6.2547 = bf16[512,128,256,1]{2,0,1,3:T(8,128)(2,1)} parameter(6) + %bitcast.1345.clone.1 = bf16[512,128,256]{2,0,1:T(8,128)(2,1)} bitcast(%param_6.2547), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=1038} + %broadcast_in_dim.2465.clone.1 = f32[512,128,256]{2,0,1:T(8,128)} convert(%bitcast.1345.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=480} + %bitcast.1344.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} bitcast(%broadcast_in_dim.2465.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=427} + %param_5.3037 = f32[]{:T(128)} parameter(5) + %div.2734.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} broadcast(%param_5.3037), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %div.2733.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} divide(%bitcast.1344.clone.1, %div.2734.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1564} + %select_n.4548.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} select(%select_n.4549.clone.1, %bitcast.1344.clone.1, %div.2733.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=1563} + %ne.1516.clone.1 = pred[512,1,128,256]{3,0,2,1:T(8,128)(4,1)} compare(%select_n.4548.clone.1, %select_n.4548.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=1566} + %constant.6351 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.2468.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} broadcast(%constant.6351), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4547.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} select(%ne.1516.clone.1, %broadcast_in_dim.2468.clone.1, %select_n.4548.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.5871.clone.1 = f32[]{:T(128)} constant(inf) + %eq.2203.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} broadcast(%constant.5871.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2201.clone.1 = pred[512,1,128,256]{3,0,2,1:T(8,128)(4,1)} compare(%select_n.4547.clone.1, %eq.2203.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.5870.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.2467.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} broadcast(%constant.5870.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4546.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} select(%eq.2201.clone.1, %broadcast_in_dim.2467.clone.1, %select_n.4547.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.5869.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.2202.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} broadcast(%constant.5869.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %eq.2200.clone.1 = pred[512,1,128,256]{3,0,2,1:T(8,128)(4,1)} compare(%select_n.4546.clone.1, %eq.2202.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=1566} + %constant.5868.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.2466.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} broadcast(%constant.5868.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=1566} + %select_n.4545.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} select(%eq.2200.clone.1, %broadcast_in_dim.2466.clone.1, %select_n.4546.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=1566} + %constant.5864.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.5733.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} broadcast(%constant.5864.clone.1), dimensions={}, metadata={op_name="broadcast.3667"} + %mul.4677.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} multiply(%select_n.4545.clone.1, %broadcast.5733.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1569} + %param_8.1862 = f32[512,1,128,256]{3,2,1,0:T(8,128)} parameter(8) + %copy.1681.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} copy(%param_8.1862), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'self_attention\'][\'wkv_b\'][\'kernel\']"} + %constant.5872.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.4678.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} broadcast(%constant.5872.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %mul.4676.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} multiply(%copy.1681.clone.1, %mul.4678.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1570} + %add.3545.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} add(%mul.4677.clone.1, %mul.4676.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1569} + %param_1.5936 = f32[]{:T(128)S(6)} parameter(1) + %div.2730.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} broadcast(%param_1.5936), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1580} + %integer_pow.383.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} multiply(%select_n.4545.clone.1, %select_n.4545.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=1584} + %constant.5867.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.4675.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} broadcast(%constant.5867.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %mul.4673.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} multiply(%integer_pow.383.clone.1, %mul.4675.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1587} + %param_4.3372 = f32[512,1,128,256]{3,2,1,0:T(8,128)} parameter(4) + %copy.1680.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} copy(%param_4.3372), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'self_attention\'][\'wkv_b\'][\'kernel\']"} + %constant.5866.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.4674.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} broadcast(%constant.5866.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %mul.4672.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} multiply(%copy.1680.clone.1, %mul.4674.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1588} + %add.3544.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} add(%mul.4673.clone.1, %mul.4672.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1587} + %param_0.4520 = f32[]{:T(128)S(6)} parameter(0) + %div.2729.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} broadcast(%param_0.4520), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %div.2728.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} divide(%add.3544.clone.1, %div.2729.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1589} + %sqrt.141.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} sqrt(%div.2728.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=1591} + %constant.5865.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3543.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} broadcast(%constant.5865.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %add.3542.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} add(%sqrt.141.clone.1, %add.3543.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1591} + %multiply.1278.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} multiply(%div.2730.clone.1, %add.3542.clone.1), metadata={op_name="multiply.306"} + %div.2727.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} divide(%add.3545.clone.1, %multiply.1278.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=1592} + %mul.4670.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} multiply(%copy.1679.clone.1, %broadcast.5733.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=1594} + %add.3541.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} add(%div.2727.clone.1, %mul.4670.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1595} + %mul.4669.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} multiply(%mul.4671.clone.1, %add.3541.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.3540.clone.1 = f32[512,1,128,256]{3,0,2,1:T(8,128)} add(%copy.1679.clone.1, %mul.4669.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=1599} + %square.613 = f32[512,1,128,256]{3,0,2,1:T(8,128)} multiply(%add.3540.clone.1, %add.3540.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=1630} + %reduce.708 = f32[]{:T(128)} reduce(%square.613, %constant.6351), dimensions={0,1,2,3}, to_apply=%region_237.262, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1631} + %reduce.709.clone.1 = f32[]{:T(128)} reduce(%integer_pow.383.clone.1, %constant.6351), dimensions={0,1,2,3}, to_apply=%region_203.228, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=1607} + ROOT %tuple.697 = (f32[]{:T(128)}, f32[512,1,128,256]{3,0,2,1:T(8,128)}, f32[512,1,128,256]{3,0,2,1:T(8,128)}, f32[512,1,128,256]{3,0,2,1:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.708, %add.3540.clone.1, %add.3544.clone.1, %add.3545.clone.1, %reduce.709.clone.1) +} + +%fused_computation.677 (param_0.1989: bf16[4,128,128,128], param_1.5311: bf16[4,128,128,64], param_2.5340: bf16[4,128,128,64]) -> (bf16[4,128,128,192], bf16[4,128,128,192]) { + %param_0.1989 = bf16[4,128,128,128]{3,1,2,0:T(8,128)(2,1)} parameter(0) + %constant.5349 = bf16[]{:T(256)} constant(-inf) + %pad.295 = bf16[4,128,128,192]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1989, %constant.5349), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/concatenate" stack_frame_id=633} + %param_1.5311 = bf16[4,128,128,64]{3,1,2,0:T(8,128)(2,1)} parameter(1) + %pad.313 = bf16[4,128,128,192]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.5311, %constant.5349), padding=0_0x0_0x0_0x128_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/concatenate" stack_frame_id=1056} + %maximum.53 = bf16[4,128,128,192]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.295, %pad.313), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/concatenate" stack_frame_id=1056} + %bitcast.872 = bf16[4,128,128,192]{3,2,1,0:T(8,128)(2,1)} bitcast(%maximum.53), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/transpose" stack_frame_id=1060} + %param_2.5340 = bf16[4,128,128,64]{3,1,2,0:T(8,128)(2,1)} parameter(2) + %pad.310.clone.1 = bf16[4,128,128,192]{3,1,2,0:T(8,128)(2,1)} pad(%param_2.5340, %constant.5349), padding=0_0x0_0x0_0x128_0, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/concatenate" stack_frame_id=633} + %maximum.51.clone.1 = bf16[4,128,128,192]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.295, %pad.310.clone.1), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/concatenate" stack_frame_id=633} + %bitcast.871.clone.1 = bf16[4,128,128,192]{3,2,1,0:T(8,128)(2,1)} bitcast(%maximum.51.clone.1), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/transpose" stack_frame_id=637} + ROOT %tuple.703 = (bf16[4,128,128,192]{3,2,1,0:T(8,128)(2,1)}, bf16[4,128,128,192]{3,2,1,0:T(8,128)(2,1)}) tuple(%bitcast.872, %bitcast.871.clone.1) +} + +%fused_computation.683 (param_0.2150: bf16[4,128,128,128], param_1.5316: bf16[4,128,128,32], param_2.4576: bf16[4,128,128,32], param_3.4018: bf16[4,128,128,32], param_4.3282: bf16[4,128,128,32]) -> (bf16[4,128,128,192], bf16[4,128,128,192]) { + %param_0.2150 = bf16[4,128,128,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %constant.5357 = bf16[]{:T(256)} constant(-inf) + %pad.307 = bf16[4,128,128,192]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.2150, %constant.5357), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/concatenate" stack_frame_id=561} diff --git a/tests/utils/reference_hlo_llama3_8b.txt b/tests/utils/reference_hlo_llama3_8b.txt new file mode 100644 index 0000000000..32cc3f89e7 --- /dev/null +++ b/tests/utils/reference_hlo_llama3_8b.txt @@ -0,0 +1,2000 @@ +HloModule jit_train_step, is_scheduled=true, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias), {4}: (4, {}, may-alias), {5}: (5, {}, may-alias), {6}: (6, {}, may-alias), {7}: (7, {}, may-alias), {8}: (8, {}, may-alias), {9}: (9, {}, may-alias), {10}: (10, {}, may-alias), {11}: (11, {}, may-alias), {12}: (12, {}, may-alias), {13}: (13, {}, may-alias), {14}: (14, {}, may-alias), {15}: (15, {}, may-alias), {16}: (16, {}, may-alias), {17}: (17, {}, may-alias), {18}: (18, {}, may-alias), {19}: (19, {}, may-alias), {20}: (20, {}, may-alias), {21}: (21, {}, may-alias), {22}: (22, {}, may-alias), {23}: (23, {}, may-alias), {24}: (24, {}, may-alias), {25}: (25, {}, may-alias), {26}: (26, {}, may-alias), {27}: (27, {}, may-alias), {28}: (28, {}, may-alias), {29}: (29, {}, may-alias), {30}: (30, {}, may-alias), {31}: (31, {}, may-alias), {32}: (32, {}, may-alias), {33}: (33, {}, may-alias), {34}: (34, {}, may-alias), {35}: (35, {}, may-alias), {36}: (36, {}, may-alias), {37}: (37, {}, may-alias), {38}: (38, {}, may-alias) }, entry_computation_layout={(s32[]{:T(128)}, f32[4096]{0:T(1024)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, /*index=5*/f32[4096,4]{0,1:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, /*index=10*/f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, s32[]{:T(128)}, f32[4096]{0:T(1024)}, /*index=15*/f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, /*index=20*/f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, /*index=25*/f32[128256,4096]{1,0:T(8,128)}, f32[4096]{0:T(1024)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, /*index=30*/f32[4096,4]{0,1:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, /*index=35*/f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, s32[]{:T(128)}, s32[4,128]{1,0:T(4,128)}, /*index=40*/s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)})->(s32[]{:T(128)}, f32[4096]{0:T(1024)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, /*index=5*/f32[4096,4]{0,1:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, /*index=10*/f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, s32[]{:T(128)}, f32[4096]{0:T(1024)}, /*index=15*/f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, /*index=20*/f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, /*index=25*/f32[128256,4096]{1,0:T(8,128)}, f32[4096]{0:T(1024)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, /*index=30*/f32[4096,4]{0,1:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, /*index=35*/f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, s32[]{:T(128)}, f32[]{:T(128)}, /*index=40*/f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, /*index=45*/f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, s32[]{:T(128)}, f32[]{:T(128)})}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false}, allow_spmd_sharding_propagation_to_output={false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,true,true,true,true,true,true,true,true,true,true,true}, num_partitions=4 + +FileNames + +FunctionNames + +FileLocations + +StackFrames + + +%fused_computation (param_0.2: bf16[128256,4096], param_1.7: s32[1024]) -> bf16[512,4096] { + %param_0.2 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(0) + %param_1.7 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.1 = s32[1024]{0:T(1024)} custom-call(%param_1.7), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + %slice.6 = s32[512]{0:T(512)} slice(%custom-call.1), slice={[0:512]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + %reshape.342 = s32[4,128]{1,0:T(4,128)} reshape(%slice.6), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=94} + %transpose.326 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.342), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=94} + %gather.4 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} gather(%param_0.2, %transpose.326), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,4096}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + %transpose.325 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} transpose(%gather.4), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + ROOT %reshape.341 = bf16[512,4096]{1,0:T(8,128)(2,1)} reshape(%transpose.325), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} +} + +%region_33.38.clone (scatter-add.6: bf16[], scatter-add.7: bf16[]) -> bf16[] { + %scatter-add.6 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} + %scatter-add.7 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} + ROOT %add.476 = bf16[]{:T(256)} add(%scatter-add.6, %scatter-add.7), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=580}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.1 (param_0.3: bf16[128256,4096], param_1.5: s32[512], param_2.4: bf16[512,4096]) -> bf16[128256,4096] { + %param_0.3 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(0) + %param_1.5 = s32[512]{0:T(512)S(1)} parameter(1) + %reshape.349 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.5), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=94} + %transpose.331 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.349), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=94} + %param_2.4 = bf16[512,4096]{1,0:T(8,128)(2,1)S(1)} parameter(2) + %reshape.350 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} reshape(%param_2.4), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=106} + %transpose.332 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} transpose(%reshape.350), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=106} + ROOT %scatter.2 = bf16[128256,4096]{1,0:T(8,128)(2,1)} scatter(%param_0.3, %transpose.331, %transpose.332), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_33.38.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=94} +} + +%region_32.37 (reduce_sum.190: f32[], reduce_sum.191: f32[]) -> f32[] { + %reduce_sum.190 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.191 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.192 = f32[]{:T(128)} add(%reduce_sum.190, %reduce_sum.191), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.301.clone.clone.clone (param_0.1252: bf16[4,128,128256], param_1.1645: s32[4,128], param_2.1587: f32[4,128], param_3.1245: f32[4,128], param_4.928: bf16[4,128], param_5.780: f32[4,128]) -> bf16[4,128,128256] { + %param_5.780 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %mul.1613 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_5.780), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + %param_3.1245 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.1612 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.1245), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + %param_0.1252 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1062 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1252), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=349} + %param_4.928 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.94 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_4.928), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=360} + %sub.93 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.1062, %sub.94), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=360} + %exp.62 = f32[4,128,128256]{2,1,0:T(8,128)} exponential(%sub.93), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=361} + %mul.1611 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1612, %exp.62), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + %param_2.1587 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.895 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1587), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=80} + %div.894 = f32[4,128,128256]{2,1,0:T(8,128)} divide(%mul.1611, %div.895), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=80} + %param_1.1645 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.649 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1645), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=365} + %eq.648 = s32[4,128,128256]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=365} + %eq.647 = pred[4,128,128256]{2,1,0:T(8,128)(4,1)} compare(%eq.649, %eq.648), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=365} + %convert_element_type.1061 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%eq.647), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=365} + %sub.92 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%div.894, %convert_element_type.1061), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=80} + %mul.1610 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1613, %sub.92), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + ROOT %convert_element_type.1060 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} convert(%mul.1610), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=349} +} + +%fused_computation.343.clone.clone (param_0.1253: f32[4,128], param_1.1646: bf16[4,128,4096], param_2.1589: bf16[4096]) -> bf16[4,128,4096] { + %param_2.1589 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %dot_general.387 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1589), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + %param_1.1646 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1064 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1646), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=328} + %param_0.1253 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.1615 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1253), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=333} + %mul.1614 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1064, %mul.1615), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=333} + %convert_element_type.1063 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1614), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=334} + ROOT %dot_general.386 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.387, %convert_element_type.1063), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} +} + +%fused_computation.234 (param_0.1272: bf16[4,128,128256], param_1.1661: s32[4,128], param_2.1613: f32[4,128], param_3.1261: f32[4,128], param_4.943: bf16[4,128], param_5.795: f32[4,128], param_6.641: f32[4,128], param_7.620: bf16[4,128,4096], param_8.398: bf16[4096]) -> (f32[], bf16[4096,128256,1]) { + %param_6.641 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %param_7.620 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(7) + %param_8.398 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(8) + %fusion.239.clone.1 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_6.641, %param_7.620, %param_8.398), kind=kLoop, calls=%fused_computation.343.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + %param_0.1272 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1661 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1613 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.1261 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %param_4.943 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %param_5.795 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %multiply_convert_fusion.1.clone.1 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} fusion(%param_0.1272, %param_1.1661, %param_2.1613, %param_3.1261, %param_4.943, /*index=5*/%param_5.795), kind=kLoop, calls=%fused_computation.301.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=349} + %convolution.88.clone.1 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} convolution(%fusion.239.clone.1, %multiply_convert_fusion.1.clone.1), window={size=4}, dim_labels=0fb_0io->bf0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=347} + %bitcast.306 = bf16[4096,128256]{1,0:T(8,128)(2,1)} bitcast(%convolution.88.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=347} + %convert_element_type.923 = f32[4096,128256]{1,0:T(8,128)} convert(%bitcast.306), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=337} + %square.157 = f32[4096,128256]{1,0:T(8,128)} multiply(%convert_element_type.923, %convert_element_type.923), metadata={op_name="jit(train_step)/square" stack_frame_id=375} + %constant.1454 = f32[]{:T(128)} constant(0) + %reduce.118 = f32[]{:T(128)} reduce(%square.157, %constant.1454), dimensions={0,1}, to_apply=%region_32.37, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377} + ROOT %tuple.154 = (f32[]{:T(128)}, bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.118, %convolution.88.clone.1) +} + +%region_34.39 (reduce_sum.196: f32[], reduce_sum.197: f32[]) -> f32[] { + %reduce_sum.196 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.197 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.198 = f32[]{:T(128)} add(%reduce_sum.196, %reduce_sum.197), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.235 (param_0.1271: bf16[128256,4096]) -> f32[] { + %param_0.1271 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.925 = f32[128256,4096]{1,0:T(8,128)} convert(%param_0.1271), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=93} + %square.159 = f32[128256,4096]{1,0:T(8,128)} multiply(%convert_element_type.925, %convert_element_type.925), metadata={op_name="jit(train_step)/square" stack_frame_id=375} + %constant.1453 = f32[]{:T(128)} constant(0) + ROOT %reduce.119 = f32[]{:T(128)} reduce(%square.159, %constant.1453), dimensions={0,1}, to_apply=%region_34.39, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377} +} + +%region_60.65 (reduce_sum.338: f32[], reduce_sum.339: f32[]) -> f32[] { + %reduce_sum.338 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.339 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.329 = f32[]{:T(128)} add(%reduce_sum.338, %reduce_sum.339), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_46.51 (reduce_sum.259: f32[], reduce_sum.260: f32[]) -> f32[] { + %reduce_sum.259 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.260 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.261 = f32[]{:T(128)} add(%reduce_sum.259, %reduce_sum.260), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.236 (param_0.1259: f32[128256,4096], param_1.1649: f32[], param_2.1601: f32[], param_3.1249: f32[], param_4.931: f32[128256,4096], param_5.783: f32[], param_6.629: bf16[128256,4096], param_7.608: pred[], param_8.386: f32[128256,4096]) -> (f32[], f32[128256,4096], f32[128256,4096], f32[128256,4096], f32[]) { + %param_0.1259 = f32[128256,4096]{1,0:T(8,128)} parameter(0) + %param_3.1249 = f32[]{:T(128)S(6)} parameter(3) + %mul.1482.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_3.1249), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.608 = pred[]{:T(512)S(6)} parameter(7) + %select_n.788.clone.1 = pred[128256,4096]{1,0:T(8,128)(4,1)} broadcast(%param_7.608), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %param_6.629 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(6) + %convert_element_type.1035.clone.1 = f32[128256,4096]{1,0:T(8,128)} convert(%param_6.629), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=93} + %param_5.783 = f32[]{:T(128)} parameter(5) + %div.797.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_5.783), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %div.796.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%convert_element_type.1035.clone.1, %div.797.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %select_n.787.clone.1 = f32[128256,4096]{1,0:T(8,128)} select(%select_n.788.clone.1, %convert_element_type.1035.clone.1, %div.796.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %ne.184.clone.1 = pred[128256,4096]{1,0:T(8,128)(4,1)} compare(%select_n.787.clone.1, %select_n.787.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=592} + %constant.1441 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.620.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.1441), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.786.clone.1 = f32[128256,4096]{1,0:T(8,128)} select(%ne.184.clone.1, %broadcast_in_dim.620.clone.1, %select_n.787.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1266.clone.1 = f32[]{:T(128)} constant(inf) + %eq.549.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.1266.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.547.clone.1 = pred[128256,4096]{1,0:T(8,128)(4,1)} compare(%select_n.786.clone.1, %eq.549.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1265.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.619.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.1265.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.785.clone.1 = f32[128256,4096]{1,0:T(8,128)} select(%eq.547.clone.1, %broadcast_in_dim.619.clone.1, %select_n.786.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1264.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.548.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.1264.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.546.clone.1 = pred[128256,4096]{1,0:T(8,128)(4,1)} compare(%select_n.785.clone.1, %eq.548.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1263.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.618.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.1263.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.784.clone.1 = f32[128256,4096]{1,0:T(8,128)} select(%eq.546.clone.1, %broadcast_in_dim.618.clone.1, %select_n.785.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1259.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.554.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.1259.clone.1), dimensions={}, metadata={op_name="broadcast.61"} + %mul.1488.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%select_n.784.clone.1, %broadcast.554.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=595} + %param_8.386 = f32[128256,4096]{1,0:T(8,128)} parameter(8) + %constant.1267.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1489.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.1267.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %mul.1487.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_8.386, %mul.1489.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %add.776.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%mul.1488.clone.1, %mul.1487.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=595} + %param_2.1601 = f32[]{:T(128)S(6)} parameter(2) + %div.793.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_2.1601), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=606} + %integer_pow.60.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%select_n.784.clone.1, %select_n.784.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=610} + %constant.1262.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1486.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.1262.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %mul.1484.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%integer_pow.60.clone.1, %mul.1486.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %param_4.931 = f32[128256,4096]{1,0:T(8,128)} parameter(4) + %constant.1261.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1485.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.1261.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %mul.1483.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_4.931, %mul.1485.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %add.775.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%mul.1484.clone.1, %mul.1483.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=613} + %param_1.1649 = f32[]{:T(128)S(6)} parameter(1) + %div.792.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_1.1649), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %div.791.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%add.775.clone.1, %div.792.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %sqrt.58.clone.1 = f32[128256,4096]{1,0:T(8,128)} sqrt(%div.791.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=617} + %constant.1260.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.774.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.1260.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %add.773.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%sqrt.58.clone.1, %add.774.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %multiply.256.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%div.793.clone.1, %add.773.clone.1), metadata={op_name="multiply.42"} + %div.790.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%add.776.clone.1, %multiply.256.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=618} + %mul.1481.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_0.1259, %broadcast.554.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=620} + %add.772.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%div.790.clone.1, %mul.1481.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=621} + %mul.1480.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%mul.1482.clone.1, %add.772.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.771.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%param_0.1259, %mul.1480.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=625} + %square.160 = f32[128256,4096]{1,0:T(8,128)} multiply(%add.771.clone.1, %add.771.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=656} + %reduce.120 = f32[]{:T(128)} reduce(%square.160, %constant.1441), dimensions={0,1}, to_apply=%region_60.65, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657} + %reduce.122.clone.1 = f32[]{:T(128)} reduce(%integer_pow.60.clone.1, %constant.1441), dimensions={0,1}, to_apply=%region_46.51, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633} + ROOT %tuple.135 = (f32[]{:T(128)}, f32[128256,4096]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.120, %add.771.clone.1, %add.775.clone.1, %add.776.clone.1, %reduce.122.clone.1) +} + +%region_59.64 (reduce_sum.331: f32[], reduce_sum.332: f32[]) -> f32[] { + %reduce_sum.331 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.332 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.323 = f32[]{:T(128)} add(%reduce_sum.331, %reduce_sum.332), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_45.50 (reduce_sum.253: f32[], reduce_sum.254: f32[]) -> f32[] { + %reduce_sum.253 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.254 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.255 = f32[]{:T(128)} add(%reduce_sum.253, %reduce_sum.254), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.237 (param_0.1260: f32[4096,128256], param_1.1650: f32[], param_2.1602: f32[], param_3.1250: f32[], param_4.932: f32[4096,128256], param_5.784: f32[], param_6.630: bf16[4096,128256,1], param_7.609: pred[], param_8.387: f32[4096,128256]) -> (f32[], f32[4096,128256], f32[4096,128256], f32[4096,128256], f32[]) { + %param_0.1260 = f32[4096,128256]{1,0:T(8,128)} parameter(0) + %param_3.1250 = f32[]{:T(128)S(6)} parameter(3) + %mul.1492.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_3.1250), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.609 = pred[]{:T(512)S(6)} parameter(7) + %select_n.798.clone.1 = pred[4096,128256]{1,0:T(8,128)(4,1)} broadcast(%param_7.609), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %param_6.630 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} parameter(6) + %bitcast.409.clone.1 = bf16[4096,128256]{1,0:T(8,128)(2,1)} bitcast(%param_6.630), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=347} + %convert_element_type.1037.clone.1 = f32[4096,128256]{1,0:T(8,128)} convert(%bitcast.409.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=337} + %param_5.784 = f32[]{:T(128)} parameter(5) + %div.805.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_5.784), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %div.804.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%convert_element_type.1037.clone.1, %div.805.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %select_n.797.clone.1 = f32[4096,128256]{1,0:T(8,128)} select(%select_n.798.clone.1, %convert_element_type.1037.clone.1, %div.804.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %ne.186.clone.1 = pred[4096,128256]{1,0:T(8,128)(4,1)} compare(%select_n.797.clone.1, %select_n.797.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=592} + %constant.1442 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.626.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.1442), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.796.clone.1 = f32[4096,128256]{1,0:T(8,128)} select(%ne.186.clone.1, %broadcast_in_dim.626.clone.1, %select_n.797.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1280.clone.1 = f32[]{:T(128)} constant(inf) + %eq.557.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.1280.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.555.clone.1 = pred[4096,128256]{1,0:T(8,128)(4,1)} compare(%select_n.796.clone.1, %eq.557.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1279.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.625.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.1279.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.795.clone.1 = f32[4096,128256]{1,0:T(8,128)} select(%eq.555.clone.1, %broadcast_in_dim.625.clone.1, %select_n.796.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1278.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.556.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.1278.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.554.clone.1 = pred[4096,128256]{1,0:T(8,128)(4,1)} compare(%select_n.795.clone.1, %eq.556.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1277.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.624.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.1277.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.794.clone.1 = f32[4096,128256]{1,0:T(8,128)} select(%eq.554.clone.1, %broadcast_in_dim.624.clone.1, %select_n.795.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1273.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.556.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.1273.clone.1), dimensions={}, metadata={op_name="broadcast.62"} + %mul.1498.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%select_n.794.clone.1, %broadcast.556.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=595} + %param_8.387 = f32[4096,128256]{1,0:T(8,128)} parameter(8) + %constant.1281.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1499.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.1281.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %mul.1497.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_8.387, %mul.1499.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %add.782.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%mul.1498.clone.1, %mul.1497.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=595} + %param_2.1602 = f32[]{:T(128)S(6)} parameter(2) + %div.801.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_2.1602), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=606} + %integer_pow.61.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%select_n.794.clone.1, %select_n.794.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=610} + %constant.1276.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1496.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.1276.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %mul.1494.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%integer_pow.61.clone.1, %mul.1496.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %param_4.932 = f32[4096,128256]{1,0:T(8,128)} parameter(4) + %constant.1275.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1495.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.1275.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %mul.1493.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_4.932, %mul.1495.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %add.781.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%mul.1494.clone.1, %mul.1493.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=613} + %param_1.1650 = f32[]{:T(128)S(6)} parameter(1) + %div.800.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_1.1650), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %div.799.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%add.781.clone.1, %div.800.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %sqrt.59.clone.1 = f32[4096,128256]{1,0:T(8,128)} sqrt(%div.799.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=617} + %constant.1274.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.780.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.1274.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %add.779.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%sqrt.59.clone.1, %add.780.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %multiply.257.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%div.801.clone.1, %add.779.clone.1), metadata={op_name="multiply.41"} + %div.798.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%add.782.clone.1, %multiply.257.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=618} + %mul.1491.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_0.1260, %broadcast.556.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=620} + %add.778.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%div.798.clone.1, %mul.1491.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=621} + %mul.1490.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%mul.1492.clone.1, %add.778.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.777.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%param_0.1260, %mul.1490.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=625} + %square.161 = f32[4096,128256]{1,0:T(8,128)} multiply(%add.777.clone.1, %add.777.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=656} + %reduce.121 = f32[]{:T(128)} reduce(%square.161, %constant.1442), dimensions={0,1}, to_apply=%region_59.64, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657} + %reduce.123.clone.1 = f32[]{:T(128)} reduce(%integer_pow.61.clone.1, %constant.1442), dimensions={0,1}, to_apply=%region_45.50, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633} + ROOT %tuple.136 = (f32[]{:T(128)}, f32[4096,128256]{1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.121, %add.777.clone.1, %add.781.clone.1, %add.782.clone.1, %reduce.123.clone.1) +} + +%region_25.30 (reduce_sum.154: f32[], reduce_sum.155: f32[]) -> f32[] { + %reduce_sum.154 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.155 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.156 = f32[]{:T(128)} add(%reduce_sum.154, %reduce_sum.155), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.254 (param_0.1277: f32[4,14336,4096]) -> f32[] { + %param_0.1277 = f32[4,14336,4096]{2,0,1:T(4,128)} parameter(0) + %bitcast.314 = f32[14336,4,4096]{2,1,0:T(4,128)} bitcast(%param_0.1277), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.164 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%bitcast.314, %bitcast.314), metadata={op_name="jit(train_step)/square" stack_frame_id=375} + %constant.1459 = f32[]{:T(128)} constant(0) + ROOT %reduce.124 = f32[]{:T(128)} reduce(%square.164, %constant.1459), dimensions={0,1,2}, to_apply=%region_25.30, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377} +} + +%region_24.29 (reduce_sum.148: f32[], reduce_sum.149: f32[]) -> f32[] { + %reduce_sum.148 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.149 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.150 = f32[]{:T(128)} add(%reduce_sum.148, %reduce_sum.149), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_23.28 (reduce_sum.142: f32[], reduce_sum.143: f32[]) -> f32[] { + %reduce_sum.142 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.143 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.147 = f32[]{:T(128)} add(%reduce_sum.142, %reduce_sum.143), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.256 (param_0.1278: f32[4,4096,14336], param_1.1664: f32[4,4096,14336]) -> (f32[], f32[]) { + %param_0.1278 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(0) + %bitcast.318 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_0.1278), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.167 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%bitcast.318, %bitcast.318), metadata={op_name="jit(train_step)/square" stack_frame_id=375} + %constant.1460 = f32[]{:T(128)} constant(0) + %reduce.125 = f32[]{:T(128)} reduce(%square.167, %constant.1460), dimensions={0,1,2}, to_apply=%region_24.29, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377} + %param_1.1664 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(1) + %bitcast.322.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_1.1664), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.170.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%bitcast.322.clone.1, %bitcast.322.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=375} + %reduce.126.clone.1 = f32[]{:T(128)} reduce(%square.170.clone.1, %constant.1460), dimensions={0,1,2}, to_apply=%region_23.28, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377} + ROOT %tuple.155 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.125, %reduce.126.clone.1) +} + +%fused_computation.259 (param_0.739: f32[14336,4,4096]) -> bf16[4,14336,4096] { + %param_0.739 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(0) + %copy.234 = bf16[14336,4,4096]{2,0,1:T(8,128)(2,1)} copy(%param_0.739), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wo\'][\'kernel\']"} + ROOT %bitcast.323 = bf16[4,14336,4096]{2,1,0:T(8,128)(2,1)} bitcast(%copy.234), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} +} + +%fused_computation.260 (param_0.741: f32[4096,4,14336]) -> bf16[4,4096,14336] { + %param_0.741 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) + %copy.235 = bf16[4096,4,14336]{2,0,1:T(8,128)(2,1)} copy(%param_0.741), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_1\'][\'kernel\']"} + ROOT %bitcast.324 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} bitcast(%copy.235), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} +} + +%fused_computation.261 (param_0.743: f32[4096,4,14336]) -> bf16[4,4096,14336] { + %param_0.743 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) + %copy.236 = bf16[4096,4,14336]{2,0,1:T(8,128)(2,1)} copy(%param_0.743), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_0\'][\'kernel\']"} + ROOT %bitcast.325 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} bitcast(%copy.236), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} +} + +%region_52.57 (reduce_sum.289: f32[], reduce_sum.290: f32[]) -> f32[] { + %reduce_sum.289 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.290 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.294 = f32[]{:T(128)} add(%reduce_sum.289, %reduce_sum.290), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_38.43 (reduce_sum.217: f32[], reduce_sum.218: f32[]) -> f32[] { + %reduce_sum.217 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.218 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.219 = f32[]{:T(128)} add(%reduce_sum.217, %reduce_sum.218), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.262 (param_0.1267: f32[14336,4,4096], param_1.1657: f32[], param_2.1609: f32[], param_3.1257: f32[], param_4.939: f32[14336,4,4096], param_5.791: f32[], param_6.637: f32[4,14336,4096], param_7.616: pred[], param_8.394: f32[14336,4,4096]) -> (f32[], f32[14336,4,4096], f32[14336,4,4096], f32[14336,4,4096], f32[]) { + %param_0.1267 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(0) + %param_3.1257 = f32[]{:T(128)S(6)} parameter(3) + %mul.1550.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_3.1257), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.616 = pred[]{:T(512)S(6)} parameter(7) + %select_n.868.clone.1 = pred[14336,4,4096]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.616), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %param_6.637 = f32[4,14336,4096]{2,0,1:T(4,128)} parameter(6) + %bitcast.423.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} bitcast(%param_6.637), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.791 = f32[]{:T(128)} parameter(5) + %div.861.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_5.791), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %div.860.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} divide(%bitcast.423.clone.1, %div.861.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %select_n.867.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} select(%select_n.868.clone.1, %bitcast.423.clone.1, %div.860.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %ne.200.clone.1 = pred[14336,4,4096]{2,1,0:T(4,128)(4,1)} compare(%select_n.867.clone.1, %select_n.867.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=592} + %constant.1449 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.668.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.1449), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.866.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} select(%ne.200.clone.1, %broadcast_in_dim.668.clone.1, %select_n.867.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1378.clone.1 = f32[]{:T(128)} constant(inf) + %eq.613.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.1378.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.612.clone.1 = pred[14336,4,4096]{2,1,0:T(4,128)(4,1)} compare(%select_n.866.clone.1, %eq.613.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1377.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.667.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.1377.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.865.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} select(%eq.612.clone.1, %broadcast_in_dim.667.clone.1, %select_n.866.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1376.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.611.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.1376.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.610.clone.1 = pred[14336,4,4096]{2,1,0:T(4,128)(4,1)} compare(%select_n.865.clone.1, %eq.611.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1375.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.666.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.1375.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.864.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} select(%eq.610.clone.1, %broadcast_in_dim.666.clone.1, %select_n.865.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1371.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.586.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.1371.clone.1), dimensions={}, metadata={op_name="broadcast.69"} + %mul.1556.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%select_n.864.clone.1, %broadcast.586.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=595} + %param_8.394 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(8) + %constant.1379.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1557.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.1379.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %mul.1555.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_8.394, %mul.1557.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %add.820.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%mul.1556.clone.1, %mul.1555.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=595} + %param_2.1609 = f32[]{:T(128)S(6)} parameter(2) + %div.857.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_2.1609), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=606} + %integer_pow.68.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%select_n.864.clone.1, %select_n.864.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=610} + %constant.1374.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1554.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.1374.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %mul.1552.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%integer_pow.68.clone.1, %mul.1554.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %param_4.939 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(4) + %constant.1373.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1553.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.1373.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %mul.1551.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_4.939, %mul.1553.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %add.819.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%mul.1552.clone.1, %mul.1551.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=613} + %param_1.1657 = f32[]{:T(128)S(6)} parameter(1) + %div.856.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_1.1657), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %div.855.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} divide(%add.819.clone.1, %div.856.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %sqrt.66.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} sqrt(%div.855.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=617} + %constant.1372.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.818.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.1372.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %add.817.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%sqrt.66.clone.1, %add.818.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %multiply.264.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%div.857.clone.1, %add.817.clone.1), metadata={op_name="multiply.34"} + %div.854.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} divide(%add.820.clone.1, %multiply.264.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=618} + %mul.1549.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_0.1267, %broadcast.586.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=620} + %add.816.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%div.854.clone.1, %mul.1549.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=621} + %mul.1548.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%mul.1550.clone.1, %add.816.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.815.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%param_0.1267, %mul.1548.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=625} + %square.171 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%add.815.clone.1, %add.815.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=656} + %reduce.127 = f32[]{:T(128)} reduce(%square.171, %constant.1449), dimensions={0,1,2}, to_apply=%region_52.57, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657} + %reduce.130.clone.1 = f32[]{:T(128)} reduce(%integer_pow.68.clone.1, %constant.1449), dimensions={0,1,2}, to_apply=%region_38.43, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633} + ROOT %tuple.137 = (f32[]{:T(128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.127, %add.815.clone.1, %add.819.clone.1, %add.820.clone.1, %reduce.130.clone.1) +} + +%region_51.56 (reduce_sum.283: f32[], reduce_sum.287: f32[]) -> f32[] { + %reduce_sum.283 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.287 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.288 = f32[]{:T(128)} add(%reduce_sum.283, %reduce_sum.287), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_37.42 (reduce_sum.211: f32[], reduce_sum.212: f32[]) -> f32[] { + %reduce_sum.211 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.212 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.213 = f32[]{:T(128)} add(%reduce_sum.211, %reduce_sum.212), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.263 (param_0.1268: f32[4096,4,14336], param_1.1658: f32[], param_2.1610: f32[], param_3.1258: f32[], param_4.940: f32[4096,4,14336], param_5.792: f32[], param_6.638: f32[4,4096,14336], param_7.617: pred[], param_8.395: f32[4096,4,14336]) -> (f32[], f32[4096,4,14336], f32[4096,4,14336], f32[4096,4,14336], f32[]) { + %param_0.1268 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) + %param_3.1258 = f32[]{:T(128)S(6)} parameter(3) + %mul.1560.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_3.1258), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.617 = pred[]{:T(512)S(6)} parameter(7) + %select_n.878.clone.1 = pred[4096,4,14336]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.617), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %param_6.638 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(6) + %bitcast.425.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_6.638), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.792 = f32[]{:T(128)} parameter(5) + %div.869.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_5.792), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %div.868.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%bitcast.425.clone.1, %div.869.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %select_n.877.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} select(%select_n.878.clone.1, %bitcast.425.clone.1, %div.868.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %ne.202.clone.1 = pred[4096,4,14336]{2,1,0:T(4,128)(4,1)} compare(%select_n.877.clone.1, %select_n.877.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=592} + %constant.1450 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.674.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1450), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.876.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} select(%ne.202.clone.1, %broadcast_in_dim.674.clone.1, %select_n.877.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1392.clone.1 = f32[]{:T(128)} constant(inf) + %eq.621.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1392.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.620.clone.1 = pred[4096,4,14336]{2,1,0:T(4,128)(4,1)} compare(%select_n.876.clone.1, %eq.621.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1391.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.673.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1391.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.875.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} select(%eq.620.clone.1, %broadcast_in_dim.673.clone.1, %select_n.876.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1390.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.619.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1390.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.618.clone.1 = pred[4096,4,14336]{2,1,0:T(4,128)(4,1)} compare(%select_n.875.clone.1, %eq.619.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1389.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.672.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1389.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.874.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} select(%eq.618.clone.1, %broadcast_in_dim.672.clone.1, %select_n.875.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1385.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.592.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1385.clone.1), dimensions={}, metadata={op_name="broadcast.71"} + %mul.1564.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.874.clone.1, %broadcast.592.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=595} + %param_8.395 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(8) + %constant.1393.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.591.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1393.clone.1), dimensions={}, metadata={op_name="broadcast.70"} + %mul.1563.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_8.395, %broadcast.591.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %add.825.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1564.clone.1, %mul.1563.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=595} + %param_2.1610 = f32[]{:T(128)S(6)} parameter(2) + %div.865.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_2.1610), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=606} + %integer_pow.69.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.874.clone.1, %select_n.874.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=610} + %constant.1388.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.590.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1388.clone.1), dimensions={}, metadata={op_name="broadcast.60"} + %mul.1562.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%integer_pow.69.clone.1, %broadcast.590.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %param_4.940 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(4) + %constant.1387.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.589.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1387.clone.1), dimensions={}, metadata={op_name="broadcast.59"} + %mul.1561.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_4.940, %broadcast.589.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %add.824.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1562.clone.1, %mul.1561.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=613} + %param_1.1658 = f32[]{:T(128)S(6)} parameter(1) + %div.864.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_1.1658), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %div.863.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.824.clone.1, %div.864.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %sqrt.67.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} sqrt(%div.863.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=617} + %constant.1386.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.587.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1386.clone.1), dimensions={}, metadata={op_name="broadcast.54"} + %add.823.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%sqrt.67.clone.1, %broadcast.587.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %multiply.265.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%div.865.clone.1, %add.823.clone.1), metadata={op_name="multiply.33"} + %div.862.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.825.clone.1, %multiply.265.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=618} + %mul.1559.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_0.1268, %broadcast.592.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=620} + %add.822.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%div.862.clone.1, %mul.1559.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=621} + %mul.1558.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%mul.1560.clone.1, %add.822.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.821.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%param_0.1268, %mul.1558.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=625} + %square.172 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%add.821.clone.1, %add.821.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=656} + %reduce.128 = f32[]{:T(128)} reduce(%square.172, %constant.1450), dimensions={0,1,2}, to_apply=%region_51.56, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657} + %reduce.131.clone.1 = f32[]{:T(128)} reduce(%integer_pow.69.clone.1, %constant.1450), dimensions={0,1,2}, to_apply=%region_37.42, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633} + ROOT %tuple.138 = (f32[]{:T(128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.128, %add.821.clone.1, %add.824.clone.1, %add.825.clone.1, %reduce.131.clone.1) +} + +%region_50.55 (reduce_sum.280: f32[], reduce_sum.281: f32[]) -> f32[] { + %reduce_sum.280 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.281 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.282 = f32[]{:T(128)} add(%reduce_sum.280, %reduce_sum.281), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_36.41 (reduce_sum.205: f32[], reduce_sum.206: f32[]) -> f32[] { + %reduce_sum.205 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.206 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.210 = f32[]{:T(128)} add(%reduce_sum.205, %reduce_sum.206), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.264 (param_0.1269: f32[4096,4,14336], param_1.1659: f32[], param_2.1611: f32[], param_3.1259: f32[], param_4.941: f32[4096,4,14336], param_5.793: f32[], param_6.639: f32[4,4096,14336], param_7.618: pred[], param_8.396: f32[4096,4,14336]) -> (f32[], f32[4096,4,14336], f32[4096,4,14336], f32[4096,4,14336], f32[]) { + %param_0.1269 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) + %param_3.1259 = f32[]{:T(128)S(6)} parameter(3) + %mul.1567.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_3.1259), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.618 = pred[]{:T(512)S(6)} parameter(7) + %select_n.888.clone.1 = pred[4096,4,14336]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.618), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %param_6.639 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(6) + %bitcast.427.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_6.639), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.793 = f32[]{:T(128)} parameter(5) + %div.877.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_5.793), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %div.876.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%bitcast.427.clone.1, %div.877.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %select_n.887.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} select(%select_n.888.clone.1, %bitcast.427.clone.1, %div.876.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %ne.204.clone.1 = pred[4096,4,14336]{2,1,0:T(4,128)(4,1)} compare(%select_n.887.clone.1, %select_n.887.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=592} + %constant.1451 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.680.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1451), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.886.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} select(%ne.204.clone.1, %broadcast_in_dim.680.clone.1, %select_n.887.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1406.clone.1 = f32[]{:T(128)} constant(inf) + %eq.629.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1406.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.628.clone.1 = pred[4096,4,14336]{2,1,0:T(4,128)(4,1)} compare(%select_n.886.clone.1, %eq.629.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1405.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.679.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1405.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.885.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} select(%eq.628.clone.1, %broadcast_in_dim.679.clone.1, %select_n.886.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1404.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.627.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1404.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.626.clone.1 = pred[4096,4,14336]{2,1,0:T(4,128)(4,1)} compare(%select_n.885.clone.1, %eq.627.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1403.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.678.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1403.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.884.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} select(%eq.626.clone.1, %broadcast_in_dim.678.clone.1, %select_n.885.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1399.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.598.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1399.clone.1), dimensions={}, metadata={op_name="broadcast.71"} + %mul.1571.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.884.clone.1, %broadcast.598.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=595} + %param_8.396 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(8) + %constant.1407.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.597.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1407.clone.1), dimensions={}, metadata={op_name="broadcast.70"} + %mul.1570.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_8.396, %broadcast.597.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %add.830.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1571.clone.1, %mul.1570.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=595} + %param_2.1611 = f32[]{:T(128)S(6)} parameter(2) + %div.873.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_2.1611), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=606} + %integer_pow.70.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.884.clone.1, %select_n.884.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=610} + %constant.1402.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.596.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1402.clone.1), dimensions={}, metadata={op_name="broadcast.60"} + %mul.1569.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%integer_pow.70.clone.1, %broadcast.596.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %param_4.941 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(4) + %constant.1401.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.595.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1401.clone.1), dimensions={}, metadata={op_name="broadcast.59"} + %mul.1568.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_4.941, %broadcast.595.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %add.829.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1569.clone.1, %mul.1568.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=613} + %param_1.1659 = f32[]{:T(128)S(6)} parameter(1) + %div.872.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_1.1659), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %div.871.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.829.clone.1, %div.872.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %sqrt.68.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} sqrt(%div.871.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=617} + %constant.1400.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.593.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.1400.clone.1), dimensions={}, metadata={op_name="broadcast.54"} + %add.828.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%sqrt.68.clone.1, %broadcast.593.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %multiply.266.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%div.873.clone.1, %add.828.clone.1), metadata={op_name="multiply.32"} + %div.870.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.830.clone.1, %multiply.266.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=618} + %mul.1566.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_0.1269, %broadcast.598.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=620} + %add.827.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%div.870.clone.1, %mul.1566.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=621} + %mul.1565.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%mul.1567.clone.1, %add.827.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.826.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%param_0.1269, %mul.1565.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=625} + %square.173 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%add.826.clone.1, %add.826.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=656} + %reduce.129 = f32[]{:T(128)} reduce(%square.173, %constant.1451), dimensions={0,1,2}, to_apply=%region_50.55, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657} + %reduce.132.clone.1 = f32[]{:T(128)} reduce(%integer_pow.70.clone.1, %constant.1451), dimensions={0,1,2}, to_apply=%region_36.41, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633} + ROOT %tuple.139 = (f32[]{:T(128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.129, %add.826.clone.1, %add.829.clone.1, %add.830.clone.1, %reduce.132.clone.1) +} + +%region_30.35 (reduce_sum.178: f32[], reduce_sum.182: f32[]) -> f32[] { + %reduce_sum.178 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.182 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.183 = f32[]{:T(128)} add(%reduce_sum.178, %reduce_sum.182), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.288 (param_0.1273: f32[4,4096,32,128]) -> f32[] { + %param_0.1273 = f32[4,4096,32,128]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.329 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1273), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.176 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%bitcast.329, %bitcast.329), metadata={op_name="jit(train_step)/square" stack_frame_id=375} + %constant.1455 = f32[]{:T(128)} constant(0) + ROOT %reduce.133 = f32[]{:T(128)} reduce(%square.176, %constant.1455), dimensions={0,1,2,3}, to_apply=%region_30.35, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377} +} + +%region_29.34 (reduce_sum.175: f32[], reduce_sum.176: f32[]) -> f32[] { + %reduce_sum.175 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.176 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.177 = f32[]{:T(128)} add(%reduce_sum.175, %reduce_sum.176), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.290 (param_0.1274: f32[4,32,128,4096]) -> f32[] { + %param_0.1274 = f32[4,32,128,4096]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.333 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} bitcast(%param_0.1274), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.179 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%bitcast.333, %bitcast.333), metadata={op_name="jit(train_step)/square" stack_frame_id=375} + %constant.1456 = f32[]{:T(128)} constant(0) + ROOT %reduce.134 = f32[]{:T(128)} reduce(%square.179, %constant.1456), dimensions={0,1,2,3}, to_apply=%region_29.34, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377} +} + +%fused_computation.291 (param_0.826: f32[32,4,128,4096]) -> bf16[4,32,128,4096] { + %param_0.826 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(0) + %copy.237 = bf16[32,4,128,4096]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.826), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'out\'][\'kernel\']"} + ROOT %bitcast.334 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} bitcast(%copy.237), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} +} + +%region_57.62 (reduce_sum.317: f32[], reduce_sum.318: f32[]) -> f32[] { + %reduce_sum.317 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.318 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.316 = f32[]{:T(128)} add(%reduce_sum.317, %reduce_sum.318), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_43.48 (reduce_sum.241: f32[], reduce_sum.245: f32[]) -> f32[] { + %reduce_sum.241 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.245 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.246 = f32[]{:T(128)} add(%reduce_sum.241, %reduce_sum.245), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.292 (param_0.1262: f32[4096,4,32,128], param_1.1652: f32[], param_2.1604: f32[], param_3.1252: f32[], param_4.934: f32[4096,4,32,128], param_5.786: f32[], param_6.632: f32[4,4096,32,128], param_7.611: pred[], param_8.389: f32[4096,4,32,128]) -> (f32[], f32[4096,4,32,128], f32[4096,4,32,128], f32[4096,4,32,128], f32[]) { + %param_0.1262 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.1252 = f32[]{:T(128)S(6)} parameter(3) + %mul.1509.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_3.1252), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.611 = pred[]{:T(512)S(6)} parameter(7) + %select_n.818.clone.1 = pred[4096,4,32,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.611), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %param_6.632 = f32[4,4096,32,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.413.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} bitcast(%param_6.632), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.786 = f32[]{:T(128)} parameter(5) + %div.821.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_5.786), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %div.820.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} divide(%bitcast.413.clone.1, %div.821.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %select_n.817.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} select(%select_n.818.clone.1, %bitcast.413.clone.1, %div.820.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %ne.190.clone.1 = pred[4096,4,32,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.817.clone.1, %select_n.817.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=592} + %constant.1444 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.638.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.1444), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.816.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} select(%ne.190.clone.1, %broadcast_in_dim.638.clone.1, %select_n.817.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1308.clone.1 = f32[]{:T(128)} constant(inf) + %eq.573.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.1308.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.572.clone.1 = pred[4096,4,32,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.816.clone.1, %eq.573.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1307.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.637.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.1307.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.815.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} select(%eq.572.clone.1, %broadcast_in_dim.637.clone.1, %select_n.816.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1306.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.571.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.1306.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.570.clone.1 = pred[4096,4,32,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.815.clone.1, %eq.571.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1305.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.636.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.1305.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.814.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} select(%eq.570.clone.1, %broadcast_in_dim.636.clone.1, %select_n.815.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1301.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.564.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.1301.clone.1), dimensions={}, metadata={op_name="broadcast.63"} + %mul.1515.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%select_n.814.clone.1, %broadcast.564.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=595} + %param_8.389 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} parameter(8) + %constant.1309.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1516.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.1309.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %mul.1514.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_8.389, %mul.1516.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %add.793.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%mul.1515.clone.1, %mul.1514.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=595} + %param_2.1604 = f32[]{:T(128)S(6)} parameter(2) + %div.817.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1604), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=606} + %integer_pow.63.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%select_n.814.clone.1, %select_n.814.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=610} + %constant.1304.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1513.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.1304.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %mul.1511.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.63.clone.1, %mul.1513.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %param_4.934 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.1303.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1512.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.1303.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %mul.1510.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_4.934, %mul.1512.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %add.792.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%mul.1511.clone.1, %mul.1510.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=613} + %param_1.1652 = f32[]{:T(128)S(6)} parameter(1) + %div.816.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1652), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %div.815.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} divide(%add.792.clone.1, %div.816.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %sqrt.61.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} sqrt(%div.815.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=617} + %constant.1302.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.791.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.1302.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %add.790.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%sqrt.61.clone.1, %add.791.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %multiply.259.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%div.817.clone.1, %add.790.clone.1), metadata={op_name="multiply.39"} + %div.814.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} divide(%add.793.clone.1, %multiply.259.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=618} + %mul.1508.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_0.1262, %broadcast.564.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=620} + %add.789.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%div.814.clone.1, %mul.1508.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=621} + %mul.1507.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%mul.1509.clone.1, %add.789.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.788.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%param_0.1262, %mul.1507.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=625} + %square.180 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%add.788.clone.1, %add.788.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=656} + %reduce.135 = f32[]{:T(128)} reduce(%square.180, %constant.1444), dimensions={0,1,2,3}, to_apply=%region_57.62, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657} + %reduce.139.clone.1 = f32[]{:T(128)} reduce(%integer_pow.63.clone.1, %constant.1444), dimensions={0,1,2,3}, to_apply=%region_43.48, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633} + ROOT %tuple.140 = (f32[]{:T(128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.135, %add.788.clone.1, %add.792.clone.1, %add.793.clone.1, %reduce.139.clone.1) +} + +%region_56.61 (reduce_sum.310: f32[], reduce_sum.311: f32[]) -> f32[] { + %reduce_sum.310 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.311 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.315 = f32[]{:T(128)} add(%reduce_sum.310, %reduce_sum.311), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_42.47 (reduce_sum.238: f32[], reduce_sum.239: f32[]) -> f32[] { + %reduce_sum.238 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.239 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.240 = f32[]{:T(128)} add(%reduce_sum.238, %reduce_sum.239), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.293 (param_0.1263: f32[32,4,128,4096], param_1.1653: f32[], param_2.1605: f32[], param_3.1253: f32[], param_4.935: f32[32,4,128,4096], param_5.787: f32[], param_6.633: f32[4,32,128,4096], param_7.612: pred[], param_8.390: f32[32,4,128,4096]) -> (f32[], f32[32,4,128,4096], f32[32,4,128,4096], f32[32,4,128,4096], f32[]) { + %param_0.1263 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(0) + %param_3.1253 = f32[]{:T(128)S(6)} parameter(3) + %mul.1519.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_3.1253), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.612 = pred[]{:T(512)S(6)} parameter(7) + %select_n.828.clone.1 = pred[32,4,128,4096]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.612), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %param_6.633 = f32[4,32,128,4096]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.415.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} bitcast(%param_6.633), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.787 = f32[]{:T(128)} parameter(5) + %div.829.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_5.787), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %div.828.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} divide(%bitcast.415.clone.1, %div.829.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %select_n.827.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} select(%select_n.828.clone.1, %bitcast.415.clone.1, %div.828.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %ne.192.clone.1 = pred[32,4,128,4096]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.827.clone.1, %select_n.827.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=592} + %constant.1445 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.644.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.1445), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.826.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} select(%ne.192.clone.1, %broadcast_in_dim.644.clone.1, %select_n.827.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1322.clone.1 = f32[]{:T(128)} constant(inf) + %eq.581.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.1322.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.580.clone.1 = pred[32,4,128,4096]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.826.clone.1, %eq.581.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1321.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.643.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.1321.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.825.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} select(%eq.580.clone.1, %broadcast_in_dim.643.clone.1, %select_n.826.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1320.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.579.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.1320.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.578.clone.1 = pred[32,4,128,4096]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.825.clone.1, %eq.579.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1319.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.642.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.1319.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.824.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} select(%eq.578.clone.1, %broadcast_in_dim.642.clone.1, %select_n.825.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1315.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.566.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.1315.clone.1), dimensions={}, metadata={op_name="broadcast.64"} + %mul.1525.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%select_n.824.clone.1, %broadcast.566.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=595} + %param_8.390 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(8) + %constant.1323.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1526.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.1323.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %mul.1524.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_8.390, %mul.1526.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %add.799.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%mul.1525.clone.1, %mul.1524.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=595} + %param_2.1605 = f32[]{:T(128)S(6)} parameter(2) + %div.825.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_2.1605), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=606} + %integer_pow.64.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%select_n.824.clone.1, %select_n.824.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=610} + %constant.1318.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1523.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.1318.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %mul.1521.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%integer_pow.64.clone.1, %mul.1523.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %param_4.935 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(4) + %constant.1317.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1522.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.1317.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %mul.1520.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_4.935, %mul.1522.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %add.798.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%mul.1521.clone.1, %mul.1520.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=613} + %param_1.1653 = f32[]{:T(128)S(6)} parameter(1) + %div.824.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_1.1653), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %div.823.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} divide(%add.798.clone.1, %div.824.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %sqrt.62.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} sqrt(%div.823.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=617} + %constant.1316.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.797.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.1316.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %add.796.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%sqrt.62.clone.1, %add.797.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %multiply.260.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%div.825.clone.1, %add.796.clone.1), metadata={op_name="multiply.38"} + %div.822.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} divide(%add.799.clone.1, %multiply.260.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=618} + %mul.1518.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_0.1263, %broadcast.566.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=620} + %add.795.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%div.822.clone.1, %mul.1518.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=621} + %mul.1517.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%mul.1519.clone.1, %add.795.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.794.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%param_0.1263, %mul.1517.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=625} + %square.181 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%add.794.clone.1, %add.794.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=656} + %reduce.136 = f32[]{:T(128)} reduce(%square.181, %constant.1445), dimensions={0,1,2,3}, to_apply=%region_56.61, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657} + %reduce.140.clone.1 = f32[]{:T(128)} reduce(%integer_pow.64.clone.1, %constant.1445), dimensions={0,1,2,3}, to_apply=%region_42.47, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633} + ROOT %tuple.141 = (f32[]{:T(128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.136, %add.794.clone.1, %add.798.clone.1, %add.799.clone.1, %reduce.140.clone.1) +} + +%region_47.52 (reduce_sum.262: f32[], reduce_sum.266: f32[]) -> f32[] { + %reduce_sum.262 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.266 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.267 = f32[]{:T(128)} add(%reduce_sum.262, %reduce_sum.266), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=639}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.300 (param_0.1282: bf16[4,128,128256], param_1.1668: f32[4,128], param_2.1616: s32[4,128], param_3.1263: bf16[4,128]) -> f32[4,128] { + %param_2.1616 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %eq.276 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1616), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=365} + %eq.263 = s32[4,128,128256]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=365} + %eq.262 = pred[4,128,128256]{2,1,0:T(8,128)(4,1)} compare(%eq.276, %eq.263), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=365} + %param_0.1282 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.962 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1282), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=349} + %param_3.1263 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) + %sub.73 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.1263), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=360} + %sub.64 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.962, %sub.73), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=360} + %param_1.1668 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %sub.71 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1668), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=637} + %sub.60 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%sub.64, %sub.71), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=637} + %constant.1465 = f32[]{:T(128)} constant(0) + %broadcast.511 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%constant.1465), dimensions={}, metadata={op_name="broadcast.83"} + %mul.1373 = f32[4,128,128256]{2,1,0:T(8,128)} select(%eq.262, %sub.60, %broadcast.511), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=638} + ROOT %reduce.137 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1373, %constant.1465), dimensions={2}, to_apply=%region_47.52, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=639} +} + +%region_7.10 (reduce_sum.93: f32[], reduce_sum.94: f32[]) -> f32[] { + %reduce_sum.93 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.94 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.95 = f32[]{:T(128)} add(%reduce_sum.93, %reduce_sum.94), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=362}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.305 (param_0.1283: bf16[4,128,128256], param_1.1669: bf16[4,128]) -> f32[4,128] { + %param_0.1283 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.968 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1283), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=349} + %param_1.1669 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) + %sub.74 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1669), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=360} + %sub.70 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.968, %sub.74), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=360} + %exp.54 = f32[4,128,128256]{2,1,0:T(8,128)} exponential(%sub.70), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=361} + %constant.1466 = f32[]{:T(128)} constant(0) + ROOT %reduce.138 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.54, %constant.1466), dimensions={2}, to_apply=%region_7.10, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=362} +} + +%region_31.36 (reduce_sum.184: f32[], reduce_sum.185: f32[]) -> f32[] { + %reduce_sum.184 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.185 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.189 = f32[]{:T(128)} add(%reduce_sum.184, %reduce_sum.185), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_28.33 (reduce_sum.169: f32[], reduce_sum.170: f32[]) -> f32[] { + %reduce_sum.169 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.170 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.171 = f32[]{:T(128)} add(%reduce_sum.169, %reduce_sum.170), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.317 (param_0.1275: f32[4,4096,8,128], param_1.1662: f32[4,4096,8,128]) -> (f32[], f32[]) { + %param_0.1275 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.338 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1275), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.184 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.338, %bitcast.338), metadata={op_name="jit(train_step)/square" stack_frame_id=375} + %constant.1457 = f32[]{:T(128)} constant(0) + %reduce.141 = f32[]{:T(128)} reduce(%square.184, %constant.1457), dimensions={0,1,2,3}, to_apply=%region_31.36, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377} + %param_1.1662 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(1) + %bitcast.342.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_1.1662), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.187.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.342.clone.1, %bitcast.342.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=375} + %reduce.142.clone.1 = f32[]{:T(128)} reduce(%square.187.clone.1, %constant.1457), dimensions={0,1,2,3}, to_apply=%region_28.33, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377} + ROOT %tuple.156 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.141, %reduce.142.clone.1) +} + +%fused_computation.320 (param_0.907: f32[4096,4,8,128]) -> bf16[4,4096,8,128] { + %param_0.907 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %copy.238 = bf16[4096,4,8,128]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.907), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'value\'][\'kernel\']"} + ROOT %bitcast.343 = bf16[4,4096,8,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%copy.238), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} +} + +%region_58.63 (reduce_sum.324: f32[], reduce_sum.325: f32[]) -> f32[] { + %reduce_sum.324 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.325 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.322 = f32[]{:T(128)} add(%reduce_sum.324, %reduce_sum.325), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_44.49 (reduce_sum.247: f32[], reduce_sum.248: f32[]) -> f32[] { + %reduce_sum.247 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.248 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.252 = f32[]{:T(128)} add(%reduce_sum.247, %reduce_sum.248), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.321 (param_0.1261: f32[4096,4,8,128], param_1.1651: f32[], param_2.1603: f32[], param_3.1251: f32[], param_4.933: f32[4096,4,8,128], param_5.785: f32[], param_6.631: f32[4,4096,8,128], param_7.610: pred[], param_8.388: f32[4096,4,8,128]) -> (f32[], f32[4096,4,8,128], f32[4096,4,8,128], f32[4096,4,8,128], f32[]) { + %param_0.1261 = f32[4096,4,8,128]{3,2,1,0:T(8,128)S(1)} parameter(0) + %param_3.1251 = f32[]{:T(128)S(6)} parameter(3) + %mul.1502.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.1251), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.610 = pred[]{:T(512)S(6)} parameter(7) + %select_n.808.clone.1 = pred[4096,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.610), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %param_6.631 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.411.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.631), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.785 = f32[]{:T(128)} parameter(5) + %div.813.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.785), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %div.812.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.411.clone.1, %div.813.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %select_n.807.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.808.clone.1, %bitcast.411.clone.1, %div.812.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %ne.188.clone.1 = pred[4096,4,8,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.807.clone.1, %select_n.807.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=592} + %constant.1443 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.632.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1443), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.806.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} select(%ne.188.clone.1, %broadcast_in_dim.632.clone.1, %select_n.807.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1294.clone.1 = f32[]{:T(128)} constant(inf) + %eq.565.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1294.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.564.clone.1 = pred[4096,4,8,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.806.clone.1, %eq.565.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1293.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.631.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1293.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.805.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} select(%eq.564.clone.1, %broadcast_in_dim.631.clone.1, %select_n.806.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1292.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.563.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1292.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.562.clone.1 = pred[4096,4,8,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.805.clone.1, %eq.563.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1291.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.630.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1291.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.804.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} select(%eq.562.clone.1, %broadcast_in_dim.630.clone.1, %select_n.805.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1287.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.562.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1287.clone.1), dimensions={}, metadata={op_name="broadcast.66"} + %mul.1506.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.804.clone.1, %broadcast.562.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=595} + %param_8.388 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(8) + %constant.1295.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.561.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1295.clone.1), dimensions={}, metadata={op_name="broadcast.65"} + %mul.1505.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.388, %broadcast.561.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %add.787.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1506.clone.1, %mul.1505.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=595} + %param_2.1603 = f32[]{:T(128)S(6)} parameter(2) + %div.809.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1603), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=606} + %integer_pow.62.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.804.clone.1, %select_n.804.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=610} + %constant.1290.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.560.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1290.clone.1), dimensions={}, metadata={op_name="broadcast.56"} + %mul.1504.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.62.clone.1, %broadcast.560.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %param_4.933 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.1289.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.559.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1289.clone.1), dimensions={}, metadata={op_name="broadcast.55"} + %mul.1503.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.933, %broadcast.559.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %add.786.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1504.clone.1, %mul.1503.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=613} + %param_1.1651 = f32[]{:T(128)S(6)} parameter(1) + %div.808.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1651), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %div.807.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.786.clone.1, %div.808.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %sqrt.60.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.807.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=617} + %constant.1288.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.557.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1288.clone.1), dimensions={}, metadata={op_name="broadcast.52"} + %add.785.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.60.clone.1, %broadcast.557.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %multiply.258.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.809.clone.1, %add.785.clone.1), metadata={op_name="multiply.40"} + %div.806.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.787.clone.1, %multiply.258.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=618} + %mul.1501.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1261, %broadcast.562.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=620} + %add.784.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%div.806.clone.1, %mul.1501.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=621} + %mul.1500.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1502.clone.1, %add.784.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.783.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)S(1)} add(%param_0.1261, %mul.1500.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=625} + %square.188 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.783.clone.1, %add.783.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=656} + %reduce.143 = f32[]{:T(128)} reduce(%square.188, %constant.1443), dimensions={0,1,2,3}, to_apply=%region_58.63, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657} + %reduce.145.clone.1 = f32[]{:T(128)} reduce(%integer_pow.62.clone.1, %constant.1443), dimensions={0,1,2,3}, to_apply=%region_44.49, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633} + ROOT %tuple.142 = (f32[]{:T(128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)S(1)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.143, %add.783.clone.1, %add.786.clone.1, %add.787.clone.1, %reduce.145.clone.1) +} + +%region_55.60 (reduce_sum.304: f32[], reduce_sum.308: f32[]) -> f32[] { + %reduce_sum.304 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.308 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.309 = f32[]{:T(128)} add(%reduce_sum.304, %reduce_sum.308), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_41.46 (reduce_sum.232: f32[], reduce_sum.233: f32[]) -> f32[] { + %reduce_sum.232 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.233 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.234 = f32[]{:T(128)} add(%reduce_sum.232, %reduce_sum.233), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.322 (param_0.1264: f32[4096,4,8,128], param_1.1654: f32[], param_2.1606: f32[], param_3.1254: f32[], param_4.936: f32[4096,4,8,128], param_5.788: f32[], param_6.634: f32[4,4096,8,128], param_7.613: pred[], param_8.391: f32[4096,4,8,128]) -> (f32[], f32[4096,4,8,128], f32[4096,4,8,128], f32[4096,4,8,128], f32[]) { + %param_0.1264 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.1254 = f32[]{:T(128)S(6)} parameter(3) + %mul.1529.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.1254), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.613 = pred[]{:T(512)S(6)} parameter(7) + %select_n.838.clone.1 = pred[4096,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.613), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %param_6.634 = f32[4,4096,8,128]{3,2,0,1:T(8,128)S(1)} parameter(6) + %bitcast.417.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.634), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.788 = f32[]{:T(128)} parameter(5) + %div.837.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.788), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %div.836.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.417.clone.1, %div.837.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %select_n.837.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.838.clone.1, %bitcast.417.clone.1, %div.836.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %ne.194.clone.1 = pred[4096,4,8,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.837.clone.1, %select_n.837.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=592} + %constant.1446 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.650.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1446), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.836.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} select(%ne.194.clone.1, %broadcast_in_dim.650.clone.1, %select_n.837.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1336.clone.1 = f32[]{:T(128)} constant(inf) + %eq.589.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1336.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.588.clone.1 = pred[4096,4,8,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.836.clone.1, %eq.589.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1335.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.649.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1335.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.835.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} select(%eq.588.clone.1, %broadcast_in_dim.649.clone.1, %select_n.836.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1334.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.587.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1334.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.586.clone.1 = pred[4096,4,8,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.835.clone.1, %eq.587.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1333.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.648.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1333.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.834.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} select(%eq.586.clone.1, %broadcast_in_dim.648.clone.1, %select_n.835.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1329.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.572.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1329.clone.1), dimensions={}, metadata={op_name="broadcast.66"} + %mul.1533.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.834.clone.1, %broadcast.572.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=595} + %param_8.391 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(8) + %constant.1337.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.571.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1337.clone.1), dimensions={}, metadata={op_name="broadcast.65"} + %mul.1532.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.391, %broadcast.571.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %add.804.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1533.clone.1, %mul.1532.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=595} + %param_2.1606 = f32[]{:T(128)S(6)} parameter(2) + %div.833.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1606), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=606} + %integer_pow.65.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.834.clone.1, %select_n.834.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=610} + %constant.1332.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.570.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1332.clone.1), dimensions={}, metadata={op_name="broadcast.56"} + %mul.1531.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.65.clone.1, %broadcast.570.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %param_4.936 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.1331.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.569.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1331.clone.1), dimensions={}, metadata={op_name="broadcast.55"} + %mul.1530.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.936, %broadcast.569.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %add.803.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1531.clone.1, %mul.1530.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=613} + %param_1.1654 = f32[]{:T(128)S(6)} parameter(1) + %div.832.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1654), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %div.831.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.803.clone.1, %div.832.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %sqrt.63.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.831.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=617} + %constant.1330.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.567.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1330.clone.1), dimensions={}, metadata={op_name="broadcast.52"} + %add.802.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.63.clone.1, %broadcast.567.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %multiply.261.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.833.clone.1, %add.802.clone.1), metadata={op_name="multiply.37"} + %div.830.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.804.clone.1, %multiply.261.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=618} + %mul.1528.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1264, %broadcast.572.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=620} + %add.801.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%div.830.clone.1, %mul.1528.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=621} + %mul.1527.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1529.clone.1, %add.801.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.800.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1264, %mul.1527.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=625} + %square.189 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.800.clone.1, %add.800.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=656} + %reduce.144 = f32[]{:T(128)} reduce(%square.189, %constant.1446), dimensions={0,1,2,3}, to_apply=%region_55.60, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657} + %reduce.146.clone.1 = f32[]{:T(128)} reduce(%integer_pow.65.clone.1, %constant.1446), dimensions={0,1,2,3}, to_apply=%region_41.46, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633} + ROOT %tuple.143 = (f32[]{:T(128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.144, %add.800.clone.1, %add.803.clone.1, %add.804.clone.1, %reduce.146.clone.1) +} + +%fused_computation.338 (param_0.982: bf16[4,128,4096], param_1.1097: f32[4,128], param_2.891: f32[4,128], param_3.630: bf16[4,128,4096], param_4.441: bf16[4096]) -> bf16[4,128,4096] { + %param_3.630 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.441 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %dot_general.375 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_4.441), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + %dot_general.365 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_3.630, %dot_general.375), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + %convert_element_type.985 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%dot_general.365), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=334} + %param_2.891 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %mul.1423 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_2.891), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=333} + %mul.1415 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.985, %mul.1423), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=333} + %param_0.982 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.996 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.982), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=328} + %param_1.1097 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.1422 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.1097), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=329} + %mul.1421 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.996, %mul.1422), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=329} + %add_any.138 = f32[4,128,4096]{2,1,0:T(8,128)} add(%mul.1415, %mul.1421), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add_any" stack_frame_id=329} + ROOT %convert_element_type.983 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%add_any.138), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=328} +} + +%region_5.8 (reduce_sum.87: f32[], reduce_sum.88: f32[]) -> f32[] { + %reduce_sum.87 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.88 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.92 = f32[]{:T(128)} add(%reduce_sum.87, %reduce_sum.88), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=330}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.339 (param_0.1284: bf16[4,128,4096]) -> f32[4,128] { + %param_0.1284 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.987 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.1284), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=328} + %square.192 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.987, %convert_element_type.987), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/square" stack_frame_id=329} + %constant.1467 = f32[]{:T(128)} constant(0) + ROOT %reduce.147 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.192, %constant.1467), dimensions={2}, to_apply=%region_5.8, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=330} +} + +%region_10.13 (reduce_sum.102: f32[], reduce_sum.106: f32[]) -> f32[] { + %reduce_sum.102 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.106 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.107 = f32[]{:T(128)} add(%reduce_sum.102, %reduce_sum.106), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=333}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.341 (param_0.1279: bf16[4,128,4096], param_1.1665: bf16[4,128,4096], param_2.1614: bf16[4096]) -> f32[4,128] { + %param_0.1279 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.994 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.1279), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=328} + %param_1.1665 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %param_2.1614 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %dot_general.374 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1614), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + %dot_general.364 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1665, %dot_general.374), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + %convert_element_type.993 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%dot_general.364), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=334} + %mul.1419 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.994, %convert_element_type.993), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=333} + %constant.1461 = f32[]{:T(128)} constant(0) + ROOT %reduce.148 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1419, %constant.1461), dimensions={2}, to_apply=%region_10.13, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=333} +} + +%region_8.11 (dot_general.182: bf16[], dot_general.183: bf16[]) -> bf16[] { + %dot_general.182 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general"} + %dot_general.183 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general"} + ROOT %add.168 = bf16[]{:T(256)} add(%dot_general.182, %dot_general.183), metadata={op_name="add.54"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.250.clone.clone (param_0.1248: f32[4096,128256]) -> bf16[4096,128256,1] { + %param_0.1248 = f32[4096,128256]{1,0:T(8,128)} parameter(0) + %convert_element_type.1051 = bf16[4096,128256]{1,0:T(8,128)(2,1)} convert(%param_0.1248), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=337} + ROOT %bitcast.449 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} bitcast(%convert_element_type.1051), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=337} +} + +%fused_computation.301.clone.1.clone.clone (param_0.1249: bf16[4,128,128256], param_1.1641: s32[4,128], param_2.1582: f32[4,128], param_3.1242: f32[4,128], param_4.925: bf16[4,128], param_5.777: f32[4,128]) -> bf16[4,128,128256] { + %param_5.777 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %mul.1603 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_5.777), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + %param_3.1242 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.1602 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.1242), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + %param_0.1249 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1054 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1249), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=349} + %param_4.925 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.88 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_4.925), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=360} + %sub.87 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.1054, %sub.88), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=360} + %exp.60 = f32[4,128,128256]{2,1,0:T(8,128)} exponential(%sub.87), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=361} + %mul.1601 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1602, %exp.60), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + %param_2.1582 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.891 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1582), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=80} + %div.890 = f32[4,128,128256]{2,1,0:T(8,128)} divide(%mul.1601, %div.891), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=80} + %param_1.1641 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.643 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1641), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=365} + %eq.642 = s32[4,128,128256]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=365} + %eq.641 = pred[4,128,128256]{2,1,0:T(8,128)(4,1)} compare(%eq.643, %eq.642), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=365} + %convert_element_type.1053 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%eq.641), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=365} + %sub.86 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%div.890, %convert_element_type.1053), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=80} + %mul.1600 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1603, %sub.86), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + ROOT %convert_element_type.1052 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} convert(%mul.1600), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=349} +} + +%fused_computation.342 (param_0.1247: f32[4,128], param_1.1640: bf16[4,128,4096], param_2.1583: f32[4096,128256], param_3.1243: bf16[4,128,128256], param_4.926: s32[4,128], param_5.778: f32[4,128], param_6.626: f32[4,128], param_7.605: bf16[4,128], param_8.384: f32[4,128]) -> (bf16[4096], bf16[4,128,4096]) { + %param_3.1243 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(3) + %param_4.926 = s32[4,128]{1,0:T(4,128)S(1)} parameter(4) + %param_5.778 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %param_6.626 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %param_7.605 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(7) + %param_8.384 = f32[4,128]{1,0:T(4,128)S(1)} parameter(8) + %multiply_convert_fusion.2.clone.1 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} fusion(%param_3.1243, %param_4.926, %param_5.778, %param_6.626, %param_7.605, /*index=5*/%param_8.384), kind=kLoop, calls=%fused_computation.301.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=349} + %param_2.1583 = f32[4096,128256]{1,0:T(8,128)} parameter(2) + %fusion.219.clone.1 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} fusion(%param_2.1583), kind=kLoop, calls=%fused_computation.250.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=337} + %convolution.86.clone.1 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} convolution(%multiply_convert_fusion.2.clone.1, %fusion.219.clone.1), window={size=1}, dim_labels=0bf_oi0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=347} + %param_1.1640 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1006 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1640), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=328} + %param_0.1247 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.1434 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1247), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=333} + %mul.1433 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1006, %mul.1434), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=333} + %convert_element_type.1005 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1433), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=334} + %multiply.252 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convolution.86.clone.1, %convert_element_type.1005), metadata={op_name="multiply.206"} + %constant.1106 = bf16[]{:T(256)} constant(0) + %reduce.149 = bf16[4096]{0:T(1024)(128)(2,1)} reduce(%multiply.252, %constant.1106), dimensions={0,1}, to_apply=%region_8.11, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + ROOT %tuple.153 = (bf16[4096]{0:T(1024)(128)(2,1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.149, %convolution.86.clone.1) +} + +%fused_computation.350 (param_0.1032: f32[64], param_1.1158: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { + %param_1.1158 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %div.675 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1158), dimensions={0,1}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=156} + %param_0.1032 = f32[64]{0:T(128)S(1)} parameter(0) + %div.673 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1032), dimensions={3}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=156} + %div.672 = f32[4,128,1,64]{3,1,0,2:T(8,128)} divide(%div.675, %div.673), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=156} + %sin.38 = f32[4,128,1,64]{3,1,0,2:T(8,128)} sine(%div.672), metadata={op_name="jit(train_step)/layers/sin" stack_frame_id=166} + %convert_element_type.1014 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.38), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=166} + %cos.41.clone.1 = f32[4,128,1,64]{3,1,0,2:T(8,128)} cosine(%div.672), metadata={op_name="jit(train_step)/layers/cos" stack_frame_id=164} + %convert_element_type.1013.clone.1 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.41.clone.1), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=164} + ROOT %tuple.150 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1014, %convert_element_type.1013.clone.1) +} + +%fused_computation.357 (param_0.1029: bf16[4,128,1,64]) -> bf16[4,128,1,128] { + %param_0.1029 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1098 = bf16[]{:T(256)} constant(-inf) + %pad.38 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1029, %constant.1098), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=165} + %pad.37 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1029, %constant.1098), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=165} + ROOT %maximum.34 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.38, %pad.37), metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=165} +} + +%fused_computation.358 (param_0.1031: bf16[4,128,1,64]) -> bf16[4,128,1,128] { + %param_0.1031 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1097 = bf16[]{:T(256)} constant(-inf) + %pad.40 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1031, %constant.1097), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=167} + %pad.39 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1031, %constant.1097), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=167} + ROOT %maximum.35 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.40, %pad.39), metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=167} +} + +%region_27.32 (reduce_sum.163: f32[], reduce_sum.164: f32[]) -> f32[] { + %reduce_sum.163 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.164 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.168 = f32[]{:T(128)} add(%reduce_sum.163, %reduce_sum.164), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_26.31 (reduce_sum.157: f32[], reduce_sum.161: f32[]) -> f32[] { + %reduce_sum.157 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.161 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.162 = f32[]{:T(128)} add(%reduce_sum.157, %reduce_sum.161), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.362 (param_0.1276: f32[4,4096], param_1.1663: f32[4,4096]) -> (f32[], f32[]) { + %param_0.1276 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(0) + %bitcast.371 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_0.1276), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.195 = f32[4096,4]{0,1:T(4,128)} multiply(%bitcast.371, %bitcast.371), metadata={op_name="jit(train_step)/square" stack_frame_id=375} + %constant.1458 = f32[]{:T(128)} constant(0) + %reduce.150 = f32[]{:T(128)} reduce(%square.195, %constant.1458), dimensions={0,1}, to_apply=%region_27.32, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377} + %param_1.1663 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(1) + %bitcast.375.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_1.1663), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.198.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%bitcast.375.clone.1, %bitcast.375.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=375} + %reduce.151.clone.1 = f32[]{:T(128)} reduce(%square.198.clone.1, %constant.1458), dimensions={0,1}, to_apply=%region_26.31, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377} + ROOT %tuple.157 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.150, %reduce.151.clone.1) +} + +%region_54.59 (reduce_sum.301: f32[], reduce_sum.302: f32[]) -> f32[] { + %reduce_sum.301 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.302 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.303 = f32[]{:T(128)} add(%reduce_sum.301, %reduce_sum.302), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_40.45 (reduce_sum.226: f32[], reduce_sum.227: f32[]) -> f32[] { + %reduce_sum.226 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.227 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.231 = f32[]{:T(128)} add(%reduce_sum.226, %reduce_sum.227), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.365 (param_0.1265: f32[4096,4], param_1.1655: f32[], param_2.1607: f32[], param_3.1255: f32[], param_4.937: f32[4096,4], param_5.789: f32[], param_6.635: f32[4,4096], param_7.614: pred[], param_8.392: f32[4096,4]) -> (f32[], f32[4096,4], f32[4096,4], f32[4096,4], f32[]) { + %param_0.1265 = f32[4096,4]{0,1:T(4,128)} parameter(0) + %param_3.1255 = f32[]{:T(128)S(6)} parameter(3) + %mul.1536.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_3.1255), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.614 = pred[]{:T(512)S(6)} parameter(7) + %select_n.848.clone.1 = pred[4096,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.614), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %param_6.635 = f32[4,4096]{1,0:T(4,128)} parameter(6) + %bitcast.419.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_6.635), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.789 = f32[]{:T(128)} parameter(5) + %div.845.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_5.789), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %div.844.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%bitcast.419.clone.1, %div.845.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %select_n.847.clone.1 = f32[4096,4]{0,1:T(4,128)} select(%select_n.848.clone.1, %bitcast.419.clone.1, %div.844.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %ne.196.clone.1 = pred[4096,4]{0,1:T(4,128)(4,1)} compare(%select_n.847.clone.1, %select_n.847.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=592} + %constant.1447 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.656.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1447), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.846.clone.1 = f32[4096,4]{0,1:T(4,128)} select(%ne.196.clone.1, %broadcast_in_dim.656.clone.1, %select_n.847.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1350.clone.1 = f32[]{:T(128)} constant(inf) + %eq.597.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1350.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.596.clone.1 = pred[4096,4]{0,1:T(4,128)(4,1)} compare(%select_n.846.clone.1, %eq.597.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1349.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.655.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1349.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.845.clone.1 = f32[4096,4]{0,1:T(4,128)} select(%eq.596.clone.1, %broadcast_in_dim.655.clone.1, %select_n.846.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1348.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.595.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1348.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.594.clone.1 = pred[4096,4]{0,1:T(4,128)(4,1)} compare(%select_n.845.clone.1, %eq.595.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1347.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.654.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1347.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.844.clone.1 = f32[4096,4]{0,1:T(4,128)} select(%eq.594.clone.1, %broadcast_in_dim.654.clone.1, %select_n.845.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1343.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.578.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1343.clone.1), dimensions={}, metadata={op_name="broadcast.68"} + %mul.1540.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.844.clone.1, %broadcast.578.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=595} + %param_8.392 = f32[4096,4]{0,1:T(4,128)} parameter(8) + %constant.1351.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.577.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1351.clone.1), dimensions={}, metadata={op_name="broadcast.67"} + %mul.1539.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_8.392, %broadcast.577.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %add.809.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%mul.1540.clone.1, %mul.1539.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=595} + %param_2.1607 = f32[]{:T(128)S(6)} parameter(2) + %div.841.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_2.1607), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=606} + %integer_pow.66.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.844.clone.1, %select_n.844.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=610} + %constant.1346.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.576.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1346.clone.1), dimensions={}, metadata={op_name="broadcast.58"} + %mul.1538.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%integer_pow.66.clone.1, %broadcast.576.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %param_4.937 = f32[4096,4]{0,1:T(4,128)} parameter(4) + %constant.1345.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.575.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1345.clone.1), dimensions={}, metadata={op_name="broadcast.57"} + %mul.1537.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_4.937, %broadcast.575.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %add.808.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%mul.1538.clone.1, %mul.1537.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=613} + %param_1.1655 = f32[]{:T(128)S(6)} parameter(1) + %div.840.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_1.1655), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %div.839.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.808.clone.1, %div.840.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %sqrt.64.clone.1 = f32[4096,4]{0,1:T(4,128)} sqrt(%div.839.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=617} + %constant.1344.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.573.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1344.clone.1), dimensions={}, metadata={op_name="broadcast.53"} + %add.807.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%sqrt.64.clone.1, %broadcast.573.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %multiply.262.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%div.841.clone.1, %add.807.clone.1), metadata={op_name="multiply.36"} + %div.838.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.809.clone.1, %multiply.262.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=618} + %mul.1535.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_0.1265, %broadcast.578.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=620} + %add.806.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%div.838.clone.1, %mul.1535.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=621} + %mul.1534.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%mul.1536.clone.1, %add.806.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.805.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%param_0.1265, %mul.1534.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=625} + %square.199 = f32[4096,4]{0,1:T(4,128)} multiply(%add.805.clone.1, %add.805.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=656} + %reduce.152 = f32[]{:T(128)} reduce(%square.199, %constant.1447), dimensions={0,1}, to_apply=%region_54.59, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657} + %reduce.154.clone.1 = f32[]{:T(128)} reduce(%integer_pow.66.clone.1, %constant.1447), dimensions={0,1}, to_apply=%region_40.45, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633} + ROOT %tuple.144 = (f32[]{:T(128)}, f32[4096,4]{0,1:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.152, %add.805.clone.1, %add.808.clone.1, %add.809.clone.1, %reduce.154.clone.1) +} + +%region_53.58 (reduce_sum.295: f32[], reduce_sum.296: f32[]) -> f32[] { + %reduce_sum.295 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.296 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.297 = f32[]{:T(128)} add(%reduce_sum.295, %reduce_sum.296), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_39.44 (reduce_sum.220: f32[], reduce_sum.224: f32[]) -> f32[] { + %reduce_sum.220 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.224 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.225 = f32[]{:T(128)} add(%reduce_sum.220, %reduce_sum.224), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.366 (param_0.1266: f32[4096,4], param_1.1656: f32[], param_2.1608: f32[], param_3.1256: f32[], param_4.938: f32[4096,4], param_5.790: f32[], param_6.636: f32[4,4096], param_7.615: pred[], param_8.393: f32[4096,4]) -> (f32[], f32[4096,4], f32[4096,4], f32[4096,4], f32[]) { + %param_0.1266 = f32[4096,4]{0,1:T(4,128)} parameter(0) + %param_3.1256 = f32[]{:T(128)S(6)} parameter(3) + %mul.1543.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_3.1256), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.615 = pred[]{:T(512)S(6)} parameter(7) + %select_n.858.clone.1 = pred[4096,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.615), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %param_6.636 = f32[4,4096]{1,0:T(4,128)} parameter(6) + %bitcast.421.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_6.636), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.790 = f32[]{:T(128)} parameter(5) + %div.853.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_5.790), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %div.852.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%bitcast.421.clone.1, %div.853.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %select_n.857.clone.1 = f32[4096,4]{0,1:T(4,128)} select(%select_n.858.clone.1, %bitcast.421.clone.1, %div.852.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %ne.198.clone.1 = pred[4096,4]{0,1:T(4,128)(4,1)} compare(%select_n.857.clone.1, %select_n.857.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=592} + %constant.1448 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.662.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1448), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.856.clone.1 = f32[4096,4]{0,1:T(4,128)} select(%ne.198.clone.1, %broadcast_in_dim.662.clone.1, %select_n.857.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1364.clone.1 = f32[]{:T(128)} constant(inf) + %eq.605.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1364.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.604.clone.1 = pred[4096,4]{0,1:T(4,128)(4,1)} compare(%select_n.856.clone.1, %eq.605.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1363.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.661.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1363.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.855.clone.1 = f32[4096,4]{0,1:T(4,128)} select(%eq.604.clone.1, %broadcast_in_dim.661.clone.1, %select_n.856.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1362.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.603.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1362.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.602.clone.1 = pred[4096,4]{0,1:T(4,128)(4,1)} compare(%select_n.855.clone.1, %eq.603.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1361.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.660.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1361.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.854.clone.1 = f32[4096,4]{0,1:T(4,128)} select(%eq.602.clone.1, %broadcast_in_dim.660.clone.1, %select_n.855.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1357.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.584.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1357.clone.1), dimensions={}, metadata={op_name="broadcast.68"} + %mul.1547.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.854.clone.1, %broadcast.584.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=595} + %param_8.393 = f32[4096,4]{0,1:T(4,128)} parameter(8) + %constant.1365.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.583.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1365.clone.1), dimensions={}, metadata={op_name="broadcast.67"} + %mul.1546.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_8.393, %broadcast.583.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %add.814.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%mul.1547.clone.1, %mul.1546.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=595} + %param_2.1608 = f32[]{:T(128)S(6)} parameter(2) + %div.849.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_2.1608), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=606} + %integer_pow.67.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.854.clone.1, %select_n.854.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=610} + %constant.1360.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.582.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1360.clone.1), dimensions={}, metadata={op_name="broadcast.58"} + %mul.1545.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%integer_pow.67.clone.1, %broadcast.582.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %param_4.938 = f32[4096,4]{0,1:T(4,128)} parameter(4) + %constant.1359.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.581.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1359.clone.1), dimensions={}, metadata={op_name="broadcast.57"} + %mul.1544.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_4.938, %broadcast.581.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %add.813.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%mul.1545.clone.1, %mul.1544.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=613} + %param_1.1656 = f32[]{:T(128)S(6)} parameter(1) + %div.848.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_1.1656), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %div.847.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.813.clone.1, %div.848.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %sqrt.65.clone.1 = f32[4096,4]{0,1:T(4,128)} sqrt(%div.847.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=617} + %constant.1358.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.579.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.1358.clone.1), dimensions={}, metadata={op_name="broadcast.53"} + %add.812.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%sqrt.65.clone.1, %broadcast.579.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %multiply.263.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%div.849.clone.1, %add.812.clone.1), metadata={op_name="multiply.35"} + %div.846.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.814.clone.1, %multiply.263.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=618} + %mul.1542.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_0.1266, %broadcast.584.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=620} + %add.811.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%div.846.clone.1, %mul.1542.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=621} + %mul.1541.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%mul.1543.clone.1, %add.811.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.810.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%param_0.1266, %mul.1541.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=625} + %square.200 = f32[4096,4]{0,1:T(4,128)} multiply(%add.810.clone.1, %add.810.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=656} + %reduce.153 = f32[]{:T(128)} reduce(%square.200, %constant.1448), dimensions={0,1}, to_apply=%region_53.58, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657} + %reduce.155.clone.1 = f32[]{:T(128)} reduce(%integer_pow.67.clone.1, %constant.1448), dimensions={0,1}, to_apply=%region_39.44, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633} + ROOT %tuple.145 = (f32[]{:T(128)}, f32[4096,4]{0,1:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.153, %add.810.clone.1, %add.813.clone.1, %add.814.clone.1, %reduce.155.clone.1) +} + +%region_9.12 (reduce_sum.99: f32[], reduce_sum.100: f32[]) -> f32[] { + %reduce_sum.100 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.99 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.101 = f32[]{:T(128)} add(%reduce_sum.99, %reduce_sum.100), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.380 (param_0.1280: bf16[4096]) -> f32[] { + %param_0.1280 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(0) + %convert_element_type.1018 = f32[4096]{0:T(1024)} convert(%param_0.1280), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=88} + %square.203 = f32[4096]{0:T(1024)} multiply(%convert_element_type.1018, %convert_element_type.1018), metadata={op_name="jit(train_step)/square" stack_frame_id=375} + %constant.1462 = f32[]{:T(128)} constant(0) + ROOT %reduce.156 = f32[]{:T(128)} reduce(%square.203, %constant.1462), dimensions={0}, to_apply=%region_9.12, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=377} +} + +%region_49.54 (reduce_sum.274: f32[], reduce_sum.275: f32[]) -> f32[] { + %reduce_sum.274 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.275 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.276 = f32[]{:T(128)} add(%reduce_sum.274, %reduce_sum.275), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_35.40 (reduce_sum.199: f32[], reduce_sum.203: f32[]) -> f32[] { + %reduce_sum.199 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.203 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.204 = f32[]{:T(128)} add(%reduce_sum.199, %reduce_sum.203), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.381 (param_0.1270: f32[4096], param_1.1660: f32[], param_2.1612: f32[], param_3.1260: f32[], param_4.942: f32[4096], param_5.794: f32[], param_6.640: bf16[4096], param_7.619: pred[], param_8.397: f32[4096]) -> (f32[], f32[4096], f32[4096], f32[4096], f32[]) { + %param_0.1270 = f32[4096]{0:T(1024)S(1)} parameter(0) + %param_3.1260 = f32[]{:T(128)S(6)} parameter(3) + %mul.1574.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_3.1260), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.619 = pred[]{:T(512)S(6)} parameter(7) + %select_n.898.clone.1 = pred[4096]{0:T(1024)(128)(4,1)} broadcast(%param_7.619), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %param_6.640 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(6) + %convert_element_type.1039.clone.1 = f32[4096]{0:T(1024)} convert(%param_6.640), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=88} + %param_5.794 = f32[]{:T(128)} parameter(5) + %div.885.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_5.794), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %div.884.clone.1 = f32[4096]{0:T(1024)} divide(%convert_element_type.1039.clone.1, %div.885.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=590} + %select_n.897.clone.1 = f32[4096]{0:T(1024)} select(%select_n.898.clone.1, %convert_element_type.1039.clone.1, %div.884.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=589} + %ne.206.clone.1 = pred[4096]{0:T(1024)(128)(4,1)} compare(%select_n.897.clone.1, %select_n.897.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=592} + %constant.1452 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.686.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.1452), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.896.clone.1 = f32[4096]{0:T(1024)} select(%ne.206.clone.1, %broadcast_in_dim.686.clone.1, %select_n.897.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1420.clone.1 = f32[]{:T(128)} constant(inf) + %eq.637.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.1420.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.635.clone.1 = pred[4096]{0:T(1024)(128)(4,1)} compare(%select_n.896.clone.1, %eq.637.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1419.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.685.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.1419.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.895.clone.1 = f32[4096]{0:T(1024)} select(%eq.635.clone.1, %broadcast_in_dim.685.clone.1, %select_n.896.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1418.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.636.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.1418.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %eq.634.clone.1 = pred[4096]{0:T(1024)(128)(4,1)} compare(%select_n.895.clone.1, %eq.636.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=592} + %constant.1417.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.684.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.1417.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=592} + %select_n.894.clone.1 = f32[4096]{0:T(1024)} select(%eq.634.clone.1, %broadcast_in_dim.684.clone.1, %select_n.895.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=592} + %constant.1413.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.600.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.1413.clone.1), dimensions={}, metadata={op_name="broadcast.72"} + %mul.1580.clone.1 = f32[4096]{0:T(1024)} multiply(%select_n.894.clone.1, %broadcast.600.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=595} + %param_8.397 = f32[4096]{0:T(1024)S(1)} parameter(8) + %constant.1421.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1581.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.1421.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %mul.1579.clone.1 = f32[4096]{0:T(1024)} multiply(%param_8.397, %mul.1581.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=596} + %add.836.clone.1 = f32[4096]{0:T(1024)S(1)} add(%mul.1580.clone.1, %mul.1579.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=595} + %param_2.1612 = f32[]{:T(128)S(6)} parameter(2) + %div.881.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_2.1612), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=606} + %integer_pow.71.clone.1 = f32[4096]{0:T(1024)} multiply(%select_n.894.clone.1, %select_n.894.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=610} + %constant.1416.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1578.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.1416.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %mul.1576.clone.1 = f32[4096]{0:T(1024)} multiply(%integer_pow.71.clone.1, %mul.1578.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=613} + %param_4.942 = f32[4096]{0:T(1024)S(1)} parameter(4) + %constant.1415.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1577.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.1415.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %mul.1575.clone.1 = f32[4096]{0:T(1024)} multiply(%param_4.942, %mul.1577.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=614} + %add.835.clone.1 = f32[4096]{0:T(1024)S(1)} add(%mul.1576.clone.1, %mul.1575.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=613} + %param_1.1660 = f32[]{:T(128)S(6)} parameter(1) + %div.880.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_1.1660), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %div.879.clone.1 = f32[4096]{0:T(1024)} divide(%add.835.clone.1, %div.880.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=615} + %sqrt.69.clone.1 = f32[4096]{0:T(1024)} sqrt(%div.879.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=617} + %constant.1414.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.834.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.1414.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %add.833.clone.1 = f32[4096]{0:T(1024)} add(%sqrt.69.clone.1, %add.834.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=617} + %multiply.267.clone.1 = f32[4096]{0:T(1024)} multiply(%div.881.clone.1, %add.833.clone.1), metadata={op_name="multiply.31"} + %div.878.clone.1 = f32[4096]{0:T(1024)} divide(%add.836.clone.1, %multiply.267.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=618} + %mul.1573.clone.1 = f32[4096]{0:T(1024)} multiply(%param_0.1270, %broadcast.600.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=620} + %add.832.clone.1 = f32[4096]{0:T(1024)} add(%div.878.clone.1, %mul.1573.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=621} + %mul.1572.clone.1 = f32[4096]{0:T(1024)} multiply(%mul.1574.clone.1, %add.832.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.831.clone.1 = f32[4096]{0:T(1024)S(1)} add(%param_0.1270, %mul.1572.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=625} + %square.204 = f32[4096]{0:T(1024)} multiply(%add.831.clone.1, %add.831.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=656} + %reduce.157 = f32[]{:T(128)} reduce(%square.204, %constant.1452), dimensions={0}, to_apply=%region_49.54, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=657} + %reduce.158.clone.1 = f32[]{:T(128)} reduce(%integer_pow.71.clone.1, %constant.1452), dimensions={0}, to_apply=%region_35.40, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=633} + ROOT %tuple.148 = (f32[]{:T(128)}, f32[4096]{0:T(1024)S(1)}, f32[4096]{0:T(1024)S(1)}, f32[4096]{0:T(1024)S(1)}, f32[]{:T(128)}) tuple(%reduce.157, %add.831.clone.1, %add.835.clone.1, %add.836.clone.1, %reduce.158.clone.1) +} + +%fused_computation.387 (param_0.1117: s32[512]) -> s32[1024] { + %constant.929 = s32[] constant(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + %broadcast.539 = s32[1024]{0:T(1024)} broadcast(%constant.929), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + %param_0.1117 = s32[512]{0:T(512)S(1)} parameter(0) + %constant.930 = s32[] constant(2147483647), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + %pad.41 = s32[1024]{0:T(1024)} pad(%param_0.1117, %constant.930), padding=0_512, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + %constant.928 = s32[] constant(128255), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + %broadcast.538 = s32[1024]{0:T(1024)} broadcast(%constant.928), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + ROOT %clamp.1 = s32[1024]{0:T(1024)} clamp(%broadcast.539, %pad.41, %broadcast.538), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} +} + +%fused_computation.388 (param_0.1116: s32[4,128]) -> s32[512] { + %param_0.1116 = s32[4,128]{1,0:T(4,128)} parameter(0) + %constant.1120 = s32[]{:T(128)} constant(0) + %broadcast.546 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1120), dimensions={}, metadata={op_name="broadcast.81"} + %lt.32 = pred[4,128]{1,0:T(4,128)(4,1)} compare(%param_0.1116, %broadcast.546), direction=LT, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/lt" stack_frame_id=94} + %constant.1107 = s32[]{:T(128)} constant(128256) + %add.760 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1107), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=94} + %add.748 = s32[4,128]{1,0:T(4,128)} add(%param_0.1116, %add.760), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=94} + %select_n.628 = s32[4,128]{1,0:T(4,128)} select(%lt.32, %add.748, %param_0.1116), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/select_n" stack_frame_id=94} + ROOT %bitcast.376 = s32[512]{0:T(512)S(1)} bitcast(%select_n.628), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=94} +} + +%region_61.66 (reduce_sum.345: f32[], reduce_sum.346: f32[]) -> f32[] { + %reduce_sum.345 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.346 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.330 = f32[]{:T(128)} add(%reduce_sum.345, %reduce_sum.346), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=666}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_48.53 (reduce_sum.268: f32[], reduce_sum.269: f32[]) -> f32[] { + %reduce_sum.268 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.269 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.273 = f32[]{:T(128)} add(%reduce_sum.268, %reduce_sum.269), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=68}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.389 (param_0.1281: bf16[4,128], param_1.1667: f32[4,128], param_2.1615: f32[4,128], param_3.1262: s32[4,128]) -> (f32[], f32[], pred[4,128], f32[4,128]) { + %param_3.1262 = s32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %constant.1427.clone.1 = s32[]{:T(128)} constant(0) + %broadcast.601.clone.1 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1427.clone.1), dimensions={}, metadata={op_name="broadcast.81"} + %ne.207.clone.1 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} compare(%param_3.1262, %broadcast.601.clone.1), direction=NE, metadata={op_name="jit(train_step)/jvp()/ne" stack_frame_id=64} + %param_1.1667 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %log.16 = f32[4,128]{1,0:T(4,128)} log(%param_1.1667), metadata={op_name="jit(train_step)/jvp()/log" stack_frame_id=363} + %param_0.1281 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(0) + %reduce_max.15 = f32[4,128]{1,0:T(4,128)} convert(%param_0.1281), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=359} + %add.762 = f32[4,128]{1,0:T(4,128)} add(%log.16, %reduce_max.15), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=363} + %square.207 = f32[4,128]{1,0:T(4,128)} multiply(%add.762, %add.762), metadata={op_name="jit(train_step)/jvp()/square" stack_frame_id=641} + %constant.1464 = f32[]{:T(128)} constant(0) + %broadcast.543 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1464), dimensions={}, metadata={op_name="broadcast.32"} + %mul.1473 = f32[4,128]{1,0:T(4,128)} multiply(%square.207, %broadcast.543), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=642} + %mul.1465 = f32[4,128]{1,0:T(4,128)} select(%ne.207.clone.1, %mul.1473, %broadcast.543), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=665} + %reduce.159 = f32[]{:T(128)} reduce(%mul.1465, %constant.1464), dimensions={0,1}, to_apply=%region_61.66, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=666} + %param_2.1615 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %neg.115.clone.1 = f32[4,128]{1,0:T(4,128)} negate(%param_2.1615), metadata={op_name="jit(train_step)/jvp()/neg" stack_frame_id=640} + %add.749.clone.1 = f32[4,128]{1,0:T(4,128)} add(%neg.115.clone.1, %mul.1473), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=643} + %mul.1466.clone.1 = f32[4,128]{1,0:T(4,128)} select(%ne.207.clone.1, %add.749.clone.1, %broadcast.543), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=69} + %reduce.160.clone.1 = f32[]{:T(128)} reduce(%mul.1466.clone.1, %constant.1464), dimensions={0,1}, to_apply=%region_48.53, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=68} + %mul.1471.clone.1 = f32[4,128]{1,0:T(4,128)} multiply(%add.762, %broadcast.543), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + %constant.1123.clone.1 = f32[]{:T(128)} constant(1) + %add.757.clone.1 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1123.clone.1), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=80} + %add.750.clone.1 = f32[4,128]{1,0:T(4,128)S(1)} add(%mul.1471.clone.1, %add.757.clone.1), metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=80} + ROOT %tuple.149 = (f32[]{:T(128)}, f32[]{:T(128)}, pred[4,128]{1,0:T(4,128)(4,1)S(1)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%reduce.159, %reduce.160.clone.1, %ne.207.clone.1, %add.750.clone.1) +} + +%fused_computation.392 (param_0.1140: f32[4,128], param_1.1333: f32[4,128]) -> f32[4,128] { + %param_0.1140 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %param_1.1333 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %constant.1101 = f32[]{:T(128)} constant(0.000244140625) + %broadcast.549 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1101), dimensions={}, metadata={op_name="broadcast.264"} + %div.728 = f32[4,128]{1,0:T(4,128)} multiply(%param_1.1333, %broadcast.549), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=330} + %constant.1099 = f32[]{:T(128)} constant(1e-05) + %add.770 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1099), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=331} + %add.769 = f32[4,128]{1,0:T(4,128)} add(%div.728, %add.770), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=331} + %rsqrt.90 = f32[4,128]{1,0:T(4,128)} rsqrt(%add.769), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=332} + %div.721 = f32[4,128]{1,0:T(4,128)} divide(%rsqrt.90, %add.769), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=332} + %constant.1096 = f32[]{:T(128)} constant(-0.5) + %mul.1477 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1096), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=332} + %mul.1470 = f32[4,128]{1,0:T(4,128)} multiply(%div.721, %mul.1477), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=332} + %mul.1469 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1140, %mul.1470), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=332} + %constant.1095 = f32[]{:T(128)} constant(0.00048828125) + %mul.1476 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1095), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=329} + ROOT %mul.1468 = f32[4,128]{1,0:T(4,128)S(1)} multiply(%mul.1469, %mul.1476), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=329} +} + +%region_0.1 (reduce_sum.67: s32[], reduce_sum.71: s32[]) -> s32[] { + %reduce_sum.67 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.71 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.72 = s32[]{:T(128)} add(%reduce_sum.67, %reduce_sum.71), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=65}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["0","2"]}]}} +} + +%fused_computation.396 (param_0.1157: pred[4,128]) -> s32[] { + %param_0.1157 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) + %convert_element_type.1031 = s32[4,128]{1,0:T(4,128)} convert(%param_0.1157), metadata={op_name="jit(train_step)/jvp()/convert_element_type" stack_frame_id=65} + %constant.1121 = s32[]{:T(128)} constant(0) + ROOT %reduce.161 = s32[]{:T(128)} reduce(%convert_element_type.1031, %constant.1121), dimensions={0,1}, to_apply=%region_0.1, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=65} +} + +%fused_computation.397 (param_0.1142: f32[4,128]) -> f32[4,128] { + %param_0.1142 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1102 = f32[]{:T(128)} constant(0.000244140625) + %broadcast.541 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1102), dimensions={}, metadata={op_name="broadcast.264"} + %div.726 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1142, %broadcast.541), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=330} + %constant.1100 = f32[]{:T(128)} constant(1e-05) + %add.759 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1100), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=331} + %add.756 = f32[4,128]{1,0:T(4,128)} add(%div.726, %add.759), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=331} + ROOT %rsqrt.88 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.756), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=332} +} + +%fused_computation.398 (param_0.1143: pred[4,128], param_1.1666: f32[]) -> f32[4,128] { + %param_0.1143 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) + %param_1.1666 = f32[]{:T(128)S(6)} parameter(1) + %broadcast_in_dim.524 = f32[4,128]{1,0:T(4,128)} broadcast(%param_1.1666), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/broadcast_in_dim" stack_frame_id=68} + %constant.1463 = f32[]{:T(128)} constant(0) + %broadcast.545 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1463), dimensions={}, metadata={op_name="broadcast.32"} + ROOT %mul.1478 = f32[4,128]{1,0:T(4,128)S(1)} select(%param_0.1143, %broadcast_in_dim.524, %broadcast.545), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=69} +} + +%fused_computation.400 () -> f32[64] { + %constant.1105 = f32[]{:T(128)} constant(500000) + %broadcast.552 = f32[64]{0:T(128)} broadcast(%constant.1105), dimensions={}, metadata={op_name="broadcast.255"} + %iota.46 = s32[64]{0:T(128)} iota(), iota_dimension=0, metadata={op_name="jit(train_step)/layers/iota" stack_frame_id=161} + %constant.1104 = s32[]{:T(128)} constant(2) + %broadcast.551 = s32[64]{0:T(128)} broadcast(%constant.1104), dimensions={}, metadata={op_name="broadcast.256"} + %mul.1479 = s32[64]{0:T(128)} multiply(%iota.46, %broadcast.551), metadata={op_name="jit(train_step)/layers/mul" stack_frame_id=162} + %convert_element_type.1032 = f32[64]{0:T(128)} convert(%mul.1479), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=162} + %constant.1103 = f32[]{:T(128)} constant(0.0078125) + %broadcast.550 = f32[64]{0:T(128)} broadcast(%constant.1103), dimensions={}, metadata={op_name="broadcast.257"} + %div.729 = f32[64]{0:T(128)} multiply(%convert_element_type.1032, %broadcast.550), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=162} + ROOT %pow.36 = f32[64]{0:T(128)S(1)} power(%broadcast.552, %div.729), metadata={op_name="jit(train_step)/layers/pow" stack_frame_id=163} +} + +%fused_computation.401 (param_0.1155: s32[4,128]) -> (f32[4,128,1,1], f32[4,128]) { + %param_0.1155 = s32[4,128]{1,0:T(4,128)} parameter(0) + %convert_element_type.1033 = f32[4,128]{1,0:T(4,128)S(1)} convert(%param_0.1155), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=156} + %bitcast.377 = f32[4,128,1,1]{1,0,3,2:T(4,128)} bitcast(%convert_element_type.1033), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=385} + ROOT %tuple.151 = (f32[4,128,1,1]{1,0,3,2:T(4,128)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%bitcast.377, %convert_element_type.1033) +} + +%fused_computation.405 (param_0.1256: f32[4096,4]) -> bf16[4,4096] { + %param_0.1256 = f32[4096,4]{0,1:T(4,128)} parameter(0) + %bitcast.451 = f32[4,4096]{1,0:T(4,128)} bitcast(%param_0.1256), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + ROOT %convert.106 = bf16[4,4096]{1,0:T(4,128)(2,1)} convert(%bitcast.451) +} + +%fused_computation.406 (param_0.1257: f32[4096,4]) -> bf16[4,4096] { + %param_0.1257 = f32[4096,4]{0,1:T(4,128)} parameter(0) + %bitcast.452 = f32[4,4096]{1,0:T(4,128)} bitcast(%param_0.1257), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + ROOT %convert.108 = bf16[4,4096]{1,0:T(4,128)(2,1)} convert(%bitcast.452) +} + +%region_6.9 (reduce_max.6: bf16[], reduce_max.8: bf16[]) -> bf16[] { + %reduce_max.6 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_max"} + %reduce_max.8 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_max"} + ROOT %reduce_max.9 = bf16[]{:T(256)} maximum(%reduce_max.6, %reduce_max.8), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=359}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.252.clone.clone (param_0.1243: f32[4096,128256]) -> bf16[4096,128256,1] { + %param_0.1243 = f32[4096,128256]{1,0:T(8,128)} parameter(0) + %convert_element_type.1044 = bf16[4096,128256]{1,0:T(8,128)(2,1)} convert(%param_0.1243), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=337} + ROOT %bitcast.447 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} bitcast(%convert_element_type.1044), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=337} +} + +%fused_computation.344.clone.clone (param_0.1244: f32[4,128], param_1.1637: bf16[4,128,4096], param_2.1578: bf16[4096]) -> bf16[4,128,4096] { + %param_2.1578 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %dot_general.383 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1578), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + %param_1.1637 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1046 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1637), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=328} + %param_0.1244 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.1595 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1244), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=333} + %mul.1594 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1046, %mul.1595), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=333} + %convert_element_type.1045 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1594), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=334} + ROOT %dot_general.382 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.383, %convert_element_type.1045), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} +} + +%fused_computation.407 (param_0.1258: f32[4096,128256], param_1.1648: f32[4,128], param_2.1600: bf16[4,128,4096], param_3.1248: bf16[4096]) -> (bf16[4,128], bf16[4,128,128256]) { + %param_1.1648 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1600 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %param_3.1248 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %fusion.240.clone.1 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1648, %param_2.1600, %param_3.1248), kind=kLoop, calls=%fused_computation.344.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + %param_0.1258 = f32[4096,128256]{1,0:T(8,128)} parameter(0) + %fusion.221.clone.1 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1258), kind=kLoop, calls=%fused_computation.252.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=337} + %convolution.87.clone.1 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} convolution(%fusion.240.clone.1, %fusion.221.clone.1), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=347} + %constant.1440 = bf16[]{:T(256)} constant(-inf) + %reduce.162 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} reduce(%convolution.87.clone.1, %constant.1440), dimensions={2}, to_apply=%region_6.9, metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=359} + ROOT %tuple.152 = (bf16[4,128]{1,0:T(4,128)(2,1)S(1)}, bf16[4,128,128256]{2,1,0:T(8,128)(2,1)}) tuple(%reduce.162, %convolution.87.clone.1) +} + +%fused_computation.408 (param_0.1255: f32[4096,4,8,128]) -> bf16[4,4096,8,128] { + %param_0.1255 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %bitcast.450 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} bitcast(%param_0.1255), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + ROOT %convert.110 = bf16[4,4096,8,128]{3,2,0,1:T(8,128)(2,1)} convert(%bitcast.450) +} + +%convert_element_type.525.reduce_sub_computation (lhs.1: bf16[], rhs.1: bf16[]) -> bf16[] { + %lhs.1 = bf16[] parameter(0) + %rhs.1 = bf16[] parameter(1) + ROOT %add.624 = bf16[] add(%lhs.1, %rhs.1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.121.clone.clone (param_0.1395: bf16[4,4096], param_1.1756: s32[]) -> bf16[4096] { + %param_0.1395 = bf16[4,4096]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1756 = s32[]{:T(128)S(6)} parameter(1) + %constant.1564 = s32[]{:T(128)} constant(0) + %dynamic_slice.316 = bf16[1,4096]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1395, %param_1.1756, %constant.1564), dynamic_slice_sizes={1,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=173}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[false,false]},"used_scoped_memory_configs":[]} + %constant.1565 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=391} + ROOT %reduce.174 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.316, %constant.1565), dimensions={0}, to_apply=%convert_element_type.525.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=398} +} + +%region_12.14 (reduce_sum.108: f32[], reduce_sum.109: f32[]) -> f32[] { + %reduce_sum.108 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.109 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.113 = f32[]{:T(128)} add(%reduce_sum.108, %reduce_sum.109), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=407}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.58.clone.clone (param_0.1396: bf16[4,4,128,4096], param_1.1757: s32[]) -> f32[4,128] { + %param_0.1396 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1757 = s32[]{:T(128)S(6)} parameter(1) + %constant.1566 = s32[]{:T(128)} constant(0) + %dynamic_slice.317 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1396, %param_1.1757, %constant.1566, %constant.1566, %constant.1566), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=173}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,true,true,true]},"used_scoped_memory_configs":[]} + %bitcast.548 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.317), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=173} + %convert_element_type.1111 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.548), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=400} + %square.214 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1111, %convert_element_type.1111), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=401} + %constant.1567 = f32[]{:T(128)} constant(0) + ROOT %reduce.175 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.214, %constant.1567), dimensions={2}, to_apply=%region_12.14, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=402} +} + +%fused_computation.143.clone.1.clone (param_0.1397: f32[4,128]) -> f32[4,128] { + %param_0.1397 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1569 = f32[]{:T(128)} constant(0.000244140625) + %closed_call.81 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1569), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=173} + %div.914 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1397, %closed_call.81), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=402} + %constant.1568 = f32[]{:T(128)} constant(1e-05) + %closed_call.80 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1568), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=173} + %add.858 = f32[4,128]{1,0:T(4,128)} add(%div.914, %closed_call.80), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=408} + ROOT %rsqrt.97 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.858), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=409} +} + +%fused_computation.24.clone.1.clone.clone (param_0.1411: bf16[4,4096,32,128], param_1.1767: s32[]) -> bf16[4096,32,128,1] { + %param_0.1411 = bf16[4,4096,32,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %param_1.1767 = s32[]{:T(128)S(6)} parameter(1) + %constant.1582 = s32[]{:T(128)} constant(0) + %dynamic_slice.323 = bf16[1,4096,32,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1411, %param_1.1767, %constant.1582, %constant.1582, %constant.1582), dynamic_slice_sizes={1,4096,32,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=173}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,true,true,true]},"used_scoped_memory_configs":[]} + ROOT %bitcast.559 = bf16[4096,32,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.323), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=421} +} + +%fused_computation.91.clone.clone (param_0.1412: f32[4,128], param_1.1768: bf16[4,4,128,4096], param_2.1677: s32[], param_3.1307: bf16[4096]) -> bf16[4,128,4096,1] { + %param_3.1307 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %dot_general.428 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_3.1307), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=399} + %param_1.1768 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(1) + %param_2.1677 = s32[]{:T(128)S(6)} parameter(2) + %constant.1583 = s32[]{:T(128)} constant(0) + %dynamic_slice.324 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_1.1768, %param_2.1677, %constant.1583, %constant.1583, %constant.1583), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=173}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,true,true,true]},"used_scoped_memory_configs":[]} + %bitcast.561 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.324), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=173} + %convert_element_type.1119 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.561), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=400} + %param_0.1412 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.1709 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1412), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=410} + %mul.1708 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1119, %mul.1709), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=410} + %convert_element_type.1118 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1708), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=411} + %dot_general.427 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.428, %convert_element_type.1118), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=399} + ROOT %bitcast.560 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.427), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=399} +} + +%fused_computation.36.clone.clone (param_0.1413: bf16[4,4096,32,128], param_1.1769: s32[], param_2.1678: f32[4,128], param_3.1308: bf16[4,4,128,4096], param_4.971: bf16[4096]) -> bf16[4,128,32,128] { + %param_2.1678 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.1308 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_1.1769 = s32[]{:T(128)S(6)} parameter(1) + %param_4.971 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.343 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1678, %param_3.1308, %param_1.1769, %param_4.971), kind=kLoop, calls=%fused_computation.91.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=399} + %param_0.1413 = bf16[4,4096,32,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %fusion.342 = bf16[4096,32,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1413, %param_1.1769), kind=kLoop, calls=%fused_computation.24.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=421} + ROOT %convolution.113 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.343, %fusion.342), window={size=1x32 pad=0_0x31_31 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=423} +} + +%fused_computation.70.clone.clone (param_0.1414: bf16[4,128,32,128]) -> (bf16[4,128,32,64], bf16[4,128,32,64]) { + %param_0.1414 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %split.160 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1414), slice={[0:4], [0:128], [0:32], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=433} + %neg.129 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} negate(%split.160), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/neg" stack_frame_id=434} + %split.161 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1414), slice={[0:4], [0:128], [0:32], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=433} + ROOT %tuple.187 = (bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.129, %split.161) +} + +%fused_computation.145.clone.clone () -> f32[64] { + %constant.1572 = f32[]{:T(128)} constant(500000) + %closed_call.84 = f32[64]{0:T(128)} broadcast(%constant.1572), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=173} + %iota.51 = s32[64]{0:T(128)} iota(), iota_dimension=0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/iota" stack_frame_id=425} + %constant.1571 = s32[]{:T(128)} constant(2) + %closed_call.83 = s32[64]{0:T(128)} broadcast(%constant.1571), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=173} + %mul.1699 = s32[64]{0:T(128)} multiply(%iota.51, %closed_call.83), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=426} + %convert_element_type.1112 = f32[64]{0:T(128)} convert(%mul.1699), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=426} + %constant.1570 = f32[]{:T(128)} constant(0.0078125) + %closed_call.82 = f32[64]{0:T(128)} broadcast(%constant.1570), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=173} + %div.915 = f32[64]{0:T(128)} multiply(%convert_element_type.1112, %closed_call.82), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=426} + ROOT %pow.38 = f32[64]{0:T(128)S(1)} power(%closed_call.84, %div.915), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/pow" stack_frame_id=427} +} + +%fused_computation.117.clone.clone (param_0.1398: f32[64], param_1.1758: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { + %param_1.1758 = f32[4,128]{1,0:T(4,128)} parameter(1) + %div.918 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1758), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=385} + %param_0.1398 = f32[64]{0:T(128)S(1)} parameter(0) + %div.917 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1398), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=385} + %div.916 = f32[4,128,1,64]{3,1,0,2:T(8,128)} divide(%div.918, %div.917), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=385} + %cos.43 = f32[4,128,1,64]{3,1,0,2:T(8,128)} cosine(%div.916), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/cos" stack_frame_id=428} + %convert_element_type.1113 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.43), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=428} + %sin.35.clone.3 = f32[4,128,1,64]{3,1,0,2:T(8,128)} sine(%div.916), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/sin" stack_frame_id=436} + %convert_element_type.829.clone.3 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.35.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=436} + ROOT %tuple.185 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1113, %convert_element_type.829.clone.3) +} + +%fused_computation.120.clone.clone (param_0.1405: bf16[4,128,1,64]) -> bf16[4,128,128] { + %param_0.1405 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1578 = bf16[]{:T(256)} constant(-inf) + %pad.61 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1405, %constant.1578), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=437} + %pad.60 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1405, %constant.1578), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=437} + %maximum.45 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.61, %pad.60), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=437} + ROOT %bitcast.554 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.45), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=432} +} + +%fused_computation.119.clone.clone (param_0.1399: bf16[4,128,1,64]) -> bf16[4,128,128] { + %param_0.1399 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1573 = bf16[]{:T(256)} constant(-inf) + %pad.59 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1399, %constant.1573), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=429} + %pad.58 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1399, %constant.1573), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=429} + %maximum.44 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.59, %pad.58), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=429} + ROOT %bitcast.549 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.44), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=431} +} + +%fused_computation.73.clone.clone (param_0.1415: bf16[4,128,32,64], param_1.1770: bf16[4,128,32,64], param_2.1679: bf16[4,128,32,128], param_3.1309: bf16[4,128,128], param_4.972: bf16[4,128,128]) -> bf16[4,32,128,128] { + %param_2.1679 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(2) + %param_4.972 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(4) + %mul.1713 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.972), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=431} + %mul.1711 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} multiply(%param_2.1679, %mul.1713), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=431} + %param_1.1770 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1584 = bf16[]{:T(256)} constant(-inf) + %pad.65 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1770, %constant.1584), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=435} + %param_0.1415 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %pad.64 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1415, %constant.1584), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=435} + %maximum.47 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.65, %pad.64), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=435} + %param_3.1309 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %mul.1712 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.1309), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=432} + %mul.1710 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.47, %mul.1712), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=432} + %add.860 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.1711, %mul.1710), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=438} + ROOT %bitcast.562 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%add.860), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=446} +} + +%fused_computation.90.clone.clone (param_0.1407: f32[4,128], param_1.1764: bf16[4,4,128,4096], param_2.1674: s32[], param_3.1304: bf16[4096]) -> bf16[4,128,4096,1] { + %param_3.1304 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %dot_general.426 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_3.1304), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=399} + %param_1.1764 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(1) + %param_2.1674 = s32[]{:T(128)S(6)} parameter(2) + %constant.1580 = s32[]{:T(128)} constant(0) + %dynamic_slice.322 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_1.1764, %param_2.1674, %constant.1580, %constant.1580, %constant.1580), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=173}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,true,true,true]},"used_scoped_memory_configs":[]} + %bitcast.557 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.322), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=173} + %convert_element_type.1117 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.557), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=400} + %param_0.1407 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.1703 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1407), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=410} + %mul.1702 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1117, %mul.1703), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=410} + %convert_element_type.1116 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1702), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=411} + %dot_general.425 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.426, %convert_element_type.1116), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=399} + ROOT %bitcast.556 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.425), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=399} +} + +%fused_computation.64.clone.1.clone.clone (param_0.1406: bf16[4,4096,8,128], param_1.1763: s32[]) -> bf16[4096,8,128,1] { + %param_0.1406 = bf16[4,4096,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %param_1.1763 = s32[]{:T(128)S(6)} parameter(1) + %constant.1579 = s32[]{:T(128)} constant(0) + %dynamic_slice.321 = bf16[1,4096,8,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1406, %param_1.1763, %constant.1579, %constant.1579, %constant.1579), dynamic_slice_sizes={1,4096,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=173}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,true,true,true]},"used_scoped_memory_configs":[]} + ROOT %bitcast.555 = bf16[4096,8,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.321), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=460} +} + +%fused_computation.89.clone.clone (param_0.1408: bf16[4,4096,8,128], param_1.1765: s32[], param_2.1675: f32[4,128], param_3.1305: bf16[4,4,128,4096], param_4.969: bf16[4096]) -> bf16[4,128,8,128] { + %param_2.1675 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.1305 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_1.1765 = s32[]{:T(128)S(6)} parameter(1) + %param_4.969 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.340 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1675, %param_3.1305, %param_1.1765, %param_4.969), kind=kLoop, calls=%fused_computation.90.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=399} + %param_0.1408 = bf16[4,4096,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %fusion.341 = bf16[4096,8,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1408, %param_1.1765), kind=kLoop, calls=%fused_computation.64.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=460} + ROOT %convolution.112 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.340, %fusion.341), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=462} +} + +%fused_computation.106.clone.clone (param_0.1409: bf16[4,128,8,128]) -> (bf16[4,128,8,64], bf16[4,128,8,64]) { + %param_0.1409 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %split.158 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1409), slice={[0:4], [0:128], [0:8], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=469} + %neg.128 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} negate(%split.158), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/neg" stack_frame_id=470} + %split.159 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1409), slice={[0:4], [0:128], [0:8], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=469} + ROOT %tuple.186 = (bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.128, %split.159) +} + +%fused_computation.109.clone.clone (param_0.1410: bf16[4,128,8,64], param_1.1766: bf16[4,128,8,64], param_2.1676: bf16[4,128,8,128], param_3.1306: bf16[4,128,128], param_4.970: bf16[4,128,128]) -> bf16[4,8,128,128] { + %param_2.1676 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(2) + %param_4.970 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(4) + %mul.1707 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.970), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=467} + %mul.1705 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%param_2.1676, %mul.1707), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=467} + %param_1.1766 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1581 = bf16[]{:T(256)} constant(-inf) + %pad.63 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1766, %constant.1581), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=471} + %param_0.1410 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %pad.62 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1410, %constant.1581), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=471} + %maximum.46 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.63, %pad.62), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=471} + %param_3.1306 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %mul.1706 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.1306), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=468} + %mul.1704 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.46, %mul.1706), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=468} + %add.859 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.1705, %mul.1704), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=472} + ROOT %bitcast.558 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%add.859), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=476} +} + +%fused_computation.135.clone.clone (param_0.1401: bf16[4,4096,8,128], param_1.1760: s32[]) -> bf16[1,4096,8,128] { + %param_0.1401 = bf16[4,4096,8,128]{3,2,0,1:T(8,128)(2,1)} parameter(0) + %param_1.1760 = s32[]{:T(128)S(6)} parameter(1) + %constant.1576 = s32[]{:T(128)} constant(0) + ROOT %dynamic_slice.319 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} dynamic-slice(%param_0.1401, %param_1.1760, %constant.1576, %constant.1576, %constant.1576), dynamic_slice_sizes={1,4096,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=173}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,true,true,true]},"used_scoped_memory_configs":[]} +} + +%fused_computation.65.clone.1.clone.clone.clone.clone (param_0.1402: bf16[1,4096,8,128]) -> bf16[4096,8,128,1] { + %param_0.1402 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) + %copy.248 = bf16[1,4096,8,128]{3,1,2,0:T(8,128)(2,1)} copy(%param_0.1402), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=173} + ROOT %bitcast.550 = bf16[4096,8,128,1]{2,0,1,3:T(8,128)(2,1)} bitcast(%copy.248), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=484} +} + +%fused_computation.88.clone.clone.clone.clone (param_0.1403: f32[4,128], param_1.1761: bf16[4,4,128,4096], param_2.1672: s32[], param_3.1302: bf16[4096]) -> bf16[4,128,4096,1] { + %param_3.1302 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %dot_general.424 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_3.1302), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=399} + %param_1.1761 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(1) + %param_2.1672 = s32[]{:T(128)S(6)} parameter(2) + %constant.1577 = s32[]{:T(128)} constant(0) + %dynamic_slice.320 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_1.1761, %param_2.1672, %constant.1577, %constant.1577, %constant.1577), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=173}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,true,true,true]},"used_scoped_memory_configs":[]} + %bitcast.552 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.320), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=173} + %convert_element_type.1115 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.552), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=400} + %param_0.1403 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.1701 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1403), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=410} + %mul.1700 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1115, %mul.1701), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=410} + %convert_element_type.1114 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1700), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=411} + %dot_general.423 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.424, %convert_element_type.1114), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=399} + ROOT %bitcast.551 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.423), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=399} +} + +%fused_computation.114.clone.clone (param_0.1404: bf16[1,4096,8,128], param_1.1762: f32[4,128], param_2.1673: bf16[4,4,128,4096], param_3.1303: s32[], param_4.968: bf16[4096]) -> bf16[4,8,128,128] { + %param_1.1762 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1673 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.1303 = s32[]{:T(128)S(6)} parameter(3) + %param_4.968 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.339 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_1.1762, %param_2.1673, %param_3.1303, %param_4.968), kind=kLoop, calls=%fused_computation.88.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=399} + %param_0.1404 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) + %fusion.338 = bf16[4096,8,128,1]{2,0,1,3:T(8,128)(2,1)} fusion(%param_0.1404), kind=kLoop, calls=%fused_computation.65.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=484} + %convolution.111 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convolution(%fusion.339, %fusion.338), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=486} + ROOT %bitcast.553 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%convolution.111), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=490} +} + +%fused_computation.402.clone.clone (param_0.1439: f32[4,32,128,128]) -> (f32[4,32,128,1], f32[4,32,128]) { + %param_0.1439 = f32[4,32,128,128]{2,1,0,3:T(8,128)S(1)} parameter(0) + %slice.11 = f32[4,32,128,1]{2,1,0,3:T(8,128)S(1)} slice(%param_0.1439), slice={[0:4], [0:32], [0:128], [0:1]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/slice" stack_frame_id=673} + %bitcast.262.clone.3 = f32[4,32,128]{2,1,0:T(8,128)S(1)} bitcast(%slice.11), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/squeeze" stack_frame_id=673} + ROOT %tuple.192 = (f32[4,32,128,1]{2,1,0,3:T(8,128)S(1)}, f32[4,32,128]{2,1,0:T(8,128)S(1)}) tuple(%slice.11, %bitcast.262.clone.3) +} + +%region_13.16 (reduce_sum.120: f32[], reduce_sum.121: f32[]) -> f32[] { + %reduce_sum.120 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.121 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.122 = f32[]{:T(128)} add(%reduce_sum.120, %reduce_sum.121), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=530}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.25.clone.1.clone.clone.clone.clone.clone.clone (param_0.1416: bf16[4,32,128,4096], param_1.1771: s32[]) -> bf16[32,128,4096,1] { + %param_0.1416 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1771 = s32[]{:T(128)S(6)} parameter(1) + %constant.1585 = s32[]{:T(128)} constant(0) + %dynamic_slice.325 = bf16[1,32,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1416, %param_1.1771, %constant.1585, %constant.1585, %constant.1585), dynamic_slice_sizes={1,32,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=173}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,true,true,true]},"used_scoped_memory_configs":[]} + ROOT %bitcast.563 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%dynamic_slice.325), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=518} +} + +%fused_computation.80.clone.clone.clone.clone.clone.clone (param_0.1417: bf16[4,32,128,128]) -> bf16[4,128,32,128] { + %param_0.1417 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(0) + ROOT %bitcast.564 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} bitcast(%param_0.1417), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=511} +} + +%fused_computation.61.clone.clone (param_0.1418: bf16[4,32,128,4096], param_1.1772: s32[], param_2.1680: bf16[4,32,128,128], param_3.1310: bf16[4,4,128,4096]) -> (f32[4,128], bf16[4,128,4096]) { + %param_3.1310 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_1.1772 = s32[]{:T(128)S(6)} parameter(1) + %constant.453.clone.1.clone.3 = s32[]{:T(128)} constant(0) + %dynamic_slice.208.clone.3 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_3.1310, %param_1.1772, %constant.453.clone.1.clone.3, %constant.453.clone.1.clone.3, %constant.453.clone.1.clone.3), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=173}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,true,true,true]},"used_scoped_memory_configs":[]} + %bitcast.207.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.208.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=173} + %param_2.1680 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %fusion.83.clone.3 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} fusion(%param_2.1680), kind=kLoop, calls=%fused_computation.80.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=511} + %param_0.1418 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %fusion.82.clone.3 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.1418, %param_1.1772), kind=kLoop, calls=%fused_computation.25.clone.1.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=518} + %convolution.62.clone.3 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} convolution(%fusion.83.clone.3, %fusion.82.clone.3), window={size=1x32}, dim_labels=0b1f_1io0->0bf1, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=520} + %bitcast.182.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%convolution.62.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=520} + %add.635.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} add(%bitcast.207.clone.3, %bitcast.182.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=524} + %convert_element_type.1120 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%add.635.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=525} + %square.215 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1120, %convert_element_type.1120), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=526} + %constant.1586 = f32[]{:T(128)} constant(0) + %reduce.177 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.215, %constant.1586), dimensions={2}, to_apply=%region_13.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=527} + ROOT %tuple.188 = (f32[4,128]{1,0:T(4,128)S(1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.177, %add.635.clone.3) +} + +%convert_element_type.523.reduce_sub_computation (lhs: bf16[], rhs: bf16[]) -> bf16[] { + %lhs = bf16[] parameter(0) + %rhs = bf16[] parameter(1) + ROOT %add.623 = bf16[] add(%lhs, %rhs), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.122.clone.clone (param_0.1400: bf16[4,4096], param_1.1759: s32[]) -> bf16[4096] { + %param_0.1400 = bf16[4,4096]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1759 = s32[]{:T(128)S(6)} parameter(1) + %constant.1574 = s32[]{:T(128)} constant(0) + %dynamic_slice.318 = bf16[1,4096]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1400, %param_1.1759, %constant.1574), dynamic_slice_sizes={1,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=173}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[false,false]},"used_scoped_memory_configs":[]} + %constant.1575 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=391} + ROOT %reduce.176 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.318, %constant.1575), dimensions={0}, to_apply=%convert_element_type.523.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=391} +} + +%fused_computation.12.clone.clone.clone (param_0.1419: bf16[4,14336,4096], param_1.1773: s32[]) -> bf16[14336,4096,1] { + %param_0.1419 = bf16[4,14336,4096]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1773 = s32[]{:T(128)S(6)} parameter(1) + %constant.1587 = s32[]{:T(128)} constant(0) + %dynamic_slice.326 = bf16[1,14336,4096]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1419, %param_1.1773, %constant.1587, %constant.1587), dynamic_slice_sizes={1,14336,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=173}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,true,true]},"used_scoped_memory_configs":[]} + ROOT %bitcast.566 = bf16[14336,4096,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.326), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=548} +} + +%bitcast_fusion.3.clone.clone (bitcast_input.12: bf16[4,128,4096]) -> bf16[4,128,4096] { + %bitcast_input.12 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + ROOT %bitcast.565 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%bitcast_input.12) +} + +%fused_computation.13.clone.clone (param_0.1420: bf16[4,128,4096], param_1.1774: bf16[4,14336,4096], param_2.1681: s32[]) -> bf16[14336,4,128] { + %param_1.1774 = bf16[4,14336,4096]{2,1,0:T(8,128)(2,1)} parameter(1) + %param_2.1681 = s32[]{:T(128)S(6)} parameter(2) + %fusion.344 = bf16[14336,4096,1]{1,0,2:T(8,128)(2,1)} fusion(%param_1.1774, %param_2.1681), kind=kLoop, calls=%fused_computation.12.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=548} + %param_0.1420 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %fusion.345 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_0.1420), kind=kLoop, calls=%bitcast_fusion.3.clone.clone + ROOT %convolution.114 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} convolution(%fusion.344, %fusion.345), window={size=4 pad=3_3 rhs_reversal=1}, dim_labels=bf0_0oi->b0f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=553} +} + +%fused_computation.144.clone.1.clone (param_0.1421: f32[4,128]) -> f32[4,128] { + %param_0.1421 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1589 = f32[]{:T(128)} constant(0.000244140625) + %closed_call.86 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1589), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=173} + %div.919 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1421, %closed_call.86), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=527} + %constant.1588 = f32[]{:T(128)} constant(1e-05) + %closed_call.85 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1588), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=173} + %add.861 = f32[4,128]{1,0:T(4,128)} add(%div.919, %closed_call.85), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=531} + ROOT %rsqrt.98 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.861), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=532} +} + +%fused_computation.11.clone.1.clone.clone (param_0.1425: bf16[4,4096,14336], param_1.1778: s32[]) -> bf16[4096,14336,1] { + %param_0.1425 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1778 = s32[]{:T(128)S(6)} parameter(1) + %constant.1591 = s32[]{:T(128)} constant(0) + %dynamic_slice.328 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1425, %param_1.1778, %constant.1591, %constant.1591), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=173}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,true,true]},"used_scoped_memory_configs":[]} + ROOT %bitcast.568 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.328), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=542} +} + +%fused_computation.96.clone.2.clone.clone (param_0.1426: f32[4,128], param_1.1779: bf16[4,128,4096], param_2.1684: bf16[4096]) -> bf16[4,128,4096] { + %param_2.1684 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %dot_general.432 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1684), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=392} + %param_1.1779 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1124 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1779), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=525} + %param_0.1426 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.1717 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1426), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=533} + %mul.1716 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1124, %mul.1717), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=533} + %convert_element_type.1123 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1716), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=534} + ROOT %dot_general.431 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.432, %convert_element_type.1123), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=392} +} + +%fused_computation.23.clone.clone (param_0.1427: bf16[4,4096,14336], param_1.1780: s32[], param_2.1685: f32[4,128], param_3.1312: bf16[4,128,4096], param_4.974: bf16[4096]) -> bf16[4,128,14336] { + %param_2.1685 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.1312 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.974 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.349 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1685, %param_3.1312, %param_4.974), kind=kLoop, calls=%fused_computation.96.clone.2.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=392} + %param_0.1427 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1780 = s32[]{:T(128)S(6)} parameter(1) + %fusion.348 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1427, %param_1.1780), kind=kLoop, calls=%fused_computation.11.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=542} + ROOT %convolution.116 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} convolution(%fusion.349, %fusion.348), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=544} +} + +%fused_computation.14.clone.1.clone.clone (param_0.1428: bf16[4,4096,14336], param_1.1781: s32[]) -> bf16[4096,14336,1] { + %param_0.1428 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1781 = s32[]{:T(128)S(6)} parameter(1) + %constant.1592 = s32[]{:T(128)} constant(0) + %dynamic_slice.329 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1428, %param_1.1781, %constant.1592, %constant.1592), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=173}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,true,true]},"used_scoped_memory_configs":[]} + ROOT %bitcast.569 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.329), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=542} +} + +%fused_computation.39.clone.1.clone.clone (param_0.1429: bf16[14336,4,128], param_1.1782: bf16[4,128,14336]) -> bf16[4,128,14336] { + %param_1.1782 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1593 = bf16[]{:T(256)} constant(1) + %jit_silu_.44 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} broadcast(%constant.1593), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)" stack_frame_id=545} + %neg.130 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} negate(%param_1.1782), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/neg" stack_frame_id=545} + %exp.69 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} exponential(%neg.130), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/exp" stack_frame_id=545} + %add.862 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} add(%exp.69, %jit_silu_.44), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/add" stack_frame_id=545} + %div.920 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} divide(%jit_silu_.44, %add.862), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/div" stack_frame_id=545} + %mul.1719 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1782, %div.920), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/mul" stack_frame_id=545} + %param_0.1429 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(0) + %bitcast.570 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} bitcast(%param_0.1429), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=553} + ROOT %mul.1718 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%mul.1719, %bitcast.570), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=557} +} + +%fused_computation.21.clone.clone (param_0.1430: bf16[4,4096,14336], param_1.1783: s32[], param_2.1686: bf16[14336,4,128], param_3.1313: bf16[4,128,14336]) -> bf16[4,128,4096] { + %param_2.1686 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(2) + %param_3.1313 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %bitcast_multiply_fusion.15 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1686, %param_3.1313), kind=kLoop, calls=%fused_computation.39.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=557} + %param_0.1430 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1783 = s32[]{:T(128)S(6)} parameter(1) + %fusion.350 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1430, %param_1.1783), kind=kLoop, calls=%fused_computation.14.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=542} + ROOT %convolution.117 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} convolution(%bitcast_multiply_fusion.15, %fusion.350), window={size=1}, dim_labels=0bf_oi0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=544} diff --git a/tests/utils/reference_hlo_qwen3_1.7b.txt b/tests/utils/reference_hlo_qwen3_1.7b.txt new file mode 100644 index 0000000000..a3c75315d3 --- /dev/null +++ b/tests/utils/reference_hlo_qwen3_1.7b.txt @@ -0,0 +1,2000 @@ +HloModule jit_train_step, is_scheduled=true, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias), {4}: (4, {}, may-alias), {5}: (5, {}, may-alias), {6}: (6, {}, may-alias), {7}: (7, {}, may-alias), {8}: (8, {}, may-alias), {9}: (9, {}, may-alias), {10}: (10, {}, may-alias), {11}: (11, {}, may-alias), {12}: (12, {}, may-alias), {13}: (13, {}, may-alias), {14}: (14, {}, may-alias), {15}: (15, {}, may-alias), {16}: (16, {}, may-alias), {17}: (17, {}, may-alias), {18}: (18, {}, may-alias), {19}: (19, {}, may-alias), {20}: (20, {}, may-alias), {21}: (21, {}, may-alias), {22}: (22, {}, may-alias), {23}: (23, {}, may-alias), {24}: (24, {}, may-alias), {25}: (25, {}, may-alias), {26}: (26, {}, may-alias), {27}: (27, {}, may-alias), {28}: (28, {}, may-alias), {29}: (29, {}, may-alias), {30}: (30, {}, may-alias), {31}: (31, {}, may-alias), {32}: (32, {}, may-alias), {33}: (33, {}, may-alias), {34}: (34, {}, may-alias), {35}: (35, {}, may-alias), {36}: (36, {}, may-alias), {37}: (37, {}, may-alias), {38}: (38, {}, may-alias), {39}: (39, {}, may-alias), {40}: (40, {}, may-alias), {41}: (41, {}, may-alias) }, entry_computation_layout={(s32[]{:T(128)}, f32[2048]{0:T(1024)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, /*index=5*/f32[2048,4]{0,1:T(4,128)}, f32[2048,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[128,4]{0,1:T(4,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, /*index=10*/f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[128,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, s32[]{:T(128)}, /*index=15*/f32[2048]{0:T(1024)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[2048,4]{0,1:T(4,128)}, /*index=20*/f32[2048,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[128,4]{0,1:T(4,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, /*index=25*/f32[128,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, f32[2048]{0:T(1024)}, f32[2048,4,6144]{2,1,0:T(4,128)}, /*index=30*/f32[2048,4,6144]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[2048,4]{0,1:T(4,128)}, f32[2048,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, /*index=35*/f32[128,4]{0,1:T(4,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[128,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, /*index=40*/f32[151936,2048]{1,0:T(8,128)}, s32[]{:T(128)}, s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)}, /*index=45*/s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)})->(s32[]{:T(128)}, f32[2048]{0:T(1024)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, /*index=5*/f32[2048,4]{0,1:T(4,128)}, f32[2048,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[128,4]{0,1:T(4,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, /*index=10*/f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[128,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, s32[]{:T(128)}, /*index=15*/f32[2048]{0:T(1024)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[2048,4]{0,1:T(4,128)}, /*index=20*/f32[2048,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[128,4]{0,1:T(4,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, /*index=25*/f32[128,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, f32[2048]{0:T(1024)}, f32[2048,4,6144]{2,1,0:T(4,128)}, /*index=30*/f32[2048,4,6144]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[2048,4]{0,1:T(4,128)}, f32[2048,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, /*index=35*/f32[128,4]{0,1:T(4,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[128,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, /*index=40*/f32[151936,2048]{1,0:T(8,128)}, s32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, /*index=45*/f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, /*index=50*/f32[]{:T(128)}, s32[]{:T(128)}, f32[]{:T(128)})}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false}, allow_spmd_sharding_propagation_to_output={false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,true,true,true,true,true,true,true,true,true,true,true}, num_partitions=4 + +FileNames + +FunctionNames + +FileLocations + +StackFrames + + +%fused_computation (param_0.2: bf16[151936,2048], param_1.7: s32[1024]) -> bf16[512,2048] { + %param_0.2 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + %param_1.7 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.1 = s32[1024]{0:T(1024)} custom-call(%param_1.7), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + %slice.6 = s32[512]{0:T(512)} slice(%custom-call.1), slice={[0:512]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + %reshape.444 = s32[4,128]{1,0:T(4,128)} reshape(%slice.6), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=94} + %transpose.461 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.444), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=94} + %gather.4 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} gather(%param_0.2, %transpose.461), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,2048}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + %transpose.460 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} transpose(%gather.4), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + ROOT %reshape.443 = bf16[512,2048]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.460), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} +} + +%region_42.47.clone (scatter-add.6: bf16[], scatter-add.7: bf16[]) -> bf16[] { + %scatter-add.7 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} + %scatter-add.6 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} + ROOT %add.584 = bf16[]{:T(256)} add(%scatter-add.6, %scatter-add.7), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=609}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.1 (param_0.3: bf16[151936,2048], param_1.5: s32[512], param_2.4: bf16[512,2048]) -> bf16[151936,2048] { + %param_0.3 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + %param_1.5 = s32[512]{0:T(512)S(1)} parameter(1) + %reshape.451 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.5), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=94} + %transpose.466 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.451), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=94} + %param_2.4 = bf16[512,2048]{1,0:T(8,128)(2,1)S(1)} parameter(2) + %reshape.452 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} reshape(%param_2.4), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=106} + %transpose.467 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} transpose(%reshape.452), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=106} + ROOT %scatter.2 = bf16[151936,2048]{1,0:T(8,128)(2,1)} scatter(%param_0.3, %transpose.466, %transpose.467), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_42.47.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=94} +} + +%region_71.76 (reduce_sum.464: f32[], reduce_sum.465: f32[]) -> f32[] { + %reduce_sum.465 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.464 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.466 = f32[]{:T(128)} add(%reduce_sum.464, %reduce_sum.465), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_56.61 (reduce_sum.386: f32[], reduce_sum.387: f32[]) -> f32[] { + %reduce_sum.387 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.386 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.388 = f32[]{:T(128)} add(%reduce_sum.386, %reduce_sum.387), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.280 (param_0.1545: f32[151936,2048], param_1.1982: f32[], param_2.1885: f32[], param_3.1395: f32[], param_4.982: f32[151936,2048], param_5.792: f32[], param_6.626: bf16[151936,2048], param_7.565: bf16[151936,2048,1], param_8.350: pred[], param_9.236: f32[151936,2048]) -> (f32[], f32[151936,2048], f32[151936,2048], f32[151936,2048], f32[]) { + %param_0.1545 = f32[151936,2048]{1,0:T(8,128)} parameter(0) + %param_3.1395 = f32[]{:T(128)S(6)} parameter(3) + %mul.1926.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_3.1395), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_8.350 = pred[]{:T(512)S(6)} parameter(8) + %select_n.856.clone.1 = pred[151936,2048]{1,0:T(8,128)(4,1)} broadcast(%param_8.350), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %param_7.565 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} parameter(7) + %bitcast.464.clone.1 = bf16[151936,2048]{1,0:T(8,128)(2,1)} bitcast(%param_7.565), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=357} + %convert_element_type.1415.clone.1 = f32[151936,2048]{1,0:T(8,128)} convert(%bitcast.464.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=601} + %param_6.626 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(6) + %convert_element_type.1414.clone.1 = f32[151936,2048]{1,0:T(8,128)} convert(%param_6.626), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=93} + %add_any.197.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%convert_element_type.1415.clone.1, %convert_element_type.1414.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add_any" stack_frame_id=93} + %param_5.792 = f32[]{:T(128)} parameter(5) + %div.938.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_5.792), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %div.937.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add_any.197.clone.1, %div.938.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %select_n.855.clone.1 = f32[151936,2048]{1,0:T(8,128)} select(%select_n.856.clone.1, %add_any.197.clone.1, %div.937.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %ne.198.clone.1 = pred[151936,2048]{1,0:T(8,128)(4,1)} compare(%select_n.855.clone.1, %select_n.855.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=621} + %constant.1723 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.684.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1723), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.854.clone.1 = f32[151936,2048]{1,0:T(8,128)} select(%ne.198.clone.1, %broadcast_in_dim.684.clone.1, %select_n.855.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1518.clone.1 = f32[]{:T(128)} constant(inf) + %eq.603.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1518.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.602.clone.1 = pred[151936,2048]{1,0:T(8,128)(4,1)} compare(%select_n.854.clone.1, %eq.603.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1517.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.683.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1517.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.853.clone.1 = f32[151936,2048]{1,0:T(8,128)} select(%eq.602.clone.1, %broadcast_in_dim.683.clone.1, %select_n.854.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1516.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.601.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1516.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.600.clone.1 = pred[151936,2048]{1,0:T(8,128)(4,1)} compare(%select_n.853.clone.1, %eq.601.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1515.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.682.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1515.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.852.clone.1 = f32[151936,2048]{1,0:T(8,128)} select(%eq.600.clone.1, %broadcast_in_dim.682.clone.1, %select_n.853.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1511.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.844.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1511.clone.1), dimensions={}, metadata={op_name="broadcast.74"} + %mul.1932.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%select_n.852.clone.1, %broadcast.844.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=624} + %param_9.236 = f32[151936,2048]{1,0:T(8,128)} parameter(9) + %constant.1519.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1933.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1519.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %mul.1931.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_9.236, %mul.1933.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %add.941.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%mul.1932.clone.1, %mul.1931.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=624} + %param_2.1885 = f32[]{:T(128)S(6)} parameter(2) + %div.934.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_2.1885), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=635} + %integer_pow.65.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%select_n.852.clone.1, %select_n.852.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=639} + %constant.1514.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1930.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1514.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %mul.1928.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%integer_pow.65.clone.1, %mul.1930.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %param_4.982 = f32[151936,2048]{1,0:T(8,128)} parameter(4) + %constant.1513.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1929.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1513.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %mul.1927.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_4.982, %mul.1929.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %add.940.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%mul.1928.clone.1, %mul.1927.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=642} + %param_1.1982 = f32[]{:T(128)S(6)} parameter(1) + %div.933.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_1.1982), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %div.932.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add.940.clone.1, %div.933.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %sqrt.62.clone.1 = f32[151936,2048]{1,0:T(8,128)} sqrt(%div.932.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=646} + %constant.1512.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.939.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1512.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %add.938.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%sqrt.62.clone.1, %add.939.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %multiply.426.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%div.934.clone.1, %add.938.clone.1), metadata={op_name="multiply.61"} + %div.931.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add.941.clone.1, %multiply.426.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=647} + %mul.1925.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_0.1545, %broadcast.844.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=649} + %add.937.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%div.931.clone.1, %mul.1925.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=650} + %mul.1924.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%mul.1926.clone.1, %add.937.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.936.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%param_0.1545, %mul.1924.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=654} + %square.214 = f32[151936,2048]{1,0:T(8,128)} multiply(%add.936.clone.1, %add.936.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=685} + %reduce.176 = f32[]{:T(128)} reduce(%square.214, %constant.1723), dimensions={0,1}, to_apply=%region_71.76, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686} + %reduce.178.clone.1 = f32[]{:T(128)} reduce(%integer_pow.65.clone.1, %constant.1723), dimensions={0,1}, to_apply=%region_56.61, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662} + ROOT %tuple.144 = (f32[]{:T(128)}, f32[151936,2048]{1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.176, %add.936.clone.1, %add.940.clone.1, %add.941.clone.1, %reduce.178.clone.1) +} + +%region_43.48 (reduce_sum.317: f32[], reduce_sum.318: f32[]) -> f32[] { + %reduce_sum.318 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.317 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.319 = f32[]{:T(128)} add(%reduce_sum.317, %reduce_sum.318), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.391.clone.clone (param_0.1532: f32[4,128], param_1.1975: bf16[4,128,2048], param_2.1861: bf16[2048]) -> bf16[4,128,2048] { + %param_2.1861 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %dot_general.480 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1861), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + %param_1.1975 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1457 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1975), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=344} + %param_0.1532 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2083 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1532), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=349} + %mul.2082 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1457, %mul.2083), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=349} + %convert_element_type.1456 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2082), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=350} + ROOT %dot_general.479 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.480, %convert_element_type.1456), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} +} + +%fused_computation.301.clone.clone.clone (param_0.1533: bf16[4,128,151936], param_1.1976: s32[4,128], param_2.1862: f32[4,128], param_3.1388: f32[4,128], param_4.972: bf16[4,128], param_5.770: f32[4,128]) -> bf16[4,128,151936] { + %param_5.770 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %mul.2087 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_5.770), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + %param_3.1388 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.2086 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.1388), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + %param_0.1533 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1460 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1533), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=359} + %param_4.972 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.94 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_4.972), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=370} + %sub.93 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1460, %sub.94), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=370} + %exp.62 = f32[4,128,151936]{2,1,0:T(8,128)} exponential(%sub.93), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=371} + %mul.2085 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2086, %exp.62), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + %param_2.1862 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.1044 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1862), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=80} + %div.1043 = f32[4,128,151936]{2,1,0:T(8,128)} divide(%mul.2085, %div.1044), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=80} + %param_1.1976 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.711 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1976), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=375} + %eq.710 = s32[4,128,151936]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=375} + %eq.709 = pred[4,128,151936]{2,1,0:T(8,128)(4,1)} compare(%eq.711, %eq.710), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=375} + %convert_element_type.1459 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%eq.709), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=375} + %sub.92 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%div.1043, %convert_element_type.1459), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=80} + %mul.2084 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2087, %sub.92), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + ROOT %convert_element_type.1458 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convert(%mul.2084), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=359} +} + +%fused_computation.284 (param_0.1558: bf16[151936,2048], param_1.1995: f32[4,128], param_2.1898: bf16[4,128,2048], param_3.1408: bf16[2048], param_4.995: bf16[4,128,151936], param_5.805: s32[4,128], param_6.639: f32[4,128], param_7.578: f32[4,128], param_8.363: bf16[4,128], param_9.237: f32[4,128]) -> (f32[], bf16[151936,2048,1]) { + %param_4.995 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(4) + %param_5.805 = s32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %param_6.639 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %param_7.578 = f32[4,128]{1,0:T(4,128)S(1)} parameter(7) + %param_8.363 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(8) + %param_9.237 = f32[4,128]{1,0:T(4,128)S(1)} parameter(9) + %multiply_convert_fusion.1.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} fusion(%param_4.995, %param_5.805, %param_6.639, %param_7.578, %param_8.363, /*index=5*/%param_9.237), kind=kLoop, calls=%fused_computation.301.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=359} + %param_1.1995 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1898 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %param_3.1408 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %fusion.269.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1995, %param_2.1898, %param_3.1408), kind=kLoop, calls=%fused_computation.391.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + %convolution.86.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} convolution(%multiply_convert_fusion.1.clone.1, %fusion.269.clone.1), window={size=4}, dim_labels=0fb_0io->bf0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=357} + %bitcast.333 = bf16[151936,2048]{1,0:T(8,128)(2,1)} bitcast(%convolution.86.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=357} + %convert_element_type.1323 = f32[151936,2048]{1,0:T(8,128)} convert(%bitcast.333), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=601} + %param_0.1558 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1322 = f32[151936,2048]{1,0:T(8,128)} convert(%param_0.1558), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=93} + %add_any.184 = f32[151936,2048]{1,0:T(8,128)} add(%convert_element_type.1323, %convert_element_type.1322), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add_any" stack_frame_id=93} + %square.215 = f32[151936,2048]{1,0:T(8,128)} multiply(%add_any.184, %add_any.184), metadata={op_name="jit(train_step)/square" stack_frame_id=385} + %constant.1736 = f32[]{:T(128)} constant(0) + %reduce.177 = f32[]{:T(128)} reduce(%square.215, %constant.1736), dimensions={0,1}, to_apply=%region_43.48, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387} + ROOT %tuple.166 = (f32[]{:T(128)}, bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.177, %convolution.86.clone.1) +} + +%region_57.62 (reduce_sum.389: f32[], reduce_sum.393: f32[]) -> f32[] { + %reduce_sum.393 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.389 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.394 = f32[]{:T(128)} add(%reduce_sum.389, %reduce_sum.393), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=668}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.300 (param_0.1569: bf16[4,128,151936], param_1.2003: f32[4,128], param_2.1901: s32[4,128], param_3.1410: bf16[4,128]) -> f32[4,128] { + %param_2.1901 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %eq.228 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1901), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=375} + %eq.213 = s32[4,128,151936]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=375} + %eq.212 = pred[4,128,151936]{2,1,0:T(8,128)(4,1)} compare(%eq.228, %eq.213), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=375} + %param_0.1569 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1328 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1569), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=359} + %param_3.1410 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) + %sub.73 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.1410), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=370} + %sub.64 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1328, %sub.73), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=370} + %param_1.2003 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %sub.71 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.2003), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=666} + %sub.60 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%sub.64, %sub.71), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=666} + %constant.1748 = f32[]{:T(128)} constant(0) + %broadcast.769 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%constant.1748), dimensions={}, metadata={op_name="broadcast.109"} + %mul.1765 = f32[4,128,151936]{2,1,0:T(8,128)} select(%eq.212, %sub.60, %broadcast.769), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=667} + ROOT %reduce.179 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1765, %constant.1748), dimensions={2}, to_apply=%region_57.62, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=668} +} + +%region_9.12 (reduce_sum.186: f32[], reduce_sum.190: f32[]) -> f32[] { + %reduce_sum.190 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.186 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.191 = f32[]{:T(128)} add(%reduce_sum.186, %reduce_sum.190), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=372}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.305 (param_0.1570: bf16[4,128,151936], param_1.2004: bf16[4,128]) -> f32[4,128] { + %param_0.1570 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1334 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1570), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=359} + %param_1.2004 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) + %sub.74 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.2004), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=370} + %sub.70 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1334, %sub.74), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=370} + %exp.54 = f32[4,128,151936]{2,1,0:T(8,128)} exponential(%sub.70), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=371} + %constant.1749 = f32[]{:T(128)} constant(0) + ROOT %reduce.180 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.54, %constant.1749), dimensions={2}, to_apply=%region_9.12, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=372} +} + +%region_33.38 (reduce_sum.269: f32[], reduce_sum.270: f32[]) -> f32[] { + %reduce_sum.270 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.269 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.274 = f32[]{:T(128)} add(%reduce_sum.269, %reduce_sum.270), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.310 (param_0.1564: f32[4,6144,2048]) -> f32[] { + %param_0.1564 = f32[4,6144,2048]{2,0,1:T(4,128)} parameter(0) + %bitcast.341 = f32[6144,4,2048]{2,1,0:T(4,128)} bitcast(%param_0.1564), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.218 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%bitcast.341, %bitcast.341), metadata={op_name="jit(train_step)/square" stack_frame_id=385} + %constant.1742 = f32[]{:T(128)} constant(0) + ROOT %reduce.181 = f32[]{:T(128)} reduce(%square.218, %constant.1742), dimensions={0,1,2}, to_apply=%region_33.38, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387} +} + +%region_32.37 (reduce_sum.263: f32[], reduce_sum.267: f32[]) -> f32[] { + %reduce_sum.267 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.263 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.268 = f32[]{:T(128)} add(%reduce_sum.263, %reduce_sum.267), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_31.36 (reduce_sum.260: f32[], reduce_sum.261: f32[]) -> f32[] { + %reduce_sum.261 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.260 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.262 = f32[]{:T(128)} add(%reduce_sum.260, %reduce_sum.261), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.312 (param_0.1565: f32[4,2048,6144], param_1.1999: f32[4,2048,6144]) -> (f32[], f32[]) { + %param_0.1565 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(0) + %bitcast.345 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_0.1565), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.221 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%bitcast.345, %bitcast.345), metadata={op_name="jit(train_step)/square" stack_frame_id=385} + %constant.1743 = f32[]{:T(128)} constant(0) + %reduce.182 = f32[]{:T(128)} reduce(%square.221, %constant.1743), dimensions={0,1,2}, to_apply=%region_32.37, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387} + %param_1.1999 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(1) + %bitcast.349.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_1.1999), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.224.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%bitcast.349.clone.1, %bitcast.349.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=385} + %reduce.183.clone.1 = f32[]{:T(128)} reduce(%square.224.clone.1, %constant.1743), dimensions={0,1,2}, to_apply=%region_31.36, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387} + ROOT %tuple.167 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.182, %reduce.183.clone.1) +} + +%fused_computation.315 (param_0.939: f32[6144,4,2048]) -> bf16[4,6144,2048] { + %param_0.939 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(0) + %copy.190 = bf16[6144,4,2048]{2,0,1:T(8,128)(2,1)} copy(%param_0.939), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wo\'][\'kernel\']"} + ROOT %bitcast.350 = bf16[4,6144,2048]{2,1,0:T(8,128)(2,1)} bitcast(%copy.190), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} +} + +%fused_computation.316 (param_0.941: f32[2048,4,6144]) -> bf16[4,2048,6144] { + %param_0.941 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) + %copy.191 = bf16[2048,4,6144]{2,0,1:T(8,128)(2,1)} copy(%param_0.941), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_1\'][\'kernel\']"} + ROOT %bitcast.351 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} bitcast(%copy.191), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} +} + +%fused_computation.317 (param_0.943: f32[2048,4,6144]) -> bf16[4,2048,6144] { + %param_0.943 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) + %copy.192 = bf16[2048,4,6144]{2,0,1:T(8,128)(2,1)} copy(%param_0.943), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_0\'][\'kernel\']"} + ROOT %bitcast.352 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} bitcast(%copy.192), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} +} + +%region_62.67 (reduce_sum.416: f32[], reduce_sum.417: f32[]) -> f32[] { + %reduce_sum.417 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.416 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.421 = f32[]{:T(128)} add(%reduce_sum.416, %reduce_sum.417), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_47.52 (reduce_sum.338: f32[], reduce_sum.339: f32[]) -> f32[] { + %reduce_sum.339 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.338 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.340 = f32[]{:T(128)} add(%reduce_sum.338, %reduce_sum.339), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.318 (param_0.1554: f32[6144,4,2048], param_1.1991: f32[], param_2.1894: f32[], param_3.1404: f32[], param_4.991: f32[6144,4,2048], param_5.801: f32[], param_6.635: f32[4,6144,2048], param_7.574: pred[], param_8.359: f32[6144,4,2048]) -> (f32[], f32[6144,4,2048], f32[6144,4,2048], f32[6144,4,2048], f32[]) { + %param_0.1554 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(0) + %param_3.1404 = f32[]{:T(128)S(6)} parameter(3) + %mul.1998.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_3.1404), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.574 = pred[]{:T(512)S(6)} parameter(7) + %select_n.946.clone.1 = pred[6144,4,2048]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.574), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %param_6.635 = f32[4,6144,2048]{2,0,1:T(4,128)} parameter(6) + %bitcast.482.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} bitcast(%param_6.635), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.801 = f32[]{:T(128)} parameter(5) + %div.1010.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_5.801), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %div.1009.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%bitcast.482.clone.1, %div.1010.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %select_n.945.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} select(%select_n.946.clone.1, %bitcast.482.clone.1, %div.1009.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %ne.216.clone.1 = pred[6144,4,2048]{2,1,0:T(4,128)(4,1)} compare(%select_n.945.clone.1, %select_n.945.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=621} + %constant.1732 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.738.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1732), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.944.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} select(%ne.216.clone.1, %broadcast_in_dim.738.clone.1, %select_n.945.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1644.clone.1 = f32[]{:T(128)} constant(inf) + %eq.675.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1644.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.674.clone.1 = pred[6144,4,2048]{2,1,0:T(4,128)(4,1)} compare(%select_n.944.clone.1, %eq.675.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1643.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.737.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1643.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.943.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} select(%eq.674.clone.1, %broadcast_in_dim.737.clone.1, %select_n.944.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1642.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.673.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1642.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.672.clone.1 = pred[6144,4,2048]{2,1,0:T(4,128)(4,1)} compare(%select_n.943.clone.1, %eq.673.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1641.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.736.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1641.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.942.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} select(%eq.672.clone.1, %broadcast_in_dim.736.clone.1, %select_n.943.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1637.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.886.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1637.clone.1), dimensions={}, metadata={op_name="broadcast.83"} + %mul.2004.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%select_n.942.clone.1, %broadcast.886.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=624} + %param_8.359 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(8) + %constant.1645.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.2005.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1645.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %mul.2003.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_8.359, %mul.2005.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %add.989.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%mul.2004.clone.1, %mul.2003.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=624} + %param_2.1894 = f32[]{:T(128)S(6)} parameter(2) + %div.1006.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_2.1894), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=635} + %integer_pow.74.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%select_n.942.clone.1, %select_n.942.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=639} + %constant.1640.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.2002.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1640.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %mul.2000.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%integer_pow.74.clone.1, %mul.2002.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %param_4.991 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(4) + %constant.1639.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.2001.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1639.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %mul.1999.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_4.991, %mul.2001.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %add.988.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%mul.2000.clone.1, %mul.1999.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=642} + %param_1.1991 = f32[]{:T(128)S(6)} parameter(1) + %div.1005.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_1.1991), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %div.1004.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%add.988.clone.1, %div.1005.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %sqrt.71.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} sqrt(%div.1004.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=646} + %constant.1638.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.987.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1638.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %add.986.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%sqrt.71.clone.1, %add.987.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %multiply.435.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%div.1006.clone.1, %add.986.clone.1), metadata={op_name="multiply.52"} + %div.1003.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%add.989.clone.1, %multiply.435.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=647} + %mul.1997.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_0.1554, %broadcast.886.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=649} + %add.985.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%div.1003.clone.1, %mul.1997.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=650} + %mul.1996.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%mul.1998.clone.1, %add.985.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.984.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%param_0.1554, %mul.1996.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=654} + %square.225 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%add.984.clone.1, %add.984.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=685} + %reduce.184 = f32[]{:T(128)} reduce(%square.225, %constant.1732), dimensions={0,1,2}, to_apply=%region_62.67, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686} + %reduce.187.clone.1 = f32[]{:T(128)} reduce(%integer_pow.74.clone.1, %constant.1732), dimensions={0,1,2}, to_apply=%region_47.52, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662} + ROOT %tuple.145 = (f32[]{:T(128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.184, %add.984.clone.1, %add.988.clone.1, %add.989.clone.1, %reduce.187.clone.1) +} + +%region_61.66 (reduce_sum.410: f32[], reduce_sum.414: f32[]) -> f32[] { + %reduce_sum.414 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.410 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.415 = f32[]{:T(128)} add(%reduce_sum.410, %reduce_sum.414), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_46.51 (reduce_sum.332: f32[], reduce_sum.333: f32[]) -> f32[] { + %reduce_sum.333 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.332 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.337 = f32[]{:T(128)} add(%reduce_sum.332, %reduce_sum.333), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.319 (param_0.1555: f32[2048,4,6144], param_1.1992: f32[], param_2.1895: f32[], param_3.1405: f32[], param_4.992: f32[2048,4,6144], param_5.802: f32[], param_6.636: f32[4,2048,6144], param_7.575: pred[], param_8.360: f32[2048,4,6144]) -> (f32[], f32[2048,4,6144], f32[2048,4,6144], f32[2048,4,6144], f32[]) { + %param_0.1555 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) + %param_3.1405 = f32[]{:T(128)S(6)} parameter(3) + %mul.2008.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_3.1405), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.575 = pred[]{:T(512)S(6)} parameter(7) + %select_n.956.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.575), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %param_6.636 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(6) + %bitcast.484.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_6.636), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.802 = f32[]{:T(128)} parameter(5) + %div.1018.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_5.802), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %div.1017.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%bitcast.484.clone.1, %div.1018.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %select_n.955.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%select_n.956.clone.1, %bitcast.484.clone.1, %div.1017.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %ne.218.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} compare(%select_n.955.clone.1, %select_n.955.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=621} + %constant.1733 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.744.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1733), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.954.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%ne.218.clone.1, %broadcast_in_dim.744.clone.1, %select_n.955.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1658.clone.1 = f32[]{:T(128)} constant(inf) + %eq.683.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1658.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.682.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} compare(%select_n.954.clone.1, %eq.683.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1657.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.743.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1657.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.953.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%eq.682.clone.1, %broadcast_in_dim.743.clone.1, %select_n.954.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1656.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.681.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1656.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.680.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} compare(%select_n.953.clone.1, %eq.681.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1655.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.742.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1655.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.952.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%eq.680.clone.1, %broadcast_in_dim.742.clone.1, %select_n.953.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1651.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.892.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1651.clone.1), dimensions={}, metadata={op_name="broadcast.85"} + %mul.2012.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.952.clone.1, %broadcast.892.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=624} + %param_8.360 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(8) + %constant.1659.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.891.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1659.clone.1), dimensions={}, metadata={op_name="broadcast.84"} + %mul.2011.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_8.360, %broadcast.891.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %add.994.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2012.clone.1, %mul.2011.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=624} + %param_2.1895 = f32[]{:T(128)S(6)} parameter(2) + %div.1014.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_2.1895), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=635} + %integer_pow.75.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.952.clone.1, %select_n.952.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=639} + %constant.1654.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.890.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1654.clone.1), dimensions={}, metadata={op_name="broadcast.73"} + %mul.2010.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%integer_pow.75.clone.1, %broadcast.890.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %param_4.992 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(4) + %constant.1653.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.889.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1653.clone.1), dimensions={}, metadata={op_name="broadcast.72"} + %mul.2009.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_4.992, %broadcast.889.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %add.993.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2010.clone.1, %mul.2009.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=642} + %param_1.1992 = f32[]{:T(128)S(6)} parameter(1) + %div.1013.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_1.1992), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %div.1012.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.993.clone.1, %div.1013.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %sqrt.72.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} sqrt(%div.1012.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=646} + %constant.1652.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.887.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1652.clone.1), dimensions={}, metadata={op_name="broadcast.65"} + %add.992.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%sqrt.72.clone.1, %broadcast.887.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %multiply.436.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%div.1014.clone.1, %add.992.clone.1), metadata={op_name="multiply.51"} + %div.1011.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.994.clone.1, %multiply.436.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=647} + %mul.2007.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_0.1555, %broadcast.892.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=649} + %add.991.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%div.1011.clone.1, %mul.2007.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=650} + %mul.2006.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%mul.2008.clone.1, %add.991.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.990.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%param_0.1555, %mul.2006.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=654} + %square.226 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%add.990.clone.1, %add.990.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=685} + %reduce.185 = f32[]{:T(128)} reduce(%square.226, %constant.1733), dimensions={0,1,2}, to_apply=%region_61.66, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686} + %reduce.188.clone.1 = f32[]{:T(128)} reduce(%integer_pow.75.clone.1, %constant.1733), dimensions={0,1,2}, to_apply=%region_46.51, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662} + ROOT %tuple.146 = (f32[]{:T(128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.185, %add.990.clone.1, %add.993.clone.1, %add.994.clone.1, %reduce.188.clone.1) +} + +%region_60.65 (reduce_sum.407: f32[], reduce_sum.408: f32[]) -> f32[] { + %reduce_sum.408 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.407 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.409 = f32[]{:T(128)} add(%reduce_sum.407, %reduce_sum.408), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_45.50 (reduce_sum.326: f32[], reduce_sum.330: f32[]) -> f32[] { + %reduce_sum.330 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.326 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.331 = f32[]{:T(128)} add(%reduce_sum.326, %reduce_sum.330), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.320 (param_0.1556: f32[2048,4,6144], param_1.1993: f32[], param_2.1896: f32[], param_3.1406: f32[], param_4.993: f32[2048,4,6144], param_5.803: f32[], param_6.637: f32[4,2048,6144], param_7.576: pred[], param_8.361: f32[2048,4,6144]) -> (f32[], f32[2048,4,6144], f32[2048,4,6144], f32[2048,4,6144], f32[]) { + %param_0.1556 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) + %param_3.1406 = f32[]{:T(128)S(6)} parameter(3) + %mul.2015.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_3.1406), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.576 = pred[]{:T(512)S(6)} parameter(7) + %select_n.966.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.576), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %param_6.637 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(6) + %bitcast.486.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_6.637), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.803 = f32[]{:T(128)} parameter(5) + %div.1026.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_5.803), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %div.1025.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%bitcast.486.clone.1, %div.1026.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %select_n.965.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%select_n.966.clone.1, %bitcast.486.clone.1, %div.1025.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %ne.220.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} compare(%select_n.965.clone.1, %select_n.965.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=621} + %constant.1734 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.750.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1734), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.964.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%ne.220.clone.1, %broadcast_in_dim.750.clone.1, %select_n.965.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1672.clone.1 = f32[]{:T(128)} constant(inf) + %eq.691.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1672.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.690.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} compare(%select_n.964.clone.1, %eq.691.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1671.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.749.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1671.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.963.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%eq.690.clone.1, %broadcast_in_dim.749.clone.1, %select_n.964.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1670.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.689.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1670.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.688.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} compare(%select_n.963.clone.1, %eq.689.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1669.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.748.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1669.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.962.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%eq.688.clone.1, %broadcast_in_dim.748.clone.1, %select_n.963.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1665.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.898.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1665.clone.1), dimensions={}, metadata={op_name="broadcast.85"} + %mul.2019.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.962.clone.1, %broadcast.898.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=624} + %param_8.361 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(8) + %constant.1673.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.897.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1673.clone.1), dimensions={}, metadata={op_name="broadcast.84"} + %mul.2018.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_8.361, %broadcast.897.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %add.999.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2019.clone.1, %mul.2018.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=624} + %param_2.1896 = f32[]{:T(128)S(6)} parameter(2) + %div.1022.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_2.1896), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=635} + %integer_pow.76.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.962.clone.1, %select_n.962.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=639} + %constant.1668.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.896.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1668.clone.1), dimensions={}, metadata={op_name="broadcast.73"} + %mul.2017.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%integer_pow.76.clone.1, %broadcast.896.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %param_4.993 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(4) + %constant.1667.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.895.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1667.clone.1), dimensions={}, metadata={op_name="broadcast.72"} + %mul.2016.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_4.993, %broadcast.895.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %add.998.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2017.clone.1, %mul.2016.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=642} + %param_1.1993 = f32[]{:T(128)S(6)} parameter(1) + %div.1021.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_1.1993), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %div.1020.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.998.clone.1, %div.1021.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %sqrt.73.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} sqrt(%div.1020.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=646} + %constant.1666.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.893.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1666.clone.1), dimensions={}, metadata={op_name="broadcast.65"} + %add.997.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%sqrt.73.clone.1, %broadcast.893.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %multiply.437.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%div.1022.clone.1, %add.997.clone.1), metadata={op_name="multiply.50"} + %div.1019.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.999.clone.1, %multiply.437.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=647} + %mul.2014.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_0.1556, %broadcast.898.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=649} + %add.996.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%div.1019.clone.1, %mul.2014.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=650} + %mul.2013.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%mul.2015.clone.1, %add.996.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.995.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%param_0.1556, %mul.2013.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=654} + %square.227 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%add.995.clone.1, %add.995.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=685} + %reduce.186 = f32[]{:T(128)} reduce(%square.227, %constant.1734), dimensions={0,1,2}, to_apply=%region_60.65, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686} + %reduce.189.clone.1 = f32[]{:T(128)} reduce(%integer_pow.76.clone.1, %constant.1734), dimensions={0,1,2}, to_apply=%region_45.50, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662} + ROOT %tuple.147 = (f32[]{:T(128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.186, %add.995.clone.1, %add.998.clone.1, %add.999.clone.1, %reduce.189.clone.1) +} + +%region_39.44 (reduce_sum.302: f32[], reduce_sum.303: f32[]) -> f32[] { + %reduce_sum.303 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.302 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.304 = f32[]{:T(128)} add(%reduce_sum.302, %reduce_sum.303), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.348 (param_0.1559: f32[4,2048,16,128]) -> f32[] { + %param_0.1559 = f32[4,2048,16,128]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.356 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1559), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.230 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%bitcast.356, %bitcast.356), metadata={op_name="jit(train_step)/square" stack_frame_id=385} + %constant.1737 = f32[]{:T(128)} constant(0) + ROOT %reduce.190 = f32[]{:T(128)} reduce(%square.230, %constant.1737), dimensions={0,1,2,3}, to_apply=%region_39.44, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387} +} + +%region_38.43 (reduce_sum.296: f32[], reduce_sum.297: f32[]) -> f32[] { + %reduce_sum.297 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.296 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.298 = f32[]{:T(128)} add(%reduce_sum.296, %reduce_sum.297), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.350 (param_0.1560: f32[4,16,128,2048]) -> f32[] { + %param_0.1560 = f32[4,16,128,2048]{3,2,0,1:T(8,128)S(1)} parameter(0) + %bitcast.360 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} bitcast(%param_0.1560), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.233 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%bitcast.360, %bitcast.360), metadata={op_name="jit(train_step)/square" stack_frame_id=385} + %constant.1738 = f32[]{:T(128)} constant(0) + ROOT %reduce.191 = f32[]{:T(128)} reduce(%square.233, %constant.1738), dimensions={0,1,2,3}, to_apply=%region_38.43, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387} +} + +%fused_computation.351 (param_0.1033: f32[16,4,128,2048]) -> bf16[4,16,128,2048] { + %param_0.1033 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(0) + %copy.193 = bf16[16,4,128,2048]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.1033), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'out\'][\'kernel\']"} + ROOT %bitcast.361 = bf16[4,16,128,2048]{3,2,1,0:T(8,128)(2,1)} bitcast(%copy.193), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} +} + +%region_68.73 (reduce_sum.449: f32[], reduce_sum.450: f32[]) -> f32[] { + %reduce_sum.450 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.449 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.451 = f32[]{:T(128)} add(%reduce_sum.449, %reduce_sum.450), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_53.58 (reduce_sum.368: f32[], reduce_sum.372: f32[]) -> f32[] { + %reduce_sum.372 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.368 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.373 = f32[]{:T(128)} add(%reduce_sum.368, %reduce_sum.372), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.352 (param_0.1548: f32[2048,4,16,128], param_1.1985: f32[], param_2.1888: f32[], param_3.1398: f32[], param_4.985: f32[2048,4,16,128], param_5.795: f32[], param_6.629: f32[4,2048,16,128], param_7.568: pred[], param_8.353: f32[2048,4,16,128]) -> (f32[], f32[2048,4,16,128], f32[2048,4,16,128], f32[2048,4,16,128], f32[]) { + %param_0.1548 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.1398 = f32[]{:T(128)S(6)} parameter(3) + %mul.1950.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_3.1398), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.568 = pred[]{:T(512)S(6)} parameter(7) + %select_n.886.clone.1 = pred[2048,4,16,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.568), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %param_6.629 = f32[4,2048,16,128]{3,2,0,1:T(8,128)S(1)} parameter(6) + %bitcast.470.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} bitcast(%param_6.629), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.795 = f32[]{:T(128)} parameter(5) + %div.962.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_5.795), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %div.961.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%bitcast.470.clone.1, %div.962.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %select_n.885.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} select(%select_n.886.clone.1, %bitcast.470.clone.1, %div.961.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %ne.204.clone.1 = pred[2048,4,16,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.885.clone.1, %select_n.885.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=621} + %constant.1726 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.702.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1726), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.884.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} select(%ne.204.clone.1, %broadcast_in_dim.702.clone.1, %select_n.885.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1560.clone.1 = f32[]{:T(128)} constant(inf) + %eq.627.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1560.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.626.clone.1 = pred[2048,4,16,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.884.clone.1, %eq.627.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1559.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.701.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1559.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.883.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} select(%eq.626.clone.1, %broadcast_in_dim.701.clone.1, %select_n.884.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1558.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.625.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1558.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.624.clone.1 = pred[2048,4,16,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.883.clone.1, %eq.625.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1557.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.700.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1557.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.882.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} select(%eq.624.clone.1, %broadcast_in_dim.700.clone.1, %select_n.883.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1553.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.858.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1553.clone.1), dimensions={}, metadata={op_name="broadcast.75"} + %mul.1956.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%select_n.882.clone.1, %broadcast.858.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=624} + %param_8.353 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} parameter(8) + %constant.1561.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1957.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1561.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %mul.1955.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_8.353, %mul.1957.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %add.957.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%mul.1956.clone.1, %mul.1955.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=624} + %param_2.1888 = f32[]{:T(128)S(6)} parameter(2) + %div.958.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1888), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=635} + %integer_pow.68.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%select_n.882.clone.1, %select_n.882.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=639} + %constant.1556.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1954.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1556.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %mul.1952.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.68.clone.1, %mul.1954.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %param_4.985 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.1555.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1953.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1555.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %mul.1951.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_4.985, %mul.1953.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %add.956.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%mul.1952.clone.1, %mul.1951.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=642} + %param_1.1985 = f32[]{:T(128)S(6)} parameter(1) + %div.957.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1985), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %div.956.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%add.956.clone.1, %div.957.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %sqrt.65.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} sqrt(%div.956.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=646} + %constant.1554.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.955.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1554.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %add.954.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%sqrt.65.clone.1, %add.955.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %multiply.429.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%div.958.clone.1, %add.954.clone.1), metadata={op_name="multiply.58"} + %div.955.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%add.957.clone.1, %multiply.429.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=647} + %mul.1949.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_0.1548, %broadcast.858.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=649} + %add.953.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%div.955.clone.1, %mul.1949.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=650} + %mul.1948.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%mul.1950.clone.1, %add.953.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.952.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%param_0.1548, %mul.1948.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=654} + %square.234 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%add.952.clone.1, %add.952.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=685} + %reduce.192 = f32[]{:T(128)} reduce(%square.234, %constant.1726), dimensions={0,1,2,3}, to_apply=%region_68.73, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686} + %reduce.194.clone.1 = f32[]{:T(128)} reduce(%integer_pow.68.clone.1, %constant.1726), dimensions={0,1,2,3}, to_apply=%region_53.58, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662} + ROOT %tuple.148 = (f32[]{:T(128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.192, %add.952.clone.1, %add.956.clone.1, %add.957.clone.1, %reduce.194.clone.1) +} + +%region_67.72 (reduce_sum.443: f32[], reduce_sum.444: f32[]) -> f32[] { + %reduce_sum.444 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.443 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.445 = f32[]{:T(128)} add(%reduce_sum.443, %reduce_sum.444), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_52.57 (reduce_sum.365: f32[], reduce_sum.366: f32[]) -> f32[] { + %reduce_sum.366 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.365 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.367 = f32[]{:T(128)} add(%reduce_sum.365, %reduce_sum.366), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.353 (param_0.1549: f32[16,4,128,2048], param_1.1986: f32[], param_2.1889: f32[], param_3.1399: f32[], param_4.986: f32[16,4,128,2048], param_5.796: f32[], param_6.630: f32[4,16,128,2048], param_7.569: pred[], param_8.354: f32[16,4,128,2048]) -> (f32[], f32[16,4,128,2048], f32[16,4,128,2048], f32[16,4,128,2048], f32[]) { + %param_0.1549 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(0) + %param_3.1399 = f32[]{:T(128)S(6)} parameter(3) + %mul.1960.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_3.1399), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.569 = pred[]{:T(512)S(6)} parameter(7) + %select_n.896.clone.1 = pred[16,4,128,2048]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.569), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %param_6.630 = f32[4,16,128,2048]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.472.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} bitcast(%param_6.630), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.796 = f32[]{:T(128)} parameter(5) + %div.970.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_5.796), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %div.969.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%bitcast.472.clone.1, %div.970.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %select_n.895.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} select(%select_n.896.clone.1, %bitcast.472.clone.1, %div.969.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %ne.206.clone.1 = pred[16,4,128,2048]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.895.clone.1, %select_n.895.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=621} + %constant.1727 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.708.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1727), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.894.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} select(%ne.206.clone.1, %broadcast_in_dim.708.clone.1, %select_n.895.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1574.clone.1 = f32[]{:T(128)} constant(inf) + %eq.635.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1574.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.634.clone.1 = pred[16,4,128,2048]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.894.clone.1, %eq.635.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1573.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.707.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1573.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.893.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} select(%eq.634.clone.1, %broadcast_in_dim.707.clone.1, %select_n.894.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1572.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.633.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1572.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.632.clone.1 = pred[16,4,128,2048]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.893.clone.1, %eq.633.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1571.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.706.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1571.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.892.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} select(%eq.632.clone.1, %broadcast_in_dim.706.clone.1, %select_n.893.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1567.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.860.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1567.clone.1), dimensions={}, metadata={op_name="broadcast.76"} + %mul.1966.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%select_n.892.clone.1, %broadcast.860.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=624} + %param_8.354 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(8) + %constant.1575.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1967.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1575.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %mul.1965.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_8.354, %mul.1967.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %add.963.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%mul.1966.clone.1, %mul.1965.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=624} + %param_2.1889 = f32[]{:T(128)S(6)} parameter(2) + %div.966.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_2.1889), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=635} + %integer_pow.69.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%select_n.892.clone.1, %select_n.892.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=639} + %constant.1570.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1964.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1570.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %mul.1962.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%integer_pow.69.clone.1, %mul.1964.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %param_4.986 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(4) + %constant.1569.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1963.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1569.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %mul.1961.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_4.986, %mul.1963.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %add.962.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%mul.1962.clone.1, %mul.1961.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=642} + %param_1.1986 = f32[]{:T(128)S(6)} parameter(1) + %div.965.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_1.1986), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %div.964.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%add.962.clone.1, %div.965.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %sqrt.66.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} sqrt(%div.964.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=646} + %constant.1568.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.961.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1568.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %add.960.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%sqrt.66.clone.1, %add.961.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %multiply.430.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%div.966.clone.1, %add.960.clone.1), metadata={op_name="multiply.57"} + %div.963.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%add.963.clone.1, %multiply.430.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=647} + %mul.1959.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_0.1549, %broadcast.860.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=649} + %add.959.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%div.963.clone.1, %mul.1959.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=650} + %mul.1958.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%mul.1960.clone.1, %add.959.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.958.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%param_0.1549, %mul.1958.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=654} + %square.235 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%add.958.clone.1, %add.958.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=685} + %reduce.193 = f32[]{:T(128)} reduce(%square.235, %constant.1727), dimensions={0,1,2,3}, to_apply=%region_67.72, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686} + %reduce.195.clone.1 = f32[]{:T(128)} reduce(%integer_pow.69.clone.1, %constant.1727), dimensions={0,1,2,3}, to_apply=%region_52.57, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662} + ROOT %tuple.149 = (f32[]{:T(128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.193, %add.958.clone.1, %add.962.clone.1, %add.963.clone.1, %reduce.195.clone.1) +} + +%region_41.46 (reduce_sum.311: f32[], reduce_sum.312: f32[]) -> f32[] { + %reduce_sum.312 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.311 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.316 = f32[]{:T(128)} add(%reduce_sum.311, %reduce_sum.312), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_36.41 (reduce_sum.284: f32[], reduce_sum.288: f32[]) -> f32[] { + %reduce_sum.288 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.284 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.289 = f32[]{:T(128)} add(%reduce_sum.284, %reduce_sum.288), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.365 (param_0.1562: f32[4,2048,8,128], param_1.1997: f32[4,2048,8,128]) -> (f32[], f32[]) { + %param_0.1562 = f32[4,2048,8,128]{3,2,0,1:T(8,128)S(1)} parameter(0) + %bitcast.365 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1562), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.238 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.365, %bitcast.365), metadata={op_name="jit(train_step)/square" stack_frame_id=385} + %constant.1740 = f32[]{:T(128)} constant(0) + %reduce.196 = f32[]{:T(128)} reduce(%square.238, %constant.1740), dimensions={0,1,2,3}, to_apply=%region_41.46, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387} + %param_1.1997 = f32[4,2048,8,128]{3,2,0,1:T(8,128)S(1)} parameter(1) + %bitcast.369.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_1.1997), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.241.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.369.clone.1, %bitcast.369.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=385} + %reduce.197.clone.1 = f32[]{:T(128)} reduce(%square.241.clone.1, %constant.1740), dimensions={0,1,2,3}, to_apply=%region_36.41, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387} + ROOT %tuple.168 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.196, %reduce.197.clone.1) +} + +%fused_computation.368 (param_0.1071: f32[2048,4,8,128]) -> bf16[4,2048,8,128] { + %param_0.1071 = f32[2048,4,8,128]{3,2,1,0:T(8,128)S(1)} parameter(0) + %copy.194 = bf16[2048,4,8,128]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.1071), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'value\'][\'kernel\']"} + ROOT %bitcast.370 = bf16[4,2048,8,128]{3,2,1,0:T(8,128)(2,1)} bitcast(%copy.194), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} +} + +%region_70.75 (reduce_sum.458: f32[], reduce_sum.459: f32[]) -> f32[] { + %reduce_sum.459 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.458 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.463 = f32[]{:T(128)} add(%reduce_sum.458, %reduce_sum.459), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_55.60 (reduce_sum.380: f32[], reduce_sum.381: f32[]) -> f32[] { + %reduce_sum.381 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.380 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.382 = f32[]{:T(128)} add(%reduce_sum.380, %reduce_sum.381), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.369 (param_0.1546: f32[2048,4,8,128], param_1.1983: f32[], param_2.1886: f32[], param_3.1396: f32[], param_4.983: f32[2048,4,8,128], param_5.793: f32[], param_6.627: f32[4,2048,8,128], param_7.566: pred[], param_8.351: f32[2048,4,8,128]) -> (f32[], f32[2048,4,8,128], f32[2048,4,8,128], f32[2048,4,8,128], f32[]) { + %param_0.1546 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.1396 = f32[]{:T(128)S(6)} parameter(3) + %mul.1936.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.1396), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.566 = pred[]{:T(512)S(6)} parameter(7) + %select_n.866.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.566), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %param_6.627 = f32[4,2048,8,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.466.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.627), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.793 = f32[]{:T(128)} parameter(5) + %div.946.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.793), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %div.945.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.466.clone.1, %div.946.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %select_n.865.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.866.clone.1, %bitcast.466.clone.1, %div.945.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %ne.200.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.865.clone.1, %select_n.865.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=621} + %constant.1724 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.690.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1724), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.864.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%ne.200.clone.1, %broadcast_in_dim.690.clone.1, %select_n.865.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1532.clone.1 = f32[]{:T(128)} constant(inf) + %eq.611.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1532.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.610.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.864.clone.1, %eq.611.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1531.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.689.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1531.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.863.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%eq.610.clone.1, %broadcast_in_dim.689.clone.1, %select_n.864.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1530.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.609.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1530.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.608.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.863.clone.1, %eq.609.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1529.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.688.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1529.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.862.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%eq.608.clone.1, %broadcast_in_dim.688.clone.1, %select_n.863.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1525.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.850.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1525.clone.1), dimensions={}, metadata={op_name="broadcast.80"} + %mul.1940.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.862.clone.1, %broadcast.850.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=624} + %param_8.351 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(8) + %constant.1533.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.849.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1533.clone.1), dimensions={}, metadata={op_name="broadcast.79"} + %mul.1939.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.351, %broadcast.849.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %add.946.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1940.clone.1, %mul.1939.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=624} + %param_2.1886 = f32[]{:T(128)S(6)} parameter(2) + %div.942.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1886), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=635} + %integer_pow.66.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.862.clone.1, %select_n.862.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=639} + %constant.1528.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.848.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1528.clone.1), dimensions={}, metadata={op_name="broadcast.69"} + %mul.1938.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.66.clone.1, %broadcast.848.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %param_4.983 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.1527.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.847.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1527.clone.1), dimensions={}, metadata={op_name="broadcast.68"} + %mul.1937.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.983, %broadcast.847.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %add.945.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1938.clone.1, %mul.1937.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=642} + %param_1.1983 = f32[]{:T(128)S(6)} parameter(1) + %div.941.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1983), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %div.940.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.945.clone.1, %div.941.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %sqrt.63.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.940.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=646} + %constant.1526.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.845.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1526.clone.1), dimensions={}, metadata={op_name="broadcast.63"} + %add.944.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.63.clone.1, %broadcast.845.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %multiply.427.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.942.clone.1, %add.944.clone.1), metadata={op_name="multiply.60"} + %div.939.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.946.clone.1, %multiply.427.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=647} + %mul.1935.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1546, %broadcast.850.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=649} + %add.943.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%div.939.clone.1, %mul.1935.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=650} + %mul.1934.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1936.clone.1, %add.943.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.942.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1546, %mul.1934.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=654} + %square.242 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.942.clone.1, %add.942.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=685} + %reduce.198 = f32[]{:T(128)} reduce(%square.242, %constant.1724), dimensions={0,1,2,3}, to_apply=%region_70.75, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686} + %reduce.200.clone.1 = f32[]{:T(128)} reduce(%integer_pow.66.clone.1, %constant.1724), dimensions={0,1,2,3}, to_apply=%region_55.60, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662} + ROOT %tuple.150 = (f32[]{:T(128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.198, %add.942.clone.1, %add.945.clone.1, %add.946.clone.1, %reduce.200.clone.1) +} + +%region_65.70 (reduce_sum.431: f32[], reduce_sum.435: f32[]) -> f32[] { + %reduce_sum.435 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.431 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.436 = f32[]{:T(128)} add(%reduce_sum.431, %reduce_sum.435), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_50.55 (reduce_sum.353: f32[], reduce_sum.354: f32[]) -> f32[] { + %reduce_sum.354 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.353 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.358 = f32[]{:T(128)} add(%reduce_sum.353, %reduce_sum.354), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.370 (param_0.1551: f32[2048,4,8,128], param_1.1988: f32[], param_2.1891: f32[], param_3.1401: f32[], param_4.988: f32[2048,4,8,128], param_5.798: f32[], param_6.632: f32[4,2048,8,128], param_7.571: pred[], param_8.356: f32[2048,4,8,128]) -> (f32[], f32[2048,4,8,128], f32[2048,4,8,128], f32[2048,4,8,128], f32[]) { + %param_0.1551 = f32[2048,4,8,128]{3,2,1,0:T(8,128)S(1)} parameter(0) + %param_3.1401 = f32[]{:T(128)S(6)} parameter(3) + %mul.1977.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.1401), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.571 = pred[]{:T(512)S(6)} parameter(7) + %select_n.916.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.571), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %param_6.632 = f32[4,2048,8,128]{3,2,0,1:T(8,128)S(1)} parameter(6) + %bitcast.476.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.632), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.798 = f32[]{:T(128)} parameter(5) + %div.986.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.798), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %div.985.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.476.clone.1, %div.986.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %select_n.915.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.916.clone.1, %bitcast.476.clone.1, %div.985.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %ne.210.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.915.clone.1, %select_n.915.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=621} + %constant.1729 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.720.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1729), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.914.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%ne.210.clone.1, %broadcast_in_dim.720.clone.1, %select_n.915.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1602.clone.1 = f32[]{:T(128)} constant(inf) + %eq.651.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1602.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.650.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.914.clone.1, %eq.651.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1601.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.719.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1601.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.913.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%eq.650.clone.1, %broadcast_in_dim.719.clone.1, %select_n.914.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1600.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.649.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1600.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.648.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} compare(%select_n.913.clone.1, %eq.649.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1599.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.718.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1599.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.912.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%eq.648.clone.1, %broadcast_in_dim.718.clone.1, %select_n.913.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1595.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.872.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1595.clone.1), dimensions={}, metadata={op_name="broadcast.80"} + %mul.1981.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.912.clone.1, %broadcast.872.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=624} + %param_8.356 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(8) + %constant.1603.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.871.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1603.clone.1), dimensions={}, metadata={op_name="broadcast.79"} + %mul.1980.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.356, %broadcast.871.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %add.973.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1981.clone.1, %mul.1980.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=624} + %param_2.1891 = f32[]{:T(128)S(6)} parameter(2) + %div.982.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1891), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=635} + %integer_pow.71.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.912.clone.1, %select_n.912.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=639} + %constant.1598.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.870.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1598.clone.1), dimensions={}, metadata={op_name="broadcast.69"} + %mul.1979.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.71.clone.1, %broadcast.870.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %param_4.988 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.1597.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.869.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1597.clone.1), dimensions={}, metadata={op_name="broadcast.68"} + %mul.1978.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.988, %broadcast.869.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %add.972.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1979.clone.1, %mul.1978.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=642} + %param_1.1988 = f32[]{:T(128)S(6)} parameter(1) + %div.981.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1988), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %div.980.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.972.clone.1, %div.981.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %sqrt.68.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.980.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=646} + %constant.1596.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.867.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1596.clone.1), dimensions={}, metadata={op_name="broadcast.63"} + %add.971.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.68.clone.1, %broadcast.867.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %multiply.432.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.982.clone.1, %add.971.clone.1), metadata={op_name="multiply.55"} + %div.979.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.973.clone.1, %multiply.432.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=647} + %mul.1976.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1551, %broadcast.872.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=649} + %add.970.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%div.979.clone.1, %mul.1976.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=650} + %mul.1975.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1977.clone.1, %add.970.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.969.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)S(1)} add(%param_0.1551, %mul.1975.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=654} + %square.243 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.969.clone.1, %add.969.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=685} + %reduce.199 = f32[]{:T(128)} reduce(%square.243, %constant.1729), dimensions={0,1,2,3}, to_apply=%region_65.70, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686} + %reduce.201.clone.1 = f32[]{:T(128)} reduce(%integer_pow.71.clone.1, %constant.1729), dimensions={0,1,2,3}, to_apply=%region_50.55, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662} + ROOT %tuple.151 = (f32[]{:T(128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)S(1)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.199, %add.969.clone.1, %add.972.clone.1, %add.973.clone.1, %reduce.201.clone.1) +} + +%fused_computation.386 (param_0.1153: bf16[4,128,2048], param_1.1210: f32[4,128], param_2.925: f32[4,128], param_3.576: bf16[4,128,2048], param_4.403: bf16[2048]) -> bf16[4,128,2048] { + %param_3.576 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.403 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %dot_general.448 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_4.403), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + %dot_general.438 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%param_3.576, %dot_general.448), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + %convert_element_type.1351 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%dot_general.438), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=350} + %param_2.925 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %mul.1851 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_2.925), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=349} + %mul.1843 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1351, %mul.1851), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=349} + %param_0.1153 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1362 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1153), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=344} + %param_1.1210 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.1850 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_1.1210), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=345} + %mul.1849 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1362, %mul.1850), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=345} + %add_any.187 = f32[4,128,2048]{2,1,0:T(8,128)} add(%mul.1843, %mul.1849), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add_any" stack_frame_id=345} + ROOT %convert_element_type.1349 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%add_any.187), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=344} +} + +%region_7.10 (reduce_sum.171: f32[], reduce_sum.184: f32[]) -> f32[] { + %reduce_sum.184 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.171 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.185 = f32[]{:T(128)} add(%reduce_sum.171, %reduce_sum.184), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=346}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.387 (param_0.1571: bf16[4,128,2048]) -> f32[4,128] { + %param_0.1571 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1353 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1571), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=344} + %square.246 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1353, %convert_element_type.1353), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/square" stack_frame_id=345} + %constant.1750 = f32[]{:T(128)} constant(0) + ROOT %reduce.202 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.246, %constant.1750), dimensions={2}, to_apply=%region_7.10, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=346} +} + +%region_12.15 (reduce_sum.198: f32[], reduce_sum.199: f32[]) -> f32[] { + %reduce_sum.199 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.198 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.200 = f32[]{:T(128)} add(%reduce_sum.198, %reduce_sum.199), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=349}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.389 (param_0.1566: bf16[4,128,2048], param_1.2000: bf16[4,128,2048], param_2.1899: bf16[2048]) -> f32[4,128] { + %param_0.1566 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1360 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1566), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=344} + %param_1.2000 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %param_2.1899 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %dot_general.447 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1899), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + %dot_general.437 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%param_1.2000, %dot_general.447), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + %convert_element_type.1359 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%dot_general.437), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=350} + %mul.1847 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1360, %convert_element_type.1359), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=349} + %constant.1744 = f32[]{:T(128)} constant(0) + ROOT %reduce.203 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1847, %constant.1744), dimensions={2}, to_apply=%region_12.15, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=349} +} + +%region_10.13 (dot_general.190: bf16[], dot_general.191: bf16[]) -> bf16[] { + %dot_general.191 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general"} + %dot_general.190 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general"} + ROOT %add.419 = bf16[]{:T(256)} add(%dot_general.190, %dot_general.191), metadata={op_name="add.82"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.288.clone.clone (param_0.1528: bf16[151936,2048]) -> bf16[151936,2048,1] { + %param_0.1528 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + ROOT %bitcast.528 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} bitcast(%param_0.1528), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=93} +} + +%fused_computation.301.clone.1.clone.clone (param_0.1529: bf16[4,128,151936], param_1.1972: s32[4,128], param_2.1856: f32[4,128], param_3.1383: f32[4,128], param_4.968: bf16[4,128], param_5.766: f32[4,128]) -> bf16[4,128,151936] { + %param_5.766 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %mul.2075 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_5.766), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + %param_3.1383 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.2074 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.1383), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + %param_0.1529 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1450 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1529), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=359} + %param_4.968 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.88 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_4.968), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=370} + %sub.87 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1450, %sub.88), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=370} + %exp.60 = f32[4,128,151936]{2,1,0:T(8,128)} exponential(%sub.87), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=371} + %mul.2073 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2074, %exp.60), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + %param_2.1856 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.1040 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1856), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=80} + %div.1039 = f32[4,128,151936]{2,1,0:T(8,128)} divide(%mul.2073, %div.1040), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=80} + %param_1.1972 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.705 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1972), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=375} + %eq.704 = s32[4,128,151936]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=375} + %eq.703 = pred[4,128,151936]{2,1,0:T(8,128)(4,1)} compare(%eq.705, %eq.704), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=375} + %convert_element_type.1449 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%eq.703), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=375} + %sub.86 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%div.1039, %convert_element_type.1449), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=80} + %mul.2072 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2075, %sub.86), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + ROOT %convert_element_type.1448 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convert(%mul.2072), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=359} +} + +%fused_computation.390 (param_0.1527: f32[4,128], param_1.1971: bf16[4,128,2048], param_2.1857: bf16[151936,2048], param_3.1384: bf16[4,128,151936], param_4.969: s32[4,128], param_5.767: f32[4,128], param_6.608: f32[4,128], param_7.563: bf16[4,128], param_8.348: f32[4,128]) -> (bf16[2048], bf16[4,128,2048]) { + %param_3.1384 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(3) + %param_4.969 = s32[4,128]{1,0:T(4,128)S(1)} parameter(4) + %param_5.767 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %param_6.608 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %param_7.563 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(7) + %param_8.348 = f32[4,128]{1,0:T(4,128)S(1)} parameter(8) + %multiply_convert_fusion.2.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} fusion(%param_3.1384, %param_4.969, %param_5.767, %param_6.608, %param_7.563, /*index=5*/%param_8.348), kind=kLoop, calls=%fused_computation.301.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=359} + %param_2.1857 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(2) + %fusion.251.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} fusion(%param_2.1857), kind=kLoop, calls=%fused_computation.288.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=93} + %convolution.84.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} convolution(%multiply_convert_fusion.2.clone.1, %fusion.251.clone.1), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=357} + %param_1.1971 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1372 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1971), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=344} + %param_0.1527 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.1862 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1527), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=349} + %mul.1861 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1372, %mul.1862), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=349} + %convert_element_type.1371 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.1861), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=350} + %multiply.420 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convolution.84.clone.1, %convert_element_type.1371), metadata={op_name="multiply.362"} + %constant.1274 = bf16[]{:T(256)} constant(0) + %reduce.204 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} reduce(%multiply.420, %constant.1274), dimensions={0,1}, to_apply=%region_10.13, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + ROOT %tuple.165 = (bf16[2048]{0:T(1024)(128)(2,1)S(1)}, bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.204, %convolution.84.clone.1) +} + +%fused_computation.398 (param_0.1197: f32[64], param_1.1275: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { + %param_1.1275 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %div.756 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1275), dimensions={0,1}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=160} + %param_0.1197 = f32[64]{0:T(128)S(1)} parameter(0) + %div.754 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1197), dimensions={3}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=160} + %div.753 = f32[4,128,1,64]{3,1,0,2:T(8,128)} divide(%div.756, %div.754), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=160} + %sin.38 = f32[4,128,1,64]{3,1,0,2:T(8,128)} sine(%div.753), metadata={op_name="jit(train_step)/layers/sin" stack_frame_id=170} + %convert_element_type.1380 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.38), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=170} + %cos.41.clone.1 = f32[4,128,1,64]{3,1,0,2:T(8,128)} cosine(%div.753), metadata={op_name="jit(train_step)/layers/cos" stack_frame_id=168} + %convert_element_type.1379.clone.1 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.41.clone.1), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=168} + ROOT %tuple.158 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1380, %convert_element_type.1379.clone.1) +} + +%fused_computation.399 (param_0.1194: bf16[4,128,1,64]) -> bf16[4,128,1,128] { + %param_0.1194 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1266 = bf16[]{:T(256)} constant(-inf) + %pad.46 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1194, %constant.1266), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=169} + %pad.45 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1194, %constant.1266), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=169} + ROOT %maximum.42 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.46, %pad.45), metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=169} +} + +%fused_computation.400 (param_0.1196: bf16[4,128,1,64]) -> bf16[4,128,1,128] { + %param_0.1196 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1265 = bf16[]{:T(256)} constant(-inf) + %pad.48 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1196, %constant.1265), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=171} + %pad.47 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1196, %constant.1265), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=171} + ROOT %maximum.43 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.48, %pad.47), metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=171} +} + +%region_35.40 (reduce_sum.281: f32[], reduce_sum.282: f32[]) -> f32[] { + %reduce_sum.282 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.281 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.283 = f32[]{:T(128)} add(%reduce_sum.281, %reduce_sum.282), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_34.39 (reduce_sum.275: f32[], reduce_sum.276: f32[]) -> f32[] { + %reduce_sum.276 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.275 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.277 = f32[]{:T(128)} add(%reduce_sum.275, %reduce_sum.276), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.410 (param_0.1563: f32[4,2048], param_1.1998: f32[4,2048]) -> (f32[], f32[]) { + %param_0.1563 = f32[4,2048]{1,0:T(4,128)S(1)} parameter(0) + %bitcast.398 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_0.1563), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.249 = f32[2048,4]{0,1:T(4,128)} multiply(%bitcast.398, %bitcast.398), metadata={op_name="jit(train_step)/square" stack_frame_id=385} + %constant.1741 = f32[]{:T(128)} constant(0) + %reduce.205 = f32[]{:T(128)} reduce(%square.249, %constant.1741), dimensions={0,1}, to_apply=%region_35.40, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387} + %param_1.1998 = f32[4,2048]{1,0:T(4,128)} parameter(1) + %bitcast.402.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_1.1998), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.252.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%bitcast.402.clone.1, %bitcast.402.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=385} + %reduce.206.clone.1 = f32[]{:T(128)} reduce(%square.252.clone.1, %constant.1741), dimensions={0,1}, to_apply=%region_34.39, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387} + ROOT %tuple.169 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.205, %reduce.206.clone.1) +} + +%region_64.69 (reduce_sum.428: f32[], reduce_sum.429: f32[]) -> f32[] { + %reduce_sum.429 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.428 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.430 = f32[]{:T(128)} add(%reduce_sum.428, %reduce_sum.429), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_49.54 (reduce_sum.347: f32[], reduce_sum.351: f32[]) -> f32[] { + %reduce_sum.351 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.347 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.352 = f32[]{:T(128)} add(%reduce_sum.347, %reduce_sum.351), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.413 (param_0.1552: f32[2048,4], param_1.1989: f32[], param_2.1892: f32[], param_3.1402: f32[], param_4.989: f32[2048,4], param_5.799: f32[], param_6.633: f32[4,2048], param_7.572: pred[], param_8.357: f32[2048,4]) -> (f32[], f32[2048,4], f32[2048,4], f32[2048,4], f32[]) { + %param_0.1552 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.1402 = f32[]{:T(128)S(6)} parameter(3) + %mul.1984.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_3.1402), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.572 = pred[]{:T(512)S(6)} parameter(7) + %select_n.926.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.572), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %param_6.633 = f32[4,2048]{1,0:T(4,128)S(1)} parameter(6) + %bitcast.478.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_6.633), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.799 = f32[]{:T(128)} parameter(5) + %div.994.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_5.799), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %div.993.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%bitcast.478.clone.1, %div.994.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %select_n.925.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%select_n.926.clone.1, %bitcast.478.clone.1, %div.993.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %ne.212.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} compare(%select_n.925.clone.1, %select_n.925.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=621} + %constant.1730 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.726.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1730), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.924.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%ne.212.clone.1, %broadcast_in_dim.726.clone.1, %select_n.925.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1616.clone.1 = f32[]{:T(128)} constant(inf) + %eq.659.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1616.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.658.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} compare(%select_n.924.clone.1, %eq.659.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1615.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.725.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1615.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.923.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%eq.658.clone.1, %broadcast_in_dim.725.clone.1, %select_n.924.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1614.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.657.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1614.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.656.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} compare(%select_n.923.clone.1, %eq.657.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1613.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.724.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1613.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.922.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%eq.656.clone.1, %broadcast_in_dim.724.clone.1, %select_n.923.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1609.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.878.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1609.clone.1), dimensions={}, metadata={op_name="broadcast.82"} + %mul.1988.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.922.clone.1, %broadcast.878.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=624} + %param_8.357 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(8) + %constant.1617.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.877.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1617.clone.1), dimensions={}, metadata={op_name="broadcast.81"} + %mul.1987.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_8.357, %broadcast.877.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %add.978.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.1988.clone.1, %mul.1987.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=624} + %param_2.1892 = f32[]{:T(128)S(6)} parameter(2) + %div.990.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_2.1892), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=635} + %integer_pow.72.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.922.clone.1, %select_n.922.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=639} + %constant.1612.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.876.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1612.clone.1), dimensions={}, metadata={op_name="broadcast.71"} + %mul.1986.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%integer_pow.72.clone.1, %broadcast.876.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %param_4.989 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.1611.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.875.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1611.clone.1), dimensions={}, metadata={op_name="broadcast.70"} + %mul.1985.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_4.989, %broadcast.875.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %add.977.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.1986.clone.1, %mul.1985.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=642} + %param_1.1989 = f32[]{:T(128)S(6)} parameter(1) + %div.989.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_1.1989), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %div.988.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.977.clone.1, %div.989.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %sqrt.69.clone.1 = f32[2048,4]{0,1:T(4,128)} sqrt(%div.988.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=646} + %constant.1610.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.873.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1610.clone.1), dimensions={}, metadata={op_name="broadcast.64"} + %add.976.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%sqrt.69.clone.1, %broadcast.873.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %multiply.433.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%div.990.clone.1, %add.976.clone.1), metadata={op_name="multiply.54"} + %div.987.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.978.clone.1, %multiply.433.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=647} + %mul.1983.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_0.1552, %broadcast.878.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=649} + %add.975.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%div.987.clone.1, %mul.1983.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=650} + %mul.1982.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%mul.1984.clone.1, %add.975.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.974.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%param_0.1552, %mul.1982.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=654} + %square.253 = f32[2048,4]{0,1:T(4,128)} multiply(%add.974.clone.1, %add.974.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=685} + %reduce.207 = f32[]{:T(128)} reduce(%square.253, %constant.1730), dimensions={0,1}, to_apply=%region_64.69, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686} + %reduce.209.clone.1 = f32[]{:T(128)} reduce(%integer_pow.72.clone.1, %constant.1730), dimensions={0,1}, to_apply=%region_49.54, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662} + ROOT %tuple.152 = (f32[]{:T(128)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.207, %add.974.clone.1, %add.977.clone.1, %add.978.clone.1, %reduce.209.clone.1) +} + +%region_63.68 (reduce_sum.422: f32[], reduce_sum.423: f32[]) -> f32[] { + %reduce_sum.423 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.422 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.424 = f32[]{:T(128)} add(%reduce_sum.422, %reduce_sum.423), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_48.53 (reduce_sum.344: f32[], reduce_sum.345: f32[]) -> f32[] { + %reduce_sum.345 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.344 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.346 = f32[]{:T(128)} add(%reduce_sum.344, %reduce_sum.345), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.414 (param_0.1553: f32[2048,4], param_1.1990: f32[], param_2.1893: f32[], param_3.1403: f32[], param_4.990: f32[2048,4], param_5.800: f32[], param_6.634: f32[4,2048], param_7.573: pred[], param_8.358: f32[2048,4]) -> (f32[], f32[2048,4], f32[2048,4], f32[2048,4], f32[]) { + %param_0.1553 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.1403 = f32[]{:T(128)S(6)} parameter(3) + %mul.1991.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_3.1403), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.573 = pred[]{:T(512)S(6)} parameter(7) + %select_n.936.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.573), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %param_6.634 = f32[4,2048]{1,0:T(4,128)} parameter(6) + %bitcast.480.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_6.634), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.800 = f32[]{:T(128)} parameter(5) + %div.1002.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_5.800), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %div.1001.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%bitcast.480.clone.1, %div.1002.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %select_n.935.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%select_n.936.clone.1, %bitcast.480.clone.1, %div.1001.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %ne.214.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} compare(%select_n.935.clone.1, %select_n.935.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=621} + %constant.1731 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.732.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1731), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.934.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%ne.214.clone.1, %broadcast_in_dim.732.clone.1, %select_n.935.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1630.clone.1 = f32[]{:T(128)} constant(inf) + %eq.667.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1630.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.666.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} compare(%select_n.934.clone.1, %eq.667.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1629.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.731.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1629.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.933.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%eq.666.clone.1, %broadcast_in_dim.731.clone.1, %select_n.934.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1628.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.665.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1628.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.664.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} compare(%select_n.933.clone.1, %eq.665.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1627.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.730.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1627.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.932.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%eq.664.clone.1, %broadcast_in_dim.730.clone.1, %select_n.933.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1623.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.884.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1623.clone.1), dimensions={}, metadata={op_name="broadcast.82"} + %mul.1995.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.932.clone.1, %broadcast.884.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=624} + %param_8.358 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(8) + %constant.1631.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.883.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1631.clone.1), dimensions={}, metadata={op_name="broadcast.81"} + %mul.1994.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_8.358, %broadcast.883.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %add.983.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.1995.clone.1, %mul.1994.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=624} + %param_2.1893 = f32[]{:T(128)S(6)} parameter(2) + %div.998.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_2.1893), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=635} + %integer_pow.73.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.932.clone.1, %select_n.932.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=639} + %constant.1626.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.882.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1626.clone.1), dimensions={}, metadata={op_name="broadcast.71"} + %mul.1993.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%integer_pow.73.clone.1, %broadcast.882.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %param_4.990 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.1625.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.881.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1625.clone.1), dimensions={}, metadata={op_name="broadcast.70"} + %mul.1992.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_4.990, %broadcast.881.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %add.982.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.1993.clone.1, %mul.1992.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=642} + %param_1.1990 = f32[]{:T(128)S(6)} parameter(1) + %div.997.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_1.1990), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %div.996.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.982.clone.1, %div.997.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %sqrt.70.clone.1 = f32[2048,4]{0,1:T(4,128)} sqrt(%div.996.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=646} + %constant.1624.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.879.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1624.clone.1), dimensions={}, metadata={op_name="broadcast.64"} + %add.981.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%sqrt.70.clone.1, %broadcast.879.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %multiply.434.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%div.998.clone.1, %add.981.clone.1), metadata={op_name="multiply.53"} + %div.995.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.983.clone.1, %multiply.434.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=647} + %mul.1990.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_0.1553, %broadcast.884.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=649} + %add.980.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%div.995.clone.1, %mul.1990.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=650} + %mul.1989.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%mul.1991.clone.1, %add.980.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.979.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%param_0.1553, %mul.1989.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=654} + %square.254 = f32[2048,4]{0,1:T(4,128)} multiply(%add.979.clone.1, %add.979.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=685} + %reduce.208 = f32[]{:T(128)} reduce(%square.254, %constant.1731), dimensions={0,1}, to_apply=%region_63.68, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686} + %reduce.210.clone.1 = f32[]{:T(128)} reduce(%integer_pow.73.clone.1, %constant.1731), dimensions={0,1}, to_apply=%region_48.53, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662} + ROOT %tuple.153 = (f32[]{:T(128)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.208, %add.979.clone.1, %add.982.clone.1, %add.983.clone.1, %reduce.210.clone.1) +} + +%region_11.14 (reduce_sum.192: f32[], reduce_sum.193: f32[]) -> f32[] { + %reduce_sum.193 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.192 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.197 = f32[]{:T(128)} add(%reduce_sum.192, %reduce_sum.193), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.428 (param_0.1567: bf16[2048]) -> f32[] { + %param_0.1567 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(0) + %convert_element_type.1384 = f32[2048]{0:T(1024)} convert(%param_0.1567), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=88} + %square.257 = f32[2048]{0:T(1024)} multiply(%convert_element_type.1384, %convert_element_type.1384), metadata={op_name="jit(train_step)/square" stack_frame_id=385} + %constant.1745 = f32[]{:T(128)} constant(0) + ROOT %reduce.211 = f32[]{:T(128)} reduce(%square.257, %constant.1745), dimensions={0}, to_apply=%region_11.14, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387} +} + +%region_59.64 (reduce_sum.401: f32[], reduce_sum.402: f32[]) -> f32[] { + %reduce_sum.402 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.401 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.403 = f32[]{:T(128)} add(%reduce_sum.401, %reduce_sum.402), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_44.49 (reduce_sum.323: f32[], reduce_sum.324: f32[]) -> f32[] { + %reduce_sum.324 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.323 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.325 = f32[]{:T(128)} add(%reduce_sum.323, %reduce_sum.324), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.429 (param_0.1557: f32[2048], param_1.1994: f32[], param_2.1897: f32[], param_3.1407: f32[], param_4.994: f32[2048], param_5.804: f32[], param_6.638: bf16[2048], param_7.577: pred[], param_8.362: f32[2048]) -> (f32[], f32[2048], f32[2048], f32[2048], f32[]) { + %param_0.1557 = f32[2048]{0:T(1024)S(1)} parameter(0) + %param_3.1407 = f32[]{:T(128)S(6)} parameter(3) + %mul.2022.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_3.1407), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.577 = pred[]{:T(512)S(6)} parameter(7) + %select_n.976.clone.1 = pred[2048]{0:T(1024)(128)(4,1)} broadcast(%param_7.577), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %param_6.638 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(6) + %convert_element_type.1417.clone.1 = f32[2048]{0:T(1024)} convert(%param_6.638), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=88} + %param_5.804 = f32[]{:T(128)} parameter(5) + %div.1034.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_5.804), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %div.1033.clone.1 = f32[2048]{0:T(1024)} divide(%convert_element_type.1417.clone.1, %div.1034.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %select_n.975.clone.1 = f32[2048]{0:T(1024)} select(%select_n.976.clone.1, %convert_element_type.1417.clone.1, %div.1033.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %ne.222.clone.1 = pred[2048]{0:T(1024)(128)(4,1)} compare(%select_n.975.clone.1, %select_n.975.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=621} + %constant.1735 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.756.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1735), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.974.clone.1 = f32[2048]{0:T(1024)} select(%ne.222.clone.1, %broadcast_in_dim.756.clone.1, %select_n.975.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1686.clone.1 = f32[]{:T(128)} constant(inf) + %eq.699.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1686.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.697.clone.1 = pred[2048]{0:T(1024)(128)(4,1)} compare(%select_n.974.clone.1, %eq.699.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1685.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.755.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1685.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.973.clone.1 = f32[2048]{0:T(1024)} select(%eq.697.clone.1, %broadcast_in_dim.755.clone.1, %select_n.974.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1684.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.698.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1684.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.696.clone.1 = pred[2048]{0:T(1024)(128)(4,1)} compare(%select_n.973.clone.1, %eq.698.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1683.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.754.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1683.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.972.clone.1 = f32[2048]{0:T(1024)} select(%eq.696.clone.1, %broadcast_in_dim.754.clone.1, %select_n.973.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1679.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.900.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1679.clone.1), dimensions={}, metadata={op_name="broadcast.86"} + %mul.2028.clone.1 = f32[2048]{0:T(1024)} multiply(%select_n.972.clone.1, %broadcast.900.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=624} + %param_8.362 = f32[2048]{0:T(1024)S(1)} parameter(8) + %constant.1687.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.2029.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1687.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %mul.2027.clone.1 = f32[2048]{0:T(1024)} multiply(%param_8.362, %mul.2029.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %add.1005.clone.1 = f32[2048]{0:T(1024)S(1)} add(%mul.2028.clone.1, %mul.2027.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=624} + %param_2.1897 = f32[]{:T(128)S(6)} parameter(2) + %div.1030.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_2.1897), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=635} + %integer_pow.77.clone.1 = f32[2048]{0:T(1024)} multiply(%select_n.972.clone.1, %select_n.972.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=639} + %constant.1682.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.2026.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1682.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %mul.2024.clone.1 = f32[2048]{0:T(1024)} multiply(%integer_pow.77.clone.1, %mul.2026.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %param_4.994 = f32[2048]{0:T(1024)S(1)} parameter(4) + %constant.1681.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.2025.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1681.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %mul.2023.clone.1 = f32[2048]{0:T(1024)} multiply(%param_4.994, %mul.2025.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %add.1004.clone.1 = f32[2048]{0:T(1024)S(1)} add(%mul.2024.clone.1, %mul.2023.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=642} + %param_1.1994 = f32[]{:T(128)S(6)} parameter(1) + %div.1029.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_1.1994), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %div.1028.clone.1 = f32[2048]{0:T(1024)} divide(%add.1004.clone.1, %div.1029.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %sqrt.74.clone.1 = f32[2048]{0:T(1024)} sqrt(%div.1028.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=646} + %constant.1680.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.1003.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1680.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %add.1002.clone.1 = f32[2048]{0:T(1024)} add(%sqrt.74.clone.1, %add.1003.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %multiply.438.clone.1 = f32[2048]{0:T(1024)} multiply(%div.1030.clone.1, %add.1002.clone.1), metadata={op_name="multiply.49"} + %div.1027.clone.1 = f32[2048]{0:T(1024)} divide(%add.1005.clone.1, %multiply.438.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=647} + %mul.2021.clone.1 = f32[2048]{0:T(1024)} multiply(%param_0.1557, %broadcast.900.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=649} + %add.1001.clone.1 = f32[2048]{0:T(1024)} add(%div.1027.clone.1, %mul.2021.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=650} + %mul.2020.clone.1 = f32[2048]{0:T(1024)} multiply(%mul.2022.clone.1, %add.1001.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.1000.clone.1 = f32[2048]{0:T(1024)S(1)} add(%param_0.1557, %mul.2020.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=654} + %square.258 = f32[2048]{0:T(1024)} multiply(%add.1000.clone.1, %add.1000.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=685} + %reduce.212 = f32[]{:T(128)} reduce(%square.258, %constant.1735), dimensions={0}, to_apply=%region_59.64, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686} + %reduce.213.clone.1 = f32[]{:T(128)} reduce(%integer_pow.77.clone.1, %constant.1735), dimensions={0}, to_apply=%region_44.49, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662} + ROOT %tuple.156 = (f32[]{:T(128)}, f32[2048]{0:T(1024)S(1)}, f32[2048]{0:T(1024)S(1)}, f32[2048]{0:T(1024)S(1)}, f32[]{:T(128)}) tuple(%reduce.212, %add.1000.clone.1, %add.1004.clone.1, %add.1005.clone.1, %reduce.213.clone.1) +} + +%fused_computation.441 (param_0.1314: s32[512]) -> s32[1024] { + %constant.1104 = s32[] constant(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + %broadcast.815 = s32[1024]{0:T(1024)} broadcast(%constant.1104), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + %param_0.1314 = s32[512]{0:T(512)S(1)} parameter(0) + %constant.1105 = s32[] constant(2147483647), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + %pad.49 = s32[1024]{0:T(1024)} pad(%param_0.1314, %constant.1105), padding=0_512, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + %constant.1103 = s32[] constant(151935), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + %broadcast.814 = s32[1024]{0:T(1024)} broadcast(%constant.1103), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} + ROOT %clamp.1 = s32[1024]{0:T(1024)} clamp(%broadcast.815, %pad.49, %broadcast.814), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=94} +} + +%fused_computation.444 (param_0.1309: s32[4,128]) -> s32[512] { + %param_0.1309 = s32[4,128]{1,0:T(4,128)} parameter(0) + %constant.1289 = s32[]{:T(128)} constant(0) + %broadcast.834 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1289), dimensions={}, metadata={op_name="broadcast.95"} + %lt.32 = pred[4,128]{1,0:T(4,128)(4,1)} compare(%param_0.1309, %broadcast.834), direction=LT, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/lt" stack_frame_id=94} + %constant.1275 = s32[]{:T(128)} constant(151936) + %add.925 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1275), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=94} + %add.903 = s32[4,128]{1,0:T(4,128)} add(%param_0.1309, %add.925), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=94} + %select_n.584 = s32[4,128]{1,0:T(4,128)} select(%lt.32, %add.903, %param_0.1309), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/select_n" stack_frame_id=94} + ROOT %bitcast.403 = s32[512]{0:T(512)S(1)} bitcast(%select_n.584), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=94} +} + +%region_40.45 (reduce_sum.305: f32[], reduce_sum.309: f32[]) -> f32[] { + %reduce_sum.309 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.305 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.310 = f32[]{:T(128)} add(%reduce_sum.305, %reduce_sum.309), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_37.42 (reduce_sum.290: f32[], reduce_sum.291: f32[]) -> f32[] { + %reduce_sum.291 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.290 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.295 = f32[]{:T(128)} add(%reduce_sum.290, %reduce_sum.291), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.446 (param_0.1561: f32[4,128], param_1.1996: f32[4,128]) -> (f32[], f32[]) { + %param_0.1561 = f32[4,128]{1,0:T(4,128)} parameter(0) + %bitcast.407 = f32[128,4]{0,1:T(4,128)} bitcast(%param_0.1561), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.261 = f32[128,4]{0,1:T(4,128)} multiply(%bitcast.407, %bitcast.407), metadata={op_name="jit(train_step)/square" stack_frame_id=385} + %constant.1739 = f32[]{:T(128)} constant(0) + %reduce.214 = f32[]{:T(128)} reduce(%square.261, %constant.1739), dimensions={0,1}, to_apply=%region_40.45, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387} + %param_1.1996 = f32[4,128]{1,0:T(4,128)} parameter(1) + %bitcast.411.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_1.1996), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %square.264.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%bitcast.411.clone.1, %bitcast.411.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=385} + %reduce.215.clone.1 = f32[]{:T(128)} reduce(%square.264.clone.1, %constant.1739), dimensions={0,1}, to_apply=%region_37.42, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=387} + ROOT %tuple.170 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.214, %reduce.215.clone.1) +} + +%region_72.77 (reduce_sum.470: f32[], reduce_sum.471: f32[]) -> f32[] { + %reduce_sum.471 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.470 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.472 = f32[]{:T(128)} add(%reduce_sum.470, %reduce_sum.471), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=695}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_58.63 (reduce_sum.395: f32[], reduce_sum.396: f32[]) -> f32[] { + %reduce_sum.396 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.395 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.400 = f32[]{:T(128)} add(%reduce_sum.395, %reduce_sum.396), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=68}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.449 (param_0.1568: bf16[4,128], param_1.2002: f32[4,128], param_2.1900: f32[4,128], param_3.1409: s32[4,128]) -> (f32[], f32[], pred[4,128], f32[4,128]) { + %param_3.1409 = s32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %constant.1693.clone.1 = s32[]{:T(128)} constant(0) + %broadcast.901.clone.1 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1693.clone.1), dimensions={}, metadata={op_name="broadcast.95"} + %ne.223.clone.1 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} compare(%param_3.1409, %broadcast.901.clone.1), direction=NE, metadata={op_name="jit(train_step)/jvp()/ne" stack_frame_id=64} + %param_1.2002 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %log.16 = f32[4,128]{1,0:T(4,128)} log(%param_1.2002), metadata={op_name="jit(train_step)/jvp()/log" stack_frame_id=373} + %param_0.1568 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(0) + %reduce_max.15 = f32[4,128]{1,0:T(4,128)} convert(%param_0.1568), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=369} + %add.927 = f32[4,128]{1,0:T(4,128)} add(%log.16, %reduce_max.15), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=373} + %square.269 = f32[4,128]{1,0:T(4,128)} multiply(%add.927, %add.927), metadata={op_name="jit(train_step)/jvp()/square" stack_frame_id=670} + %constant.1747 = f32[]{:T(128)} constant(0) + %broadcast.831 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1747), dimensions={}, metadata={op_name="broadcast.99"} + %mul.1913 = f32[4,128]{1,0:T(4,128)} multiply(%square.269, %broadcast.831), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=671} + %mul.1893 = f32[4,128]{1,0:T(4,128)} select(%ne.223.clone.1, %mul.1913, %broadcast.831), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=694} + %reduce.216 = f32[]{:T(128)} reduce(%mul.1893, %constant.1747), dimensions={0,1}, to_apply=%region_72.77, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=695} + %param_2.1900 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %neg.115.clone.1 = f32[4,128]{1,0:T(4,128)} negate(%param_2.1900), metadata={op_name="jit(train_step)/jvp()/neg" stack_frame_id=669} + %add.904.clone.1 = f32[4,128]{1,0:T(4,128)} add(%neg.115.clone.1, %mul.1913), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=672} + %mul.1894.clone.1 = f32[4,128]{1,0:T(4,128)} select(%ne.223.clone.1, %add.904.clone.1, %broadcast.831), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=69} + %reduce.219.clone.1 = f32[]{:T(128)} reduce(%mul.1894.clone.1, %constant.1747), dimensions={0,1}, to_apply=%region_58.63, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=68} + %mul.1911.clone.1 = f32[4,128]{1,0:T(4,128)} multiply(%add.927, %broadcast.831), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=80} + %constant.1292.clone.1 = f32[]{:T(128)} constant(1) + %add.922.clone.1 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1292.clone.1), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=80} + %add.915.clone.1 = f32[4,128]{1,0:T(4,128)S(1)} add(%mul.1911.clone.1, %add.922.clone.1), metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=80} + ROOT %tuple.157 = (f32[]{:T(128)}, f32[]{:T(128)}, pred[4,128]{1,0:T(4,128)(4,1)S(1)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%reduce.216, %reduce.219.clone.1, %ne.223.clone.1, %add.915.clone.1) +} + +%region_69.74 (reduce_sum.452: f32[], reduce_sum.456: f32[]) -> f32[] { + %reduce_sum.456 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.452 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.457 = f32[]{:T(128)} add(%reduce_sum.452, %reduce_sum.456), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_54.59 (reduce_sum.374: f32[], reduce_sum.375: f32[]) -> f32[] { + %reduce_sum.375 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.374 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.379 = f32[]{:T(128)} add(%reduce_sum.374, %reduce_sum.375), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.450 (param_0.1547: f32[128,4], param_1.1984: f32[], param_2.1887: f32[], param_3.1397: f32[], param_4.984: f32[128,4], param_5.794: f32[], param_6.628: f32[4,128], param_7.567: pred[], param_8.352: f32[128,4]) -> (f32[], f32[128,4], f32[128,4], f32[128,4], f32[]) { + %param_0.1547 = f32[128,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.1397 = f32[]{:T(128)S(6)} parameter(3) + %mul.1943.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_3.1397), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.567 = pred[]{:T(512)S(6)} parameter(7) + %select_n.876.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.567), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %param_6.628 = f32[4,128]{1,0:T(4,128)} parameter(6) + %bitcast.468.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_6.628), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.794 = f32[]{:T(128)} parameter(5) + %div.954.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_5.794), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %div.953.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%bitcast.468.clone.1, %div.954.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %select_n.875.clone.1 = f32[128,4]{0,1:T(4,128)} select(%select_n.876.clone.1, %bitcast.468.clone.1, %div.953.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %ne.202.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} compare(%select_n.875.clone.1, %select_n.875.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=621} + %constant.1725 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.696.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1725), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.874.clone.1 = f32[128,4]{0,1:T(4,128)} select(%ne.202.clone.1, %broadcast_in_dim.696.clone.1, %select_n.875.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1546.clone.1 = f32[]{:T(128)} constant(inf) + %eq.619.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1546.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.618.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} compare(%select_n.874.clone.1, %eq.619.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1545.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.695.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1545.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.873.clone.1 = f32[128,4]{0,1:T(4,128)} select(%eq.618.clone.1, %broadcast_in_dim.695.clone.1, %select_n.874.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1544.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.617.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1544.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.616.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} compare(%select_n.873.clone.1, %eq.617.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1543.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.694.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1543.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.872.clone.1 = f32[128,4]{0,1:T(4,128)} select(%eq.616.clone.1, %broadcast_in_dim.694.clone.1, %select_n.873.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1539.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.856.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1539.clone.1), dimensions={}, metadata={op_name="broadcast.78"} + %mul.1947.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.872.clone.1, %broadcast.856.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=624} + %param_8.352 = f32[128,4]{0,1:T(4,128)S(1)} parameter(8) + %constant.1547.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.855.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1547.clone.1), dimensions={}, metadata={op_name="broadcast.77"} + %mul.1946.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_8.352, %broadcast.855.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %add.951.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.1947.clone.1, %mul.1946.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=624} + %param_2.1887 = f32[]{:T(128)S(6)} parameter(2) + %div.950.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_2.1887), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=635} + %integer_pow.67.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.872.clone.1, %select_n.872.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=639} + %constant.1542.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.854.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1542.clone.1), dimensions={}, metadata={op_name="broadcast.67"} + %mul.1945.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%integer_pow.67.clone.1, %broadcast.854.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %param_4.984 = f32[128,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.1541.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.853.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1541.clone.1), dimensions={}, metadata={op_name="broadcast.66"} + %mul.1944.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_4.984, %broadcast.853.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %add.950.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.1945.clone.1, %mul.1944.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=642} + %param_1.1984 = f32[]{:T(128)S(6)} parameter(1) + %div.949.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_1.1984), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %div.948.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.950.clone.1, %div.949.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %sqrt.64.clone.1 = f32[128,4]{0,1:T(4,128)} sqrt(%div.948.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=646} + %constant.1540.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.851.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1540.clone.1), dimensions={}, metadata={op_name="broadcast.62"} + %add.949.clone.1 = f32[128,4]{0,1:T(4,128)} add(%sqrt.64.clone.1, %broadcast.851.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %multiply.428.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%div.950.clone.1, %add.949.clone.1), metadata={op_name="multiply.59"} + %div.947.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.951.clone.1, %multiply.428.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=647} + %mul.1942.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_0.1547, %broadcast.856.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=649} + %add.948.clone.1 = f32[128,4]{0,1:T(4,128)} add(%div.947.clone.1, %mul.1942.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=650} + %mul.1941.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%mul.1943.clone.1, %add.948.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.947.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%param_0.1547, %mul.1941.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=654} + %square.265 = f32[128,4]{0,1:T(4,128)} multiply(%add.947.clone.1, %add.947.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=685} + %reduce.217 = f32[]{:T(128)} reduce(%square.265, %constant.1725), dimensions={0,1}, to_apply=%region_69.74, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686} + %reduce.221.clone.1 = f32[]{:T(128)} reduce(%integer_pow.67.clone.1, %constant.1725), dimensions={0,1}, to_apply=%region_54.59, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662} + ROOT %tuple.159 = (f32[]{:T(128)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.217, %add.947.clone.1, %add.950.clone.1, %add.951.clone.1, %reduce.221.clone.1) +} + +%region_66.71 (reduce_sum.437: f32[], reduce_sum.438: f32[]) -> f32[] { + %reduce_sum.438 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.437 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.442 = f32[]{:T(128)} add(%reduce_sum.437, %reduce_sum.438), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_51.56 (reduce_sum.359: f32[], reduce_sum.360: f32[]) -> f32[] { + %reduce_sum.360 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.359 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.361 = f32[]{:T(128)} add(%reduce_sum.359, %reduce_sum.360), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.451 (param_0.1550: f32[128,4], param_1.1987: f32[], param_2.1890: f32[], param_3.1400: f32[], param_4.987: f32[128,4], param_5.797: f32[], param_6.631: f32[4,128], param_7.570: pred[], param_8.355: f32[128,4]) -> (f32[], f32[128,4], f32[128,4], f32[128,4], f32[]) { + %param_0.1550 = f32[128,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.1400 = f32[]{:T(128)S(6)} parameter(3) + %mul.1970.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_3.1400), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %param_7.570 = pred[]{:T(512)S(6)} parameter(7) + %select_n.906.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.570), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %param_6.631 = f32[4,128]{1,0:T(4,128)} parameter(6) + %bitcast.474.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_6.631), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + %param_5.797 = f32[]{:T(128)} parameter(5) + %div.978.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_5.797), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %div.977.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%bitcast.474.clone.1, %div.978.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=619} + %select_n.905.clone.1 = f32[128,4]{0,1:T(4,128)} select(%select_n.906.clone.1, %bitcast.474.clone.1, %div.977.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=618} + %ne.208.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} compare(%select_n.905.clone.1, %select_n.905.clone.1), direction=NE, metadata={op_name="jit(train_step)/jit(nan_to_num)/ne" stack_frame_id=621} + %constant.1728 = f32[]{:T(128)} constant(0) + %broadcast_in_dim.714.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1728), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.904.clone.1 = f32[128,4]{0,1:T(4,128)} select(%ne.208.clone.1, %broadcast_in_dim.714.clone.1, %select_n.905.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1588.clone.1 = f32[]{:T(128)} constant(inf) + %eq.643.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1588.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.642.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} compare(%select_n.904.clone.1, %eq.643.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1587.clone.1 = f32[]{:T(128)} constant(3.40282347e+38) + %broadcast_in_dim.713.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1587.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.903.clone.1 = f32[128,4]{0,1:T(4,128)} select(%eq.642.clone.1, %broadcast_in_dim.713.clone.1, %select_n.904.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1586.clone.1 = f32[]{:T(128)} constant(-inf) + %eq.641.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1586.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %eq.640.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} compare(%select_n.903.clone.1, %eq.641.clone.1), direction=EQ, metadata={op_name="jit(train_step)/jit(nan_to_num)/eq" stack_frame_id=621} + %constant.1585.clone.1 = f32[]{:T(128)} constant(-3.40282347e+38) + %broadcast_in_dim.712.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1585.clone.1), dimensions={}, metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/broadcast_in_dim" stack_frame_id=621} + %select_n.902.clone.1 = f32[128,4]{0,1:T(4,128)} select(%eq.640.clone.1, %broadcast_in_dim.712.clone.1, %select_n.903.clone.1), metadata={op_name="jit(train_step)/jit(nan_to_num)/jit(_where)/select_n" stack_frame_id=621} + %constant.1581.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.866.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1581.clone.1), dimensions={}, metadata={op_name="broadcast.78"} + %mul.1974.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.902.clone.1, %broadcast.866.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=624} + %param_8.355 = f32[128,4]{0,1:T(4,128)S(1)} parameter(8) + %constant.1589.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.865.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1589.clone.1), dimensions={}, metadata={op_name="broadcast.77"} + %mul.1973.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_8.355, %broadcast.865.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=625} + %add.968.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.1974.clone.1, %mul.1973.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=624} + %param_2.1890 = f32[]{:T(128)S(6)} parameter(2) + %div.974.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_2.1890), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=635} + %integer_pow.70.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.902.clone.1, %select_n.902.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=639} + %constant.1584.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.864.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1584.clone.1), dimensions={}, metadata={op_name="broadcast.67"} + %mul.1972.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%integer_pow.70.clone.1, %broadcast.864.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=642} + %param_4.987 = f32[128,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.1583.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.863.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1583.clone.1), dimensions={}, metadata={op_name="broadcast.66"} + %mul.1971.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_4.987, %broadcast.863.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=643} + %add.967.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.1972.clone.1, %mul.1971.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=642} + %param_1.1987 = f32[]{:T(128)S(6)} parameter(1) + %div.973.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_1.1987), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %div.972.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.967.clone.1, %div.973.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=644} + %sqrt.67.clone.1 = f32[128,4]{0,1:T(4,128)} sqrt(%div.972.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=646} + %constant.1582.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.861.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1582.clone.1), dimensions={}, metadata={op_name="broadcast.62"} + %add.966.clone.1 = f32[128,4]{0,1:T(4,128)} add(%sqrt.67.clone.1, %broadcast.861.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=646} + %multiply.431.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%div.974.clone.1, %add.966.clone.1), metadata={op_name="multiply.56"} + %div.971.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.968.clone.1, %multiply.431.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=647} + %mul.1969.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_0.1550, %broadcast.866.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=649} + %add.965.clone.1 = f32[128,4]{0,1:T(4,128)} add(%div.971.clone.1, %mul.1969.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=650} + %mul.1968.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%mul.1970.clone.1, %add.965.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=62} + %add.964.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%param_0.1550, %mul.1968.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=654} + %square.266 = f32[128,4]{0,1:T(4,128)} multiply(%add.964.clone.1, %add.964.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=685} + %reduce.218 = f32[]{:T(128)} reduce(%square.266, %constant.1728), dimensions={0,1}, to_apply=%region_66.71, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=686} + %reduce.222.clone.1 = f32[]{:T(128)} reduce(%integer_pow.70.clone.1, %constant.1728), dimensions={0,1}, to_apply=%region_51.56, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=662} + ROOT %tuple.160 = (f32[]{:T(128)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.218, %add.964.clone.1, %add.967.clone.1, %add.968.clone.1, %reduce.222.clone.1) +} + +%fused_computation.460 (param_0.1375: f32[4,128], param_1.1540: f32[4,128]) -> f32[4,128] { + %param_0.1375 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %param_1.1540 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %constant.1269 = f32[]{:T(128)} constant(0.00048828125) + %broadcast.837 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1269), dimensions={}, metadata={op_name="broadcast.399"} + %div.833 = f32[4,128]{1,0:T(4,128)} multiply(%param_1.1540, %broadcast.837), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=346} + %constant.1267 = f32[]{:T(128)} constant(1e-06) + %add.935 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1267), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=347} + %add.934 = f32[4,128]{1,0:T(4,128)} add(%div.833, %add.935), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=347} + %rsqrt.168 = f32[4,128]{1,0:T(4,128)} rsqrt(%add.934), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=348} + %div.820 = f32[4,128]{1,0:T(4,128)} divide(%rsqrt.168, %add.934), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=348} + %constant.1264 = f32[]{:T(128)} constant(-0.5) + %mul.1919 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1264), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=348} + %mul.1910 = f32[4,128]{1,0:T(4,128)} multiply(%div.820, %mul.1919), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=348} + %mul.1909 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1375, %mul.1910), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=348} + %constant.1263 = f32[]{:T(128)} constant(0.0009765625) + %mul.1918 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1263), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=345} + ROOT %mul.1908 = f32[4,128]{1,0:T(4,128)S(1)} multiply(%mul.1909, %mul.1918), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=345} +} + +%region_0.1 (reduce_sum.137: s32[], reduce_sum.138: s32[]) -> s32[] { + %reduce_sum.138 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.137 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.139 = s32[]{:T(128)} add(%reduce_sum.137, %reduce_sum.138), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=65}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["0","2"]}]}} +} + +%fused_computation.464 (param_0.1394: pred[4,128]) -> s32[] { + %param_0.1394 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) + %convert_element_type.1397 = s32[4,128]{1,0:T(4,128)} convert(%param_0.1394), metadata={op_name="jit(train_step)/jvp()/convert_element_type" stack_frame_id=65} + %constant.1290 = s32[]{:T(128)} constant(0) + ROOT %reduce.220 = s32[]{:T(128)} reduce(%convert_element_type.1397, %constant.1290), dimensions={0,1}, to_apply=%region_0.1, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=65} +} + +%fused_computation.467 (param_0.1377: f32[4,128]) -> f32[4,128] { + %param_0.1377 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1270 = f32[]{:T(128)} constant(0.00048828125) + %broadcast.829 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1270), dimensions={}, metadata={op_name="broadcast.399"} + %div.825 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1377, %broadcast.829), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=346} + %constant.1268 = f32[]{:T(128)} constant(1e-06) + %add.924 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1268), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=347} + %add.921 = f32[4,128]{1,0:T(4,128)} add(%div.825, %add.924), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=347} + ROOT %rsqrt.166 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.921), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=348} +} + +%fused_computation.468 (param_0.1378: pred[4,128], param_1.2001: f32[]) -> f32[4,128] { + %param_0.1378 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) + %param_1.2001 = f32[]{:T(128)S(6)} parameter(1) + %broadcast_in_dim.534 = f32[4,128]{1,0:T(4,128)} broadcast(%param_1.2001), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/broadcast_in_dim" stack_frame_id=68} + %constant.1746 = f32[]{:T(128)} constant(0) + %broadcast.833 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1746), dimensions={}, metadata={op_name="broadcast.99"} + ROOT %mul.1920 = f32[4,128]{1,0:T(4,128)S(1)} select(%param_0.1378, %broadcast_in_dim.534, %broadcast.833), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=69} +} + +%fused_computation.470 () -> f32[64] { + %constant.1273 = f32[]{:T(128)} constant(1e+06) + %broadcast.840 = f32[64]{0:T(128)} broadcast(%constant.1273), dimensions={}, metadata={op_name="broadcast.390"} + %iota.46 = s32[64]{0:T(128)} iota(), iota_dimension=0, metadata={op_name="jit(train_step)/layers/iota" stack_frame_id=165} + %constant.1272 = s32[]{:T(128)} constant(2) + %broadcast.839 = s32[64]{0:T(128)} broadcast(%constant.1272), dimensions={}, metadata={op_name="broadcast.391"} + %mul.1921 = s32[64]{0:T(128)} multiply(%iota.46, %broadcast.839), metadata={op_name="jit(train_step)/layers/mul" stack_frame_id=166} + %convert_element_type.1398 = f32[64]{0:T(128)} convert(%mul.1921), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=166} + %constant.1271 = f32[]{:T(128)} constant(0.0078125) + %broadcast.838 = f32[64]{0:T(128)} broadcast(%constant.1271), dimensions={}, metadata={op_name="broadcast.392"} + %div.834 = f32[64]{0:T(128)} multiply(%convert_element_type.1398, %broadcast.838), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=166} + ROOT %pow.36 = f32[64]{0:T(128)S(1)} power(%broadcast.840, %div.834), metadata={op_name="jit(train_step)/layers/pow" stack_frame_id=167} +} + +%fused_computation.471 (param_0.1392: s32[4,128]) -> (f32[4,128,1,1], f32[4,128]) { + %param_0.1392 = s32[4,128]{1,0:T(4,128)} parameter(0) + %convert_element_type.1399 = f32[4,128]{1,0:T(4,128)S(1)} convert(%param_0.1392), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=160} + %bitcast.412 = f32[4,128,1,1]{1,0,3,2:T(4,128)} bitcast(%convert_element_type.1399), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=396} + ROOT %tuple.162 = (f32[4,128,1,1]{1,0,3,2:T(4,128)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%bitcast.412, %convert_element_type.1399) +} + +%fused_computation.474 (param_0.1537: f32[2048,4]) -> bf16[4,2048] { + %param_0.1537 = f32[2048,4]{0,1:T(4,128)} parameter(0) + %bitcast.531 = f32[4,2048]{1,0:T(4,128)} bitcast(%param_0.1537), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + ROOT %convert.145 = bf16[4,2048]{1,0:T(4,128)(2,1)} convert(%bitcast.531) +} + +%fused_computation.475 (param_0.1536: f32[2048,4]) -> bf16[4,2048] { + %param_0.1536 = f32[2048,4]{0,1:T(4,128)} parameter(0) + %bitcast.530 = f32[4,2048]{1,0:T(4,128)} bitcast(%param_0.1536), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + ROOT %convert.147 = bf16[4,2048]{1,0:T(4,128)(2,1)} convert(%bitcast.530) +} + +%fused_computation.476 (param_0.1538: f32[128,4]) -> bf16[4,128] { + %param_0.1538 = f32[128,4]{0,1:T(4,128)} parameter(0) + %bitcast.532 = f32[4,128]{1,0:T(4,128)} bitcast(%param_0.1538), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + ROOT %convert.149 = bf16[4,128]{1,0:T(4,128)(2,1)} convert(%bitcast.532) +} + +%fused_computation.477 (param_0.1539: f32[128,4]) -> bf16[4,128] { + %param_0.1539 = f32[128,4]{0,1:T(4,128)} parameter(0) + %bitcast.533 = f32[4,128]{1,0:T(4,128)} bitcast(%param_0.1539), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + ROOT %convert.151 = bf16[4,128]{1,0:T(4,128)(2,1)} convert(%bitcast.533) +} + +%region_8.11 (reduce_max.6: bf16[], reduce_max.8: bf16[]) -> bf16[] { + %reduce_max.8 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_max"} + %reduce_max.6 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_max"} + ROOT %reduce_max.9 = bf16[]{:T(256)} maximum(%reduce_max.6, %reduce_max.8), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=369}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.290.clone.clone (param_0.1523: bf16[151936,2048]) -> bf16[151936,2048,1] { + %param_0.1523 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + ROOT %bitcast.526 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} bitcast(%param_0.1523), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=93} +} + +%fused_computation.392.clone.clone (param_0.1524: f32[4,128], param_1.1968: bf16[4,128,2048], param_2.1852: bf16[2048]) -> bf16[4,128,2048] { + %param_2.1852 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %dot_general.476 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1852), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + %param_1.1968 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1444 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1968), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=344} + %param_0.1524 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2067 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1524), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=349} + %mul.2066 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1444, %mul.2067), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=349} + %convert_element_type.1443 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2066), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=350} + ROOT %dot_general.475 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.476, %convert_element_type.1443), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} +} + +%fused_computation.478 (param_0.1540: bf16[151936,2048], param_1.1977: f32[4,128], param_2.1876: bf16[4,128,2048], param_3.1390: bf16[2048]) -> (bf16[4,128], bf16[4,128,151936]) { + %param_1.1977 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1876 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %param_3.1390 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %fusion.270.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1977, %param_2.1876, %param_3.1390), kind=kLoop, calls=%fused_computation.392.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=89} + %param_0.1540 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + %fusion.253.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1540), kind=kLoop, calls=%fused_computation.290.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=93} + %convolution.85.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convolution(%fusion.270.clone.1, %fusion.253.clone.1), window={size=1}, dim_labels=0bf_oi0->0bf, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=357} + %constant.1718 = bf16[]{:T(256)} constant(-inf) + %reduce.223 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} reduce(%convolution.85.clone.1, %constant.1718), dimensions={2}, to_apply=%region_8.11, metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=369} + ROOT %tuple.164 = (bf16[4,128]{1,0:T(4,128)(2,1)S(1)}, bf16[4,128,151936]{2,1,0:T(8,128)(2,1)}) tuple(%reduce.223, %convolution.85.clone.1) +} + +%fused_computation.479 (param_0.1535: f32[2048,4,8,128]) -> bf16[4,2048,8,128] { + %param_0.1535 = f32[2048,4,8,128]{3,2,1,0:T(8,128)S(1)} parameter(0) + %bitcast.529 = f32[4,2048,8,128]{3,2,0,1:T(8,128)} bitcast(%param_0.1535), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=110} + ROOT %convert.153 = bf16[4,2048,8,128]{3,2,0,1:T(8,128)(2,1)} convert(%bitcast.529) +} + +%convert_element_type.767.reduce_sub_computation (lhs.1: bf16[], rhs.1: bf16[]) -> bf16[] { + %rhs.1 = bf16[] parameter(1) + %lhs.1 = bf16[] parameter(0) + ROOT %add.755 = bf16[] add(%lhs.1, %rhs.1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.155.clone.clone (param_0.1711: bf16[4,2048], param_1.2113: s32[]) -> bf16[2048] { + %param_0.1711 = bf16[4,2048]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.2113 = s32[]{:T(128)S(6)} parameter(1) + %constant.1884 = s32[]{:T(128)} constant(0) + %dynamic_slice.388 = bf16[1,2048]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1711, %param_1.2113, %constant.1884), dynamic_slice_sizes={1,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=177}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[false,false]},"used_scoped_memory_configs":[]} + %constant.1885 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=404} + ROOT %reduce.244 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.388, %constant.1885), dimensions={0}, to_apply=%convert_element_type.767.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=412} +} + +%region_14.16 (reduce_sum.204: f32[], reduce_sum.205: f32[]) -> f32[] { + %reduce_sum.205 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.204 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.206 = f32[]{:T(128)} add(%reduce_sum.204, %reduce_sum.205), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=422}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.58.clone.clone (param_0.1712: bf16[4,4,128,2048], param_1.2114: s32[]) -> f32[4,128] { + %param_0.1712 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.2114 = s32[]{:T(128)S(6)} parameter(1) + %constant.1886 = s32[]{:T(128)} constant(0) + %dynamic_slice.389 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1712, %param_1.2114, %constant.1886, %constant.1886, %constant.1886), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=177}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,true,true,true]},"used_scoped_memory_configs":[]} + %bitcast.633 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.389), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=177} + %convert_element_type.1570 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%bitcast.633), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=414} + %square.280 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1570, %convert_element_type.1570), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=415} + %constant.1887 = f32[]{:T(128)} constant(0) + ROOT %reduce.245 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.280, %constant.1887), dimensions={2}, to_apply=%region_14.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=416} +} + +%fused_computation.179.clone.1.clone (param_0.1713: f32[4,128]) -> f32[4,128] { + %param_0.1713 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1889 = f32[]{:T(128)} constant(0.00048828125) + %closed_call.106 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1889), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=177} + %div.1077 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1713, %closed_call.106), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=416} + %constant.1888 = f32[]{:T(128)} constant(1e-06) + %closed_call.105 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1888), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=177} + %add.1039 = f32[4,128]{1,0:T(4,128)} add(%div.1077, %closed_call.105), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=423} + ROOT %rsqrt.181 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.1039), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=424} +} + +%region_15.17 (reduce_sum.207: f32[], reduce_sum.211: f32[]) -> f32[] { + %reduce_sum.211 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.207 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.212 = f32[]{:T(128)} add(%reduce_sum.207, %reduce_sum.211), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=443}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.25.clone.1.clone.clone.clone.clone (param_0.1727: bf16[4,2048,16,128], param_1.2124: s32[]) -> bf16[2048,16,128,1] { + %param_0.1727 = bf16[4,2048,16,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %param_1.2124 = s32[]{:T(128)S(6)} parameter(1) + %constant.1900 = s32[]{:T(128)} constant(0) + %dynamic_slice.395 = bf16[1,2048,16,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1727, %param_1.2124, %constant.1900, %constant.1900, %constant.1900), dynamic_slice_sizes={1,2048,16,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=177}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,true,true,true]},"used_scoped_memory_configs":[]} + ROOT %bitcast.644 = bf16[2048,16,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.395), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=433} +} + +%fused_computation.114.clone.clone.clone.clone (param_0.1728: f32[4,128], param_1.2125: bf16[4,4,128,2048], param_2.1976: s32[], param_3.1459: bf16[2048]) -> bf16[4,128,2048,1] { + %param_3.1459 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %dot_general.571 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_3.1459), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=413} + %param_1.2125 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(1) + %param_2.1976 = s32[]{:T(128)S(6)} parameter(2) + %constant.1901 = s32[]{:T(128)} constant(0) + %dynamic_slice.396 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_1.2125, %param_2.1976, %constant.1901, %constant.1901, %constant.1901), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=177}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,true,true,true]},"used_scoped_memory_configs":[]} + %bitcast.646 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.396), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=177} + %convert_element_type.1581 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%bitcast.646), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=414} + %param_0.1728 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2256 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1728), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=425} + %mul.2255 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1581, %mul.2256), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=425} + %convert_element_type.1580 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2255), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=426} + %dot_general.570 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.571, %convert_element_type.1580), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=413} + ROOT %bitcast.645 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.570), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=413} +} + +%fused_computation.61.clone.clone (param_0.1729: bf16[4,2048,16,128], param_1.2126: s32[], param_2.1977: f32[4,128], param_3.1460: bf16[4,4,128,2048], param_4.1030: bf16[2048]) -> (f32[4,128,16], bf16[4,128,16,128]) { + %param_2.1977 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.1460 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_1.2126 = s32[]{:T(128)S(6)} parameter(1) + %param_4.1030 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.74.clone.3 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1977, %param_3.1460, %param_1.2126, %param_4.1030), kind=kLoop, calls=%fused_computation.114.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=413} + %param_0.1729 = bf16[4,2048,16,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %fusion.49.clone.3 = bf16[2048,16,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1729, %param_1.2126), kind=kLoop, calls=%fused_computation.25.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=433} + %convolution.44.clone.3 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.74.clone.3, %fusion.49.clone.3), window={size=1x16 pad=0_0x15_15 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=435} + %convert_element_type.1582 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%convolution.44.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=436} + %square.282 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1582, %convert_element_type.1582), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=437} + %constant.1902 = f32[]{:T(128)} constant(0) + %reduce.247 = f32[4,128,16]{1,2,0:T(8,128)S(1)} reduce(%square.282, %constant.1902), dimensions={3}, to_apply=%region_15.17, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=438} + ROOT %tuple.208 = (f32[4,128,16]{1,2,0:T(8,128)S(1)}, bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%reduce.247, %convolution.44.clone.3) +} + +%fused_computation.151.clone.1.clone (param_0.1730: f32[4,128,16]) -> f32[4,128,16] { + %param_0.1730 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(0) + %constant.1903 = f32[]{:T(128)} constant(0.0078125) + %closed_call.108 = f32[4,128,16]{1,2,0:T(8,128)} broadcast(%constant.1903), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=177} + %div.1079 = f32[4,128,16]{1,2,0:T(8,128)} multiply(%param_0.1730, %closed_call.108), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=438} + %constant.1904 = f32[]{:T(128)} constant(1e-06) + %add.1044 = f32[4,128,16]{1,2,0:T(8,128)} broadcast(%constant.1904), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=444} + %add.1043 = f32[4,128,16]{1,2,0:T(8,128)} add(%div.1079, %add.1044), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=444} + ROOT %rsqrt.183 = f32[4,128,16]{1,2,0:T(8,128)S(1)} rsqrt(%add.1043), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=445} +} + +%fused_computation.182.clone.clone (param_0.1726: bf16[4,128], param_1.2123: s32[]) -> bf16[128] { + %param_0.1726 = bf16[4,128]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.2123 = s32[]{:T(128)S(6)} parameter(1) + %constant.1899 = s32[]{:T(128)} constant(0) + %dynamic_slice.394 = bf16[1,128]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1726, %param_1.2123, %constant.1899), dynamic_slice_sizes={1,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=177}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[false,false]},"used_scoped_memory_configs":[]} + ROOT %bitcast.643 = bf16[128]{0:T(256)(128)(2,1)S(1)} bitcast(%dynamic_slice.394), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=408} +} + +%fused_computation.121.clone.1.clone (param_0.1731: f32[4,128,16], param_1.2127: bf16[4,128,16,128], param_2.1978: bf16[128]) -> bf16[4,128,16,128] { + %param_2.1978 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(2) + %dot_general.573 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1978), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=409} + %param_1.2127 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1584 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%param_1.2127), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=436} + %param_0.1731 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(0) + %mul.2258 = f32[4,128,16,128]{3,1,2,0:T(8,128)} broadcast(%param_0.1731), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=446} + %mul.2257 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1584, %mul.2258), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=446} + %convert_element_type.1583 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2257), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=447} + ROOT %dot_general.572 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} multiply(%dot_general.573, %convert_element_type.1583), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=409} +} + +%fused_computation.90.clone.clone (param_0.1732: bf16[4,128,16,128]) -> (bf16[4,128,16,64], bf16[4,128,16,64]) { + %param_0.1732 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %split.160 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1732), slice={[0:4], [0:128], [0:16], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=457} + %neg.129 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} negate(%split.160), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/neg" stack_frame_id=458} + %split.161 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1732), slice={[0:4], [0:128], [0:16], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=457} + ROOT %tuple.209 = (bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.129, %split.161) +} + +%fused_computation.187.clone.clone () -> f32[64] { + %constant.1878 = f32[]{:T(128)} constant(1e+06) + %closed_call.104 = f32[64]{0:T(128)} broadcast(%constant.1878), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=177} + %iota.51 = s32[64]{0:T(128)} iota(), iota_dimension=0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/iota" stack_frame_id=449} + %constant.1877 = s32[]{:T(128)} constant(2) + %closed_call.103 = s32[64]{0:T(128)} broadcast(%constant.1877), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=177} + %mul.2242 = s32[64]{0:T(128)} multiply(%iota.51, %closed_call.103), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=450} + %convert_element_type.1568 = f32[64]{0:T(128)} convert(%mul.2242), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=450} + %constant.1879 = f32[]{:T(128)} constant(0.0078125) + %closed_call.102 = f32[64]{0:T(128)} broadcast(%constant.1879), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=177} + %div.1073 = f32[64]{0:T(128)} multiply(%convert_element_type.1568, %closed_call.102), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=450} + ROOT %pow.38 = f32[64]{0:T(128)S(1)} power(%closed_call.104, %div.1073), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/pow" stack_frame_id=451} +} + +%fused_computation.143.clone.clone (param_0.1706: f32[64], param_1.2109: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { + %param_1.2109 = f32[4,128]{1,0:T(4,128)} parameter(1) + %div.1076 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.2109), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=396} + %param_0.1706 = f32[64]{0:T(128)S(1)} parameter(0) + %div.1075 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1706), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=396} + %div.1074 = f32[4,128,1,64]{3,1,0,2:T(8,128)} divide(%div.1076, %div.1075), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=396} + %cos.43 = f32[4,128,1,64]{3,1,0,2:T(8,128)} cosine(%div.1074), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/cos" stack_frame_id=452} + %convert_element_type.1569 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.43), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=452} + %sin.35.clone.3 = f32[4,128,1,64]{3,1,0,2:T(8,128)} sine(%div.1074), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/sin" stack_frame_id=460} + %convert_element_type.1189.clone.3 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.35.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=460} + ROOT %tuple.205 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1569, %convert_element_type.1189.clone.3) +} + +%fused_computation.146.clone.1.clone (param_0.1707: bf16[4,128,1,64]) -> bf16[4,128,128] { + %param_0.1707 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1880 = bf16[]{:T(256)} constant(-inf) + %pad.69 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1707, %constant.1880), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=461} + %pad.68 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1707, %constant.1880), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=461} + %maximum.53 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.69, %pad.68), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=461} + ROOT %bitcast.630 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.53), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=456} +} + +%fused_computation.145.clone.1.clone (param_0.1722: bf16[4,128,1,64]) -> bf16[4,128,128] { + %param_0.1722 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1897 = bf16[]{:T(256)} constant(-inf) + %pad.71 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1722, %constant.1897), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=453} + %pad.70 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1722, %constant.1897), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=453} + %maximum.54 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.71, %pad.70), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=453} + ROOT %bitcast.641 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.54), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=455} +} + +%fused_computation.94.clone.clone (param_0.1733: bf16[4,128,16,64], param_1.2128: bf16[4,128,16,64], param_2.1979: bf16[4,128,128], param_3.1461: bf16[4,128,128], param_4.1031: f32[4,128,16], param_5.823: bf16[4,128,16,128], param_6.652: bf16[128]) -> bf16[4,16,128,128] { + %param_6.652 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(6) + %dot_general.575 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_6.652), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=409} + %param_5.823 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(5) + %convert_element_type.1586 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%param_5.823), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=436} + %param_4.1031 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(4) + %mul.2265 = f32[4,128,16,128]{3,1,2,0:T(8,128)} broadcast(%param_4.1031), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=446} + %mul.2264 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1586, %mul.2265), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=446} + %convert_element_type.1585 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2264), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=447} + %dot_general.574 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%dot_general.575, %convert_element_type.1585), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=409} + %param_3.1461 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %mul.2263 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.1461), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=455} + %mul.2261 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%dot_general.574, %mul.2263), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=455} + %param_1.2128 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1905 = bf16[]{:T(256)} constant(-inf) + %pad.75 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.2128, %constant.1905), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=459} + %param_0.1733 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %pad.74 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1733, %constant.1905), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=459} + %maximum.56 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.75, %pad.74), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=459} + %param_2.1979 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %mul.2262 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1979), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=456} + %mul.2260 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.56, %mul.2262), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=456} + %add.1045 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.2261, %mul.2260), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=462} + %constant.1906 = bf16[]{:T(256)} constant(0.08838) + %closed_call.109 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%constant.1906), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=177} + %mul.2259 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%add.1045, %closed_call.109), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=463} + ROOT %bitcast.647 = bf16[4,16,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%mul.2259), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=471} +} + +%region_16.18 (reduce_sum.213: f32[], reduce_sum.214: f32[]) -> f32[] { + %reduce_sum.214 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.213 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.218 = f32[]{:T(128)} add(%reduce_sum.213, %reduce_sum.214), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=497}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.69.clone.1.clone.clone.clone.clone (param_0.1718: bf16[4,2048,8,128], param_1.2118: s32[]) -> bf16[2048,8,128,1] { + %param_0.1718 = bf16[4,2048,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %param_1.2118 = s32[]{:T(128)S(6)} parameter(1) + %constant.1892 = s32[]{:T(128)} constant(0) + %dynamic_slice.392 = bf16[1,2048,8,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1718, %param_1.2118, %constant.1892, %constant.1892, %constant.1892), dynamic_slice_sizes={1,2048,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=177}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,true,true,true]},"used_scoped_memory_configs":[]} + ROOT %bitcast.638 = bf16[2048,8,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.392), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=489} +} + +%fused_computation.113.clone.clone.clone.clone (param_0.1719: f32[4,128], param_1.2119: bf16[4,4,128,2048], param_2.1972: s32[], param_3.1456: bf16[2048]) -> bf16[4,128,2048,1] { + %param_3.1456 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %dot_general.565 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_3.1456), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=413} + %param_1.2119 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(1) + %param_2.1972 = s32[]{:T(128)S(6)} parameter(2) + %constant.1893 = s32[]{:T(128)} constant(0) + %dynamic_slice.393 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_1.2119, %param_2.1972, %constant.1893, %constant.1893, %constant.1893), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=177}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,true,true,true]},"used_scoped_memory_configs":[]} + %bitcast.640 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.393), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=177} + %convert_element_type.1574 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%bitcast.640), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=414} + %param_0.1719 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2246 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1719), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=425} + %mul.2245 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1574, %mul.2246), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=425} + %convert_element_type.1573 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2245), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=426} + %dot_general.564 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.565, %convert_element_type.1573), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=413} + ROOT %bitcast.639 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.564), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=413} +} + +%fused_computation.84.clone.clone (param_0.1720: bf16[4,2048,8,128], param_1.2120: s32[], param_2.1973: f32[4,128], param_3.1457: bf16[4,4,128,2048], param_4.1028: bf16[2048]) -> (f32[4,128,8], bf16[4,128,8,128]) { + %param_2.1973 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.1457 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_1.2120 = s32[]{:T(128)S(6)} parameter(1) + %param_4.1028 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.73.clone.3 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1973, %param_3.1457, %param_1.2120, %param_4.1028), kind=kLoop, calls=%fused_computation.113.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=413} + %param_0.1720 = bf16[4,2048,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %fusion.87.clone.3 = bf16[2048,8,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1720, %param_1.2120), kind=kLoop, calls=%fused_computation.69.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=489} + %convolution.50.clone.3 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.73.clone.3, %fusion.87.clone.3), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=491} + %convert_element_type.1575 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%convolution.50.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=492} + %square.281 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1575, %convert_element_type.1575), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=493} + %constant.1894 = f32[]{:T(128)} constant(0) + %reduce.246 = f32[4,128,8]{1,2,0:T(8,128)S(1)} reduce(%square.281, %constant.1894), dimensions={3}, to_apply=%region_16.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=494} + ROOT %tuple.206 = (f32[4,128,8]{1,2,0:T(8,128)S(1)}, bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%reduce.246, %convolution.50.clone.3) +} + +%fused_computation.154.clone.1.clone (param_0.1721: f32[4,128,8]) -> f32[4,128,8] { + %param_0.1721 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(0) + %constant.1895 = f32[]{:T(128)} constant(0.0078125) + %closed_call.107 = f32[4,128,8]{1,2,0:T(8,128)} broadcast(%constant.1895), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=177} + %div.1078 = f32[4,128,8]{1,2,0:T(8,128)} multiply(%param_0.1721, %closed_call.107), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=494} + %constant.1896 = f32[]{:T(128)} constant(1e-06) + %add.1041 = f32[4,128,8]{1,2,0:T(8,128)} broadcast(%constant.1896), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=498} + %add.1040 = f32[4,128,8]{1,2,0:T(8,128)} add(%div.1078, %add.1041), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=498} + ROOT %rsqrt.182 = f32[4,128,8]{1,2,0:T(8,128)S(1)} rsqrt(%add.1040), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=499} +} + +%fused_computation.184.clone.clone (param_0.1705: bf16[4,128], param_1.2108: s32[]) -> bf16[128] { + %param_0.1705 = bf16[4,128]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.2108 = s32[]{:T(128)S(6)} parameter(1) + %constant.1876 = s32[]{:T(128)} constant(0) + %dynamic_slice.385 = bf16[1,128]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1705, %param_1.2108, %constant.1876), dynamic_slice_sizes={1,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=177}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[false,false]},"used_scoped_memory_configs":[]} + ROOT %bitcast.629 = bf16[128]{0:T(256)(128)(2,1)S(1)} bitcast(%dynamic_slice.385), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=484} +} + +%fused_computation.139.clone.1.clone (param_0.1723: f32[4,128,8], param_1.2121: bf16[4,128,8,128], param_2.1974: bf16[128]) -> bf16[4,128,8,128] { + %param_2.1974 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(2) + %dot_general.567 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1974), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=485} + %param_1.2121 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1577 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%param_1.2121), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=492} + %param_0.1723 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(0) + %mul.2248 = f32[4,128,8,128]{3,1,2,0:T(8,128)} broadcast(%param_0.1723), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=500} + %mul.2247 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1577, %mul.2248), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=500} + %convert_element_type.1576 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2247), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=501} + ROOT %dot_general.566 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} multiply(%dot_general.567, %convert_element_type.1576), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=485} +} + +%fused_computation.126.clone.clone (param_0.1724: bf16[4,128,8,128]) -> (bf16[4,128,8,64], bf16[4,128,8,64]) { + %param_0.1724 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %split.158 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1724), slice={[0:4], [0:128], [0:8], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=508} + %neg.128 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} negate(%split.158), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/neg" stack_frame_id=509} + %split.159 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1724), slice={[0:4], [0:128], [0:8], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=508} + ROOT %tuple.207 = (bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.128, %split.159) +} + +%fused_computation.129.clone.clone (param_0.1725: bf16[4,128,8,64], param_1.2122: bf16[4,128,8,64], param_2.1975: bf16[4,128,128], param_3.1458: bf16[4,128,128], param_4.1029: f32[4,128,8], param_5.822: bf16[4,128,8,128], param_6.651: bf16[128]) -> bf16[4,8,128,128] { + %param_6.651 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(6) + %dot_general.569 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_6.651), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=485} + %param_5.822 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(5) + %convert_element_type.1579 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%param_5.822), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=492} + %param_4.1029 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(4) + %mul.2254 = f32[4,128,8,128]{3,1,2,0:T(8,128)} broadcast(%param_4.1029), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=500} + %mul.2253 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1579, %mul.2254), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=500} + %convert_element_type.1578 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2253), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=501} + %dot_general.568 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%dot_general.569, %convert_element_type.1578), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=485} + %param_3.1458 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %mul.2252 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.1458), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=506} + %mul.2250 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%dot_general.568, %mul.2252), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=506} + %param_1.2122 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1898 = bf16[]{:T(256)} constant(-inf) diff --git a/tests/utils/update_hlo_references.py b/tests/utils/update_hlo_references.py new file mode 100644 index 0000000000..ba7fc2b3f2 --- /dev/null +++ b/tests/utils/update_hlo_references.py @@ -0,0 +1,47 @@ +# Copyright 2026 Google LLC +"""Helper script to dynamically recreate reference files for HLO validations checks rules. + +This tool dynamically removes existing reference HLO files and executes the test suite +in order to recreate and validate them. The secure CI workflow `.github/workflows/update_reference_hlo.yml` +uses this script from an isolated test runner environment to bridge dynamic artifact +extractions setup logic and commit auto updates to PR. +""" + +import os +import subprocess +import glob + + +def main(): + base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) + test_dir = os.path.join(base_dir, "tests/integration/") + test_file = os.path.join(test_dir, "hlo_diff_test.py") + + reference_pattern = os.path.join(base_dir, "tests/utils/reference_hlo_*.txt") + existing_files = glob.glob(reference_pattern) + + if existing_files: + for reference_file in existing_files: + print(f"Removing existing reference file: {reference_file}") + os.remove(reference_file) + else: + print(f"No existing reference files found matching {reference_pattern}.") + + print(f"Running test suite to generate new references: {test_file}") + + env = os.environ.copy() + env["PYTHONPATH"] = os.path.join(base_dir, "src/") + + result = subprocess.run(["pytest", test_file, "-v"], env=env, capture_output=True, text=True, check=False) + + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) + + if result.returncode == 0: + print("Reference files updated successfully.") + else: + print(f"Failed to update reference files. Test exited with code {result.returncode}") + + +if __name__ == "__main__": + main()