|
1 | 1 | use std::sync::Arc; |
2 | 2 |
|
3 | 3 | use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr, LuaSyntaxKind}; |
| 4 | +use hashbrown::HashSet; |
4 | 5 | use rowan::TextRange; |
5 | 6 |
|
6 | 7 | use super::{ |
7 | 8 | super::{InferGuard, LuaInferCache, instantiate_type_generic, resolve_signature}, |
8 | 9 | InferFailReason, InferResult, |
9 | 10 | }; |
| 11 | +use crate::semantic::overload_resolve::callable_accepts_args; |
10 | 12 | use crate::{ |
11 | 13 | CacheEntry, DbIndex, InFiled, LuaFunctionType, LuaGenericType, LuaInstanceType, |
12 | 14 | LuaIntersectionType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSignature, LuaSignatureId, |
@@ -102,12 +104,16 @@ pub fn infer_call_expr_func( |
102 | 104 | ), |
103 | 105 | LuaType::Union(union) => { |
104 | 106 | // 此时我们将其视为泛型实例化联合体 |
105 | | - if union |
106 | | - .into_vec() |
| 107 | + let union_types = union.into_vec(); |
| 108 | + if union_types |
107 | 109 | .iter() |
108 | 110 | .all(|t| matches!(t, LuaType::DocFunction(_))) |
109 | 111 | { |
110 | | - infer_generic_doc_function_union(db, cache, union, call_expr.clone(), args_count) |
| 112 | + let mut overloads = Vec::new(); |
| 113 | + for typ in union_types { |
| 114 | + overloads.extend(collect_callable_overloads(db, &typ)?); |
| 115 | + } |
| 116 | + resolve_filtered_overloads(db, cache, overloads, call_expr.clone(), args_count) |
111 | 117 | } else { |
112 | 118 | infer_union(db, cache, union, call_expr.clone(), args_count) |
113 | 119 | } |
@@ -194,19 +200,67 @@ fn infer_doc_function( |
194 | 200 | Ok(func.clone().into()) |
195 | 201 | } |
196 | 202 |
|
197 | | -fn infer_generic_doc_function_union( |
| 203 | +fn filter_callable_overloads_by_call_args( |
198 | 204 | db: &DbIndex, |
199 | 205 | cache: &mut LuaInferCache, |
200 | | - union: &LuaUnionType, |
| 206 | + overloads: Vec<Arc<LuaFunctionType>>, |
| 207 | + call_expr: &LuaCallExpr, |
| 208 | + args_count: Option<usize>, |
| 209 | +) -> Result<Vec<Arc<LuaFunctionType>>, InferFailReason> { |
| 210 | + let args = call_expr.get_args_list().ok_or(InferFailReason::None)?; |
| 211 | + let expr_types = super::infer_expr_list_types( |
| 212 | + db, |
| 213 | + cache, |
| 214 | + &args.get_args().collect::<Vec<_>>(), |
| 215 | + args_count, |
| 216 | + |db, cache, expr| Ok(infer_expr(db, cache, expr).unwrap_or(LuaType::Unknown)), |
| 217 | + )? |
| 218 | + .into_iter() |
| 219 | + .map(|(ty, _)| ty) |
| 220 | + .collect::<Vec<_>>(); |
| 221 | + let is_colon_call = call_expr.is_colon_call(); |
| 222 | + |
| 223 | + Ok(overloads |
| 224 | + .into_iter() |
| 225 | + .filter(|func| { |
| 226 | + let mut callable_tpls = HashSet::new(); |
| 227 | + func.visit_type(&mut |ty| match ty { |
| 228 | + LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { |
| 229 | + callable_tpls.insert(generic_tpl.get_tpl_id()); |
| 230 | + } |
| 231 | + LuaType::StrTplRef(str_tpl) => { |
| 232 | + callable_tpls.insert(str_tpl.get_tpl_id()); |
| 233 | + } |
| 234 | + _ => {} |
| 235 | + }); |
| 236 | + |
| 237 | + if callable_tpls.is_empty() { |
| 238 | + return true; |
| 239 | + } |
| 240 | + |
| 241 | + let mut substitutor = TypeSubstitutor::new(); |
| 242 | + substitutor.add_need_infer_tpls(callable_tpls); |
| 243 | + let match_func = match instantiate_doc_function(db, func, &substitutor) { |
| 244 | + LuaType::DocFunction(doc_func) => doc_func, |
| 245 | + _ => func.clone(), |
| 246 | + }; |
| 247 | + |
| 248 | + callable_accepts_args(db, &match_func, &expr_types, is_colon_call, args_count) |
| 249 | + }) |
| 250 | + .collect()) |
| 251 | +} |
| 252 | + |
| 253 | +fn resolve_filtered_overloads( |
| 254 | + db: &DbIndex, |
| 255 | + cache: &mut LuaInferCache, |
| 256 | + overloads: Vec<Arc<LuaFunctionType>>, |
201 | 257 | call_expr: LuaCallExpr, |
202 | 258 | args_count: Option<usize>, |
203 | 259 | ) -> InferCallFuncResult { |
204 | | - let mut overloads = Vec::new(); |
205 | | - for typ in union.into_vec() { |
206 | | - overloads.extend(collect_callable_overloads(db, &typ)?); |
207 | | - } |
208 | | - |
209 | | - resolve_signature(db, cache, overloads, call_expr.clone(), false, args_count) |
| 260 | + let contains_tpl = overloads.iter().any(|func| func.contain_tpl()); |
| 261 | + let overloads = |
| 262 | + filter_callable_overloads_by_call_args(db, cache, overloads, &call_expr, args_count)?; |
| 263 | + resolve_signature(db, cache, overloads, call_expr, contains_tpl, args_count) |
210 | 264 | } |
211 | 265 |
|
212 | 266 | fn infer_signature_doc_function( |
@@ -482,8 +536,7 @@ fn infer_union( |
482 | 536 | if overloads.is_empty() { |
483 | 537 | return Err(InferFailReason::None); |
484 | 538 | } |
485 | | - let is_generic = overloads.iter().any(|func| func.contain_tpl()); |
486 | | - resolve_signature(db, cache, overloads, call_expr, is_generic, args_count) |
| 539 | + resolve_filtered_overloads(db, cache, overloads, call_expr, args_count) |
487 | 540 | } |
488 | 541 |
|
489 | 542 | fn collect_callable_overloads( |
|
0 commit comments