Skip to content

Commit a831042

Browse files
committed
Revert "permutation"
This reverts commit 00ee4fe.
1 parent 00ee4fe commit a831042

8 files changed

Lines changed: 16 additions & 224 deletions

File tree

vortex-tensor/src/encodings/turboquant/array.rs

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,6 @@ pub struct TurboQuantMetadata {
3535
/// Whether QJL correction children are present.
3636
#[prost(bool, tag = "3")]
3737
pub has_qjl: bool,
38-
/// Whether a pre-SRHT permutation is stored (for non-power-of-2 dims).
39-
#[prost(bool, tag = "4")]
40-
pub has_permutation: bool,
4138
}
4239

4340
/// Optional QJL (Quantized Johnson-Lindenstrauss) correction for unbiased
@@ -80,11 +77,10 @@ pub(crate) enum Slot {
8077
QjlSigns = 4,
8178
QjlResidualNorms = 5,
8279
QjlRotationSigns = 6,
83-
Permutation = 7,
8480
}
8581

8682
impl Slot {
87-
pub(crate) const COUNT: usize = 8;
83+
pub(crate) const COUNT: usize = 7;
8884

8985
pub(crate) fn name(self) -> &'static str {
9086
match self {
@@ -95,7 +91,6 @@ impl Slot {
9591
Self::QjlSigns => "qjl_signs",
9692
Self::QjlResidualNorms => "qjl_residual_norms",
9793
Self::QjlRotationSigns => "qjl_rotation_signs",
98-
Self::Permutation => "permutation",
9994
}
10095
}
10196

@@ -108,7 +103,6 @@ impl Slot {
108103
4 => Self::QjlSigns,
109104
5 => Self::QjlResidualNorms,
110105
6 => Self::QjlRotationSigns,
111-
7 => Self::Permutation,
112106
_ => vortex_error::vortex_panic!("invalid slot index {idx}"),
113107
}
114108
}
@@ -126,9 +120,6 @@ impl Slot {
126120
/// - 4: `qjl_signs` — `FixedSizeListArray<u8>` (num_rows * padded_dim, 1-bit)
127121
/// - 5: `qjl_residual_norms` — `PrimitiveArray<f32>` (one per row)
128122
/// - 6: `qjl_rotation_signs` — `BitPackedArray` (3 * padded_dim, 1-bit, QJL rotation)
129-
///
130-
/// Optional permutation slot (None for power-of-2 dims):
131-
/// - 7: `permutation` — `BitPackedArray<u16>` (padded_dim, ceil(log2(padded_dim))-bit)
132123
#[derive(Clone, Debug)]
133124
pub struct TurboQuantArray {
134125
pub(crate) dtype: DType,
@@ -256,11 +247,6 @@ impl TurboQuantArray {
256247
})
257248
}
258249

259-
/// The optional pre-SRHT permutation (for non-power-of-2 dims).
260-
pub fn permutation(&self) -> Option<&ArrayRef> {
261-
self.slots[Slot::Permutation as usize].as_ref()
262-
}
263-
264250
/// Set the QJL correction fields on this array.
265251
pub(crate) fn set_qjl(&mut self, qjl: QjlCorrection) {
266252
self.slots[Slot::QjlSigns as usize] = Some(qjl.signs);

vortex-tensor/src/encodings/turboquant/compress.rs

Lines changed: 4 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,6 @@ struct MseQuantizationResult {
8080
all_indices: BufferMut<u8>,
8181
norms: BufferMut<f32>,
8282
padded_dim: usize,
83-
/// Random permutation for non-power-of-2 dims (shared by MSE and QJL).
84-
perm: Option<Vec<u16>>,
8583
}
8684

8785
/// Core quantization: extract f32 elements, build rotation, normalize/rotate/quantize all rows.
@@ -93,17 +91,9 @@ fn turboquant_quantize_core(
9391
let dimension = fsl.list_size() as usize;
9492
let num_rows = fsl.len();
9593

96-
let mut rotation = RotationMatrix::try_new(seed, dimension)?;
94+
let rotation = RotationMatrix::try_new(seed, dimension)?;
9795
let padded_dim = rotation.padded_dim();
9896

99-
// For non-power-of-2 dims, generate a random permutation to scatter
100-
// zero-padded entries uniformly before the SRHT.
101-
let perm = (dimension < padded_dim)
102-
.then(|| RotationMatrix::gen_permutation(seed.wrapping_add(42), padded_dim));
103-
if let Some(ref p) = perm {
104-
rotation = rotation.with_permutation(p.clone());
105-
}
106-
10797
let f32_elements = extract_f32_elements(fsl)?;
10898

10999
let centroids = get_centroids(padded_dim as u32, bit_width)?;
@@ -142,7 +132,6 @@ fn turboquant_quantize_core(
142132
all_indices,
143133
norms,
144134
padded_dim,
145-
perm,
146135
})
147136
}
148137

