Skip to content

Commit 098fb12

Browse files
committed
perf: filter callable overloads
1 parent b0a0f00 commit 098fb12

8 files changed

Lines changed: 140 additions & 141 deletions

File tree

crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs

Lines changed: 9 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
use std::sync::Arc;
22

33
use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr, LuaIndexExpr, LuaSyntaxKind};
4-
use hashbrown::HashSet;
54
use rowan::TextRange;
65

76
use super::{
87
super::{InferGuard, LuaInferCache, instantiate_type_generic, resolve_signature},
98
InferFailReason, InferResult,
109
};
11-
use crate::semantic::overload_resolve::callable_accepts_args;
1210
use crate::{
1311
AsyncState, CacheEntry, DbIndex, InFiled, LuaFunctionType, LuaGenericType, LuaInstanceType,
1412
LuaIntersectionType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSemanticDeclId, LuaSignature,
@@ -18,9 +16,11 @@ use crate::{
1816
use crate::{
1917
InferGuardRef,
2018
semantic::{
21-
generic::TypeSubstitutor, infer::narrow::get_type_at_call_expr_inline_cast,
22-
infer_node_semantic_decl, member::find_member_origin_owner,
23-
overload_resolve::collect_callable_overload_groups,
19+
generic::TypeSubstitutor,
20+
infer::narrow::get_type_at_call_expr_inline_cast,
21+
infer_node_semantic_decl,
22+
member::find_member_origin_owner,
23+
overload_resolve::{collect_callable_overload_groups, match_callable_by_arg_types},
2424
},
2525
};
2626
use crate::{build_self_type, infer_call_generic, infer_self_type, semantic::infer_expr};
@@ -236,13 +236,12 @@ fn infer_doc_function(
236236
Ok(func.clone().into())
237237
}
238238

239-
fn filter_callable_overloads_by_call_args(
239+
fn filter_callable_overloads_by_args(
240240
db: &DbIndex,
241241
cache: &mut LuaInferCache,
242242
overloads: Vec<Arc<LuaFunctionType>>,
243243
call_expr: &LuaCallExpr,
244244
args_count: Option<usize>,
245-
strict_arg_filter: bool,
246245
) -> Result<Vec<Arc<LuaFunctionType>>, InferFailReason> {
247246
let args = call_expr.get_args_list().ok_or(InferFailReason::None)?;
248247
let expr_types = super::infer_expr_list_types(
@@ -255,35 +254,11 @@ fn filter_callable_overloads_by_call_args(
255254
.into_iter()
256255
.map(|(ty, _)| ty)
257256
.collect::<Vec<_>>();
258-
let is_colon_call = call_expr.is_colon_call();
259257

260258
Ok(overloads
261259
.into_iter()
262-
.filter(|func| {
263-
let callable_tpls = func
264-
.get_generic_params()
265-
.iter()
266-
.map(|generic_tpl| generic_tpl.get_tpl_id())
267-
.collect::<HashSet<_>>();
268-
269-
if callable_tpls.is_empty() && !strict_arg_filter {
270-
return true;
271-
}
272-
273-
let has_tpls = !callable_tpls.is_empty();
274-
let mut substitutor = TypeSubstitutor::new();
275-
substitutor.add_need_infer_tpls(callable_tpls);
276-
let match_func = if has_tpls {
277-
let func_type = LuaType::DocFunction(func.clone());
278-
match instantiate_type_generic(db, &func_type, &substitutor) {
279-
LuaType::DocFunction(doc_func) => doc_func,
280-
_ => func.clone(),
281-
}
282-
} else {
283-
func.clone()
284-
};
285-
286-
callable_accepts_args(db, &match_func, &expr_types, is_colon_call, args_count)
260+
.filter_map(|func| {
261+
match_callable_by_arg_types(db, cache, func, &expr_types, call_expr, args_count, true)
287262
})
288263
.collect())
289264
}
@@ -542,13 +517,12 @@ fn infer_union(
542517
let mut overload_groups = Vec::new();
543518
collect_callable_overload_groups(db, &ty, &mut overload_groups)?;
544519
for overloads in overload_groups {
545-
let compatible_overloads = filter_callable_overloads_by_call_args(
520+
let compatible_overloads = filter_callable_overloads_by_args(
546521
db,
547522
cache,
548523
overloads.clone(),
549524
&call_expr,
550525
args_count,
551-
true,
552526
)?;
553527
if compatible_overloads.is_empty() {
554528
fallback_overloads.extend(overloads);
@@ -586,14 +560,6 @@ fn infer_union(
586560
let Some(first_func) = first_func else {
587561
if !fallback_overloads.is_empty() {
588562
let contains_tpl = fallback_overloads.iter().any(|func| func.contain_tpl());
589-
let fallback_overloads = filter_callable_overloads_by_call_args(
590-
db,
591-
cache,
592-
fallback_overloads,
593-
&call_expr,
594-
args_count,
595-
false,
596-
)?;
597563
return resolve_signature(
598564
db,
599565
cache,

crates/emmylua_code_analysis/src/semantic/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,10 @@ pub(crate) use infer::try_infer_expr_for_index;
6060
pub(crate) use infer::{infer_expr, try_infer_expr_no_flow};
6161
use overload_resolve::resolve_signature;
6262
pub(crate) use overload_resolve::{
63-
callable_accepts_args, collect_callable_overload_groups, get_func_param_type,
64-
is_func_last_param_variadic,
63+
callable_accepts_args, get_func_param_type, is_func_last_param_variadic,
64+
};
65+
pub use overload_resolve::{
66+
collect_callable_overload_groups, filter_callable_overloads, find_callable_overload,
6567
};
6668
pub use semantic_info::SemanticDeclLevel;
6769
pub use type_check::{TypeCheckFailReason, TypeCheckResult};

crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_overloads.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::{
1010
},
1111
};
1212

13-
pub(crate) fn collect_callable_overload_groups(
13+
pub fn collect_callable_overload_groups(
1414
db: &DbIndex,
1515
callable_type: &LuaType,
1616
groups: &mut Vec<Vec<Arc<LuaFunctionType>>>,
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
use std::sync::Arc;
2+
3+
use emmylua_parser::LuaCallExpr;
4+
5+
use crate::{
6+
DbIndex, LuaFunctionType, LuaType,
7+
semantic::{LuaInferCache, generic::infer_call_generic, infer::InferFailReason},
8+
};
9+
10+
use super::{
11+
collect_overloads::collect_callable_overload_groups,
12+
resolve_signature_by_args::callable_accepts_args,
13+
};
14+
15+
pub fn filter_callable_overloads(
16+
db: &DbIndex,
17+
cache: &mut LuaInferCache,
18+
callable_type: &LuaType,
19+
call_arg_types: &[LuaType],
20+
call_expr: &LuaCallExpr,
21+
args_count: Option<usize>,
22+
return_instantiated_generic: bool,
23+
) -> Result<Vec<Arc<LuaFunctionType>>, InferFailReason> {
24+
let mut overload_groups = Vec::new();
25+
collect_callable_overload_groups(db, callable_type, &mut overload_groups)?;
26+
27+
Ok(overload_groups
28+
.into_iter()
29+
.flatten()
30+
.filter_map(|func| {
31+
match_callable_by_arg_types(
32+
db,
33+
cache,
34+
func,
35+
call_arg_types,
36+
call_expr,
37+
args_count,
38+
return_instantiated_generic,
39+
)
40+
})
41+
.collect())
42+
}
43+
44+
pub fn find_callable_overload(
45+
db: &DbIndex,
46+
cache: &mut LuaInferCache,
47+
callable_type: &LuaType,
48+
call_arg_types: &[LuaType],
49+
call_expr: &LuaCallExpr,
50+
args_count: Option<usize>,
51+
return_instantiated_generic: bool,
52+
) -> Result<Option<Arc<LuaFunctionType>>, InferFailReason> {
53+
let mut overload_groups = Vec::new();
54+
collect_callable_overload_groups(db, callable_type, &mut overload_groups)?;
55+
56+
Ok(overload_groups.into_iter().flatten().find_map(|func| {
57+
match_callable_by_arg_types(
58+
db,
59+
cache,
60+
func,
61+
call_arg_types,
62+
call_expr,
63+
args_count,
64+
return_instantiated_generic,
65+
)
66+
}))
67+
}
68+
69+
pub(crate) fn match_callable_by_arg_types(
70+
db: &DbIndex,
71+
cache: &mut LuaInferCache,
72+
func: Arc<LuaFunctionType>,
73+
call_arg_types: &[LuaType],
74+
call_expr: &LuaCallExpr,
75+
args_count: Option<usize>,
76+
return_instantiated_generic: bool,
77+
) -> Option<Arc<LuaFunctionType>> {
78+
let has_tpls = func.contain_tpl();
79+
let match_func = if has_tpls {
80+
infer_call_generic(db, cache, func.as_ref(), call_expr.clone())
81+
.map(Arc::new)
82+
.unwrap_or_else(|_| func.clone())
83+
} else {
84+
func.clone()
85+
};
86+
87+
if !callable_accepts_args(
88+
db,
89+
&match_func,
90+
call_arg_types,
91+
call_expr.is_colon_call(),
92+
args_count,
93+
) {
94+
return None;
95+
}
96+
97+
if has_tpls && return_instantiated_generic {
98+
Some(match_func)
99+
} else {
100+
Some(func)
101+
}
102+
}

crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
mod collect_overloads;
2+
mod filter_overloads;
23
mod resolve_signature_by_args;
34

45
use std::sync::Arc;
@@ -13,7 +14,9 @@ use super::{
1314
infer::{InferCallFuncResult, InferFailReason, infer_expr_list_types, try_infer_expr_no_flow},
1415
};
1516

16-
pub(crate) use collect_overloads::collect_callable_overload_groups;
17+
pub use collect_overloads::collect_callable_overload_groups;
18+
pub(crate) use filter_overloads::match_callable_by_arg_types;
19+
pub use filter_overloads::{filter_callable_overloads, find_callable_overload};
1720
pub(crate) use resolve_signature_by_args::{
1821
callable_accepts_args, get_func_param_type, is_func_last_param_variadic,
1922
resolve_signature_by_args,

crates/emmylua_ls/src/handlers/hover/function/call_hover.rs

Lines changed: 16 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
1-
use std::{collections::HashSet, sync::Arc};
1+
use std::sync::Arc;
22

3-
use emmylua_code_analysis::{
4-
DbIndex, LuaFunctionType, LuaType, LuaTypeDeclId, TypeSubstitutor, infer_call_generic,
5-
};
3+
use emmylua_code_analysis::{DbIndex, LuaFunctionType, LuaType, find_callable_overload};
64
use emmylua_parser::LuaCallExpr;
75

86
use crate::handlers::hover::{HoverBuilder, HoverDeclContext, HoverDeclInfo};
97

108
use super::{
11-
define_hover::{HoverFunctionInfo, set_builder_contents},
9+
define_hover::{HoverFunctionInfo, set_function_info_to_builder},
1210
extract_function_member, get_function_description,
1311
render::process_function_type,
1412
};
@@ -34,7 +32,7 @@ pub(super) fn build_function_call_hover(
3432
}
3533
}
3634

37-
return set_builder_contents(builder, &mut function_infos);
35+
return set_function_info_to_builder(builder, &mut function_infos);
3836
}
3937

4038
for matched_decl in matched_decls {
@@ -44,7 +42,7 @@ pub(super) fn build_function_call_hover(
4442
}
4543
}
4644

47-
set_builder_contents(builder, &mut function_infos)
45+
set_function_info_to_builder(builder, &mut function_infos)
4846
}
4947

5048
fn infer_call_arg_types(builder: &HoverBuilder, call_expr: &LuaCallExpr) -> Vec<LuaType> {
@@ -138,87 +136,15 @@ fn find_callable_for_call(
138136
call_arg_types: &[LuaType],
139137
call_expr: &LuaCallExpr,
140138
) -> Option<Arc<LuaFunctionType>> {
141-
let mut overloads = Vec::new();
142-
let mut visiting_aliases = HashSet::new();
143-
collect_callable_functions(db, decl_type, &mut overloads, &mut visiting_aliases);
144-
145-
overloads.into_iter().find_map(|func| {
146-
let func = if func.contain_tpl() {
147-
infer_call_generic(
148-
db,
149-
&mut builder.semantic_model.get_cache().borrow_mut(),
150-
func.as_ref(),
151-
call_expr.clone(),
152-
)
153-
.map(Arc::new)
154-
.unwrap_or(func)
155-
} else {
156-
func
157-
};
158-
159-
builder
160-
.semantic_model
161-
.callable_accepts_args(
162-
func.as_ref(),
163-
call_arg_types,
164-
call_expr.is_colon_call(),
165-
None,
166-
)
167-
.then_some(func)
168-
})
169-
}
170-
171-
fn collect_callable_functions(
172-
db: &DbIndex,
173-
typ: &LuaType,
174-
overloads: &mut Vec<Arc<LuaFunctionType>>,
175-
visiting_aliases: &mut HashSet<LuaTypeDeclId>,
176-
) {
177-
match typ {
178-
LuaType::Ref(type_id) | LuaType::Def(type_id) => {
179-
let Some(type_decl) = db.get_type_index().get_type_decl(type_id) else {
180-
return;
181-
};
182-
if !visiting_aliases.insert(type_id.clone()) {
183-
return;
184-
}
185-
186-
if let Some(origin_type) = type_decl.get_alias_origin(db, None) {
187-
collect_callable_functions(db, &origin_type, overloads, visiting_aliases);
188-
}
189-
visiting_aliases.remove(type_id);
190-
}
191-
LuaType::Generic(generic) => {
192-
let type_id = generic.get_base_type_id();
193-
let substitutor = TypeSubstitutor::from_type_array(generic.get_params().clone());
194-
if !visiting_aliases.insert(type_id.clone()) {
195-
return;
196-
}
197-
198-
if let Some(type_decl) = db.get_type_index().get_type_decl(&type_id)
199-
&& let Some(origin_type) = type_decl.get_alias_origin(db, Some(&substitutor))
200-
{
201-
collect_callable_functions(db, &origin_type, overloads, visiting_aliases);
202-
}
203-
visiting_aliases.remove(&type_id);
204-
}
205-
LuaType::Union(union) => {
206-
for member in union.into_vec() {
207-
collect_callable_functions(db, &member, overloads, visiting_aliases);
208-
}
209-
}
210-
LuaType::Intersection(intersection) => {
211-
for member in intersection.get_types() {
212-
collect_callable_functions(db, member, overloads, visiting_aliases);
213-
}
214-
}
215-
LuaType::DocFunction(func) => overloads.push(func.clone()),
216-
LuaType::Signature(signature_id) => {
217-
if let Some(signature) = db.get_signature_index().get(signature_id) {
218-
overloads.extend(signature.overloads.iter().cloned());
219-
overloads.push(signature.to_doc_func_type());
220-
}
221-
}
222-
_ => {}
223-
}
139+
find_callable_overload(
140+
db,
141+
&mut builder.semantic_model.get_cache().borrow_mut(),
142+
decl_type,
143+
call_arg_types,
144+
call_expr,
145+
None,
146+
true,
147+
)
148+
.ok()
149+
.flatten()
224150
}

0 commit comments

Comments
 (0)