Skip to content

Commit 259f43d

Browse files
committed
feat: 重构条件类型推断
1 parent 9626d71 commit 259f43d

19 files changed

Lines changed: 1083 additions & 683 deletions
Lines changed: 110 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
use hashbrown::HashMap;
22

33
use rowan::{TextRange, TextSize};
4+
use smol_str::SmolStr;
5+
use std::sync::Arc;
46

5-
use crate::{GenericParam, GenericTplId, LuaType};
7+
use crate::{GenericParam, GenericTpl, GenericTplId, LuaType};
68

79
pub trait GenericIndex: std::fmt::Debug {
810
fn add_generic_scope(&mut self, ranges: Vec<TextRange>, is_func: bool) -> GenericScopeId;
@@ -15,10 +17,6 @@ pub trait GenericIndex: std::fmt::Debug {
1517
}
1618
}
1719

18-
fn append_pending_type_param(&mut self, _param: GenericParam) {}
19-
20-
fn clear_pending_type_params(&mut self) {}
21-
2220
fn find_generic(
2321
&self,
2422
position: TextSize,
@@ -29,42 +27,11 @@ pub trait GenericIndex: std::fmt::Debug {
2927
#[derive(Debug, Clone)]
3028
pub struct FileGenericIndex {
3129
scopes: Vec<FileGenericScope>,
32-
pending_type_params: Vec<GenericParam>,
3330
}
3431

3532
impl FileGenericIndex {
3633
pub fn new() -> Self {
37-
Self {
38-
scopes: Vec::new(),
39-
pending_type_params: Vec::new(),
40-
}
41-
}
42-
43-
pub fn add_generic_scope(&mut self, ranges: Vec<TextRange>, is_func: bool) -> GenericScopeId {
44-
let scope_id = GenericScopeId::new(self.scopes.len());
45-
let next_tpl_id = self.next_tpl_id(&ranges, is_func);
46-
self.scopes.push(FileGenericScope::new(ranges, next_tpl_id));
47-
scope_id
48-
}
49-
50-
pub fn append_generic_param(&mut self, scope_id: GenericScopeId, param: GenericParam) {
51-
if let Some(scope) = self.scopes.get_mut(scope_id.id) {
52-
scope.insert_param(param);
53-
}
54-
}
55-
56-
pub fn append_generic_params(&mut self, scope_id: GenericScopeId, params: Vec<GenericParam>) {
57-
for param in params {
58-
self.append_generic_param(scope_id, param);
59-
}
60-
}
61-
62-
pub fn append_pending_type_param(&mut self, param: GenericParam) {
63-
self.pending_type_params.push(param);
64-
}
65-
66-
pub fn clear_pending_type_params(&mut self) {
67-
self.pending_type_params.clear();
34+
Self { scopes: Vec::new() }
6835
}
6936

7037
fn next_tpl_id(&self, ranges: &[TextRange], is_func: bool) -> GenericTplId {
@@ -86,10 +53,31 @@ impl FileGenericIndex {
8653
.sum(),
8754
)
8855
}
56+
}
57+
58+
impl GenericIndex for FileGenericIndex {
59+
fn add_generic_scope(&mut self, ranges: Vec<TextRange>, is_func: bool) -> GenericScopeId {
60+
let scope_id = GenericScopeId::new(self.scopes.len());
61+
let next_tpl_id = self.next_tpl_id(&ranges, is_func);
62+
self.scopes.push(FileGenericScope::new(ranges, next_tpl_id));
63+
scope_id
64+
}
65+
66+
fn append_generic_param(&mut self, scope_id: GenericScopeId, param: GenericParam) {
67+
if let Some(scope) = self.scopes.get_mut(scope_id.id) {
68+
scope.insert_param(param);
69+
}
70+
}
71+
72+
fn append_generic_params(&mut self, scope_id: GenericScopeId, params: Vec<GenericParam>) {
73+
for param in params {
74+
self.append_generic_param(scope_id, param);
75+
}
76+
}
8977

9078
/// Find generic parameter by position and name.
9179
/// return (GenericTplId, constraint, default)
92-
pub fn find_generic(
80+
fn find_generic(
9381
&self,
9482
position: TextSize,
9583
name: &str,
@@ -108,49 +96,7 @@ impl FileGenericIndex {
10896
}
10997
}
11098

111-
// 搜索前置类型参数, 例如 ---@alias Pick<T, K extends keyof T>
112-
self.pending_type_params
113-
.iter()
114-
.enumerate()
115-
.rev()
116-
.find(|(_, param)| param.name == name)
117-
.map(|(idx, param)| {
118-
(
119-
GenericTplId::Type(idx as u32),
120-
param.type_constraint.clone(),
121-
param.default_type.clone(),
122-
)
123-
})
124-
}
125-
}
126-
127-
impl GenericIndex for FileGenericIndex {
128-
fn add_generic_scope(&mut self, ranges: Vec<TextRange>, is_func: bool) -> GenericScopeId {
129-
FileGenericIndex::add_generic_scope(self, ranges, is_func)
130-
}
131-
132-
fn append_generic_param(&mut self, scope_id: GenericScopeId, param: GenericParam) {
133-
FileGenericIndex::append_generic_param(self, scope_id, param);
134-
}
135-
136-
fn append_generic_params(&mut self, scope_id: GenericScopeId, params: Vec<GenericParam>) {
137-
FileGenericIndex::append_generic_params(self, scope_id, params);
138-
}
139-
140-
fn append_pending_type_param(&mut self, param: GenericParam) {
141-
FileGenericIndex::append_pending_type_param(self, param);
142-
}
143-
144-
fn clear_pending_type_params(&mut self) {
145-
FileGenericIndex::clear_pending_type_params(self);
146-
}
147-
148-
fn find_generic(
149-
&self,
150-
position: TextSize,
151-
name: &str,
152-
) -> Option<(GenericTplId, Option<LuaType>, Option<LuaType>)> {
153-
FileGenericIndex::find_generic(self, position, name)
99+
None
154100
}
155101
}
156102

@@ -195,3 +141,86 @@ impl FileGenericScope {
195141
self.ranges.iter().any(|range| range.contains(position))
196142
}
197143
}
144+
145+
#[derive(Debug, Clone, Default)]
146+
pub struct ConditionalInferIndex {
147+
scopes: Vec<ConditionalInferScope>,
148+
next_infer_id: u32,
149+
}
150+
151+
impl ConditionalInferIndex {
152+
pub fn new() -> Self {
153+
Self::default()
154+
}
155+
156+
pub fn enter_scope(&mut self) {
157+
self.scopes.push(ConditionalInferScope::new());
158+
}
159+
160+
pub fn leave_scope(&mut self) -> Option<ConditionalInferScope> {
161+
self.scopes.pop()
162+
}
163+
164+
pub fn set_current_refs_visible(&mut self, visible: bool) {
165+
if let Some(scope) = self.scopes.last_mut() {
166+
scope.refs_visible = visible;
167+
}
168+
}
169+
170+
pub fn declare(&mut self, name: &str) -> Option<Arc<GenericTpl>> {
171+
let scope_idx = self.scopes.len().checked_sub(1)?;
172+
if let Some(tpl) = self.scopes[scope_idx].bindings.get(name) {
173+
return Some(tpl.clone());
174+
}
175+
176+
let tpl_id = GenericTplId::ConditionalInfer(self.next_infer_id);
177+
self.next_infer_id += 1;
178+
let tpl = Arc::new(GenericTpl::new(
179+
tpl_id,
180+
SmolStr::new(name).into(),
181+
None,
182+
None,
183+
));
184+
185+
let scope = &mut self.scopes[scope_idx];
186+
scope.bindings.insert(name.to_string(), tpl.clone());
187+
scope
188+
.params
189+
.push(GenericParam::new(SmolStr::new(name), None, None, None));
190+
Some(tpl)
191+
}
192+
193+
pub fn find_ref(&self, name: &str) -> Option<Arc<GenericTpl>> {
194+
self.scopes
195+
.iter()
196+
.rev()
197+
.filter(|scope| scope.refs_visible)
198+
.find_map(|scope| scope.bindings.get(name).cloned())
199+
}
200+
}
201+
202+
#[derive(Debug, Clone)]
203+
pub struct ConditionalInferScope {
204+
/// 是否允许在当前阶段把普通名字解析为 conditional infer 引用.
205+
/// condition 阶段只声明 `infer P`, true 分支阶段才允许引用 `P`.
206+
refs_visible: bool,
207+
/// 当前 conditional 作用域内的 `infer` 名字到实际模板的绑定.
208+
/// 同名 `infer P` 会复用同一个 `GenericTplId::ConditionalInfer`.
209+
bindings: HashMap<String, Arc<GenericTpl>>,
210+
/// 当前 conditional 声明过的 infer 参数元数据, 保留给 `LuaConditionalType`.
211+
params: Vec<GenericParam>,
212+
}
213+
214+
impl ConditionalInferScope {
215+
fn new() -> Self {
216+
Self {
217+
refs_visible: false,
218+
bindings: HashMap::new(),
219+
params: Vec::new(),
220+
}
221+
}
222+
223+
pub fn into_params(self) -> Vec<GenericParam> {
224+
self.params
225+
}
226+
}

crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@ use std::sync::Arc;
33
use emmylua_parser::{
44
LuaAst, LuaAstNode, LuaComment, LuaDocAttributeType, LuaDocBinaryType, LuaDocConditionalType,
55
LuaDocDescriptionOwner, LuaDocFuncType, LuaDocGenericDecl, LuaDocGenericDeclList,
6-
LuaDocGenericType, LuaDocIndexAccessType, LuaDocInferType, LuaDocMappedType,
7-
LuaDocMultiLineUnionType, LuaDocObjectFieldKey, LuaDocObjectType, LuaDocStrTplType, LuaDocType,
8-
LuaDocUnaryType, LuaDocVariadicType, LuaLiteralToken, LuaSyntaxKind, LuaTypeBinaryOperator,
6+
LuaDocGenericType, LuaDocIndexAccessType, LuaDocMappedType, LuaDocMultiLineUnionType,
7+
LuaDocObjectFieldKey, LuaDocObjectType, LuaDocStrTplType, LuaDocType, LuaDocUnaryType,
8+
LuaDocVariadicType, LuaLiteralToken, LuaSyntaxKind, LuaTypeBinaryOperator,
99
LuaTypeUnaryOperator, LuaVarExpr, NumberResult,
1010
};
11-
use internment::ArcIntern;
1211
use rowan::TextRange;
1312
use smol_str::SmolStr;
1413

@@ -23,7 +22,10 @@ use crate::{
2322
},
2423
};
2524

