diff --git a/src/chronos/chronos2/layers.py b/src/chronos/chronos2/layers.py
index b00e8a8c..ca3132bb 100644
--- a/src/chronos/chronos2/layers.py
+++ b/src/chronos/chronos2/layers.py
@@ -5,6 +5,7 @@
 from dataclasses import dataclass
 
+import os
 import torch
 from einops import rearrange
 from torch import nn
 
@@ -351,19 +352,50 @@ def __init__(self, config: Chronos2CoreConfig):
         self.dropout = nn.Dropout(config.dropout_rate)
 
     def forward(
-        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False
+        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | tuple, output_attentions: bool = False
     ) -> AttentionOutput:
         # flip time and batch axes because attention operates along dim=-2
         hidden_states = rearrange(hidden_states, "batch time d -> time batch d")
         normed_hidden_states = self.layer_norm(hidden_states)
-        attention_output: AttentionOutput = self.self_attention(
-            normed_hidden_states, mask=attention_mask, output_attentions=output_attentions
-        )
-        hidden_states = hidden_states + self.dropout(attention_output[0])
+
+        if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1":
+            fast_group_time_mask, group_start_index, counts, max_group_len, group_num = attention_mask
+            fast_normed_hidden_states = torch.zeros(
+                (group_num * normed_hidden_states.size(0), max_group_len, normed_hidden_states.size(2)),
+                dtype=normed_hidden_states.dtype,
+                device=normed_hidden_states.device,
+            )
+            for i in range(group_num):
+                start_index = group_start_index[i].item()
+                group_len = counts[i].item()
+                end_index = start_index + group_len
+                fast_normed_hidden_states[i * normed_hidden_states.size(0) : (i + 1) * normed_hidden_states.size(0), :group_len, :] = normed_hidden_states[
+                    :, start_index:end_index, :
+                ]
+            fast_attention_output: AttentionOutput = self.self_attention(
+                fast_normed_hidden_states, mask=fast_group_time_mask, output_attentions=output_attentions
+            )
+            attn_weights = fast_attention_output.attn_weights
+            attention_output = torch.empty_like(normed_hidden_states)
+            for i in range(group_num):
+                start_index = group_start_index[i].item()
+                group_len = counts[i].item()
+                end_index = start_index + group_len
+                attention_output[:, start_index:end_index, :] = fast_attention_output[0][
+                    i * normed_hidden_states.size(0) : (i + 1) * normed_hidden_states.size(0), :group_len, :
+                ]
+        else:
+            attention_output: AttentionOutput = self.self_attention(
+                normed_hidden_states, mask=attention_mask, output_attentions=output_attentions
+            )
+            attn_weights = attention_output.attn_weights
+            attention_output = attention_output[0]
+
+        hidden_states = hidden_states + self.dropout(attention_output)
 
         # flip time and batch axes back to their original position
         hidden_states = rearrange(hidden_states, "time batch d -> batch time d")
-        return AttentionOutput(hidden_states=hidden_states, attn_weights=attention_output.attn_weights)
+        return AttentionOutput(hidden_states=hidden_states, attn_weights=attn_weights)
 
 
 class ResidualBlock(nn.Module):
diff --git a/src/chronos/chronos2/model.py b/src/chronos/chronos2/model.py
index 0397be2a..411fb54a 100644
--- a/src/chronos/chronos2/model.py
+++ b/src/chronos/chronos2/model.py
@@ -4,6 +4,7 @@
 # Authors: Abdul Fatir Ansari
 
 import copy
+import os
 from dataclasses import dataclass
 from typing import cast
 
@@ -51,7 +52,7 @@ def forward(
         *,
         position_ids: torch.Tensor,
         attention_mask: torch.Tensor,
-        group_time_mask: torch.Tensor,
+        group_time_mask: torch.Tensor | tuple,
         output_attentions: bool = False,
     ) -> Chronos2EncoderBlockOutput:
         # apply time attention
@@ -129,6 +130,24 @@ def _construct_and_invert_group_time_mask(
         group_time_mask = rearrange(group_time_mask, "q b t -> t 1 q b")
         group_time_mask = (1.0 - group_time_mask) * torch.finfo(floating_type).min
 
+        if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1":
+            # pad group_time_mask to the max group size to avoid wasted computation in self-attention
+            unique_groups, counts = torch.unique(group_ids, return_counts=True)
+            max_group_len = counts.max().item()
+            group_num = unique_groups.size(0)
+            fast_group_time_mask = torch.full(
+                (group_num * group_time_mask.size(0), 1, max_group_len, max_group_len), torch.finfo(floating_type).min, device=group_time_mask.device
+            )
+            group_start_index = counts.cumsum(dim=0) - counts
+            for i in range(group_num):
+                group_len = counts[i].item()
+                start_index = group_start_index[i].item()
+                end_index = start_index + group_len
+                fast_group_time_mask[
+                    i * group_time_mask.size(0) : (i + 1) * group_time_mask.size(0), :, :group_len, :group_len
+                ] = group_time_mask[:, :, start_index : end_index, start_index : end_index]
+            return fast_group_time_mask, group_start_index, counts, max_group_len, group_num
+
         return group_time_mask
 
     def forward(
@@ -152,7 +171,13 @@ def forward(
         extended_attention_mask = self._expand_and_invert_time_attention_mask(attention_mask, inputs_embeds.dtype)
 
         # construct group time mask
-        group_time_mask = self._construct_and_invert_group_time_mask(group_ids, attention_mask, inputs_embeds.dtype)
+        if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1":
+            fast_group_time_mask, group_start_index, counts, max_group_len, group_num = self._construct_and_invert_group_time_mask(
+                group_ids, attention_mask, inputs_embeds.dtype
+            )
+            group_time_mask = (fast_group_time_mask, group_start_index, counts, max_group_len, group_num)
+        else:
+            group_time_mask = self._construct_and_invert_group_time_mask(group_ids, attention_mask, inputs_embeds.dtype)
 
         all_time_self_attentions: tuple[torch.Tensor, ...] = ()
         all_group_self_attentions: tuple[torch.Tensor, ...] = ()
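
Not part of the patch, for reviewers trying it out: the padded group-attention path is opt-in and is gated on the CHRONOS2_USE_FAST_GROUP_ATTENTION environment variable checked in both files above. A minimal sketch of enabling it; only the variable name comes from this diff, the rest is plain use of os.environ:

    import os

    # Opt in to the padded per-group attention path; the default "0" keeps the
    # original behaviour (attention over the full batch with a group mask).
    os.environ["CHRONOS2_USE_FAST_GROUP_ATTENTION"] = "1"

Because the flag is read with os.environ.get(...) inside the forward passes, whatever value is set at the time the encoder runs is the one that takes effect.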
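
A self-contained sketch of the pack/unpack trick the new branch implements, using illustrative names (time, counts, hidden are not identifiers from the diff): the fast path assumes each group occupies a contiguous slice of the batch axis of the (time, batch, d) tensor, and it copies every group into its own zero-padded slot of width max_group_len so that attention runs over max_group_len positions per group instead of over the whole batch with the out-of-group positions masked.

    import torch

    time, d = 5, 4
    # three groups of related series, of sizes 3, 1 and 2, laid out contiguously
    # along the batch axis (as the fast path assumes for group_ids)
    counts = torch.tensor([3, 1, 2])
    group_start = counts.cumsum(dim=0) - counts   # first batch index of each group
    max_group_len = int(counts.max())
    group_num = counts.numel()
    total_batch = int(counts.sum())

    # (time, batch, d): the layout the group-attention layer uses after rearrange
    hidden = torch.randn(time, total_batch, d)

    # pack: one zero-padded slot of width max_group_len per (group, time step)
    packed = torch.zeros(group_num * time, max_group_len, d)
    for i in range(group_num):
        s, n = int(group_start[i]), int(counts[i])
        packed[i * time : (i + 1) * time, :n] = hidden[:, s : s + n]

    # ... self-attention would run over `packed` here, with a
    # (group_num * time, 1, max_group_len, max_group_len) mask blocking the padding ...

    # unpack: copy each group's rows back into the original layout
    out = torch.empty_like(hidden)
    for i in range(group_num):
        s, n = int(group_start[i]), int(counts[i])
        out[:, s : s + n] = packed[i * time : (i + 1) * time, :n]

    assert torch.equal(out, hidden)  # the round trip is lossless when attention is skipped

The gain is largest when groups are much smaller than the full batch, since the attention cost drops from quadratic in the batch size to quadratic in max_group_len; the Python-level loops and the extra padded copies are the price paid for it.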