Skip to content

Commit c058e3d

Browse files
committed
fix(generic): keep callable overload inference shape-aware
1 parent f81dcfb commit c058e3d

5 files changed

Lines changed: 199 additions & 19 deletions

File tree

crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,4 +200,98 @@ mod test {
200200
// unresolved and carried through a named callback value.
201201
assert_eq!(ws.expr_ty("classify_string_unresolved"), ws.ty("string"));
202202
}
203+
204+
#[test]
205+
fn test_apply_return_infer_leaves_result_unknown_when_no_callable_member_matches_arg_shape() {
206+
let mut ws = VirtualWorkspace::new();
207+
ws.def(
208+
r#"
209+
---@generic A, R
210+
---@param f fun(x: A): R
211+
---@param x A
212+
---@return R
213+
local function apply(f, x)
214+
return f(x)
215+
end
216+
217+
---@alias FnInt fun(x: integer): integer
218+
---@alias FnString fun(x: string): string
219+
220+
---@type FnInt | FnString
221+
local run
222+
223+
---@type boolean
224+
local b
225+
226+
result = apply(run, b)
227+
"#,
228+
);
229+
230+
let result_ty = ws.expr_ty("result");
231+
assert_eq!(result_ty, ws.ty("unknown"));
232+
}
233+
234+
#[test]
235+
fn test_apply_return_infer_keeps_only_arity_compatible_fallbacks() {
236+
let mut ws = VirtualWorkspace::new();
237+
ws.def(
238+
r#"
239+
---@generic A, B, R
240+
---@param f fun(x: A, y: B): R
241+
---@param x A
242+
---@param y B
243+
---@return R
244+
local function apply2(f, x, y)
245+
return f(x, y)
246+
end
247+
248+
---@overload fun(x: integer): integer
249+
---@param x integer
250+
---@param y string
251+
---@return string
252+
local function run(x, y) end
253+
254+
local source ---@type table
255+
256+
result = apply2(run, 1, source.missing)
257+
"#,
258+
);
259+
260+
let result_ty = ws.expr_ty("result");
261+
assert_eq!(ws.humanize_type(result_ty), "string");
262+
}
263+
264+
#[test]
265+
fn test_union_call_ignores_non_matching_generic_callable_member() {
266+
let mut ws = VirtualWorkspace::new();
267+
ws.def(
268+
r#"
269+
---@type (fun<T: string>(x: T): T) | fun(x: integer): integer
270+
local run
271+
272+
result = run(1)
273+
"#,
274+
);
275+
276+
let result_ty = ws.expr_ty("result");
277+
assert_eq!(ws.humanize_type(result_ty), "integer");
278+
}
279+
280+
#[test]
281+
fn test_union_call_ignores_non_matching_generic_alias_member() {
282+
let mut ws = VirtualWorkspace::new();
283+
ws.def(
284+
r#"
285+
---@alias GenericStr<T: string> fun(x: T): T
286+
287+
---@type GenericStr | fun(x: integer): integer
288+
local run
289+
290+
result = run(1)
291+
"#,
292+
);
293+
294+
let result_ty = ws.expr_ty("result");
295+
assert_eq!(ws.humanize_type(result_ty), "integer");
296+
}
203297
}

crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use crate::{
2727
};
2828
use crate::{
2929
LuaMemberOwner, LuaSemanticDeclId, SemanticDeclLevel, infer_node_semantic_decl,
30-
tpl_pattern_match_args,
30+
tpl_pattern_match_args_skip_unknown,
3131
};
3232

