Skip to content

Commit a8ee9f7

Browse files
committed
tests: data transformers logic in BenchmarkExecutor
1 parent cb51df2 commit a8ee9f7

2 files changed

Lines changed: 226 additions & 0 deletions

File tree

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# -*- coding: ascii -*-
2+
3+
"""
4+
Mock data transformer implementations for testing.
5+
6+
This module provides mock implementations of IDataTransformer used for testing
7+
the transformation pipeline in benchmark execution and algorithm evaluation.
8+
"""
9+
10+
__author__ = "Danil Totmyanin"
11+
__copyright__ = "Copyright (c) 2026 PySATL project"
12+
__license__ = "SPDX-License-Identifier: MIT"
13+
14+
from typing import Any
15+
16+
from pysatl_cpd.analysis.labeled_data import LabeledData
17+
from pysatl_cpd.core.data_providers.idata_provider import DataProvider
18+
from pysatl_cpd.core.data_transformers.idata_transformer import IDataTransformer
19+
20+
21+
class MockDataTransformer(IDataTransformer[float, float]):
    """
    Mock data transformer for testing benchmark execution.

    This transformer adds a specified constant value to every observation
    in the dataset and keeps track of how many times the `transform` method
    was applied. It wraps the transformed data back into a `LabeledData` instance.

    Two instances compare equal when their name and ``add_value`` match, which
    keeps ``==`` consistent with ``__hash__``. ``call_count`` is bookkeeping
    only and takes part in neither.

    Parameters
    ----------
    name : str, default="MockTransform"
        The string identifier for the transformer.
    add_value : float, default=1.0
        The numeric value to add to each observation.

    Attributes
    ----------
    add_value : float
        The constant added to every observation.
    call_count : int
        Number of times `transform` has been invoked on this instance.
    """

    def __init__(self, name: str = "MockTransform", add_value: float = 1.0) -> None:
        self._name = name
        self.add_value = add_value
        self.call_count = 0

    @property
    def name(self) -> str:
        """
        Return the name of the mock transformer.

        Returns
        -------
        str
            The identifier of this transformer instance.
        """
        return self._name

    def __eq__(self, other: object) -> bool:
        """
        Compare two mock transformers by configuration.

        Parameters
        ----------
        other : object
            The object to compare against.

        Returns
        -------
        bool
            True when `other` is a `MockDataTransformer` with the same name
            and `add_value`. `NotImplemented` is returned for foreign types so
            that Python can try the reflected comparison.
        """
        if not isinstance(other, MockDataTransformer):
            return NotImplemented
        # Must cover exactly the fields used by __hash__ so that equal objects
        # always hash equal.
        return (self._name, self.add_value) == (other._name, other.add_value)

    def __hash__(self) -> int:
        """
        Return a hash based on the transformer's properties.

        Used to uniquely identify the pipeline configuration in the cache.

        Returns
        -------
        int
            Hash value representing the transformer configuration.
        """
        # NOTE(review): str hashes are salted per interpreter run unless
        # PYTHONHASHSEED is fixed, so this value is only stable within one
        # process - confirm nothing persists it across runs.
        return hash((self._name, self.add_value))

    def transform(self, provider: DataProvider[float]) -> DataProvider[float]:
        """
        Transform the data by adding a constant value to each element.

        Parameters
        ----------
        provider : DataProvider[float]
            The original data provider.

        Returns
        -------
        DataProvider[float]
            A new `LabeledData` instance containing the transformed values.
            Its name is ``"<provider name>_<transformer name>"``.
        """
        # Counted before any work so tests can observe every invocation.
        self.call_count += 1

        # Transform data
        new_data: list[float] = [float(x) + self.add_value for x in provider]

        # Preserve change points if the provider has them: prefer the
        # `change_points` attribute, fall back to `change_point`, else use [].
        change_points: Any = getattr(provider, "change_points", getattr(provider, "change_point", []))

        return LabeledData(raw_data=new_data, change_points=change_points, name=f"{provider.name}_{self.name}")

tests/unit/benchmark/core/test_benchmark_executor.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pysatl_cpd.core.online.online_cpd_solver import OnlineCpdSolver
2626
from pysatl_cpd.core.online.online_detection_trace import OnlineDetectionTrace
2727
from tests.mocks.algorithms.online.simple import MockOnlineAlgorithm
28+
from tests.mocks.core.data_transformers.data_transformer import MockDataTransformer
2829

2930

