Skip to content

Commit 124598a

Browse files
committed
Enable to handle annotated polymorphic function
1 parent fc1b85c commit 124598a

6 files changed

Lines changed: 101 additions & 26 deletions

File tree

src/analyze/basic_block.rs

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ pub struct Analyzer<'tcx, 'ctx> {
3333
basic_block: BasicBlock,
3434
body: Cow<'tcx, Body<'tcx>>,
3535

36+
type_builder: TypeBuilder<'tcx>,
3637
env: Env,
3738
local_decls: IndexVec<Local, mir::LocalDecl<'tcx>>,
3839
// TODO: remove this
@@ -56,10 +57,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
5657
self.ctx.basic_block_ty(self.local_def_id, bb)
5758
}
5859

59-
fn type_builder(&self) -> TypeBuilder<'tcx> {
60-
TypeBuilder::new(self.tcx, self.local_def_id.to_def_id())
61-
}
62-
6360
fn bind_local(&mut self, local: Local, rty: rty::RefinedType<Var>) {
6461
let rty = if self.is_mut_local(local) {
6562
// elaboration:
@@ -226,7 +223,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
226223
let rty_args: IndexVec<_, _> = args
227224
.types()
228225
.map(|ty| {
229-
self.type_builder()
226+
self.type_builder
230227
.for_template(&mut self.ctx)
231228
.with_scope(&self.env)
232229
.build_refined(ty)
@@ -440,7 +437,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
440437
// TODO: move this to well-known defs?
441438
Some((def_id, args)) if self.is_box_new(def_id) => {
442439
let inner_ty = self
443-
.type_builder()
440+
.type_builder
444441
.for_template(&mut self.ctx)
445442
.build(args.type_at(0))
446443
.vacuous();
@@ -454,7 +451,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
454451
rty::FunctionType::new([param].into_iter().collect(), ret).into()
455452
}
456453
Some((def_id, args)) if self.is_mem_swap(def_id) => {
457-
let inner_ty = self.type_builder().build(args.type_at(0)).vacuous();
454+
let inner_ty = self.type_builder.build(args.type_at(0)).vacuous();
458455
let param1 =
459456
rty::RefinedType::unrefined(rty::PointerType::mut_to(inner_ty.clone()).into());
460457
let param2 =
@@ -472,15 +469,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
472469
rty::FunctionType::new([param1, param2].into_iter().collect(), ret).into()
473470
}
474471
Some((def_id, args)) => {
475-
if !args.is_empty() {
476-
tracing::warn!(?args, ?def_id, "generic args ignored");
472+
if args.consts().next().is_some() {
473+
tracing::warn!(?args, ?def_id, "const generic args ignored");
477474
}
478-
self.ctx
479-
.def_ty(def_id)
480-
.expect("unknown def")
481-
.ty
482-
.clone()
483-
.vacuous()
475+
let ty_args = args
476+
.types()
477+
.map(|ty| rty::RefinedType::unrefined(self.type_builder.build(ty)))
478+
.collect();
479+
let mut def_ty = self.ctx.def_ty(def_id).expect("unknown def").clone();
480+
def_ty.instantiate_ty_params(ty_args);
481+
def_ty.ty.vacuous()
484482
}
485483
_ => self.operand_type(func.clone()).ty,
486484
};
@@ -541,7 +539,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
541539
}
542540

543541
fn add_prophecy_var(&mut self, statement_index: usize, ty: mir_ty::Ty<'tcx>) {
544-
let ty = self.type_builder().build(ty);
542+
let ty = self.type_builder.build(ty);
545543
let temp_var = self.env.push_temp_var(ty.vacuous());
546544
self.prophecy_vars.insert(statement_index, temp_var);
547545
tracing::debug!(stmt_idx = %statement_index, temp_var = ?temp_var, "add_prophecy_var");
@@ -562,7 +560,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
562560
referent: mir::Place<'tcx>,
563561
prophecy_ty: mir_ty::Ty<'tcx>,
564562
) -> rty::RefinedType<Var> {
565-
let prophecy_ty = self.type_builder().build(prophecy_ty);
563+
let prophecy_ty = self.type_builder.build(prophecy_ty);
566564
let prophecy = self.env.push_temp_var(prophecy_ty.vacuous());
567565
let place = self.elaborate_place_for_borrow(&referent);
568566
self.env.borrow_place(place, prophecy).into()
@@ -675,7 +673,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
675673

676674
let decl = self.local_decls[destination].clone();
677675
let rty = self
678-
.type_builder()
676+
.type_builder
679677
.for_template(&mut self.ctx)
680678
.with_scope(&self.env)
681679
.build_refined(decl.ty);
@@ -749,7 +747,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
749747
#[tracing::instrument(skip(self))]
750748
fn ret_template(&mut self) -> rty::RefinedType<Var> {
751749
let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty;
752-
self.type_builder()
750+
self.type_builder
753751
.for_template(&mut self.ctx)
754752
.with_scope(&self.env)
755753
.build_refined(ret_ty)
@@ -955,13 +953,15 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
955953
let env = ctx.new_env();
956954
let local_decls = body.local_decls.clone();
957955
let prophecy_vars = Default::default();
956+
let type_builder = TypeBuilder::new(tcx, local_def_id.to_def_id());
958957
Self {
959958
ctx,
960959
tcx,
961960
local_def_id,
962961
drop_points,
963962
basic_block,
964963
body,
964+
type_builder,
965965
env,
966966
local_decls,
967967
prophecy_vars,
@@ -989,6 +989,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
989989
self
990990
}
991991

992+
pub fn type_builder(&mut self, type_builder: TypeBuilder<'tcx>) -> &mut Self {
993+
self.type_builder = type_builder;
994+
self
995+
}
996+
992997
pub fn run(&mut self, expected: &BasicBlockType) {
993998
let span = tracing::info_span!("bb", bb = ?self.basic_block);
994999
let _guard = span.enter();

src/analyze/crate_.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,22 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
211211
tracing::info!(?local_def_id, "trusted");
212212
continue;
213213
}
214-
let expected = self.ctx.def_ty(local_def_id.to_def_id()).unwrap().clone();
215-
self.ctx.local_def_analyzer(*local_def_id).run(&expected);
214+
// check polymorphic function def by replacing type params with some opaque type
215+
let type_builder = TypeBuilder::new(self.tcx, local_def_id.to_def_id())
216+
.with_param_mapper(|_| rty::Type::int());
217+
let mut expected = self.ctx.def_ty(local_def_id.to_def_id()).unwrap().clone();
218+
let subst = rty::TypeParamSubst::new(
219+
expected
220+
.free_ty_params()
221+
.into_iter()
222+
.map(|ty_param| (ty_param, rty::RefinedType::unrefined(rty::Type::int())))
223+
.collect(),
224+
);
225+
expected.subst_ty_params(&subst);
226+
self.ctx
227+
.local_def_analyzer(*local_def_id)
228+
.type_builder(type_builder)
229+
.run(&expected);
216230
}
217231
}
218232

src/analyze/local_def.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pub struct Analyzer<'tcx, 'ctx> {
2626

2727
body: Body<'tcx>,
2828
drop_points: HashMap<BasicBlock, analyze::basic_block::DropPoints>,
29+
type_builder: TypeBuilder<'tcx>,
2930
}
3031

3132
impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
@@ -306,7 +307,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
306307
}
307308
// function return type is basic block return type
308309
let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty;
309-
let rty = TypeBuilder::new(self.tcx, self.local_def_id.to_def_id())
310+
let rty = self
311+
.type_builder
310312
.for_template(&mut self.ctx)
311313
.build_basic_block(live_locals, ret_ty);
312314
self.ctx.register_basic_block_ty(self.local_def_id, bb, rty);
@@ -321,6 +323,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
321323
.basic_block_analyzer(self.local_def_id, bb)
322324
.body(self.body.clone())
323325
.drop_points(drop_points)
326+
.type_builder(self.type_builder.clone())
324327
.run(&rty);
325328
}
326329
}
@@ -426,15 +429,22 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
426429
let tcx = ctx.tcx;
427430
let body = tcx.optimized_mir(local_def_id.to_def_id()).clone();
428431
let drop_points = Default::default();
432+
let type_builder = TypeBuilder::new(tcx, local_def_id.to_def_id());
429433
Self {
430434
ctx,
431435
tcx,
432436
local_def_id,
433437
body,
434438
drop_points,
439+
type_builder,
435440
}
436441
}
437442

