Skip to content

Commit 62b3a6a

Browse files
committed
Enable to handle unannotated polymorphic function
1 parent efa69bf commit 62b3a6a

6 files changed

Lines changed: 152 additions & 30 deletions

File tree

src/analyze.rs

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@ use std::collections::HashMap;
1111
use std::rc::Rc;
1212

1313
use rustc_hir::lang_items::LangItem;
14+
use rustc_index::IndexVec;
1415
use rustc_middle::mir::{self, BasicBlock, Local};
1516
use rustc_middle::ty::{self as mir_ty, TyCtxt};
1617
use rustc_span::def_id::{DefId, LocalDefId};
1718

1819
use crate::chc;
1920
use crate::pretty::PrettyDisplayExt as _;
20-
use crate::refine::{self, BasicBlockType};
21+
use crate::refine::{self, BasicBlockType, TypeBuilder};
2122
use crate::rty;
2223

2324
mod annot;
@@ -103,6 +104,17 @@ impl<'tcx> ReplacePlacesVisitor<'tcx> {
103104
}
104105
}
105106

107+
#[derive(Debug, Clone)]
108+
struct DeferredDefTy<'tcx> {
109+
cache: Rc<RefCell<HashMap<Vec<mir_ty::Ty<'tcx>>, rty::RefinedType>>>,
110+
}
111+
112+
#[derive(Debug, Clone)]
113+
enum DefTy<'tcx> {
114+
Concrete(rty::RefinedType),
115+
Deferred(DeferredDefTy<'tcx>),
116+
}
117+
106118
#[derive(Clone)]
107119
pub struct Analyzer<'tcx> {
108120
tcx: TyCtxt<'tcx>,
@@ -112,7 +124,7 @@ pub struct Analyzer<'tcx> {
112124
/// currently contains only local-def templates,
113125
/// but will be extended to contain externally known def's refinement types
114126
/// (at least for every defs referenced by local def bodies)
115-
defs: HashMap<DefId, rty::RefinedType>,
127+
defs: HashMap<DefId, DefTy<'tcx>>,
116128

