Skip to content

Commit eacc8f6

Browse files
committed
Extend capabilities of TypeFoldable_Generic
1 parent d56483a commit eacc8f6

1 file changed

Lines changed: 92 additions & 17 deletions

File tree

  • compiler/rustc_type_ir_macros/src

compiler/rustc_type_ir_macros/src/lib.rs

Lines changed: 92 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,24 @@ decl_derive!(
1717
[GenericTypeVisitable] => customizable_type_visitable_derive
1818
);
1919

20-
struct LiftedTy {
20+
struct TransformedTy {
2121
ty: syn::Type,
2222
generic_parameter_bounds: Vec<syn::Ident>,
2323
}
2424

25+
enum TypeParameterPath {
26+
Interner,
27+
GenericParameter(syn::Ident),
28+
}
29+
30+
enum TypeParameterTransform {
31+
Continue,
32+
Stop,
33+
}
34+
35+
type TypeParameterVisitor =
36+
fn(TypeParameterPath, &mut syn::TypePath, &mut Vec<syn::Ident>) -> TypeParameterTransform;
37+
2538
fn has_ignore_attr(attrs: &[Attribute], name: &'static str, meta: &'static str) -> bool {
2639
let mut ignored = false;
2740
attrs.iter().for_each(|attr| {
@@ -91,6 +104,9 @@ fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::Toke
91104

92105
s.add_where_predicate(parse_quote! { I: Interner });
93106
s.add_bounds(synstructure::AddBounds::Fields);
107+
let generic_parameters =
108+
s.ast().generics.type_params().map(|ty| ty.ident.clone()).collect::<Vec<_>>();
109+
let mut generic_parameter_bounds = vec![];
94110
s.bind_with(|_| synstructure::BindStyle::Move);
95111
let body_try_fold = s.each_variant(|vi| {
96112
let bindings = vi.bindings();
@@ -101,6 +117,12 @@ fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::Toke
101117
if has_ignore_attr(&bind.ast().attrs, "type_foldable", "identity") {
102118
bind.to_token_stream()
103119
} else {
120+
for param in
121+
type_foldable_generic_parameters(bind.ast().ty.clone(), &generic_parameters)
122+
{
123+
push_unique(&mut generic_parameter_bounds, param);
124+
}
125+
104126
quote! {
105127
::rustc_type_ir::TypeFoldable::try_fold_with(#bind, __folder)?
106128
}
@@ -129,6 +151,9 @@ fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::Toke
129151
// to generate code for them.
130152
s.filter(|bi| !has_ignore_attr(&bi.ast().attrs, "type_foldable", "identity"));
131153
s.add_bounds(synstructure::AddBounds::Fields);
154+
for param in generic_parameter_bounds {
155+
s.add_where_predicate(parse_quote! { #param: ::rustc_type_ir::TypeFoldable<I> });
156+
}
132157
s.bound_impl(
133158
quote!(::rustc_type_ir::TypeFoldable<I>),
134159
quote! {
@@ -149,6 +174,21 @@ fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::Toke
149174
)
150175
}
151176

177+
fn type_foldable_generic_parameters(
178+
ty: syn::Type,
179+
generic_parameters: &[syn::Ident],
180+
) -> Vec<syn::Ident> {
181+
transform_type_parameters(ty, generic_parameters, |path, _, generic_parameter_bounds| {
182+
if let TypeParameterPath::GenericParameter(param) = path {
183+
push_unique(generic_parameter_bounds, param);
184+
TypeParameterTransform::Stop
185+
} else {
186+
TypeParameterTransform::Continue
187+
}
188+
})
189+
.generic_parameter_bounds
190+
}
191+
152192
/// `Lift_Generic` is specialised for structs/enums parameterised by an interner
153193
/// `I: Interner`. It derives `Lift<J>` by rewriting interner associated types
154194
/// from `I::Assoc` to `J::Assoc`. The required associated type lift bounds are
@@ -251,40 +291,75 @@ fn is_type_phantom(ty: &syn::Type) -> bool {
251291
get_first_path_segment(ty).is_some_and(|segment| segment.ident == "PhantomData")
252292
}
253293

254-
fn lift(mut ty: syn::Type, generic_parameters: &[syn::Ident]) -> LiftedTy {
255-
struct ItoJ<'a> {
294+
fn lift(ty: syn::Type, generic_parameters: &[syn::Ident]) -> TransformedTy {
295+
transform_type_parameters(ty, generic_parameters, |path, ty, generic_parameter_bounds| {
296+
match path {
297+
TypeParameterPath::Interner => {
298+
*ty.path.segments.first_mut().unwrap() = parse_quote! { J };
299+
TypeParameterTransform::Continue
300+
}
301+
TypeParameterPath::GenericParameter(param) => {
302+
push_unique(generic_parameter_bounds, param.clone());
303+
*ty = parse_quote! { <#param as ::rustc_type_ir::lift::Lift<J>>::Lifted };
304+
TypeParameterTransform::Stop
305+
}
306+
}
307+
})
308+
}
309+
310+
fn transform_type_parameters(
311+
mut ty: syn::Type,
312+
generic_parameters: &[syn::Ident],
313+
visit: TypeParameterVisitor,
314+
) -> TransformedTy {
315+
struct TypeParameterTransformer<'a> {
256316
generic_parameters: &'a [syn::Ident],
257317
generic_parameter_bounds: Vec<syn::Ident>,
318+
visit: TypeParameterVisitor,
258319
}
259320

260-
impl VisitMut for ItoJ<'_> {
321+
impl VisitMut for TypeParameterTransformer<'_> {
261322
fn visit_type_path_mut(&mut self, i: &mut syn::TypePath) {
262-
if i.qself.is_none() {
323+
let path = if i.qself.is_none() {
263324
let segments_len = i.path.segments.len();
264-
if let Some(first) = i.path.segments.first_mut() {
265-
// Turn paths from `I` into `J`
325+
i.path.segments.first().and_then(|first| {
266326
if first.ident == "I" {
267-
*first = parse_quote! { J };
327+
Some(TypeParameterPath::Interner)
268328
} else if segments_len == 1
269329
&& matches!(first.arguments, syn::PathArguments::None)
270330
&& self.generic_parameters.iter().any(|param| first.ident == *param)
271331
{
272-
let ident = first.ident.clone();
273-
if !self.generic_parameter_bounds.iter().any(|param| *param == ident) {
274-
self.generic_parameter_bounds.push(ident.clone());
275-
}
276-
277-
*i = parse_quote! { <#ident as ::rustc_type_ir::lift::Lift<J>>::Lifted };
278-
return;
332+
Some(TypeParameterPath::GenericParameter(first.ident.clone()))
333+
} else {
334+
None
279335
}
336+
})
337+
} else {
338+
None
339+
};
340+
341+
if let Some(path) = path {
342+
if let TypeParameterTransform::Stop =
343+
(self.visit)(path, i, &mut self.generic_parameter_bounds)
344+
{
345+
return;
280346
}
281347
}
348+
282349
syn::visit_mut::visit_type_path_mut(self, i);
283350
}
284351
}
285-
let mut visitor = ItoJ { generic_parameters, generic_parameter_bounds: Vec::new() };
352+
353+
let mut visitor =
354+
TypeParameterTransformer { generic_parameters, generic_parameter_bounds: vec![], visit };
286355
visitor.visit_type_mut(&mut ty);
287-
LiftedTy { ty, generic_parameter_bounds: visitor.generic_parameter_bounds }
356+
TransformedTy { ty, generic_parameter_bounds: visitor.generic_parameter_bounds }
357+
}
358+
359+
fn push_unique(params: &mut Vec<syn::Ident>, param: syn::Ident) {
360+
if !params.iter().any(|prev| *prev == param) {
361+
params.push(param);
362+
}
288363
}
289364

290365
#[cfg(not(feature = "nightly"))]

0 commit comments

Comments
 (0)