Skip to content

Commit 91bc734

Browse files
committed
refactor(diagnostic): param type check for all overloads
1 parent d3a1563 commit 91bc734

9 files changed

Lines changed: 547 additions & 267 deletions

File tree

crates/emmylua_code_analysis/src/diagnostic/checker/param_type_check.rs

Lines changed: 333 additions & 179 deletions
Large diffs are not rendered by default.

crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,27 @@
22
mod test {
33
use std::{ops::Deref, sync::Arc};
44

5+
use lsp_types::{Diagnostic, NumberOrString};
6+
use tokio_util::sync::CancellationToken;
7+
58
use crate::{DiagnosticCode, VirtualWorkspace};
69

10+
fn param_type_diagnostics(ws: &mut VirtualWorkspace, block_str: &str) -> Vec<Diagnostic> {
11+
ws.analysis
12+
.diagnostic
13+
.enable_only(DiagnosticCode::ParamTypeMismatch);
14+
let file_id = ws.def(block_str);
15+
let code = Some(NumberOrString::String(
16+
DiagnosticCode::ParamTypeMismatch.get_name().to_string(),
17+
));
18+
ws.analysis
19+
.diagnose_file(file_id, CancellationToken::new())
20+
.unwrap_or_default()
21+
.into_iter()
22+
.filter(|diagnostic| diagnostic.code == code)
23+
.collect()
24+
}
25+
726
#[test]
827
fn test_issue_216() {
928
let mut ws = VirtualWorkspace::new();
@@ -41,6 +60,27 @@ mod test {
4160
));
4261
}
4362

