Skip to content

Commit fec0b86

Browse files
committed
tests: data transformers
1 parent a8ee9f7 commit fec0b86

3 files changed

Lines changed: 182 additions & 0 deletions

File tree

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# -*- coding: ascii -*-
2+
3+
"""
4+
Dummy data transformer implementation for testing.
5+
6+
This module provides a minimal concrete implementation of IDataTransformer
7+
used strictly to test the default behaviors of the abstract base class.
8+
"""
9+
10+
__author__ = "Danil Totmyanin"
11+
__copyright__ = "Copyright (c) 2026 PySATL project"
12+
__license__ = "SPDX-License-Identifier: MIT"
13+
14+
from typing import Any
15+
16+
from pysatl_cpd.core.data_providers.idata_provider import DataProvider
17+
from pysatl_cpd.core.data_transformers.idata_transformer import IDataTransformer
18+
19+
20+
class DummyTransformer(IDataTransformer[Any, Any]):
    """
    Pass-through IDataTransformer used only by the test suite.

    The class overrides nothing except `transform`, so tests can observe
    whatever the abstract base class provides by default.
    """

    def transform(self, provider: DataProvider[Any]) -> DataProvider[Any]:
        """
        Hand the given provider back untouched.

        Parameters
        ----------
        provider : DataProvider[Any]
            Provider to be "transformed".

        Returns
        -------
        DataProvider[Any]
            The very same object that was passed in.
        """
        # No wrapping and no copy: the identity transform is all the tests need.
        return provider
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# -*- coding: ascii -*-
2+
3+
"""
4+
Tests for Data Transformers.
5+
6+
Covers ColumnsSelectorTransformer naming and column-selection logic.
7+
"""
8+
9+
__author__ = "Danil Totmyanin"
10+
__copyright__ = "Copyright (c) 2026 PySATL project"
11+
__license__ = "SPDX-License-Identifier: MIT"
12+
13+
import numpy as np
14+
import pytest
15+
16+
from pysatl_cpd.core.data_providers.numpy_data_provider import (
17+
NDArrayMultivariateProvider,
18+
NDArrayUnivariateProvider,
19+
)
20+
from pysatl_cpd.core.data_transformers.columns_selector_transformer import (
21+
ColumnsSelectorTransformer,
22+
)
23+
24+
25+
class TestColumnsSelectorTransformer:
    """Tests for ColumnsSelectorTransformer logic and naming."""

    def test_name_single_column(self) -> None:
        """Transformer name should be formatted as 'Col_X' for a single int."""
        transformer = ColumnsSelectorTransformer(columns=2)
        assert transformer.name == "Col_2"

    def test_name_multiple_columns(self) -> None:
        """Transformer name should be formatted as 'Cols_X_Y' for a list of ints."""
        transformer = ColumnsSelectorTransformer(columns=[0, 2, 3])
        assert transformer.name == "Cols_0_2_3"

    def test_transform_int_to_univariate(self) -> None:
        """Selecting a single int column should yield a Univariate provider."""
        data: np.ndarray = np.array(
            [
                [1.0, 2.0, 3.0],
                [4.0, 5.0, 6.0],
                [7.0, 8.0, 9.0],
            ]
        )
        provider = NDArrayMultivariateProvider(data=data, name="test_data")
        transformer = ColumnsSelectorTransformer(columns=1)

        result_provider = transformer.transform(provider)

        # Check type and name
        assert isinstance(result_provider, NDArrayUnivariateProvider)
        assert result_provider.name == "test_data_Col_1"

        # Check extracted data (column index 1 -> [2.0, 5.0, 8.0])
        result_data: list[float] = list(result_provider)
        np.testing.assert_array_equal(result_data, [2.0, 5.0, 8.0])

    def test_transform_list_to_multivariate(self) -> None:
        """Selecting a list of columns should yield a Multivariate provider."""
        data: np.ndarray = np.array(
            [
                [1.0, 2.0, 3.0, 4.0],
                [5.0, 6.0, 7.0, 8.0],
            ]
        )
        provider = NDArrayMultivariateProvider(data=data, name="multidataset")
        transformer = ColumnsSelectorTransformer(columns=[0, 3])

        result_provider = transformer.transform(provider)

        # Check type and name
        assert isinstance(result_provider, NDArrayMultivariateProvider)
        assert result_provider.name == "multidataset_Cols_0_3"

        # Check extracted data (columns 0 and 3)
        result_data: list[np.ndarray] = list(result_provider)
        expected_data: list[np.ndarray] = [
            np.array([1.0, 4.0]),
            np.array([5.0, 8.0]),
        ]

        # Length is asserted first so the zip below cannot silently truncate.
        assert len(result_data) == 2
        for actual_row, expected_row in zip(result_data, expected_data):
            np.testing.assert_array_equal(actual_row, expected_row)

    def test_transform_raises_value_error_on_1d_data(self) -> None:
        """Attempting to select columns from 1D data should raise ValueError."""
        # Function-scope import: only this test needs `re`.
        import re

        data: np.ndarray = np.array([1.0, 2.0, 3.0])
        provider = NDArrayUnivariateProvider(data=data, name="1d_data")
        transformer = ColumnsSelectorTransformer(columns=0)

        expected_msg = "ColumnsSelectorTransformer expects 2D data, got 1D data from provider '1d_data'."
        # `match` is treated as a regular expression, so the literal message is
        # escaped; otherwise '.' would act as a wildcard and loosen the check.
        with pytest.raises(ValueError, match=re.escape(expected_msg)):
            transformer.transform(provider)  # type: ignore[arg-type]

    def test_transform_raises_index_error_on_out_of_bounds(self) -> None:
        """Passing an out-of-bounds column index should propagate an IndexError from NumPy."""
        data: np.ndarray = np.array(
            [
                [1.0, 2.0],
                [3.0, 4.0],
            ]
        )
        provider = NDArrayMultivariateProvider(data=data, name="data")

        # Array only has columns 0 and 1, index 5 is out of bounds
        transformer = ColumnsSelectorTransformer(columns=5)

        with pytest.raises(IndexError):
            transformer.transform(provider)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# -*- coding: ascii -*-
2+
3+
"""
4+
Tests for Data Transformers.
5+
6+
Covers the default behaviors of the IDataTransformer base class.
7+
"""
8+
9+
__author__ = "Danil Totmyanin"
10+
__copyright__ = "Copyright (c) 2026 PySATL project"
11+
__license__ = "SPDX-License-Identifier: MIT"
12+
13+
14+
from tests.mocks.core.data_transformers.simple import DummyTransformer
15+
16+
17+
class TestIDataTransformer:
    """Checks the defaults inherited from the abstract IDataTransformer base."""

    def test_default_name_is_class_name(self) -> None:
        """Without an override, `name` equals the concrete class name."""
        dummy = DummyTransformer()
        reported_name = dummy.name
        assert reported_name == "DummyTransformer"

    def test_default_hash_is_hash_of_name(self) -> None:
        """Without an override, hashing the transformer hashes its name."""
        dummy = DummyTransformer()
        expected_hash = hash("DummyTransformer")
        assert hash(dummy) == expected_hash

0 commit comments

Comments
 (0)