Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 93 additions & 24 deletions i6_models/assemblies/transformer/transformer_decoder_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,20 @@
"TransformerDecoderBlockV1Config",
"TransformerDecoderBlockV1State",
"TransformerDecoderBlockV1",
"PositionalEncodingV1State",
"SinusoidalPositionalEncodingV1",
"TransformerDecoderV1Config",
"TransformerDecoderV1State",
"TransformerDecoderV1",
]

import torch
from torch import nn, Tensor
import torch.nn.functional as F

from dataclasses import dataclass, field
from typing import List, Optional, Tuple, TypedDict, Union
from typing import List, Optional, Tuple, TypedDict, Union, NotRequired

from i6_models.config import ModelConfiguration
from i6_models.config import ModelConfiguration, ModuleFactoryV1
from i6_models.parts.conformer import (
ConformerMHSARelPosV1,
ConformerPositionwiseFeedForwardV2,
Expand Down Expand Up @@ -128,6 +129,44 @@ def forward(
return labels, {**state, "module_states": new_states}


class PositionalEncodingV1State(TypedDict):
"""
State for some positional encoding.
"""

pos: Tensor


class SinusoidalPositionalEncodingV1(nn.Module, ModuleWithState[PositionalEncodingV1State]):
"""
Computes and applies a sinusoidal positional encoding.
"""

def __init__(self, cfg):
super().__init__()

def forward(self, inputs: Tensor, lengths: Tensor, state: PositionalEncodingV1State):
"""
Apply sinusoidal positional encoding on the inputs.

:param inputs: input embeddings
:param labels: input labels
:param lengths: input lengths
:param state: current state of positional encoding.
"""
sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe(
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe should be moved to primitives?

torch.arange(inputs.shape[1], device=inputs.device) + state["pos"], inputs.shape[-1]
)
output = inputs + sinus_pe.unsqueeze(0)

new_state: PositionalEncodingV1State = {"pos": state["pos"] + lengths.max()}

return output, new_state

def get_initial_state(self) -> PositionalEncodingV1State:
return {"pos": Tensor(0, dtype=torch.int32)}


@dataclass
class TransformerDecoderV1Config(ModelConfiguration):
"""
Expand All @@ -141,6 +180,9 @@ class TransformerDecoderV1Config(ModelConfiguration):
logits_bias: Whether to add a bias to the output logits.
Usually False is a good choice.
share_embedding: Whether to share the input and output embedding.
positional_encoding: optionally apply some positional encoding to the input embeddings.
output_linear_projection: Whether to apply a linear projection on the model output to 'num_output' dimension.
input_embedding_dim: Input embedding dimension. If None, use the model dimension specified by block_cfg.
"""

block_cfg: TransformerDecoderBlockV1Config
Expand All @@ -150,13 +192,16 @@ class TransformerDecoderV1Config(ModelConfiguration):
num_output: int
logits_bias: bool
share_embedding: bool
positional_encoding: Optional[ModuleFactoryV1] = ModuleFactoryV1(SinusoidalPositionalEncodingV1, None)
output_linear_projection: bool = True
input_embedding_dim: Optional[int] = None


class TransformerDecoderV1State(TypedDict):
"""Recurrent state of the transformer decoder."""

block_state: List[TransformerDecoderBlockV1State]
pos: Tensor
pos: NotRequired[Tensor]


class TransformerDecoderV1(nn.Module, ModuleWithState[TransformerDecoderV1State]):
Expand All @@ -175,28 +220,53 @@ def __init__(self, cfg: TransformerDecoderV1Config):
self.model_dim = cfg.block_cfg.ff_cfg.input_dim

self.input_dropout = BroadcastDropout(cfg.input_dropout)
self.input_embedding = nn.Embedding(cfg.num_output, self.model_dim)

embedding_dim = cfg.input_embedding_dim
if embedding_dim is None:
embedding_dim = self.model_dim

self.input_embedding = nn.Embedding(cfg.num_output, embedding_dim)
self.input_embedding_scale = (
cfg.input_embedding_scale if cfg.input_embedding_scale is not None else self.model_dim**0.5
cfg.input_embedding_scale if cfg.input_embedding_scale is not None else embedding_dim**0.5
)

if embedding_dim != self.model_dim:
self.embedding_projection = nn.Linear(embedding_dim, self.model_dim)
else:
self.embedding_projection = nn.Identity()

self.module_list = torch.nn.ModuleList(
[TransformerDecoderBlockV1(cfg.block_cfg) for _ in range(cfg.num_blocks)]
)
self.out_norm = nn.LayerNorm(self.model_dim)
self.share_embedding = cfg.share_embedding
if cfg.share_embedding:
assert not cfg.logits_bias, "Cannot use logits bias with shared embedding"
nn.init.xavier_uniform_(self.input_embedding.weight) # bad convergence with default init

self.positional_encoding = None
if cfg.positional_encoding is not None:
self.positional_encoding = cfg.positional_encoding()

self.output_linear_projection = cfg.output_linear_projection

if not self.output_linear_projection:
self.out_logits = nn.Identity()
else:
self.out_logits = nn.Linear(self.model_dim, cfg.num_output, bias=cfg.logits_bias)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just realize, this sharing is weird. I would always set self.out_logits. If sharing, you can just do self.out_logits.weights = self.input_embedding.weight. That would simplify the other code.

Also, self.out_logits should always be set (be None if not used). But with my suggestion, you don't need to care about this.

And then you would also allow to have logits_bias=True with share_embedding=True.


if cfg.share_embedding and self.output_linear_projection:
self.out_logits.weight = self.input_embedding.weight
nn.init.xavier_uniform_(self.input_embedding.weight) # bad convergence with default init

def get_initial_state(self) -> TransformerDecoderV1State:
""":return: initial decoder state"""
return {
state: TransformerDecoderV1State = {
"block_state": [block.get_initial_state() for block in self.module_list],
"pos": torch.tensor(0, dtype=torch.int32),
}

if self.positional_encoding is not None:
state["pos"] = self.positional_encoding.get_initial_state()["pos"]

return state

def transform_encoder_output(
self,
encoder_output: Tensor,
Expand Down Expand Up @@ -228,26 +298,25 @@ def forward(
- `enc_out, enc_out_mask = forward_some_encoder(...)` and
- `s = get_initial_state()`.
"""
new_state: TransformerDecoderV1State = {**state}

x = self.input_embedding(labels) * self.input_embedding_scale
sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe(
torch.arange(labels.shape[-1], device=labels.device) + state["pos"], self.model_dim
)
x = x + sinus_pe.unsqueeze(0)
x = self.embedding_projection(x)

if self.positional_encoding is not None:
x, new_pos_state = self.positional_encoding(x, labels_lens, state["pos"])
new_state["pos"] = new_pos_state["pos"]

x = self.input_dropout(x)

output = x
new_block_states = []
for block, block_state in zip(self.module_list, state["block_state"]):
output, new_block_state = block(output, labels_lens, block_state)
new_block_states.append(new_block_state)
new_state: TransformerDecoderV1State = {
**state,
"block_state": new_block_states,
"pos": state["pos"] + labels_lens.max(),
}
new_state["block_state"] = new_block_states

output = self.out_norm(output)
output_logits = (
F.linear(output, self.input_embedding.weight, None) if self.share_embedding else self.out_logits(output)
)
return output_logits, new_state
output = self.out_logits(output)

return output, new_state
Loading