Skip to content

Commit 5eb6f50

Browse files
Merge pull request #20864 from A4-Tacks/extract-method-in-trait
Fix extract function invalid self param
2 parents b9185af + 3e82fd0 commit 5eb6f50

1 file changed

Lines changed: 96 additions & 15 deletions

File tree

crates/ide-assists/src/handlers/extract_function.rs

Lines changed: 96 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,13 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
9292

9393
let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding };
9494
let insert_after = node_to_insert_after(&body, anchor)?;
95+
let trait_name = ast::Trait::cast(insert_after.clone()).and_then(|trait_| trait_.name());
9596
let semantics_scope = ctx.sema.scope(&insert_after)?;
9697
let module = semantics_scope.module();
9798
let edition = semantics_scope.krate().edition(ctx.db());
9899

99-
let (container_info, contains_tail_expr) = body.analyze_container(&ctx.sema, edition)?;
100+
let (container_info, contains_tail_expr) =
101+
body.analyze_container(&ctx.sema, edition, trait_name)?;
100102

101103
let ret_ty = body.return_ty(ctx)?;
102104
let control_flow = body.external_control_flow(ctx, &container_info)?;
@@ -181,6 +183,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
181183
builder.add_tabstop_before(cap, name);
182184
}
183185

186+
// FIXME: wrap non-adt types
184187
let fn_def = match fun.self_param_adt(ctx) {
185188
Some(adt) if anchor == Anchor::Method && !has_impl_wrapper => {
186189
fn_def.indent(1.into());
@@ -377,6 +380,7 @@ struct ControlFlow<'db> {
377380
struct ContainerInfo<'db> {
378381
is_const: bool,
379382
parent_loop: Option<SyntaxNode>,
383+
trait_name: Option<ast::Type>,
380384
/// The function's return type, const's type etc.
381385
ret_type: Option<hir::Type<'db>>,
382386
generic_param_lists: Vec<ast::GenericParamList>,
@@ -838,6 +842,7 @@ impl FunctionBody {
838842
&self,
839843
sema: &Semantics<'db, RootDatabase>,
840844
edition: Edition,
845+
trait_name: Option<ast::Name>,
841846
) -> Option<(ContainerInfo<'db>, bool)> {
842847
let mut ancestors = self.parent()?.ancestors();
843848
let infer_expr_opt = |expr| sema.type_of_expr(&expr?).map(TypeInfo::adjusted);
@@ -924,6 +929,9 @@ impl FunctionBody {
924929
false
925930
};
926931

932+
// FIXME: make trait arguments
933+
let trait_name = trait_name.map(|name| make::ty_path(make::ext::ident_path(&name.text())));
934+
927935
let parent = self.parent()?;
928936
let parents = generic_parents(&parent);
929937
let generic_param_lists = parents.iter().filter_map(|it| it.generic_param_list()).collect();
@@ -934,6 +942,7 @@ impl FunctionBody {
934942
ContainerInfo {
935943
is_const,
936944
parent_loop,
945+
trait_name,
937946
ret_type: ty,
938947
generic_param_lists,
939948
where_clauses,
@@ -1419,14 +1428,18 @@ fn fixup_call_site(builder: &mut SourceChangeBuilder, body: &FunctionBody) {
14191428
fn make_call(ctx: &AssistContext<'_>, fun: &Function<'_>, indent: IndentLevel) -> SyntaxNode {
14201429
let ret_ty = fun.return_type(ctx);
14211430

1422-
let args = make::arg_list(fun.params.iter().map(|param| param.to_arg(ctx, fun.mods.edition)));
14231431
let name = fun.name.clone();
1424-
let mut call_expr = if fun.self_param.is_some() {
1432+
let args = fun.params.iter().map(|param| param.to_arg(ctx, fun.mods.edition));
1433+
let mut call_expr = if fun.make_this_param().is_some() {
1434+
let self_arg = make::expr_path(make::ext::ident_path("self"));
1435+
let func = make::expr_path(make::path_unqualified(make::path_segment(name)));
1436+
make::expr_call(func, make::arg_list(Some(self_arg).into_iter().chain(args))).into()
1437+
} else if fun.self_param.is_some() {
14251438
let self_arg = make::expr_path(make::ext::ident_path("self"));
1426-
make::expr_method_call(self_arg, name, args).into()
1439+
make::expr_method_call(self_arg, name, make::arg_list(args)).into()
14271440
} else {
14281441
let func = make::expr_path(make::path_unqualified(make::path_segment(name)));
1429-
make::expr_call(func, args).into()
1442+
make::expr_call(func, make::arg_list(args)).into()
14301443
};
14311444

14321445
let handler = FlowHandler::from_ret_ty(fun, &ret_ty);
@@ -1729,9 +1742,28 @@ impl<'db> Function<'db> {
17291742
module: hir::Module,
17301743
edition: Edition,
17311744
) -> ast::ParamList {
1732-
let self_param = self.self_param.clone();
1745+
let this_param = self.make_this_param().map(|f| f());
1746+
let self_param = self.self_param.clone().filter(|_| this_param.is_none());
17331747
let params = self.params.iter().map(|param| param.to_param(ctx, module, edition));
1734-
make::param_list(self_param, params)
1748+
make::param_list(self_param, this_param.into_iter().chain(params))
1749+
}
1750+
1751+
fn make_this_param(&self) -> Option<impl FnOnce() -> ast::Param> {
1752+
if let Some(name) = self.mods.trait_name.clone()
1753+
&& let Some(self_param) = &self.self_param
1754+
{
1755+
Some(|| {
1756+
let bounds = make::type_bound_list([make::type_bound(name)]);
1757+
let pat = make::path_pat(make::ext::ident_path("this"));
1758+
let mut ty = make::impl_trait_type(bounds.unwrap()).into();
1759+
if self_param.amp_token().is_some() {
1760+
ty = make::ty_ref(ty, self_param.mut_token().is_some());
1761+
}
1762+
make::param(pat, ty)
1763+
})
1764+
} else {
1765+
None
1766+
}
17351767
}
17361768

17371769
fn make_ret_ty(&self, ctx: &AssistContext<'_>, module: hir::Module) -> Option<ast::RetType> {
@@ -1806,10 +1838,12 @@ fn make_body(
18061838
) -> ast::BlockExpr {
18071839
let ret_ty = fun.return_type(ctx);
18081840
let handler = FlowHandler::from_ret_ty(fun, &ret_ty);
1841+
let to_this_param = fun.self_param.clone().filter(|_| fun.make_this_param().is_some());
18091842

18101843
let block = match &fun.body {
18111844
FunctionBody::Expr(expr) => {
1812-
let expr = rewrite_body_segment(ctx, &fun.params, &handler, expr.syntax());
1845+
let expr =
1846+
rewrite_body_segment(ctx, to_this_param, &fun.params, &handler, expr.syntax());
18131847
let expr = ast::Expr::cast(expr).expect("Body segment should be an expr");
18141848
match expr {
18151849
ast::Expr::BlockExpr(block) => {
@@ -1847,7 +1881,7 @@ fn make_body(
18471881
.filter(|it| text_range.contains_range(it.text_range()))
18481882
.map(|it| match &it {
18491883
syntax::NodeOrToken::Node(n) => syntax::NodeOrToken::Node(
1850-
rewrite_body_segment(ctx, &fun.params, &handler, n),
1884+
rewrite_body_segment(ctx, to_this_param.clone(), &fun.params, &handler, n),
18511885
),
18521886
_ => it,
18531887
})
@@ -1997,42 +2031,60 @@ fn make_ty(ty: &hir::Type<'_>, ctx: &AssistContext<'_>, module: hir::Module) ->
19972031

19982032
fn rewrite_body_segment(
19992033
ctx: &AssistContext<'_>,
2034+
to_this_param: Option<ast::SelfParam>,
20002035
params: &[Param<'_>],
20012036
handler: &FlowHandler<'_>,
20022037
syntax: &SyntaxNode,
20032038
) -> SyntaxNode {
2004-
let syntax = fix_param_usages(ctx, params, syntax);
2039+
let to_this_param = to_this_param.and_then(|it| ctx.sema.to_def(&it));
2040+
let syntax = fix_param_usages(ctx, to_this_param, params, syntax);
20052041
update_external_control_flow(handler, &syntax);
20062042
syntax
20072043
}
20082044

20092045
/// change all usages to account for added `&`/`&mut` for some params
20102046
fn fix_param_usages(
20112047
ctx: &AssistContext<'_>,
2048+
to_this_param: Option<Local>,
20122049
params: &[Param<'_>],
20132050
syntax: &SyntaxNode,
20142051
) -> SyntaxNode {
20152052
let mut usages_for_param: Vec<(&Param<'_>, Vec<ast::Expr>)> = Vec::new();
2053+
let mut usages_for_self_param: Vec<ast::Expr> = Vec::new();
20162054

20172055
let tm = TreeMutator::new(syntax);
2056+
let reference_filter = |reference: &FileReference| {
2057+
syntax
2058+
.text_range()
2059+
.contains_range(reference.range)
2060+
.then_some(())
2061+
.and_then(|_| path_element_of_reference(syntax, reference))
2062+
.map(|expr| tm.make_mut(&expr))
2063+
};
20182064

2065+
if let Some(self_param) = to_this_param {
2066+
usages_for_self_param = LocalUsages::find_local_usages(ctx, self_param)
2067+
.iter()
2068+
.filter_map(reference_filter)
2069+
.collect();
2070+
}
20192071
for param in params {
20202072
if !param.kind().is_ref() {
20212073
continue;
20222074
}
20232075

20242076
let usages = LocalUsages::find_local_usages(ctx, param.var);
2025-
let usages = usages
2026-
.iter()
2027-
.filter(|reference| syntax.text_range().contains_range(reference.range))
2028-
.filter_map(|reference| path_element_of_reference(syntax, reference))
2029-
.map(|expr| tm.make_mut(&expr));
2077+
let usages = usages.iter().filter_map(reference_filter);
20302078

20312079
usages_for_param.push((param, usages.unique().collect()));
20322080
}
20332081

20342082
let res = tm.make_syntax_mut(syntax);
20352083

2084+
for self_usage in usages_for_self_param {
2085+
let this_expr = make::expr_path(make::ext::ident_path("this")).clone_for_update();
2086+
ted::replace(self_usage.syntax(), this_expr.syntax());
2087+
}
20362088
for (param, usages) in usages_for_param {
20372089
for usage in usages {
20382090
match usage.syntax().ancestors().skip(1).find_map(ast::Expr::cast) {
@@ -2939,6 +2991,35 @@ impl S {
29392991
);
29402992
}
29412993

2994+
#[test]
2995+
fn method_in_trait() {
2996+
check_assist(
2997+
extract_function,
2998+
r#"
2999+
trait Foo {
3000+
fn f(&self) -> i32;
3001+
3002+
fn foo(&self) -> i32 {
3003+
$0self.f()+self.f()$0
3004+
}
3005+
}
3006+
"#,
3007+
r#"
3008+
trait Foo {
3009+
fn f(&self) -> i32;
3010+
3011+
fn foo(&self) -> i32 {
3012+
fun_name(self)
3013+
}
3014+
}
3015+
3016+
fn $0fun_name(this: &impl Foo) -> i32 {
3017+
this.f()+this.f()
3018+
}
3019+
"#,
3020+
);
3021+
}
3022+
29423023
#[test]
29433024
fn variable_defined_inside_and_used_after_no_ret() {
29443025
check_assist(

0 commit comments

Comments
 (0)