Skip to content

Commit fd57889

Browse files
Rust: merge structurally equal types in bindgen (#1468)
* merge equal types * fix * add test * Refactor type equality algorithm structure * Handle `alias == primitive` by juggling typedefs/`Type::Id` a bit more carefully. * Avoid using `_ => false` * Fix handling of `own`/`borrow` and `future`/`stream` to be more uniform like the rest of the algorithm. * Avoid special-casing resources, but for now consider them always not-equal. * Adjust CLI flag handling * Refactor how structurally-equal types are generated Leverage the recent support for nominal type IDs to remove some now-no-longer-necessary infrastructure. This additionally overrides the `define_type` method in the Rust generator to handle aliases at one location instead of in multiple locations. --------- Co-authored-by: Alex Crichton <alex@alexcrichton.com>
1 parent 7bebfd6 commit fd57889

File tree

9 files changed

+493
-24
lines changed

9 files changed

+493
-24
lines changed

crates/core/src/lib.rs

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -176,25 +176,32 @@ pub trait InterfaceGenerator<'a> {
176176
}
177177

178178
fn define_type(&mut self, name: &str, id: TypeId) {
179-
let ty = &self.resolve().types[id];
180-
match &ty.kind {
181-
TypeDefKind::Record(record) => self.type_record(id, name, record, &ty.docs),
182-
TypeDefKind::Resource => self.type_resource(id, name, &ty.docs),
183-
TypeDefKind::Flags(flags) => self.type_flags(id, name, flags, &ty.docs),
184-
TypeDefKind::Tuple(tuple) => self.type_tuple(id, name, tuple, &ty.docs),
185-
TypeDefKind::Enum(enum_) => self.type_enum(id, name, enum_, &ty.docs),
186-
TypeDefKind::Variant(variant) => self.type_variant(id, name, variant, &ty.docs),
187-
TypeDefKind::Option(t) => self.type_option(id, name, t, &ty.docs),
188-
TypeDefKind::Result(r) => self.type_result(id, name, r, &ty.docs),
189-
TypeDefKind::List(t) => self.type_list(id, name, t, &ty.docs),
190-
TypeDefKind::Type(t) => self.type_alias(id, name, t, &ty.docs),
191-
TypeDefKind::Future(t) => self.type_future(id, name, t, &ty.docs),
192-
TypeDefKind::Stream(t) => self.type_stream(id, name, t, &ty.docs),
193-
TypeDefKind::Handle(_) => panic!("handle types do not require definition"),
194-
TypeDefKind::FixedLengthList(..) => todo!(),
195-
TypeDefKind::Map(..) => todo!(),
196-
TypeDefKind::Unknown => unreachable!(),
197-
}
179+
define_type(self, name, id)
180+
}
181+
}
182+
183+
pub fn define_type<'a, T>(generator: &mut T, name: &str, id: TypeId)
184+
where
185+
T: InterfaceGenerator<'a> + ?Sized,
186+
{
187+
let ty = &generator.resolve().types[id];
188+
match &ty.kind {
189+
TypeDefKind::Record(record) => generator.type_record(id, name, record, &ty.docs),
190+
TypeDefKind::Resource => generator.type_resource(id, name, &ty.docs),
191+
TypeDefKind::Flags(flags) => generator.type_flags(id, name, flags, &ty.docs),
192+
TypeDefKind::Tuple(tuple) => generator.type_tuple(id, name, tuple, &ty.docs),
193+
TypeDefKind::Enum(enum_) => generator.type_enum(id, name, enum_, &ty.docs),
194+
TypeDefKind::Variant(variant) => generator.type_variant(id, name, variant, &ty.docs),
195+
TypeDefKind::Option(t) => generator.type_option(id, name, t, &ty.docs),
196+
TypeDefKind::Result(r) => generator.type_result(id, name, r, &ty.docs),
197+
TypeDefKind::List(t) => generator.type_list(id, name, t, &ty.docs),
198+
TypeDefKind::Type(t) => generator.type_alias(id, name, t, &ty.docs),
199+
TypeDefKind::Future(t) => generator.type_future(id, name, t, &ty.docs),
200+
TypeDefKind::Stream(t) => generator.type_stream(id, name, t, &ty.docs),
201+
TypeDefKind::Handle(_) => panic!("handle types do not require definition"),
202+
TypeDefKind::FixedLengthList(..) => todo!(),
203+
TypeDefKind::Map(..) => todo!(),
204+
TypeDefKind::Unknown => unreachable!(),
198205
}
199206
}
200207

crates/core/src/types.rs

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use wit_parser::*;
55
#[derive(Default)]
66
pub struct Types {
77
type_info: HashMap<TypeId, TypeInfo>,
8+
equal_types: UnionFind,
89
}
910

