Skip to content

Commit ea412d8

Browse files
Arm backend: Add FP8 infrastructure and cast support (pytorch#19702)
Change-Id: I2757fcaf61cb1910a8b65e3c49853c677239b2f0
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani
Signed-off-by: Yufeng Shi <yufeng.shi@arm.com>
1 parent b4a9e72 commit ea412d8

8 files changed

Lines changed: 205 additions & 8 deletions

File tree

backends/arm/operator_support/to_dim_order_copy_support.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,26 @@ def _merge_supported_types(
120120
torch.float32,
121121
],
122122
}
123+
SUPPORTED_FP8E4M3_EXTENSION_DTYPES: SupportedTypeDict = {
124+
torch.float16: [torch.float8_e4m3fn],
125+
torch.float32: [torch.float8_e4m3fn],
126+
torch.float8_e4m3fn: [torch.float16, torch.float32],
127+
}
128+
SUPPORTED_FP8E5M2_EXTENSION_DTYPES: SupportedTypeDict = {
129+
torch.float16: [torch.float8_e5m2],
130+
torch.float32: [torch.float8_e5m2],
131+
torch.float8_e5m2: [torch.float16, torch.float32],
132+
}
133+
SUPPORTED_BF16_FP8E4M3_EXTENSION_DTYPES: SupportedTypeDict = {
134+
torch.bfloat16: [torch.float8_e4m3fn],
135+
torch.float8_e4m3fn: [torch.bfloat16],
136+
}
137+
SUPPORTED_BF16_FP8E5M2_EXTENSION_DTYPES: SupportedTypeDict = {
138+
torch.bfloat16: [torch.float8_e5m2],
139+
torch.float8_e5m2: [torch.bfloat16],
140+
}
123141

