Skip to content

Commit f8e940c

Browse files
committed
resolve instances in ctfe based on current const context
1 parent 54f67d2 commit f8e940c

10 files changed

Lines changed: 192 additions & 20 deletions

File tree

compiler/rustc_const_eval/src/const_eval/machine.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,8 @@ impl<'tcx> interpret::Machine<'tcx> for CompileTimeMachine<'tcx> {
935935

936936
fn get_default_alloc_params(&self) -> <Self::Bytes as mir::interpret::AllocBytes>::AllocParams {
937937
}
938+
939+
const SHOULD_RESPECT_CONST_BOUNDS_WHEN_RESOLVING_INSTANCES: bool = true;
938940
}
939941

940942
// Please do not add any code below the above `Machine` trait impl. I (oli-obk) plan more cleanups

compiler/rustc_const_eval/src/interpret/eval_context.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,12 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
370370
trace!("resolve: {:?}, {:#?}", def, args);
371371
trace!("typing_env: {:#?}", self.typing_env);
372372
trace!("args: {:#?}", args);
373-
match ty::Instance::try_resolve(*self.tcx, self.typing_env, def, args) {
373+
let resolve = if M::SHOULD_RESPECT_CONST_BOUNDS_WHEN_RESOLVING_INSTANCES {
374+
ty::Instance::try_resolve_for_ctfe
375+
} else {
376+
ty::Instance::try_resolve
377+
};
378+
match resolve(*self.tcx, self.typing_env, def, args) {
374379
Ok(Some(instance)) => interp_ok(instance),
375380
Ok(None) => throw_inval!(TooGeneric),
376381

compiler/rustc_const_eval/src/interpret/machine.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,8 @@ pub trait Machine<'tcx>: Sized {
638638
fn enter_trace_span(_span: impl FnOnce() -> tracing::Span) -> impl EnteredTraceSpan {
639639
()
640640
}
641+
642+
const SHOULD_RESPECT_CONST_BOUNDS_WHEN_RESOLVING_INSTANCES: bool = false;
641643
}
642644

643645
/// A lot of the flexibility above is just needed for `Miri`, but all "compile-time" machines

compiler/rustc_hir/src/hir.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4521,7 +4521,7 @@ impl fmt::Display for Safety {
45214521
}
45224522
}
45234523

4524-
#[derive(Copy, Clone, PartialEq, Eq, Debug, Encodable, Decodable, StableHash)]
4524+
#[derive(Copy, Clone, PartialEq, Eq, Debug, Encodable, Decodable, StableHash, Hash)]
45254525
#[derive(Default)]
45264526
pub enum Constness {
45274527
#[default]

compiler/rustc_middle/src/queries.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2622,7 +2622,7 @@ rustc_queries! {
26222622
/// from `Ok(None)` to avoid misleading diagnostics when an error
26232623
/// has already been/will be emitted, for the original cause.
26242624
query resolve_instance_raw(
2625-
key: ty::PseudoCanonicalInput<'tcx, (DefId, GenericArgsRef<'tcx>)>
2625+
key: ty::PseudoCanonicalInput<'tcx, (DefId, GenericArgsRef<'tcx>, hir::Constness)>
26262626
) -> Result<Option<ty::Instance<'tcx>>, ErrorGuaranteed> {
26272627
desc { "resolving instance `{}`", ty::Instance::new_raw(key.value.0, key.value.1) }
26282628
}

compiler/rustc_middle/src/query/keys.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::hash::Hash;
66

77
use rustc_ast::tokenstream::TokenStream;
88
use rustc_data_structures::stable_hasher::StableHash;
9+
use rustc_hir as hir;
910
use rustc_hir::def_id::{CrateNum, DefId, LOCAL_CRATE, LocalDefId, LocalModDefId};
1011
use rustc_hir::hir_id::OwnerId;
1112
use rustc_span::{DUMMY_SP, Ident, LocalExpnId, Span, Symbol};
@@ -243,6 +244,12 @@ impl<'tcx> QueryKey for (DefId, GenericArgsRef<'tcx>) {
243244
}
244245
}
245246

