Skip to content

Commit 01d21fa

Browse files
authored
Support recording inplace op intermediate output
Differential Revision: D93646471 Pull Request resolved: #17796
1 parent 75f5a76 commit 01d21fa

4 files changed

Lines changed: 266 additions & 5 deletions

File tree

codegen/gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,8 @@ def __call__(
309309
internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_{f.func.name}");
310310
EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}");
311311
{ret_prefix}{kernel_call}(context, {args_str});
312-
{event_tracer_output_logging}
313312
{return_assignment}
313+
{event_tracer_output_logging}
314314
{exception_boundary_end}
315315
}}
316316
),

codegen/test/test_executorch_gen.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -516,9 +516,9 @@ def test_codegen_unboxed_specialized(self) -> None:
516516
internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_op_1");
517517
EXECUTORCH_SCOPE_PROF("native_call_op_1");
518518
bool result_ = at::native::default_kernel(context, );
519+
*stack[0] = EValue(result_);
519520
internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]);
520521
521-
*stack[0] = EValue(result_);
522522
523523
}
524524
),
@@ -615,9 +615,9 @@ def test_codegen_unboxed_default(self) -> None:
615615
internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_op_1");
616616
EXECUTORCH_SCOPE_PROF("native_call_op_1");
617617
bool result_ = at::native::default_kernel(context, );
618+
*stack[0] = EValue(result_);
618619
internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]);
619620
620-
*stack[0] = EValue(result_);
621621
622622
}
623623
),
@@ -642,9 +642,9 @@ def test_codegen_unboxed_default(self) -> None:
642642
internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_op_1");
643643
EXECUTORCH_SCOPE_PROF("native_call_op_1");
644644
bool result_ = at::native::default_kernel(context, );
645+
*stack[0] = EValue(result_);
645646
internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]);
646647
647-
*stack[0] = EValue(result_);
648648
} catch (const std::exception& ex) {
649649
ET_LOG(Error, "Kernel threw an exception: %s", ex.what());
650650
context.fail(torch::executor::Error::Internal);
@@ -686,9 +686,9 @@ def test_codegen_unboxed_default_kernel_key_selected(self) -> None:
686686
internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_op_1");
687687
EXECUTORCH_SCOPE_PROF("native_call_op_1");
688688
bool result_ = at::native::default_kernel(context, );
689+
*stack[0] = EValue(result_);
689690
internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]);
690691
691-
*stack[0] = EValue(result_);
692692
693693
}
694694
),

devtools/etdump/etdump_flatcc.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,14 @@ Result<long> ETDumpGen::write_tensor_or_return_error(Tensor tensor) {
714714
return static_cast<size_t>(-1);
715715
}
716716

717+
// A tensor with nbytes > 0 but null data pointer indicates a corrupt PTE
718+
// or a bug in the system. This should not happen in normal operation.
719+
ET_CHECK_OR_RETURN_ERROR(
720+
tensor.const_data_ptr() != nullptr,
721+
InvalidState,
722+
"Tensor has nbytes=%zu but null data pointer. This indicates a corrupt program or internal error.",
723+
tensor.nbytes());
724+
717725
if (!data_sink_) {
718726
return Error::InvalidArgument;
719727
}

devtools/inspector/tests/inspector_test.py

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,30 @@
5858
to_edge,
5959
to_edge_transform_and_lower,
6060
)
61+
from executorch.exir.capture._config import ExecutorchBackendConfig
6162
from executorch.extension.pybindings.portable_lib import (
6263
_load_for_executorch_from_buffer,
6364
)
6465
from torch.export import export, ExportedProgram
6566

6667

