Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 79 additions & 24 deletions crates/ide-assists/src/handlers/add_missing_match_arms.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::iter::{self, Peekable};
use std::iter;

use either::Either;
use hir::{Adt, AsAssocItem, Crate, FindPathConfig, HasAttrs, ModuleDef, Semantics};
Expand Down Expand Up @@ -93,8 +93,8 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>)
} else {
None
};
let (mut missing_pats, is_non_exhaustive, has_hidden_variants): (
Peekable<Box<dyn Iterator<Item = (ast::Pat, bool)>>>,
let (missing_pats, is_non_exhaustive, has_hidden_variants): (
Vec<(ast::Pat, bool)>,
bool,
bool,
) = if let Some(enum_def) = resolve_enum_def(&ctx.sema, &expr, self_ty.as_ref()) {
Expand All @@ -117,15 +117,15 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>)
.filter(|(variant_pat, _)| is_variant_missing(&top_lvl_pats, variant_pat));

let option_enum = FamousDefs(&ctx.sema, module.krate(ctx.db())).core_option_Option();
let missing_pats: Box<dyn Iterator<Item = _>> = if matches!(enum_def, ExtendedEnum::Enum { enum_: e, .. } if Some(e) == option_enum)
let missing_pats: Vec<_> = if matches!(enum_def, ExtendedEnum::Enum { enum_: e, .. } if Some(e) == option_enum)
{
// Match `Some` variant first.
cov_mark::hit!(option_order);
Box::new(missing_pats.rev())
missing_pats.rev().collect()
} else {
Box::new(missing_pats)
missing_pats.collect()
};
(missing_pats.peekable(), is_non_exhaustive, has_hidden_variants)
(missing_pats, is_non_exhaustive, has_hidden_variants)
} else if let Some(enum_defs) = resolve_tuple_of_enum_def(&ctx.sema, &expr, self_ty.as_ref()) {
let is_non_exhaustive = enum_defs
.iter()
Expand Down Expand Up @@ -169,12 +169,9 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>)

(ast::Pat::from(make.tuple_pat(patterns)), is_hidden)
})
.filter(|(variant_pat, _)| is_variant_missing(&top_lvl_pats, variant_pat));
(
(Box::new(missing_pats) as Box<dyn Iterator<Item = _>>).peekable(),
is_non_exhaustive,
has_hidden_variants,
)
.filter(|(variant_pat, _)| is_variant_missing(&top_lvl_pats, variant_pat))
.collect();
(missing_pats, is_non_exhaustive, has_hidden_variants)
} else if let Some((enum_def, len)) =
resolve_array_of_enum_def(&ctx.sema, &expr, self_ty.as_ref())
{
Expand Down Expand Up @@ -205,33 +202,41 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>)

(ast::Pat::from(make.slice_pat(patterns)), is_hidden)
})
.filter(|(variant_pat, _)| is_variant_missing(&top_lvl_pats, variant_pat));
(
(Box::new(missing_pats) as Box<dyn Iterator<Item = _>>).peekable(),
is_non_exhaustive,
has_hidden_variants,
)
.filter(|(variant_pat, _)| is_variant_missing(&top_lvl_pats, variant_pat))
.collect();
(missing_pats, is_non_exhaustive, has_hidden_variants)
} else {
return None;
};

let mut needs_catch_all_arm = is_non_exhaustive && !has_catch_all_arm;

if !needs_catch_all_arm
&& ((has_hidden_variants && has_catch_all_arm) || missing_pats.peek().is_none())
&& ((has_hidden_variants && has_catch_all_arm) || missing_pats.is_empty())
{
return None;
}

let visible_count = missing_pats.iter().filter(|(_, hidden)| !hidden).count();
let label = if visible_count == 0 {
"Add missing catch-all match arm `_`".to_owned()
} else if visible_count == 1 {
let pat = &missing_pats.iter().find(|(_, hidden)| !hidden).unwrap().0;
format!("Add missing match arm `{pat}`")
} else {
format!("Add {visible_count} missing match arms")
};

