|
58 | 58 | to_edge, |
59 | 59 | to_edge_transform_and_lower, |
60 | 60 | ) |
| 61 | +from executorch.exir.capture._config import ExecutorchBackendConfig |
61 | 62 | from executorch.extension.pybindings.portable_lib import ( |
62 | 63 | _load_for_executorch_from_buffer, |
63 | 64 | ) |
64 | 65 | from torch.export import export, ExportedProgram |
65 | 66 |
|
66 | 67 |
|
| 68 | +# Models for testing inplace ops intermediate output logging |
| 69 | +class IndexPutModel(torch.nn.Module): |
| 70 | + """ |
| 71 | + A model that uses index_put to update a tensor at specific indices. |
| 72 | + When the reinplace_pass is enabled, this will be converted to index_put_ |
| 73 | + (the inplace variant), which was causing issues with event tracer logging. |
| 74 | + """ |
| 75 | + |
| 76 | + def __init__(self): |
| 77 | + super().__init__() |
| 78 | + self.register_buffer("data", torch.zeros(5, 3)) |
| 79 | + |
| 80 | + def forward(self, indices: torch.Tensor, values: torch.Tensor) -> torch.Tensor: |
| 81 | + result = self.data.index_put((indices,), values) |
| 82 | + return result.sum() |
| 83 | + |
| 84 | + |
67 | 85 | OP_TYPE = "aten::add" |
68 | 86 | EVENT_BLOCK_NAME = "block_0" |
69 | 87 | EVENTS_SIZE = 10 |
@@ -1993,3 +2011,238 @@ def _gen_random_events(self) -> List[Event]: |
1993 | 2011 | ) |
1994 | 2012 | ) |
1995 | 2013 | return events |
| 2014 | + |
| 2015 | + |
| 2016 | +class TestInplaceOpsIntermediateOutput(unittest.TestCase): |
| 2017 | + """ |
| 2018 | + Test suite for verifying that inplace operators correctly log intermediate |
| 2019 | + outputs when the event tracer is enabled. |
| 2020 | +
|
| 2021 | + This validates the fix for an issue where inplace ops converted by the |
| 2022 | + reinplace_pass could cause logging errors because the output tensor's data |
| 2023 | + pointer was null at the time of logging. |
| 2024 | +
|
| 2025 | + Note: The reinplace_pass currently only supports converting index_put to |
| 2026 | + index_put_ (see executorch/exir/passes/reinplace.py). |
| 2027 | + """ |
| 2028 | + |
| 2029 | + def _run_model_and_get_inspector( |
| 2030 | + self, |
| 2031 | + model: torch.nn.Module, |
| 2032 | + example_inputs: tuple, |
| 2033 | + run_reinplace_pass: bool = True, |
| 2034 | + ) -> Inspector: |
| 2035 | + """ |
| 2036 | + Helper method to export a model, run it with event tracing, and return |
| 2037 | + an Inspector instance for verifying intermediate outputs. |
| 2038 | + """ |
| 2039 | + model.eval() |
| 2040 | + |
| 2041 | + with tempfile.TemporaryDirectory() as tmp_dir: |
| 2042 | + model_path = os.path.join(tmp_dir, "model.pte") |
| 2043 | + etrecord_path = os.path.join(tmp_dir, "etrecord.bin") |
| 2044 | + etdump_path = os.path.join(tmp_dir, "etdump.etdp") |
| 2045 | + debug_buffer_path = os.path.join(tmp_dir, "debug_buffer.bin") |
| 2046 | + |
| 2047 | + # Step 1: Export the model |
| 2048 | + exported_program = export(model, example_inputs) |
| 2049 | + self.assertIsNotNone(exported_program) |
| 2050 | + |
| 2051 | + # Step 2: Convert to edge dialect |
| 2052 | + edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) |
| 2053 | + edge_program = to_edge(exported_program, compile_config=edge_compile_config) |
| 2054 | + self.assertIsNotNone(edge_program) |
| 2055 | + |
| 2056 | + # Keep a copy for etrecord |
| 2057 | + edge_program_copy = to_edge( |
| 2058 | + export(model, example_inputs), compile_config=edge_compile_config |
| 2059 | + ) |
| 2060 | + |
| 2061 | + # Step 3: Convert to executorch with reinplace_pass enabled |
| 2062 | + executorch_config = ExecutorchBackendConfig( |
| 2063 | + run_reinplace_pass=run_reinplace_pass |
| 2064 | + ) |
| 2065 | + executorch_program = edge_program.to_executorch(config=executorch_config) |
| 2066 | + self.assertIsNotNone(executorch_program) |
| 2067 | + |
| 2068 | + # Step 4: Generate ETRecord |
| 2069 | + generate_etrecord( |
| 2070 | + etrecord_path, |
| 2071 | + edge_program_copy, |
| 2072 | + executorch_program, |
| 2073 | + ) |
| 2074 | + |
| 2075 | + # Step 5: Save the PTE file |
| 2076 | + with open(model_path, "wb") as f: |
| 2077 | + executorch_program.write_to_file(f) |
| 2078 | + |
| 2079 | + # Step 6: Load and run with event tracing enabled |
| 2080 | + with open(model_path, "rb") as f: |
| 2081 | + pte_buffer = f.read() |
| 2082 | + |
| 2083 | + executorch_module = _load_for_executorch_from_buffer( |
| 2084 | + pte_buffer, |
| 2085 | + enable_etdump=True, |
| 2086 | + debug_buffer_size=1024 * 1024, # 1MB for testing |
| 2087 | + ) |
| 2088 | + self.assertIsNotNone(executorch_module) |
| 2089 | + |
| 2090 | + # Run the model |
| 2091 | + import torch.utils._pytree as pytree |
| 2092 | + |
| 2093 | + flattened_inputs = pytree.tree_flatten(example_inputs)[0] |
| 2094 | + executorch_module.run_method("forward", tuple(flattened_inputs)) |
| 2095 | + |
| 2096 | + # Write ETDump results |
| 2097 | + executorch_module.write_etdump_result_to_file( |
| 2098 | + etdump_path, debug_buffer_path |
| 2099 | + ) |
| 2100 | + |
| 2101 | + # Check if event tracer captured data |
| 2102 | + if not os.path.exists(etdump_path): |
| 2103 | + self.skipTest( |
| 2104 | + "Event tracer not enabled. Run with --config executorch.event_tracer_enabled=true" |
| 2105 | + ) |
| 2106 | + |
| 2107 | + # Step 7: Create Inspector and return |
| 2108 | + inspector = Inspector( |
| 2109 | + etdump_path=etdump_path, |
| 2110 | + etrecord=etrecord_path, |
| 2111 | + debug_buffer_path=debug_buffer_path, |
| 2112 | + ) |
| 2113 | + return inspector |
| 2114 | + |
| 2115 | + def test_index_put_without_reinplace_pass(self): |
| 2116 | + """ |
| 2117 | + Test that the model works correctly without the reinplace pass as a |
| 2118 | + baseline comparison, and verify intermediate output correctness. |
| 2119 | + """ |
| 2120 | + model = IndexPutModel() |
| 2121 | + indices = torch.tensor([0, 2, 4]) |
| 2122 | + values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) |
| 2123 | + example_inputs = (indices, values) |
| 2124 | + |
| 2125 | + # Compute expected intermediate output of index_put |
| 2126 | + # index_put on zeros(5,3) with indices [0,2,4] and values [[1,2,3],[4,5,6],[7,8,9]] |
| 2127 | + # Result should be: |
| 2128 | + # [[1, 2, 3], |
| 2129 | + # [0, 0, 0], |
| 2130 | + # [4, 5, 6], |
| 2131 | + # [0, 0, 0], |
| 2132 | + # [7, 8, 9]] |
| 2133 | + expected_index_put_output = torch.zeros(5, 3) |
| 2134 | + expected_index_put_output[0] = torch.tensor([1.0, 2.0, 3.0]) |
| 2135 | + expected_index_put_output[2] = torch.tensor([4.0, 5.0, 6.0]) |
| 2136 | + expected_index_put_output[4] = torch.tensor([7.0, 8.0, 9.0]) |
| 2137 | + |
| 2138 | + inspector = self._run_model_and_get_inspector( |
| 2139 | + model, example_inputs, run_reinplace_pass=False |
| 2140 | + ) |
| 2141 | + |
| 2142 | + self.assertIsNotNone(inspector) |
| 2143 | + self.assertGreater(len(inspector.event_blocks), 0) |
| 2144 | + |
| 2145 | + # Verify intermediate output correctness (same validation as with reinplace) |
| 2146 | + found_index_put_output = False |
| 2147 | + for event_block in inspector.event_blocks: |
| 2148 | + for event in event_block.events: |
| 2149 | + if hasattr(event, "debug_data") and event.debug_data is not None: |
| 2150 | + for debug_entry in event.debug_data: |
| 2151 | + if isinstance(debug_entry, torch.Tensor): |
| 2152 | + # Verify tensor has valid data pointer |
| 2153 | + self.assertIsNotNone( |
| 2154 | + debug_entry.data_ptr(), |
| 2155 | + "Intermediate output tensor should have valid data pointer", |
| 2156 | + ) |
| 2157 | + self.assertNotEqual( |
| 2158 | + debug_entry.data_ptr(), |
| 2159 | + 0, |
| 2160 | + "Intermediate output tensor data pointer should not be null", |
| 2161 | + ) |
| 2162 | + |
| 2163 | + # Check if this matches our expected index_put output shape |
| 2164 | + if debug_entry.shape == expected_index_put_output.shape: |
| 2165 | + if torch.allclose( |
| 2166 | + debug_entry, expected_index_put_output, atol=1e-5 |
| 2167 | + ): |
| 2168 | + found_index_put_output = True |
| 2169 | + |
| 2170 | + self.assertTrue( |
| 2171 | + found_index_put_output, |
| 2172 | + "Expected to find index_put intermediate output with correct tensor data (without reinplace pass).", |
| 2173 | + ) |
| 2174 | + |
| 2175 | + def test_index_put_intermediate_output_data_correctness(self): |
| 2176 | + """ |
| 2177 | + Test that the intermediate output values captured by the event tracer |
| 2178 | + are valid tensors with correct data. |
| 2179 | +
|
| 2180 | + This specifically validates that: |
| 2181 | + 1. The output tensor has a valid (non-null) data pointer |
| 2182 | + 2. The output tensor contains the correct values after index_put_ |
| 2183 | + """ |
| 2184 | + model = IndexPutModel() |
| 2185 | + # Use simple values to verify correctness |
| 2186 | + indices = torch.tensor([0, 1]) |
| 2187 | + values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) |
| 2188 | + example_inputs = (indices, values) |
| 2189 | + |
| 2190 | + # Compute expected intermediate output of index_put |
| 2191 | + # index_put on zeros(5,3) with indices [0,1] and values [[1,2,3],[4,5,6]] |
| 2192 | + # Result should be: |
| 2193 | + # [[1, 2, 3], |
| 2194 | + # [4, 5, 6], |
| 2195 | + # [0, 0, 0], |
| 2196 | + # [0, 0, 0], |
| 2197 | + # [0, 0, 0]] |
| 2198 | + expected_index_put_output = torch.zeros(5, 3) |
| 2199 | + expected_index_put_output[0] = torch.tensor([1.0, 2.0, 3.0]) |
| 2200 | + expected_index_put_output[1] = torch.tensor([4.0, 5.0, 6.0]) |
| 2201 | + |
| 2202 | + inspector = self._run_model_and_get_inspector( |
| 2203 | + model, example_inputs, run_reinplace_pass=True |
| 2204 | + ) |
| 2205 | + |
| 2206 | + self.assertIsNotNone(inspector) |
| 2207 | + self.assertGreater(len(inspector.event_blocks), 0) |
| 2208 | + |
| 2209 | + total_events = sum(len(eb.events) for eb in inspector.event_blocks) |
| 2210 | + self.assertGreater( |
| 2211 | + total_events, 0, "Expected at least one event to be captured" |
| 2212 | + ) |
| 2213 | + |
| 2214 | + # Find and verify the index_put_ output |
| 2215 | + found_index_put_output = False |
| 2216 | + for event_block in inspector.event_blocks: |
| 2217 | + for event in event_block.events: |
| 2218 | + # Check if this event has debug_data (intermediate outputs) |
| 2219 | + if hasattr(event, "debug_data") and event.debug_data is not None: |
| 2220 | + for debug_entry in event.debug_data: |
| 2221 | + if isinstance(debug_entry, torch.Tensor): |
| 2222 | + # Verify tensor has valid data pointer |
| 2223 | + self.assertIsNotNone( |
| 2224 | + debug_entry.data_ptr(), |
| 2225 | + "Intermediate output tensor should have valid data pointer", |
| 2226 | + ) |
| 2227 | + self.assertNotEqual( |
| 2228 | + debug_entry.data_ptr(), |
| 2229 | + 0, |
| 2230 | + "Intermediate output tensor data pointer should not be null", |
| 2231 | + ) |
| 2232 | + |
| 2233 | + # Check if this matches our expected index_put output shape |
| 2234 | + if debug_entry.shape == expected_index_put_output.shape: |
| 2235 | + # Verify the data is correct |
| 2236 | + if torch.allclose( |
| 2237 | + debug_entry, expected_index_put_output, atol=1e-5 |
| 2238 | + ): |
| 2239 | + found_index_put_output = True |
| 2240 | + |
| 2241 | + # Assert that we found the expected index_put output with correct data |
| 2242 | + # This validates that the intermediate output was properly logged |
| 2243 | + # and contains the correct tensor values |
| 2244 | + self.assertTrue( |
| 2245 | + found_index_put_output, |
| 2246 | + "Expected to find index_put intermediate output with correct tensor data. " |
| 2247 | + "The output tensor should match the expected result of index_put operation.", |
| 2248 | + ) |
0 commit comments