Skip to content

Commit 8af02e2

Browse files
mGCA: Validate const literal against expected type
Co-authored-by: Boxy <rust@boxyuwu.dev>
1 parent 381e9ef commit 8af02e2

25 files changed

Lines changed: 402 additions & 131 deletions

compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
3535
use rustc_infer::traits::DynCompatibilityViolation;
3636
use rustc_macros::{TypeFoldable, TypeVisitable};
3737
use rustc_middle::middle::stability::AllowUnstable;
38-
use rustc_middle::mir::interpret::LitToConstInput;
38+
use rustc_middle::mir::interpret::{LitToConstInput, const_lit_matches_ty};
3939
use rustc_middle::ty::print::PrintPolyTraitRefExt as _;
4040
use rustc_middle::ty::{
4141
self, Const, GenericArgKind, GenericArgsRef, GenericParamDefKind, Ty, TyCtxt,
@@ -2803,8 +2803,17 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
28032803
span: Span,
28042804
) -> Const<'tcx> {
28052805
let tcx = self.tcx();
2806+
if let LitKind::Err(guar) = *kind {
2807+
return ty::Const::new_error(tcx, guar);
2808+
}
28062809
let input = LitToConstInput { lit: *kind, ty, neg };
2807-
tcx.at(span).lit_to_const(input)
2810+
match tcx.at(span).lit_to_const(input) {
2811+
Some(value) => ty::Const::new_value(tcx, value.valtree, value.ty),
2812+
None => {
2813+
let e = tcx.dcx().span_err(span, "type annotations needed for the literal");
2814+
ty::Const::new_error(tcx, e)
2815+
}
2816+
}
28082817
}
28092818

28102819
#[instrument(skip(self), level = "debug")]
@@ -2833,11 +2842,15 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
28332842
_ => None,
28342843
};
28352844

2836-
lit_input
2837-
// Allow the `ty` to be an alias type, though we cannot handle it here, we just go through
2838-
// the more expensive anon const code path.
2839-
.filter(|l| !l.ty.has_aliases())
2840-
.map(|l| tcx.at(expr.span).lit_to_const(l))
2845+
lit_input.and_then(|l| {
2846+
if const_lit_matches_ty(tcx, &l.lit, l.ty, l.neg) {
2847+
tcx.at(expr.span)
2848+
.lit_to_const(l)
2849+
.map(|value| ty::Const::new_value(tcx, value.valtree, value.ty))
2850+
} else {
2851+
None
2852+
}
2853+
})
28412854
}
28422855

28432856
fn require_type_const_attribute(

compiler/rustc_middle/src/mir/interpret/mod.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,44 @@ pub struct LitToConstInput<'tcx> {
8484
pub neg: bool,
8585
}
8686

87+
pub fn const_lit_matches_ty<'tcx>(
88+
tcx: TyCtxt<'tcx>,
89+
kind: &LitKind,
90+
ty: Ty<'tcx>,
91+
neg: bool,
92+
) -> bool {
93+
match (*kind, ty.kind()) {
94+
(LitKind::Str(..), ty::Ref(_, inner_ty, _)) if inner_ty.is_str() => true,
95+
(LitKind::Str(..), ty::Str) if tcx.features().deref_patterns() => true,
96+
(LitKind::ByteStr(..), ty::Ref(_, inner_ty, _))
97+
if let ty::Slice(ty) | ty::Array(ty, _) = inner_ty.kind()
98+
&& matches!(ty.kind(), ty::Uint(ty::UintTy::U8)) =>
99+
{
100+
true
101+
}
102+
(LitKind::ByteStr(..), ty::Slice(inner_ty) | ty::Array(inner_ty, _))
103+
if tcx.features().deref_patterns()
104+
&& matches!(inner_ty.kind(), ty::Uint(ty::UintTy::U8)) =>
105+
{
106+
true
107+
}
108+
(LitKind::Byte(..), ty::Uint(ty::UintTy::U8)) => true,
109+
(LitKind::CStr(..), ty::Ref(_, inner_ty, _))
110+
if matches!(inner_ty.kind(), ty::Adt(def, _)
111+
if tcx.is_lang_item(def.did(), rustc_hir::LangItem::CStr)) =>
112+
{
113+
true
114+
}
115+
(LitKind::Int(..), ty::Uint(_)) if !neg => true,
116+
(LitKind::Int(..), ty::Int(_)) => true,
117+
(LitKind::Bool(..), ty::Bool) => true,
118+
(LitKind::Float(..), ty::Float(_)) => true,
119+
(LitKind::Char(..), ty::Char) => true,
120+
(LitKind::Err(..), _) => true,
121+
_ => false,
122+
}
123+
}
124+
87125
#[derive(Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
88126
pub struct AllocId(pub NonZero<u64>);
89127

