Skip to content

Commit 76c76cb

Browse files
mGCA: Validate const literal against expected type
Co-authored-by: Boxy <rust@boxyuwu.dev>
1 parent 56aaf58 commit 76c76cb

23 files changed

Lines changed: 365 additions & 135 deletions

compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,15 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
745745
// Avoid ICE #86756 when type error recovery goes awry.
746746
return Ty::new_error(tcx, prev).into();
747747
}
748+
if preceding_args.is_empty()
749+
&& tcx.type_of(param.def_id).skip_binder().has_param()
750+
{
751+
let guar = self.lowerer.dcx().span_delayed_bug(
752+
self.span,
753+
"default type has params but no args",
754+
);
755+
return Ty::new_error(tcx, guar).into();
756+
}
748757
tcx.at(self.span)
749758
.type_of(param.def_id)
750759
.instantiate(tcx, preceding_args)
@@ -2798,8 +2807,27 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
27982807
span: Span,
27992808
) -> Const<'tcx> {
28002809
let tcx = self.tcx();
2810+
if let LitKind::Err(guar) = *kind {
2811+
return ty::Const::new_error(tcx, guar);
2812+
}
28012813
let input = LitToConstInput { lit: *kind, ty, neg };
2802-
tcx.at(span).lit_to_const(input)
2814+
match tcx.at(span).lit_to_const(input) {
2815+
Some(value) => {
2816+
if value.ty == ty {
2817+
ty::Const::new_value(tcx, value.valtree, value.ty)
2818+
} else {
2819+
let e = tcx.dcx().span_err(
2820+
span,
2821+
format!("mismatched types: expected `{}`, found `{}`", ty, value.ty),
2822+
);
2823+
ty::Const::new_error(tcx, e)
2824+
}
2825+
}
2826+
None => {
2827+
let e = tcx.dcx().span_err(span, "type annotations needed for the literal");
2828+
ty::Const::new_error(tcx, e)
2829+
}
2830+
}
28032831
}
28042832

28052833
#[instrument(skip(self), level = "debug")]
@@ -2828,11 +2856,15 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
28282856
_ => None,
28292857
};
28302858

2831-
lit_input
2832-
// Allow the `ty` to be an alias type, though we cannot handle it here, we just go through
2833-
// the more expensive anon const code path.
2834-
.filter(|l| !l.ty.has_aliases())
2835-
.map(|l| tcx.at(expr.span).lit_to_const(l))
2859+
lit_input.and_then(|l| {
2860+
tcx.at(expr.span).lit_to_const(l).and_then(|value| {
2861+
if value.ty == ty {
2862+
Some(ty::Const::new_value(tcx, value.valtree, value.ty))
2863+
} else {
2864+
None
2865+
}
2866+
})
2867+
})
28362868
}
28372869

28382870
fn require_type_const_attribute(

compiler/rustc_middle/src/queries.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1412,7 +1412,7 @@ rustc_queries! {
14121412
// FIXME get rid of this with valtrees
14131413
query lit_to_const(
14141414
key: LitToConstInput<'tcx>
1415-
) -> ty::Const<'tcx> {
1415+
) -> Option<ty::Value<'tcx>> {
14161416
desc { "converting literal to const" }
14171417
}
14181418

compiler/rustc_middle/src/query/erase.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,10 @@ impl Erasable for Option<ty::EarlyBinder<'_, Ty<'_>>> {
256256
type Storage = [u8; size_of::<Option<ty::EarlyBinder<'static, Ty<'static>>>>()];
257257
}
258258

259+
impl Erasable for Option<ty::Value<'_>> {
260+
type Storage = [u8; size_of::<Option<ty::Value<'static>>>()];
261+
}
262+
259263
impl Erasable for rustc_hir::MaybeOwner<'_> {
260264
type Storage = [u8; size_of::<rustc_hir::MaybeOwner<'static>>()];
261265
}
Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
use rustc_abi::Size;
22
use rustc_ast::{self as ast, UintTy};
33
use rustc_hir::LangItem;
4-
use rustc_middle::bug;
54
use rustc_middle::mir::interpret::LitToConstInput;
6-
use rustc_middle::ty::{self, ScalarInt, TyCtxt, TypeVisitableExt as _};
5+
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt, TypeVisitableExt as _};
76
use tracing::trace;
87

