Skip to content

Commit 7850292

Browse files
Marco Giordanofacebook-github-bot
authored andcommitted
Fix Conv1d w8a32 operator (pytorch#16607)
Summary: #### Summary This diff fixes the Conv1d w8a32 operator by adding a transformation to the `val` attribute of the `other_inputs[0].meta` dictionary. Specifically, the `permute` operation is applied to the `original_val` tensor with the `fake_mode` context, and the resulting `transposed_val` is assigned to `transposed_inputs.meta["val"]`. Reviewed By: mcremon-meta Differential Revision: D89863750
1 parent 267a59d commit 7850292

2 files changed

Lines changed: 40 additions & 0 deletions

File tree

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,12 +432,40 @@ def get_args_and_kwargs_mixed_w8a32_conv(
432432
torch.ops.aten.permute.default,
433433
(other_inputs[0], [0, 2, 1]), # NCL -> NLC
434434
)
435+
# Propagate val metadata for transposed_inputs
436+
if "val" in other_inputs[0].meta:
437+
original_val = other_inputs[0].meta["val"]
438+
fake_mode = original_val.fake_mode
439+
if fake_mode is not None:
440+
with fake_mode:
441+
transposed_val = torch.ops.aten.permute.default(
442+
original_val, [0, 2, 1]
443+
)
444+
transposed_inputs.meta["val"] = transposed_val
445+
else:
446+
transposed_inputs.meta["val"] = torch.ops.aten.permute.default(
447+
original_val, [0, 2, 1]
448+
)
435449
copy_node_metadata(transposed_inputs, other_inputs[0])
436450

437451
transposed_weights = graph_module.graph.call_function(
438452
torch.ops.aten.permute.default,
439453
(weights_inputs[0], [2, 0, 1]), # NCL -> LNC
440454
)
455+
# Propagate val metadata for transposed_weights
456+
if "val" in weights_inputs[0].meta:
457+
original_val = weights_inputs[0].meta["val"]
458+
fake_mode = original_val.fake_mode
459+
if fake_mode is not None:
460+
with fake_mode:
461+
transposed_val = torch.ops.aten.permute.default(
462+
original_val, [2, 0, 1]
463+
)
464+
transposed_weights.meta["val"] = transposed_val
465+
else:
466+
transposed_weights.meta["val"] = torch.ops.aten.permute.default(
467+
original_val, [2, 0, 1]
468+
)
441469
copy_node_metadata(transposed_weights, weights_inputs[0])
442470

443471
args = (

backends/cadence/aot/quantizer/patterns.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,18 @@ def get_anchors(
651651
conv_layer,
652652
)
653653

654+
inputs = conv_layer.args[0]
655+
if "tensor_meta" in inputs.meta:
656+
inputs_shape = inputs.meta["tensor_meta"].shape
657+
# Bail if length != kernel size - Not yet supported
658+
if inputs_shape[-1] != cnn_weights_shape[2]:
659+
return (
660+
PartitionAnchors(
661+
empty=True,
662+
),
663+
conv_layer,
664+
)
665+
654666
return (
655667
PartitionAnchors(
656668
inputs=[],

0 commit comments

Comments
 (0)