From 0fa1a41162533e5a04a5693f6410d21897021bda Mon Sep 17 00:00:00 2001 From: Gerstenberger Date: Thu, 29 Jan 2026 11:18:41 +0100 Subject: [PATCH 1/5] TransformerDecoder: optional positional encoding and final matmul --- .../transformer/transformer_decoder_v1.py | 45 +++++++++++++------ 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/i6_models/assemblies/transformer/transformer_decoder_v1.py b/i6_models/assemblies/transformer/transformer_decoder_v1.py index f4d0e1f2..38710c56 100644 --- a/i6_models/assemblies/transformer/transformer_decoder_v1.py +++ b/i6_models/assemblies/transformer/transformer_decoder_v1.py @@ -20,7 +20,7 @@ import torch.nn.functional as F from dataclasses import dataclass, field -from typing import List, Optional, Tuple, TypedDict, Union +from typing import List, Optional, Tuple, TypedDict, Union, NotRequired from i6_models.config import ModelConfiguration from i6_models.parts.conformer import ( @@ -141,6 +141,8 @@ class TransformerDecoderV1Config(ModelConfiguration): logits_bias: Whether to add a bias to the output logits. Usually False is a good choice. share_embedding: Whether to share the input and output embedding. + use_positional_encoding: use a sinus positional encoding on the initial input + do_output_embedding_matmul: apply the final model output x output embedding matmul """ block_cfg: TransformerDecoderBlockV1Config @@ -150,13 +152,15 @@ class TransformerDecoderV1Config(ModelConfiguration): num_output: int logits_bias: bool share_embedding: bool + use_positional_encoding: bool = True + do_output_embedding_matmul: bool = True class TransformerDecoderV1State(TypedDict): """Recurrent state of the transformer decoder.""" block_state: List[TransformerDecoderBlockV1State] - pos: Tensor + pos: NotRequired[Tensor] class TransformerDecoderV1(nn.Module, ModuleWithState[TransformerDecoderV1State]): @@ -190,13 +194,20 @@ def __init__(self, cfg: TransformerDecoderV1Config): else: self.out_logits = nn.Linear(self.model_dim, cfg.num_output, bias=cfg.logits_bias) + self.use_positional_encoding = cfg.use_positional_encoding + self.do_output_embedding_matmul = cfg.do_output_embedding_matmul + def get_initial_state(self) -> TransformerDecoderV1State: """:return: initial decoder state""" - return { + state: TransformerDecoderV1State = { "block_state": [block.get_initial_state() for block in self.module_list], - "pos": torch.tensor(0, dtype=torch.int32), } + if self.use_positional_encoding: + state["pos"] = torch.tensor(0, dtype=torch.int32) + + return state + def transform_encoder_output( self, encoder_output: Tensor, @@ -229,10 +240,13 @@ def forward( - `s = get_initial_state()`. """ x = self.input_embedding(labels) * self.input_embedding_scale - sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe( - torch.arange(labels.shape[-1], device=labels.device) + state["pos"], self.model_dim - ) - x = x + sinus_pe.unsqueeze(0) + + if self.use_positional_encoding: + sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe( + torch.arange(labels.shape[-1], device=labels.device) + state["pos"], self.model_dim + ) + x = x + sinus_pe.unsqueeze(0) + x = self.input_dropout(x) output = x @@ -243,11 +257,16 @@ def forward( new_state: TransformerDecoderV1State = { **state, "block_state": new_block_states, - "pos": state["pos"] + labels_lens.max(), } + if self.use_positional_encoding: + new_state["pos"] = state["pos"] + labels_lens.max() + output = self.out_norm(output) - output_logits = ( - F.linear(output, self.input_embedding.weight, None) if self.share_embedding else self.out_logits(output) - ) - return output_logits, new_state + + if self.do_output_embedding_matmul: + output = ( + F.linear(output, self.input_embedding.weight, None) if self.share_embedding else self.out_logits(output) + ) + + return output, new_state From a2897e544d1c33f47192419d3b36aa9817f89c33 Mon Sep 17 00:00:00 2001 From: Gerstenberger Date: Wed, 11 Feb 2026 17:03:56 +0100 Subject: [PATCH 2/5] initial proposal to commentary --- .../transformer/transformer_decoder_v1.py | 106 +++++++++++++----- 1 file changed, 75 insertions(+), 31 deletions(-) diff --git a/i6_models/assemblies/transformer/transformer_decoder_v1.py b/i6_models/assemblies/transformer/transformer_decoder_v1.py index 38710c56..17ddf036 100644 --- a/i6_models/assemblies/transformer/transformer_decoder_v1.py +++ b/i6_models/assemblies/transformer/transformer_decoder_v1.py @@ -17,12 +17,11 @@ import torch from torch import nn, Tensor -import torch.nn.functional as F from dataclasses import dataclass, field from typing import List, Optional, Tuple, TypedDict, Union, NotRequired -from i6_models.config import ModelConfiguration +from i6_models.config import ModelConfiguration, ModuleFactoryV1 from i6_models.parts.conformer import ( ConformerMHSARelPosV1, ConformerPositionwiseFeedForwardV2, @@ -128,6 +127,52 @@ def forward( return labels, {**state, "module_states": new_states} +@dataclass +class SinusoidalPositionalEncodingV1Config(ModelConfiguration): + """ + Attributes: + embedding_dim: embedding dimension + """ + embedding_dim = int + + +class PositionalEncodingV1State(TypedDict): + """ + State for some positional encoding. + """ + pos: Tensor + + +class SinusoidalPositionalEncodingV1(nn.Module, ModuleWithState[PositionalEncodingV1State]): + """ + Computes and applies a sinusoidal positional encoding. + """ + def __init__(self, cfg: SinusoidalPositionalEncodingV1Config): + super().__init__() + + self.embed_dim = cfg.embedding_dim + + def forward(self, inputs: Tensor, lengths: Tensor, state: PositionalEncodingV1State): + """ + Apply sinusoidal positional encoding on the inputs. + + :param inputs: tensor to apply the positional encoding on + :param lengths: input lengths + :param state: current state of positional encoding. + """ + sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe( + torch.arange(inputs.shape[-1], device=inputs.device) + state["pos"], self.embed_dim + ) + output = inputs + sinus_pe.unsqueeze(0) + + new_state: PositionalEncodingV1State = {"pos": state["pos"] + lengths.max()} + + return output, new_state + + def get_initial_state(self) -> PositionalEncodingV1State: + return {"pos": Tensor(0, dtype=torch.int32)} + + @dataclass class TransformerDecoderV1Config(ModelConfiguration): """ @@ -141,8 +186,8 @@ class TransformerDecoderV1Config(ModelConfiguration): logits_bias: Whether to add a bias to the output logits. Usually False is a good choice. share_embedding: Whether to share the input and output embedding. - use_positional_encoding: use a sinus positional encoding on the initial input - do_output_embedding_matmul: apply the final model output x output embedding matmul + positional_encoding: optionally apply some positional encoding to the input embeddings. + output_linear_projection: Whether to apply a linear projection on the model output to 'num_output' dimension. """ block_cfg: TransformerDecoderBlockV1Config @@ -152,15 +197,15 @@ class TransformerDecoderV1Config(ModelConfiguration): num_output: int logits_bias: bool share_embedding: bool - use_positional_encoding: bool = True - do_output_embedding_matmul: bool = True + positional_encoding: Optional[ModuleFactoryV1] + output_linear_projection: bool = True class TransformerDecoderV1State(TypedDict): """Recurrent state of the transformer decoder.""" block_state: List[TransformerDecoderBlockV1State] - pos: NotRequired[Tensor] + pos_state: NotRequired[PositionalEncodingV1State] class TransformerDecoderV1(nn.Module, ModuleWithState[TransformerDecoderV1State]): @@ -188,14 +233,21 @@ def __init__(self, cfg: TransformerDecoderV1Config): ) self.out_norm = nn.LayerNorm(self.model_dim) self.share_embedding = cfg.share_embedding - if cfg.share_embedding: - assert not cfg.logits_bias, "Cannot use logits bias with shared embedding" - nn.init.xavier_uniform_(self.input_embedding.weight) # bad convergence with default init + + cfg.positional_encoding = None + if cfg.positional_encoding is not None: + self.positional_encoding = cfg.positional_encoding() + + self.output_linear_projection = cfg.output_linear_projection + + if not self.output_linear_projection: + self.out_logits = nn.Identity() else: self.out_logits = nn.Linear(self.model_dim, cfg.num_output, bias=cfg.logits_bias) - self.use_positional_encoding = cfg.use_positional_encoding - self.do_output_embedding_matmul = cfg.do_output_embedding_matmul + if cfg.share_embedding and self.output_linear_projection: + self.out_logits.weight = self.input_embedding.weight + nn.init.xavier_uniform_(self.input_embedding.weight) # bad convergence with default init def get_initial_state(self) -> TransformerDecoderV1State: """:return: initial decoder state""" @@ -203,8 +255,8 @@ def get_initial_state(self) -> TransformerDecoderV1State: "block_state": [block.get_initial_state() for block in self.module_list], } - if self.use_positional_encoding: - state["pos"] = torch.tensor(0, dtype=torch.int32) + if self.positional_encoding is not None: + state["pos_state"] = self.positional_encoding.get_initial_state() return state @@ -239,13 +291,15 @@ def forward( - `enc_out, enc_out_mask = forward_some_encoder(...)` and - `s = get_initial_state()`. """ + new_state: TransformerDecoderV1State = { + **state, + } + x = self.input_embedding(labels) * self.input_embedding_scale - if self.use_positional_encoding: - sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe( - torch.arange(labels.shape[-1], device=labels.device) + state["pos"], self.model_dim - ) - x = x + sinus_pe.unsqueeze(0) + if self.positional_encoding is not None: + x, new_pos_state = self.positional_encoding(x, labels_lens, state["pos"]) + new_state["pos_state"] = new_pos_state x = self.input_dropout(x) @@ -254,19 +308,9 @@ def forward( for block, block_state in zip(self.module_list, state["block_state"]): output, new_block_state = block(output, labels_lens, block_state) new_block_states.append(new_block_state) - new_state: TransformerDecoderV1State = { - **state, - "block_state": new_block_states, - } - - if self.use_positional_encoding: - new_state["pos"] = state["pos"] + labels_lens.max() + new_state["block_state"] = new_block_states output = self.out_norm(output) - - if self.do_output_embedding_matmul: - output = ( - F.linear(output, self.input_embedding.weight, None) if self.share_embedding else self.out_logits(output) - ) + output = self.out_logits(output) return output, new_state From 9ec338f597d41964c5bfbf50bcfc5c287ffb4c21 Mon Sep 17 00:00:00 2001 From: Gerstenberger Date: Wed, 11 Feb 2026 17:13:35 +0100 Subject: [PATCH 3/5] formating --- i6_models/assemblies/transformer/transformer_decoder_v1.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/i6_models/assemblies/transformer/transformer_decoder_v1.py b/i6_models/assemblies/transformer/transformer_decoder_v1.py index 17ddf036..4153763f 100644 --- a/i6_models/assemblies/transformer/transformer_decoder_v1.py +++ b/i6_models/assemblies/transformer/transformer_decoder_v1.py @@ -133,6 +133,7 @@ class SinusoidalPositionalEncodingV1Config(ModelConfiguration): Attributes: embedding_dim: embedding dimension """ + embedding_dim = int @@ -140,6 +141,7 @@ class PositionalEncodingV1State(TypedDict): """ State for some positional encoding. """ + pos: Tensor @@ -147,6 +149,7 @@ class SinusoidalPositionalEncodingV1(nn.Module, ModuleWithState[PositionalEncodi """ Computes and applies a sinusoidal positional encoding. """ + def __init__(self, cfg: SinusoidalPositionalEncodingV1Config): super().__init__() From 86bfe3e4156d1a0dd2db19e20db7f4c1dd628c87 Mon Sep 17 00:00:00 2001 From: Gerstenberger Date: Wed, 11 Feb 2026 17:23:14 +0100 Subject: [PATCH 4/5] fix input to _sinusoidal_pe --- .../assemblies/transformer/transformer_decoder_v1.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/i6_models/assemblies/transformer/transformer_decoder_v1.py b/i6_models/assemblies/transformer/transformer_decoder_v1.py index 4153763f..97179ea3 100644 --- a/i6_models/assemblies/transformer/transformer_decoder_v1.py +++ b/i6_models/assemblies/transformer/transformer_decoder_v1.py @@ -155,16 +155,17 @@ def __init__(self, cfg: SinusoidalPositionalEncodingV1Config): self.embed_dim = cfg.embedding_dim - def forward(self, inputs: Tensor, lengths: Tensor, state: PositionalEncodingV1State): + def forward(self, inputs: Tensor, labels: Tensor, lengths: Tensor, state: PositionalEncodingV1State): """ Apply sinusoidal positional encoding on the inputs. - :param inputs: tensor to apply the positional encoding on + :param inputs: input embeddings + :param labels: input labels :param lengths: input lengths :param state: current state of positional encoding. """ sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe( - torch.arange(inputs.shape[-1], device=inputs.device) + state["pos"], self.embed_dim + torch.arange(labels.shape[-1], device=labels.device) + state["pos"], self.embed_dim ) output = inputs + sinus_pe.unsqueeze(0) @@ -301,7 +302,7 @@ def forward( x = self.input_embedding(labels) * self.input_embedding_scale if self.positional_encoding is not None: - x, new_pos_state = self.positional_encoding(x, labels_lens, state["pos"]) + x, new_pos_state = self.positional_encoding(x, labels, labels_lens, state["pos"]) new_state["pos_state"] = new_pos_state x = self.input_dropout(x) From caf8f3ec99a9d7febde70c5c8c093b5c8a285f03 Mon Sep 17 00:00:00 2001 From: Gerstenberger Date: Thu, 26 Mar 2026 14:30:47 +0100 Subject: [PATCH 5/5] cleanup --- .../transformer/transformer_decoder_v1.py | 54 ++++++++++--------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/i6_models/assemblies/transformer/transformer_decoder_v1.py b/i6_models/assemblies/transformer/transformer_decoder_v1.py index 97179ea3..a4c20298 100644 --- a/i6_models/assemblies/transformer/transformer_decoder_v1.py +++ b/i6_models/assemblies/transformer/transformer_decoder_v1.py @@ -10,6 +10,8 @@ "TransformerDecoderBlockV1Config", "TransformerDecoderBlockV1State", "TransformerDecoderBlockV1", + "PositionalEncodingV1State", + "SinusoidalPositionalEncodingV1", "TransformerDecoderV1Config", "TransformerDecoderV1State", "TransformerDecoderV1", @@ -127,16 +129,6 @@ def forward( return labels, {**state, "module_states": new_states} -@dataclass -class SinusoidalPositionalEncodingV1Config(ModelConfiguration): - """ - Attributes: - embedding_dim: embedding dimension - """ - - embedding_dim = int - - class PositionalEncodingV1State(TypedDict): """ State for some positional encoding. @@ -150,12 +142,10 @@ class SinusoidalPositionalEncodingV1(nn.Module, ModuleWithState[PositionalEncodi Computes and applies a sinusoidal positional encoding. """ - def __init__(self, cfg: SinusoidalPositionalEncodingV1Config): + def __init__(self, cfg): super().__init__() - self.embed_dim = cfg.embedding_dim - - def forward(self, inputs: Tensor, labels: Tensor, lengths: Tensor, state: PositionalEncodingV1State): + def forward(self, inputs: Tensor, lengths: Tensor, state: PositionalEncodingV1State): """ Apply sinusoidal positional encoding on the inputs. @@ -165,7 +155,7 @@ def forward(self, inputs: Tensor, labels: Tensor, lengths: Tensor, state: Positi :param state: current state of positional encoding. """ sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe( - torch.arange(labels.shape[-1], device=labels.device) + state["pos"], self.embed_dim + torch.arange(inputs.shape[1], device=inputs.device) + state["pos"], inputs.shape[-1] ) output = inputs + sinus_pe.unsqueeze(0) @@ -192,6 +182,7 @@ class TransformerDecoderV1Config(ModelConfiguration): share_embedding: Whether to share the input and output embedding. positional_encoding: optionally apply some positional encoding to the input embeddings. output_linear_projection: Whether to apply a linear projection on the model output to 'num_output' dimension. + input_embedding_dim: Input embedding dimension. If None, use the model dimension specified by block_cfg. """ block_cfg: TransformerDecoderBlockV1Config @@ -201,15 +192,16 @@ class TransformerDecoderV1Config(ModelConfiguration): num_output: int logits_bias: bool share_embedding: bool - positional_encoding: Optional[ModuleFactoryV1] + positional_encoding: Optional[ModuleFactoryV1] = ModuleFactoryV1(SinusoidalPositionalEncodingV1, None) output_linear_projection: bool = True + input_embedding_dim: Optional[int] = None class TransformerDecoderV1State(TypedDict): """Recurrent state of the transformer decoder.""" block_state: List[TransformerDecoderBlockV1State] - pos_state: NotRequired[PositionalEncodingV1State] + pos: NotRequired[Tensor] class TransformerDecoderV1(nn.Module, ModuleWithState[TransformerDecoderV1State]): @@ -228,17 +220,28 @@ def __init__(self, cfg: TransformerDecoderV1Config): self.model_dim = cfg.block_cfg.ff_cfg.input_dim self.input_dropout = BroadcastDropout(cfg.input_dropout) - self.input_embedding = nn.Embedding(cfg.num_output, self.model_dim) + + embedding_dim = cfg.input_embedding_dim + if embedding_dim is None: + embedding_dim = self.model_dim + + self.input_embedding = nn.Embedding(cfg.num_output, embedding_dim) self.input_embedding_scale = ( - cfg.input_embedding_scale if cfg.input_embedding_scale is not None else self.model_dim**0.5 + cfg.input_embedding_scale if cfg.input_embedding_scale is not None else embedding_dim**0.5 ) + + if embedding_dim != self.model_dim: + self.embedding_projection = nn.Linear(embedding_dim, self.model_dim) + else: + self.embedding_projection = nn.Identity() + self.module_list = torch.nn.ModuleList( [TransformerDecoderBlockV1(cfg.block_cfg) for _ in range(cfg.num_blocks)] ) self.out_norm = nn.LayerNorm(self.model_dim) self.share_embedding = cfg.share_embedding - cfg.positional_encoding = None + self.positional_encoding = None if cfg.positional_encoding is not None: self.positional_encoding = cfg.positional_encoding() @@ -260,7 +263,7 @@ def get_initial_state(self) -> TransformerDecoderV1State: } if self.positional_encoding is not None: - state["pos_state"] = self.positional_encoding.get_initial_state() + state["pos"] = self.positional_encoding.get_initial_state()["pos"] return state @@ -295,15 +298,14 @@ def forward( - `enc_out, enc_out_mask = forward_some_encoder(...)` and - `s = get_initial_state()`. """ - new_state: TransformerDecoderV1State = { - **state, - } + new_state: TransformerDecoderV1State = {**state} x = self.input_embedding(labels) * self.input_embedding_scale + x = self.embedding_projection(x) if self.positional_encoding is not None: - x, new_pos_state = self.positional_encoding(x, labels, labels_lens, state["pos"]) - new_state["pos_state"] = new_pos_state + x, new_pos_state = self.positional_encoding(x, labels_lens, state["pos"]) + new_state["pos"] = new_pos_state["pos"] x = self.input_dropout(x)