diff --git a/.gitignore b/.gitignore
index abab0fa64..af3907006 100644
--- a/.gitignore
+++ b/.gitignore
@@ -52,4 +52,3 @@ site/
 biome-main/
 .review/
 pglinter_repo/
-.review/
diff --git a/crates/pgls_lsp/src/capabilities.rs b/crates/pgls_lsp/src/capabilities.rs
index 38d66d876..87aaf96f4 100644
--- a/crates/pgls_lsp/src/capabilities.rs
+++ b/crates/pgls_lsp/src/capabilities.rs
@@ -1,10 +1,13 @@
 use crate::adapters::{PositionEncoding, WideEncoding, negotiated_encoding};
 use crate::handlers::code_actions::command_id;
 use pgls_workspace::features::code_actions::CommandActionCategory;
+use pgls_workspace::features::semantic_tokens::{TokenModifier, TokenType};
 use strum::IntoEnumIterator;
 use tower_lsp::lsp_types::{
     ClientCapabilities, CompletionOptions, ExecuteCommandOptions, HoverProviderCapability, OneOf,
-    PositionEncodingKind, SaveOptions, ServerCapabilities, TextDocumentSyncCapability,
+    PositionEncodingKind, SaveOptions, SemanticTokenModifier, SemanticTokenType,
+    SemanticTokensFullOptions, SemanticTokensLegend, SemanticTokensOptions,
+    SemanticTokensServerCapabilities, ServerCapabilities, TextDocumentSyncCapability,
     TextDocumentSyncKind, TextDocumentSyncOptions, TextDocumentSyncSaveOptions,
     WorkDoneProgressOptions,
 };
@@ -69,6 +72,25 @@ pub(crate) fn server_capabilities(capabilities: &ClientCapabilities) -> ServerCa
         )),
         rename_provider: None,
         hover_provider: Some(HoverProviderCapability::Simple(true)),
+        semantic_tokens_provider: Some(SemanticTokensServerCapabilities::SemanticTokensOptions(
+            SemanticTokensOptions {
+                work_done_progress_options: WorkDoneProgressOptions {
+                    work_done_progress: None,
+                },
+                legend: SemanticTokensLegend {
+                    token_types: TokenType::legend()
+                        .into_iter()
+                        .map(SemanticTokenType::new)
+                        .collect(),
+                    token_modifiers: TokenModifier::legend()
+                        .into_iter()
+                        .map(SemanticTokenModifier::new)
+                        .collect(),
+                },
+                range: Some(true),
+                full: Some(SemanticTokensFullOptions::Bool(true)),
+            },
+        )),
         ..Default::default()
     }
 }
diff --git a/crates/pgls_lsp/src/handlers.rs b/crates/pgls_lsp/src/handlers.rs
index 42ac16963..afb1c2627 100644
--- a/crates/pgls_lsp/src/handlers.rs
+++ b/crates/pgls_lsp/src/handlers.rs
@@ -2,4 +2,5 @@ pub(crate) mod code_actions;
 pub(crate) mod completions;
 pub(crate) mod formatting;
 pub(crate) mod hover;
+pub(crate) mod semantic_tokens;
 pub(crate) mod text_document;
diff --git a/crates/pgls_lsp/src/handlers/semantic_tokens.rs b/crates/pgls_lsp/src/handlers/semantic_tokens.rs
new file mode 100644
index 000000000..4022d54ee
--- /dev/null
+++ b/crates/pgls_lsp/src/handlers/semantic_tokens.rs
@@ -0,0 +1,189 @@
+use crate::{adapters, diagnostics::LspError, session::Session};
+use pgls_text_size::TextRange;
+use pgls_workspace::features::semantic_tokens::SemanticTokensParams;
+use tower_lsp::lsp_types::{self, SemanticToken, SemanticTokens, SemanticTokensRangeResult};
+
+/// Handles a full semantic tokens request.
+#[tracing::instrument(level = "debug", skip(session), err)]
+pub fn semantic_tokens_full(
+    session: &Session,
+    params: lsp_types::SemanticTokensParams,
+) -> Result<Option<lsp_types::SemanticTokensResult>, LspError> {
+    let url = &params.text_document.uri;
+    let tokens = get_semantic_tokens(session, url, None)?;
+    Ok(Some(lsp_types::SemanticTokensResult::Tokens(tokens)))
+}
+
+/// Handles a range semantic tokens request.
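+/// Only statements overlapping the requested range contribute tokens: the
+/// workspace side filters with `RangeOverlapFilter` (added further below in
+/// this patch), so tokens of an overlapping statement may extend past the
+/// requested range, which the LSP specification permits.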
+#[tracing::instrument(level = "debug", skip(session), err)]
+pub fn semantic_tokens_range(
+    session: &Session,
+    params: lsp_types::SemanticTokensRangeParams,
+) -> Result<Option<SemanticTokensRangeResult>, LspError> {
+    let url = &params.text_document.uri;
+    let tokens = get_semantic_tokens(session, url, Some(params.range))?;
+    Ok(Some(SemanticTokensRangeResult::Tokens(tokens)))
+}
+
+/// Common implementation for semantic tokens requests.
+fn get_semantic_tokens(
+    session: &Session,
+    url: &lsp_types::Url,
+    range: Option<lsp_types::Range>,
+) -> Result<SemanticTokens, LspError> {
+    let path = session.file_path(url)?;
+    let doc = session.document(url)?;
+    let encoding = adapters::negotiated_encoding(session.client_capabilities().unwrap());
+
+    let text_range = range
+        .map(|r| -> Result<TextRange, LspError> {
+            let start = adapters::from_lsp::offset(&doc.line_index, r.start, encoding)?;
+            let end = adapters::from_lsp::offset(&doc.line_index, r.end, encoding)?;
+            Ok(TextRange::new(start, end))
+        })
+        .transpose()?;
+
+    let workspace_result = session
+        .workspace
+        .get_semantic_tokens(SemanticTokensParams {
+            path,
+            range: text_range,
+        })?;
+
+    let lsp_tokens = encode_tokens(&workspace_result.tokens, &doc.line_index, encoding)?;
+
+    Ok(SemanticTokens {
+        result_id: None,
+        data: lsp_tokens,
+    })
+}
+
+/// Encodes workspace semantic tokens into the LSP delta-encoded format.
+///
+/// LSP semantic tokens are encoded as a flat array of integers with 5 values per token:
+/// - deltaLine: line difference from the previous token
+/// - deltaStart: character offset from the start of the line (or from the previous token if on the same line)
+/// - length: the length of the token in characters
+/// - tokenType: the token type index
+/// - tokenModifiers: bit flags for token modifiers
+///
+/// Multi-line tokens (like block comments) are split into multiple LSP tokens,
+/// one per line, since LSP semantic tokens cannot span lines.
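+///
+/// For example, using this patch's legend and modifier bits: in `SELECT 1`,
+/// `SELECT` is a reserved keyword at line 0, columns 0..6 (type index 0,
+/// defaultLibrary bit = 8) and `1` is a number at columns 7..8 (type index 5,
+/// readonly bit = 4), so the encoded stream is
+/// `[0, 0, 6, 0, 8, 0, 7, 1, 5, 4]`: the first five values encode `SELECT`,
+/// the last five encode `1` on the same line, starting 7 characters later.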
+fn encode_tokens(
+    tokens: &[pgls_workspace::features::semantic_tokens::SemanticToken],
+    line_index: &adapters::line_index::LineIndex,
+    encoding: adapters::PositionEncoding,
+) -> Result<Vec<SemanticToken>, LspError> {
+    let mut result = Vec::with_capacity(tokens.len());
+    let mut prev_line = 0u32;
+    let mut prev_start = 0u32;
+
+    for token in tokens {
+        // Convert token range to LSP positions
+        let start_pos = adapters::to_lsp::position(line_index, token.range.start(), encoding)
+            .map_err(|e| LspError::from(anyhow::anyhow!("Failed to convert position: {}", e)))?;
+
+        let end_pos = adapters::to_lsp::position(line_index, token.range.end(), encoding)
+            .map_err(|e| LspError::from(anyhow::anyhow!("Failed to convert position: {}", e)))?;
+
+        if start_pos.line == end_pos.line {
+            // Single-line token - emit one LSP token
+            let length = end_pos.character - start_pos.character;
+            let delta_line = start_pos.line - prev_line;
+            let delta_start = if delta_line == 0 {
+                start_pos.character - prev_start
+            } else {
+                start_pos.character
+            };
+
+            result.push(SemanticToken {
+                delta_line,
+                delta_start,
+                length,
+                token_type: token.token_type,
+                token_modifiers_bitset: token.token_modifiers,
+            });
+
+            prev_line = start_pos.line;
+            prev_start = start_pos.character;
+        } else {
+            // Multi-line token - emit one LSP token per line
+            for line in start_pos.line..=end_pos.line {
+                let (line_start, line_length) = if line == start_pos.line {
+                    // First line: from token start to end of line
+                    let line_len = get_line_length(line_index, line, encoding);
+                    (start_pos.character, line_len.saturating_sub(start_pos.character))
+                } else if line == end_pos.line {
+                    // Last line: from start of line to token end
+                    (0, end_pos.character)
+                } else {
+                    // Middle lines: entire line
+                    (0, get_line_length(line_index, line, encoding))
+                };
+
+                // Skip empty segments
+                if line_length == 0 {
+                    continue;
+                }
+
+                let delta_line = line - prev_line;
+                let delta_start = if delta_line == 0 {
+                    line_start - prev_start
+                } else {
+                    line_start
+                };
+
+                result.push(SemanticToken {
+                    delta_line,
+                    delta_start,
+                    length: line_length,
+                    token_type: token.token_type,
+                    token_modifiers_bitset: token.token_modifiers,
+                });
+
+                prev_line = line;
+                prev_start = line_start;
+            }
+        }
+    }
+
+    Ok(result)
+}
+
+/// Gets the length of a line in the appropriate encoding (excluding the newline character).
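+///
+/// For example, a line containing `sälüt -- comment` is 18 units long in
+/// UTF-8 (`ä` and `ü` take two bytes each) but only 16 in UTF-16, so the
+/// returned length depends on the negotiated encoding.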
+fn get_line_length(
+    line_index: &adapters::line_index::LineIndex,
+    line: u32,
+    encoding: adapters::PositionEncoding,
+) -> u32 {
+    let line_usize = line as usize;
+
+    // Get the start offset of this line and the next line
+    let line_start = line_index.newlines.get(line_usize).copied();
+    let next_line_start = line_index.newlines.get(line_usize + 1).copied();
+
+    let (Some(start), Some(end)) = (line_start, next_line_start) else {
+        // Last line or invalid line - estimate from offset
+        // For the last line, we don't have a next newline offset
+        // Return 0 as a safe fallback (the token end position should handle this)
+        return 0;
+    };
+
+    // Line length in bytes (excluding newline)
+    let byte_length = end - start - pgls_text_size::TextSize::from(1u32);
+
+    // Convert to the appropriate encoding
+    match encoding {
+        adapters::PositionEncoding::Utf8 => byte_length.into(),
+        adapters::PositionEncoding::Wide(enc) => {
+            let line_col = adapters::LineCol {
+                line,
+                col: byte_length.into(),
+            };
+            line_index
+                .to_wide(enc, line_col)
+                .map(|wlc| wlc.col)
+                .unwrap_or(byte_length.into())
+        }
+    }
+}
diff --git a/crates/pgls_lsp/src/server.rs b/crates/pgls_lsp/src/server.rs
index 1d1024369..0aa44acfe 100644
--- a/crates/pgls_lsp/src/server.rs
+++ b/crates/pgls_lsp/src/server.rs
@@ -329,6 +329,28 @@ impl LanguageServer for LSPServer {
             Err(e) => LspResult::Err(into_lsp_error(e)),
         }
     }
+
+    #[tracing::instrument(level = "trace", skip_all)]
+    async fn semantic_tokens_full(
+        &self,
+        params: SemanticTokensParams,
+    ) -> LspResult<Option<SemanticTokensResult>> {
+        match handlers::semantic_tokens::semantic_tokens_full(&self.session, params) {
+            Ok(result) => LspResult::Ok(result),
+            Err(e) => LspResult::Err(into_lsp_error(e)),
+        }
+    }
+
+    #[tracing::instrument(level = "trace", skip_all)]
+    async fn semantic_tokens_range(
+        &self,
+        params: SemanticTokensRangeParams,
+    ) -> LspResult<Option<SemanticTokensRangeResult>> {
+        match handlers::semantic_tokens::semantic_tokens_range(&self.session, params) {
+            Ok(result) => LspResult::Ok(result),
+            Err(e) => LspResult::Err(into_lsp_error(e)),
+        }
+    }
 }
 
 impl Drop for LSPServer {
@@ -483,6 +505,7 @@ impl ServerFactory {
         workspace_method!(builder, register_project_folder);
         workspace_method!(builder, unregister_project_folder);
         workspace_method!(builder, invalidate_schema_cache);
+        workspace_method!(builder, get_semantic_tokens);
 
         let (service, socket) = builder.finish();
         ServerConnection { socket, service }
diff --git a/crates/pgls_workspace/src/features/mod.rs b/crates/pgls_workspace/src/features/mod.rs
index e8a089338..b69ddc480 100644
--- a/crates/pgls_workspace/src/features/mod.rs
+++ b/crates/pgls_workspace/src/features/mod.rs
@@ -3,3 +3,4 @@ pub mod completions;
 pub mod diagnostics;
 pub mod format;
 pub mod on_hover;
+pub mod semantic_tokens;
diff --git a/crates/pgls_workspace/src/features/semantic_tokens.rs b/crates/pgls_workspace/src/features/semantic_tokens.rs
new file mode 100644
index 000000000..b76449713
--- /dev/null
+++ b/crates/pgls_workspace/src/features/semantic_tokens.rs
@@ -0,0 +1,29 @@
+use pgls_fs::PgLSPath;
+use pgls_text_size::TextRange;
+
+// Re-export from the workspace implementation
+pub use crate::workspace::server::semantic_tokens::{
+    SemanticToken, TokenModifier, TokenType, NO_MODIFIERS,
+};
+
+#[derive(Debug, serde::Serialize, serde::Deserialize)]
+#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
+pub struct SemanticTokensParams {
+    /// The file path for which semantic tokens are requested.
+    pub path: PgLSPath,
+    /// Optional range to limit the tokens.
+    /// If None, returns tokens for the entire file.
+    pub range: Option<TextRange>,
+}
+
+#[derive(Debug, serde::Serialize, serde::Deserialize, Default)]
+#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
+pub struct SemanticTokensResult {
+    /// The semantic tokens for the requested file/range
+    pub tokens: Vec<SemanticToken>,
+}
+
+impl SemanticTokensResult {
+    pub fn new(tokens: Vec<SemanticToken>) -> Self {
+        Self { tokens }
+    }
+}
diff --git a/crates/pgls_workspace/src/workspace.rs b/crates/pgls_workspace/src/workspace.rs
index 63d70fa6c..192f0390e 100644
--- a/crates/pgls_workspace/src/workspace.rs
+++ b/crates/pgls_workspace/src/workspace.rs
@@ -21,6 +21,7 @@ use crate::{
         },
         format::{PullFileFormattingParams, PullFormattingResult},
         on_hover::{OnHoverParams, OnHoverResult},
+        semantic_tokens::{SemanticTokensParams, SemanticTokensResult},
     },
 };
@@ -131,6 +132,12 @@ pub trait Workspace: Send + Sync + RefUnwindSafe {
     fn on_hover(&self, params: OnHoverParams) -> Result<OnHoverResult, WorkspaceError>;
 
+    /// Retrieves semantic tokens for syntax highlighting
+    fn get_semantic_tokens(
+        &self,
+        params: SemanticTokensParams,
+    ) -> Result<SemanticTokensResult, WorkspaceError>;
+
     /// Register a possible workspace project folder. Returns the key of said project. Use this key when you want to switch to different projects.
     fn register_project_folder(
         &self,
diff --git a/crates/pgls_workspace/src/workspace/client.rs b/crates/pgls_workspace/src/workspace/client.rs
index ed87761d5..66ec4d39b 100644
--- a/crates/pgls_workspace/src/workspace/client.rs
+++ b/crates/pgls_workspace/src/workspace/client.rs
@@ -183,6 +183,13 @@ where
         self.request("pgls/on_hover", params)
     }
 
+    fn get_semantic_tokens(
+        &self,
+        params: crate::features::semantic_tokens::SemanticTokensParams,
+    ) -> Result<crate::features::semantic_tokens::SemanticTokensResult, WorkspaceError> {
+        self.request("pgls/get_semantic_tokens", params)
+    }
+
     fn invalidate_schema_cache(&self, all: bool) -> Result<(), WorkspaceError> {
         self.request("pgls/invalidate_schema_cache", all)
     }
diff --git a/crates/pgls_workspace/src/workspace/server.rs b/crates/pgls_workspace/src/workspace/server.rs
index ae941914a..c4720e46a 100644
--- a/crates/pgls_workspace/src/workspace/server.rs
+++ b/crates/pgls_workspace/src/workspace/server.rs
@@ -46,6 +46,7 @@ use crate::{
         diagnostics::{PullDiagnosticsResult, PullFileDiagnosticsParams},
         format::{PullFileFormattingParams, PullFormattingResult, StatementFormatResult},
         on_hover::{OnHoverParams, OnHoverResult},
+        semantic_tokens::{SemanticTokensParams, SemanticTokensResult},
     },
     settings::{WorkspaceSettings, WorkspaceSettingsHandle, WorkspaceSettingsHandleMut},
     workspace::{AnalyserDiagnosticsMapper, WithCSTandASTMapper},
@@ -71,6 +72,7 @@ pub(crate) mod document;
 mod migration;
 mod pg_query;
 mod schema_cache_manager;
+pub mod semantic_tokens;
 mod sql_function;
 mod statement_identifier;
 mod tree_sitter;
@@ -1058,6 +1060,29 @@ impl Workspace for WorkspaceServer {
             None => Ok(OnHoverResult::default()),
         }
     }
+
+    #[ignored_path(path=&params.path)]
+    #[tracing::instrument(level = "debug", skip_all, fields(
+        path = params.path.as_os_str().to_str(),
+    ), err)]
+    fn get_semantic_tokens(
+        &self,
+        params: SemanticTokensParams,
+    ) -> Result<SemanticTokensResult, WorkspaceError> {
+        use document::{RangeOverlapFilter, SemanticTokenMapper};
+
+        let documents = self.documents.read().unwrap();
+        let doc = documents
+            .get(&params.path)
+            .ok_or(WorkspaceError::not_found())?;
+
+        let tokens: Vec<_> = doc
+            .iter_with_filter(SemanticTokenMapper, RangeOverlapFilter::new(params.range))
+            .flat_map(|(_, tokens)| tokens)
+            .collect();
+
+        Ok(SemanticTokensResult::new(tokens))
+    }
 }
 
 /// Returns `true` if `path` is a directory or
diff --git a/crates/pgls_workspace/src/workspace/server/document.rs b/crates/pgls_workspace/src/workspace/server/document.rs
index 4a1f06b5c..da8f9ae41 100644
--- a/crates/pgls_workspace/src/workspace/server/document.rs
+++ b/crates/pgls_workspace/src/workspace/server/document.rs
@@ -9,6 +9,7 @@ use pgls_text_size::{TextRange, TextSize};
 use super::{
     annotation::AnnotationStore,
     pg_query::PgQueryStore,
+    semantic_tokens::{SemanticToken, SemanticTokenStore},
     sql_function::{SQLFunctionSignature, get_sql_fn_body, get_sql_fn_signature},
     statement_identifier::StatementId,
     tree_sitter::TreeSitterStore,
@@ -21,6 +22,7 @@ pub struct Document {
     diagnostics: Vec<SDiagnostic>,
     ast_db: PgQueryStore,
     cst_db: TreeSitterStore,
+    semantic_token_db: SemanticTokenStore,
     #[allow(dead_code)]
     annotation_db: AnnotationStore,
     suppressions: Suppressions,
@@ -30,6 +32,7 @@ impl Document {
     pub fn new(content: String, version: i32) -> Document {
         let cst_db = TreeSitterStore::new();
         let ast_db = PgQueryStore::new();
+        let semantic_token_db = SemanticTokenStore::default();
         let annotation_db = AnnotationStore::new();
         let suppressions = Suppressions::from(content.as_str());
@@ -42,6 +45,7 @@
             version,
             ast_db,
             cst_db,
+            semantic_token_db,
             annotation_db,
             suppressions,
         }
@@ -380,6 +384,53 @@ impl StatementFilter<'_> for IdFilter {
 }
 
+/// Filter that matches statements overlapping with a given range.
+/// If no range is provided, all statements pass.
+pub struct RangeOverlapFilter {
+    range: Option<TextRange>,
+}
+
+impl RangeOverlapFilter {
+    pub fn new(range: Option<TextRange>) -> Self {
+        Self { range }
+    }
+}
+
+impl StatementFilter<'_> for RangeOverlapFilter {
+    fn predicate(&self, _id: &StatementId, range: &TextRange, _content: &str) -> bool {
+        match self.range {
+            Some(filter_range) => {
+                // Check if the statement range overlaps with the filter range
+                range.start() < filter_range.end() && range.end() > filter_range.start()
+            }
+            None => true, // No range filter, all statements pass
+        }
+    }
+}
+
+/// Mapper that returns semantic tokens for each statement, adjusted to document coordinates
+pub struct SemanticTokenMapper;
+
+impl<'a> StatementMapper<'a> for SemanticTokenMapper {
+    /// Returns (statement_range, tokens_with_document_coordinates)
+    type Output = (TextRange, Vec<SemanticToken>);
+
+    fn map(&self, parser: &'a Document, id: StatementId, range: TextRange) -> Self::Output {
+        let tokens = parser.semantic_token_db.get_or_cache_tokens(&id);
+
+        // Adjust token positions from statement-relative to document-absolute
+        let adjusted_tokens: Vec<SemanticToken> = tokens
+            .into_iter()
+            .map(|t| SemanticToken {
+                range: TextRange::new(t.range.start() + range.start(), t.range.end() + range.start()),
+                token_type: t.token_type,
+                token_modifiers: t.token_modifiers,
+            })
+            .collect();
+
+        (range, adjusted_tokens)
+    }
+}
+
 /// Helper function that wraps the statement splitter and returns the ranges with unified
 /// diagnostics
 pub(crate) fn split_with_diagnostics(
diff --git a/crates/pgls_workspace/src/workspace/server/semantic_tokens.rs b/crates/pgls_workspace/src/workspace/server/semantic_tokens.rs
new file mode 100644
index 000000000..8a89328a1
--- /dev/null
+++ b/crates/pgls_workspace/src/workspace/server/semantic_tokens.rs
@@ -0,0 +1,417 @@
+//! Semantic token computation using pgls_query::scan()
+//!
+//! This module provides semantic token types and modifiers for SQL syntax highlighting.
+//! The LSP layer converts these to protocol-specific types.
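+//!
+//! Sketch of the flow: `compute_tokens("SELECT $1")` scans the statement and
+//! yields a keyword token for `SELECT` and a parameter token for `$1`, with
+//! statement-relative ranges. `SemanticTokenStore` caches the result per
+//! `StatementId`, and `SemanticTokenMapper` (in `document.rs`) shifts the
+//! ranges into document coordinates before the LSP layer delta-encodes them.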
+
+use std::num::NonZeroUsize;
+use std::sync::Mutex;
+
+use lru::LruCache;
+use pgls_query::protobuf::{KeywordKind, Token};
+use pgls_text_size::{TextRange, TextSize};
+use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
+
+use super::statement_identifier::StatementId;
+
+const DEFAULT_CACHE_SIZE: usize = 1000;
+
+/// Semantic token types for SQL syntax highlighting.
+/// The discriminant values define the indices sent to the client.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, EnumIter, IntoStaticStr)]
+#[strum(serialize_all = "camelCase")]
+#[repr(u32)]
+pub enum TokenType {
+    Keyword = 0,
+    Type = 1,
+    Function = 2,
+    Parameter = 3,
+    String = 4,
+    Number = 5,
+    Operator = 6,
+    Comment = 7,
+    Property = 8,
+}
+
+impl TokenType {
+    /// Returns the legend as a list of token type names in order.
+    pub fn legend() -> Vec<&'static str> {
+        Self::iter().map(|t| t.into()).collect()
+    }
+}
+
+/// Semantic token modifier bit flags.
+/// Modifiers can be combined using bitwise OR.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, EnumIter, IntoStaticStr)]
+#[strum(serialize_all = "camelCase")]
+#[repr(u32)]
+pub enum TokenModifier {
+    Declaration = 0,
+    Definition = 1,
+    Readonly = 2,
+    DefaultLibrary = 3,
+}
+
+impl TokenModifier {
+    /// Returns the bit flag value for this modifier.
+    pub const fn bit(self) -> u32 {
+        1 << (self as u32)
+    }
+
+    /// Returns the legend as a list of modifier names in order.
+    pub fn legend() -> Vec<&'static str> {
+        Self::iter().map(|m| m.into()).collect()
+    }
+}
+
+/// No modifiers applied.
+pub const NO_MODIFIERS: u32 = 0;
+
+/// A semantic token with absolute position
+#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
+#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
+pub struct SemanticToken {
+    /// The text range of this token
+    pub range: TextRange,
+    /// The semantic token type index
+    pub token_type: u32,
+    /// The semantic token modifiers as a bit flag
+    pub token_modifiers: u32,
+}
+
+impl SemanticToken {
+    pub fn new(start: TextSize, end: TextSize, token_type: TokenType, token_modifiers: u32) -> Self {
+        Self {
+            range: TextRange::new(start, end),
+            token_type: token_type as u32,
+            token_modifiers,
+        }
+    }
+}
+
+/// Cache for semantic tokens per statement
+pub struct SemanticTokenStore {
+    db: Mutex<LruCache<StatementId, Vec<SemanticToken>>>,
+}
+
+impl Default for SemanticTokenStore {
+    fn default() -> Self {
+        Self {
+            db: Mutex::new(LruCache::new(
+                NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap(),
+            )),
+        }
+    }
+}
+
+impl SemanticTokenStore {
+    /// Get cached tokens or compute and cache them for the given statement.
+    /// Returns tokens with positions relative to the statement start.
+    pub fn get_or_cache_tokens(&self, statement: &StatementId) -> Vec<SemanticToken> {
+        let mut cache = self.db.lock().unwrap();
+
+        if let Some(existing) = cache.get(statement) {
+            return existing.clone();
+        }
+
+        let tokens = compute_tokens(statement.content());
+        cache.put(statement.clone(), tokens.clone());
+        tokens
+    }
+}
+
+/// Maps a pgls_query token to semantic token type and modifiers
+fn map_token(token: Token, keyword_kind: KeywordKind) -> Option<(TokenType, u32)> {
+    match token {
+        // Comments
+        Token::SqlComment | Token::CComment => Some((TokenType::Comment, NO_MODIFIERS)),
+
+        // String literals
+        Token::Sconst | Token::Usconst => Some((TokenType::String, TokenModifier::Readonly.bit())),
+
+        // Numeric literals
+        Token::Iconst | Token::Fconst | Token::Bconst | Token::Xconst => {
+            Some((TokenType::Number, TokenModifier::Readonly.bit()))
+        }
+
+        // Parameters ($1, $2, etc.)
+        Token::Param => Some((TokenType::Parameter, NO_MODIFIERS)),
+
+        // Operators
+        Token::Op
+        | Token::Typecast
+        | Token::DotDot
+        | Token::ColonEquals
+        | Token::EqualsGreater
+        | Token::LessEquals
+        | Token::GreaterEquals
+        | Token::NotEquals => Some((TokenType::Operator, NO_MODIFIERS)),
+
+        // Single-character operators
+        Token::Ascii37 // %
+        | Token::Ascii42 // *
+        | Token::Ascii43 // +
+        | Token::Ascii45 // -
+        | Token::Ascii47 // /
+        | Token::Ascii60 // <
+        | Token::Ascii61 // =
+        | Token::Ascii62 // >
+        | Token::Ascii94 // ^
+        => Some((TokenType::Operator, NO_MODIFIERS)),
+
+        // Identifiers
+        Token::Ident | Token::Uident => Some((TokenType::Property, NO_MODIFIERS)),
+
+        // Type keywords - these are known SQL type names
+        Token::Bigint
+        | Token::Bit
+        | Token::BooleanP
+        | Token::CharP
+        | Token::Character
+        | Token::Dec
+        | Token::DecimalP
+        | Token::DoubleP
+        | Token::FloatP
+        | Token::IntP
+        | Token::Integer
+        | Token::Interval
+        | Token::National
+        | Token::Nchar
+        | Token::None
+        | Token::Numeric
+        | Token::Real
+        | Token::Setof
+        | Token::Smallint
+        | Token::Time
+        | Token::Timestamp
+        | Token::Varchar
+        | Token::Varying
+        | Token::Xmlattributes
+        | Token::Xmlconcat
+        | Token::Xmlelement
+        | Token::Xmlexists
+        | Token::Xmlforest
+        | Token::Xmlnamespaces
+        | Token::Xmlparse
+        | Token::Xmlpi
+        | Token::Xmlroot
+        | Token::Xmlserialize
+        | Token::Xmltable => Some((TokenType::Type, NO_MODIFIERS)),
+
+        // Keywords (other tokens with keyword_kind)
+        _ => match keyword_kind {
+            KeywordKind::TypeFuncNameKeyword => Some((TokenType::Type, NO_MODIFIERS)),
+            KeywordKind::ReservedKeyword => {
+                Some((TokenType::Keyword, TokenModifier::DefaultLibrary.bit()))
+            }
+            KeywordKind::UnreservedKeyword | KeywordKind::ColNameKeyword => {
+                Some((TokenType::Keyword, NO_MODIFIERS))
+            }
+            KeywordKind::NoKeyword => None,
+        },
+    }
+}
+
+/// Computes semantic tokens for the given SQL string
+pub fn compute_tokens(sql: &str) -> Vec<SemanticToken> {
+    let scan_result = match pgls_query::scan(sql) {
+        Ok(result) => result,
+        Err(_) => return Vec::new(),
+    };
+
+    let mut tokens = Vec::with_capacity(scan_result.tokens.len());
+
+    for scan_token in scan_result.tokens {
+        let token = Token::try_from(scan_token.token).unwrap_or(Token::Nul);
+        let keyword_kind =
+            KeywordKind::try_from(scan_token.keyword_kind).unwrap_or(KeywordKind::NoKeyword);
+
+        if let Some((token_type, token_modifiers)) = map_token(token, keyword_kind) {
+            let start = TextSize::from(scan_token.start as u32);
+            let end = TextSize::from(scan_token.end as u32);
+
+            tokens.push(SemanticToken::new(start, end, token_type, token_modifiers));
+        }
+    }
+
+    tokens
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_simple_select() {
+        let sql = "SELECT * FROM users";
+        let tokens = compute_tokens(sql);
+
+        assert!(!tokens.is_empty());
+
+        // First token should be SELECT (keyword)
+        assert_eq!(tokens[0].token_type, TokenType::Keyword as u32);
+
+        // Check that we have an identifier for 'users'
+        let users_token = tokens
+            .iter()
+            .find(|t| {
+                let start: usize = t.range.start().into();
+                let end: usize = t.range.end().into();
+                &sql[start..end] == "users"
+            })
+            .expect("Should find users token");
+        assert_eq!(users_token.token_type, TokenType::Property as u32);
+    }
+
+    #[test]
+    fn test_string_literal() {
+        let sql = "SELECT 'hello'";
+        let tokens = compute_tokens(sql);
+
+        let string_token = tokens
+            .iter()
+            .find(|t| t.token_type == TokenType::String as u32)
+            .expect("Should find string token");
+
+        assert_eq!(string_token.token_modifiers, TokenModifier::Readonly.bit());
+    }
+
+    #[test]
+    fn test_numeric_literal() {
+        let sql = "SELECT 42, 3.14";
+        let tokens = compute_tokens(sql);
+
+        let number_tokens: Vec<_> = tokens
+            .iter()
+            .filter(|t| t.token_type == TokenType::Number as u32)
+            .collect();
+
+        assert_eq!(number_tokens.len(), 2);
+    }
+
+    #[test]
+    fn test_comment() {
+        let sql = "SELECT 1 -- this is a comment";
+        let tokens = compute_tokens(sql);
+
+        let comment_token = tokens
+            .iter()
+            .find(|t| t.token_type == TokenType::Comment as u32)
+            .expect("Should find comment token");
+
+        assert!(comment_token.range.start() > TextSize::from(8u32));
+    }
+
+    #[test]
+    fn test_parameter() {
+        let sql = "SELECT $1, $2";
+        let tokens = compute_tokens(sql);
+
+        let param_tokens: Vec<_> = tokens
+            .iter()
+            .filter(|t| t.token_type == TokenType::Parameter as u32)
+            .collect();
+
+        assert_eq!(param_tokens.len(), 2);
+    }
+
+    #[test]
+    fn test_operators() {
+        let sql = "SELECT 1 + 2 * 3";
+        let tokens = compute_tokens(sql);
+
+        let operator_tokens: Vec<_> = tokens
+            .iter()
+            .filter(|t| t.token_type == TokenType::Operator as u32)
+            .collect();
+
+        assert_eq!(operator_tokens.len(), 2);
+    }
+
+    #[test]
+    fn test_type_keyword() {
+        let sql = "SELECT 1::integer";
+        let tokens = compute_tokens(sql);
+
+        // 'integer' should be a type token
+        let type_token = tokens
+            .iter()
+            .find(|t| {
+                let start: usize = t.range.start().into();
+                let end: usize = t.range.end().into();
+                &sql[start..end] == "integer"
+            })
+            .expect("Should find integer token");
+
+        assert_eq!(type_token.token_type, TokenType::Type as u32);
+    }
+
+    #[test]
+    fn test_multiline_comment() {
+        let sql = "SELECT /* this is\na multiline\ncomment */ 1";
+        let tokens = compute_tokens(sql);
+
+        // Should find a comment token spanning multiple lines
+        let comment_token = tokens
+            .iter()
+            .find(|t| t.token_type == TokenType::Comment as u32)
+            .expect("Should find comment token");
+
+        // The token should span from "/*" to "*/"
+        let start: usize = comment_token.range.start().into();
+        let end: usize = comment_token.range.end().into();
+        let comment_text = &sql[start..end];
+
+        assert!(comment_text.starts_with("/*"));
+        assert!(comment_text.ends_with("*/"));
+        assert!(comment_text.contains('\n'), "Comment should be multi-line");
+    }
+
+    #[test]
+    fn test_multiline_string() {
+        let sql = "SELECT 'hello\nworld'";
+        let tokens = compute_tokens(sql);
+
+        // Should find a string token spanning multiple lines
+        let string_token = tokens
+            .iter()
+            .find(|t| t.token_type == TokenType::String as u32)
+            .expect("Should find string token");
+
+        let start: usize = string_token.range.start().into();
+        let end: usize = string_token.range.end().into();
+        let string_text = &sql[start..end];
+
+        assert!(string_text.contains('\n'), "String should be multi-line");
+    }
+
+    #[test]
+    fn test_token_type_legend() {
+        let legend = TokenType::legend();
+        assert_eq!(legend[0], "keyword");
+        assert_eq!(legend[1], "type");
+        assert_eq!(legend[2], "function");
+        assert_eq!(legend[3], "parameter");
+        assert_eq!(legend[4], "string");
+        assert_eq!(legend[5], "number");
+        assert_eq!(legend[6], "operator");
+        assert_eq!(legend[7], "comment");
+        assert_eq!(legend[8], "property");
+    }
+
+    #[test]
+    fn test_token_modifier_legend() {
+        let legend = TokenModifier::legend();
+        assert_eq!(legend[0], "declaration");
+        assert_eq!(legend[1], "definition");
+        assert_eq!(legend[2], "readonly");
+        assert_eq!(legend[3], "defaultLibrary");
+    }
+
+    #[test]
+    fn test_token_modifier_bits() {
+        assert_eq!(TokenModifier::Declaration.bit(), 1 << 0);
+        assert_eq!(TokenModifier::Definition.bit(), 1 << 1);
+        assert_eq!(TokenModifier::Readonly.bit(), 1 << 2);
+        assert_eq!(TokenModifier::DefaultLibrary.bit(), 1 << 3);
+    }
+}