@@ -4,6 +4,7 @@ use either::Either;
44use hir:: { ExprKind , Param } ;
55use rustc_abi:: FieldIdx ;
66use rustc_errors:: { Applicability , Diag } ;
7+ use rustc_hir:: def_id:: DefId ;
78use rustc_hir:: intravisit:: Visitor ;
89use rustc_hir:: { self as hir, BindingMode , ByRef , Node } ;
910use rustc_middle:: bug;
@@ -1092,42 +1093,65 @@ impl<'infcx, 'tcx> MirBorrowckCtxt<'_, 'infcx, 'tcx> {
10921093 let mut look_at_return = true ;
10931094
10941095 err. span_label ( closure_span, "in this closure" ) ;
1095- // If the HIR node is a function or method call, get the DefId
1096- // of the callee function or method, the span, and args of the call expr
1097- let get_call_details = || {
1098- let hir:: Node :: Expr ( hir:: Expr { hir_id, kind, .. } ) = node else {
1099- return None ;
1096+ let closure_arg_has_fn_trait_bound =
1097+ |callee_def_id, input_index, generic_args : ty:: GenericArgsRef < ' tcx > | {
1098+ let sig = tcx. fn_sig ( callee_def_id) . instantiate ( tcx, generic_args) . skip_binder ( ) ;
1099+ let Some ( input_ty) : Option < Ty < ' tcx > > = sig. inputs ( ) . get ( input_index) . copied ( )
1100+ else {
1101+ return false ;
1102+ } ;
1103+
1104+ tcx. predicates_of ( callee_def_id)
1105+ . instantiate ( tcx, generic_args)
1106+ . predicates
1107+ . iter ( )
1108+ . any ( |predicate| {
1109+ predicate. as_trait_clause ( ) . is_some_and ( |trait_pred| {
1110+ trait_pred. polarity ( ) == ty:: PredicatePolarity :: Positive
1111+ && tcx. fn_trait_kind_from_def_id ( trait_pred. def_id ( ) )
1112+ == Some ( ty:: ClosureKind :: Fn )
1113+ && trait_pred. self_ty ( ) . skip_binder ( ) . peel_refs ( )
1114+ == input_ty. peel_refs ( )
1115+ } )
1116+ } )
11001117 } ;
11011118
1102- let typeck_results = tcx. typeck ( def_id) ;
1119+ // If the HIR node is a function or method call, get the DefId
1120+ // of the callee function or method, the span, and argument info for the call expr.
1121+ let get_call_details =
1122+ || -> Option < ( DefId , Span , usize , usize , ty:: GenericArgsRef < ' tcx > ) > {
1123+ let hir:: Node :: Expr ( hir:: Expr { hir_id, kind, .. } ) = node else {
1124+ return None ;
1125+ } ;
11031126
1104- match kind {
1105- hir:: ExprKind :: Call ( expr, args) => {
1106- if let Some ( ty:: FnDef ( def_id, _) ) =
1107- typeck_results. node_type_opt ( expr. hir_id ) . as_ref ( ) . map ( |ty| ty. kind ( ) )
1108- {
1109- Some ( ( * def_id, expr. span , * args) )
1110- } else {
1111- None
1127+ let typeck_results = tcx. typeck ( def_id) ;
1128+
1129+ match kind {
1130+ hir:: ExprKind :: Call ( expr, args) => {
1131+ if let Some ( ty:: FnDef ( def_id, generic_args) ) =
1132+ typeck_results. node_type_opt ( expr. hir_id ) . as_ref ( ) . map ( |ty| ty. kind ( ) )
1133+ {
1134+ let arg_pos = args. iter ( ) . position ( |arg| arg. hir_id == closure_id) ?;
1135+ Some ( ( * def_id, expr. span , arg_pos, arg_pos, generic_args) )
1136+ } else {
1137+ None
1138+ }
11121139 }
1140+ hir:: ExprKind :: MethodCall ( _, _, args, span) => {
1141+ let arg_pos = args. iter ( ) . position ( |arg| arg. hir_id == closure_id) ?;
1142+ let def_id = typeck_results. type_dependent_def_id ( * hir_id) ?;
1143+ let generic_args = typeck_results. node_args_opt ( * hir_id) ?;
1144+ Some ( ( def_id, * span, arg_pos, arg_pos + 1 , generic_args) )
1145+ }
1146+ _ => None ,
11131147 }
1114- hir:: ExprKind :: MethodCall ( _, _, args, span) => typeck_results
1115- . type_dependent_def_id ( * hir_id)
1116- . map ( |def_id| ( def_id, * span, * args) ) ,
1117- _ => None ,
1118- }
1119- } ;
1148+ } ;
11201149
11211150 // If we can detect the expression to be a function or method call where the closure was
11221151 // an argument, we point at the function or method definition argument...
1123- if let Some ( ( callee_def_id, call_span, call_args) ) = get_call_details ( ) {
1124- let arg_pos = call_args
1125- . iter ( )
1126- . enumerate ( )
1127- . filter ( |( _, arg) | arg. hir_id == closure_id)
1128- . map ( |( pos, _) | pos)
1129- . next ( ) ;
1130-
1152+ if let Some ( ( callee_def_id, call_span, arg_pos, input_index, generic_args) ) =
1153+ get_call_details ( )
1154+ {
11311155 let arg = match tcx. hir_get_if_local ( callee_def_id) {
11321156 Some (
11331157 hir:: Node :: Item ( hir:: Item {
@@ -1144,16 +1168,12 @@ impl<'infcx, 'tcx> MirBorrowckCtxt<'_, 'infcx, 'tcx> {
11441168 ..
11451169 } ) ,
11461170 ) => Some (
1147- arg_pos
1148- . and_then ( |pos| {
1149- sig. decl . inputs . get (
1150- pos + if sig. decl . implicit_self ( ) . has_implicit_self ( ) {
1151- 1
1152- } else {
1153- 0
1154- } ,
1155- )
1156- } )
1171+ sig. decl
1172+ . inputs
1173+ . get (
1174+ arg_pos
1175+ + if sig. decl . implicit_self ( ) . has_implicit_self ( ) { 1 } else { 0 } ,
1176+ )
11571177 . map ( |arg| arg. span )
11581178 . unwrap_or ( ident. span ) ,
11591179 ) ,
@@ -1163,6 +1183,13 @@ impl<'infcx, 'tcx> MirBorrowckCtxt<'_, 'infcx, 'tcx> {
11631183 err. span_label ( span, "change this to accept `FnMut` instead of `Fn`" ) ;
11641184 err. span_label ( call_span, "expects `Fn` instead of `FnMut`" ) ;
11651185 look_at_return = false ;
1186+ } else if closure_arg_has_fn_trait_bound ( callee_def_id, input_index, generic_args) {
1187+ // The callee is not local, so we cannot point at its argument declaration, but we
1188+ // can still explain that this call site expects an `Fn` closure. Avoid falling
1189+ // through to the enclosing function's return type, which is misleading in cases
1190+ // like `flat_map(|_| external::map(|_| ...))`.
1191+ err. span_label ( call_span, "expects `Fn` instead of `FnMut`" ) ;
1192+ look_at_return = false ;
11661193 }
11671194 }
11681195
0 commit comments