|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import unittest |
| 8 | + |
| 9 | +import torch |
| 10 | +from backends.cuda.passes.replace_int64_floordiv import ( |
| 11 | + ReplaceInt64FloorDivWithFloatPass, |
| 12 | +) |
| 13 | +from executorch.exir import to_edge |
| 14 | +from executorch.exir.dialects._ops import ops as exir_ops |
| 15 | +from torch.export import export |
| 16 | + |
| 17 | + |
| 18 | +_INT_DIV_OPS = ( |
| 19 | + exir_ops.edge.aten.floor_divide.default, |
| 20 | + exir_ops.edge.aten.div.Tensor_mode, |
| 21 | + exir_ops.edge.aten.div.Scalar_mode, |
| 22 | +) |
| 23 | + |
| 24 | + |
| 25 | +def _count_int_floordiv(graph_module) -> int: |
| 26 | + """Count integer floor-division nodes remaining in the graph.""" |
| 27 | + n = 0 |
| 28 | + for node in graph_module.graph.nodes: |
| 29 | + if node.op != "call_function" or node.target not in _INT_DIV_OPS: |
| 30 | + continue |
| 31 | + if node.target in ( |
| 32 | + exir_ops.edge.aten.div.Tensor_mode, |
| 33 | + exir_ops.edge.aten.div.Scalar_mode, |
| 34 | + ): |
| 35 | + rmode = node.kwargs.get("rounding_mode", None) |
| 36 | + if rmode != "floor": |
| 37 | + continue |
| 38 | + val = node.meta.get("val", None) |
| 39 | + if isinstance(val, torch.Tensor) and val.dtype in ( |
| 40 | + torch.int64, |
| 41 | + torch.int32, |
| 42 | + ): |
| 43 | + n += 1 |
| 44 | + return n |
| 45 | + |
| 46 | + |
| 47 | +class TestReplaceInt64FloorDivWithFloatPass(unittest.TestCase): |
| 48 | + """Test the ReplaceInt64FloorDivWithFloatPass transformation pass.""" |
| 49 | + |
| 50 | + def _edge_gm(self, module, inputs): |
| 51 | + ep = to_edge(export(module, inputs, strict=True)) |
| 52 | + return ep, ep.exported_program().graph_module |
| 53 | + |
| 54 | + def test_tensor_tensor_floordiv_rewritten(self): |
| 55 | + """int64 a // b (tensor/tensor), including negative numerators.""" |
| 56 | + |
| 57 | + class M(torch.nn.Module): |
| 58 | + def forward(self, a, b): |
| 59 | + return a // b |
| 60 | + |
| 61 | + a = torch.tensor([-5, 7, -8, 9, -1, 0], dtype=torch.long) |
| 62 | + b = torch.tensor([2, 3, 4, 5, 3, 7], dtype=torch.long) |
| 63 | + ep, gm = self._edge_gm(M().eval(), (a, b)) |
| 64 | + |
| 65 | + self.assertGreater(_count_int_floordiv(gm), 0) |
| 66 | + ReplaceInt64FloorDivWithFloatPass()(gm) |
| 67 | + self.assertEqual(_count_int_floordiv(gm), 0) |
| 68 | + |
| 69 | + out = ep.exported_program().module()(a, b) |
| 70 | + self.assertEqual(out.dtype, torch.int64) |
| 71 | + self.assertTrue(torch.equal(out, a // b)) |
| 72 | + |
| 73 | + def test_scalar_divisor_floordiv_rewritten(self): |
| 74 | + """int64 a // 3 (scalar divisor lifted to a 0-d tensor constant).""" |
| 75 | + |
| 76 | + class M(torch.nn.Module): |
| 77 | + def forward(self, a): |
| 78 | + return a // 3 |
| 79 | + |
| 80 | + a = torch.tensor([-5, 7, -8, 9, -1, 0], dtype=torch.long) |
| 81 | + ep, gm = self._edge_gm(M().eval(), (a,)) |
| 82 | + |
| 83 | + self.assertGreater(_count_int_floordiv(gm), 0) |
| 84 | + ReplaceInt64FloorDivWithFloatPass()(gm) |
| 85 | + self.assertEqual(_count_int_floordiv(gm), 0) |
| 86 | + |
| 87 | + out = ep.exported_program().module()(a) |
| 88 | + self.assertTrue(torch.equal(out, a // 3)) |
| 89 | + |
| 90 | + def test_div_rounding_mode_floor_rewritten(self): |
| 91 | + """torch.div(..., rounding_mode='floor') on int64 is rewritten.""" |
| 92 | + |
| 93 | + class M(torch.nn.Module): |
| 94 | + def forward(self, a, b): |
| 95 | + return torch.div(a, b, rounding_mode="floor") |
| 96 | + |
| 97 | + a = torch.tensor([-5, 7, -8, 9], dtype=torch.long) |
| 98 | + b = torch.tensor([2, 3, 4, 5], dtype=torch.long) |
| 99 | + ep, gm = self._edge_gm(M().eval(), (a, b)) |
| 100 | + |
| 101 | + self.assertGreater(_count_int_floordiv(gm), 0) |
| 102 | + ReplaceInt64FloorDivWithFloatPass()(gm) |
| 103 | + self.assertEqual(_count_int_floordiv(gm), 0) |
| 104 | + |
| 105 | + out = ep.exported_program().module()(a, b) |
| 106 | + self.assertTrue(torch.equal(out, torch.div(a, b, rounding_mode="floor"))) |
| 107 | + |
| 108 | + def test_int32_floordiv_rewritten(self): |
| 109 | + """int32 floor-division is also rewritten and stays int32.""" |
| 110 | + |
| 111 | + class M(torch.nn.Module): |
| 112 | + def forward(self, a, b): |
| 113 | + return a // b |
| 114 | + |
| 115 | + a = torch.tensor([-5, 7, -8, 9], dtype=torch.int32) |
| 116 | + b = torch.tensor([2, 3, 4, 5], dtype=torch.int32) |
| 117 | + ep, gm = self._edge_gm(M().eval(), (a, b)) |
| 118 | + |
| 119 | + self.assertGreater(_count_int_floordiv(gm), 0) |
| 120 | + ReplaceInt64FloorDivWithFloatPass()(gm) |
| 121 | + self.assertEqual(_count_int_floordiv(gm), 0) |
| 122 | + |
| 123 | + out = ep.exported_program().module()(a, b) |
| 124 | + self.assertEqual(out.dtype, torch.int32) |
| 125 | + self.assertTrue(torch.equal(out, a // b)) |
| 126 | + |
| 127 | + def test_float_division_untouched(self): |
| 128 | + """Real float division must not be rewritten.""" |
| 129 | + |
| 130 | + class M(torch.nn.Module): |
| 131 | + def forward(self, a, b): |
| 132 | + return a / b |
| 133 | + |
| 134 | + a = torch.tensor([1.0, 2.0, 3.0]) |
| 135 | + b = torch.tensor([2.0, 3.0, 4.0]) |
| 136 | + ep, gm = self._edge_gm(M().eval(), (a, b)) |
| 137 | + |
| 138 | + before = [n.target for n in gm.graph.nodes if n.op == "call_function"] |
| 139 | + result = ReplaceInt64FloorDivWithFloatPass()(gm) |
| 140 | + self.assertFalse(result.modified) |
| 141 | + after = [n.target for n in gm.graph.nodes if n.op == "call_function"] |
| 142 | + self.assertEqual(before, after) |
| 143 | + |
| 144 | + def test_trunc_rounding_mode_untouched(self): |
| 145 | + """div with rounding_mode='trunc' must not be rewritten.""" |
| 146 | + |
| 147 | + class M(torch.nn.Module): |
| 148 | + def forward(self, a, b): |
| 149 | + return torch.div(a, b, rounding_mode="trunc") |
| 150 | + |
| 151 | + a = torch.tensor([-5, 7, -8, 9], dtype=torch.long) |
| 152 | + b = torch.tensor([2, 3, 4, 5], dtype=torch.long) |
| 153 | + ep, gm = self._edge_gm(M().eval(), (a, b)) |
| 154 | + |
| 155 | + result = ReplaceInt64FloorDivWithFloatPass()(gm) |
| 156 | + self.assertFalse(result.modified) |
| 157 | + |
| 158 | + def test_floor_divide_default_branch(self): |
| 159 | + """Exercise the floor_divide.default match/rewrite branch. |
| 160 | +
|
| 161 | + This pin lowers ``//`` to ``div.Tensor_mode``; floor_divide.default does |
| 162 | + not appear naturally, so we synthesize it by retargeting a node. |
| 163 | + """ |
| 164 | + |
| 165 | + class M(torch.nn.Module): |
| 166 | + def forward(self, a, b): |
| 167 | + return a // b |
| 168 | + |
| 169 | + a = torch.tensor([-5, 7, -8, 9], dtype=torch.long) |
| 170 | + b = torch.tensor([2, 3, 4, 5], dtype=torch.long) |
| 171 | + ep, gm = self._edge_gm(M().eval(), (a, b)) |
| 172 | + |
| 173 | + # Retarget the div.Tensor_mode node to floor_divide.default. |
| 174 | + for node in list(gm.graph.nodes): |
| 175 | + if node.target == exir_ops.edge.aten.div.Tensor_mode: |
| 176 | + with gm.graph.inserting_before(node): |
| 177 | + new = gm.graph.call_function( |
| 178 | + exir_ops.edge.aten.floor_divide.default, args=node.args |
| 179 | + ) |
| 180 | + new.meta = node.meta.copy() |
| 181 | + node.replace_all_uses_with(new) |
| 182 | + gm.graph.erase_node(node) |
| 183 | + gm.recompile() |
| 184 | + |
| 185 | + self.assertGreater(_count_int_floordiv(gm), 0) |
| 186 | + ReplaceInt64FloorDivWithFloatPass()(gm) |
| 187 | + self.assertEqual(_count_int_floordiv(gm), 0) |
| 188 | + |
| 189 | + out = ep.exported_program().module()(a, b) |
| 190 | + self.assertTrue(torch.equal(out, a // b)) |
| 191 | + |
| 192 | + def test_ring_buffer_mask_analog(self): |
| 193 | + """gemma4_31b sliding-window analog: negative numerators + scalar divisor.""" |
| 194 | + |
| 195 | + class M(torch.nn.Module): |
| 196 | + def forward(self, input_pos): |
| 197 | + buf_size = 8 |
| 198 | + seq_len = input_pos.shape[0] |
| 199 | + total_written = input_pos[0] + seq_len |
| 200 | + j = torch.arange(buf_size, dtype=torch.long) |
| 201 | + wraps = (total_written - 1 - j) // buf_size |
| 202 | + return j + wraps * buf_size |
| 203 | + |
| 204 | + input_pos = torch.arange(3, dtype=torch.long) |
| 205 | + ep, gm = self._edge_gm(M().eval(), (input_pos,)) |
| 206 | + |
| 207 | + ReplaceInt64FloorDivWithFloatPass()(gm) |
| 208 | + self.assertEqual(_count_int_floordiv(gm), 0) |
| 209 | + |
| 210 | + out = ep.exported_program().module()(input_pos) |
| 211 | + ref = M()(input_pos) |
| 212 | + self.assertTrue(torch.equal(out, ref)) |
| 213 | + |
| 214 | + |
| 215 | +if __name__ == "__main__": |
| 216 | + unittest.main() |
0 commit comments