Skip to content

Commit cd0562b

Browse files
committed
Add HLO compiler-graph validation checks and a pipeline that automatically extracts reference HLO artifacts and commits the updates
1 parent 586e692 commit cd0562b

8 files changed

Lines changed: 6314 additions & 7 deletions

.github/workflows/run_tests_against_package.yml

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ on:
7070
description: 'If false, maxtext_sha must be provided for checkout'
7171
type: boolean
7272
default: false
73+
is_update_hlo:
74+
required: false
75+
type: boolean
76+
default: false
7377

7478
permissions:
7579
contents: read
@@ -167,13 +171,19 @@ jobs:
167171
else
168172
SPLIT_ARGS=""
169173
fi
170-
$PYTHON_EXE -m pytest ${INPUTS_PYTEST_ADDOPTS} \
171-
-v \
172-
-m "${FINAL_PYTEST_MARKER}" \
173-
--durations=0 \
174-
$PYTEST_COV_ARGS \
175-
$SPLIT_ARGS \
176-
${INPUTS_PYTEST_EXTRA_ARGS}
174+
175+
# When manually updating the HLO references, skip test execution and run only the update script.
if [ "${INPUTS_IS_UPDATE_HLO}" == "true" ]; then
  # Use the same interpreter as the pytest branch below so the update script
  # sees the same installed packages (NOTE(review): PYTHON_EXE is set earlier
  # in this step, outside this hunk — confirm it is always defined here).
  $PYTHON_EXE tests/utils/update_hlo_references.py
else
  $PYTHON_EXE -m pytest ${INPUTS_PYTEST_ADDOPTS} \
    -v \
    -m "${FINAL_PYTEST_MARKER}" \
    --durations=0 \
    $PYTEST_COV_ARGS \
    $SPLIT_ARGS \
    ${INPUTS_PYTEST_EXTRA_ARGS}
fi
177187
178188
env:
179189
PYTHONPATH: "${{ github.workspace }}/src"
@@ -185,6 +195,14 @@ jobs:
185195
INPUTS_WORKER_GROUP: ${{ inputs.worker_group }}
186196
INPUTS_PYTEST_EXTRA_ARGS: ${{ inputs.pytest_extra_args }}
187197
INPUTS_MAXTEXT_INSTALLED: ${{ inputs.maxtext_installed }}
198+
INPUTS_IS_UPDATE_HLO: ${{ inputs.is_update_hlo }}
199+
- name: Upload Reference HLO
200+
if: ${{ inputs.is_update_hlo }}
201+
uses: actions/upload-artifact@v4
202+
with:
203+
name: reference-hlo
204+
path: tests/utils/reference_hlo_*.txt
205+
if-no-files-found: ignore
188206
- name: Upload results to Codecov
189207
if: ${{ !inputs.maxtext_installed }} # Skip code coverage upload for maxtext image testing
190208
uses: codecov/codecov-action@v5

.github/workflows/run_tests_coordinator.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ on:
4747
description: 'If false, maxtext_sha must be provided for checkout'
4848
type: boolean
4949
default: false
50+
is_update_hlo:
51+
required: false
52+
type: boolean
53+
default: false
5054

5155
permissions:
5256
contents: read
@@ -150,3 +154,4 @@ jobs:
150154
worker_group: ${{ matrix.worker_group }}
151155
total_workers: ${{ contains(inputs.flavor, 'cpu-unit') && 2 || 1 }}
152156
maxtext_sha: ${{ inputs.maxtext_sha }}
157+
is_update_hlo: ${{ inputs.is_update_hlo }}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
name: "Update HLO References (for hlo_diff_test.py)"
2+
3+
on:
4+
workflow_dispatch:
5+
permissions:
6+
contents: read
7+
8+
jobs:
9+
build-wheel:
10+
uses: ./.github/workflows/build_package.yml
11+
with:
12+
device_type: tpu
13+
device_name: v6e-4
14+
cloud_runner: linux-x86-n2-16-buildkit
15+
16+
run-tests:
17+
needs: build-wheel
18+
uses: ./.github/workflows/run_tests_coordinator.yml
19+
with:
20+
flavor: tpu-integration
21+
base_image: maxtext-unit-test-tpu:py312
22+
is_scheduled_run: false
23+
maxtext_sha: ${{ github.sha }}
24+
is_update_hlo: true
25+
26+
commit-changes:
27+
needs: run-tests # Wait for tests to finish
28+
runs-on: ubuntu-latest
29+
permissions:
30+
contents: write
31+
steps:
32+
- name: Checkout code
33+
uses: actions/checkout@v4
34+
with:
35+
ref: ${{ github.ref }}
36+
37+
- name: Download Reference HLO
38+
uses: actions/download-artifact@v4
39+
with:
40+
name: reference-hlo
41+
path: tests/utils/
42+
43+
- name: Commit and Push changes
  run: |
    git config --global user.name "github-actions[bot]"
    git config --global user.email "github-actions[bot]@users.noreply.github.com"
    git add tests/utils/reference_hlo_*.txt
    # `git commit` exits non-zero when nothing is staged, which would fail the
    # job even though the committed references are already up to date. Only
    # commit and push when the index actually differs from HEAD.
    if git diff --cached --quiet; then
      echo "Reference HLO files are already up to date; nothing to commit."
    else
      git commit -m "Update reference HLO from CI artifact"
      git push
    fi