3031
def _make_provider(
@@ -519,3 +520,139 @@ def test_multiple_thresholds_create_separate_pickle_files(self, tmp_path: Path)
519520
with open(registry_path, encoding="utf-8") as f:
520521
rows: list[dict[str, str]] = list(csv.DictReader(f))
521522
assert len(rows) == 3
523+
524+
525+
# ---------------------------------------------------------------------------
526+
# 6. Data Transformers
527+
# ---------------------------------------------------------------------------
528+
class TestBenchmarkExecutorTransformers:
    """Tests for the DataTransformer integration in BenchmarkExecutor."""

    def test_transformer_modifies_data_passed_to_algorithm(self) -> None:
        """Executor should pass transformed data, not raw data, to the solver."""
        algo = MockOnlineAlgorithm[float](name="A", return_sequence=[0.0])
        transformer = MockDataTransformer(name="T1", add_value=5.0)
        entry = AlgorithmEntry(algorithm=algo, thresholds=[1.0], transformer=transformer)

        # Original provider with zeros
        provider: LabeledData[float] = LabeledData(raw_data=[0.0, 0.0, 0.0], change_points=[], name="data")
        solver: OnlineCpdSolver = OnlineCpdSolver()

        executor: BenchmarkExecutor[float] = BenchmarkExecutor(
            entries=[entry],
            providers=[provider],
            solver=solver,
        )
        executor.execute()

        # The algorithm should have received [5.0, 5.0, 5.0]
        history: list[float] = algo.get_call_history()
        assert history == [5.0, 5.0, 5.0]

    def test_record_metadata_uses_transformer_name_and_hash(self) -> None:
        """Benchmark record should inherit the full name and hash from the Entry."""
        algo = MockOnlineAlgorithm[float](name="BaseAlgo", return_sequence=[0.0])
        transformer = MockDataTransformer(name="MyTF", add_value=1.0)
        entry = AlgorithmEntry(algorithm=algo, thresholds=[1.0], transformer=transformer)

        provider: LabeledData[float] = _make_provider(3, name="d1")
        solver: OnlineCpdSolver = OnlineCpdSolver()

        executor: BenchmarkExecutor[float] = BenchmarkExecutor(
            entries=[entry],
            providers=[provider],
            solver=solver,
        )
        results = executor.execute()
        record: BenchmarkRecord = results[0][0]

        # Name should be combined
        assert record.algorithm == "BaseAlgo_MyTF"
        assert record.algorithm == entry.full_name

        # Hash should match the entry's composite hash
        assert record.configuration_hash == entry.full_hash

    def test_caching_separates_different_transformers(self, tmp_path: Path) -> None:
        """Using the same algorithm but different transformers should create separate cache records."""
        algo = MockOnlineAlgorithm[float](name="A", return_sequence=[0.0])

        entry_clean = AlgorithmEntry(algorithm=algo, thresholds=[1.0], transformer=None)
        entry_transformed = AlgorithmEntry(
            algorithm=algo, thresholds=[1.0], transformer=MockDataTransformer(name="T1", add_value=2.0)
        )

        provider: LabeledData[float] = _make_provider(3, name="data")
        solver: OnlineCpdSolver = OnlineCpdSolver()

        executor: BenchmarkExecutor[float] = BenchmarkExecutor(
            entries=[entry_clean, entry_transformed],
            providers=[provider],
            solver=solver,
            dump_dir=tmp_path,
        )
        executor.execute()

        # Should produce two distinct pickle files
        pkl_files: list[Path] = list(tmp_path.glob("*.pkl"))
        assert len(pkl_files) == 2

        # Names of the files should reflect the different algorithm representations.
        # "A_" is a substring of "A_T1_", so testing the joined names would let the
        # transformed file satisfy both checks; inspect each file on its own instead.
        transformed_files: list[Path] = [f for f in pkl_files if "A_T1_" in f.name]
        clean_files: list[Path] = [f for f in pkl_files if "A_T1_" not in f.name]
        assert len(transformed_files) == 1
        assert len(clean_files) == 1
        assert "A_" in clean_files[0].name

    def test_transformer_is_called_even_on_cache_hit(self, tmp_path: Path) -> None:
        """Transformer should be applied before checking cache, incrementing its call count."""
        transformer = MockDataTransformer(name="T1", add_value=1.0)
        provider: LabeledData[float] = _make_provider(3, name="data")
        solver: OnlineCpdSolver = OnlineCpdSolver()

        # First run to populate cache
        algo1 = MockOnlineAlgorithm[float](name="A", return_sequence=[0.0])
        entry1 = AlgorithmEntry(algorithm=algo1, thresholds=[1.0], transformer=transformer)
        exec1: BenchmarkExecutor[float] = BenchmarkExecutor([entry1], [provider], solver, tmp_path)
        exec1.execute()

        assert transformer.call_count == 1
        assert len(algo1.get_call_history()) == 3

        # Second run with cache hit (using a fresh algorithm instance to verify it doesn't run)
        algo2 = MockOnlineAlgorithm[float](name="A", return_sequence=[0.0])
        entry2 = AlgorithmEntry(algorithm=algo2, thresholds=[1.0], transformer=transformer)
        exec2: BenchmarkExecutor[float] = BenchmarkExecutor([entry2], [provider], solver, tmp_path)
        exec2.execute()

        # Transformer is still called during iteration
        assert transformer.call_count == 2

        # But the solver/algorithm was skipped due to cache hit
        assert len(algo2.get_call_history()) == 0

    def test_multiple_entries_mixed_transformers(self) -> None:
        """Executor should properly route data when processing mixed transformer configurations."""
        algo1 = MockOnlineAlgorithm[float](name="A", return_sequence=[0.0])
        algo2 = MockOnlineAlgorithm[float](name="A", return_sequence=[0.0])
        algo3 = MockOnlineAlgorithm[float](name="A", return_sequence=[0.0])

        entry_none = AlgorithmEntry(algorithm=algo1, thresholds=[1.0])
        entry_t1 = AlgorithmEntry(algorithm=algo2, thresholds=[1.0], transformer=MockDataTransformer("T1", 10.0))
        entry_t2 = AlgorithmEntry(algorithm=algo3, thresholds=[1.0], transformer=MockDataTransformer("T2", 20.0))

        # Provider yields [1.0, 1.0]
        provider: LabeledData[float] = _make_provider(2, name="data")
        solver: OnlineCpdSolver = OnlineCpdSolver()

        executor: BenchmarkExecutor[float] = BenchmarkExecutor(
            entries=[entry_none, entry_t1, entry_t2],
            providers=[provider],
            solver=solver,
        )
        results = executor.execute()

        assert len(results) == 3

        # Verify specific algorithm histories to ensure they received correct streams
        assert algo1.get_call_history() == [1.0, 1.0]  # No transformation
        assert algo2.get_call_history() == [11.0, 11.0]  # 1.0 + 10.0
        assert algo3.get_call_history() == [21.0, 21.0]  # 1.0 + 20.0

0 commit comments

Comments
 (0)