|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# pyre-unsafe |
| 8 | + |
| 9 | +""" |
| 10 | +Convenience helpers built on top of `tap_intermediate_outputs_` / `strip_taps_`: |
| 11 | +
|
| 12 | +* `tap_compare`: one-shot helper that exports a model, taps it, lowers with |
| 13 | + the user's partitioner, runs through the ExecuTorch runtime, and returns |
| 14 | + the AOT-vs-runtime comparison DataFrame plus the tap specs. The simplest |
| 15 | + way to use the intermediate-output tap. |
| 16 | +* `specs_to_dataframe`: build a per-tap DataFrame from a tap_specs list and |
| 17 | + the runtime's flat output tuple. |
| 18 | +* `compare_aot_runtime_dataframe`: side-by-side AOT-vs-runtime DataFrame from |
| 19 | + the flat outputs of the *tapped* ExportedProgram (eager) and the post-strip |
| 20 | + runtime program. |
| 21 | +""" |
| 22 | + |
| 23 | +from __future__ import annotations |
| 24 | + |
| 25 | +import os |
| 26 | +import tempfile |
| 27 | +from collections.abc import Sequence |
| 28 | +from typing import Any |
| 29 | + |
| 30 | +import pandas as pd |
| 31 | +import torch |
| 32 | +import torch.utils._pytree as pytree |
| 33 | +from executorch.devtools.intermediate_output_tap._spec import TapSpec |
| 34 | +from executorch.devtools.intermediate_output_tap._strip_pass import strip_taps_ |
| 35 | +from executorch.devtools.intermediate_output_tap._tap_pass import ( |
| 36 | + tap_intermediate_outputs_, |
| 37 | + TapRule, |
| 38 | +) |
| 39 | + |
| 40 | + |
| 41 | +def tap_compare( |
| 42 | + model: torch.nn.Module, |
| 43 | + example_inputs: tuple[Any, ...], |
| 44 | + partitioner: list | None = None, |
| 45 | + *, |
| 46 | + rules: Sequence[TapRule] | TapRule | None = None, |
| 47 | + error_on_empty: bool = True, |
| 48 | +) -> tuple[pd.DataFrame, list[TapSpec]]: |
| 49 | + """ |
| 50 | + One-shot AOT-vs-runtime numerical-debugging helper. |
| 51 | +
|
| 52 | + Runs the full pipeline: export -> tap -> capture AOT reference values |
| 53 | + -> lower with `partitioner` -> strip -> to_executorch -> runtime |
| 54 | + -> AOT-vs-runtime DataFrame. |
| 55 | +
|
| 56 | + Args: |
| 57 | + model: Eager nn.Module to debug. |
| 58 | + example_inputs: Positional args to the model's forward. |
| 59 | + partitioner: Optional list of partitioners passed to |
| 60 | + `to_edge_transform_and_lower`. Defaults to `[]` (no delegation). |
| 61 | + rules: Same semantics as `tap_intermediate_outputs_` — a sequence of |
| 62 | + `(selector, reducer)` pairs (or a single tuple as sugar). |
| 63 | + Defaults to `[(select_all_call_function(), STATS)]`. |
| 64 | + error_on_empty: Same semantics as `tap_intermediate_outputs_`. |
| 65 | +
|
| 66 | + Returns: |
| 67 | + A `(df, specs)` tuple where: |
| 68 | + - `df`: side-by-side AOT-vs-runtime DataFrame from |
| 69 | + `compare_aot_runtime_dataframe`. |
| 70 | + - `specs`: list of `TapSpec`s in tap-creation order. |
| 71 | + """ |
| 72 | + from executorch.exir import to_edge_transform_and_lower |
| 73 | + |
| 74 | + ep = torch.export.export(model, example_inputs, strict=True) |
| 75 | + ep_t, specs = tap_intermediate_outputs_( |
| 76 | + ep, |
| 77 | + rules=rules, |
| 78 | + error_on_empty=error_on_empty, |
| 79 | + ) |
| 80 | + |
| 81 | + # AOT-side reference values: tap.Tensor's eager impl applies the reducer, |
| 82 | + # so the flat outputs of the tapped EP already contain reduced values at |
| 83 | + # the same positions the runtime will use. |
| 84 | + aot_out = ep_t.module()(*example_inputs) |
| 85 | + aot_flat, _ = pytree.tree_flatten(aot_out) |
| 86 | + |
| 87 | + edge = to_edge_transform_and_lower(ep_t, partitioner=partitioner or []) |
| 88 | + strip_taps_(edge) |
| 89 | + et_program = edge.to_executorch() |
| 90 | + |
| 91 | + flat_inputs, _ = pytree.tree_flatten(example_inputs) |
| 92 | + rt_flat = list(_run_pte(et_program, flat_inputs)) |
| 93 | + |
| 94 | + df = compare_aot_runtime_dataframe(specs, aot_flat, rt_flat) |
| 95 | + return df, specs |
| 96 | + |
| 97 | + |
| 98 | +def _run_pte(et_program, example_inputs: tuple[Any, ...]) -> Sequence[Any]: |
| 99 | + from executorch.runtime import Runtime, Verification |
| 100 | + |
| 101 | + with tempfile.TemporaryDirectory() as temp_dir: |
| 102 | + pte_path = os.path.join(temp_dir, "model.pte") |
| 103 | + et_program.save(pte_path) |
| 104 | + rt = Runtime.get() |
| 105 | + program = rt.load_program(pte_path, verification=Verification.Minimal) |
| 106 | + method = program.load_method("forward") |
| 107 | + return method.execute(example_inputs) |
| 108 | + |
| 109 | + |
| 110 | +def _flat_floats(v: Any) -> list[float]: |
| 111 | + """Flatten a tap value (tensor / list / scalar) to a flat list of floats.""" |
| 112 | + if isinstance(v, torch.Tensor): |
| 113 | + return [ |
| 114 | + float(x) for x in v.detach().to(torch.float32).cpu().reshape(-1).tolist() |
| 115 | + ] |
| 116 | + if isinstance(v, (list, tuple)): |
| 117 | + out: list[float] = [] |
| 118 | + for x in v: |
| 119 | + out.extend(_flat_floats(x)) |
| 120 | + return out |
| 121 | + try: |
| 122 | + return [float(v)] |
| 123 | + except (TypeError, ValueError): |
| 124 | + return [] |
| 125 | + |
| 126 | + |
| 127 | +def _sqnr_db(aot_vals: list[float], rt_vals: list[float]) -> float: |
| 128 | + """Signal-to-quantization-noise ratio in dB. Higher is better. |
| 129 | +
|
| 130 | + Thin wrapper around `torch.ao.ns.fx.utils.compute_sqnr` (the canonical |
| 131 | + implementation already used by `backends/test/harness/error_statistics.py`). |
| 132 | + """ |
| 133 | + from torch.ao.ns.fx.utils import compute_sqnr |
| 134 | + |
| 135 | + n = min(len(aot_vals), len(rt_vals)) |
| 136 | + if n == 0: |
| 137 | + return float("nan") |
| 138 | + aot_t = torch.tensor(aot_vals[:n], dtype=torch.float32) |
| 139 | + rt_t = torch.tensor(rt_vals[:n], dtype=torch.float32) |
| 140 | + return float(compute_sqnr(rt_t, aot_t)) |
| 141 | + |
| 142 | + |
| 143 | +def compare_aot_runtime_dataframe( |
| 144 | + specs: Sequence[TapSpec], |
| 145 | + aot_flat: Sequence[Any], |
| 146 | + rt_flat: Sequence[Any], |
| 147 | +) -> pd.DataFrame: |
| 148 | + """ |
| 149 | + Build a side-by-side AOT-vs-runtime DataFrame from the flat outputs of |
| 150 | + the *tapped* ExportedProgram (eager) and the post-strip runtime program. |
| 151 | +
|
| 152 | + Both `aot_flat[spec.output_index]` and `rt_flat[spec.output_index]` already |
| 153 | + contain the *reduced* tap value, since `tap.Tensor`'s eager impl applies |
| 154 | + the named reducer (see `custom_ops_lib.py`). |
| 155 | +
|
| 156 | + Output columns per spec: |
| 157 | + - For non-FULL_TENSOR reducers: one `aot_<field>` and `rt_<field>` per |
| 158 | + reducer field (e.g. `aot_min`, `rt_min`, ...). |
| 159 | + - For FULL_TENSOR: `sqnr_db` (signal-to-noise of aot vs rt over the |
| 160 | + whole tensor, in dB) |
| 161 | + """ |
| 162 | + rows: list[dict[str, Any]] = [] |
| 163 | + for spec in specs: |
| 164 | + aot_vals = _flat_floats(aot_flat[spec.output_index]) |
| 165 | + rt_vals = _flat_floats(rt_flat[spec.output_index]) |
| 166 | + |
| 167 | + row: dict[str, Any] = { |
| 168 | + "node_name": spec.node_name, |
| 169 | + "module_path": spec.module_path, |
| 170 | + "module_class": spec.module_class, |
| 171 | + "op_target": spec.op_target, |
| 172 | + "reducer_name": spec.reducer_name, |
| 173 | + "output_index": spec.output_index, |
| 174 | + } |
| 175 | + |
| 176 | + if spec.reducer_name == "FULL_TENSOR": |
| 177 | + row["sqnr_db"] = _sqnr_db(aot_vals, rt_vals) |
| 178 | + row["aot_numel"] = len(aot_vals) |
| 179 | + row["rt_numel"] = len(rt_vals) |
| 180 | + else: |
| 181 | + fields = ( |
| 182 | + list(spec.fields) |
| 183 | + if spec.fields |
| 184 | + else [f"v{i}" for i in range(max(len(aot_vals), len(rt_vals)))] |
| 185 | + ) |
| 186 | + for i, f in enumerate(fields): |
| 187 | + row[f"aot_{f}"] = aot_vals[i] if i < len(aot_vals) else float("nan") |
| 188 | + row[f"rt_{f}"] = rt_vals[i] if i < len(rt_vals) else float("nan") |
| 189 | + |
| 190 | + rows.append(row) |
| 191 | + return pd.DataFrame(rows) |
0 commit comments