|
| 1 | +# Copyright 2026 Google LLC |
| 2 | +"""Tests for HLO Graph Diff Verification. |
| 3 | +
|
| 4 | +Integrates validation checks system that detects unintended compiler graph transformations |
| 5 | +or model graph deviations without breaking isolated PR constraints validations. |
| 6 | +
|
| 7 | +This is part 1 of the automated compiler validation pipeline: |
| 8 | +- Ensures the PR isn't changing compiler performance and model graph generation unintentionally. |
| 9 | +- Checks current runtime dumps against base reference checkpoints (in tests/utils/reference_hlo_*.txt). |
| 10 | +""" |
| 11 | + |
| 12 | +import os |
| 13 | +import unittest |
| 14 | +import shutil |
| 15 | +import jax |
| 16 | +import pytest |
| 17 | +from maxtext.trainers.pre_train import train_compile |
| 18 | +from absl import flags |
| 19 | +import difflib |
| 20 | +import re |
| 21 | + |
| 22 | +pytestmark = [pytest.mark.integration_test] |
| 23 | + |
| 24 | +# Maximum number of filtered lines from HLO to compare or store in the reference file |
| 25 | +MAX_LINES = 2000 |
| 26 | + |
| 27 | +# Matches operation sections content like: 1234 {"sharding parameters", ...} or 1234 "operation metadata" |
| 28 | +SECTION_CONTENT_REGEX = re.compile(r'^\d+ (?:"[^"]*"|\{[^\}]*\})$') |
| 29 | + |
| 30 | + |
| 31 | +# Filters and normalizes an HLO line to ignore environment-specific details. |
| 32 | +def filter_line(line): |
| 33 | + # Matches operation sections content like: 1234 {"sharding parameters", ...} or 1234 "operation metadata" |
| 34 | + if SECTION_CONTENT_REGEX.match(line): |
| 35 | + return None |
| 36 | + return line |
| 37 | + |
| 38 | + |
| 39 | +@pytest.mark.tpu_backend |
| 40 | +class TestHloDiff: |
| 41 | + """Tests for HLO Graph Diff Verification.""" |
| 42 | + |
| 43 | + def setup_method(self): |
| 44 | + # Disable cache to ensure compilation occurs every time |
| 45 | + jax.config.update("jax_enable_compilation_cache", False) |
| 46 | + |
| 47 | + def check_files_equal_ignoring_sections(self, file_path1, file_path2): |
| 48 | + """Asserts that two text files are identical, ignoring specific sections.""" |
| 49 | + |
| 50 | + def get_filtered_lines(file_path): |
| 51 | + with open(file_path, "r", encoding="utf-8") as f: |
| 52 | + lines = [] |
| 53 | + for line in f: |
| 54 | + line = filter_line(line) |
| 55 | + if line is None: |
| 56 | + continue |
| 57 | + lines.append(line) |
| 58 | + return lines |
| 59 | + |
| 60 | + lines1 = get_filtered_lines(file_path1)[:MAX_LINES] |
| 61 | + lines2 = get_filtered_lines(file_path2)[:MAX_LINES] |
| 62 | + |
| 63 | + if lines1 == lines2: |
| 64 | + return True |
| 65 | + |
| 66 | + print("\n" + "=" * 20 + " HLO Diff " + "=" * 20) |
| 67 | + for line in difflib.unified_diff(lines1, lines2, fromfile=file_path1, tofile=file_path2): |
| 68 | + print(line, end="") |
| 69 | + print("=" * 50 + "\n") |
| 70 | + return False |
| 71 | + |
| 72 | + @pytest.mark.parametrize( |
| 73 | + "test_id, config_file, overrides", |
| 74 | + [ |
| 75 | + ( |
| 76 | + "deepseek3", |
| 77 | + "src/maxtext/configs/models/deepseek3-test.yml", |
| 78 | + { |
| 79 | + "compile_topology": "v6e-4", |
| 80 | + "base_num_decoder_layers": 4, |
| 81 | + "per_device_batch_size": 1, |
| 82 | + "max_target_length": 128, |
| 83 | + }, |
| 84 | + ), |
| 85 | + ( |
| 86 | + "llama3_8b", |
| 87 | + "src/maxtext/configs/models/llama3-8b.yml", |
| 88 | + { |
| 89 | + "compile_topology": "v6e-4", |
| 90 | + "base_num_decoder_layers": 4, |
| 91 | + "per_device_batch_size": 1, |
| 92 | + "max_target_length": 128, |
| 93 | + }, |
| 94 | + ), |
| 95 | + ( |
| 96 | + "qwen3_1.7b", |
| 97 | + "src/maxtext/configs/models/qwen3-1.7b.yml", |
| 98 | + { |
| 99 | + "compile_topology": "v6e-4", |
| 100 | + "base_num_decoder_layers": 4, |
| 101 | + "per_device_batch_size": 1, |
| 102 | + "max_target_length": 128, |
| 103 | + }, |
| 104 | + ), |
| 105 | + ], |
| 106 | + ) |
| 107 | + def test_hlo_diff(self, test_id, config_file, overrides): |
| 108 | + """Test HLO diff for parameterized configurations.""" |
| 109 | + local_landing_dir = os.path.join(os.path.dirname(__file__), f"hlo_diff_dump_{test_id}") |
| 110 | + |
| 111 | + # Clean up before run |
| 112 | + if os.path.exists(local_landing_dir): |
| 113 | + shutil.rmtree(local_landing_dir) |
| 114 | + os.makedirs(local_landing_dir, exist_ok=True) |
| 115 | + |
| 116 | + try: |
| 117 | + base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) |
| 118 | + config_path = os.path.join(base_dir, config_file) |
| 119 | + |
| 120 | + # Arguments for train_compile |
| 121 | + test_args = [ |
| 122 | + None, |
| 123 | + config_path, |
| 124 | + "steps=1", |
| 125 | + "dataset_type=synthetic", |
| 126 | + "override_model_config=true", |
| 127 | + "compile_topology_num_slices=1", |
| 128 | + ] |
| 129 | + |
| 130 | + for key, value in overrides.items(): |
| 131 | + test_args.append(f"{key}={value}") |
| 132 | + |
| 133 | + test_args.append( |
| 134 | + f'compile_xla_flags="--xla_dump_to={local_landing_dir} ' |
| 135 | + f'--xla_dump_hlo_as_text=True --xla_dump_hlo_module_re=jit_train_step"' |
| 136 | + ) |
| 137 | + |
| 138 | + print(f"Running train_compile for {test_id} with args:", test_args) |
| 139 | + |
| 140 | + # Initialize flags if not already done to avoid ParseCommandLine error |
| 141 | + if not flags.FLAGS.is_parsed(): |
| 142 | + flags.FLAGS(["dummy_program", "--grain_train_files=dummy"], known_only=True) |
| 143 | + |
| 144 | + train_compile.main(tuple(test_args)) |
| 145 | + |
| 146 | + print(f"Files in landing dir for {test_id}:", os.listdir(local_landing_dir)) |
| 147 | + |
| 148 | + # Locate the dumped HLO file |
| 149 | + files = os.listdir(local_landing_dir) |
| 150 | + matches = [f for f in files if f.endswith(".after_optimizations.txt")] |
| 151 | + dumped_hlo = matches[0] if matches else None |
| 152 | + |
| 153 | + assert dumped_hlo, f"Dumped HLO file not found for {test_id}!" |
| 154 | + |
| 155 | + dumped_hlo_path = os.path.join(local_landing_dir, dumped_hlo) |
| 156 | + |
| 157 | + reference_hlo_path = os.path.join(os.path.dirname(__file__), f"../utils/reference_hlo_{test_id}.txt") |
| 158 | + |
| 159 | + if not os.path.exists(reference_hlo_path): |
| 160 | + print(f"Reference file not found. Creating it at {reference_hlo_path}") |
| 161 | + with ( |
| 162 | + open(dumped_hlo_path, "r", encoding="utf-8") as f_in, |
| 163 | + open(reference_hlo_path, "w", encoding="utf-8") as f_out, |
| 164 | + ): |
| 165 | + count = 0 |
| 166 | + for line in f_in: |
| 167 | + line = filter_line(line) |
| 168 | + if line is None: |
| 169 | + continue |
| 170 | + f_out.write(line) |
| 171 | + count += 1 |
| 172 | + if count >= MAX_LINES: |
| 173 | + break |
| 174 | + print(f"Reference file created for {test_id}. Please commit it.") |
| 175 | + else: |
| 176 | + assert self.check_files_equal_ignoring_sections(dumped_hlo_path, reference_hlo_path), ( |
| 177 | + f"HLO deviation detected in {test_id}! If this is intended, please run the 'Update HLO Reference' workflow " |
| 178 | + f"from Github Actions against your branch to update the reference HLO." |
| 179 | + ) |
| 180 | + print(f"HLO comparison successful against {reference_hlo_path}") |
| 181 | + |
| 182 | + finally: |
| 183 | + if os.path.exists(local_landing_dir): |
| 184 | + shutil.rmtree(local_landing_dir) |
| 185 | + |
| 186 | + |
| 187 | +if __name__ == "__main__": |
| 188 | + unittest.main() |
0 commit comments