Skip to content

Commit 939aa34

Browse files
committed
optimize infer_union
1 parent a1de533 commit 939aa34

2 files changed

Lines changed: 45 additions & 7 deletions

File tree

crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,4 +205,22 @@ mod tests {
205205
"#
206206
));
207207
}
208+
209+
#[test]
210+
fn test_1() {
211+
let mut ws = VirtualWorkspace::new();
212+
213+
assert!(ws.check_code_for(
214+
DiagnosticCode::ReturnTypeMismatch,
215+
r#"
216+
---@return string?
217+
local function a()
218+
---@type int?
219+
local ccc
220+
return ccc and a() or nil
221+
end
222+
223+
"#
224+
));
225+
}
208226
}

crates/emmylua_code_analysis/src/semantic/infer/infer_binary.rs

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::collections::HashSet;
2+
13
use emmylua_parser::{BinaryOperator, LuaBinaryExpr};
24
use smol_str::SmolStr;
35

@@ -61,17 +63,35 @@ fn infer_binary_expr_type(
6163
}
6264

6365
fn infer_union(db: &DbIndex, u: &LuaUnionType, right: &LuaType, op: BinaryOperator) -> InferResult {
64-
let mut union_types = vec![];
66+
let mut unique_union_types = HashSet::new();
67+
6568
for ty in u.get_types() {
66-
let ty = infer_binary_expr_type(db, ty.clone(), right.clone(), op)?;
67-
union_types.push(ty);
69+
let inferred_ty = infer_binary_expr_type(db, ty.clone(), right.clone(), op)?;
70+
flatten_and_insert(inferred_ty, &mut unique_union_types);
6871
}
69-
union_types.dedup();
7072

71-
match union_types.len() {
73+
match unique_union_types.len() {
7274
0 => Some(LuaType::Unknown),
73-
1 => Some(union_types[0].clone()),
74-
_ => Some(LuaType::Union(LuaUnionType::new(union_types).into())),
75+
1 => Some(unique_union_types.into_iter().next().unwrap()),
76+
_ => Some(LuaType::Union(
77+
LuaUnionType::new(unique_union_types.into_iter().collect()).into(),
78+
)),
79+
}
80+
}
81+
82+
fn flatten_and_insert(ty: LuaType, unique_union_types: &mut HashSet<LuaType>) {
83+
let mut stack = vec![ty];
84+
while let Some(current_ty) = stack.pop() {
85+
match current_ty {
86+
LuaType::Union(u) => {
87+
for inner_ty in u.get_types() {
88+
stack.push(inner_ty.clone());
89+
}
90+
}
91+
_ => {
92+
unique_union_types.insert(current_ty);
93+
}
94+
}
7595
}
7696
}
7797

0 commit comments

Comments
 (0)