"""CLI entry point: run Optuna hyper-parameter search for a benchmark model."""

import argparse
import json
import logging
import os
import sys

# BUG FIX: the repository root must be on sys.path BEFORE importing
# ts_benchmark; in the original the insert came after the import, which
# defeated its purpose when the script was launched from elsewhere.
sys.path.insert(0, os.path.abspath(os.path.dirname(os.path.dirname(__file__))))

from ts_benchmark.hpo import run_optuna_search


def main() -> None:
    """Parse CLI arguments, configure logging, run the search, print results."""
    parser = argparse.ArgumentParser(
        description="Run hyperparameter optimisation using Optuna.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--config-path",
        type=str,
        required=True,
        help="Relative path to the config JSON under the config directory.",
    )
    parser.add_argument(
        "--data-name-list",
        type=str,
        nargs="+",
        required=True,
        help="One or more series names on which to perform HPO.",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        required=True,
        help="Fully qualified model name to optimise (e.g. 'olinear.OLinear').",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default="",
        help=(
            "Relative path under the result directory to store HPO outputs. "
            "If empty, results will be saved directly under 'result/'."
        ),
    )
    # NOTE(review): --adapter is parsed but never forwarded to
    # run_optuna_search below — confirm whether it should be passed through.
    parser.add_argument(
        "--adapter",
        type=str,
        default=None,
        help="Optional adapter name to wrap the model during evaluation.",
    )
    parser.add_argument(
        "--n-trials",
        type=int,
        default=10,
        help="Number of Optuna trials to run.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility.",
    )

    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s(%(lineno)d): %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    # Silence Optuna's per-trial INFO chatter.
    logging.getLogger("optuna").setLevel(logging.WARNING)

    result = run_optuna_search(
        config_path=args.config_path,
        data_name_list=args.data_name_list,
        model_name=args.model_name,
        save_path=args.save_path,
        n_trials=args.n_trials,
        seed=args.seed,
    )

    # Pretty-print the best trial results to stdout.
    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()
class AutoCorrelation(nn.Module):
    """Auto-Correlation mechanism (Autoformer) with two phases:
    (1) period-based dependency discovery via FFT cross-correlation;
    (2) time-delay aggregation of the value series.
    Drop-in replacement for the self-attention family.

    Args:
        mask_flag: kept for attention-API compatibility; unused here.
        factor: top-k delays are ``int(factor * log(length))``.
        scale: kept for attention-API compatibility; unused here.
        attention_dropout: dropout rate (module kept for interface parity).
        output_attention: if True, forward also returns the correlation map.
    """

    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
        super(AutoCorrelation, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def time_delay_agg_training(self, values, corr):
        """Speed-up (batch-normalization style) delay aggregation for training.

        values: (B, H, E, L); corr: correlation scores per lag.
        """
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # Top-k delays shared across the whole batch during training.
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
        # Normalise the selected correlations into aggregation weights.
        tmp_corr = torch.softmax(weights, dim=-1)
        # Aggregate rolled (circularly delayed) copies, weighted by correlation.
        tmp_values = values
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            pattern = torch.roll(tmp_values, -int(index[i]), -1)
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        """Delay aggregation for inference; per-sample top-k delays.

        values: (B, H, E, L).
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # BUG FIX: the index was created with a hard-coded .cuda(), which
        # crashed CPU-only inference; build it on the input's device instead.
        init_index = torch.arange(length, device=values.device) \
            .unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1)
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        tmp_corr = torch.softmax(weights, dim=-1)
        # Double the series so gather() can read past the end (circular shift).
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """Standard (per-head, per-channel) delay aggregation."""
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # BUG FIX: device-agnostic index (was a hard-coded .cuda()).
        init_index = torch.arange(length, device=values.device) \
            .unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1)
        top_k = int(self.factor * math.log(length))
        weights, delay = torch.topk(corr, top_k, dim=-1)
        tmp_corr = torch.softmax(weights, dim=-1)
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        """queries/keys/values: (B, L, H, E). Returns (aggregated V, corr or None)."""
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        # Align key/value length with the query length.
        if L > S:
            zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # Period-based dependencies via FFT cross-correlation.
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        corr = torch.fft.irfft(res, dim=-1)

        # Time-delay aggregation (batched fast path while training).
        if self.training:
            V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
        else:
            V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)

        if self.output_attention:
            return V.contiguous(), corr.permute(0, 3, 1, 2)
        else:
            return V.contiguous(), None


class AutoCorrelationLayer(nn.Module):
    """Projects Q/K/V, applies the given AutoCorrelation per head, and
    projects the concatenated heads back to d_model."""

    def __init__(self, correlation, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AutoCorrelationLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        # Split the projections into H heads.
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_correlation(
            queries,
            keys,
            values,
            attn_mask
        )
        out = out.view(B, L, -1)

        return self.out_projection(out), attn
class moving_avg(nn.Module):
    """Moving average that highlights the trend of a time series; both ends
    are edge-replicated so the output keeps the input length (expects an odd
    kernel_size)."""

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # Replicate both ends of the series so pooling preserves length.
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x


class series_decomp(nn.Module):
    """Decompose a series into (seasonal residual, moving-average trend)."""

    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean


class series_decomp_multi(nn.Module):
    """Average of several series_decomp blocks with different kernels (FEDformer).

    BUG FIX: the sub-decompositions were stored in a plain Python list, so
    they were never registered as sub-modules — ``.to(device)``, ``.eval()``
    and ``state_dict()`` all missed them.  They are now an ``nn.ModuleList``.
    """

    def __init__(self, kernel_size):
        super(series_decomp_multi, self).__init__()
        self.kernel_size = kernel_size
        self.series_decomp = nn.ModuleList(series_decomp(kernel) for kernel in kernel_size)

    def forward(self, x):
        moving_mean = []
        res = []
        for func in self.series_decomp:
            sea, trend = func(x)
            moving_mean.append(trend)
            res.append(sea)

        # Average the seasonal and trend parts over all kernel sizes.
        sea = sum(res) / len(res)
        moving_mean = sum(moving_mean) / len(moving_mean)
        return sea, moving_mean


class EncoderLayer(nn.Module):
    """Autoformer encoder layer with the progressive decomposition
    architecture: attention + FFN, each followed by a series decomposition
    that strips the trend."""

    def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        # Point-wise feed-forward implemented as kernel-1 convolutions.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        new_x, attn = self.attention(
            x, x, x,
            attn_mask=attn_mask
        )
        x = x + self.dropout(new_x)
        # Keep only the seasonal part after each sub-layer.
        x, _ = self.decomp1(x)
        y = x
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        res, _ = self.decomp2(x + y)
        return res, attn


class Encoder(nn.Module):
    """Autoformer encoder: a stack of attention layers with optional
    interleaved convolution layers and a final norm."""

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        attns = []
        if self.conv_layers is not None:
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            # NOTE(review): the last attention layer is intentionally called
            # without attn_mask here (matches upstream Autoformer) — confirm.
            x, attn = self.attn_layers[-1](x)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns


class DecoderLayer(nn.Module):
    """Autoformer decoder layer: self- and cross-attention plus FFN, with a
    decomposition after each stage; the stripped trends are accumulated and
    projected to the output channel count."""

    def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None,
                 moving_avg=25, dropout=0.1, activation="relu"):
        super(DecoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.decomp3 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        # Projects the accumulated trend from d_model to c_out channels.
        self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1,
                                    padding_mode='circular', bias=False)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        x = x + self.dropout(self.self_attention(
            x, x, x,
            attn_mask=x_mask
        )[0])
        x, trend1 = self.decomp1(x)
        x = x + self.dropout(self.cross_attention(
            x, cross, cross,
            attn_mask=cross_mask
        )[0])
        x, trend2 = self.decomp2(x)
        y = x
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        x, trend3 = self.decomp3(x + y)

        # Trends from all three decompositions feed the trend stream.
        residual_trend = trend1 + trend2 + trend3
        residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2)
        return x, residual_trend


class Decoder(nn.Module):
    """Autoformer decoder: stacks DecoderLayers, threading the trend
    accumulator through every layer."""

    def __init__(self, layers, norm_layer=None, projection=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
        for layer in self.layers:
            x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
            trend = trend + residual_trend

        if self.norm is not None:
            x = self.norm(x)

        if self.projection is not None:
            x = self.projection(x)
        return x, trend
class ColAttention(nn.Module):
    """Axial attention along the column (height) axis of a 2-D feature map.

    Each spatial column attends only within itself, keeping cost linear in
    width.  The attended output is blended back into the input through a
    learnable, zero-initialised gate ``gamma``, so the module starts out as
    an identity mapping.
    """

    def __init__(self, in_dim, q_k_dim):
        # in_dim: input channels; q_k_dim: reduced channel dim for Q/K.
        super(ColAttention, self).__init__()
        self.in_dim = in_dim
        self.q_k_dim = q_k_dim

        # 1x1 convolutions project the map into query/key/value spaces.
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.in_dim, kernel_size=1)
        self.softmax = Softmax(dim=2)
        # Zero init: the residual branch dominates early in training.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # x: (b, c, h, w)
        b, _, h, w = x.size()
        Q = self.query_conv(x)
        K = self.key_conv(x)
        V = self.value_conv(x)

        # Fold the width axis into the batch so each column attends independently.
        Q = Q.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h).permute(0, 2, 1)  # size = (b*w,h,c2)
        K = K.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h)  # size = (b*w,c2,h)
        V = V.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h)  # size = (b*w,c1,h)

        # (b*w, h, h): row-to-row affinities within a column.
        col_attn = torch.bmm(Q, K)
        col_attn = self.softmax(col_attn)
        out = torch.bmm(V, col_attn.permute(0, 2, 1))
        # Restore (b, c, h, w) layout.
        out = out.view(b, w, -1, h).permute(0, 2, 3, 1)
        # Gated residual connection.
        out = self.gamma * out + x
        return out
class Inception_Block_V1(nn.Module):
    """Inception-style multi-kernel 2-D convolution: runs parallel convs with
    kernel sizes 1, 3, 5, ... (padding i keeps spatial dims) and averages them."""

    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
        super(Inception_Block_V1, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        kernels = []
        for i in range(self.num_kernels):
            # Kernel 2i+1 with padding i preserves H and W.
            kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i))
        self.kernels = nn.ModuleList(kernels)
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        # Kaiming init for every conv, zero biases.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Average the parallel branch outputs.
        res_list = [kernel(x) for kernel in self.kernels]
        return torch.stack(res_list, dim=-1).mean(-1)


class Inception_Block_V2(nn.Module):
    """Separable Inception variant: pairs of (1,k)/(k,1) convolutions plus a
    1x1 convolution, averaged."""

    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
        super(Inception_Block_V2, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        kernels = []
        for i in range(self.num_kernels // 2):
            kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=(1, 2 * i + 3), padding=(0, i + 1)))
            kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=(2 * i + 3, 1), padding=(i + 1, 0)))
        kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
        self.kernels = nn.ModuleList(kernels)
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # BUG FIX: the original iterated range(self.num_kernels + 1), which
        # only equals len(self.kernels) when num_kernels is even; odd values
        # raised IndexError.  Iterate the actual ModuleList instead.
        res_list = [kernel(x) for kernel in self.kernels]
        return torch.stack(res_list, dim=-1).mean(-1)
class ResNetBasicBlock(nn.Module):
    """Two-convolution residual block (ResNet "basic" variant) with
    BatchNorm and a configurable activation (GELU by default)."""
    expansion = 1

    def __init__(self, in_channels, out_channels, hidden_channels=None, stride=1, downsample=None, act_layer=None):
        super().__init__()
        # First convolution (may spatially downsample via stride).
        hidden_channels = hidden_channels or out_channels
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden_channels)
        # Second convolution (always stride 1).
        self.conv2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Optional projection to match the identity branch's shape/channels.
        self.downsample = downsample

        self.act_layer = act_layer or nn.GELU()

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act_layer(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        # NOTE(review): no activation after the residual add (differs from the
        # classic ResNet block); the commented-out line suggests this is
        # deliberate — confirm intended.
        # out = self.act_layer(out)

        return out
class Inception_Trans_Block_V1(nn.Module):
    """Inception-style transposed-convolution block.

    Runs ``num_kernels`` parallel ConvTranspose2d branches with kernel sizes
    1, 3, 5, ... (padding i), then averages their outputs.  ``output_size``
    is forwarded to each branch to resolve the stride ambiguity of
    transposed convolutions.
    """

    def __init__(self, in_channels, out_channels, stride=1, num_kernels=6, init_weight=True):
        super(Inception_Trans_Block_V1, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        self.stride = stride

        branches = [
            nn.ConvTranspose2d(in_channels, out_channels,
                               kernel_size=2 * idx + 1, padding=idx, stride=stride)
            for idx in range(num_kernels)
        ]
        self.kernels = nn.ModuleList(branches)
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        # Kaiming init for every transposed conv; zero the biases.
        for module in self.modules():
            if isinstance(module, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x, output_size):
        # Average all parallel branch outputs.
        outputs = [branch(x, output_size=output_size) for branch in self.kernels]
        return torch.stack(outputs, dim=-1).mean(-1)
class PositionalEmbedding(nn.Module):
    """Sinusoidal positional encoding, precomputed for up to max_len positions."""

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        # BUG FIX: the original wrote ``pe.require_grad`` (missing the final
        # 's'), which silently set an unrelated attribute.
        pe.requires_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Returns (1, seq_len, d_model); broadcast-added by callers.
        return self.pe[:, :x.size(1)]


class TokenEmbedding(nn.Module):
    """Token embedding via a kernel-3 circular Conv1d over the feature axis."""

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # BUG FIX: the original compared version *strings*
        # (torch.__version__ >= '1.5.0'), which is lexicographic and picks the
        # wrong padding for torch 1.10-1.13.  Compare numeric (major, minor).
        version = tuple(int(part) for part in torch.__version__.split("+")[0].split(".")[:2])
        padding = 1 if version >= (1, 5) else 2
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding, padding_mode='circular', bias=False)
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        # (B, L, C) -> conv over L -> (B, L, d_model)
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x


class FixedEmbedding(nn.Module):
    """Embedding table with frozen sinusoidal weights (never trained)."""

    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()

        w = torch.zeros(c_in, d_model).float()
        # BUG FIX: was ``w.require_grad`` (typo); the weight is additionally
        # frozen below via requires_grad=False.
        w.requires_grad = False

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        w[:, 0::2] = torch.sin(position * div_term)
        w[:, 1::2] = torch.cos(position * div_term)

        self.emb = nn.Embedding(c_in, d_model)
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        # detach(): outputs never carry gradient history.
        return self.emb(x).detach()
class TemporalEmbedding(nn.Module):
    """Sum of per-component calendar embeddings (month/day/weekday/hour, plus
    minute when freq == 't')."""

    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super(TemporalEmbedding, self).__init__()

        # Cardinality of each calendar component.
        minute_size = 4
        hour_size = 24
        weekday_size = 7
        day_size = 32
        month_size = 13

        Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
        if freq == 't':
            # Minute embedding exists only for minute-level data.
            self.minute_embed = Embed(minute_size, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)

    def forward(self, x):
        # x columns are integer codes: [month, day, weekday, hour, (minute)].
        codes = x.long()
        parts = [
            self.hour_embed(codes[:, :, 3]),
            self.weekday_embed(codes[:, :, 2]),
            self.day_embed(codes[:, :, 1]),
            self.month_embed(codes[:, :, 0]),
        ]
        if hasattr(self, 'minute_embed'):
            parts.append(self.minute_embed(codes[:, :, 4]))
        total = parts[0]
        for part in parts[1:]:
            total = total + part
        return total


class TimeFeatureEmbedding(nn.Module):
    """Linear projection of real-valued time features onto d_model."""

    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super(TimeFeatureEmbedding, self).__init__()

        # Number of time features produced per sampling frequency.
        freq_map = {'h': 4, 't': 5, 's': 6,
                    'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
        self.embed = nn.Linear(freq_map[freq], d_model, bias=False)

    def forward(self, x):
        return self.embed(x)


class DataEmbedding(nn.Module):
    """Value + positional (+ temporal) embedding followed by dropout."""

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type != 'timeF':
            self.temporal_embedding = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            self.temporal_embedding = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        if x_mark is None:
            out = self.value_embedding(x) + self.position_embedding(x)
        else:
            out = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
        return self.dropout(out)
class DataEmbedding_inverted(nn.Module):
    """iTransformer-style inverted embedding: each variate's whole series
    becomes one token, projected from length c_in to d_model."""

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding_inverted, self).__init__()
        self.value_embedding = nn.Linear(c_in, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark=None):
        # x: (B, L, N) -> variate tokens (B, N, L).  Optional covariates
        # (e.g. timestamps) are appended along the token axis, so the output
        # token count may exceed N.
        tokens = x.permute(0, 2, 1)
        if x_mark is not None:
            tokens = torch.cat([tokens, x_mark.permute(0, 2, 1)], dim=1)
        # (B, N[+marks], d_model)
        return self.dropout(self.value_embedding(tokens))
class DataEmbedding_conv(nn.Module):
    """Per-variate linear token embedding with a learned positional offset
    and a LayerNorm.  When ``PatchTST_flag`` is set, the raw transposed
    input is returned instead (patching happens downstream)."""

    def __init__(self, c_in=96, d_model=256, token_num=7, dropout=0.1, PatchTST_flag=False):
        super(DataEmbedding_conv, self).__init__()
        self.PatchTST_flag = PatchTST_flag
        assert c_in > 0 and d_model > 0
        self.act_layer = nn.GELU()

        # Norm is only needed on the non-PatchTST path.
        if not self.PatchTST_flag:
            self.norm_layer = nn.LayerNorm(d_model)

        self.linear_layer = nn.Linear(c_in, d_model)

        # Learned offset, broadcast over batch and tokens.
        self.positional_encoding = nn.Parameter(torch.zeros(1, 1, d_model))

        # NOTE(review): dropout is constructed but never applied in forward
        # (the dropout return was commented out upstream) — confirm intended.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark=None):
        # x: (B, L, N) -> one token per variate: (B, N, L).
        tokens = x.permute(0, 2, 1)

        if self.PatchTST_flag:
            # PatchTST path: defer embedding to the patching module.
            return tokens

        tokens = self.linear_layer(tokens) + self.positional_encoding

        if x_mark is not None:
            # Optionally append covariates (e.g. timestamps) as extra tokens.
            tokens = torch.cat([tokens, self.linear_layer(x_mark.permute(0, 2, 1))], dim=1)

        return self.norm_layer(tokens)


class DataEmbedding_wo_pos(nn.Module):
    """Value embedding only; positional and temporal branches are built for
    API parity but deliberately unused in forward."""

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding_wo_pos, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type != 'timeF':
            self.temporal_embedding = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            self.temporal_embedding = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        # Temporal marks are intentionally ignored here (matches upstream).
        return self.dropout(self.value_embedding(x))
class PatchEmbedding(nn.Module):
    """Split each series into overlapping patches and embed every patch.

    Output is (B * n_vars, patch_num, d_model); unless ``one_output`` is set,
    ``n_vars`` is returned too so callers can un-fold the variate axis.
    """

    def __init__(self, d_model, patch_len, stride, padding, dropout, one_output=False):
        super(PatchEmbedding, self).__init__()
        # Patching parameters.
        self.patch_len = patch_len
        self.stride = stride
        self.one_output = one_output
        # Right-pad by replication so the final patch is complete.
        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))

        # Project each raw patch onto the d_model space.
        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)

        # Sinusoidal positional embedding over patch positions.
        self.position_embedding = PositionalEmbedding(d_model)

        # Residual dropout.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (B, n_vars, L)
        n_vars = x.shape[1]
        padded = self.padding_patch_layer(x)
        patches = padded.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        # Merge batch and variate axes: (B * n_vars, patch_num, patch_len).
        patches = patches.reshape(-1, patches.shape[2], patches.shape[3])
        embedded = self.dropout(self.value_embedding(patches) + self.position_embedding(patches))

        if self.one_output:
            return embedded
        return embedded, n_vars
+ """ + + def __init__(self, input_dim, output_dim, p_ratio=0.25, activation='gelu', use_p_bias=True): + super(FANLayer, self).__init__() + + # Ensure the p_ratio is within a valid range + assert 0 < p_ratio < 0.5, "p_ratio must be between 0 and 0.5" + + self.p_ratio = p_ratio + p_output_dim = int(output_dim * self.p_ratio) + g_output_dim = output_dim - p_output_dim * 2 # Account for cosine and sine terms + + # Linear transformation for the p component (for cosine and sine parts) + self.input_linear_p = nn.Linear(input_dim, p_output_dim, bias=use_p_bias) + + # Linear transformation for the g component + self.input_linear_g = nn.Linear(input_dim, g_output_dim) + + # Set the activation function + if isinstance(activation, str): + self.activation = getattr(F, activation) + else: + self.activation = activation if activation else lambda x: x + + def forward(self, src): + """ + Args: + src (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + Tensor: Output tensor of shape (batch_size, output_dim), after applying the FAN layer. 
+ """ + + # Apply the linear transformation followed by the activation for the g component + g = self.activation(self.input_linear_g(src)) + + # Apply the linear transformation for the p component + p = self.input_linear_p(src) + + # Concatenate cos(p), sin(p), and activated g along the last dimension + output = torch.cat((torch.cos(p), torch.sin(p), g), dim=-1) + + return output diff --git a/ts_benchmark/baselines/olinear/layers/Leddam.py b/ts_benchmark/baselines/olinear/layers/Leddam.py new file mode 100644 index 00000000..b49fd7aa --- /dev/null +++ b/ts_benchmark/baselines/olinear/layers/Leddam.py @@ -0,0 +1,381 @@ +import torch +import torch.nn as nn +import math +from torch import Tensor +import torch.nn.functional as F +from typing import Optional + + +class Leddam(nn.Module): + def __init__(self, configs, + enc_in, + seq_len, + d_model, + dropout, + pe_type, + kernel_size, + n_layers=3): + + super(Leddam, self).__init__() + self.n_layers = n_layers + self.LD = LD(kernel_size=kernel_size) + self.channel_attn_blocks = nn.ModuleList([ + channel_attn_block(enc_in, d_model, dropout) + for _ in range(self.n_layers) + ]) + self.auto_attn_blocks = nn.ModuleList([ + auto_attn_block(enc_in, d_model, dropout) + for _ in range(self.n_layers) + ]) + self.position_embedder = DataEmbedding(pe_type=pe_type, seq_len=seq_len, + d_model=d_model, c_in=enc_in) + + def forward(self, inp): + inp = self.position_embedder(inp.permute(0, 2, 1)).permute(0, 2, 1) + main = self.LD(inp) + residual = inp - main + + res_1 = residual + res_2 = residual + for i in range(self.n_layers): + res_1 = self.auto_attn_blocks[i](res_1) + for i in range(self.n_layers): + res_2 = self.channel_attn_blocks[i](res_2) + res = res_1 + res_2 + + return res, main + + +class channel_attn_block(nn.Module): + def __init__(self, enc_in, d_model, dropout): + super(channel_attn_block, self).__init__() + self.channel_att_norm = nn.BatchNorm1d(enc_in) + self.fft_norm = nn.LayerNorm(d_model) + self.channel_attn = 
class auto_attn_block(nn.Module):
    """Residual block: Auto_Attention followed by a position-wise FFN.

    Operates on tensors shaped [B, d_model, enc_in] (see Leddam.forward);
    BatchNorm runs over the enc_in channel axis, LayerNorm over d_model.
    """

    def __init__(self, enc_in, d_model, dropout):
        super(auto_attn_block, self).__init__()
        self.auto_attn_norm = nn.BatchNorm1d(enc_in)
        self.fft_norm = nn.LayerNorm(d_model)
        self.auto_attn = Auto_Attention(P=64, d_model=d_model, proj_dropout=dropout)
        # Two-layer feed-forward with expansion factor 2.
        hidden = int(d_model * 2)
        self.fft_layer = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
        )

    def forward(self, residual):
        attended = self.auto_attn(residual) + residual          # attention + skip
        out = self.auto_attn_norm(attended.permute(0, 2, 1))    # norm over channels
        out = self.fft_norm(self.fft_layer(out) + out)          # FFN + skip + norm
        return out.permute(0, 2, 1)
class Auto_Attention(nn.Module):
    """Single-query attention over period-shifted copies of the input.

    The input is cyclically shifted by multiples of ``P`` and the shifts are
    stacked as tokens; the unshifted copy attends over all of them.
    """

    def __init__(self, P, d_model, proj_dropout=0.2):
        """
        Args:
            P (int): Shift period.
            d_model (int): Input/output dimension for queries, keys, values.
            proj_dropout (float): Dropout applied after the output projection.
        """
        super(Auto_Attention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.out_projector = nn.Sequential(nn.Linear(d_model, d_model), nn.Dropout(proj_dropout))
        self.P = P
        # Fixed (non-trainable) 1/sqrt(d_model) scaling factor.
        self.scale = nn.Parameter(torch.tensor(d_model ** -0.5), requires_grad=False)

    def auto_attention(self, inp):
        """Attend from the first token of each group to every token.

        inp: [..., tokens, d_model]; returns [..., 1, d_model] per group.
        """
        query = self.W_Q(inp[:, :, 0, :].unsqueeze(-2))
        keys = self.W_K(inp)
        values = self.W_V(inp)

        weights = torch.matmul(query, keys.transpose(-2, -1)) * self.scale
        weights = F.softmax(weights, dim=-1)
        return torch.matmul(weights, values)

    def forward(self, inp):
        """inp: [B, T, N] -> output of the same shape after shift-attention.

        NOTE(review): given how Leddam calls this, the last permuted axis is
        d_model-sized — confirm against callers.
        """
        inp = inp.permute(0, 2, 1)  # [B, T, N] -> [B, N, T]
        T = inp.size(-1)

        # Number of cyclic shifts by P with distinct alignments.
        n_shifts = int(T / self.P) - 1 if T % self.P == 0 else int(T / self.P)

        shifted = [inp]
        for s in range(n_shifts):
            cut = (s + 1) * self.P
            shifted.append(torch.cat([inp[:, :, cut:], inp[:, :, :cut]], dim=-1))

        # [B, N, T, tokens] -> [B, N, tokens, T]: unshifted copy first.
        tokens = torch.stack(shifted, dim=-1).permute(0, 1, 3, 2)

        out = self.auto_attention(tokens).squeeze(-2)
        return self.out_projector(out).permute(0, 2, 1)
+ transpose(2,3) + k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0, 2, 3, 1) + v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1, 2) # v_s : [bs x n_heads x q_len x d_v] + + # Apply Scaled Dot-Product Attention (multiple heads) + if prev is not None: + output, prev = self.sdp_attn(q_s, k_s, v_s) + else: + output = self.sdp_attn(q_s, k_s, v_s) + # output: [bs x n_heads x q_len x d_v] + + # back to the original inputs dimensions + output = output.transpose(1, 2).contiguous().view(bs, -1, + self.n_heads * self.d_v) # output: [bs x q_len x n_heads * d_v] + output = self.to_out(output) + if prev is not None: + return output, prev + else: + return output + + +class ScaledDotProductAttention(nn.Module): + def __init__(self, d_model, n_heads, attn_dropout=0.): + super().__init__() + self.attn_dropout = nn.Dropout(attn_dropout) + head_dim = d_model // n_heads + self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=False) + + def forward(self, q: Tensor, k: Tensor, v: Tensor, prev: Optional[Tensor] = None): + """ + Input shape: + q : [bs x n_heads x max_q_len x d_k] + k : [bs x n_heads x d_k x seq_len] + v : [bs x n_heads x seq_len x d_v] + prev : [bs x n_heads x q_len x seq_len] + Output shape: + output: [bs x n_heads x q_len x d_v] + attn : [bs x n_heads x q_len x seq_len] + scores : [bs x n_heads x q_len x seq_len] + """ + # Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence + attn_scores = torch.matmul(q, k) * (self.scale) # Scale + + # Add pre-softmax attention scores from the previous layer (optional) + if prev is not None: + attn_scores = attn_scores + prev + # Normalize the attention weights + attn_weights = F.softmax(attn_scores, dim=-1) # attn_weights: [bs x n_heads x max_q_len x q_len] + attn_weights = self.attn_dropout(attn_weights) + + # Compute the new values given the attention weights + output = torch.matmul(attn_weights, v) # output: [bs x n_heads x max_q_len x d_v] + 
def positional_encoding(pe, learn_pe, q_len, d_model):
    """Build a positional-encoding parameter of shape (q_len, d_model) or (q_len, 1).

    Args:
        pe: Encoding type — one of None/'no', 'zero', 'zeros',
            'normal'/'gauss', 'uniform', 'lin1d', 'exp1d', 'lin2d',
            'exp2d', 'sincos'.
        learn_pe: Whether the returned parameter is trainable.
        q_len: Number of positions (tokens).
        d_model: Embedding width.

    Returns:
        nn.Parameter with requires_grad = learn_pe (forced False for pe=None).

    Raises:
        ValueError: If ``pe`` is not a recognized encoding type.
    """
    # Idiom fix: compare to None with `is`, not `==`.
    if pe is None or pe == 'no':
        # pe=None + learn_pe=False can be used to measure the impact of pe.
        W_pos = torch.empty((q_len, d_model))
        nn.init.uniform_(W_pos, -0.02, 0.02)
        learn_pe = False
    elif pe == 'zero':
        # Single learnable column, broadcast over d_model by the consumer.
        W_pos = torch.empty((q_len, 1))
        nn.init.uniform_(W_pos, -0.02, 0.02)
    elif pe == 'zeros':
        W_pos = torch.empty((q_len, d_model))
        nn.init.uniform_(W_pos, -0.02, 0.02)
    elif pe == 'normal' or pe == 'gauss':
        W_pos = torch.zeros((q_len, 1))
        torch.nn.init.normal_(W_pos, mean=0.0, std=0.1)
    elif pe == 'uniform':
        W_pos = torch.zeros((q_len, 1))
        nn.init.uniform_(W_pos, a=0.0, b=0.1)
    elif pe == 'lin1d':
        W_pos = Coord1dPosEncoding(q_len, exponential=False, normalize=True)
    elif pe == 'exp1d':
        W_pos = Coord1dPosEncoding(q_len, exponential=True, normalize=True)
    elif pe == 'lin2d':
        W_pos = Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True)
    elif pe == 'exp2d':
        W_pos = Coord2dPosEncoding(q_len, d_model, exponential=True, normalize=True)
    elif pe == 'sincos':
        W_pos = SinCosPosEncoding(q_len, d_model, normalize=True)
    else:
        # BUGFIX: the original message was missing the opening quote on 'uniform'.
        raise ValueError(
            f"{pe} is not a valid pe (positional encoder. Available types: 'gauss'=='normal', "
            "'zeros', 'zero', 'uniform', 'lin1d', 'exp1d', 'lin2d', 'exp2d', 'sincos', None.)"
        )
    return nn.Parameter(W_pos, requires_grad=learn_pe)
class LD(nn.Module):
    """Learnable decomposition: a depthwise smoothing conv shared by all channels.

    The kernel is initialized as a softmax-normalized Gaussian bump (so the
    initial operator is a weighted moving average whose weights sum to 1) and
    remains trainable afterwards.
    """

    def __init__(self, kernel_size=25):
        super(LD, self).__init__()
        # One shared 1-channel conv applied to every series independently;
        # replicate padding keeps the output length equal to the input length.
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, stride=1,
                              padding=int(kernel_size // 2),
                              padding_mode='replicate', bias=True)

        # Gaussian-shaped initial weights centered on the kernel midpoint.
        center = kernel_size // 2
        sigma = 1.0  # 1 for variance
        offsets = torch.arange(kernel_size, dtype=torch.float)
        bump = torch.exp(-((offsets - center) / (2 * sigma)) ** 2).view(1, 1, -1)

        # Normalize via softmax so initial weights form a convex combination.
        self.conv.weight.data = F.softmax(bump, dim=-1)
        self.conv.bias.data.fill_(0.0)

    def forward(self, inp):
        """Smooth inp of shape [B, T, N]; returns the same shape."""
        inp = inp.permute(0, 2, 1)  # [B, T, N] -> [B, N, T]
        # Run the shared conv over each channel separately, then re-assemble.
        smoothed = torch.cat(
            [self.conv(channel) for channel in torch.split(inp, 1, dim=1)],
            dim=1,
        )
        return smoothed.permute(0, 2, 1)
+ """ + # Separate query and key + query = self.W_Q(inp[:, :, 0, :].unsqueeze(-2)) # Query + keys = self.W_K(inp) # Keys + values = self.W_V(inp) # Values + + # Calculate dot product + attn_scores = torch.matmul(query, keys.transpose(-2, -1)) * self.scale + + # Normalize attention scores + attn_scores = F.softmax(attn_scores, dim=-1) + + # Weighted sum + output = torch.matmul(attn_scores, values) + + return output + + def forward(self, inp): + """ + Forward pass of the Auto-Attention module. + + Args: + P (int): The period for autoregressive behavior. + inp (torch.Tensor): Input data of shape [B, T, N], where B is the batch size, + T is the sequence length, and N is the number of features. + + Returns: + output (torch.Tensor): Output after autoregressive self-attention. + """ + # Permute the input for further processing + inp = inp.permute(0, 2, 1) # [B, T, N] -> [B, N, T] + + T = inp.size(-1) + + cat_sequences = [inp] + index = int(T / self.P) - 1 if T % self.P == 0 else int(T / self.P) + + for i in range(index): + end = (i + 1) * self.P + + # Concatenate sequences to support autoregressive behavior + cat_sequence = torch.cat([inp[:, :, end:], inp[:, :, 0:end]], dim=-1) + cat_sequences.append(cat_sequence) + + # Stack the concatenated sequences [B,N,T,tokens] + output = torch.stack(cat_sequences, dim=-1) + + # Permute the output for attention calculation [B,N,tokens,T] + output = output.permute(0, 1, 3, 2) + + # Apply autoregressive self-attention + output = self.auto_attention(output).squeeze(-2) + output = self.out_projector(output).permute(0, 2, 1) + + return output + + +class MultiheadAttention(nn.Module): + def __init__(self, configs, d_model, n_heads=1, attn_dropout=0., proj_dropout=0.2): + """Multi Head Attention Layer + Input shape: + Q: [batch_size (bs) x max_q_len x d_model] + K, V: [batch_size (bs) x q_len x d_model] + mask: [q_len x q_len] + """ + super().__init__() + + self.attnLinear = configs.Leddam_attnLinear + + if self.attnLinear: + 
print('AttnLinear is used...') + + self.token_num = configs.enc_in + + self.v_proj = nn.Linear(d_model, d_model) + self.out_proj = nn.Linear(d_model, d_model) + + init_weight_mat = torch.eye(self.token_num) * 1.0 + torch.randn(self.token_num, self.token_num) * 1.0 + self.weight_mat = nn.Parameter(init_weight_mat[None, :, :]) + + self.dropout = nn.Dropout(attn_dropout) + self.dropout_proj = nn.Dropout(proj_dropout) + else: + print('MultiheadAttention is used...') + + d_k = d_model // n_heads + d_v = d_model // n_heads + + self.configs, self.n_heads, self.d_k, self.d_v = configs, n_heads, d_k, d_v + + self.W_Q = nn.Linear(d_model, d_k * n_heads) + self.W_K = nn.Linear(d_model, d_k * n_heads) + self.W_V = nn.Linear(d_model, d_v * n_heads) + + # Scaled Dot-Product Attention (multiple heads) + self.sdp_attn = ScaledDotProductAttention(configs, d_model, n_heads, attn_dropout=attn_dropout) + # Project output + self.to_out = nn.Sequential( + nn.Linear(n_heads * d_v, d_model), + nn.Dropout(proj_dropout) + ) + + def forward(self, Q: Tensor, K: Optional[Tensor] = None, V: Optional[Tensor] = None, prev: Optional[Tensor] = None, + ): + + bs = Q.size(0) + + if self.attnLinear: + # Q: b, n, d + values = self.v_proj(Q) + A = F.softplus(self.weight_mat) + A = F.normalize(A, p=1, dim=-1) + A = self.dropout(A) + + new_x = A @ values + output = self.dropout_proj(self.out_proj(new_x)) + + return output + + else: + + if K is None: + K = Q + if V is None: + V = Q + # Linear (+ split in multiple heads) + # q_s : [bs x n_heads x max_q_len x d_k] + q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1, 2) + # k_s : [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3) + k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0, 2, 3, 1) + v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1, + 2) # v_s : [bs x n_heads x q_len x d_v] + + # Apply Scaled Dot-Product Attention (multiple heads) + if prev is not None: + output, prev = self.sdp_attn(q_s, 
k_s, v_s) + else: + output = self.sdp_attn(q_s, k_s, v_s) + # output: [bs x n_heads x q_len x d_v] + + # back to the original inputs dimensions + output = output.transpose(1, 2).contiguous().view(bs, -1, + self.n_heads * self.d_v) # output: [bs x q_len x n_heads * d_v] + output = self.to_out(output) + + if prev is not None: + return output, prev + else: + return output + + +class ScaledDotProductAttention(nn.Module): + def __init__(self, configs, d_model, n_heads, attn_dropout=0.): + super().__init__() + self.attn_dropout = nn.Dropout(attn_dropout) + self.num_heads = n_heads + head_dim = d_model // n_heads + self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=False) + + self.SF_mode = configs.attn_enhance + self.softmax_flag = configs.attn_softmax_flag + self.weight_plus = configs.attn_weight_plus + self.outside_softmax = configs.attn_outside_softmax + + self.plot_mat_flag = configs.plot_mat_flag + self.save_folder = os.path.join('./attn_results', + f'{configs.plot_mat_label}_{configs.seq_len}_{configs.pred_len}') + + self.token_num = configs.enc_in + + if self.SF_mode: + print('Enhanced attention is used...') + print(f'self.weight_plus in FullAttention_ablation: {self.weight_plus}') + print(f'self.softmax_flag in FullAttention_ablation: {self.softmax_flag}') + print(f'self.outside_softmax in FullAttention_ablation: {self.outside_softmax}') + else: + print('Vanilla attention is used...') + + if self.SF_mode and self.token_num is not None: + # [1,1,N,1] + if self.softmax_flag: + self.tau = nn.Parameter(torch.ones(1, 1, self.token_num, 1)) + + init_weight_mat = (torch.eye(self.token_num) * 1.0 + + torch.randn(self.token_num, self.token_num) * 1.0) + self.weight_mat = nn.Parameter(init_weight_mat[None, None, :, :].repeat(1, self.num_heads or 1, 1, 1)) + + def forward(self, q: Tensor, k: Tensor, v: Tensor, prev: Optional[Tensor] = None): + """ + Input shape: + q : [bs x n_heads x max_q_len x d_k] + k : [bs x n_heads x d_k x seq_len] + v : [bs x 
    def forward(self, q: Tensor, k: Tensor, v: Tensor, prev: Optional[Tensor] = None):
        """
        Input shape:
            q : [bs x n_heads x max_q_len x d_k]
            k : [bs x n_heads x d_k x seq_len]
            v : [bs x n_heads x seq_len x d_v]
            prev : [bs x n_heads x q_len x seq_len]
        Output shape:
            output: [bs x n_heads x q_len x d_v]
            attn : [bs x n_heads x q_len x seq_len]
            scores : [bs x n_heads x q_len x seq_len]
        """
        # Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence
        A = torch.matmul(q, k) * self.scale  # Scale

        # Add pre-softmax attention scores from the previous layer (optional)
        if prev is not None:
            A = A + prev

        weight_mat = None
        ori_attn_mat = None
        # Snapshot of the vanilla softmax attention, kept only for plotting.
        if not self.training and self.plot_mat_flag and self.SF_mode:
            ori_attn_mat = A.softmax(dim=-1)

        # attention matrix adjustment; 240507
        if self.SF_mode and self.token_num is not None:
            # token_contribution = self.token_contribution.to(queries.device)

            # 2d — learned token-pair reweighting; two parameterizations:
            if self.softmax_flag:
                # rows renormalized with a learned temperature tau (>0 via softplus)
                weight_mat = F.softmax(self.weight_mat / F.softplus(self.tau), dim=-1)
            else:
                # unnormalized positive weights
                weight_mat = F.softplus(self.weight_mat)

        if self.SF_mode and weight_mat is not None:
            if self.outside_softmax:
                # Combine AFTER the softmax, then L1-renormalize each row.
                if self.weight_plus:
                    A = A.softmax(dim=-1) + weight_mat
                else:
                    A = A.softmax(dim=-1) * weight_mat
                A = F.normalize(A, p=1, dim=-1)
            else:
                # Combine BEFORE the softmax.
                if self.weight_plus:
                    A = A + weight_mat
                else:
                    # ablation: this is a better choice
                    A = A * weight_mat

                # attention matrix [b,h,l,s]
                A = torch.softmax(A, dim=-1)
        else:
            A = torch.softmax(A, dim=-1)

        # plot
        # NOTE(review): stochastic file-writing side effect during eval — fires
        # on ~1% of eval batches whenever plot_mat_flag is set.
        if not self.training and self.plot_mat_flag and random.random() < 0.01:  # and A.shape[-1] > 10
            batch_idx = random.randint(0, A.shape[0] - 1)
            head_idx = random.randint(0, A.shape[1] - 1)

            # final
            att_mat_2d = A[batch_idx, head_idx, :, :]
            time_or_channel = 'channel'
            plot_mat(att_mat_2d, str_cat='final_attn_mat',
                     str0=f"batch_{batch_idx}_head_{head_idx}_{time_or_channel}",
                     save_folder=self.save_folder)

            if ori_attn_mat is not None:
                ori_att_mat_2d = ori_attn_mat[batch_idx, head_idx, :, :]
                time_or_channel = 'channel'
                plot_mat(ori_att_mat_2d, str_cat='ori_attn_mat',
                         str0=f"batch_{batch_idx}_head_{head_idx}_"
                              f"{time_or_channel}",
                         save_folder=self.save_folder)

            if weight_mat is not None:
                # weight_mat's leading dims may be smaller than A's; clamp indices.
                batch_idx2 = min(batch_idx, weight_mat.shape[0] - 1)
                head_idx2 = min(head_idx, weight_mat.shape[1] - 1)
                weight_mat_2d = weight_mat[batch_idx2, head_idx2, :, :]

                plot_mat(weight_mat_2d, str_cat='adding_mat_2d',
                         str0=f"batch_{batch_idx2}_head_{head_idx2}_{time_or_channel}",
                         save_folder=self.save_folder)

            print(f'Attention matrix in FullAttention has been saved to {self.save_folder}...')

        # Normalize the attention weights
        # attn_weights = F.softmax(attn_scores, dim=-1)  # attn_weights: [bs x n_heads x max_q_len x q_len]
        attn_weights = self.attn_dropout(A)

        # Compute the new values given the attention weights
        output = torch.matmul(attn_weights, v)  # output: [bs x n_heads x max_q_len x d_v]

        if prev is not None:
            return output, A
        else:
            return output
