@@ -145,44 +145,44 @@ pub enum DispatchPlan {
145145
146146/// A fused plan: stages, source buffers and shared-memory.
147147///
148- /// The kernel runs in two phases:
148+ /// Stages are stored in kernel execution order. There are two phases:
149149///
150- /// 1. `smem_stages` run first and decode their **entire** output into
151- /// shared memory (e.g. all dict values, all run-end endpoints). This data
152- /// stays resident for the output stage to index into.
153- /// 2. `output_stage` then iterates over the input in tiles of
154- /// `SMEM_TILE_SIZE` (1024) elements, decoding each tile into a scratch
155- /// region of shared memory, applying scalar ops (which may reference the
156- /// smem_stages data), and writing the result to global memory.
150+ /// 1. All stages except the last run first and decode their output
151+ /// into shared memory (e.g. all dict values, all run-end endpoints).
152+ /// This data stays resident for the output stage to index into.
153+ ///
154+ /// 2. The last stage (the output stage) iterates over the input in tiles
155+ /// of `SMEM_TILE_SIZE` (1024) elements, decoding each tile into a
156+ /// scratch region of shared memory, applying scalar ops (which may
157+ /// reference data from the earlier stages), and writing the result to
158+ /// global memory.
157159///
158160/// # Shared memory allocation
159161///
160162/// Total shared memory = (`smem_cursor` + `SMEM_TILE_SIZE`) × `elem_bytes`.
161163///
162164/// This is sufficient because:
163165///
164- /// - `smem_stages` only originate from dict (values) and run-end
165- /// (ends, values). `push_smem_stage` reserves
166- /// the full auxiliary data length in `smem_cursor`, so each stage's source
167- /// op has room to decode the complete input.
166+ /// - Earlier stages only originate from dict (values) and run-end (ends,
167+ /// values). `push_smem_stage` reserves the full auxiliary data length in
168+ /// `smem_cursor`, so each stage's source op has room to decode the complete
169+ /// input.
168170///
169- /// - `output_stage` tiles at `SMEM_TILE_SIZE` (1024 elements), so its
170- /// source op never writes more than 1024 elements into the scratch region,
171- /// even though each block is responsible for `ELEMENTS_PER_BLOCK` (2048)
172- /// output elements — it processes them in two passes through the scratch.
171+ /// - The output stage (last) tiles at `SMEM_TILE_SIZE` (1024 elements),
172+ /// so its source op never writes more than 1024 elements into the
173+ /// scratch region, even though each block is responsible for
174+ /// `ELEMENTS_PER_BLOCK` (2048) output elements — it processes them in
175+ /// two passes through the scratch.
173176///
174177/// Note: `BITUNPACK` writes full FastLanes blocks (1024 elements), which can
175178/// exceed `stage.len` by up to 1023 elements. This overflow is absorbed by
176179/// the scratch region (`SMEM_TILE_SIZE` ≥ `FL_CHUNK_SIZE`).
177180pub struct FusedPlan {
178- /// Stages that decode fully into shared memory before the output stage
179- /// runs. Their data stays resident so the output stage can reference it.
180- smem_stages : Vec < ( Stage , SmemOffset , StageLen ) > ,
181- /// The root stage that produces the final output. Iterates in tiles of
182- /// `SMEM_TILE_SIZE`, writing each to global memory. `None` only during
183- /// construction; always `Some` after [`build`](Self::build).
184- output_stage : Option < ( Stage , SmemOffset , StageLen ) > ,
185- /// Shared memory elements reserved by `smem_stages` (fully decoded).
181+ /// Stages in kernel execution order. All stages except the last decode
182+ /// fully into persistent shared memory; the final stage produces the
183+ /// output.
184+ stages : Vec < ( Stage , SmemOffset , StageLen ) > ,
185+ /// Shared memory elements reserved by the preceding (non-output) stages.
186186 smem_cursor : SmemOffset ,
187187 /// Source buffers. `None` entries are placeholder slots for pending subtrees,
188188 /// filled by [`materialize_with_subtrees`] before device copy.
@@ -246,26 +246,25 @@ impl FusedPlan {
246246
247247 let mut pending_subtrees: Vec < ArrayRef > = Vec :: new ( ) ;
248248 let mut plan = Self {
249- smem_stages : Vec :: new ( ) ,
250- output_stage : None ,
249+ stages : Vec :: new ( ) ,
251250 smem_cursor : SmemOffset :: from ( 0u32 ) ,
252251 source_buffers : Vec :: new ( ) ,
253252 elem_bytes,
254253 } ;
255254
256255 let len = array. len ( ) as u32 ;
257256 let output = plan. walk ( array. clone ( ) , & mut pending_subtrees) ?;
258- plan. output_stage = Some ( ( output, plan. smem_cursor , len) ) ;
257+ plan. stages . push ( ( output, plan. smem_cursor , len) ) ;
259258
260259 Ok ( ( plan, pending_subtrees) )
261260 }
262261
263262 /// Shared memory bytes needed to launch this plan.
264263 ///
265- /// `smem_cursor` covers the fully-decoded smem_stages (dict values,
266- /// run-end endpoints ). `SMEM_TILE_SIZE` covers the output stage's
267- /// scratch region — the output stage processes `ELEMENTS_PER_BLOCK`
268- /// (2048) elements per block by tiling through this 1024-element window.
264+ /// `smem_cursor` covers the preceding fully-decoded stages (dict values,
265+ /// run-end ). `SMEM_TILE_SIZE` covers the output stage's scratch region —
266+ /// the output stage processes `ELEMENTS_PER_BLOCK` (2048) elements per
267+ /// block by tiling through this 1024-element window.
269268 fn shared_mem_bytes ( & self ) -> u32 {
// Work in elements first, then convert to bytes: the region reserved so far
// (`smem_cursor`) plus one `SMEM_TILE_SIZE` scratch tile, multiplied by
// `elem_bytes` (presumably the byte width of one element — confirm at the
// site where `elem_bytes` is computed).
// NOTE(review): the add and the multiply are unchecked `u32` arithmetic;
// this assumes plan sizes stay far below `u32::MAX` — confirm upstream.
270269 ( self . smem_cursor + SMEM_TILE_SIZE ) * self . elem_bytes
271270 }
@@ -311,9 +310,8 @@ impl FusedPlan {
311310 } ;
312311
313312 let stages: Vec < MaterializedStage > = self
314- . smem_stages
313+ . stages
315314 . iter ( )
316- . chain ( self . output_stage . iter ( ) )
317315 . map ( |( stage, smem_offset, len) | {
318316 MaterializedStage :: new (
319317 resolve_ptr ( stage) ,
@@ -592,7 +590,7 @@ impl FusedPlan {
592590 /// stage runs. Returns the shared memory offset where the data starts.
593591 fn push_smem_stage ( & mut self , spec : Stage , len : u32 ) -> u32 {
// Snapshot the cursor before it advances: this is both the offset recorded
// for the new stage and the value returned to the caller.
594592 let smem_offset = self . smem_cursor ;
// Record the stage together with its offset and length. The `-`/`+` pair
// below is this diff's rename of the target vector from `smem_stages` to
// `stages`; the pushed tuple itself is unchanged.
595- self . smem_stages . push ( ( spec, smem_offset, len) ) ;
593+ self . stages . push ( ( spec, smem_offset, len) ) ;
// Reserve `len` elements so the next pushed stage starts right after this
// one.
// NOTE(review): `+=` is unchecked; overflow would panic or wrap depending
// on the build profile — confirm `len` is bounded by the callers.
596594 self . smem_cursor += len;
597595 smem_offset
598596 }
0 commit comments