Skip to content

Commit e3d5de2

Browse files
Arm backend: Preserve MXFP linear output dtype (#20487)
Infer the output dtype of MXFP linear replacements from the source nn.Linear module. Keep the internal MXFP custom op output in FP32, and insert a cast back to the inferred dtype when needed.

This lets BF16 models keep BF16 outputs from MXFP linear layers, which keeps SDPA input and attention mask dtypes compatible during export.

Add Qwen3 VL layer coverage for MXFP8 BF16 attention, MLP, and decoder layers.

Change-Id: Id6143ff330aeeca0815756c5468efb9930ac185f

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani

Signed-off-by: Yufeng Shi <yufeng.shi@arm.com>
1 parent d0916de commit e3d5de2

4 files changed

Lines changed: 131 additions & 3 deletions

File tree

backends/arm/ao_ext/ops/mxfp_linear_op.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@
3333
)
3434

3535

36+
_SUPPORTED_OUTPUT_DTYPES: set[torch.dtype] = {
37+
torch.float32,
38+
torch.bfloat16,
39+
}
40+
41+
3642
def _get_mx_elem_dtype(
3743
weight_qdata: torch.Tensor,
3844
weight_payload_dtype: str = "",
@@ -139,10 +145,12 @@ def __init__(
139145
bias: torch.Tensor | None,
140146
weight_dtype: MXFPDType,
141147
block_size: int,
148+
output_dtype: torch.dtype = torch.float32,
142149
) -> None:
143150
super().__init__()
144151
self.weight_dtype = mxfp_dtype_to_str(weight_dtype)
145152
self.block_size = block_size
153+
self.output_dtype = output_dtype
146154

147155
self.register_buffer("weight_qdata", weight_qdata, persistent=True)
148156
self.register_buffer("weight_scale", weight_scale, persistent=True)
@@ -159,14 +167,17 @@ def __init__(
159167
)
160168

161169
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the MXFP linear custom op and return the result in ``output_dtype``.

    Args:
        x: Input activations passed straight to the custom op.

    Returns:
        The custom op result, cast to ``self.output_dtype`` unless that
        dtype is ``torch.float32``.
    """
    result = torch.ops.tosa_mxfp.linear.default(
        x,
        self.weight_qdata,
        self.weight_scale,
        self.bias,
        self.block_size,
        self.weight_dtype,
    )
    # The custom op output is kept in FP32; a cast is only emitted when the
    # module was configured for a different output dtype.
    if self.output_dtype == torch.float32:
        return result
    return result.to(self.output_dtype)
170181

171182

172183
def transform_linear_to_mxfp(
@@ -196,10 +207,14 @@ def transform_linear_to_mxfp(
196207
weight_scale = weight_scale.unsqueeze(0)
197208

198209
bias = module.bias.detach().to(torch.float32) if module.bias is not None else None
210+
output_dtype = weight.dtype
211+
if output_dtype not in _SUPPORTED_OUTPUT_DTYPES:
212+
raise ValueError(f"Unsupported output_dtype: {output_dtype}")
199213
return MXFPLinearOp(
200214
weight_qdata,
201215
weight_scale,
202216
bias,
203217
config.weight_dtype,
204218
config.block_size,
219+
output_dtype,
205220
)

backends/arm/test/misc/test_mxfp_linear_ao.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,53 @@ def _is_selected_linear(module: torch.nn.Module, fqn: str) -> bool:
9898
assert isinstance(model.skipped, torch.nn.Linear)
9999

100100

101+
def test_mxfp_linear_preserves_bfloat16_output_dtype() -> None:
    """A BF16 linear converted by ``to_mxfp`` keeps producing BF16 outputs."""
    model = LinearModule().eval().to(torch.bfloat16)
    config = MXFPOpConfig(weight_dtype=torch.float8_e4m3fn)
    to_mxfp(model, config)

    bf16_input = torch.randn(4, 32, dtype=torch.bfloat16)
    output = model(bf16_input)

    converted = model.linear
    assert isinstance(converted, MXFPLinearOp)
    # The output dtype is inferred from the BF16 source module.
    assert converted.output_dtype == torch.bfloat16
    assert output.dtype == torch.bfloat16
113+
114+
115+
def test_mxfp_linear_op_output_dtype_constructor_arg() -> None:
    """``MXFPLinearOp`` defaults to FP32 output and honours ``output_dtype``."""
    model = LinearModule().eval()
    config = MXFPOpConfig(weight_dtype=torch.float8_e4m3fn)
    to_mxfp(model, config)
    converted = model.linear
    assert isinstance(converted, MXFPLinearOp)

    # Both variants reuse the quantized buffers of the converted layer and
    # differ only in the requested output dtype.
    shared_args = (
        converted.weight_qdata,
        converted.weight_scale,
        converted.bias,
        config.weight_dtype,
        config.block_size,
    )
    fp32_linear = MXFPLinearOp(*shared_args)
    bf16_linear = MXFPLinearOp(*shared_args, output_dtype=torch.bfloat16)

    test_input = torch.randn(4, 32)

    expectations = (
        (fp32_linear, torch.float32),
        (bf16_linear, torch.bfloat16),
    )
    for linear_op, expected_dtype in expectations:
        assert linear_op.output_dtype == expected_dtype
        assert linear_op(test_input).dtype == expected_dtype
146+
147+
101148
def _test_mxfp_linear_export_preserves_custom_op(config: MXFPOpConfig) -> None:
102149
model = LinearModule().eval()
103150
to_mxfp(model, config)
@@ -135,3 +182,26 @@ def test_mxfp6_e3m2_linear_export_preserves_custom_op() -> None:
135182
_test_mxfp_linear_export_preserves_custom_op(
136183
MXFPOpConfig(weight_dtype=DTYPE_FP6_E3M2)
137184
)
185+
186+
187+
def test_mxfp_linear_export_preserves_inferred_bfloat16_output_dtype() -> None:
    """Exporting a BF16 MXFP linear yields one BF16 cast fed by the custom op."""
    model = LinearModule().eval().to(torch.bfloat16)
    config = MXFPOpConfig(weight_dtype=torch.float8_e4m3fn)
    to_mxfp(model, config)

    example_inputs = (torch.randn(4, 32, dtype=torch.bfloat16),)
    exported = export(model, example_inputs, strict=False)

    def _is_dtype_cast(node: torch.fx.Node) -> bool:
        # Matches the aten.to.dtype call produced by the output cast.
        return node.op == "call_function" and node.target == torch.ops.aten.to.dtype

    cast_nodes = [
        node for node in exported.graph_module.graph.nodes if _is_dtype_cast(node)
    ]

    assert len(cast_nodes) == 1
    cast_node = cast_nodes[0]
    assert cast_node.args[1] == torch.bfloat16
    assert cast_node.meta["val"].dtype == torch.bfloat16
    # The cast must consume the MXFP linear custom op directly.
    cast_input = cast_node.args[0]
    assert isinstance(cast_input, torch.fx.Node)
    assert cast_input.target == torch.ops.tosa_mxfp.linear.default

backends/arm/test/models/Qwen3_VL/test_qwen3_vl_layers.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,31 @@ def test_qwen3_vl_tosa_mxfp8_fp32(test_case: Qwen3VLTestCase):
560560
pipeline.run()
561561

562562

563+
@common.parametrize("test_case", TOSA_MXFP8_TEST_CASES)
def test_qwen3_vl_tosa_mxfp8_bf16(test_case: Qwen3VLTestCase):
    """Run the MXFP8 TOSA FP pipeline on a Qwen3 VL layer converted to BF16."""
    fp_model, fp_inputs = test_case.model_cls.prepare_model_and_inputs()
    model, inputs = _to_bfloat16(fp_model, fp_inputs)
    config = MXFPOpConfig(weight_dtype=torch.float8_e4m3fn)

    # Pipeline settings: lower with the bf16 and mxfp TOSA extensions enabled.
    pipeline_kwargs = dict(
        aten_op=aten_op_mxfp_linear,
        exir_op=[],
        filter_fn=_is_linear,
        frobenius_threshold=0.05,
        cosine_threshold=0.995,
        mxfp_config=config,
        tosa_version="1.1",
        tosa_extensions=["bf16", "mxfp"],
    )

    with torch.no_grad():
        pipeline = MXFPTosaPipelineFP[input_t](model, inputs, **pipeline_kwargs)
        pipeline.run()
586+
587+
563588
@common.SkipIfNoModelConverter
564589
@common.parametrize(
565590
"test_case",

backends/arm/test/passes/test_rewrite_mxfp_linear_pass.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,15 @@ def _get_nodes_from_target(
5454

5555
def _rewrite_linear_module(
5656
config: MXFPOpConfig,
57+
model_dtype: torch.dtype = torch.float32,
5758
) -> tuple[torch.fx.GraphModule, list[torch.fx.Node], list[torch.fx.Node]]:
58-
model = _LinearModule(bias=True).eval()
59+
model = _LinearModule(bias=True).eval().to(model_dtype)
5960
to_mxfp(model, config, filter_fn=_is_linear)
60-
exported = export(model, (torch.randn(4, 5, 32),), strict=False)
61+
exported = export(
62+
model,
63+
(torch.randn(4, 5, 32, dtype=model_dtype),),
64+
strict=False,
65+
)
6166
tosa_spec = TosaSpecification.create_from_string("TOSA-1.1+FP+mxfp")
6267

6368
with TosaLoweringContext(tosa_spec):
@@ -98,6 +103,19 @@ def test_rewrite_mxfp_linear_replaces_custom_op() -> None:
98103
assert tuple(output_node.meta["val"][0].shape) == (4, 5, 8)
99104

100105

106+
def test_rewrite_mxfp_linear_preserves_inferred_bfloat16_output_cast() -> None:
    """After the rewrite, the matmul stays FP32 while the graph output is BF16."""
    rewritten = _rewrite_linear_module(MXFPOpConfig(), model_dtype=torch.bfloat16)
    # The second tuple element is not needed by this test.
    graph_module, _, matmul_nodes = rewritten

    output_node = graph_module.graph.output_node()

    assert len(matmul_nodes) == 1
    matmul_val = matmul_nodes[0].meta["val"]
    assert matmul_val.dtype == torch.float32
    first_output_val = output_node.meta["val"][0]
    assert first_output_val.dtype == torch.bfloat16
117+
118+
101119
def test_rewrite_mxfp6_linear_marks_payload_dtype() -> None:
102120
graph_module, cast_nodes, matmul_nodes = _rewrite_linear_module(
103121
MXFPOpConfig(weight_dtype=DTYPE_FP6_E2M3)

0 commit comments

Comments
 (0)