Skip to content

Commit ed36445

Browse files
committed
Implement #[thrust::extern_spec_fn]
1 parent b41df50 commit ed36445

3 files changed

Lines changed: 66 additions & 2 deletions

File tree

src/analyze/annot.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ pub fn callable_path() -> [Symbol; 2] {
3333
[Symbol::intern("thrust"), Symbol::intern("callable")]
3434
}
3535

36+
pub fn extern_spec_fn_path() -> [Symbol; 2] {
37+
[Symbol::intern("thrust"), Symbol::intern("extern_spec_fn")]
38+
}
39+
3640
/// A [`annot::Resolver`] implementation for resolving function parameters.
3741
///
3842
/// The parameter names and their sorts needs to be configured via

src/analyze/crate_.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
4646
self.trusted.insert(local_def_id.to_def_id());
4747
}
4848

49+
if analyzer.is_annotated_as_extern_spec_fn() {
50+
assert!(analyzer.is_fully_annotated());
51+
self.trusted.insert(local_def_id.to_def_id());
52+
}
53+
4954
let sig = self
5055
.tcx
5156
.fn_sig(local_def_id)
@@ -56,7 +61,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
5661
self.ctx.register_deferred_def(local_def_id.to_def_id());
5762
} else {
5863
let expected = analyzer.expected_ty();
59-
self.ctx.register_def(local_def_id.to_def_id(), expected);
64+
let target_def_id = if analyzer.is_annotated_as_extern_spec_fn() {
65+
analyzer.extern_spec_fn_target_def_id()
66+
} else {
67+
local_def_id.to_def_id()
68+
};
69+
self.ctx.register_def(target_def_id, expected);
6070
}
6171
}
6272

src/analyze/local_def.rs

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rustc_index::bit_set::BitSet;
66
use rustc_index::IndexVec;
77
use rustc_middle::mir::{self, BasicBlock, Body, Local};
88
use rustc_middle::ty::{self as mir_ty, TyCtxt, TypeAndMut};
9-
use rustc_span::def_id::LocalDefId;
9+
use rustc_span::def_id::{DefId, LocalDefId};
1010
use rustc_span::symbol::Ident;
1111

1212
use crate::analyze;
@@ -126,6 +126,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
126126
.is_some()
127127
}
128128

129+
pub fn is_annotated_as_extern_spec_fn(&self) -> bool {
130+
self.tcx
131+
.get_attrs_by_path(
132+
self.local_def_id.to_def_id(),
133+
&analyze::annot::extern_spec_fn_path(),
134+
)
135+
.next()
136+
.is_some()
137+
}
138+
129139
// TODO: unify this logic with extraction functions above
130140
pub fn is_fully_annotated(&self) -> bool {
131141
let has_require = self
@@ -240,6 +250,46 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
240250
rty::RefinedType::unrefined(builder.build().into())
241251
}
242252

253+
pub fn extern_spec_fn_target_def_id(&self) -> DefId {
254+
struct ExtractDefId<'tcx> {
255+
tcx: TyCtxt<'tcx>,
256+
outer_def_id: LocalDefId,
257+
inner_def_id: Option<DefId>,
258+
}
259+
260+
impl<'tcx> rustc_hir::intravisit::Visitor<'tcx> for ExtractDefId<'tcx> {
261+
type NestedFilter = rustc_middle::hir::nested_filter::OnlyBodies;
262+
263+
fn nested_visit_map(&mut self) -> Self::Map {
264+
self.tcx.hir()
265+
}
266+
267+
fn visit_qpath(
268+
&mut self,
269+
qpath: &rustc_hir::QPath<'tcx>,
270+
hir_id: rustc_hir::HirId,
271+
_span: rustc_span::Span,
272+
) {
273+
let typeck_result = self.tcx.typeck(self.outer_def_id);
274+
if let rustc_hir::def::Res::Def(_, def_id) = typeck_result.qpath_res(qpath, hir_id)
275+
{
276+
self.inner_def_id = Some(def_id);
277+
}
278+
}
279+
}
280+
281+
use rustc_hir::intravisit::Visitor as _;
282+
let mut visitor = ExtractDefId {
283+
tcx: self.tcx,
284+
outer_def_id: self.local_def_id,
285+
inner_def_id: None,
286+
};
287+
if let rustc_hir::Node::Item(item) = self.tcx.hir_node_by_def_id(self.local_def_id) {
288+
visitor.visit_item(item);
289+
}
290+
visitor.inner_def_id.expect("invalid extern_spec_fn")
291+
}
292+
243293
fn is_mut_param(&self, param_idx: rty::FunctionParamIdx) -> bool {
244294
let param_local = analyze::local_of_function_param(param_idx);
245295
self.body.local_decls[param_local].mutability.is_mut()

0 commit comments

Comments
 (0)