@@ -177,23 +166,15 @@ fn build_turboquant_mse(
177166

178167
let rotation_signs = bitpack_rotation_signs(&core.rotation)?;
179168

180-
let mut array = TurboQuantArray::try_new_mse(
169+
TurboQuantArray::try_new_mse(
181170
fsl.dtype().clone(),
182171
codes,
183172
norms_array,
184173
centroids_array,
185174
rotation_signs,
186175
dimension,
187176
bit_width,
188-
)?;
189-
190-
// Store permutation for non-power-of-2 dims.
191-
if let Some(ref perm) = core.perm {
192-
array.slots[crate::encodings::turboquant::array::Slot::Permutation as usize] =
193-
Some(bitpack_permutation(perm)?);
194-
}
195-
196-
Ok(array)
177+
)
197178
}
198179

199180
/// Encode a FixedSizeListArray into a MSE-only `TurboQuantArray`.
@@ -266,16 +247,7 @@ pub fn turboquant_encode_qjl(
266247

267248
// QJL uses a different rotation than the MSE stage to ensure statistical
268249
// independence between the quantization noise and the sign projection.
269-
// The same permutation is shared: it's a property of the padded embedding
270-
// space, not of the rotation itself.
271-
let qjl_rotation = {
272-
let rot = RotationMatrix::try_new(seed.wrapping_add(25), dim)?;
273-
if let Some(ref p) = core.perm {
274-
rot.with_permutation(p.clone())
275-
} else {
276-
rot
277-
}
278-
};
250+
let qjl_rotation = RotationMatrix::try_new(seed.wrapping_add(25), dim)?;
279251

280252
let num_rows = fsl.len();
281253
let mut residual_norms_buf = BufferMut::<f32>::with_capacity(num_rows);
@@ -378,12 +350,3 @@ fn bitpack_rotation_signs(rotation: &RotationMatrix) -> VortexResult<ArrayRef> {
378350
let prim = PrimitiveArray::new::<u8>(buf.freeze(), Validity::NonNullable);
379351
Ok(bitpack_encode(&prim, 1, None)?.into_array())
380352
}
381-
382-
/// Bitpack a permutation of u16 indices for efficient storage.
383-
fn bitpack_permutation(perm: &[u16]) -> VortexResult<ArrayRef> {
384-
let mut buf = BufferMut::<u16>::with_capacity(perm.len());
385-
buf.extend_from_slice(perm);
386-
let prim = PrimitiveArray::new::<u16>(buf.freeze(), Validity::NonNullable);
387-
let bit_width = (perm.len() as f64).log2().ceil() as u8;
388-
Ok(bitpack_encode(&prim, bit_width, None)?.into_array())
389-
}

vortex-tensor/src/encodings/turboquant/compute/slice.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ use vortex_array::arrays::slice::SliceReduce;
99
use vortex_error::VortexResult;
1010

1111
use crate::encodings::turboquant::array::QjlCorrection;
12-
use crate::encodings::turboquant::array::Slot;
1312
use crate::encodings::turboquant::array::TurboQuant;
1413
use crate::encodings::turboquant::array::TurboQuantArray;
1514

@@ -41,8 +40,6 @@ impl SliceReduce for TurboQuant {
4140
if let Some(qjl) = sliced_qjl {
4241
result.set_qjl(qjl);
4342
}
44-
// Permutation is shared (not per-row), clone unchanged.
45-
result.slots[Slot::Permutation as usize] = array.permutation().cloned();
4643

4744
Ok(Some(result.into_array()))
4845
}

vortex-tensor/src/encodings/turboquant/compute/take.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ use vortex_array::arrays::dict::TakeExecute;
99
use vortex_error::VortexResult;
1010

1111
use crate::encodings::turboquant::array::QjlCorrection;
12-
use crate::encodings::turboquant::array::Slot;
1312
use crate::encodings::turboquant::array::TurboQuant;
1413
use crate::encodings::turboquant::array::TurboQuantArray;
1514

@@ -46,8 +45,6 @@ impl TakeExecute for TurboQuant {
4645
if let Some(qjl) = taken_qjl {
4746
result.set_qjl(qjl);
4847
}
49-
// Permutation is shared (not per-row), clone unchanged.
50-
result.slots[Slot::Permutation as usize] = array.permutation().cloned();
5148

5249
Ok(Some(result.into_array()))
5350
}

vortex-tensor/src/encodings/turboquant/decompress.rs

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -52,30 +52,14 @@ pub fn execute_decompress(
5252
let centroids_prim = array.centroids().clone().execute::<PrimitiveArray>(ctx)?;
5353
let centroids = centroids_prim.as_slice::<f32>();
5454

55-
// Unpack optional permutation (for non-power-of-2 dims).
56-
let perm: Option<Vec<u16>> = array
57-
.permutation()
58-
.map(|arr| {
59-
let prim = arr.clone().execute::<PrimitiveArray>(ctx)?;
60-
Ok::<_, vortex_error::VortexError>(prim.as_slice::<u16>().to_vec())
61-
})
62-
.transpose()?;
63-
6455
// FastLanes SIMD-unpacks the 1-bit bitpacked rotation signs into u8 0/1 values,
6556
// then we expand to u32 XOR masks once (amortized over all rows). This enables
6657
// branchless XOR-based sign application in the per-row SRHT hot loop.
6758
let signs_prim = array
6859
.rotation_signs()
6960
.clone()
7061
.execute::<PrimitiveArray>(ctx)?;
71-
let rotation = {
72-
let rot = RotationMatrix::from_u8_slice(signs_prim.as_slice::<u8>(), dim)?;
73-
if let Some(ref p) = perm {
74-
rot.with_permutation(p.clone())
75-
} else {
76-
rot
77-
}
78-
};
62+
let rotation = RotationMatrix::from_u8_slice(signs_prim.as_slice::<u8>(), dim)?;
7963

8064
// Unpack codes from FixedSizeListArray → flat u8 elements.
8165
let codes_fsl = array.codes().clone().execute::<FixedSizeListArray>(ctx)?;
@@ -129,14 +113,7 @@ pub fn execute_decompress(
129113
let residual_norms = residual_norms_prim.as_slice::<f32>();
130114

131115
let qjl_rot_signs_prim = qjl.rotation_signs.clone().execute::<PrimitiveArray>(ctx)?;
132-
let qjl_rot = {
133-
let rot = RotationMatrix::from_u8_slice(qjl_rot_signs_prim.as_slice::<u8>(), dim)?;
134-
if let Some(ref p) = perm {
135-
rot.with_permutation(p.clone())
136-
} else {
137-
rot
138-
}
139-
};
116+
let qjl_rot = RotationMatrix::from_u8_slice(qjl_rot_signs_prim.as_slice::<u8>(), dim)?;
140117

141118
let qjl_scale = qjl_correction_scale(padded_dim);
142119
let mse_elements = mse_output.as_ref();

vortex-tensor/src/encodings/turboquant/mod.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -415,10 +415,9 @@ mod tests {
415415

416416
// For power-of-2 dims, QJL bias should be small (< 0.15).
417417
// For non-power-of-2 dims (e.g., 768 padded to 1024), the bias is
418-
// larger due to distributional mismatch: the zero-padded vector has
419-
// fewer effective nonzero terms per SRHT coordinate, changing the
420-
// kurtosis. The pre-SRHT permutation helps with butterfly alignment
421-
// but does not fully resolve this; dimension-aware centroids would.
418+
// inherently larger because the SRHT centroids are optimized for the
419+
// padded dimension's coordinate distribution, which differs from the
420+
// actual distribution of a zero-padded lower-dimensional vector.
422421
let threshold = if dim.is_power_of_two() { 0.15 } else { 0.25 };
423422
assert!(
424423
mean_rel_error.abs() < threshold,

0 commit comments

Comments
 (0)