Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 105 additions & 2 deletions crates/transport/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,70 @@ macro_rules! impl_copy_codec {
};
}

// The Component Model canonical ABI mandates a single canonical `NaN`
// representation for floating point values. Encoding canonicalizes `NaN`s to
// match; decoding is lenient and accepts any `NaN` representation.
//
// See `canonicalize_nan{32,64}` in
// <https://github.com/WebAssembly/component-model/blob/main/design/mvp/canonical-abi/definitions.py>.
const CANONICAL_NAN_F32: u32 = 0x7fc0_0000;
const CANONICAL_NAN_F64: u64 = 0x7ff8_0000_0000_0000;

/// Defines a floating-point codec that canonicalizes `NaN` values on encode to
/// match the Component Model canonical ABI, delegating the actual byte encoding
/// and decoding to the wrapped `wasm-tokio` codec.
macro_rules! impl_canonical_nan_codec {
($name:ident, $inner:ty, $t:ty, $canon:expr) => {
#[doc = concat!("Canonicalizes `NaN`s on encode, wrapping [`", stringify!($inner), "`].")]
#[derive(Debug, Default)]
pub struct $name($inner);

impl tokio_util::codec::Encoder<$t> for $name {
type Error = std::io::Error;

fn encode(&mut self, item: $t, dst: &mut BytesMut) -> Result<(), Self::Error> {
let item = if item.is_nan() {
<$t>::from_bits($canon)
} else {
item
};
self.0.encode(item, dst)
}
}

impl tokio_util::codec::Encoder<&$t> for $name {
type Error = std::io::Error;

fn encode(&mut self, item: &$t, dst: &mut BytesMut) -> Result<(), Self::Error> {
tokio_util::codec::Encoder::<$t>::encode(self, *item, dst)
}
}

impl tokio_util::codec::Encoder<&&$t> for $name {
type Error = std::io::Error;

fn encode(&mut self, item: &&$t, dst: &mut BytesMut) -> Result<(), Self::Error> {
tokio_util::codec::Encoder::<$t>::encode(self, **item, dst)
}
}

impl tokio_util::codec::Decoder for $name {
type Item = $t;
type Error = std::io::Error;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<$t>, Self::Error> {
self.0.decode(src)
}
}

impl_deferred_sync!($name);
impl_deferred_sync!(CoreVecDecoder<$name>);
};
}

impl_canonical_nan_codec!(CanonicalNanF32Codec, F32Codec, f32, CANONICAL_NAN_F32);
impl_canonical_nan_codec!(CanonicalNanF64Codec, F64Codec, f64, CANONICAL_NAN_F64);

impl_copy_codec!(bool, BoolCodec);
impl_copy_codec!(i8, S8Codec);
impl_copy_codec!(i16, S16Codec);
Expand All @@ -939,8 +1003,8 @@ impl_copy_codec!(i32, S32Codec);
impl_copy_codec!(u32, U32Codec);
impl_copy_codec!(i64, S64Codec);
impl_copy_codec!(u64, U64Codec);
impl_copy_codec!(f32, F32Codec);
impl_copy_codec!(f64, F64Codec);
impl_copy_codec!(f32, CanonicalNanF32Codec);
impl_copy_codec!(f64, CanonicalNanF64Codec);
impl_copy_codec!(char, Utf8Codec);

impl<T> Encode<T> for u8 {
Expand Down Expand Up @@ -2295,4 +2359,43 @@ mod tests {
assert_eq!(buf.as_ref(), b"\x42\x42");
Ok(())
}

#[test]
fn canonical_nan_f32() {
let mut enc = <f32 as Encode<NoopStream>>::Encoder::default();

// A non-canonical (e.g. signalling) `NaN` is canonicalized on encode.
let mut buf = BytesMut::new();
enc.encode(f32::from_bits(0x7f80_0001), &mut buf).unwrap();
assert_eq!(buf.as_ref(), CANONICAL_NAN_F32.to_le_bytes());

// A negative `NaN` is canonicalized to the (positive) canonical `NaN`.
let mut buf = BytesMut::new();
enc.encode(f32::from_bits(0xffc0_0000), &mut buf).unwrap();
assert_eq!(buf.as_ref(), CANONICAL_NAN_F32.to_le_bytes());

// Non-`NaN` values are encoded unchanged.
let mut buf = BytesMut::new();
enc.encode(1.5_f32, &mut buf).unwrap();
assert_eq!(buf.as_ref(), 1.5_f32.to_bits().to_le_bytes());
}

#[test]
fn canonical_nan_f64() {
let mut enc = <f64 as Encode<NoopStream>>::Encoder::default();

let mut buf = BytesMut::new();
enc.encode(f64::from_bits(0x7ff0_0000_0000_0001), &mut buf)
.unwrap();
assert_eq!(buf.as_ref(), CANONICAL_NAN_F64.to_le_bytes());

let mut buf = BytesMut::new();
enc.encode(f64::from_bits(0xfff8_0000_0000_0000), &mut buf)
.unwrap();
assert_eq!(buf.as_ref(), CANONICAL_NAN_F64.to_le_bytes());

let mut buf = BytesMut::new();
enc.encode(1.5_f64, &mut buf).unwrap();
assert_eq!(buf.as_ref(), 1.5_f64.to_bits().to_le_bytes());
}
}
Loading