1011
#[derive(Default, Clone, Copy, Debug)]
@@ -93,6 +94,22 @@ impl Types {
9394
}
9495
}
9596
}
97+
pub fn collect_equal_types(&mut self, resolve: &Resolve) {
98+
for (i, (ty, _)) in resolve.types.iter().enumerate() {
99+
// TODO: we could define a hash function for TypeDefKind to prevent the inner loop.
100+
for (earlier, _) in resolve.types.iter().take(i) {
101+
if self.equal_types.find(ty) == self.equal_types.find(earlier) {
102+
continue;
103+
}
104+
// The correctness of is_structurally_equal relies on the fact
105+
// that resolve.types.iter() is in topological order.
106+
if self.is_structurally_equal(resolve, ty, earlier) {
107+
self.equal_types.union(ty, earlier);
108+
break;
109+
}
110+
}
111+
}
112+
}
96113

97114
fn type_info_func(&mut self, resolve: &Resolve, func: &Function, import: bool) {
98115
let mut live = LiveTypes::default();
@@ -233,4 +250,185 @@ impl Types {
233250
None => TypeInfo::default(),
234251
}
235252
}
253+
254+
fn is_structurally_equal(&mut self, resolve: &Resolve, a: TypeId, b: TypeId) -> bool {
255+
let a_def = &resolve.types[a].kind;
256+
let b_def = &resolve.types[b].kind;
257+
if self.equal_types.find(a) == self.equal_types.find(b) {
258+
return true;
259+
}
260+
match (a_def, b_def) {
261+
// Peel off typedef layers and continue recursing.
262+
(TypeDefKind::Type(a), _) => self.type_id_equal_to_type(resolve, b, a),
263+
(_, TypeDefKind::Type(b)) => self.type_id_equal_to_type(resolve, a, b),
264+
265+
(TypeDefKind::Record(ra), TypeDefKind::Record(rb)) => {
266+
ra.fields.len() == rb.fields.len()
267+
// Fields are ordered in WIT, so record {a: T, b: U} is different from {b: U, a: T}
268+
&& ra.fields.iter().zip(rb.fields.iter()).all(|(fa, fb)| {
269+
fa.name == fb.name && self.types_equal(resolve, &fa.ty, &fb.ty)
270+
})
271+
}
272+
(TypeDefKind::Record(_), _) => false,
273+
(TypeDefKind::Variant(va), TypeDefKind::Variant(vb)) => {
274+
va.cases.len() == vb.cases.len()
275+
&& va.cases.iter().zip(vb.cases.iter()).all(|(ca, cb)| {
276+
ca.name == cb.name && self.optional_types_equal(resolve, &ca.ty, &cb.ty)
277+
})
278+
}
279+
(TypeDefKind::Variant(_), _) => false,
280+
(TypeDefKind::Enum(ea), TypeDefKind::Enum(eb)) => {
281+
ea.cases.len() == eb.cases.len()
282+
&& ea
283+
.cases
284+
.iter()
285+
.zip(eb.cases.iter())
286+
.all(|(ca, cb)| ca.name == cb.name)
287+
}
288+
(TypeDefKind::Enum(_), _) => false,
289+
(TypeDefKind::Flags(fa), TypeDefKind::Flags(fb)) => {
290+
fa.flags.len() == fb.flags.len()
291+
&& fa
292+
.flags
293+
.iter()
294+
.zip(fb.flags.iter())
295+
.all(|(fa, fb)| fa.name == fb.name)
296+
}
297+
(TypeDefKind::Flags(_), _) => false,
298+
(TypeDefKind::Tuple(ta), TypeDefKind::Tuple(tb)) => {
299+
ta.types.len() == tb.types.len()
300+
&& ta
301+
.types
302+
.iter()
303+
.zip(tb.types.iter())
304+
.all(|(a, b)| self.types_equal(resolve, a, b))
305+
}
306+
(TypeDefKind::Tuple(_), _) => false,
307+
(TypeDefKind::List(la), TypeDefKind::List(lb)) => self.types_equal(resolve, la, lb),
308+
(TypeDefKind::List(_), _) => false,
309+
(TypeDefKind::FixedLengthList(ta, sa), TypeDefKind::FixedLengthList(tb, sb)) => {
310+
sa == sb && self.types_equal(resolve, ta, tb)
311+
}
312+
(TypeDefKind::FixedLengthList(..), _) => false,
313+
(TypeDefKind::Option(oa), TypeDefKind::Option(ob)) => self.types_equal(resolve, oa, ob),
314+
(TypeDefKind::Option(_), _) => false,
315+
(TypeDefKind::Result(ra), TypeDefKind::Result(rb)) => {
316+
self.optional_types_equal(resolve, &ra.ok, &rb.ok)
317+
&& self.optional_types_equal(resolve, &ra.err, &rb.err)
318+
}
319+
(TypeDefKind::Result(_), _) => false,
320+
(TypeDefKind::Map(ak, av), TypeDefKind::Map(bk, bv)) => {
321+
self.types_equal(resolve, ak, bk) && self.types_equal(resolve, av, bv)
322+
}
323+
(TypeDefKind::Map(..), _) => false,
324+
(TypeDefKind::Future(a), TypeDefKind::Future(b)) => {
325+
self.optional_types_equal(resolve, a, b)
326+
}
327+
(TypeDefKind::Future(..), _) => false,
328+
(TypeDefKind::Stream(a), TypeDefKind::Stream(b)) => {
329+
self.optional_types_equal(resolve, a, b)
330+
}
331+
(TypeDefKind::Stream(..), _) => false,
332+
(TypeDefKind::Handle(a), TypeDefKind::Handle(b)) => match (a, b) {
333+
(Handle::Own(a), Handle::Own(b)) | (Handle::Borrow(a), Handle::Borrow(b)) => {
334+
self.is_structurally_equal(resolve, *a, *b)
335+
}
336+
(Handle::Own(_) | Handle::Borrow(_), _) => false,
337+
},
338+
(TypeDefKind::Handle(_), _) => false,
339+
(TypeDefKind::Unknown, _) => unreachable!(),
340+
341+
// TODO: for now consider all resources not-equal to each other.
342+
// This is because the same type id can be used for both an imported
343+
// and exported resource where those should be distinct types.
344+
(TypeDefKind::Resource, _) => false,
345+
}
346+
}
347+
348+
fn types_equal(&mut self, resolve: &Resolve, a: &Type, b: &Type) -> bool {
349+
match (a, b) {
350+
// Peel off typedef layers and continue recursing.
351+
(Type::Id(a), b) => self.type_id_equal_to_type(resolve, *a, b),
352+
(a, Type::Id(b)) => self.type_id_equal_to_type(resolve, *b, a),
353+
354+
// When both a and b are primitives, they're only equal of
355+
// the primitives are the same.
356+
(
357+
Type::Bool
358+
| Type::U8
359+
| Type::S8
360+
| Type::U16
361+
| Type::S16
362+
| Type::U32
363+
| Type::S32
364+
| Type::U64
365+
| Type::S64
366+
| Type::F32
367+
| Type::F64
368+
| Type::Char
369+
| Type::String
370+
| Type::ErrorContext,
371+
_,
372+
) => a == b,
373+
}
374+
}
375+
376+
fn type_id_equal_to_type(&mut self, resolve: &Resolve, a: TypeId, b: &Type) -> bool {
377+
let ak = &resolve.types[a].kind;
378+
match (ak, b) {
379+
(TypeDefKind::Type(a), b) => self.types_equal(resolve, a, b),
380+
(_, Type::Id(b)) => self.is_structurally_equal(resolve, a, *b),
381+
382+
// Type `a` isn't a typedef, and type `b` is a primitive, so it's no
383+
// longer possible for them to be equal.
384+
_ => false,
385+
}
386+
}
387+
388+
fn optional_types_equal(
389+
&mut self,
390+
resolve: &Resolve,
391+
a: &Option<Type>,
392+
b: &Option<Type>,
393+
) -> bool {
394+
match (a, b) {
395+
(Some(a), Some(b)) => self.types_equal(resolve, a, b),
396+
(Some(_), None) | (None, Some(_)) => false,
397+
(None, None) => true,
398+
}
399+
}
400+
401+
pub fn get_representative_type(&mut self, id: TypeId) -> TypeId {
402+
self.equal_types.find(id)
403+
}
404+
}
405+
406+
#[derive(Default)]
407+
pub struct UnionFind {
408+
parent: HashMap<TypeId, TypeId>,
409+
}
410+
impl UnionFind {
411+
fn find(&mut self, id: TypeId) -> TypeId {
412+
// Path compression
413+
let parent = self.parent.get(&id).copied().unwrap_or(id);
414+
if parent != id {
415+
let root = self.find(parent);
416+
self.parent.insert(id, root);
417+
root
418+
} else {
419+
id
420+
}
421+
}
422+
fn union(&mut self, a: TypeId, b: TypeId) {
423+
let ra = self.find(a);
424+
let rb = self.find(b);
425+
if ra != rb {
426+
// Use smaller id as root for determinism
427+
if ra < rb {
428+
self.parent.insert(rb, ra);
429+
} else {
430+
self.parent.insert(ra, rb);
431+
}
432+
}
433+
}
236434
}

