Skip to content

Commit 47e4fe3

Browse files
author
ci bot
committed
Merge branch 'aarthy/monitor' into 'enterprise'
fix(monitors): prevent overconfident prediction bounds See merge request dkinternal/testgen/dataops-testgen!417
2 parents 388129f + 67a1afb commit 47e4fe3

4 files changed

Lines changed: 25 additions & 98 deletions

File tree

testgen/commands/test_thresholds_prediction.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -110,19 +110,19 @@ def run(self) -> None:
110110
)
111111
test_prediction.extend([lower, upper, staleness, prediction])
112112
else:
113-
functional_table_type = group["functional_table_type"].iloc[0]
114-
is_cumulative = bool(
115-
functional_table_type and str(functional_table_type).startswith("cumulative")
116-
)
117113
lower, upper, prediction = compute_sarimax_threshold(
118114
history,
119115
sensitivity=self.test_suite.predict_sensitivity or PredictSensitivity.medium,
120116
min_lookback=self.test_suite.predict_min_lookback or 1,
121117
exclude_weekends=self.test_suite.predict_exclude_weekends,
122118
holiday_codes=self.test_suite.holiday_codes_list,
123119
schedule_tz=self.tz,
124-
is_cumulative=is_cumulative,
125120
)
121+
if test_type == "Volume_Trend":
122+
if lower is not None:
123+
lower = max(lower, 0.0)
124+
if upper is not None:
125+
upper = max(upper, 0.0)
126126
test_prediction.extend([lower, upper, None, prediction])
127127

