11use rowan:: { NodeOrToken , TextRange } ;
22use salsa:: Database as Db ;
33use squawk_syntax:: {
4- SyntaxKind ,
4+ SyntaxElement , SyntaxKind ,
55 ast:: { self , AstNode } ,
66} ;
77
88use crate :: db:: { File , parse} ;
99
10+ fn highlight_param_mode ( out : & mut SemanticTokenBuilder , mode : ast:: ParamMode ) {
11+ match mode {
12+ ast:: ParamMode :: ParamIn ( param_in) => {
13+ if let Some ( token) = param_in. in_token ( ) {
14+ out. push_keyword ( token. into ( ) ) ;
15+ }
16+ }
17+ ast:: ParamMode :: ParamInOut ( param_in_out) => {
18+ if let Some ( token) = param_in_out. in_token ( ) {
19+ out. push_keyword ( token. into ( ) ) ;
20+ }
21+ if let Some ( token) = param_in_out. inout_token ( ) {
22+ out. push_keyword ( token. into ( ) ) ;
23+ }
24+ if let Some ( token) = param_in_out. out_token ( ) {
25+ out. push_keyword ( token. into ( ) ) ;
26+ }
27+ }
28+ ast:: ParamMode :: ParamOut ( param_out) => {
29+ if let Some ( token) = param_out. out_token ( ) {
30+ out. push_keyword ( token. into ( ) ) ;
31+ }
32+ }
33+ ast:: ParamMode :: ParamVariadic ( param_variadic) => {
34+ if let Some ( token) = param_variadic. variadic_token ( ) {
35+ out. push_keyword ( token. into ( ) ) ;
36+ }
37+ }
38+ }
39+ }
40+
41+ fn highlight_type ( out : & mut SemanticTokenBuilder , ty : ast:: Type ) {
42+ match ty {
43+ ast:: Type :: ArrayType ( array_type) => {
44+ if let Some ( ty) = array_type. ty ( ) {
45+ highlight_type ( out, ty) ;
46+ }
47+ }
48+ ast:: Type :: BitType ( bit_type) => {
49+ if let Some ( token) = bit_type. bit_token ( ) {
50+ out. push_type ( token. into ( ) ) ;
51+ }
52+ }
53+ ast:: Type :: CharType ( char_type) => {
54+ if let Some ( token) = char_type
55+ . varchar_token ( )
56+ . or_else ( || char_type. nchar_token ( ) )
57+ . or_else ( || char_type. character_token ( ) )
58+ . or_else ( || char_type. char_token ( ) )
59+ {
60+ out. push_type ( token. into ( ) ) ;
61+ } ;
62+ }
63+ ast:: Type :: DoubleType ( double_type) => {
64+ if let Some ( token) = double_type. double_token ( ) {
65+ out. push_type ( token. into ( ) ) ;
66+ }
67+ }
68+ ast:: Type :: ExprType ( _) => ( ) ,
69+ ast:: Type :: IntervalType ( interval_type) => {
70+ if let Some ( token) = interval_type. interval_token ( ) {
71+ out. push_type ( token. into ( ) ) ;
72+ }
73+ }
74+ ast:: Type :: PathType ( path_type) => {
75+ if let Some ( name_ref) = path_type
76+ . path ( )
77+ . and_then ( |path| path. segment ( ) )
78+ . and_then ( |ps| ps. name_ref ( ) )
79+ {
80+ out. push_type ( name_ref. syntax ( ) . clone ( ) . into ( ) ) ;
81+ }
82+ }
83+ ast:: Type :: PercentType ( _) => ( ) ,
84+ ast:: Type :: TimeType ( time_type) => {
85+ if let Some ( token) = time_type
86+ . timestamp_token ( )
87+ . or_else ( || time_type. time_token ( ) )
88+ {
89+ out. push_type ( token. into ( ) ) ;
90+ }
91+ }
92+ }
93+ }
94+
1095/// A semantic token with its position and classification.
1196#[ derive( Debug , Clone , PartialEq , Eq ) ]
1297pub struct SemanticToken {
@@ -41,6 +126,35 @@ pub enum SemanticTokenType {
41126 PositionalParam ,
42127}
43128
129+ #[ derive( Default ) ]
130+ struct SemanticTokenBuilder {
131+ tokens : Vec < SemanticToken > ,
132+ }
133+
134+ impl SemanticTokenBuilder {
135+ fn build ( mut self ) -> Vec < SemanticToken > {
136+ self . tokens
137+ . sort_by_key ( |token| ( token. range . start ( ) , token. range . end ( ) ) ) ;
138+ self . tokens
139+ }
140+
141+ fn push_keyword ( & mut self , syntax_element : SyntaxElement ) {
142+ self . push_token ( syntax_element, SemanticTokenType :: Keyword ) ;
143+ }
144+
145+ fn push_type ( & mut self , syntax_element : SyntaxElement ) {
146+ self . push_token ( syntax_element, SemanticTokenType :: Type ) ;
147+ }
148+
149+ fn push_token ( & mut self , syntax_element : SyntaxElement , token_type : SemanticTokenType ) {
150+ self . tokens . push ( SemanticToken {
151+ range : syntax_element. text_range ( ) ,
152+ token_type,
153+ modifiers : None ,
154+ } ) ;
155+ }
156+ }
157+
44158#[ salsa:: tracked]
45159pub fn semantic_tokens (
46160 db : & dyn Db ,
@@ -66,7 +180,7 @@ pub fn semantic_tokens(
66180 }
67181 } ;
68182
69- let mut out = vec ! [ ] ;
183+ let mut out = SemanticTokenBuilder :: default ( ) ;
70184
71185 // Taken from: https://github.com/rust-lang/rust-analyzer/blob/2efc80078029894eec0699f62ec8d5c1a56af763/crates/ide/src/syntax_highlighting.rs#L267C21-L267C21
72186 let preorder = root. preorder_with_tokens ( ) ;
@@ -84,34 +198,76 @@ pub fn semantic_tokens(
84198
85199 match event {
86200 Enter ( NodeOrToken :: Node ( node) ) => {
87- if let Some ( target) = ast:: Target :: cast ( node)
201+ if let Some ( target) = ast:: Target :: cast ( node. clone ( ) )
88202 && let Some ( as_name) = target. as_name ( )
89203 && let Some ( name) = as_name. name ( )
90204 {
91- let range = name. syntax ( ) . text_range ( ) ;
92- out. push ( SemanticToken {
93- range,
94- token_type : SemanticTokenType :: Name ,
95- modifiers : None ,
96- } ) ;
205+ out. push_token ( name. syntax ( ) . clone ( ) . into ( ) , SemanticTokenType :: Name ) ;
97206 } ;
207+
208+ if let Some ( alias) = ast:: Alias :: cast ( node. clone ( ) )
209+ && let Some ( column_list) = alias. column_list ( )
210+ {
211+ for column in column_list. columns ( ) {
212+ if let Some ( ty) = column. ty ( ) {
213+ highlight_type ( & mut out, ty) ;
214+ }
215+ }
216+ }
217+
218+ if let Some ( cast_expr) = ast:: CastExpr :: cast ( node. clone ( ) )
219+ && let Some ( ty) = cast_expr. ty ( )
220+ {
221+ highlight_type ( & mut out, ty) ;
222+ }
223+
224+ if let Some ( create_function) = ast:: CreateFunction :: cast ( node) {
225+ if let Some ( param_list) = create_function. param_list ( ) {
226+ for param in param_list. params ( ) {
227+ if let Some ( mode) = param. mode ( ) {
228+ highlight_param_mode ( & mut out, mode) ;
229+ }
230+ if let Some ( name) = param. name ( ) {
231+ out. push_token (
232+ name. syntax ( ) . clone ( ) . into ( ) ,
233+ SemanticTokenType :: Parameter ,
234+ ) ;
235+ }
236+ if let Some ( ty) = param. ty ( ) {
237+ highlight_type ( & mut out, ty) ;
238+ }
239+ }
240+ }
241+
242+ if let Some ( ret_type) = create_function. ret_type ( ) {
243+ if let Some ( ty) = ret_type. ty ( ) {
244+ highlight_type ( & mut out, ty) ;
245+ }
246+ if let Some ( table_arg_list) = ret_type. table_arg_list ( ) {
247+ for arg in table_arg_list. args ( ) {
248+ if let ast:: TableArg :: Column ( column) = arg
249+ && let Some ( ty) = column. ty ( )
250+ {
251+ highlight_type ( & mut out, ty) ;
252+ }
253+ }
254+ }
255+ }
256+ }
98257 }
99258 Enter ( NodeOrToken :: Token ( token) ) => {
100259 if token. kind ( ) == SyntaxKind :: WHITESPACE {
101260 continue ;
102261 }
103262 if token. kind ( ) == SyntaxKind :: POSITIONAL_PARAM {
104- out. push ( SemanticToken {
105- range : token. text_range ( ) ,
106- token_type : SemanticTokenType :: PositionalParam ,
107- modifiers : None ,
108- } )
263+ out. push_token ( token. into ( ) , SemanticTokenType :: PositionalParam ) ;
109264 }
110265 }
111266 Leave ( _) => { }
112267 }
113268 }
114- out
269+
270+ out. build ( )
115271}
116272
117273#[ cfg( test) ]
@@ -143,12 +299,61 @@ mod test {
143299 }
144300
145301 #[ test]
146- fn create_function ( ) {
147- assert_snapshot ! ( semantic_tokens( "
148- create function add(a int, b int) returns int
149- as 'select $1 + $2'
150- language sql;
151- " ) , @"" ) ;
302+ fn create_function_misc_params ( ) {
303+ assert_snapshot ! ( semantic_tokens(
304+ "
305+ create function add(
306+ in a int = 1,
307+ inout b text default 'x',
308+ in out c varchar(10)[],
309+ variadic d int[]
310+ ) returns int
311+ as 'select $1 + $2'
312+ language sql;
313+ " ,
314+ ) , @r#"
315+ "in" @ 24..26: Keyword
316+ "a" @ 27..28: Parameter
317+ "int" @ 29..32: Type
318+ "inout" @ 40..45: Keyword
319+ "b" @ 46..47: Parameter
320+ "text" @ 48..52: Type
321+ "in" @ 68..70: Keyword
322+ "out" @ 71..74: Keyword
323+ "c" @ 75..76: Parameter
324+ "varchar" @ 77..84: Type
325+ "variadic" @ 94..102: Keyword
326+ "d" @ 103..104: Parameter
327+ "int" @ 105..108: Type
328+ "int" @ 121..124: Type
329+ "# ) ;
330+ }
331+
332+ #[ test]
333+ fn create_function_param_mode_type ( ) {
334+ assert_snapshot ! ( semantic_tokens(
335+ "
336+ create function f(int8 in int8)
337+ returns void
338+ as '' language sql;
339+ " ,
340+ ) , @r#"
341+ "int8" @ 19..23: Parameter
342+ "in" @ 24..26: Keyword
343+ "int8" @ 27..31: Type
344+ "void" @ 41..45: Type
345+ "# ) ;
346+ }
347+
348+ #[ test]
349+ fn create_function_percent_type ( ) {
350+ assert_snapshot ! ( semantic_tokens(
351+ "
352+ create function f(a t.c%type)
353+ returns t.b%type
354+ as '' language plpgsql;
355+ " ,
356+ ) , @r#""a" @ 19..20: Parameter"# ) ;
152357 }
153358
154359 #[ test]
@@ -170,4 +375,49 @@ select $1, $2;
170375 "$2" @ 12..14: PositionalParam
171376 "# )
172377 }
378+
379+ #[ test]
380+ fn from_alias_column_types ( ) {
381+ assert_snapshot ! ( semantic_tokens(
382+ "
383+ select *
384+ from f as t(a int, b jsonb, c text, x int, ca char(5)[], ia int[][], r jbpop);
385+ " ,
386+ ) , @r#"
387+ "int" @ 24..27: Type
388+ "jsonb" @ 31..36: Type
389+ "text" @ 40..44: Type
390+ "int" @ 48..51: Type
391+ "char" @ 56..60: Type
392+ "int" @ 70..73: Type
393+ "jbpop" @ 81..86: Type
394+ "# ) ;
395+ }
396+
397+ #[ test]
398+ fn cast_types ( ) {
399+ assert_snapshot ! ( semantic_tokens(
400+ "
401+ select '1'::jsonb, '2'::json, cast(1 as integer), cast(1 as int4[][]), cast(1 as varchar(10));
402+ " ,
403+ ) , @r#"
404+ "jsonb" @ 13..18: Type
405+ "json" @ 25..29: Type
406+ "integer" @ 41..48: Type
407+ "int4" @ 61..65: Type
408+ "varchar" @ 82..89: Type
409+ "# ) ;
410+ }
411+
412+ #[ test]
413+ fn positional_param_and_cast_type ( ) {
414+ assert_snapshot ! ( semantic_tokens(
415+ "
416+ select $2::jsonb;
417+ " ,
418+ ) , @r#"
419+ "$2" @ 8..10: PositionalParam
420+ "jsonb" @ 12..17: Type
421+ "# ) ;
422+ }
173423}
0 commit comments