247+
impl<'tcx> QueryKey for (DefId, GenericArgsRef<'tcx>, hir::Constness) {
248+
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
249+
self.0.default_span(tcx)
250+
}
251+
}
252+
246253
impl<'tcx> QueryKey for ty::TraitRef<'tcx> {
247254
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
248255
tcx.def_span(self.def_id)

compiler/rustc_middle/src/ty/instance.rs

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -494,15 +494,16 @@ impl<'tcx> Instance<'tcx> {
494494
/// couldn't complete due to errors elsewhere - this is distinct
495495
/// from `Ok(None)` to avoid misleading diagnostics when an error
496496
/// has already been/will be emitted, for the original cause
497-
#[instrument(level = "debug", skip(tcx), ret)]
498-
pub fn try_resolve(
497+
fn try_resolve_inner(
499498
tcx: TyCtxt<'tcx>,
500499
typing_env: ty::TypingEnv<'tcx>,
501500
def_id: DefId,
502501
args: GenericArgsRef<'tcx>,
502+
mut constness: hir::Constness,
503503
) -> Result<Option<Instance<'tcx>>, ErrorGuaranteed> {
504+
let def_kind = tcx.def_kind(def_id);
504505
assert_matches!(
505-
tcx.def_kind(def_id),
506+
def_kind,
506507
DefKind::Fn
507508
| DefKind::AssocFn
508509
| DefKind::Const { .. }
@@ -518,6 +519,10 @@ impl<'tcx> Instance<'tcx> {
518519
`try_normalize_erasing_regions`."
519520
);
520521

522+
if !def_kind.is_assoc() {
523+
constness = hir::Constness::NotConst;
524+
}
525+
521526
// Rust code can easily create exponentially-long types using only a
522527
// polynomial recursion depth. Even with the default recursion
523528
// depth, you can easily get cases that take >2^60 steps to run,
@@ -530,11 +535,36 @@ impl<'tcx> Instance<'tcx> {
530535
return Ok(None);
531536
}
532537

538+
let input = tcx.erase_and_anonymize_regions(typing_env.as_query_input((def_id, args)));
539+
533540
// All regions in the result of this query are erased, so it's
534541
// fine to erase all of the input regions.
535-
tcx.resolve_instance_raw(
536-
tcx.erase_and_anonymize_regions(typing_env.as_query_input((def_id, args))),
537-
)
542+
tcx.resolve_instance_raw(ty::PseudoCanonicalInput {
543+
typing_env: input.typing_env,
544+
value: (input.value.0, input.value.1, constness),
545+
})
546+
}
547+
548+
/// See `try_resolve_inner`.
549+
#[instrument(level = "debug", skip(tcx), ret)]
550+
pub fn try_resolve(
551+
tcx: TyCtxt<'tcx>,
552+
typing_env: ty::TypingEnv<'tcx>,
553+
def_id: DefId,
554+
args: GenericArgsRef<'tcx>,
555+
) -> Result<Option<Instance<'tcx>>, ErrorGuaranteed> {
556+
Self::try_resolve_inner(tcx, typing_env, def_id, args, hir::Constness::NotConst)
557+
}
558+
559+
/// See `try_resolve_inner`.
560+
#[instrument(level = "debug", skip(tcx), ret)]
561+
pub fn try_resolve_for_ctfe(
562+
tcx: TyCtxt<'tcx>,
563+
typing_env: ty::TypingEnv<'tcx>,
564+
def_id: DefId,
565+
args: GenericArgsRef<'tcx>,
566+
) -> Result<Option<Instance<'tcx>>, ErrorGuaranteed> {
567+
Self::try_resolve_inner(tcx, typing_env, def_id, args, hir::Constness::Const)
538568
}
539569

540570
pub fn expect_resolve(

compiler/rustc_trait_selection/src/solve/select.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use rustc_infer::traits::{
88
Selection, SelectionError, SelectionResult, TraitObligation,
99
};
1010
use rustc_macros::extension;
11-
use rustc_middle::{bug, span_bug};
11+
use rustc_middle::{bug, span_bug, ty};
1212
use rustc_span::Span;
1313
use thin_vec::thin_vec;
1414

@@ -30,6 +30,19 @@ impl<'tcx> InferCtxt<'tcx> {
3030
.break_value()
3131
.unwrap()
3232
}
33+
fn select_host_effect_predicate_in_new_trait_solver(
34+
&self,
35+
obligation: &Obligation<'tcx, ty::HostEffectPredicate<'tcx>>,
36+
) -> SelectionResult<'tcx, Selection<'tcx>> {
37+
assert!(self.next_trait_solver());
38+
39+
self.visit_proof_tree(
40+
Goal::new(self.tcx, obligation.param_env, ty::Binder::dummy(obligation.predicate)),
41+
&mut Select { span: obligation.cause.span },
42+
)
43+
.break_value()
44+
.unwrap()
45+
}
3346
}
3447

3548
struct Select {

compiler/rustc_ty_utils/src/instance.rs

Lines changed: 80 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use rustc_errors::ErrorGuaranteed;
2-
use rustc_hir::LangItem;
32
use rustc_hir::def_id::DefId;
3+
use rustc_hir::{Constness, LangItem};
44
use rustc_infer::infer::TyCtxtInferExt;
5+
use rustc_infer::traits::{
6+
ImplSource, Obligation, ObligationCause, ScrubbedTraitError, SelectionError,
7+
};
58
use rustc_middle::bug;
69
use rustc_middle::query::Providers;
710
use rustc_middle::traits::{BuiltinImplSource, CodegenObligationError};
@@ -10,17 +13,19 @@ use rustc_middle::ty::{
1013
Unnormalized,
1114
};
1215
use rustc_span::sym;
13-
use rustc_trait_selection::traits;
16+
use rustc_trait_selection::error_reporting::InferCtxtErrorExt as _;
17+
use rustc_trait_selection::solve::InferCtxtSelectExt as _;
18+
use rustc_trait_selection::traits::{self, ObligationCtxt};
1419
use tracing::debug;
1520
use traits::translate_args;
1621

1722
use crate::errors::UnexpectedFnPtrAssociatedItem;
1823

1924
fn resolve_instance_raw<'tcx>(
2025
tcx: TyCtxt<'tcx>,
21-
key: ty::PseudoCanonicalInput<'tcx, (DefId, GenericArgsRef<'tcx>)>,
26+
key: ty::PseudoCanonicalInput<'tcx, (DefId, GenericArgsRef<'tcx>, Constness)>,
2227
) -> Result<Option<Instance<'tcx>>, ErrorGuaranteed> {
23-
let PseudoCanonicalInput { typing_env, value: (def_id, args) } = key;
28+
let PseudoCanonicalInput { typing_env, value: (def_id, args, constness) } = key;
2429

