Skip to content

Commit 4dcf5f8

Browse files
authored
SATS: make field_names/variant_names return iterator + add FieldNameVisitor::visit_seq (#2927)
1 parent 5770386 commit 4dcf5f8

7 files changed

Lines changed: 93 additions & 128 deletions

File tree

crates/bindings-macro/src/sats.rs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream {
347347
de_generics.params.insert(0, de_lt_param.into());
348348
let (de_impl_generics, _, de_where_clause) = de_generics.split_for_impl();
349349

350-
let (iter_n, iter_n2, iter_n3) = (0usize.., 0usize.., 0usize..);
350+
let (iter_n, iter_n2, iter_n3, iter_n4) = (0usize.., 0usize.., 0usize.., 0usize..);
351351

352352
match &ty.data {
353353
SatsTypeData::Product(fields) => {
@@ -443,8 +443,8 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream {
443443
impl #de_impl_generics #spacetimedb_lib::de::FieldNameVisitor<'de> for __ProductVisitor #ty_generics #de_where_clause {
444444
type Output = __ProductFieldIdent;
445445

446-
fn field_names(&self, names: &mut dyn #spacetimedb_lib::de::ValidNames) {
447-
names.extend::<&[&str]>(&[#(#field_strings),*])
446+
fn field_names(&self) -> impl '_ + Iterator<Item = Option<&str>> {
447+
[#(#field_strings),*].into_iter().map(Some)
448448
}
449449

450450
fn visit<__E: #spacetimedb_lib::de::Error>(self, name: &str) -> Result<Self::Output, __E> {
@@ -453,6 +453,13 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream {
453453
_ => Err(#spacetimedb_lib::de::Error::unknown_field_name(name, &self)),
454454
}
455455
}
456+
457+
fn visit_seq<__E: #spacetimedb_lib::de::Error>(self, index: usize, name: &str) -> Result<Self::Output, __E> {
458+
match index {
459+
#(#iter_n4 => Ok(__ProductFieldIdent::#field_names),)*
460+
_ => Err(#spacetimedb_lib::de::Error::unknown_field_name(name, &self)),
461+
}
462+
}
456463
}
457464

458465
#[allow(non_camel_case_types)]
@@ -516,11 +523,11 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream {
516523
#(#variant_idents,)*
517524
}
518525

519-
impl #de_impl_generics #spacetimedb_lib::de::VariantVisitor for __SumVisitor #ty_generics #de_where_clause {
526+
impl #de_impl_generics #spacetimedb_lib::de::VariantVisitor<'de> for __SumVisitor #ty_generics #de_where_clause {
520527
type Output = __Variant;
521528

522-
fn variant_names(&self, names: &mut dyn #spacetimedb_lib::de::ValidNames) {
523-
names.extend::<&[&str]>(&[#(#variant_names,)*])
529+
fn variant_names(&self) -> impl '_ + Iterator<Item = &str> {
530+
[#(#variant_names,)*].into_iter()
524531
}
525532

526533
fn visit_tag<E: #spacetimedb_lib::de::Error>(self, __tag: u8) -> Result<Self::Output, E> {

crates/sats/src/algebraic_value/de.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,12 @@ impl SumAccess {
173173
}
174174
}
175175

176-
impl de::SumAccess<'_> for SumAccess {
176+
impl<'de> de::SumAccess<'de> for SumAccess {
177177
type Error = ValueDeserializeError;
178178

179179
type Variant = ValueDeserializer;
180180

181-
fn variant<V: de::VariantVisitor>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
181+
fn variant<V: de::VariantVisitor<'de>>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
182182
let tag = visitor.visit_tag(self.sum.tag)?;
183183
let val = *self.sum.value;
184184
Ok((tag, ValueDeserializer { val }))
@@ -313,7 +313,7 @@ impl<'de> de::SumAccess<'de> for &'de SumAccess {
313313

314314
type Variant = &'de ValueDeserializer;
315315

316-
fn variant<V: de::VariantVisitor>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
316+
fn variant<V: de::VariantVisitor<'de>>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
317317
let tag = visitor.visit_tag(self.sum.tag)?;
318318
Ok((tag, ValueDeserializer::from_ref(&self.sum.value)))
319319
}

crates/sats/src/bsatn/de.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ impl<'de, R: BufReader<'de>> SumAccess<'de> for Deserializer<'_, R> {
146146
type Error = DecodeError;
147147
type Variant = Self;
148148

149-
fn variant<V: de::VariantVisitor>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
149+
fn variant<V: de::VariantVisitor<'de>>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
150150
let tag = self.reader.get_u8()?;
151151
visitor.visit_tag(tag).map(|variant| (variant, self))
152152
}

crates/sats/src/de.rs

Lines changed: 45 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ pub trait Error: Sized {
184184
ProductKind::Normal => "field",
185185
ProductKind::ReducerArgs => "reducer argument",
186186
};
187-
if let Some(one_of) = one_of_names(|n| expected.field_names(n)) {
187+
if let Some(one_of) = one_of_names(|| expected.field_names()) {
188188
Self::custom(format_args!("unknown {el_ty} `{field_name}`, expected {one_of}"))
189189
} else {
190190
Self::custom(format_args!("unknown {el_ty} `{field_name}`, there are no {el_ty}s"))
@@ -200,8 +200,8 @@ pub trait Error: Sized {
200200
}
201201

202202
/// The `name` is not that of a variant of the sum type.
203-
fn unknown_variant_name<T: VariantVisitor>(name: &str, expected: &T) -> Self {
204-
if let Some(one_of) = one_of_names(|n| expected.variant_names(n)) {
203+
fn unknown_variant_name<'de, T: VariantVisitor<'de>>(name: &str, expected: &T) -> Self {
204+
if let Some(one_of) = one_of_names(|| expected.variant_names().map(Some)) {
205205
Self::custom(format_args!("unknown variant `{name}`, expected {one_of}",))
206206
} else {
207207
Self::custom(format_args!("unknown variant `{name}`, there are no variants"))
@@ -358,39 +358,18 @@ pub trait FieldNameVisitor<'de> {
358358
ProductKind::Normal
359359
}
360360

361-
/// Provides the visitor the chance to add valid names into `names`.
362-
fn field_names(&self, names: &mut dyn ValidNames);
361+
/// Provides a list of valid field names.
362+
///
363+
/// Where `None` is yielded, this indicates a nameless field.
364+
fn field_names(&self) -> impl '_ + Iterator<Item = Option<&str>>;
363365

366+
/// Deserializes the name of a field using `name`.
364367
fn visit<E: Error>(self, name: &str) -> Result<Self::Output, E>;
365-
}
366368

367-
/// A trait for types storing a set of valid names.
368-
pub trait ValidNames {
369-
/// Adds the name `s` to the set.
370-
fn push(&mut self, s: &str);
371-
372-
/// Runs the function `names` provided with `self` as the store
373-
/// and then returns back `self`.
374-
/// This method exists for convenience.
375-
fn run(mut self, names: &impl Fn(&mut dyn ValidNames)) -> Self
376-
where
377-
Self: Sized,
378-
{
379-
names(&mut self);
380-
self
381-
}
382-
}
383-
384-
impl dyn ValidNames + '_ {
385-
/// Adds the names in `iter` to the set.
386-
pub fn extend<I: IntoIterator>(&mut self, iter: I)
387-
where
388-
I::Item: AsRef<str>,
389-
{
390-
for name in iter {
391-
self.push(name.as_ref())
392-
}
393-
}
369+
/// Deserializes the name of a field using `index`.
370+
///
371+
/// The `name` is provided for error messages.
372+
fn visit_seq<E: Error>(self, index: usize, name: &str) -> Result<Self::Output, E>;
394373
}
395374

396375
/// A visitor walking through a [`Deserializer`] for sums.
@@ -442,17 +421,17 @@ pub trait SumAccess<'de> {
442421
/// The `visitor` is provided by the [`Deserializer`].
443422
/// This method is typically called from [`SumVisitor::visit_sum`]
444423
/// which will provide the [`V: VariantVisitor`](VariantVisitor).
445-
fn variant<V: VariantVisitor>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error>;
424+
fn variant<V: VariantVisitor<'de>>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error>;
446425
}
447426

448427
/// A visitor passed from [`SumVisitor`] to [`SumAccess::variant`]
449428
/// which the latter uses to decide what variant to deserialize.
450-
pub trait VariantVisitor {
429+
pub trait VariantVisitor<'de> {
451430
/// The result of identifying a variant, e.g., some index type.
452431
type Output;
453432

454-
/// Provides the visitor the chance to add valid names into `names`.
455-
fn variant_names(&self, names: &mut dyn ValidNames);
433+
/// Provides a list of variant names.
434+
fn variant_names(&self) -> impl '_ + Iterator<Item = &str>;
456435

457436
/// Identify the variant based on `tag`.
458437
fn visit_tag<E: Error>(self, tag: u8) -> Result<Self::Output, E>;
@@ -669,71 +648,42 @@ impl<'de, T, const N: usize> ArrayVisitor<'de, T> for BasicArrayVisitor<N> {
669648
}
670649
}
671650

672-
/// Provided a function `names` that is allowed to store a name into a valid set,
651+
/// Provided a list of names,
673652
/// returns a human readable list of all the names,
674653
/// or `None` in the case of an empty list of names.
675-
fn one_of_names(names: impl Fn(&mut dyn ValidNames)) -> Option<impl fmt::Display> {
676-
/// An implementation of `ValidNames` that just counts how many valid names are pushed into it.
677-
struct CountNames(usize);
678-
679-
impl ValidNames for CountNames {
680-
fn push(&mut self, _: &str) {
681-
self.0 += 1
682-
}
683-
}
654+
fn one_of_names<'a, I: Iterator<Item = Option<&'a str>>>(names: impl Fn() -> I) -> Option<impl fmt::Display> {
655+
// Count how many names there are.
656+
let count = names().count();
684657

685-
/// An implementation of `ValidNames` that provides a human friendly enumeration of names.
686-
struct OneOfNames<'a, 'b> {
687-
/// A `.push(_)` counter.
688-
index: usize,
689-
/// How many names there were.
690-
count: usize,
691-
/// Result of formatting thus far.
692-
f: Result<&'a mut fmt::Formatter<'b>, fmt::Error>,
693-
}
694-
695-
impl<'a, 'b> OneOfNames<'a, 'b> {
696-
fn new(count: usize, f: &'a mut fmt::Formatter<'b>) -> Self {
697-
Self {
698-
index: 0,
699-
count,
700-
f: Ok(f),
701-
}
702-
}
703-
}
704-
705-
impl ValidNames for OneOfNames<'_, '_> {
706-
fn push(&mut self, name: &str) {
707-
// This will give us, after all `.push()`es have been made, the following:
658+
// There was at least one name; render those names.
659+
(count != 0).then(move || {
660+
fmt_fn(move |f| {
661+
let mut anon_name = 0;
662+
// An example of what happens for names "foo", "bar", and "baz":
708663
//
709664
// count = 1 -> "`foo`"
710665
// = 2 -> "`foo` or `bar`"
711666
// > 2 -> "one of `foo`, `bar`, or `baz`"
712-
713-
let Ok(f) = &mut self.f else {
714-
return;
715-
};
716-
717-
self.index += 1;
718-
719-
if let Err(e) = match (self.count, self.index) {
720-
(1, _) => write!(f, "`{name}`"),
721-
(2, 1) => write!(f, "`{name}`"),
722-
(2, 2) => write!(f, "`or `{name}`"),
723-
(_, 1) => write!(f, "one of `{name}`"),
724-
(c, i) if i < c => write!(f, ", `{name}`"),
725-
(_, _) => write!(f, ", `, or {name}`"),
726-
} {
727-
self.f = Err(e);
667+
for (index, mut name) in names().enumerate() {
668+
let mut name_buf: String = String::new();
669+
let name = name.get_or_insert_with(|| {
670+
name_buf = format!("{anon_name}");
671+
anon_name += 1;
672+
&name_buf
673+
});
674+
match (count, index) {
675+
(1, _) => write!(f, "`{name}`"),
676+
(2, 1) => write!(f, "`{name}`"),
677+
(2, 2) => write!(f, "`or `{name}`"),
678+
(_, 1) => write!(f, "one of `{name}`"),
679+
(c, i) if i < c => write!(f, ", `{name}`"),
680+
(_, _) => write!(f, ", `, or {name}`"),
681+
}?;
728682
}
729-
}
730-
}
731683

732-
// Count how many names have been pushed.
733-
let count = CountNames(0).run(&names).0;
734-
735-
// There was at least one name; render those names.
736-
(count != 0).then(|| fmt_fn(move |fmt| OneOfNames::new(count, fmt).run(&names).f.map(drop)))
684+
Ok(())
685+
})
686+
})
737687
}
738688

739689
/// Deserializes `none` variant of an optional value.
@@ -752,11 +702,11 @@ impl<E: Error> Default for NoneAccess<E> {
752702
}
753703
}
754704

755-
impl<E: Error> SumAccess<'_> for NoneAccess<E> {
705+
impl<'de, E: Error> SumAccess<'de> for NoneAccess<E> {
756706
type Error = E;
757707
type Variant = Self;
758708

759-
fn variant<V: VariantVisitor>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
709+
fn variant<V: VariantVisitor<'de>>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
760710
visitor.visit_name("none").map(|var| (var, self))
761711
}
762712
}

crates/sats/src/de/impls.rs

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,11 @@ impl<'de, T: Deserialize<'de>> SumVisitor<'de> for OptionVisitor<T> {
211211
}
212212
}
213213

214-
impl<'de, T: Deserialize<'de>> VariantVisitor for OptionVisitor<T> {
214+
impl<'de, T: Deserialize<'de>> VariantVisitor<'de> for OptionVisitor<T> {
215215
type Output = bool;
216216

217-
fn variant_names(&self, names: &mut dyn super::ValidNames) {
218-
names.extend(["some", "none"])
217+
fn variant_names(&self) -> impl '_ + Iterator<Item = &str> {
218+
["some", "none"].into_iter()
219219
}
220220

221221
fn visit_tag<E: Error>(self, tag: u8) -> Result<Self::Output, E> {
@@ -268,11 +268,11 @@ impl<'de, T: Deserialize<'de>, E: Deserialize<'de>> SumVisitor<'de> for ResultVi
268268
}
269269
}
270270

271-
impl<'de, T: Deserialize<'de>, U: Deserialize<'de>> VariantVisitor for ResultVisitor<T, U> {
271+
impl<'de, T: Deserialize<'de>, U: Deserialize<'de>> VariantVisitor<'de> for ResultVisitor<T, U> {
272272
type Output = ResultVariant;
273273

274-
fn variant_names(&self, names: &mut dyn super::ValidNames) {
275-
names.extend(["ok", "err"])
274+
fn variant_names(&self) -> impl '_ + Iterator<Item = &str> {
275+
["ok", "err"].into_iter()
276276
}
277277

278278
fn visit_tag<E: Error>(self, tag: u8) -> Result<Self::Output, E> {
@@ -335,11 +335,11 @@ impl<'de, S: Copy + DeserializeSeed<'de>> SumVisitor<'de> for BoundVisitor<S> {
335335
}
336336
}
337337

338-
impl<'de, T: Copy + DeserializeSeed<'de>> VariantVisitor for BoundVisitor<T> {
338+
impl<'de, T: Copy + DeserializeSeed<'de>> VariantVisitor<'de> for BoundVisitor<T> {
339339
type Output = BoundVariant;
340340

341-
fn variant_names(&self, names: &mut dyn super::ValidNames) {
342-
names.extend(["included", "excluded", "unbounded"])
341+
fn variant_names(&self) -> impl '_ + Iterator<Item = &str> {
342+
["included", "excluded", "unbounded"].into_iter()
343343
}
344344

345345
fn visit_tag<E: Error>(self, tag: u8) -> Result<Self::Output, E> {
@@ -420,12 +420,12 @@ impl<'de> SumVisitor<'de> for WithTypespace<'_, SumType> {
420420
}
421421
}
422422

423-
impl VariantVisitor for WithTypespace<'_, SumType> {
423+
impl VariantVisitor<'_> for WithTypespace<'_, SumType> {
424424
type Output = u8;
425425

426-
fn variant_names(&self, names: &mut dyn super::ValidNames) {
426+
fn variant_names(&self) -> impl '_ + Iterator<Item = &str> {
427427
// Provide the names known from the `SumType`.
428-
names.extend(self.ty().variants.iter().filter_map(|v| v.name()))
428+
self.ty().variants.iter().filter_map(|v| v.name())
429429
}
430430

431431
fn visit_tag<E: Error>(self, tag: u8) -> Result<Self::Output, E> {
@@ -643,8 +643,8 @@ impl FieldNameVisitor<'_> for TupleNameVisitor<'_> {
643643
// The index of the field name.
644644
type Output = usize;
645645

646-
fn field_names(&self, names: &mut dyn super::ValidNames) {
647-
names.extend(self.elems.iter().filter_map(|f| f.name()))
646+
fn field_names(&self) -> impl '_ + Iterator<Item = Option<&str>> {
647+
self.elems.iter().map(|f| f.name())
648648
}
649649

650650
fn kind(&self) -> ProductKind {
@@ -658,6 +658,14 @@ impl FieldNameVisitor<'_> for TupleNameVisitor<'_> {
658658
.position(|f| f.has_name(name))
659659
.ok_or_else(|| Error::unknown_field_name(name, &self))
660660
}
661+
662+
fn visit_seq<E: Error>(self, index: usize, name: &str) -> Result<Self::Output, E> {
663+
self.elems
664+
.get(index)
665+
.ok_or_else(|| Error::unknown_field_name(name, &self))?;
666+
667+
Ok(index)
668+
}
661669
}
662670

663671
impl_deserialize!([] spacetimedb_primitives::TableId, de => u32::deserialize(de).map(Self));

0 commit comments

Comments
 (0)