26-
use super::{file_generic_index::GenericIndex, preprocess_description};
25+
use super::{
26+
file_generic_index::{ConditionalInferIndex, GenericIndex},
27+
preprocess_description,
28+
};
2729

2830
#[derive(Debug)]
2931
pub struct DocTypeAnalyzeContext<'a> {
@@ -33,6 +35,7 @@ pub struct DocTypeAnalyzeContext<'a> {
3335
pub workspace_id: WorkspaceId,
3436
comment: Option<LuaComment>,
3537
options: DocTypeAnalyzeOptions,
38+
conditional_infer_index: ConditionalInferIndex,
3639
}
3740

3841
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -77,6 +80,7 @@ impl<'a> DocTypeAnalyzeContext<'a> {
7780
workspace_id,
7881
comment: None,
7982
options: DocTypeAnalyzeOptions::default(),
83+
conditional_infer_index: ConditionalInferIndex::new(),
8084
}
8185
}
8286

@@ -203,7 +207,9 @@ pub fn infer_type(analyzer: &mut DocTypeAnalyzeContext<'_>, node: LuaDocType) ->
203207
}
204208
LuaDocType::Infer(infer_type) => {
205209
if let Some(name) = infer_type.get_generic_decl_name_text() {
206-
return LuaType::ConditionalInfer(ArcIntern::new(SmolStr::new(&name)));
210+
if let Some(tpl) = analyzer.conditional_infer_index.declare(&name) {
211+
return LuaType::TplRef(tpl);
212+
}
207213
}
208214
}
209215
LuaDocType::Mapped(mapped_type) => {
@@ -246,6 +252,10 @@ fn infer_buildin_or_ref_type(
246252
LuaType::Table
247253
}
248254
_ => {
255+
if let Some(tpl) = analyzer.conditional_infer_index.find_ref(name) {
256+
return LuaType::TplRef(tpl);
257+
}
258+
249259
if let Some((tpl_id, constraint, default_type)) =
250260
analyzer.generic_index.find_generic(position, name)
251261
{
@@ -899,26 +909,37 @@ fn infer_conditional_type(
899909
cond_type: &LuaDocConditionalType,
900910
) -> LuaType {
901911
if let Some((condition, when_true, when_false)) = cond_type.get_types() {
902-
// 收集条件中的所有 infer 声明
903-
let infer_params = collect_cond_infer_params(&condition);
904-
if !infer_params.is_empty() {
905-
// 条件表达式中 infer 声明的类型参数只允许在`true`分支中使用
906-
let true_range = when_true.get_range();
907-
let scope_id = analyzer
908-
.generic_index
909-
.add_generic_scope(vec![true_range], false);
910-
analyzer
911-
.generic_index
912-
.append_generic_params(scope_id, infer_params.clone());
913-
}
914-
915-
// 处理条件和分支类型
912+
analyzer.conditional_infer_index.enter_scope();
913+
916914
let condition_type = infer_type(analyzer, condition);
915+
let LuaType::Call(alias_call) = condition_type else {
916+
analyzer.conditional_infer_index.leave_scope();
917+
return LuaType::Unknown;
918+
};
919+
if alias_call.get_call_kind() != LuaAliasCallKind::Extends
920+
|| alias_call.get_operands().len() != 2
921+
{
922+
analyzer.conditional_infer_index.leave_scope();
923+
return LuaType::Unknown;
924+
}
925+
let operands = alias_call.get_operands();
926+
let checked_type = operands[0].clone();
927+
let extends_type = operands[1].clone();
928+
929+
analyzer
930+
.conditional_infer_index
931+
.set_current_refs_visible(true);
917932
let true_type = infer_type(analyzer, when_true);
933+
let infer_params = analyzer
934+
.conditional_infer_index
935+
.leave_scope() // 退出当前作用域
936+
.map(|scope| scope.into_params())
937+
.unwrap_or_default();
918938
let false_type = infer_type(analyzer, when_false);
919939

920940
return LuaConditionalType::new(
921-
condition_type,
941+
checked_type,
942+
extends_type,
922943
true_type,
923944
false_type,
924945
infer_params,
@@ -930,18 +951,6 @@ fn infer_conditional_type(
930951
LuaType::Unknown
931952
}
932953

933-
/// 收集条件类型中的条件表达式中所有 infer 声明
934-
fn collect_cond_infer_params(doc_type: &LuaDocType) -> Vec<GenericParam> {
935-
let mut params = Vec::new();
936-
let doc_infer_types = doc_type.descendants::<LuaDocInferType>();
937-
for infer_type in doc_infer_types {
938-
if let Some(name) = infer_type.get_generic_decl_name_text() {
939-
params.push(GenericParam::new(SmolStr::new(&name), None, None, None));
940-
}
941-
}
942-
params
943-
}
944-
945954
fn infer_mapped_type(
946955
analyzer: &mut DocTypeAnalyzeContext<'_>,
947956
mapped_type: &LuaDocMappedType,

0 commit comments

Comments
 (0)