Skip to content

Commit ca67d4b

Browse files
committed
fix #1048
1 parent ccdb922 commit ca67d4b

6 files changed

Lines changed: 515 additions & 67 deletions

File tree

crates/emmylua_code_analysis/src/compilation/test/flow.rs

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2897,4 +2897,223 @@ _2 = a[1]
28972897

28982898
assert_eq!(ws.expr_ty("after_assign"), ws.ty("Foo|Bar"));
28992899
}
2900+
2901+
#[test]
2902+
fn test_issue_1048() {
2903+
let mut ws = VirtualWorkspace::new();
2904+
2905+
ws.def(
2906+
r#"
2907+
--- @alias RunMode 'run'|'skip'
2908+
2909+
--- @class Suite
2910+
--- @field result string?
2911+
--- @field mode RunMode
2912+
2913+
--- @param a string
2914+
function TestSuite(a) end
2915+
"#,
2916+
);
2917+
assert!(ws.has_no_diagnostic(
2918+
DiagnosticCode::ParamTypeMismatch,
2919+
r#"
2920+
--- @type Suite
2921+
local suite
2922+
2923+
suite.result = 'a'
2924+
if suite.mode == "run" then
2925+
TestSuite(suite.result)
2926+
end
2927+
"#,
2928+
));
2929+
}
2930+
2931+
#[test]
2932+
fn test_discriminant_narrowed_sibling_field_keeps_prior_assignment_flow() {
2933+
let mut ws = VirtualWorkspace::new();
2934+
2935+
ws.def(
2936+
r#"
2937+
---@class A
2938+
---@field type "point"
2939+
---@field handle string?
2940+
2941+
---@class B
2942+
---@field type "unit"
2943+
---@field handle integer?
2944+
2945+
---@param a string
2946+
function testA(a) end
2947+
"#,
2948+
);
2949+
assert!(ws.has_no_diagnostic(
2950+
DiagnosticCode::ParamTypeMismatch,
2951+
r#"
2952+
---@param target A | B
2953+
function test(target)
2954+
target.handle = "ready"
2955+
if target.type == "point" then
2956+
testA(target.handle)
2957+
end
2958+
end
2959+
"#,
2960+
));
2961+
}
2962+
2963+
#[test]
2964+
fn test_discriminant_narrowed_sibling_field_keeps_prior_truthiness_flow() {
2965+
let mut ws = VirtualWorkspace::new();
2966+
2967+
ws.def(
2968+
r#"
2969+
---@class A
2970+
---@field type "point"
2971+
---@field handle string?
2972+
2973+
---@class B
2974+
---@field type "unit"
2975+
---@field handle integer?
2976+
2977+
---@param a string
2978+
function testA(a) end
2979+
"#,
2980+
);
2981+
assert!(ws.has_no_diagnostic(
2982+
DiagnosticCode::ParamTypeMismatch,
2983+
r#"
2984+
---@param target A | B
2985+
function test(target)
2986+
if target.handle then
2987+
if target.type == "point" then
2988+
testA(target.handle)
2989+
end
2990+
end
2991+
end
2992+
"#,
2993+
));
2994+
}
2995+
2996+
#[test]
2997+
fn test_discriminant_narrowed_sibling_field_keeps_prior_nil_guard_flow() {
2998+
let mut ws = VirtualWorkspace::new();
2999+
3000+
ws.def(
3001+
r#"
3002+
---@class A
3003+
---@field type "point"
3004+
---@field handle string?
3005+
3006+
---@class B
3007+
---@field type "unit"
3008+
---@field handle integer?
3009+
3010+
---@param a string
3011+
function testA(a) end
3012+
"#,
3013+
);
3014+
assert!(ws.has_no_diagnostic(
3015+
DiagnosticCode::ParamTypeMismatch,
3016+
r#"
3017+
---@param target A | B
3018+
function test(target)
3019+
if target.handle ~= nil then
3020+
if target.type == "point" then
3021+
testA(target.handle)
3022+
end
3023+
end
3024+
end
3025+
"#,
3026+
));
3027+
}
3028+
3029+
#[test]
3030+
fn test_discriminant_narrowed_sibling_field_keeps_prior_literal_guard_flow() {
3031+
let mut ws = VirtualWorkspace::new();
3032+
3033+
ws.def(
3034+
r#"
3035+
---@class A
3036+
---@field type "point"
3037+
---@field handle string?
3038+
3039+
---@class B
3040+
---@field type "unit"
3041+
---@field handle integer?
3042+
3043+
---@param a string
3044+
function testA(a) end
3045+
"#,
3046+
);
3047+
assert!(ws.has_no_diagnostic(
3048+
DiagnosticCode::ParamTypeMismatch,
3049+
r#"
3050+
---@param target A | B
3051+
function test(target)
3052+
if target.handle == "ready" then
3053+
if target.type == "point" then
3054+
testA(target.handle)
3055+
end
3056+
end
3057+
end
3058+
"#,
3059+
));
3060+
}
3061+
3062+
#[test]
3063+
fn test_discriminant_false_branch_all_members_match_is_never() {
3064+
let mut ws = VirtualWorkspace::new();
3065+
3066+
ws.def(
3067+
r#"
3068+
---@class A
3069+
---@field kind "foo"
3070+
3071+
---@class B
3072+
---@field kind "foo"
3073+
"#,
3074+
);
3075+
3076+
ws.def(
3077+
r#"
3078+
---@param target A | B
3079+
function test(target)
3080+
if target.kind ~= "foo" then
3081+
impossible = target
3082+
end
3083+
end
3084+
"#,
3085+
);
3086+
3087+
assert_eq!(ws.expr_ty("impossible"), ws.ty("never"));
3088+
}
3089+
3090+
#[test]
3091+
fn test_discriminant_sibling_projection_preserves_missing_member_nil() {
3092+
let mut ws = VirtualWorkspace::new();
3093+
3094+
ws.def(
3095+
r#"
3096+
---@class A
3097+
---@field type "point"
3098+
---@field handle string
3099+
3100+
---@class C
3101+
---@field type "point"
3102+
3103+
---@param a string
3104+
function testA(a) end
3105+
"#,
3106+
);
3107+
assert!(!ws.has_no_diagnostic(
3108+
DiagnosticCode::ParamTypeMismatch,
3109+
r#"
3110+
---@param target A | C
3111+
function test(target)
3112+
if target.type == "point" then
3113+
testA(target.handle)
3114+
end
3115+
end
3116+
"#,
3117+
));
3118+
}
29003119
}

