
Commit 8020fe0

Add a16w8 MHA softmax FVP coverage for Ethos-U85 (pytorch#19493)
Differential Revision: D103734699
Pull Request resolved: pytorch#19493
1 parent 84f39aa commit 8020fe0

2 files changed

Lines changed: 118 additions & 0 deletions


backends/arm/test/ops/test_softmax.py

Lines changed: 117 additions & 0 deletions
@@ -7,6 +7,8 @@
 
 from typing import Tuple
 
+import pytest
+
 import torch
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import (
@@ -120,3 +122,118 @@ def test_softmax_vgf_quant(test_data):
     # TODO: MLETORCH-1136 Change args of run_method_and_compare_outputs of the vgf tests
     # pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()
+
+
+# ---------------------------------------------------------------------------
+# a16w8 (int16 IO + int8 weights) softmax FVP coverage.
+#
+# Sweeps a multi-head-attention-shaped softmax over a wide range of
+# pre-softmax input magnitudes to surface int16 numerics issues in the
+# lowered graph (e.g. the Ethos-U85 ReduceSum int16 silent-zero issue in the
+# softmax decomposition, fixed by the follow-up Vela patch in this stack).
+# ---------------------------------------------------------------------------
+
+
+class MultiHeadAttentionSoftmax(torch.nn.Module):
+    """Generic multi-head-attention softmax: reshape -> softmax(dim=-1) -> flatten.
+
+    H heads, M query tokens, W K/V window. Output shape: (N, T, H*M*W).
+    """
+
+    H = 4
+    M = 1
+    W = 16
+    IN_FEATURES = H * M * W  # 64
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        n, t, _ = x.shape
+        x = x.reshape(n, t, self.H, self.M, self.W)
+        x = torch.softmax(x, dim=-1)
+        x = x.reshape(n, t, self.IN_FEATURES)
+        return x
+
+
+# (input_low, input_high) per case. Keys are the parametrize ids.
+# Range coverage spans realistic post-1/sqrt(d) attention logits (typically
+# in [-10, +10]) plus a couple of wider buffer cases. atol below is sized
+# at ~1.5x the observed FVP max-abs softmax error across the sweep at
+# qtol=0, measured against the quantized reference.
+mha_softmax_sweep = {
+    "range_neg0p01_to_0p01": (-0.01, 0.01),
+    "range_neg0p1_to_0p1": (-0.1, 0.1),
+    "range_neg1_to_1": (-1.0, 1.0),
+    "range_neg3_to_3": (-3.0, 3.0),
+    "range_neg5_to_5": (-5.0, 5.0),
+    "range_neg10_to_10": (-10.0, 10.0),
+    "range_neg30_to_30": (-30.0, 30.0),
+}
+
+_MHA_ATOL = 0.003
+
+
+def _make_mha_softmax_inputs(
+    input_low: float, input_high: float, num_test: int = 8, seed: int = 42
+) -> Tuple[torch.Tensor]:
+    # Local Generator so this helper does not mutate the global RNG state
+    # and the test suite stays order-independent.
+    gen = torch.Generator().manual_seed(seed)
+    span = input_high - input_low
+    return (
+        torch.rand(
+            num_test,
+            1,
+            MultiHeadAttentionSoftmax.IN_FEATURES,
+            generator=gen,
+        )
+        * span
+        + input_low,
+    )
+
+
+@common.parametrize("sweep_case", mha_softmax_sweep)
+@common.XfailIfNoCorstone300
+def test_mha_softmax_a16w8_u55_INT(sweep_case: Tuple[float, float]) -> None:
+    input_low, input_high = sweep_case
+    pipeline = EthosU55PipelineINT[input_t1](
+        MultiHeadAttentionSoftmax(),
+        _make_mha_softmax_inputs(input_low, input_high),
+        [],
+        exir_ops=[],
+        a16w8_quantization=True,
+        symmetric_io_quantization=True,
+        epsilon=2**-16,
+        atol=_MHA_ATOL,
+    )
+    pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op])
+    pipeline.run()
+
+
+# All cases hit the Ethos-U85 int16 ReduceSum silent-zero issue inside the
+# softmax decomposition. strict=False so the test target stays green both
+# on stock Vela 5.0 (cases XFAIL) and once the upstream Vela fix lands
+# (cases XPASS).
+# Upstream report:
+# https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/issues/23
+@common.parametrize("sweep_case", mha_softmax_sweep)
+@common.XfailIfNoCorstone320
+@pytest.mark.xfail(
+    reason=(
+        "Ethos-U85 int16 ReduceSum silent-zero in softmax decomposition: "
+        "https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/issues/23"
+    ),
+    strict=False,
+)
+def test_mha_softmax_a16w8_u85_INT(sweep_case: Tuple[float, float]) -> None:
+    input_low, input_high = sweep_case
+    pipeline = EthosU85PipelineINT[input_t1](
+        MultiHeadAttentionSoftmax(),
+        _make_mha_softmax_inputs(input_low, input_high),
+        [],
+        exir_ops=[],
+        a16w8_quantization=True,
+        symmetric_io_quantization=True,
+        epsilon=2**-16,
+        atol=_MHA_ATOL,
+    )
+    pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op])
+    pipeline.run()
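As context for the sweep above, the following standalone sketch (not part of the commit) illustrates what the range cases probe numerically: it fake-quantizes the pre-softmax inputs to symmetric int16 and reports the worst-case deviation of the resulting softmax from the float reference over a few of the same ranges. The scale choice, tensor layout, and seed mirror the test, but they are illustrative assumptions; the snippet runs in plain PyTorch and does not exercise the a16w8 quantizer, the lowered Ethos-U graph, or the ReduceSum decomposition issue the U85 cases xfail on.

import torch


def fake_quant_int16(x: torch.Tensor, low: float, high: float) -> torch.Tensor:
    # Symmetric int16 fake-quantization of the softmax input (illustrative
    # scale choice; not necessarily the scheme used by the Arm backend quantizer).
    scale = max(abs(low), abs(high)) / 32767.0
    return (x / scale).round().clamp(-32768, 32767) * scale


def max_softmax_error(low: float, high: float, seed: int = 42) -> float:
    gen = torch.Generator().manual_seed(seed)
    # Inputs in the (N, T, H, M, W) = (8, 1, 4, 1, 16) layout the MHA module reshapes to.
    x = torch.rand(8, 1, 4, 1, 16, generator=gen) * (high - low) + low
    ref = torch.softmax(x, dim=-1)                              # float reference
    quant = torch.softmax(fake_quant_int16(x, low, high), dim=-1)
    return (ref - quant).abs().max().item()


if __name__ == "__main__":
    for low, high in [(-0.01, 0.01), (-1.0, 1.0), (-10.0, 10.0), (-30.0, 30.0)]:
        print(f"[{low}, {high}]: max |softmax_fp32 - softmax_int16_input| = "
              f"{max_softmax_error(low, high):.5f}")

The FVP cases themselves are driven by the EthosU55/U85 pipelines above; with the Corstone FVPs installed they can be selected through pytest's standard -k filter (for example -k mha_softmax), while the XfailIfNoCorstone decorators cover environments where the FVP is absent.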

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ def define_arm_tests():
         "ops/test_rsqrt.py",
         "ops/test_slice.py",
         "ops/test_sigmoid.py",
+        "ops/test_softmax.py",
         "ops/test_sub.py",
         "ops/test_sum.py",
         "ops/test_tanh.py",
