Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions crates/emmylua_code_analysis/locales/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ Should not reassign to iter variable:
zh_CN: '不应重新赋值给迭代变量'
zh_HK: '不應重新指定迭代變數'

expected `%{source}` but found `%{found}`:
en: expected `%{source}` but found `%{found}`
zh_CN: '预期 `%{source}`,但得到 `%{found}`'
zh_HK: '期望 `%{source}`,但得到 `%{found}`'
expected `%{source}` but found `%{found}`. %{reason}:
en: expected `%{source}` but found `%{found}`. %{reason}
zh_CN: '预期 `%{source}`,但得到 `%{found}`。 %{reason}'
zh_HK: '期望 `%{source}`,但得到 `%{found}`。 %{reason}'
function %{name} may be nil:
en: function %{name} may be nil
zh_CN: '函数 %{name} 可能为 nil'
Expand Down
7 changes: 5 additions & 2 deletions crates/emmylua_code_analysis/resources/std/debug.lua
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,11 @@ function debug.getregistry() end
---
--- Variable names starting with '(' (open parenthesis) represent variables with
--- no known names (variables from chunks saved without debug information).
---@param f integer
---@param f async fun(...):any...
---@param up integer
---@return table
---@return string name
---@return any value
---@nodiscard
function debug.getupvalue(f, up) end

---
Expand Down Expand Up @@ -236,6 +238,7 @@ function debug.setlocal(thread, level, var, value) end
---@param value T
---@param meta? table
---@return T value
---@overload fun(value: table, meta: T): T
function debug.setmetatable(value, meta) end

---
Expand Down
7 changes: 3 additions & 4 deletions crates/emmylua_code_analysis/resources/std/global.lua
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ function dofile(filename) end
--- `error` function was called. Level 2 points the error to where the function
--- that called `error` was called; and so on. Passing a level 0 avoids the
--- addition of error position information to the message.
---@overload fun(message:string)
---@param message string
---@param message any
---@param level? integer
function error(message, level) end

Expand All @@ -114,7 +113,7 @@ function getmetatable(object) end
--- will iterate over the key–value pairs (1,`t[1]`), (2,`t[2]`), ..., up to
--- the first absent index.
---@generic V
---@param t V[] | table<any, V>
---@param t V[] | table<any, V> | {[any]: V}
---@return fun(tbl: any):int, std.NotNull<V>
function ipairs(t) end

Expand Down Expand Up @@ -232,7 +231,7 @@ function next(table, index) end
--- See function `next` for the caveats of modifying the table during its
--- traversal.
---@generic K, V
---@param t table<K, V>
---@param t table<K, V> | V[] | {[K]: V}
---@return fun(tbl: any):K, std.NotNull<V>
function pairs(t) end
---
Expand Down
32 changes: 22 additions & 10 deletions crates/emmylua_code_analysis/resources/std/io.lua
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,25 @@ function io.output(file) end
---@return file
function io.popen(prog, mode) end

---@alias std.readmode
---| integer
---| string
---| "n" # Reads a number, returning a float or integer based on Lua's conversion grammar.
---| "a" # Reads the entire file starting from the current position.
---| "l" # Reads a line and ignores the end-of-line marker.
---| "L" # Reads a line and preserves the end-of-line marker.
---| "*n" # Reads a number, returning a float or integer based on Lua's conversion grammar.
---| "*a" # Reads the entire file starting from the current position.
---| "*l" # Reads a line and ignores the end-of-line marker.
---| "*L" # Reads a line and preserves the end-of-line marker.

---
--- Equivalent to `io.input():read(···)`.
--- @param format '*n' | '*a' | '*l' | integer
--- @return string | integer | nil
--- @overload fun(format:'*n'): integer
--- @overload fun(format:'*a' | '*l' | integer): string | nil
function io.read(format) end
---@param ... std.readmode
---@return any
---@return any ...
---@nodiscard
function io.read(...) end

---
--- In case of success, returns a handle for a temporary file. This file is
Expand Down Expand Up @@ -195,11 +207,11 @@ function file:lines(...) end
--- *number*: reads a string with up to this number of bytes, returning **nil**
--- on end of file. If `number` is zero, it reads nothing and returns an
--- empty string, or **nil** on end of file.
--- @param format '*n' | '*a' | '*l' | integer
--- @return string | integer | nil
--- @overload fun(format:'*n'): integer
--- @overload fun(format:'*a' | '*l' | integer): string | nil
function file:read(format) end
---@param ... std.readmode
---@return any
---@return any ...
---@nodiscard
function file:read(...) end