2530
let result = if let Some(trait_def_id) = tcx.trait_of_assoc(def_id) {
2631
debug!(" => associated item, attempting to find impl in typing_env {:#?}", typing_env);
@@ -30,6 +35,7 @@ fn resolve_instance_raw<'tcx>(
3035
typing_env,
3136
trait_def_id,
3237
tcx.normalize_erasing_regions(typing_env, Unnormalized::new_wip(args)),
38+
constness,
3339
)
3440
} else {
3541
let def = if tcx.intrinsic(def_id).is_some() {
@@ -101,24 +107,88 @@ fn resolve_instance_raw<'tcx>(
101107
result
102108
}
103109

110+
#[tracing::instrument(level = "debug", skip(tcx))]
104111
fn resolve_associated_item<'tcx>(
105112
tcx: TyCtxt<'tcx>,
106113
trait_item_id: DefId,
107114
typing_env: ty::TypingEnv<'tcx>,
108115
trait_id: DefId,
109116
rcvr_args: GenericArgsRef<'tcx>,
117+
constness: Constness,
110118
) -> Result<Option<Instance<'tcx>>, ErrorGuaranteed> {
111-
debug!(?trait_item_id, ?typing_env, ?trait_id, ?rcvr_args, "resolve_associated_item");
112-
113119
let trait_ref = ty::TraitRef::from_assoc(tcx, trait_id, rcvr_args);
114120

