Skip to content

Commit 431032b

Browse files
Arm backend: Add quantizer tutorial (pytorch#18490)
Adds a new jupyter notebook tutorial for the new composable quantizer in the Arm backend Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent ddd62c5 commit 431032b

File tree

1 file changed

+316
-0
lines changed

1 file changed

+316
-0
lines changed
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"# Copyright 2026 Arm Limited and/or its affiliates.\n",
10+
"#\n",
11+
"# This source code is licensed under the BSD-style license found in the\n",
12+
"# LICENSE file in the root directory of this source tree."
13+
]
14+
},
15+
{
16+
"cell_type": "markdown",
17+
"metadata": {},
18+
"source": [
19+
"# WIP: TOSA/EthosU/VgfQuantizer composable quantizer tutorial\n",
20+
"\n",
21+
"This is an in-depth tutorial of the new `TOSA/EthosU/VgfQuantizer` API. While the `TOSAQuantizer` is used in the example, both the\n",
22+
"`EthosUQuantizer` and `VgfQuantizer` directly inherit from this base class. \n",
23+
"\n",
24+
"Note that the main API and functionality remains largely the same to allow for a drop-in replacement, but the underlying framework is different - as will be explained. **Both the quantizer and this tutorial are currently experimental and may change without prior notice.** Refer to https://github.com/pytorch/executorch/issues/17701 for questions and feedback.\n",
25+
"\n",
26+
"Before you begin:\n",
27+
"1. (In a clean virtual environment with a compatible Python version) Install executorch using `./install_executorch.sh`\n",
28+
"2. Install Arm TOSA dependencies using `examples/arm/setup.sh --disable-ethos-u-deps`\n",
29+
"\n",
30+
"With all commands executed from the base `executorch` folder."
31+
]
32+
},
33+
{
34+
"cell_type": "markdown",
35+
"metadata": {},
36+
"source": [
37+
"## Setup model and logging"
38+
]
39+
},
40+
{
41+
"cell_type": "code",
42+
"execution_count": null,
43+
"metadata": {},
44+
"outputs": [],
45+
"source": [
46+
"import torch\n",
47+
"\n",
48+
"class ToyModel(torch.nn.Module):\n",
49+
" def __init__(self):\n",
50+
" super().__init__()\n",
51+
" self.conv1 = torch.nn.Conv2d(1, 1, 1)\n",
52+
" self.conv2 = torch.nn.Conv2d(1, 1, 1)\n",
53+
"\n",
54+
" def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n",
55+
" x = self.conv1(x)\n",
56+
" x = torch.relu(x) \n",
57+
" y = self.conv2(y)\n",
58+
" z = x/y\n",
59+
" return z.view((1,))\n",
60+
"\n",
61+
"example_inputs = (torch.ones(1,1,1,1),torch.ones(1,1,1,1))\n",
62+
"\n",
63+
"model = ToyModel()"
64+
]
65+
},
66+
{
67+
"cell_type": "code",
68+
"execution_count": null,
69+
"metadata": {},
70+
"outputs": [],
71+
"source": [
72+
"# Set logger to DEBUG for full quantization report\n",
73+
"import logging\n",
74+
"logging.basicConfig()\n",
75+
"logging.getLogger().setLevel(logging.DEBUG)"
76+
]
77+
},
78+
{
79+
"cell_type": "code",
80+
"execution_count": null,
81+
"metadata": {},
82+
"outputs": [],
83+
"source": [
84+
"# If you have model-explorer installed, you can visualize the exported program with the following code:\n",
85+
"from executorch.devtools.visualization import visualize\n",
86+
"\n",
87+
"exported_program = torch.export.export(model, example_inputs)\n",
88+
"visualize(exported_program)"
89+
]
90+
},
91+
{
92+
"cell_type": "markdown",
93+
"metadata": {},
94+
"source": [
95+
"# Basic quantizer useage\n",
96+
"The experimental API is enabled by setting `use_composable_quantizer=True` when initializing\n",
97+
"the quantizer. The name `composable_quantizer` refers to the new implementation using multiple\n",
98+
"separate quantizers; the user configures quantization by specifying a sequence of quantizers,\n",
99+
"with each annotating a selection of nodes with a particular quantization config.\n",
100+
"\n",
101+
"The node selection and quantization config is set by the user. They can be selected using basic API-calls as demonstrated here,\n",
102+
"or be completely customized as shown in the advanced section. However, the backend has limits on what is supported and may\n",
103+
"reject quantization of nodes with unsupported quantization configs. A few operators additionally require special quantization\n",
104+
"strategies for numerical correctness, which is encoded in the backend specific `TOSAQuantizationSpec`. These special cases will\n",
105+
"be reported in a quantization report.\n",
106+
"\n",
107+
"The quantizer additionally applies it's own filtering of the selected nodes to only quantize what is known to be supported\n",
108+
"in the backend.\n",
109+
"\n",
110+
"Below, the model is quantized by three different quantizers:\n",
111+
"1. The nodes named 'conv2d' and 'relu' are quantized using the a8a8 config.\n",
112+
"2. The remaning conv is targeted with None, leaving it non-quantized.\n",
113+
"3. The remaining nodes are targeted by the global config, which is a16w8.\n",
114+
"\n",
115+
"Note that order of configuration is important, later specified quantizers have precedence (with the exception of global,\n",
116+
"which is always applied last). Switching 1 and 2 would leave both convolutions in floating point. "
117+
]
118+
},
119+
{
120+
"cell_type": "code",
121+
"execution_count": null,
122+
"metadata": {},
123+
"outputs": [],
124+
"source": [
125+
"from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec\n",
126+
"from executorch.backends.arm.quantizer import (\n",
127+
" TOSAQuantizer,\n",
128+
" get_symmetric_quantization_config,\n",
129+
" get_symmetric_a16w8_quantization_config,\n",
130+
")\n",
131+
"from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n",
132+
"\n",
133+
"# Export the model\n",
134+
"exported_program = torch.export.export(model, example_inputs)\n",
135+
"graph_module = exported_program.module(check_guards=False)\n",
136+
"\n",
137+
"# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n",
138+
"target = \"TOSA-1.0+INT\"\n",
139+
"compile_spec = TosaCompileSpec(target)\n",
140+
"quantizer = TOSAQuantizer(compile_spec, use_composable_quantizer=True)\n",
141+
"\n",
142+
"a16w8_config = get_symmetric_a16w8_quantization_config()\n",
143+
"fp_config = None\n",
144+
"a8w8_config = get_symmetric_quantization_config()\n",
145+
"\n",
146+
"quantizer.set_global(a16w8_config) # Gloabl config, applied last\n",
147+
"quantizer.set_module_type(torch.nn.Conv2d, fp_config) # Applied second, remaning conv2d set to floating point\n",
148+
"quantizer.set_node_name([\"conv2d\", \"relu\"], a8w8_config) # Applied first, conv+relu quantized using the a8a8 config\n",
149+
"\n",
150+
"# Post training quantization\n",
151+
"quantized_graph_module = prepare_pt2e(graph_module, quantizer)\n",
152+
"quantized_graph_module(*example_inputs) # Calibrate the graph module with the example input\n",
153+
"quantized_graph_module = convert_pt2e(quantized_graph_module)"
154+
]
155+
},
156+
{
157+
"cell_type": "code",
158+
"execution_count": null,
159+
"metadata": {},
160+
"outputs": [],
161+
"source": [
162+
"exported_program = torch.export.export(quantized_graph_module, example_inputs)\n",
163+
"visualize(exported_program)"
164+
]
165+
},
166+
{
167+
"cell_type": "markdown",
168+
"metadata": {},
169+
"source": [
170+
"### The quantization report\n",
171+
"\n",
172+
"In the logged quantization report each quantizer has added one header describing targeted nodes, the used quantization config, and the supported operators / operator patterns. \n",
173+
"```\n",
174+
"PatternQuantizer using NodeNameNodeFinder targeting names: conv2d, relu\n",
175+
"Annotating with executorch.backends.arm.quantizer.arm_quantizer.get_symmetric_quantization_config(is_per_channel=True)\n",
176+
"Supported operators and patterns defined by executorch.backends.arm.quantizer.quantizer_support.TOSA_QUANTIZER_SUPPORT_DICT\n",
177+
"```\n",
178+
"\n",
179+
"it then gives a short overview of how many nodes it has targeted\n",
180+
"```\n",
181+
" Accepted nodes: 2\n",
182+
" Rejected due to previous annotation: 0\n",
183+
" Rejected nodes: 0\n",
184+
"```\n",
185+
"\n",
186+
"and finally a node-by-node report\n",
187+
"```\n",
188+
" NODE NAME INPUT QSPEC MAP OUTPUT QSPEC MAP\n",
189+
" -- ----------- ---------------------------------------- ---------------------\n",
190+
" ╒ conv2d x: INT8_PER_TENSOR_QSPEC NO_QSPEC\n",
191+
" | _param_constant0: INT8_PER_CHANNEL_QSPEC\n",
192+
" | _param_constant1: DERIVED_QSPEC\n",
193+
" ╘ relu \n",
194+
"```\n",
195+
"\n",
196+
"The brackets here indicates that the conv2d and relu has been recognized as a single pattern\n",
197+
"to be quantized to allow a fusing later in the backend. One quantization config translates to\n",
198+
"many different quantization annotations for different types of tensors; per tensor for\n",
199+
"activations, per channel for weights, and a special quantization spec for the int32 bias. \n",
200+
"\n",
201+
"### Pre-transform for annotation vs. final quantization report\n",
202+
"One important detail is that there are two reports printed, one named PRE-TRANSFORM_FOR_ANNOTATION QUANTIZATION REPORT,\n",
203+
"and one named FINAL QUANTIZATION REPORT. This is related to the fact that some operators has to be decomposed before quantization to ensure\n",
204+
"that all \"sub operators\" gets quantized properly. As an example, the division operator in the first report\n",
205+
"has decomposed into a reciprocal and multiplication operator in the second. Had it not been marked for quantization\n",
206+
"in the first step, it would have remained a single division operator.\n",
207+
"\n",
208+
"**This is important to be aware of when doing mixed quantization since this means that for an operator to be fully quantized,\n",
209+
"both the original operator and the decomposition needs to be targeted.**\n",
210+
"\n",
211+
"### SharedQspecQuantizer\n",
212+
"Last in the report there is always an additional quantizer applied which is not specified by the user, the SharedQspecQuantizer.\n",
213+
"It handles data shuffling operators without numerical behaviour such as copies and reshapes to ensure that they are quantized with the same qspec as\n",
214+
"surrounding nodes, rather than counting on the user to configure them correctly. It shouldn't need much attention as it is not configured, \n",
215+
"but it is good to be aware of when analyzing the quantization behaviour. The targeted operators are defined by `SHARED_QSPEC_OPS_DEFAULT`\n",
216+
"in the quantizer class."
217+
]
218+
},
219+
{
220+
"cell_type": "markdown",
221+
"metadata": {},
222+
"source": [
223+
"# Advanced quantizer useage\n",
224+
"\n",
225+
"The composability of the quantizer has an additional benefit for advanced user in that each component can easily be modified\n",
226+
"or completely be swapped out in cases where special behaviour is needed. Let's see this in action by recreating what happens under\n",
227+
"the hood when the `set_node_name` API is used."
228+
]
229+
},
230+
{
231+
"cell_type": "code",
232+
"execution_count": null,
233+
"metadata": {},
234+
"outputs": [],
235+
"source": [
236+
"import executorch.backends.cortex_m.quantizer.node_finders as node_finders\n",
237+
"from executorch.backends.arm.quantizer.quantization_config import TOSAQuantizationConfig\n",
238+
"from executorch.backends.arm.quantizer.arm_quantizer_utils import PatternQuantizer\n",
239+
"from executorch.backends.cortex_m.quantizer.pattern_matcher import PatternMatcher\n",
240+
"from torchao.quantization.pt2e.quantizer import (\n",
241+
" QuantizationSpec,\n",
242+
")\n",
243+
"from torchao.quantization.pt2e import (\n",
244+
" MinMaxObserver,\n",
245+
")\n",
246+
"\n",
247+
"# Export the model\n",
248+
"exported_program = torch.export.export(model, example_inputs)\n",
249+
"graph_module = exported_program.module(check_guards=False)\n",
250+
"\n",
251+
"# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n",
252+
"target = \"TOSA-1.0+INT\"\n",
253+
"compile_spec = TosaCompileSpec(target)\n",
254+
"quantizer = TOSAQuantizer(compile_spec, use_composable_quantizer=True)\n",
255+
"\n",
256+
"\n",
257+
"# The first component is the selection of nodes, done through NodeFinders\n",
258+
"# A node finder is a class implementing the NodeFinder interface\n",
259+
"# This is instantiated inside the set_node_name function\n",
260+
"node_finder = node_finders.NodeNameNodeFinder(\"conv2d\")\n",
261+
"\n",
262+
"# The second component is the quantization config, which may be custom\n",
263+
"# This is what is returned by the get_symmetric_quantization_config function\n",
264+
"qspec = QuantizationSpec(torch.int8, MinMaxObserver, quant_min=-128, quant_max=127, qscheme = torch.per_tensor_symmetric)\n",
265+
"quantization_config = TOSAQuantizationConfig(input_activation=qspec,output_activation=qspec, weight=None, bias=None)\n",
266+
"\n",
267+
"# The third component is the pattern matcher which defines support for the backend\n",
268+
"# This would typically be the TOSA_QUANTIZER_SUPPORT_DICT but here a minimal support dict is created for demonstration purposes\n",
269+
"# A pattern_checker is a class implementing the PatternChecker interface, or it can be None if no extra checks are needed\n",
270+
"# This is instantiated by the backend and used by all sub-quantizers\n",
271+
"pattern_checker = None\n",
272+
"SUPPORT_DICT = {(torch.ops.aten.conv2d.default,) : pattern_checker}\n",
273+
"pattern_matcher = PatternMatcher(SUPPORT_DICT, \"MY_SUPPORT_DICT\")\n",
274+
"\n",
275+
"# All components are brought together in the PatterQuantizer and added to the quantizer\n",
276+
"# This is done last in the set_node_name function\n",
277+
"pattern_quantizer = PatternQuantizer(quantization_config, node_finder, pattern_matcher)\n",
278+
"quantizer.add_quantizer(pattern_quantizer)\n",
279+
"\n",
280+
"quantized_graph_module = prepare_pt2e(graph_module, quantizer)\n",
281+
"quantized_graph_module(*example_inputs)\n",
282+
"quantized_graph_module = convert_pt2e(quantized_graph_module)"
283+
]
284+
},
285+
{
286+
"cell_type": "markdown",
287+
"metadata": {},
288+
"source": [
289+
"As confirmed by the report, the quantizer has now only targeted one single convolution with a custom quantization config, and this config was then propagated to the relu node by the SharedQpsecQuantizer as expected. The view stays in float since the preceeding division operator is in float.\n",
290+
"\n",
291+
"This useage of the quantizer has less guarantees of producing numerically correct or even functional graphs, but it can be a useful tool for debugging or when an otherwise unsupported behaviour is required."
292+
]
293+
}
294+
],
295+
"metadata": {
296+
"kernelspec": {
297+
"display_name": ".venv (3.10.15)",
298+
"language": "python",
299+
"name": "python3"
300+
},
301+
"language_info": {
302+
"codemirror_mode": {
303+
"name": "ipython",
304+
"version": 3
305+
},
306+
"file_extension": ".py",
307+
"mimetype": "text/x-python",
308+
"name": "python",
309+
"nbconvert_exporter": "python",
310+
"pygments_lexer": "ipython3",
311+
"version": "3.10.15"
312+
}
313+
},
314+
"nbformat": 4,
315+
"nbformat_minor": 4
316+
}

0 commit comments

Comments
 (0)