124-
def is_node_tosa_supported(
142+
def is_node_tosa_supported( # noqa: C901
125143
self, node: fx.Node, tosa_spec: TosaSpecification
126144
) -> bool:
127145
"""Return True if the node is supported by TOSA.
@@ -148,6 +166,26 @@ def is_node_tosa_supported(
148166
supported_dtypes = self._merge_supported_types(
149167
self.SUPPORTED_BF16_EXTENSION_DTYPES, supported_dtypes
150168
)
169+
if tosa_spec.support_extension("fp8e4m3"):
170+
supported_dtypes = self._merge_supported_types(
171+
self.SUPPORTED_FP8E4M3_EXTENSION_DTYPES, supported_dtypes
172+
)
173+
if tosa_spec.support_extension("fp8e5m2"):
174+
supported_dtypes = self._merge_supported_types(
175+
self.SUPPORTED_FP8E5M2_EXTENSION_DTYPES, supported_dtypes
176+
)
177+
if tosa_spec.support_extension("bf16") and tosa_spec.support_extension(
178+
"fp8e4m3"
179+
):
180+
supported_dtypes = self._merge_supported_types(
181+
self.SUPPORTED_BF16_FP8E4M3_EXTENSION_DTYPES, supported_dtypes
182+
)
183+
if tosa_spec.support_extension("bf16") and tosa_spec.support_extension(
184+
"fp8e5m2"
185+
):
186+
supported_dtypes = self._merge_supported_types(
187+
self.SUPPORTED_BF16_FP8E5M2_EXTENSION_DTYPES, supported_dtypes
188+
)
151189

152190
if len(node.all_input_nodes) != 1:
153191
self.reporter.report_reject(

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,10 @@ def tosa_support_factory(
295295
disallowed_dtypes = [torch.float64]
296296
if not tosa_spec.support_extension("bf16"):
297297
disallowed_dtypes.append(torch.bfloat16)
298+
if not tosa_spec.support_extension("fp8e4m3"):
299+
disallowed_dtypes.append(torch.float8_e4m3fn)
300+
if not tosa_spec.support_extension("fp8e5m2"):
301+
disallowed_dtypes.append(torch.float8_e5m2)
298302
if tosa_spec.is_U55_subset:
299303
disallowed_dtypes.append(torch.bool)
300304
negative_checks.append(

backends/arm/operators/op_tosa_identity.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def define_node(
4646
ts.DType.FP16,
4747
ts.DType.FP32,
4848
ts.DType.BF16,
49+
ts.DType.FP8E4M3,
50+
ts.DType.FP8E5M2,
4951
],
5052
self.tosa_spec,
5153
)

backends/arm/process_node.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,22 @@
3030

3131
def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
3232
tensor = tensor.detach().cpu().contiguous()
33-
if tensor.dtype == torch.bfloat16:
33+
if tensor.dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2):
3434
try:
3535
import ml_dtypes # type: ignore[import-not-found]
3636
except ImportError as e:
3737
raise RuntimeError(
38-
"ml_dtypes is required to serialize bfloat16 tensors for TOSA. Have you run setup.sh?"
38+
f"ml_dtypes is required to serialize {tensor.dtype} tensors for TOSA. "
39+
"Have you run setup.sh?"
3940
) from e
40-
return tensor.view(torch.uint16).numpy().view(ml_dtypes.bfloat16)
41+
42+
ml_dtype_map = {
43+
torch.bfloat16: (torch.uint16, ml_dtypes.bfloat16),
44+
torch.float8_e4m3fn: (torch.uint8, ml_dtypes.float8_e4m3fn),
45+
torch.float8_e5m2: (torch.uint8, ml_dtypes.float8_e5m2),
46+
}
47+
storage_dtype, ml_dtype = ml_dtype_map[tensor.dtype]
48+
return tensor.view(storage_dtype).numpy().view(ml_dtype)
4149
else:
4250
return tensor.numpy()
4351

backends/arm/test/ops/test_to_copy.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,130 @@ def test_to_tosa_FP_bf16_with_extension():
116116
pipeline.run()
117117

118118

119+
_TO_COPY_TEST_DATA_FP_FP8 = {
120+
"fp32_to_fp8e4m3": lambda: (
121+
torch.rand((1, 2, 3, 4), dtype=torch.float32),
122+
torch.float8_e4m3fn,
123+
"fp8e4m3",
124+
),
125+
"fp16_to_fp8e5m2": lambda: (
126+
torch.rand((1, 2, 3, 4), dtype=torch.float16),
127+
torch.float8_e5m2,
128+
"fp8e5m2",
129+
),
130+
"fp8e4m3_to_fp32": lambda: (
131+
torch.rand((1, 2, 3, 4), dtype=torch.float32).to(torch.float8_e4m3fn),
132+
torch.float32,
133+
"fp8e4m3",
134+
),
135+
"fp8e5m2_to_fp16": lambda: (
136+
torch.rand((1, 2, 3, 4), dtype=torch.float32).to(torch.float8_e5m2),
137+
torch.float16,
138+
"fp8e5m2",
139+
),
140+
}
141+
142+
143+
def test_to_tosa_FP_fp8e4m3_requires_extension():
144+
test_tensor = torch.rand((1, 2, 3, 4), dtype=torch.float32)
145+
pipeline = OpNotSupportedPipeline[input_t1](
146+
Cast(torch.float8_e4m3fn),
147+
(test_tensor,),
148+
{
149+
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1
150+
},
151+
)
152+
pipeline.run()
153+
154+
155+
def test_to_tosa_FP_fp8e5m2_requires_extension():
156+
test_tensor = torch.rand((1, 2, 3, 4), dtype=torch.float16)
157+
pipeline = OpNotSupportedPipeline[input_t1](
158+
Cast(torch.float8_e5m2),
159+
(test_tensor,),
160+
{
161+
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1
162+
},
163+
)
164+
pipeline.run()
165+
166+
167+
def test_to_tosa_FP_bf16_to_fp8e4m3_requires_both_extensions():
168+
test_tensor = torch.rand((1, 2, 3, 4), dtype=torch.bfloat16)
169+
pipeline = OpNotSupportedPipeline[input_t1](
170+
Cast(torch.float8_e4m3fn),
171+
(test_tensor,),
172+
{
173+
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1
174+
},
175+
tosa_extensions=["bf16"],
176+
)
177+
pipeline.run()
178+
179+
180+
def test_to_tosa_FP_bf16_to_fp8e5m2_requires_both_extensions():
181+
test_tensor = torch.rand((1, 2, 3, 4), dtype=torch.bfloat16)
182+
pipeline = OpNotSupportedPipeline[input_t1](
183+
Cast(torch.float8_e5m2),
184+
(test_tensor,),
185+
{
186+
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1
187+
},
188+
tosa_extensions=["bf16"],
189+
)
190+
pipeline.run()
191+
192+
193+
@common.parametrize("test_data", _TO_COPY_TEST_DATA_FP_FP8)
194+
def test_to_tosa_FP_fp8_with_extension(test_data: Tuple):
195+
test_tensor, new_dtype, tosa_extension = test_data()
196+
pipeline = TosaPipelineFP[input_t1](
197+
Cast(new_dtype),
198+
(test_tensor,),
199+
aten_op=[],
200+
exir_op=[],
201+
tosa_extensions=[tosa_extension],
202+
)
203+
pipeline.run()
204+
205+
206+
_TO_COPY_TEST_DATA_BF16_FP8 = {
207+
"bf16_to_fp8e4m3": lambda: (
208+
torch.rand((1, 2, 3, 4), dtype=torch.bfloat16),
209+
torch.float8_e4m3fn,
210+
["bf16", "fp8e4m3"],
211+
),
212+
"fp8e4m3_to_bf16": lambda: (
213+
torch.rand((1, 2, 3, 4), dtype=torch.float32).to(torch.float8_e4m3fn),
214+
torch.bfloat16,
215+
["bf16", "fp8e4m3"],
216+
),
217+
"bf16_to_fp8e5m2": lambda: (
218+
torch.rand((1, 2, 3, 4), dtype=torch.bfloat16),
219+
torch.float8_e5m2,
220+
["bf16", "fp8e5m2"],
221+
),
222+
"fp8e5m2_to_bf16": lambda: (
223+
torch.rand((1, 2, 3, 4), dtype=torch.float32).to(torch.float8_e5m2),
224+
torch.bfloat16,
225+
["bf16", "fp8e5m2"],
226+
),
227+
}
228+
229+
230+
@common.parametrize("test_data", _TO_COPY_TEST_DATA_BF16_FP8)
231+
def test_to_tosa_FP_bf16_fp8_with_extensions(test_data: Tuple):
232+
test_tensor, new_dtype, tosa_extensions = test_data()
233+
pipeline = TosaPipelineFP[input_t1](
234+
Cast(new_dtype),
235+
(test_tensor,),
236+
aten_op=[],
237+
exir_op=[],
238+
tosa_extensions=tosa_extensions,
239+
)
240+
pipeline.run()
241+
242+
119243
@common.parametrize("test_data", _TO_COPY_TEST_DATA_FP)
120244
@common.SkipIfNoModelConverter
121245
def test_to_vgf_no_quant(test_data: Tuple):

backends/arm/test/runner_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262
torch.int16: np.int16,
6363
torch.int32: np.int32,
6464
torch.int64: np.int64,
65+
torch.float8_e4m3fn: np.uint8,
66+
torch.float8_e5m2: np.uint8,
6567
torch.float16: np.float16,
6668
torch.float32: np.float32,
6769
torch.float64: np.float64,
@@ -190,6 +192,10 @@ def torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
190192
if tensor.dtype == torch.bfloat16:
191193
# NumPy doesn't support bfloat16, so use uint16 instead. The dtype is inferred from the model anyway.
192194
tensor = tensor.view(torch.uint16)
195+
elif tensor.dtype == torch.float8_e4m3fn:
196+
tensor = tensor.view(torch.uint8)
197+
elif tensor.dtype == torch.float8_e5m2:
198+
tensor = tensor.view(torch.uint8)
193199
return tensor.numpy()
194200

195201

backends/arm/tosa/mapping.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ def map_dtype(data_type: torch.dtype) -> Any:
106106
torch.int32: ts.DType.INT32,
107107
torch.int: ts.DType.INT32,
108108
torch.bool: ts.DType.BOOL,
109+
torch.float8_e4m3fn: ts.DType.FP8E4M3,
110+
torch.float8_e5m2: ts.DType.FP8E5M2,
109111
}
110112
if data_type not in dtype_map:
111113
raise ValueError(f"Unknown type: {data_type}")
@@ -231,6 +233,12 @@ def __validate(self, tosa_spec: TosaSpecification) -> bool:
231233
case ts.DType.BF16:
232234
if not tosa_spec.support_extension("bf16"):
233235
return False
236+
case ts.DType.FP8E4M3:
237+
if not tosa_spec.support_extension("fp8e4m3"):
238+
return False
239+
case ts.DType.FP8E5M2:
240+
if not tosa_spec.support_extension("fp8e5m2"):
241+
return False
234242

235243
return True
236244

backends/test/harness/tester.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -439,17 +439,24 @@ def _assert_outputs_equal(
439439
f"\tMismatched count: {(model != ref).sum().item()} / {model.numel()}\n"
440440
)
441441
else:
442+
# torch.allclose() does not have a CPU implementation for FP8 tensors
443+
# in some PyTorch builds, so compare FP8 outputs in float32 instead.
444+
compare_model = model
445+
compare_ref = ref
446+
if model.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
447+
compare_model = model.to(torch.float32)
448+
compare_ref = ref.to(torch.float32)
442449
assert torch.allclose(
443-
model,
444-
ref,
450+
compare_model,
451+
compare_ref,
445452
atol=atol,
446453
rtol=rtol,
447454
equal_nan=True,
448455
), (
449456
f"Output {i} does not match reference output.\n"
450457
f"\tGiven atol: {atol}, rtol: {rtol}.\n"
451458
f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n"
452-
f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref).to(torch.double))}.\n"
459+
f"\tDifference: max: {torch.max(compare_model-compare_ref)}, abs: {torch.max(torch.abs(compare_model-compare_ref))}, mean abs error: {torch.mean(torch.abs(compare_model-compare_ref).to(torch.double))}.\n"
453460
f"\t-- Model vs. Reference --\n"
454461
f"\t Numel: {model.numel()}, {ref.numel()}\n"
455462
f"\tMedian: {model.median()}, {ref.median()}\n"

0 commit comments

Comments
 (0)