Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions crates/sats/src/algebraic_value/ser.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::bsatn::decode;
use crate::de::DeserializeSeed;
use crate::ser::{self, ForwardNamedToSeqProduct, Serialize};
use crate::{i256, u256};
use crate::{AlgebraicType, AlgebraicValue, ArrayValue, F32, F64};
use crate::{i256, u256, WithTypespace};
use crate::{AlgebraicValue, ArrayValue, F32, F64};
use core::convert::Infallible;
use core::mem::MaybeUninit;
use core::ptr;
Expand Down Expand Up @@ -81,18 +83,25 @@ impl ser::Serializer for ValueSerializer {
value.serialize(self).map(|v| AlgebraicValue::sum(tag, v))
}

unsafe fn serialize_bsatn(self, ty: &AlgebraicType, mut bsatn: &[u8]) -> Result<Self::Ok, Self::Error> {
let res = AlgebraicValue::decode(ty, &mut bsatn);
unsafe fn serialize_bsatn<Ty>(self, ty: &Ty, mut bsatn: &[u8]) -> Result<Self::Ok, Self::Error>
where
for<'a, 'de> WithTypespace<'a, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
{
let res = decode(ty, &mut bsatn);
// SAFETY: Caller promised that `res.is_ok()`.
Ok(unsafe { res.unwrap_unchecked() })
let val = unsafe { res.unwrap_unchecked() };
Ok(val.into())
}

unsafe fn serialize_bsatn_in_chunks<'a, I: Iterator<Item = &'a [u8]>>(
unsafe fn serialize_bsatn_in_chunks<'a, Ty, I: Iterator<Item = &'a [u8]>>(
self,
ty: &crate::AlgebraicType,
ty: &Ty,
total_bsatn_len: usize,
chunks: I,
) -> Result<Self::Ok, Self::Error> {
) -> Result<Self::Ok, Self::Error>
where
for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
{
// SAFETY: Caller promised `total_bsatn_len == chunks.map(|c| c.len()).sum() <= isize::MAX`.
unsafe {
concat_byte_chunks_buf(total_bsatn_len, chunks, |bsatn| {
Expand Down
10 changes: 7 additions & 3 deletions crates/sats/src/algebraic_value_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl Hash for ArrayValue {

type HR = Result<(), DecodeError>;

pub fn hash_bsatn<'a>(state: &mut impl Hasher, ty: &AlgebraicType, de: Deserializer<'_, impl BufReader<'a>>) -> HR {
fn hash_bsatn<'a>(state: &mut impl Hasher, ty: &AlgebraicType, de: Deserializer<'_, impl BufReader<'a>>) -> HR {
match ty {
AlgebraicType::Ref(_) => unreachable!("hash_bsatn does not have a typespace"),
AlgebraicType::Sum(ty) => hash_bsatn_sum(state, ty, de),
Expand Down Expand Up @@ -166,7 +166,11 @@ fn hash_bsatn_prod<'a>(state: &mut impl Hasher, ty: &ProductType, mut de: Deseri
}

/// Hashes every elem in the BSATN-encoded array value.
fn hash_bsatn_array<'a>(state: &mut impl Hasher, ty: &AlgebraicType, de: Deserializer<'_, impl BufReader<'a>>) -> HR {
pub fn hash_bsatn_array<'a>(
state: &mut impl Hasher,
ty: &AlgebraicType,
de: Deserializer<'_, impl BufReader<'a>>,
) -> HR {
// The BSATN is length-prefixed.
// `Hash for &[T]` also does length-prefixing.
match ty {
Expand Down Expand Up @@ -236,9 +240,9 @@ fn hash_bsatn_de<'a, T: Hash + Deserialize<'a>>(

#[cfg(test)]
mod tests {
use super::hash_bsatn;
use crate::{
bsatn::{to_vec, Deserializer},
hash_bsatn,
proptest::generate_typed_value,
AlgebraicType, AlgebraicValue,
};
Expand Down
21 changes: 14 additions & 7 deletions crates/sats/src/bsatn/ser.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::buffer::BufWriter;
use crate::de::DeserializeSeed;
use crate::ser::{self, Error, ForwardNamedToSeqProduct, SerializeArray, SerializeSeqProduct};
use crate::AlgebraicValue;
use crate::{i256, u256};
use crate::{AlgebraicValue, WithTypespace};
use core::fmt;

/// Defines the BSATN serialization data format.
Expand Down Expand Up @@ -159,20 +160,26 @@ impl<W: BufWriter> ser::Serializer for Serializer<'_, W> {
value.serialize(self)
}

unsafe fn serialize_bsatn(self, ty: &crate::AlgebraicType, bsatn: &[u8]) -> Result<Self::Ok, Self::Error> {
debug_assert!(AlgebraicValue::decode(ty, &mut { bsatn }).is_ok());
unsafe fn serialize_bsatn<Ty>(self, ty: &Ty, bsatn: &[u8]) -> Result<Self::Ok, Self::Error>
where
for<'a, 'de> WithTypespace<'a, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
{
debug_assert!(crate::bsatn::decode(ty, &mut { bsatn }).is_ok());
self.writer.put_slice(bsatn);
Ok(())
}

unsafe fn serialize_bsatn_in_chunks<'a, I: Clone + Iterator<Item = &'a [u8]>>(
unsafe fn serialize_bsatn_in_chunks<'a, Ty, I: Clone + Iterator<Item = &'a [u8]>>(
self,
ty: &crate::AlgebraicType,
ty: &Ty,
total_bsatn_len: usize,
bsatn: I,
) -> Result<Self::Ok, Self::Error> {
) -> Result<Self::Ok, Self::Error>
where
for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
{
debug_assert!(total_bsatn_len <= isize::MAX as usize);
debug_assert!(AlgebraicValue::decode(ty, &mut &*concat_bytes_slow(total_bsatn_len, bsatn.clone())).is_ok());
debug_assert!(crate::bsatn::decode(ty, &mut &*concat_bytes_slow(total_bsatn_len, bsatn.clone())).is_ok());

for chunk in bsatn {
self.writer.put_slice(chunk);
Expand Down
51 changes: 28 additions & 23 deletions crates/sats/src/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ use crate::{
},
i256, impl_deserialize, impl_serialize,
sum_type::{OPTION_NONE_TAG, OPTION_SOME_TAG},
u256, AlgebraicType, AlgebraicValue, ProductType, ProductTypeElement, ProductValue, SumType, SumTypeVariant,
SumValue, WithTypespace,
u256, AlgebraicType, AlgebraicValue, ArrayType, ProductType, ProductTypeElement, ProductValue, SumType,
SumTypeVariant, SumValue, WithTypespace,
};
use core::ops::{Index, Mul};
use core::{mem, ops::Deref};
Expand Down Expand Up @@ -199,8 +199,8 @@ impl AlgebraicTypeLayout {
// but we don't care to avoid that and optimize right now,
// as this is only executed during upgrade / migration,
// and that doesn't need to be fast right now.
let old = AlgebraicTypeLayout::from(old.deref().clone());
let new = AlgebraicTypeLayout::from(new.deref().clone());
let old = AlgebraicTypeLayout::from(old.elem_ty.deref().clone());
let new = AlgebraicTypeLayout::from(new.elem_ty.deref().clone());
old.is_compatible_with(&new)
}
(Self::VarLen(VarLenType::String), Self::VarLen(VarLenType::String)) => true,
Expand Down Expand Up @@ -515,11 +515,11 @@ impl HasLayout for PrimitiveType {
pub enum VarLenType {
/// The string type corresponds to `AlgebraicType::String`.
String,
/// An array type. The whole outer `AlgebraicType` is stored here.
/// An array type. The inner `AlgebraicType` is stored here.
///
/// Storing the whole `AlgebraicType` here allows us to directly call BSATN ser/de,
/// and to report type errors.
Array(Box<AlgebraicType>),
/// Previously, the outer type, i.e., `AlgebraicType::Array`, was stored.
/// However, that was both less efficient and more bug-prone.
Array(ArrayType),
}

#[cfg(feature = "memory-usage")]
Expand Down Expand Up @@ -554,7 +554,7 @@ impl From<AlgebraicType> for AlgebraicTypeLayout {
AlgebraicType::Product(prod) => AlgebraicTypeLayout::Product(prod.into()),

AlgebraicType::String => AlgebraicTypeLayout::VarLen(VarLenType::String),
AlgebraicType::Array(_) => AlgebraicTypeLayout::VarLen(VarLenType::Array(Box::new(ty))),
AlgebraicType::Array(array) => AlgebraicTypeLayout::VarLen(VarLenType::Array(array)),

AlgebraicType::Bool => AlgebraicTypeLayout::Bool,
AlgebraicType::I8 => AlgebraicTypeLayout::I8,
Expand Down Expand Up @@ -690,19 +690,11 @@ impl AlgebraicTypeLayout {
/// It is intended for use in error paths, where performance is a secondary concern.
pub fn algebraic_type(&self) -> AlgebraicType {
match self {
AlgebraicTypeLayout::Primitive(prim) => prim.algebraic_type(),
AlgebraicTypeLayout::VarLen(var_len) => var_len.algebraic_type(),
AlgebraicTypeLayout::Product(prod) => AlgebraicType::Product(prod.view().product_type()),
AlgebraicTypeLayout::Sum(sum) => AlgebraicType::Sum(sum.sum_type()),
}
}
}

impl VarLenType {
fn algebraic_type(&self) -> AlgebraicType {
match self {
VarLenType::String => AlgebraicType::String,
VarLenType::Array(ty) => ty.as_ref().clone(),
Self::Primitive(prim) => prim.algebraic_type(),
Self::VarLen(VarLenType::String) => AlgebraicType::String,
Self::VarLen(VarLenType::Array(array)) => AlgebraicType::Array(array.clone()),
Self::Product(prod) => AlgebraicType::Product(prod.view().product_type()),
Self::Sum(sum) => AlgebraicType::Sum(sum.sum_type()),
}
}
}
Expand Down Expand Up @@ -828,7 +820,9 @@ impl<'de> DeserializeSeed<'de> for &AlgebraicTypeLayout {
AlgebraicTypeLayout::Primitive(PrimitiveType::U256) => u256::deserialize(de).map(Into::into),
AlgebraicTypeLayout::Primitive(PrimitiveType::F32) => f32::deserialize(de).map(Into::into),
AlgebraicTypeLayout::Primitive(PrimitiveType::F64) => f64::deserialize(de).map(Into::into),
AlgebraicTypeLayout::VarLen(VarLenType::Array(ty)) => WithTypespace::empty(&**ty).deserialize(de),
AlgebraicTypeLayout::VarLen(VarLenType::Array(ty)) => {
WithTypespace::empty(ty).deserialize(de).map(AlgebraicValue::Array)
}
AlgebraicTypeLayout::VarLen(VarLenType::String) => <Box<str>>::deserialize(de).map(Into::into),
}
}
Expand Down Expand Up @@ -1124,4 +1118,15 @@ mod test {
}
}
}

#[test]
fn infinite_recursion_in_is_compatible_with_with_array_type() {
    // Regression test: `AlgebraicTypeLayout::VarLen(Array(x))` used to store
    // `x = Box::new(AlgebraicType::Array(elem_ty))`, i.e., the outer array type.
    // `AlgebraicTypeLayout::is_compatible_with` was not set up to handle that,
    // which caused an infinite recursion / stack overflow.
    // To avoid such bugs in the future, `x` is now `elem_ty` instead.
    let array_layout: AlgebraicTypeLayout = AlgebraicType::array(AlgebraicType::U64).into();
    assert!(array_layout.is_compatible_with(&array_layout));
}
}
2 changes: 1 addition & 1 deletion crates/sats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pub use crate as sats;
pub use algebraic_type::AlgebraicType;
pub use algebraic_type_ref::AlgebraicTypeRef;
pub use algebraic_value::{i256, u256, AlgebraicValue, F32, F64};
pub use algebraic_value_hash::hash_bsatn;
pub use algebraic_value_hash::hash_bsatn_array;
pub use array_type::ArrayType;
pub use array_value::ArrayValue;
pub use product_type::ProductType;
Expand Down
17 changes: 12 additions & 5 deletions crates/sats/src/satn.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::de::DeserializeSeed;
use crate::time_duration::TimeDuration;
use crate::timestamp::Timestamp;
use crate::{i256, u256};
use crate::{i256, u256, AlgebraicValue, WithTypespace};
use crate::{ser, ProductType, ProductTypeElement};
use core::fmt;
use core::fmt::Write as _;
Expand Down Expand Up @@ -706,17 +707,23 @@ impl<'a, 'f> ser::Serializer for PsqlFormatter<'a, 'f> {
self.fmt.serialize_variant(tag, name, value)
}

unsafe fn serialize_bsatn(self, ty: &crate::AlgebraicType, bsatn: &[u8]) -> Result<Self::Ok, Self::Error> {
unsafe fn serialize_bsatn<Ty>(self, ty: &Ty, bsatn: &[u8]) -> Result<Self::Ok, Self::Error>
where
for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
{
// SAFETY: The caller's requirements for this method are forwarded to the method we are calling.
unsafe { self.fmt.serialize_bsatn(ty, bsatn) }
}

unsafe fn serialize_bsatn_in_chunks<'c, I: Clone + Iterator<Item = &'c [u8]>>(
unsafe fn serialize_bsatn_in_chunks<'c, Ty, I: Clone + Iterator<Item = &'c [u8]>>(
self,
ty: &crate::AlgebraicType,
ty: &Ty,
total_bsatn_len: usize,
bsatn: I,
) -> Result<Self::Ok, Self::Error> {
) -> Result<Self::Ok, Self::Error>
where
for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
{
// SAFETY: The caller's requirements for this method are forwarded to the method we are calling.
unsafe { self.fmt.serialize_bsatn_in_chunks(ty, total_bsatn_len, bsatn) }
}
Expand Down
24 changes: 17 additions & 7 deletions crates/sats/src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ mod impls;
#[cfg(feature = "serde")]
pub mod serde;

use crate::{algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter, AlgebraicType};
use crate::de::DeserializeSeed;
use crate::{algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter};
use crate::{AlgebraicValue, WithTypespace};
use core::marker::PhantomData;
use core::{convert::Infallible, fmt};
use ethnum::{i256, u256};
Expand Down Expand Up @@ -130,9 +132,13 @@ pub trait Serializer: Sized {
///
/// # Safety
///
/// - `AlgebraicValue::decode(ty, &mut bsatn).is_ok()`.
/// - `decode(ty, &mut bsatn).is_ok()`.
/// That is, `bsatn` encodes a valid element of `ty`.
unsafe fn serialize_bsatn(self, ty: &AlgebraicType, bsatn: &[u8]) -> Result<Self::Ok, Self::Error> {
/// It's up to the caller to arrange `Ty` such that this holds.
unsafe fn serialize_bsatn<Ty>(self, ty: &Ty, bsatn: &[u8]) -> Result<Self::Ok, Self::Error>
where
for<'a, 'de> WithTypespace<'a, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
{
// TODO(Centril): Consider instead deserializing the `bsatn` through a
// deserializer that serializes into `self` directly.

Expand Down Expand Up @@ -168,14 +174,18 @@ pub trait Serializer: Sized {
///
/// - `total_bsatn_len == bsatn.map(|c| c.len()).sum() <= isize::MAX`
/// - Let `buf` be defined as above, i.e., the bytes of `bsatn` concatenated.
/// Then `AlgebraicValue::decode(ty, &mut buf).is_ok()`.
/// Then `decode(ty, &mut buf).is_ok()`.
/// That is, `buf` encodes a valid element of `ty`.
unsafe fn serialize_bsatn_in_chunks<'a, I: Clone + Iterator<Item = &'a [u8]>>(
/// It's up to the caller to arrange `Ty` such that this holds.
unsafe fn serialize_bsatn_in_chunks<'a, Ty, I: Clone + Iterator<Item = &'a [u8]>>(
self,
ty: &AlgebraicType,
ty: &Ty,
total_bsatn_len: usize,
bsatn: I,
) -> Result<Self::Ok, Self::Error> {
) -> Result<Self::Ok, Self::Error>
where
for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
{
// TODO(Centril): Unlike above, in this case we must at minimum concatenate `bsatn`
// before we can do the piping mentioned above, but that's better than
// serializing to `AlgebraicValue` first, so consider that.
Expand Down
8 changes: 4 additions & 4 deletions crates/table/src/bflatn_from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use spacetimedb_sats::{
align_to, AlgebraicTypeLayout, HasLayout as _, ProductTypeLayoutView, RowTypeLayout, SumTypeLayout, VarLenType,
},
ser::{SerializeNamedProduct, Serializer},
u256, AlgebraicType,
u256, ArrayType,
};

/// Serializes the row in `page` where the fixed part starts at `fixed_offset`
Expand Down Expand Up @@ -243,7 +243,7 @@ pub(crate) unsafe fn serialize_value<S: Serializer>(
}
AlgebraicTypeLayout::VarLen(VarLenType::Array(ty)) => {
// SAFETY: `value` was valid at `ty` and `VarLenRef`s won't be dangling.
unsafe { serialize_bsatn(ser, bytes, page, blob_store, curr_offset, ty) }
unsafe { serialize_array(ser, bytes, page, blob_store, curr_offset, ty) }
}
}
}
Expand Down Expand Up @@ -285,13 +285,13 @@ unsafe fn serialize_string<S: Serializer>(
}
}

unsafe fn serialize_bsatn<S: Serializer>(
unsafe fn serialize_array<S: Serializer>(
ser: S,
bytes: &Bytes,
page: &Page,
blob_store: &dyn BlobStore,
curr_offset: CurrOffset<'_>,
ty: &AlgebraicType,
ty: &ArrayType,
) -> Result<S::Ok, S::Error> {
// SAFETY: `value` was valid at and aligned for `ty`.
// These `ty` store a `vlr: VarLenRef` as their fixed value.
Expand Down
4 changes: 3 additions & 1 deletion crates/table/src/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ unsafe fn hash_value(
}
}
AlgebraicTypeLayout::VarLen(VarLenType::Array(ty)) => {
let ty = &ty.elem_ty;

// SAFETY: `value` was valid at and aligned for `ty`.
// These `ty` store a `vlr: VarLenRef` as their value,
// so the range is valid and properly aligned for `VarLenRef`.
Expand All @@ -168,7 +170,7 @@ unsafe fn hash_value(
unsafe {
run_vlo_bytes(page, bytes, blob_store, curr_offset, |mut bsatn| {
let de = Deserializer::new(&mut bsatn);
spacetimedb_sats::hash_bsatn(hasher, ty, de).unwrap();
spacetimedb_sats::hash_bsatn_array(hasher, ty, de).unwrap();
});
}
}
Expand Down
Loading