Skip to content

Commit c8c6cf0

Browse files
authored
Merge pull request #15 from coord-e/extern-spec
Implement #[thrust::extern_spec_fn]
2 parents 5f5f102 + a48be41 commit c8c6cf0

5 files changed

Lines changed: 97 additions & 2 deletions

File tree

src/analyze/annot.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ pub fn callable_path() -> [Symbol; 2] {
3333
[Symbol::intern("thrust"), Symbol::intern("callable")]
3434
}
3535

36+
pub fn extern_spec_fn_path() -> [Symbol; 2] {
37+
[Symbol::intern("thrust"), Symbol::intern("extern_spec_fn")]
38+
}
39+
3640
/// A [`annot::Resolver`] implementation for resolving function parameters.
3741
///
3842
/// The parameter names and their sorts needs to be configured via

src/analyze/crate_.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,22 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
4545
self.trusted.insert(local_def_id.to_def_id());
4646
}
4747

48+
if analyzer.is_annotated_as_extern_spec_fn() {
49+
assert!(analyzer.is_fully_annotated());
50+
self.trusted.insert(local_def_id.to_def_id());
51+
}
52+
4853
use mir_ty::TypeVisitableExt as _;
4954
if sig.has_param() && !analyzer.is_fully_annotated() {
5055
self.ctx.register_deferred_def(local_def_id.to_def_id());
5156
} else {
5257
let expected = analyzer.expected_ty();
53-
self.ctx.register_def(local_def_id.to_def_id(), expected);
58+
let target_def_id = if analyzer.is_annotated_as_extern_spec_fn() {
59+
analyzer.extern_spec_fn_target_def_id()
60+
} else {
61+
local_def_id.to_def_id()
62+
};
63+
self.ctx.register_def(target_def_id, expected);
5464
}
5565
}
5666

src/analyze/local_def.rs

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rustc_index::bit_set::BitSet;
66
use rustc_index::IndexVec;
77
use rustc_middle::mir::{self, BasicBlock, Body, Local};
88
use rustc_middle::ty::{self as mir_ty, TyCtxt, TypeAndMut};
9-
use rustc_span::def_id::LocalDefId;
9+
use rustc_span::def_id::{DefId, LocalDefId};
1010
use rustc_span::symbol::Ident;
1111

1212
use crate::analyze;
@@ -126,6 +126,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
126126
.is_some()
127127
}
128128

129+
pub fn is_annotated_as_extern_spec_fn(&self) -> bool {
130+
self.tcx
131+
.get_attrs_by_path(
132+
self.local_def_id.to_def_id(),
133+
&analyze::annot::extern_spec_fn_path(),
134+
)
135+
.next()
136+
.is_some()
137+
}
138+
129139
// TODO: unify this logic with extraction functions above
130140
pub fn is_fully_annotated(&self) -> bool {
131141
let has_require = self
@@ -240,6 +250,48 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
240250
rty::RefinedType::unrefined(builder.build().into())
241251
}
242252

253+
/// Extract the target DefId from `#[thrust::extern_spec_fn]` function.
254+
pub fn extern_spec_fn_target_def_id(&self) -> DefId {
255+
struct ExtractDefId<'tcx> {
256+
tcx: TyCtxt<'tcx>,
257+
outer_def_id: LocalDefId,
258+
inner_def_id: Option<DefId>,
259+
}
260+
261+
impl<'tcx> rustc_hir::intravisit::Visitor<'tcx> for ExtractDefId<'tcx> {
262+
type NestedFilter = rustc_middle::hir::nested_filter::OnlyBodies;
263+
264+
fn nested_visit_map(&mut self) -> Self::Map {
265+
self.tcx.hir()
266+
}
267+
268+
fn visit_qpath(
269+
&mut self,
270+
qpath: &rustc_hir::QPath<'tcx>,
271+
hir_id: rustc_hir::HirId,
272+
_span: rustc_span::Span,
273+
) {
274+
let typeck_result = self.tcx.typeck(self.outer_def_id);
275+
if let rustc_hir::def::Res::Def(_, def_id) = typeck_result.qpath_res(qpath, hir_id)
276+
{
277+
assert!(self.inner_def_id.is_none(), "invalid extern_spec_fn");
278+
self.inner_def_id = Some(def_id);
279+
}
280+
}
281+
}
282+
283+
use rustc_hir::intravisit::Visitor as _;
284+
let mut visitor = ExtractDefId {
285+
tcx: self.tcx,
286+
outer_def_id: self.local_def_id,
287+
inner_def_id: None,
288+
};
289+
if let rustc_hir::Node::Item(item) = self.tcx.hir_node_by_def_id(self.local_def_id) {
290+
visitor.visit_item(item);
291+
}
292+
visitor.inner_def_id.expect("invalid extern_spec_fn")
293+
}
294+
243295
fn is_mut_param(&self, param_idx: rty::FunctionParamIdx) -> bool {
244296
let param_local = analyze::local_of_function_param(param_idx);
245297
self.body.local_decls[param_local].mutability.is_mut()

tests/ui/fail/extern_spec_take.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//@error-in-other-file: Unsat
2+
3+
#[thrust::extern_spec_fn]
4+
#[thrust::requires(true)]
5+
#[thrust::ensures(result == *dest && ^dest == 0)]
6+
fn _extern_spec_take(dest: &mut i32) -> i32 {
7+
std::mem::take(dest)
8+
}
9+
10+
fn main() {
11+
let mut x = 42;
12+
let old = std::mem::take(&mut x);
13+
assert!(x == 42);
14+
}

tests/ui/pass/extern_spec_take.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//@check-pass
2+
3+
#[thrust::extern_spec_fn]
4+
#[thrust::requires(true)]
5+
#[thrust::ensures(result == *dest && ^dest == 0)]
6+
fn _extern_spec_take(dest: &mut i32) -> i32 {
7+
std::mem::take(dest)
8+
}
9+
10+
fn main() {
11+
let mut x = 42;
12+
let old = std::mem::take(&mut x);
13+
assert!(old == 42);
14+
assert!(x == 0);
15+
}

0 commit comments

Comments
 (0)