@@ -3,6 +3,7 @@ use std::collections::HashMap;
33use rustc_index:: IndexVec ;
44use rustc_middle:: mir:: { Local , Mutability } ;
55use rustc_middle:: ty as mir_ty;
6+ use rustc_span:: def_id:: DefId ;
67
78use super :: basic_block:: BasicBlockType ;
89use crate :: chc;
@@ -60,11 +61,43 @@ where
6061#[ derive( Clone ) ]
6162pub struct TypeBuilder < ' tcx > {
6263 tcx : mir_ty:: TyCtxt < ' tcx > ,
64+ type_param_mapping : HashMap < u32 , rty:: TypeParamIdx > ,
6365}
6466
6567impl < ' tcx > TypeBuilder < ' tcx > {
66- pub fn new ( tcx : mir_ty:: TyCtxt < ' tcx > ) -> Self {
67- Self { tcx }
68+ pub fn new ( tcx : mir_ty:: TyCtxt < ' tcx > , def_id : DefId ) -> Self {
69+ // The index of TyKind::ParamTy is based on the every generic parameters in
70+ // the definition, including lifetimes. Given the following definition:
71+ //
72+ // struct X<'a, T> { f: &'a T }
73+ //
74+ // The type of field `f` is &T1 (not T0). However, in Thrust, we ignore lifetime
75+ // parameters and the index of rty::ParamType is based on type parameters only.
76+ // We're building a mapping from the original index to the new index here.
77+ let generics = tcx. generics_of ( def_id) ;
78+ let mut type_param_mapping: HashMap < u32 , rty:: TypeParamIdx > = Default :: default ( ) ;
79+ for i in 0 ..generics. count ( ) {
80+ let generic_param = generics. param_at ( i, tcx) ;
81+ match generic_param. kind {
82+ mir_ty:: GenericParamDefKind :: Lifetime => { }
83+ mir_ty:: GenericParamDefKind :: Type { .. } => {
84+ type_param_mapping. insert ( i as u32 , type_param_mapping. len ( ) . into ( ) ) ;
85+ }
86+ mir_ty:: GenericParamDefKind :: Const { .. } => unimplemented ! ( ) ,
87+ }
88+ }
89+ Self {
90+ tcx,
91+ type_param_mapping,
92+ }
93+ }
94+
95+ fn translate_param_type ( & self , ty : & mir_ty:: ParamTy ) -> rty:: ParamType {
96+ let index = * self
97+ . type_param_mapping
98+ . get ( & ty. index )
99+ . expect ( "unknown type param idx" ) ;
100+ rty:: ParamType :: new ( index)
68101 }
69102
70103 // TODO: consolidate two impls
@@ -89,7 +122,7 @@ impl<'tcx> TypeBuilder<'tcx> {
89122 rty:: TupleType :: new ( elems) . into ( )
90123 }
91124 mir_ty:: TyKind :: Never => rty:: Type :: never ( ) ,
92- mir_ty:: TyKind :: Param ( ty) => rty :: ParamType :: new ( ty . index . into ( ) ) . into ( ) ,
125+ mir_ty:: TyKind :: Param ( ty) => self . translate_param_type ( ty ) . into ( ) ,
93126 mir_ty:: TyKind :: FnPtr ( sig) => {
94127 // TODO: justification for skip_binder
95128 let sig = sig. skip_binder ( ) ;
@@ -135,7 +168,7 @@ impl<'tcx> TypeBuilder<'tcx> {
135168 registry : & ' a mut R ,
136169 ) -> TemplateTypeBuilder < ' tcx , ' a , R , EmptyTemplateScope > {
137170 TemplateTypeBuilder {
138- tcx : self . tcx ,
171+ inner : self . clone ( ) ,
139172 registry,
140173 scope : Default :: default ( ) ,
141174 }
@@ -147,7 +180,7 @@ impl<'tcx> TypeBuilder<'tcx> {
147180 sig : mir_ty:: FnSig < ' tcx > ,
148181 ) -> FunctionTemplateTypeBuilder < ' tcx , ' a , R > {
149182 FunctionTemplateTypeBuilder {
150- tcx : self . tcx ,
183+ inner : self . clone ( ) ,
151184 registry,
152185 param_tys : sig
153186 . inputs ( )
@@ -166,15 +199,15 @@ impl<'tcx> TypeBuilder<'tcx> {
166199}
167200
168201pub struct TemplateTypeBuilder < ' tcx , ' a , R , S > {
169- tcx : mir_ty :: TyCtxt < ' tcx > ,
202+ inner : TypeBuilder < ' tcx > ,
170203 registry : & ' a mut R ,
171204 scope : S ,
172205}
173206
174207impl < ' tcx , ' a , R , S > TemplateTypeBuilder < ' tcx , ' a , R , S > {
175208 pub fn with_scope < T > ( self , scope : T ) -> TemplateTypeBuilder < ' tcx , ' a , R , T > {
176209 TemplateTypeBuilder {
177- tcx : self . tcx ,
210+ inner : self . inner ,
178211 registry : self . registry ,
179212 scope,
180213 }
@@ -207,29 +240,27 @@ where
207240 rty:: TupleType :: new ( elems) . into ( )
208241 }
209242 mir_ty:: TyKind :: Never => rty:: Type :: never ( ) ,
210- mir_ty:: TyKind :: Param ( ty) => rty :: ParamType :: new ( ty . index . into ( ) ) . into ( ) ,
243+ mir_ty:: TyKind :: Param ( ty) => self . inner . translate_param_type ( ty ) . into ( ) ,
211244 mir_ty:: TyKind :: FnPtr ( sig) => {
212245 // TODO: justification for skip_binder
213246 let sig = sig. skip_binder ( ) ;
214- let ty = TypeBuilder :: new ( self . tcx )
215- . for_function_template ( self . registry , sig)
216- . build ( ) ;
247+ let ty = self . inner . for_function_template ( self . registry , sig) . build ( ) ;
217248 rty:: Type :: function ( ty)
218249 }
219250 mir_ty:: TyKind :: Adt ( def, params) if def. is_box ( ) => {
220251 rty:: PointerType :: own ( self . build ( params. type_at ( 0 ) ) ) . into ( )
221252 }
222253 mir_ty:: TyKind :: Adt ( def, params) => {
223254 if def. is_enum ( ) {
224- let sym = refine:: datatype_symbol ( self . tcx , def. did ( ) ) ;
255+ let sym = refine:: datatype_symbol ( self . inner . tcx , def. did ( ) ) ;
225256 let args: IndexVec < _ , _ > =
226257 params. types ( ) . map ( |ty| self . build_refined ( ty) ) . collect ( ) ;
227258 rty:: EnumType :: new ( sym, args) . into ( )
228259 } else if def. is_struct ( ) {
229260 let elem_tys = def
230261 . all_fields ( )
231262 . map ( |field| {
232- let ty = field. ty ( self . tcx , params) ;
263+ let ty = field. ty ( self . inner . tcx , params) ;
233264 // elaboration: all fields are boxed
234265 rty:: PointerType :: own ( self . build ( ty) ) . into ( )
235266 } )
@@ -245,10 +276,7 @@ where
245276
246277 pub fn build_refined ( & mut self , ty : mir_ty:: Ty < ' tcx > ) -> rty:: RefinedType < S :: Var > {
247278 // TODO: consider building ty with scope
248- let ty = TypeBuilder :: new ( self . tcx )
249- . for_template ( self . registry )
250- . build ( ty)
251- . vacuous ( ) ;
279+ let ty = self . inner . for_template ( self . registry ) . build ( ty) . vacuous ( ) ;
252280 let tmpl = self . scope . build_template ( ) . build ( ty) ;
253281 self . registry . register_template ( tmpl)
254282 }
@@ -269,7 +297,7 @@ where
269297 tys. push ( ty) ;
270298 }
271299 let ty = FunctionTemplateTypeBuilder {
272- tcx : self . tcx ,
300+ inner : self . inner . clone ( ) ,
273301 registry : self . registry ,
274302 param_tys : tys,
275303 ret_ty,
@@ -283,7 +311,7 @@ where
283311}
284312
285313pub struct FunctionTemplateTypeBuilder < ' tcx , ' a , R > {
286- tcx : mir_ty :: TyCtxt < ' tcx > ,
314+ inner : TypeBuilder < ' tcx > ,
287315 registry : & ' a mut R ,
288316 param_tys : Vec < mir_ty:: TypeAndMut < ' tcx > > ,
289317 ret_ty : mir_ty:: Ty < ' tcx > ,
@@ -324,7 +352,7 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> {
324352 & mut self ,
325353 refinement : rty:: Refinement < rty:: FunctionParamIdx > ,
326354 ) -> & mut Self {
327- let ty = TypeBuilder :: new ( self . tcx ) . build ( self . ret_ty ) ;
355+ let ty = self . inner . build ( self . ret_ty ) ;
328356 self . ret_rty = Some ( rty:: RefinedType :: new ( ty. vacuous ( ) , refinement) ) ;
329357 self
330358 }
@@ -350,17 +378,17 @@ where
350378 . unwrap_or_else ( || {
351379 if idx == self . param_tys . len ( ) - 1 {
352380 if let Some ( param_refinement) = & self . param_refinement {
353- let ty = TypeBuilder :: new ( self . tcx ) . build ( param_ty. ty ) ;
381+ let ty = self . inner . build ( param_ty. ty ) ;
354382 rty:: RefinedType :: new ( ty. vacuous ( ) , param_refinement. clone ( ) )
355383 } else {
356- TypeBuilder :: new ( self . tcx )
384+ self . inner
357385 . for_template ( self . registry )
358386 . with_scope ( & builder)
359387 . build_refined ( param_ty. ty )
360388 }
361389 } else {
362390 rty:: RefinedType :: unrefined (
363- TypeBuilder :: new ( self . tcx )
391+ self . inner
364392 . for_template ( self . registry )
365393 . with_scope ( & builder)
366394 . build ( param_ty. ty ) ,
@@ -383,8 +411,8 @@ where
383411 let param_rty = if let Some ( param_refinement) = & self . param_refinement {
384412 rty:: RefinedType :: new ( rty:: Type :: unit ( ) , param_refinement. clone ( ) )
385413 } else {
386- let unit_ty = mir_ty:: Ty :: new_unit ( self . tcx ) ;
387- TypeBuilder :: new ( self . tcx )
414+ let unit_ty = mir_ty:: Ty :: new_unit ( self . inner . tcx ) ;
415+ self . inner
388416 . for_template ( self . registry )
389417 . with_scope ( & builder)
390418 . build_refined ( unit_ty)
@@ -393,7 +421,7 @@ where
393421 }
394422
395423 let ret_rty = self . ret_rty . clone ( ) . unwrap_or_else ( || {
396- TypeBuilder :: new ( self . tcx )
424+ self . inner
397425 . for_template ( self . registry )
398426 . with_scope ( & builder)
399427 . build_refined ( self . ret_ty )
0 commit comments