Skip to content

Commit 902cab3

Browse files
authored
ide: semantic token highlighting for func params & types (#1057)
1 parent 1be2baf commit 902cab3

6 files changed

Lines changed: 278 additions & 29 deletions

File tree

crates/squawk_ide/src/semantic_tokens.rs

Lines changed: 271 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,97 @@
11
use rowan::{NodeOrToken, TextRange};
22
use salsa::Database as Db;
33
use squawk_syntax::{
4-
SyntaxKind,
4+
SyntaxElement, SyntaxKind,
55
ast::{self, AstNode},
66
};
77

88
use crate::db::{File, parse};
99

10+
fn highlight_param_mode(out: &mut SemanticTokenBuilder, mode: ast::ParamMode) {
11+
match mode {
12+
ast::ParamMode::ParamIn(param_in) => {
13+
if let Some(token) = param_in.in_token() {
14+
out.push_keyword(token.into());
15+
}
16+
}
17+
ast::ParamMode::ParamInOut(param_in_out) => {
18+
if let Some(token) = param_in_out.in_token() {
19+
out.push_keyword(token.into());
20+
}
21+
if let Some(token) = param_in_out.inout_token() {
22+
out.push_keyword(token.into());
23+
}
24+
if let Some(token) = param_in_out.out_token() {
25+
out.push_keyword(token.into());
26+
}
27+
}
28+
ast::ParamMode::ParamOut(param_out) => {
29+
if let Some(token) = param_out.out_token() {
30+
out.push_keyword(token.into());
31+
}
32+
}
33+
ast::ParamMode::ParamVariadic(param_variadic) => {
34+
if let Some(token) = param_variadic.variadic_token() {
35+
out.push_keyword(token.into());
36+
}
37+
}
38+
}
39+
}
40+
41+
fn highlight_type(out: &mut SemanticTokenBuilder, ty: ast::Type) {
42+
match ty {
43+
ast::Type::ArrayType(array_type) => {
44+
if let Some(ty) = array_type.ty() {
45+
highlight_type(out, ty);
46+
}
47+
}
48+
ast::Type::BitType(bit_type) => {
49+
if let Some(token) = bit_type.bit_token() {
50+
out.push_type(token.into());
51+
}
52+
}
53+
ast::Type::CharType(char_type) => {
54+
if let Some(token) = char_type
55+
.varchar_token()
56+
.or_else(|| char_type.nchar_token())
57+
.or_else(|| char_type.character_token())
58+
.or_else(|| char_type.char_token())
59+
{
60+
out.push_type(token.into());
61+
};
62+
}
63+
ast::Type::DoubleType(double_type) => {
64+
if let Some(token) = double_type.double_token() {
65+
out.push_type(token.into());
66+
}
67+
}
68+
ast::Type::ExprType(_) => (),
69+
ast::Type::IntervalType(interval_type) => {
70+
if let Some(token) = interval_type.interval_token() {
71+
out.push_type(token.into());
72+
}
73+
}
74+
ast::Type::PathType(path_type) => {
75+
if let Some(name_ref) = path_type
76+
.path()
77+
.and_then(|path| path.segment())
78+
.and_then(|ps| ps.name_ref())
79+
{
80+
out.push_type(name_ref.syntax().clone().into());
81+
}
82+
}
83+
ast::Type::PercentType(_) => (),
84+
ast::Type::TimeType(time_type) => {
85+
if let Some(token) = time_type
86+
.timestamp_token()
87+
.or_else(|| time_type.time_token())
88+
{
89+
out.push_type(token.into());
90+
}
91+
}
92+
}
93+
}
94+
1095
/// A semantic token with its position and classification.
1196
#[derive(Debug, Clone, PartialEq, Eq)]
1297
pub struct SemanticToken {
@@ -41,6 +126,35 @@ pub enum SemanticTokenType {
41126
PositionalParam,
42127
}
43128

129+
#[derive(Default)]
130+
struct SemanticTokenBuilder {
131+
tokens: Vec<SemanticToken>,
132+
}
133+
134+
impl SemanticTokenBuilder {
135+
fn build(mut self) -> Vec<SemanticToken> {
136+
self.tokens
137+
.sort_by_key(|token| (token.range.start(), token.range.end()));
138+
self.tokens
139+
}
140+
141+
fn push_keyword(&mut self, syntax_element: SyntaxElement) {
142+
self.push_token(syntax_element, SemanticTokenType::Keyword);
143+
}
144+
145+
fn push_type(&mut self, syntax_element: SyntaxElement) {
146+
self.push_token(syntax_element, SemanticTokenType::Type);
147+
}
148+
149+
fn push_token(&mut self, syntax_element: SyntaxElement, token_type: SemanticTokenType) {
150+
self.tokens.push(SemanticToken {
151+
range: syntax_element.text_range(),
152+
token_type,
153+
modifiers: None,
154+
});
155+
}
156+
}
157+
44158
#[salsa::tracked]
45159
pub fn semantic_tokens(
46160
db: &dyn Db,
@@ -66,7 +180,7 @@ pub fn semantic_tokens(
66180
}
67181
};
68182

69-
let mut out = vec![];
183+
let mut out = SemanticTokenBuilder::default();
70184

71185
// Taken from: https://github.com/rust-lang/rust-analyzer/blob/2efc80078029894eec0699f62ec8d5c1a56af763/crates/ide/src/syntax_highlighting.rs#L267C21-L267C21
72186
let preorder = root.preorder_with_tokens();
@@ -84,34 +198,76 @@ pub fn semantic_tokens(
84198

85199
match event {
86200
Enter(NodeOrToken::Node(node)) => {
87-
if let Some(target) = ast::Target::cast(node)
201+
if let Some(target) = ast::Target::cast(node.clone())
88202
&& let Some(as_name) = target.as_name()
89203
&& let Some(name) = as_name.name()
90204
{
91-
let range = name.syntax().text_range();
92-
out.push(SemanticToken {
93-
range,
94-
token_type: SemanticTokenType::Name,
95-
modifiers: None,
96-
});
205+
out.push_token(name.syntax().clone().into(), SemanticTokenType::Name);
97206
};
207+
208+
if let Some(alias) = ast::Alias::cast(node.clone())
209+
&& let Some(column_list) = alias.column_list()
210+
{
211+
for column in column_list.columns() {
212+
if let Some(ty) = column.ty() {
213+
highlight_type(&mut out, ty);
214+
}
215+
}
216+
}
217+
218+
if let Some(cast_expr) = ast::CastExpr::cast(node.clone())
219+
&& let Some(ty) = cast_expr.ty()
220+
{
221+
highlight_type(&mut out, ty);
222+
}
223+
224+
if let Some(create_function) = ast::CreateFunction::cast(node) {
225+
if let Some(param_list) = create_function.param_list() {
226+
for param in param_list.params() {
227+
if let Some(mode) = param.mode() {
228+
highlight_param_mode(&mut out, mode);
229+
}
230+
if let Some(name) = param.name() {
231+
out.push_token(
232+
name.syntax().clone().into(),
233+
SemanticTokenType::Parameter,
234+
);
235+
}
236+
if let Some(ty) = param.ty() {
237+
highlight_type(&mut out, ty);
238+
}
239+
}
240+
}
241+
242+
if let Some(ret_type) = create_function.ret_type() {
243+
if let Some(ty) = ret_type.ty() {
244+
highlight_type(&mut out, ty);
245+
}
246+
if let Some(table_arg_list) = ret_type.table_arg_list() {
247+
for arg in table_arg_list.args() {
248+
if let ast::TableArg::Column(column) = arg
249+
&& let Some(ty) = column.ty()
250+
{
251+
highlight_type(&mut out, ty);
252+
}
253+
}
254+
}
255+
}
256+
}
98257
}
99258
Enter(NodeOrToken::Token(token)) => {
100259
if token.kind() == SyntaxKind::WHITESPACE {
101260
continue;
102261
}
103262
if token.kind() == SyntaxKind::POSITIONAL_PARAM {
104-
out.push(SemanticToken {
105-
range: token.text_range(),
106-
token_type: SemanticTokenType::PositionalParam,
107-
modifiers: None,
108-
})
263+
out.push_token(token.into(), SemanticTokenType::PositionalParam);
109264
}
110265
}
111266
Leave(_) => {}
112267
}
113268
}
114-
out
269+
270+
out.build()
115271
}
116272

117273
#[cfg(test)]
@@ -143,12 +299,61 @@ mod test {
143299
}
144300

145301
#[test]
146-
fn create_function() {
147-
assert_snapshot!(semantic_tokens("
148-
create function add(a int, b int) returns int
149-
as 'select $1 + $2'
150-
language sql;
151-
"), @"");
302+
fn create_function_misc_params() {
303+
assert_snapshot!(semantic_tokens(
304+
"
305+
create function add(
306+
in a int = 1,
307+
inout b text default 'x',
308+
in out c varchar(10)[],
309+
variadic d int[]
310+
) returns int
311+
as 'select $1 + $2'
312+
language sql;
313+
",
314+
), @r#"
315+
"in" @ 24..26: Keyword
316+
"a" @ 27..28: Parameter
317+
"int" @ 29..32: Type
318+
"inout" @ 40..45: Keyword
319+
"b" @ 46..47: Parameter
320+
"text" @ 48..52: Type
321+
"in" @ 68..70: Keyword
322+
"out" @ 71..74: Keyword
323+
"c" @ 75..76: Parameter
324+
"varchar" @ 77..84: Type
325+
"variadic" @ 94..102: Keyword
326+
"d" @ 103..104: Parameter
327+
"int" @ 105..108: Type
328+
"int" @ 121..124: Type
329+
"#);
330+
}
331+
332+
#[test]
333+
fn create_function_param_mode_type() {
334+
assert_snapshot!(semantic_tokens(
335+
"
336+
create function f(int8 in int8)
337+
returns void
338+
as '' language sql;
339+
",
340+
), @r#"
341+
"int8" @ 19..23: Parameter
342+
"in" @ 24..26: Keyword
343+
"int8" @ 27..31: Type
344+
"void" @ 41..45: Type
345+
"#);
346+
}
347+
348+
#[test]
349+
fn create_function_percent_type() {
350+
assert_snapshot!(semantic_tokens(
351+
"
352+
create function f(a t.c%type)
353+
returns t.b%type
354+
as '' language plpgsql;
355+
",
356+
), @r#""a" @ 19..20: Parameter"#);
152357
}
153358

154359
#[test]
@@ -170,4 +375,49 @@ select $1, $2;
170375
"$2" @ 12..14: PositionalParam
171376
"#)
172377
}
378+
379+
#[test]
380+
fn from_alias_column_types() {
381+
assert_snapshot!(semantic_tokens(
382+
"
383+
select *
384+
from f as t(a int, b jsonb, c text, x int, ca char(5)[], ia int[][], r jbpop);
385+
",
386+
), @r#"
387+
"int" @ 24..27: Type
388+
"jsonb" @ 31..36: Type
389+
"text" @ 40..44: Type
390+
"int" @ 48..51: Type
391+
"char" @ 56..60: Type
392+
"int" @ 70..73: Type
393+
"jbpop" @ 81..86: Type
394+
"#);
395+
}
396+
397+
#[test]
398+
fn cast_types() {
399+
assert_snapshot!(semantic_tokens(
400+
"
401+
select '1'::jsonb, '2'::json, cast(1 as integer), cast(1 as int4[][]), cast(1 as varchar(10));
402+
",
403+
), @r#"
404+
"jsonb" @ 13..18: Type
405+
"json" @ 25..29: Type
406+
"integer" @ 41..48: Type
407+
"int4" @ 61..65: Type
408+
"varchar" @ 82..89: Type
409+
"#);
410+
}
411+
412+
#[test]
413+
fn positional_param_and_cast_type() {
414+
assert_snapshot!(semantic_tokens(
415+
"
416+
select $2::jsonb;
417+
",
418+
), @r#"
419+
"$2" @ 8..10: PositionalParam
420+
"jsonb" @ 12..17: Type
421+
"#);
422+
}
173423
}

crates/squawk_parser/src/grammar.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3664,12 +3664,12 @@ fn opt_column_list_with(p: &mut Parser<'_>, kind: ColumnDefKind) -> bool {
36643664
fn column(p: &mut Parser<'_>, kind: &ColumnDefKind) -> CompletedMarker {
36653665
assert!(p.at_ts(COLUMN_FIRST));
36663666
let m = p.start();
3667-
// https://www.depesz.com/2024/10/03/waiting-for-postgresql-18-add-temporal-foreign-key-contraints/
3668-
// TODO: add validation to ensure this is in the right position
3669-
p.eat(PERIOD_KW);
36703667
match kind {
36713668
ColumnDefKind::Name => name(p),
36723669
ColumnDefKind::NameRef => {
3670+
// https://www.depesz.com/2024/10/03/waiting-for-postgresql-18-add-temporal-foreign-key-contraints/
3671+
// TODO: add validation to ensure this is in the right position
3672+
p.eat(PERIOD_KW);
36733673
// supports parsing things like:
36743674
// INSERT INTO tictactoe (game, board[1:3][1:3])
36753675
name_ref(p).map(|lhs| postfix_expr(p, lhs, true));

crates/squawk_syntax/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ pub use squawk_parser::SyntaxKind;
4242
use ast::AstNode;
4343
use rowan::GreenNode;
4444
use syntax_error::SyntaxError;
45-
pub use syntax_node::{SyntaxNode, SyntaxNodePtr, SyntaxToken};
45+
pub use syntax_node::{SyntaxElement, SyntaxNode, SyntaxNodePtr, SyntaxToken};
4646
pub use token_text::TokenText;
4747

4848
/// `Parse` is the result of the parsing: a syntax tree and a collection of

crates/squawk_syntax/src/postgresql.ungram

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,6 @@ CompressionMethod =
664664
'compression' ('#ident' | 'default')
665665

666666
Column =
667-
'period'?
668667
(
669668
Name WithOptions? constraints:ColumnConstraint*
670669
DeferrableConstraintOption? NotDeferrableConstraintOption?
@@ -674,7 +673,7 @@ Column =
674673
DeferrableConstraintOption? NotDeferrableConstraintOption?
675674
InitiallyDeferredConstraintOption? InitiallyImmediateConstraintOption?
676675
NotEnforced? Enforced?
677-
| NameRef
676+
| 'period'? NameRef
678677
| IndexExpr
679678
)
680679
ColumnConstraint =

0 commit comments

Comments
 (0)