Skip to content

Commit b8aea0f

Browse files
committed
feat: add SegmentAggregationMetric
1 parent f13336b commit b8aea0f

2 files changed

Lines changed: 124 additions & 0 deletions

File tree

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# -*- coding: ascii -*-
2+
"""
3+
Module for computing aggregated metrics over specific dataset transitions (bisegments).
4+
"""
5+
6+
__author__ = "Your Name"
7+
__copyright__ = "Copyright (c) 2026 PySATL project"
8+
__license__ = "SPDX-License-Identifier: MIT"
9+
10+
from collections.abc import Sequence
11+
from typing import Any, cast
12+
13+
from pysatl_cpd.benchmark.metrics.aggregation_metric import AggregationMetric
14+
from pysatl_cpd.benchmark.metrics.multiple_run_metric import MultipleRunMetric
15+
from pysatl_cpd.core.data_providers.dataset import PandasLabeledDataProvider, SegmentFilter
16+
from pysatl_cpd.core.online.online_detection_trace import OnlineDetectionTrace
17+
18+
19+
class SegmentAggregationMetric[TraceT: OnlineDetectionTrace[Any], ResultInT, ResultOutT](
20+
MultipleRunMetric[TraceT, PandasLabeledDataProvider, dict[str, ResultOutT]]
21+
):
22+
"""
23+
Evaluates an aggregation metric exclusively on specific transition types (bisegments).
24+
25+
This metric slices both the input data providers and their corresponding
26+
detection traces based on user-provided transition filters. It then groups
27+
these slices by transition type and computes the underlying base metric for
28+
each group independently.
29+
30+
Parameters
31+
----------
32+
base_agg_metric : AggregationMetric[TraceT, PandasLabeledDataProvider, ResultInT, ResultOutT]
33+
The underlying metric to compute (e.g., F1Metric, MeanDelayMetric) for each group.
34+
transition_filters : dict[str, SegmentFilter]
35+
A mapping where keys are human-readable transition names (e.g., 'A -> B')
36+
and values are callable predicates that filter bisegments.
37+
"""
38+
39+
def __init__(
40+
self,
41+
base_agg_metric: AggregationMetric[TraceT, PandasLabeledDataProvider, ResultInT, ResultOutT],
42+
transition_filters: dict[str, SegmentFilter],
43+
) -> None:
44+
self._base_agg_metric = base_agg_metric
45+
self._transition_filters = transition_filters
46+
47+
@property
48+
def base_agg_metric(self) -> AggregationMetric[TraceT, PandasLabeledDataProvider, ResultInT, ResultOutT]:
49+
"""
50+
Returns the underlying aggregation metric instance.
51+
"""
52+
53+
return self._base_agg_metric
54+
55+
def evaluate(self, runs: Sequence[tuple[TraceT, PandasLabeledDataProvider]]) -> dict[str, ResultOutT]:
56+
"""
57+
Evaluate the metric grouped by segment transitions.
58+
59+
Parameters
60+
----------
61+
runs : Sequence[tuple[TraceT, PandasLabeledDataProvider]]
62+
The full benchmark execution results.
63+
64+
Returns
65+
-------
66+
dict[str, Rout]
67+
A dictionary mapping the transition name to the computed metric result.
68+
If a transition filter matches no segments, it is omitted from the output.
69+
"""
70+
71+
grouped_runs: dict[str, list[tuple[TraceT, PandasLabeledDataProvider]]] = {
72+
name: [] for name in self._transition_filters
73+
}
74+
75+
for trace, provider in runs:
76+
for trans_name, filter_fn in self._transition_filters.items():
77+
sub_providers = provider.query_bisegments(filter_fn)
78+
sub_indices = provider.query_bisegments_indexes(filter_fn)
79+
80+
for sub_prov, (g_start, _, g_end) in zip(sub_providers, sub_indices, strict=False):
81+
sub_trace = cast(TraceT, trace.slice(g_start, g_end))
82+
grouped_runs[trans_name].append((sub_trace, sub_prov))
83+
84+
results: dict[str, ResultOutT] = {}
85+
for trans_name, sub_runs in grouped_runs.items():
86+
if sub_runs:
87+
results[trans_name] = self._base_agg_metric.evaluate(sub_runs)
88+
89+
return results

pysatl_cpd/core/online/online_detection_trace.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,41 @@ class OnlineDetectionTrace[StateT: OnlineAlgorithmState](DetectionTrace):
155155
learning_periods: list[tuple[int, int]] = field(default_factory=list)
156156
algorithm_states: list[StateT | None]
157157

158+
def slice(self, start: int, end: int) -> "OnlineDetectionTrace[StateT]":
159+
"""
160+
Create a new trace representing a slice of the current trace [start, end] (inclusive).
161+
Automatically recalculates all relative indices (change points, periods).
162+
"""
163+
new_df = self.detection_function[start : end + 1].copy()
164+
new_pt = self.processing_time[start : end + 1].copy()
165+
166+
new_states = self.algorithm_states[start : end + 1] if self.algorithm_states else []
167+
168+
def shift_points(pts: Sequence[int]) -> list[int]:
169+
return [p - start for p in pts if start <= p <= end]
170+
171+
def shift_periods(periods: list[tuple[int, int]]) -> list[tuple[int, int]]:
172+
res = []
173+
for p_start, p_end in periods:
174+
if p_end < start or p_start > end:
175+
continue
176+
res.append((max(0, p_start - start), min(end - start, p_end - start)))
177+
return res
178+
179+
return type(self)(
180+
algorithm_name=self.algorithm_name,
181+
configuration_hash=self.configuration_hash,
182+
threshold=self.threshold,
183+
detected_change_points=shift_points(self.detected_change_points),
184+
forced_change_points=shift_points(self.forced_change_points),
185+
signal_change_points=shift_points(self.signal_change_points),
186+
detection_function=new_df,
187+
processing_time=new_pt,
188+
algorithm_states=new_states,
189+
skip_periods=shift_periods(self.skip_periods),
190+
learning_periods=shift_periods(self.learning_periods),
191+
)
192+
158193
@classmethod
159194
def from_run(
160195
cls,

0 commit comments

Comments
 (0)