Skip to content

Commit 5a920c3

Browse files
Arm backend: Preserve MXFP Conv2d output dtype (#20513)
Infer the MXFP Conv2d wrapper output dtype from the source Conv2d weight dtype, matching the MXFP linear path. Cast the custom op output back to bf16 when the original module is bf16, while keeping the MXFP TOSA op output in fp32. Add AO, export, and rewrite pass tests covering the default fp32 constructor path and inferred bf16 output preservation. Change-Id: I48fb70157439650b329d7db35fd794200fe1545d cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani Signed-off-by: Yufeng Shi <yufeng.shi@arm.com>
1 parent 11f363a commit 5a920c3

3 files changed

Lines changed: 118 additions & 3 deletions

File tree

backends/arm/ao_ext/ops/mxfp_conv2d_op.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@
3232
)
3333

3434

35+
_SUPPORTED_OUTPUT_DTYPES: set[torch.dtype] = {
36+
torch.float32,
37+
torch.bfloat16,
38+
}
39+
40+
3541
def _get_mx_elem_dtype(
3642
weight_qdata: torch.Tensor,
3743
weight_payload_dtype: str = "",
@@ -208,10 +214,12 @@ def __init__(
208214
groups: int,
209215
weight_dtype: MXFPDType,
210216
block_size: int,
217+
output_dtype: torch.dtype = torch.float32,
211218
) -> None:
212219
super().__init__()
213220
self.weight_dtype = mxfp_dtype_to_str(weight_dtype)
214221
self.block_size = block_size
222+
self.output_dtype = output_dtype
215223

216224
self.register_buffer("weight_qdata", weight_qdata, persistent=True)
217225
self.register_buffer("weight_scale", weight_scale, persistent=True)
@@ -233,7 +241,7 @@ def __init__(
233241
self.groups = groups
234242

235243
def forward(self, x: torch.Tensor) -> torch.Tensor:
236-
return torch.ops.tosa_mxfp.conv2d.default(
244+
output = torch.ops.tosa_mxfp.conv2d.default(
237245
x,
238246
self.weight_qdata,
239247
self.weight_scale,
@@ -245,6 +253,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
245253
self.block_size,
246254
self.weight_dtype,
247255
)
256+
if self.output_dtype != torch.float32:
257+
output = output.to(self.output_dtype)
258+
return output
248259

249260

250261
def transform_conv2d_to_mxfp(
@@ -276,6 +287,9 @@ def transform_conv2d_to_mxfp(
276287
)
277288

278289
bias = module.bias.detach().to(torch.float32) if module.bias is not None else None
290+
output_dtype = weight_ohwi.dtype
291+
if output_dtype not in _SUPPORTED_OUTPUT_DTYPES:
292+
raise ValueError(f"Unsupported output_dtype: {output_dtype}")
279293
return MXFPConv2dOp(
280294
weight_qdata,
281295
weight_scale,
@@ -286,4 +300,5 @@ def transform_conv2d_to_mxfp(
286300
module.groups,
287301
config.weight_dtype,
288302
config.block_size,
303+
output_dtype,
289304
)

backends/arm/test/misc/test_mxfp_conv2d_ao.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,61 @@ def test_mxfp_conv2d_quantize_supports_fp4_weights() -> None:
159159
)
160160

161161

162+
def test_mxfp_conv2d_preserves_bfloat16_output_dtype() -> None:
163+
model = Conv2dModule().eval().to(torch.bfloat16)
164+
to_mxfp(
165+
model,
166+
MXFPOpConfig(weight_dtype=torch.float8_e4m3fn),
167+
)
168+
169+
output = model(torch.randn(1, IN_CHANNELS, 8, 8, dtype=torch.bfloat16))
170+
171+
assert isinstance(model.conv, MXFPConv2dOp)
172+
assert model.conv.output_dtype == torch.bfloat16
173+
assert output.dtype == torch.bfloat16
174+
175+
176+
def test_mxfp_conv2d_op_output_dtype_constructor_arg() -> None:
177+
model = Conv2dModule().eval()
178+
config = MXFPOpConfig(weight_dtype=torch.float8_e4m3fn)
179+
to_mxfp(
180+
model,
181+
config,
182+
)
183+
assert isinstance(model.conv, MXFPConv2dOp)
184+
185+
fp32_conv = MXFPConv2dOp(
186+
model.conv.weight_qdata,
187+
model.conv.weight_scale,
188+
model.conv.bias,
189+
model.conv.stride,
190+
model.conv.padding,
191+
model.conv.dilation,
192+
model.conv.groups,
193+
config.weight_dtype,
194+
config.block_size,
195+
)
196+
bf16_conv = MXFPConv2dOp(
197+
model.conv.weight_qdata,
198+
model.conv.weight_scale,
199+
model.conv.bias,
200+
model.conv.stride,
201+
model.conv.padding,
202+
model.conv.dilation,
203+
model.conv.groups,
204+
config.weight_dtype,
205+
config.block_size,
206+
output_dtype=torch.bfloat16,
207+
)
208+
209+
test_input = torch.randn(1, IN_CHANNELS, 8, 8)
210+
211+
assert fp32_conv.output_dtype == torch.float32
212+
assert fp32_conv(test_input).dtype == torch.float32
213+
assert bf16_conv.output_dtype == torch.bfloat16
214+
assert bf16_conv(test_input).dtype == torch.bfloat16
215+
216+
162217
def _test_mxfp_conv2d_export_preserves_custom_op(config: MXFPOpConfig) -> None:
163218
model = Conv2dModule().eval()
164219
to_mxfp(model, config)
@@ -198,6 +253,33 @@ def test_mxfp6_e3m2_conv2d_export_preserves_custom_op() -> None:
198253
)
199254

200255

256+
def test_mxfp_conv2d_export_preserves_inferred_bfloat16_output_dtype() -> None:
257+
model = Conv2dModule().eval().to(torch.bfloat16)
258+
to_mxfp(
259+
model,
260+
MXFPOpConfig(weight_dtype=torch.float8_e4m3fn),
261+
)
262+
263+
exported = export(
264+
model,
265+
(torch.randn(1, IN_CHANNELS, 8, 8, dtype=torch.bfloat16),),
266+
strict=False,
267+
)
268+
269+
cast_nodes = [
270+
node
271+
for node in exported.graph_module.graph.nodes
272+
if node.op == "call_function" and node.target == torch.ops.aten.to.dtype
273+
]
274+
275+
assert len(cast_nodes) == 1
276+
assert cast_nodes[0].args[1] == torch.bfloat16
277+
assert cast_nodes[0].meta["val"].dtype == torch.bfloat16
278+
cast_input = cast_nodes[0].args[0]
279+
assert isinstance(cast_input, torch.fx.Node)
280+
assert cast_input.target == torch.ops.tosa_mxfp.conv2d.default
281+
282+
201283
def test_mxfp_conv2d_cpu_impl_matches_ref() -> None:
202284
ref_model = Conv2dModule().eval()
203285
test_model = Conv2dModule().eval()

backends/arm/test/passes/test_rewrite_mxfp_conv2d_pass.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,15 @@ def _nodes_from_target(
5959
def _rewrite_conv2d_module(
6060
config: MXFPOpConfig,
6161
bias: bool = True,
62+
model_dtype: torch.dtype = torch.float32,
6263
) -> tuple[torch.fx.GraphModule, list[torch.fx.Node], list[torch.fx.Node]]:
63-
model = _Conv2dModule(bias=bias).eval()
64+
model = _Conv2dModule(bias=bias).eval().to(model_dtype)
6465
to_mxfp(model, config, filter_fn=_is_conv2d)
65-
exported = export(model, (torch.randn(1, 32, 10, 12),), strict=False)
66+
exported = export(
67+
model,
68+
(torch.randn(1, 32, 10, 12, dtype=model_dtype),),
69+
strict=False,
70+
)
6671
tosa_spec = TosaSpecification.create_from_string("TOSA-1.1+FP+mxfp")
6772

6873
with TosaLoweringContext(tosa_spec):
@@ -113,6 +118,19 @@ def test_rewrite_mxfp_conv2d_restores_output_shape() -> None:
113118
assert tuple(output_node.meta["val"].shape) == (1, 8, 5, 6)
114119

115120

121+
def test_rewrite_mxfp_conv2d_preserves_inferred_bfloat16_output_cast() -> None:
122+
graph_module, _, conv_nodes = _rewrite_conv2d_module(
123+
MXFPOpConfig(),
124+
model_dtype=torch.bfloat16,
125+
)
126+
127+
output_node = graph_module.graph.output_node()
128+
129+
assert len(conv_nodes) == 1
130+
assert conv_nodes[0].meta["val"].dtype == torch.float32
131+
assert output_node.meta["val"][0].dtype == torch.bfloat16
132+
133+
116134
def test_rewrite_mxfp4_conv2d_marks_payloads() -> None:
117135
model = _Conv2dModule(bias=True).eval()
118136
to_mxfp(

0 commit comments

Comments
 (0)