Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 90 additions & 10 deletions crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::HashSet;

use emmylua_parser::{LuaAst, LuaAstNode, LuaExpr, LuaIndexExpr, LuaIndexKey, LuaVarExpr};
use emmylua_parser::{LuaAst, LuaAstNode, LuaIndexExpr, LuaIndexKey, LuaVarExpr};

use crate::{DiagnosticCode, InferFailReason, LuaType, SemanticModel};

Expand Down Expand Up @@ -54,8 +54,6 @@ fn check_index_expr(
code: DiagnosticCode,
) -> Option<()> {
let db = context.db;

let index_key = index_expr.get_index_key()?;
let prefix_typ = semantic_model
.infer_expr(index_expr.get_prefix_expr()?)
.unwrap_or(LuaType::Unknown);
Expand All @@ -64,14 +62,10 @@ fn check_index_expr(
return Some(());
}

if !is_valid_index_key(&index_key) {
return Some(());
}
let index_key = index_expr.get_index_key()?;

let result = semantic_model.infer_expr(LuaExpr::IndexExpr(index_expr.clone()));
match result {
Err(InferFailReason::FieldDotFound) => {}
_ => return Some(()),
if is_valid_member(semantic_model, &prefix_typ, index_expr, &index_key).is_some() {
return Some(());
}

let index_name = index_key.get_path_part();
Expand Down Expand Up @@ -121,9 +115,95 @@ fn is_valid_prefix_type(typ: &LuaType) -> bool {
}
}

#[allow(dead_code)]
fn is_valid_index_key(index_key: &LuaIndexKey) -> bool {
match index_key {
LuaIndexKey::String(_) | LuaIndexKey::Name(_) | LuaIndexKey::Integer(_) => true,
_ => false,
}
}

fn is_valid_member(
semantic_model: &SemanticModel,
prefix_typ: &LuaType,
index_expr: &LuaIndexExpr,
index_key: &LuaIndexKey,
) -> Option<()> {
// 检查 member_info
let need_add_diagnostic =
match semantic_model.get_semantic_info(index_expr.syntax().clone().into()) {
Some(info) => info.semantic_decl.is_none() && info.typ.is_unknown(),
None => true,
};

if !need_add_diagnostic {
return Some(());
}

// 获取并验证 key_type
let key_type = match index_key {
LuaIndexKey::Expr(expr) => match semantic_model.infer_expr(expr.clone()) {
Ok(
LuaType::Any
| LuaType::Unknown
| LuaType::Table
| LuaType::TplRef(_)
| LuaType::StrTplRef(_),
) => {
return Some(());
}
Ok(typ) => typ,
// 解析失败时认为其是合法的, 因为他可能没有标注类型
Err(InferFailReason::UnResolveDeclType(_)) => {
return Some(());
}
Err(_) => {
return None;
}
},
_ => return None,
};

// 允许特定类型组合通过
match (prefix_typ, &key_type) {
(LuaType::Tuple(_), LuaType::Integer | LuaType::IntegerConst(_)) => return Some(()),
_ => {}
}

// 解决`key`类型为联合名称时的报错(通常是`pairs`返回的`key`)
let mut key_path_set = HashSet::new();
get_index_key_names(&mut key_path_set, &key_type);
if key_path_set.is_empty() {
return None;
}
let member_path_set: HashSet<_> = semantic_model
.infer_member_infos(prefix_typ)?
.iter()
.map(|info| info.key.to_path())
.collect();

if member_path_set.is_empty() {
return None;
}
if key_path_set.is_subset(&member_path_set) {
return Some(());
}

None
}

fn get_index_key_names(name_set: &mut HashSet<String>, typ: &LuaType) {
match typ {
LuaType::StringConst(name) => {
name_set.insert(name.as_ref().to_string());
}
LuaType::IntegerConst(i) => {
name_set.insert(format!("[{}]", i));
}
LuaType::Union(union_typ) => union_typ
.get_types()
.iter()
.for_each(|typ| get_index_key_names(name_set, typ)),
_ => {}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,52 @@ mod tests {
));
}

#[test]
fn test_enum() {
let mut ws = VirtualWorkspace::new();
assert!(ws.check_code_for_namespace(
DiagnosticCode::AssignTypeMismatch,
r#"
---@enum SubscriberFlags
local SubscriberFlags = {
None = 0,
Tracking = 1 << 0,
Recursed = 1 << 1,
ToCheckDirty = 1 << 3,
Dirty = 1 << 4,
}
---@class Subscriber
---@field flags SubscriberFlags

---@type Subscriber
local subscriber

subscriber.flags = subscriber.flags & ~SubscriberFlags.Tracking -- 被推断为`integer`而不是实际整数值, 允许匹配
"#
));

assert!(!ws.check_code_for_namespace(
DiagnosticCode::AssignTypeMismatch,
r#"
---@enum SubscriberFlags
local SubscriberFlags = {
None = 0,
Tracking = 1 << 0,
Recursed = 1 << 1,
ToCheckDirty = 1 << 3,
Dirty = 1 << 4,
}
---@class Subscriber
---@field flags SubscriberFlags

---@type Subscriber
local subscriber

subscriber.flags = 9 -- 不允许匹配不上的实际值
"#
));
}

#[test]
fn test_issue_193() {
let mut ws = VirtualWorkspace::new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,20 @@ mod test {
assert!(ws.check_code_for(
DiagnosticCode::InjectField,
r#"
local a = { 'a' }
a[#a + 1] = 'b'
local a = { 'a' }
a[#a + 1] = 'b'

---@type string[]
local b = { 'a' }
b[#b + 1] = 'b'

---@type table<integer, string>
local c = { 'a' }
c[#c + 1] = 'b'

---@type { [integer]: string }
local d = { 'a' }
d[#d + 1] = 'b'
"#
));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,4 +259,21 @@ mod tests {
"#
));
}

#[test]
fn test_3() {
let mut ws = VirtualWorkspace::new();

assert!(ws.check_code_for(
DiagnosticCode::ReturnTypeMismatch,
r#"
---@return table<string, {old: any, new: any}>
local function test()
---@type table<string, {old: any, new: any}>|table
local a
return a
end
"#
));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,83 @@ mod test {
"#
));
}

#[test]
fn test_any_key() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();
assert!(ws.check_code_for(
DiagnosticCode::UndefinedField,
r#"
---@class LogicalOperators
local logicalOperators <const> = {}

---@param key any
local function test(key)
print(logicalOperators[key])
end
"#
));
}

#[test]
fn test_class_key_to_class_key() {
let mut ws = VirtualWorkspace::new();

assert!(!ws.check_code_for(
DiagnosticCode::UndefinedField,
r#"
--- @type table<string, integer>
local FUNS = {}

---@class D10.AAA

---@type D10.AAA
local Test1

local a = FUNS[Test1]
"#
));

assert!(ws.check_code_for(
DiagnosticCode::UndefinedField,
r#"
---@generic K, V
---@param t table<K, V> | V[] | {[K]: V}
---@return fun(tbl: any):K, std.NotNull<V>
local function pairs(t) end

---@class D11.AAA
---@field name string
---@field key string
local AAA = {}

---@type D11.AAA
local a

for k, v in pairs(AAA) do
if not a[k] then
-- a[k] = v
end
end
"#
));
}

#[test]
fn test_2() {
let mut ws = VirtualWorkspace::new();

assert!(ws.check_code_for(
DiagnosticCode::UndefinedField,
r#"
local function sortCallbackOfIndex()
---@type table<string, integer>
local indexMap = {}
return function(v)
return -indexMap[v]
end
end
"#
));
}
}
56 changes: 56 additions & 0 deletions crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,67 @@ fn object_tpl_pattern_match(
}
}
}
LuaType::TableConst(inst) => {
let owner = LuaMemberOwner::Element(inst.clone());
object_tpl_pattern_match_member_owner_match(
db,
cache,
root,
origin_obj,
owner,
substitutor,
);
}
_ => {}
}
Some(())
}

