Skip to content

Commit fc1b85c

Browse files
committed
Handle parameter shifting in TypeBuilder
1 parent 3f431a8 commit fc1b85c

4 files changed

Lines changed: 82 additions & 81 deletions

File tree

src/analyze/basic_block.rs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
5656
self.ctx.basic_block_ty(self.local_def_id, bb)
5757
}
5858

59+
fn type_builder(&self) -> TypeBuilder<'tcx> {
60+
TypeBuilder::new(self.tcx, self.local_def_id.to_def_id())
61+
}
62+
5963
fn bind_local(&mut self, local: Local, rty: rty::RefinedType<Var>) {
6064
let rty = if self.is_mut_local(local) {
6165
// elaboration:
@@ -222,7 +226,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
222226
let rty_args: IndexVec<_, _> = args
223227
.types()
224228
.map(|ty| {
225-
TypeBuilder::new(self.tcx)
229+
self.type_builder()
226230
.for_template(&mut self.ctx)
227231
.with_scope(&self.env)
228232
.build_refined(ty)
@@ -435,7 +439,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
435439
let func_ty = match func.const_fn_def() {
436440
// TODO: move this to well-known defs?
437441
Some((def_id, args)) if self.is_box_new(def_id) => {
438-
let inner_ty = TypeBuilder::new(self.tcx)
442+
let inner_ty = self
443+
.type_builder()
439444
.for_template(&mut self.ctx)
440445
.build(args.type_at(0))
441446
.vacuous();
@@ -449,7 +454,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
449454
rty::FunctionType::new([param].into_iter().collect(), ret).into()
450455
}
451456
Some((def_id, args)) if self.is_mem_swap(def_id) => {
452-
let inner_ty = TypeBuilder::new(self.tcx).build(args.type_at(0)).vacuous();
457+
let inner_ty = self.type_builder().build(args.type_at(0)).vacuous();
453458
let param1 =
454459
rty::RefinedType::unrefined(rty::PointerType::mut_to(inner_ty.clone()).into());
455460
let param2 =
@@ -536,7 +541,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
536541
}
537542

538543
fn add_prophecy_var(&mut self, statement_index: usize, ty: mir_ty::Ty<'tcx>) {
539-
let ty = TypeBuilder::new(self.tcx).build(ty);
544+
let ty = self.type_builder().build(ty);
540545
let temp_var = self.env.push_temp_var(ty.vacuous());
541546
self.prophecy_vars.insert(statement_index, temp_var);
542547
tracing::debug!(stmt_idx = %statement_index, temp_var = ?temp_var, "add_prophecy_var");
@@ -557,7 +562,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
557562
referent: mir::Place<'tcx>,
558563
prophecy_ty: mir_ty::Ty<'tcx>,
559564
) -> rty::RefinedType<Var> {
560-
let prophecy_ty = TypeBuilder::new(self.tcx).build(prophecy_ty);
565+
let prophecy_ty = self.type_builder().build(prophecy_ty);
561566
let prophecy = self.env.push_temp_var(prophecy_ty.vacuous());
562567
let place = self.elaborate_place_for_borrow(&referent);
563568
self.env.borrow_place(place, prophecy).into()
@@ -669,7 +674,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
669674
}
670675

671676
let decl = self.local_decls[destination].clone();
672-
let rty = TypeBuilder::new(self.tcx)
677+
let rty = self
678+
.type_builder()
673679
.for_template(&mut self.ctx)
674680
.with_scope(&self.env)
675681
.build_refined(decl.ty);
@@ -743,7 +749,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
743749
#[tracing::instrument(skip(self))]
744750
fn ret_template(&mut self) -> rty::RefinedType<Var> {
745751
let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty;
746-
TypeBuilder::new(self.tcx)
752+
self.type_builder()
747753
.for_template(&mut self.ctx)
748754
.with_scope(&self.env)
749755
.build_refined(ret_ty)

src/analyze/crate_.rs

Lines changed: 14 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
132132

133133
let mut param_resolver = analyze::annot::ParamResolver::default();
134134
for (input_ident, input_ty) in self.tcx.fn_arg_names(def_id).iter().zip(sig.inputs()) {
135-
let input_ty = TypeBuilder::new(self.tcx).build(*input_ty);
135+
let input_ty = TypeBuilder::new(self.tcx, def_id).build(*input_ty);
136136
param_resolver.push_param(input_ident.name, input_ty.to_sort());
137137
}
138138

139139
let mut require_annot = self.extract_require_annot(&param_resolver, def_id);
140140
let mut ensure_annot = {
141-
let output_ty = TypeBuilder::new(self.tcx).build(sig.output());
141+
let output_ty = TypeBuilder::new(self.tcx, def_id).build(sig.output());
142142
let resolver = annot::StackedResolver::default()
143143
.resolver(analyze::annot::ResultResolver::new(output_ty.to_sort()))
144144
.resolver((&param_resolver).map(rty::RefinedTypeVar::Free));
@@ -175,7 +175,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
175175
self.trusted.insert(def_id);
176176
}
177177

178-
let mut builder = TypeBuilder::new(self.tcx).for_function_template(&mut self.ctx, sig);
178+
let mut builder =
179+
TypeBuilder::new(self.tcx, def_id).for_function_template(&mut self.ctx, sig);
179180
if let Some(AnnotFormula::Formula(require)) = require_annot {
180181
let formula = require.map_var(|idx| {
181182
if idx.index() == sig.inputs().len() - 1 {
@@ -252,28 +253,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
252253
};
253254
let adt = self.tcx.adt_def(local_def_id);
254255

255-
// The index of TyKind::ParamTy is based on the every generic parameters in
256-
// the definition, including lifetimes. Given the following definition:
257-
//
258-
// struct X<'a, T> { f: &'a T }
259-
//
260-
// The type of field `f` is &T1 (not T0). However, in Thrust, we ignore lifetime
261-
// parameters and the index of rty::ParamType is based on type parameters only.
262-
// We're building a mapping from the original index to the new index here.
263-
let generics = self.tcx.generics_of(local_def_id);
264-
let mut type_param_mapping: std::collections::HashMap<usize, usize> =
265-
Default::default();
266-
for i in 0..generics.count() {
267-
let generic_param = generics.param_at(i, self.tcx);
268-
match generic_param.kind {
269-
mir_ty::GenericParamDefKind::Lifetime => {}
270-
mir_ty::GenericParamDefKind::Type { .. } => {
271-
type_param_mapping.insert(i, type_param_mapping.len());
272-
}
273-
mir_ty::GenericParamDefKind::Const { .. } => unimplemented!(),
274-
}
275-
}
276-
277256
let name = refine::datatype_symbol(self.tcx, local_def_id.to_def_id());
278257
let variants: IndexVec<_, _> = adt
279258
.variants()
@@ -287,27 +266,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
287266
.iter()
288267
.map(|field| {
289268
let field_ty = self.tcx.type_of(field.did).instantiate_identity();
290-
291-
// see the comment above about this mapping
292-
let subst = rty::TypeParamSubst::new(
293-
type_param_mapping
294-
.iter()
295-
.map(|(old, new)| {
296-
let old = rty::TypeParamIdx::from(*old);
297-
let new =
298-
rty::ParamType::new(rty::TypeParamIdx::from(*new));
299-
(old, rty::RefinedType::unrefined(new.into()))
300-
})
301-
.collect(),
302-
);
303-
304-
// the subst doesn't contain refinements, so it's OK to take ty only
305-
// after substitution
306-
let mut field_rty = rty::RefinedType::unrefined(
307-
TypeBuilder::new(self.tcx).build(field_ty),
308-
);
309-
field_rty.subst_ty_params(&subst);
310-
field_rty.ty
269+
TypeBuilder::new(self.tcx, local_def_id.to_def_id()).build(field_ty)
311270
})
312271
.collect();
313272
rty::EnumVariantDef {
@@ -318,7 +277,15 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
318277
})
319278
.collect();
320279

321-
let ty_params = type_param_mapping.len();
280+
let generics = self.tcx.generics_of(local_def_id);
281+
let ty_params = (0..generics.count())
282+
.filter(|idx| {
283+
matches!(
284+
generics.param_at(*idx, self.tcx).kind,
285+
mir_ty::GenericParamDefKind::Type { .. }
286+
)
287+
})
288+
.count();
322289
tracing::debug!(?local_def_id, ?name, ?ty_params, "ty_params count");
323290

324291
let def = rty::EnumDatatypeDef {

src/analyze/local_def.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
306306
}
307307
// function return type is basic block return type
308308
let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty;
309-
let rty = TypeBuilder::new(self.tcx)
309+
let rty = TypeBuilder::new(self.tcx, self.local_def_id.to_def_id())
310310
.for_template(&mut self.ctx)
311311
.build_basic_block(live_locals, ret_ty);
312312
self.ctx.register_basic_block_ty(self.local_def_id, bb, rty);

src/refine/template.rs

Lines changed: 54 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::collections::HashMap;
33
use rustc_index::IndexVec;
44
use rustc_middle::mir::{Local, Mutability};
55
use rustc_middle::ty as mir_ty;
6+
use rustc_span::def_id::DefId;
67

78
use super::basic_block::BasicBlockType;
89
use crate::chc;
@@ -60,11 +61,43 @@ where
6061
#[derive(Clone)]
6162
pub struct TypeBuilder<'tcx> {
6263
tcx: mir_ty::TyCtxt<'tcx>,
64+
type_param_mapping: HashMap<u32, rty::TypeParamIdx>,
6365
}
6466

6567
impl<'tcx> TypeBuilder<'tcx> {
66-
pub fn new(tcx: mir_ty::TyCtxt<'tcx>) -> Self {
67-
Self { tcx }
68+
pub fn new(tcx: mir_ty::TyCtxt<'tcx>, def_id: DefId) -> Self {
69+
// The index of TyKind::ParamTy is based on the every generic parameters in
70+
// the definition, including lifetimes. Given the following definition:
71+
//
72+
// struct X<'a, T> { f: &'a T }
73+
//
74+
// The type of field `f` is &T1 (not T0). However, in Thrust, we ignore lifetime
75+
// parameters and the index of rty::ParamType is based on type parameters only.
76+
// We're building a mapping from the original index to the new index here.
77+
let generics = tcx.generics_of(def_id);
78+
let mut type_param_mapping: HashMap<u32, rty::TypeParamIdx> = Default::default();
79+
for i in 0..generics.count() {
80+
let generic_param = generics.param_at(i, tcx);
81+
match generic_param.kind {
82+
mir_ty::GenericParamDefKind::Lifetime => {}
83+
mir_ty::GenericParamDefKind::Type { .. } => {
84+
type_param_mapping.insert(i as u32, type_param_mapping.len().into());
85+
}
86+
mir_ty::GenericParamDefKind::Const { .. } => unimplemented!(),
87+
}
88+
}
89+
Self {
90+
tcx,
91+
type_param_mapping,
92+
}
93+
}
94+
95+
fn translate_param_type(&self, ty: &mir_ty::ParamTy) -> rty::ParamType {
96+
let index = *self
97+
.type_param_mapping
98+
.get(&ty.index)
99+
.expect("unknown type param idx");
100+
rty::ParamType::new(index)
68101
}
69102

70103
// TODO: consolidate two impls
@@ -89,7 +122,7 @@ impl<'tcx> TypeBuilder<'tcx> {
89122
rty::TupleType::new(elems).into()
90123
}
91124
mir_ty::TyKind::Never => rty::Type::never(),
92-
mir_ty::TyKind::Param(ty) => rty::ParamType::new(ty.index.into()).into(),
125+
mir_ty::TyKind::Param(ty) => self.translate_param_type(ty).into(),
93126
mir_ty::TyKind::FnPtr(sig) => {
94127
// TODO: justification for skip_binder
95128
let sig = sig.skip_binder();
@@ -135,7 +168,7 @@ impl<'tcx> TypeBuilder<'tcx> {
135168
registry: &'a mut R,
136169
) -> TemplateTypeBuilder<'tcx, 'a, R, EmptyTemplateScope> {
137170
TemplateTypeBuilder {
138-
tcx: self.tcx,
171+
inner: self.clone(),
139172
registry,
140173
scope: Default::default(),
141174
}
@@ -147,7 +180,7 @@ impl<'tcx> TypeBuilder<'tcx> {
147180
sig: mir_ty::FnSig<'tcx>,
148181
) -> FunctionTemplateTypeBuilder<'tcx, 'a, R> {
149182
FunctionTemplateTypeBuilder {
150-
tcx: self.tcx,
183+
inner: self.clone(),
151184
registry,
152185
param_tys: sig
153186
.inputs()
@@ -166,15 +199,15 @@ impl<'tcx> TypeBuilder<'tcx> {
166199
}
167200

168201
pub struct TemplateTypeBuilder<'tcx, 'a, R, S> {
169-
tcx: mir_ty::TyCtxt<'tcx>,
202+
inner: TypeBuilder<'tcx>,
170203
registry: &'a mut R,
171204
scope: S,
172205
}
173206

174207
impl<'tcx, 'a, R, S> TemplateTypeBuilder<'tcx, 'a, R, S> {
175208
pub fn with_scope<T>(self, scope: T) -> TemplateTypeBuilder<'tcx, 'a, R, T> {
176209
TemplateTypeBuilder {
177-
tcx: self.tcx,
210+
inner: self.inner,
178211
registry: self.registry,
179212
scope,
180213
}
@@ -207,29 +240,27 @@ where
207240
rty::TupleType::new(elems).into()
208241
}
209242
mir_ty::TyKind::Never => rty::Type::never(),
210-
mir_ty::TyKind::Param(ty) => rty::ParamType::new(ty.index.into()).into(),
243+
mir_ty::TyKind::Param(ty) => self.inner.translate_param_type(ty).into(),
211244
mir_ty::TyKind::FnPtr(sig) => {
212245
// TODO: justification for skip_binder
213246
let sig = sig.skip_binder();
214-
let ty = TypeBuilder::new(self.tcx)
215-
.for_function_template(self.registry, sig)
216-
.build();
247+
let ty = self.inner.for_function_template(self.registry, sig).build();
217248
rty::Type::function(ty)
218249
}
219250
mir_ty::TyKind::Adt(def, params) if def.is_box() => {
220251
rty::PointerType::own(self.build(params.type_at(0))).into()
221252
}
222253
mir_ty::TyKind::Adt(def, params) => {
223254
if def.is_enum() {
224-
let sym = refine::datatype_symbol(self.tcx, def.did());
255+
let sym = refine::datatype_symbol(self.inner.tcx, def.did());
225256
let args: IndexVec<_, _> =
226257
params.types().map(|ty| self.build_refined(ty)).collect();
227258
rty::EnumType::new(sym, args).into()
228259
} else if def.is_struct() {
229260
let elem_tys = def
230261
.all_fields()
231262
.map(|field| {
232-
let ty = field.ty(self.tcx, params);
263+
let ty = field.ty(self.inner.tcx, params);
233264
// elaboration: all fields are boxed
234265
rty::PointerType::own(self.build(ty)).into()
235266
})
@@ -245,10 +276,7 @@ where
245276

246277
pub fn build_refined(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::RefinedType<S::Var> {
247278
// TODO: consider building ty with scope
248-
let ty = TypeBuilder::new(self.tcx)
249-
.for_template(self.registry)
250-
.build(ty)
251-
.vacuous();
279+
let ty = self.inner.for_template(self.registry).build(ty).vacuous();
252280
let tmpl = self.scope.build_template().build(ty);
253281
self.registry.register_template(tmpl)
254282
}
@@ -269,7 +297,7 @@ where
269297
tys.push(ty);
270298
}
271299
let ty = FunctionTemplateTypeBuilder {
272-
tcx: self.tcx,
300+
inner: self.inner.clone(),
273301
registry: self.registry,
274302
param_tys: tys,
275303
ret_ty,
@@ -283,7 +311,7 @@ where
283311
}
284312

285313
pub struct FunctionTemplateTypeBuilder<'tcx, 'a, R> {
286-
tcx: mir_ty::TyCtxt<'tcx>,
314+
inner: TypeBuilder<'tcx>,
287315
registry: &'a mut R,
288316
param_tys: Vec<mir_ty::TypeAndMut<'tcx>>,
289317
ret_ty: mir_ty::Ty<'tcx>,
@@ -324,7 +352,7 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> {
324352
&mut self,
325353
refinement: rty::Refinement<rty::FunctionParamIdx>,
326354
) -> &mut Self {
327-
let ty = TypeBuilder::new(self.tcx).build(self.ret_ty);
355+
let ty = self.inner.build(self.ret_ty);
328356
self.ret_rty = Some(rty::RefinedType::new(ty.vacuous(), refinement));
329357
self
330358
}
@@ -350,17 +378,17 @@ where
350378
.unwrap_or_else(|| {
351379
if idx == self.param_tys.len() - 1 {
352380
if let Some(param_refinement) = &self.param_refinement {
353-
let ty = TypeBuilder::new(self.tcx).build(param_ty.ty);
381+
let ty = self.inner.build(param_ty.ty);
354382
rty::RefinedType::new(ty.vacuous(), param_refinement.clone())
355383
} else {
356-
TypeBuilder::new(self.tcx)
384+
self.inner
357385
.for_template(self.registry)
358386
.with_scope(&builder)
359387
.build_refined(param_ty.ty)
360388
}
361389
} else {
362390
rty::RefinedType::unrefined(
363-
TypeBuilder::new(self.tcx)
391+
self.inner
364392
.for_template(self.registry)
365393
.with_scope(&builder)
366394
.build(param_ty.ty),
@@ -383,8 +411,8 @@ where
383411
let param_rty = if let Some(param_refinement) = &self.param_refinement {
384412
rty::RefinedType::new(rty::Type::unit(), param_refinement.clone())
385413
} else {
386-
let unit_ty = mir_ty::Ty::new_unit(self.tcx);
387-
TypeBuilder::new(self.tcx)
414+
let unit_ty = mir_ty::Ty::new_unit(self.inner.tcx);
415+
self.inner
388416
.for_template(self.registry)
389417
.with_scope(&builder)
390418
.build_refined(unit_ty)
@@ -393,7 +421,7 @@ where
393421
}
394422

395423
let ret_rty = self.ret_rty.clone().unwrap_or_else(|| {
396-
TypeBuilder::new(self.tcx)
424+
self.inner
397425
.for_template(self.registry)
398426
.with_scope(&builder)
399427
.build_refined(self.ret_ty)

0 commit comments

Comments
 (0)