443+
pub fn type_builder(&mut self, type_builder: TypeBuilder<'tcx>) -> &mut Self {
444+
self.type_builder = type_builder;
445+
self
446+
}
447+
438448
pub fn run(&mut self, expected: &rty::RefinedType) {
439449
let span = tracing::info_span!("def", def = %self.tcx.def_path_str(self.local_def_id));
440450
let _guard = span.enter();

src/refine/template.rs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,24 @@ where
5858
}
5959
}
6060

61+
trait ParamTypeMapper {
62+
fn map_param_ty(&self, ty: rty::ParamType) -> rty::Type<rty::Closed>;
63+
}
64+
65+
impl<F> ParamTypeMapper for F
66+
where
67+
F: Fn(rty::ParamType) -> rty::Type<rty::Closed>,
68+
{
69+
fn map_param_ty(&self, ty: rty::ParamType) -> rty::Type<rty::Closed> {
70+
self(ty)
71+
}
72+
}
73+
6174
#[derive(Clone)]
6275
pub struct TypeBuilder<'tcx> {
6376
tcx: mir_ty::TyCtxt<'tcx>,
6477
type_param_mapping: HashMap<u32, rty::TypeParamIdx>,
78+
param_type_mapper: std::rc::Rc<dyn ParamTypeMapper>,
6579
}
6680