98
use crate::builder::parse_float_into_scalar;
109

1110
pub(crate) fn lit_to_const<'tcx>(
1211
tcx: TyCtxt<'tcx>,
1312
lit_input: LitToConstInput<'tcx>,
14-
) -> ty::Const<'tcx> {
15-
let LitToConstInput { lit, ty, neg } = lit_input;
13+
) -> Option<ty::Value<'tcx>> {
14+
let LitToConstInput { lit, ty: expected_ty, neg } = lit_input;
1615

17-
if let Err(guar) = ty.error_reported() {
18-
return ty::Const::new_error(tcx, guar);
16+
if expected_ty.error_reported().is_err() {
17+
return None;
1918
}
2019

2120
let trunc = |n, width: ty::UintTy| {
@@ -29,66 +28,76 @@ pub(crate) fn lit_to_const<'tcx>(
2928
trace!("trunc result: {}", result);
3029

3130
ScalarInt::try_from_uint(result, width)
32-
.unwrap_or_else(|| bug!("expected to create ScalarInt from uint {:?}", result))
3331
};
3432

35-
let valtree = match (lit, ty.kind()) {
36-
(ast::LitKind::Str(s, _), ty::Ref(_, inner_ty, _)) if inner_ty.is_str() => {
33+
let (valtree, valtree_ty) = match (lit, expected_ty.kind()) {
34+
(ast::LitKind::Str(s, _), _) => {
3735
let str_bytes = s.as_str().as_bytes();
38-
ty::ValTree::from_raw_bytes(tcx, str_bytes)
39-
}
40-
(ast::LitKind::Str(s, _), ty::Str) if tcx.features().deref_patterns() => {
41-
// String literal patterns may have type `str` if `deref_patterns` is enabled, in order
42-
// to allow `deref!("..."): String`.
43-
let str_bytes = s.as_str().as_bytes();
44-
ty::ValTree::from_raw_bytes(tcx, str_bytes)
36+
let valtree_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_static, tcx.types.str_);
37+
(ty::ValTree::from_raw_bytes(tcx, str_bytes), valtree_ty)
4538
}
4639
(ast::LitKind::ByteStr(byte_sym, _), ty::Ref(_, inner_ty, _))
4740
if let ty::Slice(ty) | ty::Array(ty, _) = inner_ty.kind()
4841
&& let ty::Uint(UintTy::U8) = ty.kind() =>
4942
{
50-
ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str())
43+
(ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str()), expected_ty)
5144
}
5245
(ast::LitKind::ByteStr(byte_sym, _), ty::Slice(inner_ty) | ty::Array(inner_ty, _))
5346
if tcx.features().deref_patterns()
5447
&& let ty::Uint(UintTy::U8) = inner_ty.kind() =>
5548
{
5649
// Byte string literal patterns may have type `[u8]` or `[u8; N]` if `deref_patterns` is
5750
// enabled, in order to allow, e.g., `deref!(b"..."): Vec<u8>`.
58-
ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str())
51+
(ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str()), expected_ty)
5952
}
60-
(ast::LitKind::Byte(n), ty::Uint(ty::UintTy::U8)) => {
61-
ty::ValTree::from_scalar_int(tcx, n.into())
53+
(ast::LitKind::ByteStr(byte_sym, _), _) => {
54+
let valtree = ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str());
55+
let valtree_ty = Ty::new_array(tcx, tcx.types.u8, byte_sym.as_byte_str().len() as u64);
56+
(valtree, valtree_ty)
6257
}
63-
(ast::LitKind::CStr(byte_sym, _), ty::Ref(_, inner_ty, _)) if matches!(inner_ty.kind(), ty::Adt(def, _) if tcx.is_lang_item(def.did(), LangItem::CStr)) =>
58+
(ast::LitKind::Byte(n), _) => (ty::ValTree::from_scalar_int(tcx, n.into()), tcx.types.u8),
59+
(ast::LitKind::CStr(byte_sym, _), _)
60+
if let Some(cstr_def_id) = tcx.lang_items().get(LangItem::CStr) =>
6461
{
6562
// A CStr is a newtype around a byte slice, so we create the inner slice here.
6663
// We need a branch for each "level" of the data structure.
64+
let cstr_ty = tcx.type_of(cstr_def_id).skip_binder();
6765
let bytes = ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str());
68-
ty::ValTree::from_branches(tcx, [ty::Const::new_value(tcx, bytes, *inner_ty)])
66+
let valtree =
67+
ty::ValTree::from_branches(tcx, [ty::Const::new_value(tcx, bytes, cstr_ty)]);
68+
let valtree_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_static, cstr_ty);
69+
(valtree, valtree_ty)
70+
}
71+
(ast::LitKind::Int(n, ast::LitIntType::Unsigned(ui)), _) if !neg => {
72+
let scalar_int = trunc(n.get(), ui)?;
73+
(ty::ValTree::from_scalar_int(tcx, scalar_int), Ty::new_uint(tcx, ui))
74+
}
75+
(ast::LitKind::Int(_, ast::LitIntType::Unsigned(_)), _) if neg => return None,
76+
(ast::LitKind::Int(n, ast::LitIntType::Signed(i)), _) => {
77+
let scalar_int =
78+
trunc(if neg { u128::wrapping_neg(n.get()) } else { n.get() }, i.to_unsigned())?;
79+
(ty::ValTree::from_scalar_int(tcx, scalar_int), Ty::new_int(tcx, i))
6980
}
70-
(ast::LitKind::Int(n, _), ty::Uint(ui)) if !neg => {
71-
let scalar_int = trunc(n.get(), *ui);
72-
ty::ValTree::from_scalar_int(tcx, scalar_int)
81+
(ast::LitKind::Int(n, ast::LitIntType::Unsuffixed), ty::Uint(ui)) if !neg => {
82+
let scalar_int = trunc(n.get(), *ui)?;
83+
(ty::ValTree::from_scalar_int(tcx, scalar_int), Ty::new_uint(tcx, *ui))
7384
}
74-
(ast::LitKind::Int(n, _), ty::Int(i)) => {
85+
(ast::LitKind::Int(n, ast::LitIntType::Unsuffixed), ty::Int(i)) => {
7586
// Unsigned "negation" has the same bitwise effect as signed negation,
7687
// which gets the result we want without additional casts.
7788
let scalar_int =
78-
trunc(if neg { u128::wrapping_neg(n.get()) } else { n.get() }, i.to_unsigned());
79-
ty::ValTree::from_scalar_int(tcx, scalar_int)
89+
trunc(if neg { u128::wrapping_neg(n.get()) } else { n.get() }, i.to_unsigned())?;
90+
(ty::ValTree::from_scalar_int(tcx, scalar_int), Ty::new_int(tcx, *i))
8091
}
81-
(ast::LitKind::Bool(b), ty::Bool) => ty::ValTree::from_scalar_int(tcx, b.into()),
92+
(ast::LitKind::Bool(b), _) => (ty::ValTree::from_scalar_int(tcx, b.into()), tcx.types.bool),
8293
(ast::LitKind::Float(n, _), ty::Float(fty)) => {
83-
let bits = parse_float_into_scalar(n, *fty, neg).unwrap_or_else(|| {
84-
tcx.dcx().bug(format!("couldn't parse float literal: {:?}", lit_input.lit))
85-
});
86-
ty::ValTree::from_scalar_int(tcx, bits)
94+
let bits = parse_float_into_scalar(n, *fty, neg)?;
95+
(ty::ValTree::from_scalar_int(tcx, bits), Ty::new_float(tcx, *fty))
8796
}
88-
(ast::LitKind::Char(c), ty::Char) => ty::ValTree::from_scalar_int(tcx, c.into()),
89-
(ast::LitKind::Err(guar), _) => return ty::Const::new_error(tcx, guar),
90-
_ => return ty::Const::new_misc_error(tcx),
97+
(ast::LitKind::Char(c), _) => (ty::ValTree::from_scalar_int(tcx, c.into()), tcx.types.char),
98+
(ast::LitKind::Err(_), _) => return None,
99+
_ => return None,
91100
};
92101

93-
ty::Const::new_value(tcx, valtree, ty)
102+
Some(ty::Value { ty: valtree_ty, valtree })
94103
}

