11use emmylua_parser:: {
22 BinaryOperator , LuaAstNode , LuaBinaryExpr , LuaCallExpr , LuaChunk , LuaExpr , LuaLiteralToken ,
3- LuaNameExpr ,
43} ;
54
65use crate :: {
6+ infer_expr,
77 semantic:: infer:: infer_name:: narrow:: {
8- condition_flow:: InferConditionFlow , get_single_antecedent,
9- get_type_at_flow:: get_type_at_flow, ResultTypeOrContinue ,
8+ condition_flow:: { call_flow:: get_type_at_call_expr, InferConditionFlow } ,
9+ get_single_antecedent,
10+ get_type_at_flow:: get_type_at_flow,
11+ ResultTypeOrContinue ,
1012 } ,
1113 DbIndex , FlowNode , FlowTree , InferFailReason , LuaDeclId , LuaInferCache , LuaType , TypeOps ,
1214} ;
@@ -44,13 +46,25 @@ pub fn get_type_at_binary_expr(
4446 root,
4547 decl_id,
4648 flow_node,
47- left_expr,
48- right_expr,
49+ left_expr. clone ( ) ,
50+ right_expr. clone ( ) ,
4951 condition_flow,
5052 ) ?;
5153 if let ResultTypeOrContinue :: Result ( result_type) = result_type {
5254 return Ok ( ResultTypeOrContinue :: Result ( result_type) ) ;
5355 }
56+
57+ return maybe_var_eq_narrow (
58+ db,
59+ tree,
60+ cache,
61+ root,
62+ decl_id,
63+ flow_node,
64+ left_expr,
65+ right_expr,
66+ condition_flow,
67+ ) ;
5468 }
5569 BinaryOperator :: OpNe => {
5670 let result_type = maybe_type_guard_bianry (
@@ -60,13 +74,25 @@ pub fn get_type_at_binary_expr(
6074 root,
6175 decl_id,
6276 flow_node,
63- left_expr,
64- right_expr,
77+ left_expr. clone ( ) ,
78+ right_expr. clone ( ) ,
6579 condition_flow. get_negated ( ) ,
6680 ) ?;
6781 if let ResultTypeOrContinue :: Result ( result_type) = result_type {
6882 return Ok ( ResultTypeOrContinue :: Result ( result_type) ) ;
6983 }
84+
85+ return maybe_var_eq_narrow (
86+ db,
87+ tree,
88+ cache,
89+ root,
90+ decl_id,
91+ flow_node,
92+ left_expr,
93+ right_expr,
94+ condition_flow. get_negated ( ) ,
95+ ) ;
7096 }
7197 _ => { }
7298 }
@@ -169,7 +195,7 @@ fn maybe_type_guard_bianry(
169195 Ok ( ResultTypeOrContinue :: Result ( result_type) )
170196}
171197
172- fn maybe_var_eq (
198+ fn maybe_var_eq_narrow (
173199 db : & DbIndex ,
174200 tree : & FlowTree ,
175201 cache : & mut LuaInferCache ,
@@ -178,6 +204,74 @@ fn maybe_var_eq(
178204 flow_node : & FlowNode ,
179205 left_expr : LuaExpr ,
180206 right_expr : LuaExpr ,
207+ condition_flow : InferConditionFlow ,
181208) -> Result < ResultTypeOrContinue , InferFailReason > {
209+ // only check left as need narrow
210+ match left_expr {
211+ LuaExpr :: NameExpr ( left_name_expr) => {
212+ let Some ( maybe_ref_id) = db
213+ . get_reference_index ( )
214+ . get_var_reference_decl ( & cache. get_file_id ( ) , left_name_expr. get_range ( ) )
215+ else {
216+ return Ok ( ResultTypeOrContinue :: Continue ) ;
217+ } ;
218+
219+ if maybe_ref_id != decl_id {
220+ // If the reference declaration ID does not match, we cannot narrow it
221+ return Ok ( ResultTypeOrContinue :: Continue ) ;
222+ }
223+
224+ let right_expr_type = infer_expr ( db, cache, right_expr) ?;
225+ let antecedent_flow_id = get_single_antecedent ( tree, flow_node) ?;
226+ let antecedent_type =
227+ get_type_at_flow ( db, tree, cache, root, maybe_ref_id, antecedent_flow_id) ?;
228+
229+ let result_type = match condition_flow {
230+ InferConditionFlow :: TrueCondition => {
231+ TypeOps :: Narrow . apply ( db, & antecedent_type, & right_expr_type)
232+ }
233+ InferConditionFlow :: FalseCondition => {
234+ TypeOps :: Remove . apply ( db, & antecedent_type, & right_expr_type)
235+ }
236+ } ;
237+ return Ok ( ResultTypeOrContinue :: Result ( result_type) ) ;
238+ }
239+ LuaExpr :: CallExpr ( left_call_expr) => {
240+ match right_expr {
241+ LuaExpr :: LiteralExpr ( literal_expr) => match literal_expr. get_literal ( ) {
242+ Some ( LuaLiteralToken :: Bool ( b) ) => {
243+ let flow = if b. is_true ( ) {
244+ condition_flow
245+ } else {
246+ condition_flow. get_negated ( )
247+ } ;
248+
249+ return get_type_at_call_expr (
250+ db,
251+ tree,
252+ cache,
253+ root,
254+ decl_id,
255+ flow_node,
256+ left_call_expr,
257+ flow,
258+ ) ;
259+ }
260+ _ => return Ok ( ResultTypeOrContinue :: Continue ) ,
261+ } ,
262+ _ => { }
263+ } ;
264+
265+ return Ok ( ResultTypeOrContinue :: Continue ) ;
266+ }
267+ LuaExpr :: IndexExpr ( _) => {
268+ // todo
269+ }
270+ _ => {
271+ // If the left expression is not a name or call expression, we cannot narrow it
272+ return Ok ( ResultTypeOrContinue :: Continue ) ;
273+ }
274+ }
275+
182276 Ok ( ResultTypeOrContinue :: Continue )
183277}
0 commit comments