---
--- Sets and gets the file position, measured from the beginning of the
Expand Down
22 changes: 12 additions & 10 deletions crates/emmylua_code_analysis/resources/std/string.lua
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,14 @@ function string.dump(func, strip) end
---
--- If the pattern has captures, then in a successful match the captured values
--- are also returned, after the two indices.
---@overload fun(s:string, pattern:string):integer, integer, string...
---@param s string
---@param pattern string
---@param init? integer
---@param plain? boolean
---@return integer, integer, string...
---@param s string|number
---@param pattern string|number
---@param init? integer
---@param plain? boolean
---@return integer|nil start
---@return integer|nil end
---@return string ... captured
---@nodiscard
function string.find(s, pattern, init, plain) end

---
Expand Down Expand Up @@ -276,11 +278,11 @@ function string.reverse(s) end
--- corrected to 1. If `j` is greater than the string length, it is corrected to
--- that length. If, after these corrections, `i` is greater than `j`, the
--- function returns the empty string.
---@overload fun(s:string, i:integer):string
---@param s string
---@param i integer
---@param j integer
---@param s string|number
---@param i integer
---@param j? integer
---@return string
---@nodiscard
function string.sub(s, i, j) end

---@version >5.3
Expand Down
2 changes: 1 addition & 1 deletion crates/emmylua_code_analysis/resources/std/table.lua
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ function table.sort(list, comp) end
--- return `list[i]`, `list[i+1]`, `···`, `list[j]`
--- By default, i is 1 and j is #list.
---@generic T
---@param list [T...] | table<any, T>
---@param i? integer
---@param j? integer
---@param list [T...]
---@return T...
function table.unpack(list, i, j) end

Expand Down
14 changes: 8 additions & 6 deletions crates/emmylua_code_analysis/resources/std/utf8.lua
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ function utf8.codepoint(s, i, j) end
--- positions `i` and `j` (both inclusive). The default for `i` is 1 and for
--- `j` is -1. If it finds any invalid byte sequence, returns a false value
--- plus the position of the first invalid byte.
---@overload fun(s:string):number
---@param s string
---@param i? number
---@param j? number
---@return number
function utf8.len(s, i, j) end
---@param s string
---@param i? integer
---@param j? integer
---@param lax? boolean
---@return integer?
---@return integer? errpos
---@nodiscard
function utf8.len(s, i, j, lax) end

---
--- Returns the position (in bytes) where the encoding of the `n`-th character
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ pub fn analyze_table_expr(analyzer: &mut DeclAnalyzer, expr: LuaTableExpr) -> Op
);

let decl_feature = if analyzer.is_meta {
LuaMemberFeature::MetaFieldDecl
LuaMemberFeature::MetaDefine
} else {
LuaMemberFeature::FileFieldDecl
LuaMemberFeature::FileDefine
};

let member_id = LuaMemberId::new(field.get_syntax_id(), file_id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,35 @@ fn broadcast_up(
if let Some(ne_type_assert) = type_assert.get_negation() {
if let Some(else_stat) = if_stat.get_else_clause() {
let block_range = else_stat.get_range();
flow_chain.add_type_assert(path, ne_type_assert, block_range, actual_range);
} else if is_block_has_return(if_stat.get_block()?).unwrap_or(false) {
flow_chain.add_type_assert(
path,
ne_type_assert.clone(),
block_range,
actual_range,
);
} else if is_block_has_return(if_stat.get_block()).unwrap_or(false) {
let parent_block = if_stat.get_parent::<LuaBlock>()?;
let parent_range = parent_block.get_range();
let if_range = if_stat.get_range();
if if_range.end() < parent_range.end() {
let range = TextRange::new(if_range.end(), parent_range.end());
flow_chain.add_type_assert(path, ne_type_assert, range, actual_range);
flow_chain.add_type_assert(
path,
ne_type_assert.clone(),
range,
actual_range,
);
}
}
for else_if_clause in if_stat.get_else_if_clause_list() {
let block_range = else_if_clause.get_range();
flow_chain.add_type_assert(
path,
ne_type_assert.clone(),
block_range,
actual_range,
);
}
}
}
LuaAst::LuaWhileStat(while_stat) => {
Expand Down Expand Up @@ -445,10 +464,12 @@ fn infer_lua_type_assert(
Some(())
}

fn is_block_has_return(block: LuaBlock) -> Option<bool> {
for stat in block.get_stats() {
if is_stat_change_flow(stat.clone()).unwrap_or(false) {
return Some(true);
fn is_block_has_return(block: Option<LuaBlock>) -> Option<bool> {
if let Some(block) = block {
for stat in block.get_stats() {
if is_stat_change_flow(stat.clone()).unwrap_or(false) {
return Some(true);
}
}
}

Expand All @@ -469,9 +490,7 @@ fn is_stat_change_flow(stat: LuaStat) -> Option<bool> {
Some(false)
}
LuaStat::ReturnStat(_) => Some(true),
LuaStat::DoStat(do_stat) => {
Some(is_block_has_return(do_stat.get_block()?).unwrap_or(false))
}
LuaStat::DoStat(do_stat) => Some(is_block_has_return(do_stat.get_block()).unwrap_or(false)),
_ => Some(false),
}
}
Expand Down
22 changes: 22 additions & 0 deletions crates/emmylua_code_analysis/src/compilation/test/flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,4 +278,26 @@ print(a.field)
"#
));
}

#[test]
fn test_elseif() {
let mut ws = VirtualWorkspace::new();

assert!(ws.check_code_for(
DiagnosticCode::NeedCheckNil,
r#"
---@class D11
---@field public a string

---@type D11|nil
local a

if not a then
elseif a.a then
print(a.a)
end

"#
));
}
}
8 changes: 8 additions & 0 deletions crates/emmylua_code_analysis/src/db_index/type/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,14 @@ impl LuaType {
}
}

