Skip to content

Commit c180cdf

Browse files
danieldk (Daniël de Kok)
authored and committed
Storage: allow lookup of multiple embeddings at the same time
1 parent c056986 commit c180cdf

5 files changed

Lines changed: 233 additions & 62 deletions

File tree

Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,10 @@ default = ["memmap"]
3939

4040
[dev-dependencies]
4141
approx = "0.3"
42-
maplit = "1"
43-
lazy_static = "1"
4442
criterion = "0.3"
43+
lazy_static = "1"
44+
maplit = "1"
45+
tempfile = "3"
4546

4647
[[bench]]
4748
name = "array"

src/chunks/storage/array.rs

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::io::{Read, Seek, SeekFrom, Write};
33
use std::mem::size_of;
44

55
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
6-
use ndarray::{Array2, ArrayView2, ArrayViewMut2, CowArray, Ix1};
6+
use ndarray::{Array2, ArrayView2, ArrayViewMut2, Axis, CowArray, Ix1};
77

88
use super::{Storage, StorageView, StorageViewMut};
99
use crate::chunks::io::{ChunkIdentifier, ReadChunk, TypeId, WriteChunk};
@@ -24,7 +24,7 @@ mod mmap {
2424
use byteorder::ByteOrder;
2525
use byteorder::{LittleEndian, ReadBytesExt};
2626
use memmap::{Mmap, MmapOptions};
27-
use ndarray::{ArrayView2, CowArray, Ix1};
27+
use ndarray::{Array2, ArrayView2, Axis, CowArray, Ix1};
2828
use ndarray::{Dimension, Ix2};
2929

3030
#[cfg(target_endian = "little")]
@@ -65,6 +65,27 @@ mod mmap {
6565
CowArray::from(embedding)
6666
}
6767

68+
#[allow(clippy::let_and_return)]
69+
fn embeddings(&self, indices: &[usize]) -> Array2<f32> {
70+
#[allow(clippy::cast_ptr_alignment,unused_mut)]
71+
let embeddings =
72+
// Alignment is ok, padding guarantees that the pointer is at
73+
// a multiple of 4.
74+
unsafe { ArrayView2::from_shape_ptr(self.shape, self.map.as_ptr() as *const f32) };
75+
76+
#[allow(unused_mut)]
77+
let mut selected_embeddings = embeddings.select(Axis(0), indices);
78+
79+
#[cfg(target_endian = "big")]
80+
LittleEndian::from_slice_f32(
81+
selected_embeddings
82+
.as_slice_mut()
83+
.expect("Cannot borrow matrix as mutable slice"),
84+
);
85+
86+
selected_embeddings
87+
}
88+
6889
fn shape(&self) -> (usize, usize) {
6990
self.shape.into_pattern()
7091
}
@@ -245,6 +266,10 @@ impl Storage for NdArray {
245266
CowArray::from(self.inner.row(idx))
246267
}
247268

269+
fn embeddings(&self, indices: &[usize]) -> Array2<f32> {
270+
self.inner.select(Axis(0), indices)
271+
}
272+
248273
fn shape(&self) -> (usize, usize) {
249274
self.inner.dim()
250275
}
@@ -321,7 +346,7 @@ mod tests {
321346
use ndarray::Array2;
322347

323348
use crate::chunks::io::{ReadChunk, WriteChunk};
324-
use crate::chunks::storage::{NdArray, StorageView};
349+
use crate::chunks::storage::{NdArray, Storage, StorageView};
325350

326351
const N_ROWS: usize = 100;
327352
const N_COLS: usize = 100;
@@ -342,6 +367,19 @@ mod tests {
342367
read.read_u64::<LittleEndian>().unwrap()
343368
}
344369

370+
#[test]
371+
fn embeddings_returns_expected_embeddings() {
372+
const CHECK_INDICES: &[usize] = &[0, 50, 99, 0];
373+
374+
let check_arr = test_ndarray();
375+
376+
let embeddings = check_arr.embeddings(CHECK_INDICES);
377+
378+
for (embedding, &idx) in embeddings.outer_iter().zip(CHECK_INDICES) {
379+
assert_eq!(embedding, check_arr.embedding(idx));
380+
}
381+
}
382+
345383
#[test]
346384
fn ndarray_correct_chunk_size() {
347385
let check_arr = test_ndarray();
@@ -365,4 +403,36 @@ mod tests {
365403
let arr = NdArray::read_chunk(&mut cursor).unwrap();
366404
assert_eq!(arr.view(), check_arr.view());
367405
}
406+
407+
#[cfg(feature = "memmap")]
408+
mod mmap {
409+
use std::io::{BufReader, BufWriter, Seek, SeekFrom};
410+
411+
use tempfile::tempfile;
412+
413+
use super::test_ndarray;
414+
use crate::chunks::io::{MmapChunk, WriteChunk};
415+
use crate::chunks::storage::{MmapArray, Storage};
416+
417+
fn test_mmap_array() -> MmapArray {
418+
let array = test_ndarray();
419+
let mut tmp = tempfile().unwrap();
420+
array.write_chunk(&mut BufWriter::new(&mut tmp)).unwrap();
421+
tmp.seek(SeekFrom::Start(0)).unwrap();
422+
MmapArray::mmap_chunk(&mut BufReader::new(tmp)).unwrap()
423+
}
424+
425+
#[test]
426+
fn embeddings_returns_expected_embeddings() {
427+
const CHECK_INDICES: &[usize] = &[0, 50, 99, 0];
428+
429+
let check_arr = test_mmap_array();
430+
431+
let embeddings = check_arr.embeddings(CHECK_INDICES);
432+
433+
for (embedding, &idx) in embeddings.outer_iter().zip(CHECK_INDICES) {
434+
assert_eq!(embedding, check_arr.embedding(idx));
435+
}
436+
}
437+
}
368438
}

src/chunks/storage/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Embedding matrix representations.
22
3-
use ndarray::{ArrayView2, ArrayViewMut2, CowArray, Ix1};
3+
use ndarray::{Array2, ArrayView2, ArrayViewMut2, CowArray, Ix1};
44

55
mod array;
66
#[cfg(feature = "memmap")]
@@ -23,6 +23,9 @@ pub use self::wrappers::{StorageViewWrap, StorageWrap};
2323
pub trait Storage {
2424
fn embedding(&self, idx: usize) -> CowArray<f32, Ix1>;
2525

26+
/// Retrieve multiple embeddings.
27+
fn embeddings(&self, indices: &[usize]) -> Array2<f32>;
28+
2629
fn shape(&self) -> (usize, usize);
2730
}
2831

src/chunks/storage/quantized.rs

Lines changed: 131 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::io::{Read, Seek, SeekFrom, Write};
33
use std::mem::size_of;
44

55
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
6-
use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, CowArray, IntoDimension, Ix1};
6+
use ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, Axis, CowArray, IntoDimension, Ix1};
77
use rand::{RngCore, SeedableRng};
88
use rand_xorshift::XorShiftRng;
99
use reductive::pq::{QuantizeVector, ReconstructVector, TrainPQ, PQ};
@@ -257,6 +257,21 @@ impl Storage for QuantizedArray {
257257
CowArray::from(reconstructed)
258258
}
259259

