Skip to content

Commit 2624d73

Browse files
committed
refactor: ruff
1 parent 328f1e3 commit 2624d73

2 files changed

Lines changed: 45 additions & 30 deletions

File tree

examples/noreset_shewhart.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
# 1. Dataset generation
2525
# ---------------------------------------------------------------------------
2626

27+
2728
def generate_dataset(
2829
n: int,
2930
series_length: int = 200,
@@ -47,11 +48,13 @@ def generate_dataset(
4748

4849
df = pd.DataFrame({"value": data, "segment": segments})
4950

50-
seg_info = pd.DataFrame({
51-
"segment": [0, 1],
52-
"start": [0, change_point],
53-
"end": [change_point - 1, series_length - 1],
54-
})
51+
seg_info = pd.DataFrame(
52+
{
53+
"segment": [0, 1],
54+
"start": [0, change_point],
55+
"end": [change_point - 1, series_length - 1],
56+
}
57+
)
5558

5659
provider = PandasLabeledDataProvider(
5760
dataset=df,
@@ -79,11 +82,13 @@ def generate_arl_dataset(
7982
data = rng.normal(mu, sigma, size=series_length)
8083

8184
df = pd.DataFrame({"value": data, "segment": 0})
82-
seg_info = pd.DataFrame({
83-
"segment": [0],
84-
"start": [0],
85-
"end": [series_length - 1],
86-
})
85+
seg_info = pd.DataFrame(
86+
{
87+
"segment": [0],
88+
"start": [0],
89+
"end": [series_length - 1],
90+
}
91+
)
8792

8893
provider = PandasLabeledDataProvider(
8994
dataset=df,
@@ -100,6 +105,7 @@ def generate_arl_dataset(
100105
# 2. Main benchmark
101106
# ---------------------------------------------------------------------------
102107

108+
103109
def main() -> None:
104110
# --- Parameters ---
105111
N_SERIES = 25
@@ -118,16 +124,26 @@ def main() -> None:
118124

119125
# --- Generate datasets ---
120126
providers = generate_dataset(
121-
n=N_SERIES, series_length=SERIES_LENGTH, change_point=CHANGE_POINT,
122-
mu_before=MU_BEFORE, mu_after=MU_AFTER, sigma=SIGMA, seed=42,
127+
n=N_SERIES,
128+
series_length=SERIES_LENGTH,
129+
change_point=CHANGE_POINT,
130+
mu_before=MU_BEFORE,
131+
mu_after=MU_AFTER,
132+
sigma=SIGMA,
133+
seed=42,
123134
)
124135
arl_providers = generate_arl_dataset(
125-
n=N_SERIES, series_length=SERIES_LENGTH,
126-
mu=MU_BEFORE, sigma=SIGMA, seed=42,
136+
n=N_SERIES,
137+
series_length=SERIES_LENGTH,
138+
mu=MU_BEFORE,
139+
sigma=SIGMA,
140+
seed=42,
127141
)
128142

129143
print(f"Algorithm: ShewhartControlChart(learning_period={LEARNING_PERIOD}, window={WINDOW_SIZE})")
130-
print(f"Dataset (NoReset): {N_SERIES} series, length={SERIES_LENGTH}, cp={CHANGE_POINT}, shift={MU_AFTER - MU_BEFORE:.1f}σ")
144+
print(
145+
f"Dataset (NoReset): {N_SERIES} series, length={SERIES_LENGTH}, cp={CHANGE_POINT}, shift={MU_AFTER - MU_BEFORE:.1f}σ"
146+
)
131147
print(f"Dataset (ARL): {N_SERIES} series, length={SERIES_LENGTH}, no change points")
132148
print(f"Error margin: {ERROR_MARGIN}")
133149
print("-" * 115)
@@ -139,9 +155,7 @@ def main() -> None:
139155
solver = OnlineCpdSolver()
140156

141157
entry = OnlineBenchmarkEntry(
142-
algorithm=algorithm,
143-
thresholds=LinspaceThresholds(start=0, stop=7, num=30),
144-
entry_name="Shewhart"
158+
algorithm=algorithm, thresholds=LinspaceThresholds(start=0, stop=7, num=30), entry_name="Shewhart"
145159
)
146160

147161
# ==========================================
@@ -172,7 +186,7 @@ def main() -> None:
172186
# ==========================================
173187
runner_arl = NoResetBenchmark(
174188
solver=solver,
175-
policy=PointBasedPolicy(strict=True), # Быстрая поточечная экстракция
189+
policy=PointBasedPolicy(strict=True), # Быстрая поточечная экстракция
176190
metrics={"arl": ARLMetric()},
177191
dump_dir="benchmark_cache/arl",
178192
verbose=True,

pysatl_cpd/benchmark/noreset/noreset_benchmark_runner.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
"""
32
NoReset benchmark runner implementation.
43
@@ -19,28 +18,27 @@
1918
import pandas as pd
2019
from tqdm.auto import tqdm
2120

22-
from pysatl_cpd.analysis.labeled_data import LabeledData
2321
from pysatl_cpd.benchmark.core.benchmark_executor import BenchmarkExecutor
2422
from pysatl_cpd.benchmark.metrics.multiple_run_metric import MultipleRunMetric
2523
from pysatl_cpd.benchmark.noreset.noreset_detection_trace import NoResetDetectionTrace
2624
from pysatl_cpd.benchmark.noreset.threshold_policy import ThresholdPolicy
2725
from pysatl_cpd.core.algorithm_entry import AlgorithmEntry
28-
from pysatl_cpd.core.online.ionline_algorithm import OnlineAlgorithm
29-
from pysatl_cpd.core.online.online_cpd_solver import OnlineCpdSolver
30-
from pysatl_cpd.core.online.online_detection_trace import OnlineDetectionTrace
31-
3226
from pysatl_cpd.core.data_providers.dataset import (
3327
AnnotationFilter,
3428
Dataset,
3529
PandasLabeledDataProvider,
3630
SegmentFilter,
3731
)
32+
from pysatl_cpd.core.online.ionline_algorithm import OnlineAlgorithm
33+
from pysatl_cpd.core.online.online_cpd_solver import OnlineCpdSolver
34+
from pysatl_cpd.core.online.online_detection_trace import OnlineDetectionTrace
3835

3936

4037
class ThresholdRange(Protocol):
4138
"""Protocol for generating a sequence of thresholds."""
42-
def get_thresholds(self) -> list[float]:
43-
...
39+
40+
def get_thresholds(self) -> list[float]: ...
41+
4442

4543
@dataclasses.dataclass
4644
class ManualThresholds(ThresholdRange):
@@ -49,6 +47,7 @@ class ManualThresholds(ThresholdRange):
4947
def get_thresholds(self) -> list[float]:
5048
return self.thresholds
5149

50+
5251
@dataclasses.dataclass
5352
class LinspaceThresholds(ThresholdRange):
5453
start: float
@@ -58,17 +57,19 @@ class LinspaceThresholds(ThresholdRange):
5857
def get_thresholds(self) -> list[float]:
5958
return np.linspace(self.start, self.stop, self.num).tolist()
6059

60+
6161
class DataTransformer(Protocol):
6262
"""Protocol for transforming data providers before running the algorithm."""
63-
def transform(self, provider: PandasLabeledDataProvider) -> PandasLabeledDataProvider:
64-
...
63+
64+
def transform(self, provider: PandasLabeledDataProvider) -> PandasLabeledDataProvider: ...
6565

6666

6767
@dataclasses.dataclass
6868
class OnlineBenchmarkEntry:
6969
"""
7070
Configuration entry for running an online algorithm in the benchmark.
7171
"""
72+
7273
algorithm: OnlineAlgorithm
7374
thresholds: ThresholdRange
7475
data_transformer: DataTransformer | None = None
@@ -127,7 +128,7 @@ def run(
127128
Mapping of entry_name to a DataFrame containing thresholds and metrics.
128129
"""
129130
if not providers:
130-
return {entry.entry_name : pd.DataFrame() for entry in entries}
131+
return {entry.entry_name: pd.DataFrame() for entry in entries}
131132

132133
inf_entries: list[AlgorithmEntry[Any, Any, Any]] = []
133134
for entry in entries:

0 commit comments

Comments
 (0)