Skip to content

Commit 552a1d3

Browse files
committed
Add compiler checks validations and automatic artifact PR auto extraction updates pipelines integrations execution rules
1 parent 586e692 commit 552a1d3

7 files changed

Lines changed: 2286 additions & 7 deletions

File tree

.github/workflows/build_and_test_maxtext.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ jobs:
199199
base_image: maxtext-unit-test-tpu:py312
200200
cloud_runner: linux-x86-ct6e-180-4tpu
201201
pytest_marker: 'not cpu_only and not gpu_only and not integration_test and not post_training'
202+
pytest_addopts: '--ignore=tests/integration/hlo_diff_test.py'
202203
xla_python_client_mem_fraction: 0.75
203204
tf_force_gpu_allow_growth: false
204205
container_resource_option: "--privileged"

.github/workflows/run_tests_against_package.yml

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ on:
7070
description: 'If false, maxtext_sha must be provided for checkout'
7171
type: boolean
7272
default: false
73+
is_update_hlo:
74+
required: false
75+
type: boolean
76+
default: false
7377

7478
permissions:
7579
contents: read
@@ -167,13 +171,19 @@ jobs:
167171
else
168172
SPLIT_ARGS=""
169173
fi
170-
$PYTHON_EXE -m pytest ${INPUTS_PYTEST_ADDOPTS} \
171-
-v \
172-
-m "${FINAL_PYTEST_MARKER}" \
173-
--durations=0 \
174-
$PYTEST_COV_ARGS \
175-
$SPLIT_ARGS \
176-
${INPUTS_PYTEST_EXTRA_ARGS}
174+
175+
# Setup substitution: If manually updating HLO, skip tests execution and run only the update script instead!
176+
if [ "${INPUTS_IS_UPDATE_HLO}" == "true" ]; then
177+
python3 tests/utils/update_hlo_references.py
178+
else
179+
$PYTHON_EXE -m pytest ${INPUTS_PYTEST_ADDOPTS} \
180+
-v \
181+
-m "${FINAL_PYTEST_MARKER}" \
182+
--durations=0 \
183+
$PYTEST_COV_ARGS \
184+
$SPLIT_ARGS \
185+
${INPUTS_PYTEST_EXTRA_ARGS}
186+
fi
177187
178188
env:
179189
PYTHONPATH: "${{ github.workspace }}/src"
@@ -185,6 +195,14 @@ jobs:
185195
INPUTS_WORKER_GROUP: ${{ inputs.worker_group }}
186196
INPUTS_PYTEST_EXTRA_ARGS: ${{ inputs.pytest_extra_args }}
187197
INPUTS_MAXTEXT_INSTALLED: ${{ inputs.maxtext_installed }}
198+
INPUTS_IS_UPDATE_HLO: ${{ inputs.is_update_hlo }}
199+
- name: Upload Reference HLO
200+
if: ${{ inputs.is_update_hlo }}
201+
uses: actions/upload-artifact@v4
202+
with:
203+
name: reference-hlo
204+
path: tests/utils/reference_hlo_*.txt
205+
if-no-files-found: ignore
188206
- name: Upload results to Codecov
189207
if: ${{ !inputs.maxtext_installed }} # Skip code coverage upload for maxtext image testing
190208
uses: codecov/codecov-action@v5

.github/workflows/run_tests_coordinator.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ on:
4747
description: 'If false, maxtext_sha must be provided for checkout'
4848
type: boolean
4949
default: false
50+
is_update_hlo:
51+
required: false
52+
type: boolean
53+
default: false
5054

