@@ -1242,9 +1242,17 @@ def test_replace_conv1d_with_linear(self) -> None:
12421242 self .assertTrue (result .modified )
12431243 graph_after_passes = result .graph_module
12441244
1245- # Validate numerical accuracy
1245+ # Conv and linear compute the same dot product but accumulate fp32
1246+ # terms in different order, so non-associativity of floating-point
1247+ # addition produces diffs up to ~1.2e-05. Use rtol=2e-05.
12461248 inputs = [x , weights , bias ]
1247- validate (gm_before , graph_after_passes , inputs , "ReplaceTrivialConvWithLinear" )
1249+ validate (
1250+ gm_before ,
1251+ graph_after_passes ,
1252+ inputs ,
1253+ "ReplaceTrivialConvWithLinear" ,
1254+ rtol = 2e-5 ,
1255+ )
12481256
12491257 # Assert that conv1d is trivially converted to linear
12501258 self .assertEqual (
@@ -1278,9 +1286,17 @@ def test_replace_conv2d_with_linear(self) -> None:
12781286 self .assertTrue (result .modified )
12791287 graph_after_passes = result .graph_module
12801288
1281- # Validate numerical accuracy
1289+ # Conv and linear compute the same dot product but accumulate fp32
1290+ # terms in different order, so non-associativity of floating-point
1291+ # addition produces diffs up to ~1.2e-05. Use rtol=2e-05.
12821292 inputs = [x , weights , bias ]
1283- validate (gm_before , graph_after_passes , inputs , "ReplaceTrivialConvWithLinear" )
1293+ validate (
1294+ gm_before ,
1295+ graph_after_passes ,
1296+ inputs ,
1297+ "ReplaceTrivialConvWithLinear" ,
1298+ rtol = 2e-5 ,
1299+ )
12841300
12851301 # Assert that conv2d is trivially converted to linear
12861302 self .assertEqual (
0 commit comments