crates/rust/src/interface.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2577,6 +2577,16 @@ impl<'a> wit_bindgen_core::InterfaceGenerator<'a> for InterfaceGenerator<'a> {
25772577
self.print_typedef_record(id, record, docs);
25782578
}
25792579

2580+
fn define_type(&mut self, name: &str, id: TypeId) {
2581+
let equal = self.r#gen.types.get_representative_type(id);
2582+
if equal == id {
2583+
wit_bindgen_core::define_type(self, name, id)
2584+
} else {
2585+
let docs = &self.resolve.types[id].docs;
2586+
self.print_typedef_alias(id, &Type::Id(equal), &docs);
2587+
}
2588+
}
2589+
25802590
fn type_resource(&mut self, _id: TypeId, name: &str, docs: &Docs) {
25812591
self.rustdoc(docs);
25822592
let camel = to_upper_camel_case(name);

crates/rust/src/lib.rs

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,18 @@ pub struct Opts {
274274
#[cfg_attr(feature = "clap", clap(flatten))]
275275
#[cfg_attr(feature = "serde", serde(flatten))]
276276
pub async_: AsyncFilterSet,
277+
278+
/// Find all structurally equal types and only generate one type definition for
279+
/// each equivalence class. Other types in the same class will be type aliases to the
280+
/// generated type. This avoids clone when converting between types that are
281+
/// structurally equal, which is useful when import and export the same interface.
282+
///
283+
/// Types containing resource, future, or stream are never considered equal.
284+
#[cfg_attr(
285+
feature = "clap",
286+
arg(long, require_equals = true, value_name = "true|false")
287+
)]
288+
pub merge_structurally_equal_types: Option<Option<bool>>,
277289
}
278290

