Skip to content

Commit 40df862

Browse files
committed
refactor: tie patches to owning ops, consolidate tests into rstest
- Move patches from loose Stage fields to op-associated storage: source_patches for SourceOp, (ScalarOp, Option<Patches>) tuples - Consolidate 7 patch tests into 3 rstest parametrized groups with unsliced/sliced/large-offset cases - Fix walk_for scalar_ops push for new tuple type Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent c6700ae commit 40df862

2 files changed

Lines changed: 109 additions & 205 deletions

File tree

vortex-cuda/src/dynamic_dispatch/mod.rs

Lines changed: 72 additions & 174 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,7 @@ impl MaterializedPlan {
498498

499499
#[cfg(test)]
500500
mod tests {
501+
use std::ops::Range;
501502
use std::sync::Arc;
502503

503504
use cudarc::driver::DevicePtr;
@@ -2121,10 +2122,34 @@ mod tests {
21212122
Ok(())
21222123
}
21232124

2125+
/// Empty nullable array should preserve nullability.
2126+
#[crate::test]
2127+
async fn test_empty_nullable_array() -> VortexResult<()> {
2128+
let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
2129+
2130+
let array = PrimitiveArray::new(Buffer::<u32>::empty(), Validity::AllValid);
2131+
let result = try_gpu_dispatch(&array.into_array(), &mut cuda_ctx).await?;
2132+
let prim = result.into_primitive();
2133+
assert_eq!(prim.len(), 0);
2134+
assert_eq!(prim.validity()?.nullability(), Nullability::Nullable);
2135+
Ok(())
2136+
}
2137+
2138+
// ---------------------------------------------------------------
2139+
// Patch tests — fused dynamic dispatch with exception values
2140+
// ---------------------------------------------------------------
2141+
2142+
#[rstest]
2143+
#[case::unsliced(3000, None)]
2144+
#[case::mid_slice(5000, Some(500..3500))]
2145+
#[case::start_slice(5000, Some(0..1000))]
2146+
#[case::chunk_aligned(5000, Some(1024..3000))]
21242147
#[crate::test]
2125-
fn test_sliced_bitpacked_with_patches() -> VortexResult<()> {
2148+
fn test_bitpacked_with_patches(
2149+
#[case] len: usize,
2150+
#[case] slice_range: Option<Range<usize>>,
2151+
) -> VortexResult<()> {
21262152
let bit_width: u8 = 4;
2127-
let len = 5000usize;
21282153
let max_val = (1u32 << bit_width) - 1;
21292154
let values: Vec<u32> = (0..len)
21302155
.map(|i| {
@@ -2144,12 +2169,15 @@ mod tests {
21442169
)?;
21452170
assert!(bp.patches().is_some(), "expected patches");
21462171

2147-
// Slice crossing chunk boundaries.
2148-
let sliced = bp.into_array().slice(500..3500)?;
2149-
let expected: Vec<u32> = values[500..3500].to_vec();
2172+
let (array, expected) = if let Some(range) = slice_range {
2173+
let sliced = bp.into_array().slice(range.clone())?;
2174+
(sliced, values[range].to_vec())
2175+
} else {
2176+
(bp.into_array(), values)
2177+
};
21502178

21512179
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
2152-
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
2180+
let plan = dispatch_plan(&array, &cuda_ctx)?;
21532181
let actual = run_dynamic_dispatch_plan(
21542182
&cuda_ctx,
21552183
expected.len(),
@@ -2160,10 +2188,15 @@ mod tests {
21602188
Ok(())
21612189
}
21622190

2191+
#[rstest]
2192+
#[case::unsliced(3000, None)]
2193+
#[case::mid_slice(5000, Some(500..3500))]
21632194
#[crate::test]
2164-
fn test_sliced_for_bitpacked_with_patches() -> VortexResult<()> {
2195+
fn test_for_bitpacked_with_patches(
2196+
#[case] len: usize,
2197+
#[case] slice_range: Option<Range<usize>>,
2198+
) -> VortexResult<()> {
21652199
let bit_width: u8 = 6;
2166-
let len = 5000usize;
21672200
let reference = 42u32;
21682201
let max_val = (1u32 << bit_width) - 1;
21692202
let residuals: Vec<u32> = (0..len)
@@ -2186,11 +2219,15 @@ mod tests {
21862219
assert!(bp.patches().is_some(), "expected patches");
21872220
let for_arr = FoR::try_new(bp.into_array(), Scalar::from(reference))?;
21882221

2189-
let sliced = for_arr.into_array().slice(500..3500)?;
2190-
let expected: Vec<u32> = all_values[500..3500].to_vec();
2222+
let (array, expected) = if let Some(range) = slice_range {
2223+
let sliced = for_arr.into_array().slice(range.clone())?;
2224+
(sliced, all_values[range].to_vec())
2225+
} else {
2226+
(for_arr.into_array(), all_values)
2227+
};
21912228

21922229
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
2193-
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
2230+
let plan = dispatch_plan(&array, &cuda_ctx)?;
21942231
let actual = run_dynamic_dispatch_plan(
21952232
&cuda_ctx,
21962233
expected.len(),
@@ -2201,81 +2238,49 @@ mod tests {
22012238
Ok(())
22022239
}
22032240

2241+
#[rstest]
2242+
#[case::unsliced(2000, None)]
2243+
#[case::mid_slice(5000, Some(100..4000))]
2244+
#[case::large_offset(5000, Some(1500..4500))]
22042245
#[crate::test]
2205-
fn test_sliced_alp_with_patches() -> VortexResult<()> {
2206-
let len = 5000usize;
2246+
fn test_alp_with_patches(
2247+
#[case] len: usize,
2248+
#[case] slice_range: Option<Range<usize>>,
2249+
) -> VortexResult<()> {
22072250
let mut values: Vec<f32> = (0..len).map(|i| (i as f32) * 1.1).collect();
2251+
// Insert exception values that ALP can't encode.
22082252
values[0] = 99.9;
22092253
values[500] = std::f32::consts::PI;
22102254
values[1024] = std::f32::consts::E;
2211-
values[2048] = std::f32::consts::LN_2;
2212-
values[3333] = std::f32::consts::SQRT_2;
2213-
2214-
let float_prim = PrimitiveArray::new(Buffer::from(values.clone()), NonNullable);
2215-
let encoded = alp_encode(
2216-
float_prim.as_view(),
2217-
None,
2218-
&mut LEGACY_SESSION.create_execution_ctx(),
2219-
)?
2220-
.into_array();
2221-
2222-
let sliced = encoded.slice(100..4000)?;
2223-
2224-
// Decode on CPU as ground truth (accounts for ALP precision loss + patches).
2225-
let cpu_decoded = sliced
2226-
.clone()
2227-
.execute::<PrimitiveArray>(&mut LEGACY_SESSION.create_execution_ctx())?;
2228-
let expected: Vec<f32> = cpu_decoded.as_slice::<f32>().to_vec();
2229-
2230-
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
2231-
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
2232-
let actual = run_dispatch_plan_f32(
2233-
&cuda_ctx,
2234-
expected.len(),
2235-
&plan.dispatch_plan,
2236-
plan.shared_mem_bytes,
2237-
)?;
2238-
for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
2239-
assert!(
2240-
a.to_bits() == e.to_bits(),
2241-
"mismatch at index {i} (original index {}): gpu={a} cpu={e} (bits: {:#010x} vs {:#010x})",
2242-
i + 100,
2243-
a.to_bits(),
2244-
e.to_bits(),
2245-
);
2255+
if len > 2048 {
2256+
values[2048] = std::f32::consts::LN_2;
2257+
}
2258+
if len > 3333 {
2259+
values[3333] = std::f32::consts::SQRT_2;
22462260
}
2247-
Ok(())
2248-
}
2249-
2250-
#[crate::test]
2251-
fn test_sliced_alp_with_patches_large_offset() -> VortexResult<()> {
2252-
let len = 5000usize;
2253-
let mut values: Vec<f32> = (0..len).map(|i| (i as f32) * 1.1).collect();
2254-
// Place patches across multiple chunks, including after position 1024.
2255-
values[0] = 99.9;
2256-
values[500] = std::f32::consts::PI;
2257-
values[1024] = std::f32::consts::E;
2258-
values[2048] = std::f32::consts::LN_2;
2259-
values[3333] = std::f32::consts::SQRT_2;
22602261

2261-
let float_prim = PrimitiveArray::new(Buffer::from(values.clone()), NonNullable);
2262+
let float_prim = PrimitiveArray::new(Buffer::from(values), NonNullable);
22622263
let encoded = alp_encode(
22632264
float_prim.as_view(),
22642265
None,
22652266
&mut LEGACY_SESSION.create_execution_ctx(),
22662267
)?
22672268
.into_array();
22682269

2269-
// Slice with offset >= 1024 to exercise the double-counted offset bug.
2270-
let sliced = encoded.slice(1500..4500)?;
2270+
let (array, base_offset) = if let Some(ref range) = slice_range {
2271+
(encoded.slice(range.clone())?, range.start)
2272+
} else {
2273+
(encoded, 0)
2274+
};
22712275

2272-
let cpu_decoded = sliced
2276+
// Decode on CPU as ground truth (accounts for ALP precision loss + patches).
2277+
let cpu_decoded = array
22732278
.clone()
22742279
.execute::<PrimitiveArray>(&mut LEGACY_SESSION.create_execution_ctx())?;
22752280
let expected: Vec<f32> = cpu_decoded.as_slice::<f32>().to_vec();
22762281

22772282
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
2278-
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
2283+
let plan = dispatch_plan(&array, &cuda_ctx)?;
22792284
let actual = run_dispatch_plan_f32(
22802285
&cuda_ctx,
22812286
expected.len(),
@@ -2286,121 +2291,14 @@ mod tests {
22862291
assert!(
22872292
a.to_bits() == e.to_bits(),
22882293
"mismatch at index {i} (original index {}): gpu={a} cpu={e} (bits: {:#010x} vs {:#010x})",
2289-
i + 1500,
2294+
i + base_offset,
22902295
a.to_bits(),
22912296
e.to_bits(),
22922297
);
22932298
}
22942299
Ok(())
22952300
}
22962301

2297-
/// Empty nullable array should preserve nullability.
2298-
#[crate::test]
2299-
async fn test_empty_nullable_array() -> VortexResult<()> {
2300-
let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
2301-
2302-
let array = PrimitiveArray::new(Buffer::<u32>::empty(), Validity::AllValid);
2303-
let result = try_gpu_dispatch(&array.into_array(), &mut cuda_ctx).await?;
2304-
let prim = result.into_primitive();
2305-
assert_eq!(prim.len(), 0);
2306-
assert_eq!(prim.validity()?.nullability(), Nullability::Nullable);
2307-
Ok(())
2308-
}
2309-
2310-
// ---------------------------------------------------------------
2311-
// Patch tests — fused dynamic dispatch with exception values
2312-
// ---------------------------------------------------------------
2313-
2314-
#[crate::test]
2315-
fn test_bitpacked_with_patches() -> VortexResult<()> {
2316-
let bit_width: u8 = 4;
2317-
let len = 3000usize;
2318-
let max_val = (1u32 << bit_width) - 1;
2319-
let values: Vec<u32> = (0..len)
2320-
.map(|i| {
2321-
if i % 100 == 0 {
2322-
1000
2323-
} else {
2324-
(i as u32) % (max_val + 1)
2325-
}
2326-
})
2327-
.collect();
2328-
2329-
let prim = PrimitiveArray::new(Buffer::from(values.clone()), NonNullable);
2330-
let bp = BitPacked::encode(
2331-
&prim.into_array(),
2332-
bit_width,
2333-
&mut LEGACY_SESSION.create_execution_ctx(),
2334-
)?;
2335-
assert!(bp.patches().is_some(), "expected patches");
2336-
2337-
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
2338-
let plan = dispatch_plan(&bp.into_array(), &cuda_ctx)?;
2339-
let actual =
2340-
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
2341-
assert_eq!(actual, values);
2342-
Ok(())
2343-
}
2344-
2345-
#[crate::test]
2346-
fn test_for_bitpacked_with_patches() -> VortexResult<()> {
2347-
let bit_width: u8 = 6;
2348-
let len = 3000usize;
2349-
let reference = 42u32;
2350-
let max_val = (1u32 << bit_width) - 1;
2351-
let residuals: Vec<u32> = (0..len)
2352-
.map(|i| {
2353-
if i % 200 == 0 {
2354-
500
2355-
} else {
2356-
(i as u32) % (max_val + 1)
2357-
}
2358-
})
2359-
.collect();
2360-
let expected: Vec<u32> = residuals.iter().map(|&v| v + reference).collect();
2361-
2362-
let prim = PrimitiveArray::new(Buffer::from(residuals), NonNullable);
2363-
let bp = BitPacked::encode(
2364-
&prim.into_array(),
2365-
bit_width,
2366-
&mut LEGACY_SESSION.create_execution_ctx(),
2367-
)?;
2368-
assert!(bp.patches().is_some(), "expected patches");
2369-
let for_arr = FoR::try_new(bp.into_array(), Scalar::from(reference))?;
2370-
2371-
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
2372-
let plan = dispatch_plan(&for_arr.into_array(), &cuda_ctx)?;
2373-
let actual =
2374-
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
2375-
assert_eq!(actual, expected);
2376-
Ok(())
2377-
}
2378-
2379-
#[crate::test]
2380-
fn test_alp_with_patches() -> VortexResult<()> {
2381-
let len = 2000usize;
2382-
let mut values: Vec<f32> = (0..len).map(|i| (i as f32) * 1.1).collect();
2383-
// Insert exception values that ALP can't encode.
2384-
values[0] = 99.9;
2385-
values[500] = std::f32::consts::PI;
2386-
values[1024] = std::f32::consts::E;
2387-
2388-
let float_prim = PrimitiveArray::new(Buffer::from(values.clone()), NonNullable);
2389-
let encoded = alp_encode(
2390-
float_prim.as_view(),
2391-
None,
2392-
&mut LEGACY_SESSION.create_execution_ctx(),
2393-
)?
2394-
.into_array();
2395-
2396-
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
2397-
let plan = dispatch_plan(&encoded, &cuda_ctx)?;
2398-
let actual =
2399-
run_dispatch_plan_f32(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
2400-
assert_eq!(actual, values);
2401-
Ok(())
2402-
}
2403-
24042302
// ---------------------------------------------------------------
24052303
// Additional patch tests — typed widths, edge cases, composites
24062304
// ---------------------------------------------------------------

0 commit comments

Comments
 (0)