compiler/rustc_middle/src/queries.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1412,7 +1412,7 @@ rustc_queries! {
14121412
// FIXME get rid of this with valtrees
14131413
query lit_to_const(
14141414
key: LitToConstInput<'tcx>
1415-
) -> ty::Const<'tcx> {
1415+
) -> Option<ty::Value<'tcx>> {
14161416
desc { "converting literal to const" }
14171417
}
14181418

compiler/rustc_middle/src/query/erase.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,10 @@ impl Erasable for Option<ty::EarlyBinder<'_, Ty<'_>>> {
256256
type Storage = [u8; size_of::<Option<ty::EarlyBinder<'static, Ty<'static>>>>()];
257257
}
258258

259+
impl Erasable for Option<ty::Value<'_>> {
260+
type Storage = [u8; size_of::<Option<ty::Value<'static>>>()];
261+
}
262+
259263
impl Erasable for rustc_hir::MaybeOwner<'_> {
260264
type Storage = [u8; size_of::<rustc_hir::MaybeOwner<'static>>()];
261265
}

compiler/rustc_mir_build/src/thir/constant.rs

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@ use rustc_ast::{self as ast, UintTy};
33
use rustc_hir::LangItem;
44
use rustc_middle::bug;
55
use rustc_middle::mir::interpret::LitToConstInput;
6-
use rustc_middle::ty::{self, ScalarInt, TyCtxt, TypeVisitableExt as _};
6+
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt, TypeVisitableExt as _};
77
use tracing::trace;
88

99
use crate::builder::parse_float_into_scalar;
1010

1111
pub(crate) fn lit_to_const<'tcx>(
1212
tcx: TyCtxt<'tcx>,
1313
lit_input: LitToConstInput<'tcx>,
14-
) -> ty::Const<'tcx> {
15-
let LitToConstInput { lit, ty, neg } = lit_input;
14+
) -> Option<ty::Value<'tcx>> {
15+
let LitToConstInput { lit, ty: expected_ty, neg } = lit_input;
1616

