|
| 1 | +use std::collections::HashSet; |
| 2 | + |
1 | 3 | use emmylua_parser::{BinaryOperator, LuaBinaryExpr}; |
2 | 4 | use smol_str::SmolStr; |
3 | 5 |
|
@@ -61,17 +63,35 @@ fn infer_binary_expr_type( |
61 | 63 | } |
62 | 64 |
|
63 | 65 | fn infer_union(db: &DbIndex, u: &LuaUnionType, right: &LuaType, op: BinaryOperator) -> InferResult { |
64 | | - let mut union_types = vec![]; |
| 66 | + let mut unique_union_types = HashSet::new(); |
| 67 | + |
65 | 68 | for ty in u.get_types() { |
66 | | - let ty = infer_binary_expr_type(db, ty.clone(), right.clone(), op)?; |
67 | | - union_types.push(ty); |
| 69 | + let inferred_ty = infer_binary_expr_type(db, ty.clone(), right.clone(), op)?; |
| 70 | + flatten_and_insert(inferred_ty, &mut unique_union_types); |
68 | 71 | } |
69 | | - union_types.dedup(); |
70 | 72 |
|
71 | | - match union_types.len() { |
| 73 | + match unique_union_types.len() { |
72 | 74 | 0 => Some(LuaType::Unknown), |
73 | | - 1 => Some(union_types[0].clone()), |
74 | | - _ => Some(LuaType::Union(LuaUnionType::new(union_types).into())), |
| 75 | + 1 => Some(unique_union_types.into_iter().next().unwrap()), |
| 76 | + _ => Some(LuaType::Union( |
| 77 | + LuaUnionType::new(unique_union_types.into_iter().collect()).into(), |
| 78 | + )), |
| 79 | + } |
| 80 | +} |
| 81 | + |
| 82 | +fn flatten_and_insert(ty: LuaType, unique_union_types: &mut HashSet<LuaType>) { |
| 83 | + let mut stack = vec![ty]; |
| 84 | + while let Some(current_ty) = stack.pop() { |
| 85 | + match current_ty { |
| 86 | + LuaType::Union(u) => { |
| 87 | + for inner_ty in u.get_types() { |
| 88 | + stack.push(inner_ty.clone()); |
| 89 | + } |
| 90 | + } |
| 91 | + _ => { |
| 92 | + unique_union_types.insert(current_ty); |
| 93 | + } |
| 94 | + } |
75 | 95 | } |
76 | 96 | } |
77 | 97 |
|
|
0 commit comments