tests/integration/hlo_diff_test.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# Copyright 2026 Google LLC
2+
"""Tests for HLO Graph Diff Verification.
3+
4+
Integrates a validation check that detects unintended compiler graph transformations
5+
or model graph deviations, without breaking the isolated per-PR validation constraints.
6+
7+
This is part 1 of the automated compiler validation pipeline:
8+
- Ensures the PR isn't changing compiler performance and model graph generation unintentionally.
9+
- Checks the HLO dumped by the current run against the committed reference files (tests/utils/reference_hlo_*.txt).
10+
"""
11+
12+
import os
13+
import unittest
14+
import shutil
15+
import jax
16+
import pytest
17+
from maxtext.trainers.pre_train import train_compile
18+
from absl import flags
19+
import difflib
20+
import re
21+
22+
pytestmark = [pytest.mark.integration_test]
23+
24+
# Maximum number of filtered lines from HLO to compare or store in the reference file
MAX_LINES = 2000

# Matches a whole line made of a decimal number, a single space, and then either
# one double-quoted string or one brace-enclosed group, for example
# `1234 {"sharding parameters", ...}` or `1234 "operation metadata"`.
# `filter_line` drops every line this pattern matches.
SECTION_CONTENT_REGEX = re.compile(r'^\d+ (?:"[^"]*"|\{[^\}]*\})$')
29+
30+
31+
def filter_line(line):
    """Decide whether one HLO text line takes part in the comparison.

    Args:
        line: A single line read from an HLO text dump.

    Returns:
        ``None`` when ``SECTION_CONTENT_REGEX`` matches the line (numbered
        section content such as ``1234 "operation metadata"``), otherwise the
        line itself, unchanged.
    """
    is_section_content = SECTION_CONTENT_REGEX.match(line) is not None
    return None if is_section_content else line