crates/emmylua_code_analysis/src/semantic/cache/mod.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ mod cache_options;
22

33
pub use cache_options::{CacheOptions, LuaAnalysisPhase};
44
use emmylua_parser::{LuaExpr, LuaSyntaxId, LuaVarExpr};
5-
use hashbrown::{HashMap, HashSet};
5+
use hashbrown::HashMap;
66
use std::{rc::Rc, sync::Arc};
77

88
use crate::{
@@ -39,7 +39,7 @@ pub(in crate::semantic) enum FlowMode {
3939

4040
impl FlowMode {
4141
pub fn uses_conditions(self) -> bool {
42-
matches!(self, Self::WithConditions)
42+
!matches!(self, Self::WithoutConditions)
4343
}
4444
}
4545

@@ -64,7 +64,6 @@ pub struct LuaInferCache {
6464
pub(in crate::semantic) flow_assignment_info_cache: Vec<Option<Rc<FlowAssignmentInfo>>>,
6565
pub index_ref_origin_type_cache: HashMap<VarRefId, CacheEntry<LuaType>>,
6666
pub expr_var_ref_id_cache: HashMap<LuaSyntaxId, VarRefId>,
67-
pub narrow_by_literal_stop_position_cache: HashSet<LuaSyntaxId>,
6867
}
6968

7069
impl LuaInferCache {
@@ -82,7 +81,6 @@ impl LuaInferCache {
8281
flow_assignment_info_cache: Vec::new(),
8382
index_ref_origin_type_cache: HashMap::new(),
8483
expr_var_ref_id_cache: HashMap::new(),
85-
narrow_by_literal_stop_position_cache: HashSet::new(),
8684
}
8785
}
8886

crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::{
1313
InferConditionFlow, PendingConditionNarrow, always_literal_equal,
1414
call_flow::get_type_at_call_expr,
1515
},
16-
get_single_antecedent, get_var_ref_type,
16+
get_single_antecedent,
1717
var_ref_id::get_var_expr_var_ref_id,
1818
},
1919
},
@@ -524,7 +524,6 @@ fn maybe_field_literal_eq_action(
524524
condition_flow: InferConditionFlow,
525525
) -> Result<Option<ConditionFlowAction>, InferFailReason> {
526526
// only check left as need narrow
527-
let syntax_id = left_expr.get_syntax_id();
528527
let (index_expr, literal_expr) = match (left_expr, right_expr) {
529528
(LuaExpr::IndexExpr(index_expr), LuaExpr::LiteralExpr(literal_expr)) => {
530529
(index_expr, literal_expr)
@@ -542,29 +541,36 @@ fn maybe_field_literal_eq_action(
542541
return Ok(None);
543542
};
544543

544+
let index_var_ref_id =
545+
get_var_expr_var_ref_id(db, cache, LuaExpr::IndexExpr(index_expr.clone()));
546+
if index_var_ref_id.as_ref() == Some(var_ref_id) {
547+
return Ok(None);
548+
}
549+
545550
let Some(maybe_var_ref_id) = get_var_expr_var_ref_id(db, cache, prefix_expr.clone()) else {
546551
// If we cannot find a reference declaration ID, we cannot narrow it
547552
return Ok(None);
548553
};
549554

550555
if maybe_var_ref_id != *var_ref_id {
551-
if cache
552-
.narrow_by_literal_stop_position_cache
553-
.contains(&syntax_id)
554-
&& var_ref_id.start_with(&maybe_var_ref_id)
555-
{
556-
return Ok(Some(ConditionFlowAction::Result(get_var_ref_type(
557-
db, cache, var_ref_id,
558-
)?)));
556+
if var_ref_id.start_with(&maybe_var_ref_id) {
557+
let right_type = infer_expr(db, cache, LuaExpr::LiteralExpr(literal_expr))?;
558+
return Ok(Some(ConditionFlowAction::NeedSubquery(
559+
ConditionSubquery::FieldLiteralSibling {
560+
var_ref_id: var_ref_id.clone(),
561+
discriminant_prefix_var_ref_id: maybe_var_ref_id,
562+
antecedent_flow_id: get_single_antecedent(flow_node)?,
563+
subquery_condition_flow: condition_flow,
564+
idx: LuaIndexMemberExpr::IndexExpr(index_expr),
565+
right_expr_type: right_type,
566+
},
567+
)));
559568
}
560569

561570
return Ok(None);
562571
}
563572

564573
let antecedent_flow_id = get_single_antecedent(flow_node)?;
565-
cache
566-
.narrow_by_literal_stop_position_cache
567-
.insert(syntax_id);
568574
let right_type = infer_expr(db, cache, LuaExpr::LiteralExpr(literal_expr))?;
569575
Ok(Some(ConditionFlowAction::NeedSubquery(
570576
ConditionSubquery::FieldLiteralEq {

0 commit comments

Comments
 (0)