Skip to content

Commit c021d2d

Browse files
beetreesfolkertdev
authored andcommitted
Fallback {float} to f32 when f32: From<{float}>
1 parent 7db0ab4 commit c021d2d

9 files changed

Lines changed: 207 additions & 5 deletions

File tree

compiler/rustc_hir/src/lang_items.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,9 @@ language_item_table! {
443443
FieldBase, sym::field_base, field_base, Target::AssocTy, GenericRequirement::Exact(0);
444444
FieldType, sym::field_type, field_type, Target::AssocTy, GenericRequirement::Exact(0);
445445
FieldOffset, sym::field_offset, field_offset, Target::AssocConst, GenericRequirement::Exact(0);
446+
447+
// Used to fallback `{float}` to `f32` when `f32: From<{float}>`
448+
From, sym::From, from_trait, Target::Trait, GenericRequirement::Exact(1);
446449
}
447450

448451
/// The requirement imposed on the generics of a lang item

compiler/rustc_hir_typeck/src/fallback.rs

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use rustc_hir::attrs::DivergingFallbackBehavior;
1111
use rustc_hir::def::{DefKind, Res};
1212
use rustc_hir::def_id::DefId;
1313
use rustc_hir::intravisit::{InferKind, Visitor};
14-
use rustc_middle::ty::{self, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable};
14+
use rustc_middle::ty::{self, FloatVid, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable};
1515
use rustc_session::lint;
1616
use rustc_span::def_id::LocalDefId;
1717
use rustc_span::{DUMMY_SP, Span};
@@ -55,15 +55,20 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
5555

5656
let (diverging_fallback, diverging_fallback_ty) =
5757
self.calculate_diverging_fallback(&unresolved_variables);
58+
let fallback_to_f32 = self.calculate_fallback_to_f32(&unresolved_variables);
5859

5960
// We do fallback in two passes, to try to generate
6061
// better error messages.
6162
// The first time, we do *not* replace opaque types.
6263
let mut fallback_occurred = false;
6364
for ty in unresolved_variables {
6465
debug!("unsolved_variable = {:?}", ty);
65-
fallback_occurred |=
66-
self.fallback_if_possible(ty, &diverging_fallback, diverging_fallback_ty);
66+
fallback_occurred |= self.fallback_if_possible(
67+
ty,
68+
&diverging_fallback,
69+
diverging_fallback_ty,
70+
&fallback_to_f32,
71+
);
6772
}
6873

