Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use marrow::view::{BytesView, PrimitiveView};
use serde::de::Visitor;

use crate::internal::{
error::{set_default, try_, Context, ContextSupport, Error, ErrorKind, Result},
error::{set_default, try_, Context, ContextSupport, Result},
utils::{array_view_ext::ViewAccess, Offset},
};

use super::{
enums_as_string_impl::EnumAccess, integer_deserializer::Integer,
random_access_deserializer::RandomAccessDeserializer, utils::bitset_all_set,
random_access_deserializer::RandomAccessDeserializer,
};

pub struct DictionaryDeserializer<'a, K: Integer, V: Offset> {
Expand All @@ -19,15 +19,6 @@ pub struct DictionaryDeserializer<'a, K: Integer, V: Offset> {

impl<'a, K: Integer, V: Offset> DictionaryDeserializer<'a, K, V> {
pub fn new(path: String, keys: PrimitiveView<'a, K>, values: BytesView<'a, V>) -> Result<Self> {
if let Some(validity) = &values.validity {
if !bitset_all_set(validity, values.offsets.len() - 1)? {
return Err(Error::new(
ErrorKind::NullabilityViolation { field: None },
"Null for non-nullable type: dictionaries do not support nullable values"
.into(),
));
}
}
Ok(Self {
path,
keys: keys.clone(),
Expand All @@ -53,7 +44,11 @@ impl<'de, K: Integer, V: Offset> RandomAccessDeserializer<'de>
for DictionaryDeserializer<'de, K, V>
{
fn is_some(&self, idx: usize) -> Result<bool> {
self.keys.is_some(idx)
if !self.keys.is_some(idx)? {
return Ok(false);
}
let key: usize = self.keys.get_required(idx)?.into_i64()?.try_into()?;
ViewAccess::<[u8]>::is_some(&self.values, key)
}

fn deserialize_any_some<VV: Visitor<'de>>(&self, visitor: VV, idx: usize) -> Result<VV::Value> {
Expand Down
6 changes: 1 addition & 5 deletions serde_arrow/src/internal/deserialization/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@ use serde::{

use crate::internal::{
error::{fail, Error, Result},
utils::array_ext::{all_set_buffer, get_bit_buffer},
utils::array_ext::get_bit_buffer,
};

/// Query a single bit of `set`.
///
/// Forwards the bitset's backing bytes and its bit offset, together with
/// `idx`, to `get_bit_buffer` and returns that result unchanged.
pub fn bitset_is_set(set: &BitsWithOffset<'_>, idx: usize) -> Result<bool> {
    let (bytes, bit_offset) = (set.data, set.offset);
    get_bit_buffer(bytes, bit_offset, idx)
}

/// Check whether the `len` bits of `set` starting at its offset are all set.
///
/// The bit range handed to `all_set_buffer` is `set.offset..set.offset + len`.
pub fn bitset_all_set(set: &BitsWithOffset<'_>, len: usize) -> Result<bool> {
    let first_bit = set.offset;
    let past_last_bit = first_bit + len;
    all_set_buffer(set.data, first_bit, past_last_bit)
}

pub struct U8Deserializer(pub u8);

macro_rules! unimplemented {
Expand Down
75 changes: 0 additions & 75 deletions serde_arrow/src/internal/utils/array_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,38 +534,6 @@ pub fn get_bit_buffer(data: &[u8], offset: usize, idx: usize) -> Result<bool> {
Ok(byte & flag == flag)
}

/// True if all bits in the `start_bit..end_bit` range are set.
///
/// Bit positions index into `data` as a whole (bit `8 * i + j` lives in
/// `data[i]`). An empty range, including `start_bit >= end_bit`, yields
/// `Ok(true)` because there is no unset bit to find.
///
/// # Errors
///
/// Fails with "Invalid access in bitset" if `end_bit` lies beyond the last bit
/// of `data`.
pub fn all_set_buffer(data: &[u8], start_bit: usize, end_bit: usize) -> Result<bool> {
    if end_bit > data.len() * 8 {
        fail!("Invalid access in bitset");
    }

    let mut current = start_bit;

    // Leading bits: walk bit by bit until the next byte boundary is reached.
    while current < end_bit && (current % 8) != 0 {
        if !get_bit_buffer(data, 0, current)? {
            return Ok(false);
        }
        current += 1;
    }

    // Whole bytes: `current` is byte aligned here whenever bits remain. The
    // comparison is inclusive so that a final, fully covered byte is checked
    // in a single step instead of falling through to the per-bit loop below.
    // The subtraction cannot underflow because `current < end_bit` is checked
    // first.
    while current < end_bit && end_bit - current >= 8 {
        if data[current / 8] != 0xFF {
            return Ok(false);
        }
        current += 8;
    }

    // Trailing bits: fewer than eight bits are left at this point.
    while current < end_bit {
        if !get_bit_buffer(data, 0, current)? {
            return Ok(false);
        }
        current += 1;
    }

    Ok(true)
}

#[test]
fn test_set_bit_buffer() {
let mut buffer = vec![];
Expand All @@ -587,46 +555,3 @@ fn test_set_bit_buffer() {
set_bit_buffer(&mut buffer, 4, false);
assert_eq!(buffer, vec![0b_0010_0001, 0b_0000_0000, 0b_0000_0100]);
}

// Exercises `all_set_buffer` on hand-written bit patterns. The expectations
// below treat the least significant bit of each byte as its first bit.
#[test]
fn test_all_set_buffer() {
// Single byte: ranges touching only the first or only the last bit.
assert!(all_set_buffer(&[0b_0000_0001], 0, 1).unwrap());
assert!(!all_set_buffer(&[0b_0000_0001], 0, 2).unwrap());
assert!(all_set_buffer(&[0b_1000_0000], 7, 8).unwrap());
assert!(!all_set_buffer(&[0b_1000_0000], 6, 8).unwrap());

// Single byte: full-byte range, and ranges around one cleared bit (bit 4).
assert!(all_set_buffer(&[0b_1111_1111], 0, 8).unwrap());
assert!(!all_set_buffer(&[0b_1110_1111], 0, 8).unwrap());
assert!(all_set_buffer(&[0b_1110_1111], 0, 4).unwrap());
assert!(all_set_buffer(&[0b_1110_1111], 5, 8).unwrap());
assert!(!all_set_buffer(&[0b_1110_1111], 4, 5).unwrap());

// The same patterns placed in the second byte, behind an all-zero first byte.
assert!(all_set_buffer(&[0, 0b_1111_1111], 8, 16).unwrap());
assert!(!all_set_buffer(&[0, 0b_1110_1111], 8, 16).unwrap());
assert!(all_set_buffer(&[0, 0b_1110_1111], 8, 12).unwrap());
assert!(all_set_buffer(&[0, 0b_1110_1111], 13, 16).unwrap());
assert!(!all_set_buffer(&[0, 0b_1110_1111], 12, 13).unwrap());

// Three-byte buffers: ranges covering one or two bytes, with cleared bits at
// the start, in the middle and at the end of the range.
assert!(all_set_buffer(&[0, 0b_1111_1111, 0], 8, 16).unwrap());
assert!(all_set_buffer(&[0, 0b_1111_1111, 0b_1111_1111], 8, 24).unwrap());
assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_1111_1111], 7, 24).unwrap());
assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_0111_1111], 8, 24).unwrap());
assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_1111_1110], 8, 24).unwrap());
assert!(!all_set_buffer(&[0, 0b_1111_0111, 0b_1111_1111], 8, 24).unwrap());
assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_0111_1111], 8, 24).unwrap());
assert!(all_set_buffer(&[0, 0b_1111_1111, 0b_0111_1111], 8, 23).unwrap());

