diff --git a/pyrefly/lib/binding/function.rs b/pyrefly/lib/binding/function.rs index 43d1771b89..92d7a4d6af 100644 --- a/pyrefly/lib/binding/function.rs +++ b/pyrefly/lib/binding/function.rs @@ -63,6 +63,7 @@ use crate::binding::binding::ReturnTypeKind; use crate::binding::bindings::BindingsBuilder; use crate::binding::bindings::LegacyTParamCollector; use crate::binding::expr::Usage; +use crate::binding::pattern::pattern_is_syntactically_exhaustive_for_subject; use crate::binding::scope::FlowStyle; use crate::binding::scope::InstanceAttribute; use crate::binding::scope::Scope; @@ -974,7 +975,12 @@ fn function_last_expressions<'a>( let mut syntactically_exhaustive = false; for case in x.cases.iter() { f(sys_info, &case.body, res)?; - if case.pattern.is_wildcard() || case.pattern.is_irrefutable() { + if case.guard.is_none() + && pattern_is_syntactically_exhaustive_for_subject( + &x.subject, + &case.pattern, + ) + { syntactically_exhaustive = true; break; } diff --git a/pyrefly/lib/binding/pattern.rs b/pyrefly/lib/binding/pattern.rs index eb569835f9..2615450a61 100644 --- a/pyrefly/lib/binding/pattern.rs +++ b/pyrefly/lib/binding/pattern.rs @@ -12,6 +12,7 @@ use ruff_python_ast::AtomicNodeIndex; use ruff_python_ast::Expr; use ruff_python_ast::ExprNumberLiteral; use ruff_python_ast::ExprStringLiteral; +use ruff_python_ast::ExprTuple; use ruff_python_ast::Int; use ruff_python_ast::MatchCase; use ruff_python_ast::Number; @@ -43,6 +44,47 @@ use crate::error::context::ErrorInfo; use crate::export::special::SpecialExport; use crate::types::facet::UnresolvedFacetKind; +fn sequence_pattern_is_syntactically_exhaustive(subjects: &[Expr], patterns: &[Pattern]) -> bool { + let Some(star_idx) = patterns + .iter() + .position(|pattern| matches!(pattern, Pattern::MatchStar(_))) + else { + return subjects.len() == patterns.len() + && subjects.iter().zip(patterns).all(|(subject, pattern)| { + pattern_is_syntactically_exhaustive_for_subject(subject, pattern) + }); + }; + let suffix_len = patterns.len() - star_idx - 1; + if subjects.len() + 1 < patterns.len() { + return false; + } + patterns[..star_idx] + .iter() + .zip(&subjects[..star_idx]) + .all(|(pattern, subject)| pattern_is_syntactically_exhaustive_for_subject(subject, pattern)) + && patterns[star_idx + 1..] + .iter() + .zip(&subjects[subjects.len() - suffix_len..]) + .all(|(pattern, subject)| { + pattern_is_syntactically_exhaustive_for_subject(subject, pattern) + }) +} + +pub(crate) fn pattern_is_syntactically_exhaustive_for_subject( + subject: &Expr, + pattern: &Pattern, +) -> bool { + if pattern.is_wildcard() || pattern.is_irrefutable() { + return true; + } + match (subject, pattern) { + (Expr::Tuple(ExprTuple { elts, .. }), Pattern::MatchSequence(x)) => { + sequence_pattern_is_syntactically_exhaustive(elts, &x.patterns) + } + _ => false, + } +} + #[derive(Clone, Debug)] enum MatchSubject { /// No narrowing subject available. @@ -538,7 +580,8 @@ impl<'a> BindingsBuilder<'a> { .. } = case; self.start_branch(); - let case_is_irrefutable = pattern.is_wildcard() || pattern.is_irrefutable(); + let case_is_irrefutable = guard.is_none() + && pattern_is_syntactically_exhaustive_for_subject(&subject_expr, &pattern); if case_is_irrefutable { exhaustive = true; } diff --git a/pyrefly/lib/test/pattern_match.rs b/pyrefly/lib/test/pattern_match.rs index ab914733d3..29c6ef8a48 100644 --- a/pyrefly/lib/test/pattern_match.rs +++ b/pyrefly/lib/test/pattern_match.rs @@ -812,6 +812,77 @@ def test_multi_match2(o1: object, o2: object) -> None: "#, ); +// Regression test for https://github.com/facebook/pyrefly/issues/2932 +testcase!( + test_match_multi_subject_tuple_catch_all_is_exhaustive, + r#" +from typing import assert_type + +def test(x: int | None, y: int | None) -> None: + match x, y: + case None, None: + raise ValueError + case int(m), None: + u = m * 3 + v = m + case None, int(n): + u = n + v = n // 3 + case _, _: + raise ValueError + + assert_type(u, int) + assert_type(v, int) +"#, +); + +testcase!( + test_match_multi_subject_tuple_catch_all_counts_for_return_analysis, + r#" +def test(x: int | None, y: int | None) -> int: + match x, y: + case None, None: + return 0 + case _, _: + return 1 +"#, +); + +testcase!( + test_match_multi_subject_guarded_tuple_catch_all_is_not_exhaustive, + r#" +from typing import assert_type + +def test(x: int | None, y: int | None, cond: bool) -> None: + match x, y: + case None, None: + raise ValueError + case int(m), None: + u = m * 3 + v = m + case None, int(n): + u = n + v = n // 3 + case _, _ if cond: + raise ValueError + + assert_type(u, int) # E: `u` may be uninitialized + assert_type(v, int) # E: `v` may be uninitialized +"#, +); + +testcase!( + test_match_multi_subject_guarded_tuple_catch_all_counts_for_return_analysis, + r#" +def test(x: int | None, y: int | None, cond: bool) -> int: # E: Function declared to return `int`, but one or more paths are missing an explicit `return` + match x, y: + case None, None: + return 0 + case _, _ if cond: + return 1 +"#, +); + testcase!( test_exhaustive_enum_or_pattern_no_missing_return, r#"