
Commit 30e87e7

chore: minor cleanup

Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
Parent: 786d6b7

2 files changed: 33 additions & 34 deletions

vortex-cuda/src/dynamic_dispatch/plan_builder.rs

Lines changed: 32 additions & 34 deletions
```diff
@@ -145,44 +145,44 @@ pub enum DispatchPlan {
 
 /// A fused plan: stages, source buffers and shared-memory.
 ///
-/// The kernel runs in two phases:
+/// Stages are stored in kernel execution order. There are two phases:
 ///
-/// 1. `smem_stages` run first and decode their **entire** output into
-///    shared memory (e.g. all dict values, all run-end endpoints). This data
-///    stays resident for the output stage to index into.
-/// 2. `output_stage` then iterates over the input in tiles of
-///    `SMEM_TILE_SIZE` (1024) elements, decoding each tile into a scratch
-///    region of shared memory, applying scalar ops (which may reference the
-///    smem_stages data), and writing the result to global memory.
+/// 1. All stages except the last run first and decode their output
+///    into shared memory (e.g. all dict values, all run-end endpoints).
+///    This data stays resident for the output stage to index into.
+///
+/// 2. The last stage (the output stage) iterates over the input in tiles
+///    of `SMEM_TILE_SIZE` (1024) elements, decoding each tile into a
+///    scratch region of shared memory, applying scalar ops (which may
+///    reference data from the earlier stages), and writing the result to
+///    global memory.
 ///
 /// # Shared memory allocation
 ///
 /// Total shared memory = (`smem_cursor` + `SMEM_TILE_SIZE`) × `elem_bytes`.
 ///
 /// This is sufficient because:
 ///
-/// - `smem_stages` only originate from dict (values) and run-end
-///   (ends, values). `push_smem_stage` reserves
-///   the full auxiliary data length in `smem_cursor`, so each stage's source
-///   op has room to decode the complete input.
+/// - Earlier stages only originate from dict (values) and run-end (ends,
+///   values). `push_smem_stage` reserves the full auxiliary data length in
+///   `smem_cursor`, so each stage's source op has room to decode the complete
+///   input.
 ///
-/// - `output_stage` tiles at `SMEM_TILE_SIZE` (1024 elements), so its
-///   source op never writes more than 1024 elements into the scratch region,
-///   even though each block is responsible for `ELEMENTS_PER_BLOCK` (2048)
-///   output elements — it processes them in two passes through the scratch.
+/// - The output stage (last) tiles at `SMEM_TILE_SIZE` (1024 elements),
+///   so its source op never writes more than 1024 elements into the
+///   scratch region, even though each block is responsible for
+///   `ELEMENTS_PER_BLOCK` (2048) output elements — it processes them in
+///   two passes through the scratch.
 ///
 /// Note: `BITUNPACK` writes full FastLanes blocks (1024 elements), which can
 /// exceed `stage.len` by up to 1023 elements. This overflow is absorbed by
 /// the scratch region (`SMEM_TILE_SIZE` ≥ `FL_CHUNK_SIZE`).
 pub struct FusedPlan {
-    /// Stages that decode fully into shared memory before the output stage
-    /// runs. Their data stays resident so the output stage can reference it.
-    smem_stages: Vec<(Stage, SmemOffset, StageLen)>,
-    /// The root stage that produces the final output. Iterates in tiles of
-    /// `SMEM_TILE_SIZE`, writing each to global memory. `None` only during
-    /// construction; always `Some` after [`build`](Self::build).
-    output_stage: Option<(Stage, SmemOffset, StageLen)>,
-    /// Shared memory elements reserved by `smem_stages` (fully decoded).
+    /// Stages in kernel execution order. All stages except the last decode
+    /// fully into persistent shared memory; the final stage produces the
+    /// output.
+    stages: Vec<(Stage, SmemOffset, StageLen)>,
+    /// Shared memory elements reserved by the preceding (non-output) stages.
     smem_cursor: SmemOffset,
     /// Source buffers. `None` entries are placeholder slots for pending subtrees,
     /// filled by [`materialize_with_subtrees`] before device copy.
```
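The hunk above folds the old `smem_stages` / `output_stage` pair into a single `stages` vector whose last entry is the output stage. A minimal sketch of that invariant, using stand-in types rather than the crate's real `Stage`, `SmemOffset`, and `StageLen`:

```rust
// Simplified stand-ins; the real types live in plan_builder.rs.
type SmemOffset = u32;
type StageLen = u32;
struct Stage; // opaque here

struct FusedPlanSketch {
    /// Kernel execution order: every stage except the last decodes fully
    /// into persistent shared memory; the last one produces the output.
    stages: Vec<(Stage, SmemOffset, StageLen)>,
}

impl FusedPlanSketch {
    /// The output stage is now just the last entry, replacing the old
    /// separate `output_stage: Option<(Stage, SmemOffset, StageLen)>` field.
    fn output_stage(&self) -> &(Stage, SmemOffset, StageLen) {
        self.stages
            .last()
            .expect("a built plan always ends with its output stage")
    }
}

fn main() {
    // Hypothetical plan: one persistent stage, then the output stage.
    let plan = FusedPlanSketch {
        stages: vec![(Stage, 0, 256), (Stage, 256, 2048)],
    };
    let (_stage, smem_offset, len) = plan.output_stage();
    assert_eq!((*smem_offset, *len), (256, 2048));
}
```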
```diff
@@ -246,26 +246,25 @@ impl FusedPlan {
 
         let mut pending_subtrees: Vec<ArrayRef> = Vec::new();
         let mut plan = Self {
-            smem_stages: Vec::new(),
-            output_stage: None,
+            stages: Vec::new(),
             smem_cursor: SmemOffset::from(0u32),
             source_buffers: Vec::new(),
             elem_bytes,
         };
 
         let len = array.len() as u32;
         let output = plan.walk(array.clone(), &mut pending_subtrees)?;
-        plan.output_stage = Some((output, plan.smem_cursor, len));
+        plan.stages.push((output, plan.smem_cursor, len));
 
         Ok((plan, pending_subtrees))
     }
 
     /// Shared memory bytes needed to launch this plan.
     ///
-    /// `smem_cursor` covers the fully-decoded smem_stages (dict values,
-    /// run-end endpoints). `SMEM_TILE_SIZE` covers the output stage's
-    /// scratch region — the output stage processes `ELEMENTS_PER_BLOCK`
-    /// (2048) elements per block by tiling through this 1024-element window.
+    /// `smem_cursor` covers the preceding fully-decoded stages (dict values,
+    /// run-end endpoints). `SMEM_TILE_SIZE` covers the output stage's scratch
+    /// region — the output stage processes `ELEMENTS_PER_BLOCK` (2048)
+    /// elements per block by tiling through this 1024-element window.
     fn shared_mem_bytes(&self) -> u32 {
        (self.smem_cursor + SMEM_TILE_SIZE) * self.elem_bytes
     }
```
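To make the budget formula concrete, here is a small worked example of the `shared_mem_bytes` arithmetic. `SMEM_TILE_SIZE` = 1024 comes from the doc comment; the cursor value and element width below are hypothetical:

```rust
const SMEM_TILE_SIZE: u32 = 1024; // scratch tile, per the doc comment

// Same arithmetic as FusedPlan::shared_mem_bytes, as a free function.
fn shared_mem_bytes(smem_cursor: u32, elem_bytes: u32) -> u32 {
    (smem_cursor + SMEM_TILE_SIZE) * elem_bytes
}

fn main() {
    // Hypothetical plan: a dict-values stage of 512 elements was reserved
    // via push_smem_stage, and the output element type is 4 bytes wide.
    let smem_cursor = 512;
    let elem_bytes = 4;
    // (512 persistent + 1024 scratch) * 4 = 6144 bytes per block.
    assert_eq!(shared_mem_bytes(smem_cursor, elem_bytes), 6144);
}
```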
```diff
@@ -311,9 +310,8 @@ impl FusedPlan {
         };
 
         let stages: Vec<MaterializedStage> = self
-            .smem_stages
+            .stages
             .iter()
-            .chain(self.output_stage.iter())
             .map(|(stage, smem_offset, len)| {
                 MaterializedStage::new(
                     resolve_ptr(stage),
```
```diff
@@ -592,7 +590,7 @@ impl FusedPlan {
     /// stage runs. Returns the shared memory offset where the data starts.
     fn push_smem_stage(&mut self, spec: Stage, len: u32) -> u32 {
         let smem_offset = self.smem_cursor;
-        self.smem_stages.push((spec, smem_offset, len));
+        self.stages.push((spec, smem_offset, len));
         self.smem_cursor += len;
         smem_offset
     }
```
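For illustration, a sketch of how `push_smem_stage` packs persistent stages back to back: each call hands out the current cursor as the stage's offset, then advances the cursor by the stage length. Stage names and sizes here are hypothetical, and offsets are simplified to plain `u32`:

```rust
struct PlanSketch {
    stages: Vec<(&'static str, u32, u32)>, // (stage, smem offset, length)
    smem_cursor: u32,
}

impl PlanSketch {
    // Mirrors push_smem_stage: reserve [cursor, cursor + len) for this stage.
    fn push_smem_stage(&mut self, stage: &'static str, len: u32) -> u32 {
        let smem_offset = self.smem_cursor;
        self.stages.push((stage, smem_offset, len));
        self.smem_cursor += len; // next stage starts where this one ends
        smem_offset
    }
}

fn main() {
    let mut plan = PlanSketch { stages: Vec::new(), smem_cursor: 0 };
    // Two hypothetical auxiliary stages, fully decoded into shared memory:
    assert_eq!(plan.push_smem_stage("dict values", 256), 0);
    assert_eq!(plan.push_smem_stage("run-end ends", 128), 256);
    assert_eq!(plan.smem_cursor, 384); // persistent elements reserved so far
}
```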

vortex-cuda/src/hybrid_dispatch/mod.rs

Lines changed: 1 addition & 0 deletions
```diff
@@ -63,6 +63,7 @@ use crate::executor::CudaExecutionCtx;
 ///
 /// Returns `Ok(Canonical)` on success. Returns `Err` when the array
 /// cannot be handled (non-primitive output dtype, no registered kernel).
+#[allow(clippy::cognitive_complexity)]
 pub async fn try_gpu_dispatch(
     array: &ArrayRef,
     ctx: &mut CudaExecutionCtx,
```
