Skip to content

Commit 313da4f

Browse files
committed
fix(hover): generic function call
1 parent 4621da4 commit 313da4f

5 files changed

Lines changed: 167 additions & 74 deletions

File tree

crates/emmylua_ls/src/handlers/hover/function/call_hover.rs

Lines changed: 14 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
use std::{collections::HashSet, sync::Arc};
22

3-
use emmylua_code_analysis::{
4-
DbIndex, LuaDocReturnOverloadInfo, LuaFunctionType, LuaSignature, LuaType, LuaTypeDeclId,
5-
VariadicType, infer_call_generic,
6-
};
3+
use emmylua_code_analysis::{DbIndex, LuaFunctionType, LuaType, LuaTypeDeclId, infer_call_generic};
74
use emmylua_parser::LuaCallExpr;
85

96
use crate::handlers::hover::{HoverBuilder, HoverDeclContext, HoverDeclInfo};
@@ -13,7 +10,10 @@ use super::{
1310
extract_function_member,
1411
generic::generic_type_substitutor,
1512
get_function_description,
16-
render::{FunctionRenderContext, build_function_return_overload_rows, process_function_type},
13+
render::{
14+
FunctionRenderContext, build_function_return_overload_rows,
15+
instantiate_call_return_overloads, process_function_type, render_function,
16+
},
1717
};
1818