6974
fallback_occurred
@@ -73,7 +78,8 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
7378
///
7479
/// - Unconstrained ints are replaced with `i32`.
7580
///
76-
/// - Unconstrained floats are replaced with `f64`.
81+
/// - Unconstrained floats are replaced with `f64`, except when there is a trait predicate
82+
/// `f32: From<{float}>`, in which case `f32` is used as the fallback instead.
7783
///
7884
/// - Non-numerics may get replaced with `()` or `!`, depending on how they
7985
/// were categorized by [`Self::calculate_diverging_fallback`], crate's
@@ -89,6 +95,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
8995
ty: Ty<'tcx>,
9096
diverging_fallback: &UnordSet<Ty<'tcx>>,
9197
diverging_fallback_ty: Ty<'tcx>,
98+
fallback_to_f32: &UnordSet<FloatVid>,
9299
) -> bool {
93100
// Careful: we do NOT shallow-resolve `ty`. We know that `ty`
94101
// is an unsolved variable, and we determine its fallback
@@ -111,6 +118,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
111118
let fallback = match ty.kind() {
112119
_ if let Some(e) = self.tainted_by_errors() => Ty::new_error(self.tcx, e),
113120
ty::Infer(ty::IntVar(_)) => self.tcx.types.i32,
121+
ty::Infer(ty::FloatVar(vid)) if fallback_to_f32.contains(vid) => self.tcx.types.f32,
114122
ty::Infer(ty::FloatVar(_)) => self.tcx.types.f64,
115123
_ if diverging_fallback.contains(&ty) => {
116124
self.diverging_fallback_has_occurred.set(true);
@@ -125,6 +133,38 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
125133
true
126134
}
127135

136+
/// Existing code relies on `f32: From<T>` (usually written as `T: Into<f32>`) resolving `T` to
137+
/// `f32` when the type of `T` is inferred from an unsuffixed float literal. Using the default
138+
/// fallback of `f64`, this would break when adding `impl From<f16> for f32`, as there are now
139+
/// two float type which could be `T`, meaning that the fallback of `f64` would be used and
140+
/// compilation error would occur as `f32` does not implement `From<f64>`. To avoid breaking
141+
/// existing code, we instead fallback `T` to `f32` when there is a trait predicate
142+
/// `f32: From<T>`. This means code like the following will continue to compile:
143+
///
144+
/// ```rust
145+
/// fn foo<T: Into<f32>>(_: T) {}
146+
///
147+
/// foo(1.0);
148+
/// ```
149+
fn calculate_fallback_to_f32(&self, unresolved_variables: &[Ty<'tcx>]) -> UnordSet<FloatVid> {
150+
let roots: UnordSet<ty::FloatVid> = self.from_float_for_f32_root_vids();
151+
if roots.is_empty() {
152+
// Most functions have no `f32: From<{float}>` predicates, so short-circuit and return
153+
// an empty set when this is the case.
154+
return UnordSet::new();
155+
}
156+
// Calculate all the unresolved variables that need to fallback to `f32` here. This ensures
157+
// we don't need to find root variables in `fallback_if_possible`: see the comment at the
158+
// top of that function for details.
159+
let fallback_to_f32 = unresolved_variables
160+
.iter()
161+
.flat_map(|ty| ty.float_vid())
162+
.filter(|vid| roots.contains(&self.root_float_var(*vid)))
163+
.collect();
164+
debug!("calculate_fallback_to_f32: fallback_to_f32={:?}", fallback_to_f32);
165+
fallback_to_f32
166+
}
167+
128168
fn calculate_diverging_fallback(
129169
&self,
130170
unresolved_variables: &[Ty<'tcx>],
@@ -362,6 +402,11 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
362402
Some(self.root_var(self.shallow_resolve(ty).ty_vid()?))
363403
}
364404

405+
/// If `ty` is an unresolved float type variable, returns its root vid.
406+
pub(crate) fn root_float_vid(&self, ty: Ty<'tcx>) -> Option<ty::FloatVid> {
407+
Some(self.root_float_var(self.shallow_resolve(ty).float_vid()?))
408+
}
409+
365410
/// Given a set of diverging vids and coercions, walk the HIR to gather a
366411
/// set of suggestions which can be applied to preserve fallback to unit.
367412
fn try_to_suggest_annotations(

compiler/rustc_hir_typeck/src/fn_ctxt/inspect_obligations.rs

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
//! A utility module to inspect currently ambiguous obligations in the current context.
22
3+
use rustc_data_structures::unord::UnordSet;
4+
use rustc_hir::def_id::DefId;
35
use rustc_infer::traits::{self, ObligationCause, PredicateObligations};
46
use rustc_middle::ty::{self, Ty, TypeVisitableExt};
57
use rustc_span::Span;
@@ -96,6 +98,69 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
9698
});
9799
obligations_for_self_ty
98100
}
101+
102+
/// Only needed for the `From<{float}>` for `f32` type fallback.
103+
#[instrument(skip(self), level = "debug")]
104+
pub(crate) fn from_float_for_f32_root_vids(&self) -> UnordSet<ty::FloatVid> {
105+
if self.next_trait_solver() {
106+
self.from_float_for_f32_root_vids_next()
107+
} else {
108+
let Some(from_trait) = self.tcx.lang_items().from_trait() else {
109+
return UnordSet::new();
110+
};
111+
self.fulfillment_cx
112+
.borrow_mut()
113+
.pending_obligations()
114+
.into_iter()
115+
.filter_map(|obligation| {
116+
self.predicate_from_float_for_f32_root_vid(from_trait, obligation.predicate)
117+
})
118+
.collect()
119+
}
120+
}
121+
122+
fn predicate_from_float_for_f32_root_vid(
123+
&self,
124+
from_trait: DefId,
125+
predicate: ty::Predicate<'tcx>,
126+
) -> Option<ty::FloatVid> {
127+
// The predicates we are looking for look like
128+
// `TraitPredicate(<f32 as std::convert::From<{float}>>, polarity:Positive)`.
129+
// They will have no bound variables.
130+
match predicate.kind().no_bound_vars() {
131+
Some(ty::PredicateKind::Clause(ty::ClauseKind::Trait(ty::TraitPredicate {
132+
polarity: ty::PredicatePolarity::Positive,
133+
trait_ref,
134+
}))) if trait_ref.def_id == from_trait
135+
&& self.shallow_resolve(trait_ref.self_ty()).kind()
136+
== &ty::Float(ty::FloatTy::F32) =>
137+
{
138+
self.root_float_vid(trait_ref.args.type_at(1))
139+
}
140+
_ => None,
141+
}
142+
}
143+
144+
fn from_float_for_f32_root_vids_next(&self) -> UnordSet<ty::FloatVid> {
145+
let Some(from_trait) = self.tcx.lang_items().from_trait() else {
146+
return UnordSet::new();
147+
};
148+
let obligations = self.fulfillment_cx.borrow().pending_obligations();
149+
debug!(?obligations);
150+
let mut vids = UnordSet::new();
151+
for obligation in obligations {
152+
let mut visitor = FindFromFloatForF32RootVids {
153+
fcx: self,
154+
from_trait,
155+
vids: &mut vids,
156+
span: obligation.cause.span,
157+
};
158+
159+
let goal = obligation.as_goal();
160+
self.visit_proof_tree(goal, &mut visitor);
161+
}
162+
vids
163+
}
99164
}
100165

101166
struct NestedObligationsForSelfTy<'a, 'tcx> {
@@ -105,7 +170,7 @@ struct NestedObligationsForSelfTy<'a, 'tcx> {
105170
obligations_for_self_ty: &'a mut PredicateObligations<'tcx>,
106171
}
107172

108-
impl<'a, 'tcx> ProofTreeVisitor<'tcx> for NestedObligationsForSelfTy<'a, 'tcx> {
173+
impl<'tcx> ProofTreeVisitor<'tcx> for NestedObligationsForSelfTy<'_, 'tcx> {
109174
fn span(&self) -> Span {
110175
self.root_cause.span
111176
}
@@ -144,3 +209,37 @@ impl<'a, 'tcx> ProofTreeVisitor<'tcx> for NestedObligationsForSelfTy<'a, 'tcx> {
144209
}
145210
}
146211
}
212+
213+
struct FindFromFloatForF32RootVids<'a, 'tcx> {
214+
fcx: &'a FnCtxt<'a, 'tcx>,
215+
from_trait: DefId,
216+
vids: &'a mut UnordSet<ty::FloatVid>,
217+
span: Span,
218+
}
219+
220+
impl<'tcx> ProofTreeVisitor<'tcx> for FindFromFloatForF32RootVids<'_, 'tcx> {
221+
fn span(&self) -> Span {
222+
self.span
223+
}
224+
225+
fn config(&self) -> InspectConfig {
226+
// Avoid hang from exponentially growing proof trees (see `cycle-modulo-ambig-aliases.rs`).
227+
// 3 is more than enough for all occurences in practice (a.k.a. `Into`).
228+
InspectConfig { max_depth: 3 }
229+
}
230+
231+
fn visit_goal(&mut self, inspect_goal: &InspectGoal<'_, 'tcx>) {
232+
if let Some(vid) = self
233+
.fcx
234+
.predicate_from_float_for_f32_root_vid(self.from_trait, inspect_goal.goal().predicate)
235+
{
236+
self.vids.insert(vid);
237+
} else if let Some(candidate) = inspect_goal.unique_applicable_candidate() {
238+
let start_len = self.vids.len();
239+
let _ = candidate.goal().infcx().commit_if_ok(|_| {
240+
candidate.visit_nested_no_probe(self);
241+
if self.vids.len() > start_len { Ok(()) } else { Err(()) }
242+
});
243+
}
244+
}
245+
}