260+
fn embeddings(&self, indices: &[usize]) -> Array2<f32> {
261+
let quantized_select = self.quantized_embeddings.select(Axis(0), indices);
262+
let mut reconstructed = self.quantizer.reconstruct_batch(quantized_select);
263+
264+
if let Some(ref norms) = self.norms {
265+
let norms_select = norms.select(Axis(0), indices);
266+
reconstructed *= &norms_select
267+
.view()
268+
.into_shape((norms_select.len(), 1))
269+
.unwrap();
270+
}
271+
272+
reconstructed
273+
}
274+
260275
fn shape(&self) -> (usize, usize) {
261276
(
262277
self.quantized_embeddings.nrows(),
@@ -451,7 +466,7 @@ mod mmap {
451466
use std::io::{BufReader, Seek, SeekFrom, Write};
452467

453468
use memmap::{Mmap, MmapOptions};
454-
use ndarray::{Array1, ArrayView2, CowArray, Ix1};
469+
use ndarray::{Array1, Array2, ArrayView2, Axis, CowArray, Ix1};
455470
use reductive::pq::{QuantizeVector, ReconstructVector, PQ};
456471

457472
use super::{PQRead, QuantizedArray, Storage};
@@ -530,6 +545,23 @@ mod mmap {
530545
CowArray::from(reconstructed)
531546
}
532547

548+
fn embeddings(&self, indices: &[usize]) -> Array2<f32> {
549+
let quantized = unsafe { self.quantized_embeddings() };
550+
551+
let quantized_select = quantized.select(Axis(0), indices);
552+
let mut reconstructed = self.quantizer.reconstruct_batch(quantized_select);
553+
554+
if let Some(ref norms) = self.norms {
555+
let norms_select = norms.select(Axis(0), indices);
556+
reconstructed *= &norms_select
557+
.view()
558+
.into_shape((norms_select.len(), 1))
559+
.unwrap();
560+
}
561+
562+
reconstructed
563+
}
564+
533565
fn shape(&self) -> (usize, usize) {
534566
(
535567
self.quantized_embeddings.len() / self.quantizer.quantized_len(),
@@ -654,6 +686,19 @@ mod tests {
654686
}
655687
}
656688

689+
#[test]
690+
fn embeddings_returns_expected_embeddings() {
691+
const CHECK_INDICES: &[usize] = &[0, 50, 99, 0];
692+
693+
let check_arr = test_quantized_array(true);
694+
695+
let embeddings = check_arr.embeddings(CHECK_INDICES);
696+
697+
for (embedding, &idx) in embeddings.outer_iter().zip(CHECK_INDICES) {
698+
assert_eq!(embedding, check_arr.embedding(idx));
699+
}
700+
}
701+
657702
#[test]
658703
fn quantized_array_correct_chunk_size() {
659704
let check_arr = test_quantized_array(false);
@@ -700,64 +745,96 @@ mod tests {
700745
storage_eq(&quantized, &reconstructed);
701746
}
702747

703-
#[test]
704748
#[cfg(feature = "memmap")]
705-
fn mmap_quantized_array() {
706-
use crate::chunks::io::MmapChunk;
707-
use crate::chunks::storage::MmapQuantizedArray;
708-
use std::fs::File;
709-
use std::io::BufReader;
710-
711-
let mut storage_read =
712-
BufReader::new(File::open("testdata/quantized_storage.bin").unwrap());
713-
let check_arr = QuantizedArray::read_chunk(&mut storage_read).unwrap();
714-
715-
// Memory map matrix.
716-
storage_read.seek(SeekFrom::Start(0)).unwrap();
717-
let arr = MmapQuantizedArray::mmap_chunk(&mut storage_read).unwrap();
718-
719-
// Check
720-
storage_eq(&arr, &check_arr);
721-
}
749+
mod mmap {
750+
use std::io::{BufReader, BufWriter, Cursor, Seek, SeekFrom};
751+
752+
use tempfile::tempfile;
753+
754+
use super::{storage_eq, test_quantized_array};
755+
use crate::chunks::io::{MmapChunk, ReadChunk, WriteChunk};
756+
use crate::chunks::storage::quantized::Reconstruct;
757+
use crate::chunks::storage::{MmapQuantizedArray, QuantizedArray, Storage};
758+
759+
fn test_mmap_quantized_array(norms: bool) -> MmapQuantizedArray {
760+
let quantized = test_quantized_array(norms);
761+
let mut tmp = tempfile().unwrap();
762+
quantized
763+
.write_chunk(&mut BufWriter::new(&mut tmp))
764+
.unwrap();
765+
tmp.seek(SeekFrom::Start(0)).unwrap();
766+
MmapQuantizedArray::mmap_chunk(&mut BufReader::new(tmp)).unwrap()
767+
}
722768

723-
#[test]
724-
#[cfg(feature = "memmap")]
725-
fn reconstruct_mmap_quantized_array() {
726-
use std::fs::File;
727-
use std::io::BufReader;
769+
#[test]
770+
fn embeddings_returns_expected_embeddings() {
771+
const CHECK_INDICES: &[usize] = &[0, 50, 99, 0];
728772

729-
use crate::chunks::io::MmapChunk;
730-
use crate::chunks::storage::MmapQuantizedArray;
773+
let check_arr = test_mmap_quantized_array(true);
731774

732-
let mut storage_read =
733-
BufReader::new(File::open("testdata/quantized_storage.bin").unwrap());
734-
let quantized = MmapQuantizedArray::mmap_chunk(&mut storage_read).unwrap();
735-
let reconstructed = quantized.reconstruct();
736-
storage_eq(&quantized, &reconstructed);
737-
}
775+
let embeddings = check_arr.embeddings(CHECK_INDICES);
738776

739-
#[test]
740-
#[cfg(feature = "memmap")]
741-
fn write_mmap_quantized_array() {
742-
use crate::chunks::io::MmapChunk;
743-
use crate::chunks::storage::MmapQuantizedArray;
744-
use std::fs::File;
745-
use std::io::BufReader;
746-
747-
// Memory map matrix.
748-
let mut storage_read =
749-
BufReader::new(File::open("testdata/quantized_storage.bin").unwrap());
750-
let check_arr = MmapQuantizedArray::mmap_chunk(&mut storage_read).unwrap();
751-
752-
// Write matrix
753-
let mut cursor = Cursor::new(Vec::new());
754-
check_arr.write_chunk(&mut cursor).unwrap();
777+
for (embedding, &idx) in embeddings.outer_iter().zip(CHECK_INDICES) {
778+
assert_eq!(embedding, check_arr.embedding(idx));
779+
}
780+
}
755781

756-
// Read using non-mmap'ed reader.
757-
cursor.seek(SeekFrom::Start(0)).unwrap();
758-
let arr = QuantizedArray::read_chunk(&mut cursor).unwrap();
782+
#[test]
783+
fn reconstruct_mmap_quantized_array() {
784+
use std::fs::File;
785+
use std::io::BufReader;
786+
787+
use crate::chunks::io::MmapChunk;
788+
use crate::chunks::storage::MmapQuantizedArray;
759789

760-
// Check
761-
storage_eq(&arr, &check_arr);
790+
let mut storage_read =
791+
BufReader::new(File::open("testdata/quantized_storage.bin").unwrap());
792+
let quantized = MmapQuantizedArray::mmap_chunk(&mut storage_read).unwrap();
793+
let reconstructed = quantized.reconstruct();
794+
storage_eq(&quantized, &reconstructed);
795+
}
796+
797+
#[test]
798+
fn mmap_quantized_array() {
799+
use crate::chunks::io::MmapChunk;
800+
use crate::chunks::storage::MmapQuantizedArray;
801+
use std::fs::File;
802+
use std::io::BufReader;
803+
804+
let mut storage_read =
805+
BufReader::new(File::open("testdata/quantized_storage.bin").unwrap());
806+
let check_arr = QuantizedArray::read_chunk(&mut storage_read).unwrap();
807+
808+
// Memory map matrix.
809+
storage_read.seek(SeekFrom::Start(0)).unwrap();
810+
let arr = MmapQuantizedArray::mmap_chunk(&mut storage_read).unwrap();
811+
812+
// Check
813+
storage_eq(&arr, &check_arr);
814+
}
815+
816+
#[test]
817+
fn write_mmap_quantized_array() {
818+
use crate::chunks::io::MmapChunk;
819+
use crate::chunks::storage::MmapQuantizedArray;
820+
use std::fs::File;
821+
use std::io::BufReader;
822+
823+
// Memory map matrix.
824+
let mut storage_read =
825+
BufReader::new(File::open("testdata/quantized_storage.bin").unwrap());
826+
let check_arr = MmapQuantizedArray::mmap_chunk(&mut storage_read).unwrap();
827+
828+
// Write matrix
829+
let mut cursor = Cursor::new(Vec::new());
830+
check_arr.write_chunk(&mut cursor).unwrap();
831+
832+
// Read using non-mmap'ed reader.
833+
cursor.seek(SeekFrom::Start(0)).unwrap();
834+
let arr = QuantizedArray::read_chunk(&mut cursor).unwrap();
835+
836+
// Check
837+
storage_eq(&arr, &check_arr);
838+
}
762839
}
763840
}

0 commit comments

Comments
 (0)