def Coord1dPosEncoding(q_len, exponential=False, normalize=True):
    """1-d coordinate positional encoding.

    Positions are mapped linearly (or via sqrt when ``exponential``) from
    [0, 1] onto [-1, 1]; optional normalization recenters to zero mean and
    rescales to std 0.1.

    Returns a tensor of shape (q_len, 1).
    """
    power = .5 if exponential else 1
    coords = torch.linspace(0, 1, q_len).reshape(-1, 1)
    cpe = 2 * coords ** power - 1
    if normalize:
        cpe = (cpe - cpe.mean()) / (cpe.std() * 10)
    return cpe
class RevIN(nn.Module):
    """Reversible Instance Normalization (RevIN) with optional missing-value masks.

    mode='norm' records per-instance statistics and normalizes the input;
    mode='denorm' inverts the transform using the recorded statistics.
    Adapted from https://github.com/ts-kim/RevIN with mask handling added.
    """

    def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param subtract_last: if True, subtract the last time step instead of the mean
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        # Most recent mask seen by _get_statistics/_normalize.
        self.mask = None
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str, mask=None):
        # x: [b, l, n]; mask (optional): True at positions to be ignored/imputed.
        if mode == 'norm':
            self._get_statistics(x, mask)
            x = self._normalize(x, mask)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x, mask=None):
        """Record per-instance mean/stdev (or last value) for the later denorm."""
        self.mask = mask
        dim2reduce = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            if mask is None:
                self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
            else:
                assert isinstance(mask, torch.Tensor)
                # Zero out masked entries so they do not bias the sums,
                # in case other values were filled in beforehand.
                x = x.masked_fill(mask, 0)
                self.mean = (torch.sum(x, dim=1) / torch.sum(~mask, dim=1)).unsqueeze(1).detach()
                # self.mean could be nan or inf (e.g. a fully-masked series).
                self.mean = torch.nan_to_num(self.mean, nan=0.0, posinf=0.0, neginf=0.0)

        if mask is None:
            self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
        else:
            # Biased std over observed entries only.
            # NOTE(review): this branch reads self.mean, which is never set when
            # subtract_last=True — confirm masks are not combined with that mode.
            self.stdev = (torch.sqrt(torch.sum((x - self.mean) ** 2, dim=1) / torch.sum(~mask, dim=1) + self.eps)
                          .unsqueeze(1).detach())
            self.stdev = torch.nan_to_num(self.stdev, nan=0.0, posinf=None, neginf=None)

    def _normalize(self, x, mask=None):
        self.mask = mask
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean

        x = x / self.stdev

        # x should be zero where the values are masked
        if mask is not None:
            # forward fill (alternative, disabled):
            # x, mask2 = forward_fill(x, mask)
            # x = x.masked_fill(mask2, 0)

            # mean imputation: masked positions become the (normalized) mean, i.e. 0
            x = x.masked_fill(mask, 0)

        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            # eps**2 guards against division by a (near-)zero affine weight.
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x
class RevIN(nn.Module):
    """Per-instance normalisation used by the Leddam-style blocks.

    ``forward`` standardises each series with its own mean/std over the time
    axis (dim 1) and caches those statistics; ``inverse_normalize`` replays
    them over ``output_dim`` forecast steps to map predictions back to the
    original scale.
    """

    def __init__(self, channel, output_dim):
        super(RevIN, self).__init__()
        # Statistics are captured lazily on each forward pass.
        self.stdev = None
        self.means = None
        self.output_dim = output_dim

    def forward(self, x):
        """Normalise x of shape [batch, length, channel]; caches mean/std."""
        mu = x.mean(1, keepdim=True).detach()
        sigma = torch.sqrt(x.var(1, keepdim=True, unbiased=False) + 1e-5)
        self.means = mu
        self.stdev = sigma
        return (x - mu) / sigma

    def inverse_normalize(self, x_normalized):
        """Undo forward() on a tensor whose time dimension equals output_dim."""
        sigma = self.stdev[:, 0, :].unsqueeze(1).repeat(1, self.output_dim, 1)
        mu = self.means[:, 0, :].unsqueeze(1).repeat(1, self.output_dim, 1)
        return x_normalized * sigma + mu
# Code implementation from https://github.com/shreyansh26/FlashAttention-PyTorch
class FlashAttention(nn.Module):
    """Block-wise (FlashAttention-style) exact softmax attention.

    The sequence is processed in tiles with the online max/normaliser
    recurrence, so the full [L, S] score matrix is never materialised at once.
    """

    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FlashAttention, self).__init__()
        # NOTE(review): scale and mask_flag are stored but flash_attention_forward
        # derives its own 1/sqrt(d) scale and only honours the explicit `mask` arg.
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
        print('FlashAttention is used...')

    def flash_attention_forward(self, Q, K, V, mask=None):
        """Tiled attention over [B, H, L, D] inputs.

        :param mask: optional [B, S] tensor; positions with value <= 0 are masked out.
        :returns: (O, l, m) — attention output, per-row softmax normaliser, running max.
        """
        BLOCK_SIZE = 128
        NEG_INF = -1e10  # stand-in for -infinity in the running max
        EPSILON = 1e-10  # avoids division by zero in the normaliser

        O = torch.zeros_like(Q, requires_grad=True)
        l = torch.zeros(Q.shape[:-1])[..., None]
        m = torch.ones(Q.shape[:-1])[..., None] * NEG_INF

        # BUGFIX: follow the device of the input instead of hard-coding CUDA,
        # so the module also runs on CPU-only hosts.
        O = O.to(device=Q.device)
        l = l.to(device=Q.device)
        m = m.to(device=Q.device)

        # NOTE(review): the query block size is capped by the head dim (Q.shape[-1])
        # even though the split below is along the sequence dim — kept as upstream.
        Q_BLOCK_SIZE = min(BLOCK_SIZE, Q.shape[-1])
        KV_BLOCK_SIZE = BLOCK_SIZE

        Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)
        K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)
        V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)
        if mask is not None:
            mask_BLOCKS = list(torch.split(mask, KV_BLOCK_SIZE, dim=1))

        Tr = len(Q_BLOCKS)
        Tc = len(K_BLOCKS)

        O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))
        l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
        m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))

        for j in range(Tc):
            Kj = K_BLOCKS[j]
            Vj = V_BLOCKS[j]
            if mask is not None:
                maskj = mask_BLOCKS[j]

            for i in range(Tr):
                Qi = Q_BLOCKS[i]
                Oi = O_BLOCKS[i]
                li = l_BLOCKS[i]
                mi = m_BLOCKS[i]

                scale = 1 / np.sqrt(Q.shape[-1])
                Qi_scaled = Qi * scale

                S_ij = torch.einsum('... i d, ... j d -> ... i j', Qi_scaled, Kj)
                if mask is not None:
                    # Mask out key positions for this tile.
                    maskj_temp = rearrange(maskj, 'b j -> b 1 1 j')
                    S_ij = torch.where(maskj_temp > 0, S_ij, NEG_INF)

                # COMPAT: use the canonical `keepdim` kwarg instead of the
                # numpy-style `keepdims` alias.
                m_block_ij, _ = torch.max(S_ij, dim=-1, keepdim=True)
                P_ij = torch.exp(S_ij - m_block_ij)
                if mask is not None:
                    P_ij = torch.where(maskj_temp > 0, P_ij, 0.)

                l_block_ij = torch.sum(P_ij, dim=-1, keepdim=True) + EPSILON

                P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj)

                # Online softmax update: merge this tile into the running stats.
                mi_new = torch.maximum(m_block_ij, mi)
                li_new = torch.exp(mi - mi_new) * li + torch.exp(m_block_ij - mi_new) * l_block_ij

                O_BLOCKS[i] = (li / li_new) * torch.exp(mi - mi_new) * Oi + (
                        torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj
                l_BLOCKS[i] = li_new
                m_BLOCKS[i] = mi_new

        O = torch.cat(O_BLOCKS, dim=2)
        l = torch.cat(l_BLOCKS, dim=2)
        m = torch.cat(m_BLOCKS, dim=2)
        return O, l, m

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, token_weight=None, *args, **kwargs):
        """Inputs are [B, L, H, E]; permuted to [B, H, L, E] for the tiled kernel."""
        res = \
            self.flash_attention_forward(queries.permute(0, 2, 1, 3), keys.permute(0, 2, 1, 3),
                                         values.permute(0, 2, 1, 3),
                                         attn_mask)[0]
        return res.permute(0, 2, 1, 3).contiguous(), None
class FullAttention(nn.Module):
    """Scaled dot-product attention with optional learnable adjustment of the
    attention matrix ("importance" mode), either through a free [N, N] weight
    matrix (ij_mat_para) or a distance-decay prior (ij_mat_flag), plus optional
    eval-time heat-map dumps via ``plot_mat``.
    """

    def __init__(self, mask_flag=False, factor=5, scale=None, attention_dropout=0.1, output_attention=False,
                 token_num=None, imp_mode=False, ij_mat_flag=False, ij_attn_adjust_init=10.0, ij_mat_para=0,
                 num_heads=None, weight_plus=False, plot_mat_flag=False, save_folder='./'):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
        self.imp_mode = imp_mode
        self.token_num = token_num
        self.ij_mat_flag = ij_mat_flag  # False #
        self.ij_mat_para = ij_mat_para
        self.weight_plus = weight_plus
        self.plot_mat = plot_mat_flag
        # Heat-maps are written below save_folder/attn_mat.
        self.save_folder = os.path.join(save_folder, 'attn_mat')
        if self.plot_mat:
            os.makedirs(self.save_folder, exist_ok=True)
        # NOTE(review): the num_heads argument is effectively ignored; heads
        # are collapsed to 1 here.
        self.num_heads = 1

        print(f'self.weight_plus in FullAttention: {self.weight_plus}')

        print(f'ij_mat_flag in FullAttention:{self.ij_mat_flag}')
        print(f'ij_mat_para in FullAttention:{self.ij_mat_para}')

        if self.imp_mode and self.token_num is not None:
            # [1, 1, 1, N] additive per-token contribution (unused in forward;
            # see the dead `token_contribution` branch there).
            self.token_contribution = nn.Parameter(torch.zeros(1, self.num_heads or 1, 1, self.token_num))
            # [N, 1] temperature for the softmax over the weight matrix.
            self.tau = nn.Parameter(torch.ones(self.token_num, 1))
            if self.ij_mat_para:
                print('self.ij_mat_para in FullAttention is enabled...')

                # Identity-plus-noise initialisation of the learnable [H, N, N] matrix.
                self.weight_mat = nn.Parameter((torch.eye(self.token_num) * 1.0 +
                                                torch.randn(self.token_num, self.token_num) * 1.0)
                                               [None, :, :].repeat(self.num_heads or 1, 1, 1))

                # ablation study: ones(token_num, token_num) sucks

            elif self.ij_mat_flag:
                print('distance-based short-sight-attention in FullAttention is enabled...')

                # Per-row temperature and learnable exponent for the |i-j| decay.
                self.attn_tau = nn.Parameter(torch.ones(token_num, 1) * ij_attn_adjust_init)
                self.exp_para = nn.Parameter(torch.tensor(-10.0))
                range_tensor = torch.arange(token_num)
                # ij_mat[i, j] = |i - j| (pairwise index distance).
                ij_mat = (range_tensor.view(token_num, 1) -
                          range_tensor.view(1, token_num)).abs()

                self.register_buffer('ij_mat', ij_mat)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, token_weight=None):
        # queries/keys: [B, L/S, H, E]; values: [B, S, H, D].
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)
        # this_device = queries.device

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)

            scores.masked_fill_(attn_mask.mask, -np.inf)

        A = scale * scores

        weight_mat = None
        token_contribution = None

        # attention matrix adjustment; 240507
        if self.imp_mode and self.token_num is not None:
            if token_weight is None:
                # 2d: adjustment depends on parameters only.
                if self.ij_mat_para:
                    weight_mat = F.softmax(self.weight_mat / F.softplus(self.tau), dim=-1)[None, :, :, :]
                elif self.ij_mat_flag:
                    # NOTE(review): the conditional is redundant — ij_mat_flag is
                    # already known True on this branch.
                    ij_mat = self.ij_mat.pow(F.softplus(self.exp_para)) if self.ij_mat_flag else self.ij_mat
                    weight_mat = (-ij_mat.unsqueeze(0).unsqueeze(0)
                                  / F.softplus(self.attn_tau).unsqueeze(0).unsqueeze(0))
                    weight_mat = F.softmax(weight_mat / F.softplus(self.tau), dim=-1)
            else:
                # 4d: adjustment additionally modulated by a per-token weight [b, l].
                if token_weight.shape[-1] != self.token_num:
                    print(f'token_weight ({token_weight.shape[-1]}) does not match token_num ({self.token_num})!')
                    raise ValueError
                # token_weight: [b,l] --> [b,1,1,l]; clamp in case there is 0 in token_weight
                token_weight = torch.maximum(token_weight.unsqueeze(1).unsqueeze(1), torch.tensor(1e-5))
                if self.ij_mat_para:
                    weight_mat = self.weight_mat.unsqueeze(0)  # [1,h,l,l]
                    # token_weight is considered
                    weight_mat = F.softmax(weight_mat * token_weight /
                                           F.softplus(self.tau.unsqueeze(0).unsqueeze(0)), dim=-1)

                else:
                    if self.ij_mat_flag:
                        ij_mat = self.ij_mat.pow(F.softplus(self.exp_para))
                        weight_mat = (-ij_mat.unsqueeze(0).unsqueeze(0)
                                      / F.softplus(self.attn_tau).unsqueeze(0).unsqueeze(0) / token_weight)
                        weight_mat = F.softmax(weight_mat, dim=-1)

                    else:
                        weight_mat = F.softmax(token_weight / F.softplus(self.tau), dim=-1)

        # NOTE(review): token_contribution is never reassigned above, so this
        # branch is currently dead code.
        if token_contribution is not None:
            A = A + token_contribution

        if self.weight_plus and weight_mat is not None:
            # Additive adjustment happens before the softmax.
            A = A + weight_mat

        # attention matrix [b,h,l,s]
        A = torch.softmax(A, dim=-1)

        if not self.weight_plus and weight_mat is not None:
            # Multiplicative adjustment after the softmax, then L1 re-normalisation.
            A = A * weight_mat
            A = F.normalize(A, p=1, dim=-1)

        # Occasional eval-time visualisation of the attention matrices.
        if not self.training and self.plot_mat and random.random() < 0.08 and A.shape[-1] > 10:
            batch_idx = random.randint(0, A.shape[0] - 1)
            head_idx = random.randint(0, A.shape[1] - 1)
            att_mat_2d = A[batch_idx, head_idx, :, :]
            time_or_channel = 'temporal' if self.ij_mat_flag else 'channel'
            plot_mat(att_mat_2d, str_cat='attn_mat', str0=f"batch_{batch_idx}_head_{head_idx}_{time_or_channel}",
                     save_folder=self.save_folder)

            if weight_mat is not None:
                batch_idx2 = min(batch_idx, weight_mat.shape[0] - 1)
                head_idx2 = min(head_idx, weight_mat.shape[1] - 1)
                weight_mat_2d = weight_mat[batch_idx2, head_idx2, :, :]

                # Distance-amplified view of the learned weights for inspection.
                range_tensor = torch.arange(self.token_num)
                ij_mat = (range_tensor.view(self.token_num, 1) -
                          range_tensor.view(1, self.token_num)).abs()
                manual_weight = (ij_mat + 1).pow(1)
                weight_mat_2d = weight_mat_2d * manual_weight.to(weight_mat_2d.device)

                weight_mat_2d = F.normalize(weight_mat_2d, p=1, dim=-1)

                plot_mat(weight_mat_2d, str_cat='weight_mat_2d',
                         str0=f"batch_{batch_idx2}_head_{head_idx2}_{time_or_channel}",
                         save_folder=self.save_folder)

                # only the ij (distance prior) component
                if self.ij_mat_flag:
                    ij_mat = self.ij_mat.pow(F.softplus(self.exp_para))
                    weight_mat = (-ij_mat.unsqueeze(0).unsqueeze(0)
                                  / F.softplus(self.attn_tau).unsqueeze(0).unsqueeze(0))
                    weight_mat = F.softmax(weight_mat, dim=-1)
                    weight_mat_2d_ij = weight_mat[0, 0, :, :]

                    range_tensor = torch.arange(self.token_num)
                    ij_mat = (range_tensor.view(self.token_num, 1) -
                              range_tensor.view(1, self.token_num)).abs()
                    manual_weight = (ij_mat + 1).pow(0.5)
                    weight_mat_2d_ij = weight_mat_2d_ij * manual_weight.to(weight_mat_2d_ij.device)

                    weight_mat_2d_ij = F.normalize(weight_mat_2d_ij, p=1, dim=-1)
                    plot_mat(weight_mat_2d_ij, str_cat='weight_mat_2d_ij', str0=f"batch_{0}_head_{0}_{time_or_channel}",
                             save_folder=self.save_folder)

            # Distance-corrected view of the final attention matrix.
            if self.ij_mat_flag:
                if hasattr(self, 'ij_mat'):
                    ij_mat = self.ij_mat
                else:
                    range_tensor = torch.arange(self.token_num)
                    ij_mat = (range_tensor.view(self.token_num, 1) -
                              range_tensor.view(1, self.token_num)).abs()
                manual_weight = (ij_mat + 1).pow(-0.5)
                att_mat_2d = att_mat_2d * manual_weight.to(att_mat_2d.device)

                att_mat_2d = F.normalize(att_mat_2d, p=1, dim=-1)
                plot_mat(att_mat_2d, str_cat='attn_mat', str0=f"manual_batch_{batch_idx}_head_{head_idx}_"
                                                              f"{time_or_channel}",
                         save_folder=self.save_folder)

            print('Attention matrix in FullAttention has been saved...')

        # dropout, reserved
        A = self.dropout(A)

        V = torch.einsum("bhls,bshd->blhd", A, values)

        if self.output_attention:
            return V.contiguous(), A
        else:
            return V.contiguous(), None
