|
12 | 12 | //! For dimensions that are not powers of 2, the input is zero-padded to the |
13 | 13 | //! next power of 2 before the transform and truncated afterward. |
14 | 14 |
|
| 15 | +use rand::RngExt; |
15 | 16 | use rand::SeedableRng; |
16 | 17 | use rand::rngs::StdRng; |
17 | 18 | use vortex_array::arrays::BoolArray; |
@@ -212,42 +213,8 @@ impl RotationMatrix { |
212 | 213 | /// contains `3 * padded_dim` bits in inverse-application order `[D₃ | D₂ | D₁]`. |
213 | 214 | /// Convention: bit set (1) = +1, bit unset (0) = -1 (negate). |
214 | 215 | /// |
215 | | -/// Applies: H → D₃ → H → D₂ → H → D₁ → scale |
216 | | -#[inline] |
217 | | -pub fn apply_inverse_srht_from_bits( |
218 | | - buf: &mut [f32], |
219 | | - signs_bytes: &[u8], |
220 | | - padded_dim: usize, |
221 | | - norm_factor: f32, |
222 | | -) { |
223 | | - debug_assert!(padded_dim.is_power_of_two()); |
224 | | - debug_assert_eq!(buf.len(), padded_dim); |
225 | | - |
226 | | - for round in 0..3 { |
227 | | - walsh_hadamard_transform(buf); |
228 | | - apply_signs_from_bits(buf, signs_bytes, round * padded_dim); |
229 | | - } |
230 | | - |
231 | | - for val in buf.iter_mut() { |
232 | | - *val *= norm_factor; |
233 | | - } |
234 | | -} |
235 | | - |
236 | | -/// Element-wise negate coordinates where the sign bit is unset (0 = -1). |
237 | | -#[inline] |
238 | | -fn apply_signs_from_bits(buf: &mut [f32], signs_bytes: &[u8], bit_offset: usize) { |
239 | | - for (j, val) in buf.iter_mut().enumerate() { |
240 | | - let idx = bit_offset + j; |
241 | | - let is_positive = (signs_bytes[idx / 8] >> (idx % 8)) & 1 == 1; |
242 | | - if !is_positive { |
243 | | - *val = -*val; |
244 | | - } |
245 | | - } |
246 | | -} |
247 | | - |
248 | 216 | /// Generate a vector of random ±1 signs. |
249 | 217 | fn gen_random_signs(rng: &mut StdRng, len: usize) -> Vec<f32> { |
250 | | - use rand::RngExt; |
251 | 218 | (0..len) |
252 | 219 | .map(|_| { |
253 | 220 | if rng.random_bool(0.5) { |
@@ -416,48 +383,6 @@ mod tests { |
416 | 383 | Ok(()) |
417 | 384 | } |
418 | 385 |
|
419 | | - /// Verify that the hot-path `apply_inverse_srht_from_bits` matches `inverse_rotate`. |
420 | | - #[rstest] |
421 | | - #[case(64)] |
422 | | - #[case(128)] |
423 | | - #[case(768)] |
424 | | - fn hot_path_matches_inverse_rotate(#[case] dim: usize) -> VortexResult<()> { |
425 | | - let rot = RotationMatrix::try_new(99, dim)?; |
426 | | - let padded_dim = rot.padded_dim(); |
427 | | - let norm_factor = rot.norm_factor(); |
428 | | - |
429 | | - let signs_array = rot.export_inverse_signs_bool_array(); |
430 | | - let bit_buf = signs_array.to_bit_buffer(); |
431 | | - let (_, _, raw_buf) = bit_buf.into_inner(); |
432 | | - |
433 | | - // Create some rotated input. |
434 | | - let mut input = vec![0.0f32; padded_dim]; |
435 | | - for i in 0..dim { |
436 | | - input[i] = (i as f32 + 1.0) * 0.01; |
437 | | - } |
438 | | - let mut rotated = vec![0.0f32; padded_dim]; |
439 | | - rot.rotate(&input, &mut rotated); |
440 | | - |
441 | | - // Inverse via the struct method. |
442 | | - let mut recovered1 = vec![0.0f32; padded_dim]; |
443 | | - rot.inverse_rotate(&rotated, &mut recovered1); |
444 | | - |
445 | | - // Inverse via the hot-path function. |
446 | | - let mut recovered2 = rotated.clone(); |
447 | | - apply_inverse_srht_from_bits(&mut recovered2, raw_buf.as_ref(), padded_dim, norm_factor); |
448 | | - |
449 | | - for i in 0..padded_dim { |
450 | | - assert!( |
451 | | - (recovered1[i] - recovered2[i]).abs() < 1e-10, |
452 | | - "Hot-path mismatch at {i}: {} vs {}", |
453 | | - recovered1[i], |
454 | | - recovered2[i] |
455 | | - ); |
456 | | - } |
457 | | - |
458 | | - Ok(()) |
459 | | - } |
460 | | - |
461 | 386 | #[test] |
462 | 387 | fn wht_basic() { |
463 | 388 | // WHT of [1, 0, 0, 0] should be [1, 1, 1, 1] |
|
0 commit comments