1- use emmylua_parser:: { LuaCallExpr , LuaChunk } ;
1+ use std:: ops:: Deref ;
2+
3+ use emmylua_parser:: { LuaCallExpr , LuaChunk , LuaExpr } ;
24
35use crate :: {
4- infer_call_expr_func ,
5- semantic:: {
6- infer :: {
7- infer_call :: infer_call_expr ,
8- narrow :: { condition_flow :: InferConditionFlow , ResultTypeOrContinue } ,
9- VarRefId ,
6+ infer_expr ,
7+ semantic:: infer :: {
8+ narrow :: {
9+ condition_flow :: InferConditionFlow , get_single_antecedent ,
10+ get_type_at_cast_flow :: cast_type , get_type_at_flow :: get_type_at_flow ,
11+ var_ref_id :: get_var_expr_var_ref_id , ResultTypeOrContinue ,
1012 } ,
11- semantic_info :: { infer_expr_semantic_decl , SemanticDeclGuard } ,
13+ VarRefId ,
1214 } ,
13- DbIndex , FlowNode , FlowTree , InferFailReason , LuaInferCache , LuaSemanticDeclId ,
14- SemanticDeclLevel ,
15+ DbIndex , FlowNode , FlowTree , InferFailReason , LuaInferCache , LuaSignatureCast , LuaSignatureId ,
16+ LuaType , TypeOps ,
1517} ;
1618
1719#[ allow( unused) ]
@@ -29,35 +31,230 @@ pub fn get_type_at_call_expr(
2931 return Ok ( ResultTypeOrContinue :: Continue ) ;
3032 } ;
3133
32- let Some ( semantic_decl_id) = infer_expr_semantic_decl (
34+ let maybe_func = infer_expr ( db, cache, prefix_expr. clone ( ) ) ?;
35+ match maybe_func {
36+ LuaType :: DocFunction ( f) => {
37+ let return_type = f. get_ret ( ) ;
38+ match return_type {
39+ LuaType :: TypeGuard ( guard_type) => {
40+ return get_type_at_call_expr_by_type_guard (
41+ db,
42+ tree,
43+ cache,
44+ root,
45+ var_ref_id,
46+ flow_node,
47+ call_expr,
48+ guard_type. deref ( ) . clone ( ) ,
49+ condition_flow,
50+ )
51+ }
52+ _ => {
53+ // If the return type is not a type guard, we cannot infer the type cast.
54+ return Ok ( ResultTypeOrContinue :: Continue ) ;
55+ }
56+ }
57+ }
58+ LuaType :: Signature ( signature_id) => {
59+ let Some ( signature_cast) = db
60+ . get_flow_index ( )
61+ . get_signature_cast ( & cache. get_file_id ( ) , & signature_id)
62+ else {
63+ return Ok ( ResultTypeOrContinue :: Continue ) ;
64+ } ;
65+
66+ match signature_cast. name . as_str ( ) {
67+ "self" => {
68+ return get_type_at_call_expr_by_signature_self (
69+ db,
70+ tree,
71+ cache,
72+ root,
73+ var_ref_id,
74+ flow_node,
75+ prefix_expr,
76+ signature_cast,
77+ condition_flow,
78+ ) ;
79+ }
80+ name => {
81+ return get_type_at_call_expr_by_signature_param_name (
82+ db,
83+ tree,
84+ cache,
85+ root,
86+ var_ref_id,
87+ flow_node,
88+ call_expr,
89+ signature_cast,
90+ signature_id,
91+ name,
92+ condition_flow,
93+ )
94+ }
95+ }
96+ }
97+ _ => {
98+ // If the prefix expression is not a function, we cannot infer the type cast.
99+ return Ok ( ResultTypeOrContinue :: Continue ) ;
100+ }
101+ }
102+
103+ Ok ( ResultTypeOrContinue :: Continue )
104+ }
105+
106+ fn get_type_at_call_expr_by_type_guard (
107+ db : & DbIndex ,
108+ tree : & FlowTree ,
109+ cache : & mut LuaInferCache ,
110+ root : & LuaChunk ,
111+ var_ref_id : & VarRefId ,
112+ flow_node : & FlowNode ,
113+ call_expr : LuaCallExpr ,
114+ guard_type : LuaType ,
115+ condition_flow : InferConditionFlow ,
116+ ) -> Result < ResultTypeOrContinue , InferFailReason > {
117+ let Some ( arg_list) = call_expr. get_args_list ( ) else {
118+ return Ok ( ResultTypeOrContinue :: Continue ) ;
119+ } ;
120+
121+ let Some ( first_arg) = arg_list. get_args ( ) . next ( ) else {
122+ return Ok ( ResultTypeOrContinue :: Continue ) ;
123+ } ;
124+
125+ let Some ( maybe_ref_id) = get_var_expr_var_ref_id ( db, cache, first_arg) else {
126+ return Ok ( ResultTypeOrContinue :: Continue ) ;
127+ } ;
128+
129+ if maybe_ref_id != * var_ref_id {
130+ return Ok ( ResultTypeOrContinue :: Continue ) ;
131+ }
132+
133+ match condition_flow {
134+ InferConditionFlow :: TrueCondition => Ok ( ResultTypeOrContinue :: Result ( guard_type) ) ,
135+ InferConditionFlow :: FalseCondition => {
136+ let antecedent_flow_id = get_single_antecedent ( tree, flow_node) ?;
137+ let antecedent_type =
138+ get_type_at_flow ( db, tree, cache, root, var_ref_id, antecedent_flow_id) ?;
139+ Ok ( ResultTypeOrContinue :: Result ( TypeOps :: Remove . apply (
140+ db,
141+ & antecedent_type,
142+ & guard_type,
143+ ) ) )
144+ }
145+ }
146+ }
147+
148+ fn get_type_at_call_expr_by_signature_self (
149+ db : & DbIndex ,
150+ tree : & FlowTree ,
151+ cache : & mut LuaInferCache ,
152+ root : & LuaChunk ,
153+ var_ref_id : & VarRefId ,
154+ flow_node : & FlowNode ,
155+ call_prefix : LuaExpr ,
156+ signature_cast : & LuaSignatureCast ,
157+ condition_flow : InferConditionFlow ,
158+ ) -> Result < ResultTypeOrContinue , InferFailReason > {
159+ let LuaExpr :: IndexExpr ( call_prefix_index) = call_prefix else {
160+ return Ok ( ResultTypeOrContinue :: Continue ) ;
161+ } ;
162+
163+ let Some ( self_expr) = call_prefix_index. get_prefix_expr ( ) else {
164+ return Ok ( ResultTypeOrContinue :: Continue ) ;
165+ } ;
166+
167+ let Some ( name_var_ref_id) = get_var_expr_var_ref_id ( db, cache, self_expr) else {
168+ return Ok ( ResultTypeOrContinue :: Continue ) ;
169+ } ;
170+
171+ if name_var_ref_id != * var_ref_id {
172+ return Ok ( ResultTypeOrContinue :: Continue ) ;
173+ }
174+
175+ let antecedent_flow_id = get_single_antecedent ( tree, flow_node) ?;
176+ let antecedent_type = get_type_at_flow ( db, tree, cache, root, var_ref_id, antecedent_flow_id) ?;
177+
178+ let Some ( cast_op_type) = signature_cast. cast . to_node ( root) else {
179+ return Ok ( ResultTypeOrContinue :: Continue ) ;
180+ } ;
181+
182+ let result_type = cast_type (
33183 db,
34- cache,
35- prefix_expr,
36- SemanticDeclGuard :: default ( ) ,
37- SemanticDeclLevel :: NoTrace ,
38- ) else {
184+ cache. get_file_id ( ) ,
185+ cast_op_type,
186+ antecedent_type,
187+ condition_flow,
188+ ) ?;
189+ Ok ( ResultTypeOrContinue :: Result ( result_type) )
190+ }
191+
192+ fn get_type_at_call_expr_by_signature_param_name (
193+ db : & DbIndex ,
194+ tree : & FlowTree ,
195+ cache : & mut LuaInferCache ,
196+ root : & LuaChunk ,
197+ var_ref_id : & VarRefId ,
198+ flow_node : & FlowNode ,
199+ call_expr : LuaCallExpr ,
200+ signature_cast : & LuaSignatureCast ,
201+ signature_id : LuaSignatureId ,
202+ name : & str ,
203+ condition_flow : InferConditionFlow ,
204+ ) -> Result < ResultTypeOrContinue , InferFailReason > {
205+ let colon_call = call_expr. is_colon_call ( ) ;
206+ let Some ( arg_list) = call_expr. get_args_list ( ) else {
39207 return Ok ( ResultTypeOrContinue :: Continue ) ;
40208 } ;
41209
42- let LuaSemanticDeclId :: Signature ( signature_id ) = semantic_decl_id else {
210+ let Some ( signature ) = db . get_signature_index ( ) . get ( & signature_id ) else {
43211 return Ok ( ResultTypeOrContinue :: Continue ) ;
44212 } ;
45213
46- let Some ( signature_cast) = db
47- . get_flow_index ( )
48- . get_signature_cast ( & cache. get_file_id ( ) , & signature_id)
49- else {
214+ let Some ( mut param_idx) = signature. find_param_idx ( name) else {
50215 return Ok ( ResultTypeOrContinue :: Continue ) ;
51216 } ;
52217
53- match signature_cast. name . as_str ( ) {
54- "self" => { }
218+ let colon_define = signature. is_colon_define ;
219+ match ( colon_call, colon_define) {
220+ ( true , false ) => {
221+ if param_idx == 0 {
222+ return Ok ( ResultTypeOrContinue :: Continue ) ;
223+ }
224+
225+ param_idx -= 1 ;
226+ }
227+ ( false , true ) => {
228+ param_idx += 1 ;
229+ }
55230 _ => { }
56231 }
57232
58- Ok ( ResultTypeOrContinue :: Continue )
59- }
233+ let Some ( expr) = arg_list. get_args ( ) . nth ( param_idx) else {
234+ return Ok ( ResultTypeOrContinue :: Continue ) ;
235+ } ;
236+
237+ let Some ( name_var_ref_id) = get_var_expr_var_ref_id ( db, cache, expr) else {
238+ return Ok ( ResultTypeOrContinue :: Continue ) ;
239+ } ;
240+
241+ if name_var_ref_id != * var_ref_id {
242+ return Ok ( ResultTypeOrContinue :: Continue ) ;
243+ }
244+
245+ let antecedent_flow_id = get_single_antecedent ( tree, flow_node) ?;
246+ let antecedent_type = get_type_at_flow ( db, tree, cache, root, var_ref_id, antecedent_flow_id) ?;
60247
61- // pub fn get_type_cast_call(call_prefix: LuaExpr) {
62- // let
63- // }
248+ let Some ( cast_op_type) = signature_cast. cast . to_node ( root) else {
249+ return Ok ( ResultTypeOrContinue :: Continue ) ;
250+ } ;
251+
252+ let result_type = cast_type (
253+ db,
254+ cache. get_file_id ( ) ,
255+ cast_op_type,
256+ antecedent_type,
257+ condition_flow,
258+ ) ?;
259+ Ok ( ResultTypeOrContinue :: Result ( result_type) )
260+ }
0 commit comments