128128
prediction_results.append(test_prediction)
@@ -263,13 +263,10 @@ def compute_sarimax_threshold(
263263
exclude_weekends: bool = False,
264264
holiday_codes: list[str] | None = None,
265265
schedule_tz: str | None = None,
266-
is_cumulative: bool = False,
267266
) -> tuple[float | None, float | None, str | None]:
268267
"""Compute SARIMAX-based thresholds for the next forecast point.
269268
270269
Returns (lower, upper, forecast_json) or (None, None, None) if insufficient data.
271-
For cumulative tables, the lower tolerance is floored at the last observed value
272-
so that any decrease in row count is detected as an anomaly.
273270
"""
274271
if len(history) < min_lookback:
275272
return None, None, None
@@ -299,12 +296,7 @@ def compute_sarimax_threshold(
299296

300297
if pd.isna(lower_tolerance) or pd.isna(upper_tolerance):
301298
return None, None, None
302-
303-
lower_tolerance = float(lower_tolerance)
304-
if is_cumulative:
305-
last_observed = float(history["result_signal"].iloc[-1])
306-
lower_tolerance = max(lower_tolerance, last_observed)
307-
308-
return lower_tolerance, float(upper_tolerance), forecast.to_json()
299+
else:
300+
return float(lower_tolerance), float(upper_tolerance), forecast.to_json()
309301
except NotEnoughData:
310302
return None, None, None

testgen/common/time_series_service.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from datetime import datetime
33

44
import holidays
5+
import numpy as np
56
import pandas as pd
67
from statsmodels.tsa.statespace.sarimax import SARIMAX
78

@@ -94,7 +95,21 @@ def get_exog_flags(index: pd.DatetimeIndex) -> pd.DataFrame:
9495

9596
results = pd.DataFrame(index=forecast_index)
9697
results["mean"] = forecast.predicted_mean
97-
results["se"] = forecast.var_pred_mean ** 0.5
98+
99+
# SE estimation: take the max of three sources to prevent overconfident bounds.
100+
# 1. Model SE (var_pred_mean): can be artificially small when AR/MA nearly cancel
101+
# 2. Residual SE: the model's actual 1-step prediction errors (after Kalman burn-in)
102+
# 3. Raw diff SE: std of first-differences of the original data — captures inherent
103+
# point-to-point variability that the model may underestimate
104+
model_se = forecast.var_pred_mean ** 0.5
105+
order_sum = model.k_ar + model.k_diff + model.k_ma
106+
burn_in = max(order_sum, 3)
107+
usable_residuals = fitted_model.resid.iloc[burn_in:]
108+
resid_se = usable_residuals.std() if len(usable_residuals) >= 5 else 0.0
109+
raw_diffs = np.diff(history.iloc[:, 0].values)
110+
raw_diff_se = np.std(raw_diffs, ddof=1) if len(raw_diffs) > 1 else 0.0
111+
results["se"] = np.maximum(model_se, max(resid_se, raw_diff_se))
112+
98113
return results
99114

100115

testgen/template/prediction/get_historical_test_results.sql

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ WITH filtered_defs AS (
22
-- Filter definitions first to minimize join surface area
33
SELECT id,
44
test_suite_id,
5-
table_groups_id,
65
schema_name,
76
table_name,
87
column_name,
@@ -18,13 +17,8 @@ SELECT r.test_definition_id,
1817
CASE
1918
WHEN r.result_signal ~ '^-?[0-9]*\.?[0-9]+$' THEN r.result_signal::NUMERIC
2019
ELSE NULL
21-
END AS result_signal,
22-
dtc.functional_table_type
20+
END AS result_signal
2321
FROM test_results r
2422
JOIN filtered_defs d ON d.id = r.test_definition_id
25-
LEFT JOIN data_table_chars dtc
26-
ON dtc.table_groups_id = d.table_groups_id
27-
AND dtc.schema_name = d.schema_name
28-
AND dtc.table_name = d.table_name
2923
WHERE r.test_suite_id = :TEST_SUITE_ID
3024
ORDER BY r.test_time;

tests/unit/common/test_time_series_service.py

Lines changed: 1 addition & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pandas as pd
55
import pytest
66

7-
from testgen.commands.test_thresholds_prediction import compute_freshness_threshold, compute_sarimax_threshold
7+
from testgen.commands.test_thresholds_prediction import compute_freshness_threshold
88
from testgen.common.freshness_service import (
99
MIN_FRESHNESS_GAPS,
1010
FreshnessThreshold,
@@ -634,77 +634,3 @@ def test_without_exclusions_timezone_has_no_effect(self):
634634
forecast_with_tz = get_sarimax_forecast(history, num_forecast=3, exclude_weekends=False, tz="America/New_York")
635635

636636
pd.testing.assert_frame_equal(forecast_no_tz, forecast_with_tz)
637-
638-
639-
class Test_ComputeSarimaxThreshold_CumulativeFloor:
640-
"""Tests for the cumulative table floor constraint in compute_sarimax_threshold."""
641-
642-
@staticmethod
643-
def _make_monotonic_history(n_days: int = 30, start_value: int = 1000, daily_growth: int = 100) -> pd.DataFrame:
644-
"""Create a monotonically increasing row count history (cumulative table)."""
645-
dates = pd.date_range("2026-01-01", periods=n_days, freq="1D")
646-
values = [start_value + i * daily_growth for i in range(n_days)]
647-
return pd.DataFrame({"result_signal": values}, index=dates)
648-
649-
def test_cumulative_floors_lower_at_last_observed(self):
650-
history = self._make_monotonic_history(n_days=30, start_value=1000, daily_growth=100)
651-
last_observed = float(history["result_signal"].iloc[-1])
652-
653-
lower, upper, prediction = compute_sarimax_threshold(
654-
history, PredictSensitivity.medium, is_cumulative=True,
655-
)
656-
657-
assert lower is not None
658-
assert upper is not None
659-
assert prediction is not None
660-
assert lower >= last_observed
661-
662-
def test_non_cumulative_allows_lower_below_last_observed(self):
663-
# With high variance, SARIMAX lower bound can drop below last observed
664-
rng = np.random.default_rng(42)
665-
dates = pd.date_range("2026-01-01", periods=30, freq="1D")
666-
# Trending up but with large noise — lower bound should be below last value
667-
values = [1000 + i * 50 + rng.normal(0, 200) for i in range(30)]
668-
history = pd.DataFrame({"result_signal": values}, index=dates)
669-
last_observed = float(history["result_signal"].iloc[-1])
670-
671-
lower, upper, prediction = compute_sarimax_threshold(
672-
history, PredictSensitivity.low, is_cumulative=False,
673-
)
674-
675-
assert lower is not None
676-
# With low sensitivity (z=-3.0) and high noise, lower should be below last value
677-
# This is the behavior we're protecting against with the cumulative floor
678-
assert lower < last_observed
679-
680-
def test_cumulative_does_not_affect_upper_tolerance(self):
681-
history = self._make_monotonic_history(n_days=30)
682-
683-
_, upper_cumulative, _ = compute_sarimax_threshold(
684-
history, PredictSensitivity.medium, is_cumulative=True,
685-
)
686-
_, upper_normal, _ = compute_sarimax_threshold(
687-
history, PredictSensitivity.medium, is_cumulative=False,
688-
)
689-
690-
assert upper_cumulative == upper_normal
691-
692-
def test_cumulative_with_insufficient_data_returns_none(self):
693-
history = self._make_monotonic_history(n_days=2)
694-
695-
lower, upper, prediction = compute_sarimax_threshold(
696-
history, PredictSensitivity.medium, min_lookback=5, is_cumulative=True,
697-
)
698-
699-
assert lower is None
700-
assert upper is None
701-
assert prediction is None
702-
703-
def test_cumulative_default_is_false(self):
704-
history = self._make_monotonic_history(n_days=30)
705-
706-
# Without is_cumulative param, should behave as non-cumulative
707-
lower_default, _, _ = compute_sarimax_threshold(history, PredictSensitivity.medium)
708-
lower_explicit, _, _ = compute_sarimax_threshold(history, PredictSensitivity.medium, is_cumulative=False)
709-
710-
assert lower_default == lower_explicit

0 commit comments

Comments
 (0)