37+
38+
39+
@pytest.mark.tpu_backend
class TestHloDiff:
    """Tests for HLO Graph Diff Verification.

    Each case runs ``train_compile`` with XLA dump flags, locates the dumped
    ``*.after_optimizations.txt`` file for ``jit_train_step`` and compares its
    filtered text with ``tests/utils/reference_hlo_<test_id>.txt``.
    """

    def setup_method(self):
        """Turn off the JAX compilation cache before every test."""
        # Disable cache to ensure compilation occurs every time
        jax.config.update("jax_enable_compilation_cache", False)

    @staticmethod
    def _read_filtered_lines(file_path):
        """Return the first MAX_LINES lines of `file_path` kept by `filter_line`.

        Args:
            file_path: Path of a UTF-8 text file.

        Returns:
            A list of lines (line endings preserved) with every line rejected
            by ``filter_line`` removed, truncated to ``MAX_LINES`` entries.
        """
        kept = []
        with open(file_path, "r", encoding="utf-8") as f:
            for raw_line in f:
                # Stop reading as soon as the comparison window is full.
                if len(kept) >= MAX_LINES:
                    break
                line = filter_line(raw_line)
                if line is None:
                    continue
                kept.append(line)
        return kept

    def check_files_equal_ignoring_sections(self, file_path1, file_path2):
        """Report whether two text files are identical after filtering.

        Only the first ``MAX_LINES`` filtered lines of each file are compared.
        When they differ, a unified diff is printed to stdout.

        Args:
            file_path1: Path of the first file (the "from" side of the diff).
            file_path2: Path of the second file (the "to" side of the diff).

        Returns:
            True when the filtered contents are equal, False otherwise.
        """
        lines1 = self._read_filtered_lines(file_path1)
        lines2 = self._read_filtered_lines(file_path2)

        if lines1 == lines2:
            return True

        print("\n" + "=" * 20 + " HLO Diff " + "=" * 20)
        for line in difflib.unified_diff(lines1, lines2, fromfile=file_path1, tofile=file_path2):
            print(line, end="")
        print("=" * 50 + "\n")
        return False

    @pytest.mark.parametrize(
        "test_id, config_file, overrides",
        [
            (
                "deepseek3",
                "src/maxtext/configs/models/deepseek3-test.yml",
                {
                    "compile_topology": "v6e-4",
                    "base_num_decoder_layers": 4,
                    "per_device_batch_size": 1,
                    "max_target_length": 128,
                },
            ),
            (
                "llama3_8b",
                "src/maxtext/configs/models/llama3-8b.yml",
                {
                    "compile_topology": "v6e-4",
                    "base_num_decoder_layers": 4,
                    "per_device_batch_size": 1,
                    "max_target_length": 128,
                },
            ),
            (
                "qwen3_1.7b",
                "src/maxtext/configs/models/qwen3-1.7b.yml",
                {
                    "compile_topology": "v6e-4",
                    "base_num_decoder_layers": 4,
                    "per_device_batch_size": 1,
                    "max_target_length": 128,
                },
            ),
        ],
    )
    def test_hlo_diff(self, test_id, config_file, overrides):
        """Test HLO diff for parameterized configurations.

        Args:
            test_id: Short identifier used in the dump directory and reference
                file names.
            config_file: Model config path, relative to the repository root.
            overrides: Mapping of config keys to values appended to the
                ``train_compile`` arguments as ``key=value``.

        If the reference file does not exist yet it is created from the current
        dump and the test passes; otherwise the dump must match the reference.
        """
        local_landing_dir = os.path.join(os.path.dirname(__file__), f"hlo_diff_dump_{test_id}")

        # Clean up before run
        if os.path.exists(local_landing_dir):
            shutil.rmtree(local_landing_dir)
        os.makedirs(local_landing_dir, exist_ok=True)

        try:
            base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
            config_path = os.path.join(base_dir, config_file)

            # Arguments for train_compile
            test_args = [
                None,
                config_path,
                "steps=1",
                "dataset_type=synthetic",
                "override_model_config=true",
                "compile_topology_num_slices=1",
            ]
            test_args.extend(f"{key}={value}" for key, value in overrides.items())
            test_args.append(
                f'compile_xla_flags="--xla_dump_to={local_landing_dir} '
                f'--xla_dump_hlo_as_text=True --xla_dump_hlo_module_re=jit_train_step"'
            )

            print(f"Running train_compile for {test_id} with args:", test_args)

            # Initialize flags if not already done to avoid ParseCommandLine error
            if not flags.FLAGS.is_parsed():
                flags.FLAGS(["dummy_program", "--grain_train_files=dummy"], known_only=True)

            train_compile.main(tuple(test_args))

            dumped_files = os.listdir(local_landing_dir)
            print(f"Files in landing dir for {test_id}:", dumped_files)

            # Locate the dumped HLO file. os.listdir returns entries in arbitrary
            # order, so sort to make the choice deterministic if several match.
            matches = sorted(f for f in dumped_files if f.endswith(".after_optimizations.txt"))
            assert matches, f"Dumped HLO file not found for {test_id}!"

            dumped_hlo_path = os.path.join(local_landing_dir, matches[0])
            reference_hlo_path = os.path.join(os.path.dirname(__file__), f"../utils/reference_hlo_{test_id}.txt")

            if not os.path.exists(reference_hlo_path):
                print(f"Reference file not found. Creating it at {reference_hlo_path}")
                # Read the dump before opening the reference for writing so a read
                # failure cannot leave an empty reference file behind.
                filtered_lines = self._read_filtered_lines(dumped_hlo_path)
                with open(reference_hlo_path, "w", encoding="utf-8") as f_out:
                    f_out.writelines(filtered_lines)
                print(f"Reference file created for {test_id}. Please commit it.")
            else:
                assert self.check_files_equal_ignoring_sections(dumped_hlo_path, reference_hlo_path), (
                    f"HLO deviation detected in {test_id}! If this is intended, please run the "
                    f"'Update HLO References (for hlo_diff_test.py)' workflow from GitHub Actions against your "
                    f"branch to update the reference HLO."
                )
                print(f"HLO comparison successful against {reference_hlo_path}")

        finally:
            if os.path.exists(local_landing_dir):
                shutil.rmtree(local_landing_dir)
185+
186+
187+
if __name__ == "__main__":
    # TestHloDiff is a plain pytest-style class (no unittest.TestCase base), so
    # unittest.main() would collect no tests; run this file through pytest and
    # propagate its exit code instead.
    raise SystemExit(pytest.main([__file__]))

0 commit comments

Comments
 (0)