6781
impl<'tcx> TypeBuilder<'tcx> {
@@ -89,15 +103,25 @@ impl<'tcx> TypeBuilder<'tcx> {
89103
Self {
90104
tcx,
91105
type_param_mapping,
106+
param_type_mapper: std::rc::Rc::new(|ty: rty::ParamType| ty.into()),
92107
}
93108
}
94109

95-
fn translate_param_type(&self, ty: &mir_ty::ParamTy) -> rty::ParamType {
110+
pub fn with_param_mapper<F>(mut self, mapper: F) -> Self
111+
where
112+
F: Fn(rty::ParamType) -> rty::Type<rty::Closed> + 'static,
113+
{
114+
self.param_type_mapper = std::rc::Rc::new(mapper);
115+
self
116+
}
117+
118+
fn translate_param_type(&self, ty: &mir_ty::ParamTy) -> rty::Type<rty::Closed> {
96119
let index = *self
97120
.type_param_mapping
98121
.get(&ty.index)
99122
.expect("unknown type param idx");
100-
rty::ParamType::new(index)
123+
let param_ty = rty::ParamType::new(index);
124+
self.param_type_mapper.map_param_ty(param_ty)
101125
}
102126

103127
// TODO: consolidate two impls
@@ -122,7 +146,7 @@ impl<'tcx> TypeBuilder<'tcx> {
122146
rty::TupleType::new(elems).into()
123147
}
124148
mir_ty::TyKind::Never => rty::Type::never(),
125-
mir_ty::TyKind::Param(ty) => self.translate_param_type(ty).into(),
149+
mir_ty::TyKind::Param(ty) => self.translate_param_type(ty),
126150
mir_ty::TyKind::FnPtr(sig) => {
127151
// TODO: justification for skip_binder
128152
let sig = sig.skip_binder();
@@ -240,7 +264,7 @@ where
240264
rty::TupleType::new(elems).into()
241265
}
242266
mir_ty::TyKind::Never => rty::Type::never(),
243-
mir_ty::TyKind::Param(ty) => self.inner.translate_param_type(ty).into(),
267+
mir_ty::TyKind::Param(ty) => self.inner.translate_param_type(ty).vacuous(),
244268
mir_ty::TyKind::FnPtr(sig) => {
245269
// TODO: justification for skip_binder
246270
let sig = sig.skip_binder();

tests/ui/fail/fn_poly_annot.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
//@error-in-other-file: Unsat
2+
3+
#[thrust::requires(true)]
4+
#[thrust::ensures(result != x.0)]
5+
fn left<T, U>(x: (T, U)) -> T {
6+
x.0
7+
}
8+
9+
fn main() {
10+
assert!(left((42, 0)) == 42);
11+
}

tests/ui/pass/fn_poly_annot.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
//@check-pass
2+
3+
#[thrust::requires(true)]
4+
#[thrust::ensures(result == x.0)]
5+
fn left<T, U>(x: (T, U)) -> T {
6+
x.0
7+
}
8+
9+
fn main() {
10+
assert!(left((42, 0)) == 42);
11+
}

0 commit comments

Comments
 (0)