Skip to content

Commit 39c6903

Browse files
committed
feat: add threshold policy for NoResetBenchmark
1 parent e982f55 commit 39c6903

2 files changed

Lines changed: 656 additions & 9 deletions

File tree

Lines changed: 240 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,276 @@
1+
# pysatl_cpd/benchmark/noreset/threshold_policy.py
2+
3+
"""
4+
Threshold policies for signal extraction in NoReset benchmark.
5+
6+
This module provides the ThresholdPolicy protocol and two concrete
7+
implementations: PointBasedPolicy and EventBasedPolicy.
8+
"""
9+
10+
__author__ = "Danil Totmyanin"
11+
__copyright__ = "Copyright (c) 2026 PySATL project"
12+
__license__ = "SPDX-License-Identifier: MIT"
13+
114
from collections.abc import Sequence
2-
from typing import Protocol, runtime_checkable
15+
from typing import Protocol, cast, runtime_checkable
16+
17+
import numpy as np
318

419
from pysatl_cpd.core.typedefs import UnivariateNumericArray
520

621

722
@runtime_checkable
823
class ThresholdPolicy(Protocol):
24+
"""
25+
Protocol for signal extraction from a detection function.
26+
27+
Implementations define how to convert a raw detection function array
28+
into a list of signal indices given a threshold and known change points.
29+
"""
30+
931
def apply(
1032
self,
1133
detection_function: UnivariateNumericArray,
1234
threshold: float,
13-
change_points: Sequence[int], # true, 1-based
14-
) -> list[int]: ... # 1-based signal indices
35+
change_points: Sequence[int],
36+
) -> list[int]:
37+
"""
38+
Extract signal indices from the detection function.
39+
40+
Parameters
41+
----------
42+
detection_function : UnivariateNumericArray
43+
Array of detection statistic values, one per time step.
44+
threshold : float
45+
Detection threshold.
46+
change_points : Sequence[int]
47+
True change point indices (1-based). Used by some policies
48+
to define delay windows.
49+
50+
Returns
51+
-------
52+
list[int]
53+
1-based indices where signals were detected.
54+
"""
55+
...
1556

1657

1758
class PointBasedPolicy:
59+
"""
60+
Signal extraction policy based on point-wise threshold comparison.
61+
62+
Any position where the detection function satisfies the threshold
63+
condition is considered a signal. The change_points argument is
64+
accepted for interface compatibility but is ignored.
65+
66+
Parameters
67+
----------
68+
strict : bool, default=True
69+
If True, signal condition is detection_function > threshold.
70+
If False, signal condition is detection_function >= threshold.
71+
"""
72+
1873
def __init__(self, strict: bool = True) -> None:
19-
return
74+
self.strict = strict
75+
76+
@staticmethod
77+
def _exceeds(arr: np.ndarray, threshold: float, strict: bool) -> np.ndarray:
78+
"""
79+
Check whether array values exceed threshold.
80+
81+
Parameters
82+
----------
83+
arr : np.ndarray
84+
Array of values to check.
85+
threshold : float
86+
Threshold value.
87+
strict : bool
88+
If True, uses strict inequality (>).
89+
If False, uses non-strict inequality (>=).
90+
91+
Returns
92+
-------
93+
np.ndarray
94+
Boolean array.
95+
"""
96+
return arr > threshold if strict else arr >= threshold
2097

2198
def apply(
2299
self,
23100
detection_function: UnivariateNumericArray,
24101
threshold: float,
25-
change_points: Sequence[int], # true, 1-based
102+
change_points: Sequence[int],
26103
) -> list[int]:
27-
raise NotImplementedError("Method `apply` is not implemented yet.")
104+
"""
105+
Return 1-based indices where detection function exceeds threshold.
106+
107+
Parameters
108+
----------
109+
detection_function : UnivariateNumericArray
110+
Array of detection statistic values.
111+
threshold : float
112+
Detection threshold.
113+
change_points : Sequence[int]
114+
Ignored. Present for interface compatibility.
115+
116+
Returns
117+
-------
118+
list[int]
119+
Sorted list of 1-based signal indices.
120+
"""
121+
if len(detection_function) == 0:
122+
return []
123+
124+
res = (np.where(self._exceeds(detection_function, threshold, self.strict))[0] + 1).tolist()
125+
return cast(list[int], res)
28126

29127

