@@ -8,7 +8,7 @@ use ide_db::{
88use syntax:: {
99 AstNode , AstPtr , TextSize ,
1010 ast:: {
11- self , BlockExpr , Expr , ExprStmt , HasArgList ,
11+ self , BlockExpr , Expr , ExprStmt , HasArgList , RefExpr ,
1212 edit:: { AstNodeEdit , IndentLevel } ,
1313 } ,
1414} ;
@@ -69,7 +69,7 @@ fn fixes(ctx: &DiagnosticsContext<'_, '_>, d: &hir::TypeMismatch<'_>) -> Option<
6969
7070 if let Some ( expr_ptr) = d. expr_or_pat . value . cast :: < ast:: Expr > ( ) {
7171 let expr_ptr = & InFile { file_id : d. expr_or_pat . file_id , value : expr_ptr } ;
72- add_reference ( ctx, d, expr_ptr, & mut fixes) ;
72+ add_or_fix_reference ( ctx, d, expr_ptr, & mut fixes) ;
7373 add_missing_ok_or_some ( ctx, d, expr_ptr, & mut fixes) ;
7474 remove_unnecessary_wrapper ( ctx, d, expr_ptr, & mut fixes) ;
7575 remove_semicolon ( ctx, d, expr_ptr, & mut fixes) ;
@@ -79,21 +79,52 @@ fn fixes(ctx: &DiagnosticsContext<'_, '_>, d: &hir::TypeMismatch<'_>) -> Option<
7979 if fixes. is_empty ( ) { None } else { Some ( fixes) }
8080}
8181
82- fn add_reference (
82+ fn add_or_fix_reference (
8383 ctx : & DiagnosticsContext < ' _ , ' _ > ,
8484 d : & hir:: TypeMismatch < ' _ > ,
8585 expr_ptr : & InFile < AstPtr < ast:: Expr > > ,
8686 acc : & mut Vec < Assist > ,
8787) -> Option < ( ) > {
8888 let range = ctx. sema . diagnostics_display_range ( ( * expr_ptr) . map ( |it| it. into ( ) ) ) ;
8989
90- let ( _, mutability) = d. expected . as_reference ( ) ?;
91- let actual_with_ref = d. actual . add_reference ( ctx. db ( ) , mutability) ;
90+ let ( expected_with_ref_removed, expected_mutability) = d. expected . as_reference ( ) ?;
91+
92+ if let Some ( ( actual_with_ref_removed, hir:: Mutability :: Shared ) ) = d. actual . as_reference ( )
93+ && expected_mutability == hir:: Mutability :: Mut
94+ && actual_with_ref_removed. could_coerce_to ( ctx. db ( ) , & expected_with_ref_removed)
95+ {
96+ // The actual type is `&T`, and the exprected type is `&mut T`, (or `U` that `T` can be coerced to).
97+ // It's likely that, instead of adding a reference, we should just change the mutability of
98+ // the existing one.
99+
100+ let expr = expr_ptr. to_node ( ctx. db ( ) ) ;
101+ // If the node comes from a macro expansion, then we shouldn't assist,
102+ // as the suggestion would overwrite the macro _definition_ position
103+ let expr = ctx. sema . original_ast_node ( expr) ?;
104+ let expr_without_ref = RefExpr :: cast ( expr. syntax ( ) . clone ( ) ) ?. expr ( ) ?;
105+
106+ let file_id = expr_ptr. file_id . original_file ( ctx. db ( ) ) ;
107+ let mut builder = SourceChangeBuilder :: new ( file_id. file_id ( ctx. db ( ) ) ) ;
108+ let editor = builder. make_editor ( expr. syntax ( ) ) ;
109+ let make = editor. make ( ) ;
110+ let new_expr = make. expr_ref ( expr_without_ref, true ) ;
111+ builder. replace_ast ( expr, new_expr) ;
112+ let source_change = builder. finish ( ) ;
113+ acc. push ( fix (
114+ "make_reference_mutable" ,
115+ "Make reference mutable" ,
116+ source_change,
117+ range. range ,
118+ ) ) ;
119+ return Some ( ( ) ) ;
120+ }
121+
122+ let actual_with_ref = d. actual . add_reference ( ctx. db ( ) , expected_mutability) ;
92123 if !actual_with_ref. could_coerce_to ( ctx. db ( ) , & d. expected ) {
93124 return None ;
94125 }
95126
96- let ampersands = format ! ( "&{}" , mutability . as_keyword_for_ref( ) ) ;
127+ let ampersands = format ! ( "&{}" , expected_mutability . as_keyword_for_ref( ) ) ;
97128
98129 let edit = TextEdit :: insert ( range. range . start ( ) , ampersands) ;
99130 let source_change = SourceChange :: from_text_edit ( range. file_id , edit) ;
@@ -390,6 +421,96 @@ fn test(_arg: &mut i32) {}
390421 ) ;
391422 }
392423
424+ #[ test]
425+ fn fix_reference_to_int ( ) {
426+ check_fix (
427+ r#"
428+ fn main() {
429+ test($0&123);
430+ }
431+ fn test(_arg: &mut i32) {}
432+ "# ,
433+ r#"
434+ fn main() {
435+ test(&mut 123);
436+ }
437+ fn test(_arg: &mut i32) {}
438+ "# ,
439+ ) ;
440+ }
441+
442+ #[ test]
443+ fn add_reference_to_parenthesized_int ( ) {
444+ check_fix (
445+ r#"
446+ fn main() {
447+ test(($0123));
448+ }
449+ fn test(_arg: &i32) {}
450+ "# ,
451+ r#"
452+ fn main() {
453+ test((&123));
454+ }
455+ fn test(_arg: &i32) {}
456+ "# ,
457+ ) ;
458+ }
459+
460+ #[ test]
461+ fn add_mutable_reference_to_parenthesized_int ( ) {
462+ check_fix (
463+ r#"
464+ fn main() {
465+ test(($0123));
466+ }
467+ fn test(_arg: &mut i32) {}
468+ "# ,
469+ r#"
470+ fn main() {
471+ test((&mut 123));
472+ }
473+ fn test(_arg: &mut i32) {}
474+ "# ,
475+ ) ;
476+ }
477+
478+ #[ test]
479+ fn fix_reference_to_parenthesized_int_paren_inside_ref ( ) {
480+ check_fix (
481+ r#"
482+ fn main() {
483+ test(&$0(123));
484+ }
485+ fn test(_arg: &mut i32) {}
486+ "# ,
487+ r#"
488+ fn main() {
489+ test(&mut (123));
490+ }
491+ fn test(_arg: &mut i32) {}
492+ "# ,
493+ ) ;
494+ }
495+
496+ #[ test]
497+ fn fix_reference_to_parenthesized_int_ref_inside_paren ( ) {
498+ check_fix (
499+ r#"
500+ fn main() {
501+ test(($0&123));
502+ }
503+ fn test(_arg: &mut i32) {}
504+ "# ,
505+ r#"
506+ fn main() {
507+ test((&mut 123));
508+ }
509+ fn test(_arg: &mut i32) {}
510+ "# ,
511+ ) ;
512+ }
513+
393514 #[ test]
394515 fn add_reference_to_array ( ) {
395516 check_fix (
@@ -409,6 +530,19 @@ fn test(_arg: &[i32]) {}
409530 ) ;
410531 }
411532
533+ #[ test]
534+ fn fix_reference_to_array ( ) {
535+ check_no_fix (
536+ r#"
537+ //- minicore: coerce_unsized
538+ fn main() {
539+ test($0&[1, 2, 3]);
540+ }
541+ fn test(_arg: &mut [i32]) {}
542+ "# ,
543+ ) ;
544+ }
545+
412546 #[ test]
413547 fn add_reference_with_autoderef ( ) {
414548 check_fix (
@@ -442,6 +576,49 @@ fn test(_arg: &Bar) {}
442576 ) ;
443577 }
444578
579+ #[ test]
580+ // FIXME: this should suggest making the reference mutable instead: `&Foo -> &mut Foo`.
581+ // Currently it doesn't, as the logic for that assist strips away references, and thus checks
582+ // whether `Foo` can be coerced to `Bar` (which it can't), instead of checking `&mut Foo` to
583+ // `&mut Bar` (which it can)
584+ fn fix_reference_with_autoderef ( ) {
585+ check_fix (
586+ r#"
587+ //- minicore: coerce_unsized, deref_mut
588+ struct Foo;
589+ struct Bar;
590+ impl core::ops::Deref for Foo {
591+ type Target = Bar;
592+ fn deref(&self) -> &Self::Target { loop {} }
593+ }
594+ impl core::ops::DerefMut for Foo {
595+ fn deref_mut(&mut self) -> &mut Self::Target { loop {} }
596+ }
597+
598+ fn main() {
599+ test($0&Foo);
600+ }
601+ fn test(_arg: &mut Bar) {}
602+ "# ,
603+ r#"
604+ struct Foo;
605+ struct Bar;
606+ impl core::ops::Deref for Foo {
607+ type Target = Bar;
608+ fn deref(&self) -> &Self::Target { loop {} }
609+ }
610+ impl core::ops::DerefMut for Foo {
611+ fn deref_mut(&mut self) -> &mut Self::Target { loop {} }
612+ }
613+
614+ fn main() {
615+ test(&mut &Foo);
616+ }
617+ fn test(_arg: &mut Bar) {}
618+ "# ,
619+ ) ;
620+ }
621+
445622 #[ test]
446623 fn add_reference_to_method_call ( ) {
447624 check_fix (
@@ -498,6 +675,22 @@ fn main() {
498675 ) ;
499676 }
500677
678+ #[ test]
679+ fn fix_reference_to_let_stmt ( ) {
680+ check_fix (
681+ r#"
682+ fn main() {
683+ let _test: &mut i32 = $0&123;
684+ }
685+ "# ,
686+ r#"
687+ fn main() {
688+ let _test: &mut i32 = &mut 123;
689+ }
690+ "# ,
691+ ) ;
692+ }
693+
501694 #[ test]
502695 fn add_reference_to_macro_call ( ) {
503696 check_fix (
@@ -526,6 +719,56 @@ fn main() {
526719 ) ;
527720 }
528721
722+ #[ test]
723+ fn fix_reference_to_macro_call ( ) {
724+ check_fix (
725+ r#"
726+ macro_rules! thousand {
727+ () => {
728+ 1000_u64
729+ };
730+ }
731+
732+ fn test(_foo: &mut u64) {}
733+ fn main() {
734+ test($0&thousand!());
735+ }
736+ "# ,
737+ r#"
738+ macro_rules! thousand {
739+ () => {
740+ 1000_u64
741+ };
742+ }
743+
744+ fn test(_foo: &mut u64) {}
745+ fn main() {
746+ test(&mut thousand!());
747+ }
748+ "# ,
749+ ) ;
750+ }
751+
752+ #[ test]
753+ // If the immutable reference comes from a macro expansion,
754+ // we can't do anything to change it to a mutable one.
755+ fn dont_fix_reference_inside_macro_call ( ) {
756+ check_no_fix (
757+ r#"
758+ macro_rules! thousand {
759+ () => {
760+ &1000_u64
761+ };
762+ }
763+
764+ fn test(_foo: &mut u64) {}
765+ fn main() {
766+ test($0thousand!());
767+ }
768+ "# ,
769+ ) ;
770+ }
771+
529772 #[ test]
530773 fn const_generic_type_mismatch ( ) {
531774 check_diagnostics (
0 commit comments