68+
# Models for testing inplace ops intermediate output logging
69+
class IndexPutModel(torch.nn.Module):
70+
"""
71+
A model that uses index_put to update a tensor at specific indices.
72+
When the reinplace_pass is enabled, this will be converted to index_put_
73+
(the inplace variant), which was causing issues with event tracer logging.
74+
"""
75+
76+
def __init__(self):
77+
super().__init__()
78+
self.register_buffer("data", torch.zeros(5, 3))
79+
80+
def forward(self, indices: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
81+
result = self.data.index_put((indices,), values)
82+
return result.sum()
83+
84+
6785
OP_TYPE = "aten::add"
6886
EVENT_BLOCK_NAME = "block_0"
6987
EVENTS_SIZE = 10
@@ -1993,3 +2011,238 @@ def _gen_random_events(self) -> List[Event]:
19932011
)
19942012
)
19952013
return events
2014+
2015+
2016+
class TestInplaceOpsIntermediateOutput(unittest.TestCase):
2017+
"""
2018+
Test suite for verifying that inplace operators correctly log intermediate
2019+
outputs when the event tracer is enabled.
2020+
2021+
This validates the fix for an issue where inplace ops converted by the
2022+
reinplace_pass could cause logging errors because the output tensor's data
2023+
pointer was null at the time of logging.
2024+
2025+
Note: The reinplace_pass currently only supports converting index_put to
2026+
index_put_ (see executorch/exir/passes/reinplace.py).
2027+
"""
2028+
2029+
def _run_model_and_get_inspector(
2030+
self,
2031+
model: torch.nn.Module,
2032+
example_inputs: tuple,
2033+
run_reinplace_pass: bool = True,
2034+
) -> Inspector:
2035+
"""
2036+
Helper method to export a model, run it with event tracing, and return
2037+
an Inspector instance for verifying intermediate outputs.
2038+
"""
2039+
model.eval()
2040+
2041+
with tempfile.TemporaryDirectory() as tmp_dir:
2042+
model_path = os.path.join(tmp_dir, "model.pte")
2043+
etrecord_path = os.path.join(tmp_dir, "etrecord.bin")
2044+
etdump_path = os.path.join(tmp_dir, "etdump.etdp")
2045+
debug_buffer_path = os.path.join(tmp_dir, "debug_buffer.bin")
2046+
2047+
# Step 1: Export the model
2048+
exported_program = export(model, example_inputs)
2049+
self.assertIsNotNone(exported_program)
2050+
2051+
# Step 2: Convert to edge dialect
2052+
edge_compile_config = EdgeCompileConfig(_check_ir_validity=False)
2053+
edge_program = to_edge(exported_program, compile_config=edge_compile_config)
2054+
self.assertIsNotNone(edge_program)
2055+
2056+
# Keep a copy for etrecord
2057+
edge_program_copy = to_edge(
2058+
export(model, example_inputs), compile_config=edge_compile_config
2059+
)
2060+
2061+
# Step 3: Convert to executorch with reinplace_pass enabled
2062+
executorch_config = ExecutorchBackendConfig(
2063+
run_reinplace_pass=run_reinplace_pass
2064+
)
2065+
executorch_program = edge_program.to_executorch(config=executorch_config)
2066+
self.assertIsNotNone(executorch_program)
2067+
2068+
# Step 4: Generate ETRecord
2069+
generate_etrecord(
2070+
etrecord_path,
2071+
edge_program_copy,
2072+
executorch_program,
2073+
)
2074+
2075+
# Step 5: Save the PTE file
2076+
with open(model_path, "wb") as f:
2077+
executorch_program.write_to_file(f)
2078+
2079+
# Step 6: Load and run with event tracing enabled
2080+
with open(model_path, "rb") as f:
2081+
pte_buffer = f.read()
2082+
2083+
executorch_module = _load_for_executorch_from_buffer(
2084+
pte_buffer,
2085+
enable_etdump=True,
2086+
debug_buffer_size=1024 * 1024, # 1MB for testing
2087+
)
2088+
self.assertIsNotNone(executorch_module)
2089+
2090+
# Run the model
2091+
import torch.utils._pytree as pytree
2092+
2093+
flattened_inputs = pytree.tree_flatten(example_inputs)[0]
2094+
executorch_module.run_method("forward", tuple(flattened_inputs))
2095+
2096+
# Write ETDump results
2097+
executorch_module.write_etdump_result_to_file(
2098+
etdump_path, debug_buffer_path
2099+
)
2100+
2101+
# Check if event tracer captured data
2102+
if not os.path.exists(etdump_path):
2103+
self.skipTest(
2104+
"Event tracer not enabled. Run with --config executorch.event_tracer_enabled=true"
2105+
)
2106+
2107+
# Step 7: Create Inspector and return
2108+
inspector = Inspector(
2109+
etdump_path=etdump_path,
2110+
etrecord=etrecord_path,
2111+
debug_buffer_path=debug_buffer_path,
2112+
)
2113+
return inspector
2114+
2115+
def test_index_put_without_reinplace_pass(self):
2116+
"""
2117+
Test that the model works correctly without the reinplace pass as a
2118+
baseline comparison, and verify intermediate output correctness.
2119+
"""
2120+
model = IndexPutModel()
2121+
indices = torch.tensor([0, 2, 4])
2122+
values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
2123+
example_inputs = (indices, values)
2124+
2125+
# Compute expected intermediate output of index_put
2126+
# index_put on zeros(5,3) with indices [0,2,4] and values [[1,2,3],[4,5,6],[7,8,9]]
2127+
# Result should be:
2128+
# [[1, 2, 3],
2129+
# [0, 0, 0],
2130+
# [4, 5, 6],
2131+
# [0, 0, 0],
2132+
# [7, 8, 9]]
2133+
expected_index_put_output = torch.zeros(5, 3)
2134+
expected_index_put_output[0] = torch.tensor([1.0, 2.0, 3.0])
2135+
expected_index_put_output[2] = torch.tensor([4.0, 5.0, 6.0])
2136+
expected_index_put_output[4] = torch.tensor([7.0, 8.0, 9.0])
2137+
2138+
inspector = self._run_model_and_get_inspector(
2139+
model, example_inputs, run_reinplace_pass=False
2140+
)
2141+
2142+
self.assertIsNotNone(inspector)
2143+
self.assertGreater(len(inspector.event_blocks), 0)
2144+
2145+
# Verify intermediate output correctness (same validation as with reinplace)
2146+
found_index_put_output = False
2147+
for event_block in inspector.event_blocks:
2148+
for event in event_block.events:
2149+
if hasattr(event, "debug_data") and event.debug_data is not None:
2150+
for debug_entry in event.debug_data:
2151+
if isinstance(debug_entry, torch.Tensor):
2152+
# Verify tensor has valid data pointer
2153+
self.assertIsNotNone(
2154+
debug_entry.data_ptr(),
2155+
"Intermediate output tensor should have valid data pointer",
2156+
)
2157+
self.assertNotEqual(
2158+
debug_entry.data_ptr(),
2159+
0,
2160+
"Intermediate output tensor data pointer should not be null",
2161+
)
2162+
2163+
# Check if this matches our expected index_put output shape
2164+
if debug_entry.shape == expected_index_put_output.shape:
2165+
if torch.allclose(
2166+
debug_entry, expected_index_put_output, atol=1e-5
2167+
):
2168+
found_index_put_output = True
2169+
2170+
self.assertTrue(
2171+
found_index_put_output,
2172+
"Expected to find index_put intermediate output with correct tensor data (without reinplace pass).",
2173+
)
2174+
2175+
def test_index_put_intermediate_output_data_correctness(self):
2176+
"""
2177+
Test that the intermediate output values captured by the event tracer
2178+
are valid tensors with correct data.
2179+
2180+
This specifically validates that:
2181+
1. The output tensor has a valid (non-null) data pointer
2182+
2. The output tensor contains the correct values after index_put_
2183+
"""
2184+
model = IndexPutModel()
2185+
# Use simple values to verify correctness
2186+
indices = torch.tensor([0, 1])
2187+
values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
2188+
example_inputs = (indices, values)
2189+
2190+
# Compute expected intermediate output of index_put
2191+
# index_put on zeros(5,3) with indices [0,1] and values [[1,2,3],[4,5,6]]
2192+
# Result should be:
2193+
# [[1, 2, 3],
2194+
# [4, 5, 6],
2195+
# [0, 0, 0],
2196+
# [0, 0, 0],
2197+
# [0, 0, 0]]
2198+
expected_index_put_output = torch.zeros(5, 3)
2199+
expected_index_put_output[0] = torch.tensor([1.0, 2.0, 3.0])
2200+
expected_index_put_output[1] = torch.tensor([4.0, 5.0, 6.0])
2201+
2202+
inspector = self._run_model_and_get_inspector(
2203+
model, example_inputs, run_reinplace_pass=True
2204+
)
2205+
2206+
self.assertIsNotNone(inspector)
2207+
self.assertGreater(len(inspector.event_blocks), 0)
2208+
2209+
total_events = sum(len(eb.events) for eb in inspector.event_blocks)
2210+
self.assertGreater(
2211+
total_events, 0, "Expected at least one event to be captured"
2212+
)
2213+
2214+
# Find and verify the index_put_ output
2215+
found_index_put_output = False
2216+
for event_block in inspector.event_blocks:
2217+
for event in event_block.events:
2218+
# Check if this event has debug_data (intermediate outputs)
2219+
if hasattr(event, "debug_data") and event.debug_data is not None:
2220+
for debug_entry in event.debug_data:
2221+
if isinstance(debug_entry, torch.Tensor):
2222+
# Verify tensor has valid data pointer
2223+
self.assertIsNotNone(
2224+
debug_entry.data_ptr(),
2225+
"Intermediate output tensor should have valid data pointer",
2226+
)
2227+
self.assertNotEqual(
2228+
debug_entry.data_ptr(),
2229+
0,
2230+
"Intermediate output tensor data pointer should not be null",
2231+
)
2232+
2233+
# Check if this matches our expected index_put output shape
2234+
if debug_entry.shape == expected_index_put_output.shape:
2235+
# Verify the data is correct
2236+
if torch.allclose(
2237+
debug_entry, expected_index_put_output, atol=1e-5
2238+
):
2239+
found_index_put_output = True
2240+
2241+
# Assert that we found the expected index_put output with correct data
2242+
# This validates that the intermediate output was properly logged
2243+
# and contains the correct tensor values
2244+
self.assertTrue(
2245+
found_index_put_output,
2246+
"Expected to find index_put intermediate output with correct tensor data. "
2247+
"The output tensor should match the expected result of index_put operation.",
2248+
)

0 commit comments

Comments
 (0)