compiler/rustc_infer/src/infer/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,6 +1166,10 @@ impl<'tcx> InferCtxt<'tcx> {
11661166
self.inner.borrow_mut().type_variables().sub_unification_table_root_var(var)
11671167
}
11681168

1169+
pub fn root_float_var(&self, var: ty::FloatVid) -> ty::FloatVid {
1170+
self.inner.borrow_mut().float_unification_table().find(var)
1171+
}
1172+
11691173
pub fn root_const_var(&self, var: ty::ConstVid) -> ty::ConstVid {
11701174
self.inner.borrow_mut().const_unification_table().find(var).vid
11711175
}

compiler/rustc_middle/src/ty/sty.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,6 +1180,14 @@ impl<'tcx> Ty<'tcx> {
11801180
}
11811181
}
11821182

1183+
#[inline]
1184+
pub fn float_vid(self) -> Option<ty::FloatVid> {
1185+
match self.kind() {
1186+
&Infer(FloatVar(vid)) => Some(vid),
1187+
_ => None,
1188+
}
1189+
}
1190+
11831191
#[inline]
11841192
pub fn is_ty_or_numeric_infer(self) -> bool {
11851193
matches!(self.kind(), Infer(_))

library/core/src/convert/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,7 @@ pub const trait Into<T>: Sized {
577577
/// [`from`]: From::from
578578
/// [book]: ../../book/ch09-00-error-handling.html
579579
#[rustc_diagnostic_item = "From"]
580+
#[lang = "From"]
580581
#[stable(feature = "rust1", since = "1.0.0")]
581582
#[rustc_on_unimplemented(on(
582583
all(Self = "&str", T = "alloc::string::String"),

tests/ui/float/f32-into-f32.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//@ revisions: old-solver next-solver
2+
//@[next-solver] compile-flags: -Znext-solver
3+
//@ run-pass
4+
5+
fn foo(_: impl Into<f32>) {}
6+
7+
fn main() {
8+
foo(1.0);
9+
}

tests/ui/float/trait-f16-or-f32.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
//@ check-fail
2+
3+
#![feature(f16)]
4+
5+
trait Trait {}
6+
impl Trait for f16 {}
7+
impl Trait for f32 {}
8+
9+
fn foo(_: impl Trait) {}
10+
11+
fn main() {
12+
foo(1.0); //~ ERROR the trait bound `f64: Trait` is not satisfied
13+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
error[E0277]: the trait bound `f64: Trait` is not satisfied
2+
--> $DIR/trait-f16-or-f32.rs:12:9
3+
|
4+
LL | foo(1.0);
5+
| --- ^^^ the trait `Trait` is not implemented for `f64`
6+
| |
7+
| required by a bound introduced by this call
8+
|
9+
= help: the following other types implement trait `Trait`:
10+
f16
11+
f32
12+
note: required by a bound in `foo`
13+
--> $DIR/trait-f16-or-f32.rs:9:16
14+
|
15+
LL | fn foo(_: impl Trait) {}
16+
| ^^^^^ required by this bound in `foo`
17+
18+
error: aborting due to 1 previous error
19+
20+
For more information about this error, try `rustc --explain E0277`.

0 commit comments

Comments
 (0)