@@ -3,7 +3,7 @@ use std::io::{Read, Seek, SeekFrom, Write};
33use std:: mem:: size_of;
44
55use byteorder:: { LittleEndian , ReadBytesExt , WriteBytesExt } ;
6- use ndarray:: { Array , Array1 , Array2 , ArrayView1 , ArrayView2 , CowArray , IntoDimension , Ix1 } ;
6+ use ndarray:: { Array , Array1 , Array2 , ArrayView1 , ArrayView2 , Axis , CowArray , IntoDimension , Ix1 } ;
77use rand:: { RngCore , SeedableRng } ;
88use rand_xorshift:: XorShiftRng ;
99use reductive:: pq:: { QuantizeVector , ReconstructVector , TrainPQ , PQ } ;
@@ -257,6 +257,21 @@ impl Storage for QuantizedArray {
257257 CowArray :: from ( reconstructed)
258258 }
259259
260+ fn embeddings ( & self , indices : & [ usize ] ) -> Array2 < f32 > {
261+ let quantized_select = self . quantized_embeddings . select ( Axis ( 0 ) , indices) ;
262+ let mut reconstructed = self . quantizer . reconstruct_batch ( quantized_select) ;
263+
264+ if let Some ( ref norms) = self . norms {
265+ let norms_select = norms. select ( Axis ( 0 ) , indices) ;
266+ reconstructed *= & norms_select
267+ . view ( )
268+ . into_shape ( ( norms_select. len ( ) , 1 ) )
269+ . unwrap ( ) ;
270+ }
271+
272+ reconstructed
273+ }
274+
260275 fn shape ( & self ) -> ( usize , usize ) {
261276 (
262277 self . quantized_embeddings . nrows ( ) ,
@@ -451,7 +466,7 @@ mod mmap {
451466 use std:: io:: { BufReader , Seek , SeekFrom , Write } ;
452467
453468 use memmap:: { Mmap , MmapOptions } ;
454- use ndarray:: { Array1 , ArrayView2 , CowArray , Ix1 } ;
469+ use ndarray:: { Array1 , Array2 , ArrayView2 , Axis , CowArray , Ix1 } ;
455470 use reductive:: pq:: { QuantizeVector , ReconstructVector , PQ } ;
456471
457472 use super :: { PQRead , QuantizedArray , Storage } ;
@@ -530,6 +545,23 @@ mod mmap {
530545 CowArray :: from ( reconstructed)
531546 }
532547
548+ fn embeddings ( & self , indices : & [ usize ] ) -> Array2 < f32 > {
549+ let quantized = unsafe { self . quantized_embeddings ( ) } ;
550+
551+ let quantized_select = quantized. select ( Axis ( 0 ) , indices) ;
552+ let mut reconstructed = self . quantizer . reconstruct_batch ( quantized_select) ;
553+
554+ if let Some ( ref norms) = self . norms {
555+ let norms_select = norms. select ( Axis ( 0 ) , indices) ;
556+ reconstructed *= & norms_select
557+ . view ( )
558+ . into_shape ( ( norms_select. len ( ) , 1 ) )
559+ . unwrap ( ) ;
560+ }
561+
562+ reconstructed
563+ }
564+
533565 fn shape ( & self ) -> ( usize , usize ) {
534566 (
535567 self . quantized_embeddings . len ( ) / self . quantizer . quantized_len ( ) ,
@@ -654,6 +686,19 @@ mod tests {
654686 }
655687 }
656688
689+ #[ test]
690+ fn embeddings_returns_expected_embeddings ( ) {
691+ const CHECK_INDICES : & [ usize ] = & [ 0 , 50 , 99 , 0 ] ;
692+
693+ let check_arr = test_quantized_array ( true ) ;
694+
695+ let embeddings = check_arr. embeddings ( CHECK_INDICES ) ;
696+
697+ for ( embedding, & idx) in embeddings. outer_iter ( ) . zip ( CHECK_INDICES ) {
698+ assert_eq ! ( embedding, check_arr. embedding( idx) ) ;
699+ }
700+ }
701+
657702 #[ test]
658703 fn quantized_array_correct_chunk_size ( ) {
659704 let check_arr = test_quantized_array ( false ) ;
@@ -700,64 +745,96 @@ mod tests {
700745 storage_eq ( & quantized, & reconstructed) ;
701746 }
702747
703- #[ test]
704748 #[ cfg( feature = "memmap" ) ]
705- fn mmap_quantized_array ( ) {
706- use crate :: chunks:: io:: MmapChunk ;
707- use crate :: chunks:: storage:: MmapQuantizedArray ;
708- use std:: fs:: File ;
709- use std:: io:: BufReader ;
710-
711- let mut storage_read =
712- BufReader :: new ( File :: open ( "testdata/quantized_storage.bin" ) . unwrap ( ) ) ;
713- let check_arr = QuantizedArray :: read_chunk ( & mut storage_read) . unwrap ( ) ;
714-
715- // Memory map matrix.
716- storage_read. seek ( SeekFrom :: Start ( 0 ) ) . unwrap ( ) ;
717- let arr = MmapQuantizedArray :: mmap_chunk ( & mut storage_read) . unwrap ( ) ;
718-
719- // Check
720- storage_eq ( & arr, & check_arr) ;
721- }
749+ mod mmap {
750+ use std:: io:: { BufReader , BufWriter , Cursor , Seek , SeekFrom } ;
751+
752+ use tempfile:: tempfile;
753+
754+ use super :: { storage_eq, test_quantized_array} ;
755+ use crate :: chunks:: io:: { MmapChunk , ReadChunk , WriteChunk } ;
756+ use crate :: chunks:: storage:: quantized:: Reconstruct ;
757+ use crate :: chunks:: storage:: { MmapQuantizedArray , QuantizedArray , Storage } ;
758+
759+ fn test_mmap_quantized_array ( norms : bool ) -> MmapQuantizedArray {
760+ let quantized = test_quantized_array ( norms) ;
761+ let mut tmp = tempfile ( ) . unwrap ( ) ;
762+ quantized
763+ . write_chunk ( & mut BufWriter :: new ( & mut tmp) )
764+ . unwrap ( ) ;
765+ tmp. seek ( SeekFrom :: Start ( 0 ) ) . unwrap ( ) ;
766+ MmapQuantizedArray :: mmap_chunk ( & mut BufReader :: new ( tmp) ) . unwrap ( )
767+ }
722768
723- #[ test]
724- #[ cfg( feature = "memmap" ) ]
725- fn reconstruct_mmap_quantized_array ( ) {
726- use std:: fs:: File ;
727- use std:: io:: BufReader ;
769+ #[ test]
770+ fn embeddings_returns_expected_embeddings ( ) {
771+ const CHECK_INDICES : & [ usize ] = & [ 0 , 50 , 99 , 0 ] ;
728772
729- use crate :: chunks:: io:: MmapChunk ;
730- use crate :: chunks:: storage:: MmapQuantizedArray ;
773+ let check_arr = test_mmap_quantized_array ( true ) ;
731774
732- let mut storage_read =
733- BufReader :: new ( File :: open ( "testdata/quantized_storage.bin" ) . unwrap ( ) ) ;
734- let quantized = MmapQuantizedArray :: mmap_chunk ( & mut storage_read) . unwrap ( ) ;
735- let reconstructed = quantized. reconstruct ( ) ;
736- storage_eq ( & quantized, & reconstructed) ;
737- }
775+ let embeddings = check_arr. embeddings ( CHECK_INDICES ) ;
738776
739- #[ test]
740- #[ cfg( feature = "memmap" ) ]
741- fn write_mmap_quantized_array ( ) {
742- use crate :: chunks:: io:: MmapChunk ;
743- use crate :: chunks:: storage:: MmapQuantizedArray ;
744- use std:: fs:: File ;
745- use std:: io:: BufReader ;
746-
747- // Memory map matrix.
748- let mut storage_read =
749- BufReader :: new ( File :: open ( "testdata/quantized_storage.bin" ) . unwrap ( ) ) ;
750- let check_arr = MmapQuantizedArray :: mmap_chunk ( & mut storage_read) . unwrap ( ) ;
751-
752- // Write matrix
753- let mut cursor = Cursor :: new ( Vec :: new ( ) ) ;
754- check_arr. write_chunk ( & mut cursor) . unwrap ( ) ;
777+ for ( embedding, & idx) in embeddings. outer_iter ( ) . zip ( CHECK_INDICES ) {
778+ assert_eq ! ( embedding, check_arr. embedding( idx) ) ;
779+ }
780+ }
755781
756- // Read using non-mmap'ed reader.
757- cursor. seek ( SeekFrom :: Start ( 0 ) ) . unwrap ( ) ;
758- let arr = QuantizedArray :: read_chunk ( & mut cursor) . unwrap ( ) ;
782+ #[ test]
783+ fn reconstruct_mmap_quantized_array ( ) {
784+ use std:: fs:: File ;
785+ use std:: io:: BufReader ;
786+
787+ use crate :: chunks:: io:: MmapChunk ;
788+ use crate :: chunks:: storage:: MmapQuantizedArray ;
759789
760- // Check
761- storage_eq ( & arr, & check_arr) ;
790+ let mut storage_read =
791+ BufReader :: new ( File :: open ( "testdata/quantized_storage.bin" ) . unwrap ( ) ) ;
792+ let quantized = MmapQuantizedArray :: mmap_chunk ( & mut storage_read) . unwrap ( ) ;
793+ let reconstructed = quantized. reconstruct ( ) ;
794+ storage_eq ( & quantized, & reconstructed) ;
795+ }
796+
797+ #[ test]
798+ fn mmap_quantized_array ( ) {
799+ use crate :: chunks:: io:: MmapChunk ;
800+ use crate :: chunks:: storage:: MmapQuantizedArray ;
801+ use std:: fs:: File ;
802+ use std:: io:: BufReader ;
803+
804+ let mut storage_read =
805+ BufReader :: new ( File :: open ( "testdata/quantized_storage.bin" ) . unwrap ( ) ) ;
806+ let check_arr = QuantizedArray :: read_chunk ( & mut storage_read) . unwrap ( ) ;
807+
808+ // Memory map matrix.
809+ storage_read. seek ( SeekFrom :: Start ( 0 ) ) . unwrap ( ) ;
810+ let arr = MmapQuantizedArray :: mmap_chunk ( & mut storage_read) . unwrap ( ) ;
811+
812+ // Check
813+ storage_eq ( & arr, & check_arr) ;
814+ }
815+
816+ #[ test]
817+ fn write_mmap_quantized_array ( ) {
818+ use crate :: chunks:: io:: MmapChunk ;
819+ use crate :: chunks:: storage:: MmapQuantizedArray ;
820+ use std:: fs:: File ;
821+ use std:: io:: BufReader ;
822+
823+ // Memory map matrix.
824+ let mut storage_read =
825+ BufReader :: new ( File :: open ( "testdata/quantized_storage.bin" ) . unwrap ( ) ) ;
826+ let check_arr = MmapQuantizedArray :: mmap_chunk ( & mut storage_read) . unwrap ( ) ;
827+
828+ // Write matrix
829+ let mut cursor = Cursor :: new ( Vec :: new ( ) ) ;
830+ check_arr. write_chunk ( & mut cursor) . unwrap ( ) ;
831+
832+ // Read using non-mmap'ed reader.
833+ cursor. seek ( SeekFrom :: Start ( 0 ) ) . unwrap ( ) ;
834+ let arr = QuantizedArray :: read_chunk ( & mut cursor) . unwrap ( ) ;
835+
836+ // Check
837+ storage_eq ( & arr, & check_arr) ;
838+ }
762839 }
763840}
0 commit comments