Commit 19ffa55

rascani and claude authored
Cortex-M backend: add quantized_activation op with LUT lowering for sigmoid/tanh/silu (pytorch#19792)
### Summary

CMSIS-NN has no s8 activation primitive: the s16 path requantizes around an on-target polynomial, which costs an extra s8 → s16 → activation → s8 trip per call. Instead, this lowers standalone aten.sigmoid / aten.tanh / aten.silu to a single cortex_m.quantized_activation(input, lut) op backed by a 256-entry int8 LUT precomputed ahead of time (AoT) from the input/output qparams and the activation function.

The kernel is a single byte-indexed lookup loop -- shape-agnostic, activation-agnostic, and free of any runtime requantization. Encoding the activation in the LUT bytes rather than a kind enum keeps the kernel surface to one op.

For SiLU specifically, the LUT can encode `x * sigmoid(x)` directly, so the naive sigmoid-plus-elementwise-mul decomposition is unnecessary. aten.silu is added to the to_edge preserve_ops list so it doesn't decompose to sigmoid+mul before the lowering pass sees it; this is set globally because no per-test opt-out exists today.

LUT-build numerics deliberately mirror the existing cortex_m CMSIS-NN conventions. Sigmoid/silu use a sign-branched stable form that always exponentiates a non-positive value, so the LUT build can't trip OverflowError for unusually wide input qparams. The final fp → int8 quantize uses round-half-away-from-zero, matching the rounding requantize_cmsis applies after its right-shift in passes_utils.

### Test plan

In Silero VAD the final `sigmoid(final_conv(x))` now lowers; the 3 remaining sigmoids and 2 tanhs are LSTMCell gates and stay in aten because PyTorch export captures nn.LSTMCell as a single high-level op -- the quantizer never sees the gates and can't annotate them, and to_edge only decomposes the cell after the quantizer has run. test_lstm_cell.py captures the expected end-state as an xfail that will flip green once a pre-annotation decompose pass lands; that work is tracked as a separate follow-up.

Other activations (GELU for KWT, Mish, ELU, Softplus) plug in as a few additional entries in passes_utils._ACTIVATION_FNS plus matching quantizer patterns. The generic op + LUT design carries them with no kernel changes.

---------

Co-authored-by: Claude <noreply@anthropic.com>
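The LUT idea in the summary can be sketched in plain Python. This is an illustrative model only, not the code from this commit: the helper names (`build_lut`, `apply_lut`, `max_abs_error`) and the qparams are assumptions picked so that no table entry is clamped. It builds a 256-entry SiLU table directly from `x * sigmoid(x)` and measures how far the dequantized table output is from the float reference.

```python
import math


def stable_sigmoid(x: float) -> float:
    # Exponentiate only non-positive values so math.exp cannot overflow.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def silu(x: float) -> float:
    return x * stable_sigmoid(x)


def round_half_away(v: float) -> int:
    return int(math.copysign(math.floor(abs(v) + 0.5), v))


def build_lut(fn, in_scale, in_zp, out_scale, out_zp):
    """Return 256 int8 values; entry q + 128 holds the output for input q."""
    lut = []
    for q in range(-128, 128):
        x = (q - in_zp) * in_scale       # dequantize the input code
        v = fn(x) / out_scale + out_zp   # apply the activation, then requantize
        lut.append(max(-128, min(127, round_half_away(v))))
    return lut


def apply_lut(lut, q_values):
    # The whole "kernel": one biased table read per element.
    return [lut[q + 128] for q in q_values]


# Hypothetical qparams, chosen so that no output entry needs clamping.
IN_SCALE, IN_ZP = 0.05, 0
OUT_SCALE, OUT_ZP = 0.03, -118

SILU_LUT = build_lut(silu, IN_SCALE, IN_ZP, OUT_SCALE, OUT_ZP)


def max_abs_error(lut):
    """Largest gap between the float SiLU and the dequantized table output."""
    worst = 0.0
    for q in range(-128, 128):
        exact = silu((q - IN_ZP) * IN_SCALE)
        approx = (lut[q + 128] - OUT_ZP) * OUT_SCALE
        worst = max(worst, abs(exact - approx))
    return worst
```

Under these assumed qparams the only error left is the single output rounding, so the gap stays within half an output step; a sigmoid-then-mul decomposition would instead round twice.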
1 parent a6f0cf1 commit 19ffa55

12 files changed

Lines changed: 475 additions & 6 deletions

backends/cortex_m/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ set(_cortex_m_kernels__srcs
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_minimum.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_pad.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_activation.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_avg_pool2d.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_batch_matmul.cpp

backends/cortex_m/ops/op_quantized_activation.cpp

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "cortex_m_ops_common.h"
+
+#include <cstring>
+
+#if defined(__ARM_FEATURE_MVE) && (__ARM_FEATURE_MVE & 1)
+#include <arm_mve.h>
+#define HAS_HELIUM_SIMD 1
+#endif
+
+#if defined(ARM_MATH_DSP) && !defined(HAS_HELIUM_SIMD)
+#include <arm_acle.h>
+#define HAS_DSP_PACKED_LUT 1
+#endif
+
+namespace cortex_m {
+namespace native {
+
+#if defined(HAS_DSP_PACKED_LUT)
+// Local 4-byte read/write helpers. We deliberately don't include
+// `arm_nnsupportfunctions.h` for the equivalent CMSIS-NN `arm_nn_read_s8x4_ia`
+// / `arm_nn_write_s8x4_ia` -- the header is public but pulls in the entire
+// CMSIS-NN support surface (~1500 lines) just for two memcpy wrappers.
+static inline uint32_t read_u8x4_ia(const int8_t** in) {
+  uint32_t val;
+  std::memcpy(&val, *in, 4);
+  *in += 4;
+  return val;
+}
+
+static inline void write_u8x4_ia(int8_t** out, uint32_t val) {
+  std::memcpy(*out, &val, 4);
+  *out += 4;
+}
+#endif
+
+// cppcheck-suppress unusedFunction
+Tensor& quantized_activation_out(
+    KernelRuntimeContext& /*context*/,
+    const Tensor& input,
+    const Tensor& lut,
+    Tensor& out) {
+  ET_CHECK_MSG(
+      input.scalar_type() == ScalarType::Char,
+      "quantized_activation: input must be int8");
+  ET_CHECK_MSG(
+      out.scalar_type() == ScalarType::Char,
+      "quantized_activation: output must be int8");
+  ET_CHECK_MSG(
+      lut.scalar_type() == ScalarType::Char,
+      "quantized_activation: lut must be int8");
+  ET_CHECK_MSG(
+      lut.numel() == 256,
+      "quantized_activation: lut must have 256 entries, got %" PRId64,
+      static_cast<int64_t>(lut.numel()));
+  ET_CHECK_MSG(
+      input.numel() == out.numel(),
+      "quantized_activation: input and output must have the same numel");
+
+  const int8_t* in_data = input.const_data_ptr<int8_t>();
+  const int8_t* lut_data = lut.const_data_ptr<int8_t>();
+  int8_t* out_data = out.mutable_data_ptr<int8_t>();
+
+  // The LUT is precomputed AoT from the input/output qparams and the
+  // activation function (sigmoid / tanh / silu / ...), so the kernel does not
+  // need to know which activation it is implementing. The signed int8 input
+  // is biased by 128 to use it as an unsigned [0, 255] table index.
+  const int64_t n = input.numel();
+  int64_t i = 0;
+
+#if defined(HAS_HELIUM_SIMD)
+  // M55/M85: 16 lanes per iteration. Reinterpret the int8 input as uint8
+  // (bit-identical load), add 128 mod 256 to produce a uint8 LUT index, then
+  // gather-load the int8 result from the LUT.
+  for (; i + 15 < n; i += 16) {
+    uint8x16_t in_u8 = vldrbq_u8(reinterpret_cast<const uint8_t*>(in_data + i));
+    uint8x16_t idx = vaddq_n_u8(in_u8, 128);
+    int8x16_t result = vldrbq_gather_offset_s8(lut_data, idx);
+    vstrbq_s8(out_data + i, result);
+  }
+#elif defined(HAS_DSP_PACKED_LUT)
+  // M4/M7 (DSP, no MVE): process 4 bytes per iteration. The DSP win comes from
+  // (a) folding 4 byte-loads into one word-load, (b) batching the +128 bias
+  // with `__uadd8`, and (c) folding 4 byte-stores into one word-store. The
+  // LUT lookups themselves still hit memory four times per word -- no DSP
+  // gather instruction exists on M-class.
+  const int8_t* in_ptr = in_data;
+  int8_t* out_ptr = out_data;
+  const int64_t word_iters = n >> 2;
+  for (int64_t w = 0; w < word_iters; ++w) {
+    const uint32_t in_word = read_u8x4_ia(&in_ptr);
+    const uint32_t idx_word = __uadd8(in_word, 0x80808080u);
+    const uint32_t out_word = static_cast<uint32_t>(static_cast<uint8_t>(
+                                  lut_data[idx_word & 0xFFu])) |
+        (static_cast<uint32_t>(
+             static_cast<uint8_t>(lut_data[(idx_word >> 8) & 0xFFu]))
+         << 8) |
+        (static_cast<uint32_t>(
+             static_cast<uint8_t>(lut_data[(idx_word >> 16) & 0xFFu]))
+         << 16) |
+        (static_cast<uint32_t>(
+             static_cast<uint8_t>(lut_data[(idx_word >> 24) & 0xFFu]))
+         << 24);
+    write_u8x4_ia(&out_ptr, out_word);
+  }
+  i = word_iters << 2;
+#endif
+
+  // 4x-unrolled scalar tail. On M-class cores without MVE or DSP the unroll
+  // lets the compiler issue independent LUT loads; on the MVE / DSP paths
+  // above this only runs for the < 16- (or < 4-) element remainder.
+  for (; i + 3 < n; i += 4) {
+    out_data[i + 0] = lut_data[static_cast<uint8_t>(in_data[i + 0] + 128)];
+    out_data[i + 1] = lut_data[static_cast<uint8_t>(in_data[i + 1] + 128)];
+    out_data[i + 2] = lut_data[static_cast<uint8_t>(in_data[i + 2] + 128)];
+    out_data[i + 3] = lut_data[static_cast<uint8_t>(in_data[i + 3] + 128)];
+  }
+  for (; i < n; ++i) {
+    out_data[i] = lut_data[static_cast<uint8_t>(in_data[i] + 128)];
+  }
+
+  return out;
+}
+
+} // namespace native
+} // namespace cortex_m
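
The packed-word (DSP) path of the kernel can be modeled in Python to check that it agrees with the scalar loop. This is a rough host-side sketch, not part of the commit: `uadd8` here is a hand-written stand-in for the per-byte wrapping add that the kernel gets from `__uadd8`, the function names are made up for illustration, and little-endian byte order is assumed for the word load and store.

```python
def uadd8(a: int, b: int) -> int:
    """Model of a per-byte wrapping add: four independent u8 lanes in a u32."""
    out = 0
    for lane in range(4):
        shift = 8 * lane
        s = ((a >> shift) & 0xFF) + ((b >> shift) & 0xFF)
        out |= (s & 0xFF) << shift  # drop the carry so lanes stay independent
    return out


def lookup_scalar(lut: bytes, data: bytes) -> bytes:
    # data holds int8 values as raw bytes; adding 128 modulo 256 turns each
    # byte into an unsigned table index.
    return bytes(lut[(b + 128) & 0xFF] for b in data)


def lookup_packed(lut: bytes, data: bytes) -> bytes:
    out = bytearray()
    n_words = len(data) // 4
    for w in range(n_words):
        # One word load instead of four byte loads (little-endian assumed).
        word = int.from_bytes(data[4 * w : 4 * w + 4], "little")
        idx = uadd8(word, 0x80808080)  # bias all four lanes at once
        packed = 0
        for lane in range(4):
            # The table reads themselves are still one per byte.
            packed |= lut[(idx >> (8 * lane)) & 0xFF] << (8 * lane)
        out += packed.to_bytes(4, "little")
    # A remainder of fewer than four bytes falls back to the scalar path.
    out += lookup_scalar(lut, data[4 * n_words :])
    return bytes(out)
```

Dropping the carry between lanes is what makes the batched bias safe: a plain 32-bit add of `0x80808080` would let an overflow in one byte corrupt its neighbour's index.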

backends/cortex_m/ops/operators.py

Lines changed: 29 additions & 0 deletions
@@ -264,6 +264,35 @@ def quantized_mul_impl(
     return result
 
 
+# ===================================================================
+# QUANTIZED ACTIVATION (LUT) OPERATION DEFINITION
+# ===================================================================
+# Generic table-lookup activation. The 256-entry int8 LUT is precomputed AoT
+# from the input/output qparams and the activation function (sigmoid, tanh,
+# silu, ...), so the kernel is identical regardless of which activation it
+# evaluates: out[i] = lut[input[i] + 128].
+lib.define("quantized_activation(Tensor input, Tensor lut) -> Tensor")
+lib.define(
+    "quantized_activation.out(Tensor input, Tensor lut, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
+
+@register_fake("cortex_m::quantized_activation")  # type: ignore[misc]
+def quantized_activation_meta(input: torch.Tensor, lut: torch.Tensor) -> torch.Tensor:
+    assert input.dtype == torch.int8, "quantized_activation input must be int8"
+    assert lut.dtype == torch.int8 and lut.numel() == 256, (
+        "quantized_activation lut must be int8 with 256 entries; "
+        f"got dtype={lut.dtype}, numel={lut.numel()}"
+    )
+    return torch.empty_like(input)
+
+
+@impl(lib, "quantized_activation", "CompositeExplicitAutograd")  # type: ignore[misc]
+def quantized_activation_impl(input: torch.Tensor, lut: torch.Tensor) -> torch.Tensor:
+    indices = input.to(torch.int32) + 128
+    return lut[indices].to(torch.int8)
+
+
 # ===================================================================
 # QUANTIZED BATCH MATMUL OPERATION DEFINITION
 # ===================================================================

backends/cortex_m/ops/operators.yaml

Lines changed: 6 additions & 0 deletions
@@ -29,6 +29,12 @@
     - arg_meta: null
       kernel_name: cortex_m::quantized_mul_out
 
+- func: cortex_m::quantized_activation.out(Tensor input, Tensor lut, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: cortex_m::quantized_activation_out
+
 - func: cortex_m::minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:

backends/cortex_m/ops/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ OPERATORS = [
     "quantized_avg_pool2d",
     "quantized_batch_matmul",
     "quantized_max_pool2d",
+    "quantized_activation",
 ]
 
 def define_common_targets():

backends/cortex_m/passes/convert_to_cortex_m_pass.py

Lines changed: 42 additions & 1 deletion
@@ -13,7 +13,10 @@
 from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
 
 from executorch.backends.cortex_m.passes.cortex_m_pass import CortexMPass
-from executorch.backends.cortex_m.passes.passes_utils import quantize_multiplier_aot
+from executorch.backends.cortex_m.passes.passes_utils import (
+    build_activation_lut,
+    quantize_multiplier_aot,
+)
 from executorch.backends.cortex_m.passes.scratch_buffer_sizes import (
     required_cmsis_nn_buffer_sizes,
 )
@@ -483,6 +486,38 @@ def _get_bmm_replacement(self, node):
         )
         return exir_ops.edge.cortex_m.quantized_batch_matmul.default, args
 
+    def _get_activation_replacement(self, node):
+        """Lower a standalone quantized sigmoid / tanh / silu to a single
+        cortex_m.quantized_activation call backed by an AoT-built 256-entry
+        int8 LUT. The kernel is shape-agnostic; the LUT encodes both the
+        activation function and the input/output qparams.
+        """
+        input_qparams = node.meta["input_qparams"][0]
+        output_qparams = node.meta["output_qparams"][0]
+        lut_tensor = build_activation_lut(
+            node.target,
+            float(input_qparams.scale),
+            int(input_qparams.zp),
+            float(output_qparams.scale),
+            int(output_qparams.zp),
+        )
+
+        # Constant placeholders must appear before user-input placeholders;
+        # anchor on the first existing placeholder so the new LUT lands in the
+        # constant-placeholder block at the top of the graph.
+        first_placeholder = next(n for n in node.graph.nodes if n.op == "placeholder")
+        with node.graph.inserting_before(first_placeholder):
+            lut_node = create_constant_placeholder(
+                self.exported_program,
+                node.graph,
+                node.name + "_lut",
+                InputKind.PARAMETER,
+                lut_tensor,
+            )
+
+        new_args = (node.args[0], lut_node)
+        return exir_ops.edge.cortex_m.quantized_activation.default, new_args
+
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         modified = False
         for node in graph_module.graph.nodes:
@@ -506,6 +541,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                     op, args = self._get_convolution_replacement(node)
                 case exir_ops.edge.aten.bmm.default:
                     op, args = self._get_bmm_replacement(node)
+                case (
+                    exir_ops.edge.aten.sigmoid.default
+                    | exir_ops.edge.aten.tanh.default
+                    | exir_ops.edge.aten.silu.default
+                ):
+                    op, args = self._get_activation_replacement(node)
                 case _:
                     continue
 

backends/cortex_m/passes/passes_utils.py

Lines changed: 61 additions & 0 deletions
@@ -190,6 +190,67 @@ def is_qualified_int8_node(args) -> bool:
     return False
 
 
+def _stable_sigmoid(x: float) -> float:
+    # Always exponentiate the non-positive value so `math.exp` never overflows
+    # for unusually large `|x|` (e.g. wide-range input qparams). Algebraically
+    # identical to `1 / (1 + exp(-x))`.
+    if x >= 0:
+        return 1.0 / (1.0 + math.exp(-x))
+    e = math.exp(x)
+    return e / (1.0 + e)
+
+
+def _stable_silu(x: float) -> float:
+    return x * _stable_sigmoid(x)
+
+
+_ACTIVATION_FNS = {
+    exir_ops.edge.aten.sigmoid.default: _stable_sigmoid,
+    exir_ops.edge.aten.tanh.default: math.tanh,
+    exir_ops.edge.aten.silu.default: _stable_silu,
+}
+
+
+def _round_half_away_from_zero(x: float) -> int:
+    # Matches the rounding convention `requantize_cmsis` (above) applies after
+    # the right-shift step: ties on positive values round toward +∞, ties on
+    # negative values round toward -∞. Python's built-in `round` would use
+    # banker's rounding instead and disagree at exact half-integers.
+    return int(math.copysign(math.floor(abs(x) + 0.5), x)) if x != 0 else 0
+
+
+def build_activation_lut(
+    target,
+    input_scale: float,
+    input_zp: int,
+    output_scale: float,
+    output_zp: int,
+) -> torch.Tensor:
+    """AoT-compute a 256-entry int8 lookup table for a quantized activation.
+
+    `target` is the edge-dialect op being lowered (e.g.
+    `exir_ops.edge.aten.sigmoid.default`).
+
+    The LUT is indexed by the input byte value biased by 128: for any int8
+    input `q_in`, the kernel reads `lut[q_in + 128]` to get the int8 output.
+    Because the LUT is computed in float and quantized once per entry, the
+    runtime kernel is a single memory-lookup with no requantization math.
+    """
+    if target not in _ACTIVATION_FNS:
+        raise ValueError(
+            f"build_activation_lut: unsupported activation target {target!r} "
+            f"(supported: {sorted(t.__name__ for t in _ACTIVATION_FNS)})"
+        )
+    f = _ACTIVATION_FNS[target]
+    lut = torch.empty(256, dtype=torch.int8)
+    for q in range(-128, 128):
+        x = (q - input_zp) * input_scale
+        y = f(x)
+        q_out = _round_half_away_from_zero(y / output_scale + output_zp)
+        lut[q + 128] = max(-128, min(127, q_out))
+    return lut
+
+
 def quantize_multiplier_aot(scale: float) -> tuple[int, int]:
     if scale == 0.0:
         return 0, 0
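
Two of the numeric choices above are easy to check in isolation. The snippet below is a standalone sketch with its own copies of the helpers (public names here, not the private ones in `passes_utils`); it compares Python's built-in `round` against round-half-away-from-zero at exact ties, and shows the kind of input where the textbook sigmoid form overflows while the sign-branched form does not. The probe value `-1000.0` is an arbitrary assumption chosen to force the overflow.

```python
import math


def round_half_away_from_zero(x: float) -> int:
    return int(math.copysign(math.floor(abs(x) + 0.5), x))


def naive_sigmoid(x: float) -> float:
    # Textbook form: exponentiates -x, which overflows for very negative x.
    return 1.0 / (1.0 + math.exp(-x))


def stable_sigmoid(x: float) -> float:
    # Sign-branched form: the argument to math.exp is never positive.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def overflows(fn, x: float) -> bool:
    """Report whether fn(x) raises OverflowError."""
    try:
        fn(x)
    except OverflowError:
        return True
    return False


TIES = [0.5, 1.5, 2.5, -0.5, -2.5]
BANKERS = [round(t) for t in TIES]                       # round-half-to-even
HALF_AWAY = [round_half_away_from_zero(t) for t in TIES]
```

The two rounding modes disagree on four of the five ties listed, which is why a LUT built with plain `round` could differ by one code from a kernel that requantizes with the half-away convention.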

backends/cortex_m/quantizer/pattern_checkers.py

Lines changed: 19 additions & 0 deletions
@@ -99,6 +99,25 @@ def check_quantization_config(
         return is_int8
 
 
+class CortexMActivationCheck(PatternCheck):
+    """Accept standalone elementwise activations (sigmoid / tanh / silu)
+    that the LUT-based cortex_m.quantized_activation op handles uniformly.
+
+    The kernel is shape-agnostic and the LUT is computed AoT from per-tensor
+    qparams, so the only thing to enforce is int8 per-tensor quantization.
+    """
+
+    @classmethod
+    def check_quantization_config(
+        cls, pattern: list[Node], quantization_config: QuantizationConfig
+    ) -> bool:
+        is_int8 = cls.is_int8_activations(quantization_config)
+        is_per_tensor = cls.is_per_tensor(
+            quantization_config.get_input_act_qspec()
+        ) and cls.is_per_tensor(quantization_config.get_output_act_qspec())
+        return is_int8 and is_per_tensor
+
+
 class CortexMSoftmaxCheck(PatternCheck):
 
     @classmethod

backends/cortex_m/quantizer/quantizer_support.py

Lines changed: 8 additions & 0 deletions
@@ -5,6 +5,7 @@
 
 import torch
 from executorch.backends.cortex_m.quantizer.pattern_checkers import (
+    CortexMActivationCheck,
     CortexMAddMulCheck,
     CortexMAvgPool2DCheck,
     CortexMBmmCheck,
@@ -119,6 +120,12 @@
     (torch.ops.aten.softmax.int,): CortexMSoftmaxCheck,
 }
 
+ACTIVATION_OP_PATTERNS = {
+    (torch.ops.aten.sigmoid.default,): CortexMActivationCheck,
+    (torch.ops.aten.tanh.default,): CortexMActivationCheck,
+    (torch.ops.aten.silu.default,): CortexMActivationCheck,
+}
+
 POOL_OP_PATTERNS = {
     (torch.ops.aten.avg_pool2d.default,): CortexMAvgPool2DCheck,
     (torch.ops.aten.max_pool2d.default,): CortexMMaxPool2DCheck,
@@ -161,4 +168,5 @@
     | CONV_TRANSPOSE_OP_PATTERNS
     | POOL_OP_PATTERNS
     | BMM_OP_PATTERNS
+    | ACTIVATION_OP_PATTERNS
 )
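
The commit message notes that further activations such as GELU would only need extra registry entries plus matching quantizer patterns. A rough, torch-free sketch of the registry half of that claim follows. Everything in it is hypothetical: the string keys stand in for edge-dialect ops, `ACTIVATION_FNS` stands in for `_ACTIVATION_FNS`, and the erf-based GELU is one possible definition (a tanh-approximated GELU would produce a different table).

```python
import math


def gelu(x: float) -> float:
    # Exact (erf-based) GELU; assumed here, not taken from the commit.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


# Hypothetical stand-in for the activation registry, keyed by name.
ACTIVATION_FNS = {"tanh": math.tanh}
ACTIVATION_FNS["gelu"] = gelu  # the only change needed on the table-build side


def build_lut(name, in_scale, in_zp, out_scale, out_zp):
    """Build a 256-entry int8 table for a registered activation name."""
    if name not in ACTIVATION_FNS:
        raise ValueError(f"unsupported activation {name!r}")
    fn = ACTIVATION_FNS[name]
    lut = []
    for q in range(-128, 128):
        v = fn((q - in_zp) * in_scale) / out_scale + out_zp
        r = int(math.copysign(math.floor(abs(v) + 0.5), v))  # half away from zero
        lut.append(max(-128, min(127, r)))
    return lut


GELU_LUT = build_lut("gelu", 0.05, 0, 0.05, 0)
```

Because the lookup loop never inspects which function produced the table, a new entry like this changes only the bytes in the LUT, not the kernel.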