63+
#[test]
64+
fn test_overload_param_type_mismatch_unions_failed_position() {
65+
let mut ws = VirtualWorkspace::new();
66+
let diagnostics = param_type_diagnostics(
67+
&mut ws,
68+
r#"
69+
---@type fun(name: "游戏-初始化") | fun(name: "游戏-开始")
70+
local event
71+
local bad ---@type boolean
72+
73+
event(bad)
74+
"#,
75+
);
76+
77+
assert_eq!(diagnostics.len(), 1);
78+
let message = &diagnostics[0].message;
79+
assert!(message.contains("boolean"), "{message}");
80+
assert!(message.contains("游戏-初始化"), "{message}");
81+
assert!(message.contains("游戏-开始"), "{message}");
82+
}
83+
4484
#[test]
4585
fn test_issue_75() {
4686
let mut ws = VirtualWorkspace::new_with_init_std_lib();
@@ -825,8 +865,9 @@ mod test {
825865
826866
---@class (partial) D21.A
827867
---@field event fun(self: self, event: "游戏-初始化")
868+
---@field event fun(self: self, event: "游戏-开始")
828869
829-
---@param p string
870+
---@param p boolean
830871
local function test(p)
831872
M:event(p)
832873
end

crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs

Lines changed: 4 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
use std::{ops::Deref, sync::Arc};
22

3-
use emmylua_parser::{LuaAstNode, LuaAstToken, LuaCallExpr, LuaExpr, LuaIndexExpr};
3+
use emmylua_parser::{LuaAstNode, LuaAstToken, LuaCallExpr, LuaExpr};
44
use hashbrown::HashSet;
55
use rowan::TextRange;
66

77
use crate::{
8-
DbIndex, DocTypeInferContext, GenericTplId, LuaFunctionType, LuaSemanticDeclId, LuaType,
9-
SemanticDeclLevel, SemanticModel, TypeOps, TypeSubstitutor, VariadicType, infer_doc_type,
8+
DbIndex, DocTypeInferContext, GenericTplId, LuaFunctionType, LuaType, SemanticModel, TypeOps,
9+
TypeSubstitutor, VariadicType, infer_doc_type,
1010
};
1111

1212
use super::{TplContext, tpl_pattern_match_args};
@@ -59,7 +59,7 @@ pub fn build_call_constraint_context(
5959
params.insert(0, ("self".into(), Some(LuaType::SelfInfer)));
6060
}
6161
(true, false) => {
62-
let source_type = infer_call_source_type(semantic_model, call_expr)?;
62+
let source_type = semantic_model.infer_call_receiver_type(call_expr)?;
6363
args.insert(
6464
0,
6565
CallConstraintArg {
@@ -103,75 +103,6 @@ pub fn normalize_constraint_type(db: &DbIndex, ty: LuaType) -> LuaType {
103103
}
104104
}
105105

106-
// 解析冒号调用时调用者的具体类型
107-
fn infer_call_source_type(
108-
semantic_model: &SemanticModel,
109-
call_expr: &LuaCallExpr,
110-
) -> Option<LuaType> {
111-
match call_expr.get_prefix_expr()? {
112-
LuaExpr::IndexExpr(index_expr) => {
113-
let decl = semantic_model.find_decl(
114-
index_expr.syntax().clone().into(),
115-
SemanticDeclLevel::default(),
116-
)?;
117-
118-
if let LuaSemanticDeclId::Member(member_id) = decl
119-
&& let Some(LuaSemanticDeclId::Member(member_id)) =
120-
semantic_model.get_member_origin_owner(member_id)
121-
{
122-
let root = semantic_model
123-
.get_db()
124-
.get_vfs()
125-
.get_syntax_tree(&member_id.file_id)?
126-
.get_red_root();
127-
let cur_node = member_id.get_syntax_id().to_node_from_root(&root)?;
128-
let index_expr = LuaIndexExpr::cast(cur_node)?;
129-
130-
return index_expr.get_prefix_expr().map(|prefix_expr| {
131-
semantic_model
132-
.infer_expr(prefix_expr.clone())
133-
.unwrap_or(LuaType::SelfInfer)
134-
});
135-
}
136-
137-
return if let Some(prefix_expr) = index_expr.get_prefix_expr() {
138-
let expr_type = semantic_model
139-
.infer_expr(prefix_expr.clone())
140-
.unwrap_or(LuaType::SelfInfer);
141-
Some(expr_type)
142-
} else {
143-
None
144-
};
145-
}
146-
LuaExpr::NameExpr(name_expr) => {
147-
let decl = semantic_model.find_decl(
148-
name_expr.syntax().clone().into(),
149-
SemanticDeclLevel::default(),
150-
)?;
151-
if let LuaSemanticDeclId::Member(member_id) = decl {
152-
let root = semantic_model
153-
.get_db()
154-
.get_vfs()
155-
.get_syntax_tree(&member_id.file_id)?
156-
.get_red_root();
157-
let cur_node = member_id.get_syntax_id().to_node_from_root(&root)?;
158-
let index_expr = LuaIndexExpr::cast(cur_node)?;
159-
160-
return index_expr.get_prefix_expr().map(|prefix_expr| {
161-
semantic_model
162-
.infer_expr(prefix_expr.clone())
163-
.unwrap_or(LuaType::SelfInfer)
164-
});
165-
}
166-
167-
return None;
168-
}
169-
_ => {}
170-
}
171-
172-
None
173-
}
174-
175106
// 推推导每个实参类型
176107
fn get_arg_infos(
177108
semantic_model: &SemanticModel,

crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::sync::Arc;
22

3-
use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr, LuaSyntaxKind};
3+
use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr, LuaIndexExpr, LuaSyntaxKind};
44
use hashbrown::HashSet;
55
use rowan::TextRange;
66

@@ -11,13 +11,15 @@ use super::{
1111
use crate::semantic::overload_resolve::callable_accepts_args;
1212
use crate::{
1313
AsyncState, CacheEntry, DbIndex, InFiled, LuaFunctionType, LuaGenericType, LuaInstanceType,
14-
LuaIntersectionType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSignature, LuaSignatureId,
15-
LuaType, LuaTypeDeclId, LuaUnionType, TypeVisitTrait, VariadicType,
14+
LuaIntersectionType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSemanticDeclId, LuaSignature,
15+
LuaSignatureId, LuaType, LuaTypeDeclId, LuaUnionType, SemanticDeclLevel, TypeVisitTrait,
16+
VariadicType,
1617
};
1718
use crate::{
1819
InferGuardRef,
1920
semantic::{
2021
generic::TypeSubstitutor, infer::narrow::get_type_at_call_expr_inline_cast,
22+
infer_node_semantic_decl, member::find_member_origin_owner,
2123
overload_resolve::collect_callable_overload_groups,
2224
},
2325
};
@@ -838,6 +840,67 @@ fn signature_is_generic(
838840
}
839841
}
840842

843+
/// 推断调用者的具体类型.
844+
pub fn infer_call_receiver_type(
845+
db: &DbIndex,
846+
cache: &mut LuaInferCache,
847+
call_expr: &LuaCallExpr,
848+
) -> Option<LuaType> {
849+
match call_expr.get_prefix_expr()? {
850+
LuaExpr::IndexExpr(index_expr) => {
851+
let decl = infer_node_semantic_decl(
852+
db,
853+
cache,
854+
index_expr.syntax().clone(),
855+
SemanticDeclLevel::default(),
856+
)?;
857+
858+
if let LuaSemanticDeclId::Member(member_id) = decl
859+
&& let Some(LuaSemanticDeclId::Member(member_id)) =
860+
find_member_origin_owner(db, cache, member_id)
861+
{
862+
let root = db
863+
.get_vfs()
864+
.get_syntax_tree(&member_id.file_id)?
865+
.get_red_root();
866+
let cur_node = member_id.get_syntax_id().to_node_from_root(&root)?;
867+
let index_expr = LuaIndexExpr::cast(cur_node)?;
868+
869+
return index_expr.get_prefix_expr().map(|prefix_expr| {
870+
infer_expr(db, cache, prefix_expr).unwrap_or(LuaType::SelfInfer)
871+
});
872+
}
873+
874+
index_expr
875+
.get_prefix_expr()
876+
.map(|prefix_expr| infer_expr(db, cache, prefix_expr).unwrap_or(LuaType::SelfInfer))
877+
}
878+
LuaExpr::NameExpr(name_expr) => {
879+
let decl = infer_node_semantic_decl(
880+
db,
881+
cache,
882+
name_expr.syntax().clone(),
883+
SemanticDeclLevel::default(),
884+
)?;
885+
if let LuaSemanticDeclId::Member(member_id) = decl {
886+
let root = db
887+
.get_vfs()
888+
.get_syntax_tree(&member_id.file_id)?
889+
.get_red_root();
890+
let cur_node = member_id.get_syntax_id().to_node_from_root(&root)?;
891+
let index_expr = LuaIndexExpr::cast(cur_node)?;
892+
893+
return index_expr.get_prefix_expr().map(|prefix_expr| {
894+
infer_expr(db, cache, prefix_expr).unwrap_or(LuaType::SelfInfer)
895+
});
896+
}
897+
898+
None
899+
}
900+
_ => None,
901+
}
902+
}
903+
841904
#[cfg(test)]
842905
mod tests {
843906
use crate::{

crates/emmylua_code_analysis/src/semantic/infer/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use emmylua_parser::{
1717
};
1818
use infer_binary::infer_binary_expr;
1919
use infer_call::infer_call_expr;
20-
pub use infer_call::infer_call_expr_func;
20+
pub use infer_call::{infer_call_expr_func, infer_call_receiver_type};
2121
pub use infer_doc_type::{DocTypeInferContext, infer_doc_type};
2222
pub use infer_fail_reason::InferFailReason;
2323
pub use infer_index::infer_index_expr;

crates/emmylua_code_analysis/src/semantic/mod.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use emmylua_parser::{
2121
LuaSyntaxToken, LuaTableExpr,
2222
};
2323
pub use infer::infer_index_expr;
24-
use infer::{infer_bind_value_type, infer_expr_list_types};
24+
use infer::{infer_bind_value_type, infer_call_receiver_type, infer_expr_list_types};
2525
pub use infer::{infer_table_field_value_should_be, infer_table_should_be};
2626
use lsp_types::Uri;
2727
pub use member::LuaMemberInfo;
@@ -59,7 +59,10 @@ pub use infer::infer_param;
5959
pub(crate) use infer::try_infer_expr_for_index;
6060
pub(crate) use infer::{infer_expr, try_infer_expr_no_flow};
6161
use overload_resolve::resolve_signature;
62-
pub(crate) use overload_resolve::{callable_accepts_args, collect_callable_overload_groups};
62+
pub(crate) use overload_resolve::{
63+
callable_accepts_args, collect_callable_overload_groups, get_func_param_type,
64+
is_func_last_param_variadic,
65+
};
6366
pub use semantic_info::SemanticDeclLevel;
6467
pub use type_check::{TypeCheckFailReason, TypeCheckResult};
6568

@@ -329,6 +332,10 @@ impl<'a> SemanticModel<'a> {
329332
find_member_origin_owner(self.db, &mut self.infer_cache.borrow_mut(), member_id)
330333
}
331334

335+
pub fn infer_call_receiver_type(&self, call_expr: &LuaCallExpr) -> Option<LuaType> {
336+
infer_call_receiver_type(self.db, &mut self.infer_cache.borrow_mut(), call_expr)
337+
}
338+
332339
pub fn get_index_decl_type(&self, index_expr: LuaIndexExpr) -> Option<LuaType> {
333340
let cache = &mut self.infer_cache.borrow_mut();
334341
infer_index_expr(self.db, cache, index_expr, false).ok()

0 commit comments

Comments
 (0)