// One-bit ranges that do not start on a byte boundary.
assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_0111_1111], 23, 24).unwrap());
assert!(all_set_buffer(&[0, 0b_1111_1111, 0b_0111_1111], 22, 23).unwrap());

// Four-byte buffers: ranges spanning up to three bytes.
assert!(all_set_buffer(&[0, 0b_1111_1111, 0b_1111_1111, 0b_1111_1111], 8, 32).unwrap());
assert!(!all_set_buffer(&[0, 0b_0111_1111, 0b_1111_1111, 0b_1111_1111], 8, 32).unwrap());
assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_1111_1111, 0b_1111_1110], 8, 32).unwrap());
assert!(all_set_buffer(&[0, 0b_1111_1111, 0b_1111_1111, 0b_1111_1110], 8, 24).unwrap());
assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_1111_1111, 0b_0111_1111], 8, 32).unwrap());
assert!(all_set_buffer(&[0, 0b_1111_1111, 0b_1111_1111, 0b_0111_1111], 8, 31).unwrap());
assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_1110_1111, 0b_1111_1111], 8, 32).unwrap());
assert!(all_set_buffer(&[0, 0b_1111_1111, 0b_1110_1111, 0b_1111_1111], 8, 20).unwrap());
assert!(all_set_buffer(&[0, 0b_1111_1111, 0b_1110_1111, 0b_1111_1111], 21, 32).unwrap());
assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_1110_1111, 0b_1111_1111], 20, 21).unwrap());
}
44 changes: 40 additions & 4 deletions serde_arrow/src/test_with_arrow/impls/arrow_dictionary.rs
Comment thread
chmp marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,21 @@ mod construction {
values: Box::new(to_array(DataType::Utf8, true, [None::<&str>, None])),
});

assert!(ArrayDeserializer::new(String::from("$"), None, array.as_view()).is_err());
let deserializer =
ArrayDeserializer::new(String::from("$"), None, array.as_view()).unwrap();

assert_eq!(
Option::<String>::deserialize(deserializer.at(0)).unwrap(),
None
);
assert_eq!(
Option::<String>::deserialize(deserializer.at(1)).unwrap(),
None
);
assert_eq!(
Option::<String>::deserialize(deserializer.at(2)).unwrap(),
None
);
}

#[test]
Expand All @@ -196,11 +210,25 @@ mod construction {
)),
});

assert!(ArrayDeserializer::new(String::from("$"), None, array.as_view()).is_err());
let deserializer =
ArrayDeserializer::new(String::from("$"), None, array.as_view()).unwrap();

assert_eq!(
Option::<String>::deserialize(deserializer.at(0)).unwrap(),
None
);
assert_eq!(
Option::<String>::deserialize(deserializer.at(1)).unwrap(),
None
);
assert_eq!(
Option::<String>::deserialize(deserializer.at(2)).unwrap(),
Some(String::from("foo")),
);
}

#[test]
fn some_null_values_v2() {
fn some_null_values_out_of_range() {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this test is really required anymore. It seems to be testing that if the dictionary contains a null but isn't used, it still errors out.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. Not sure what I was thinking here. I would probably remove the test.

let array = Array::Dictionary(DictionaryArray {
keys: Box::new(to_array(DataType::Int8, false, [1, 1, 0])),
values: Box::new(to_array(
Expand All @@ -224,6 +252,14 @@ mod construction {
)),
});

assert!(ArrayDeserializer::new(String::from("$"), None, array.as_view()).is_err());
let deserializer =
ArrayDeserializer::new(String::from("$"), None, array.as_view()).unwrap();

for i in 0..3 {
assert_eq!(
Option::<String>::deserialize(deserializer.at(i)).unwrap(),
Some(String::from("1")),
);
}
}
}
Loading