class FullAttention_SF(nn.Module):
    """Scaled dot-product attention with an optional "SF" enhancement: the
    softmaxed scores are combined (additively or multiplicatively) with a
    learnable [N, N] matrix or a distance-decay prior before the final softmax.
    A per-token contribution vector (contri_flag) can further modulate scores.
    """

    def __init__(self, mask_flag=False, factor=5, scale=None, attention_dropout=0.1, output_attention=False,
                 token_num=None, SF_mode=False, contri_flag=False, ij_mat_flag=False, ij_attn_adjust_init=10.0,
                 ij_mat_para=0, weight_plus=False, plot_mat_flag=False, save_folder='./'):
        super(FullAttention_SF, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
        self.SF_mode = SF_mode
        self.contri_flag = contri_flag
        self.token_num = token_num
        self.ij_mat_flag = ij_mat_flag  # False #
        self.ij_mat_para = ij_mat_para
        self.weight_plus = weight_plus
        self.plot_mat_flag = plot_mat_flag
        self.save_folder = os.path.join(save_folder)
        if self.plot_mat_flag and not os.path.exists(self.save_folder):
            os.makedirs(self.save_folder, exist_ok=True)
        # Heads are collapsed to 1 for the learnable adjustment matrices.
        self.num_heads = 1

        print(f'self.weight_plus in FullAttention: {self.weight_plus}')

        print(f'ij_mat_flag in FullAttention:{self.ij_mat_flag}')
        print(f'ij_mat_para in FullAttention:{self.ij_mat_para}')

        if self.contri_flag:
            # [1,1,1,N] learnable per-token contribution.
            self.token_contri = nn.Parameter(torch.ones(1, self.num_heads or 1, 1, self.token_num))
        else:
            self.token_contri = None

        if self.SF_mode and self.weight_plus:
            # Scalar mixing weight for the additive combination.
            self.sum_weight = nn.Parameter(torch.tensor(1.0))

        if self.SF_mode and self.token_num is not None:
            # [1,1,N,1] per-row softmax temperature.
            self.tau = nn.Parameter(torch.ones(1, 1, self.token_num, 1))
            if self.ij_mat_para:
                print('self.ij_mat_para in FullAttention is enabled...')

                # Identity-plus-noise initialisation of the learnable [1,H,N,N] matrix.
                self.weight_mat = nn.Parameter((torch.eye(self.token_num) * 1.0 +
                                                torch.randn(self.token_num, self.token_num) * 1.0)
                                               [None, None, :, :].repeat(1, self.num_heads or 1, 1, 1))

                # ablation study: ones(token_num, token_num) sucks

            elif self.ij_mat_flag:
                print('distance-based short-sight-attention in FullAttention is enabled...')

                self.attn_tau = nn.Parameter(torch.ones(token_num, 1) * ij_attn_adjust_init)
                self.exp_para = nn.Parameter(torch.tensor(-10.0))
                range_tensor = torch.arange(token_num)
                # ij_mat[i, j] = |i - j| (pairwise index distance).
                ij_mat = (range_tensor.view(token_num, 1) -
                          range_tensor.view(1, token_num)).abs()

                self.register_buffer('ij_mat', ij_mat)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, token_weight=None):
        # queries/keys: [B, L/S, H, E]; values: [B, S, H, D].
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)

            scores.masked_fill_(attn_mask.mask, -np.inf)

        A = scale * scores

        weight_mat = None

        # attention matrix adjustment; 240507
        if self.SF_mode and self.token_num is not None:
            # Build the adjustment matrix [1, 1/h, N, N].
            if self.ij_mat_para:
                if self.contri_flag and self.token_contri is not None:
                    weight_mat = F.softmax(self.weight_mat / F.softplus(self.tau) / F.softplus(self.token_contri),
                                           dim=-1)
                else:
                    weight_mat = F.softmax(self.weight_mat / F.softplus(self.tau), dim=-1)
            elif self.ij_mat_flag:
                # [N, N] distance prior raised to a learnable exponent.
                ij_mat = self.ij_mat.pow(F.softplus(self.exp_para))
                # [1,1,N,N]
                weight_mat = (-ij_mat.unsqueeze(0).unsqueeze(0)
                              / F.softplus(self.attn_tau).unsqueeze(0).unsqueeze(0))
                weight_mat = F.softmax(weight_mat / F.softplus(self.tau), dim=-1)

        if self.SF_mode and weight_mat is not None:
            if self.weight_plus:
                A = A + self.sum_weight * weight_mat
            else:
                # ablation: the multiplicative combination is the better choice
                A = A * weight_mat

        elif self.contri_flag:
            # not SF mode but use token contribution
            token_contri = F.softmax(self.token_contri, dim=-1)
            A = A * token_contri

        # attention matrix [b,h,l,s]
        A = torch.softmax(A, dim=-1)

        # Occasional eval-time visualisation of the attention matrices.
        if not self.training and self.plot_mat_flag and random.random() < 0.01:  # and A.shape[-1] > 10
            batch_idx = random.randint(0, A.shape[0] - 1)
            head_idx = random.randint(0, A.shape[1] - 1)
            att_mat_2d = A[batch_idx, head_idx, :, :]
            time_or_channel = 'channel'
            plot_mat(att_mat_2d, str_cat='self_attn_SF', str0=f"batch_{batch_idx}_head_{head_idx}_{time_or_channel}",
                     save_folder=self.save_folder)

            if weight_mat is not None:
                batch_idx2 = min(batch_idx, weight_mat.shape[0] - 1)
                head_idx2 = min(head_idx, weight_mat.shape[1] - 1)
                weight_mat_2d = weight_mat[batch_idx2, head_idx2, :, :]

                plot_mat(weight_mat_2d, str_cat='weight_mat_2d',
                         str0=f"batch_{batch_idx2}_head_{head_idx2}_{time_or_channel}",
                         save_folder=self.save_folder)

            print(f'Attention matrix in FullAttention has been saved to {self.save_folder}...')

        # dropout, reserved
        A = self.dropout(A)

        V = torch.einsum("bhls,bshd->blhd", A, values)

        if self.output_attention:
            return V.contiguous(), A
        else:
            return V.contiguous(), None
class FullAttention_ablation(nn.Module):
    """Ablation harness for the enhanced attention: switches select whether the
    learnable [N, N] matrix is softmaxed (softmax_flag), combined additively or
    multiplicatively (weight_plus), and inside or outside the score softmax
    (outside_softmax). SF_mode=0 degenerates to vanilla attention.
    """

    def __init__(self, mask_flag=False, factor=5, scale=None, attention_dropout=0.1, output_attention=False,
                 token_num=None, SF_mode=1, softmax_flag=1, weight_plus=0, outside_softmax=0,
                 plot_mat_flag=False, save_folder='./', plot_grad_flag=False, **kwargs):
        super(FullAttention_ablation, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
        self.enhance_mode = SF_mode
        self.softmax_flag = softmax_flag
        self.token_num = token_num
        self.outside_softmax = outside_softmax  # False #
        self.weight_plus = weight_plus
        self.plot_mat_flag = plot_mat_flag
        self.plot_grad_flag = plot_grad_flag
        self.save_folder = os.path.join(save_folder)
        if self.plot_mat_flag and not os.path.exists(self.save_folder):
            os.makedirs(self.save_folder, exist_ok=True)
        self.num_heads = 1

        print(f'self.weight_plus in FullAttention_ablation: {self.weight_plus}')
        print(f'self.softmax_flag in FullAttention_ablation: {self.softmax_flag}')
        print(f'self.outside_softmax in FullAttention_ablation: {self.outside_softmax}')

        if not self.enhance_mode:
            print('Vanilla attention is used...')
        else:
            print('Enhanced attention is used...')

        if self.enhance_mode and self.token_num is not None:
            # [1,1,N,1] temperature, only needed when the matrix is softmaxed.
            if self.softmax_flag:
                self.tau = nn.Parameter(torch.ones(1, 1, self.token_num, 1))

            # Identity-plus-noise initialisation of the learnable matrix.
            init_weight_mat = (torch.eye(self.token_num) * 1.0 +
                               torch.randn(self.token_num, self.token_num) * 1.0)
            self.weight_mat = nn.Parameter(init_weight_mat[None, None, :, :].repeat(1, self.num_heads or 1, 1, 1))

            # NOTE(review): tau2 appears unused in forward — kept for experiments.
            self.tau2 = nn.Parameter(torch.ones(1, 1, self.token_num, 1))

    def forward(self, queries, keys, values, attn_mask, **kwargs):
        # queries/keys: [B, L/S, H, E]; values: [B, S, H, D].
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)

            scores.masked_fill_(attn_mask.mask, -np.inf)

        A = scale * scores

        weight_mat = None
        ori_attn_mat = None
        if self.enhance_mode:
            # Keep the unmodified attention matrix only for eval-time plotting.
            if not self.training and self.plot_mat_flag:
                ori_attn_mat = torch.softmax(A, dim=-1)

        # attention matrix adjustment; 240507
        if self.enhance_mode and self.token_num is not None:
            if self.softmax_flag:
                weight_mat = F.softmax(self.weight_mat / F.softplus(self.tau), dim=-1)
            else:
                weight_mat = F.softplus(self.weight_mat)

        if self.enhance_mode and weight_mat is not None:
            if self.outside_softmax:
                # Combine AFTER softmaxing the raw scores.
                if self.weight_plus:
                    A = A.softmax(dim=-1) + weight_mat
                else:
                    A = A.softmax(dim=-1) * weight_mat
                    A = F.normalize(A, p=1, dim=-1)
            else:
                # Combine BEFORE the (single) softmax.
                if self.weight_plus:
                    A = A + weight_mat
                else:
                    # ablation: this is the better choice
                    A = A * weight_mat

                # attention matrix [b,h,l,s]
                A = torch.softmax(A, dim=-1)

        else:
            A = torch.softmax(A, dim=-1)

        # Eval-time visualisation of the final/original/adjustment matrices.
        if not self.training and self.plot_mat_flag and random.random() < 1:  # and A.shape[-1] > 10
            batch_idx = random.randint(0, A.shape[0] - 1)
            head_idx = random.randint(0, A.shape[1] - 1)

            att_mat_2d = A[batch_idx, head_idx, :, :]
            time_or_channel = 'channel'
            plot_mat(att_mat_2d, str_cat='final_attn_mat', str0=f"batch_{batch_idx}_head_{head_idx}_{time_or_channel}",
                     save_folder=self.save_folder)

            if ori_attn_mat is not None and self.enhance_mode:
                ori_att_mat_2d = ori_attn_mat[batch_idx, head_idx, :, :]
                time_or_channel = 'channel'
                plot_mat(ori_att_mat_2d, str_cat='ori_attn_mat', str0=f"batch_{batch_idx}_head_{head_idx}_"
                                                                      f"{time_or_channel}",
                         save_folder=self.save_folder)

            if weight_mat is not None and self.enhance_mode:
                batch_idx2 = min(batch_idx, weight_mat.shape[0] - 1)
                head_idx2 = min(head_idx, weight_mat.shape[1] - 1)
                weight_mat_2d = weight_mat[batch_idx2, head_idx2, :, :]

                plot_mat(weight_mat_2d, str_cat='adding_mat_2d',
                         str0=f"batch_{batch_idx2}_head_{head_idx2}_{time_or_channel}",
                         save_folder=self.save_folder)

            print(f'Attention matrix in FullAttention has been saved to {self.save_folder}...')

        # dropout, reserved
        A = self.dropout(A)

        V = torch.einsum("bhls,bshd->blhd", A, values)

        # Occasional visualisation of the gradient on the learnable matrix.
        if self.plot_grad_flag and random.random() < 0.01 and self.weight_mat.grad is not None:
            batch_idx = random.randint(0, self.weight_mat.shape[0] - 1)
            head_idx = random.randint(0, self.weight_mat.shape[1] - 1)

            weight_mat_grad = self.weight_mat.grad[batch_idx, head_idx, :, :]
            time_or_channel = 'channel'
            plot_mat(weight_mat_grad, str_cat='weight_grad_mat',
                     str0=f"batch_{batch_idx}_head_{head_idx}_{time_or_channel}",
                     save_folder=self.save_folder)

            print(f'Attention gradient matrix in FullAttention has been saved to {self.save_folder}...')

        if self.output_attention:
            return V.contiguous(), A
        else:
            return V.contiguous(), None
