|
| 1 | +use std::sync::Arc; |
| 2 | + |
| 3 | +use emmylua_parser::{LuaAstToken, LuaCallExpr}; |
| 4 | +use rowan::TextRange; |
| 5 | + |
| 6 | +use crate::{ |
| 7 | + DbIndex, LuaFunctionType, LuaType, LuaTypeNode, SemanticModel, TypeSubstitutor, |
| 8 | + build_call_generic_substitutor, filter_callable_overloads, |
| 9 | +}; |
| 10 | + |
| 11 | +// 泛型约束上下文 |
| 12 | +pub(super) struct CallConstraintContext { |
| 13 | + pub params: Vec<(String, Option<LuaType>)>, |
| 14 | + pub args: Vec<CallConstraintArg>, |
| 15 | + pub substitutor: TypeSubstitutor, |
| 16 | +} |
| 17 | + |
| 18 | +pub(super) struct CallConstraintArg { |
| 19 | + pub raw_type: LuaType, |
| 20 | + pub range: TextRange, |
| 21 | +} |
| 22 | + |
| 23 | +pub(super) fn build_call_constraint_context( |
| 24 | + semantic_model: &SemanticModel, |
| 25 | + call_expr: &LuaCallExpr, |
| 26 | +) -> Option<CallConstraintContext> { |
| 27 | + let mut args = get_arg_infos(semantic_model, call_expr)?; |
| 28 | + let call_arg_types = args |
| 29 | + .iter() |
| 30 | + .map(|arg| arg.raw_type.clone()) |
| 31 | + .collect::<Vec<_>>(); |
| 32 | + let doc_func = get_call_doc_func(semantic_model, call_expr, &call_arg_types)?; |
| 33 | + |
| 34 | + let mut params = doc_func.get_params().to_vec(); |
| 35 | + let substitutor = { |
| 36 | + build_call_generic_substitutor( |
| 37 | + semantic_model.get_db(), |
| 38 | + &mut semantic_model.get_cache().borrow_mut(), |
| 39 | + &doc_func, |
| 40 | + call_expr, |
| 41 | + ) |
| 42 | + .ok()? |
| 43 | + }; |
| 44 | + |
| 45 | + // 处理冒号调用与函数定义在 self 参数上的差异 |
| 46 | + match (call_expr.is_colon_call(), doc_func.is_colon_define()) { |
| 47 | + (true, true) | (false, false) => {} |
| 48 | + (false, true) => { |
| 49 | + params.insert(0, ("self".into(), Some(LuaType::SelfInfer))); |
| 50 | + } |
| 51 | + (true, false) => { |
| 52 | + let source_type = semantic_model.infer_call_receiver_type(call_expr)?; |
| 53 | + args.insert( |
| 54 | + 0, |
| 55 | + CallConstraintArg { |
| 56 | + raw_type: source_type, |
| 57 | + range: call_expr.get_colon_token()?.get_range(), |
| 58 | + }, |
| 59 | + ); |
| 60 | + } |
| 61 | + } |
| 62 | + |
| 63 | + Some(CallConstraintContext { |
| 64 | + params, |
| 65 | + args, |
| 66 | + substitutor, |
| 67 | + }) |
| 68 | +} |
| 69 | + |
| 70 | +fn get_call_doc_func( |
| 71 | + semantic_model: &SemanticModel, |
| 72 | + call_expr: &LuaCallExpr, |
| 73 | + call_arg_types: &[LuaType], |
| 74 | +) -> Option<Arc<LuaFunctionType>> { |
| 75 | + let prefix_expr = call_expr.get_prefix_expr()?.clone(); |
| 76 | + let callable_type = semantic_model.infer_expr(prefix_expr).ok()?; |
| 77 | + let overloads = filter_callable_overloads( |
| 78 | + semantic_model.get_db(), |
| 79 | + &mut semantic_model.get_cache().borrow_mut(), |
| 80 | + &callable_type, |
| 81 | + &call_arg_types, |
| 82 | + call_expr, |
| 83 | + None, |
| 84 | + false, |
| 85 | + ) |
| 86 | + .ok()?; |
| 87 | + let generic_arg_count = |func: &LuaFunctionType| -> usize { |
| 88 | + call_arg_types |
| 89 | + .iter() |
| 90 | + .enumerate() |
| 91 | + .filter(|(arg_index, _)| { |
| 92 | + let mut param_index = *arg_index; |
| 93 | + match (func.is_colon_define(), call_expr.is_colon_call()) { |
| 94 | + (true, false) => { |
| 95 | + if param_index == 0 { |
| 96 | + return false; |
| 97 | + } |
| 98 | + param_index -= 1; |
| 99 | + } |
| 100 | + (false, true) => param_index += 1, |
| 101 | + _ => {} |
| 102 | + } |
| 103 | + |
| 104 | + let param_type = func |
| 105 | + .get_params() |
| 106 | + .get(param_index) |
| 107 | + .or_else(|| { |
| 108 | + func.get_params() |
| 109 | + .last() |
| 110 | + .filter(|last_param| last_param.0 == "...") |
| 111 | + }) |
| 112 | + .and_then(|(_, param_type)| param_type.as_ref()); |
| 113 | + param_type.is_some_and(|param_type| { |
| 114 | + param_type.any_type(|ty| match ty { |
| 115 | + LuaType::TplRef(tpl) => tpl.get_tpl_id().is_func(), |
| 116 | + LuaType::StrTplRef(tpl) => tpl.get_tpl_id().is_func(), |
| 117 | + _ => false, |
| 118 | + }) |
| 119 | + }) |
| 120 | + }) |
| 121 | + .count() |
| 122 | + }; |
| 123 | + // 诊断阶段保留所有可匹配候选, 但优先选择当前实参直接命中具体参数类型的 overload. |
| 124 | + let doc_func = overloads |
| 125 | + .into_iter() |
| 126 | + .min_by_key(|func| generic_arg_count(func.as_ref()))?; |
| 127 | + Some(doc_func) |
| 128 | +} |
| 129 | + |
| 130 | +// 将推导结果转换为更易比较的形式 |
| 131 | +pub fn normalize_constraint_type(db: &DbIndex, ty: LuaType) -> LuaType { |
| 132 | + match ty { |
| 133 | + LuaType::Tuple(tuple) if tuple.is_infer_resolve() => tuple.collapse_to_union(db), |
| 134 | + _ => ty, |
| 135 | + } |
| 136 | +} |
| 137 | + |
| 138 | +// 推导每个实参类型 |
| 139 | +fn get_arg_infos( |
| 140 | + semantic_model: &SemanticModel, |
| 141 | + call_expr: &LuaCallExpr, |
| 142 | +) -> Option<Vec<CallConstraintArg>> { |
| 143 | + let arg_exprs = call_expr.get_args_list()?.get_args().collect::<Vec<_>>(); |
| 144 | + let arg_infos = semantic_model |
| 145 | + .infer_expr_list_types(&arg_exprs, None) |
| 146 | + .into_iter() |
| 147 | + .map(|(raw_type, range)| CallConstraintArg { raw_type, range }) |
| 148 | + .collect(); |
| 149 | + |
| 150 | + Some(arg_infos) |
| 151 | +} |
0 commit comments