Skip to content

Commit 9687d36

Browse files
committed
tests: SegmentAggregationMetric
1 parent fec0b86 commit 9687d36

3 files changed

Lines changed: 217 additions & 0 deletions

File tree

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# -*- coding: ascii -*-
2+
3+
"""
4+
Mock pandas data provider for testing segment logic.
5+
"""
6+
7+
__author__ = "Danil Totmyanin"
8+
__copyright__ = "Copyright (c) 2026 PySATL project"
9+
__license__ = "SPDX-License-Identifier: MIT"
10+
11+
12+
from pysatl_cpd.core.data_providers.dataset import PandasLabeledDataProvider, SegmentFilter
13+
14+
15+
class MockPandasLabeledDataProvider(PandasLabeledDataProvider):
16+
"""
17+
Mock implementation of PandasLabeledDataProvider for testing segment slicing.
18+
19+
Bypasses pandas DataFrame initialization entirely and returns pre-configured
20+
bisegments and indices when queried.
21+
"""
22+
23+
def __init__(self, name: str = "MockPandasProvider") -> None:
24+
self._name = name
25+
self.mock_bisegments: list[PandasLabeledDataProvider] = []
26+
self.mock_indexes: list[tuple[int, int, int]] = []
27+
28+
@property
29+
def name(self) -> str:
30+
return self._name
31+
32+
def query_bisegments(self, filter_fn: SegmentFilter | None = None) -> list[PandasLabeledDataProvider]:
33+
"""Return pre-configured bisegments."""
34+
return self.mock_bisegments
35+
36+
def query_bisegments_indexes(self, filter_fn: SegmentFilter | None = None) -> list[tuple[int, int, int]]:
37+
"""Return pre-configured bisegment indices."""
38+
return self.mock_indexes