compiler/rustc_mir_build/src/thir/pattern/mod.rs

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@ use std::cmp::Ordering;
88
use std::sync::Arc;
99

1010
use rustc_abi::{FieldIdx, Integer};
11+
use rustc_ast::LitKind;
1112
use rustc_data_structures::assert_matches;
1213
use rustc_errors::codes::*;
1314
use rustc_hir::def::{CtorOf, DefKind, Res};
1415
use rustc_hir::pat_util::EnumerateAndAdjustIterator;
15-
use rustc_hir::{self as hir, RangeEnd};
16+
use rustc_hir::{self as hir, LangItem, RangeEnd};
1617
use rustc_index::Idx;
1718
use rustc_middle::mir::interpret::LitToConstInput;
1819
use rustc_middle::thir::{
@@ -197,8 +198,6 @@ impl<'tcx> PatCtxt<'tcx> {
197198
expr: Option<&'tcx hir::PatExpr<'tcx>>,
198199
ty: Ty<'tcx>,
199200
) -> Result<(), ErrorGuaranteed> {
200-
use rustc_ast::ast::LitKind;
201-
202201
let Some(expr) = expr else {
203202
return Ok(());
204203
};
@@ -696,9 +695,61 @@ impl<'tcx> PatCtxt<'tcx> {
696695

697696
let pat_ty = self.typeck_results.node_type(pat.hir_id);
698697
let lit_input = LitToConstInput { lit: lit.node, ty: pat_ty, neg: *negated };
699-
let constant = self.tcx.at(expr.span).lit_to_const(lit_input);
698+
let error_const = || {
699+
if let Some(guar) = self.typeck_results.tainted_by_errors {
700+
ty::Const::new_error(self.tcx, guar)
701+
} else {
702+
ty::Const::new_error_with_message(
703+
self.tcx,
704+
expr.span,
705+
"literal does not match expected type",
706+
)
707+
}
708+
};
709+
let constant = if self.const_lit_matches_ty(&lit.node, pat_ty, *negated) {
710+
match self.tcx.at(expr.span).lit_to_const(lit_input) {
711+
Some(value) => {
712+
let const_ty = if value.ty == pat_ty { value.ty } else { pat_ty };
713+
ty::Const::new_value(self.tcx, value.valtree, const_ty)
714+
}
715+
None => error_const(),
716+
}
717+
} else {
718+
error_const()
719+
};
700720
self.const_to_pat(constant, pat_ty, expr.hir_id, lit.span)
701721
}
702722
}
703723
}
724+
725+
fn const_lit_matches_ty(&self, kind: &LitKind, ty: Ty<'tcx>, neg: bool) -> bool {
726+
let tcx = self.tcx;
727+
match (*kind, ty.kind()) {
728+
(LitKind::Str(..), ty::Ref(_, inner_ty, _)) if inner_ty.is_str() => true,
729+
(LitKind::Str(..), ty::Str) if tcx.features().deref_patterns() => true,
730+
(LitKind::ByteStr(..), ty::Ref(_, inner_ty, _))
731+
if let ty::Slice(ty) | ty::Array(ty, _) = inner_ty.kind()
732+
&& matches!(ty.kind(), ty::Uint(ty::UintTy::U8)) =>
733+
{
734+
true
735+
}
736+
(LitKind::ByteStr(..), ty::Slice(inner_ty) | ty::Array(inner_ty, _))
737+
if tcx.features().deref_patterns()
738+
&& matches!(inner_ty.kind(), ty::Uint(ty::UintTy::U8)) =>
739+
{
740+
true
741+
}
742+
(LitKind::Byte(..), ty::Uint(ty::UintTy::U8)) => true,
743+
(LitKind::CStr(..), ty::Ref(_, inner_ty, _)) if matches!(inner_ty.kind(), ty::Adt(def, _) if tcx.is_lang_item(def.did(), LangItem::CStr)) => {
744+
true
745+
}
746+
(LitKind::Int(..), ty::Uint(_)) if !neg => true,
747+
(LitKind::Int(..), ty::Int(_)) => true,
748+
(LitKind::Bool(..), ty::Bool) => true,
749+
(LitKind::Float(..), ty::Float(_)) => true,
750+
(LitKind::Char(..), ty::Char) => true,
751+
(LitKind::Err(..), _) => true,
752+
_ => false,
753+
}
754+
}
704755
}