pub fn is_optional(&self) -> bool {
match self {
LuaType::Nil | LuaType::Any | LuaType::Unknown => true,
LuaType::Union(u) => u.types.iter().any(|t| t.is_optional()),
_ => false,
}
}

pub fn is_always_truthy(&self) -> bool {
match self {
LuaType::Nil | LuaType::Boolean | LuaType::Any | LuaType::Unknown => false,
Expand Down
24 changes: 24 additions & 0 deletions crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ fn check_index_expr(
let prefix_typ = semantic_model
.infer_expr(index_expr.get_prefix_expr()?)
.unwrap_or(LuaType::Unknown);

if !is_valid_prefix_type(&prefix_typ) {
return Some(());
}

let index_name = index_key.get_path_part();
match code {
DiagnosticCode::InjectField => {
Expand Down Expand Up @@ -92,3 +97,22 @@ fn check_index_expr(

Some(())
}

#[allow(dead_code)]
fn is_valid_prefix_type(typ: &LuaType) -> bool {
let mut current_typ = typ;
loop {
match current_typ {
LuaType::Any
| LuaType::Unknown
| LuaType::Table
| LuaType::TplRef(_)
| LuaType::StrTplRef(_)
| LuaType::TableConst(_) => return false,
LuaType::Instance(instance_typ) => {
current_typ = instance_typ.get_base();
}
_ => return true,
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,18 @@ fn check_closure_expr(
semantic_model.get_file_id(),
&closure_expr,
))?;
let source_params_len = match source_typ {
LuaType::DocFunction(func_type) => func_type.get_params().len(),
let source_params_len = match &source_typ {
LuaType::DocFunction(func_type) => {
let params = func_type.get_params();
get_params_len(params)
}
LuaType::Signature(signature_id) => {
let signature = context.db.get_signature_index().get(&signature_id)?;
signature.get_type_params().len()
let params = signature.get_type_params();
get_params_len(&params)
}
_ => return Some(()),
};
}?;

// 只检查右值参数多于左值参数的情况, 右值参数少于左值参数的情况是能够接受的
if source_params_len > right_value.params.len() {
Expand Down Expand Up @@ -166,7 +170,14 @@ fn check_call_expr(
// Check for redundant parameters
else if call_args_count > params.len() {
// 参数定义中最后一个参数是 `...`
if params.last().map_or(false, |(name, _)| name == "...") {
if params.last().map_or(false, |(name, typ)| {
name == "..."
|| if let Some(typ) = typ {
typ.is_variadic()
} else {
false
}
}) {
return Some(());
}

Expand Down Expand Up @@ -198,3 +209,18 @@ fn check_call_expr(

Some(())
}

fn get_params_len(params: &[(String, Option<LuaType>)]) -> Option<usize> {
if let Some((name, typ)) = params.last() {
// 如果最后一个参数是可变参数, 则直接返回, 不需要检查
if name == "..." {
return None;
}
if let Some(typ) = typ {
if typ.is_variadic() {
return None;
}
}
}
Some(params.len())
}
Loading