tests/mocks/core/online/online_detection_trace.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,14 @@ def __init__(self, detected_change_points: Sequence[int]):
4242
detection_function=np.array([]),
4343
algorithm_states=[],
4444
)
45+
46+
def slice(self, start: int, end: int) -> "MockOnlineDetectionTrace":
47+
"""
48+
Mock implementation of slice.
49+
50+
Returns a new MockOnlineDetectionTrace containing only the change points
51+
that fall within [start, end], shifted relative to `start`.
52+
"""
53+
54+
shifted_cps: list[int] = [cp - start for cp in self.detected_change_points if start <= cp <= end]
55+
return MockOnlineDetectionTrace(detected_change_points=shifted_cps)
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
# -*- coding: ascii -*-
2+
3+
"""
4+
Unit tests for SegmentAggregationMetric.
5+
6+
Verifies that the metric correctly slices traces and providers according to
7+
transition filters, groups them by transition name, and delegates evaluation
8+
to the base aggregation metric.
9+
"""
10+
11+
__author__ = "Danil Totmyanin"
12+
__copyright__ = "Copyright (c) 2026 PySATL project"
13+
__license__ = "SPDX-License-Identifier: MIT"
14+
15+
from collections.abc import Sequence
16+
17+
from pysatl_cpd.benchmark.metrics.segment_aggregation_metric import SegmentAggregationMetric
18+
from pysatl_cpd.core.data_providers.dataset import PandasLabeledDataProvider, SegmentFilter, SegmentInfo
19+
from tests.mocks.analysis.metrics.run_metric import MockRunMetric
20+
from tests.mocks.benchmark.metrics.aggregation_metric import MockAggregationMetric
21+
from tests.mocks.core.data_providers.pandas_provider import MockPandasLabeledDataProvider
22+
from tests.mocks.core.online.online_detection_trace import MockOnlineDetectionTrace
23+
24+
25+
def dummy_filter(pair: tuple[SegmentInfo, SegmentInfo]) -> bool:
26+
"""A dummy segment filter for testing."""
27+
return True
28+
29+
30+
class TestSegmentAggregationMetricInit:
31+
"""Tests for SegmentAggregationMetric initialization."""
32+
33+
def test_initialization_stores_properties(self) -> None:
34+
"""Metric should store the base metric and transition filters."""
35+
base_run_metric: MockRunMetric[MockOnlineDetectionTrace, PandasLabeledDataProvider] = MockRunMetric([1.0])
36+
base_agg_metric: MockAggregationMetric[MockOnlineDetectionTrace, PandasLabeledDataProvider] = (
37+
MockAggregationMetric(base_run_metric)
38+
)
39+
filters: dict[str, SegmentFilter] = {"A->B": dummy_filter}
40+
41+
metric: SegmentAggregationMetric[MockOnlineDetectionTrace, float, float] = SegmentAggregationMetric(
42+
base_agg_metric=base_agg_metric,
43+
transition_filters=filters,
44+
)
45+
46+
assert metric.base_agg_metric is base_agg_metric
47+
assert metric._transition_filters == filters
48+
49+
50+
class TestSegmentAggregationMetricEvaluate:
51+
"""Tests for the evaluate() method of SegmentAggregationMetric."""
52+
53+
def test_evaluate_empty_runs(self) -> None:
54+
"""Evaluating with an empty runs list should yield an empty result dict."""
55+
base_run_metric: MockRunMetric[MockOnlineDetectionTrace, PandasLabeledDataProvider] = MockRunMetric([1.0])
56+
base_agg_metric: MockAggregationMetric[MockOnlineDetectionTrace, PandasLabeledDataProvider] = (
57+
MockAggregationMetric(base_run_metric)
58+
)
59+
filters: dict[str, SegmentFilter] = {"A->B": dummy_filter}
60+
61+
metric: SegmentAggregationMetric[MockOnlineDetectionTrace, float, float] = SegmentAggregationMetric(
62+
base_agg_metric=base_agg_metric,
63+
transition_filters=filters,
64+
)
65+
66+
result: dict[str, float] = metric.evaluate([])
67+
68+
# If no runs provided, no sub_runs are created, so the result should be empty
69+
assert result == {}
70+
assert len(base_agg_metric.aggregate_calls) == 0
71+
72+
def test_evaluate_filters_with_no_matches_are_omitted(self) -> None:
73+
"""Filters that produce no bisegments should not appear in the final output."""
74+
base_run_metric: MockRunMetric[MockOnlineDetectionTrace, PandasLabeledDataProvider] = MockRunMetric([1.0])
75+
base_agg_metric: MockAggregationMetric[MockOnlineDetectionTrace, PandasLabeledDataProvider] = (
76+
MockAggregationMetric(base_run_metric)
77+
)
78+
filters: dict[str, SegmentFilter] = {"A->B": dummy_filter, "C->D": dummy_filter}
79+
80+
metric: SegmentAggregationMetric[MockOnlineDetectionTrace, float, float] = SegmentAggregationMetric(
81+
base_agg_metric=base_agg_metric,
82+
transition_filters=filters,
83+
)
84+
85+
trace = MockOnlineDetectionTrace(detected_change_points=[])
86+
provider = MockPandasLabeledDataProvider(name="MainProvider")
87+
88+
# We configure the provider to return nothing for any query
89+
provider.mock_bisegments = []
90+
provider.mock_indexes = []
91+
92+
runs: Sequence[tuple[MockOnlineDetectionTrace, PandasLabeledDataProvider]] = [(trace, provider)]
93+
94+
result: dict[str, float] = metric.evaluate(runs)
95+
96+
assert result == {}
97+
assert len(base_agg_metric.aggregate_calls) == 0
98+
99+
def test_evaluate_groups_and_delegates_correctly(self) -> None:
100+
"""
101+
Metric should slice traces, group by filter name, and call the base
102+
metric evaluate() with the correctly grouped sub-runs.
103+
"""
104+
# 1. Setup base metrics. Our mock aggregation metric just sums the results.
105+
# The base run metric returns 1.0 for every call.
106+
base_run_metric: MockRunMetric[MockOnlineDetectionTrace, PandasLabeledDataProvider] = MockRunMetric([1.0])
107+
base_agg_metric: MockAggregationMetric[MockOnlineDetectionTrace, PandasLabeledDataProvider] = (
108+
MockAggregationMetric(base_run_metric)
109+
)
110+
111+
filters: dict[str, SegmentFilter] = {
112+
"A->B": dummy_filter,
113+
"C->D": dummy_filter,
114+
}
115+
116+
metric: SegmentAggregationMetric[MockOnlineDetectionTrace, float, float] = SegmentAggregationMetric(
117+
base_agg_metric=base_agg_metric,
118+
transition_filters=filters,
119+
)
120+
121+
# 2. Setup traces and providers
122+
main_trace = MockOnlineDetectionTrace(detected_change_points=[15, 45])
123+
main_provider = MockPandasLabeledDataProvider(name="MainProvider")
124+
125+
# Let's say query_bisegments returns two pieces:
126+
# First piece: index [10, 15, 20] (covers cp at 15)
127+
# Second piece: index [40, 45, 50] (covers cp at 45)
128+
sub_prov1 = MockPandasLabeledDataProvider(name="Sub1")
129+
sub_prov2 = MockPandasLabeledDataProvider(name="Sub2")
130+
131+
main_provider.mock_bisegments = [sub_prov1, sub_prov2]
132+
main_provider.mock_indexes = [(10, 15, 20), (40, 45, 50)]
133+
134+
runs: Sequence[tuple[MockOnlineDetectionTrace, PandasLabeledDataProvider]] = [(main_trace, main_provider)]
135+
136+
# 3. Execute
137+
result: dict[str, float] = metric.evaluate(runs)
138+
139+
# 4. Verify results
140+
# The provider is queried TWICE (once for 'A->B', once for 'C->D').
141+
# Each query returns 2 sub-providers.
142+
# So 'A->B' group gets 2 sub-runs, 'C->D' group gets 2 sub-runs.
143+
# Since base_run_metric returns 1.0 for each run, aggregate sum is 2.0 for each group.
144+
assert "A->B" in result
145+
assert "C->D" in result
146+
assert result["A->B"] == 2.0
147+
assert result["C->D"] == 2.0
148+
149+
# Verify that slicing happened correctly:
150+
# The run metric was called 4 times total (2 for 'A->B', 2 for 'C->D').
151+
assert len(base_run_metric.calls) == 4
152+
153+
# Let's inspect the first call: it should be sub_prov1 and a sliced trace.
154+
trace1_sliced, prov1_sliced = base_run_metric.calls[0]
155+
assert isinstance(trace1_sliced, MockOnlineDetectionTrace)
156+
assert trace1_sliced.algorithm_name == "MockOnlineAlgorithm"
157+
# The slice was [10, 20]. The original trace had [15, 45].
158+
# Sliced trace should have 15 shifted by 10 -> [5].
159+
assert trace1_sliced.detected_change_points == [5]
160+
assert prov1_sliced is sub_prov1
161+
162+
# Let's inspect the second call: it should be sub_prov2.
163+
trace2_sliced, prov2_sliced = base_run_metric.calls[1]
164+
assert isinstance(trace2_sliced, MockOnlineDetectionTrace)
165+
# The slice was [40, 50]. The original trace had [15, 45].
166+
# Sliced trace should have 45 shifted by 40 -> [5].
167+
assert trace2_sliced.detected_change_points == [5]
168+
assert prov2_sliced is sub_prov2

0 commit comments

Comments
 (0)