Skip to content

Commit 803e47d

Browse files
authored
Generic numeric debugging (#19317) (#19317) (#19317)
Summary: Pull Request resolved: #19317 Pulled By: metascroy metascroy Differential Revision: D103956056
1 parent 09a7cbe commit 803e47d

16 files changed

Lines changed: 2864 additions & 0 deletions
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
oncall("executorch")
4+
5+
runtime.python_library(
6+
name = "spec",
7+
srcs = ["_spec.py"],
8+
)
9+
10+
runtime.python_library(
11+
name = "custom_ops_lib",
12+
srcs = ["custom_ops_lib.py"],
13+
deps = [
14+
"//caffe2:torch",
15+
],
16+
)
17+
18+
runtime.python_library(
19+
name = "selectors",
20+
srcs = ["_selectors.py"],
21+
deps = [
22+
"//caffe2:torch",
23+
],
24+
)
25+
26+
runtime.python_library(
27+
name = "reducers",
28+
srcs = ["_reducers.py"],
29+
deps = [
30+
"//caffe2:torch",
31+
"//executorch/exir/dialects:lib",
32+
],
33+
)
34+
35+
runtime.python_library(
36+
name = "tap_pass",
37+
srcs = ["_tap_pass.py"],
38+
deps = [
39+
"//caffe2:torch",
40+
":custom_ops_lib",
41+
":reducers",
42+
":selectors",
43+
":spec",
44+
],
45+
)
46+
47+
runtime.python_library(
48+
name = "strip_pass",
49+
srcs = ["_strip_pass.py"],
50+
deps = [
51+
"//caffe2:torch",
52+
":reducers",
53+
":tap_pass",
54+
],
55+
)
56+
57+
runtime.python_library(
58+
name = "convenience",
59+
srcs = ["_convenience.py"],
60+
deps = [
61+
"fbsource//third-party/pypi/pandas:pandas",
62+
"//caffe2:torch",
63+
"//executorch/exir:lib",
64+
"//executorch/runtime:runtime",
65+
":reducers",
66+
":selectors",
67+
":spec",
68+
":strip_pass",
69+
":tap_pass",
70+
],
71+
)
72+
73+
runtime.python_library(
74+
name = "lib",
75+
srcs = ["__init__.py"],
76+
deps = [
77+
":convenience",
78+
":custom_ops_lib",
79+
":reducers",
80+
":selectors",
81+
":spec",
82+
":strip_pass",
83+
":tap_pass",
84+
],
85+
)
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
"""
10+
Public API for the ExecuTorch numerical debugger.
11+
12+
Backend-agnostic intermediate-value tap:
13+
14+
- Runtime side : USER_OUTPUT taps (this module — works through delegates without
15+
any backend-side changes)
16+
17+
Typical usage:
18+
19+
from executorch.devtools.intermediate_output_tap import (
20+
compare_aot_runtime_dataframe,
21+
tap_intermediate_outputs, strip_taps_, STATS,
22+
)
23+
24+
ep = export(model, example_inputs)
25+
ep_tapped, specs = tap_intermediate_outputs(ep, reducer=STATS)
26+
aot_flat, _ = pytree.tree_flatten(ep_tapped.module()(*example_inputs))
27+
edge = to_edge_transform_and_lower(ep_tapped, partitioner=[XnnpackPartitioner()])
28+
strip_taps_(edge)
29+
et_program = edge.to_executorch()
30+
31+
rt_flat = runtime.forward(*example_inputs)
32+
df = compare_aot_runtime_dataframe(specs, aot_flat, rt_flat)
33+
"""
34+
35+
# Importing this module registers torch.ops.executorch_devtools.tap.Tensor.
36+
from executorch.devtools.intermediate_output_tap import custom_ops_lib # noqa: F401
37+
from executorch.devtools.intermediate_output_tap._convenience import (
38+
compare_aot_runtime_dataframe,
39+
tap_compare,
40+
)
41+
from executorch.devtools.intermediate_output_tap._reducers import (
42+
FULL_TENSOR,
43+
get_reducer,
44+
StatReducer,
45+
STATS,
46+
)
47+
from executorch.devtools.intermediate_output_tap._selectors import (
48+
NodeSelector,
49+
select_all,
50+
select_all_call_function,
51+
select_any,
52+
select_by_module_class,
53+
select_by_module_path,
54+
select_by_op_type,
55+
select_not,
56+
)
57+
from executorch.devtools.intermediate_output_tap._spec import TapSpec
58+
from executorch.devtools.intermediate_output_tap._strip_pass import strip_taps_
59+
from executorch.devtools.intermediate_output_tap._tap_pass import (
60+
find_tap_nodes,
61+
is_tap_node,
62+
tap_intermediate_outputs_,
63+
TapRule,
64+
)
65+
66+
67+
__all__ = [
68+
# Core API
69+
"tap_intermediate_outputs_",
70+
"strip_taps_",
71+
"TapSpec",
72+
"TapRule",
73+
# Convenience
74+
"tap_compare",
75+
"compare_aot_runtime_dataframe",
76+
# Reducers
77+
"StatReducer",
78+
"FULL_TENSOR",
79+
"STATS",
80+
"get_reducer",
81+
# Selectors
82+
"NodeSelector",
83+
"select_all_call_function",
84+
"select_by_op_type",
85+
"select_by_module_path",
86+
"select_by_module_class",
87+
"select_any",
88+
"select_all",
89+
"select_not",
90+
# Helpers
91+
"find_tap_nodes",
92+
"is_tap_node",
93+
]
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
"""
10+
Convenience helpers built on top of `tap_intermediate_outputs_` / `strip_taps_`:
11+
12+
* `tap_compare`: one-shot helper that exports a model, taps it, lowers with
13+
the user's partitioner, runs through the ExecuTorch runtime, and returns
14+
the AOT-vs-runtime comparison DataFrame plus the tap specs. The simplest
15+
way to use the intermediate-output tap.
16+
* `specs_to_dataframe`: build a per-tap DataFrame from a tap_specs list and
17+
the runtime's flat output tuple.
18+
* `compare_aot_runtime_dataframe`: side-by-side AOT-vs-runtime DataFrame from
19+
the flat outputs of the *tapped* ExportedProgram (eager) and the post-strip
20+
runtime program.
21+
"""
22+
23+
from __future__ import annotations
24+
25+
import os
26+
import tempfile
27+
from collections.abc import Sequence
28+
from typing import Any
29+
30+
import pandas as pd
31+
import torch
32+
import torch.utils._pytree as pytree
33+
from executorch.devtools.intermediate_output_tap._spec import TapSpec
34+
from executorch.devtools.intermediate_output_tap._strip_pass import strip_taps_
35+
from executorch.devtools.intermediate_output_tap._tap_pass import (
36+
tap_intermediate_outputs_,
37+
TapRule,
38+
)
39+
40+
41+
def tap_compare(
42+
model: torch.nn.Module,
43+
example_inputs: tuple[Any, ...],
44+
partitioner: list | None = None,
45+
*,
46+
rules: Sequence[TapRule] | TapRule | None = None,
47+
error_on_empty: bool = True,
48+
) -> tuple[pd.DataFrame, list[TapSpec]]:
49+
"""
50+
One-shot AOT-vs-runtime numerical-debugging helper.
51+
52+
Runs the full pipeline: export -> tap -> capture AOT reference values
53+
-> lower with `partitioner` -> strip -> to_executorch -> runtime
54+
-> AOT-vs-runtime DataFrame.
55+
56+
Args:
57+
model: Eager nn.Module to debug.
58+
example_inputs: Positional args to the model's forward.
59+
partitioner: Optional list of partitioners passed to
60+
`to_edge_transform_and_lower`. Defaults to `[]` (no delegation).
61+
rules: Same semantics as `tap_intermediate_outputs_` — a sequence of
62+
`(selector, reducer)` pairs (or a single tuple as sugar).
63+
Defaults to `[(select_all_call_function(), STATS)]`.
64+
error_on_empty: Same semantics as `tap_intermediate_outputs_`.
65+
66+
Returns:
67+
A `(df, specs)` tuple where:
68+
- `df`: side-by-side AOT-vs-runtime DataFrame from
69+
`compare_aot_runtime_dataframe`.
70+
- `specs`: list of `TapSpec`s in tap-creation order.
71+
"""
72+
from executorch.exir import to_edge_transform_and_lower
73+
74+
ep = torch.export.export(model, example_inputs, strict=True)
75+
ep_t, specs = tap_intermediate_outputs_(
76+
ep,
77+
rules=rules,
78+
error_on_empty=error_on_empty,
79+
)
80+
81+
# AOT-side reference values: tap.Tensor's eager impl applies the reducer,
82+
# so the flat outputs of the tapped EP already contain reduced values at
83+
# the same positions the runtime will use.
84+
aot_out = ep_t.module()(*example_inputs)
85+
aot_flat, _ = pytree.tree_flatten(aot_out)
86+
87+
edge = to_edge_transform_and_lower(ep_t, partitioner=partitioner or [])
88+
strip_taps_(edge)
89+
et_program = edge.to_executorch()
90+
91+
flat_inputs, _ = pytree.tree_flatten(example_inputs)
92+
rt_flat = list(_run_pte(et_program, flat_inputs))
93+
94+
df = compare_aot_runtime_dataframe(specs, aot_flat, rt_flat)
95+
return df, specs
96+
97+
98+
def _run_pte(et_program, example_inputs: tuple[Any, ...]) -> Sequence[Any]:
99+
from executorch.runtime import Runtime, Verification
100+
101+
with tempfile.TemporaryDirectory() as temp_dir:
102+
pte_path = os.path.join(temp_dir, "model.pte")
103+
et_program.save(pte_path)
104+
rt = Runtime.get()
105+
program = rt.load_program(pte_path, verification=Verification.Minimal)
106+
method = program.load_method("forward")
107+
return method.execute(example_inputs)
108+
109+
110+
def _flat_floats(v: Any) -> list[float]:
111+
"""Flatten a tap value (tensor / list / scalar) to a flat list of floats."""
112+
if isinstance(v, torch.Tensor):
113+
return [
114+
float(x) for x in v.detach().to(torch.float32).cpu().reshape(-1).tolist()
115+
]
116+
if isinstance(v, (list, tuple)):
117+
out: list[float] = []
118+
for x in v:
119+
out.extend(_flat_floats(x))
120+
return out
121+
try:
122+
return [float(v)]
123+
except (TypeError, ValueError):
124+
return []
125+
126+
127+
def _sqnr_db(aot_vals: list[float], rt_vals: list[float]) -> float:
128+
"""Signal-to-quantization-noise ratio in dB. Higher is better.
129+
130+
Thin wrapper around `torch.ao.ns.fx.utils.compute_sqnr` (the canonical
131+
implementation already used by `backends/test/harness/error_statistics.py`).
132+
"""
133+
from torch.ao.ns.fx.utils import compute_sqnr
134+
135+
n = min(len(aot_vals), len(rt_vals))
136+
if n == 0:
137+
return float("nan")
138+
aot_t = torch.tensor(aot_vals[:n], dtype=torch.float32)
139+
rt_t = torch.tensor(rt_vals[:n], dtype=torch.float32)
140+
return float(compute_sqnr(rt_t, aot_t))
141+
142+
143+
def compare_aot_runtime_dataframe(
144+
specs: Sequence[TapSpec],
145+
aot_flat: Sequence[Any],
146+
rt_flat: Sequence[Any],
147+
) -> pd.DataFrame:
148+
"""
149+
Build a side-by-side AOT-vs-runtime DataFrame from the flat outputs of
150+
the *tapped* ExportedProgram (eager) and the post-strip runtime program.
151+
152+
Both `aot_flat[spec.output_index]` and `rt_flat[spec.output_index]` already
153+
contain the *reduced* tap value, since `tap.Tensor`'s eager impl applies
154+
the named reducer (see `custom_ops_lib.py`).
155+
156+
Output columns per spec:
157+
- For non-FULL_TENSOR reducers: one `aot_<field>` and `rt_<field>` per
158+
reducer field (e.g. `aot_min`, `rt_min`, ...).
159+
- For FULL_TENSOR: `sqnr_db` (signal-to-noise of aot vs rt over the
160+
whole tensor, in dB)
161+
"""
162+
rows: list[dict[str, Any]] = []
163+
for spec in specs:
164+
aot_vals = _flat_floats(aot_flat[spec.output_index])
165+
rt_vals = _flat_floats(rt_flat[spec.output_index])
166+
167+
row: dict[str, Any] = {
168+
"node_name": spec.node_name,
169+
"module_path": spec.module_path,
170+
"module_class": spec.module_class,
171+
"op_target": spec.op_target,
172+
"reducer_name": spec.reducer_name,
173+
"output_index": spec.output_index,
174+
}
175+
176+
if spec.reducer_name == "FULL_TENSOR":
177+
row["sqnr_db"] = _sqnr_db(aot_vals, rt_vals)
178+
row["aot_numel"] = len(aot_vals)
179+
row["rt_numel"] = len(rt_vals)
180+
else:
181+
fields = (
182+
list(spec.fields)
183+
if spec.fields
184+
else [f"v{i}" for i in range(max(len(aot_vals), len(rt_vals)))]
185+
)
186+
for i, f in enumerate(fields):
187+
row[f"aot_{f}"] = aot_vals[i] if i < len(aot_vals) else float("nan")
188+
row[f"rt_{f}"] = rt_vals[i] if i < len(rt_vals) else float("nan")
189+
190+
rows.append(row)
191+
return pd.DataFrame(rows)

0 commit comments

Comments
 (0)