3333
use super::{TypeSubstitutor, collect_callable_overload_groups};
@@ -259,15 +259,22 @@ pub fn infer_callable_return_from_remaining_args(
259259
let call_arg_types =
260260
match infer_expr_list_types(context.db, context.cache, arg_exprs, None, infer_expr) {
261261
Ok(types) => types.into_iter().map(|(ty, _)| ty).collect::<Vec<_>>(),
262-
Err(_) => {
263-
return infer_callable_return_from_arg_types(context, callable_type, None, true);
264-
}
262+
Err(_) => arg_exprs
263+
.iter()
264+
.map(|arg_expr| {
265+
infer_expr(context.db, context.cache, arg_expr.clone())
266+
.unwrap_or(LuaType::Unknown)
267+
})
268+
.collect::<Vec<_>>(),
265269
};
266270
if call_arg_types.is_empty() {
267271
return Ok(None);
268272
}
269273

270-
infer_callable_return_from_arg_types(context, callable_type, Some(&call_arg_types), true)
274+
// Preserve any known remaining-arg shape, including arity, even when some later arguments
275+
// collapse to `unknown`. This avoids unioning returns from overloads that are impossible
276+
// for the current call.
277+
infer_callable_return_from_arg_types(context, callable_type, Some(&call_arg_types), false)
271278
}
272279