compiler/rustc_trait_selection/src/traits/wf.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,8 +1006,20 @@ impl<'a, 'tcx> TypeVisitor<TyCtxt<'tcx>> for WfPredicates<'a, 'tcx> {
10061006
self.add_wf_preds_for_inherent_projection(uv.into());
10071007
return; // Subtree is handled by above function
10081008
} else {
1009-
let obligations = self.nominal_obligations(uv.def, uv.args);
1010-
self.out.extend(obligations);
1009+
if uv.args.is_empty()
1010+
&& tcx
1011+
.predicates_of(uv.def)
1012+
.predicates
1013+
.iter()
1014+
.any(|(pred, _)| pred.has_param())
1015+
{
1016+
tcx.dcx().delayed_bug(
1017+
"unevaluated const has predicates with params but no args",
1018+
);
1019+
} else {
1020+
let obligations = self.nominal_obligations(uv.def, uv.args);
1021+
self.out.extend(obligations);
1022+
}
10111023
}
10121024
}
10131025
}

compiler/rustc_ty_utils/src/consts.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,10 @@ fn recurse_build<'tcx>(
5959
}
6060
&ExprKind::Literal { lit, neg } => {
6161
let sp = node.span;
62-
tcx.at(sp).lit_to_const(LitToConstInput { lit: lit.node, ty: node.ty, neg })
62+
match tcx.at(sp).lit_to_const(LitToConstInput { lit: lit.node, ty: node.ty, neg }) {
63+
Some(value) => ty::Const::new_value(tcx, value.valtree, value.ty),
64+
None => ty::Const::new_misc_error(tcx),
65+
}
6366
}
6467
&ExprKind::NonHirLiteral { lit, user_ty: _ } => {
6568
let val = ty::ValTree::from_scalar_int(tcx, lit);

tests/ui/const-generics/adt_const_params/byte-string-u8-validation.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
struct ConstBytes<const T: &'static [*mut u8; 3]>
99
//~^ ERROR rustc_dump_predicates
1010
//~| NOTE Binder { value: ConstArgHasType(T/#0, &'static [*mut u8; 3_usize]), bound_vars: [] }
11-
//~| NOTE Binder { value: TraitPredicate(<ConstBytes<{const error}> as std::marker::Sized>, polarity:Positive), bound_vars: [] }
11+
//~| NOTE Binder { value: TraitPredicate(<ConstBytes<b"AAA"> as std::marker::Sized>, polarity:Positive), bound_vars: [] }
1212
where
1313
ConstBytes<b"AAA">: Sized;
1414
//~^ ERROR mismatched types
Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,3 @@
1-
error: rustc_dump_predicates
2-
--> $DIR/byte-string-u8-validation.rs:8:1
3-
|
4-
LL | struct ConstBytes<const T: &'static [*mut u8; 3]>
5-
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6-
|
7-
= note: Binder { value: ConstArgHasType(T/#0, &'static [*mut u8; 3_usize]), bound_vars: [] }
8-
= note: Binder { value: TraitPredicate(<ConstBytes<{const error}> as std::marker::Sized>, polarity:Positive), bound_vars: [] }
9-
101
error[E0308]: mismatched types
112
--> $DIR/byte-string-u8-validation.rs:13:16
123
|
@@ -16,6 +7,15 @@ LL | ConstBytes<b"AAA">: Sized;
167
= note: expected reference `&'static [*mut u8; 3]`
178
found reference `&'static [u8; 3]`
189

10+
error: rustc_dump_predicates
11+
--> $DIR/byte-string-u8-validation.rs:8:1
12+
|
13+
LL | struct ConstBytes<const T: &'static [*mut u8; 3]>
14+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
15+
|
16+
= note: Binder { value: ConstArgHasType(T/#0, &'static [*mut u8; 3_usize]), bound_vars: [] }
17+
= note: Binder { value: TraitPredicate(<ConstBytes<b"AAA"> as std::marker::Sized>, polarity:Positive), bound_vars: [] }
18+
1919
error: aborting due to 2 previous errors
2020

2121
For more information about this error, try `rustc --explain E0308`.

0 commit comments

Comments
 (0)