fn object_tpl_pattern_match_member_owner_match(
db: &DbIndex,
cache: &mut LuaInferCache,
root: &LuaSyntaxNode,
object: &LuaObjectType,
owner: LuaMemberOwner,
substitutor: &mut TypeSubstitutor,
) -> Option<()> {
let owner_type = match &owner {
LuaMemberOwner::Element(inst) => LuaType::TableConst(inst.clone()),
LuaMemberOwner::Type(type_id) => LuaType::Ref(type_id.clone()),
_ => {
return None;
}
};

let members = infer_member_map(db, &owner_type)?;
for (k, v) in members {
let resolve_key = match &k {
LuaMemberKey::Integer(i) => Some(LuaType::IntegerConst(i.clone())),
LuaMemberKey::Name(s) => Some(LuaType::StringConst(s.clone().into())),
_ => None,
};
let resolve_type = match v.len() {
0 => LuaType::Any,
1 => v[0].typ.clone(),
_ => {
let mut types = Vec::new();
for m in v {
types.push(m.typ.clone());
}
LuaType::Union(LuaUnionType::new(types).into())
}
};

if let Some(_) = resolve_key {
if let Some(field_value) = object.get_field(&k) {
tpl_pattern_match(db, cache, root, field_value, &resolve_type, substitutor);
}
}
}

Some(())
}

fn array_tpl_pattern_match(
db: &DbIndex,
cache: &mut LuaInferCache,
Expand Down
Loading