diff --git a/crates/squawk_ide/src/find_references.rs b/crates/squawk_ide/src/find_references.rs index d4772cb5..680a3c28 100644 --- a/crates/squawk_ide/src/find_references.rs +++ b/crates/squawk_ide/src/find_references.rs @@ -1,7 +1,8 @@ use crate::binder::Binder; use crate::builtins::{builtins_binder, parse_builtins}; +use crate::classify::classify_def_node; use crate::db::{File, bind, parse}; -use crate::goto_definition::{FileId, Location}; +use crate::goto_definition::{FileId, Location, LocationKind}; use crate::offsets::token_from_offset; use crate::resolve; use rowan::TextSize; @@ -23,7 +24,7 @@ pub fn find_references(db: &dyn Db, file: File, offset: TextSize) -> Vec Vec Vec Vec Option<(FileId, SmallVec<[SyntaxNodePtr; 1]>)> { +) -> Option<(FileId, SmallVec<[SyntaxNodePtr; 1]>, LocationKind)> { let token = token_from_offset(file, offset)?; let parent = token.parent()?; - if let Some(name) = ast::Name::cast(parent.clone()) { + if let Some(name) = ast::Name::cast(parent.clone()) + && let Some(kind) = classify_def_node(name.syntax()).map(LocationKind::from) + { return Some(( FileId::Current, smallvec![SyntaxNodePtr::new(name.syntax())], + kind, )); } @@ -105,8 +112,8 @@ fn find_target_defs( FileId::Current => (current_binder, file.syntax()), FileId::Builtins => (builtins_binder, builtins_tree.syntax()), }; - if let Some(ptrs) = resolve::resolve_name_ref_ptrs(binder, root, &name_ref) { - return Some((file_id, ptrs)); + if let Some((ptrs, kind)) = resolve::resolve_name_ref(binder, root, &name_ref) { + return Some((file_id, ptrs, kind)); } } } diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index 4d3e47f8..3d4fc059 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -1,5 +1,6 @@ use crate::binder; use crate::builtins::parse_builtins; +use crate::classify::{NameRefClass, classify_def_node}; use crate::db::{File, parse}; use crate::offsets::token_from_offset; use crate::resolve; @@ -31,7 +32,10 @@ pub fn goto_definition(db: &dyn Db, file: File, offset: TextSize) -> SmallVec<[L if let Some(case_expr) = ast::CaseExpr::cast(parent) && let Some(case_token) = case_expr.case_token() { - return smallvec![Location::range(case_token.text_range())]; + return smallvec![Location::current( + case_token.text_range(), + LocationKind::CaseExpr + )]; } } } @@ -40,14 +44,14 @@ pub fn goto_definition(db: &dyn Db, file: File, offset: TextSize) -> SmallVec<[L if ast::Commit::can_cast(parent.kind()) && let Some(begin_range) = find_preceding_begin(source_file, token.text_range().start()) { - return smallvec![Location::range(begin_range)]; + return smallvec![Location::current(begin_range, LocationKind::CommitBegin)]; } // goto def on ROLLBACK -> BEGIN/START TRANSACTION if ast::Rollback::can_cast(parent.kind()) && let Some(begin_range) = find_preceding_begin(source_file, token.text_range().start()) { - return smallvec![Location::range(begin_range)]; + return smallvec![Location::current(begin_range, LocationKind::CommitBegin)]; } // goto def on BEGIN/START TRANSACTION -> COMMIT or ROLLBACK @@ -55,11 +59,13 @@ pub fn goto_definition(db: &dyn Db, file: File, offset: TextSize) -> SmallVec<[L && let Some(end_range) = find_following_commit_or_rollback(source_file, token.text_range().end()) { - return smallvec![Location::range(end_range)]; + return smallvec![Location::current(end_range, LocationKind::CommitEnd)]; } - if let Some(name) = ast::Name::cast(parent.clone()) { - return smallvec![Location::range(name.syntax().text_range())]; + if let Some(name) = ast::Name::cast(parent.clone()) + && let Some(kind) = classify_def_node(name.syntax()).map(LocationKind::from) + { + return smallvec![Location::current(name.syntax().text_range(), kind)]; } if let Some(name_ref) = ast::NameRef::cast(parent.clone()) { @@ -71,13 +77,14 @@ pub fn goto_definition(db: &dyn Db, file: File, offset: TextSize) -> SmallVec<[L // TODO: we should salsa this let binder_output = binder::bind(file); let root = file.syntax(); - if let Some(ptrs) = resolve::resolve_name_ref_ptrs(&binder_output, root, &name_ref) { + if let Some((ptrs, kind)) = resolve::resolve_name_ref(&binder_output, root, &name_ref) { let ranges = ptrs .iter() .map(|ptr| ptr.to_node(file.syntax()).text_range()) .map(|range| Location { file: file_id, range, + kind, }) .collect(); return ranges; @@ -106,6 +113,7 @@ pub fn goto_definition(db: &dyn Db, file: File, offset: TextSize) -> SmallVec<[L return smallvec![Location { file: file_id, range: ptr.to_node(file.syntax()).text_range(), + kind: LocationKind::Type, }]; } } @@ -120,17 +128,111 @@ pub enum FileId { Builtins, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LocationKind { + Aggregate, + CaseExpr, + Channel, + Column, + CommitBegin, + CommitEnd, + Cursor, + Database, + EventTrigger, + Extension, + Function, + Index, + NamedArgParameter, + Policy, + PreparedStatement, + Procedure, + Role, + Schema, + Sequence, + Server, + Table, + Tablespace, + Trigger, + Type, + View, + Window, +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct Location { pub file: FileId, pub range: TextRange, + pub kind: LocationKind, } impl Location { - fn range(range: TextRange) -> Location { + fn current(range: TextRange, kind: LocationKind) -> Location { Location { file: FileId::Current, range, + kind, + } + } +} + +impl From for LocationKind { + fn from(class: NameRefClass) -> Self { + match class { + NameRefClass::Aggregate => LocationKind::Aggregate, + NameRefClass::Channel => LocationKind::Channel, + NameRefClass::Cursor => LocationKind::Cursor, + NameRefClass::Database => LocationKind::Database, + NameRefClass::EventTrigger => LocationKind::EventTrigger, + NameRefClass::Extension => LocationKind::Extension, + NameRefClass::Index => LocationKind::Index, + NameRefClass::NamedArgParameter => LocationKind::NamedArgParameter, + NameRefClass::Policy => LocationKind::Policy, + NameRefClass::PreparedStatement => LocationKind::PreparedStatement, + NameRefClass::Role => LocationKind::Role, + NameRefClass::Schema => LocationKind::Schema, + NameRefClass::Sequence => LocationKind::Sequence, + NameRefClass::Server => LocationKind::Server, + NameRefClass::Tablespace => LocationKind::Tablespace, + NameRefClass::Trigger => LocationKind::Trigger, + NameRefClass::Type => LocationKind::Type, + NameRefClass::View => LocationKind::View, + NameRefClass::Window => LocationKind::Window, + + NameRefClass::CallProcedure | NameRefClass::Procedure | NameRefClass::ProcedureCall => { + LocationKind::Procedure + } + + NameRefClass::Function + | NameRefClass::FunctionCall + | NameRefClass::FunctionName + | NameRefClass::Routine + | NameRefClass::SelectFunctionCall => LocationKind::Function, + + NameRefClass::AlterColumn + | NameRefClass::CompositeTypeField + | NameRefClass::ConstraintColumn + | NameRefClass::CreateIndexColumn + | NameRefClass::DeleteColumn + | NameRefClass::ForeignKeyColumn + | NameRefClass::InsertColumn + | NameRefClass::JoinUsingColumn + | NameRefClass::MergeColumn + | NameRefClass::PolicyColumn + | NameRefClass::QualifiedColumn + | NameRefClass::SelectColumn + | NameRefClass::SelectQualifiedColumn + | NameRefClass::UpdateColumn => LocationKind::Column, + + NameRefClass::DeleteQualifiedColumnTable + | NameRefClass::ForeignKeyTable + | NameRefClass::FromTable + | NameRefClass::InsertQualifiedColumnTable + | NameRefClass::LikeTable + | NameRefClass::MergeQualifiedColumnTable + | NameRefClass::PolicyQualifiedColumnTable + | NameRefClass::SelectQualifiedColumnTable + | NameRefClass::Table + | NameRefClass::UpdateQualifiedColumnTable => LocationKind::Table, } } } diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index 439e08b0..15f33169 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -6,8 +6,9 @@ use squawk_syntax::{ }; use crate::binder::Binder; -use crate::classify::{NameRefClass, classify_name_ref}; +use crate::classify::{NameRefClass, classify_def_node, classify_name_ref}; use crate::column_name::ColumnName; +use crate::goto_definition::LocationKind; use crate::infer::{Type, infer_type_from_expr, infer_type_from_ty}; pub(crate) use crate::symbols::Schema; use crate::symbols::{Name, SymbolKind}; @@ -27,6 +28,14 @@ pub(crate) fn resolve_name_ref_ptrs( root: &SyntaxNode, name_ref: &ast::NameRef, ) -> Option> { + resolve_name_ref(binder, root, name_ref).map(|(ptrs, _)| ptrs) +} + +pub(crate) fn resolve_name_ref( + binder: &Binder, + root: &SyntaxNode, + name_ref: &ast::NameRef, +) -> Option<(SmallVec<[SyntaxNodePtr; 1]>, LocationKind)> { let context = classify_name_ref(name_ref.syntax())?; match context { @@ -37,10 +46,11 @@ pub(crate) fn resolve_name_ref_ptrs( if schema.is_none() && let Some(cte_ptr) = resolve_cte_table(name_ref, &table_name) { - return Some(smallvec![cte_ptr]); + return Some((smallvec![cte_ptr], LocationKind::Table)); } - resolve_table_name_ptr(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) + resolve_table_name_ptr(binder, &table_name, &schema, position) + .map(|ptr| (smallvec![ptr], LocationKind::Table)) } NameRefClass::NamedArgParameter => { let (function_name, schema) = find_func_call_from_named_arg(name_ref)?; @@ -58,21 +68,23 @@ pub(crate) fn resolve_name_ref_ptrs( })?; let param_ptr = find_param_in_func_def(root, function_ptr, ¶m_name)?; - Some(smallvec![param_ptr]) + Some((smallvec![param_ptr], LocationKind::NamedArgParameter)) } NameRefClass::Cursor => { let cursor_name = Name::from_node(name_ref); binder .lookup(&cursor_name, SymbolKind::Cursor) - .map(|ptr| smallvec![ptr]) + .map(|ptr| (smallvec![ptr], LocationKind::Cursor)) } NameRefClass::PreparedStatement => { let statement_name = Name::from_node(name_ref); - resolve_prepared_statement_name_ptr(binder, &statement_name).map(|ptr| smallvec![ptr]) + resolve_prepared_statement_name_ptr(binder, &statement_name) + .map(|ptr| (smallvec![ptr], LocationKind::PreparedStatement)) } NameRefClass::Channel => { let channel_name = Name::from_node(name_ref); - resolve_channel_name_ptr(binder, &channel_name).map(|ptr| smallvec![ptr]) + resolve_channel_name_ptr(binder, &channel_name) + .map(|ptr| (smallvec![ptr], LocationKind::Channel)) } NameRefClass::FromTable => { let table_name = Name::from_node(name_ref); @@ -89,7 +101,7 @@ pub(crate) fn resolve_name_ref_ptrs( if schema.is_none() && let Some(cte_ptr) = resolve_cte_table(name_ref, &table_name) { - return Some(smallvec![cte_ptr]); + return Some((smallvec![cte_ptr], LocationKind::Table)); } let position = name_ref.syntax().text_range().start(); @@ -97,15 +109,17 @@ pub(crate) fn resolve_name_ref_ptrs( if let Some(table_name_ptr) = resolve_table_name_ptr(binder, &table_name, &schema, position) { - return Some(smallvec![table_name_ptr]); + return Some((smallvec![table_name_ptr], LocationKind::Table)); } - resolve_view_name_ptr(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) + resolve_view_name_ptr(binder, &table_name, &schema, position) + .map(|ptr| (smallvec![ptr], LocationKind::View)) } NameRefClass::Index => { let position = name_ref.syntax().text_range().start(); let index_name = Name::from_node(name_ref); - resolve_index_name_ptr(binder, &index_name, &None, position).map(|ptr| smallvec![ptr]) + resolve_index_name_ptr(binder, &index_name, &None, position) + .map(|ptr| (smallvec![ptr], LocationKind::Index)) } NameRefClass::Type => { let (type_name, schema) = if let Some(parent) = name_ref.syntax().parent() @@ -128,19 +142,23 @@ pub(crate) fn resolve_name_ref_ptrs( }; let type_name = resolve_float_precision(name_ref, type_name); let position = name_ref.syntax().text_range().start(); - resolve_type_name_ptr(binder, &type_name, &schema, position).map(|ptr| smallvec![ptr]) + resolve_type_name_ptr(binder, &type_name, &schema, position) + .map(|ptr| (smallvec![ptr], LocationKind::Type)) } NameRefClass::View => { let (view_name, schema) = extract_table_schema_from_name_ref(name_ref)?; let position = name_ref.syntax().text_range().start(); - resolve_view_name_ptr(binder, &view_name, &schema, position).map(|ptr| smallvec![ptr]) + resolve_view_name_ptr(binder, &view_name, &schema, position) + .map(|ptr| (smallvec![ptr], LocationKind::View)) + } + NameRefClass::Window => { + resolve_window_name_ptr(name_ref).map(|ptr| (smallvec![ptr], LocationKind::Window)) } - NameRefClass::Window => resolve_window_name_ptr(name_ref).map(|ptr| smallvec![ptr]), NameRefClass::Sequence => { let (sequence_name, schema) = extract_table_schema_from_name_ref(name_ref)?; let position = name_ref.syntax().text_range().start(); resolve_sequence_name_ptr(binder, &sequence_name, &schema, position) - .map(|ptr| smallvec![ptr]) + .map(|ptr| (smallvec![ptr], LocationKind::Sequence)) } NameRefClass::Trigger => { let drop_trigger = name_ref @@ -158,7 +176,7 @@ pub(crate) fn resolve_name_ref_ptrs( let table_name = extract_table_name(&on_table_path)?; let position = name_ref.syntax().text_range().start(); resolve_trigger_name_ptr(binder, &trigger_name, &schema, position, Some(table_name)) - .map(|ptr| smallvec![ptr]) + .map(|ptr| (smallvec![ptr], LocationKind::Trigger)) } NameRefClass::Policy => { let (policy_name, on_table) = name_ref.syntax().ancestors().find_map(|a| { @@ -173,44 +191,53 @@ pub(crate) fn resolve_name_ref_ptrs( let (table_name, schema) = extract_table_schema_from_path(&on_table_path)?; let position = name_ref.syntax().text_range().start(); resolve_policy_name_ptr(binder, &policy_name, &schema, position, table_name) - .map(|ptr| smallvec![ptr]) + .map(|ptr| (smallvec![ptr], LocationKind::Policy)) } NameRefClass::EventTrigger => { let event_trigger_name = Name::from_node(name_ref); - resolve_event_trigger_name_ptr(binder, &event_trigger_name).map(|ptr| smallvec![ptr]) + resolve_event_trigger_name_ptr(binder, &event_trigger_name) + .map(|ptr| (smallvec![ptr], LocationKind::EventTrigger)) } NameRefClass::Database => { let database_name = Name::from_node(name_ref); - resolve_database_name_ptr(binder, &database_name).map(|ptr| smallvec![ptr]) + resolve_database_name_ptr(binder, &database_name) + .map(|ptr| (smallvec![ptr], LocationKind::Database)) } NameRefClass::Server => { let server_name = Name::from_node(name_ref); - resolve_server_name_ptr(binder, &server_name).map(|ptr| smallvec![ptr]) + resolve_server_name_ptr(binder, &server_name) + .map(|ptr| (smallvec![ptr], LocationKind::Server)) } NameRefClass::Extension => { let extension_name = Name::from_node(name_ref); - resolve_extension_name_ptr(binder, &extension_name).map(|ptr| smallvec![ptr]) + resolve_extension_name_ptr(binder, &extension_name) + .map(|ptr| (smallvec![ptr], LocationKind::Extension)) } NameRefClass::Role => { let role_name = Name::from_node(name_ref); - resolve_role_name_ptr(binder, &role_name).map(|ptr| smallvec![ptr]) + resolve_role_name_ptr(binder, &role_name) + .map(|ptr| (smallvec![ptr], LocationKind::Role)) } NameRefClass::QualifiedColumn => { let path = name_ref.syntax().ancestors().find_map(ast::Path::cast)?; let column_name = Name::from_node(name_ref); let table_path = path.qualifier()?; - resolve_column_for_path(binder, root, &table_path, column_name) - .map(|ptr| smallvec![ptr]) + resolve_column_for_path(binder, root, &table_path, column_name).map(|ptr| { + let kind = resolved_location_kind(root, &ptr, LocationKind::Column); + (smallvec![ptr], kind) + }) } NameRefClass::Tablespace => { let tablespace_name = Name::from_node(name_ref); - resolve_tablespace_name_ptr(binder, &tablespace_name).map(|ptr| smallvec![ptr]) + resolve_tablespace_name_ptr(binder, &tablespace_name) + .map(|ptr| (smallvec![ptr], LocationKind::Tablespace)) } NameRefClass::ForeignKeyTable => { let path = name_ref.syntax().ancestors().find_map(ast::Path::cast)?; let (table_name, schema) = extract_table_schema_from_path(&path)?; let position = name_ref.syntax().text_range().start(); - resolve_table_name_ptr(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) + resolve_table_name_ptr(binder, &table_name, &schema, position) + .map(|ptr| (smallvec![ptr], LocationKind::Table)) } NameRefClass::ForeignKeyColumn => { // TODO: the ast is too flat here @@ -230,19 +257,29 @@ pub(crate) fn resolve_name_ref_ptrs( return None; }; let column_name = Name::from_node(name_ref); - resolve_column_for_path(binder, root, &path, column_name).map(|ptr| smallvec![ptr]) + resolve_column_for_path(binder, root, &path, column_name).map(|ptr| { + let kind = resolved_location_kind(root, &ptr, LocationKind::Column); + (smallvec![ptr], kind) + }) } NameRefClass::ConstraintColumn => { let column_name = Name::from_node(name_ref); for ancestor in name_ref.syntax().ancestors() { if let Some(create_table) = ast::CreateTableLike::cast(ancestor.clone()) { return find_column_in_create_table(binder, root, &create_table, &column_name) - .map(|ptr| smallvec![ptr]); + .map(|ptr| { + let kind = resolved_location_kind(root, &ptr, LocationKind::Column); + (smallvec![ptr], kind) + }); } if let Some(alter_table) = ast::AlterTable::cast(ancestor) { let table_path = alter_table.relation_name()?.path()?; - return resolve_column_for_path(binder, root, &table_path, column_name) - .map(|ptr| smallvec![ptr]); + return resolve_column_for_path(binder, root, &table_path, column_name).map( + |ptr| { + let kind = resolved_location_kind(root, &ptr, LocationKind::Column); + (smallvec![ptr], kind) + }, + ); } } None @@ -258,8 +295,10 @@ pub(crate) fn resolve_name_ref_ptrs( } })?; let column_name = Name::from_node(name_ref); - resolve_column_for_path(binder, root, &on_table_path, column_name) - .map(|ptr| smallvec![ptr]) + resolve_column_for_path(binder, root, &on_table_path, column_name).map(|ptr| { + let kind = resolved_location_kind(root, &ptr, LocationKind::Column); + (smallvec![ptr], kind) + }) } NameRefClass::PolicyQualifiedColumnTable => { let on_table_path = name_ref.syntax().ancestors().find_map(|n| { @@ -273,7 +312,8 @@ pub(crate) fn resolve_name_ref_ptrs( })?; let (table_name, schema) = extract_table_schema_from_path(&on_table_path)?; let position = name_ref.syntax().text_range().start(); - resolve_table_name_ptr(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) + resolve_table_name_ptr(binder, &table_name, &schema, position) + .map(|ptr| (smallvec![ptr], LocationKind::Table)) } NameRefClass::LikeTable => { let like_clause = name_ref @@ -283,7 +323,8 @@ pub(crate) fn resolve_name_ref_ptrs( let path = like_clause.path()?; let (table_name, schema) = extract_table_schema_from_path(&path)?; let position = name_ref.syntax().text_range().start(); - resolve_table_name_ptr(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) + resolve_table_name_ptr(binder, &table_name, &schema, position) + .map(|ptr| (smallvec![ptr], LocationKind::Table)) } NameRefClass::Function => { let function_sig = name_ref @@ -295,7 +336,7 @@ pub(crate) fn resolve_name_ref_ptrs( let params = extract_param_signature(&function_sig); let position = name_ref.syntax().text_range().start(); resolve_function(binder, &function_name, &schema, params.as_deref(), position) - .map(|ptr| smallvec![ptr]) + .map(|ptr| (smallvec![ptr], LocationKind::Function)) } NameRefClass::Aggregate => { let aggregate = name_ref @@ -313,7 +354,7 @@ pub(crate) fn resolve_name_ref_ptrs( params.as_deref(), position, ) - .map(|ptr| smallvec![ptr]) + .map(|ptr| (smallvec![ptr], LocationKind::Aggregate)) } NameRefClass::Procedure => { let function_sig = name_ref @@ -331,7 +372,7 @@ pub(crate) fn resolve_name_ref_ptrs( params.as_deref(), position, ) - .map(|ptr| smallvec![ptr]) + .map(|ptr| (smallvec![ptr], LocationKind::Procedure)) } NameRefClass::Routine => { let function_sig = name_ref @@ -346,17 +387,17 @@ pub(crate) fn resolve_name_ref_ptrs( if let Some(ptr) = resolve_function(binder, &routine_name, &schema, params.as_deref(), position) { - return Some(smallvec![ptr]); + return Some((smallvec![ptr], LocationKind::Function)); } if let Some(ptr) = resolve_aggregate(binder, &routine_name, &schema, params.as_deref(), position) { - return Some(smallvec![ptr]); + return Some((smallvec![ptr], LocationKind::Aggregate)); } resolve_procedure(binder, &routine_name, &schema, params.as_deref(), position) - .map(|ptr| smallvec![ptr]) + .map(|ptr| (smallvec![ptr], LocationKind::Procedure)) } NameRefClass::CallProcedure => { let call = name_ref.syntax().ancestors().find_map(ast::Call::cast)?; @@ -364,11 +405,11 @@ pub(crate) fn resolve_name_ref_ptrs( let (procedure_name, schema) = extract_table_schema_from_path(&path)?; let position = name_ref.syntax().text_range().start(); resolve_procedure(binder, &procedure_name, &schema, None, position) - .map(|ptr| smallvec![ptr]) + .map(|ptr| (smallvec![ptr], LocationKind::Procedure)) } NameRefClass::Schema => { let schema_name = Name::from_node(name_ref); - resolve_schema(binder, &schema_name).map(|ptr| smallvec![ptr]) + resolve_schema(binder, &schema_name).map(|ptr| (smallvec![ptr], LocationKind::Schema)) } NameRefClass::FunctionCall => { let schema = if let Some(parent_node) = name_ref.syntax().parent() @@ -383,7 +424,7 @@ pub(crate) fn resolve_name_ref_ptrs( let function_name = Name::from_node(name_ref); let position = name_ref.syntax().text_range().start(); resolve_function(binder, &function_name, &schema, None, position) - .map(|ptr| smallvec![ptr]) + .map(|ptr| (smallvec![ptr], LocationKind::Function)) } NameRefClass::ProcedureCall => { let schema = if let Some(parent_node) = name_ref.syntax().parent() @@ -399,7 +440,7 @@ pub(crate) fn resolve_name_ref_ptrs( let position = name_ref.syntax().text_range().start(); resolve_procedure(binder, &procedure_name, &schema, None, position) - .map(|ptr| smallvec![ptr]) + .map(|ptr| (smallvec![ptr], LocationKind::Procedure)) } NameRefClass::FunctionName => { let path_type = name_ref @@ -410,7 +451,7 @@ pub(crate) fn resolve_name_ref_ptrs( let (function_name, schema) = extract_table_schema_from_path(&path)?; let position = name_ref.syntax().text_range().start(); resolve_function(binder, &function_name, &schema, None, position) - .map(|ptr| smallvec![ptr]) + .map(|ptr| (smallvec![ptr], LocationKind::Function)) } NameRefClass::SelectFunctionCall => { let schema = if let Some(parent_node) = name_ref.syntax().parent() @@ -427,12 +468,12 @@ pub(crate) fn resolve_name_ref_ptrs( // functions take precedence if let Some(ptr) = resolve_function(binder, &function_name, &schema, None, position) { - return Some(smallvec![ptr]); + return Some((smallvec![ptr], LocationKind::Function)); } // aggregates take precedence over function-call-style column access if let Some(ptr) = resolve_aggregate(binder, &function_name, &schema, None, position) { - return Some(smallvec![ptr]); + return Some((smallvec![ptr], LocationKind::Aggregate)); } // if no function found, check if this is function-call-style column access @@ -443,52 +484,68 @@ pub(crate) fn resolve_name_ref_ptrs( if schema.is_none() && let Some(ptr) = resolve_fn_call_column(binder, root, name_ref) { - return Some(smallvec![ptr]); + return Some((smallvec![ptr], LocationKind::Column)); } None } - NameRefClass::CreateIndexColumn => { - resolve_create_index_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) - } + NameRefClass::CreateIndexColumn => resolve_create_index_column_ptr(binder, root, name_ref) + .map(|ptr| { + let kind = resolved_location_kind(root, &ptr, LocationKind::Column); + (smallvec![ptr], kind) + }), NameRefClass::SelectColumn => { - resolve_select_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) + resolve_select_column_ptr(binder, root, name_ref).map(|ptr| { + let kind = resolved_location_kind(root, &ptr, LocationKind::Column); + (smallvec![ptr], kind) + }) } NameRefClass::SelectQualifiedColumnTable => { resolve_select_qualified_column_table_name_ptr(binder, name_ref) - .map(|ptr| smallvec![ptr]) + .map(|ptr| (smallvec![ptr], LocationKind::Table)) } NameRefClass::SelectQualifiedColumn => { - resolve_select_qualified_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) + resolve_select_qualified_column_ptr(binder, root, name_ref).map(|ptr| { + let kind = resolved_location_kind(root, &ptr, LocationKind::Column); + (smallvec![ptr], kind) + }) } NameRefClass::CompositeTypeField => { - resolve_composite_type_field_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) + resolve_composite_type_field_ptr(binder, root, name_ref) + .map(|ptr| (smallvec![ptr], LocationKind::Column)) } NameRefClass::InsertColumn => { - resolve_insert_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) - } - NameRefClass::InsertQualifiedColumnTable => { - resolve_insert_table_name_ptr(binder, name_ref).map(|ptr| smallvec![ptr]) + resolve_insert_column_ptr(binder, root, name_ref).map(|ptr| { + let kind = resolved_location_kind(root, &ptr, LocationKind::Column); + (smallvec![ptr], kind) + }) } + NameRefClass::InsertQualifiedColumnTable => resolve_insert_table_name_ptr(binder, name_ref) + .map(|ptr| (smallvec![ptr], LocationKind::Table)), NameRefClass::DeleteColumn => { - resolve_delete_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) - } - NameRefClass::DeleteQualifiedColumnTable => { - resolve_delete_table_name_ptr(binder, name_ref).map(|ptr| smallvec![ptr]) + resolve_delete_column_ptr(binder, root, name_ref).map(|ptr| { + let kind = resolved_location_kind(root, &ptr, LocationKind::Column); + (smallvec![ptr], kind) + }) } + NameRefClass::DeleteQualifiedColumnTable => resolve_delete_table_name_ptr(binder, name_ref) + .map(|ptr| (smallvec![ptr], LocationKind::Table)), NameRefClass::UpdateColumn => { - resolve_update_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) - } - NameRefClass::UpdateQualifiedColumnTable => { - resolve_update_table_name_ptr(binder, name_ref).map(|ptr| smallvec![ptr]) - } - NameRefClass::MergeColumn => { - resolve_merge_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) - } - NameRefClass::MergeQualifiedColumnTable => { - resolve_merge_table_name_ptr(binder, name_ref).map(|ptr| smallvec![ptr]) + resolve_update_column_ptr(binder, root, name_ref).map(|ptr| { + let kind = resolved_location_kind(root, &ptr, LocationKind::Column); + (smallvec![ptr], kind) + }) } - NameRefClass::JoinUsingColumn => resolve_join_using_columns(binder, root, name_ref), + NameRefClass::UpdateQualifiedColumnTable => resolve_update_table_name_ptr(binder, name_ref) + .map(|ptr| (smallvec![ptr], LocationKind::Table)), + NameRefClass::MergeColumn => resolve_merge_column_ptr(binder, root, name_ref).map(|ptr| { + let kind = resolved_location_kind(root, &ptr, LocationKind::Column); + (smallvec![ptr], kind) + }), + NameRefClass::MergeQualifiedColumnTable => resolve_merge_table_name_ptr(binder, name_ref) + .map(|ptr| (smallvec![ptr], LocationKind::Table)), + NameRefClass::JoinUsingColumn => resolve_join_using_columns(binder, root, name_ref) + .map(|ptrs| (ptrs, LocationKind::Column)), NameRefClass::AlterColumn => { let column_name = Name::from_node(name_ref); let alter_table = name_ref @@ -496,13 +553,25 @@ pub(crate) fn resolve_name_ref_ptrs( .ancestors() .find_map(ast::AlterTable::cast)?; let table_path = alter_table.relation_name()?.path()?; - resolve_column_for_path(binder, root, &table_path, column_name) - .map(|ptr| smallvec![ptr]) + resolve_column_for_path(binder, root, &table_path, column_name).map(|ptr| { + let kind = resolved_location_kind(root, &ptr, LocationKind::Column); + (smallvec![ptr], kind) + }) } } .or_else(|| resolve_special_keyword_as_function(binder, name_ref)) } +fn resolved_location_kind( + root: &SyntaxNode, + ptr: &SyntaxNodePtr, + fallback: LocationKind, +) -> LocationKind { + classify_def_node(&ptr.to_node(root)) + .map(LocationKind::from) + .unwrap_or(fallback) +} + fn resolve_table_name_ptr( binder: &Binder, table_name: &Name, @@ -729,7 +798,7 @@ fn resolve_for_kind_with_params( fn resolve_special_keyword_as_function( binder: &Binder, name_ref: &ast::NameRef, -) -> Option> { +) -> Option<(SmallVec<[SyntaxNodePtr; 1]>, LocationKind)> { let function_name = name_ref .syntax() .first_child_or_token() @@ -742,7 +811,8 @@ fn resolve_special_keyword_as_function( })?; let function_name = Name::from_string(function_name); let position = name_ref.syntax().text_range().start(); - resolve_function(binder, &function_name, &None, None, position).map(|ptr| smallvec![ptr]) + resolve_function(binder, &function_name, &None, None, position) + .map(|ptr| (smallvec![ptr], LocationKind::Function)) } fn resolve_function( diff --git a/crates/squawk_ide/src/semantic_tokens.rs b/crates/squawk_ide/src/semantic_tokens.rs index bfc4511d..4e092dcf 100644 --- a/crates/squawk_ide/src/semantic_tokens.rs +++ b/crates/squawk_ide/src/semantic_tokens.rs @@ -6,6 +6,7 @@ use squawk_syntax::{ }; use crate::db::{File, parse}; +use crate::goto_definition::{LocationKind, goto_definition}; fn highlight_param_mode(out: &mut SemanticTokenBuilder, mode: ast::ParamMode) { match mode { @@ -71,15 +72,7 @@ fn highlight_type(out: &mut SemanticTokenBuilder, ty: ast::Type) { out.push_type(token.into()); } } - ast::Type::PathType(path_type) => { - if let Some(name_ref) = path_type - .path() - .and_then(|path| path.segment()) - .and_then(|ps| ps.name_ref()) - { - out.push_type(name_ref.syntax().clone().into()); - } - } + ast::Type::PathType(_) => (), ast::Type::PercentType(_) => (), ast::Type::TimeType(time_type) => { if let Some(token) = time_type @@ -124,6 +117,59 @@ pub enum SemanticTokenType { Type, Parameter, PositionalParam, + Column, + Table, + Schema, +} + +impl TryFrom for SemanticTokenType { + type Error = LocationKind; + + fn try_from(kind: LocationKind) -> Result { + match kind { + LocationKind::Aggregate | LocationKind::Function | LocationKind::Procedure => { + Ok(SemanticTokenType::Function) + } + LocationKind::Column => Ok(SemanticTokenType::Column), + LocationKind::NamedArgParameter => Ok(SemanticTokenType::Parameter), + LocationKind::Schema => Ok(SemanticTokenType::Schema), + LocationKind::Sequence | LocationKind::Table | LocationKind::View => { + Ok(SemanticTokenType::Table) + } + LocationKind::Type => Ok(SemanticTokenType::Type), + LocationKind::CaseExpr + | LocationKind::Channel + | LocationKind::CommitBegin + | LocationKind::CommitEnd + | LocationKind::Cursor + | LocationKind::Database + | LocationKind::EventTrigger + | LocationKind::Extension + | LocationKind::Index + | LocationKind::Policy + | LocationKind::PreparedStatement + | LocationKind::Role + | LocationKind::Server + | LocationKind::Tablespace + | LocationKind::Trigger + | LocationKind::Window => Err(kind), + } + } +} + +fn highlight_name_ref( + db: &dyn Db, + file: File, + out: &mut SemanticTokenBuilder, + name_ref: &ast::NameRef, +) { + let offset = name_ref.syntax().text_range().start(); + let Some(location) = goto_definition(db, file, offset).into_iter().next() else { + return; + }; + if let Ok(token_type) = SemanticTokenType::try_from(location.kind) { + out.push_token(name_ref.syntax().clone().into(), token_type); + } } #[derive(Default)] @@ -202,8 +248,12 @@ pub fn semantic_tokens( && let Some(as_name) = target.as_name() && let Some(name) = as_name.name() { - out.push_token(name.syntax().clone().into(), SemanticTokenType::Name); - }; + out.push_token(name.syntax().clone().into(), SemanticTokenType::Column); + } + + if let Some(name_ref) = ast::NameRef::cast(node.clone()) { + highlight_name_ref(db, file, &mut out, &name_ref); + } if let Some(alias) = ast::Alias::cast(node.clone()) && let Some(column_list) = alias.column_list() @@ -221,6 +271,12 @@ pub fn semantic_tokens( highlight_type(&mut out, ty); } + if let Some(with_table) = ast::WithTable::cast(node.clone()) + && let Some(name) = with_table.name() + { + out.push_token(name.syntax().clone().into(), SemanticTokenType::Table); + } + if let Some(create_function) = ast::CreateFunction::cast(node) { if let Some(param_list) = create_function.param_list() { for param in param_list.params() { @@ -361,8 +417,8 @@ as '' language plpgsql; assert_snapshot!(semantic_tokens(" select 1 and, 2 select; "), @r#" - "and" @ 10..13: Name - "select" @ 17..23: Name + "and" @ 10..13: Column + "select" @ 17..23: Column "#) } @@ -381,7 +437,7 @@ select $1, $2; assert_snapshot!(semantic_tokens( " select * -from f as t(a int, b jsonb, c text, x int, ca char(5)[], ia int[][], r jbpop); +from f as t(a int, b jsonb, c text, x int, ca char(5)[], ia int[][], r text); ", ), @r#" "int" @ 24..27: Type @@ -390,7 +446,7 @@ from f as t(a int, b jsonb, c text, x int, ca char(5)[], ia int[][], r jbpop); "int" @ 48..51: Type "char" @ 56..60: Type "int" @ 70..73: Type - "jbpop" @ 81..86: Type + "text" @ 81..85: Type "#); } @@ -420,4 +476,123 @@ select $2::jsonb; "jsonb" @ 12..17: Type "#); } + + #[test] + fn select_target_column() { + assert_snapshot!(semantic_tokens( + " +create table t(a int, b text); +select a, b from t; +", + ), @r#" + "int" @ 18..21: Type + "text" @ 25..29: Type + "a" @ 39..40: Column + "b" @ 42..43: Column + "t" @ 49..50: Table + "#); + } + + #[test] + fn select_target_qualified_column() { + assert_snapshot!(semantic_tokens( + " +create table t(a int); +select t.a from t; +", + ), @r#" + "int" @ 18..21: Type + "t" @ 31..32: Table + "a" @ 33..34: Column + "t" @ 40..41: Table + "#); + } + + #[test] + fn select_target_function_call() { + assert_snapshot!(semantic_tokens( + " +create function f() returns int as 'select 1' language sql; +select f(); +", + ), @r#" + "int" @ 29..32: Type + "f" @ 68..69: Function + "#); + } + + #[test] + fn select_function_arg_and_qualified_column() { + assert_snapshot!(semantic_tokens( + " +create table t(a int); +create function b(t) returns int as 'select 1' language sql; +select b(t), t.b from t; +", + ), @r#" + "int" @ 18..21: Type + "t" @ 42..43: Type + "int" @ 53..56: Type + "b" @ 92..93: Function + "t" @ 94..95: Table + "t" @ 98..99: Table + "b" @ 100..101: Function + "t" @ 107..108: Table + "#); + } + + #[test] + fn policy_field_style_function_call() { + assert_snapshot!(semantic_tokens( + " +create table t(c int); +create function x(t) returns int as 'select 1' language sql; +create policy p on t + with check (t.x > 0 and t.c > 0); +", + ), @r#" + "int" @ 18..21: Type + "t" @ 42..43: Type + "int" @ 53..56: Type + "t" @ 104..105: Table + "t" @ 120..121: Table + "x" @ 122..123: Function + "t" @ 132..133: Table + "c" @ 134..135: Column + "#); + } + + #[test] + fn with_cte_name() { + assert_snapshot!(semantic_tokens( + " +with t as ( + select 1 +) +select * from t; +", + ), @r#" + "t" @ 6..7: Table + "t" @ 40..41: Table + "#); + } + + #[test] + fn select_target_schema_qualified() { + assert_snapshot!(semantic_tokens( + " +create schema s; +create table s.t(a int); +select s.t.a from s.t; +", + ), @r#" + "s" @ 31..32: Schema + "int" @ 37..40: Type + "s" @ 50..51: Schema + "t" @ 52..53: Table + "a" @ 54..55: Column + "s" @ 61..62: Schema + "t" @ 63..64: Table + "#); + } } diff --git a/crates/squawk_server/src/lsp_utils.rs b/crates/squawk_server/src/lsp_utils.rs index e681cc35..e03a6ba2 100644 --- a/crates/squawk_server/src/lsp_utils.rs +++ b/crates/squawk_server/src/lsp_utils.rs @@ -319,6 +319,9 @@ fn to_token_type(ty: SemanticTokenType) -> lsp_types::SemanticTokenType { SemanticTokenType::PositionalParam | SemanticTokenType::Parameter => { lsp_types::SemanticTokenType::PARAMETER } + SemanticTokenType::Column => lsp_types::SemanticTokenType::VARIABLE, + SemanticTokenType::Table => lsp_types::SemanticTokenType::STRUCT, + SemanticTokenType::Schema => lsp_types::SemanticTokenType::NAMESPACE, } } diff --git a/crates/squawk_wasm/src/lib.rs b/crates/squawk_wasm/src/lib.rs index 05b1a614..e200a49e 100644 --- a/crates/squawk_wasm/src/lib.rs +++ b/crates/squawk_wasm/src/lib.rs @@ -32,13 +32,16 @@ const SEMANTIC_TOKEN_MODIFIERS: &[&str] = &["declaration", "definition", "readon fn semantic_token_type_name(ty: SemanticTokenType) -> &'static str { match ty { SemanticTokenType::Bool | SemanticTokenType::Keyword => "keyword", + SemanticTokenType::Column => "variable", SemanticTokenType::Comment => "comment", SemanticTokenType::Function => "function", SemanticTokenType::Name | SemanticTokenType::NameRef => "variable", SemanticTokenType::Number => "number", SemanticTokenType::Operator | SemanticTokenType::Punctuation => "operator", SemanticTokenType::Parameter | SemanticTokenType::PositionalParam => "parameter", + SemanticTokenType::Schema => "namespace", SemanticTokenType::String => "string", + SemanticTokenType::Table => "struct", SemanticTokenType::Type => "type", } } diff --git a/squawk-vscode/syntaxes/pgsql.tmLanguage.json b/squawk-vscode/syntaxes/pgsql.tmLanguage.json index 09453257..fb58fa11 100644 --- a/squawk-vscode/syntaxes/pgsql.tmLanguage.json +++ b/squawk-vscode/syntaxes/pgsql.tmLanguage.json @@ -90,9 +90,21 @@ }, "4": { "name": "entity.name.function.pgsql" + }, + "5": { + "name": "keyword.operator.pgsql" + }, + "6": { + "name": "entity.name.function.pgsql" + }, + "7": { + "name": "keyword.operator.pgsql" + }, + "8": { + "name": "entity.name.function.pgsql" } }, - "match": "(?i)\\b(create)\\s+(or\\s+replace\\s+)?(function|view)\\s+((?:(?:[\\w]+|\".+\")\\.){0,2}(?:[\\w]+|\".+\"))" + "match": "(?i)\\b(create)\\s+(or\\s+replace\\s+)?(function|view)\\s+((?:[\\w]+|\"[^\"]+\"))(?:(\\.)(?:([\\w]+|\"[^\"]+\"))(?:(\\.)(?:([\\w]+|\"[^\"]+\")))?)?" }, { "captures": { @@ -107,9 +119,21 @@ }, "4": { "name": "entity.name.function.pgsql" + }, + "5": { + "name": "keyword.operator.pgsql" + }, + "6": { + "name": "entity.name.function.pgsql" + }, + "7": { + "name": "keyword.operator.pgsql" + }, + "8": { + "name": "entity.name.function.pgsql" } }, - "match": "(?i)\\b(create)\\s+(aggregate|collation|conversion|database|domain|event\\s+trigger|group|(?:unique\\s+)?index|language|operator\\s+class|operator|rule|schema|sequence|(?:(?:global|local)\\s+)?(?:(?:temp|temporary|unlogged)\\s+)?table|tablespace|trigger|type|user|(?:materialized\\s+)?view)\\s+(if\\s+not\\s+exists\\s+)?((?:(?:[\\w]+|\".+\")\\.){0,2}(?:[\\w]+|\".+\"))" + "match": "(?i)\\b(create)\\s+(aggregate|collation|conversion|database|domain|event\\s+trigger|group|(?:unique\\s+)?index|language|operator\\s+class|operator|rule|schema|sequence|(?:(?:global|local)\\s+)?(?:(?:temp|temporary|unlogged)\\s+)?table|tablespace|trigger|type|user|(?:materialized\\s+)?view)\\s+(if\\s+not\\s+exists\\s+)?((?:[\\w]+|\"[^\"]+\"))(?:(\\.)(?:([\\w]+|\"[^\"]+\"))(?:(\\.)(?:([\\w]+|\"[^\"]+\")))?)?" } ] }, @@ -174,7 +198,7 @@ "name": "keyword.operator.cast.pgsql" }, { - "match": "[+\\-*/<>=~!@#%\\^&|`?]", + "match": "[+\\-*/<>=~!@#%\\^&|`?\\.]", "name": "keyword.operator.pgsql" } ]