Skip to content

Commit 03b7227

Browse files
authored
Merge pull request #294 from xuhuanzy/diagnostic
diagnostic
2 parents 5ec9f09 + ff37b81 commit 03b7227

8 files changed

Lines changed: 329 additions & 12 deletions

File tree

crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs

Lines changed: 90 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::collections::HashSet;
22

3-
use emmylua_parser::{LuaAst, LuaAstNode, LuaExpr, LuaIndexExpr, LuaIndexKey, LuaVarExpr};
3+
use emmylua_parser::{LuaAst, LuaAstNode, LuaIndexExpr, LuaIndexKey, LuaVarExpr};
44

55
use crate::{DiagnosticCode, InferFailReason, LuaType, SemanticModel};
66

@@ -54,8 +54,6 @@ fn check_index_expr(
5454
code: DiagnosticCode,
5555
) -> Option<()> {
5656
let db = context.db;
57-
58-
let index_key = index_expr.get_index_key()?;
5957
let prefix_typ = semantic_model
6058
.infer_expr(index_expr.get_prefix_expr()?)
6159
.unwrap_or(LuaType::Unknown);
@@ -64,14 +62,10 @@ fn check_index_expr(
6462
return Some(());
6563
}
6664

67-
if !is_valid_index_key(&index_key) {
68-
return Some(());
69-
}
65+
let index_key = index_expr.get_index_key()?;
7066

71-
let result = semantic_model.infer_expr(LuaExpr::IndexExpr(index_expr.clone()));
72-
match result {
73-
Err(InferFailReason::FieldDotFound) => {}
74-
_ => return Some(()),
67+
if is_valid_member(semantic_model, &prefix_typ, index_expr, &index_key).is_some() {
68+
return Some(());
7569
}
7670

7771
let index_name = index_key.get_path_part();
@@ -121,9 +115,95 @@ fn is_valid_prefix_type(typ: &LuaType) -> bool {
121115
}
122116
}
123117

118+
#[allow(dead_code)]
124119
fn is_valid_index_key(index_key: &LuaIndexKey) -> bool {
125120
match index_key {
126121
LuaIndexKey::String(_) | LuaIndexKey::Name(_) | LuaIndexKey::Integer(_) => true,
127122
_ => false,
128123
}
129124
}
125+
126+
fn is_valid_member(
127+
semantic_model: &SemanticModel,
128+
prefix_typ: &LuaType,
129+
index_expr: &LuaIndexExpr,
130+
index_key: &LuaIndexKey,
131+
) -> Option<()> {
132+
// 检查 member_info
133+
let need_add_diagnostic =
134+
match semantic_model.get_semantic_info(index_expr.syntax().clone().into()) {
135+
Some(info) => info.semantic_decl.is_none() && info.typ.is_unknown(),
136+
None => true,
137+
};
138+
139+
if !need_add_diagnostic {
140+
return Some(());
141+
}
142+
143+
// 获取并验证 key_type
144+
let key_type = match index_key {
145+
LuaIndexKey::Expr(expr) => match semantic_model.infer_expr(expr.clone()) {
146+
Ok(
147+
LuaType::Any
148+
| LuaType::Unknown
149+
| LuaType::Table
150+
| LuaType::TplRef(_)
151+
| LuaType::StrTplRef(_),
152+
) => {
153+
return Some(());
154+
}
155+
Ok(typ) => typ,
156+
// 解析失败时认为其是合法的, 因为他可能没有标注类型
157+
Err(InferFailReason::UnResolveDeclType(_)) => {
158+
return Some(());
159+
}
160+
Err(_) => {
161+
return None;
162+
}
163+
},
164+
_ => return None,
165+
};
166+
167+
// 允许特定类型组合通过
168+
match (prefix_typ, &key_type) {
169+
(LuaType::Tuple(_), LuaType::Integer | LuaType::IntegerConst(_)) => return Some(()),
170+
_ => {}
171+
}
172+
173+
// 解决`key`类型为联合名称时的报错(通常是`pairs`返回的`key`)
174+
let mut key_path_set = HashSet::new();
175+
get_index_key_names(&mut key_path_set, &key_type);
176+
if key_path_set.is_empty() {
177+
return None;
178+
}
179+
let member_path_set: HashSet<_> = semantic_model
180+
.infer_member_infos(prefix_typ)?
181+
.iter()
182+
.map(|info| info.key.to_path())
183+
.collect();
184+
185+
if member_path_set.is_empty() {
186+
return None;
187+
}
188+
if key_path_set.is_subset(&member_path_set) {
189+
return Some(());
190+
}
191+
192+
None
193+
}
194+
195+
fn get_index_key_names(name_set: &mut HashSet<String>, typ: &LuaType) {
196+
match typ {
197+
LuaType::StringConst(name) => {
198+
name_set.insert(name.as_ref().to_string());
199+
}
200+
LuaType::IntegerConst(i) => {
201+
name_set.insert(format!("[{}]", i));
202+
}
203+
LuaType::Union(union_typ) => union_typ
204+
.get_types()
205+
.iter()
206+
.for_each(|typ| get_index_key_names(name_set, typ)),
207+
_ => {}
208+
}
209+
}

crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,52 @@ mod tests {
6464
));
6565
}
6666

