Skip to content

Commit 5bb4589

Browse files
committed
feat: add ColumnsSelectorTransformer
1 parent e8fb007 commit 5bb4589

4 files changed

Lines changed: 101 additions & 11 deletions

File tree

pysatl_cpd/benchmark/core/benchmark_logger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
from typing import Any
88

9-
__author__ = "PySATL contributors"
9+
__author__ = "Danil Totmyanin"
1010
__copyright__ = "Copyright (c) 2026 PySATL project"
1111
__license__ = "SPDX-License-Identifier: MIT"
1212

pysatl_cpd/core/algorithm_entry.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
Container for benchmark algorithm execution entries.
44
"""
55

6-
__author__ = "PySATL contributors"
6+
__author__ = "Danil Totmyanin"
77
__copyright__ = "Copyright (c) 2026 PySATL project"
88
__license__ = "SPDX-License-Identifier: MIT"
99

1010
from collections.abc import Sequence
1111
from dataclasses import dataclass
12+
from typing import Any
1213

1314
from pysatl_cpd.core.data_transformers.idata_transformer import IDataTransformer
1415
from pysatl_cpd.core.online.ionline_algorithm import OnlineAlgorithm, OnlineAlgorithmConfiguration, OnlineAlgorithmState
@@ -36,7 +37,7 @@ class AlgorithmEntry[DataT, ConfigT: OnlineAlgorithmConfiguration, StateT: Onlin
3637

3738
algorithm: OnlineAlgorithm[DataT, ConfigT, StateT]
3839
thresholds: Sequence[float]
39-
transformer: IDataTransformer | None = None
40+
transformer: IDataTransformer[Any, Any] | None = None
4041

4142
@property
4243
def full_name(self) -> str:
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# -*- coding: ascii -*-
2+
3+
"""
4+
Columns Selector Transformer Implementation.
5+
6+
This module provides a transformer that allows selecting specific columns
7+
from multivariate time series data.
8+
"""
9+
10+
__author__ = "Danil Totmyanin"
11+
__copyright__ = "Copyright (c) 2026 PySATL project"
12+
__license__ = "SPDX-License-Identifier: MIT"
13+
14+
import numpy as np
15+
16+
from pysatl_cpd.core.data_providers.idata_provider import DataProvider
17+
from pysatl_cpd.core.data_providers.numpy_data_provider import (
18+
NDArrayMultivariateProvider,
19+
NDArrayUnivariateProvider,
20+
)
21+
from pysatl_cpd.core.data_transformers.idata_transformer import IDataTransformer
22+
23+
24+
class ColumnsSelectorTransformer(IDataTransformer[np.ndarray, np.ndarray | float]):
25+
"""
26+
Transformer for selecting specific columns from multivariate data.
27+
28+
If a single integer index is provided, it transforms multivariate data
29+
into univariate data. If a list of indices is provided, it returns
30+
multivariate data containing only the specified columns.
31+
32+
Parameters
33+
----------
34+
columns : list[int] or int
35+
Indices of columns to select from the input multivariate array.
36+
"""
37+
38+
def __init__(self, columns: list[int] | int) -> None:
39+
self.cols = columns
40+
41+
@property
42+
def name(self) -> str:
43+
"""
44+
Return a unique name including selected column indices.
45+
46+
Returns
47+
-------
48+
str
49+
Formatted name like 'Col_0' or 'Cols_0_2_3'.
50+
"""
51+
if isinstance(self.cols, int):
52+
return f"Col_{self.cols}"
53+
cols_str = "_".join(map(str, self.cols))
54+
return f"Cols_{cols_str}"
55+
56+
def transform(self, provider: DataProvider[np.ndarray]) -> DataProvider[np.ndarray | float]:
57+
"""
58+
Extract selected columns and wrap into a new NumPy data provider.
59+
60+
Parameters
61+
----------
62+
provider : DataProvider[np.ndarray]
63+
Multivariate data provider yielding 1-D NumPy arrays.
64+
65+
Returns
66+
-------
67+
DataProvider[Any]
68+
NDArrayUnivariateProvider if `columns` is int,
69+
NDArrayMultivariateProvider if `columns` is list[int].
70+
71+
Raises
72+
------
73+
ValueError
74+
If the data provided by the source is not 2-dimensional.
75+
"""
76+
raw_nd_data = np.array(list(provider))
77+
78+
if raw_nd_data.ndim < 2:
79+
raise ValueError(
80+
f"ColumnsSelectorTransformer expects 2D data, "
81+
f"got {raw_nd_data.ndim}D data from provider '{provider.name}'."
82+
)
83+
84+
cols_data = raw_nd_data[:, self.cols]
85+
86+
new_provider_name = f"{provider.name}_{self.name}"
87+
88+
if isinstance(self.cols, int):
89+
return NDArrayUnivariateProvider(data=cols_data, name=new_provider_name)
90+
91+
return NDArrayMultivariateProvider(data=cols_data, name=new_provider_name)

pysatl_cpd/core/data_transformers/idata_transformer.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,16 @@
77
feeding data into change-point detection algorithms.
88
"""
99

10-
__author__ = "PySATL contributors"
10+
__author__ = "Danil Totmyanin"
1111
__copyright__ = "Copyright (c) 2026 PySATL project"
1212
__license__ = "SPDX-License-Identifier: MIT"
1313

1414
from abc import ABC, abstractmethod
15-
from typing import Any
1615

1716
from pysatl_cpd.core.data_providers.idata_provider import DataProvider
1817

1918

20-
class IDataTransformer(ABC):
19+
class IDataTransformer[DataInT, DataOutT](ABC):
2120
"""
2221
Abstract base class for data transformers.
2322
@@ -27,24 +26,23 @@ class IDataTransformer(ABC):
2726
"""
2827

2928
@abstractmethod
30-
def transform(self, provider: DataProvider[Any]) -> DataProvider[Any]:
29+
def transform(self, provider: DataProvider[DataInT]) -> DataProvider[DataOutT]:
3130
"""
3231
Apply transformation to the given data provider.
3332
3433
Parameters
3534
----------
36-
provider : DataProvider[Any]
35+
provider : DataProvider[DataInT]
3736
The source data provider.
3837
3938
Returns
4039
-------
41-
DataProvider[Any]
40+
DataProvider[DataOut]
4241
A new data provider yielding transformed observations.
4342
"""
4443
raise NotImplementedError
4544

4645
@property
47-
@abstractmethod
4846
def name(self) -> str:
4947
"""
5048
Return the human-readable name of the transformer.
@@ -54,7 +52,7 @@ def name(self) -> str:
5452
str
5553
Transformer identifier used for logging and caching.
5654
"""
57-
raise NotImplementedError
55+
return type(self).__name__
5856

5957
def __hash__(self) -> int:
6058
"""

0 commit comments

Comments
 (0)