# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Quantizable GRU modules following the torch.ao.nn.quantizable.LSTM pattern.

The standard nn.GRU is an opaque composite op that the quantizer cannot
annotate. This module decomposes GRU into nn.Linear + FloatFunctional
so that QAT observers can be inserted at each arithmetic boundary.

GRU cell equations:
    r_t = sigmoid(x_t @ W_ir.T + b_ir + h_{t-1} @ W_hr.T + b_hr)
    z_t = sigmoid(x_t @ W_iz.T + b_iz + h_{t-1} @ W_hz.T + b_hz)
    n_t = tanh(x_t @ W_in.T + b_in + r_t * (h_{t-1} @ W_hn.T + b_hn))
    h_t = (1 - z_t) * n_t + z_t * h_{t-1}
"""

from typing import List, Optional, Tuple

import torch
from torch import nn, Tensor
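
# Illustrative sketch (not part of the module API): when loaded with the same
# weights, the decomposed GRUCell below is intended to match nn.GRUCell
# numerically, up to floating-point tolerance. The sizes used here are
# arbitrary example values.
#
#     ref = nn.GRUCell(8, 16)
#     cell = GRUCell.from_params(
#         ref.weight_ih, ref.weight_hh, ref.bias_ih, ref.bias_hh
#     )
#     x, h = torch.randn(4, 8), torch.randn(4, 16)
#     torch.testing.assert_close(cell(x, h), ref(x, h))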

class GRUCell(nn.Module):
    """A quantizable GRU cell with FloatFunctional ops for each arithmetic boundary."""

    _FLOAT_MODULE = nn.GRUCell

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # Input projections: x_t -> [r, z, n] (3*hidden_size)
        self.input_linear = nn.Linear(
            input_size, 3 * hidden_size, bias=bias, **factory_kwargs
        )
        # Hidden projections: h_{t-1} -> [r, z, n] (3*hidden_size)
        self.hidden_linear = nn.Linear(
            hidden_size, 3 * hidden_size, bias=bias, **factory_kwargs
        )

        # Gate activations
        self.reset_gate = nn.Sigmoid()
        self.update_gate = nn.Sigmoid()
        self.new_gate = nn.Tanh()

        # FloatFunctional for each observable arithmetic op
        self.add_r = torch.ao.nn.quantized.FloatFunctional()  # input_r + hidden_r
        self.add_z = torch.ao.nn.quantized.FloatFunctional()  # input_z + hidden_z
        self.mul_r_nh = torch.ao.nn.quantized.FloatFunctional()  # r_t * hidden_n
        self.add_n = torch.ao.nn.quantized.FloatFunctional()  # input_n + r*hidden_n
        self.mul_1mz_n = torch.ao.nn.quantized.FloatFunctional()  # (1-z) * n
        self.mul_z_h = torch.ao.nn.quantized.FloatFunctional()  # z * h_{t-1}
        self.add_h = torch.ao.nn.quantized.FloatFunctional()  # (1-z)*n + z*h

    def forward(self, x: Tensor, hidden: Optional[Tensor] = None) -> Tensor:
        if hidden is None:
            hidden = torch.zeros(
                x.shape[0], self.hidden_size, dtype=x.dtype, device=x.device
            )

        igates = self.input_linear(x)
        hgates = self.hidden_linear(hidden)

        # Split into r, z, n components
        H = self.hidden_size
        input_r, input_z, input_n = (
            igates[:, :H],
            igates[:, H : 2 * H],
            igates[:, 2 * H :],
        )
        hidden_r, hidden_z, hidden_n = (
            hgates[:, :H],
            hgates[:, H : 2 * H],
            hgates[:, 2 * H :],
        )

        r_t = self.reset_gate(self.add_r.add(input_r, hidden_r))
        z_t = self.update_gate(self.add_z.add(input_z, hidden_z))
        n_t = self.new_gate(self.add_n.add(input_n, self.mul_r_nh.mul(r_t, hidden_n)))

        h_t = self.add_h.add(
            self.mul_1mz_n.mul(1.0 - z_t, n_t),
            self.mul_z_h.mul(z_t, hidden),
        )
        return h_t

    @classmethod
    def from_params(
        cls,
        wi: Tensor,
        wh: Tensor,
        bi: Optional[Tensor] = None,
        bh: Optional[Tensor] = None,
    ) -> "GRUCell":
        input_size = wi.shape[1]
        hidden_size = wh.shape[1]
        cell = cls(input_size, hidden_size, bias=(bi is not None))
        cell.input_linear.weight = nn.Parameter(wi)
        if bi is not None:
            cell.input_linear.bias = nn.Parameter(bi)
        cell.hidden_linear.weight = nn.Parameter(wh)
        if bh is not None:
            cell.hidden_linear.bias = nn.Parameter(bh)
        return cell

    @classmethod
    def from_float(cls, other, use_precomputed_fake_quant=False):
        assert type(other) is cls._FLOAT_MODULE
        assert hasattr(other, "qconfig"), "The float module must have 'qconfig'"
        observed = cls.from_params(
            other.weight_ih,
            other.weight_hh,
            other.bias_ih,
            other.bias_hh,
        )
        observed.qconfig = other.qconfig
        observed.input_linear.qconfig = other.qconfig
        observed.hidden_linear.qconfig = other.qconfig
        return observed

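# Hedged usage sketch: GRUCell.from_float asserts that the float nn.GRUCell
# carries a qconfig; the generic default QAT qconfig below is only an
# illustrative assumption, not a backend requirement.
#
#     float_cell = nn.GRUCell(8, 16)
#     float_cell.qconfig = torch.ao.quantization.get_default_qat_qconfig("qnnpack")
#     observed_cell = GRUCell.from_float(float_cell)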

class _GRUSingleLayer(nn.Module):
    """A single unidirectional GRU layer that processes a full sequence."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.cell = GRUCell(input_size, hidden_size, bias=bias, **factory_kwargs)

    def forward(
        self,
        x: Tensor,
        hidden: Optional[Tensor] = None,
        reverse: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        result = []
        seq_len = x.shape[0]
        indices = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
        for i in indices:
            hidden = self.cell(x[i], hidden)
            result.append(hidden)
        if reverse:
            result.reverse()
        return torch.stack(result, 0), hidden

    @classmethod
    def from_params(cls, *args, **kwargs):
        cell = GRUCell.from_params(*args, **kwargs)
        layer = cls(cell.input_size, cell.hidden_size, cell.bias)
        layer.cell = cell
        return layer

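# Hedged shape sketch: _GRUSingleLayer consumes a (seq_len, batch, input_size)
# tensor and returns the stacked per-step hidden states plus the final state.
# Sizes below are arbitrary example values.
#
#     layer = _GRUSingleLayer(8, 16)
#     out, h_n = layer(torch.randn(5, 4, 8))  # out: (5, 4, 16), h_n: (4, 16)
#     out_bw, _ = layer(torch.randn(5, 4, 8), reverse=True)  # right-to-left pass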

class _GRULayer(nn.Module):
    """A single GRU layer, optionally bidirectional."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        batch_first: bool = False,
        bidirectional: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.layer_fw = _GRUSingleLayer(
            input_size, hidden_size, bias=bias, **factory_kwargs
        )
        if self.bidirectional:
            self.layer_bw = _GRUSingleLayer(
                input_size, hidden_size, bias=bias, **factory_kwargs
            )

    def forward(
        self, x: Tensor, hidden: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        if self.batch_first:
            x = x.transpose(0, 1)

        hx_fw = None
        hx_bw = None
        if hidden is not None:
            if self.bidirectional:
                hx_fw = hidden[0]
                hx_bw = hidden[1]
            else:
                hx_fw = hidden

        result_fw, h_fw = self.layer_fw(x, hx_fw)

        if self.bidirectional:
            result_bw, h_bw = self.layer_bw(x, hx_bw, reverse=True)
            result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
            h = torch.stack([h_fw, h_bw], 0)
        else:
            result = result_fw
            h = h_fw

        if self.batch_first:
            result = result.transpose(0, 1)

        return result, h

    @classmethod
    def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
        assert hasattr(other, "qconfig") or (qconfig is not None)

        input_size = kwargs.get("input_size", other.input_size)
        hidden_size = kwargs.get("hidden_size", other.hidden_size)
        bias = kwargs.get("bias", other.bias)
        batch_first = kwargs.get("batch_first", other.batch_first)
        bidirectional = kwargs.get("bidirectional", other.bidirectional)

        layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
        layer.qconfig = getattr(other, "qconfig", qconfig)

        wi = getattr(other, f"weight_ih_l{layer_idx}")
        wh = getattr(other, f"weight_hh_l{layer_idx}")
        bi = getattr(other, f"bias_ih_l{layer_idx}", None)
        bh = getattr(other, f"bias_hh_l{layer_idx}", None)
        layer.layer_fw = _GRUSingleLayer.from_params(wi, wh, bi, bh)

        if other.bidirectional:
            wi = getattr(other, f"weight_ih_l{layer_idx}_reverse")
            wh = getattr(other, f"weight_hh_l{layer_idx}_reverse")
            bi = getattr(other, f"bias_ih_l{layer_idx}_reverse", None)
            bh = getattr(other, f"bias_hh_l{layer_idx}_reverse", None)
            layer.layer_bw = _GRUSingleLayer.from_params(wi, wh, bi, bh)
        return layer

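# Hedged shape sketch: with bidirectional=True the forward and backward outputs
# are concatenated along the feature axis, so the per-layer output width
# doubles. Sizes below are arbitrary example values.
#
#     layer = _GRULayer(8, 16, bidirectional=True)
#     out, h = layer(torch.randn(5, 4, 8))  # out: (5, 4, 32), h: (2, 4, 16)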

class GRU(nn.Module):
    """A quantizable GRU following the torch.ao.nn.quantizable.LSTM pattern.

    Converts a standard nn.GRU into observable form with nn.Linear +
    FloatFunctional ops for each arithmetic boundary. An illustrative
    end-to-end conversion sketch is included under the ``__main__`` guard
    at the end of this module.
    """

    _FLOAT_MODULE = nn.GRU

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        # Dropout is stored for interface parity with nn.GRU; this module does
        # not apply it between layers.
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.training = False

        num_directions = 2 if bidirectional else 1
        layers: List[_GRULayer] = [
            _GRULayer(
                input_size,
                hidden_size,
                bias,
                batch_first=False,
                bidirectional=bidirectional,
                **factory_kwargs,
            )
        ]
        for _ in range(1, num_layers):
            layers.append(
                _GRULayer(
                    hidden_size * num_directions,
                    hidden_size,
                    bias,
                    batch_first=False,
                    bidirectional=bidirectional,
                    **factory_kwargs,
                )
            )
        self.layers = nn.ModuleList(layers)

    def forward(
        self, x: Tensor, hidden: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        if self.batch_first:
            x = x.transpose(0, 1)

        num_directions = 2 if self.bidirectional else 1
        if hidden is None:
            hx_list = [None] * self.num_layers
        else:
            # Split the (num_layers * num_directions, batch, hidden) state
            # into one entry per layer.
            hx = hidden.reshape(
                self.num_layers, num_directions, hidden.shape[-2], hidden.shape[-1]
            )
            hx_list = [hx[idx].squeeze(0) for idx in range(self.num_layers)]

        h_list = []
        for idx, layer in enumerate(self.layers):
            x, h = layer(x, hx_list[idx])
            h_list.append(h)

        # Re-pack per-layer states into (num_layers * num_directions, batch, hidden).
        h_tensor = torch.stack(h_list)
        h_tensor = h_tensor.reshape(-1, h_tensor.shape[-2], h_tensor.shape[-1])

        if self.batch_first:
            x = x.transpose(0, 1)

        return x, h_tensor

    @classmethod
    def from_float(cls, other, qconfig=None):
        assert isinstance(other, cls._FLOAT_MODULE)
        assert hasattr(other, "qconfig") or qconfig
        observed = cls(
            other.input_size,
            other.hidden_size,
            other.num_layers,
            other.bias,
            other.batch_first,
            other.dropout,
            other.bidirectional,
        )
        observed.qconfig = getattr(other, "qconfig", qconfig)
        for idx in range(other.num_layers):
            observed.layers[idx] = _GRULayer.from_float(
                other, idx, qconfig, batch_first=False
            )
        if other.training:
            observed.train()
            observed = torch.ao.quantization.prepare_qat(observed, inplace=True)
        else:
            observed.eval()
            observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed
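

# Minimal end-to-end sketch (illustrative only): convert a float nn.GRU into
# the observable GRU above and run a forward pass. The qconfig choice and the
# tensor sizes below are assumptions for demonstration, not backend
# requirements.
if __name__ == "__main__":
    float_gru = nn.GRU(
        input_size=8, hidden_size=16, num_layers=2, batch_first=True
    )
    float_gru.qconfig = torch.ao.quantization.get_default_qat_qconfig("qnnpack")
    observed_gru = GRU.from_float(float_gru)

    example = torch.randn(4, 5, 8)  # (batch, seq_len, input_size)
    out, h_n = observed_gru(example)
    print(out.shape, h_n.shape)  # expected: (4, 5, 16) and (2, 4, 16)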