@@ -8,11 +8,12 @@ use std::cmp::Ordering;
88use std:: sync:: Arc ;
99
1010use rustc_abi:: { FieldIdx , Integer } ;
11+ use rustc_ast:: LitKind ;
1112use rustc_data_structures:: assert_matches;
1213use rustc_errors:: codes:: * ;
1314use rustc_hir:: def:: { CtorOf , DefKind , Res } ;
1415use rustc_hir:: pat_util:: EnumerateAndAdjustIterator ;
15- use rustc_hir:: { self as hir, RangeEnd } ;
16+ use rustc_hir:: { self as hir, LangItem , RangeEnd } ;
1617use rustc_index:: Idx ;
1718use rustc_middle:: mir:: interpret:: LitToConstInput ;
1819use rustc_middle:: thir:: {
@@ -197,8 +198,6 @@ impl<'tcx> PatCtxt<'tcx> {
197198 expr : Option < & ' tcx hir:: PatExpr < ' tcx > > ,
198199 ty : Ty < ' tcx > ,
199200 ) -> Result < ( ) , ErrorGuaranteed > {
200- use rustc_ast:: ast:: LitKind ;
201-
202201 let Some ( expr) = expr else {
203202 return Ok ( ( ) ) ;
204203 } ;
@@ -696,9 +695,58 @@ impl<'tcx> PatCtxt<'tcx> {
696695
697696 let pat_ty = self . typeck_results . node_type ( pat. hir_id ) ;
698697 let lit_input = LitToConstInput { lit : lit. node , ty : pat_ty, neg : * negated } ;
699- let constant = self . tcx . at ( expr. span ) . lit_to_const ( lit_input) ;
698+ let error_const = || {
699+ if let Some ( guar) = self . typeck_results . tainted_by_errors {
700+ ty:: Const :: new_error ( self . tcx , guar)
701+ } else {
702+ ty:: Const :: new_error_with_message (
703+ self . tcx ,
704+ expr. span ,
705+ "literal does not match expected type" ,
706+ )
707+ }
708+ } ;
709+ let constant = if self . const_lit_matches_ty ( & lit. node , pat_ty, * negated) {
710+ match self . tcx . at ( expr. span ) . lit_to_const ( lit_input) {
711+ Some ( value) => ty:: Const :: new_value ( self . tcx , value. valtree , value. ty ) ,
712+ None => error_const ( ) ,
713+ }
714+ } else {
715+ error_const ( )
716+ } ;
700717 self . const_to_pat ( constant, pat_ty, expr. hir_id , lit. span )
701718 }
702719 }
703720 }
721+
722+ fn const_lit_matches_ty ( & self , kind : & LitKind , ty : Ty < ' tcx > , neg : bool ) -> bool {
723+ let tcx = self . tcx ;
724+ match ( * kind, ty. kind ( ) ) {
725+ ( LitKind :: Str ( ..) , ty:: Ref ( _, inner_ty, _) ) if inner_ty. is_str ( ) => true ,
726+ ( LitKind :: Str ( ..) , ty:: Str ) if tcx. features ( ) . deref_patterns ( ) => true ,
727+ ( LitKind :: ByteStr ( ..) , ty:: Ref ( _, inner_ty, _) )
728+ if let ty:: Slice ( ty) | ty:: Array ( ty, _) = inner_ty. kind ( )
729+ && matches ! ( ty. kind( ) , ty:: Uint ( ty:: UintTy :: U8 ) ) =>
730+ {
731+ true
732+ }
733+ ( LitKind :: ByteStr ( ..) , ty:: Slice ( inner_ty) | ty:: Array ( inner_ty, _) )
734+ if tcx. features ( ) . deref_patterns ( )
735+ && matches ! ( inner_ty. kind( ) , ty:: Uint ( ty:: UintTy :: U8 ) ) =>
736+ {
737+ true
738+ }
739+ ( LitKind :: Byte ( ..) , ty:: Uint ( ty:: UintTy :: U8 ) ) => true ,
740+ ( LitKind :: CStr ( ..) , ty:: Ref ( _, inner_ty, _) ) if matches ! ( inner_ty. kind( ) , ty:: Adt ( def, _) if tcx. is_lang_item( def. did( ) , LangItem :: CStr ) ) => {
741+ true
742+ }
743+ ( LitKind :: Int ( ..) , ty:: Uint ( _) ) if !neg => true ,
744+ ( LitKind :: Int ( ..) , ty:: Int ( _) ) => true ,
745+ ( LitKind :: Bool ( ..) , ty:: Bool ) => true ,
746+ ( LitKind :: Float ( ..) , ty:: Float ( _) ) => true ,
747+ ( LitKind :: Char ( ..) , ty:: Char ) => true ,
748+ ( LitKind :: Err ( ..) , _) => true ,
749+ _ => false ,
750+ }
751+ }
704752}
0 commit comments