Skip to content

Commit b337baa

Browse files
committed
update diagnostic: GenericConstraintMismatch support check union
1 parent 5312a09 commit b337baa

3 files changed

Lines changed: 113 additions & 31 deletions

File tree

crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs

Lines changed: 74 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -110,37 +110,86 @@ fn check_call_expr(
110110
continue;
111111
};
112112

113-
match param_type {
114-
LuaType::StrTplRef(str_tpl_ref) => {
115-
let extend_type = get_extend_type(
116-
semantic_model,
117-
&call_expr,
118-
str_tpl_ref.get_tpl_id(),
119-
signature,
120-
);
121-
check_str_tpl_ref(
113+
check_param_type(
114+
context,
115+
semantic_model,
116+
&call_expr,
117+
i,
118+
param_type,
119+
signature,
120+
&arg_infos,
121+
false,
122+
);
123+
}
124+
}
125+
126+
Some(())
127+
}
128+
129+
fn check_param_type(
130+
context: &mut DiagnosticContext,
131+
semantic_model: &SemanticModel,
132+
call_expr: &LuaCallExpr,
133+
param_index: usize,
134+
param_type: &LuaType,
135+
signature: &LuaSignature,
136+
arg_infos: &[(LuaType, TextRange)],
137+
from_union: bool,
138+
) -> Option<()> {
139+
// 应该先通过泛型体操约束到唯一类型再进行检查
140+
match param_type {
141+
LuaType::StrTplRef(str_tpl_ref) => {
142+
let extend_type = get_extend_type(
143+
semantic_model,
144+
&call_expr,
145+
str_tpl_ref.get_tpl_id(),
146+
signature,
147+
);
148+
let arg_expr = call_expr.get_args_list()?.get_args().nth(param_index)?;
149+
let arg_type = semantic_model.infer_expr(arg_expr.clone()).ok()?;
150+
151+
if from_union && !arg_type.is_string() {
152+
return None;
153+
}
154+
155+
check_str_tpl_ref(
156+
context,
157+
semantic_model,
158+
str_tpl_ref,
159+
&arg_type,
160+
arg_expr.get_range(),
161+
extend_type,
162+
);
163+
}
164+
LuaType::TplRef(tpl_ref) => {
165+
let extend_type =
166+
get_extend_type(semantic_model, &call_expr, tpl_ref.get_tpl_id(), signature);
167+
check_tpl_ref(
168+
context,
169+
semantic_model,
170+
&extend_type,
171+
arg_infos.get(param_index),
172+
);
173+
}
174+
LuaType::Union(union_type) => {
175+
// 如果不是来自 union, 才展开 union 中的每个类型进行检查
176+
if !from_union {
177+
for union_member_type in union_type.get_types() {
178+
check_param_type(
122179
context,
123180
semantic_model,
124-
&call_expr,
125-
i,
126-
str_tpl_ref,
127-
extend_type,
128-
);
129-
}
130-
LuaType::TplRef(tpl_ref) => {
131-
let extend_type = get_extend_type(
132-
semantic_model,
133-
&call_expr,
134-
tpl_ref.get_tpl_id(),
181+
call_expr,
182+
param_index,
183+
union_member_type,
135184
signature,
185+
arg_infos,
186+
true,
136187
);
137-
check_tpl_ref(context, semantic_model, &extend_type, arg_infos.get(i));
138188
}
139-
_ => {}
140189
}
141190
}
191+
_ => {}
142192
}
143-
144193
Some(())
145194
}
146195

@@ -182,15 +231,11 @@ fn get_extend_type(
182231
fn check_str_tpl_ref(
183232
context: &mut DiagnosticContext,
184233
semantic_model: &SemanticModel,
185-
call_expr: &LuaCallExpr,
186-
param_index: usize,
187234
str_tpl_ref: &LuaStringTplType,
235+
arg_type: &LuaType,
236+
range: TextRange,
188237
extend_type: Option<LuaType>,
189238
) -> Option<()> {
190-
let arg_expr = call_expr.get_args_list()?.get_args().nth(param_index)?;
191-
let arg_type = semantic_model.infer_expr(arg_expr.clone()).ok()?;
192-
let range = arg_expr.get_range();
193-
194239
match arg_type {
195240
LuaType::StringConst(str) | LuaType::DocStringConst(str) => {
196241
let full_type_name = format!(

crates/emmylua_code_analysis/src/diagnostic/checker/return_type_mismatch.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,6 @@ fn check_return_stat(
127127
let return_expr_type = &return_expr_types[0];
128128
let return_expr_range = return_expr_ranges[0];
129129
let result = semantic_model.type_check(check_type, &return_expr_type);
130-
dbg!(&check_type);
131-
dbg!(&return_expr_type);
132130
if !result.is_ok() {
133131
if return_expr_type.is_table() {
134132
if let Some(return_expr) = return_exprs.get(0) {

crates/emmylua_code_analysis/src/diagnostic/test/generic_constraint_mismatch_test.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,4 +200,43 @@ mod test {
200200
"#
201201
));
202202
}
203+
204+
#[test]
205+
fn test_union() {
206+
let mut ws = VirtualWorkspace::new();
207+
ws.def(
208+
r#"
209+
---@class ab
210+
211+
---@generic T
212+
---@param a `T`|T
213+
---@return T
214+
function name(a)
215+
return a
216+
end
217+
"#,
218+
);
219+
assert!(ws.check_code_for(
220+
DiagnosticCode::GenericConstraintMismatch,
221+
r#"
222+
---@type ab
223+
local a
224+
225+
name(a)
226+
"#
227+
));
228+
assert!(ws.check_code_for(
229+
DiagnosticCode::GenericConstraintMismatch,
230+
r#"
231+
name("ab")
232+
"#
233+
));
234+
235+
assert!(!ws.check_code_for(
236+
DiagnosticCode::GenericConstraintMismatch,
237+
r#"
238+
name("a")
239+
"#
240+
));
241+
}
203242
}

0 commit comments

Comments
 (0)