
Commit 493e84a

Andrew Pullin authored and facebook-github-bot committed
Add quantizable GRU and RNN modules for ARM backend (#17141)
Summary:
Adds quantizable versions of the GRU and RNN modules that can be used with PyTorch quantization-aware training (QAT) for the ARM backend. The standard nn.GRU and nn.RNN are opaque composite ops that the quantizer cannot annotate. These modules decompose the RNN operations into nn.Linear + FloatFunctional so that QAT observers can be inserted at each arithmetic boundary.

## New modules:
- `GRUCell`, `_GRUSingleLayer`, `_GRULayer`, `GRU`
- `RNNCell`, `_RNNSingleLayer`, `_RNNLayer`, `RNN`

## Features:
- `from_float()` class method to convert from nn.GRU/nn.RNN
- Multi-layer support
- Bidirectional support
- Both tanh and relu nonlinearities (for RNN)

## Usage:
```python
import torch

from executorch.backends.arm.quantizable import GRU, RNN

# Create a quantizable GRU directly
model = GRU(input_size=10, hidden_size=20, num_layers=2)

# Or convert from an existing nn.GRU
eager_model = torch.nn.GRU(10, 20, 2)
eager_model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
quantizable_model = GRU.from_float(eager_model)
```

Differential Revision: D92059608
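For symmetry with the GRU usage above, here is a minimal RNN sketch (an editor's addition, not from the commit message; the `nonlinearity` keyword is an assumption, mirroring nn.RNN, which the relu/tanh support listed above suggests):

```python
import torch

from executorch.backends.arm.quantizable import RNN

# Assumption: RNN mirrors nn.RNN's constructor, including nonlinearity=.
model = RNN(input_size=10, hidden_size=20, num_layers=1, nonlinearity="relu")

x = torch.randn(5, 3, 10)  # (seq_len, batch, input_size); batch_first=False
out, h = model(x)          # expected: out (5, 3, 20), h (1, 3, 20)
```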
1 parent 126507c commit 493e84a

7 files changed

Lines changed: 1179 additions & 0 deletions


backends/arm/quantizable/TARGETS

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

runtime.python_library(
    name = "quantizable",
    srcs = glob(["*.py"]),
    deps = [
        "//caffe2:torch",
    ],
)
backends/arm/quantizable/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .gru import GRU  # noqa
from .rnn import RNN  # noqa

backends/arm/quantizable/gru.py

Lines changed: 364 additions & 0 deletions
@@ -0,0 +1,364 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Quantizable GRU modules following the torch.ao.nn.quantizable.LSTM pattern.

The standard nn.GRU is an opaque composite op that the quantizer cannot
annotate. This module decomposes GRU into nn.Linear + FloatFunctional
so that QAT observers can be inserted at each arithmetic boundary.

GRU cell equations:
    r_t = sigmoid(x_t @ W_ir.T + b_ir + h_{t-1} @ W_hr.T + b_hr)
    z_t = sigmoid(x_t @ W_iz.T + b_iz + h_{t-1} @ W_hz.T + b_hz)
    n_t = tanh(x_t @ W_in.T + b_in + r_t * (h_{t-1} @ W_hn.T + b_hn))
    h_t = (1 - z_t) * n_t + z_t * h_{t-1}
"""

from typing import List, Optional, Tuple

import torch
from torch import nn, Tensor

class GRUCell(nn.Module):
    """A quantizable GRU cell with FloatFunctional ops for each arithmetic boundary."""

    _FLOAT_MODULE = nn.GRUCell

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # Input projections: x_t -> [r, z, n] (3*hidden_size)
        self.input_linear = nn.Linear(
            input_size, 3 * hidden_size, bias=bias, **factory_kwargs
        )
        # Hidden projections: h_{t-1} -> [r, z, n] (3*hidden_size)
        self.hidden_linear = nn.Linear(
            hidden_size, 3 * hidden_size, bias=bias, **factory_kwargs
        )

        # Gate activations
        self.reset_gate = nn.Sigmoid()
        self.update_gate = nn.Sigmoid()
        self.new_gate = nn.Tanh()

        # FloatFunctional for each observable arithmetic op
        self.add_r = torch.ao.nn.quantized.FloatFunctional()  # input_r + hidden_r
        self.add_z = torch.ao.nn.quantized.FloatFunctional()  # input_z + hidden_z
        self.mul_r_nh = torch.ao.nn.quantized.FloatFunctional()  # r_t * hidden_n
        self.add_n = torch.ao.nn.quantized.FloatFunctional()  # input_n + r*hidden_n
        self.mul_1mz_n = torch.ao.nn.quantized.FloatFunctional()  # (1-z) * n
        self.mul_z_h = torch.ao.nn.quantized.FloatFunctional()  # z * h_{t-1}
        self.add_h = torch.ao.nn.quantized.FloatFunctional()  # (1-z)*n + z*h

    def forward(self, x: Tensor, hidden: Optional[Tensor] = None) -> Tensor:
        if hidden is None:
            hidden = torch.zeros(x.shape[0], self.hidden_size, device=x.device)

        igates = self.input_linear(x)
        hgates = self.hidden_linear(hidden)

        # Split into r, z, n components
        H = self.hidden_size
        input_r, input_z, input_n = (
            igates[:, :H],
            igates[:, H : 2 * H],
            igates[:, 2 * H :],
        )
        hidden_r, hidden_z, hidden_n = (
            hgates[:, :H],
            hgates[:, H : 2 * H],
            hgates[:, 2 * H :],
        )

        r_t = self.reset_gate(self.add_r.add(input_r, hidden_r))
        z_t = self.update_gate(self.add_z.add(input_z, hidden_z))
        n_t = self.new_gate(self.add_n.add(input_n, self.mul_r_nh.mul(r_t, hidden_n)))

        h_t = self.add_h.add(
            self.mul_1mz_n.mul(1.0 - z_t, n_t),
            self.mul_z_h.mul(z_t, hidden),
        )
        return h_t

    @classmethod
    def from_params(
        cls,
        wi: Tensor,
        wh: Tensor,
        bi: Optional[Tensor] = None,
        bh: Optional[Tensor] = None,
    ) -> "GRUCell":
        input_size = wi.shape[1]
        hidden_size = wh.shape[1]
        cell = cls(input_size, hidden_size, bias=(bi is not None))
        cell.input_linear.weight = nn.Parameter(wi)
        if bi is not None:
            cell.input_linear.bias = nn.Parameter(bi)
        cell.hidden_linear.weight = nn.Parameter(wh)
        if bh is not None:
            cell.hidden_linear.bias = nn.Parameter(bh)
        return cell

    @classmethod
    def from_float(cls, other, use_precomputed_fake_quant=False):
        assert type(other) is cls._FLOAT_MODULE
        assert hasattr(other, "qconfig"), "The float module must have 'qconfig'"
        observed = cls.from_params(
            other.weight_ih,
            other.weight_hh,
            other.bias_ih,
            other.bias_hh,
        )
        observed.qconfig = other.qconfig
        observed.input_linear.qconfig = other.qconfig
        observed.hidden_linear.qconfig = other.qconfig
        return observed

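# --- Editor's sketch (not part of the commit): single-step usage -------------
# A GRUCell maps a (batch, input_size) input plus an optional
# (batch, hidden_size) state to the next hidden state. Sizes are illustrative.
#
#     cell = GRUCell(input_size=4, hidden_size=8)
#     h1 = cell(torch.randn(2, 4))        # zero initial state -> (2, 8)
#     h2 = cell(torch.randn(2, 4), h1)    # carry the state forward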
class _GRUSingleLayer(nn.Module):
    """A single one-directional GRU layer that processes a sequence."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.cell = GRUCell(input_size, hidden_size, bias=bias, **factory_kwargs)

    def forward(
        self,
        x: Tensor,
        hidden: Optional[Tensor] = None,
        reverse: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        result = []
        seq_len = x.shape[0]
        indices = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
        for i in indices:
            hidden = self.cell(x[i], hidden)
            result.append(hidden)
        if reverse:
            result.reverse()
        return torch.stack(result, 0), hidden

    @classmethod
    def from_params(cls, *args, **kwargs):
        cell = GRUCell.from_params(*args, **kwargs)
        layer = cls(cell.input_size, cell.hidden_size, cell.bias)
        layer.cell = cell
        return layer

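# --- Editor's note (not part of the commit): sequence layout -----------------
# _GRUSingleLayer expects time-major input (seq_len, batch, input_size); with
# reverse=True it walks the sequence backwards, then re-reverses the collected
# outputs so result[i] always aligns with x[i].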
class _GRULayer(nn.Module):
    """A single bi-directional GRU layer."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        batch_first: bool = False,
        bidirectional: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.layer_fw = _GRUSingleLayer(
            input_size, hidden_size, bias=bias, **factory_kwargs
        )
        if self.bidirectional:
            self.layer_bw = _GRUSingleLayer(
                input_size, hidden_size, bias=bias, **factory_kwargs
            )

    def forward(
        self, x: Tensor, hidden: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        if self.batch_first:
            x = x.transpose(0, 1)

        hx_fw = None
        hx_bw = None
        if hidden is not None:
            if self.bidirectional:
                hx_fw = hidden[0]
                hx_bw = hidden[1]
            else:
                hx_fw = hidden

        result_fw, h_fw = self.layer_fw(x, hx_fw)

        if self.bidirectional:
            result_bw, h_bw = self.layer_bw(x, hx_bw, reverse=True)
            result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
            h = torch.stack([h_fw, h_bw], 0)
        else:
            result = result_fw
            h = h_fw

        if self.batch_first:
            result = result.transpose(0, 1)

        return result, h

    @classmethod
    def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
        assert hasattr(other, "qconfig") or (qconfig is not None)

        input_size = kwargs.get("input_size", other.input_size)
        hidden_size = kwargs.get("hidden_size", other.hidden_size)
        bias = kwargs.get("bias", other.bias)
        batch_first = kwargs.get("batch_first", other.batch_first)
        bidirectional = kwargs.get("bidirectional", other.bidirectional)

        layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
        layer.qconfig = getattr(other, "qconfig", qconfig)

        wi = getattr(other, f"weight_ih_l{layer_idx}")
        wh = getattr(other, f"weight_hh_l{layer_idx}")
        bi = getattr(other, f"bias_ih_l{layer_idx}", None)
        bh = getattr(other, f"bias_hh_l{layer_idx}", None)
        layer.layer_fw = _GRUSingleLayer.from_params(wi, wh, bi, bh)

        if other.bidirectional:
            wi = getattr(other, f"weight_ih_l{layer_idx}_reverse")
            wh = getattr(other, f"weight_hh_l{layer_idx}_reverse")
            bi = getattr(other, f"bias_ih_l{layer_idx}_reverse", None)
            bh = getattr(other, f"bias_hh_l{layer_idx}_reverse", None)
            layer.layer_bw = _GRUSingleLayer.from_params(wi, wh, bi, bh)
        return layer

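# --- Editor's note (not part of the commit): _GRULayer output shapes ---------
# With bidirectional=True, forward() concatenates the two directions along the
# feature axis and stacks the final states, so for input (seq, batch, input):
#
#     result: (seq, batch, 2 * hidden_size)   # torch.cat of fw/bw outputs
#     h:      (2, batch, hidden_size)         # torch.stack of fw/bw states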
class GRU(nn.Module):
    """A quantizable GRU following the torch.ao.nn.quantizable.LSTM pattern.

    Converts a standard nn.GRU into observable form with nn.Linear +
    FloatFunctional ops for each arithmetic boundary.

    """

    _FLOAT_MODULE = nn.GRU

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.training = False

        num_directions = 2 if bidirectional else 1
        layers: List[_GRULayer] = [
            _GRULayer(
                input_size,
                hidden_size,
                bias,
                batch_first=False,
                bidirectional=bidirectional,
                **factory_kwargs,
            )
        ]
        for _ in range(1, num_layers):
            layers.append(
                _GRULayer(
                    hidden_size * num_directions,
                    hidden_size,
                    bias,
                    batch_first=False,
                    bidirectional=bidirectional,
                    **factory_kwargs,
                )
            )
        self.layers = nn.ModuleList(layers)

    def forward(
        self, x: Tensor, hidden: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        if self.batch_first:
            x = x.transpose(0, 1)

        num_directions = 2 if self.bidirectional else 1
        if hidden is None:
            hx_list = [None] * self.num_layers
        else:
            hx = hidden.reshape(
                self.num_layers, num_directions, hidden.shape[-2], hidden.shape[-1]
            )
            hx_list = [hx[idx].squeeze(0) for idx in range(self.num_layers)]

        h_list = []
        for idx, layer in enumerate(self.layers):
            x, h = layer(x, hx_list[idx])
            h_list.append(h)

        h_tensor = torch.stack(h_list)
        h_tensor = h_tensor.reshape(-1, h_tensor.shape[-2], h_tensor.shape[-1])

        if self.batch_first:
            x = x.transpose(0, 1)

        return x, h_tensor

    @classmethod
    def from_float(cls, other, qconfig=None):
        assert isinstance(other, cls._FLOAT_MODULE)
        assert hasattr(other, "qconfig") or qconfig
        observed = cls(
            other.input_size,
            other.hidden_size,
            other.num_layers,
            other.bias,
            other.batch_first,
            other.dropout,
            other.bidirectional,
        )
        observed.qconfig = getattr(other, "qconfig", qconfig)
        for idx in range(other.num_layers):
            observed.layers[idx] = _GRULayer.from_float(
                other, idx, qconfig, batch_first=False
            )
        if other.training:
            observed.train()
            observed = torch.ao.quantization.prepare_qat(observed, inplace=True)
        else:
            observed.eval()
            observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed
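A natural smoke test for this decomposition (an editor's sketch, not part of the commit) is to convert an eager nn.GRU in eval mode and compare outputs. In eval mode, `from_float()` runs `prepare()`, which only attaches observers that record statistics without changing values, so the observed module should still match the float reference up to float accumulation-order noise:

```python
import torch

from executorch.backends.arm.quantizable import GRU

# Float reference with a qconfig attached, as from_float() requires.
eager = torch.nn.GRU(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)
eager.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
eager.eval()  # eval mode -> from_float() uses prepare(), not prepare_qat()

observed = GRU.from_float(eager)

x = torch.randn(7, 3, 10)  # (seq_len, batch, input_size)
with torch.no_grad():
    ref_out, ref_h = eager(x)
    obs_out, obs_h = observed(x)

# Expected to agree up to numerical noise from differing accumulation order.
torch.testing.assert_close(obs_out, ref_out, rtol=1e-4, atol=1e-5)
torch.testing.assert_close(obs_h, ref_h, rtol=1e-4, atol=1e-5)
```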
