@@ -13,6 +13,7 @@ pub enum ClauseType {
1313 Select ,
1414 Where ,
1515 From ,
16+ Join ,
1617 Update ,
1718 Delete ,
1819}
@@ -33,6 +34,7 @@ impl TryFrom<&str> for ClauseType {
3334 "from" => Ok ( Self :: From ) ,
3435 "update" => Ok ( Self :: Update ) ,
3536 "delete" => Ok ( Self :: Delete ) ,
37+ "join" => Ok ( Self :: Join ) ,
3638 _ => {
3739 let message = format ! ( "Unimplemented ClauseType: {}" , value) ;
3840
@@ -106,14 +108,35 @@ pub(crate) struct CompletionContext<'a> {
106108 pub schema_cache : & ' a SchemaCache ,
107109 pub position : usize ,
108110
109- pub schema_name : Option < String > ,
111+ /// If the cursor is on a node that uses dot notation
112+ /// to specify an alias or schema, this will hold the schema's or
113+ /// alias's name.
114+ ///
115+ /// Here, `auth` is a schema name:
116+ /// ```sql
117+ /// select * from auth.users;
118+ /// ```
119+ ///
120+ /// Here, `u` is an alias name:
121+ /// ```sql
122+ /// select
123+ /// *
124+ /// from
125+ /// auth.users u
126+ /// left join identities i
127+ /// on u.id = i.user_id;
128+ /// ```
129+ pub schema_or_alias_name : Option < String > ,
110130 pub wrapping_clause_type : Option < ClauseType > ,
111131
112132 pub wrapping_node_kind : Option < WrappingNode > ,
113133
114134 pub is_invocation : bool ,
115135 pub wrapping_statement_range : Option < tree_sitter:: Range > ,
116136
137+ /// Some incomplete statements can't be correctly parsed by TreeSitter.
138+ pub is_in_error_node : bool ,
139+
117140 pub mentioned_relations : HashMap < Option < String > , HashSet < String > > ,
118141
119142 pub mentioned_table_aliases : HashMap < String , String > ,
@@ -127,13 +150,14 @@ impl<'a> CompletionContext<'a> {
127150 schema_cache : params. schema ,
128151 position : usize:: from ( params. position ) ,
129152 node_under_cursor : None ,
130- schema_name : None ,
153+ schema_or_alias_name : None ,
131154 wrapping_clause_type : None ,
132155 wrapping_node_kind : None ,
133156 wrapping_statement_range : None ,
134157 is_invocation : false ,
135158 mentioned_relations : HashMap :: new ( ) ,
136159 mentioned_table_aliases : HashMap :: new ( ) ,
160+ is_in_error_node : false ,
137161 } ;
138162
139163 ctx. gather_tree_context ( ) ;
@@ -246,34 +270,77 @@ impl<'a> CompletionContext<'a> {
246270 self . wrapping_statement_range = Some ( parent_node. range ( ) ) ;
247271 }
248272 "invocation" => self . is_invocation = true ,
249-
250273 _ => { }
251274 }
252275
276+ // try to gather context from the siblings if we're within an error node.
277+ if self . is_in_error_node {
278+ let mut next_sibling = current_node. next_named_sibling ( ) ;
279+ while let Some ( n) = next_sibling {
280+ if n. kind ( ) . starts_with ( "keyword_" ) {
281+ if let Some ( txt) = self . get_ts_node_content ( n) . and_then ( |txt| match txt {
282+ NodeText :: Original ( txt) => Some ( txt) ,
283+ NodeText :: Replaced => None ,
284+ } ) {
285+ match txt {
286+ "where" | "update" | "select" | "delete" | "from" | "join" => {
287+ self . wrapping_clause_type = txt. try_into ( ) . ok ( ) ;
288+ break ;
289+ }
290+ _ => { }
291+ }
292+ } ;
293+ }
294+ next_sibling = n. next_named_sibling ( ) ;
295+ }
296+ let mut prev_sibling = current_node. prev_named_sibling ( ) ;
297+ while let Some ( n) = prev_sibling {
298+ if n. kind ( ) . starts_with ( "keyword_" ) {
299+ if let Some ( txt) = self . get_ts_node_content ( n) . and_then ( |txt| match txt {
300+ NodeText :: Original ( txt) => Some ( txt) ,
301+ NodeText :: Replaced => None ,
302+ } ) {
303+ match txt {
304+ "where" | "update" | "select" | "delete" | "from" | "join" => {
305+ self . wrapping_clause_type = txt. try_into ( ) . ok ( ) ;
306+ break ;
307+ }
308+ _ => { }
309+ }
310+ } ;
311+ }
312+ prev_sibling = n. prev_named_sibling ( ) ;
313+ }
314+ }
315+
253316 match current_node_kind {
254- "object_reference" => {
317+ "object_reference" | "field" => {
255318 let content = self . get_ts_node_content ( current_node) ;
256319 if let Some ( node_txt) = content {
257320 match node_txt {
258321 NodeText :: Original ( txt) => {
259322 let parts: Vec < & str > = txt. split ( '.' ) . collect ( ) ;
260323 if parts. len ( ) == 2 {
261- self . schema_name = Some ( parts[ 0 ] . to_string ( ) ) ;
324+ self . schema_or_alias_name = Some ( parts[ 0 ] . to_string ( ) ) ;
262325 }
263326 }
264327 NodeText :: Replaced => { }
265328 }
266329 }
267330 }
268331
269- "where" | "update" | "select" | "delete" | "from" => {
332+ "where" | "update" | "select" | "delete" | "from" | "join" => {
270333 self . wrapping_clause_type = current_node_kind. try_into ( ) . ok ( ) ;
271334 }
272335
273336 "relation" | "binary_expression" | "assignment" => {
274337 self . wrapping_node_kind = current_node_kind. try_into ( ) . ok ( ) ;
275338 }
276339
340+ "ERROR" => {
341+ self . is_in_error_node = true ;
342+ }
343+
277344 _ => { }
278345 }
279346
@@ -380,7 +447,10 @@ mod tests {
380447
381448 let ctx = CompletionContext :: new ( & params) ;
382449
383- assert_eq ! ( ctx. schema_name, expected_schema. map( |f| f. to_string( ) ) ) ;
450+ assert_eq ! (
451+ ctx. schema_or_alias_name,
452+ expected_schema. map( |f| f. to_string( ) )
453+ ) ;
384454 }
385455 }
386456
0 commit comments