115121
let input = typing_env.as_query_input(trait_ref);
116-
let vtbl = match tcx.codegen_select_candidate(input) {
117-
Ok(vtbl) => vtbl,
118-
Err(CodegenObligationError::Ambiguity | CodegenObligationError::Unimplemented) => {
122+
let vtbl = if constness == Constness::Const && tcx.next_trait_solver_globally() {
123+
let (infcx, param_env) =
124+
tcx.infer_ctxt().ignoring_regions().build_with_typing_env(typing_env);
125+
126+
let obligation_cause = ObligationCause::dummy();
127+
let host_predicate =
128+
ty::HostEffectPredicate { trait_ref, constness: ty::BoundConstness::Const };
129+
let obligation = Obligation::new(tcx, obligation_cause, param_env, host_predicate);
130+
131+
let selection = match infcx.select_host_effect_predicate_in_new_trait_solver(&obligation) {
132+
Ok(Some(selection)) => selection,
133+
Ok(None) => return Ok(None),
134+
Err(SelectionError::Unimplemented) => return Ok(None),
135+
Err(e) => {
136+
bug!("Encountered error `{:?}` selecting `{:?}` during codegen", e, host_predicate)
137+
}
138+
};
139+
140+
debug!(?selection);
141+
142+
// Currently, we use a fulfillment context to completely resolve
143+
// all nested obligations. This is because they can inform the
144+
// inference of the impl's type parameters.
145+
let ocx = ObligationCtxt::new(&infcx);
146+
let impl_source = selection.map(|obligation| {
147+
ocx.register_obligation(obligation);
148+
});
149+
150+
// In principle, we only need to do this so long as `impl_source`
151+
// contains unbound type parameters. It could be a slight
152+
// optimization to stop iterating early.
153+
let errors = ocx.evaluate_obligations_error_on_ambiguity();
154+
if !errors.is_empty() {
155+
// `rustc_monomorphize::collector` assumes there are no type errors.
156+
// Cycle errors are the only post-monomorphization errors possible; emit them now so
157+
// `rustc_ty_utils::resolve_associated_item` doesn't return `None` post-monomorphization.
158+
for err in errors {
159+
if let ScrubbedTraitError::Cycle(cycle) = err {
160+
infcx.err_ctxt().report_overflow_obligation_cycle(&cycle);
161+
}
162+
}
119163
return Ok(None);
120164
}
121-
Err(CodegenObligationError::UnconstrainedParam(guar)) => return Err(guar),
165+
166+
let impl_source = infcx.resolve_vars_if_possible(impl_source);
167+
let impl_source = tcx.erase_and_anonymize_regions(impl_source);
168+
if impl_source.has_non_region_infer() {
169+
// Unused generic types or consts on an impl get replaced with inference vars,
170+
// but never resolved, causing the return value of a query to contain inference
171+
// vars. We do not have a concept for this and will in fact ICE in stable hashing
172+
// of the return value. So bail out instead.
173+
let guar = match impl_source {
174+
ImplSource::UserDefined(impl_) => tcx.dcx().span_delayed_bug(
175+
tcx.def_span(impl_.impl_def_id),
176+
"this impl has unconstrained generic parameters",
177+
),
178+
_ => unreachable!(),
179+
};
180+
return Err(guar);
181+
}
182+
183+
&*tcx.arena.alloc(impl_source)
184+
} else {
185+
match tcx.codegen_select_candidate(input) {
186+
Ok(vtbl) => vtbl,
187+
Err(CodegenObligationError::Ambiguity | CodegenObligationError::Unimplemented) => {
188+
return Ok(None);
189+
}
190+
Err(CodegenObligationError::UnconstrainedParam(guar)) => return Err(guar),
191+
}
122192
};
123193

124194
// Now that we know which impl is being used, we can dispatch to
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//@ compile-flags: -Znext-solver
2+
//@ check-pass
3+
4+
#![feature(min_specialization, const_trait_impl)]
5+
6+
struct Dummy;
7+
8+
const trait DummyTrait {
9+
fn dummy_fn() -> u32;
10+
}
11+
impl DummyTrait for Wrap<i32> {
12+
fn dummy_fn() -> u32 {
13+
println!("wut");
14+
0
15+
}
16+
}
17+
18+
const trait Trait {
19+
fn trait_fn() -> u32;
20+
}
21+
impl<T> const Trait for T where T: DummyTrait {
22+
default fn trait_fn() -> u32 {
23+
42
24+
}
25+
}
26+
27+
struct Wrap<T>(T);
28+
29+
impl<T> const Trait for Wrap<T> where Self: [const] DummyTrait {
30+
fn trait_fn() -> u32 {
31+
<Wrap<T>>::dummy_fn()
32+
}
33+
}
34+
35+
const fn indirect<T: DummyTrait>() -> u32 {
36+
T::trait_fn()
37+
}
38+
39+
const A: u32 = indirect::<Wrap<i32>>();
40+
41+
const B: () = { assert!(A == 42); };
42+
43+
fn main() {}

0 commit comments

Comments
 (0)