273280
fn instantiate_callable_from_arg_types(
@@ -302,7 +309,12 @@ fn instantiate_callable_from_arg_types(
302309
substitutor: &mut callable_substitutor,
303310
call_expr: context.call_expr.clone(),
304311
};
305-
if tpl_pattern_match_args(&mut callable_context, &callable_param_types, call_arg_types).is_err()
312+
if tpl_pattern_match_args_skip_unknown(
313+
&mut callable_context,
314+
&callable_param_types,
315+
call_arg_types,
316+
)
317+
.is_err()
306318
{
307319
return None;
308320
}

crates/emmylua_code_analysis/src/semantic/generic/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ pub use instantiate_type::*;
1515
use rowan::NodeOrToken;
1616
pub use tpl_context::TplContext;
1717
pub use tpl_pattern::tpl_pattern_match_args;
18+
pub use tpl_pattern::tpl_pattern_match_args_skip_unknown;
1819
pub use type_substitutor::TypeSubstitutor;
1920

2021
use crate::DbIndex;

crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,23 @@ pub fn tpl_pattern_match_args(
3232
context: &mut TplContext,
3333
func_param_types: &[LuaType],
3434
call_arg_types: &[LuaType],
35+
) -> TplPatternMatchResult {
36+
tpl_pattern_match_args_inner(context, func_param_types, call_arg_types, false)
37+
}
38+
39+
pub fn tpl_pattern_match_args_skip_unknown(
40+
context: &mut TplContext,
41+
func_param_types: &[LuaType],
42+
call_arg_types: &[LuaType],
43+
) -> TplPatternMatchResult {
44+
tpl_pattern_match_args_inner(context, func_param_types, call_arg_types, true)
45+
}
46+
47+
fn tpl_pattern_match_args_inner(
48+
context: &mut TplContext,
49+
func_param_types: &[LuaType],
50+
call_arg_types: &[LuaType],
51+
skip_unknown_tpl: bool,
3552
) -> TplPatternMatchResult {
3653
for i in 0..func_param_types.len() {
3754
if i >= call_arg_types.len() {
@@ -54,6 +71,9 @@ pub fn tpl_pattern_match_args(
5471
)?;
5572
break;
5673
}
74+
_ if skip_unknown_tpl
75+
&& func_param_type.contain_tpl()
76+
&& (call_arg_type.is_any() || call_arg_type.is_unknown()) => {}
5777
_ => {
5878
tpl_pattern_match(context, func_param_type, call_arg_type)?;
5979
}

crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
use std::sync::Arc;
22

33
use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr, LuaSyntaxKind};
4+
use hashbrown::HashSet;
45
use rowan::TextRange;
56

67
use super::{
78
super::{InferGuard, LuaInferCache, instantiate_type_generic, resolve_signature},
89
InferFailReason, InferResult,
910
};
11+
use crate::semantic::overload_resolve::callable_accepts_args;
1012
use crate::{
1113
CacheEntry, DbIndex, InFiled, LuaFunctionType, LuaGenericType, LuaInstanceType,
1214
LuaIntersectionType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSignature, LuaSignatureId,
@@ -102,12 +104,16 @@ pub fn infer_call_expr_func(
102104
),
103105
LuaType::Union(union) => {
104106
// 此时我们将其视为泛型实例化联合体
105-
if union
106-
.into_vec()
107+
let union_types = union.into_vec();
108+
if union_types
107109
.iter()
108110
.all(|t| matches!(t, LuaType::DocFunction(_)))
109111
{
110-
infer_generic_doc_function_union(db, cache, union, call_expr.clone(), args_count)
112+
let mut overloads = Vec::new();
113+
for typ in union_types {
114+
overloads.extend(collect_callable_overloads(db, &typ)?);
115+
}
116+
resolve_filtered_overloads(db, cache, overloads, call_expr.clone(), args_count)
111117
} else {
112118
infer_union(db, cache, union, call_expr.clone(), args_count)
113119
}
@@ -194,19 +200,67 @@ fn infer_doc_function(
194200
Ok(func.clone().into())
195201
}
196202

197-
fn infer_generic_doc_function_union(
203+
fn filter_callable_overloads_by_call_args(
198204
db: &DbIndex,
199205
cache: &mut LuaInferCache,
200-
union: &LuaUnionType,
206+
overloads: Vec<Arc<LuaFunctionType>>,
207+
call_expr: &LuaCallExpr,
208+
args_count: Option<usize>,
209+
) -> Result<Vec<Arc<LuaFunctionType>>, InferFailReason> {
210+
let args = call_expr.get_args_list().ok_or(InferFailReason::None)?;
211+
let expr_types = super::infer_expr_list_types(
212+
db,
213+
cache,
214+
&args.get_args().collect::<Vec<_>>(),
215+
args_count,
216+
|db, cache, expr| Ok(infer_expr(db, cache, expr).unwrap_or(LuaType::Unknown)),
217+
)?
218+
.into_iter()
219+
.map(|(ty, _)| ty)
220+
.collect::<Vec<_>>();
221+
let is_colon_call = call_expr.is_colon_call();
222+
223+
Ok(overloads
224+
.into_iter()
225+
.filter(|func| {
226+
let mut callable_tpls = HashSet::new();
227+
func.visit_type(&mut |ty| match ty {
228+
LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => {
229+
callable_tpls.insert(generic_tpl.get_tpl_id());
230+
}
231+
LuaType::StrTplRef(str_tpl) => {
232+
callable_tpls.insert(str_tpl.get_tpl_id());
233+
}
234+
_ => {}
235+
});
236+
237+
if callable_tpls.is_empty() {
238+
return true;
239+
}
240+
241+
let mut substitutor = TypeSubstitutor::new();
242+
substitutor.add_need_infer_tpls(callable_tpls);
243+
let match_func = match instantiate_doc_function(db, func, &substitutor) {
244+
LuaType::DocFunction(doc_func) => doc_func,
245+
_ => func.clone(),
246+
};
247+
248+
callable_accepts_args(db, &match_func, &expr_types, is_colon_call, args_count)
249+
})
250+
.collect())
251+
}
252+
253+
fn resolve_filtered_overloads(
254+
db: &DbIndex,
255+
cache: &mut LuaInferCache,
256+
overloads: Vec<Arc<LuaFunctionType>>,
201257
call_expr: LuaCallExpr,
202258
args_count: Option<usize>,
203259
) -> InferCallFuncResult {
204-
let mut overloads = Vec::new();
205-
for typ in union.into_vec() {
206-
overloads.extend(collect_callable_overloads(db, &typ)?);
207-
}
208-
209-
resolve_signature(db, cache, overloads, call_expr.clone(), false, args_count)
260+
let contains_tpl = overloads.iter().any(|func| func.contain_tpl());
261+
let overloads =
262+
filter_callable_overloads_by_call_args(db, cache, overloads, &call_expr, args_count)?;
263+
resolve_signature(db, cache, overloads, call_expr, contains_tpl, args_count)
210264
}
211265

212266
fn infer_signature_doc_function(
@@ -482,8 +536,7 @@ fn infer_union(
482536
if overloads.is_empty() {
483537
return Err(InferFailReason::None);
484538
}
485-
let is_generic = overloads.iter().any(|func| func.contain_tpl());
486-
resolve_signature(db, cache, overloads, call_expr, is_generic, args_count)
539+
resolve_filtered_overloads(db, cache, overloads, call_expr, args_count)
487540
}
488541

489542
fn collect_callable_overloads(

0 commit comments

Comments
 (0)