Skip to content

Commit 8f1b5ee

Browse files
authored
Fix flaky ReplaceTrivialConvWithLinear pass validation tolerance (#18482)
Differential Revision: D98001101 Pull Request resolved: #18482
1 parent 2a68e74 commit 8f1b5ee

1 file changed

Lines changed: 20 additions & 4 deletions

File tree

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,9 +1242,17 @@ def test_replace_conv1d_with_linear(self) -> None:
12421242
self.assertTrue(result.modified)
12431243
graph_after_passes = result.graph_module
12441244

1245-
# Validate numerical accuracy
1245+
# Conv and linear compute the same dot product but accumulate fp32
1246+
# terms in different order, so non-associativity of floating-point
1247+
# addition produces diffs up to ~1.2e-05. Use rtol=2e-05.
12461248
inputs = [x, weights, bias]
1247-
validate(gm_before, graph_after_passes, inputs, "ReplaceTrivialConvWithLinear")
1249+
validate(
1250+
gm_before,
1251+
graph_after_passes,
1252+
inputs,
1253+
"ReplaceTrivialConvWithLinear",
1254+
rtol=2e-5,
1255+
)
12481256

12491257
# Assert that conv1d is trivially converted to linear
12501258
self.assertEqual(
@@ -1278,9 +1286,17 @@ def test_replace_conv2d_with_linear(self) -> None:
12781286
self.assertTrue(result.modified)
12791287
graph_after_passes = result.graph_module
12801288

1281-
# Validate numerical accuracy
1289+
# Conv and linear compute the same dot product but accumulate fp32
1290+
# terms in different order, so non-associativity of floating-point
1291+
# addition produces diffs up to ~1.2e-05. Use rtol=2e-05.
12821292
inputs = [x, weights, bias]
1283-
validate(gm_before, graph_after_passes, inputs, "ReplaceTrivialConvWithLinear")
1293+
validate(
1294+
gm_before,
1295+
graph_after_passes,
1296+
inputs,
1297+
"ReplaceTrivialConvWithLinear",
1298+
rtol=2e-5,
1299+
)
12841300

12851301
# Assert that conv2d is trivially converted to linear
12861302
self.assertEqual(

0 commit comments

Comments
 (0)