Skip to content

Commit 294f3c0

Browse files
committed
feat(ls): improves completion in with and for statements
Enhances the language server's code completion to correctly infer the types of variables declared within `with` and `for` expressions. This allows for accurate field and method suggestions when accessing properties of these local variables.
1 parent bd38dc6 commit 294f3c0

8 files changed

Lines changed: 313 additions & 96 deletions

File tree

ls/src/features/completion.rs

Lines changed: 222 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ use yara_x_parser::cst::{Immutable, Node, SyntaxKind, Token, CST};
1515

1616
use crate::documents::storage::DocumentStorage;
1717
use crate::utils::cst_traversal::{
18-
idents_declared_by_expr, non_error_parent, prev_non_trivia_token,
19-
rule_containing_token, token_at_position,
18+
find_declaration, idents_declared_by_expr, non_error_parent,
19+
prev_non_trivia_token, rule_containing_token, token_at_position,
2020
};
2121

2222
const PATTERN_MODS: &[(SyntaxKind, &[&str])] = &[
@@ -98,7 +98,7 @@ pub fn completion(
9898
if !is_trigger_character
9999
&& non_error_parent(&token)?.kind() == SyntaxKind::SOURCE_FILE
100100
{
101-
return Some(source_file_suggestions());
101+
return Some(top_level_suggestions());
102102
}
103103

104104
let prev_token = prev_non_trivia_token(&token)?;
@@ -142,7 +142,7 @@ fn condition_suggestions(
142142
let mut result = Vec::new();
143143

144144
#[cfg(feature = "full-compiler")]
145-
if let Some(suggestions) = module_suggestions(&token) {
145+
if let Some(suggestions) = field_suggestions(&token) {
146146
return Some(suggestions);
147147
}
148148

@@ -271,7 +271,7 @@ fn import_suggestions() -> Vec<CompletionItem> {
271271
}
272272

273273
/// Collects completion suggestions outside any block.
274-
fn source_file_suggestions() -> Vec<CompletionItem> {
274+
fn top_level_suggestions() -> Vec<CompletionItem> {
275275
// Propose import or rule definition with snippet
276276
SRC_SUGGESTIONS
277277
.map(|(label, insert_text)| CompletionItem {
@@ -292,6 +292,7 @@ fn source_file_suggestions() -> Vec<CompletionItem> {
292292
.collect()
293293
}
294294

295+
/// Collects completion suggestions for pattern modifiers.
295296
fn pattern_modifier_suggestions(node: Node<Immutable>) -> Vec<CompletionItem> {
296297
for (kind, valid_modifiers) in PATTERN_MODS {
297298
if node.children_with_tokens().any(|child| child.kind() == *kind) {
@@ -326,43 +327,137 @@ fn rule_suggestions() -> Vec<CompletionItem> {
326327
.collect()
327328
}
328329

329-
#[cfg(feature = "full-compiler")]
330-
fn module_suggestions(
331-
token: &Token<Immutable>,
332-
) -> Option<Vec<CompletionItem>> {
333-
let mut curr;
330+
#[derive(Debug)]
331+
enum Segment {
332+
Field(String),
333+
Index,
334+
}
334335

336+
/// Collects completion suggestions for structure fields.
337+
#[cfg(feature = "full-compiler")]
338+
fn field_suggestions(token: &Token<Immutable>) -> Option<Vec<CompletionItem>> {
335339
// Check if we are at a position that triggers completion.
336-
match token.kind() {
340+
let token = match token.kind() {
337341
SyntaxKind::DOT => {
338342
// structure. <cursor>
339-
curr = prev_non_trivia_token(token);
343+
prev_non_trivia_token(token)
340344
}
341345
SyntaxKind::IDENT => {
342346
// structure.field <cursor>
343347
// We need to check if previous is DOT
344-
let prev = prev_non_trivia_token(token)?;
345-
if prev.kind() == SyntaxKind::DOT {
346-
// It is a field
347-
curr = prev_non_trivia_token(&prev);
348-
} else {
349-
return None;
350-
}
348+
prev_non_trivia_token(token)
349+
.filter(|t| t.kind() == SyntaxKind::DOT)
350+
.and_then(|t| prev_non_trivia_token(&t))
351351
}
352+
_ => None,
353+
}?;
354+
355+
let current_struct = match get_struct(&token)? {
356+
Type::Struct(s) => s,
352357
_ => return None,
353-
}
358+
};
354359

355-
#[derive(Debug)]
356-
enum Segment {
357-
Field(String),
358-
Index,
359-
}
360+
// Now `current_struct` is the structure before the cursor.
361+
// We want to suggest fields for this structure.
362+
let suggestions = current_struct
363+
.fields()
364+
.flat_map(|f| {
365+
let name = f.name();
366+
let ty = f.ty();
367+
368+
if let Type::Func(ref func_def) = ty {
369+
func_def
370+
.signatures
371+
.iter()
372+
.map(|sig| {
373+
let arg_types = sig
374+
.args
375+
.iter()
376+
.map(ty_to_string)
377+
.collect::<Vec<_>>();
378+
379+
let args_template = arg_types
380+
.iter()
381+
.enumerate()
382+
.map(|(n, arg_type)| {
383+
format!("${{{}:{arg_type}}}", n + 1)
384+
})
385+
.join(",");
360386

387+
CompletionItem {
388+
label: format!(
389+
"{}({})",
390+
name,
391+
arg_types.join(", ")
392+
),
393+
kind: Some(CompletionItemKind::METHOD),
394+
insert_text: Some(format!(
395+
"{name}({args_template})",
396+
)),
397+
insert_text_format: Some(
398+
InsertTextFormat::SNIPPET,
399+
),
400+
label_details: Some(CompletionItemLabelDetails {
401+
description: Some(ty_to_string(&ty)),
402+
..Default::default()
403+
}),
404+
..Default::default()
405+
}
406+
})
407+
.collect()
408+
} else {
409+
let insert_text = match &ty {
410+
Type::Array(_) => format!("{name}[${{1}}]${{2}}"),
411+
_ => name.to_string(),
412+
};
413+
414+
vec![CompletionItem {
415+
label: name.to_string(),
416+
kind: Some(CompletionItemKind::FIELD),
417+
insert_text: Some(insert_text),
418+
insert_text_format: Some(InsertTextFormat::SNIPPET),
419+
label_details: Some(CompletionItemLabelDetails {
420+
description: Some(ty_to_string(&ty)),
421+
..Default::default()
422+
}),
423+
..Default::default()
424+
}]
425+
}
426+
})
427+
.collect();
428+
429+
Some(suggestions)
430+
}
431+
432+
#[cfg(feature = "full-compiler")]
433+
/// Given a token, returns the type of the structure that the token is part of.
434+
///
435+
/// This function traverses the CST backwards from the given token to determine
436+
/// the full path to a field within a structure (e.g., `module.field.subfield`).
437+
/// It then uses this path to look up the corresponding `Type` definition.
438+
///
439+
/// If the token is part of a `for` or `with` statement, it will try to resolve
440+
/// the type from the declared variables in those statements.
441+
///
442+
/// Returns an `Option<Type>` representing the type of the structure or field
443+
/// identified by the token. Returns `None` if the type cannot be determined.
444+
fn get_struct(token: &Token<Immutable>) -> Option<Type> {
361445
let mut path = Vec::new();
362446

447+
let mut curr = Some(token.clone());
363448
while let Some(token) = curr {
364449
match token.kind() {
365450
SyntaxKind::IDENT => {
451+
// If the identifier is a variable declared in a `for` or `with`
452+
// statement, we need to find the type of that variable.
453+
if let Some((_, declaration)) = find_declaration(&token) {
454+
return get_type_from_declaration(
455+
&declaration,
456+
&token,
457+
path.into_iter().rev(),
458+
);
459+
}
460+
366461
path.push(Segment::Field(token.text().to_string()));
367462
// Look for previous DOT
368463
if let Some(prev) = prev_non_trivia_token(&token) {
@@ -428,82 +523,113 @@ fn module_suggestions(
428523
}
429524
}
430525
}
526+
Some(current_kind)
527+
}
431528

432-
let current_struct = match current_kind {
433-
Type::Struct(s) => s,
434-
_ => return None,
435-
};
436-
437-
// Now `current_struct` is the structure before the cursor.
438-
// We want to suggest fields for this structure.
439-
let suggestions = current_struct
440-
.fields()
441-
.flat_map(|f| {
442-
let name = f.name();
443-
let ty = f.ty();
444-
445-
if let Type::Func(ref func_def) = ty {
446-
func_def
447-
.signatures
448-
.iter()
449-
.map(|sig| {
450-
let arg_types = sig
451-
.args
452-
.iter()
453-
.map(ty_to_string)
454-
.collect::<Vec<_>>();
529+
#[cfg(feature = "full-compiler")]
530+
/// Resolves the `Type` of an identifier declared within `for` or `with` statements.
531+
///
532+
/// This function is called when `get_struct` identifies an identifier that is
533+
/// not a module name but rather a variable declared in a `for` or `with` expression.
534+
/// It then attempts to deduce the type of this variable based on its declaration.
535+
///
536+
/// # Arguments
537+
///
538+
/// * `declaration` - The `Node` representing the `for` or `with` declaration.
539+
/// * `ident` - The `Token` of the identifier whose type needs to be resolved.
540+
/// * `path` - An iterator over `Segment`s representing the access path (fields,
541+
/// array indices) applied to the declared variable.
542+
///
543+
/// # Returns
544+
///
545+
/// An `Option<Type>` representing the resolved type of the identifier. Returns `None`
546+
/// if the type cannot be determined or if the access path is invalid for the type.
547+
fn get_type_from_declaration(
548+
declaration: &Node<Immutable>,
549+
ident: &Token<Immutable>,
550+
path: impl Iterator<Item = Segment>,
551+
) -> Option<Type> {
552+
match declaration.kind() {
553+
SyntaxKind::WITH_EXPR => {
554+
let with_decls = declaration
555+
.children()
556+
.find(|n| n.kind() == SyntaxKind::WITH_DECLS)?;
455557

456-
let args_template = arg_types
457-
.iter()
458-
.enumerate()
459-
.map(|(n, arg_type)| {
460-
format!("${{{}:{arg_type}}}", n + 1)
461-
})
462-
.join(",");
558+
for with_decl in with_decls.children() {
559+
let declared_ident = with_decl.first_token()?;
560+
if declared_ident.text() != ident.text() {
561+
continue;
562+
}
463563

464-
CompletionItem {
465-
label: format!(
466-
"{}({})",
467-
name,
468-
arg_types.join(", ")
469-
),
470-
kind: Some(CompletionItemKind::METHOD),
471-
insert_text: Some(format!(
472-
"{name}({args_template})",
473-
)),
474-
insert_text_format: Some(
475-
InsertTextFormat::SNIPPET,
476-
),
477-
label_details: Some(CompletionItemLabelDetails {
478-
description: Some(ty_to_string(&ty)),
479-
..Default::default()
480-
}),
481-
..Default::default()
564+
let mut current_type = get_struct(&with_decl.last_token()?)?;
565+
566+
for segment in path {
567+
match segment {
568+
Segment::Field(name) => {
569+
if let Type::Struct(struct_def) = current_type {
570+
current_type = struct_def
571+
.fields()
572+
.find(|field| field.name() == name)?
573+
.ty()
574+
} else {
575+
return None;
576+
}
482577
}
483-
})
484-
.collect()
485-
} else {
486-
let insert_text = match &ty {
487-
Type::Array(_) => format!("{name}[${{1}}]${{2}}"),
488-
_ => name.to_string(),
489-
};
490-
491-
vec![CompletionItem {
492-
label: name.to_string(),
493-
kind: Some(CompletionItemKind::FIELD),
494-
insert_text: Some(insert_text),
495-
insert_text_format: Some(InsertTextFormat::SNIPPET),
496-
label_details: Some(CompletionItemLabelDetails {
497-
description: Some(ty_to_string(&ty)),
498-
..Default::default()
499-
}),
500-
..Default::default()
501-
}]
578+
Segment::Index => {
579+
if let Type::Array(inner) = current_type {
580+
current_type = *inner
581+
} else {
582+
return None;
583+
}
584+
}
585+
}
586+
}
587+
return Some(current_type);
502588
}
503-
})
504-
.collect();
505-
506-
Some(suggestions)
589+
return None;
590+
}
591+
SyntaxKind::FOR_EXPR => {
592+
let colon = declaration
593+
.children_with_tokens()
594+
.find(|child| child.kind() == SyntaxKind::COLON)?
595+
.into_token()?;
596+
597+
let iterable_last_token = prev_non_trivia_token(&colon)?;
598+
599+
let iterable_type = get_struct(&iterable_last_token)?;
600+
601+
let mut current_type = match iterable_type {
602+
Type::Array(inner) => *inner,
603+
Type::Map(_, value) => *value,
604+
_ => return None,
605+
};
606+
607+
for segment in path {
608+
match segment {
609+
Segment::Field(name) => {
610+
if let Type::Struct(struct_def) = current_type {
611+
current_type = struct_def
612+
.fields()
613+
.find(|field| field.name() == name)?
614+
.ty()
615+
} else {
616+
return None;
617+
}
618+
}
619+
Segment::Index => {
620+
if let Type::Array(inner) = current_type {
621+
current_type = *inner
622+
} else {
623+
return None;
624+
}
625+
}
626+
}
627+
}
628+
return Some(current_type);
629+
}
630+
_ => {}
631+
}
632+
None
507633
}
508634

509635
/// Given a token that must be a closing (right) bracket, find the

0 commit comments

Comments
 (0)