acc.add(
AssistId::quick_fix("add_missing_match_arms"),
"Fill match arms",
label,
ctx.sema.original_range(match_expr.syntax()).range,
|builder| {
// having any hidden variants means that we need a catch-all arm
needs_catch_all_arm |= has_hidden_variants;

let mut missing_arms = missing_pats
.into_iter()
.filter(|(_, hidden)| {
// filter out hidden patterns because they're handled by the catch-all arm
!hidden
Expand Down Expand Up @@ -635,7 +640,7 @@ mod tests {
use crate::AssistConfig;
use crate::tests::{
TEST_CONFIG, check_assist, check_assist_not_applicable, check_assist_target,
check_assist_unresolved, check_assist_with_config,
check_assist_unresolved, check_assist_with_config, check_assist_with_label,
};

use super::add_missing_match_arms;
Expand Down Expand Up @@ -1828,8 +1833,10 @@ fn foo(t: Test) {

#[test]
fn lazy_computation() {
// Computing a single missing arm is enough to determine applicability of the assist.
cov_mark::check_count!(add_missing_match_arms_lazy_computation, 1);
// We now collect all missing arms eagerly, so we can show the count
// of missing arms.
cov_mark::check_count!(add_missing_match_arms_lazy_computation, 4);

check_assist_unresolved(
add_missing_match_arms,
r#"
Expand All @@ -1841,6 +1848,54 @@ fn foo(tuple: (A, A)) {
);
}

#[test]
fn label_single_missing_arm() {
check_assist_with_label(
add_missing_match_arms,
r#"
enum A { One, Two }
fn foo(a: A) {
match $0a {
A::One => {}
}
}
"#,
"Add missing match arm `A::Two`",
);
}

#[test]
fn label_multiple_missing_arms() {
check_assist_with_label(
add_missing_match_arms,
r#"
enum A { One, Two, Three }
fn foo(a: A) {
match $0a {}
}
"#,
"Add 3 missing match arms",
);
}

#[test]
fn label_catch_all_only() {
check_assist_with_label(
add_missing_match_arms,
r#"
//- /main.rs crate:main deps:e
fn foo(t: ::e::E) {
match $0t {
e::E::A => {}
}
}
//- /e.rs crate:e
pub enum E { A, #[doc(hidden)] B, }
"#,
"Add missing catch-all match arm `_`",
);
}

#[test]
fn adds_comma_before_new_arms() {
check_assist(
Expand Down
20 changes: 18 additions & 2 deletions crates/ide-assists/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,15 @@ pub(crate) fn check_assist_target(
check(assist, ra_fixture, ExpectedResult::Target(target), None);
}

#[track_caller]
pub(crate) fn check_assist_with_label(
assist: Handler,
#[rust_analyzer::rust_fixture] ra_fixture: &str,
label: &str,
) {
check(assist, ra_fixture, ExpectedResult::Label(label), None);
}

#[track_caller]
pub(crate) fn check_assist_not_applicable(
assist: Handler,
Expand Down Expand Up @@ -307,6 +316,7 @@ enum ExpectedResult<'a> {
Unresolved,
After(&'a str),
Target(&'a str),
Label(&'a str),
}

#[track_caller]
Expand Down Expand Up @@ -335,7 +345,7 @@ fn check_with_config(

let ctx = AssistContext::new(sema, &config, frange);
let resolve = match expected {
ExpectedResult::Unresolved => AssistResolveStrategy::None,
ExpectedResult::Unresolved | ExpectedResult::Label(_) => AssistResolveStrategy::None,
_ => AssistResolveStrategy::All,
};
let mut acc = Assists::new(&ctx, resolve);
Expand Down Expand Up @@ -404,14 +414,20 @@ fn check_with_config(
let range = assist.target;
assert_eq_text!(&text_without_caret[range], target);
}
(Some(assist), ExpectedResult::Label(label)) => {
assert_eq!(assist.label.to_string(), label);
}
(Some(assist), ExpectedResult::Unresolved) => assert!(
assist.source_change.is_none(),
"unresolved assist should not contain source changes"
),
(Some(_), ExpectedResult::NotApplicable) => panic!("assist should not be applicable!"),
(
None,
ExpectedResult::After(_) | ExpectedResult::Target(_) | ExpectedResult::Unresolved,
ExpectedResult::After(_)
| ExpectedResult::Target(_)
| ExpectedResult::Label(_)
| ExpectedResult::Unresolved,
) => {
panic!("code action is not applicable")
}
Expand Down