279291
impl Opts {
@@ -283,6 +295,18 @@ impl Opts {
283295
r.opts = self;
284296
r
285297
}
298+
299+
fn merge_structurally_equal_types(&self) -> bool {
300+
const DEFAULT: bool = false;
301+
match self.merge_structurally_equal_types {
302+
// no option passed, use the default
303+
None => DEFAULT,
304+
// --merge-structurally-equal-types
305+
Some(None) => true,
306+
// --merge-structurally-equal-types=val
307+
Some(Some(val)) => val,
308+
}
309+
}
286310
}
287311

288312
impl RustWasm {
@@ -1082,6 +1106,9 @@ impl WorldGenerator for RustWasm {
10821106
"// * disable-run-ctors-once-workaround"
10831107
);
10841108
}
1109+
if self.opts.merge_structurally_equal_types() {
1110+
uwriteln!(self.src_preamble, "// * merge_structurally_equal_types");
1111+
}
10851112
if let Some(s) = &self.opts.export_macro_name {
10861113
uwriteln!(self.src_preamble, "// * export-macro-name: {s}");
10871114
}
@@ -1101,6 +1128,9 @@ impl WorldGenerator for RustWasm {
11011128
uwriteln!(self.src_preamble, "// * async: {opt}");
11021129
}
11031130
self.types.analyze(resolve);
1131+
if self.opts.merge_structurally_equal_types() {
1132+
self.types.collect_equal_types(resolve);
1133+
}
11041134
self.world = Some(world);
11051135

11061136
let world = &resolve.worlds[world];
@@ -1206,9 +1236,9 @@ impl WorldGenerator for RustWasm {
12061236
_files: &mut Files,
12071237
) -> Result<()> {
12081238
let mut to_define = Vec::new();
1209-
for (name, ty_id) in resolve.interfaces[id].types.iter() {
1239+
for (ty_name, ty_id) in resolve.interfaces[id].types.iter() {
12101240
let full_name = full_wit_type_name(resolve, *ty_id);
1211-
to_define.push((name, ty_id));
1241+
to_define.push((ty_name, ty_id));
12121242
self.generated_types.insert(full_name);
12131243
}
12141244

@@ -1224,8 +1254,8 @@ impl WorldGenerator for RustWasm {
12241254
return Ok(());
12251255
}
12261256

1227-
for (name, ty_id) in to_define {
1228-
r#gen.define_type(&name, *ty_id);
1257+
for (ty_name, ty_id) in to_define {
1258+
r#gen.define_type(&ty_name, *ty_id);
12291259
}
12301260

12311261
let macro_name =
@@ -1427,7 +1457,11 @@ impl WorldGenerator for RustWasm {
14271457
}
14281458
}
14291459

1430-
fn compute_module_path(name: &WorldKey, resolve: &Resolve, is_export: bool) -> Vec<String> {
1460+
pub(crate) fn compute_module_path(
1461+
name: &WorldKey,
1462+
resolve: &Resolve,
1463+
is_export: bool,
1464+
) -> Vec<String> {
14311465
let mut path = Vec::new();
14321466
if is_export {
14331467
path.push("exports".to_string());
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package example:composition;
2+
3+
let host = new test:host { ... };
4+
let proxy = new test:proxy { ...host, ... };
5+
let runner = new test:runner { ...proxy, ... };
6+
export runner...;

0 commit comments

Comments
 (0)