117129
/// Resulting CHC system.
118130
system: Rc<RefCell<chc::System>>,
@@ -207,11 +219,72 @@ impl<'tcx> Analyzer<'tcx> {
207219

208220
pub fn register_def(&mut self, def_id: DefId, rty: rty::RefinedType) {
209221
tracing::info!(def_id = ?def_id, rty = %rty.display(), "register_def");
210-
self.defs.insert(def_id, rty);
222+
self.defs.insert(def_id, DefTy::Concrete(rty));
223+
}
224+
225+
pub fn register_deferred_def(&mut self, def_id: DefId) {
226+
tracing::info!(def_id = ?def_id, "register_deferred_def");
227+
self.defs.insert(
228+
def_id,
229+
DefTy::Deferred(DeferredDefTy {
230+
cache: Rc::new(RefCell::new(HashMap::new())),
231+
}),
232+
);
233+
}
234+
235+
pub fn concrete_def_ty(&self, def_id: DefId) -> Option<&rty::RefinedType> {
236+
self.defs.get(&def_id).and_then(|def_ty| match def_ty {
237+
DefTy::Concrete(rty) => Some(rty),
238+
DefTy::Deferred(_) => None,
239+
})
211240
}
212241

213-
pub fn def_ty(&self, def_id: DefId) -> Option<&rty::RefinedType> {
214-
self.defs.get(&def_id)
242+
pub fn def_ty_with_args(
243+
&mut self,
244+
def_id: DefId,
245+
args: mir_ty::GenericArgsRef<'tcx>,
246+
) -> Option<rty::RefinedType> {
247+
let type_builder = TypeBuilder::new(self.tcx, def_id);
248+
let rty_args: IndexVec<_, _> = args.types().map(|ty| type_builder.build(ty)).collect();
249+
250+
let deferred_ty = match self.defs.get(&def_id)? {
251+
DefTy::Concrete(rty) => {
252+
let mut def_ty = rty.clone();
253+
def_ty.instantiate_ty_params(
254+
rty_args
255+
.clone()
256+
.into_iter()
257+
.map(rty::RefinedType::unrefined)
258+
.collect(),
259+
);
260+
return Some(def_ty);
261+
}
262+
DefTy::Deferred(deferred) => deferred,
263+
};
264+
265+
let ty_args: Vec<_> = args.types().collect();
266+
let deferred_ty_cache = Rc::clone(&deferred_ty.cache); // to cut reference to allow &mut self
267+
if let Some(rty) = deferred_ty_cache.borrow().get(&ty_args) {
268+
return Some(rty.clone());
269+
}
270+
let local_def_id = def_id.as_local()?;
271+
272+
let sig = self
273+
.tcx
274+
.fn_sig(def_id)
275+
.instantiate(self.tcx, args)
276+
.skip_binder();
277+
let expected = self
278+
.crate_analyzer()
279+
.fn_def_ty_with_sig(local_def_id.to_def_id(), sig)
280+
.unwrap();
281+
self.local_def_analyzer(local_def_id)
282+
.type_builder(type_builder.with_param_mapper(move |ty| rty_args[ty.idx].clone()))
283+
.run(&expected);
284+
deferred_ty_cache
285+
.borrow_mut()
286+
.insert(ty_args, expected.clone());
287+
Some(expected)
215288
}
216289

217290
pub fn register_basic_block_ty(

src/analyze/basic_block.rs

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -263,12 +263,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
263263
_ty,
264264
) => {
265265
let func_ty = match operand.const_fn_def() {
266-
Some((def_id, args)) => {
267-
if !args.is_empty() {
268-
tracing::warn!(?args, ?def_id, "generic args ignored");
269-
}
270-
self.ctx.def_ty(def_id).expect("unknown def").ty.clone()
271-
}
266+
Some((def_id, args)) => self
267+
.ctx
268+
.def_ty_with_args(def_id, args)
269+
.expect("unknown def")
270+
.ty
271+
.clone(),
272272
_ => unimplemented!(),
273273
};
274274
PlaceType::with_ty_and_term(func_ty.vacuous(), chc::Term::null())
@@ -468,18 +468,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
468468
let ret = rty::RefinedType::new(rty::Type::unit(), ret_formula.into());
469469
rty::FunctionType::new([param1, param2].into_iter().collect(), ret).into()
470470
}
471-
Some((def_id, args)) => {
472-
if args.consts().next().is_some() {
473-
tracing::warn!(?args, ?def_id, "const generic args ignored");
474-
}
475-
let ty_args = args
476-
.types()
477-
.map(|ty| rty::RefinedType::unrefined(self.type_builder.build(ty)))
478-
.collect();
479-
let mut def_ty = self.ctx.def_ty(def_id).expect("unknown def").clone();
480-
def_ty.instantiate_ty_params(ty_args);
481-
def_ty.ty.vacuous()
482-
}
471+
Some((def_id, args)) => self
472+
.ctx
473+
.def_ty_with_args(def_id, args)
474+
.expect("unknown def")
475+
.ty
476+
.vacuous(),
483477
_ => self.operand_type(func.clone()).ty,
484478
};
485479
let expected_args: IndexVec<_, _> = args

src/analyze/crate_.rs

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,19 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
128128
#[tracing::instrument(skip(self), fields(def_id = %self.tcx.def_path_str(def_id)))]
129129
fn refine_fn_def(&mut self, def_id: DefId) {
130130
let sig = self.tcx.fn_sig(def_id);
131-
let sig = sig.instantiate_identity().skip_binder(); // TODO: is it OK?
131+
let sig = sig.instantiate_identity().skip_binder();
132+
if let Some(rty) = self.fn_def_ty_with_sig(def_id, sig) {
133+
self.ctx.register_def(def_id, rty);
134+
} else {
135+
self.ctx.register_deferred_def(def_id);
136+
}
137+
}
132138

139+
pub fn fn_def_ty_with_sig(
140+
&mut self,
141+
def_id: DefId,
142+
sig: mir_ty::FnSig<'tcx>,
143+
) -> Option<rty::RefinedType> {
133144
let mut param_resolver = analyze::annot::ParamResolver::default();
134145
for (input_ident, input_ty) in self.tcx.fn_arg_names(def_id).iter().zip(sig.inputs()) {
135146
let input_ty = TypeBuilder::new(self.tcx, def_id).build(*input_ty);
@@ -198,8 +209,14 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
198209
if let Some(ret_rty) = ret_annot {
199210
builder.ret_rty(ret_rty);
200211
}
201-
let rty = rty::RefinedType::unrefined(builder.build().into());
202-
self.ctx.register_def(def_id, rty);
212+
213+
// can't generate template with type parameter...
214+
use mir_ty::TypeVisitableExt as _;
215+
if builder.would_contain_template() && sig.has_param() {
216+
None
217+
} else {
218+
Some(rty::RefinedType::unrefined(builder.build().into()))
219+
}
203220
}
204221

205222
fn analyze_local_defs(&mut self) {
@@ -211,10 +228,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
211228
tracing::info!(?local_def_id, "trusted");
212229
continue;
213230
}
231+
let Some(expected) = self.ctx.concrete_def_ty(local_def_id.to_def_id()) else {
232+
// when the local_def_id is deferred it would be skipped
233+
continue;
234+
};
235+
214236
// check polymorphic function def by replacing type params with some opaque type
237+
// (and this is no-op if the function is mono)
215238
let type_builder = TypeBuilder::new(self.tcx, local_def_id.to_def_id())
216239
.with_param_mapper(|_| rty::Type::int());
217-
let mut expected = self.ctx.def_ty(local_def_id.to_def_id()).unwrap().clone();
240+
let mut expected = expected.clone();
218241
let subst = rty::TypeParamSubst::new(
219242
expected
220243
.free_ty_params()
@@ -236,7 +259,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
236259
// TODO: replace code here with relate_* in Env + Refine context (created with empty env)
237260
let entry_ty = self
238261
.ctx
239-
.def_ty(def_id)
262+
.concrete_def_ty(def_id)
240263
.unwrap()
241264
.ty
242265
.as_function()

src/refine/template.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,12 @@ where
8383
pub struct TypeBuilder<'tcx> {
8484
tcx: mir_ty::TyCtxt<'tcx>,
8585
/// Maps index in [`mir_ty::ParamTy`] to [`rty::TypeParamIdx`].
86-
/// These indices may differ because we skip lifetime parameters.
86+
/// These indices may differ because we skip lifetime parameters and they always need to be
87+
/// mapped when we translate a [`mir_ty::ParamTy`] to [`rty::ParamType`].
8788
/// See [`rty::TypeParamIdx`] for more details.
8889
param_idx_mapping: HashMap<u32, rty::TypeParamIdx>,
90+
/// Optionally also want to further map rty::ParamType to other rty::Type before generating
91+
/// templates. This is no-op by default.
8992
param_type_mapper: std::rc::Rc<dyn ParamTypeMapper>,
9093
}
9194

@@ -100,7 +103,7 @@ impl<'tcx> TypeBuilder<'tcx> {
100103
mir_ty::GenericParamDefKind::Type { .. } => {
101104
param_idx_mapping.insert(i as u32, param_idx_mapping.len().into());
102105
}
103-
mir_ty::GenericParamDefKind::Const { .. } => unimplemented!(),
106+
mir_ty::GenericParamDefKind::Const { .. } => {}
104107
}
105108
}
106109
Self {
@@ -397,6 +400,17 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> {
397400
self.ret_rty = Some(rty);
398401
self
399402
}
403+
404+
pub fn would_contain_template(&self) -> bool {
405+
if self.param_tys.is_empty() {
406+
return self.ret_rty.is_none();
407+
}
408+
409+
let last_param_idx = rty::FunctionParamIdx::from(self.param_tys.len() - 1);
410+
let param_annotated =
411+
self.param_refinement.is_some() || self.param_rtys.contains_key(&last_param_idx);
412+
self.ret_rty.is_none() || !param_annotated
413+
}
400414
}
401415

402416
impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R>

tests/ui/fail/fn_poly.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//@error-in-other-file: Unsat
2+
3+
fn left<T, U>(x: (T, U)) -> T {
4+
x.0
5+
}
6+
7+
fn main() {
8+
assert!(left((42, 0)) == 0);
9+
}

tests/ui/pass/fn_poly.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//@check-pass
2+
3+
fn left<T, U>(x: (T, U)) -> T {
4+
x.0
5+
}
6+
7+
fn main() {
8+
assert!(left((42, 0)) == 42);
9+
}

0 commit comments

Comments
 (0)