Skip to content

Commit ac3003e

Browse files
authored
[cuda backend] replace floor_div with float_div (pytorch#20000)
After pin bump to pytorch 2.12, we noticed that `floor_div` with tensor as divisor [can not be correctly compiled by AOT Inductor,](pytorch/pytorch#186164) leading to cuda-backend-delegated model output irrevalant with input (e.g. gemma4-31b). To mitigate the issue, this PR replaces `floor_div` with `float_div` to support the models we need.
1 parent ff90ade commit ac3003e

4 files changed

Lines changed: 373 additions & 2 deletions

File tree

.github/workflows/cuda.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ jobs:
340340
name: "whisper-large-v3-turbo"
341341
quant: "non-quantized"
342342
with:
343-
timeout: 90
343+
timeout: 150
344344
secrets-env: EXECUTORCH_HF_TOKEN
345345
runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || matrix.model.name == 'gemma-4-31B-it-HQQ-INT4') && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }}
346346
gpu-arch-type: cuda

backends/cuda/cuda_backend.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
from executorch.backends.cuda.passes.move_cond_predicate_to_cpu import (
2020
MoveCondPredicateToCpuPass,
2121
)
22+
from executorch.backends.cuda.passes.replace_int64_floordiv import (
23+
ReplaceInt64FloorDivWithFloatPass,
24+
)
2225
from executorch.backends.cuda.triton.replacement_pass import (
2326
ReplaceEdgeOpWithTritonOpPass,
2427
)
@@ -257,7 +260,7 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]
257260
f"Expected 'ON' or 'OFF'."
258261
)
259262
triton_kernel_mode = mode
260-
passes = [MoveCondPredicateToCpuPass()]
263+
passes = [MoveCondPredicateToCpuPass(), ReplaceInt64FloorDivWithFloatPass()]
261264
if triton_kernel_mode == "ON":
262265
passes.append(ReplaceEdgeOpWithTritonOpPass())
263266
return passes
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Graph Transformation Pass for Integer Floor-Division Replacement.
9+
10+
Rewrites integer (int64/int32) floor-division into a float64-domain floor to
11+
work around a torch-2.12 AOTInductor/Inductor CUDA miscompile:
12+
13+
floor_divide(a, b) -> floor(a.to(float64) / b.to(float64)).to(orig_int_dtype)
14+
"""
15+
16+
import logging
17+
18+
import torch
19+
from executorch.exir.dialects._ops import ops as exir_ops
20+
from torch.fx import GraphModule, Node
21+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
22+
23+
logger = logging.getLogger(__name__)
24+
25+
# NOTE: Integer dtypes we rewrite. float64 (53-bit mantissa) is for
26+
# |value| < 2**53, which covers models' index ranges but not enough
27+
# for extreme large numbers.
28+
_INT_DTYPES = (torch.int64, torch.int32)
29+
30+
# Edge ops that perform a floor-rounded integer division.
31+
_FLOOR_DIVIDE_OP = exir_ops.edge.aten.floor_divide.default
32+
_DIV_MODE_OPS = (
33+
exir_ops.edge.aten.div.Tensor_mode,
34+
exir_ops.edge.aten.div.Scalar_mode,
35+
)
36+
37+
38+
class ReplaceInt64FloorDivWithFloatPass(PassBase):
39+
# Work around a torch-2.12 AOTInductor/Inductor CUDA miscompile of integer
40+
# (int64) floor-division: fused/broadcast int64 floor_divide is mis-lowered
41+
# (truncation instead of floor; cross-division term bleed under dynamic shapes).
42+
# TODO(gasoonjia): remove this pass once the upstream issue solved.
43+
# Upstream issue: https://github.com/pytorch/pytorch/issues/186164
44+
"""
45+
Pass to rewrite integer floor-division into a float64-domain floor.
46+
47+
Matches ``floor_divide.default`` and the floor-mode ``div.Tensor_mode`` /
48+
``div.Scalar_mode`` overloads on integer operands, and replaces each with
49+
``floor(a.to(float64) / b.to(float64)).to(orig_int_dtype)`` built from edge
50+
dialect ops. Float floor-division and non-integer nodes are left untouched.
51+
"""
52+
53+
def __init__(self):
54+
super().__init__()
55+
self._replacement_count = 0
56+
57+
def call(self, graph_module: GraphModule) -> PassResult:
58+
self._replacement_count = 0
59+
modified = False
60+
61+
for node in graph_module.graph.nodes:
62+
if not self._should_replace_node(node):
63+
continue
64+
try:
65+
self._replace_node(graph_module, node)
66+
modified = True
67+
self._replacement_count += 1
68+
except Exception as e:
69+
logger.warning(f"Failed to rewrite floor-div node {node.name}: {e}")
70+
# Continue with other nodes even if one fails.
71+
72+
if modified:
73+
graph_module.recompile()
74+
75+
logger.info(
76+
f"Rewrote {self._replacement_count} integer floor-division nodes "
77+
f"into float64-domain floor"
78+
)
79+
80+
return PassResult(graph_module, modified)
81+
82+
@staticmethod
83+
def _node_dtype(node: Node):
84+
val = node.meta.get("val", None)
85+
if isinstance(val, torch.Tensor):
86+
return val.dtype
87+
return None
88+
89+
@staticmethod
90+
def _rounding_mode(node: Node):
91+
if "rounding_mode" in node.kwargs:
92+
return node.kwargs["rounding_mode"]
93+
# Trailing positional arg: div(self, other, rounding_mode)
94+
if len(node.args) > 2:
95+
return node.args[2]
96+
return None
97+
98+
def _should_replace_node(self, node: Node) -> bool:
99+
if node.op != "call_function":
100+
return False
101+
102+
if node.target == _FLOOR_DIVIDE_OP:
103+
pass
104+
elif node.target in _DIV_MODE_OPS:
105+
if self._rounding_mode(node) != "floor":
106+
return False
107+
else:
108+
return False
109+
110+
# Only rewrite when the result is an integer tensor. Guard meta access:
111+
# a node may lack meta["val"]; skip conservatively if so.
112+
out_dtype = self._node_dtype(node)
113+
if out_dtype not in _INT_DTYPES:
114+
return False
115+
116+
return True
117+
118+
def _replace_node(self, graph_module: GraphModule, node: Node) -> None:
119+
orig_dtype = self._node_dtype(node)
120+
a = node.args[0]
121+
b = node.args[1]
122+
123+
graph = graph_module.graph
124+
with graph.inserting_before(node):
125+
a_f = graph.call_function(
126+
exir_ops.edge.aten._to_copy.default,
127+
args=(a,),
128+
kwargs={"dtype": torch.float64},
129+
)
130+
if isinstance(b, Node):
131+
b_f = graph.call_function(
132+
exir_ops.edge.aten._to_copy.default,
133+
args=(b,),
134+
kwargs={"dtype": torch.float64},
135+
)
136+
q = graph.call_function(exir_ops.edge.aten.div.Tensor, args=(a_f, b_f))
137+
else:
138+
# Python-scalar divisor: stays bit-exact, no cast needed for b.
139+
q = graph.call_function(
140+
exir_ops.edge.aten.div.Scalar, args=(a_f, float(b))
141+
)
142+
fl = graph.call_function(exir_ops.edge.aten.floor.default, args=(q,))
143+
new_node = graph.call_function(
144+
exir_ops.edge.aten._to_copy.default,
145+
args=(fl,),
146+
kwargs={"dtype": orig_dtype},
147+
)
148+
149+
new_node.meta = node.meta.copy()
150+
151+
node.replace_all_uses_with(new_node)
152+
graph.erase_node(node)
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from backends.cuda.passes.replace_int64_floordiv import (
11+
ReplaceInt64FloorDivWithFloatPass,
12+
)
13+
from executorch.exir import to_edge
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
from torch.export import export
16+
17+
18+
_INT_DIV_OPS = (
19+
exir_ops.edge.aten.floor_divide.default,
20+
exir_ops.edge.aten.div.Tensor_mode,
21+
exir_ops.edge.aten.div.Scalar_mode,
22+
)
23+
24+
25+
def _count_int_floordiv(graph_module) -> int:
26+
"""Count integer floor-division nodes remaining in the graph."""
27+
n = 0
28+
for node in graph_module.graph.nodes:
29+
if node.op != "call_function" or node.target not in _INT_DIV_OPS:
30+
continue
31+
if node.target in (
32+
exir_ops.edge.aten.div.Tensor_mode,
33+
exir_ops.edge.aten.div.Scalar_mode,
34+
):
35+
rmode = node.kwargs.get("rounding_mode", None)
36+
if rmode != "floor":
37+
continue
38+
val = node.meta.get("val", None)
39+
if isinstance(val, torch.Tensor) and val.dtype in (
40+
torch.int64,
41+
torch.int32,
42+
):
43+
n += 1
44+
return n
45+
46+
47+
class TestReplaceInt64FloorDivWithFloatPass(unittest.TestCase):
48+
"""Test the ReplaceInt64FloorDivWithFloatPass transformation pass."""
49+
50+
def _edge_gm(self, module, inputs):
51+
ep = to_edge(export(module, inputs, strict=True))
52+
return ep, ep.exported_program().graph_module
53+
54+
def test_tensor_tensor_floordiv_rewritten(self):
55+
"""int64 a // b (tensor/tensor), including negative numerators."""
56+
57+
class M(torch.nn.Module):
58+
def forward(self, a, b):
59+
return a // b
60+
61+
a = torch.tensor([-5, 7, -8, 9, -1, 0], dtype=torch.long)
62+
b = torch.tensor([2, 3, 4, 5, 3, 7], dtype=torch.long)
63+
ep, gm = self._edge_gm(M().eval(), (a, b))
64+
65+
self.assertGreater(_count_int_floordiv(gm), 0)
66+
ReplaceInt64FloorDivWithFloatPass()(gm)
67+
self.assertEqual(_count_int_floordiv(gm), 0)
68+
69+
out = ep.exported_program().module()(a, b)
70+
self.assertEqual(out.dtype, torch.int64)
71+
self.assertTrue(torch.equal(out, a // b))
72+
73+
def test_scalar_divisor_floordiv_rewritten(self):
74+
"""int64 a // 3 (scalar divisor lifted to a 0-d tensor constant)."""
75+
76+
class M(torch.nn.Module):
77+
def forward(self, a):
78+
return a // 3
79+
80+
a = torch.tensor([-5, 7, -8, 9, -1, 0], dtype=torch.long)
81+
ep, gm = self._edge_gm(M().eval(), (a,))
82+
83+
self.assertGreater(_count_int_floordiv(gm), 0)
84+
ReplaceInt64FloorDivWithFloatPass()(gm)
85+
self.assertEqual(_count_int_floordiv(gm), 0)
86+
87+
out = ep.exported_program().module()(a)
88+
self.assertTrue(torch.equal(out, a // 3))
89+
90+
def test_div_rounding_mode_floor_rewritten(self):
91+
"""torch.div(..., rounding_mode='floor') on int64 is rewritten."""
92+
93+
class M(torch.nn.Module):
94+
def forward(self, a, b):
95+
return torch.div(a, b, rounding_mode="floor")
96+
97+
a = torch.tensor([-5, 7, -8, 9], dtype=torch.long)
98+
b = torch.tensor([2, 3, 4, 5], dtype=torch.long)
99+
ep, gm = self._edge_gm(M().eval(), (a, b))
100+
101+
self.assertGreater(_count_int_floordiv(gm), 0)
102+
ReplaceInt64FloorDivWithFloatPass()(gm)
103+
self.assertEqual(_count_int_floordiv(gm), 0)
104+
105+
out = ep.exported_program().module()(a, b)
106+
self.assertTrue(torch.equal(out, torch.div(a, b, rounding_mode="floor")))
107+
108+
def test_int32_floordiv_rewritten(self):
109+
"""int32 floor-division is also rewritten and stays int32."""
110+
111+
class M(torch.nn.Module):
112+
def forward(self, a, b):
113+
return a // b
114+
115+
a = torch.tensor([-5, 7, -8, 9], dtype=torch.int32)
116+
b = torch.tensor([2, 3, 4, 5], dtype=torch.int32)
117+
ep, gm = self._edge_gm(M().eval(), (a, b))
118+
119+
self.assertGreater(_count_int_floordiv(gm), 0)
120+
ReplaceInt64FloorDivWithFloatPass()(gm)
121+
self.assertEqual(_count_int_floordiv(gm), 0)
122+
123+
out = ep.exported_program().module()(a, b)
124+
self.assertEqual(out.dtype, torch.int32)
125+
self.assertTrue(torch.equal(out, a // b))
126+
127+
def test_float_division_untouched(self):
128+
"""Real float division must not be rewritten."""
129+
130+
class M(torch.nn.Module):
131+
def forward(self, a, b):
132+
return a / b
133+
134+
a = torch.tensor([1.0, 2.0, 3.0])
135+
b = torch.tensor([2.0, 3.0, 4.0])
136+
ep, gm = self._edge_gm(M().eval(), (a, b))
137+
138+
before = [n.target for n in gm.graph.nodes if n.op == "call_function"]
139+
result = ReplaceInt64FloorDivWithFloatPass()(gm)
140+
self.assertFalse(result.modified)
141+
after = [n.target for n in gm.graph.nodes if n.op == "call_function"]
142+
self.assertEqual(before, after)
143+
144+
def test_trunc_rounding_mode_untouched(self):
145+
"""div with rounding_mode='trunc' must not be rewritten."""
146+
147+
class M(torch.nn.Module):
148+
def forward(self, a, b):
149+
return torch.div(a, b, rounding_mode="trunc")
150+
151+
a = torch.tensor([-5, 7, -8, 9], dtype=torch.long)
152+
b = torch.tensor([2, 3, 4, 5], dtype=torch.long)
153+
ep, gm = self._edge_gm(M().eval(), (a, b))
154+
155+
result = ReplaceInt64FloorDivWithFloatPass()(gm)
156+
self.assertFalse(result.modified)
157+
158+
def test_floor_divide_default_branch(self):
159+
"""Exercise the floor_divide.default match/rewrite branch.
160+
161+
This pin lowers ``//`` to ``div.Tensor_mode``; floor_divide.default does
162+
not appear naturally, so we synthesize it by retargeting a node.
163+
"""
164+
165+
class M(torch.nn.Module):
166+
def forward(self, a, b):
167+
return a // b
168+
169+
a = torch.tensor([-5, 7, -8, 9], dtype=torch.long)
170+
b = torch.tensor([2, 3, 4, 5], dtype=torch.long)
171+
ep, gm = self._edge_gm(M().eval(), (a, b))
172+
173+
# Retarget the div.Tensor_mode node to floor_divide.default.
174+
for node in list(gm.graph.nodes):
175+
if node.target == exir_ops.edge.aten.div.Tensor_mode:
176+
with gm.graph.inserting_before(node):
177+
new = gm.graph.call_function(
178+
exir_ops.edge.aten.floor_divide.default, args=node.args
179+
)
180+
new.meta = node.meta.copy()
181+
node.replace_all_uses_with(new)
182+
gm.graph.erase_node(node)
183+
gm.recompile()
184+
185+
self.assertGreater(_count_int_floordiv(gm), 0)
186+
ReplaceInt64FloorDivWithFloatPass()(gm)
187+
self.assertEqual(_count_int_floordiv(gm), 0)
188+
189+
out = ep.exported_program().module()(a, b)
190+
self.assertTrue(torch.equal(out, a // b))
191+
192+
def test_ring_buffer_mask_analog(self):
193+
"""gemma4_31b sliding-window analog: negative numerators + scalar divisor."""
194+
195+
class M(torch.nn.Module):
196+
def forward(self, input_pos):
197+
buf_size = 8
198+
seq_len = input_pos.shape[0]
199+
total_written = input_pos[0] + seq_len
200+
j = torch.arange(buf_size, dtype=torch.long)
201+
wraps = (total_written - 1 - j) // buf_size
202+
return j + wraps * buf_size
203+
204+
input_pos = torch.arange(3, dtype=torch.long)
205+
ep, gm = self._edge_gm(M().eval(), (input_pos,))
206+
207+
ReplaceInt64FloorDivWithFloatPass()(gm)
208+
self.assertEqual(_count_int_floordiv(gm), 0)
209+
210+
out = ep.exported_program().module()(input_pos)
211+
ref = M()(input_pos)
212+
self.assertTrue(torch.equal(out, ref))
213+
214+
215+
if __name__ == "__main__":
216+
unittest.main()

0 commit comments

Comments
 (0)