@@ -136,6 +136,38 @@ def median_test(ts, level=4, filt_len=7, quantiles=(0.005, 0.095), copy=True):
136136 as_anomaly = True ,
137137 )
138138
139+ def choose_npartitions (
140+ nrows ,
141+ * ,
142+ window ,
143+ small_series_threshold = None ,
144+ min_partition_size = None ,
145+ target_partition_size = 5000 ,
146+ ):
147+ if nrows <= 0 :
148+ raise ValueError ("nrows must be positive" )
149+ if window <= 0 :
150+ raise ValueError ("window must be positive" )
151+ if target_partition_size <= 0 :
152+ raise ValueError ("target_partition_size must be positive" )
153+
154+ if min_partition_size is None :
155+ min_partition_size = max (4 * window , 128 )
156+
157+ if small_series_threshold is None :
158+ small_series_threshold = min_partition_size
159+
160+ if small_series_threshold <= 0 :
161+ raise ValueError ("small_series_threshold must be positive" )
162+ if min_partition_size <= 0 :
163+ raise ValueError ("min_partition_size must be positive" )
164+
165+ if nrows <= small_series_threshold :
166+ return 1
167+
168+ proposed = max (1 , nrows // target_partition_size )
169+ max_safe = max (1 , nrows // min_partition_size )
170+ return max (1 , min (proposed , max_safe ))
139171
140172def median_test_oneside (
141173 ts ,
@@ -154,13 +186,23 @@ def median_test_oneside(
154186 vals = ts [::- 1 ]
155187 else :
156188 vals = ts
157- vals = to_dataframe (vals )
158- vals .columns = ["ts" ]
189+
190+ if isinstance (vals , pd .Series ):
191+ vals = vals .to_frame (name = "ts" )
192+ elif isinstance (vals , pd .DataFrame ):
193+ vals = vals .copy ()
194+ if vals .shape [1 ] != 1 :
195+ raise ValueError ("median_test_oneside expects single-column input" )
196+ vals .columns = ["ts" ]
197+ else :
198+ raise TypeError ("Expected pandas Series or DataFrame" )
159199
160200 vals ["z" ] = vals .ts .diff ()
161201 min_periods = kappa * 2 - 1
162202
163- dds = dd .from_pandas (vals , npartitions = 50 )
203+ window = kappa * 2
204+ npartitions = choose_npartitions (len (vals ), window = window )
205+ dds = dd .from_pandas (vals , npartitions = npartitions )
164206 dds ["my" ] = dds ["ts" ].shift ().rolling (kappa * 2 , min_periods = min_periods ).median ()
165207 dds ["mz" ] = dds .z .shift ().rolling (kappa * 2 , min_periods = min_periods ).median ()
166208 dds ["pred" ] = dds .my + kappa * dds .mz
@@ -299,10 +341,7 @@ def median_test_twoside(
299341 warnings .filterwarnings ("ignore" )
300342
301343 vals = ts_out .to_numpy ()
302- # if ts_out.ndim == 1:
303- # filt = medfilt(vals,filt_len)
304- # else:
305- # filt = np.apply_along_axis(medfilt,0,vals,filt_len)
344+
306345
307346 def mseq (flen ):
308347 halflen = flen // 2
@@ -312,7 +351,10 @@ def mseq(flen):
312351
313352 medseq = mseq (filt_len )
314353
315- dds = dd .from_pandas (ts_out , npartitions = 1 )
354+ kappa = filt_len // 2
355+ window = kappa * 2
356+ npartitions = choose_npartitions (len (vals ), window = window )
357+ dds = dd .from_pandas (ts_out , npartitions = npartitions )
316358 filt = (
317359 dds .rolling (filt_len , center = True )
318360 .apply (lambda x : np .nanmedian (x [medseq ]), raw = True , engine = "numba" )
0 commit comments