Skip to content

Commit 6f6ff6a

Browse files
Marco Giordano and facebook-github-bot
authored and committed
Fix GRU w8a32 operator
Summary:

# Context
This diff fixes the reference implementation of the w8a32 GRU operator and enhances the operator's pattern matching.

# Mitigation
The reference implementation now has the right output dimension, and pattern matching now uses a safer check for the operator's parameters.

Reviewed By: hsharma35

Differential Revision: D90437262
1 parent 5bc996a commit 6f6ff6a

4 files changed

Lines changed: 33 additions & 8 deletions

File tree

backends/cadence/aot/ops_registrations.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2854,7 +2854,15 @@ def quantized_w8a32_gru_meta(
28542854
bias_hidden: torch.Tensor,
28552855
b_h_scale: float,
28562856
) -> torch.Tensor:
2857-
return hidden.new_empty((2, hidden.shape[-1]), dtype=torch.float32)
2857+
seq_len = inputs.shape[1]
2858+
assert seq_len == 1
2859+
# inputs comes in shape [batch, seq_len, input_size]
2860+
# hidden comes in shape [batch, seq_len, hidden_size]
2861+
# weights_inputs comes in shape [3 * hidden_size, input_size]
2862+
# weights_hidden comes in shape [3 * hidden_size, hidden_size]
2863+
# output comes in empty with shape [2, batch, seq_len, hidden_size]
2864+
# The first dimension stacks the output and the new hidden state
2865+
return hidden.new_empty((2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32)
28582866

28592867

28602868
# Validate that all meta kernels have reference implementations

backends/cadence/aot/quantizer/patterns.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -696,14 +696,16 @@ def get_anchors(
696696
)
697697

698698
# Bail if input or states are not multiple of 4 (SIMD)
699-
if gru_layer.args[0].meta["tensor_meta"].shape[-1] % 4 != 0:
699+
tensor_meta_0 = gru_layer.args[0].meta.get("tensor_meta", None)
700+
if tensor_meta_0 is None or tensor_meta_0.shape[-1] % 4 != 0:
700701
return (
701702
PartitionAnchors(
702703
empty=True,
703704
),
704705
gru_layer,
705706
)
706-
if gru_layer.args[1].meta["tensor_meta"].shape[-1] % 4 != 0:
707+
tensor_meta_1 = gru_layer.args[1].meta.get("tensor_meta", None)
708+
if tensor_meta_1 is None or tensor_meta_1.shape[-1] % 4 != 0:
707709
return (
708710
PartitionAnchors(
709711
empty=True,
@@ -718,13 +720,22 @@ def __init__(self, args, meta):
718720

719721
wrapper = Wrapper(tuple(gru_layer.args[2]), gru_layer.meta)
720722

723+
# Using SharedQuantizationSpec so that bias_hh has the same observer as bias_ih
724+
# Both biases get the same quantization scale to match the cpp operator
725+
bias_ih_node = wrapper.args[2]
726+
bias_ih_edge = (bias_ih_node, gru_layer)
727+
shared_bias_qspec = SharedQuantizationSpec(edge_or_node=bias_ih_edge)
728+
721729
return (
722730
PartitionAnchors(
723731
inputs=[],
724732
# pyre-fixme[6]: Expected `List[Tuple[Node, int]]` but got `List[Tuple[Wrapper, int]]`.
725733
weights=[(wrapper, 0), (wrapper, 1)],
726734
# pyre-fixme[6]: Expected `List[Union[Tuple[Node, int], Tuple[Node, int, DerivedQuantizationSpec]]]` but got `List[Tuple[Wrapper, int]]`.
727-
biases=[(wrapper, 2), (wrapper, 3)],
735+
biases=[
736+
(wrapper, 2), # bias_ih gets normal qspec
737+
(wrapper, 3, shared_bias_qspec), # bias_hh shares observer with bias_ih
738+
],
728739
output=[],
729740
others=[(gru_layer, 0), (gru_layer, 1)],
730741
),

backends/cadence/aot/ref_implementations.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1060,8 +1060,13 @@ def quantized_w8a32_gru(
10601060

10611061
assert new_hidden.shape == original_hidden_shape
10621062

1063-
new_hidden = new_hidden.view(-1)
1064-
return torch.stack([new_hidden, new_hidden], dim=0)
1063+
batch_size = inputs.shape[0]
1064+
input_dim = inputs.shape[1]
1065+
hidden_dim = hidden.shape[-1]
1066+
1067+
new_hidden_expanded = new_hidden.unsqueeze(1).expand(batch_size, input_dim, hidden_dim)
1068+
1069+
return torch.stack([new_hidden_expanded, new_hidden_expanded], dim=0)
10651070

10661071

10671072
@impl_tracked(m, "quantized_conv2d_nhwc.per_tensor")

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2975,10 +2975,11 @@ def test_quantized_w8a32_gru(
29752975
torch.float32,
29762976
f"Output dtype should be float32 in {name}",
29772977
)
2978+
expected_shape = (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1])
29782979
self.assertEqual(
29792980
output.shape,
2980-
(2, hidden.shape[-1]),
2981-
f"Output shape should match {(2, hidden.shape[-1])} in {name}",
2981+
expected_shape,
2982+
f"Output shape should match {expected_shape} in {name}",
29822983
)
29832984
assert isinstance(output, torch.Tensor)
29842985

0 commit comments

Comments (0)