Skip to content

Commit d7a3f66

Browse files
fix
1 parent 947a3ee commit d7a3f66

3 files changed

Lines changed: 82 additions & 2 deletions

File tree

pyrefly/lib/binding/function.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ use crate::binding::binding::ReturnTypeKind;
6363
use crate::binding::bindings::BindingsBuilder;
6464
use crate::binding::bindings::LegacyTParamCollector;
6565
use crate::binding::expr::Usage;
66+
use crate::binding::pattern::pattern_is_syntactically_exhaustive_for_subject;
6667
use crate::binding::scope::FlowStyle;
6768
use crate::binding::scope::InstanceAttribute;
6869
use crate::binding::scope::Scope;
@@ -974,7 +975,7 @@ fn function_last_expressions<'a>(
974975
let mut syntactically_exhaustive = false;
975976
for case in x.cases.iter() {
976977
f(sys_info, &case.body, res)?;
977-
if case.pattern.is_wildcard() || case.pattern.is_irrefutable() {
978+
if pattern_is_syntactically_exhaustive_for_subject(&x.subject, &case.pattern) {
978979
syntactically_exhaustive = true;
979980
break;
980981
}

pyrefly/lib/binding/pattern.rs

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use ruff_python_ast::AtomicNodeIndex;
1212
use ruff_python_ast::Expr;
1313
use ruff_python_ast::ExprNumberLiteral;
1414
use ruff_python_ast::ExprStringLiteral;
15+
use ruff_python_ast::ExprTuple;
1516
use ruff_python_ast::Int;
1617
use ruff_python_ast::MatchCase;
1718
use ruff_python_ast::Number;
@@ -43,6 +44,47 @@ use crate::error::context::ErrorInfo;
4344
use crate::export::special::SpecialExport;
4445
use crate::types::facet::UnresolvedFacetKind;
4546

47+
fn sequence_pattern_is_syntactically_exhaustive(subjects: &[Expr], patterns: &[Pattern]) -> bool {
48+
let Some(star_idx) = patterns
49+
.iter()
50+
.position(|pattern| matches!(pattern, Pattern::MatchStar(_)))
51+
else {
52+
return subjects.len() == patterns.len()
53+
&& subjects.iter().zip(patterns).all(|(subject, pattern)| {
54+
pattern_is_syntactically_exhaustive_for_subject(subject, pattern)
55+
});
56+
};
57+
let suffix_len = patterns.len() - star_idx - 1;
58+
if subjects.len() + 1 < patterns.len() {
59+
return false;
60+
}
61+
patterns[..star_idx]
62+
.iter()
63+
.zip(&subjects[..star_idx])
64+
.all(|(pattern, subject)| pattern_is_syntactically_exhaustive_for_subject(subject, pattern))
65+
&& patterns[star_idx + 1..]
66+
.iter()
67+
.zip(&subjects[subjects.len() - suffix_len..])
68+
.all(|(pattern, subject)| {
69+
pattern_is_syntactically_exhaustive_for_subject(subject, pattern)
70+
})
71+
}
72+
73+
pub(crate) fn pattern_is_syntactically_exhaustive_for_subject(
74+
subject: &Expr,
75+
pattern: &Pattern,
76+
) -> bool {
77+
if pattern.is_wildcard() || pattern.is_irrefutable() {
78+
return true;
79+
}
80+
match (subject, pattern) {
81+
(Expr::Tuple(ExprTuple { elts, .. }), Pattern::MatchSequence(x)) => {
82+
sequence_pattern_is_syntactically_exhaustive(elts, &x.patterns)
83+
}
84+
_ => false,
85+
}
86+
}
87+
4688
#[derive(Clone, Debug)]
4789
enum MatchSubject {
4890
/// No narrowing subject available.
@@ -538,7 +580,8 @@ impl<'a> BindingsBuilder<'a> {
538580
..
539581
} = case;
540582
self.start_branch();
541-
let case_is_irrefutable = pattern.is_wildcard() || pattern.is_irrefutable();
583+
let case_is_irrefutable =
584+
pattern_is_syntactically_exhaustive_for_subject(&subject_expr, &pattern);
542585
if case_is_irrefutable {
543586
exhaustive = true;
544587
}

pyrefly/lib/test/pattern_match.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,42 @@ def test_multi_match2(o1: object, o2: object) -> None:
812812
"#,
813813
);
814814

815+
// Regression test for https://github.com/facebook/pyrefly/issues/2932
816+
testcase!(
817+
test_match_multi_subject_tuple_catch_all_is_exhaustive,
818+
r#"
819+
from typing import assert_type
820+
821+
def test(x: int | None, y: int | None) -> None:
822+
match x, y:
823+
case None, None:
824+
raise ValueError
825+
case int(m), None:
826+
u = m * 3
827+
v = m
828+
case None, int(n):
829+
u = n
830+
v = n // 3
831+
case _, _:
832+
raise ValueError
833+
834+
assert_type(u, int)
835+
assert_type(v, int)
836+
"#,
837+
);
838+
839+
testcase!(
840+
test_match_multi_subject_tuple_catch_all_counts_for_return_analysis,
841+
r#"
842+
def test(x: int | None, y: int | None) -> int:
843+
match x, y:
844+
case None, None:
845+
return 0
846+
case _, _:
847+
return 1
848+
"#,
849+
);
850+
815851
testcase!(
816852
test_exhaustive_enum_or_pattern_no_missing_return,
817853
r#"

0 commit comments

Comments
 (0)