67+
#[test]
68+
fn test_enum() {
69+
let mut ws = VirtualWorkspace::new();
70+
assert!(ws.check_code_for_namespace(
71+
DiagnosticCode::AssignTypeMismatch,
72+
r#"
73+
---@enum SubscriberFlags
74+
local SubscriberFlags = {
75+
None = 0,
76+
Tracking = 1 << 0,
77+
Recursed = 1 << 1,
78+
ToCheckDirty = 1 << 3,
79+
Dirty = 1 << 4,
80+
}
81+
---@class Subscriber
82+
---@field flags SubscriberFlags
83+
84+
---@type Subscriber
85+
local subscriber
86+
87+
subscriber.flags = subscriber.flags & ~SubscriberFlags.Tracking -- 被推断为`integer`而不是实际整数值, 允许匹配
88+
"#
89+
));
90+
91+
assert!(!ws.check_code_for_namespace(
92+
DiagnosticCode::AssignTypeMismatch,
93+
r#"
94+
---@enum SubscriberFlags
95+
local SubscriberFlags = {
96+
None = 0,
97+
Tracking = 1 << 0,
98+
Recursed = 1 << 1,
99+
ToCheckDirty = 1 << 3,
100+
Dirty = 1 << 4,
101+
}
102+
---@class Subscriber
103+
---@field flags SubscriberFlags
104+
105+
---@type Subscriber
106+
local subscriber
107+
108+
subscriber.flags = 9 -- 不允许匹配不上的实际值
109+
"#
110+
));
111+
}
112+
67113
#[test]
68114
fn test_issue_193() {
69115
let mut ws = VirtualWorkspace::new();

crates/emmylua_code_analysis/src/diagnostic/test/inject_field_test.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,20 @@ mod test {
112112
assert!(ws.check_code_for(
113113
DiagnosticCode::InjectField,
114114
r#"
115-
local a = { 'a' }
116-
a[#a + 1] = 'b'
115+
local a = { 'a' }
116+
a[#a + 1] = 'b'
117+
118+
---@type string[]
119+
local b = { 'a' }
120+
b[#b + 1] = 'b'
121+
122+
---@type table<integer, string>
123+
local c = { 'a' }
124+
c[#c + 1] = 'b'
125+
126+
---@type { [integer]: string }
127+
local d = { 'a' }
128+
d[#d + 1] = 'b'
117129
"#
118130
));
119131
}

crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,4 +259,21 @@ mod tests {
259259
"#
260260
));
261261
}
262+
263+
#[test]
264+
fn test_3() {
265+
let mut ws = VirtualWorkspace::new();
266+
267+
assert!(ws.check_code_for(
268+
DiagnosticCode::ReturnTypeMismatch,
269+
r#"
270+
---@return table<string, {old: any, new: any}>
271+
local function test()
272+
---@type table<string, {old: any, new: any}>|table
273+
local a
274+
return a
275+
end
276+
"#
277+
));
278+
}
262279
}

crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,83 @@ mod test {
7979
"#
8080
));
8181
}
82+
83+
#[test]
84+
fn test_any_key() {
85+
let mut ws = VirtualWorkspace::new_with_init_std_lib();
86+
assert!(ws.check_code_for(
87+
DiagnosticCode::UndefinedField,
88+
r#"
89+
---@class LogicalOperators
90+
local logicalOperators <const> = {}
91+
92+
---@param key any
93+
local function test(key)
94+
print(logicalOperators[key])
95+
end
96+
"#
97+
));
98+
}
99+
100+
#[test]
101+
fn test_class_key_to_class_key() {
102+
let mut ws = VirtualWorkspace::new();
103+
104+
assert!(!ws.check_code_for(
105+
DiagnosticCode::UndefinedField,
106+
r#"
107+
--- @type table<string, integer>
108+
local FUNS = {}
109+
110+
---@class D10.AAA
111+
112+
---@type D10.AAA
113+
local Test1
114+
115+
local a = FUNS[Test1]
116+
"#
117+
));
118+
119+
assert!(ws.check_code_for(
120+
DiagnosticCode::UndefinedField,
121+
r#"
122+
---@generic K, V
123+
---@param t table<K, V> | V[] | {[K]: V}
124+
---@return fun(tbl: any):K, std.NotNull<V>
125+
local function pairs(t) end
126+
127+
---@class D11.AAA
128+
---@field name string
129+
---@field key string
130+
local AAA = {}
131+
132+
---@type D11.AAA
133+
local a
134+
135+
for k, v in pairs(AAA) do
136+
if not a[k] then
137+
-- a[k] = v
138+
end
139+
end
140+
"#
141+
));
142+
}
143+
144+
#[test]
145+
fn test_2() {
146+
let mut ws = VirtualWorkspace::new();
147+
148+
assert!(ws.check_code_for(
149+
DiagnosticCode::UndefinedField,
150+
r#"
151+
local function sortCallbackOfIndex()
152+
---@type table<string, integer>
153+
local indexMap = {}
154+
return function(v)
155+
return -indexMap[v]
156+
end
157+
end
158+
"#
159+
));
160+
}
82161
}

crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,67 @@ fn object_tpl_pattern_match(
202202
}
203203
}
204204
}
205+
LuaType::TableConst(inst) => {
206+
let owner = LuaMemberOwner::Element(inst.clone());
207+
object_tpl_pattern_match_member_owner_match(
208+
db,
209+
cache,
210+
root,
211+
origin_obj,
212+
owner,
213+
substitutor,
214+
);
215+
}
205216
_ => {}
206217
}
207218
Some(())
208219
}
209220

221+
fn object_tpl_pattern_match_member_owner_match(
222+
db: &DbIndex,
223+
cache: &mut LuaInferCache,
224+
root: &LuaSyntaxNode,
225+
object: &LuaObjectType,
226+
owner: LuaMemberOwner,
227+
substitutor: &mut TypeSubstitutor,
228+
) -> Option<()> {
229+
let owner_type = match &owner {
230+
LuaMemberOwner::Element(inst) => LuaType::TableConst(inst.clone()),
231+
LuaMemberOwner::Type(type_id) => LuaType::Ref(type_id.clone()),
232+
_ => {
233+
return None;
234+
}
235+
};
236+
237+
let members = infer_member_map(db, &owner_type)?;
238+
for (k, v) in members {
239+
let resolve_key = match &k {
240+
LuaMemberKey::Integer(i) => Some(LuaType::IntegerConst(i.clone())),
241+
LuaMemberKey::Name(s) => Some(LuaType::StringConst(s.clone().into())),
242+
_ => None,
243+
};
244+
let resolve_type = match v.len() {
245+
0 => LuaType::Any,
246+
1 => v[0].typ.clone(),
247+
_ => {
248+
let mut types = Vec::new();
249+
for m in v {
250+
types.push(m.typ.clone());
251+
}
252+
LuaType::Union(LuaUnionType::new(types).into())
253+
}
254+
};
255+
256+
if let Some(_) = resolve_key {
257+
if let Some(field_value) = object.get_field(&k) {
258+
tpl_pattern_match(db, cache, root, field_value, &resolve_type, substitutor);
259+
}
260+
}
261+
}
262+
263+
Some(())
264+
}
265+
210266
fn array_tpl_pattern_match(
211267
db: &DbIndex,
212268
cache: &mut LuaInferCache,

0 commit comments

Comments
 (0)