Skip to content

Commit 538cec3

Browse files
committed
fix(generic): stabilize callable overload inference
Infer higher-order callable returns from compatible overload members instead of a single order-dependent branch. Keep unknown returns when the argument shape is unknown, treat function as an erased callable fallback, and union compatible callable-union returns. Assisted-by: Codex
1 parent 0a343a5 commit 538cec3

6 files changed

Lines changed: 815 additions & 174 deletions

File tree

crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs

Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,333 @@ mod test {
8585

8686
assert_eq!(ws.expr_ty("result"), ws.ty("string"));
8787
}
88+
89+
#[test]
90+
fn test_apply_return_infer_prefers_structural_callback_over_function_fallback() {
91+
let mut ws = VirtualWorkspace::new();
92+
ws.def(
93+
r#"
94+
---@generic A, R
95+
---@param f fun(x: A): R
96+
---@param x A
97+
---@return R
98+
local function apply(f, x)
99+
return f(x)
100+
end
101+
102+
---@overload fun<T>(cb: fun(): T): T
103+
---@param cb function
104+
---@return boolean
105+
local function run(cb) end
106+
107+
---@overload fun<T>(cb: fun(x: integer): T): integer
108+
---@overload fun<T>(cb: fun(x: string): T): string
109+
---@overload fun(cb: function): boolean
110+
local function classify(cb) end
111+
112+
---@overload fun<T>(cb: fun(): T): { value: T }
113+
---@overload fun(cb: function): boolean
114+
local function wrap(cb) end
115+
116+
local source ---@type table
117+
118+
---@return integer
119+
local function cb_concrete()
120+
return 1
121+
end
122+
123+
---@type fun(): unknown
124+
local cb_unknown
125+
126+
---@type fun(x: integer): unknown
127+
local cb_param_unknown
128+
129+
---@type fun(x: string): unknown
130+
local cb_param_unknown_string
131+
132+
---@param x integer
133+
local function cb_param_unresolved(x)
134+
return source.missing
135+
end
136+
137+
---@param x string
138+
local function cb_param_unresolved_string(x)
139+
return source.missing
140+
end
141+
142+
local function cb_named_unresolved()
143+
return source.missing
144+
end
145+
146+
run_concrete = apply(run, cb_concrete)
147+
148+
run_unknown = apply(run, cb_unknown)
149+
150+
run_unresolved = apply(run, function()
151+
return source.missing
152+
end)
153+
154+
run_named_unresolved = apply(run, cb_named_unresolved)
155+
156+
wrap_named_unresolved = apply(wrap, cb_named_unresolved)
157+
158+
classify_unknown = apply(classify, cb_param_unknown)
159+
160+
classify_unresolved = apply(classify, cb_param_unresolved)
161+
162+
classify_string_unknown = apply(classify, cb_param_unknown_string)
163+
164+
classify_string_unresolved = apply(classify, cb_param_unresolved_string)
165+
"#,
166+
);
167+
168+
// The callback return is concrete, so `T` is inferred as `integer` and the generic
169+
// overload is more informative than the `function -> boolean` fallback.
170+
assert_eq!(ws.expr_ty("run_concrete"), ws.ty("integer"));
171+
172+
// `function` is an erased fallback. A structural `fun(): unknown` callback should keep
173+
// the generic overload and preserve the unknown return.
174+
assert_eq!(ws.expr_ty("run_unknown"), ws.ty("unknown"));
175+
176+
// An unresolved closure return is treated the same as an explicit `unknown` return.
177+
assert_eq!(ws.expr_ty("run_unresolved"), ws.ty("unknown"));
178+
179+
// The named-callback path should stay aligned with the inline unresolved callback case.
180+
assert_eq!(ws.expr_ty("run_named_unresolved"), ws.ty("unknown"));
181+
182+
// The structural overload still wins when the unknown is nested in the return shape.
183+
assert_eq!(
184+
ws.expr_ty("wrap_named_unresolved"),
185+
ws.ty("{ value: unknown }")
186+
);
187+
188+
// The callback's parameter type is known, so the generic `fun(x: integer): T` overload
189+
// should still win even though the callback return is only `unknown`.
190+
assert_eq!(ws.expr_ty("classify_unknown"), ws.ty("integer"));
191+
192+
// The callback return is unresolved, but its parameter is still `integer`, so overload
193+
// ranking should keep using that known shape and pick the generic integer branch.
194+
assert_eq!(ws.expr_ty("classify_unresolved"), ws.ty("integer"));
195+
196+
// The callback's parameter type is `string`, so overload selection should not fall back
197+
// to the first generic branch when the callback return is only `unknown`.
198+
assert_eq!(ws.expr_ty("classify_string_unknown"), ws.ty("string"));
199+
200+
// The same `string`-parameter branch should still win when the callback return is
201+
// unresolved and carried through a named callback value.
202+
assert_eq!(ws.expr_ty("classify_string_unresolved"), ws.ty("string"));
203+
}
204+
205+
#[test]
206+
fn test_apply_return_infer_leaves_result_unknown_when_no_callable_member_matches_arg_shape() {
207+
let mut ws = VirtualWorkspace::new();
208+
ws.def(
209+
r#"
210+
---@generic A, R
211+
---@param f fun(x: A): R
212+
---@param x A
213+
---@return R
214+
local function apply(f, x)
215+
return f(x)
216+
end
217+
218+
---@alias FnInt fun(x: integer): integer
219+
---@alias FnString fun(x: string): string
220+
221+
---@type FnInt | FnString
222+
local run
223+
224+
---@type boolean
225+
local b
226+
227+
result = apply(run, b)
228+
"#,
229+
);
230+
231+
let result_ty = ws.expr_ty("result");
232+
assert_eq!(result_ty, ws.ty("unknown"));
233+
}
234+
235+
#[test]
236+
fn test_apply_return_infer_uses_function_fallback_when_no_structural_overload_matches() {
237+
let mut ws = VirtualWorkspace::new();
238+
ws.def(
239+
r#"
240+
---@generic A, R
241+
---@param f fun(x: A): R
242+
---@param x A
243+
---@return R
244+
local function apply(f, x)
245+
return f(x)
246+
end
247+
248+
---@param cb function
249+
---@return boolean
250+
local function run(cb) end
251+
252+
---@type fun(): unknown
253+
local cb
254+
255+
result = apply(run, cb)
256+
"#,
257+
);
258+
259+
assert_eq!(ws.expr_ty("result"), ws.ty("boolean"));
260+
}
261+
262+
#[test]
263+
fn test_apply_return_infer_keeps_only_arity_compatible_fallbacks() {
264+
let mut ws = VirtualWorkspace::new();
265+
ws.def(
266+
r#"
267+
---@generic A, B, R
268+
---@param f fun(x: A, y: B): R
269+
---@param x A
270+
---@param y B
271+
---@return R
272+
local function apply2(f, x, y)
273+
return f(x, y)
274+
end
275+
276+
---@overload fun(x: integer): integer
277+
---@param x integer
278+
---@param y string
279+
---@return string
280+
local function run(x, y) end
281+
282+
local source ---@type table
283+
284+
result = apply2(run, 1, source.missing)
285+
"#,
286+
);
287+
288+
let result_ty = ws.expr_ty("result");
289+
assert_eq!(ws.humanize_type(result_ty), "string");
290+
}
291+
292+
#[test]
293+
fn test_apply_return_infer_keeps_same_arity_overload_returns_when_tail_is_unknown() {
294+
let mut ws = VirtualWorkspace::new();
295+
ws.def(
296+
r#"
297+
---@generic A, B, R
298+
---@param f fun(x: A, y: B): R
299+
---@param x A
300+
---@param y B
301+
---@return R
302+
local function apply2(f, x, y)
303+
return f(x, y)
304+
end
305+
306+
---@overload fun(x: integer, y: number): number
307+
---@param x integer
308+
---@param y string
309+
---@return string
310+
local function run(x, y) end
311+
312+
local source ---@type table
313+
314+
result = apply2(run, 1, source.missing)
315+
"#,
316+
);
317+
318+
let result_ty = ws.expr_ty("result");
319+
assert_eq!(result_ty, ws.ty("number|string"));
320+
}
321+
322+
#[test]
323+
fn test_apply_return_infer_keeps_unknown_return_when_arg_shape_is_unknown() {
324+
let mut ws = VirtualWorkspace::new();
325+
ws.def(
326+
r#"
327+
---@generic A, R
328+
---@param f fun(x: A): R
329+
---@param x A
330+
---@return R
331+
local function apply(f, x)
332+
return f(x)
333+
end
334+
335+
---@overload fun(x: integer): unknown
336+
---@param x string
337+
---@return string
338+
local function run(x) end
339+
340+
local source ---@type table
341+
342+
result = apply(run, source.missing)
343+
"#,
344+
);
345+
346+
let result_ty = ws.expr_ty("result");
347+
assert_eq!(result_ty, ws.ty("unknown|string"));
348+
}
349+
350+
#[test]
351+
fn test_union_call_ignores_non_matching_generic_callable_member() {
352+
let mut ws = VirtualWorkspace::new();
353+
ws.def(
354+
r#"
355+
---@type (fun<T: string>(x: T): T) | fun(x: integer): integer
356+
local run
357+
358+
result = run(1)
359+
"#,
360+
);
361+
362+
let result_ty = ws.expr_ty("result");
363+
assert_eq!(ws.humanize_type(result_ty), "integer");
364+
}
365+
366+
#[test]
367+
fn test_union_call_ignores_non_matching_generic_alias_member() {
368+
let mut ws = VirtualWorkspace::new();
369+
ws.def(
370+
r#"
371+
---@alias GenericStr<T: string> fun(x: T): T
372+
373+
---@type GenericStr | fun(x: integer): integer
374+
local run
375+
376+
result = run(1)
377+
"#,
378+
);
379+
380+
let result_ty = ws.expr_ty("result");
381+
assert_eq!(ws.humanize_type(result_ty), "integer");
382+
}
383+
384+
#[test]
385+
fn test_direct_callable_union_unions_same_domain_returns() {
386+
let mut ws = VirtualWorkspace::new();
387+
ws.def(
388+
r#"
389+
---@alias FnA fun(x: integer): integer
390+
---@alias FnB fun(x: integer): boolean
391+
392+
---@type FnA | FnB
393+
local run
394+
395+
result = run(1)
396+
"#,
397+
);
398+
399+
assert_eq!(ws.expr_ty("result"), ws.ty("integer|boolean"));
400+
}
401+
402+
#[test]
403+
fn test_plain_function_call_returns_unknown_values() {
404+
let mut ws = VirtualWorkspace::new();
405+
ws.def(
406+
r#"
407+
---@type function
408+
local f
409+
410+
a, b = f(1)
411+
"#,
412+
);
413+
414+
assert_eq!(ws.expr_ty("a"), ws.ty("unknown"));
415+
assert_eq!(ws.expr_ty("b"), ws.ty("unknown"));
416+
}
88417
}

0 commit comments

Comments
 (0)