1919
pub(super) fn build_function_call_hover(
@@ -30,7 +30,9 @@ pub(super) fn build_function_call_hover(
3030
find_decls_for_call(builder, db, &ordered_decls, &call_arg_types, call_expr);
3131
if matched_decls.is_empty() {
3232
for matched_decl in ordered_decls {
33-
if let Some(info) = build_decl_hover_function_info(builder, db, matched_decl) {
33+
if let Some(info) =
34+
build_unmatched_call_hover_function_info(builder, db, matched_decl, call_expr)
35+
{
3436
function_infos.push(info);
3537
}
3638
}
@@ -61,10 +63,11 @@ fn infer_call_arg_types(builder: &HoverBuilder, call_expr: &LuaCallExpr) -> Vec<
6163
.collect()
6264
}
6365

64-
fn build_decl_hover_function_info(
66+
fn build_unmatched_call_hover_function_info(
6567
builder: &mut HoverBuilder,
6668
db: &DbIndex,
6769
matched_decl: &HoverDeclInfo,
70+
call_expr: &LuaCallExpr,
6871
) -> Option<HoverFunctionInfo> {
6972
let match_semantic_decl = matched_decl.id();
7073
let function_member = extract_function_member(db, match_semantic_decl);
@@ -74,6 +77,7 @@ fn build_decl_hover_function_info(
7477
matched_decl.typ(),
7578
match_semantic_decl,
7679
function_member,
80+
Some(call_expr),
7781
)?;
7882
if contents.is_empty() {
7983
return None;
@@ -121,14 +125,15 @@ fn build_call_hover_function_info(
121125
return_docs: Vec::new(),
122126
ret_detail: Some(ret_detail),
123127
};
124-
vec![super::render::render_function(builder, db, ctx)?]
128+
vec![render_function(builder, db, ctx)?]
125129
} else {
126130
process_function_type(
127131
builder,
128132
db,
129133
&LuaType::DocFunction(call_func.clone()),
130134
match_semantic_decl,
131135
function_member,
136+
None,
132137
)?
133138
}
134139
} else {
@@ -138,57 +143,14 @@ fn build_call_hover_function_info(
138143
&LuaType::DocFunction(call_func.clone()),
139144
match_semantic_decl,
140145
function_member,
146+
None,
141147
)?
142148
};
143149

144150
let description = get_function_description(builder, db, match_semantic_decl);
145151
HoverFunctionInfo::from_contents(contents, description)
146152
}
147153

148-
fn instantiate_call_return_overloads(
149-
builder: &HoverBuilder,
150-
db: &DbIndex,
151-
call_expr: &LuaCallExpr,
152-
signature: &LuaSignature,
153-
) -> Vec<LuaDocReturnOverloadInfo> {
154-
let mut cache = builder.semantic_model.get_cache().borrow_mut();
155-
156-
signature
157-
.return_overloads
158-
.iter()
159-
.map(|row| {
160-
let row_return_type = match row.type_refs.len() {
161-
0 => LuaType::Nil,
162-
1 => row.type_refs[0].clone(),
163-
_ => LuaType::Variadic(VariadicType::Multi(row.type_refs.clone()).into()),
164-
};
165-
let row_function = LuaFunctionType::new(
166-
signature.async_state,
167-
signature.is_colon_define,
168-
signature.is_vararg,
169-
signature.get_type_params(),
170-
row_return_type,
171-
Some(signature.get_function_generic_params()),
172-
);
173-
let type_refs = infer_call_generic(db, &mut cache, &row_function, call_expr.clone())
174-
.ok()
175-
.map(|func| match func.get_ret() {
176-
LuaType::Variadic(variadic) => match variadic.as_ref() {
177-
VariadicType::Multi(types) => types.clone(),
178-
VariadicType::Base(_) => vec![LuaType::Variadic(variadic.clone())],
179-
},
180-
typ => vec![typ.clone()],
181-
})
182-
.unwrap_or_else(|| row.type_refs.clone());
183-
184-
LuaDocReturnOverloadInfo {
185-
type_refs,
186-
description: row.description.clone(),
187-
}
188-
})
189-
.collect()
190-
}
191-
192154
struct MatchedCallDecl<'a> {
193155
decl: &'a HoverDeclInfo,
194156
func: Arc<LuaFunctionType>,

crates/emmylua_ls/src/handlers/hover/function/define_hover.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ pub(super) fn build_function_define_hover(
6060
.unwrap_or_else(|| decl_info.typ());
6161

6262
let Some(contents) =
63-
process_function_type(builder, db, typ, semantic_decl_id, function_member)
63+
process_function_type(builder, db, typ, semantic_decl_id, function_member, None)
6464
else {
6565
continue;
6666
};

crates/emmylua_ls/src/handlers/hover/function/render.rs

Lines changed: 127 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
use std::{collections::HashSet, fmt::Write};
1+
use std::{collections::HashSet, fmt::Write, sync::Arc};
22

33
use emmylua_code_analysis::{
44
AsyncState, DbIndex, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaFunctionType, LuaMember,
5-
LuaMemberOwner, LuaSemanticDeclId, LuaType, RenderLevel, VariadicType, humanize_type,
5+
LuaMemberOwner, LuaSemanticDeclId, LuaSignature, LuaType, RenderLevel, VariadicType,
6+
humanize_type, infer_call_generic,
67
};
8+
use emmylua_parser::LuaCallExpr;
79

810
use crate::handlers::hover::{
911
HoverBuilder,
@@ -29,14 +31,16 @@ pub(super) fn process_function_type(
2931
typ: &LuaType,
3032
semantic_decl: &LuaSemanticDeclId,
3133
function_member: Option<&LuaMember>,
34+
call_expr: Option<&LuaCallExpr>,
3235
) -> Option<Vec<String>> {
3336
match typ {
3437
LuaType::DocFunction(lua_func) => {
38+
let lua_func = instantiate_function_for_call(builder, db, lua_func, call_expr);
3539
let ctx = FunctionRenderContext {
36-
func: lua_func,
40+
func: lua_func.as_ref(),
3741
semantic_decl,
3842
owner_member: function_member,
39-
return_docs: convert_function_return_to_docs(lua_func),
43+
return_docs: convert_function_return_to_docs(lua_func.as_ref()),
4044
ret_detail: None,
4145
};
4246
let content = render_function(builder, db, ctx)?;
@@ -46,31 +50,24 @@ pub(super) fn process_function_type(
4650
let signature = db.get_signature_index().get(&signature_id)?;
4751
let fake_doc_function = signature.to_doc_func_type();
4852
let mut contents = Vec::with_capacity(signature.overloads.len() + 1);
49-
for (i, overload) in std::iter::once(fake_doc_function.as_ref())
50-
.chain(signature.overloads.iter().map(|overload| overload.as_ref()))
53+
for (i, overload) in std::iter::once(&fake_doc_function)
54+
.chain(signature.overloads.iter())
5155
.enumerate()
5256
{
57+
let overload = instantiate_function_for_call(builder, db, overload, call_expr);
5358
// 提前计算 return_docs 和 ret_detail 的差异, 免重复的 hover_doc_function_type 调用
5459
let (return_docs, ret_detail) = if i == 0 && !signature.return_overloads.is_empty()
5560
{
5661
let detail =
57-
build_function_return_overload_rows(builder, &signature.return_overloads);
62+
build_signature_return_overload_rows(builder, db, signature, call_expr);
5863
(Vec::new(), Some(detail))
5964
} else {
60-
let docs = if i == 0 {
61-
if signature.return_docs.is_empty() {
62-
convert_function_return_to_docs(overload)
63-
} else {
64-
signature.return_docs.clone()
65-
}
66-
} else {
67-
convert_function_return_to_docs(overload)
68-
};
65+
let docs = signature_return_docs(signature, i, overload.as_ref(), call_expr);
6966
(docs, None)
7067
};
7168

7269
let ctx = FunctionRenderContext {
73-
func: overload,
70+
func: overload.as_ref(),
7471
semantic_decl,
7572
owner_member: function_member,
7673
return_docs,
@@ -83,9 +80,14 @@ pub(super) fn process_function_type(
8380
LuaType::Union(union) => {
8481
let mut contents = Vec::new();
8582
for typ in union.into_vec() {
86-
if let Some(content) =
87-
process_function_type(builder, db, &typ, semantic_decl, function_member)
88-
{
83+
if let Some(content) = process_function_type(
84+
builder,
85+
db,
86+
&typ,
87+
semantic_decl,
88+
function_member,
89+
call_expr,
90+
) {
8991
contents.extend(content);
9092
}
9193
}
@@ -95,6 +97,111 @@ pub(super) fn process_function_type(
9597
}
9698
}
9799

100+
fn instantiate_function_for_call(
101+
builder: &HoverBuilder,
102+
db: &DbIndex,
103+
func: &Arc<LuaFunctionType>,
104+
call_expr: Option<&LuaCallExpr>,
105+
) -> Arc<LuaFunctionType> {
106+
let Some(call_expr) = call_expr else {
107+
return func.clone();
108+
};
109+
if !func.contain_tpl() {
110+
return func.clone();
111+
}
112+
113+
infer_call_generic(
114+
db,
115+
&mut builder.semantic_model.get_cache().borrow_mut(),
116+
func.as_ref(),
117+
call_expr.clone(),
118+
)
119+
.map(Arc::new)
120+
.unwrap_or_else(|_| func.clone())
121+
}
122+
123+
fn build_signature_return_overload_rows(
124+
builder: &mut HoverBuilder,
125+
db: &DbIndex,
126+
signature: &LuaSignature,
127+
call_expr: Option<&LuaCallExpr>,
128+
) -> String {
129+
if let Some(call_expr) = call_expr {
130+
let return_overloads = instantiate_call_return_overloads(builder, db, call_expr, signature);
131+
build_function_return_overload_rows(builder, &return_overloads)
132+
} else {
133+
build_function_return_overload_rows(builder, &signature.return_overloads)
134+
}
135+
}
136+
137+
fn signature_return_docs(
138+
signature: &LuaSignature,
139+
index: usize,
140+
func: &LuaFunctionType,
141+
call_expr: Option<&LuaCallExpr>,
142+
) -> Vec<LuaDocReturnInfo> {
143+
if index == 0 && !signature.return_docs.is_empty() {
144+
if call_expr.is_none() {
145+
return signature.return_docs.clone();
146+
}
147+
148+
let mut return_docs = signature.return_docs.clone();
149+
for (return_doc, inferred_doc) in return_docs
150+
.iter_mut()
151+
.zip(convert_function_return_to_docs(func))
152+
{
153+
return_doc.type_ref = inferred_doc.type_ref;
154+
}
155+
return return_docs;
156+
}
157+
158+
convert_function_return_to_docs(func)
159+
}
160+
161+
pub(super) fn instantiate_call_return_overloads(
162+
builder: &HoverBuilder,
163+
db: &DbIndex,
164+
call_expr: &LuaCallExpr,
165+
signature: &LuaSignature,
166+
) -> Vec<LuaDocReturnOverloadInfo> {
167+
let mut cache = builder.semantic_model.get_cache().borrow_mut();
168+
169+
signature
170+
.return_overloads
171+
.iter()
172+
.map(|row| {
173+
let row_return_type = match row.type_refs.len() {
174+
0 => LuaType::Nil,
175+
1 => row.type_refs[0].clone(),
176+
_ => LuaType::Variadic(VariadicType::Multi(row.type_refs.clone()).into()),
177+
};
178+
let row_function = LuaFunctionType::new(
179+
signature.async_state,
180+
signature.is_colon_define,
181+
signature.is_vararg,
182+
signature.get_type_params(),
183+
row_return_type,
184+
Some(signature.get_function_generic_params()),
185+
);
186+
let type_refs = infer_call_generic(db, &mut cache, &row_function, call_expr.clone())
187+
.ok()
188+
.map(|func| match func.get_ret() {
189+
LuaType::Variadic(variadic) => match variadic.as_ref() {
190+
VariadicType::Multi(types) => types.clone(),
191+
VariadicType::Base(_) => vec![LuaType::Variadic(variadic.clone())],
192+
},
193+
typ => vec![typ.clone()],
194+
})
195+
.unwrap_or_else(|| row.type_refs.clone());
196+
197+
LuaDocReturnOverloadInfo {
198+
type_refs,
199+
description: row.description.clone(),
200+
}
201+
})
202+
.collect()
203+
}
204+
98205
/// 渲染单个函数签名的完整 hover 文本
99206
pub(super) fn render_function(
100207
builder: &mut HoverBuilder,

crates/emmylua_ls/src/handlers/hover/function/table_field.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pub(super) fn build_table_field_hover(
3434
let function_member = extract_function_member(db, semantic_decl_id);
3535

3636
let Some(contents) =
37-
process_function_type(builder, db, &typ, semantic_decl_id, function_member)
37+
process_function_type(builder, db, &typ, semantic_decl_id, function_member, None)
3838
else {
3939
continue;
4040
};

crates/emmylua_ls/src/handlers/test/hover_function_test.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,30 @@ mod tests {
667667
Ok(())
668668
}
669669

670+
#[gtest]
671+
fn test_call_hover_shows_all_generic_overloads_when_no_match() -> Result<()> {
672+
let mut ws = ProviderVirtualWorkspace::new();
673+
check!(ws.check_hover(
674+
r#"
675+
---@generic T, U
676+
---@overload fun(value: string, fallback: T): T, U
677+
---@overload fun(value: number, fallback: T): T, U
678+
---@param value table
679+
---@param fallback T
680+
---@return T
681+
---@return U
682+
function generic_test(value, fallback)
683+
end
684+
685+
generic_te<??>st(true, false)
686+
"#,
687+
VirtualHoverResult {
688+
value: "```lua\nfunction generic_test(value: table, fallback: boolean) -> boolean, unknown (+2 overloads)\n```\n\n---\n\n---\n\n```lua\nfunction generic_test(value: string, fallback: boolean) -> boolean, unknown\n```\n\n```lua\nfunction generic_test(value: number, fallback: boolean) -> boolean, unknown\n```".to_string(),
689+
},
690+
));
691+
Ok(())
692+
}
693+
670694
#[gtest]
671695
fn test_fix_method_1() -> Result<()> {
672696
let mut ws = ProviderVirtualWorkspace::new();

0 commit comments

Comments
 (0)