diff --git a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs index 820e2000..3c6158b1 100644 --- a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs @@ -2,13 +2,13 @@ use marrow::view::{BytesView, PrimitiveView}; use serde::de::Visitor; use crate::internal::{ - error::{set_default, try_, Context, ContextSupport, Error, ErrorKind, Result}, + error::{set_default, try_, Context, ContextSupport, Result}, utils::{array_view_ext::ViewAccess, Offset}, }; use super::{ enums_as_string_impl::EnumAccess, integer_deserializer::Integer, - random_access_deserializer::RandomAccessDeserializer, utils::bitset_all_set, + random_access_deserializer::RandomAccessDeserializer, }; pub struct DictionaryDeserializer<'a, K: Integer, V: Offset> { @@ -19,15 +19,6 @@ pub struct DictionaryDeserializer<'a, K: Integer, V: Offset> { impl<'a, K: Integer, V: Offset> DictionaryDeserializer<'a, K, V> { pub fn new(path: String, keys: PrimitiveView<'a, K>, values: BytesView<'a, V>) -> Result { - if let Some(validity) = &values.validity { - if !bitset_all_set(validity, values.offsets.len() - 1)? { - return Err(Error::new( - ErrorKind::NullabilityViolation { field: None }, - "Null for non-nullable type: dictionaries do not support nullable values" - .into(), - )); - } - } Ok(Self { path, keys: keys.clone(), @@ -53,7 +44,11 @@ impl<'de, K: Integer, V: Offset> RandomAccessDeserializer<'de> for DictionaryDeserializer<'de, K, V> { fn is_some(&self, idx: usize) -> Result { - self.keys.is_some(idx) + if !self.keys.is_some(idx)? { + return Ok(false); + } + let key: usize = self.keys.get_required(idx)?.into_i64()?.try_into()?; + ViewAccess::<[u8]>::is_some(&self.values, key) } fn deserialize_any_some>(&self, visitor: VV, idx: usize) -> Result { diff --git a/serde_arrow/src/internal/deserialization/utils.rs b/serde_arrow/src/internal/deserialization/utils.rs index 0feb4b0e..7e00a1fa 100644 --- a/serde_arrow/src/internal/deserialization/utils.rs +++ b/serde_arrow/src/internal/deserialization/utils.rs @@ -6,17 +6,13 @@ use serde::{ use crate::internal::{ error::{fail, Error, Result}, - utils::array_ext::{all_set_buffer, get_bit_buffer}, + utils::array_ext::get_bit_buffer, }; pub fn bitset_is_set(set: &BitsWithOffset<'_>, idx: usize) -> Result { get_bit_buffer(set.data, set.offset, idx) } -pub fn bitset_all_set(set: &BitsWithOffset<'_>, len: usize) -> Result { - all_set_buffer(set.data, set.offset, set.offset + len) -} - pub struct U8Deserializer(pub u8); macro_rules! unimplemented { diff --git a/serde_arrow/src/internal/utils/array_ext.rs b/serde_arrow/src/internal/utils/array_ext.rs index 85add671..89893b9b 100644 --- a/serde_arrow/src/internal/utils/array_ext.rs +++ b/serde_arrow/src/internal/utils/array_ext.rs @@ -534,38 +534,6 @@ pub fn get_bit_buffer(data: &[u8], offset: usize, idx: usize) -> Result { Ok(byte & flag == flag) } -/// True if all bits in the `start_bit..end_bit` range are set. -pub fn all_set_buffer(data: &[u8], start_bit: usize, end_bit: usize) -> Result { - if end_bit > data.len() * 8 { - fail!("Invalid access in bitset"); - } - - let mut current = start_bit; - - while current < end_bit && (current % 8) != 0 { - if !get_bit_buffer(data, 0, current)? { - return Ok(false); - } - current += 1; - } - - while current.saturating_add(8) < end_bit { - if data[current / 8] != 0xFF { - return Ok(false); - } - current += 8; - } - - while current < end_bit { - if !get_bit_buffer(data, 0, current)? { - return Ok(false); - } - current += 1; - } - - Ok(true) -} - #[test] fn test_set_bit_buffer() { let mut buffer = vec![]; @@ -587,46 +555,3 @@ fn test_set_bit_buffer() { set_bit_buffer(&mut buffer, 4, false); assert_eq!(buffer, vec![0b_0010_0001, 0b_0000_0000, 0b_0000_0100]); } - -#[test] -fn test_all_set_buffer() { - assert!(all_set_buffer(&[0b_0000_0001], 0, 1).unwrap()); - assert!(!all_set_buffer(&[0b_0000_0001], 0, 2).unwrap()); - assert!(all_set_buffer(&[0b_1000_0000], 7, 8).unwrap()); - assert!(!all_set_buffer(&[0b_1000_0000], 6, 8).unwrap()); - - assert!(all_set_buffer(&[0b_1111_1111], 0, 8).unwrap()); - assert!(!all_set_buffer(&[0b_1110_1111], 0, 8).unwrap()); - assert!(all_set_buffer(&[0b_1110_1111], 0, 4).unwrap()); - assert!(all_set_buffer(&[0b_1110_1111], 5, 8).unwrap()); - assert!(!all_set_buffer(&[0b_1110_1111], 4, 5).unwrap()); - - assert!(all_set_buffer(&[0, 0b_1111_1111], 8, 16).unwrap()); - assert!(!all_set_buffer(&[0, 0b_1110_1111], 8, 16).unwrap()); - assert!(all_set_buffer(&[0, 0b_1110_1111], 8, 12).unwrap()); - assert!(all_set_buffer(&[0, 0b_1110_1111], 13, 16).unwrap()); - assert!(!all_set_buffer(&[0, 0b_1110_1111], 12, 13).unwrap()); - - assert!(all_set_buffer(&[0, 0b_1111_1111, 0], 8, 16).unwrap()); - assert!(all_set_buffer(&[0, 0b_1111_1111, 0b_1111_1111], 8, 24).unwrap()); - assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_1111_1111], 7, 24).unwrap()); - assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_0111_1111], 8, 24).unwrap()); - assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_1111_1110], 8, 24).unwrap()); - assert!(!all_set_buffer(&[0, 0b_1111_0111, 0b_1111_1111], 8, 24).unwrap()); - assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_0111_1111], 8, 24).unwrap()); - assert!(all_set_buffer(&[0, 0b_1111_1111, 0b_0111_1111], 8, 23).unwrap()); - - assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_0111_1111], 23, 24).unwrap()); - assert!(all_set_buffer(&[0, 0b_1111_1111, 0b_0111_1111], 22, 23).unwrap()); - - assert!(all_set_buffer(&[0, 0b_1111_1111, 0b_1111_1111, 0b_1111_1111], 8, 32).unwrap()); - assert!(!all_set_buffer(&[0, 0b_0111_1111, 0b_1111_1111, 0b_1111_1111], 8, 32).unwrap()); - assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_1111_1111, 0b_1111_1110], 8, 32).unwrap()); - assert!(all_set_buffer(&[0, 0b_1111_1111, 0b_1111_1111, 0b_1111_1110], 8, 24).unwrap()); - assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_1111_1111, 0b_0111_1111], 8, 32).unwrap()); - assert!(all_set_buffer(&[0, 0b_1111_1111, 0b_1111_1111, 0b_0111_1111], 8, 31).unwrap()); - assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_1110_1111, 0b_1111_1111], 8, 32).unwrap()); - assert!(all_set_buffer(&[0, 0b_1111_1111, 0b_1110_1111, 0b_1111_1111], 8, 20).unwrap()); - assert!(all_set_buffer(&[0, 0b_1111_1111, 0b_1110_1111, 0b_1111_1111], 21, 32).unwrap()); - assert!(!all_set_buffer(&[0, 0b_1111_1111, 0b_1110_1111, 0b_1111_1111], 20, 21).unwrap()); -} diff --git a/serde_arrow/src/test_with_arrow/impls/arrow_dictionary.rs b/serde_arrow/src/test_with_arrow/impls/arrow_dictionary.rs index c8b794fb..0b73dee0 100644 --- a/serde_arrow/src/test_with_arrow/impls/arrow_dictionary.rs +++ b/serde_arrow/src/test_with_arrow/impls/arrow_dictionary.rs @@ -182,7 +182,21 @@ mod construction { values: Box::new(to_array(DataType::Utf8, true, [None::<&str>, None])), }); - assert!(ArrayDeserializer::new(String::from("$"), None, array.as_view()).is_err()); + let deserializer = + ArrayDeserializer::new(String::from("$"), None, array.as_view()).unwrap(); + + assert_eq!( + Option::::deserialize(deserializer.at(0)).unwrap(), + None + ); + assert_eq!( + Option::::deserialize(deserializer.at(1)).unwrap(), + None + ); + assert_eq!( + Option::::deserialize(deserializer.at(2)).unwrap(), + None + ); } #[test] @@ -196,11 +210,25 @@ mod construction { )), }); - assert!(ArrayDeserializer::new(String::from("$"), None, array.as_view()).is_err()); + let deserializer = + ArrayDeserializer::new(String::from("$"), None, array.as_view()).unwrap(); + + assert_eq!( + Option::::deserialize(deserializer.at(0)).unwrap(), + None + ); + assert_eq!( + Option::::deserialize(deserializer.at(1)).unwrap(), + None + ); + assert_eq!( + Option::::deserialize(deserializer.at(2)).unwrap(), + Some(String::from("foo")), + ); } #[test] - fn some_null_values_v2() { + fn some_null_values_out_of_range() { let array = Array::Dictionary(DictionaryArray { keys: Box::new(to_array(DataType::Int8, false, [1, 1, 0])), values: Box::new(to_array( @@ -224,6 +252,14 @@ mod construction { )), }); - assert!(ArrayDeserializer::new(String::from("$"), None, array.as_view()).is_err()); + let deserializer = + ArrayDeserializer::new(String::from("$"), None, array.as_view()).unwrap(); + + for i in 0..3 { + assert_eq!( + Option::::deserialize(deserializer.at(i)).unwrap(), + Some(String::from("1")), + ); + } } }