
Commit 4296b3c

yadavay-amzn authored and cloud-fan committed
[SPARK-57220][SQL] Extend block-chunked segment-tree window frame to shrinking frames
### What changes were proposed in this pull request?

Extends `SegmentTreeWindowFunctionFrame` (introduced in [SPARK-56546](https://issues.apache.org/jira/browse/SPARK-56546) for sliding aggregates) to also handle **shrinking** frames of the form `... ROWS/RANGE BETWEEN ` *lower* ` AND UNBOUNDED FOLLOWING`.

The class is parameterized with `ubound: Option[BoundOrdering]` (`None` = shrinking, `Some(ub)` = sliding) and a `fallbackFactory` for the small-partition path, so the same machinery (build, spill via `TaskMemoryManager`, eligibility allowlist, SQLMetrics) serves both shapes. The dispatcher in `WindowEvaluatorFactoryBase` gains a shrinking-frame branch that consults the existing `eligibleForSegTree` gate and, on success, builds the unified frame with `ubound = None`.

### Why are the changes needed?

The legacy `UnboundedFollowingWindowFunctionFrame` recomputes the suffix aggregate from scratch for every output row, which costs O(n · (n - 1) / 2). Its own scaladoc acknowledges this (`WindowFunctionFrame.scala:636`):

> This is a very expensive operator to use, O(n * (n - 1) / 2), because we need to maintain a buffer and must do full recalculation after each row.

The segment tree built by SPARK-56546 already supports arbitrary `[lower, upper)` queries; routing shrinking frames into it is purely a dispatch + parameter change.

Shrinking frames are common in retention / cohort / "remaining-lifetime" analytics. For partitions of 100K+ rows the legacy O(N²) path is infeasible.

### Does this PR introduce _any_ user-facing change?

No.

- Same opt-in conf: `spark.sql.window.segmentTree.enabled` (default `false`).
- Same eligibility allowlist (DeclarativeAggregate with `mergeExpressions`, no FILTER, no DISTINCT).
- Same `minPartitionRows` fallback. The fallback type is now shape-dependent: `SlidingWindowFunctionFrame` for moving frames, `UnboundedFollowingWindowFunctionFrame` for shrinking frames.
- No analyzer / SQL grammar / plan-shape changes.

### How was this patch tested?

New `UnboundedFollowingSegmentTreeSuite` mirrors `SegmentTreeWindowFunctionSuite`'s structure with oracle-vs-naive equivalence over ROWS/RANGE frames, NULL/NaN, multi-aggregate, type coverage, and fallback paths. All existing window suites still pass with the unified rewrite.

Benchmark: `UnboundedFollowingWindowBenchmark` on Linux x86_64 (Intel Xeon Platinum 8259CL 2.50GHz, OpenJDK 17.0.19+10-LTS), single-partition `SUM(v) OVER (ORDER BY id ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)`:

| N    | naive (best) | segtree (best) | speedup |
|------|-------------:|---------------:|--------:|
| 5K   |       620 ms |          73 ms |    8.5× |
| 10K  |     2,471 ms |         110 ms |   22.5× |
| 25K  |    14,259 ms |         119 ms |  119.3× |
| 50K  |    57,022 ms |         181 ms |  314.2× |
| 100K |     (~4 min) |         269 ms |     n/a |
| 200K |    (~16 min) |         480 ms |     n/a |

Naive is clean O(N²); segtree wall time grows sub-linearly over the measured range. Full results checked in at `sql/core/benchmarks/UnboundedFollowingWindowBenchmark-results.txt`.

### Was this patch authored or co-authored using generative AI tooling?

Yes. Authored with assistance from Claude (Anthropic).

Closes #56291 from yadavay-amzn/SPARK-57220.

Authored-by: Anupam Yadav <anupamya@amazon.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
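To make the complexity argument concrete, here is a minimal sketch contrasting the two strategies for `SUM(v) OVER (... ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)`. It assumes a plain array-backed segment tree over `Long` sums; it is not the block-chunked, spillable `WindowSegmentTree` from the patch, and the names `SuffixSumSketch`, `SumTree`, `naive`, and `viaSegmentTree` are hypothetical.

```scala
// Illustrative only: these names are hypothetical and are not Spark APIs.
object SuffixSumSketch {

  /** Naive shrinking frame: recompute the suffix sum for every row, O(n^2) total. */
  def naive(values: Array[Long]): Array[Long] =
    Array.tabulate(values.length) { i =>
      var acc = 0L
      var j = i
      while (j < values.length) { acc += values(j); j += 1 }
      acc
    }

  /** Iterative bottom-up segment tree; leaves live at indices [n, 2n). */
  final class SumTree(values: Array[Long]) {
    private val n = values.length
    private val nodes = new Array[Long](2 * math.max(n, 1))
    Array.copy(values, 0, nodes, n, n)
    // Fill internal nodes from the bottom up: each parent holds the sum of its children.
    for (i <- n - 1 until 0 by -1) nodes(i) = nodes(2 * i) + nodes(2 * i + 1)

    /** Sum over the half-open range [lo, hi) in O(log n). */
    def query(lo: Int, hi: Int): Long = {
      var l = lo + n
      var r = hi + n
      var acc = 0L
      while (l < r) {
        if ((l & 1) == 1) { acc += nodes(l); l += 1 }
        if ((r & 1) == 1) { r -= 1; acc += nodes(r) }
        l >>= 1
        r >>= 1
      }
      acc
    }
  }

  /** Shrinking frame via the tree: one [i, n) query per row, O(n log n) total. */
  def viaSegmentTree(values: Array[Long]): Array[Long] = {
    val tree = new SumTree(values)
    Array.tabulate(values.length)(i => tree.query(i, values.length))
  }
}
```

Both functions return the same per-row suffix sums; only the cost per row differs (a full scan of the remaining rows versus a logarithmic number of node reads).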
1 parent 89d5f18 commit 4296b3c

6 files changed

Lines changed: 906 additions & 63 deletions


Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
+================================================================================================
+Section A - SUM (non-invertible suffix)
+================================================================================================
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64
+Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
+SUM shrinking frame, N=10K rows:          Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+SUM naive (master O(N^2))                          2471           2495          14          0.0      241298.5       1.0X
+SUM segtree                                         110            115           4          0.1       10744.6      22.5X
+
+
+================================================================================================
+Section A - MIN
+================================================================================================
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64
+Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
+MIN shrinking frame, N=10K rows:          Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+MIN naive (master O(N^2))                          2417           2438          23          0.0      236035.8       1.0X
+MIN segtree                                         215            219           5          0.0       21015.3      11.2X
+
+
+================================================================================================
+Section A - MAX
+================================================================================================
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64
+Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
+MAX shrinking frame, N=10K rows:          Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+MAX naive (master O(N^2))                          2396           2401           5          0.0      233937.5       1.0X
+MAX segtree                                         228            229           1          0.0       22259.2      10.5X
+
+
+================================================================================================
+Section A - COUNT
+================================================================================================
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64
+Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
+COUNT shrinking frame, N=10K rows:        Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+COUNT naive (master O(N^2))                        2203           2222          16          0.0      215139.0       1.0X
+COUNT segtree                                        80             88           9          0.1        7846.1      27.4X
+
+
+================================================================================================
+Section A - AVG (multi-buffer)
+================================================================================================
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64
+Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
+AVG shrinking frame, N=10K rows:          Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+AVG naive (master O(N^2))                          2886           2900          18          0.0      281837.8       1.0X
+AVG segtree                                          84             86           4          0.1        8165.1      34.5X
+
+
+================================================================================================
+Section B - N=5K
+================================================================================================
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64
+Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
+SUM shrinking frame, N=5K rows:           Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+SUM naive (master O(N^2)) N=5K                      620            628           7          0.0      121170.2       1.0X
+SUM segtree N=5K                                     73             74           1          0.1       14302.8       8.5X
+
+
+================================================================================================
+Section B - N=25K (stress)
+================================================================================================
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64
+Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
+SUM shrinking frame, N=25K rows:          Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+SUM naive (master O(N^2)) N=25K                   14259          14341         108          0.0      556977.9       1.0X
+SUM segtree N=25K                                   119            120           0          0.2        4667.1     119.3X
+
+
+================================================================================================
+Section B - N=50K (stress, last naive run)
+================================================================================================
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64
+Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
+SUM shrinking frame, N=50K rows:          Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+SUM naive (master O(N^2)) N=50K                   57022          57659         987          0.0     1113704.1       1.0X
+SUM segtree N=50K                                   181            182           1          0.3        3544.3     314.2X
+
+
+================================================================================================
+Section B - N=100K (segtree-only, stress)
+================================================================================================
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64
+Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
+SUM shrinking frame, N=100K rows (segtree-only):  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+-------------------------------------------------------------------------------------------------------------------------------
+SUM segtree N=100K                                         269            270           2          0.4        2627.9       1.0X
+
+
+================================================================================================
+Section B - N=200K (segtree-only, stress)
+================================================================================================
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 5.10.255-254.1008.amzn2int.x86_64
+Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
+SUM shrinking frame, N=200K rows (segtree-only):  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+-------------------------------------------------------------------------------------------------------------------------------
+SUM segtree N=200K                                         480            481           1          0.4        2343.7       1.0X
+
+
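As a rough sanity check on the scaling in the SUM results above, the following sketch fits an empirical exponent k in time ~ N^k between the smallest and largest measured sizes. The best-time figures are copied from the tables; the object and method names (`ScalingSketch`, `exponent`) are hypothetical, and a two-point fit is only a coarse indicator rather than a complexity proof.

```scala
// Illustrative only: a two-point log-log slope over the reported best times.
object ScalingSketch {
  // (rows, best time in ms) for the SUM benchmark, taken from the results above.
  val naive: Seq[(Int, Double)] =
    Seq(5000 -> 620.0, 10000 -> 2471.0, 25000 -> 14259.0, 50000 -> 57022.0)
  val segtree: Seq[(Int, Double)] =
    Seq(5000 -> 73.0, 10000 -> 110.0, 25000 -> 119.0, 50000 -> 181.0,
      100000 -> 269.0, 200000 -> 480.0)

  /** Exponent k such that time ~ N^k between the first and last sample. */
  def exponent(samples: Seq[(Int, Double)]): Double = {
    val (n0, t0) = samples.head
    val (n1, t1) = samples.last
    math.log(t1 / t0) / math.log(n1.toDouble / n0)
  }
}
```

On these numbers the naive path comes out close to 2, consistent with quadratic growth, while the segtree path comes out well below 1 over this range, which likely reflects fixed per-query overhead dominating at small N rather than a true sub-linear algorithm.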

sql/core/src/main/scala/org/apache/spark/sql/execution/window/SegmentTreeWindowFunctionFrame.scala

Lines changed: 97 additions & 54 deletions
@@ -27,14 +27,21 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
 
 /**
- * Moving-frame window function frame backed by [[WindowSegmentTree]]. Produces
- * the same outputs as [[SlidingWindowFunctionFrame]] for RowFrame or
- * single-column RangeFrame moving frames whose aggregates are all
- * [[DeclarativeAggregate]] with no FILTER/DISTINCT. For partitions below
- * `spark.sql.window.segmentTree.minPartitionRows`, delegates to a wrapped
- * [[SlidingWindowFunctionFrame]]. Under RANGE, two forward-only cursors
- * (`lowerIter` / `upperIter`) advance the bounds in O(n) total; the segtree
- * answers `[lowerBound, upperBound)` in O(log n).
+ * Window function frame backed by [[WindowSegmentTree]]. Handles two frame
+ * shapes:
+ *  - **Sliding** (`ubound = Some(...)`): both edges move; mirrors
+ *    [[SlidingWindowFunctionFrame]]. O(N log W) total.
+ *  - **Shrinking** (`ubound = None`): upper edge pinned to partition end
+ *    (`BETWEEN <lower> AND UNBOUNDED FOLLOWING`); replaces
+ *    [[UnboundedFollowingWindowFunctionFrame]]'s O(N^2) full recompute with
+ *    O(N log N).
+ *
+ * Eligibility, build, spill, and memory accounting are identical for both
+ * shapes; only the per-row cursor logic differs (admit+drop for sliding,
+ * drop-only for shrinking).
+ *
+ * For partitions below `spark.sql.window.segmentTree.minPartitionRows`,
+ * delegates to a frame produced by `fallbackFactory`.
  *
  * @note Not thread-safe.
  */
@@ -45,7 +52,8 @@ private[window] final class SegmentTreeWindowFunctionFrame(
     inputSchema: Seq[Attribute],
     frameType: FrameType,
     lbound: BoundOrdering,
-    ubound: BoundOrdering,
+    ubound: Option[BoundOrdering],
+    fallbackFactory: () => WindowFunctionFrame,
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
     conf: SQLConf,
     maxCachedBlocks: Option[Int],
@@ -57,16 +65,18 @@ private[window] final class SegmentTreeWindowFunctionFrame(
   require(frameType == RowFrame || frameType == RangeFrame,
     s"SegmentTreeWindowFunctionFrame supports RowFrame or RangeFrame, got $frameType")
 
-  private[this] var fallback: SlidingWindowFunctionFrame = _
+  // True when this is a shrinking-frame (UnboundedFollowing) instance.
+  // Shorthand to avoid repeated `ubound.isEmpty` reads in hot loops.
+  private[this] val shrinking: Boolean = ubound.isEmpty
+
+  private[this] var fallback: WindowFunctionFrame = _
   private[this] var tree: WindowSegmentTree = _
 
   /**
-   * Allocate a fresh fallback sliding-window frame. Called lazily from
-   * `prepare()` on the small-partition path. Factored out for testability
-   * (subclasses can inject a throwing fallback for prepare-failure tests).
+   * Allocate a fresh fallback frame via `fallbackFactory`. Called lazily
+   * from `prepare()` on the small-partition path.
    */
-  private[window] def newFallback(): SlidingWindowFunctionFrame =
-    new SlidingWindowFunctionFrame(target, processor, lbound, ubound)
+  private[window] def newFallback(): WindowFunctionFrame = fallbackFactory()
 
   /** Test hook: whether the fallback frame has been lazily allocated. */
   private[window] def fallbackAllocated: Boolean = fallback != null
@@ -100,8 +110,11 @@ private[window] final class SegmentTreeWindowFunctionFrame(
 
   /**
    * Runtime dispatch flag: when `true`, `write()`, `currentLowerBound()`, and
-   * `currentUpperBound()` delegate to the wrapped [[SlidingWindowFunctionFrame]]
-   * (small-partition path). Set by `prepare()` based on partition size vs.
+   * `currentUpperBound()` delegate to the wrapped fallback frame produced by
+   * `fallbackFactory` (small-partition path). The fallback type is shape-
+   * dependent: [[SlidingWindowFunctionFrame]] for moving frames and
+   * [[UnboundedFollowingWindowFunctionFrame]] for shrinking frames. Set by
+   * `prepare()` based on partition size vs.
    * `spark.sql.window.segmentTree.minPartitionRows`.
    */
   private[window] var fallbackUsed: Boolean = false
@@ -155,19 +168,31 @@ private[window] final class SegmentTreeWindowFunctionFrame(
     // Count only on the successful segtree path: if `tree.build` throws,
     // the counter is not bumped.
     numSegmentTreeFrames.foreach(_ += 1)
-    frameType match {
-      case RowFrame =>
-        boundIter = rows.generateIterator()
-        nextRow = WindowFunctionFrame.getNextOrNull(boundIter)
-      case RangeFrame =>
-        lowerIter = rows.generateIterator()
-        upperIter = rows.generateIterator()
-        // Pre-seed cursor heads so `RangeBoundOrdering.compare` never
-        // dereferences null on round 0. Either may be null if `rows` is
-        // empty; the advance loops' `!= null` / `< upperBound` guards
-        // handle that.
-        lowerRow = WindowFunctionFrame.getNextOrNull(lowerIter)
-        upperRow = WindowFunctionFrame.getNextOrNull(upperIter)
+    if (shrinking) {
+      // Upper bound pinned to partition end; never moves.
+      upperBound = tree.size
+      frameType match {
+        case RowFrame =>
+          // RowFrame lower-bound advance is pure index arithmetic; no iterator.
+        case RangeFrame =>
+          lowerIter = rows.generateIterator()
+          lowerRow = WindowFunctionFrame.getNextOrNull(lowerIter)
+      }
+    } else {
+      frameType match {
+        case RowFrame =>
+          boundIter = rows.generateIterator()
+          nextRow = WindowFunctionFrame.getNextOrNull(boundIter)
+        case RangeFrame =>
+          lowerIter = rows.generateIterator()
+          upperIter = rows.generateIterator()
+          // Pre-seed cursor heads so `RangeBoundOrdering.compare` never
+          // dereferences null on round 0. Either may be null if `rows` is
+          // empty; the advance loops' `!= null` / `< upperBound` guards
+          // handle that.
+          lowerRow = WindowFunctionFrame.getNextOrNull(lowerIter)
+          upperRow = WindowFunctionFrame.getNextOrNull(upperIter)
+      }
     }
   }

@@ -196,27 +221,42 @@ private[window] final class SegmentTreeWindowFunctionFrame(
     }
   }
 
-  // `writeRow`/`writeRange` mirror the `(lowerBound, upperBound)` monotone
-  // cursor invariant of `SlidingWindowFunctionFrame.write`, but run
-  // admit-then-drop (no buffer to maintain) instead of drop-then-admit.
-  // Any future fix to Sliding's boundary semantics must be mirrored here;
-  // equivalence is guarded by `SegmentTreeWindowFunctionSuite` flag-on/off
-  // tests (`checkRangeEquivalence`, `feature flag off ...`, fallback tests)
-  // which compare against the Sliding baseline.
+  // `writeRow`/`writeRange` maintain the `(lowerBound, upperBound)` monotone
+  // cursor invariant for both sliding and shrinking frame shapes:
+  //
+  // - Sliding (`ubound.isDefined`, mirrors `SlidingWindowFunctionFrame.write`):
+  //   run admit-then-drop (no buffer to maintain) instead of drop-then-admit.
+  //   The admit loop below (`if (!shrinking)`) extends `upperBound`; the drop
+  //   loop advances `lowerBound`. Any future fix to Sliding's boundary
+  //   semantics must be mirrored here; equivalence is guarded by
+  //   `SegmentTreeWindowFunctionSuite` flag-on/off tests
+  //   (`checkRangeEquivalence`, `feature flag off ...`, fallback tests)
+  //   against the Sliding baseline.
+  //
+  // - Shrinking (`ubound.isEmpty`, upper is `tree.size`): drop-only. The admit
+  //   loop is skipped; only `lowerBound` advances each step. Equivalence is
+  //   guarded by `UnboundedFollowingSegmentTreeSuite` against the
+  //   `UnboundedFollowingWindowFunctionFrame` baseline.
+  //
+  // In both shapes, the segtree's `query(lowerBound, upperBound, ...)` is
+  // re-issued only when `boundsChanged` is true.
   private def writeRow(index: Int, current: InternalRow): Unit = {
     var boundsChanged = index == 0
 
-    // admit loop: extend upperBound; if a candidate is already below the
-    // lower bound, advance lowerBound in lock-step to preserve invariant
-    // (0 <= lowerBound <= upperBound <= tree.size).
-    while (nextRow != null &&
-        ubound.compare(nextRow, upperBound, current, index) <= 0) {
-      if (lbound.compare(nextRow, lowerBound, current, index) < 0) {
-        lowerBound += 1
+    if (!shrinking) {
+      val ub = ubound.get
+      // admit loop: extend upperBound; if a candidate is already below the
+      // lower bound, advance lowerBound in lock-step to preserve invariant
+      // (0 <= lowerBound <= upperBound <= tree.size).
+      while (nextRow != null &&
+          ub.compare(nextRow, upperBound, current, index) <= 0) {
+        if (lbound.compare(nextRow, lowerBound, current, index) < 0) {
+          lowerBound += 1
+        }
+        nextRow = WindowFunctionFrame.getNextOrNull(boundIter)
+        upperBound += 1
+        boundsChanged = true
       }
-      nextRow = WindowFunctionFrame.getNextOrNull(boundIter)
-      upperBound += 1
-      boundsChanged = true
     }
     // drop loop: advance lowerBound to the frame's left edge. RowFrame's
     // `lbound.compare` is pure index arithmetic so the input row is unread;
@@ -235,13 +275,16 @@ private[window] final class SegmentTreeWindowFunctionFrame(
   private def writeRange(index: Int, current: InternalRow): Unit = {
     var boundsChanged = index == 0
 
-    // admit loop (upper edge). `RangeBoundOrdering.compare` ignores its index
-    // arguments; we pass `upperBound` for API symmetry with RowBoundOrdering.
-    while (upperRow != null &&
-        ubound.compare(upperRow, upperBound, current, index) <= 0) {
-      upperBound += 1
-      upperRow = WindowFunctionFrame.getNextOrNull(upperIter)
-      boundsChanged = true
+    if (!shrinking) {
+      val ub = ubound.get
+      // admit loop (upper edge). `RangeBoundOrdering.compare` ignores its index
+      // arguments; we pass `upperBound` for API symmetry with RowBoundOrdering.
+      while (upperRow != null &&
+          ub.compare(upperRow, upperBound, current, index) <= 0) {
+        upperBound += 1
+        upperRow = WindowFunctionFrame.getNextOrNull(upperIter)
+        boundsChanged = true
+      }
     }
 
     // drop loop (lower edge): strict `< 0`, guarded by
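
The drop-only cursor described in the comments above can be sketched in isolation for a ROWS frame. This is a simplified model under stated assumptions, not the patch's code: `ShrinkingCursorSketch` and `evaluate` are hypothetical names, the lower bound is modelled as a plain row offset instead of a `BoundOrdering`, and a direct slice sum stands in for the segment tree's range query. SQL would return NULL for an empty frame; the sketch does not model that.

```scala
// Illustrative only: models `ROWS BETWEEN <offset> AND UNBOUNDED FOLLOWING` for SUM.
object ShrinkingCursorSketch {
  /**
   * offset < 0 means PRECEDING, 0 means CURRENT ROW, > 0 means FOLLOWING.
   * Returns (per-row results, number of range queries issued).
   */
  def evaluate(values: Array[Long], offset: Int): (Array[Long], Int) = {
    val n = values.length
    // Stand-in for the tree's [lo, hi) query; the real frame asks the segment tree.
    def query(lo: Int, hi: Int): Long = values.slice(lo, hi).sum

    val upperBound = n // pinned to the partition end; never moves
    var lowerBound = 0
    var cached = 0L
    var queries = 0
    val out = new Array[Long](n)
    var index = 0
    while (index < n) {
      var boundsChanged = index == 0
      // drop loop: advance the lower edge until it reaches index + offset
      while (lowerBound < upperBound && lowerBound < index + offset) {
        lowerBound += 1
        boundsChanged = true
      }
      // Re-issue the range query only when a bound moved; otherwise reuse the last result.
      if (boundsChanged) {
        cached = query(lowerBound, upperBound)
        queries += 1
      }
      out(index) = cached
      index += 1
    }
    (out, queries)
  }
}
```

With `offset = 0` every row moves the lower edge, so each row issues one query. With a PRECEDING offset the lower edge stays at 0 for the first rows, so those rows reuse the cached aggregate.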