5155
permissions:
5256
contents: read
@@ -150,3 +154,4 @@ jobs:
150154
worker_group: ${{ matrix.worker_group }}
151155
total_workers: ${{ contains(inputs.flavor, 'cpu-unit') && 2 || 1 }}
152156
maxtext_sha: ${{ inputs.maxtext_sha }}
157+
is_update_hlo: ${{ inputs.is_update_hlo }}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
name: "Update HLO Reference (for hlo_diff_test.py)"
2+
3+
on:
4+
workflow_dispatch:
5+
permissions:
6+
contents: read
7+
8+
jobs:
9+
build-wheel:
10+
uses: ./.github/workflows/build_package.yml
11+
with:
12+
device_type: tpu
13+
device_name: v6e-4
14+
cloud_runner: linux-x86-n2-16-buildkit
15+
16+
run-tests:
17+
needs: build-wheel
18+
uses: ./.github/workflows/run_tests_coordinator.yml
19+
with:
20+
flavor: tpu-unit
21+
base_image: maxtext-unit-test-tpu:py312
22+
is_scheduled_run: false
23+
maxtext_sha: ${{ github.sha }}
24+
is_update_hlo: true
25+
26+
commit-changes:
27+
needs: run-tests # Wait for tests to finish
28+
runs-on: ubuntu-latest
29+
permissions:
30+
contents: write
31+
steps:
32+
- name: Checkout code
33+
uses: actions/checkout@v4
34+
with:
35+
ref: ${{ github.ref }}
36+
37+
- name: Download Reference HLO
38+
uses: actions/download-artifact@v4
39+
with:
40+
name: reference-hlo
41+
path: tests/utils/
42+
43+
- name: Commit and Push changes
44+
run: |
45+
git config --global user.name "github-actions[bot]"
46+
git config --global user.email "github-actions[bot]@users.noreply.github.com"
47+
git add tests/utils/reference_hlo_*.txt
48+
git commit -m "Update reference HLO from CI artifact"
49+
git push

