@@ -36,7 +36,7 @@ static CENTROID_CACHE: LazyLock<DashMap<(u32, u8), Buffer<f32>>> = LazyLock::new
3636/// Returns `2^bit_width` centroids sorted in ascending order, representing optimal scalar
3737/// quantization levels for the coordinate distribution after random rotation in
3838/// `dimension`-dimensional space.
39- pub fn get_centroids ( dimension : u32 , bit_width : u8 ) -> VortexResult < Buffer < f32 > > {
39+ pub fn compute_or_get_centroids ( dimension : u32 , bit_width : u8 ) -> VortexResult < Buffer < f32 > > {
4040 vortex_ensure ! (
4141 ( 1 ..=MAX_BIT_WIDTH ) . contains( & bit_width) ,
4242 "TurboQuant bit_width must be 1-{}, got {bit_width}" ,
@@ -239,7 +239,7 @@ mod tests {
239239 #[ case] bits : u8 ,
240240 #[ case] expected : usize ,
241241 ) -> VortexResult < ( ) > {
242- let centroids = get_centroids ( dim, bits) ?;
242+ let centroids = compute_or_get_centroids ( dim, bits) ?;
243243 assert_eq ! ( centroids. len( ) , expected) ;
244244 Ok ( ( ) )
245245 }
@@ -251,7 +251,7 @@ mod tests {
251251 #[ case( 128 , 4 ) ]
252252 #[ case( 768 , 2 ) ]
253253 fn centroids_are_sorted ( #[ case] dim : u32 , #[ case] bits : u8 ) -> VortexResult < ( ) > {
254- let centroids = get_centroids ( dim, bits) ?;
254+ let centroids = compute_or_get_centroids ( dim, bits) ?;
255255 for window in centroids. windows ( 2 ) {
256256 assert ! (
257257 window[ 0 ] < window[ 1 ] ,
@@ -268,7 +268,7 @@ mod tests {
268268 #[ case( 256 , 2 ) ]
269269 #[ case( 768 , 2 ) ]
270270 fn centroids_are_symmetric ( #[ case] dim : u32 , #[ case] bits : u8 ) -> VortexResult < ( ) > {
271- let centroids = get_centroids ( dim, bits) ?;
271+ let centroids = compute_or_get_centroids ( dim, bits) ?;
272272 let count = centroids. len ( ) ;
273273 for idx in 0 ..count / 2 {
274274 let diff = ( centroids[ idx] + centroids[ count - 1 - idx] ) . abs ( ) ;
@@ -287,7 +287,7 @@ mod tests {
287287 #[ case( 128 , 1 ) ]
288288 #[ case( 128 , 4 ) ]
289289 fn centroids_within_bounds ( #[ case] dim : u32 , #[ case] bits : u8 ) -> VortexResult < ( ) > {
290- let centroids = get_centroids ( dim, bits) ?;
290+ let centroids = compute_or_get_centroids ( dim, bits) ?;
291291 for & val in centroids. iter ( ) {
292292 assert ! (
293293 ( -1.0 ..=1.0 ) . contains( & val) ,
@@ -299,15 +299,15 @@ mod tests {
299299
300300 #[ test]
301301 fn centroids_cached ( ) -> VortexResult < ( ) > {
302- let c1 = get_centroids ( 128 , 2 ) ?;
303- let c2 = get_centroids ( 128 , 2 ) ?;
302+ let c1 = compute_or_get_centroids ( 128 , 2 ) ?;
303+ let c2 = compute_or_get_centroids ( 128 , 2 ) ?;
304304 assert_eq ! ( c1, c2) ;
305305 Ok ( ( ) )
306306 }
307307
308308 #[ test]
309309 fn find_nearest_basic ( ) -> VortexResult < ( ) > {
310- let centroids = get_centroids ( 128 , 2 ) ?;
310+ let centroids = compute_or_get_centroids ( 128 , 2 ) ?;
311311 let boundaries = compute_centroid_boundaries ( & centroids) ;
312312 assert_eq ! ( find_nearest_centroid( -1.0 , & boundaries) , 0 ) ;
313313
@@ -324,9 +324,9 @@ mod tests {
324324
325325 #[ test]
326326 fn rejects_invalid_params ( ) {
327- assert ! ( get_centroids ( 128 , 0 ) . is_err( ) ) ;
328- assert ! ( get_centroids ( 128 , 9 ) . is_err( ) ) ;
329- assert ! ( get_centroids ( 1 , 2 ) . is_err( ) ) ;
330- assert ! ( get_centroids ( 127 , 2 ) . is_err( ) ) ;
327+ assert ! ( compute_or_get_centroids ( 128 , 0 ) . is_err( ) ) ;
328+ assert ! ( compute_or_get_centroids ( 128 , 9 ) . is_err( ) ) ;
329+ assert ! ( compute_or_get_centroids ( 1 , 2 ) . is_err( ) ) ;
330+ assert ! ( compute_or_get_centroids ( 127 , 2 ) . is_err( ) ) ;
331331 }
332332}
0 commit comments