f"{time_or_channel}", + save_folder=self.save_folder) + + if weight_mat is not None and self.enhance_mode: + batch_idx2 = min(batch_idx, weight_mat.shape[0] - 1) + head_idx2 = min(head_idx, weight_mat.shape[1] - 1) + weight_mat_2d = weight_mat[batch_idx2, head_idx2, :, :] + + plot_mat(weight_mat_2d, str_cat='adding_mat_2d', + str0=f"batch_{batch_idx2}_head_{head_idx2}_{time_or_channel}", + save_folder=self.save_folder) + + print(f'Attention matrix in FullAttention has been saved to {self.save_folder}...') + + # dropout, reserved + A = self.dropout(A) + + # print(f'A.shape: {A.shape}') + # print(f'values.shape: {values.shape}') + V = torch.einsum("bhls,bshd->blhd", A, values) + + # plot gradient + if self.plot_grad_flag and random.random() < 0.01 and self.weight_mat.grad is not None: + batch_idx = random.randint(0, self.weight_mat.shape[0] - 1) + head_idx = random.randint(0, self.weight_mat.shape[1] - 1) + + # final + weight_mat_grad = self.weight_mat.grad[batch_idx, head_idx, :, :] + time_or_channel = 'channel' + plot_mat(weight_mat_grad, str_cat='weight_grad_mat', + str0=f"batch_{batch_idx}_head_{head_idx}_{time_or_channel}", + save_folder=self.save_folder) + + print(f'Attention gradient matrix in FullAttention has been saved to {self.save_folder}...') + + if self.output_attention: + return V.contiguous(), A + else: + return V.contiguous(), None + + +class FullAttention_L1(nn.Module): + def __init__(self, mask_flag=False, factor=5, scale=None, attention_dropout=0.1, output_attention=False, + token_num=None, dim=None, plot_mat_flag=False, save_folder='./'): + super(FullAttention_L1, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.token_num = token_num + self.dropout = nn.Dropout(attention_dropout) + self.plot_mat_flag = plot_mat_flag + self.save_folder = save_folder + if self.plot_mat_flag and not os.path.exists(self.save_folder): + os.makedirs(self.save_folder, exist_ok=True) + + # self.tau1 = 
nn.Parameter(torch.ones(1, self.token_num, 1, 1)) + # self.tau2 = nn.Parameter(torch.ones(1, self.token_num, 1, 1)) + # self.power = nn.Parameter(torch.tensor(2.0)) + + init_weight_mat = torch.eye(self.token_num) * 1.0 + torch.randn(self.token_num, self.token_num) * 1.0 + self.weight_mat = nn.Parameter(init_weight_mat[None, None, :, :]) + self.weight_mat2 = nn.Parameter(torch.randn(1, 1, self.token_num, self.token_num)) + # self.dim = dim + # if self.dim is not None: + # self.norm1 = nn.LayerNorm(self.dim) + # self.norm2 = nn.LayerNorm(self.dim) + + self.lin_v = nn.Linear(self.token_num, self.token_num) + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, token_weight=None): + B, L, H, E = queries.shape + _, S, _, D = values.shape + + scale = self.scale or 1. / sqrt(E) + + # queries, keys = ((queries / F.softplus(self.tau1)).softmax(dim=-1), + # (keys / F.softplus(self.tau2)).softmax(dim=-1)) + + # queries, keys = self.norm1(queries), self.norm2(keys) + + scores = torch.einsum("blhe,bshe->bhls", queries, keys) + + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + + scores.masked_fill_(attn_mask.mask, -np.inf) + + A = scale * scores + # A = scores + + # softmax --> L1 + # A = F.normalize(F.softplus(A) ** 3, p=1, dim=-1) + + # /tau (seems helpless) + # A = (A / F.softplus(self.tau)).softmax(dim=-1) + + # softmax & softmax + # A = (A.softmax(dim=-1)).softmax(dim=-1) # .softmax(dim=-1) + # A = A.exp().softmax(dim=-1) # could produce nan, probably due to overlarge values + + # softmax --> x^3: failed + + # enhanced attention + A = A.softmax(dim=-1) + F.softplus(self.weight_mat) + # A = entmax15(A, dim=-1) + F.softplus(self.weight_mat) + # A = sparsemax(A, dim=-1) + F.softplus(self.weight_mat) + A = F.normalize(A, p=1, dim=-1) + + # x --> (x+3)^3 --> softmax: failed + # A = ((A + self.bias) ** 3).softmax(dim=-1) + + # plot mat + if not self.training and self.plot_mat_flag and 
class EnhancedAttention(nn.Module):
    """Attention whose weights are the L1-normalised sum of a softmaxed
    similarity matrix and a learnable, softplus'd [N, N] matrix.

    If ``CovMat`` is supplied at construction it replaces the query/key
    similarity entirely, making the attention input-independent.
    """

    def __init__(self, mask_flag=False, factor=5, scale=None, attention_dropout=0.1, output_attention=False,
                 token_num=None, CovMat=None, plot_mat_flag=False, save_folder='./',
                 **kwargs):
        super(EnhancedAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.token_num = token_num
        # Broadcast the fixed covariance matrix to [1, 1, N, N] when provided.
        self.CovMat = CovMat.unsqueeze(0).unsqueeze(0) if CovMat is not None else None
        self.dropout = nn.Dropout(attention_dropout)
        self.plot_mat_flag = plot_mat_flag
        self.save_folder = save_folder
        if self.plot_mat_flag and not os.path.exists(self.save_folder):
            os.makedirs(self.save_folder, exist_ok=True)

        # Identity-plus-noise initialisation of the learnable additive term.
        init_weight_mat = torch.eye(self.token_num) * 1.0 + torch.randn(self.token_num, self.token_num) * 1.0
        self.weight_mat = nn.Parameter(init_weight_mat[None, None, :, :])

        print('Enhanced Attention is used...')

    def forward(self, queries, keys, values, attn_mask=None, **kwargs):
        bsz, q_len, _, head_dim = queries.shape

        extra = F.softplus(self.weight_mat)
        if self.CovMat is not None:
            # Input-independent path: softmaxed covariance plus learnable term.
            attn = F.softmax(self.CovMat, dim=-1) + extra
        else:
            scaling = self.scale or 1. / sqrt(head_dim)
            sim = torch.einsum("blhe,bshe->bhls", queries, keys)
            if self.mask_flag:
                if attn_mask is None:
                    attn_mask = TriangularCausalMask(bsz, q_len, device=queries.device)
                sim.masked_fill_(attn_mask.mask, -np.inf)
            attn = (scaling * sim).softmax(dim=-1) + extra

        # Rows sum to one again after adding the learnable term.
        attn = F.normalize(attn, p=1, dim=-1)

        # Occasional eval-time heat-map dump of the final attention matrix.
        if not self.training and self.plot_mat_flag and random.random() < 0.01:
            b_idx = random.randint(0, attn.shape[0] - 1)
            h_idx = random.randint(0, attn.shape[1] - 1)
            time_or_channel = 'channel'
            plot_mat(attn[b_idx, h_idx, :, :], str_cat='final_attn_mat',
                     str0=f"batch_{b_idx}_head_{h_idx}_{time_or_channel}",
                     save_folder=self.save_folder)
            print(f'Attention matrix in Enhanced Attention has been saved to {self.save_folder}...')

        attn = self.dropout(attn)

        out = torch.einsum("bhls,bshd->blhd", attn, values)

        if self.output_attention:
            return out.contiguous(), attn
        else:
            return out.contiguous(), None
class VanillaAttention(nn.Module):
    """Plain scaled dot-product attention with optional causal masking and an
    occasional eval-time heat-map dump of the attention matrix."""

    def __init__(self, mask_flag=False, factor=5, scale=None, attention_dropout=0.1, output_attention=False,
                 plot_mat_flag=False, save_folder='./', **kwargs):
        super(VanillaAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
        self.plot_mat_flag = plot_mat_flag
        self.save_folder = save_folder
        if self.plot_mat_flag and not os.path.exists(self.save_folder):
            os.makedirs(self.save_folder, exist_ok=True)

        print('Vanilla attention is used...')

    def forward(self, queries, keys, values, attn_mask=None, **kwargs):
        bsz, q_len, _, head_dim = queries.shape
        scaling = self.scale or 1. / sqrt(head_dim)

        sim = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(bsz, q_len, device=queries.device)
            sim.masked_fill_(attn_mask.mask, -np.inf)

        attn = (scaling * sim).softmax(dim=-1)

        # Occasional eval-time heat-map dump of the final attention matrix.
        if not self.training and self.plot_mat_flag and random.random() < 0.01:
            b_idx = random.randint(0, attn.shape[0] - 1)
            h_idx = random.randint(0, attn.shape[1] - 1)
            time_or_channel = 'channel'
            plot_mat(attn[b_idx, h_idx, :, :], str_cat='final_attn_mat',
                     str0=f"batch_{b_idx}_head_{h_idx}_{time_or_channel}",
                     save_folder=self.save_folder)
            print(f'Attention matrix in vanilla Attention has been saved to {self.save_folder}...')

        attn = self.dropout(attn)

        out = torch.einsum("bhls,bshd->blhd", attn, values)

        if self.output_attention:
            return out.contiguous(), attn
        else:
            return out.contiguous(), None
+ # self.Q_learn = nn.Parameter(torch.randn(1, self.token_num, 1, self.head_dim)) + # self.K_learn = nn.Parameter(torch.randn(1, self.token_num, 1, self.head_dim)) + + # self.tau = nn.Parameter(torch.tensor(3.0)) + + def forward(self, queries, keys, values, attn_mask=None, token_weight=None, **kwargs): + B, L, H, E = queries.shape + _, S, _, D = values.shape + + scale = self.scale or 1. / sqrt(E) + + scores = torch.einsum("blhe,bshe->bhls", queries, keys * F.sigmoid(self.q_k_layer(queries))) + + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + + scores.masked_fill_(attn_mask.mask, -np.inf) + + A = scale * scores + + # enhanced attention 2 + # alpha = F.sigmoid(self.tau) + # A = alpha * A.softmax(dim=-1) + (1 - alpha) * F.normalize(self.weight_mat, p=1, dim=-1) + + # QK^T + # delta_mat = torch.einsum("blhe,bshe->bhls", self.Q_learn, self.K_learn) + # A = A.softmax(dim=-1) + F.softplus(delta_mat) + # A = F.normalize(A, p=1, dim=-1) + + # original enhanced attention + A = A.softmax(dim=-1) + F.softplus(self.weight_mat) + A = F.normalize(A, p=1, dim=-1) + + # plot mat + if not self.training and self.plot_mat_flag and random.random() < 0.01: # and A.shape[-1] > 10 + batch_idx = random.randint(0, A.shape[0] - 1) + head_idx = random.randint(0, A.shape[1] - 1) + + # final + att_mat_2d = A[batch_idx, head_idx, :, :] + time_or_channel = 'channel' + plot_mat(att_mat_2d, str_cat='final_attn_mat', str0=f"batch_{batch_idx}_head_{head_idx}_{time_or_channel}", + save_folder=self.save_folder) + + print(f'Attention matrix in Enhanced Attention has been saved to {self.save_folder}...') + + # dropout, reserved + A = self.dropout(A) + + V = torch.einsum("bhls,bshd->blhd", A, values) + + if self.output_attention: + return V.contiguous(), A + else: + return V.contiguous(), None + + +class BlockWiseEnhancedAttention(nn.Module): + def __init__(self, mask_flag=False, factor=5, scale=None, attention_dropout=0.1, 
output_attention=False, + token_num=None, plot_mat_flag=False, save_folder='./', block_size=150, shuffle=True, **kwargs): + super(BlockWiseEnhancedAttention, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.token_num = token_num + self.dropout = nn.Dropout(attention_dropout) + self.plot_mat_flag = plot_mat_flag + self.save_folder = save_folder + self.block_size = int(block_size) + self.shuffle = shuffle + if self.plot_mat_flag and not os.path.exists(self.save_folder): + os.makedirs(self.save_folder, exist_ok=True) + + self.block_num = int(np.ceil(self.token_num / self.block_size)) + + if not self.shuffle: + init_weight_mat = torch.randn(1, 1, self.block_num, self.block_size, self.block_size) + self.weight_mat = nn.Parameter(init_weight_mat) + + def forward(self, queries, keys, values, attn_mask, **kwargs): + B, L, H, E = queries.shape + _, S, _, D = values.shape + + # shuffle + if self.shuffle: + if self.training: + g = torch.Generator() + seed = int(time.time_ns()) % (2 ** 32) + g.manual_seed(seed) + perm = torch.randperm(L, generator=g) + else: + perm = torch.randperm(L) + queries = queries[:, perm, :, :] + keys = keys[:, perm, :, :] + values = values[:, perm, :, :] + + if self.token_num != self.block_num * self.block_size: + delta = self.block_num * self.block_size - self.token_num + queries = F.pad(queries, pad=(0, 0, 0, 0, 0, delta), mode='constant', value=0) + keys = F.pad(keys, pad=(0, 0, 0, 0, 0, delta), mode='constant', value=0) + values = F.pad(values, pad=(0, 0, 0, 0, 0, delta), mode='constant', value=0) + + queries = rearrange(queries, 'b (n s) h e -> b n s h e', n=self.block_num) + keys = rearrange(keys, 'b (n s) h e -> b n s h e', n=self.block_num) + + scale = self.scale or 1. 
/ sqrt(E) + + scores = torch.einsum("b n s h e, b n l h e -> b h n s l", queries, keys) + + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + + scores.masked_fill_(attn_mask.mask, -np.inf) + + A = scale * scores + + if self.shuffle: + # vanilla attention + A = A.softmax(dim=-1) + else: + # enhanced attention + A = A.softmax(dim=-1) + F.softplus(self.weight_mat) + A = F.normalize(A, p=1, dim=-1) + + # plot mat + if not self.training and self.plot_mat_flag and random.random() < 0.01: # and A.shape[-1] > 10 + batch_idx = random.randint(0, A.shape[0] - 1) + head_idx = random.randint(0, A.shape[1] - 1) + + # final + att_mat_2d = A[batch_idx, head_idx, :, :] + time_or_channel = 'channel' + plot_mat(att_mat_2d, str_cat='final_attn_mat', str0=f"batch_{batch_idx}_head_{head_idx}_{time_or_channel}", + save_folder=self.save_folder) + + print(f'Attention matrix in Enhanced Attention has been saved to {self.save_folder}...') + + # dropout, reserved + A = self.dropout(A) + + values = rearrange(values, 'b (n s) h e -> b n s h e', n=self.block_num) + # print(f'values.shape: {values.shape}') + # print(f'A.shape: {A.shape}') + + V = torch.einsum("b h n s l, b n l h e -> b n s h e", A, values) + + V = rearrange(V, 'b n s h e -> b (n s) h e')[:, :self.token_num, :, :] + + if self.shuffle: + # back to original order + inv_perm = torch.argsort(perm) + V = V[:, inv_perm, :, :] + + if self.output_attention: + return V.contiguous(), A + else: + return V.contiguous(), None + + +class RowWiseEnhancedAttention(nn.Module): + # aim for no performance loss + def __init__(self, mask_flag=False, factor=5, scale=None, attention_dropout=0.1, output_attention=False, + token_num=None, enhanced=True, plot_mat_flag=False, save_folder='./', block_size=150, **kwargs): + super(RowWiseEnhancedAttention, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.token_num = token_num + 
self.enhanced = enhanced + self.dropout = nn.Dropout(attention_dropout) + self.plot_mat_flag = plot_mat_flag + self.save_folder = save_folder + self.block_size = int(block_size) + if self.plot_mat_flag and not os.path.exists(self.save_folder): + os.makedirs(self.save_folder, exist_ok=True) + + self.block_num = int(np.ceil(self.token_num / self.block_size)) + + if self.enhanced: + print('Enhanced attention in RowWiseEnhancedAttention is used...') + init_weight_mat = torch.randn(1, 1, self.token_num, self.token_num) + self.weight_mat = nn.Parameter(init_weight_mat) + else: + print('Vanilla attention in RowWiseEnhancedAttention is used...') + + def forward(self, queries, keys, values, attn_mask, **kwargs): + B, L, H, E = queries.shape + _, S, _, D = values.shape + + V_list = [] + scale = self.scale or 1. / sqrt(E) + + for i in range(self.block_num): + start_idx = i * self.block_size + end_idx = min(L, (i + 1) * self.block_size) + query = queries[:, start_idx:end_idx, :, :] + + A = torch.einsum("blhe,bshe->bhls", query, keys) * scale + + if self.enhanced: + weight_mat = self.weight_mat[:, :, start_idx:end_idx, :] + A = A.softmax(dim=-1) + F.softplus(weight_mat) + A = F.normalize(A, p=1, dim=-1) + else: + A = A.softmax(dim=-1) + + A = self.dropout(A) + + V = torch.einsum("bhls,bshd->blhd", A, values) + + V_list.append(V) + + V = torch.concat(V_list, dim=1) + + return V.contiguous(), None + + +class NystromAttention(nn.Module): + def __init__(self, mask_flag=False, factor=5, scale=None, attention_dropout=0.1, output_attention=False, + token_num=None, plot_mat_flag=False, save_folder='./', num_landmarks=64, pinv_iterations=6, **kwargs): + super(NystromAttention, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.token_num = token_num + self.dropout = nn.Dropout(attention_dropout) + self.plot_mat_flag = plot_mat_flag + self.save_folder = save_folder + self.num_landmarks = int(num_landmarks) + 
self.pinv_iterations = int(pinv_iterations) + if self.plot_mat_flag and not os.path.exists(self.save_folder): + os.makedirs(self.save_folder, exist_ok=True) + + self.merge_size = int(np.ceil(self.token_num / self.num_landmarks)) + + def forward(self, queries, keys, values, attn_mask, **kwargs): + B, L, H, E = queries.shape + _, S, _, D = values.shape + + if self.token_num != self.num_landmarks * self.merge_size: + delta = self.num_landmarks * self.merge_size - self.token_num + queries2 = F.pad(queries, pad=(0, 0, 0, 0, 0, delta), mode='constant', value=0) + keys2 = F.pad(keys, pad=(0, 0, 0, 0, 0, delta), mode='constant', value=0) + else: + queries2, keys2 = queries, keys + + # reduce the number of tokens + landmark_einops_eq = '... (n l) h e -> ... n h e' + q_landmarks = reduce(queries2, landmark_einops_eq, 'sum', l=self.merge_size) / self.merge_size + k_landmarks = reduce(keys2, landmark_einops_eq, 'sum', l=self.merge_size) / self.merge_size + + einops_eq = 'b l h e, b s h e -> b h l s' + sim1 = einsum(einops_eq, queries, k_landmarks) + sim2 = einsum(einops_eq, q_landmarks, k_landmarks) + sim3 = einsum(einops_eq, q_landmarks, keys) + + # eq (15) in the paper and aggregate values + + attn1, attn2, attn3 = map(lambda t: t.softmax(dim=-1), (sim1, sim2, sim3)) + attn2_inv = moore_penrose_iter_pinv(attn2, self.pinv_iterations) + + # b, h, l, e + V = (attn1 @ attn2_inv) @ (attn3 @ values.transpose(1, 2)) + + # b, l, h, e + V = V.transpose(1, 2) + + return V.contiguous(), None + + +class FullAttention_Gauss(nn.Module): + def __init__(self, mask_flag=False, factor=5, scale=None, attention_dropout=0.1, output_attention=False, + token_num=None): + super(FullAttention_Gauss, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.token_num = token_num + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + if token_num is not None: + self.tau = nn.Parameter(torch.ones(1, 1, self.token_num, 1)) + else: + self.tau = 
nn.Parameter(torch.ones(1)) + + def forward(self, queries, keys, values, attn_mask, **kwargs): + B, L, H, E = queries.shape + _, S, _, D = values.shape + + # gauss RBF Kernel Weighting: could incur out of memory + # b,h,l,1,e - b,h,1,l,e --> b,h,l,l,e + diff = queries.transpose(1, 2).unsqueeze(-2) - keys.transpose(1, 2).unsqueeze(2) + # b,h,l,l + dist_sq = torch.sum(diff ** 2, dim=-1) + A = torch.exp(-dist_sq / F.softplus(self.tau)) + + # dropout, reserved + A = self.dropout(A) + + V = torch.einsum("bhls,bshd->blhd", A, values) + + if self.output_attention: + return V.contiguous(), A + else: + return V.contiguous(), None + + +class FullAttention_ablation_2501(nn.Module): + def __init__(self, mask_flag=False, factor=5, scale=None, attention_dropout=0.1, output_attention=False, + token_num=None, SF_mode=1, softmax_flag=1, weight_plus=0, outside_softmax=0, enhance_alpha=1, + plot_mat_flag=False, save_folder='./', plot_grad_flag=False): # './utils/corr_mat/traffic.npy' + super(FullAttention_ablation_2501, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + self.SF_mode = SF_mode + self.softmax_flag = softmax_flag + self.token_num = token_num + self.outside_softmax = outside_softmax # False # + self.enhance_alpha = enhance_alpha # False # + self.weight_plus = weight_plus + self.plot_mat_flag = plot_mat_flag + self.plot_grad_flag = plot_grad_flag + self.save_folder = os.path.join(save_folder) + if self.plot_mat_flag and not os.path.exists(self.save_folder): + os.makedirs(self.save_folder, exist_ok=True) + self.num_heads = 1 + + print(f'self.weight_plus in FullAttention_ablation: {self.weight_plus}') + print(f'self.softmax_flag in FullAttention_ablation: {self.softmax_flag}') + print(f'self.outside_softmax in FullAttention_ablation: {self.outside_softmax}') + + if self.weight_plus and self.enhance_alpha: + self.alpha = nn.Parameter(torch.tensor(0.0)) + + if not 
self.SF_mode: + print('Vanilla attention is used...') + else: + print('Enhanced attention is used...') + + if self.SF_mode and self.token_num is not None: + # [1,1,N,1] + if self.softmax_flag: + self.tau = nn.Parameter(torch.ones(1, 1, self.token_num, 1)) + + init_weight_mat = (torch.eye(self.token_num) * 1.0 + + torch.randn(self.token_num, self.token_num) * 1.0) + # ablation + # init_weight_mat = (torch.eye(self.token_num) * 0.0 + + # torch.randn(self.token_num, self.token_num) * 1.0) + self.weight_mat = nn.Parameter(init_weight_mat[None, None, :, :].repeat(1, self.num_heads or 1, 1, 1)) + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, token_weight=None): + B, L, H, E = queries.shape + _, S, _, D = values.shape + scale = self.scale or 1. / sqrt(E) + # this_device = queries.device + + scores = torch.einsum("blhe,bshe->bhls", queries, keys) + + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + + scores.masked_fill_(attn_mask.mask, -np.inf) + + A = scale * scores + + weight_mat = None + ori_attn_mat = None + if self.SF_mode: + if not self.training and self.plot_mat_flag: + ori_attn_mat = torch.softmax(A, dim=-1) + + # attention matrix adjustment; 240507 + if self.SF_mode and self.token_num is not None: + # 2d + if self.softmax_flag: + weight_mat = F.softmax(self.weight_mat / F.softplus(self.tau), dim=-1) + else: + # use scale or not + weight_mat = F.softplus(self.weight_mat) # / sqrt(self.token_num) + + if self.SF_mode and weight_mat is not None: + if self.outside_softmax: + if self.weight_plus: + A = A.softmax(dim=-1) + weight_mat + else: + A = A.softmax(dim=-1) * weight_mat + A = F.normalize(A, p=1, dim=-1) + else: + if self.weight_plus: + if self.enhance_alpha: + alpha = F.softplus(self.alpha) + A = alpha * A + (1 - alpha) * weight_mat + else: + A = A + weight_mat + else: + # ablation: this is a better choice + A = A * weight_mat + + # attention matrix [b,h,l,s] + A = 
torch.softmax(A, dim=-1) + + else: + A = torch.softmax(A, dim=-1) + + # plot mat + if not self.training and self.plot_mat_flag and random.random() < 1: # and A.shape[-1] > 10 + batch_idx = random.randint(0, A.shape[0] - 1) + head_idx = random.randint(0, A.shape[1] - 1) + + # final + att_mat_2d = A[batch_idx, head_idx, :, :] + time_or_channel = 'channel' + plot_mat(att_mat_2d, str_cat='final_attn_mat', str0=f"batch_{batch_idx}_head_{head_idx}_{time_or_channel}", + save_folder=self.save_folder) + + if ori_attn_mat is not None and self.SF_mode: + ori_att_mat_2d = ori_attn_mat[batch_idx, head_idx, :, :] + time_or_channel = 'channel' + plot_mat(ori_att_mat_2d, str_cat='ori_attn_mat', str0=f"batch_{batch_idx}_head_{head_idx}_" + f"{time_or_channel}", + save_folder=self.save_folder) + + if weight_mat is not None and self.SF_mode: + batch_idx2 = min(batch_idx, weight_mat.shape[0] - 1) + head_idx2 = min(head_idx, weight_mat.shape[1] - 1) + weight_mat_2d = weight_mat[batch_idx2, head_idx2, :, :] + + plot_mat(weight_mat_2d, str_cat='adding_mat_2d', + str0=f"batch_{batch_idx2}_head_{head_idx2}_{time_or_channel}", + save_folder=self.save_folder) + + print(f'Attention matrix in FullAttention has been saved to {self.save_folder}...') + + # dropout, reserved + A = self.dropout(A) + + # print(f'A.shape: {A.shape}') + # print(f'values.shape: {values.shape}') + V = torch.einsum("bhls,bshd->blhd", A, values) + + # plot gradient + if self.plot_grad_flag and random.random() < 0.01 and self.weight_mat.grad is not None: + batch_idx = random.randint(0, self.weight_mat.shape[0] - 1) + head_idx = random.randint(0, self.weight_mat.shape[1] - 1) + + # final + weight_mat_grad = self.weight_mat.grad[batch_idx, head_idx, :, :] + time_or_channel = 'channel' + plot_mat(weight_mat_grad, str_cat='weight_grad_mat', + str0=f"batch_{batch_idx}_head_{head_idx}_{time_or_channel}", + save_folder=self.save_folder) + + print(f'Attention gradient matrix in FullAttention has been saved to 
{self.save_folder}...') + + if self.output_attention: + return V.contiguous(), A + else: + return V.contiguous(), None + + +class FullAttention_ori(nn.Module): + def __init__(self, mask_flag=False, factor=5, scale=None, attention_dropout=0.1, output_attention=False): + super(FullAttention_ori, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): + B, L, H, E = queries.shape + _, S, _, D = values.shape + scale = self.scale or 1. / sqrt(E) + + scores = torch.einsum("blhe,bshe->bhls", queries, keys) + + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + + scores.masked_fill_(attn_mask.mask, -np.inf) + + A = self.dropout(torch.softmax(scale * scores, dim=-1)) + V = torch.einsum("bhls,bshd->blhd", A, values) + + if self.output_attention: + return V.contiguous(), A + else: + return V.contiguous(), None + + +class LaserAttention(nn.Module): + def __init__(self, mask_flag=False, factor=5, scale=None, attention_dropout=0.1, output_attention=False): + super(LaserAttention, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, *args, **kwargs): + B, L, H, E = queries.shape + _, S, _, D = values.shape + scale = self.scale or 1. 
/ sqrt(E) + + # laser attention + value_max = values.max(dim=1, keepdims=True)[0] + values = (values - value_max).exp() + + scores = torch.einsum("blhe,bshe->bhls", queries, keys) + + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + + scores.masked_fill_(attn_mask.mask, -np.inf) + + A = self.dropout(torch.softmax(scale * scores, dim=-1)) + V = torch.einsum("bhls,bshd->blhd", A, values).contiguous() + + # laser attention + V = V.log() + value_max + + if self.output_attention: + return V, A + else: + return V, None + + +class LinearAttention(nn.Module): + def __init__(self, mask_flag=False, factor=5, scale=None, attention_dropout=0.1, output_attention=False, + mapping_fun='softmax_learn', token_num=None, imp_mode=False, d_model=None, + plot_mat_flag=False, save_folder='./'): + super(LinearAttention, self).__init__() + # mask_flag, factor, scale, attention_dropout are not used + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + self.mapping_fun = mapping_fun + self.plot_mat_flag = plot_mat_flag + self.save_folder = save_folder + + self.softmax = nn.Softmax(dim=-1) + + self.imp_mode = imp_mode + self.token_num = token_num + self.d_model = d_model + + print(f'Linear Attention is used. 
mapping_fun: {self.mapping_fun}') + + if self.imp_mode and self.token_num: + # [1,1,1,N] + self.attn_adjust = nn.Parameter(torch.zeros(1, 1, 1, self.token_num)) + # [1,1,N,N] + ij_mat = torch.ones(1, 1, token_num, token_num) + # ij_mat = (torch.arange(token_num).view(token_num, 1) - + # torch.arange(token_num).view(1, token_num)).pow(2).unsqueeze(0).unsqueeze(0) + self.register_buffer('ij_mat', ij_mat) + + if self.mapping_fun == 'softmax_learn': + self.delta1 = nn.Parameter(torch.tensor(0.0)) + # self.delta1 = torch.tensor(0.0) + # self.delta2 = nn.Parameter(torch.tensor(-1.0)) + elif self.mapping_fun == 'softmax_learn_v2': + # assert self.token_num is not None, 'token_num should not be NAN' + self.tau_q = nn.Parameter(torch.zeros(1)) + # self.tau_k = nn.Parameter(torch.zeros(1)) + if self.mapping_fun == 'agent': + self.n_agent = 4 + assert self.d_model is not None, 'self.d_model is None' + self.agent_linear = nn.Linear(self.d_model, self.d_model) + self.agent_conv1d = nn.Conv1d(self.token_num, self.n_agent, kernel_size=1) + + assert mapping_fun in ['softmax_learn', 'softmax_q_k', 'x_3', 'relu', 'elu_plus_1', 'agent', 'softmax_learn_v2'] + + def forward(self, queries, keys, values, attn_mask=None, tau=None, delta=None, token_weight=None): + B, L, H, E = queries.shape + _, S, _, D = values.shape + + if self.mapping_fun == 'agent': + assert L == self.token_num, 'input does not match with LinearAttention' + # [b l h d] --> [b h l d] + q, k, v = queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2) + if self.mapping_fun == 'softmax_learn': + q[q < 0] = -20 + k[k < 0] = -20 + q = self.softmax(q / nn.Softplus()(self.delta1)) + k = self.softmax(k / nn.Softplus()(self.delta1)) + elif self.mapping_fun == 'softmax_learn_v2': + q = torch.where(q < 0, -20, q) + k = torch.where(k < 0, -20, k) + q = self.softmax(q / F.softplus(self.tau_q)) + k = self.softmax(k / F.softplus(self.tau_q)) + + elif self.mapping_fun == 'softmax_q_k': + q 
= self.softmax(q) + # softmax2 + softmax2 = nn.Softmax(dim=-2) + k = softmax2(k) + elif self.mapping_fun == 'x_3': + # x**3 and relu + q = nn.ReLU()(q) + k = nn.ReLU()(k) + # x**3 + q_norm = q.norm(dim=-1, keepdim=True) + k_norm = k.norm(dim=-1, keepdim=True) + q = q ** 3 + k = k ** 3 + q = q / (q.norm(dim=-1, keepdim=True) + 1e-6) * q_norm.clone() # clone() to make it independent + k = k / (k.norm(dim=-1, keepdim=True) + 1e-6) * k_norm.clone() + elif self.mapping_fun == 'relu': + q = nn.ReLU()(q) + k = nn.ReLU()(k) + elif self.mapping_fun == 'elu_plus_1': + # elu+1 + q = F.elu(q) + 1 + k = F.elu(k) + 1 + elif self.mapping_fun == 'agent': + # agent [b,h,n,d] --> [b,h,N,d] + agent1 = (self.agent_conv1d(self.agent_linear(v).flatten(start_dim=0, end_dim=1)) + .view(B, H, -1, q.shape[-1])) + + # q * agent1.T --> softmax + q = self.softmax(torch.matmul(q, agent1.transpose(-1, -2))) # [b,h,N,n] + # agent1.T * k.T --> softmax + k = self.softmax(torch.matmul(agent1, k.transpose(-1, -2))).transpose(-1, -2) # [b,h,N,n] + + special_mode = self.imp_mode and self.token_num is not None and token_weight is not None + + output = None # to make pycharm happy + if not special_mode: + kv = torch.einsum("bhdl, bhle -> bhde", k.transpose(-1, -2), v) + z = 1 / (torch.einsum("bhld, bhd -> bhl", q, k.transpose(-1, -2).sum(dim=-1)) + 1e-6) + # output should be blhd. Bug here! 
0508 + output = torch.einsum("bhle, bhed, bhl -> blhd", q, kv, z) + + if not self.training and self.plot_mat_flag and random.random() < 0.01: + A = torch.einsum("bhle,bhse->bhls", q, k) + A = F.normalize(A, p=1, dim=-1) + batch_idx = random.randint(0, A.shape[0] - 1) + head_idx = random.randint(0, A.shape[1] - 1) + att_mat_2d = A[batch_idx, head_idx, :, :] + plot_mat(att_mat_2d, str_cat='linear_attn_mat', str0=f"{self.mapping_fun}_" + f"batch_{batch_idx}_head_{head_idx}", + save_folder=self.save_folder) + print(f'Attention matrix in LinearAttention has been saved to {self.save_folder}...') + + attn_weights = None + if self.output_attention or special_mode: + # even though we do not explicitly visit attention matrix usually in linear attention, but here we are: + attn_scores = torch.einsum("bhld, bhLd -> bhlL", q, k) + attn_weights = F.normalize(attn_scores, p=1, dim=-1) + + if special_mode: + # token_weight: [b,l] --> [b,1,1,l] + # attn_adjust has to be positive + weight_mat = -self.ij_mat / F.softplus(self.attn_adjust) * token_weight.unsqueeze(1).unsqueeze(1) + weight_mat = weight_mat.exp() + attn_weights = attn_weights * weight_mat + attn_weights = F.normalize(attn_weights, p=1, dim=-1) + + output = torch.einsum("bhls,bhsd->blhd", attn_weights, v) + + return output.contiguous(), attn_weights + + +# Code implementation from https://github.com/zhouhaoyi/Informer2020 +class ProbAttention(nn.Module): + def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): + super(ProbAttention, self).__init__() + self.factor = factor + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + print('ProbAttention is used...') + + def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) + # Q [B, H, L, D] + B, H, L_K, E = K.shape + _, _, L_Q, _ = Q.shape + + # calculate the sampled Q_K + K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) + # real 
U = U_part(factor*ln(L_k))*L_q + index_sample = torch.randint(L_K, (L_Q, sample_k)) + K_sample = K_expand[:, :, torch.arange( + L_Q).unsqueeze(1), index_sample, :] + Q_K_sample = torch.matmul( + Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2) + + # find the Top_k query with sparsity measurement + M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) + M_top = M.topk(n_top, sorted=False)[1] + + # use the reduced Q to calculate Q_K + Q_reduce = Q[torch.arange(B)[:, None, None], + torch.arange(H)[None, :, None], + M_top, :] # factor*ln(L_q) + Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k + + return Q_K, M_top + + def _get_initial_context(self, V, L_Q): + B, H, L_V, D = V.shape + if not self.mask_flag: + # V_sum = V.sum(dim=-2) + V_sum = V.mean(dim=-2) + contex = V_sum.unsqueeze(-2).expand(B, H, + L_Q, V_sum.shape[-1]).clone() + else: # use mask + # requires that L_Q == L_V, i.e. for self-attention only + assert (L_Q == L_V) + contex = V.cumsum(dim=-2) + return contex + + def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): + B, H, L_V, D = V.shape + + if self.mask_flag: + attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device) + scores.masked_fill_(attn_mask.mask, -np.inf) + + attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) + + context_in[torch.arange(B)[:, None, None], + torch.arange(H)[None, :, None], + index, :] = torch.matmul(attn, V).type_as(context_in) + if self.output_attention: + attns = (torch.ones([B, H, L_V, L_V]) / + L_V).type_as(attn).to(attn.device) + attns[torch.arange(B)[:, None, None], torch.arange(H)[ + None, :, None], index, :] = attn + return context_in, attns + else: + return context_in, None + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, *args, **kwargs): + B, L_Q, H, D = queries.shape + _, L_K, _, _ = keys.shape + + queries = queries.transpose(2, 1) + keys = keys.transpose(2, 1) + values = values.transpose(2, 1) + + U_part = 
self.factor * \ + np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k) + u = self.factor * \ + np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) + + U_part = U_part if U_part < L_K else L_K + u = u if u < L_Q else L_Q + + scores_top, index = self._prob_QK( + queries, keys, sample_k=U_part, n_top=u) + + # add scale factor + scale = self.scale or 1. / sqrt(D) + if scale is not None: + scores_top = scores_top * scale + # get the context + context = self._get_initial_context(values, L_Q) + # update the context with selected top_k queries + context, attn = self._update_context( + context, values, scores_top, index, L_Q, attn_mask) + + return context.contiguous(), attn + + +class dynamic_projection(nn.Module): + def __init__(self, dim1, dim2): + super().__init__() + self.dim1 = dim1 + self.mlp = nn.Linear(dim1, dim2) + + def forward(self, src): + # src: b, n, d + assert src.shape[-1] == self.dim1 + src_dp = self.mlp(src) + src_dp = F.softmax(src_dp, dim=-1) + src_dp = torch.einsum('bef,bec -> bcf', src, src_dp) + return src_dp + + +class AttentionLayer(nn.Module): + def __init__(self, attention, d_model, n_heads, d_keys=None, + d_values=None, dp_rank=None, imp_mode=False): + super(AttentionLayer, self).__init__() + + self.imp_mode = imp_mode + self.n_heads = n_heads + self.dp_rank = dp_rank + + d_keys = d_keys or (d_model // n_heads) + d_values = d_values or (d_model // n_heads) + + self.inner_attention = attention + + self.query_projection = nn.Linear(d_model, d_keys * n_heads) + self.key_projection = nn.Linear(d_model, d_keys * n_heads) + self.value_projection = nn.Linear(d_model, d_values * n_heads) + + if dp_rank: + self.dp_key = dynamic_projection(d_keys * n_heads, dp_rank) + self.dp_value = dynamic_projection(d_values * n_heads, dp_rank) + + self.out_projection = nn.Linear(d_values * n_heads, d_model) + + def forward(self, queries, keys, values, attn_mask=None, tau=None, delta=None, token_weight=None, **kwargs): + B, L, _ = queries.shape + _, S, _ = keys.shape 
+ H = self.n_heads + + queries = self.query_projection(queries).view(B, L, H, -1) + keys = self.key_projection(keys) + values = self.value_projection(values) + + if self.dp_rank: + S = self.dp_rank + keys = self.dp_key(keys) + values = self.dp_value(values) + + keys = keys.view(B, S, H, -1) + values = values.view(B, S, H, -1) + + out, attn = self.inner_attention( + queries, + keys, + values, + attn_mask=attn_mask, + tau=tau, + delta=delta, + token_weight=token_weight.to(queries.device) if token_weight is not None else None + ) + # [b,l,h,s] + out = out.view(B, L, -1) + + return self.out_projection(out), attn + + +class ReformerLayer(nn.Module): + def __init__(self, attention, d_model, n_heads, d_keys=None, + d_values=None, causal=False, bucket_size=4, n_hashes=4): + super().__init__() + self.bucket_size = bucket_size + self.attn = LSHSelfAttention( + dim=d_model, + heads=n_heads, + bucket_size=bucket_size, + n_hashes=n_hashes, + causal=causal + ) + print('ReformerLayer is used...') + + def fit_length(self, queries): + # inside reformer: assert N % (bucket_size * 2) == 0 + B, N, C = queries.shape + if N % (self.bucket_size * 2) == 0: + return queries + else: + # fill the time series + fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2)) + return torch.cat([queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1) + + def forward(self, queries, keys, values, attn_mask, tau, delta, *args, **kwargs): + # in Reformer: defalut queries=keys + if queries.ndim > 3: + queries = queries.flatten(2) + B, N, C = queries.shape + + queries = self.attn(self.fit_length(queries))[:, :N, :] + return queries, None diff --git a/ts_benchmark/baselines/olinear/layers/StandardNorm.py b/ts_benchmark/baselines/olinear/layers/StandardNorm.py new file mode 100644 index 00000000..990d0fdc --- /dev/null +++ b/ts_benchmark/baselines/olinear/layers/StandardNorm.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn + + +class Normalize(nn.Module): + def __init__(self, 
num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False): + """ + :param num_features: the number of features or channels + :param eps: a value added for numerical stability + :param affine: if True, RevIN has learnable affine parameters + """ + super(Normalize, self).__init__() + self.num_features = num_features + self.eps = eps + self.affine = affine + self.subtract_last = subtract_last + self.non_norm = non_norm + if self.affine: + self._init_params() + + def forward(self, x, mode: str): + if mode == 'norm': + self._get_statistics(x) + x = self._normalize(x) + elif mode == 'denorm': + x = self._denormalize(x) + else: + raise NotImplementedError + return x + + def _init_params(self): + # initialize RevIN params: (C,) + self.affine_weight = nn.Parameter(torch.ones(self.num_features)) + self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) + + def _get_statistics(self, x): + dim2reduce = tuple(range(1, x.ndim - 1)) + if self.subtract_last: + self.last = x[:, -1, :].unsqueeze(1) + else: + self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() + self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach() + + def _normalize(self, x): + if self.non_norm: + return x + if self.subtract_last: + x = x - self.last + else: + x = x - self.mean + x = x / self.stdev + if self.affine: + x = x * self.affine_weight + x = x + self.affine_bias + return x + + def _denormalize(self, x): + if self.non_norm: + return x + if self.affine: + x = x - self.affine_bias + x = x / (self.affine_weight + self.eps * self.eps) + x = x * self.stdev + if self.subtract_last: + x = x + self.last + else: + x = x + self.mean + return x diff --git a/ts_benchmark/baselines/olinear/layers/SwinTransformer.py b/ts_benchmark/baselines/olinear/layers/SwinTransformer.py new file mode 100644 index 00000000..6140f24c --- /dev/null +++ b/ts_benchmark/baselines/olinear/layers/SwinTransformer.py @@ -0,0 +1,496 @@ +import math + 
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    ws = fix_window(var2tuple(window_size), H, W)
    assert H % ws[0] == 0 and W % ws[1] == 0
    grid = x.view(B, H // ws[0], ws[0], W // ws[1], ws[1], C)
    return grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, ws[0], ws[1], C)


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    ws = fix_window(var2tuple(window_size), H, W)
    B = int(windows.shape[0] / (H * W / ws[0] / ws[1]))
    grid = windows.view(B, H // ws[0], W // ws[1], ws[0], ws[1], -1)
    return grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)