tests/integration/hlo_diff_test.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
"""Tests for HLO Graph Diff Verification.
2+
3+
Integrates validation checks system that detects unintended compiler graph transformations
4+
or model graph deviations without breaking isolated PR constraints validations.
5+
6+
This is part 1 of the automated compiler validation pipeline:
7+
- Ensures the PR isn't changing compiler performance and model graph generation unintentionally.
8+
- Checks current runtime dumps against base reference checkpoints (in tests/utils/reference_hlo_*.txt).
9+
"""
10+
11+
import os
12+
import unittest
13+
import shutil
14+
import jax
15+
import pytest
16+
from maxtext.trainers.pre_train import train_compile
17+
from absl import flags
18+
import difflib
19+
import re
20+
21+
# Maximum number of filtered lines from HLO to compare or store in the reference file
22+
MAX_LINES = 2000
23+
24+
# Matches operation sections content like: 1234 {"sharding parameters", ...} or 1234 "operation metadata"
25+
SECTION_CONTENT_REGEX = re.compile(r'^\d+ (?:"[^"]*"|\{[^\}]*\})$')
26+
27+
28+
# Filters and normalizes an HLO line to ignore environment-specific details.
29+
def filter_line(line):
30+
# Matches operation sections content like: 1234 {"sharding parameters", ...} or 1234 "operation metadata"
31+
if SECTION_CONTENT_REGEX.match(line):
32+
return None
33+
return line
34+
35+
36+
@pytest.mark.tpu_backend
37+
class TestHloDiff:
38+
"""Tests for HLO Graph Diff Verification."""
39+
40+
def setup_method(self):
41+
# Disable cache to ensure compilation occurs every time
42+
jax.config.update("jax_enable_compilation_cache", False)
43+
44+
def check_files_equal_ignoring_sections(self, file_path1, file_path2):
45+
"""Asserts that two text files are identical, ignoring specific sections."""
46+
47+
def get_filtered_lines(file_path):
48+
with open(file_path, "r", encoding="utf-8") as f:
49+
lines = []
50+
for line in f:
51+
line = filter_line(line)
52+
if line is None:
53+
continue
54+
lines.append(line)
55+
return lines
56+
57+
lines1 = get_filtered_lines(file_path1)[:MAX_LINES]
58+
lines2 = get_filtered_lines(file_path2)[:MAX_LINES]
59+
60+
if lines1 == lines2:
61+
return True
62+
63+
print("\n" + "=" * 20 + " HLO Diff " + "=" * 20)
64+
for line in difflib.unified_diff(lines1, lines2, fromfile=file_path1, tofile=file_path2):
65+
print(line, end="")
66+
print("=" * 50 + "\n")
67+
return False
68+
69+
@pytest.mark.parametrize("test_id, config_file, overrides, reference_suffix", [
70+
(
71+
"deepseek3",
72+
"src/maxtext/configs/models/deepseek3-test.yml",
73+
{
74+
"compile_topology": "v6e-4",
75+
"base_num_decoder_layers": 4,
76+
"per_device_batch_size": 1,
77+
"max_target_length": 128,
78+
},
79+
"deepseek3",
80+
),
81+
])
82+
def test_hlo_diff(self, test_id, config_file, overrides, reference_suffix):
83+
"""Test HLO diff for parameterized configurations."""
84+
local_landing_dir = os.path.join(os.path.dirname(__file__), f"hlo_diff_dump_{test_id}")
85+
86+
# Clean up before run
87+
if os.path.exists(local_landing_dir):
88+
shutil.rmtree(local_landing_dir)
89+
os.makedirs(local_landing_dir, exist_ok=True)
90+
91+
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
92+
config_path = os.path.join(base_dir, config_file)
93+
94+
# Arguments for train_compile
95+
test_args = [
96+
None,
97+
config_path,
98+
"steps=1",
99+
"dataset_type=synthetic",
100+
"override_model_config=true",
101+
"compile_topology_num_slices=1",
102+
]
103+
104+
for key, value in overrides.items():
105+
test_args.append(f"{key}={value}")
106+
107+
test_args.append(
108+
f'compile_xla_flags="--xla_dump_to={local_landing_dir} '
109+
f'--xla_dump_hlo_as_text=True --xla_dump_hlo_module_re=jit_train_step"'
110+
)
111+
112+
print(f"Running train_compile for {test_id} with args:", test_args)
113+
114+
# Initialize flags if not already done to avoid ParseCommandLine error
115+
if not flags.FLAGS.is_parsed():
116+
flags.FLAGS(["dummy_program", "--grain_train_files=dummy"], known_only=True)
117+
118+
train_compile.main(tuple(test_args))
119+
120+
print(f"Files in landing dir for {test_id}:", os.listdir(local_landing_dir))
121+
122+
# Locate the dumped HLO file
123+
files = os.listdir(local_landing_dir)
124+
matches = [f for f in files if f.endswith(".after_optimizations.txt")]
125+
dumped_hlo = matches[0] if matches else None
126+
127+
assert dumped_hlo, f"Dumped HLO file not found for {test_id}!"
128+
129+
dumped_hlo_path = os.path.join(local_landing_dir, dumped_hlo)
130+
131+
reference_hlo_path = os.path.join(
132+
os.path.dirname(__file__), f"../utils/reference_hlo_{reference_suffix}.txt"
133+
)
134+
135+
if not os.path.exists(reference_hlo_path):
136+
print(f"Reference file not found. Creating it at {reference_hlo_path}")
137+
with open(dumped_hlo_path, "r", encoding="utf-8") as f_in, open(reference_hlo_path, "w", encoding="utf-8") as f_out:
138+
count = 0
139+
for line in f_in:
140+
line = filter_line(line)
141+
if line is None:
142+
continue
143+
f_out.write(line)
144+
count += 1
145+
if count >= MAX_LINES:
146+
break
147+
print(f"Reference file created for {test_id}. Please commit it.")
148+
else:
149+
assert self.check_files_equal_ignoring_sections(dumped_hlo_path, reference_hlo_path), (
150+
f"HLO deviation detected in {test_id}! If this is intended, please run the 'Update HLO Reference' workflow "
151+
f"from Github Actions against your branch to update the reference HLO."
152+
)
153+
print(f"HLO comparison successful against {reference_hlo_path}")
154+
155+
# Clean up
156+
shutil.rmtree(local_landing_dir)
157+
158+
159+
if __name__ == "__main__":
160+
unittest.main()

0 commit comments

Comments
 (0)