diff --git a/crates/ide-assists/src/handlers/add_lifetime_to_type.rs b/crates/ide-assists/src/handlers/add_lifetime_to_type.rs index dc847dcdbe7d..6abda9bdfeaa 100644 --- a/crates/ide-assists/src/handlers/add_lifetime_to_type.rs +++ b/crates/ide-assists/src/handlers/add_lifetime_to_type.rs @@ -1,6 +1,8 @@ +use either::Either; use syntax::{ - SyntaxKind, SyntaxNode, SyntaxToken, + SyntaxKind, SyntaxNode, ast::{self, AstNode, HasGenericParams, HasName}, + match_ast, }; use crate::{AssistContext, AssistId, Assists}; @@ -23,12 +25,12 @@ use crate::{AssistContext, AssistId, Assists}; // } // ``` pub(crate) fn add_lifetime_to_type(acc: &mut Assists, ctx: &AssistContext<'_, '_>) -> Option<()> { - let ref_type_focused = ctx.find_node_at_offset::()?; - if ref_type_focused.lifetime().is_some_and(|lifetime| lifetime.text() != "'_") { - return None; - } - let node = ctx.find_node_at_offset::()?; + // XXX: Maybe delete this and allow it to be triggered conveniently on ADT + let _trigger = syntax::algo::ancestors_at_offset(node.syntax(), ctx.offset()) + .take_while(|it| node.syntax() != it) + .filter_map(Either::>::cast) + .find_map(|it| Missing::from_node(it.syntax().clone(), ctx))?; let has_lifetime = node .generic_param_list() .is_some_and(|gen_list| gen_list.lifetime_params().next().is_some()); @@ -37,7 +39,7 @@ pub(crate) fn add_lifetime_to_type(acc: &mut Assists, ctx: &AssistContext<'_, '_ return None; } - let changes = fetch_borrowed_types(&node)?; + let changes = fetch_borrowed_types(&node, ctx)?; let target = node.syntax().text_range(); acc.add(AssistId::quick_fix("add_lifetime_to_type"), "Add lifetime", target, |builder| { @@ -56,18 +58,23 @@ pub(crate) fn add_lifetime_to_type(acc: &mut Assists, ctx: &AssistContext<'_, '_ for change in changes { match change { - Change::Replace(it) => { - builder.replace(it.text_range(), "'a"); + Missing::Lifetime(lt) => { + builder.replace(lt.syntax().text_range(), "'a"); } - Change::Insert(it) => { - builder.insert(it.text_range().end(), "'a "); + Missing::RefType(ref_type) => { + if let Some(amp_token) = ref_type.amp_token() { + builder.insert(amp_token.text_range().end(), "'a "); + } + } + Missing::PathType(path_type) => { + builder.insert(path_type.syntax().text_range().end(), "<'a>"); } } } }) } -fn fetch_borrowed_types(node: &ast::Adt) -> Option> { +fn fetch_borrowed_types(node: &ast::Adt, ctx: &AssistContext<'_, '_>) -> Option> { let ref_types: Vec<_> = match node { ast::Adt::Enum(enum_) => { let variant_list = enum_.variant_list()?; @@ -76,58 +83,71 @@ fn fetch_borrowed_types(node: &ast::Adt) -> Option> { .filter_map(|variant| { let field_list = variant.field_list()?; - find_ref_types_from_field_list(&field_list) + find_ref_types_from_field_list(&field_list, ctx) }) .flatten() .collect() } ast::Adt::Struct(strukt) => { let field_list = strukt.field_list()?; - find_ref_types_from_field_list(&field_list)? + find_ref_types_from_field_list(&field_list, ctx)? } ast::Adt::Union(un) => { let record_field_list = un.record_field_list()?; - find_ref_types_from_field_list(&record_field_list.into())? + find_ref_types_from_field_list(&record_field_list.into(), ctx)? } }; if ref_types.is_empty() { None } else { Some(ref_types) } } -fn find_ref_types_from_field_list(field_list: &ast::FieldList) -> Option> { +fn find_ref_types_from_field_list( + field_list: &ast::FieldList, + ctx: &AssistContext<'_, '_>, +) -> Option> { let ref_types: Vec<_> = match field_list { ast::FieldList::RecordFieldList(record_list) => { - record_list.fields().flat_map(|f| infer_lifetimes(f.syntax())).collect() + record_list.fields().flat_map(|f| infer_lifetimes(f.syntax(), ctx)).collect() } ast::FieldList::TupleFieldList(tuple_field_list) => { - tuple_field_list.fields().flat_map(|f| infer_lifetimes(f.syntax())).collect() + tuple_field_list.fields().flat_map(|f| infer_lifetimes(f.syntax(), ctx)).collect() } }; if ref_types.is_empty() { None } else { Some(ref_types) } } -enum Change { - Replace(SyntaxToken), - Insert(SyntaxToken), +enum Missing { + RefType(ast::RefType), + PathType(ast::PathType), + Lifetime(ast::Lifetime), +} + +impl Missing { + fn from_node(node: SyntaxNode, ctx: &AssistContext<'_, '_>) -> Option { + match_ast! { + match node { + ast::Lifetime(it) => (it.syntax().text() == "'_").then_some(Missing::Lifetime(it)), + ast::RefType(it) => (it.lifetime().is_none() && it.amp_token().is_some()).then_some(Missing::RefType(it)), + ast::PathType(it) => { + let needs_lifetime = match ctx.sema.resolve_path(&it.path()?)? { + hir::PathResolution::Def(hir::ModuleDef::Adt(adt)) => adt.lifetime(ctx.db()).is_some(), + // FIXME: check TypeAlias and Trait lifetime params + _ => false, + }; + (needs_lifetime && ast::Type::from(it.clone()).generic_arg_list().is_none()) + .then_some(Missing::PathType(it)) + }, + _ => None, + } + } + } } -fn infer_lifetimes(node: &SyntaxNode) -> Vec { +fn infer_lifetimes(node: &SyntaxNode, ctx: &AssistContext<'_, '_>) -> Vec { node.children() .filter(|it| !matches!(it.kind(), SyntaxKind::FN_PTR_TYPE | SyntaxKind::TYPE_BOUND_LIST)) - .flat_map(|it| { - infer_lifetimes(&it) - .into_iter() - .chain(ast::Lifetime::cast(it.clone()).and_then(|lt| { - lt.lifetime_ident_token().filter(|lt| lt.text() == "'_").map(Change::Replace) - })) - .chain( - ast::RefType::cast(it) - .filter(|ty| ty.lifetime().is_none()) - .and_then(|ty| ty.amp_token()) - .map(Change::Insert), - ) - }) + .flat_map(|it| infer_lifetimes(&it, ctx).into_iter().chain(Missing::from_node(it, ctx))) .collect() } @@ -180,9 +200,34 @@ mod tests { fn add_lifetime_to_explicit_infer_lifetime() { check_assist( add_lifetime_to_type, - r#"struct Foo { a: &'_ $0i32, b: &'_ (&'_ i32, fn(&str) -> &str) }"#, + r#"struct Foo { a: &'_$0 i32, b: &'_ (&'_ i32, fn(&str) -> &str) }"#, r#"struct Foo<'a> { a: &'a i32, b: &'a (&'a i32, fn(&str) -> &str) }"#, ); + + check_assist( + add_lifetime_to_type, + r#"struct Foo { a: &'_$0 i32, b: Foo<'_> }"#, + r#"struct Foo<'a> { a: &'a i32, b: Foo<'a> }"#, + ); + + check_assist( + add_lifetime_to_type, + r#"struct Foo { a: &'_ i32, b: Foo<'_$0> }"#, + r#"struct Foo<'a> { a: &'a i32, b: Foo<'a> }"#, + ); + } + + #[test] + fn add_lifetime_to_implicit_infer_lifetime() { + check_assist( + add_lifetime_to_type, + r#" +struct Ref<'a>(&'a ()); +struct Foo { a: &'_ i32, b: Ref$0 }"#, + r#" +struct Ref<'a>(&'a ()); +struct Foo<'a> { a: &'a i32, b: Ref<'a> }"#, + ); } #[test]