@@ -77,9 +77,25 @@ impl SorfMatrix {
7777 /// round-major, block-major order, with each `u64` contributing 64 sign bits in
7878 /// least-significant-bit-first order.
7979 pub fn try_new ( seed : u64 , dimensions : usize , num_rounds : usize ) -> VortexResult < Self > {
80+ Self :: try_new_padded ( dimensions. next_power_of_two ( ) , num_rounds, seed)
81+ }
82+
83+ /// Create a new structured Walsh-Hadamard-based orthogonal transform for a padded dimension.
84+ ///
85+ /// `padded_dimensions` must already be a power of two. Callers that start from an unpadded
86+ /// logical dimension should call [`Self::try_new`] instead.
87+ pub ( crate ) fn try_new_padded (
88+ padded_dimensions : usize ,
89+ num_rounds : usize ,
90+ seed : u64 ,
91+ ) -> VortexResult < Self > {
8092 vortex_ensure ! ( num_rounds >= 1 , "num_rounds must be >= 1, got {num_rounds}" ) ;
93+ vortex_ensure ! (
94+ padded_dimensions. is_power_of_two( ) ,
95+ "padded_dimensions must be a power of two, got {padded_dimensions}"
96+ ) ;
8197
82- let padded_dim = dimensions . next_power_of_two ( ) ;
98+ let padded_dim = padded_dimensions ;
8399 let sign_masks = gen_sign_masks_from_seed ( seed, padded_dim, num_rounds) ;
84100
85101 // Compute in f64 for precision, then store as f32 since the WHT operates on f32 buffers.
@@ -132,8 +148,7 @@ impl SorfMatrix {
132148 /// Apply the forward structured transform: `norm · H · D_k · ... · H · D₁ · x`.
133149 fn apply_srht ( & self , buf : & mut [ f32 ] ) {
134150 for round in 0 ..self . num_rounds {
135- let offset = round * self . padded_dim ;
136- apply_signs_xor ( buf, & self . sign_masks [ offset..offset + self . padded_dim ] ) ;
151+ self . apply_signs_xor ( buf, round) ;
137152 walsh_hadamard_transform ( buf) ;
138153 }
139154
@@ -148,14 +163,24 @@ impl SorfMatrix {
148163 fn apply_inverse_srht ( & self , buf : & mut [ f32 ] ) {
149164 for round in ( 0 ..self . num_rounds ) . rev ( ) {
150165 walsh_hadamard_transform ( buf) ;
151- let offset = round * self . padded_dim ;
152- apply_signs_xor ( buf, & self . sign_masks [ offset..offset + self . padded_dim ] ) ;
166+ self . apply_signs_xor ( buf, round) ;
153167 }
154168
155169 let norm = self . norm_factor ;
156170 buf. iter_mut ( ) . for_each ( |val| * val *= norm) ;
157171 }
158172
173+ /// Apply one round's sign masks via XOR on the IEEE 754 sign bit.
174+ ///
175+ /// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM). Equivalent to
176+ /// multiplying each element by +/-1.0, but avoids FP dependency chains.
177+ fn apply_signs_xor ( & self , buf : & mut [ f32 ] , round : usize ) {
178+ let masks = & self . sign_masks [ round * self . padded_dim ..] [ ..self . padded_dim ] ;
179+ for ( val, & mask) in buf. iter_mut ( ) . zip ( masks. iter ( ) ) {
180+ * val = f32:: from_bits ( val. to_bits ( ) ^ mask) ;
181+ }
182+ }
183+
159184 /// Export the sign vectors as a flat `Vec<u8>` of 0/1 values in inverse application order
160185 /// `[D_k | ... | D₁]`.
161186 ///
@@ -263,16 +288,6 @@ fn sign_mask_from_word(word: u64, bit_idx: usize) -> u32 {
263288 }
264289}
265290
266- /// Apply sign masks via XOR on the IEEE 754 sign bit.
267- ///
268- /// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM). Equivalent to
269- /// multiplying each element by +/-1.0, but avoids FP dependency chains.
270- fn apply_signs_xor ( buf : & mut [ f32 ] , masks : & [ u32 ] ) {
271- for ( val, & mask) in buf. iter_mut ( ) . zip ( masks. iter ( ) ) {
272- * val = f32:: from_bits ( val. to_bits ( ) ^ mask) ;
273- }
274- }
275-
276291/// In-place Fast Walsh-Hadamard Transform (FWHT), unnormalized and iterative.
277292///
278293/// Input length must be a power of 2. Runs in O(n log n) via `log2(n)` stages of `n / 2`
@@ -327,14 +342,24 @@ mod tests {
327342 . collect ( )
328343 }
329344
/// Convert a test-case dimension given as `u32` into the `usize` the transform API expects.
///
/// # Panics
///
/// Panics if `dim` does not fit in `usize`, which can only happen on targets whose pointer
/// width is below 32 bits.
fn dim_to_usize(dim: u32) -> usize {
    usize::try_from(dim).expect("u32 test dimension must fit in usize on this target")
}
348+
/// Widen a test-case round count given as `u8` into `usize`; this conversion cannot fail.
fn rounds_to_usize(num_rounds: u8) -> usize {
    num_rounds.into()
}
352+
330353 #[ test]
331354 fn deterministic_from_seed ( ) -> VortexResult < ( ) > {
332- let r1 = SorfMatrix :: try_new ( 42 , 64 , 3 ) ?;
333- let r2 = SorfMatrix :: try_new ( 42 , 64 , 3 ) ?;
355+ let dim = dim_to_usize ( 64u32 ) ;
356+ let num_rounds = rounds_to_usize ( 3u8 ) ;
357+ let r1 = SorfMatrix :: try_new ( 42u64 , dim, num_rounds) ?;
358+ let r2 = SorfMatrix :: try_new ( 42u64 , dim, num_rounds) ?;
334359 let pd = r1. padded_dim ( ) ;
335360
336361 let mut input = vec ! [ 0.0f32 ; pd] ;
337- for i in 0 ..64 {
362+ for i in 0 ..dim {
338363 input[ i] = i as f32 ;
339364 }
340365 let mut out1 = vec ! [ 0.0f32 ; pd] ;
@@ -349,41 +374,58 @@ mod tests {
349374
#[test]
fn export_inverse_signs_matches_golden_words() -> VortexResult<()> {
    let seed = 42u64;
    let dim = dim_to_usize(64u32);
    let num_rounds = rounds_to_usize(2u8);

    let rot = SorfMatrix::try_new(seed, dim, num_rounds)?;
    let padded_dim = rot.padded_dim();
    let actual = rot.export_inverse_signs_u8();

    // Replay the generator with the same seed and draw one word for each of the two rounds.
    let mut rng = SplitMix64::new(seed);
    let round0_word = rng.next_u64();
    let round1_word = rng.next_u64();

    // The export lists rounds in inverse application order, so round 1 comes before round 0.
    let mut expected = Vec::with_capacity(num_rounds * padded_dim);
    for word in [round1_word, round0_word] {
        expected.extend(unpack_sign_bits(word, padded_dim));
    }

    assert_eq!(actual, expected);
    Ok(())
}
365394
#[test]
fn one_word_generates_64_signs_lsb_first() {
    let seed = 42u64;
    let padded_dim = dim_to_usize(64u32);
    let num_rounds = rounds_to_usize(1u8);

    let masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds);
    assert_eq!(masks.len(), padded_dim);

    // The first word drawn from an identically seeded generator must yield the masks, with
    // bit index 0 mapping to the first mask.
    let word = SplitMix64::new(seed).next_u64();
    let mut expected = Vec::with_capacity(padded_dim);
    for bit_idx in 0..padded_dim {
        expected.push(sign_mask_from_word(word, bit_idx));
    }
    assert_eq!(masks, expected);
}
378410
#[test]
fn accepts_non_power_of_two_dimensions() -> VortexResult<()> {
    // 100 is not a power of two; the constructor is expected to pad it up to 128.
    let dim = dim_to_usize(100u32);
    let num_rounds = rounds_to_usize(3u8);

    let rot = SorfMatrix::try_new(42u64, dim, num_rounds)?;
    let padded_dim = rot.padded_dim();

    assert_eq!(padded_dim, 128);
    Ok(())
}
417+
379418 #[ test]
380419 fn tail_block_uses_only_required_bits ( ) {
381- let masks = gen_sign_masks_from_seed ( 42 , 32 , 1 ) ;
382- assert_eq ! ( masks. len( ) , 32 ) ;
420+ let seed = 42u64 ;
421+ let padded_dim = dim_to_usize ( 32u32 ) ;
422+ let num_rounds = rounds_to_usize ( 1u8 ) ;
423+ let masks = gen_sign_masks_from_seed ( seed, padded_dim, num_rounds) ;
424+ assert_eq ! ( masks. len( ) , padded_dim) ;
383425
384- let mut rng = SplitMix64 :: new ( 42 ) ;
426+ let mut rng = SplitMix64 :: new ( seed ) ;
385427 let word = rng. next_u64 ( ) ;
386- let expected: Vec < _ > = ( 0 ..32 )
428+ let expected: Vec < _ > = ( 0 ..padded_dim )
387429 . map ( |bit_idx| sign_mask_from_word ( word, bit_idx) )
388430 . collect ( ) ;
389431 assert_eq ! ( masks, expected) ;
@@ -392,19 +434,21 @@ mod tests {
392434 /// Verify roundtrip is exact to f32 precision across many dimensions and round counts,
393435 /// including non-power-of-two dimensions that require padding.
394436 #[ rstest]
395- #[ case( 32 , 3 ) ]
396- #[ case( 64 , 3 ) ]
397- #[ case( 100 , 3 ) ]
398- #[ case( 128 , 1 ) ]
399- #[ case( 128 , 2 ) ]
400- #[ case( 128 , 3 ) ]
401- #[ case( 128 , 5 ) ]
402- #[ case( 256 , 3 ) ]
403- #[ case( 512 , 3 ) ]
404- #[ case( 768 , 3 ) ]
405- #[ case( 1024 , 3 ) ]
406- fn roundtrip_exact ( #[ case] dim : usize , #[ case] num_rounds : usize ) -> VortexResult < ( ) > {
407- let rot = SorfMatrix :: try_new ( 42 , dim, num_rounds) ?;
437+ #[ case( 32u32 , 3u8 ) ]
438+ #[ case( 64u32 , 3u8 ) ]
439+ #[ case( 100u32 , 3u8 ) ]
440+ #[ case( 128u32 , 1u8 ) ]
441+ #[ case( 128u32 , 2u8 ) ]
442+ #[ case( 128u32 , 3u8 ) ]
443+ #[ case( 128u32 , 5u8 ) ]
444+ #[ case( 256u32 , 3u8 ) ]
445+ #[ case( 512u32 , 3u8 ) ]
446+ #[ case( 768u32 , 3u8 ) ]
447+ #[ case( 1024u32 , 3u8 ) ]
448+ fn roundtrip_exact ( #[ case] dim : u32 , #[ case] num_rounds : u8 ) -> VortexResult < ( ) > {
449+ let dim = dim_to_usize ( dim) ;
450+ let num_rounds = rounds_to_usize ( num_rounds) ;
451+ let rot = SorfMatrix :: try_new ( 42u64 , dim, num_rounds) ?;
408452 let padded_dim = rot. padded_dim ( ) ;
409453
410454 let mut input = vec ! [ 0.0f32 ; padded_dim] ;
@@ -435,12 +479,14 @@ mod tests {
435479
436480 /// Verify norm preservation across dimensions and round counts.
437481 #[ rstest]
438- #[ case( 128 , 1 ) ]
439- #[ case( 128 , 3 ) ]
440- #[ case( 128 , 5 ) ]
441- #[ case( 768 , 3 ) ]
442- fn preserves_norm ( #[ case] dim : usize , #[ case] num_rounds : usize ) -> VortexResult < ( ) > {
443- let rot = SorfMatrix :: try_new ( 7 , dim, num_rounds) ?;
482+ #[ case( 128u32 , 1u8 ) ]
483+ #[ case( 128u32 , 3u8 ) ]
484+ #[ case( 128u32 , 5u8 ) ]
485+ #[ case( 768u32 , 3u8 ) ]
486+ fn preserves_norm ( #[ case] dim : u32 , #[ case] num_rounds : u8 ) -> VortexResult < ( ) > {
487+ let dim = dim_to_usize ( dim) ;
488+ let num_rounds = rounds_to_usize ( num_rounds) ;
489+ let rot = SorfMatrix :: try_new ( 42u64 , dim, num_rounds) ?;
444490 let padded_dim = rot. padded_dim ( ) ;
445491
446492 let mut input = vec ! [ 0.0f32 ; padded_dim] ;
@@ -465,16 +511,15 @@ mod tests {
465511
466512 /// Verify that export -> [`from_u8_slice`] produces identical transform output.
467513 #[ rstest]
468- #[ case( 64 , 3 ) ]
469- #[ case( 128 , 1 ) ]
470- #[ case( 128 , 3 ) ]
471- #[ case( 128 , 5 ) ]
472- #[ case( 768 , 3 ) ]
473- fn sign_export_import_roundtrip (
474- #[ case] dim : usize ,
475- #[ case] num_rounds : usize ,
476- ) -> VortexResult < ( ) > {
477- let rot = SorfMatrix :: try_new ( 42 , dim, num_rounds) ?;
514+ #[ case( 64u32 , 3u8 ) ]
515+ #[ case( 128u32 , 1u8 ) ]
516+ #[ case( 128u32 , 3u8 ) ]
517+ #[ case( 128u32 , 5u8 ) ]
518+ #[ case( 768u32 , 3u8 ) ]
519+ fn sign_export_import_roundtrip ( #[ case] dim : u32 , #[ case] num_rounds : u8 ) -> VortexResult < ( ) > {
520+ let dim = dim_to_usize ( dim) ;
521+ let num_rounds = rounds_to_usize ( num_rounds) ;
522+ let rot = SorfMatrix :: try_new ( 42u64 , dim, num_rounds) ?;
478523 let padded_dim = rot. padded_dim ( ) ;
479524
480525 let signs_u8 = rot. export_inverse_signs_u8 ( ) ;
0 commit comments