def fix_window(x, h, w, max_prod=math.inf):
    """Clamp a window to the feature-map extent: non-positive or oversized
    entries fall back to the full height/width. If the clamped area still
    exceeds max_prod, shrink the axis that was NOT snapped to full size."""
    fixed = (x[0] if 0 < x[0] <= h else h, x[1] if 0 < x[1] <= w else w)
    if fixed[0] * fixed[1] > max_prod:
        if fixed[0] == h:
            fixed = (h, max(max_prod // h, 1))
        elif fixed[1] == w:
            fixed = (max(max_prod // w, 1), w)
    return fixed


def fix_shift(shift_size, window_size):
    """Cap each shift at half the (rounded-up) window size; returns a list."""
    caps = ((window_size[0] + 1) // 2, (window_size[1] + 1) // 2)
    return [min(s, c) for s, c in zip(shift_size, caps)]
class WindowAttention(nn.Module):
    """Window-based multi-head self-attention with optional dynamic position
    bias (DPB), Swin-style relative position bias, shifted-window masking and
    an optional importance mask over tokens.

    NOTE: forward() mutates self.window_size / self.shift_size when the caller
    passes different values, so a single instance tracks the last-used geometry.
    """

    def __init__(self, dim, window_size, num_heads, shift_size=0, mask_flag=False, rel_pos_flag=True, DPB=True,
                 seq_len=96, mask_weight_flag=True):
        super().__init__()
        self.row_att_mat_index = None
        self.num_win = None
        # BUG FIX: always define this attribute up front. forward() reads it
        # (``self.relative_position_index is None``) before the first rebuild,
        # which raised AttributeError when DPB was enabled.
        self.relative_position_index = None
        self.dim = dim
        self.window_size = self._to_2tuple(window_size)
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.shift_size = self._to_2tuple(shift_size)
        self.mask_flag = mask_flag
        self.rel_pos_flag_ori = rel_pos_flag
        # relative bias only makes sense when at least one window dim is fixed
        self.rel_pos_flag = rel_pos_flag and any(i > 0 for i in self.window_size)
        self.seq_len = seq_len
        self.mask_weight_flag = mask_weight_flag
        self.DPB = DPB

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)

        if self.DPB:
            self.pow_mode = True

            # small MLP that generates the relative position bias on the fly
            self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                         nn.ReLU(),
                                         nn.Linear(512, self.num_heads, bias=False))
            self.h_times = nn.Parameter(torch.tensor(2.0))
            # no log
            self.pow_para = nn.Parameter(torch.tensor(0.0))
            self.ws_scalar = nn.Parameter(torch.tensor(1.0)) if self.pow_mode else nn.Parameter(torch.tensor(5.0))
            self.ws_scalar2 = nn.Parameter(torch.tensor(2.0))  # 5-->3
            self.period_scale = nn.Parameter(torch.tensor(-3.0))
            self.relative_coords_table = None
            self.relative_position_bias_table = None
        elif self.rel_pos_flag:
            w_s = [a if a > 0 else self.seq_len for a in self.window_size]
            # parameter table of relative position bias: (2*Wh-1) * (2*Ww-1), nH
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros((2 * w_s[0] - 1) * (2 * w_s[1] - 1), num_heads))
            trunc_normal_(self.relative_position_bias_table, std=.02)

            if all(i > 0 for i in self.window_size):
                self.relative_position_index = create_swin_relative_index(self.window_size)

        elif self.rel_pos_flag_ori:
            # fall back to a plain learned absolute positional table
            self.pos_table = nn.Parameter(torch.zeros(1, self.seq_len, self.dim))

        # temperature for importance-mask sharpening (used by the caller)
        self.tau = nn.Parameter(torch.tensor(-1.0))

    @staticmethod
    def _to_2tuple(x, num=2):
        """Local copy of utils var2tuple so this class has no intra-file deps."""
        num = int(num)
        if isinstance(x, tuple):
            if len(x) == num:
                return x
            if len(x) > num:
                return x[:num]
            return x + (x[-1],) * (num - len(x))
        return (x,) * num

    def create_mask(self, window_size, shift_size, res):
        """Build the Swin shifted-window attention mask for a res[0] x res[1]
        grid: token pairs from different pre-shift regions get -inf scores."""
        if any(shift_size) and all(i > 0 for i in window_size):
            img_mask = torch.zeros((res[0], res[1]))
            if 0 < shift_size[0] < window_size[0]:
                h_slices = (slice(0, -window_size[0]),
                            slice(-window_size[0], -shift_size[0]),
                            slice(-shift_size[0], None))
            else:
                h_slices = (slice(0, None),)
            if 0 < shift_size[1] < window_size[1]:
                w_slices = (slice(0, -window_size[1]),
                            slice(-window_size[1], -shift_size[1]),
                            slice(-shift_size[1], None))
            else:
                w_slices = (slice(0, None),)
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[h, w] = cnt
                    cnt += 1
        else:
            img_mask = torch.zeros((res[0], res[1]))

        # [num_windows, ws0, ws1]
        img_masks = window_partition(img_mask.unsqueeze(0).unsqueeze(-1), window_size).squeeze(-1)
        self.num_win = num_win = img_masks.shape[0]
        # attention-score mask: [num_windows, ws0*ws1, ws0*ws1]
        mask = img_masks.view(num_win, 1, -1) - img_masks.view(num_win, -1, 1)
        mask = mask.masked_fill(mask != 0, float('-inf')).masked_fill(mask == 0, float(0.0))
        return mask

    def forward(self, x, window_size=None, shift_size=None, res=None, imp_mask=None):
        # imp_mask: [B*nW, w0*w1]

        window_change = 0 if window_size == self.window_size else 1
        if window_size is not None and window_change:
            self.window_size = window_size
        shift_change = 0 if shift_size == self.shift_size else 1
        if shift_size is not None and shift_change:
            self.shift_size = shift_size

        if self.mask_flag and any(self.shift_size) and all(i > 0 for i in self.window_size):
            swin_mask = self.create_mask(self.window_size, self.shift_size, res=res)
        else:
            swin_mask = None

        B, N, C = x.shape
        if self.rel_pos_flag_ori and not self.rel_pos_flag:
            x = x + self.pos_table[:, :N, :]

        # 3, B, num_heads, N, d
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # b,h,n,d
        q, k, v = qkv[0], qkv[1], qkv[2]

        # b,h,n,n
        attn = (q @ k.transpose(-2, -1)) * self.scale

        if self.DPB or self.rel_pos_flag:
            if window_change or self.relative_position_index is None:
                self.relative_position_index = create_swin_relative_index(self.window_size)

            relative_position_index = self.relative_position_index

            if self.DPB:
                # 1, 2*Wh-1, 2*Ww-1, 2
                self.relative_coords_table = get_relative_coords_table(self.window_size,
                                                                       h_times=F.softplus(self.h_times),
                                                                       ws_scalar=F.softplus(self.ws_scalar),
                                                                       ws_scalar2=F.softplus(self.ws_scalar2),
                                                                       pow_para=F.sigmoid(self.pow_para),
                                                                       pow_mode=self.pow_mode).to(x.device)

                # BUG FIX: the attribute is ``num_heads`` (``n_heads`` was
                # never defined on this class, so this raised AttributeError).
                self.relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)

            if relative_position_index is not None:
                # relative position bias: Wh*Ww, Wh*Ww, nH
                relative_position_bias = self.relative_position_bias_table[relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
                attn = attn + relative_position_bias.unsqueeze(0)

        if swin_mask is not None and self.num_win is not None:
            size1 = attn.shape
            # swin_mask: [num_windows, ws0*ws1, ws0*ws1]
            attn = (attn.view(-1, self.num_win, self.num_heads, N, N) + swin_mask.unsqueeze(0).unsqueeze(2).to(
                attn.device))
            attn = attn.view(size1)

        if imp_mask is not None:
            # [b,n] --> [b,1,1,n]
            attn = attn * imp_mask.unsqueeze(1).unsqueeze(1)

        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x
def var2tuple(x, num=2):
    """Coerce x into a tuple of length ``num``: tuples are trimmed or padded
    by repeating the last element; scalars are repeated ``num`` times."""
    num = int(num)
    if isinstance(x, tuple):
        if len(x) == num:
            return x
        elif len(x) > num:
            return x[:num]
        else:
            return x + (x[-1],) * (num - len(x))
    return (x,) * num


class SwinTransformerBlock(nn.Module):
    """One Swin block: (optional) shifted-window partition -> WindowAttention
    -> reverse -> residual + MLP. Accepts either a plain feature map or an
    (x, mask) tuple where mask carries per-position importance weights.
    """

    def __init__(self, dim, input_resolution=None, num_heads=8, window_size=(5, 5), shift_size=(0, 0), mask_flag=False,
                 seq_len=96, mask_weight_flag=True, series_shift=False, pad_first=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.window_size = var2tuple(window_size)

        self.shift_size = var2tuple(shift_size)
        self.num_heads = num_heads
        self.mask_weight_flag = mask_weight_flag
        # series_shift rolls the flattened H*W sequence instead of 2-D rolling
        self.series_shift = series_shift
        # pad_first controls whether padding happens before or after the shift
        self.pad_first = pad_first

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, self.window_size, num_heads, shift_size=self.shift_size, mask_flag=mask_flag,
                                    seq_len=seq_len, mask_weight_flag=mask_weight_flag)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )
        # self.pe = nn.Parameter(torch.zeros(1, 100, dim))

        # temperature for sharpening the importance mask before normalization
        self.tau = nn.Parameter(torch.tensor(0.0))

    def forward(self, x_tuple):
        # args: x:[b, h, w, c], mask: [b, h, w]

        # parse input
        if isinstance(x_tuple, tuple):
            x, mask = x_tuple
            mask = mask if self.mask_weight_flag else None
        else:
            x, mask = x_tuple, None

        # returned unchanged so stacked blocks keep seeing the original mask
        mask_ori = mask

        B, H0, W0, C = x.shape
        shortcut = x

        if mask is not None:
            assert (B, H0, W0) == (mask.shape[0], mask.shape[1], mask.shape[2]), \
                "Check x and mask in SwinTransformerBlock..."

        if self.input_resolution is not None:
            H, W = self.input_resolution
            assert H == H0 and W == W0, "Input feature has wrong size in SwinTransformerBlock..."

        # adjust if window_size is too big
        window_size2 = fix_window(self.window_size, H0, W0)
        self.shift_size = fix_shift(self.shift_size, window_size2)
        if self.window_size[0] != window_size2[0] or self.window_size[1] != window_size2[1]:
            self.window_size = window_size2

        if self.pad_first:
            # pad --> shift
            # do some padding if needed
            x, H, W = win_padding(x, self.window_size)
            if mask is not None:
                mask = win_padding(mask.unsqueeze(-1), self.window_size)[0].squeeze(-1)
            # Shift window partition
            if any(self.shift_size):

                if self.series_shift:
                    # new way
                    # [B, H, W, C] --> [B, H*W, C] --> [B, shift(H*W), C] --> [B, H, W, C]
                    shifted_x = (torch.roll(x.flatten(start_dim=1, end_dim=2), shifts=-self.shift_size[1], dims=1)
                                 .view(B, H, W, -1))
                    if mask is not None:
                        mask = (torch.roll(mask.flatten(start_dim=1, end_dim=2), shifts=-self.shift_size[1], dims=1)
                                .view(B, H, W, -1))
                else:
                    # old method: use the traditional swin
                    shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
                    if mask is not None:
                        mask = torch.roll(mask, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))

            else:
                shifted_x = x
        else:
            # shift --> pad
            # Shift window partition
            if any(self.shift_size):

                if self.series_shift:
                    # new way
                    # [B, H, W, C] --> [B, H*W, C] --> [B, shift(H*W), C] --> [B, H, W, C]
                    shifted_x = (torch.roll(x.flatten(start_dim=1, end_dim=2), shifts=-self.shift_size[1], dims=1)
                                 .view(B, H0, W0, -1))
                    if mask is not None:
                        mask = (torch.roll(mask.flatten(start_dim=1, end_dim=2), shifts=-self.shift_size[1], dims=1)
                                .view(B, H0, W0, -1))
                else:
                    # old method: use the traditional swin
                    shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
                    if mask is not None:
                        mask = torch.roll(mask, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))

            else:
                shifted_x = x

            # do some padding if needed
            shifted_x, H, W = win_padding(shifted_x, self.window_size)
            if mask is not None:
                mask = win_padding(mask.unsqueeze(-1), self.window_size)[0].squeeze(-1)

        shifted_x = self.norm1(shifted_x)

        # Window partition
        # (B, H, W, C) --> (num_windows * B, window_size, window_size, C) -->
        # (num_windows * B, window_size * window_size, C)
        windows = window_partition(shifted_x, self.window_size).flatten(start_dim=1, end_dim=2)
        if mask is not None:
            # [B*nW, w0*w1]
            mask = window_partition(mask.unsqueeze(-1), self.window_size).flatten(start_dim=1, end_dim=2).squeeze(-1)
            # normalize
            mask = F.normalize(mask.pow(F.softplus(self.tau)), p=1, dim=-1)

        # Window attention: input b,n,c
        attn_windows = self.attn(windows, window_size=self.window_size, shift_size=self.shift_size,
                                 res=(H, W), imp_mask=mask)

        # Merge windows
        # (num_windows * B, window_size * window_size, C) --> (B, H, W, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        if self.pad_first:
            # Reverse shift
            if any(self.shift_size):

                if self.series_shift:
                    # new way
                    # NOTE(review): shifted_x still has the padded size (H, W)
                    # here, but the view uses (H0, W0) — looks inconsistent
                    # when padding actually occurred; confirm intent.
                    x = (torch.roll(shifted_x.flatten(start_dim=1, end_dim=2), shifts=self.shift_size[1], dims=1)
                         .view(B, H0, W0, -1))
                else:
                    # old way:
                    x = torch.roll(shifted_x, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))

            else:
                x = shifted_x

            if H != H0 or W != W0:
                x = x[:, :H0, :W0, :]
        else:
            if H != H0 or W != W0:
                shifted_x = shifted_x[:, :H0, :W0, :]

            # Reverse
            if any(self.shift_size):

                if self.series_shift:
                    # new way
                    x = (torch.roll(shifted_x.flatten(start_dim=1, end_dim=2), shifts=self.shift_size[1], dims=1)
                         .view(B, H0, W0, -1))
                else:
                    # old way:
                    x = torch.roll(shifted_x, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))

            else:
                x = shifted_x

        # residual + MLP
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x, mask_ori
class SwinTransformerBlockTwice(nn.Module):
    """Stack of interleaved SwinTransformerBlocks: block1 (unshifted) and
    block2 (shifted), optionally block3 with full (-1, -1) windows, preceded
    by an optional 3x3 conv patch layer and channel projections."""

    def __init__(self, dim, input_resolution=None, c_in=7, c_out=7, num_heads=8,
                 window_size1=(5, 5), window_size2=(5, 5), shift_size=(3, 3), block_num=3,
                 attn_mask_flag=False, conv_patch_flag=True, seq_len=96, mask_weight_flag=True,
                 all_attn_flag=False):
        super().__init__()
        self.c_in = c_in
        self.num_heads = num_heads
        self.block_num = block_num
        self.attn_mask_flag = attn_mask_flag
        self.conv_patch_flag = conv_patch_flag
        self.all_attn_flag = all_attn_flag

        self.window_size1 = var2tuple(window_size1)
        self.window_size2 = var2tuple(window_size2)
        self.shift_size = var2tuple(shift_size)

        print(f'self.window_size1: {self.window_size1}')
        print(f'self.window_size2: {self.window_size2}')
        print(f'self.shift_size: {self.shift_size}')

        # conv as patch; input should be [n,c,h,w]
        if self.conv_patch_flag:
            self.conv_patch = nn.Conv2d(c_in, c_in, kernel_size=3, stride=1, padding='same', bias=True)

        # channel projections in/out of the block dimension
        self.proj0 = nn.Linear(c_in, dim) if c_in != dim else nn.Identity()

        self.proj1 = nn.Linear(dim, c_out) if c_out != dim else nn.Identity()

        # column and row attention
        self.block1 = nn.ModuleList([SwinTransformerBlock(dim, input_resolution, num_heads,
                                                          window_size=self.window_size1,
                                                          shift_size=(0, 0), seq_len=seq_len,
                                                          mask_weight_flag=mask_weight_flag, mask_flag=False)
                                     for _ in range(self.block_num)])

        # NOTE(review): this tests window_size1 for the (-1, -1) sentinel while
        # the log message below mentions window_size2 as well — confirm which
        # window is intended to gate block2.
        if self.window_size2 is not None and not all(i == -1 for i in self.window_size1):

            self.block2 = nn.ModuleList([SwinTransformerBlock(dim, input_resolution, num_heads,
                                                              window_size=self.window_size2,
                                                              shift_size=self.shift_size,
                                                              mask_flag=self.attn_mask_flag, seq_len=seq_len,
                                                              mask_weight_flag=mask_weight_flag)
                                         for _ in range(self.block_num)])
            if self.all_attn_flag:
                self.block3 = nn.ModuleList([SwinTransformerBlock(dim, input_resolution, num_heads,
                                                                  window_size=(-1, -1),
                                                                  shift_size=0,
                                                                  mask_flag=self.attn_mask_flag, seq_len=seq_len,
                                                                  mask_weight_flag=mask_weight_flag)
                                             for _ in range(self.block_num)])

                self.all_swin_layers = nn.Sequential(*[layer for pair in zip(self.block1, self.block2, self.block3)
                                                       for layer in pair])  # , self.block3
            else:
                self.all_swin_layers = nn.Sequential(*[layer for pair in zip(self.block1, self.block2)
                                                       for layer in pair])
        else:
            print('Because of the setting of window_size2 or window_size1, only block1 is employed '
                  'in SwinTransformerBlockTwice...')
            self.all_swin_layers = nn.Sequential(*self.block1)

    def forward(self, x, mask=None):
        # args: x:[b,h,w,c], mask:[b,h,w]
        # [b,h,w,c] --> [b,c,h,w] --> [b,h,w,c]
        if self.conv_patch_flag:
            x = self.conv_patch(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        # [b,h,w,c]; each Swin block passes the (x, mask) tuple through
        x = self.proj0(x)
        x, _ = self.all_swin_layers((x, mask))

        x = self.proj1(x)
        return x
class EncoderLayer(nn.Module):
    """Transformer encoder layer: attention + residual, then an optional FFN
    (1x1-conv MLP or FAN) with its own residual.

    ``x`` may be a tensor (self-attention) or a [q, k] / [q, k, v] sequence
    (cross-attention); in the sequence case the original key tensor is
    threaded through the return value so stacked layers can keep reusing it.
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu", MLP_flag=True,
                 FAN_flag=False, **kwargs):
        super(EncoderLayer, self).__init__()
        self.MLP_flag = MLP_flag
        self.FAN_flag = FAN_flag
        # dff is defaulted at 4*d_model
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        if self.MLP_flag:
            # position-wise FFN implemented with 1x1 convolutions
            self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
            self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
            self.activation = F.relu if activation == "relu" else F.gelu
        elif self.FAN_flag:
            self.FANLayers = nn.Sequential(
                FANLayer(d_model, d_ff, activation=activation),
                FANLayer(d_ff, d_model, activation=None)
            )

        if self.MLP_flag or self.FAN_flag:
            self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask=None, tau=None, delta=None, token_weight=None, **kwargs):
        list_flag = isinstance(x, (tuple, list))
        if list_flag:
            # BUG FIX: validate the length BEFORE indexing, so a malformed
            # input raises the intended ValueError rather than an IndexError
            # from ``x[1]``.
            if len(x) not in {2, 3}:
                raise ValueError('Input error in EncoderLayer')
            k_ori = x[1]
            q, k, v = (x[0], x[1], x[1]) if len(x) == 2 else (x[0], x[1], x[2])
            x = q
        else:
            q, k, v = x, x, x

        new_x, attn = self.attention(
            q, k, v,
            attn_mask=attn_mask,
            tau=tau, delta=delta,
            token_weight=token_weight
        )
        x = x + self.dropout(new_x)

        y = x = self.norm1(x)
        output = y

        if self.MLP_flag:
            y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
            y = self.dropout(self.conv2(y).transpose(-1, 1))
        elif self.FAN_flag:
            y = self.FANLayers(y)

        if self.MLP_flag or self.FAN_flag:
            output = self.norm2(x + y)

        if list_flag:
            return [output, k_ori], attn
        else:
            return output, attn
class LinearEncoder(nn.Module):
    """Encoder layer that replaces attention with a learned token-mixing
    matrix (optionally biased by a precomputed correlation matrix), followed
    by a 1x1-conv FFN."""

    def __init__(self, d_model, d_ff=None, CovMat=None, dropout=0.1, activation="relu", token_num=None, **kwargs):
        super(LinearEncoder, self).__init__()

        d_ff = d_ff or 4 * d_model
        self.d_model = d_model
        self.d_ff = d_ff
        # keep a broadcast-ready (1, N, N) copy of the correlation prior
        self.CovMat = CovMat.unsqueeze(0) if CovMat is not None else None
        self.token_num = token_num

        self.norm1 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # attention replaced by a value projection + fixed mixing matrix
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # identity-plus-noise initialisation of the (1, N, N) mixing weights
        init_weight_mat = torch.eye(self.token_num) * 1.0 + torch.randn(self.token_num, self.token_num) * 1.0
        self.weight_mat = nn.Parameter(init_weight_mat[None, :, :])

        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.activation = F.relu if activation == "relu" else F.gelu
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, **kwargs):
        # x: (batch, token_num, d_model)
        values = self.v_proj(x)

        mix = F.softplus(self.weight_mat)
        if self.CovMat is not None:
            mix = F.softmax(self.CovMat, dim=-1) + mix
        # rows normalised to sum to 1, then dropout on the mixing weights
        mix = self.dropout(F.normalize(mix, p=1, dim=-1))

        mixed = mix @ values

        x = self.norm1(x + self.dropout(self.out_proj(mixed)))

        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))
        return self.norm2(x + hidden), None
class LinearEncoder_Multihead(nn.Module):
    """Multi-head variant of LinearEncoder: one learned (N, N) mixing matrix
    per head; the correlation prior is intentionally ignored here."""

    def __init__(self, d_model, d_ff=None, CovMat=None, dropout=0.1, activation="relu", token_num=None, n_heads=2,
                 **kwargs):
        super(LinearEncoder_Multihead, self).__init__()

        d_ff = d_ff or 4 * d_model
        self.d_model = d_model
        self.d_ff = d_ff
        self.CovMat = None  # the correlation prior is deliberately disabled
        self.token_num = token_num
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # attention replaced by per-head value projection + mixing matrices
        head_dim = d_model // n_heads
        self.v_proj = nn.Linear(d_model, head_dim * head_dim)
        self.out_proj = nn.Linear(head_dim * head_dim, d_model)

        self.weight_mat = nn.Parameter(torch.randn(self.n_heads, self.token_num, self.token_num))

        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.activation = F.relu if activation == "relu" else F.gelu
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, **kwargs):
        # x: (batch, token_num, d_model)
        B, N, D = x.shape
        # (B, N, H, d)
        values = self.v_proj(x).reshape(B, N, self.n_heads, -1)

        mix = self.dropout(F.normalize(F.softplus(self.weight_mat), p=1, dim=-1))

        # (H, N, N) @ (B, H, N, d) -> (B, H, N, d) -> (B, N, H*d)
        mixed = (mix @ values.transpose(1, 2)).transpose(1, 2).flatten(-2)

        x = self.norm1(x + self.dropout(self.out_proj(mixed)))

        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))
        return self.norm2(x + hidden), None
class LinearEncoder_newffn(nn.Module):
    """LinearEncoder variant whose position-wise FFN uses plain nn.Linear
    layers instead of 1x1 convolutions; token mixing is otherwise identical."""

    def __init__(self, d_model, d_ff=None, CovMat=None, dropout=0.1, activation="relu", token_num=None, **kwargs):
        super(LinearEncoder_newffn, self).__init__()

        d_ff = d_ff or 4 * d_model
        self.d_model = d_model
        self.d_ff = d_ff
        # broadcast-ready (1, N, N) copy of the optional correlation prior
        self.CovMat = CovMat.unsqueeze(0) if CovMat is not None else None
        self.token_num = token_num

        self.norm1 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # attention replaced by value projection + fixed mixing matrix
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # identity-plus-noise initialisation, shaped (1, N, N) for broadcasting
        seed = torch.eye(self.token_num) * 1.0 + torch.randn(self.token_num, self.token_num) * 1.0
        self.weight_mat = nn.Parameter(seed[None, :, :])

        self.ffn1 = nn.Linear(d_model, d_ff)
        self.ffn2 = nn.Linear(d_ff, d_model)
        self.activation = F.relu if activation == "relu" else F.gelu
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, **kwargs):
        # x: (batch, token_num, d_model)
        values = self.v_proj(x)

        mix = F.softplus(self.weight_mat)
        if self.CovMat is not None:
            mix = F.softmax(self.CovMat, dim=-1) + mix
        mix = self.dropout(F.normalize(mix, p=1, dim=-1))

        x = self.norm1(x + self.dropout(self.out_proj(mix @ values)))

        # linear FFN
        hidden = self.dropout(self.activation(self.ffn1(x)))
        hidden = self.dropout(self.ffn2(hidden))
        return self.norm2(x + hidden), None
class LinearEncoder_ablation(nn.Module):
    """Ablation of LinearEncoder with independently switchable pieces:
    the variable(token)-mixing branch (var_linear_enable / var_linear_mode)
    and the temporal FFN branch (temp_linear for the conv MLP, or
    temp_attn_linear for the newLinear FFN)."""

    def __init__(self, d_model, d_ff=None, CovMat=None, dropout=0.1, activation="relu", token_num=None,
                 var_linear_mode='normal', temp_linear=True, temp_attn_linear=False, var_linear_enable=True, **kwargs):
        super(LinearEncoder_ablation, self).__init__()

        d_ff = d_ff or 4 * d_model
        self.d_model = d_model
        self.d_ff = d_ff
        # broadcast-ready (1, N, N) copy of the optional correlation prior
        self.CovMat = CovMat.unsqueeze(0) if CovMat is not None else None
        self.token_num = token_num

        self.var_linear_enable = var_linear_enable
        # any non-'normal' mode selects the projection+mixing-matrix path
        self.linear_r_mode = var_linear_mode != 'normal'
        self.temp_linear = temp_linear
        self.temp_attn_linear = temp_attn_linear

        self.dropout = nn.Dropout(dropout)

        # at least one branch must remain enabled
        if not self.var_linear_enable:
            assert self.temp_linear or self.temp_attn_linear

        if self.var_linear_enable:
            self.norm1 = nn.LayerNorm(d_model)

            # attention --> linear
            if self.linear_r_mode:
                self.v_proj = nn.Linear(d_model, d_model)
                self.out_proj = nn.Linear(d_model, d_model)

                # identity-plus-noise initialisation, shaped (1, N, N)
                init_weight_mat = torch.eye(self.token_num) * 1.0 + torch.randn(self.token_num, self.token_num) * 1.0
                self.weight_mat = nn.Parameter(init_weight_mat[None, :, :])

            else:
                # plain linear layer mixing across the token dimension
                self.var_lin = nn.Linear(self.token_num, self.token_num)

        if self.temp_linear:
            self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
            self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
            self.activation = F.relu if activation == "relu" else F.gelu
            self.norm2 = nn.LayerNorm(d_model)

        elif self.temp_attn_linear:
            self.ffn1 = newLinear(d_model, d_ff)
            self.ffn2 = newLinear(d_ff, d_model)
            self.activation = F.relu if activation == "relu" else F.gelu
            self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, **kwargs):
        # x.shape: b, l, d_model
        output = x  # to make PyCharm happy

        if self.var_linear_enable:

            if self.linear_r_mode:
                values = self.v_proj(x)
                if self.CovMat is not None:
                    A = F.softmax(self.CovMat, dim=-1) + F.softplus(self.weight_mat)
                else:
                    A = F.softplus(self.weight_mat)
                A = F.normalize(A, p=1, dim=-1)
                A = self.dropout(A)
                new_x = A @ values
                new_x = self.out_proj(new_x)
            else:
                # mix tokens by applying the linear layer on the token axis
                new_x = self.var_lin(x.transpose(-1, -2)).transpose(-1, -2)

            x = x + self.dropout(new_x)
            output = x = self.norm1(x)

        if self.temp_linear:
            y = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
            y = self.dropout(self.conv2(y).transpose(-1, 1))
            output = self.norm2(x + y)

        elif self.temp_attn_linear:
            y = self.dropout(self.activation(self.ffn1(x)))
            y = self.dropout(self.ffn2(y))
            output = self.norm2(x + y)

        return output, None
class LinearEncoder_pre_post_lin(nn.Module):
    """Ablation of LinearEncoder toggling the value projection (pre_lin) and
    the output projection (post_lin) around the learned token-mixing matrix;
    disabled projections become identity maps."""

    def __init__(self, d_model, d_ff=None, dropout=0.1, activation="relu", token_num=None,
                 pre_lin=True, post_lin=True, **kwargs):
        super(LinearEncoder_pre_post_lin, self).__init__()

        d_ff = d_ff or 4 * d_model
        self.d_model = d_model
        self.d_ff = d_ff
        self.token_num = token_num

        self.pre_lin = pre_lin
        self.post_lin = post_lin

        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # real projections or pass-throughs, depending on the ablation flags
        self.v_proj = nn.Linear(d_model, d_model) if self.pre_lin else nn.Identity()
        self.out_proj = nn.Linear(d_model, d_model) if self.post_lin else nn.Identity()

        # identity-plus-noise initialisation, shaped (1, N, N) for broadcasting
        seed = torch.eye(self.token_num) * 1.0 + torch.randn(self.token_num, self.token_num) * 1.0
        self.weight_mat = nn.Parameter(seed[None, :, :])

        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.activation = F.relu if activation == "relu" else F.gelu
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, **kwargs):
        # x: (batch, token_num, d_model)
        values = self.v_proj(x)
        mix = self.dropout(F.normalize(F.softplus(self.weight_mat), p=1, dim=-1))
        mixed = self.out_proj(mix @ values)

        x = self.norm1(x + self.dropout(mixed))

        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))
        return self.norm2(x + hidden), None
class LinearEncoder_abla_design(nn.Module):
    """Design-space ablation of LinearEncoder: the transforms applied to the
    correlation prior (CovMatTrans) and to the learned weights (WeightTrans),
    plus the row normalisation (NormSet), are all configurable; onlyconv
    drops the learned weights and uses the transformed prior alone."""

    def __init__(self, d_model, d_ff=None, CovMat=None, dropout=0.1, activation="relu", token_num=None,
                 CovMatTrans='softmax', WeightTrans='softplus', NormSet='L1', onlyconv=False, **kwargs):
        super(LinearEncoder_abla_design, self).__init__()

        d_ff = d_ff or 4 * d_model
        self.d_model = d_model
        self.d_ff = d_ff
        # broadcast-ready (1, N, N) copy of the optional correlation prior
        self.CovMat = CovMat.unsqueeze(0) if CovMat is not None else None
        self.token_num = token_num

        self.CovMatTrans = CovMatTrans
        self.WeightTrans = WeightTrans
        self.NormSet = NormSet
        self.onlyconv = onlyconv

        self.dropout = nn.Dropout(dropout)

        self.norm1 = nn.LayerNorm(d_model)

        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        assert not (self.CovMatTrans == 'none' and self.WeightTrans == 'none'), \
            'both self.CovMatTrans and self.WeightTrans are None'

        # NOTE(review): cov_mat_trans is a plain tensor computed once at init;
        # it will NOT follow the module on .to(device)/.half() — confirm the
        # caller moves CovMat first. Also the CovMat-is-not-None assert below
        # runs AFTER this block already used self.CovMat — confirm intent.
        self.cov_mat_trans = 0.0
        if self.CovMatTrans != 'none' or self.onlyconv:
            if self.CovMatTrans == 'softmax':
                self.cov_mat_trans = F.softmax(self.CovMat, dim=-1)
            elif self.CovMatTrans == 'softplus':
                self.cov_mat_trans = F.softplus(self.CovMat)
            elif self.CovMatTrans == 'sigmoid':
                self.cov_mat_trans = F.sigmoid(self.CovMat)
            elif self.CovMatTrans == 'relu':
                self.cov_mat_trans = F.relu(self.CovMat)
            elif self.CovMatTrans == 'identity':
                self.cov_mat_trans = self.CovMat
            else:
                raise NotImplementedError

        # Linear design
        if self.CovMatTrans != 'none' or self.onlyconv:
            assert self.CovMat is not None, 'CovMat cannot be None in LinearEncoder.'

        self.weight_mat = None
        if self.WeightTrans != 'none' and not self.onlyconv:
            # identity-plus-noise initialisation, shaped (1, N, N)
            init_weight_mat = torch.eye(self.token_num) * 1.0 + torch.randn(self.token_num, self.token_num) * 1.0
            self.weight_mat = nn.Parameter(init_weight_mat[None, :, :])

        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.activation = F.relu if activation == "relu" else F.gelu
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, **kwargs):
        # x.shape: b, l, d_model

        values = self.v_proj(x)

        # transform the learned weights (skipped entirely under onlyconv)
        weight_mat_trans = 0.0
        if not self.onlyconv and self.WeightTrans != 'none' and self.weight_mat is not None:
            if self.WeightTrans == 'softmax':
                weight_mat_trans = F.softmax(self.weight_mat, dim=-1)
            elif self.WeightTrans == 'softplus':
                weight_mat_trans = F.softplus(self.weight_mat)
            elif self.WeightTrans == 'sigmoid':
                weight_mat_trans = F.sigmoid(self.weight_mat)
            elif self.WeightTrans == 'relu':
                weight_mat_trans = F.relu(self.weight_mat)
            elif self.WeightTrans == 'identity':
                weight_mat_trans = self.weight_mat
            else:
                raise NotImplementedError

        # combined mixing matrix: transformed prior + transformed weights
        A = self.cov_mat_trans + weight_mat_trans

        if self.NormSet != 'none':
            if self.NormSet == 'L1':
                A = F.normalize(A, p=1, dim=-1)
            elif self.NormSet == 'L2':
                A = F.normalize(A, p=2, dim=-1)

        A = self.dropout(A)
        new_x = A @ values
        new_x = self.out_proj(new_x)

        x = x + self.dropout(new_x)
        x = self.norm1(x)

        y = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        output = self.norm2(x + y)

        return output, None
class LinearEncoder_abla_design_nocorr(nn.Module):
    """Ablation of LinearEncoder_abla_design without the correlation prior:
    only the learned mixing matrix is used, with a configurable element-wise
    transform (WeightTrans) and row normalisation (NormSet)."""

    def __init__(self, d_model, d_ff=None, dropout=0.1, activation="relu", token_num=None,
                 WeightTrans='softplus', NormSet='L1', **kwargs):
        super(LinearEncoder_abla_design_nocorr, self).__init__()

        d_ff = d_ff or 4 * d_model
        self.d_model = d_model
        self.d_ff = d_ff
        self.token_num = token_num

        self.WeightTrans = WeightTrans
        self.NormSet = NormSet

        self.dropout = nn.Dropout(dropout)

        self.norm1 = nn.LayerNorm(d_model)

        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Linear design: identity-plus-noise seed, shaped (1, N, N)
        seed_mat = torch.eye(self.token_num) * 1.0 + torch.randn(self.token_num, self.token_num) * 1.0
        self.weight_mat = nn.Parameter(seed_mat[None, :, :])

        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.activation = F.relu if activation == "relu" else F.gelu
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, **kwargs):
        # x: (batch, token_num, d_model)
        values = self.v_proj(x)

        if self.WeightTrans == 'softmax':
            mix = F.softmax(self.weight_mat, dim=-1)
        elif self.WeightTrans == 'softplus':
            mix = F.softplus(self.weight_mat)
        elif self.WeightTrans == 'sigmoid':
            mix = F.sigmoid(self.weight_mat)
        elif self.WeightTrans == 'relu':
            mix = F.relu(self.weight_mat)
        elif self.WeightTrans == 'identity':
            mix = self.weight_mat
        else:
            raise NotImplementedError

        if self.NormSet != 'none':
            # softmax rows already sum to one, so L1 re-normalisation is skipped
            if self.NormSet == 'L1' and self.WeightTrans != 'softmax':
                mix = F.normalize(mix, p=1, dim=-1)
            elif self.NormSet == 'L2':
                mix = F.normalize(mix, p=2, dim=-1)

        mix = self.dropout(mix)
        mixed = self.out_proj(mix @ values)

        x = self.norm1(x + self.dropout(mixed))

        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))
        return self.norm2(x + hidden), None
= A @ values + new_x = self.out_proj(new_x) + + x = x + self.dropout(new_x) + x = self.norm1(x) + + y = self.dropout(self.activation(self.conv1(x.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + output = self.norm2(x + y) + + return output, None + + +class Encoder_ori(nn.Module): + def __init__(self, attn_layers, conv_layers=None, norm_layer=None, one_output=False, CKA_flag=False): + super(Encoder_ori, self).__init__() + self.attn_layers = nn.ModuleList(attn_layers) + self.norm = norm_layer + self.one_output = one_output + self.CKA_flag = CKA_flag + if self.CKA_flag: + print('CKA is enabled...') + + def forward(self, x, attn_mask=None, tau=None, delta=None): + # x [B, nvars, D] + attns = [] + X0 = None # to make Pycharm happy + layer_len = len(self.attn_layers) + for i, attn_layer in enumerate(self.attn_layers): + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) + attns.append(attn) + + if not self.training and self.CKA_flag and layer_len > 1: + if i == 0: + X0 = x + + if i == layer_len - 1 and random.uniform(0, 1) < 1e-1: + CudaCKA1 = CudaCKA(device=x.device) + cka_value = CudaCKA1.linear_CKA(X0.flatten(0, 1)[:1000], x.flatten(0, 1)[:1000]) + print(f'CKA: \t{cka_value:.3f}') + + if isinstance(x, tuple) or isinstance(x, List): + x = x[0] + + if self.norm is not None: + x = self.norm(x) + + if self.one_output: + return x + else: + return x, attns + + +class FlattenHead(nn.Module): + def __init__(self, nf, target_window, head_dropout=0.0): + super().__init__() + self.flatten = nn.Flatten(start_dim=-2) + self.linear = nn.Linear(nf, target_window) + self.dropout = nn.Dropout(head_dropout) + + def forward(self, x): # x: [bs x nvars x d_model x patch_num] + x = self.flatten(x) + x = self.linear(x) + x = self.dropout(x) + return x + + +class refine_module(nn.Module): + def __init__(self, configs): + super().__init__() + self.seq_len = configs.seq_len + self.pred_len = configs.pred_len + self.enc_in = configs.enc_in + self.e_layers 
# (continuation of refine_module.__init__ from the previous chunk line)
        self.e_layers = configs.second_e_layers
        self.d_model = configs.d_model
        self.d_ff = configs.d_ff
        self.dropout = configs.dropout
        self.n_heads = configs.n_heads
        self.activation = configs.activation
        self.patch_len = configs.temp_patch_len2
        self.stride = configs.temp_stride2

        self.encoder = Encoder_ori(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(mask_flag=False, attention_dropout=self.dropout,
                                      output_attention=False, token_num=None, imp_mode=False,
                                      ij_mat_flag=False, num_heads=self.n_heads, plot_mat_flag=False),
                        d_model=self.d_model, n_heads=self.n_heads),
                    d_model=self.d_model,
                    d_ff=self.d_ff,
                    dropout=self.dropout,
                    activation=self.activation
                ) for _ in range(self.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(self.d_model)
        )

        self.revin_layer = RevIN(self.enc_in, affine=True)
        # self.revin_layer2 = RevIN(self.enc_in, affine=True)

        self.patch_embedding = PatchEmbedding(
            d_model=self.d_model, patch_len=self.patch_len, stride=self.stride, padding=self.stride,
            dropout=self.dropout)
        # NOTE(review): patch_embedding2 is constructed but its use in
        # forward() is commented out — confirm whether it can be dropped.
        self.patch_embedding2 = PatchEmbedding(
            d_model=self.d_model, patch_len=self.patch_len, stride=self.stride, padding=self.stride,
            dropout=self.dropout)

        # flatten head
        self.pred_token_num = int((self.pred_len - self.patch_len) / self.stride + 2)
        self.head_nf = self.d_model * self.pred_token_num
        self.head = FlattenHead(nf=self.head_nf, target_window=self.pred_len,
                                head_dropout=configs.dropout)
        # self.dropout_layer = nn.Dropout(configs.dropout)

        print('refine_module is used')

    def forward(self, x, pred):
        # x, pred: [batch, len, vars]
        assert x.ndim == 3 and pred.ndim == 3 and x.shape[-1] == self.enc_in and pred.shape[-1] == self.enc_in

        # Normalise input and prediction jointly so they share statistics.
        xy_concat = torch.concat([x, pred], dim=1)
        # normalize xy_concat
        xy_concat = self.revin_layer(xy_concat, mode='norm')
        pred = xy_concat[:, -self.pred_len:, :]

        # use another revin_layer
        # pred = self.revin_layer2(pred, mode='norm')

        # patch embedding: return [b*n, token_num, d_model]
        xy_embed, _ = self.patch_embedding(xy_concat.transpose(-1, -2))
        pred = xy_embed[:, -self.pred_token_num:, :]
        # pred, _ = self.patch_embedding2(pred.transpose(-1, -2))

        # encoder [b*n, token_num, d_model]
        # (a (query, context) tuple is passed as `x`; the attention layers are
        #  expected to unpack it)
        pred_refine_feat, _ = self.encoder(x=(pred, xy_embed))
        pred_refine_feat = torch.reshape(
            pred_refine_feat, (-1, self.enc_in, pred_refine_feat.shape[-2], pred_refine_feat.shape[-1]))
        pred_refine_feat = pred_refine_feat.transpose(-1, -2)

        # Decoder
        pred_refine = self.head(pred_refine_feat)  # z: [bs x nvars x pred_len]
        pred_refine = pred_refine.transpose(-1, -2)

        pred_refine = self.revin_layer(pred_refine, mode='denorm')

        return pred_refine


class fix_mask_with_neighbor(nn.Module):
    """Fill masked positions with a learned blend of a per-channel
    (depthwise) 1-D convolution of the series and the series itself."""

    def __init__(self, enc_in, kernel_size=3):
        super().__init__()
        self.enc_in = enc_in
        self.kernel_size = kernel_size
        # Blend factor is sigmoid(alpha); -3.0 starts close to "keep x".
        self.alpha = nn.Parameter(torch.tensor(-3.0))
        # different channels do not mix
        self.conv1d = nn.Conv1d(self.enc_in, self.enc_in, groups=self.enc_in, kernel_size=self.kernel_size,
                                padding='same', padding_mode='zeros', bias=True)  # padding_mode: zeros circular

    def forward(self, x, mask=None):
        if mask is None:
            return x
        B, N, D = x.shape
        # Accept [B, L, N] input as well; convert to channel-first.
        if D == self.enc_in:
            x = x.transpose(-1, -2)
            B, N, D = x.shape
        assert N == self.enc_in, 'N!=self.enc_in in fix_mask_with_neighbor class...'
# (continuation of fix_mask_with_neighbor.forward from the previous chunk line)
        alpha = F.sigmoid(self.alpha)
        x2 = alpha * self.conv1d(x) + (1 - alpha) * x
        if mask.shape[1] != N:
            mask = mask.transpose(-1, -2)
        # Only masked (True) positions are replaced by the blended value.
        x = torch.where(mask, x2, x)
        return x


def swin_output_update2(dec_seq_i, revin_layer, mask, x_ori, dec_seq_inter, dec_seq_inter2):
    # dec_seq_i: [b, l, n]
    # update dec_seq_i, revin_layer, dec_seq_inter, dec_seq_inter2 when Swin_output is enabled
    # NOTE(review): `dec_seq_i` is appended to `dec_seq_inter` and then
    # mutated in place by the calibration write below, so the stored entry
    # changes too — confirm this aliasing is intended.
    if dec_seq_i is not None:
        if revin_layer is not None:
            dec_seq_i = revin_layer(dec_seq_i, mode='denorm')
        dec_seq_inter.append(dec_seq_i)

        dec_seq_i[~mask] = x_ori[~mask]  # calibration
        dec_seq_inter2.append(dec_seq_i)
        if revin_layer is not None:
            dec_seq_i = revin_layer(dec_seq_i, mode='norm', mask=None)  # new revin_layer
    return dec_seq_i


class Encoder(nn.Module):
    # the general class
    """General encoder combining (optionally) temporal patch attention,
    channel attention, and Swin-style refinement stages; also supports an
    imputation mode that emits intermediate sequences."""

    def __init__(self, attn_layers, temp_attn_layers=None, temp_attn_params=None, conv_layers=None, norm_layer=None,
                 hierarchyFlag=False, patch_ln=False, imp_mode=False, Swin_after_patch=0, Swin_after_iTrans=0,
                 swin_layers=None, swin_layers2=None, Patch_CI=True, neighbor_fix=False, swin_first=False):
        # , temp_token_num=None, channel_token_num=None
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer
        self.temp_attn_layers = nn.ModuleList(temp_attn_layers) if temp_attn_layers is not None else None
        self.temp_attn_params = temp_attn_params
        self.imp_mode = imp_mode
        # Hierarchical token merging is disabled in imputation mode.
        self.hierarchyFlag = False if self.imp_mode else hierarchyFlag

        self.Swin_after_patch = Swin_after_patch
        self.Swin_after_iTrans = Swin_after_iTrans

        self.seq_len = temp_attn_params.seq_len
        self.d_model = temp_attn_params.d_model
        self.temp_token_num = temp_attn_params.temp_token_num
        self.enc_in = temp_attn_params.enc_in

        self.swin_layers = swin_layers
        self.swin_layers2 = swin_layers2

        self.Patch_CI = Patch_CI
        self.neighbor_fix = neighbor_fix
        self.swin_first = swin_first

        self.act_layer = nn.GELU()
        # self.act_layer = nn.ReLU()

        print('Encoder is initialized...')

        self.projector3 = nn.Linear(self.seq_len, self.d_model, bias=True) \
            if self.Swin_after_patch else nn.Identity()

        if self.imp_mode and self.Swin_after_iTrans and self.swin_layers2 is not None:
            self.projector3_last = nn.Linear(self.d_model, self.seq_len, bias=True)
            self.norm_post_proj3_last = nn.LayerNorm(self.d_model)
            self.projector4_last = nn.Linear(self.seq_len, self.d_model, bias=True)

        self.dropout = nn.Dropout(p=0.1)
        if temp_attn_layers is not None:
            self.len1 = len1 = len(temp_attn_layers)
            print('Temporal attention is initialized in Encoder...')
        self.patch_ln = patch_ln  # or len(attn_layers) == 0
        if self.patch_ln:
            print('LayerNorm in PatchTST is used...')

        if self.neighbor_fix:
            print('Class fix_mask_with_neighbor is used...')
            self.fix_layer = fix_mask_with_neighbor(self.enc_in)

        if self.Patch_CI:
            # Channel-independent patching: each channel is patched alone.
            self.projector1 = nn.Linear(self.temp_attn_params.temp_patch_len, self.d_model)

            # patch -- swin
            self.projector2 = nn.Linear(self.d_model * self.temp_token_num,
                                        self.seq_len if self.Swin_after_patch
                                        else self.d_model)
        else:
            self.projector1 = nn.Linear(self.temp_attn_params.temp_patch_len * self.enc_in, self.d_model)
            # [b, n'*d_model] --> [b,n*len]
            self.projector2 = nn.Linear(self.d_model * self.temp_token_num,
                                        self.seq_len * self.enc_in if self.Swin_after_patch
                                        else self.d_model * self.enc_in)

        if self.imp_mode and not self.Swin_after_patch and len(self.attn_layers) == 0:
            self.proj2_d2l = nn.Linear(self.d_model, self.seq_len)  # for check

        # for shortcut
        # if self.Swin_after_patch and self.len1 >= 1 and self.swin_first:
        #     self.proj2_l2l_patch2seq = nn.Linear(self.seq_len, self.seq_len)
        #     self.alpha = nn.Parameter(torch.tensor(-3.0))

        # Learnable positional encoding for the patch tokens.
        pe_mat = torch.zeros(1, self.temp_token_num, self.d_model)
        self.positional_encoding = nn.Parameter(pe_mat)

        # (the `self.norm0 = nn.LayerNorm(...)` assignment is split across the
        #  collapsed patch boundary and is carried in full on the next chunk line)
# (continuation of Encoder.__init__ from the previous chunk line)
        self.norm0 = nn.LayerNorm(self.d_model)
        self.norm1 = nn.LayerNorm(self.d_model)

        # hierarchy mode
        # NOTE(review): `len1` is only bound when temp_attn_layers is not None;
        # and the `else:` branch re-binds self.projector2 (d_model -> seq_len),
        # overwriting the Patch_CI-dependent definition above — confirm intended.
        if self.hierarchyFlag and len1 > 1:
            self.proj = nn.ModuleList([nn.Linear(self.d_model * 2, self.d_model,
                                                 bias=True) for _ in range(len1 - 1)] +
                                      [nn.Linear(self.d_model * len1,
                                                 self.d_model,
                                                 bias=True)])
            self.norm2 = nn.ModuleList([nn.LayerNorm(self.d_model) for _ in range(len1)])
        else:
            self.projector2 = nn.Linear(self.d_model, self.seq_len, bias=True)

        if len(self.attn_layers):
            self.norm_for_i = nn.LayerNorm(self.d_model)

        # Pi mode: blend patch-attention output with a projected shortcut
        # before/after the channel attention stack.
        self.Pi_flag = len(self.attn_layers) > 0 and self.temp_attn_layers is not None

        if self.Pi_flag:
            self.projector_shortcut = nn.Linear(self.seq_len, self.d_model, bias=True)
            self.pi_weight = nn.Parameter(torch.zeros(2))
            self.pi_weight2 = nn.Parameter(torch.zeros(2))
            self.tau = nn.Parameter(torch.ones(1) * -5)
            self.tau2 = nn.Parameter(torch.ones(1) * -5)

    def forward(self, x, attn_mask=None, tau=None, delta=None, temp_token_weight=None, ch_token_weight=None,
                revin_layer=None, mask=None, x_ori=None, dec_seq_inter=None, dec_seq_inter2=None, swin_output=0):
        # x [B, N, D] L means channel here
        B, N, D = x.shape
        input_ori = x
        attns = []
        Swin_inter_output = x.transpose(-1, -2)
        if self.conv_layers is not None:
            # not used in this project
            for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
                # delta only work at the first layer
                delta = delta if i == 0 else None
                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
            attns.append(attn)
        else:
            x0 = None
            # patch tst
            if self.temp_attn_layers is not None:
                if self.neighbor_fix:
                    x = self.fix_layer(x, mask)

                if swin_output:
                    Swin_inter_output = swin_output_update2(x.transpose(-1, -2), revin_layer, mask, x_ori,
                                                            dec_seq_inter, dec_seq_inter2)

                if self.Pi_flag:
                    x0 = x

                if self.Patch_CI:
                    # [b,n,d] --> [b,n,n',p] --> [b',n',p]
                    assert D > self.temp_attn_params.temp_patch_len
                    # Pad so the last patch is complete before unfolding.
                    rem = (D - self.temp_attn_params.temp_patch_len) % self.temp_attn_params.temp_stride
                    if rem != 0:
                        x = F.pad(x, pad=[0, self.temp_attn_params.temp_stride - rem])
                    x = x.unfold(dimension=-1, size=self.temp_attn_params.temp_patch_len,
                                 step=self.temp_attn_params.temp_stride).flatten(start_dim=0, end_dim=1)
                else:
                    # [b,n,d] --> [b,n,n',p] --> [b,n',n, p] --> [b,n',n*p]
                    x = x.unfold(dimension=-1, size=self.temp_attn_params.temp_patch_len,
                                 step=self.temp_attn_params.temp_stride).transpose(1, 2).flatten(start_dim=-2)

                # [b',n',p] --> [b',n',d_model]
                # self.projector1 has been modified according to self.Patch_CI
                x = self.projector1(x) + self.positional_encoding
                x = self.norm0(x)

                x_list = []
                # x0 = x
                B2, L2, D2 = x.shape

                for i, temp_attn_layer in enumerate(self.temp_attn_layers):
                    if temp_token_weight is not None:
                        assert x.shape[1] == temp_token_weight.shape[1], 'Please check temp_token_weight.'

                    x, _ = temp_attn_layer(x, token_weight=temp_token_weight)
                    if i == 0 and self.hierarchyFlag and self.len1 > 1:
                        x_list.append(x)
                    if self.hierarchyFlag and x.shape[1] < L2:
                        # Upsample coarse tokens back to the original token
                        # count so all levels can be summed later.
                        # x2 = x.repeat_interleave(2 ** i, dim=1)
                        x2 = x.repeat_interleave(math.ceil(L2 / x.shape[1]), dim=1)
                        x2 = x2[:, :L2, :]
                        x_list.append(x2)

                    if self.hierarchyFlag and self.len1 > 1 and x.shape[1] > 2 and i < self.len1 - 1:
                        # 240905; token num should > 2
                        # token num --> 1/2
                        B3, L3, D3 = x.shape
                        # print(f'x.shape: {x.shape}')
                        if L3 % 2 == 1:
                            x = torch.cat([x, x[:, [-1], :]], dim=1)
                        x2 = x.reshape(B3, -1, 2, D3).flatten(start_dim=-2)
                        # print(f'x2.shape: {x2.shape}')
                        # print(f'self.proj[i]: {self.proj[i]}')
                        x2 = self.proj[i](x2)
                        x = self.act_layer(x2)
                        # x = self.norm2[i](x2)

                        # temp_token_weight; for imputation
                        temp_token_weight = hier_half_token_weight(temp_token_weight)
                        # print('Token_weight num reduced to ', temp_token_weight.shape[1], '. ')
                # output
                if self.hierarchyFlag and self.len1 > 1:
                    # cat and project
                    # x2 = torch.cat(x_list, dim=-1)
                    # x = self.norm2[-1](self.proj[-1](x2))
                    x = torch.sum(torch.stack(x_list), dim=0)
                    # x = self.norm2[-1](x)
                    x = self.act_layer(x)

            # temp_attn_layers could be empty; Swin can still be used without temp_attn_layers
            if self.imp_mode and self.Swin_after_patch and self.swin_layers is not None:
                if self.temp_attn_layers is not None:
                    # self.Patch_CI: [b',n',d_model] --> [b',n'*d_model] --> [b',seq_len] --> [b,n,l]
                    # not self.Patch_CI: [b,n',d_model] --> [b,n'*d_model] --> [b,n*seq_len] --> [b,n,l]
                    x = self.projector2(x.flatten(start_dim=-2)).view(B, N, -1)

                    # shortcut
                    # if self.swin_first:
                    #     alpha = F.sigmoid(self.alpha)
                    #     x = self.proj2_l2l_patch2seq(x) + alpha * input_ori
                else:
                    # [b,n,d] --> [b,n,l]
                    x = self.projector2(x)

                # seq_len
                # Swin_inter_output has to be [b,l,n]
                Swin_inter_output = x.transpose(-1, -2)

                # revin layer
                if swin_output:
                    Swin_inter_output = swin_output_update2(Swin_inter_output, revin_layer, mask, x_ori,
                                                            dec_seq_inter, dec_seq_inter2)
                if isinstance(self.swin_layers, nn.ModuleList):
                    for layer in self.swin_layers:
                        Swin_inter_output, _ = layer((Swin_inter_output, mask))
                else:
                    Swin_inter_output, _ = self.swin_layers((Swin_inter_output, mask))

                # revin layer
                if swin_output:
                    Swin_inter_output = swin_output_update2(Swin_inter_output, revin_layer, mask, x_ori,
                                                            dec_seq_inter, dec_seq_inter2)

                x = Swin_inter_output.transpose(-1, -2)

                if len(self.attn_layers):
                    # [b,n,l] --> [b,n,dim]
                    x = self.projector3(x)
                    x = self.norm_for_i(x)
            else:
                if self.temp_attn_layers is not None:

                    # if self.Patch_CI: [b',n',d_model] --> [b',n'*d_model] --> [b',d_model] --> [b,n,dim]
                    # else: [b, n',d_model] --> [b,n'*d_model] --> [b,n*dim] --> [b,n,dim]
                    # x = x + x0
                    x = self.projector2(x.flatten(start_dim=-2)).view(B, N, -1)

                    if self.patch_ln:
                        x = self.norm1(x)
                    else:
                        # (the `x = self.act_layer(x)` statement is split across
                        #  the collapsed patch boundary and is carried in full
                        #  on the next chunk line)
self.act_layer(x) + + if len(self.attn_layers) == 0 and self.imp_mode: + # also check this intermediate result; 240702 + # [b, n, dim] --> [b, n, l] + # print('Check this....') + Swin_inter_output = self.proj2_d2l(x) + Swin_inter_output.transpose(-1, -2) + Swin_inter_output = Swin_inter_output.transpose(-1, -2) + # revin layer + if swin_output: + Swin_inter_output = swin_output_update2(Swin_inter_output, revin_layer, mask, x_ori, + dec_seq_inter, dec_seq_inter2) + + if self.Pi_flag: + # x0 = x + + pi_weight = F.softmax(self.pi_weight / F.softplus(self.tau)) + x = x0 = x * pi_weight[0] + self.projector_shortcut(x0) * pi_weight[1] + # if not self.training and random.uniform(0, 1) < 1e-3: + # print(f'pi_weight: {pi_weight}') + + for attn_layer in self.attn_layers: + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta, token_weight=ch_token_weight) + # attns.append(attn) + + # 240925 + if self.Pi_flag: + pi_weight2 = F.softmax(self.pi_weight2 / F.softplus(self.tau2)) + x = x * pi_weight2[0] + x0 * pi_weight2[1] + + # pi_weight2 = F.softplus(self.pi_weight2) + # x = x + x0 * pi_weight2[1] + + # x = x0 + + # if not self.training and random.uniform(0, 1) < 1e-3: + # print(f'pi_weight2: {pi_weight2}') + + if self.imp_mode and self.Swin_after_iTrans and self.swin_layers2 is not None: + + # [b,n,dim] --> [b,n,l] --> [b,l,n] + Swin_inter_output = self.projector3_last(self.norm_post_proj3_last(x)).transpose(-1, -2) + + # revin layer + if swin_output: + Swin_inter_output = swin_output_update2(Swin_inter_output, revin_layer, mask, x_ori, + dec_seq_inter, dec_seq_inter2) + + # [b,n,l] --> [b,l,n] --> [b,n,l] + if isinstance(self.swin_layers2, nn.ModuleList): + for layer in self.swin_layers2: + Swin_inter_output, _ = layer((Swin_inter_output, mask)) + else: + Swin_inter_output, _ = self.swin_layers2((Swin_inter_output, mask)) + + # [b,l,n] --> [b,n,l] --> [b,n,dim] + # x = self.projector4_last(Swin_inter_output.transpose(-1, -2)) + + # keep it is + x = 
Swin_inter_output.transpose(-1, -2) + + if self.norm is not None and x.shape[-1] == self.d_model: + x = self.norm(x) + + return x, attns, Swin_inter_output.transpose(-1, -2) # Swin_inter_output not updated for Swin_after_iTrans + + +class DecoderLayer(nn.Module): + def __init__(self, self_attention, cross_attention, d_model, d_ff=None, + dropout=0.1, activation="relu"): + super(DecoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.self_attention = self_attention + self.cross_attention = cross_attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): + x = x + self.dropout(self.self_attention( + x, x, x, + attn_mask=x_mask, + tau=tau, delta=None + )[0]) + x = self.norm1(x) + + x = x + self.dropout(self.cross_attention( + x, cross, cross, + attn_mask=cross_mask, + tau=tau, delta=delta + )[0]) + + y = x = self.norm2(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm3(x + y) + + +class Decoder(nn.Module): + def __init__(self, layers, norm_layer=None, projection=None): + super(Decoder, self).__init__() + self.layers = nn.ModuleList(layers) + self.norm = norm_layer + self.projection = projection + + def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): + for layer in self.layers: + x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta) + + if self.norm is not None: + x = self.norm(x) + + if self.projection is not None: + x = self.projection(x) + return x diff --git a/ts_benchmark/baselines/olinear/layers/__init__.py 
# ---- patch metadata (collapsed): new empty file layers/__init__.py;
# ---- new file ts_benchmark/baselines/olinear/layers/newLinear.py begins ----
import torch
import torch.nn as nn
import torch.nn.functional as F


class newLinear(nn.Module):
    """Linear map over the last dimension whose effective weight matrix is
    constrained to be non-negative with unit L1 row sums.

    The raw parameter is passed through softplus (non-negativity) and then
    L1-normalised per output row, so each output is a convex-like weighted
    average of the inputs. An optional additive bias is supported.
    """

    def __init__(self, input_dim, output_dim, bias=False):
        super(newLinear, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias = bias

        # Unconstrained raw weights; the constraint is applied in forward().
        self.weight_mat = nn.Parameter(torch.randn(self.output_dim, self.input_dim))
        if self.bias:
            self.bias_weight = nn.Parameter(torch.zeros(1, self.output_dim))

    def forward(self, x):
        """Apply the constrained linear map; leading dimensions of ``x`` are
        preserved, the last dimension becomes ``output_dim``."""
        lead_shape = x.shape[:-1]
        assert x.shape[-1] == self.input_dim

        # (output_dim, input_dim): non-negative, each row summing to 1.
        weight = F.normalize(F.softplus(self.weight_mat), p=1, dim=-1)

        flat = x.reshape(-1, self.input_dim)
        out = flat @ weight.transpose(-1, -2)
        if self.bias:
            out = out + self.bias_weight

        return out.reshape(lead_shape + (self.output_dim,)).contiguous()


# ---- patch metadata: new file models/__init__.py; its two lines are kept as
# ---- comments here because the project-local import cannot resolve in an
# ---- isolated module (NOTE(review): restore when files are split again):
#   __all__ = ["Model"]
#   from ts_benchmark.baselines.olinear.models.olinear_model import Model
# ---- patch metadata: new file models/olinear_model.py begins ----
import os
# (the `import numpy as np` statement is split across the collapsed patch
#  boundary and is carried in full on the next chunk line)
import numpy as np  # completes the import split across the collapsed patch line
from ..layers.RevIN import RevIN
from ..layers.Transformer_EncDec import Encoder_ori, EncoderLayer, LinearEncoder, LinearEncoder_Multihead
from ..layers.SelfAttention_Family import AttentionLayer, EnhancedAttention

import sys


class Model(nn.Module):
    """OLinear forecasting model: RevIN normalisation, token embedding,
    an orthogonal-transform ("Q matrix") domain transformer, and a linear
    head mapping to the prediction window."""

    def __init__(self, configs):
        super(Model, self).__init__()
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in  # channels
        self.seq_len = configs.seq_len
        self.hidden_size = self.d_model = configs.d_model  # hidden_size
        self.d_ff = configs.d_ff  # d_ff

        self.Q_chan_indep = configs.Q_chan_indep

        # NOTE(review): device is fixed at construction time (cuda:0 if
        # available) and Q_mat/Q_out_mat are plain tensor attributes, not
        # registered buffers — they will neither appear in state_dict nor move
        # with Model.to(device). Confirm this is intended.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        q_mat_path = self._resolve_path(configs.Q_MAT_file if self.Q_chan_indep else configs.q_mat_file, configs.root_path)
        q_out_mat_path = self._resolve_path(configs.Q_OUT_MAT_file if self.Q_chan_indep else configs.q_out_mat_file, configs.root_path)

        self.Q_mat = self._load_q_matrix(
            matrix_path=q_mat_path,
            is_channel_indep=self.Q_chan_indep,
            expected_first_dim=(self.enc_in if self.Q_chan_indep else self.seq_len),
            square_size=self.seq_len,
            matrix_name="Q",
            device=device,
        )

        self.Q_out_mat = self._load_q_matrix(
            matrix_path=q_out_mat_path,
            is_channel_indep=self.Q_chan_indep,
            expected_first_dim=(self.enc_in if self.Q_chan_indep else self.pred_len),
            square_size=self.pred_len,
            matrix_name="Q_out",
            device=device,
        )

        self.patch_len = configs.temp_patch_len
        self.stride = configs.temp_stride

        # self.channel_independence = configs.channel_independence
        self.embed_size = configs.embed_size  # embed_size
        self.embeddings = nn.Parameter(torch.randn(1, self.embed_size))

        # Head: [pred_len * embed_size] -> d_ff -> pred_len.
        self.fc = nn.Sequential(
            nn.Linear(self.pred_len * self.embed_size, self.d_ff),
            nn.GELU(),
            nn.Linear(self.d_ff, self.pred_len)
        )

        # for final input and output
        self.revin_layer = RevIN(self.enc_in, affine=True)
        self.dropout = nn.Dropout(configs.dropout)

        # ############# transformer related #########
        self.encoder = Encoder_ori(
            [
                LinearEncoder(
                    d_model=configs.d_model, d_ff=configs.d_ff, CovMat=None,
                    dropout=configs.dropout, activation=configs.activation, token_num=self.enc_in,
                ) for _ in range(configs.e_layers)
            ],
            norm_layer=nn.LayerNorm(configs.d_model),
            one_output=True,
            CKA_flag=configs.CKA_flag
        )
        # seq domain -> d_model -> encoder -> pred domain, all on the
        # flattened (T*D) representation.
        self.ortho_trans = nn.Sequential(
            nn.Linear(self.seq_len * self.embed_size, self.d_model),
            self.encoder,
            nn.Linear(self.d_model, self.pred_len * self.embed_size)
        )

        # learnable delta
        self.delta1 = nn.Parameter(torch.zeros(1, self.enc_in, 1, self.seq_len))
        self.delta2 = nn.Parameter(torch.zeros(1, self.enc_in, 1, self.pred_len))

    @staticmethod
    def _resolve_path(file_path, root_path):
        """Return `file_path` if absolute or existing as-is; otherwise resolve
        it relative to `root_path`."""
        if os.path.isabs(file_path):
            return file_path
        if os.path.isfile(file_path):
            return file_path
        return os.path.abspath(os.path.join(root_path, file_path))

    @staticmethod
    def _load_q_matrix(
        matrix_path,
        is_channel_indep,
        expected_first_dim,
        square_size,
        matrix_name,
        device,
    ):
        """Load an orthogonal-transform matrix from .npy, falling back to
        identity matrices when the file is missing; validate its shape.

        Raises ValueError on ndim / first-dim / last-dim mismatch.
        """
        if os.path.isfile(matrix_path):
            matrix = np.load(matrix_path)
        else:
            # Keep training runnable without external Q files by using identity transforms.
            if is_channel_indep:
                matrix = np.stack([np.eye(square_size, dtype=np.float32) for _ in range(expected_first_dim)], axis=0)
            else:
                matrix = np.eye(square_size, dtype=np.float32)

        tensor = torch.from_numpy(matrix).to(torch.float32).to(device)

        expected_ndim = 3 if is_channel_indep else 2
        if tensor.ndim != expected_ndim:
            raise ValueError(
                f"{matrix_name} matrix ndim mismatch: expected {expected_ndim}, got {tensor.ndim}. "
                f"Path: {matrix_path}"
            )

        if tensor.shape[0] != expected_first_dim:
            raise ValueError(
                f"{matrix_name} matrix first dim mismatch: expected {expected_first_dim}, got {tensor.shape[0]}. "
                f"Path: {matrix_path}"
            )

        if tensor.shape[-1] != square_size:
            raise ValueError(
                f"{matrix_name} matrix last dim mismatch: expected {square_size}, got {tensor.shape[-1]}. "
                f"Path: {matrix_path}"
            )

        return tensor

    # dimension extension
    def tokenEmb(self, x, embeddings):
        """Expand [B, T, N] to [B, N, T, D] by broadcasting against the
        learnable embedding vector (D = embed_size; D<=1 adds a unit dim)."""
        if self.embed_size <= 1:
            return x.transpose(-1, -2).unsqueeze(-1)
        # x: [B, T, N] --> [B, N, T]
        x = x.transpose(-1, -2)
        x = x.unsqueeze(-1)
        # B*N*T*1 x 1*D = B*N*T*D
        return x * embeddings

    def Fre_Trans(self, x):
        """Orthogonal-domain transform: project to the Q basis, run the
        transformer on the flattened representation, project back via Q_out."""
        # [B, N, T, D]
        B, N, T, D = x.shape
        assert T == self.seq_len
        # [B, N, D, T]
        x = x.transpose(-1, -2)

        # orthogonal transformation
        # [B, N, D, T]
        if self.Q_chan_indep:
            x_trans = torch.einsum('bndt,ntv->bndv', x, self.Q_mat.transpose(-1, -2))
        else:
            x_trans = torch.einsum('bndt,tv->bndv', x, self.Q_mat.transpose(-1, -2)) + self.delta1
        # added on 25/1/30
        # x_trans = F.gelu(x_trans)
        # [B, N, D, T]
        assert x_trans.shape[-1] == self.seq_len

        # ########## transformer ####
        x_trans = self.ortho_trans(x_trans.flatten(-2)).reshape(B, N, D, self.pred_len)

        # [B, N, D, tau]; orthogonal transformation
        if self.Q_chan_indep:
            x = torch.einsum('bndt,ntv->bndv', x_trans, self.Q_out_mat)
        else:
            x = torch.einsum('bndt,tv->bndv', x_trans, self.Q_out_mat) + self.delta2
        # added on 25/1/30
        # x = F.gelu(x)

        # [B, N, tau, D]
        x = x.transpose(-1, -2)
        return x

    def forward(self, x, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        """Forecast: x [Batch, Input length, Channel] -> [Batch, pred_len, Channel]."""
        # x: [Batch, Input length, Channel]
        B, T, N = x.shape

        # revin norm
        x = self.revin_layer(x, mode='norm')
        x_ori = x

        # ########### frequency (high-level) part ##########
        # input fre fine-tuning
        # [B, T, N]
        # embedding x: [B, N, T, D]
        x = self.tokenEmb(x_ori, self.embeddings)
        # [B, N, tau, D]
        x = self.Fre_Trans(x)

        # linear
        # [B, N, tau*D] --> [B, N, dim] --> [B, N, tau] --> [B, tau, N]
        out = self.fc(x.flatten(-2)).transpose(-1, -2)

        # dropout
        out = self.dropout(out)

        # revin denorm
        out = self.revin_layer(out, mode='denorm')

        return out

# ---- patch metadata: new file ts_benchmark/baselines/olinear/olinear.py ----
# -*- coding: utf-8 -*-
import os

import torch.nn as nn
from torch import optim

from ts_benchmark.baselines.deep_forecasting_model_base import DeepForecastingModelBase

from .models.olinear_model import Model as OLinearModel

# Default hyper-parameters; user kwargs passed to OLinear override these.
MODEL_HYPER_PARAMS = {
    "d_model": 256,
    "d_ff": 512,
    "e_layers": 2,
    "dropout": 0.1,
    "embed_size": 1,
    "temp_patch_len": 1,
    "temp_stride": 1,
    "activation": "gelu",
    "CKA_flag": False,
    "Q_chan_indep": False,
    "q_mat_file": "dataset/ILI_Q.npy",
    "q_out_mat_file": "dataset/ILI_Q_out.npy",
    "root_path": "./",
    "lr": 0.001,
}


class OLinear(DeepForecastingModelBase):
    """Benchmark adapter wiring the OLinear Model into the
    DeepForecastingModelBase training/evaluation loop."""

    def __init__(self, **kwargs):
        # Default root_path to the repository root (three levels up) so the
        # relative Q-matrix paths resolve when the caller does not set it.
        if "root_path" not in kwargs:
            kwargs["root_path"] = os.path.abspath(
                os.path.join(os.path.dirname(__file__), "..", "..", "..")
            )
        super(OLinear, self).__init__(MODEL_HYPER_PARAMS, **kwargs)

    @property
    def model_name(self):
        return "OLinear"

    def _init_criterion_and_optimizer(self):
        criterion = nn.MSELoss()
        optimizer = optim.Adam(self.model.parameters(), lr=self.config.lr)
        return criterion, optimizer

    def _init_model(self):
        return OLinearModel(self.config)

    def _process(self, input, target, input_mark, target_mark):
        outputs = self.model(input)
        return {"output": outputs}

# ---- patch metadata: new file ts_benchmark/baselines/olinear/utils/CKA.py ----
# inspired by
# https://github.com/yuanli2333/CKA-Centered-Kernel-Alignment/blob/master/CKA.py

import math
# (the `import torch` statement is split across the collapsed patch boundary
#  and is carried in full on the next chunk line)
import torch  # completes the import split across the collapsed patch line
import numpy as np


class CKA(object):
    """NumPy implementation of Centered Kernel Alignment (CKA), both linear
    and RBF-kernel variants, via the HSIC formulation."""

    def __init__(self):
        pass

    def centering(self, K):
        """Double-center the Gram matrix K: H @ K @ H with H = I - 11^T/n."""
        n = K.shape[0]
        unit = np.ones([n, n])
        I = np.eye(n)
        H = I - unit / n
        return np.dot(np.dot(H, K), H)

    def rbf(self, X, sigma=None):
        """RBF Gram matrix; sigma defaults to sqrt(median pairwise distance)."""
        GX = np.dot(X, X.T)
        KX = np.diag(GX) - GX + (np.diag(GX) - GX).T
        if sigma is None:
            mdist = np.median(KX[KX != 0])
            sigma = math.sqrt(mdist)
        KX *= - 0.5 / (sigma * sigma)
        KX = np.exp(KX)
        return KX

    def kernel_HSIC(self, X, Y, sigma):
        return np.sum(self.centering(self.rbf(X, sigma)) * self.centering(self.rbf(Y, sigma)))

    def linear_HSIC(self, X, Y):
        L_X = X @ X.T
        L_Y = Y @ Y.T
        return np.sum(self.centering(L_X) * self.centering(L_Y))

    def linear_CKA(self, X, Y):
        """Linear CKA in [0, 1]; equals 1 when X == Y."""
        hsic = self.linear_HSIC(X, Y)
        var1 = np.sqrt(self.linear_HSIC(X, X))
        var2 = np.sqrt(self.linear_HSIC(Y, Y))

        return hsic / (var1 * var2)

    def kernel_CKA(self, X, Y, sigma=None):
        hsic = self.kernel_HSIC(X, Y, sigma)
        var1 = np.sqrt(self.kernel_HSIC(X, X, sigma))
        var2 = np.sqrt(self.kernel_HSIC(Y, Y, sigma))

        return hsic / (var1 * var2)


class CudaCKA(object):
    """Torch implementation of CKA that runs on the given device.

    NOTE: helper tensors (ones/eye) are created in the default dtype, so
    inputs are expected to be float32 on `device`.
    """

    def __init__(self, device):
        self.device = device

    def centering(self, K):
        n = K.shape[0]
        unit = torch.ones([n, n], device=self.device)
        I = torch.eye(n, device=self.device)
        H = I - unit / n
        return torch.matmul(torch.matmul(H, K), H)

    def rbf(self, X, sigma=None):
        GX = torch.matmul(X, X.T)
        KX = torch.diag(GX) - GX + (torch.diag(GX) - GX).T
        if sigma is None:
            mdist = torch.median(KX[KX != 0])
            sigma = math.sqrt(mdist)
        KX *= - 0.5 / (sigma * sigma)
        KX = torch.exp(KX)
        return KX

    def kernel_HSIC(self, X, Y, sigma):
        return torch.sum(self.centering(self.rbf(X, sigma)) * self.centering(self.rbf(Y, sigma)))

    def linear_HSIC(self, X, Y):
        L_X = torch.matmul(X, X.T)
        L_Y = torch.matmul(Y, Y.T)
        return torch.sum(self.centering(L_X) * self.centering(L_Y))

    def linear_CKA(self, X, Y):
        hsic = self.linear_HSIC(X, Y)
        var1 = torch.sqrt(self.linear_HSIC(X, X))
        var2 = torch.sqrt(self.linear_HSIC(Y, Y))

        return hsic / (var1 * var2)

    def kernel_CKA(self, X, Y, sigma=None):
        hsic = self.kernel_HSIC(X, Y, sigma)
        var1 = torch.sqrt(self.kernel_HSIC(X, X, sigma))
        var2 = torch.sqrt(self.kernel_HSIC(Y, Y, sigma))
        return hsic / (var1 * var2)

# ---- patch metadata: new empty file utils/__init__.py;
# ---- new file ts_benchmark/baselines/olinear/utils/losses.py begins ----
# This source code is provided for the purposes of scientific reproducibility
# under the following limited license from Element AI Inc. The code is an
# implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis
# expansion analysis for interpretable time series forecasting,
# https://arxiv.org/abs/1905.10437). The copyright to the source code is
# licensed under the Creative Commons - Attribution-NonCommercial 4.0
# International license (CC BY-NC 4.0):
# https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether
# for the benefit of third parties or internally in production) requires an
# explicit license. The subject-matter of the N-BEATS model and associated
# materials are the property of Element AI Inc. and may be subject to patent
# protection. No license to patents is granted hereunder (whether express or
# implied). Copyright 2020 Element AI Inc. All rights reserved.

"""
Loss functions for PyTorch.
"""

import torch as t
import torch.nn as nn
import numpy as np
import pdb


def divide_no_nan(a, b):
    """
    a/b where the resulted NaN or Inf are replaced by 0.
    """
    result = a / b
    result[result != result] = .0
    result[result == np.inf] = .0
    # Bug fix: a negative numerator over zero yields -inf, which the original
    # code left in place despite the docstring promising "Inf replaced by 0";
    # zero it out as well.
    result[result == -np.inf] = .0
    return result
+ """ + result = a / b + result[result != result] = .0 + result[result == np.inf] = .0 + return result + + +class mape_loss(nn.Module): + def __init__(self, reduction: bool = True): + super(mape_loss, self).__init__() + self.reduction = reduction + + def forward(self, insample: t.Tensor, freq: int, + forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float: + """ + MAPE loss as defined in: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error + + :param forecast: Forecast values. Shape: batch, time + :param target: Target values. Shape: batch, time + :param mask: 0/1 mask. Shape: batch, time + :return: Loss value + """ + weights = divide_no_nan(mask, target) + if self.reduction: + return t.mean(t.abs((forecast - target) * weights)) + else: + return t.mean(t.abs((forecast - target) * weights), dim=[0, 2], keepdim=True) + + +class smape_loss(nn.Module): + def __init__(self, reduction: bool = True): + super(smape_loss, self).__init__() + self.reduction = reduction + + def forward(self, insample: t.Tensor, freq: int, + forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float: + """ + sMAPE loss as defined in https://robjhyndman.com/hyndsight/smape/ (Makridakis 1993) + + :param forecast: Forecast values. Shape: batch, time + :param target: Target values. Shape: batch, time + :param mask: 0/1 mask. 
class TriangularCausalMask():
    """Boolean causal mask: True strictly above the diagonal (future positions)."""

    def __init__(self, B, L, device="cpu"):
        with torch.no_grad():
            ones = torch.ones([B, 1, L, L], dtype=torch.bool)
            self._mask = torch.triu(ones, diagonal=1).to(device)

    @property
    def mask(self):
        # Shape: [B, 1, L, L]
        return self._mask
def RSE(pred, true):
    """Root relative squared error: ||true - pred|| / ||true - mean(true)||."""
    num = np.sqrt(np.sum((true - pred) ** 2))
    den = np.sqrt(np.sum((true - true.mean()) ** 2))
    return num / den


def CORR(pred, true):
    """Per-variable correlation-style score averaged over the last axis.

    NOTE(review): the denominator squares before summing, so this is not the
    textbook Pearson correlation; kept as-is for comparability — confirm intent.
    """
    dt = true - true.mean(0)
    dp = pred - pred.mean(0)
    u = (dt * dp).sum(0)
    d = np.sqrt((dt ** 2 * dp ** 2).sum(0))
    return (u / d).mean(-1)


def MAE(pred, true):
    """Mean absolute error."""
    return np.mean(np.abs(true - pred))


def MSE(pred, true):
    """Mean squared error."""
    return np.mean((true - pred) ** 2)


def RMSE(pred, true):
    """Root mean squared error."""
    return np.sqrt(MSE(pred, true))


def MAPE(pred, true):
    """Mean absolute percentage error (undefined where true == 0)."""
    return np.mean(np.abs((pred - true) / true))


def MSPE(pred, true):
    """Mean squared percentage error (undefined where true == 0)."""
    return np.mean(((pred - true) / true) ** 2)
def calculate_mase(y_pred, y_true, y_naive=None):
    """
    Mean Absolute Scaled Error (MASE).

    :param y_pred: predictions, shape (..., pred_len, channel), ndim >= 3.
    :param y_true: ground truth, same shape as y_pred.
    :param y_naive: series providing the naive one-step scale; defaults to y_true.
    :return: scalar MASE averaged over batch and channel (NaN-safe).
    :raises ValueError: when the three inputs do not share one shape.
    """
    assert y_pred.ndim >= 3, 'Error in y_pred.shape...'

    if y_naive is None:
        y_naive = y_true

    if y_true.shape != y_pred.shape or y_true.shape != y_naive.shape:
        raise ValueError("The input shape must be the same.")

    def _flat(a):
        # Collapse leading dims to [batch, pred_len, channel].
        return a.reshape(-1, a.shape[-2], a.shape[-1])

    true_f, pred_f, naive_f = _flat(y_true), _flat(y_pred), _flat(y_naive)

    # Model MAE and naive one-step MAE, both [batch, channel].
    model_mae = np.mean(np.abs(true_f - pred_f), axis=1)
    naive_mae = np.mean(np.abs(true_f[:, 1:, :] - naive_f[:, :-1, :]), axis=1)

    # Near-zero naive errors would explode the ratio; drop them via NaN.
    naive_mae[naive_mae < 1e-5] = np.nan

    return np.nanmean(model_mae / naive_mae)
class TimeFeature:
    """Base class: maps a DatetimeIndex to a normalized feature in [-0.5, 0.5]."""

    def __init__(self):
        pass

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        pass

    def __repr__(self):
        return self.__class__.__name__ + "()"


class SecondOfMinute(TimeFeature):
    """Second of minute encoded as value between [-0.5, 0.5]"""
    # BUGFIX(doc): docstring previously said "Minute of hour" (copy-paste error).

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.second / 59.0 - 0.5


class MinuteOfHour(TimeFeature):
    """Minute of hour encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.minute / 59.0 - 0.5


class HourOfDay(TimeFeature):
    """Hour of day encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.hour / 23.0 - 0.5


class DayOfWeek(TimeFeature):
    """Day of week encoded as value between [-0.5, 0.5]"""
    # BUGFIX(doc): docstring previously said "Hour of day" (copy-paste error).

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.dayofweek / 6.0 - 0.5


class DayOfMonth(TimeFeature):
    """Day of month encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.day - 1) / 30.0 - 0.5


class DayOfYear(TimeFeature):
    """Day of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.dayofyear - 1) / 365.0 - 0.5


class MonthOfYear(TimeFeature):
    """Month of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.month - 1) / 11.0 - 0.5


class WeekOfYear(TimeFeature):
    """Week of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.isocalendar().week - 1) / 52.0 - 0.5


def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Returns a list of time features that will be appropriate for the given frequency string.

    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.

    Raises
    ------
    RuntimeError
        When the frequency is not one of the supported pandas offsets.
    """
    features_by_offsets = {
        offsets.YearEnd: [],
        offsets.QuarterEnd: [MonthOfYear],
        offsets.MonthEnd: [MonthOfYear],
        offsets.Week: [DayOfMonth, WeekOfYear],
        offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Minute: [
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
        offsets.Second: [
            SecondOfMinute,
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
    }

    offset = to_offset(freq_str)

    for offset_type, feature_classes in features_by_offsets.items():
        if isinstance(offset, offset_type):
            return [cls() for cls in feature_classes]

    supported_freq_msg = f"""
    Unsupported frequency {freq_str}
    The following frequencies are supported:
    Y - yearly
    alias: A
    M - monthly
    W - weekly
    D - daily
    B - business days
    H - hourly
    T - minutely
    alias: min
    S - secondly
    """
    raise RuntimeError(supported_freq_msg)


def time_features(dates, freq='h'):
    """Stack all features for `freq` into an array of shape [n_features, len(dates)]."""
    return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)])
