diff --git a/pypots/anomaly_detection/__init__.py b/pypots/anomaly_detection/__init__.py
index 1e058a29..6376d1cc 100644
--- a/pypots/anomaly_detection/__init__.py
+++ b/pypots/anomaly_detection/__init__.py
@@ -6,6 +6,7 @@
 # License: BSD-3-Clause
 
 from .autoformer import Autoformer
+from .dcdetector import DCdetector
 from .dlinear import DLinear
 from .imputeformer import ImputeFormer
 from .patchtst import PatchTST
@@ -29,6 +30,7 @@
 
 __all__ = [
     "Autoformer",
+    "DCdetector",
     "SAITS",
     "TEFN",
     "ImputeFormer",
diff --git a/pypots/anomaly_detection/dcdetector/__init__.py b/pypots/anomaly_detection/dcdetector/__init__.py
new file mode 100644
index 00000000..9720acae
--- /dev/null
+++ b/pypots/anomaly_detection/dcdetector/__init__.py
@@ -0,0 +1,19 @@
+"""
+The implementation of DCdetector for the partially-observed time-series anomaly detection task.
+
+Refer to the paper
+`Yiyuan Yang, Chaoli Zhang, Tian Zhou, Qingsong Wen, and Liang Sun.
+"DCdetector: Dual Attention Contrastive Representation Learning for Time Series Anomaly Detection".
+In Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining, 2023.
+<https://arxiv.org/abs/2306.10347>`_
+
+"""
+
+# Created by Yiyuan Yang
+# License: BSD-3-Clause
+
+from .model import DCdetector
+
+__all__ = [
+    "DCdetector",
+]
diff --git a/pypots/anomaly_detection/dcdetector/core.py b/pypots/anomaly_detection/dcdetector/core.py
new file mode 100644
index 00000000..1765cb51
--- /dev/null
+++ b/pypots/anomaly_detection/dcdetector/core.py
@@ -0,0 +1,210 @@
+"""
+The core model of DCdetector for the anomaly detection task.
+
+"""
+
+# Created by Yiyuan Yang
+# License: BSD-3-Clause
+
+import torch
+import torch.nn as nn
+
+from ...nn.modules import ModelCore
+from ...nn.modules.dcdetector import BackboneDCdetector
+from ...nn.modules.loss import Criterion
+
+
+def _kl_loss(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
+    """Asymmetric (unnormalized) KL divergence.
+
+    Parameters
+    ----------
+    p : torch.Tensor, shape [B, H, L, L]
+    q : torch.Tensor, shape [B, H, L, L]
+
+    Returns
+    -------
+    torch.Tensor, shape [B, L]
+        Mean KL divergence per batch sample and time step.
+
+    """
+    res = p * (torch.log(p + 0.0001) - torch.log(q + 0.0001))
+    return torch.mean(torch.sum(res, dim=-1), dim=1)
+
+
+def _normalize_prior(prior: torch.Tensor, n_steps: int) -> torch.Tensor:
+    """Normalise a prior attention map to sum to 1 along the last dimension.
+
+    Rows whose sum is zero (e.g. every attention weight removed by dropout)
+    are mapped to all-zero rows instead of NaN.
+
+    Parameters
+    ----------
+    prior : torch.Tensor, shape [B, H, L, L]
+        Raw (unnormalized) prior attention map.
+
+    n_steps : int
+        Length of the time-step dimension (L). Unused: the row sums are
+        broadcast instead of repeated; kept for backward compatibility.
+
+    Returns
+    -------
+    torch.Tensor, shape [B, H, L, L]
+        Row-normalized prior map.
+
+    """
+    # keepdim=True lets the division broadcast over the last dimension; clamping the
+    # denominator avoids 0/0 = NaN for all-zero rows (rows summing to >= eps are unaffected)
+    row_sums = torch.sum(prior, dim=-1, keepdim=True)
+    return prior / row_sums.clamp_min(torch.finfo(prior.dtype).eps)
+
+
+class _DCdetector(ModelCore):
+    """The core PyTorch model of DCdetector.
+
+    This module wraps :class:`BackboneDCdetector` and adds the minimax
+    contrastive loss used during training/validation.
+
+    The training objective is a minimax game between two attention views:
+    *series* (patch-wise, inter-patch attention) and *prior* (in-patch,
+    intra-patch attention). The loss ``prior_loss - series_loss`` is
+    minimised, which encourages the two views to maximally disagree on
+    anomalous patterns.
+
+    Parameters
+    ----------
+    n_steps :
+        Number of time steps in each input window.
+
+    n_features :
+        Number of input features.
+
+    patch_sizes :
+        List of patch sizes for multi-scale patching.
+
+    d_model :
+        Model embedding dimension.
+
+    n_heads :
+        Number of attention heads.
+
+    e_layers :
+        Number of encoder layers.
+
+    dropout :
+        Dropout rate.
+
+    training_loss :
+        Loss criterion (used only for its ``lower_better`` attribute; the
+        actual training loss is the DCdetector contrastive loss).
+
+    validation_metric :
+        Validation metric (same remark as ``training_loss``).
+
+    """
+
+    def __init__(
+        self,
+        n_steps: int,
+        n_features: int,
+        patch_sizes: list,
+        d_model: int,
+        n_heads: int,
+        e_layers: int,
+        dropout: float,
+        training_loss: Criterion,
+        validation_metric: Criterion,
+    ):
+        super().__init__()
+
+        self.n_steps = n_steps
+        self.patch_sizes = patch_sizes
+
+        self.training_loss = training_loss
+        if validation_metric.__class__.__name__ == "Criterion":
+            # in this case, we need validation_metric.lower_better in _train_model() so only pass Criterion()
+            # we use training_loss as validation_metric for concrete calculation process
+            self.validation_metric = self.training_loss
+        else:
+            self.validation_metric = validation_metric
+
+        self.backbone = BackboneDCdetector(
+            n_steps=n_steps,
+            n_features=n_features,
+            patch_sizes=patch_sizes,
+            d_model=d_model,
+            n_heads=n_heads,
+            e_layers=e_layers,
+            dropout=dropout,
+        )
+
+    def forward(
+        self,
+        inputs: dict,
+        calc_criterion: bool = False,
+    ) -> dict:
+        """Forward pass.
+
+        Parameters
+        ----------
+        inputs : dict
+            Must contain key ``"X"`` with shape ``[B, L, M]``.
+
+        calc_criterion : bool
+            If True, the contrastive loss is added to the returned dict as
+            ``"loss"`` (training mode) or ``"metric"`` (evaluation mode).
+
+        Returns
+        -------
+        dict with keys:
+
+        - ``"series"`` – list of patch-wise attention tensors,
+          each of shape ``[B, H, L, L]``.
+        - ``"prior"`` – list of in-patch attention tensors,
+          each of shape ``[B, H, L, L]``.
+        - ``"loss"`` / ``"metric"`` – scalar contrastive loss
+          (only when ``calc_criterion=True``).
+
+        """
+        X = inputs["X"]
+
+        series, prior = self.backbone(X)
+
+        results = {
+            "series": series,
+            "prior": prior,
+        }
+
+        if calc_criterion:
+            series_loss = 0.0
+            prior_loss = 0.0
+
+            for series_u, prior_u in zip(series, prior):
+                # Normalise prior so it sums to 1 along the last dimension
+                prior_norm = _normalize_prior(prior_u, self.n_steps)
+
+                # Symmetric KL between series and normalised prior
+                series_loss += torch.mean(
+                    _kl_loss(series_u, prior_norm.detach())
+                ) + torch.mean(_kl_loss(prior_norm.detach(), series_u))
+
+                # Symmetric KL in the opposite direction (minimax partner)
+                prior_loss += torch.mean(
+                    _kl_loss(prior_norm, series_u.detach())
+                ) + torch.mean(_kl_loss(series_u.detach(), prior_norm))
+
+            series_loss = series_loss / len(prior)
+            prior_loss = prior_loss / len(prior)
+
+            # Minimax training objective: minimise prior_loss - series_loss
+            # NOTE: both terms evaluate the same symmetric KL (``detach`` only changes the
+            # gradients), so the forward value of ``loss`` is numerically ~0 and the
+            # optimisation is driven by the gradients alone.
+            loss = prior_loss - series_loss
+
+            if self.training:
+                results["loss"] = loss
+            else:
+                results["metric"] = loss
+
+        return results
diff --git a/pypots/anomaly_detection/dcdetector/model.py b/pypots/anomaly_detection/dcdetector/model.py
new file mode 100644
index 00000000..61b6205c
--- /dev/null
+++ b/pypots/anomaly_detection/dcdetector/model.py
@@ -0,0 +1,344 @@
+"""
+The implementation of DCdetector for the partially-observed time-series anomaly detection task.
+
+"""
+
+# Created by Yiyuan Yang
+# License: BSD-3-Clause
+
+from typing import Union, Optional
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from ..base import BaseNNDetector
+from ...data.checking import key_in_data_set
+from ...data.dataset.base import BaseDataset
+from ...imputation.saits.data import DatasetForSAITS
+from ...nn.functional import autocast
+from ...nn.modules.loss import Criterion, MAE, MSE
+from ...optim.adam import Adam
+from ...optim.base import Optimizer
+from ...utils.logging import logger
+from .core import _DCdetector, _kl_loss, _normalize_prior
+
+
+class DCdetector(BaseNNDetector):
+    """The PyTorch implementation of the DCdetector model :cite:`yang2023dcdetector`
+    for the anomaly detection task.
+
+    DCdetector learns dual attention representations (patch-wise and in-patch)
+    via a minimax contrastive objective. Anomaly scores are the per-window
+    softmax of the negated KL divergence between the two attention views;
+    time steps whose score exceeds the threshold are flagged as anomalous.
+
+    Parameters
+    ----------
+    n_steps : int
+        The number of time steps in the time-series data sample.
+        Must be divisible by every value in ``patch_sizes``.
+
+    n_features : int
+        The number of features in the time-series data sample.
+
+    anomaly_rate : float
+        The estimated anomaly rate in the dataset, within the range (0, 1).
+        Used for thresholding.
+
+    patch_sizes : list of int
+        Patch sizes for multi-scale patching (e.g. ``[3, 5, 7]``).
+        Each value must divide ``n_steps`` evenly.
+
+    d_model : int
+        Dimension of the model embeddings.
+
+    n_heads : int
+        Number of attention heads.
+
+    e_layers : int
+        Number of encoder layers.
+
+    dropout : float, optional
+        Dropout rate. Default is 0.
+
+    batch_size : int, optional
+        Number of samples per training batch. Default is 32.
+
+    epochs : int, optional
+        Maximum number of training epochs. Default is 100.
+
+    patience : int or None, optional
+        Early-stopping patience. Disabled if None. Default is None.
+
+    training_loss : Criterion or type, optional
+        Loss criterion. Used for its ``lower_better`` attribute; the actual
+        training loss is the DCdetector contrastive loss. Defaults to MAE.
+
+    validation_metric : Criterion or type, optional
+        Validation metric. Same remark as ``training_loss``. Defaults to MSE.
+
+    optimizer : Optimizer or type, optional
+        Optimizer for training. Defaults to Adam.
+
+    num_workers : int, optional
+        Number of DataLoader worker processes. Default is 0.
+
+    device : str, torch.device, or list, optional
+        Device(s) for model training and inference.
+
+    saving_path : str, optional
+        Directory to save model checkpoints. No saving if None.
+
+    model_saving_strategy : str or None, optional
+        Checkpoint saving strategy: one of ``{None, "best", "better", "all"}``.
+
+    verbose : bool, optional
+        Whether to print training progress. Default is True.
+
+    """
+
+    def __init__(
+        self,
+        n_steps: int,
+        n_features: int,
+        anomaly_rate: float,
+        patch_sizes: list,
+        d_model: int,
+        n_heads: int,
+        e_layers: int,
+        dropout: float = 0,
+        batch_size: int = 32,
+        epochs: int = 100,
+        patience: Optional[int] = None,
+        training_loss: Union[Criterion, type] = MAE,
+        validation_metric: Union[Criterion, type] = MSE,
+        optimizer: Union[Optimizer, type] = Adam,
+        num_workers: int = 0,
+        device: Optional[Union[str, torch.device, list]] = None,
+        saving_path: str = None,
+        model_saving_strategy: Optional[str] = "best",
+        verbose: bool = True,
+    ):
+        super().__init__(
+            anomaly_rate=anomaly_rate,
+            training_loss=training_loss,
+            validation_metric=validation_metric,
+            batch_size=batch_size,
+            epochs=epochs,
+            patience=patience,
+            num_workers=num_workers,
+            device=device,
+            saving_path=saving_path,
+            model_saving_strategy=model_saving_strategy,
+            verbose=verbose,
+        )
+
+        # Validate that n_steps is divisible by each patch size (explicit raise: asserts vanish under -O)
+        for ps in patch_sizes:
+            if ps <= 0 or n_steps % ps != 0:
+                raise ValueError(
+                    f"n_steps ({n_steps}) must be divisible by each patch_size, "
+                    f"but {ps} does not divide {n_steps} evenly."
+                )
+
+        self.n_steps = n_steps
+        self.n_features = n_features
+        self.patch_sizes = patch_sizes
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.e_layers = e_layers
+        self.dropout = dropout
+
+        self.model = _DCdetector(
+            n_steps=n_steps,
+            n_features=n_features,
+            patch_sizes=patch_sizes,
+            d_model=d_model,
+            n_heads=n_heads,
+            e_layers=e_layers,
+            dropout=dropout,
+            training_loss=self.training_loss,
+            validation_metric=self.validation_metric,
+        )
+
+        self._send_model_to_given_device()
+        self._print_model_size()
+
+        if isinstance(optimizer, Optimizer):
+            self.optimizer = optimizer
+        else:
+            self.optimizer = optimizer()
+            assert isinstance(self.optimizer, Optimizer)
+        self.optimizer.init_optimizer(self.model.parameters())
+
+    def _assemble_input_for_training(self, data: list) -> dict:
+        """Prepare a training batch."""
+        (
+            indices,
+            X,
+            missing_mask,
+            X_ori,
+            indicating_mask,
+        ) = self._send_data_to_given_device(data)
+
+        return {
+            "X": X,
+            "missing_mask": missing_mask,
+            "X_ori": X_ori,
+            "indicating_mask": indicating_mask,
+        }
+
+    def _assemble_input_for_validating(self, data: list) -> dict:
+        """Prepare a validation batch (same as training)."""
+        return self._assemble_input_for_training(data)
+
+    def _assemble_input_for_testing(self, data: list) -> dict:
+        """Prepare an inference batch."""
+        indices, X, missing_mask = self._send_data_to_given_device(data)
+
+        return {
+            "X": X,
+            "missing_mask": missing_mask,
+        }
+
+    def fit(
+        self,
+        train_set: Union[dict, str],
+        val_set: Optional[Union[dict, str]] = None,
+        file_type: str = "hdf5",
+    ) -> None:
+        """Train the model.
+
+        Parameters
+        ----------
+        train_set : dict or str
+            Training dataset.
+
+        val_set : dict or str, optional
+            Validation dataset. Must contain ``"X_ori"``.
+
+        file_type : str, optional
+            File type for lazy-loading. Default is ``"hdf5"``.
+
+        """
+        self.train_set = train_set
+
+        train_dataset = DatasetForSAITS(
+            train_set, return_X_ori=False, return_y=False, file_type=file_type
+        )
+        train_dataloader = DataLoader(
+            train_dataset,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+        )
+
+        val_dataloader = None
+        if val_set is not None:
+            if not key_in_data_set("X_ori", val_set):
+                raise ValueError("val_set must contain 'X_ori' for model validation.")
+            val_dataset = DatasetForSAITS(
+                val_set, return_X_ori=True, return_y=False, file_type=file_type
+            )
+            val_dataloader = DataLoader(
+                val_dataset,
+                batch_size=self.batch_size,
+                shuffle=False,
+                num_workers=self.num_workers,
+            )
+
+        self._train_model(train_dataloader, val_dataloader)
+        self.model.load_state_dict(self.best_model_dict)
+        self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
+
+    @torch.no_grad()
+    def predict(
+        self,
+        test_set: Union[dict, str],
+        file_type: str = "hdf5",
+        **kwargs,
+    ) -> dict:
+        """Detect anomalies in the test set.
+
+        Anomaly scores are computed as the softmax, over the time steps of each
+        window, of the negated combined KL divergence between the patch-wise and
+        in-patch attention maps, scaled by a temperature of 50 (as in the original paper).
+        The threshold is the ``(1 - anomaly_rate)`` percentile of the pooled
+        training-set and test-set scores.
+
+        Parameters
+        ----------
+        test_set : dict or str
+            Test dataset.
+
+        file_type : str, optional
+            File type for lazy-loading. Default is ``"hdf5"``.
+
+        Returns
+        -------
+        dict
+            Contains key ``"anomaly_detection"`` with a 1-D binary array of
+            length ``n_test_samples * n_steps``.
+
+        """
+        # fit() stores the training set that is scored below to calibrate the threshold
+        if getattr(self, "train_set", None) is None:
+            raise RuntimeError("Call fit() before predict(): no training set is available.")
+
+        self.model.eval()
+        temperature = 50
+
+        def _build_dataloader(dataset_arg):
+            ds = BaseDataset(
+                dataset_arg,
+                return_X_ori=False,
+                return_X_pred=False,
+                return_y=False,
+                file_type=file_type,
+            )
+            return DataLoader(
+                ds,
+                batch_size=self.batch_size,
+                shuffle=False,
+                num_workers=self.num_workers,
+            )
+
+        def _score_dataloader(dataloader):
+            """Return per-time-step anomaly scores, shape [N, L]."""
+            score_collector = []
+            for data in dataloader:
+                inputs = self._assemble_input_for_testing(data)
+                with autocast(enabled=self.amp_enabled):
+                    results = self.model(inputs)
+
+                series = results["series"]
+                prior = results["prior"]
+
+                # Sum the temperature-scaled KL terms over every scale/layer.
+                # No detach() is needed: the whole method runs under torch.no_grad().
+                series_loss = 0.0
+                prior_loss = 0.0
+
+                for series_u, prior_u in zip(series, prior):
+                    prior_norm = _normalize_prior(prior_u, self.n_steps)
+
+                    series_loss = series_loss + _kl_loss(series_u, prior_norm) * temperature
+                    prior_loss = prior_loss + _kl_loss(prior_norm, series_u) * temperature
+
+                # metric: [B, L] — softmax over the time-step dimension
+                metric = torch.softmax((-series_loss - prior_loss), dim=-1)
+                score_collector.append(metric.detach().cpu().numpy())
+
+            return np.concatenate(score_collector, axis=0)  # [N, L]
+
+        train_scores = _score_dataloader(_build_dataloader(self.train_set)).reshape(-1)
+        test_scores = _score_dataloader(_build_dataloader(test_set)).reshape(-1)
+
+        combined = np.concatenate([train_scores, test_scores], axis=0)
+        threshold = np.percentile(combined, 100 - self.anomaly_rate * 100)
+        logger.info(f"Threshold: {threshold}")
+
+        anomaly_pred = (test_scores > threshold).astype(int)
+
+        return {"anomaly_detection": anomaly_pred}
diff --git a/pypots/nn/modules/dcdetector/__init__.py b/pypots/nn/modules/dcdetector/__init__.py
new file mode 100644
index 00000000..9f0522ff
--- /dev/null
+++ b/pypots/nn/modules/dcdetector/__init__.py
@@ -0,0 +1,23 @@
+"""
+The package including the
modules of DCdetector.
+
+Refer to the paper
+`Yiyuan Yang, Chaoli Zhang, Tian Zhou, Qingsong Wen, and Liang Sun.
+"DCdetector: Dual Attention Contrastive Representation Learning for Time Series Anomaly Detection".
+In Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining, 2023.
+<https://arxiv.org/abs/2306.10347>`_
+
+Notes
+-----
+This implementation is inspired by the official one https://github.com/DAMO-DI-ML/KDD2023-DCdetector
+
+"""
+
+# Created by Yiyuan Yang
+# License: BSD-3-Clause
+
+from .layers import BackboneDCdetector
+
+__all__ = [
+    "BackboneDCdetector",
+]
diff --git a/pypots/nn/modules/dcdetector/layers.py b/pypots/nn/modules/dcdetector/layers.py
new file mode 100644
index 00000000..febc9acb
--- /dev/null
+++ b/pypots/nn/modules/dcdetector/layers.py
@@ -0,0 +1,524 @@
+"""
+PyTorch modules/layers for DCdetector.
+
+"""
+
+# Created by Yiyuan Yang
+# License: BSD-3-Clause
+
+import math
+
+import torch
+import torch.nn as nn
+from einops import rearrange, reduce, repeat
+
+from ..revin import RevIN
+
+
+class DCdetectorTokenEmbedding(nn.Module):
+    """Token embedding for DCdetector using a 1D circular convolution."""
+
+    def __init__(self, c_in: int, d_model: int):
+        super().__init__()
+        padding = 1 if torch.__version__ >= "1.5.0" else 2
+        self.tokenConv = nn.Conv1d(
+            in_channels=c_in,
+            out_channels=d_model,
+            kernel_size=3,
+            padding=padding,
+            padding_mode="circular",
+            bias=False,
+        )
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="leaky_relu")
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: [B, L, C] -> permute -> [B, C, L] -> conv -> [B, d_model, L] -> transpose -> [B, L, d_model]
+        return self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
+
+
+class DCdetectorPositionalEncoding(nn.Module):
+    """Fixed sinusoidal positional encoding for DCdetector."""
+
+    def __init__(self, d_model: int, max_len: int = 5000):
+        super().__init__()
+        pe = torch.zeros(max_len, d_model).float()
+        pe.requires_grad = False
+
+        position = torch.arange(0, max_len).float().unsqueeze(1)
+        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
+
+        pe[:, 0::2] = torch.sin(position * div_term)
+        # An odd d_model has one fewer odd column than even columns, so trim the cosine term to fit
+        pe[:, 1::2] = torch.cos(position * div_term)[:, : d_model // 2]
+
+        pe = pe.unsqueeze(0)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.pe[:, : x.size(1)]
+
+
+class DCdetectorDataEmbedding(nn.Module):
+    """Data embedding for DCdetector: token embedding + positional encoding + dropout."""
+
+    def __init__(self, c_in: int, d_model: int, dropout: float = 0.05):
+        super().__init__()
+        self.value_embedding = DCdetectorTokenEmbedding(c_in=c_in, d_model=d_model)
+        self.position_embedding = DCdetectorPositionalEncoding(d_model=d_model)
+        self.dropout = nn.Dropout(p=dropout)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.value_embedding(x) + self.position_embedding(x)
+        return self.dropout(x)
+
+
+class DACStructure(nn.Module):
+    """Dual Attention Contrastive (DAC) structure.
+
+    Computes two attention views for contrastive representation learning:
+
+    - **Patch-wise (series)**: inter-patch attention that captures long-range dependencies
+      across patch positions.
+    - **In-patch (prior)**: intra-patch attention that captures local dependencies
+      within each patch.
+
+    Both views are upsampled back to the full window size and averaged over the channel
+    dimension to produce per-time-step attention matrices of shape ``[B, H, L, L]``.
+
+    Parameters
+    ----------
+    win_size :
+        The full window size (equal to n_steps).
+
+    patch_sizes :
+        List of patch sizes used for multi-scale patching.
+
+    n_features :
+        Number of input features (channels).
+
+    mask_flag :
+        Whether to apply a causal mask. Default is False.
+
+    scale :
+        Scaling factor for attention scores. If None, uses ``1/sqrt(d_k)``.
+
+    attention_dropout :
+        Dropout rate applied to attention weights.
+
+    output_attention :
+        Whether to return attention maps.
+
+    """
+
+    def __init__(
+        self,
+        win_size: int,
+        patch_sizes: list,
+        n_features: int,
+        mask_flag: bool = False,
+        scale=None,
+        attention_dropout: float = 0.05,
+        output_attention: bool = True,
+    ):
+        super().__init__()
+        self.scale = scale
+        self.mask_flag = mask_flag
+        self.output_attention = output_attention
+        self.dropout = nn.Dropout(attention_dropout)
+        self.window_size = win_size
+        self.patch_sizes = patch_sizes
+        self.n_features = n_features
+
+    def forward(
+        self,
+        queries_patch_size: torch.Tensor,
+        queries_patch_num: torch.Tensor,
+        keys_patch_size: torch.Tensor,
+        keys_patch_num: torch.Tensor,
+        values: torch.Tensor,
+        patch_index: int,
+        attn_mask=None,
+    ):
+        """
+        Parameters
+        ----------
+        queries_patch_size : torch.Tensor, shape [B*C, patch_num, H, d_k]
+            Queries for the patch-wise attention.
+
+        queries_patch_num : torch.Tensor, shape [B*C, patch_size, H, d_k]
+            Queries for the in-patch attention.
+
+        keys_patch_size : torch.Tensor, shape [B*C, patch_num, H, d_k]
+            Keys for the patch-wise attention.
+
+        keys_patch_num : torch.Tensor, shape [B*C, patch_size, H, d_k]
+            Keys for the in-patch attention.
+
+        values : torch.Tensor, shape [B, L, H, d_v]
+            Values (not used directly, kept for interface consistency).
+
+        patch_index : int
+            Index into self.patch_sizes for the current scale.
+
+        attn_mask :
+            Optional attention mask (not currently applied).
+
+        Returns
+        -------
+        series_patch_size : torch.Tensor, shape [B, H, L, L]
+            Patch-wise attention map upsampled to the full window size.
+
+        series_patch_num : torch.Tensor, shape [B, H, L, L]
+            In-patch attention map upsampled to the full window size.
+
+        """
+        patchsize = self.patch_sizes[patch_index]
+        patch_num = self.window_size // patchsize
+
+        # ---- Patch-wise representation (inter-patch attention) ----
+        # queries_patch_size: [B*C, patch_num, H, d_k]
+        B, L, H, E = queries_patch_size.shape
+        scale_patch_size = self.scale or 1.0 / math.sqrt(E)
+        scores_patch_size = torch.einsum("blhe,bshe->bhls", queries_patch_size, keys_patch_size)
+        attn_patch_size = scale_patch_size * scores_patch_size
+        # series_patch_size: [B*C, H, patch_num, patch_num]
+        series_patch_size = self.dropout(torch.softmax(attn_patch_size, dim=-1))
+
+        # ---- In-patch representation (intra-patch attention) ----
+        # queries_patch_num: [B*C, patch_size, H, d_k]
+        B, L, H, E = queries_patch_num.shape
+        scale_patch_num = self.scale or 1.0 / math.sqrt(E)
+        scores_patch_num = torch.einsum("blhe,bshe->bhls", queries_patch_num, keys_patch_num)
+        attn_patch_num = scale_patch_num * scores_patch_num
+        # series_patch_num: [B*C, H, patch_size, patch_size]
+        series_patch_num = self.dropout(torch.softmax(attn_patch_num, dim=-1))
+
+        # ---- Upsample both maps to full window size [B*C, H, L, L] ----
+        # Repeat each attention value patchsize times in both spatial dims
+        series_patch_size = repeat(
+            series_patch_size,
+            "b l m n -> b l (m repeat_m) (n repeat_n)",
+            repeat_m=patchsize,
+            repeat_n=patchsize,
+        )
+        # Tile the in-patch attention patch_num times in both spatial dims
+        series_patch_num = series_patch_num.repeat(1, 1, patch_num, patch_num)
+
+        # ---- Average over the channel dimension ----
+        # [B*C, H, L, L] -> [B, H, L, L]
+        series_patch_size = reduce(
+            series_patch_size, "(b reduce_b) l m n -> b l m n", "mean", reduce_b=self.n_features
+        )
+        series_patch_num = reduce(
+            series_patch_num, "(b reduce_b) l m n -> b l m n", "mean", reduce_b=self.n_features
+        )
+
+        return series_patch_size, series_patch_num
+
+
+class DCdetectorAttentionLayer(nn.Module):
+    """Attention layer wrapping the :class:`DACStructure`.
+
+    Projects input embeddings to queries/keys/values and delegates to the
+    inner dual-attention module.
+
+    Parameters
+    ----------
+    attention :
+        The inner :class:`DACStructure` module.
+
+    d_model :
+        Model dimension.
+
+    patch_sizes :
+        List of patch sizes used for multi-scale patching.
+
+    n_features :
+        Number of input features (channels).
+
+    n_heads :
+        Number of attention heads.
+
+    win_size :
+        Full window size (equal to n_steps).
+
+    d_keys :
+        Dimension of keys/queries per head. Defaults to ``d_model // n_heads``.
+
+    d_values :
+        Dimension of values per head. Defaults to ``d_model // n_heads``.
+
+    """
+
+    def __init__(
+        self,
+        attention: DACStructure,
+        d_model: int,
+        patch_sizes: list,
+        n_features: int,
+        n_heads: int,
+        win_size: int,
+        d_keys: int = None,
+        d_values: int = None,
+    ):
+        super().__init__()
+        d_keys = d_keys or (d_model // n_heads)
+        d_values = d_values or (d_model // n_heads)
+
+        self.norm = nn.LayerNorm(d_model)
+        self.inner_attention = attention
+        self.patch_sizes = patch_sizes
+        self.n_features = n_features
+        self.window_size = win_size
+        self.n_heads = n_heads
+
+        self.patch_query_projection = nn.Linear(d_model, d_keys * n_heads)
+        self.patch_key_projection = nn.Linear(d_model, d_keys * n_heads)
+        self.out_projection = nn.Linear(d_values * n_heads, d_model)
+        self.value_projection = nn.Linear(d_model, d_values * n_heads)
+
+    def forward(
+        self,
+        x_patch_size: torch.Tensor,
+        x_patch_num: torch.Tensor,
+        x_ori: torch.Tensor,
+        patch_index: int,
+        attn_mask=None,
+    ):
+        """
+        Parameters
+        ----------
+        x_patch_size : torch.Tensor, shape [B*C, patch_num, d_model]
+            Patch-size-based embedding of the input.
+
+        x_patch_num : torch.Tensor, shape [B*C, patch_size, d_model]
+            Patch-num-based embedding of the input.
+
+        x_ori : torch.Tensor, shape [B, L, d_model]
+            Window-level embedding of the full input.
+
+        patch_index : int
+            Index into patch_sizes for the current scale.
+
+        attn_mask :
+            Optional attention mask.
+
+        Returns
+        -------
+        series : torch.Tensor, shape [B, H, L, L]
+        prior : torch.Tensor, shape [B, H, L, L]
+
+        """
+        H = self.n_heads
+
+        # ---- Patch-size branch ----
+        B, L, _ = x_patch_size.shape
+        queries_patch_size = self.patch_query_projection(x_patch_size).view(B, L, H, -1)
+        keys_patch_size = self.patch_key_projection(x_patch_size).view(B, L, H, -1)
+
+        # ---- Patch-num branch ----
+        B, L, _ = x_patch_num.shape
+        queries_patch_num = self.patch_query_projection(x_patch_num).view(B, L, H, -1)
+        keys_patch_num = self.patch_key_projection(x_patch_num).view(B, L, H, -1)
+
+        # ---- Values from window-level embedding ----
+        B, L, _ = x_ori.shape
+        values = self.value_projection(x_ori).view(B, L, H, -1)
+
+        series, prior = self.inner_attention(
+            queries_patch_size,
+            queries_patch_num,
+            keys_patch_size,
+            keys_patch_num,
+            values,
+            patch_index,
+            attn_mask,
+        )
+        return series, prior
+
+
+class DCdetectorEncoder(nn.Module):
+    """Stack of :class:`DCdetectorAttentionLayer` modules.
+
+    Parameters
+    ----------
+    attn_layers :
+        List of attention layers to stack.
+
+    norm_layer :
+        Optional normalization layer applied after all attention layers.
+
+    """
+
+    def __init__(self, attn_layers: list, norm_layer=None):
+        super().__init__()
+        self.attn_layers = nn.ModuleList(attn_layers)
+        self.norm = norm_layer
+
+    def forward(
+        self,
+        x_patch_size: torch.Tensor,
+        x_patch_num: torch.Tensor,
+        x_ori: torch.Tensor,
+        patch_index: int,
+        attn_mask=None,
+    ):
+        """
+        Returns
+        -------
+        series_list : list of torch.Tensor
+            One attention map per encoder layer, each of shape [B, H, L, L].
+
+        prior_list : list of torch.Tensor
+            One attention map per encoder layer, each of shape [B, H, L, L].
+
+        """
+        series_list = []
+        prior_list = []
+        for attn_layer in self.attn_layers:
+            series, prior = attn_layer(x_patch_size, x_patch_num, x_ori, patch_index, attn_mask)
+            series_list.append(series)
+            prior_list.append(prior)
+        return series_list, prior_list
+
+
+class BackboneDCdetector(nn.Module):
+    """Backbone of the DCdetector model.
+
+    Implements the multi-scale dual-attention contrastive architecture from
+    :cite:`yang2023dcdetector`. For each patch size, the input is split into
+    two complementary patch views (patch-wise and in-patch), embedded, and then
+    passed through a shared encoder that returns the dual attention maps.
+    RevIN normalization is applied to the input before patching.
+
+    Parameters
+    ----------
+    n_steps :
+        The number of time steps (window size). Must be divisible by every
+        element of ``patch_sizes``.
+
+    n_features :
+        The number of input features (channels).
+
+    patch_sizes :
+        List of patch sizes for multi-scale patching (e.g. ``[3, 5, 7]``).
+        Each value must divide ``n_steps`` evenly.
+
+    d_model :
+        Dimension of the model embeddings.
+
+    n_heads :
+        Number of attention heads.
+
+    e_layers :
+        Number of encoder layers.
+
+    dropout :
+        Dropout rate applied in embeddings and attention.
+
+    """
+
+    def __init__(
+        self,
+        n_steps: int,
+        n_features: int,
+        patch_sizes: list,
+        d_model: int,
+        n_heads: int,
+        e_layers: int,
+        dropout: float,
+    ):
+        super().__init__()
+        self.patch_sizes = patch_sizes
+        self.n_features = n_features
+        self.n_steps = n_steps
+
+        # Reversible Instance Normalization
+        self.revin = RevIN(n_features)
+
+        # Per-patch-size embeddings (two views per scale)
+        self.embedding_patch_size = nn.ModuleList()
+        self.embedding_patch_num = nn.ModuleList()
+        for patchsize in patch_sizes:
+            self.embedding_patch_size.append(
+                DCdetectorDataEmbedding(patchsize, d_model, dropout)
+            )
+            self.embedding_patch_num.append(
+                DCdetectorDataEmbedding(n_steps // patchsize, d_model, dropout)
+            )
+
+        # Window-level embedding for the value branch
+        self.embedding_window_size = DCdetectorDataEmbedding(n_features, d_model, dropout)
+
+        # Shared encoder
+        self.encoder = DCdetectorEncoder(
+            [
+                DCdetectorAttentionLayer(
+                    DACStructure(
+                        n_steps,
+                        patch_sizes,
+                        n_features,
+                        mask_flag=False,
+                        attention_dropout=dropout,
+                        output_attention=True,
+                    ),
+                    d_model,
+                    patch_sizes,
+                    n_features,
+                    n_heads,
+                    n_steps,
+                )
+                for _ in range(e_layers)
+            ],
+            norm_layer=nn.LayerNorm(d_model),
+        )
+
+    def forward(self, x: torch.Tensor):
+        """
+        Parameters
+        ----------
+        x : torch.Tensor, shape [B, L, M]
+            Input time-series data (B=batch, L=n_steps, M=n_features).
+
+        Returns
+        -------
+        series_list : list of torch.Tensor
+            Flattened list of patch-wise attention maps (one per scale × layer).
+            Each tensor has shape ``[B, H, L, L]``.
+
+        prior_list : list of torch.Tensor
+            Flattened list of in-patch attention maps (one per scale × layer).
+            Each tensor has shape ``[B, H, L, L]``.
+
+        """
+        # RevIN normalization
+        x = self.revin(x, mode="norm")
+
+        # Window-level embedding (used as values in the attention layer)
+        x_ori = self.embedding_window_size(x)
+
+        series_patch_mean = []
+        prior_patch_mean = []
+
+        for patch_index, patchsize in enumerate(self.patch_sizes):
+            # ---- Patch-size view: [B, L, M] -> [B, M, L] -> [B*M, L//p, p] ----
+            x_patch_size = rearrange(x, "b l m -> b m l")
+            x_patch_size = rearrange(x_patch_size, "b m (n p) -> (b m) n p", p=patchsize)
+            x_patch_size = self.embedding_patch_size[patch_index](x_patch_size)
+
+            # ---- Patch-num view: [B, L, M] -> [B, M, L] -> [B*M, p, L//p] ----
+            x_patch_num = rearrange(x, "b l m -> b m l")
+            x_patch_num = rearrange(x_patch_num, "b m (p n) -> (b m) p n", p=patchsize)
+            x_patch_num = self.embedding_patch_num[patch_index](x_patch_num)
+
+            series, prior = self.encoder(x_patch_size, x_patch_num, x_ori, patch_index)
+            series_patch_mean.append(series)
+            prior_patch_mean.append(prior)
+
+        # Flatten the nested lists (each element was a list of e_layers tensors)
+        series_list = [item for sublist in series_patch_mean for item in sublist]
+        prior_list = [item for sublist in prior_patch_mean for item in sublist]
+
+        return series_list, prior_list
diff --git a/tests/anomaly_detection/dcdetector.py b/tests/anomaly_detection/dcdetector.py
new file mode 100644
index 00000000..51750325
--- /dev/null
+++ b/tests/anomaly_detection/dcdetector.py
@@ -0,0 +1,127 @@
+"""
+Test cases for DCdetector anomaly detection model.
# Created by Yiyuan Yang
# License: BSD-3-Clause

import os.path
import unittest

import pytest

from pypots.anomaly_detection import DCdetector
from pypots.nn.functional import calc_acc, calc_precision_recall_f1
from pypots.optim import Adam
from pypots.utils.logging import logger
from tests.global_test_config import (
    DATA,
    EPOCHS,
    DEVICE,
    TRAIN_SET,
    VAL_SET,
    TEST_SET,
    GENERAL_H5_TRAIN_SET_PATH,
    GENERAL_H5_VAL_SET_PATH,
    GENERAL_H5_TEST_SET_PATH,
    RESULT_SAVING_DIR_FOR_ANOMALY_DETECTION,
    check_tb_and_model_checkpoints_existence,
)


class TestDCdetector(unittest.TestCase):
    """End-to-end checks for the DCdetector anomaly detection model.

    The numeric prefixes in the test names (``test_0`` … ``test_4``) keep the
    cases in fit -> detect -> inspect -> save/load -> lazy-loading order, and
    all cases share the single model instance created in the class body.
    """

    logger.info("Running tests for an anomaly detection model DCdetector...")

    # Where logs and model checkpoints are written
    saving_path = os.path.join(RESULT_SAVING_DIR_FOR_ANOMALY_DETECTION, "DCdetector")
    model_save_name = "saved_dcdetector_model.pypots"

    # A custom Adam optimizer handed to the model
    optimizer = Adam(lr=0.001, weight_decay=1e-5)

    # The model under test. DATA["n_steps"] has to be divisible by every patch
    # size; the original note gives n_steps=8 (N_STEPS + N_PRED_STEPS = 6+2)
    # -- TODO confirm against tests.global_test_config.
    dcdetector = DCdetector(
        n_steps=DATA["n_steps"],
        n_features=DATA["n_features"],
        anomaly_rate=DATA["anomaly_rate"],
        patch_sizes=[2, 4],
        d_model=16,
        n_heads=2,
        e_layers=2,
        dropout=0.1,
        epochs=EPOCHS,
        saving_path=saving_path,
        optimizer=optimizer,
        device=DEVICE,
    )

    @staticmethod
    def _score_against_test_set(detection_results):
        """Score predictions against the flattened ``TEST_SET`` anomaly labels.

        Returns the tuple ``(accuracy, precision, recall, f1)``.
        """
        anomaly_labels = TEST_SET["anomaly_y"].flatten()
        predictions = detection_results["anomaly_detection"]
        accuracy = calc_acc(predictions, anomaly_labels)
        precision, recall, f1 = calc_precision_recall_f1(predictions, anomaly_labels)
        return accuracy, precision, recall, f1

    @pytest.mark.xdist_group(name="anomaly-detection-dcdetector")
    def test_0_fit(self):
        """Train the model on the in-memory training and validation sets."""
        self.dcdetector.fit(TRAIN_SET, VAL_SET)

    @pytest.mark.xdist_group(name="anomaly-detection-dcdetector")
    def test_1_detect(self):
        """Run detection on the in-memory test set and log the metrics."""
        detection_results = self.dcdetector.predict(TEST_SET)
        accuracy, precision, recall, f1 = self._score_against_test_set(detection_results)
        logger.info(
            f"DCdetector Accuracy: {accuracy}, F1: {f1}, Precision: {precision}, Recall: {recall}"
        )

    @pytest.mark.xdist_group(name="anomaly-detection-dcdetector")
    def test_2_parameters(self):
        """Verify that the key attributes are populated after training."""
        detector = self.dcdetector
        for attr_name in ("model", "optimizer"):
            assert getattr(detector, attr_name, None) is not None
        assert hasattr(detector, "best_loss")
        # An infinite best_loss would mean no loss value was ever recorded.
        self.assertNotEqual(detector.best_loss, float("inf"))
        assert getattr(detector, "best_model_dict", None) is not None

    @pytest.mark.xdist_group(name="anomaly-detection-dcdetector")
    def test_3_saving_path(self):
        """Check the saving directory, then round-trip the model through disk."""
        assert os.path.exists(self.saving_path), f"file {self.saving_path} does not exist"
        check_tb_and_model_checkpoints_existence(self.dcdetector)

        # Save the model to disk, then load it back from the same file
        checkpoint_path = os.path.join(self.saving_path, self.model_save_name)
        self.dcdetector.save(checkpoint_path)
        self.dcdetector.load(checkpoint_path)

    @pytest.mark.xdist_group(name="anomaly-detection-dcdetector")
    def test_4_lazy_loading(self):
        """Train and predict from HDF5 file paths instead of in-memory data."""
        self.dcdetector.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH)
        detection_results = self.dcdetector.predict(GENERAL_H5_TEST_SET_PATH)
        accuracy, precision, recall, f1 = self._score_against_test_set(detection_results)
        logger.info(
            f"Lazy-loading DCdetector Accuracy: {accuracy}, F1: {f1}, "
            f"Precision: {precision}, Recall: {recall}"
        )


if __name__ == "__main__":
    unittest.main()