30128
class EventBasedPolicy:
129+
"""
130+
Signal extraction policy based on rising-edge detection with delay windows.
131+
132+
In normal (edge) mode, a signal is produced only when the detection
133+
function crosses the threshold from below (rising edge). Inside delay
134+
windows [true_cp, true_cp + max_delay] (1-based, inclusive), the policy
135+
switches to point-based mode to correctly capture detection delay.
136+
137+
The previous value used for edge detection (prev) is tracked continuously,
138+
including values inside delay windows (variant A). This means that if the
139+
detection function is above threshold at the end of a window, the first
140+
element after the window will not produce an edge signal.
141+
142+
For the first element, prev is treated as -inf (always below threshold).
143+
144+
Parameters
145+
----------
146+
max_delay : int
147+
Maximum allowable detection delay. Defines the right boundary of
148+
the delay window as true_cp + max_delay (inclusive). Must be >= 0.
149+
strict_edge : bool, default=True
150+
If True, rising edge condition requires detection_function > threshold.
151+
If False, condition is detection_function >= threshold.
152+
prev is always checked with strict inequality (prev < threshold).
153+
strict_point : bool, default=True
154+
If True, point-based condition in delay window is
155+
detection_function > threshold.
156+
If False, condition is detection_function >= threshold.
157+
158+
Raises
159+
------
160+
ValueError
161+
If max_delay is negative.
162+
"""
163+
31164
def __init__(
32165
self,
33166
max_delay: int,
34167
strict_edge: bool = True,
35168
strict_point: bool = True,
36169
) -> None:
37-
return
170+
if max_delay < 0:
171+
raise ValueError(f"max_delay must be non-negative, got {max_delay}")
172+
self.max_delay = max_delay
173+
self.strict_edge = strict_edge
174+
self.strict_point = strict_point
175+
176+
@staticmethod
177+
def _exceeds(arr: np.ndarray, threshold: float, strict: bool) -> np.ndarray:
178+
"""
179+
Check whether array values exceed threshold.
180+
181+
Parameters
182+
----------
183+
arr : np.ndarray
184+
Array of values to check.
185+
threshold : float
186+
Threshold value.
187+
strict : bool
188+
If True, uses strict inequality (>).
189+
If False, uses non-strict inequality (>=).
190+
191+
Returns
192+
-------
193+
np.ndarray
194+
Boolean array.
195+
"""
196+
return arr > threshold if strict else arr >= threshold
197+
198+
def _build_window_mask(
199+
self,
200+
length: int,
201+
change_points: Sequence[int],
202+
) -> np.ndarray:
203+
"""
204+
Build a boolean mask indicating which 0-based indices are in delay windows.
205+
206+
Uses cumsum trick for fully vectorized computation over change points.
207+
208+
Parameters
209+
----------
210+
length : int
211+
Length of the detection function array.
212+
change_points : Sequence[int]
213+
True change point indices (1-based).
214+
215+
Returns
216+
-------
217+
np.ndarray
218+
Boolean array of shape (length,) where True means the position
219+
is inside a delay window.
220+
"""
221+
if not change_points:
222+
return np.zeros(length, dtype=bool)
223+
224+
lefts = np.clip(np.array(change_points, dtype=int) - 1, 0, length - 1)
225+
rights = np.clip(lefts + self.max_delay, 0, length - 1)
226+
227+
marker = np.zeros(length + 1, dtype=int)
228+
np.add.at(marker, lefts, 1)
229+
np.add.at(marker, rights + 1, -1)
230+
return np.cumsum(marker)[:length] > 0
38231

39232
def apply(
40233
self,
41234
detection_function: UnivariateNumericArray,
42235
threshold: float,
43-
change_points: Sequence[int], # true, 1-based
236+
change_points: Sequence[int],
44237
) -> list[int]:
45-
raise NotImplementedError("Method `apply` is not implemented yet.")
238+
"""
239+
Extract signal indices using rising-edge detection with delay windows.
240+
241+
Fully vectorized implementation using numpy masks.
242+
243+
Parameters
244+
----------
245+
detection_function : UnivariateNumericArray
246+
Array of detection statistic values.
247+
threshold : float
248+
Detection threshold.
249+
change_points : Sequence[int]
250+
True change point indices (1-based). Used to define delay windows
251+
where point-based mode is applied.
252+
253+
Returns
254+
-------
255+
list[int]
256+
Sorted list of 1-based signal indices.
257+
"""
258+
n = len(detection_function)
259+
if n == 0:
260+
return []
261+
262+
window_mask = self._build_window_mask(n, change_points)
263+
264+
# prev[i] = df[i-1], prev[0] = -inf
265+
prev = np.empty(n, dtype=detection_function.dtype)
266+
prev[0] = float("-inf")
267+
prev[1:] = detection_function[:-1]
268+
269+
# edge signals: rising edge outside windows
270+
edge = (prev < threshold) & self._exceeds(detection_function, threshold, self.strict_edge) & ~window_mask
271+
272+
# point signals: threshold exceeded inside windows
273+
point = self._exceeds(detection_function, threshold, self.strict_point) & window_mask
274+
275+
res = (np.where(edge | point)[0] + 1).tolist()
276+
return cast(list[int], res)

0 commit comments

Comments
 (0)