|
25 | 25 | from pysatl_cpd.core.online.online_cpd_solver import OnlineCpdSolver |
26 | 26 | from pysatl_cpd.core.online.online_detection_trace import OnlineDetectionTrace |
27 | 27 | from tests.mocks.algorithms.online.simple import MockOnlineAlgorithm |
| 28 | +from tests.mocks.core.data_transformers.data_transformer import MockDataTransformer |
28 | 29 |
|
29 | 30 |
|
30 | 31 | def _make_provider( |
@@ -519,3 +520,139 @@ def test_multiple_thresholds_create_separate_pickle_files(self, tmp_path: Path) |
519 | 520 | with open(registry_path, encoding="utf-8") as f: |
520 | 521 | rows: list[dict[str, str]] = list(csv.DictReader(f)) |
521 | 522 | assert len(rows) == 3 |
| 523 | + |
| 524 | + |
| 525 | +# --------------------------------------------------------------------------- |
| 526 | +# 6. Data Transformers |
| 527 | +# --------------------------------------------------------------------------- |
| 528 | +class TestBenchmarkExecutorTransformers: |
| 529 | + """Tests for the DataTransformer integration in BenchmarkExecutor.""" |
| 530 | + |
| 531 | + def test_transformer_modifies_data_passed_to_algorithm(self) -> None: |
| 532 | + """Executor should pass transformed data, not raw data, to the solver.""" |
| 533 | + algo = MockOnlineAlgorithm[float](name="A", return_sequence=[0.0]) |
| 534 | + transformer = MockDataTransformer(name="T1", add_value=5.0) |
| 535 | + entry = AlgorithmEntry(algorithm=algo, thresholds=[1.0], transformer=transformer) |
| 536 | + |
| 537 | + # Original provider with zeros |
| 538 | + provider: LabeledData[float] = LabeledData(raw_data=[0.0, 0.0, 0.0], change_points=[], name="data") |
| 539 | + solver: OnlineCpdSolver = OnlineCpdSolver() |
| 540 | + |
| 541 | + executor: BenchmarkExecutor[float] = BenchmarkExecutor( |
| 542 | + entries=[entry], |
| 543 | + providers=[provider], |
| 544 | + solver=solver, |
| 545 | + ) |
| 546 | + executor.execute() |
| 547 | + |
| 548 | + # The algorithm should have received [5.0, 5.0, 5.0] |
| 549 | + history: list[float] = algo.get_call_history() |
| 550 | + assert history == [5.0, 5.0, 5.0] |
| 551 | + |
| 552 | + def test_record_metadata_uses_transformer_name_and_hash(self) -> None: |
| 553 | + """Benchmark record should inherit the full name and hash from the Entry.""" |
| 554 | + algo = MockOnlineAlgorithm[float](name="BaseAlgo", return_sequence=[0.0]) |
| 555 | + transformer = MockDataTransformer(name="MyTF", add_value=1.0) |
| 556 | + entry = AlgorithmEntry(algorithm=algo, thresholds=[1.0], transformer=transformer) |
| 557 | + |
| 558 | + provider: LabeledData[float] = _make_provider(3, name="d1") |
| 559 | + solver: OnlineCpdSolver = OnlineCpdSolver() |
| 560 | + |
| 561 | + executor: BenchmarkExecutor[float] = BenchmarkExecutor( |
| 562 | + entries=[entry], |
| 563 | + providers=[provider], |
| 564 | + solver=solver, |
| 565 | + ) |
| 566 | + results = executor.execute() |
| 567 | + record: BenchmarkRecord = results[0][0] |
| 568 | + |
| 569 | + # Name should be combined |
| 570 | + assert record.algorithm == "BaseAlgo_MyTF" |
| 571 | + assert record.algorithm == entry.full_name |
| 572 | + |
| 573 | + # Hash should match the entry's composite hash |
| 574 | + assert record.configuration_hash == entry.full_hash |
| 575 | + |
| 576 | + def test_caching_separates_different_transformers(self, tmp_path: Path) -> None: |
| 577 | + """Using the same algorithm but different transformers should create separate cache records.""" |
| 578 | + algo = MockOnlineAlgorithm[float](name="A", return_sequence=[0.0]) |
| 579 | + |
| 580 | + entry_clean = AlgorithmEntry(algorithm=algo, thresholds=[1.0], transformer=None) |
| 581 | + entry_transformed = AlgorithmEntry( |
| 582 | + algorithm=algo, thresholds=[1.0], transformer=MockDataTransformer(name="T1", add_value=2.0) |
| 583 | + ) |
| 584 | + |
| 585 | + provider: LabeledData[float] = _make_provider(3, name="data") |
| 586 | + solver: OnlineCpdSolver = OnlineCpdSolver() |
| 587 | + |
| 588 | + executor: BenchmarkExecutor[float] = BenchmarkExecutor( |
| 589 | + entries=[entry_clean, entry_transformed], |
| 590 | + providers=[provider], |
| 591 | + solver=solver, |
| 592 | + dump_dir=tmp_path, |
| 593 | + ) |
| 594 | + executor.execute() |
| 595 | + |
| 596 | + # Should produce two distinct pickle files |
| 597 | + pkl_files: list[Path] = list(tmp_path.glob("*.pkl")) |
| 598 | + assert len(pkl_files) == 2 |
| 599 | + |
| 600 | + # Names of the files should reflect the different algorithm representations |
| 601 | + file_names: str = " ".join(f.name for f in pkl_files) |
| 602 | + assert "A_" in file_names |
| 603 | + assert "A_T1_" in file_names |
| 604 | + |
| 605 | + def test_transformer_is_called_even_on_cache_hit(self, tmp_path: Path) -> None: |
| 606 | + """Transformer should be applied before checking cache, incrementing its call count.""" |
| 607 | + transformer = MockDataTransformer(name="T1", add_value=1.0) |
| 608 | + provider: LabeledData[float] = _make_provider(3, name="data") |
| 609 | + solver: OnlineCpdSolver = OnlineCpdSolver() |
| 610 | + |
| 611 | + # First run to populate cache |
| 612 | + algo1 = MockOnlineAlgorithm[float](name="A", return_sequence=[0.0]) |
| 613 | + entry1 = AlgorithmEntry(algorithm=algo1, thresholds=[1.0], transformer=transformer) |
| 614 | + exec1: BenchmarkExecutor[float] = BenchmarkExecutor([entry1], [provider], solver, tmp_path) |
| 615 | + exec1.execute() |
| 616 | + |
| 617 | + assert transformer.call_count == 1 |
| 618 | + assert len(algo1.get_call_history()) == 3 |
| 619 | + |
| 620 | + # Second run with cache hit (using a fresh algorithm instance to verify it doesn't run) |
| 621 | + algo2 = MockOnlineAlgorithm[float](name="A", return_sequence=[0.0]) |
| 622 | + entry2 = AlgorithmEntry(algorithm=algo2, thresholds=[1.0], transformer=transformer) |
| 623 | + exec2: BenchmarkExecutor[float] = BenchmarkExecutor([entry2], [provider], solver, tmp_path) |
| 624 | + exec2.execute() |
| 625 | + |
| 626 | + # Transformer is still called during iteration |
| 627 | + assert transformer.call_count == 2 |
| 628 | + |
| 629 | + # But the solver/algorithm was skipped due to cache hit |
| 630 | + assert len(algo2.get_call_history()) == 0 |
| 631 | + |
| 632 | + def test_multiple_entries_mixed_transformers(self) -> None: |
| 633 | + """Executor should properly route data when processing mixed transformer configurations.""" |
| 634 | + algo1 = MockOnlineAlgorithm[float](name="A", return_sequence=[0.0]) |
| 635 | + algo2 = MockOnlineAlgorithm[float](name="A", return_sequence=[0.0]) |
| 636 | + algo3 = MockOnlineAlgorithm[float](name="A", return_sequence=[0.0]) |
| 637 | + |
| 638 | + entry_none = AlgorithmEntry(algorithm=algo1, thresholds=[1.0]) |
| 639 | + entry_t1 = AlgorithmEntry(algorithm=algo2, thresholds=[1.0], transformer=MockDataTransformer("T1", 10.0)) |
| 640 | + entry_t2 = AlgorithmEntry(algorithm=algo3, thresholds=[1.0], transformer=MockDataTransformer("T2", 20.0)) |
| 641 | + |
| 642 | + # Provider yields [1.0, 1.0] |
| 643 | + provider: LabeledData[float] = _make_provider(2, name="data") |
| 644 | + solver: OnlineCpdSolver = OnlineCpdSolver() |
| 645 | + |
| 646 | + executor: BenchmarkExecutor[float] = BenchmarkExecutor( |
| 647 | + entries=[entry_none, entry_t1, entry_t2], |
| 648 | + providers=[provider], |
| 649 | + solver=solver, |
| 650 | + ) |
| 651 | + results = executor.execute() |
| 652 | + |
| 653 | + assert len(results) == 3 |
| 654 | + |
| 655 | + # Verify specific algorithm histories to ensure they received correct streams |
| 656 | + assert algo1.get_call_history() == [1.0, 1.0] # No transformation |
| 657 | + assert algo2.get_call_history() == [11.0, 11.0] # 1.0 + 10.0 |
| 658 | + assert algo3.get_call_history() == [21.0, 21.0] # 1.0 + 20.0 |
0 commit comments