-
Notifications
You must be signed in to change notification settings - Fork 966
Fix Conv1d w8a32 operator (#16607) #16607
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -438,26 +438,36 @@ def get_args_and_kwargs_mixed_w8a32_conv( | |
| torch.ops.aten.permute.default, | ||
| (other_inputs[0], [0, 2, 1]), # NCL -> NLC | ||
| ) | ||
| assert "val" in other_inputs[0].meta, "Missing val metadata on input node" | ||
| original_val = other_inputs[0].meta["val"] | ||
| assert original_val.fake_mode is not None, "fake_mode is None on input node" | ||
| with original_val.fake_mode: | ||
| transposed_inputs.meta["val"] = torch.ops.aten.permute.default( | ||
| original_val, [0, 2, 1] | ||
| ) | ||
| # Propagate val metadata for transposed_inputs | ||
| if "val" in other_inputs[0].meta: | ||
| original_val = other_inputs[0].meta["val"] | ||
| fake_mode = original_val.fake_mode | ||
| if fake_mode is not None: | ||
| with fake_mode: | ||
| transposed_val = torch.ops.aten.permute.default(original_val, [0, 2, 1]) | ||
| transposed_inputs.meta["val"] = transposed_val | ||
| else: | ||
| transposed_inputs.meta["val"] = torch.ops.aten.permute.default( | ||
| original_val, [0, 2, 1] | ||
| ) | ||
|
Comment on lines
+441
to
+452
|
||
| copy_node_metadata(transposed_inputs, other_inputs[0]) | ||
|
|
||
| transposed_weights = graph_module.graph.call_function( | ||
| torch.ops.aten.permute.default, | ||
| (weights_inputs[0], [2, 0, 1]), # NCL -> LNC | ||
| ) | ||
| assert "val" in weights_inputs[0].meta, "Missing val metadata on weight node" | ||
| original_val = weights_inputs[0].meta["val"] | ||
| assert original_val.fake_mode is not None, "fake_mode is None on weight node" | ||
| with original_val.fake_mode: | ||
| transposed_weights.meta["val"] = torch.ops.aten.permute.default( | ||
| original_val, [2, 0, 1] | ||
| ) | ||
| # Propagate val metadata for transposed_weights | ||
| if "val" in weights_inputs[0].meta: | ||
| original_val = weights_inputs[0].meta["val"] | ||
| fake_mode = original_val.fake_mode | ||
| if fake_mode is not None: | ||
| with fake_mode: | ||
| transposed_val = torch.ops.aten.permute.default(original_val, [2, 0, 1]) | ||
| transposed_weights.meta["val"] = transposed_val | ||
| else: | ||
| transposed_weights.meta["val"] = torch.ops.aten.permute.default( | ||
| original_val, [2, 0, 1] | ||
| ) | ||
|
Comment on lines
+441
to
+470
|
||
| copy_node_metadata(transposed_weights, weights_inputs[0]) | ||
|
|
||
| args = ( | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -718,7 +718,7 @@ def get_anchors( | |||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| cnn_weights = conv_layer.args[1] | ||||||||||||||||||||||||
| if hasattr(cnn_weights.meta, "tensor_meta"): | ||||||||||||||||||||||||
| if "tensor_meta" in cnn_weights.meta: | ||||||||||||||||||||||||
| cnn_weights_shape = cnn_weights.meta["tensor_meta"].shape | ||||||||||||||||||||||||
| # Bail if the channels are not multiple of 4 (SIMD) | ||||||||||||||||||||||||
| if cnn_weights_shape[0] % 4 != 0: | ||||||||||||||||||||||||
|
|
@@ -744,6 +744,18 @@ def get_anchors( | |||||||||||||||||||||||
| conv_layer, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| inputs = conv_layer.args[0] | ||||||||||||||||||||||||
| if "tensor_meta" in inputs.meta: | ||||||||||||||||||||||||
| inputs_shape = inputs.meta["tensor_meta"].shape | ||||||||||||||||||||||||
| # Bail if length != kernel size - Not yet supported | ||||||||||||||||||||||||
|
Comment on lines
+747
to
+750
|
||||||||||||||||||||||||
| if inputs_shape[-1] != cnn_weights_shape[2]: | ||||||||||||||||||||||||
|
Comment on lines
+750
to
+751
|
||||||||||||||||||||||||
| # Bail if length != kernel size - Not yet supported | |
| if inputs_shape[-1] != cnn_weights_shape[2]: | |
| # Bail only when the input length is smaller than the kernel size. | |
| # Conv1d supports input lengths greater than the kernel size. | |
| if inputs_shape[-1] < cnn_weights_shape[2]: |
Copilot
AI
Jan 29, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This check restricts the w8a32_conv pattern to only match when the input length equals the kernel size (3). While the comment indicates this is intentionally not yet supported, this is quite restrictive. Standard convolution operations typically support input lengths greater than or equal to the kernel size. The reference implementation in ref_implementations.py (lines 926-970) and the test in test_ref_implementations.py (lines 1156-1166 show length=5 with kernel=3) both support arbitrary input lengths. Consider whether this restriction is necessary, or if it should be relaxed to allow input_length >= kernel_size to enable the optimization in more cases.
Copilot
AI
Apr 8, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new anchor-bail condition inputs_shape[-1] != cnn_weights_shape[2] incorrectly restricts quantized_w8a32 Conv1d fusion to cases where input length equals kernel size (3). The operator’s fake/meta kernel and reference implementation support general input lengths (output length in_length - kernel + 1), and existing tests exercise in_length=5 with kernel=3 (see backends/cadence/aot/tests/test_ref_implementations.py:1170+). This check will prevent valid fusions and likely regress model coverage; it should be removed or replaced with the actual supported constraint (if any).
| inputs = conv_layer.args[0] | |
| if "tensor_meta" in inputs.meta: | |
| inputs_shape = inputs.meta["tensor_meta"].shape | |
| # Bail if length != kernel size - Not yet supported | |
| if inputs_shape[-1] != cnn_weights_shape[2]: | |
| return ( | |
| PartitionAnchors( | |
| empty=True, | |
| ), | |
| conv_layer, | |
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`get_args_and_kwargs_mixed_w8a32_conv` now conditionally propagates `meta["val"]` for the inserted permute nodes (and adds a fake_mode fallback). There doesn't appear to be any unit/integration test coverage exercising QuantFusion on a Conv1d -> quantized_w8a32_conv rewrite, so regressions here (e.g., missing/incorrect meta causing later passes to fail) may go unnoticed. Add a small test that runs QuantFusion on a minimal Conv1d graph and asserts the resulting graph contains the expected permutes plus
`cadence::quantized_w8a32_conv`, and that the graph can run FakeTensor/meta propagation without errors.