17-
if let Err(guar) = ty.error_reported() {
18-
return ty::Const::new_error(tcx, guar);
17+
if expected_ty.error_reported().is_err() {
18+
return None;
1919
}
2020

2121
let trunc = |n, width: ty::UintTy| {
@@ -32,63 +32,84 @@ pub(crate) fn lit_to_const<'tcx>(
3232
.unwrap_or_else(|| bug!("expected to create ScalarInt from uint {:?}", result))
3333
};
3434

35-
let valtree = match (lit, ty.kind()) {
36-
(ast::LitKind::Str(s, _), ty::Ref(_, inner_ty, _)) if inner_ty.is_str() => {
35+
let (valtree, valtree_ty) = match (lit, expected_ty.kind()) {
36+
(ast::LitKind::Str(s, _), _) => {
3737
let str_bytes = s.as_str().as_bytes();
38-
ty::ValTree::from_raw_bytes(tcx, str_bytes)
39-
}
40-
(ast::LitKind::Str(s, _), ty::Str) if tcx.features().deref_patterns() => {
41-
// String literal patterns may have type `str` if `deref_patterns` is enabled, in order
42-
// to allow `deref!("..."): String`.
43-
let str_bytes = s.as_str().as_bytes();
44-
ty::ValTree::from_raw_bytes(tcx, str_bytes)
38+
let valtree_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_static, tcx.types.str_);
39+
(ty::ValTree::from_raw_bytes(tcx, str_bytes), valtree_ty)
4540
}
4641
(ast::LitKind::ByteStr(byte_sym, _), ty::Ref(_, inner_ty, _))
4742
if let ty::Slice(ty) | ty::Array(ty, _) = inner_ty.kind()
4843
&& let ty::Uint(UintTy::U8) = ty.kind() =>
4944
{
50-
ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str())
45+
(ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str()), expected_ty)
5146
}
5247
(ast::LitKind::ByteStr(byte_sym, _), ty::Slice(inner_ty) | ty::Array(inner_ty, _))
5348
if tcx.features().deref_patterns()
5449
&& let ty::Uint(UintTy::U8) = inner_ty.kind() =>
5550
{
5651
// Byte string literal patterns may have type `[u8]` or `[u8; N]` if `deref_patterns` is
5752
// enabled, in order to allow, e.g., `deref!(b"..."): Vec<u8>`.
58-
ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str())
53+
(ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str()), expected_ty)
5954
}
60-
(ast::LitKind::Byte(n), ty::Uint(ty::UintTy::U8)) => {
61-
ty::ValTree::from_scalar_int(tcx, n.into())
55+
(ast::LitKind::ByteStr(byte_sym, _), _) => {
56+
let valtree = ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str());
57+
let valtree_ty = Ty::new_array(tcx, tcx.types.u8, byte_sym.as_byte_str().len() as u64);
58+
(valtree, valtree_ty)
6259
}
63-
(ast::LitKind::CStr(byte_sym, _), ty::Ref(_, inner_ty, _)) if matches!(inner_ty.kind(), ty::Adt(def, _) if tcx.is_lang_item(def.did(), LangItem::CStr)) =>
60+
(ast::LitKind::Byte(n), _) => (ty::ValTree::from_scalar_int(tcx, n.into()), tcx.types.u8),
61+
(ast::LitKind::CStr(byte_sym, _), _)
62+
if let Some(cstr_def_id) = tcx.lang_items().get(LangItem::CStr) =>
6463
{
6564
// A CStr is a newtype around a byte slice, so we create the inner slice here.
6665
// We need a branch for each "level" of the data structure.
66+
let cstr_ty = tcx.type_of(cstr_def_id).skip_binder();
6767
let bytes = ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str());
68-
ty::ValTree::from_branches(tcx, [ty::Const::new_value(tcx, bytes, *inner_ty)])
68+
let valtree =
69+
ty::ValTree::from_branches(tcx, [ty::Const::new_value(tcx, bytes, cstr_ty)]);
70+
let valtree_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_static, cstr_ty);
71+
(valtree, valtree_ty)
6972
}
70-
(ast::LitKind::Int(n, _), ty::Uint(ui)) if !neg => {
73+
(ast::LitKind::Int(n, ast::LitIntType::Unsigned(ui)), _) if !neg => {
74+
let scalar_int = trunc(n.get(), ui);
75+
(ty::ValTree::from_scalar_int(tcx, scalar_int), Ty::new_uint(tcx, ui))
76+
}
77+
(ast::LitKind::Int(_, ast::LitIntType::Unsigned(_)), _) if neg => return None,
78+
(ast::LitKind::Int(n, ast::LitIntType::Signed(i)), _) => {
79+
let scalar_int =
80+
trunc(if neg { u128::wrapping_neg(n.get()) } else { n.get() }, i.to_unsigned());
81+
(ty::ValTree::from_scalar_int(tcx, scalar_int), Ty::new_int(tcx, i))
82+
}
83+
(ast::LitKind::Int(n, ast::LitIntType::Unsuffixed), ty::Uint(ui)) if !neg => {
7184
let scalar_int = trunc(n.get(), *ui);
72-
ty::ValTree::from_scalar_int(tcx, scalar_int)
85+
(ty::ValTree::from_scalar_int(tcx, scalar_int), Ty::new_uint(tcx, *ui))
7386
}
74-
(ast::LitKind::Int(n, _), ty::Int(i)) => {
87+
(ast::LitKind::Int(n, ast::LitIntType::Unsuffixed), ty::Int(i)) => {
7588
// Unsigned "negation" has the same bitwise effect as signed negation,
7689
// which gets the result we want without additional casts.
7790
let scalar_int =
7891
trunc(if neg { u128::wrapping_neg(n.get()) } else { n.get() }, i.to_unsigned());
79-
ty::ValTree::from_scalar_int(tcx, scalar_int)
92+
(ty::ValTree::from_scalar_int(tcx, scalar_int), Ty::new_int(tcx, *i))
93+
}
94+
(ast::LitKind::Bool(b), _) => (ty::ValTree::from_scalar_int(tcx, b.into()), tcx.types.bool),
95+
(ast::LitKind::Float(n, ast::LitFloatType::Suffixed(fty)), _) => {
96+
let fty = match fty {
97+
ast::FloatTy::F16 => ty::FloatTy::F16,
98+
ast::FloatTy::F32 => ty::FloatTy::F32,
99+
ast::FloatTy::F64 => ty::FloatTy::F64,
100+
ast::FloatTy::F128 => ty::FloatTy::F128,
101+
};
102+
let bits = parse_float_into_scalar(n, fty, neg)?;
103+
(ty::ValTree::from_scalar_int(tcx, bits), Ty::new_float(tcx, fty))
80104
}
81-
(ast::LitKind::Bool(b), ty::Bool) => ty::ValTree::from_scalar_int(tcx, b.into()),
82-
(ast::LitKind::Float(n, _), ty::Float(fty)) => {
83-
let bits = parse_float_into_scalar(n, *fty, neg).unwrap_or_else(|| {
84-
tcx.dcx().bug(format!("couldn't parse float literal: {:?}", lit_input.lit))
85-
});
86-
ty::ValTree::from_scalar_int(tcx, bits)
105+
(ast::LitKind::Float(n, ast::LitFloatType::Unsuffixed), ty::Float(fty)) => {
106+
let bits = parse_float_into_scalar(n, *fty, neg)?;
107+
(ty::ValTree::from_scalar_int(tcx, bits), Ty::new_float(tcx, *fty))
87108
}
88-
(ast::LitKind::Char(c), ty::Char) => ty::ValTree::from_scalar_int(tcx, c.into()),
89-
(ast::LitKind::Err(guar), _) => return ty::Const::new_error(tcx, guar),
90-
_ => return ty::Const::new_misc_error(tcx),
109+
(ast::LitKind::Char(c), _) => (ty::ValTree::from_scalar_int(tcx, c.into()), tcx.types.char),
110+
(ast::LitKind::Err(_), _) => return None,
111+
_ => return None,
91112
};
92113

93-
ty::Const::new_value(tcx, valtree, ty)
114+
Some(ty::Value { ty: valtree_ty, valtree })
94115
}

compiler/rustc_mir_build/src/thir/pattern/mod.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@ use std::cmp::Ordering;
88
use std::sync::Arc;
99

1010
use rustc_abi::{FieldIdx, Integer};
11+
use rustc_ast::LitKind;
1112
use rustc_data_structures::assert_matches;
1213
use rustc_errors::codes::*;
1314
use rustc_hir::def::{CtorOf, DefKind, Res};
1415
use rustc_hir::pat_util::EnumerateAndAdjustIterator;
1516
use rustc_hir::{self as hir, RangeEnd};
1617
use rustc_index::Idx;
17-
use rustc_middle::mir::interpret::LitToConstInput;
18+
use rustc_middle::mir::interpret::{LitToConstInput, const_lit_matches_ty};
1819
use rustc_middle::thir::{
1920
Ascription, DerefPatBorrowMode, FieldPat, LocalVarId, Pat, PatKind, PatRange, PatRangeBoundary,
2021
};
@@ -197,8 +198,6 @@ impl<'tcx> PatCtxt<'tcx> {
197198
expr: Option<&'tcx hir::PatExpr<'tcx>>,
198199
ty: Ty<'tcx>,
199200
) -> Result<(), ErrorGuaranteed> {
200-
use rustc_ast::ast::LitKind;
201-
202201
let Some(expr) = expr else {
203202
return Ok(());
204203
};
@@ -696,7 +695,17 @@ impl<'tcx> PatCtxt<'tcx> {
696695

697696
let pat_ty = self.typeck_results.node_type(pat.hir_id);
698697
let lit_input = LitToConstInput { lit: lit.node, ty: pat_ty, neg: *negated };
699-
let constant = self.tcx.at(expr.span).lit_to_const(lit_input);
698+
let constant = const_lit_matches_ty(self.tcx, &lit.node, pat_ty, *negated)
699+
.then(|| self.tcx.at(expr.span).lit_to_const(lit_input))
700+
.flatten()
701+
.map(|v| ty::Const::new_value(self.tcx, v.valtree, pat_ty))
702+
.unwrap_or_else(|| {
703+
ty::Const::new_error_with_message(
704+
self.tcx,
705+
expr.span,
706+
"literal does not match expected type",
707+
)
708+
});
700709
self.const_to_pat(constant, pat_ty, expr.hir_id, lit.span)
701710
}
702711
}

compiler/rustc_ty_utils/src/consts.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,10 @@ fn recurse_build<'tcx>(
5959
}
6060
&ExprKind::Literal { lit, neg } => {
6161
let sp = node.span;
62-
tcx.at(sp).lit_to_const(LitToConstInput { lit: lit.node, ty: node.ty, neg })
62+
match tcx.at(sp).lit_to_const(LitToConstInput { lit: lit.node, ty: node.ty, neg }) {
63+
Some(value) => ty::Const::new_value(tcx, value.valtree, value.ty),
64+
None => ty::Const::new_misc_error(tcx),
65+
}
6366
}
6467
&ExprKind::NonHirLiteral { lit, user_ty: _ } => {
6568
let val = ty::ValTree::from_scalar_int(tcx, lit);

tests/ui/const-generics/adt_const_params/byte-string-u8-validation.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
struct ConstBytes<const T: &'static [*mut u8; 3]>
99
//~^ ERROR rustc_dump_predicates
1010
//~| NOTE Binder { value: ConstArgHasType(T/#0, &'static [*mut u8; 3_usize]), bound_vars: [] }
11-
//~| NOTE Binder { value: TraitPredicate(<ConstBytes<{const error}> as std::marker::Sized>, polarity:Positive), bound_vars: [] }
11+
//~| NOTE Binder { value: TraitPredicate(<ConstBytes<b"AAA"> as std::marker::Sized>, polarity:Positive), bound_vars: [] }
1212
where
1313
ConstBytes<b"AAA">: Sized;
1414
//~^ ERROR mismatched types
Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,3 @@
1-
error: rustc_dump_predicates
2-
--> $DIR/byte-string-u8-validation.rs:8:1
3-
|
4-
LL | struct ConstBytes<const T: &'static [*mut u8; 3]>
5-
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6-
|
7-
= note: Binder { value: ConstArgHasType(T/#0, &'static [*mut u8; 3_usize]), bound_vars: [] }
8-
= note: Binder { value: TraitPredicate(<ConstBytes<{const error}> as std::marker::Sized>, polarity:Positive), bound_vars: [] }
9-
101
error[E0308]: mismatched types
112
--> $DIR/byte-string-u8-validation.rs:13:16
123
|
@@ -16,6 +7,15 @@ LL | ConstBytes<b"AAA">: Sized;
167
= note: expected reference `&'static [*mut u8; 3]`
178
found reference `&'static [u8; 3]`
189

10+
error: rustc_dump_predicates
11+
--> $DIR/byte-string-u8-validation.rs:8:1
12+
|
13+
LL | struct ConstBytes<const T: &'static [*mut u8; 3]>
14+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
15+
|
16+
= note: Binder { value: ConstArgHasType(T/#0, &'static [*mut u8; 3_usize]), bound_vars: [] }
17+
= note: Binder { value: TraitPredicate(<ConstBytes<b"AAA"> as std::marker::Sized>, polarity:Positive), bound_vars: [] }
18+
1919
error: aborting due to 2 previous errors
2020

2121
For more information about this error, try `rustc --explain E0308`.

tests/ui/const-generics/adt_const_params/mismatch-raw-ptr-in-adt.stderr

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,19 @@ LL | struct ConstBytes<const T: &'static [*mut u8; 3]>;
88
= note: `[*mut u8; 3]` must implement `ConstParamTy_`, but it does not
99

1010
error[E0308]: mismatched types
11-
--> $DIR/mismatch-raw-ptr-in-adt.rs:9:46
11+
--> $DIR/mismatch-raw-ptr-in-adt.rs:9:23
1212
|
1313
LL | let _: ConstBytes<b"AAA"> = ConstBytes::<b"BBB">;
14-
| ^^^^^^ expected `&[*mut u8; 3]`, found `&[u8; 3]`
14+
| ^^^^^^ expected `&[*mut u8; 3]`, found `&[u8; 3]`
1515
|
1616
= note: expected reference `&'static [*mut u8; 3]`
1717
found reference `&'static [u8; 3]`
1818

1919
error[E0308]: mismatched types
20-
--> $DIR/mismatch-raw-ptr-in-adt.rs:9:23
20+
--> $DIR/mismatch-raw-ptr-in-adt.rs:9:46
2121
|
2222
LL | let _: ConstBytes<b"AAA"> = ConstBytes::<b"BBB">;
23-
| ^^^^^^ expected `&[*mut u8; 3]`, found `&[u8; 3]`
23+
| ^^^^^^ expected `&[*mut u8; 3]`, found `&[u8; 3]`
2424
|
2525
= note: expected reference `&'static [*mut u8; 3]`
2626
found reference `&'static [u8; 3]`

0 commit comments

Comments
 (0)