Skip to content

Commit 1241177

Browse files
committed
Support enums with lifetime params
1 parent 855ff0a commit 1241177

4 files changed

Lines changed: 76 additions & 14 deletions

File tree

src/analyze/crate_.rs

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,29 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
251251
continue;
252252
};
253253
let adt = self.tcx.adt_def(local_def_id);
254+
255+
// The index of TyKind::ParamTy is based on the every generic parameters in
256+
// the definition, including lifetimes. Given the following definition:
257+
//
258+
// struct X<'a, T> { f: &'a T }
259+
//
260+
// The type of field `f` is &T1 (not T0). However, in Thrust, we ignore lifetime
261+
// parameters and the index of rty::ParamType is based on type parameters only.
262+
// We're building a mapping from the original index to the new index here.
263+
let generics = self.tcx.generics_of(local_def_id);
264+
let mut type_param_mapping: std::collections::HashMap<usize, usize> =
265+
Default::default();
266+
for i in 0..generics.count() {
267+
let generic_param = generics.param_at(i, self.tcx);
268+
match generic_param.kind {
269+
mir_ty::GenericParamDefKind::Lifetime => {}
270+
mir_ty::GenericParamDefKind::Type { .. } => {
271+
type_param_mapping.insert(i, type_param_mapping.len());
272+
}
273+
mir_ty::GenericParamDefKind::Const { .. } => unimplemented!(),
274+
}
275+
}
276+
254277
let name = refine::datatype_symbol(self.tcx, local_def_id.to_def_id());
255278
let variants: IndexVec<_, _> = adt
256279
.variants()
@@ -264,7 +287,26 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
264287
.iter()
265288
.map(|field| {
266289
let field_ty = self.tcx.type_of(field.did).instantiate_identity();
267-
self.ctx.unrefined_ty(field_ty)
290+
291+
// see the comment above about this mapping
292+
let subst = rty::TypeParamSubst::new(
293+
type_param_mapping
294+
.iter()
295+
.map(|(old, new)| {
296+
let old = rty::TypeParamIdx::from(*old);
297+
let new =
298+
rty::ParamType::new(rty::TypeParamIdx::from(*new));
299+
(old, rty::RefinedType::unrefined(new.into()))
300+
})
301+
.collect(),
302+
);
303+
304+
// the subst doesn't contain refinements, so it's OK to take ty only
305+
// after substitution
306+
let mut field_rty =
307+
rty::RefinedType::unrefined(self.ctx.unrefined_ty(field_ty));
308+
field_rty.subst_ty_params(&subst);
309+
field_rty.ty
268310
})
269311
.collect();
270312
rty::EnumVariantDef {
@@ -275,19 +317,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
275317
})
276318
.collect();
277319

278-
let ty_params = adt
279-
.all_fields()
280-
.map(|f| self.tcx.type_of(f.did).instantiate_identity())
281-
.flat_map(|ty| {
282-
if let mir_ty::TyKind::Param(p) = ty.kind() {
283-
Some(p.index as usize)
284-
} else {
285-
None
286-
}
287-
})
288-
.max()
289-
.map(|max| max + 1)
290-
.unwrap_or(0);
320+
let ty_params = type_param_mapping.len();
291321
tracing::debug!(?local_def_id, ?name, ?ty_params, "ty_params count");
292322

293323
let def = rty::EnumDatatypeDef {

src/rty/params.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ impl<T> std::ops::Index<TypeParamIdx> for TypeParamSubst<T> {
7171
}
7272

7373
impl<T> TypeParamSubst<T> {
74+
pub fn new(subst: BTreeMap<TypeParamIdx, RefinedType<T>>) -> Self {
75+
Self { subst }
76+
}
77+
7478
pub fn singleton(idx: TypeParamIdx, ty: RefinedType<T>) -> Self {
7579
let mut subst = BTreeMap::default();
7680
subst.insert(idx, ty);

tests/ui/fail/adt_poly_ref.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//@error-in-other-file: Unsat
2+
//@compile-flags: -C debug-assertions=off
3+
4+
enum X<'a, T> {
5+
A(&'a T),
6+
}
7+
8+
fn main() {
9+
let i = 42;
10+
let x = X::A(&i);
11+
match x {
12+
X::A(i) => assert!(*i == 41),
13+
}
14+
}

tests/ui/pass/adt_poly_ref.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//@check-pass
2+
//@compile-flags: -C debug-assertions=off
3+
4+
enum X<'a, T> {
5+
A(&'a T),
6+
}
7+
8+
fn main() {
9+
let i = 42;
10+
let x = X::A(&i);
11+
match x {
12+
X::A(i) => assert!(*i == 42),
13+
}
14+
}

0 commit comments

Comments
 (0)