Skip to content

Commit 8954186

Browse files
Eli
authored and committed
Dask partitions counted heuristically.
1 parent 26441b8 commit 8954186

1 file changed

Lines changed: 50 additions & 8 deletions

File tree

vtools/functions/error_detect.py

Lines changed: 50 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -136,6 +136,38 @@ def median_test(ts, level=4, filt_len=7, quantiles=(0.005, 0.095), copy=True):
136136
as_anomaly=True,
137137
)
138138

139+
def choose_npartitions(
    nrows,
    *,
    window,
    small_series_threshold=None,
    min_partition_size=None,
    target_partition_size=5000,
):
    """Choose a partition count for a series of ``nrows`` rows.

    The count aims for roughly ``target_partition_size`` rows per partition
    but is capped so that no partition holds fewer than
    ``min_partition_size`` rows.  Series at or below
    ``small_series_threshold`` rows are kept in a single partition.

    Parameters
    ----------
    nrows : int
        Number of rows in the series. Must be positive.
    window : int
        Rolling-window length the partitions have to accommodate. Must be
        positive. Only used to derive the default ``min_partition_size``.
    small_series_threshold : int, optional
        Row count at or below which a single partition is returned.
        Defaults to ``min_partition_size``.
    min_partition_size : int, optional
        Smallest acceptable number of rows per partition. Defaults to
        ``max(4 * window, 128)``.
    target_partition_size : int, optional
        Preferred number of rows per partition.

    Returns
    -------
    int
        Number of partitions, always at least 1.

    Raises
    ------
    ValueError
        If ``nrows``, ``window``, ``target_partition_size``,
        ``small_series_threshold`` or ``min_partition_size`` is not positive.
    """
    if nrows <= 0:
        raise ValueError("nrows must be positive")
    if window <= 0:
        raise ValueError("window must be positive")
    if target_partition_size <= 0:
        raise ValueError("target_partition_size must be positive")

    # Fill in defaults first so that the positivity checks below cover both
    # caller-supplied and derived values.
    if min_partition_size is None:
        # Several windows' worth of rows per partition, never fewer than 128.
        min_partition_size = max(4 * window, 128)
    if small_series_threshold is None:
        small_series_threshold = min_partition_size

    if small_series_threshold <= 0:
        raise ValueError("small_series_threshold must be positive")
    if min_partition_size <= 0:
        raise ValueError("min_partition_size must be positive")

    # Short series: splitting buys nothing, keep everything together.
    if nrows <= small_series_threshold:
        return 1

    proposed = nrows // target_partition_size  # count that hits the target size
    max_safe = nrows // min_partition_size  # most partitions the size floor allows

    # Either quotient may be 0 when nrows is below the corresponding size, so
    # clamp once at the end; int() normalises float / NumPy scalar results.
    return max(1, int(min(proposed, max_safe)))
139171

140172
def median_test_oneside(
141173
ts,
@@ -154,13 +186,23 @@ def median_test_oneside(
154186
vals = ts[::-1]
155187
else:
156188
vals = ts
157-
vals = to_dataframe(vals)
158-
vals.columns = ["ts"]
189+
190+
if isinstance(vals, pd.Series):
191+
vals = vals.to_frame(name="ts")
192+
elif isinstance(vals, pd.DataFrame):
193+
vals = vals.copy()
194+
if vals.shape[1] != 1:
195+
raise ValueError("median_test_oneside expects single-column input")
196+
vals.columns = ["ts"]
197+
else:
198+
raise TypeError("Expected pandas Series or DataFrame")
159199

160200
vals["z"] = vals.ts.diff()
161201
min_periods = kappa * 2 - 1
162202

163-
dds = dd.from_pandas(vals, npartitions=50)
203+
window = kappa * 2
204+
npartitions = choose_npartitions(len(vals), window=window)
205+
dds = dd.from_pandas(vals, npartitions=npartitions)
164206
dds["my"] = dds["ts"].shift().rolling(kappa * 2, min_periods=min_periods).median()
165207
dds["mz"] = dds.z.shift().rolling(kappa * 2, min_periods=min_periods).median()
166208
dds["pred"] = dds.my + kappa * dds.mz
@@ -299,10 +341,7 @@ def median_test_twoside(
299341
warnings.filterwarnings("ignore")
300342

301343
vals = ts_out.to_numpy()
302-
# if ts_out.ndim == 1:
303-
# filt = medfilt(vals,filt_len)
304-
# else:
305-
# filt = np.apply_along_axis(medfilt,0,vals,filt_len)
344+
306345

307346
def mseq(flen):
308347
halflen = flen // 2
@@ -312,7 +351,10 @@ def mseq(flen):
312351

313352
medseq = mseq(filt_len)
314353

315-
dds = dd.from_pandas(ts_out, npartitions=1)
354+
kappa = filt_len // 2
355+
window = kappa * 2
356+
npartitions = choose_npartitions(len(vals), window=window)
357+
dds = dd.from_pandas(ts_out, npartitions=npartitions)
316358
filt = (
317359
dds.rolling(filt_len, center=True)
318360
.apply(lambda x: np.nanmedian(x[medseq]), raw=True, engine="numba")

0 commit comments

Comments
 (0)