def adjust_learning_rate(optimizer, epoch, args, scheduler=None, printout=True):
    """
    Update the optimizer's learning rate for `epoch` according to args.lradj.

    'type1' halves each epoch, 'type2' reads a fixed epoch->lr table, 'type3'
    decays by 0.9 per epoch after epoch 3, 'constant' keeps the base rate,
    'TST' mirrors the supplied scheduler, and 'cosine'/'card' apply cosine
    annealing over args.train_epochs. Epochs with no entry leave the lr as-is.
    """
    lr_adjust = {}
    policy = args.lradj
    if policy == 'type1':
        lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))}
    elif policy == 'type2':
        # Hand-tuned table; epochs absent from it keep the current lr.
        lr_adjust = {
            2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
            10: 5e-7, 15: 1e-7, 20: 5e-8
        }
    elif policy == 'type3':
        base = args.learning_rate
        lr_adjust = {epoch: base if epoch < 3 else base * (0.9 ** ((epoch - 3) // 1))}
    elif policy == 'constant':
        lr_adjust = {epoch: args.learning_rate}
    elif policy == 'TST':
        assert scheduler is not None
        lr_adjust = {epoch: scheduler.get_last_lr()[0]}
    elif policy in ['cosine', 'card']:
        # Cosine annealing from learning_rate down to min_lr, no warmup.
        min_lr = 0
        warmup_epochs = 0
        lr = (min_lr + (args.learning_rate - min_lr) * 0.5 *
              (1. + math.cos(math.pi * (epoch - warmup_epochs) / (args.train_epochs - warmup_epochs))))
        lr_adjust = {epoch: lr}

    if epoch in lr_adjust:
        lr = lr_adjust[epoch]
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        if printout:
            print('Updating learning rate to {}'.format(lr))
def convert_size(size_bytes):
    """Format a byte count as a human-readable string, e.g. 2048 -> '2.00KB'."""
    remaining = size_bytes
    for unit in ('B', 'KB', 'MB', 'GB', 'TB'):
        if remaining < 1024:
            return f"{remaining:.2f}{unit}"
        remaining = remaining / 1024.0
    return f"{remaining:.2f}PB"


def delete_txt_files_in_folder(path):
    """Remove every '*.txt' file directly inside `path` (non-recursive)."""
    for name in os.listdir(path):
        if name.endswith('.txt'):
            os.remove(os.path.join(path, name))


class dotdict(dict):
    """dict subclass that also exposes items as attributes (d.key <-> d['key'])."""
    # Missing keys read as None because dict.get backs attribute access.
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


class StandardScaler():
    """Standardize data with fixed statistics: z = (x - mean) / std."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        """Map raw data into standardized space."""
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        """Map standardized data back to the original scale."""
        return data * self.std + self.mean
def adjustment(gt, pred):
    """
    Point-adjustment for anomaly-detection evaluation: once any point inside a
    ground-truth anomaly segment is detected, the whole segment (backwards and
    forwards until a normal point) is marked as detected.

    :param gt: ground-truth labels (0/1), indexable sequence.
    :param pred: predicted labels (0/1); modified in place.
    :return: (gt, pred) after adjustment.
    """
    in_segment = False
    for i in range(len(gt)):
        newly_hit = gt[i] == 1 and pred[i] == 1 and not in_segment
        if newly_hit:
            in_segment = True
            # Back-fill the segment (index 0 is intentionally never reached,
            # matching the original range(i, 0, -1) behaviour).
            for j in range(i, 0, -1):
                if gt[j] == 0:
                    break
                if pred[j] == 0:
                    pred[j] = 1
            # Forward-fill to the end of the segment.
            for j in range(i, len(gt)):
                if gt[j] == 0:
                    break
                if pred[j] == 0:
                    pred[j] = 1
        elif gt[i] == 0:
            in_segment = False
        if in_segment:
            pred[i] = 1
    return gt, pred


def cal_accuracy(y_pred, y_true):
    """Element-wise accuracy between two label arrays."""
    return np.mean(y_pred == y_true)
def compare_prefix_before_third_underscore(str1, str2, num=3):
    """
    Compare the first `num` underscore-separated fields of two strings.

    Returns True when both strings share the same first `num` fields; False
    when either argument is None.

    BUGFIX: fields are rejoined with '_' instead of '' — joining with the
    empty string made distinct prefixes collide (e.g. 'a_bc' vs 'ab_c').
    """
    if str1 is None or str2 is None:
        return False
    prefix1 = '_'.join(str1.split("_", num)[:num])
    prefix2 = '_'.join(str2.split("_", num)[:num])
    return prefix1 == prefix2


def compute_gradient_norm(model):
    """
    Global L2 norm over the gradients of all trainable parameters of `model`.

    Parameters without gradients are skipped; returns 0.0 if none have grads.
    """
    total_norm = 0.0
    for param in model.parameters():
        if param.requires_grad and param.grad is not None:
            param_norm = param.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5
    return total_norm


def is_not_empty_or_nan(a):
    """
    True when `a` is non-empty and NaN-free.

    Lists fail when empty or containing a float NaN; torch tensors fail when
    empty or containing any NaN; bare floats fail when NaN; anything else passes.
    """
    if isinstance(a, list):
        if not a:
            return False
        if any(isinstance(i, (float, np.float32, np.float64)) and np.isnan(i) for i in a):
            return False
    elif isinstance(a, torch.Tensor):
        if a.numel() == 0:
            return False
        if torch.isnan(a).any():
            return False
    else:
        if isinstance(a, (float, np.float32, np.float64)) and np.isnan(a):
            return False

    return True
def hier_half_token_weight(token_weight, ratio=2):
    """
    Merge every `ratio` adjacent token weights by summation
    ([b, N] -> [b, ceil(N / ratio)]); None passes through unchanged.

    When N is not divisible by `ratio`, the trailing weights are duplicated
    to pad the final group.
    """
    if token_weight is None:
        return None
    B, N = token_weight.shape
    remainder = N % ratio
    if remainder != 0:
        pad = ratio - remainder
        token_weight = torch.cat([token_weight, token_weight[:, -pad:]], dim=-1)
    return token_weight.reshape(B, -1, ratio).sum(dim=-1)


def cosine_distance(tensor1, tensor2, keepdims=False):
    """Cosine distance (1 - cosine similarity) along the last dim."""
    assert tensor1.shape == tensor2.shape, "Both tensors must have the same shape in cosine_distance"
    dist = 1 - F.cosine_similarity(tensor1, tensor2, dim=-1)
    return dist.unsqueeze(-1) if keepdims else dist


def euclidean_distance(tensor1, tensor2, keepdims=False):
    """Euclidean (L2) distance along the last dim."""
    assert tensor1.shape == tensor2.shape, "Both tensors must have the same shape in euclidean_distance"
    dist = torch.sqrt(((tensor1 - tensor2) ** 2).sum(-1))
    return dist.unsqueeze(-1) if keepdims else dist
def undo_unfold(inp, length, stride, fft_flag=False):
    """
    Invert a tensor.unfold-style patching: [b, n, num, period] -> [b, length, n].

    Overlapping positions (stride < period) are averaged. With fft_flag, or
    when stride == period, the patches tile exactly and a flatten suffices.
    """
    B, N, num, period = inp.shape
    if fft_flag:
        assert num == length // stride, (f'num:{num}, length:{length}, stride:{stride}. inp.shape: {inp.shape}. '
                                         f'Please check the inputs of undo_unfold().')
    else:
        assert num == (length - period) // stride + 1, 'Please check the inputs of undo_unfold().'

    if stride == period or fft_flag:
        # Non-overlapping patches: concatenate them back along time.
        return inp.flatten(start_dim=2).transpose(-1, -2)

    out = torch.zeros(B, N, length, device=inp.device)
    hits = torch.zeros_like(out)

    for i in range(num):
        lo = i * stride
        hi = lo + period
        out[:, :, lo:hi] += inp[:, :, i, :]
        hits[:, :, lo:hi] += 1

    # Average where patches overlapped.
    covered = hits > 0
    out[covered] /= hits[covered]

    return out.transpose(-1, -2)


def create_sub_diagonal_matrix(n, value=1, offset=0):
    """n x n matrix with `value` on the diagonal shifted by `offset`; None when |offset| >= n."""
    if abs(offset) >= n:
        return None
    diag_vec = torch.ones(n - abs(offset)) * value
    return torch.diag(diag_vec, diagonal=offset)
def create_sin_pos_embed(max_len, d_model):
    """Classic sinusoidal positional embedding; returns shape [1, max_len, d_model]."""
    positions = torch.arange(0, max_len).float().unsqueeze(1)
    # Per-pair frequencies: 10000^(-2i/d_model).
    freqs = (torch.arange(0, d_model, 2).float()
             * -(math.log(10000.0) / d_model)).exp()

    table = torch.zeros(max_len, d_model).float()
    table[:, 0::2] = torch.sin(positions * freqs)
    table[:, 1::2] = torch.cos(positions * freqs)

    # [1, max_len, d_model]
    return table.unsqueeze(0)


def var2tuple2(x, num=2):
    """
    Normalize `x` to a tuple of exactly `num` elements: scalars are repeated,
    long tuples truncated, short tuples padded by repeating the last element.
    """
    num = int(num)
    if not isinstance(x, tuple):
        return (x,) * num
    if len(x) == num:
        return x
    if len(x) > num:
        return x[:num]
    return x + (x[-1],) * (num - len(x))
window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + return relative_position_index + + +def get_relative_coords_table(window_size, h_times=torch.tensor(2.0), ws_scalar=torch.tensor(5.0), + ws_scalar2=None, pow_para=None, pow_mode=False): + assert isinstance(window_size, tuple) and len(window_size) == 2 and ws_scalar > 0, f"ws_scalar:{ws_scalar}" + # get relative_coords_table + relative_coords_h = torch.arange(-(window_size[0] - 1), window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(window_size[1] - 1), window_size[1], dtype=torch.float32) + relative_coords_table = torch.stack( + torch.meshgrid([relative_coords_h, + relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + if torch.is_tensor(h_times): + relative_coords_table = relative_coords_table.to(h_times.device) + + if window_size[0] > 1: + relative_coords_table[:, :, :, 0] /= (window_size[0] - 1) + if window_size[1] > 1: + relative_coords_table[:, :, :, 1] /= (window_size[1] - 1) + + if pow_mode: + relative_coords_table *= ws_scalar + relative_coords_table = torch.sign(relative_coords_table) * torch.exp( + torch.abs(relative_coords_table) - 1) / (torch.exp(ws_scalar) - 1) + + if torch.any(torch.isnan(relative_coords_table)): + print('\t relative_coords_table is nan. 
please check...', ws_scalar) + print("\t relative_coords_table.shape: ", relative_coords_table.shape) + else: + # if ws_scalar2 is None: + relative_coords_table *= ws_scalar # normalize to -ws_scalar, ws_scalar + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1) / torch.log2(ws_scalar) # log2 log1p + # else: + # tmp = relative_coords_table[:, :, :, 1].clone() + # tmp = tmp * ws_scalar + # tmp2 = relative_coords_table[:, :, :, 0].clone() + # tmp2 = tmp2 * ws_scalar2 + # + # relative_coords_table[:, :, :, 1] = torch.sign(tmp) * torch.log2( + # torch.abs(tmp) + 1.0) / torch.log2(ws_scalar) + # relative_coords_table[:, :, :, 0] = torch.sign(tmp2) * torch.log2( + # torch.abs(tmp2) + 1.0) / torch.log2(ws_scalar2) + + relative_coords_table[:, :, :, 0] *= h_times + return relative_coords_table + + +def create_swin_relative_index_1d(window_size, period=None): + window_size = to_2tuple(window_size) + period = period or window_size[1] + if torch.is_tensor(period): + rel_pos = (torch.arange(window_size[1]).view(1, -1).to(period.device) + + torch.arange(window_size[0]).view(-1, 1).to(period.device) * period).view(1, -1) + else: + rel_pos = (torch.arange(window_size[1]).view(1, -1) + + torch.arange(window_size[0]).view(-1, 1) * period).view(1, -1) + rel_pos = rel_pos - rel_pos.transpose(-1, -2) + return rel_pos + + +def norm_rel_pos_1d(relative_position_index, ws_scalar): + if not torch.is_tensor(ws_scalar): + ws_scalar = torch.tensor(ws_scalar) + # one way + relative_position_index_norm = (relative_position_index / torch.max(relative_position_index).item() + * ws_scalar) # normalize to -8, 8 + relative_position_index_norm = torch.sign(relative_position_index_norm) * torch.log2( + torch.abs(relative_position_index_norm) + 1.0) / torch.log2(ws_scalar) + + # more simple + # relative_position_index_norm = (torch.sign(relative_position_index) * torch.log1p( + # torch.abs(relative_position_index) / 
torch.max(relative_position_index).item() * ws_scalar) + # / torch.log1p(ws_scalar)) + return relative_position_index_norm + + +def compute_weights(alpha, length, stages=None, multiple_flag=True): + assert alpha <= 0 + if alpha == 0: + weights = torch.ones(length) + return weights + stage_num = 1 + rem = 0 + if stages is not None: + # assert (length + 1) % stages == 0 or length % stages == 0 + stage_num = (length + 1) // stages + rem = length + 1 - stage_num * stages + + weights = torch.tensor([i ** alpha for i in range(length + 1, 0, -1)]) + # weights2 = torch.tensor([i ** (alpha / 2) for i in range(length + 1, 0, -1)]) + + # iTransformer + if multiple_flag and stages is not None: + # on SDA now + slices = list(range(stage_num - 1, length, stage_num)) + if rem > 0: + slices = [a + i + 1 if i < rem else a + rem for i, a in enumerate(slices)] + weights[slices] = torch.minimum(weights[slices] * 1.5, weights[-2]) + # weights[slices] = weights2[slices] + + # remove the first element + weights = weights[:length] + + return weights + + +def roll_without_cycle(x, shifts, dims): + if not isinstance(shifts, (tuple, list)): + shifts = (shifts,) + if not isinstance(dims, (tuple, list)): + dims = (dims,) + + assert len(shifts) == len(dims), "shifts and dims must have the same length" + + shifted_x = torch.roll(x, shifts, dims) + + for shift, dim in zip(shifts, dims): + zeros_slices = [slice(None)] * x.ndim + if shift == 0: + continue + if shift > 0: + zeros_slices[dim] = slice(0, shift) + else: + zeros_slices[dim] = slice(shift, None) + + shifted_x[tuple(zeros_slices)] = 0 + + return shifted_x + + +def forward_fill(x, mask): + b, l, n = x.size() + # x = x.clone() + mask = mask.clone() + + padding_positions = (mask == 1).nonzero(as_tuple=True) + + for batch_index, length_index, feature_index in zip(*padding_positions): + # search backwards + for prev_length_index in range(length_index - 1, -1, -1): + if mask[batch_index, prev_length_index, feature_index] == 0: + x[batch_index, 
def closest_divisor(a, b):
    """
    Return the divisor of `a` closest to `b` (e.g. a=10, b=3 -> 2); ties
    prefer the smaller candidate. Returns None only when no divisor is found
    in [1, a] (impossible for positive `a`; kept for safety).

    BUGFIX: the previous version exhausted the entire downward search before
    ever looking upward, so e.g. closest_divisor(16, 7) returned 4 instead of
    the nearer divisor 8. The search now expands outward from `b` one step at
    a time in both directions.
    """
    if a % b == 0:
        return b

    for d in range(1, a + 1):
        lower = b - d
        if lower > 0 and a % lower == 0:
            return lower
        upper = b + d
        if upper <= a and a % upper == 0:
            return upper

    return None


def adapt_win(seq_len, period):
    """
    Choose a 2-D attention window (h, w) for a [ceil(seq_len / period), period]
    layout: cap the longer side at 7 and keep the window area near 5**2 = 25.
    """
    H, W = math.ceil(seq_len / period), period
    scalar = 5
    max_hw = 7
    if H <= W:
        w = min(W // 2, max_hw)
        h = min(scalar ** 2 // w, H, w)
    else:
        h = min(H // 2, max_hw)
        w = min(scalar ** 2 // h, W, h)
    return h, w
+ tau + 1] # / n + + # mean; [m, 2tau+1] + cross_corr = cross_corr.flatten(0, 1).mean(0) + + # max delay + delay_vec = min(tau, n_middle) - cross_corr.max(-1)[1] + delay_vec[0] = first_row_shift + + return cross_corr, delay_vec + + +def cyclic_shift_per_row(x, vec): + b, h, w, c = x.shape + assert len(vec) == h, f"len(vec){len(vec)} should be equal to h{h}..." + + for i in range(h): + shift_amount = vec[i].item() if torch.is_tensor(vec) else vec[i] + x[:, i, :, :] = torch.roll(x[:, i, :, :], shifts=shift_amount, dims=1) + + return x + + +def find_period_multiple_k_ori(x, k=1): + """ + Find the period of the signal x, where the period is a multiple of k. + x is expected to be of shape (B,T,C) where C is the number of channels and T is the length of each channel. + """ + B, T, C = x.shape + + len_fft = math.ceil(T / k) * k + # Compute the FFT of the input + X = torch.fft.rfft(x, n=len_fft, dim=1) + frequency_list = abs(X).mean(0).mean(-1) + frequency_list[0:2] = 0 # period cannot be 1 + top_fre = torch.argmax(frequency_list) + top_fre = top_fre.detach().cpu().numpy() + + max_period = len_fft // int(top_fre) + + max_period = round(max_period / k) * k if max_period > k else max_period + + max_period = min(max(max_period, 2), T // 2) + + return int(max_period) + + +def find_period_multiple_k(x, k=1, harmonic=False): + """ + Find the period of the signal x, where the period is a multiple of k. + x is expected to be of shape (B,T,C) where C is the number of channels and T is the length of each channel. 
+ """ + B, T, C = x.shape + + len_fft = math.ceil(T / k) * k + # Compute the FFT of the input + X = torch.fft.rfft(x, n=len_fft, dim=1) + frequency_list = abs(X).mean(0).mean(-1) + frequency_list[0:2] = 0 # period cannot be 1 + + if not harmonic: + top_fre = torch.argmax(frequency_list) + max_period = len_fft // int(top_fre) + max_period = round(max_period / k) * k if max_period > k else max_period + max_period = min(max(max_period, 2), T // 2) + + harm_period = 5 + else: + # _, top_list = torch.topk(frequency_list, k=2) + # top_list = top_list.detach().cpu().numpy() + # top_list = list(top_list) + # + # top_fre, sub_fre = min(top_list), max(top_list) + # + # max_period = len_fft // int(top_fre) + # max_period = round(max_period / k) * k if max_period > k else max_period + # max_period = int(min(max(max_period, 2), T // 2)) + # + # harm_period = len_fft // int(sub_fre) + # harm_period = int(min(max(harm_period, 1), max_period)) + + # another implementation + top_fre = torch.argmax(frequency_list) + max_period = len_fft // int(top_fre) + # max_period = round(max_period / k) * k if max_period > k else max_period + max_period = int(min(max(max_period, 2), T // 2)) + + sub_fre_list = frequency_list[top_fre + 1:] + sub_fre = torch.argmax(sub_fre_list) + top_fre + 1 + + if frequency_list[sub_fre] > frequency_list[top_fre] * 0.95: + harm_period = len_fft // int(sub_fre) + harm_period = int(min(max(harm_period, 1), max_period)) + else: + harm_period = 5 + + return max_period, harm_period + + +def compute_harm_fre(x, period, win_size=5): + """ + compute harmonic frequency under period + x is expected to be of shape (B,L,N) + """ + B, L, N = x.shape + rem = L % period + if rem != 0: + x = F.pad(x, pad=[0, 0, 0, period - rem]) + L = x.shape[1] + + # [_, period] + x = x.reshape(B, L // period, period, N).transpose(-1, -2).flatten(0, -2) + + x_fre = torch.fft.rfft(x, dim=-1).abs().mean(0) + x_fre[0:period // win_size] = 0 + + top_fre = torch.argmax(x_fre) + + harm_period = 
period // int(top_fre) + + return harm_period + + +def create_block_missing(input_size, mask_rate=0.1, block_length=(3, 10), device='cpu'): + # (batch, len, channels) + batch, length, channels = input_size + + mask = torch.ones(input_size, dtype=torch.float32, device=device) + + total_elements = batch * length * channels + total_mask_elements = int(total_elements * mask_rate) + masked_elements = 0 + + while masked_elements < total_mask_elements: + b = random.randint(0, batch - 1) + c = random.randint(0, channels - 1) + block_len = random.randint(block_length[0], block_length[1]) + + if masked_elements + block_len > total_mask_elements: + block_len = total_mask_elements - masked_elements + + start = random.randint(0, length - block_len) + + mask[b, start:start + block_len, c] = 0 + + masked_elements += block_len + + return mask + + +def apply_difference(data, n=1): + """ + Apply differencing to the data. + :param data: Input data [batch, length, channel] + :param n: Order of differencing + :return: Differenced data and the last original data point for each series + """ + for i in range(n): + data[..., 1:, :] = data[..., 1:, :] - data[..., :-1, :] + return data + + +def moore_penrose_iter_pinv(x, iters=6): + device = x.device + + abs_x = torch.abs(x) + col = abs_x.sum(dim=-1) + row = abs_x.sum(dim=-2) + z = rearrange(x, '... i j -> ... 
j i') / (torch.max(col) * torch.max(row)) + + I = torch.eye(x.shape[-1], device=device) + I = rearrange(I, 'i j -> () i j') + + for _ in range(iters): + xz = x @ z + z = 0.25 * z @ (13 * I - (xz @ (15 * I - (xz @ (7 * I - xz))))) + + return z diff --git a/ts_benchmark/hpo/__init__.py b/ts_benchmark/hpo/__init__.py new file mode 100644 index 00000000..e13ffff2 --- /dev/null +++ b/ts_benchmark/hpo/__init__.py @@ -0,0 +1,2 @@ +from .optuna_search import run_optuna_search +from .search_space import sample_params \ No newline at end of file diff --git a/ts_benchmark/hpo/optuna_search.py b/ts_benchmark/hpo/optuna_search.py new file mode 100644 index 00000000..913c0a85 --- /dev/null +++ b/ts_benchmark/hpo/optuna_search.py @@ -0,0 +1,61 @@ +import json +import os +from typing import List +import optuna +from ts_benchmark.pipeline import pipeline +from .search_space import sample_params + + +def evaluate_params(params, config_data, data_name_list, model_name, strategy_args): + # 在此处进行评估 + data_config = config_data["data_config"] + model_config = config_data["model_config"] + evaluation_config = config_data["evaluation_config"] + + model_config["models"] = [{ + "adapter": None, + "model_name": model_name, + "model_hyper_params": params + }] + + # 进行模型训练并评估 + try: + log_files = pipeline(data_config, model_config, evaluation_config) + except Exception as e: + print(f"Error: {e}") + return float("inf") + + # 假设我们从日志文件中提取损失值(这里我们假设返回0.0作为占位符) + return 0.0 # 在这里替换为实际计算的验证损失值 + + +def run_optuna_search(config_path: str, data_name_list: List[str], model_name: str, save_path: str, n_trials: int = 10, + seed: int = None): + # 加载配置文件 + with open(config_path, "r") as f: + config_data = json.load(f) + + study = optuna.create_study(direction="minimize", study_name="hyperparameter_optimization") + study.optimize( + lambda trial: evaluate_params(sample_params(model_name, trial), config_data, data_name_list, model_name, {}), + n_trials=n_trials) + + # 保存最优超参数 + best_params = 
study.best_params + best_value = study.best_value + + best_params_json = { + "model_name": model_name, + "series_name": data_name_list[0], # 假设只有一个数据集 + "objective": "val_loss", # 假设优化目标是 val_loss + "best_value": best_value, + "best_params": best_params + } + + if not os.path.exists(save_path): + os.makedirs(save_path) + + with open(os.path.join(save_path, f"{model_name}_{data_name_list[0]}_best_params.json"), "w") as f: + json.dump(best_params_json, f, indent=2) + + return best_params_json \ No newline at end of file diff --git a/ts_benchmark/hpo/search_space.py b/ts_benchmark/hpo/search_space.py new file mode 100644 index 00000000..7280a473 --- /dev/null +++ b/ts_benchmark/hpo/search_space.py @@ -0,0 +1,13 @@ +import optuna + +def sample_params(model_name: str, trial: optuna.trial.Trial) -> dict: + if model_name == 'olinear.OLinear': + return { + "lr": trial.suggest_loguniform("lr", 1e-5, 1e-2), + "dropout": trial.suggest_float("dropout", 0.0, 0.5), + "d_model": trial.suggest_categorical("d_model", [64, 128, 256, 512]), + "d_ff": trial.suggest_categorical("d_ff", [128, 256, 512, 1024]), + "e_layers": trial.suggest_int("e_layers", 1, 3) + } + else: + raise NotImplementedError(f"Model {model_name} not implemented in search space.") \ No newline at end of file