44 changes: 38 additions & 6 deletions src/chronos/chronos2/layers.py
@@ -5,6 +5,7 @@

from dataclasses import dataclass

import os
import torch
from einops import rearrange
from torch import nn
@@ -351,19 +352,50 @@ def __init__(self, config: Chronos2CoreConfig):
self.dropout = nn.Dropout(config.dropout_rate)

def forward(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | tuple, output_attentions: bool = False
) -> AttentionOutput:
# flip time and batch axes because attention operates along dim=-2
hidden_states = rearrange(hidden_states, "batch time d -> time batch d")
normed_hidden_states = self.layer_norm(hidden_states)
attention_output: AttentionOutput = self.self_attention(
normed_hidden_states, mask=attention_mask, output_attentions=output_attentions
)
hidden_states = hidden_states + self.dropout(attention_output[0])

        if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1":
            # fast path: the mask argument is a tuple of the padded per-group mask
            # plus the metadata needed to pack groups and scatter results back
            fast_group_time_mask, group_start_index, counts, max_group_len, group_num = attention_mask
fast_normed_hidden_states = torch.zeros(
(group_num * normed_hidden_states.size(0), max_group_len, normed_hidden_states.size(2)),
dtype=normed_hidden_states.dtype,
device=normed_hidden_states.device,
)
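            # pack each group's contiguous slice of the batch axis into its own zero-padded block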
for i in range(group_num):
start_index = group_start_index[i].item()
group_len = counts[i].item()
end_index = start_index + group_len
fast_normed_hidden_states[i * normed_hidden_states.size(0) : (i + 1) * normed_hidden_states.size(0), :group_len, :] = normed_hidden_states[
:, start_index:end_index, :
]
            fast_attention_output: AttentionOutput = self.self_attention(
                fast_normed_hidden_states, mask=fast_group_time_mask, output_attentions=output_attentions
            )
attn_weights = fast_attention_output.attn_weights
attention_output = torch.empty_like(normed_hidden_states)
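            # scatter each group's attention output back to its original batch positions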
for i in range(group_num):
start_index = group_start_index[i].item()
group_len = counts[i].item()
end_index = start_index + group_len
attention_output[:, start_index:end_index, :] = fast_attention_output[0][
i * normed_hidden_states.size(0) : (i + 1) * normed_hidden_states.size(0), :group_len, :
]
else:
attention_output: AttentionOutput = self.self_attention(
normed_hidden_states, mask=attention_mask, output_attentions=output_attentions
)
attn_weights = attention_output.attn_weights
attention_output = attention_output[0]

hidden_states = hidden_states + self.dropout(attention_output)
# flip time and batch axes back to their original position
hidden_states = rearrange(hidden_states, "time batch d -> batch time d")

return AttentionOutput(hidden_states=hidden_states, attn_weights=attention_output.attn_weights)
return AttentionOutput(hidden_states=hidden_states, attn_weights=attn_weights)


class ResidualBlock(nn.Module):
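The fast path is a pack/unpack trick: group attention runs along the batch axis (after the time/batch flip above), so each contiguous group can be lifted into its own zero-padded block, attended to in a single batched call, and scattered back. Below is a minimal sketch of that pattern, assuming the (time, batch, d) layout from the diff and sorted group_ids so that each group occupies a contiguous slice; the helper names pack_groups/unpack_groups are illustrative, not part of the library.

import torch

def pack_groups(x: torch.Tensor, counts: torch.Tensor) -> torch.Tensor:
    # x: (time, batch, d); the batch axis holds contiguous groups of sizes `counts`.
    # Returns a zero-padded tensor of shape (group_num * time, max_group_len, d).
    time, _, d = x.shape
    group_num = counts.numel()
    max_group_len = int(counts.max())
    starts = counts.cumsum(dim=0) - counts  # first batch index of each group
    packed = torch.zeros(group_num * time, max_group_len, d, dtype=x.dtype, device=x.device)
    for i in range(group_num):
        s, n = int(starts[i]), int(counts[i])
        packed[i * time : (i + 1) * time, :n] = x[:, s : s + n]
    return packed

def unpack_groups(packed: torch.Tensor, counts: torch.Tensor, time: int) -> torch.Tensor:
    # Inverse of pack_groups: scatter the padded blocks back onto a single batch axis.
    d = packed.size(-1)
    starts = counts.cumsum(dim=0) - counts
    out = packed.new_empty(time, int(counts.sum()), d)
    for i in range(counts.numel()):
        s, n = int(starts[i]), int(counts[i])
        out[:, s : s + n] = packed[i * time : (i + 1) * time, :n]
    return out

# round trip over a batch of 6 series split into groups of sizes 3, 1, 2
x = torch.randn(5, 6, 8)  # (time, batch, d)
counts = torch.tensor([3, 1, 2])
assert torch.equal(unpack_groups(pack_groups(x, counts), counts, time=5), x)

Padded positions never leak into the attention result because the packed mask is filled with torch.finfo(dtype).min, so softmax assigns them effectively zero weight.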
29 changes: 27 additions & 2 deletions src/chronos/chronos2/model.py
@@ -4,6 +4,7 @@
# Authors: Abdul Fatir Ansari <ansarnd@amazon.com>

import copy
import os
from dataclasses import dataclass
from typing import cast

@@ -51,7 +52,7 @@ def forward(
*,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
group_time_mask: torch.Tensor,
group_time_mask: torch.Tensor | tuple,
output_attentions: bool = False,
) -> Chronos2EncoderBlockOutput:
# apply time attention
@@ -129,6 +130,24 @@ def _construct_and_invert_group_time_mask(
group_time_mask = rearrange(group_time_mask, "q b t -> t 1 q b")
group_time_mask = (1.0 - group_time_mask) * torch.finfo(floating_type).min

if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1":
            # pad each group's mask block to the max group size so one batched
            # self-attention call covers all groups with minimal wasted computation
unique_groups, counts = torch.unique(group_ids, return_counts=True)
max_group_len = counts.max().item()
group_num = unique_groups.size(0)
            fast_group_time_mask = torch.full(
                (group_num * group_time_mask.size(0), 1, max_group_len, max_group_len),
                torch.finfo(floating_type).min,
                dtype=floating_type,  # keep the mask in the same dtype as the hidden states
                device=group_time_mask.device,
            )
            # first index of each group along the batch axis; assumes group_ids are
            # sorted so each group occupies a contiguous slice
            group_start_index = counts.cumsum(dim=0) - counts
for i in range(group_num):
group_len = counts[i].item()
start_index = group_start_index[i].item()
end_index = start_index + group_len
fast_group_time_mask[
i * group_time_mask.size(0) : (i + 1) * group_time_mask.size(0), :, :group_len, :group_len
] = group_time_mask[:, :, start_index : end_index, start_index : end_index]
return fast_group_time_mask, group_start_index, counts, max_group_len, group_num

return group_time_mask

def forward(
@@ -152,7 +171,13 @@ def forward(
extended_attention_mask = self._expand_and_invert_time_attention_mask(attention_mask, inputs_embeds.dtype)

# construct group time mask
group_time_mask = self._construct_and_invert_group_time_mask(group_ids, attention_mask, inputs_embeds.dtype)
        if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1":
            # fast mode: the mask is a tuple carrying the padded per-group mask plus packing metadata
            fast_group_time_mask, group_start_index, counts, max_group_len, group_num = self._construct_and_invert_group_time_mask(
                group_ids, attention_mask, inputs_embeds.dtype
            )
            group_time_mask = (fast_group_time_mask, group_start_index, counts, max_group_len, group_num)
else:
group_time_mask = self._construct_and_invert_group_time_mask(group_ids, attention_mask, inputs_embeds.dtype)

all_time_self_attentions: tuple[torch.Tensor, ...] = ()
all_group_self_attentions: tuple[torch.Tensor, ...] = ()
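The flag is read inside forward on every call, so it can be toggled without rebuilding the model. The snippet below enables it and walks through the torch.unique/cumsum bookkeeping that the fast mask construction relies on, using hypothetical group ids for illustration.

import os
import torch

# opt in to the padded per-group attention path added in this diff
os.environ["CHRONOS2_USE_FAST_GROUP_ATTENTION"] = "1"

# hypothetical sorted group ids for a batch of 6 series: groups of sizes 3, 1, 2
group_ids = torch.tensor([0, 0, 0, 1, 2, 2])
unique_groups, counts = torch.unique(group_ids, return_counts=True)
group_start_index = counts.cumsum(dim=0) - counts  # tensor([0, 3, 4])
max_group_len = int(counts.max())                  # 3
group_num = unique_groups.numel()                  # 3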