@@ -32,43 +32,61 @@ class TableAlterColumn(PydanticModel):
3232 is_struct : bool
3333 is_array_of_struct : bool
3434 is_array_of_primitive : bool
35+ quoted : bool = False
3536
3637 @classmethod
37- def primitive (self , name : str ) -> TableAlterColumn :
38+ def primitive (self , name : str , quoted : bool = False ) -> TableAlterColumn :
3839 return self (
39- name = name , is_struct = False , is_array_of_struct = False , is_array_of_primitive = False
40+ name = name ,
41+ is_struct = False ,
42+ is_array_of_struct = False ,
43+ is_array_of_primitive = False ,
44+ quoted = quoted ,
4045 )
4146
4247 @classmethod
43- def struct (self , name : str ) -> TableAlterColumn :
48+ def struct (self , name : str , quoted : bool = False ) -> TableAlterColumn :
4449 return self (
45- name = name , is_struct = True , is_array_of_struct = False , is_array_of_primitive = False
50+ name = name ,
51+ is_struct = True ,
52+ is_array_of_struct = False ,
53+ is_array_of_primitive = False ,
54+ quoted = quoted ,
4655 )
4756
4857 @classmethod
49- def array_of_struct (self , name : str ) -> TableAlterColumn :
58+ def array_of_struct (self , name : str , quoted : bool = False ) -> TableAlterColumn :
5059 return self (
51- name = name , is_struct = False , is_array_of_struct = True , is_array_of_primitive = False
60+ name = name ,
61+ is_struct = False ,
62+ is_array_of_struct = True ,
63+ is_array_of_primitive = False ,
64+ quoted = quoted ,
5265 )
5366
5467 @classmethod
55- def array_of_primitive (self , name : str ) -> TableAlterColumn :
68+ def array_of_primitive (self , name : str , quoted : bool = False ) -> TableAlterColumn :
5669 return self (
57- name = name , is_struct = False , is_array_of_struct = False , is_array_of_primitive = True
70+ name = name ,
71+ is_struct = False ,
72+ is_array_of_struct = False ,
73+ is_array_of_primitive = True ,
74+ quoted = quoted ,
5875 )
5976
6077 @classmethod
6178 def from_struct_kwarg (self , struct : exp .StructKwarg ) -> TableAlterColumn :
6279 name = struct .alias_or_name
80+ quoted = struct .this .quoted
6381 if struct .expression .is_type (exp .DataType .Type .STRUCT ):
64- return self .struct (name )
82+ return self .struct (name , quoted = quoted )
6583 elif struct .expression .is_type (exp .DataType .Type .ARRAY ):
6684 if struct .expression .expressions [0 ].is_type (exp .DataType .Type .STRUCT ):
67- return self .array_of_struct (name )
85+ return self .array_of_struct (name , quoted = quoted )
6886 else :
69- return self .array_of_primitive (name )
87+ return self .array_of_primitive (name , quoted = quoted )
7088 else :
71- return self .primitive (name )
89+ return self .primitive (name , quoted = quoted )
7290
7391 @property
7492 def is_array (self ) -> bool :
@@ -82,23 +100,29 @@ def is_primitive(self) -> bool:
82100 def is_nested (self ) -> bool :
83101 return not self .is_primitive
84102
103+ @property
104+ def identifier (self ) -> exp .Identifier :
105+ return exp .to_identifier (self .name , quoted = self .quoted )
106+
85107
86108class TableAlterColumnPosition (PydanticModel ):
87109 is_first : bool
88110 is_last : bool
89- after : t .Optional [str ] = None
111+ after : t .Optional [exp . Identifier ] = None
90112
91113 @classmethod
92114 def first (self ) -> TableAlterColumnPosition :
93115 return self (is_first = True , is_last = False , after = None )
94116
95117 @classmethod
96- def last (self , after : t .Optional [str ] = None ) -> TableAlterColumnPosition :
97- return self (is_first = False , is_last = True , after = after )
118+ def last (
119+ self , after : t .Optional [t .Union [str , exp .Identifier ]] = None
120+ ) -> TableAlterColumnPosition :
121+ return self (is_first = False , is_last = True , after = exp .to_identifier (after ) if after else None )
98122
99123 @classmethod
100- def middle (self , after : str ) -> TableAlterColumnPosition :
101- return self (is_first = False , is_last = False , after = after )
124+ def middle (self , after : t . Union [ str , exp . Identifier ] ) -> TableAlterColumnPosition :
125+ return self (is_first = False , is_last = False , after = exp . to_identifier ( after ) )
102126
103127 @classmethod
104128 def create (
@@ -117,7 +141,7 @@ def create(
117141
118142 @property
119143 def column_position_node (self ) -> t .Optional [exp .ColumnPosition ]:
120- column = exp . column ( self .after ) if self . after and not self .is_last else None
144+ column = self .after if not self .is_last else None
121145 position = None
122146 if self .is_first :
123147 position = "FIRST"
@@ -195,45 +219,49 @@ def is_drop(self) -> bool:
195219 def is_alter_type (self ) -> bool :
196220 return self .op .is_alter_type
197221
198- def full_column_path (self , array_suffix : str ) -> str :
222+ def column_identifiers (self , array_element_selector : str ) -> t . List [ exp . Identifier ] :
199223 results = []
200224 for column in self .columns :
201- if column .is_array_of_struct and len (self .columns ) > 1 :
202- results .append (column .name + array_suffix )
203- else :
204- results .append (column .name )
205- return "." .join (results )
206-
207- def column (self , array_suffix : str ) -> exp .Column :
208- return exp .column (self .full_column_path (array_suffix ))
209-
210- def column_def (self , array_suffix : str ) -> exp .ColumnDef :
225+ results .append (column .identifier )
226+ if column .is_array_of_struct and len (self .columns ) > 1 and array_element_selector :
227+ results .append (exp .to_identifier (array_element_selector ))
228+ return results
229+
230+ def column (self , array_element_selector : str ) -> t .Union [exp .Dot , exp .Identifier ]:
231+ columns = self .column_identifiers (array_element_selector )
232+ if len (columns ) == 1 :
233+ return columns [0 ]
234+ return exp .Dot .build (columns )
235+
236+ def column_def (self , array_element_selector : str ) -> exp .ColumnDef :
211237 return exp .ColumnDef (
212- this = exp . to_identifier ( self .full_column_path ( array_suffix ) ),
238+ this = self .column ( array_element_selector ),
213239 kind = self .column_type ,
214240 )
215241
216- def expression (self , table_name : t .Union [str , exp .Table ], array_suffix : str ) -> exp .AlterTable :
242+ def expression (
243+ self , table_name : t .Union [str , exp .Table ], array_element_selector : str
244+ ) -> exp .AlterTable :
217245 if self .is_alter_type :
218246 return exp .AlterTable (
219247 this = exp .to_table (table_name ),
220248 actions = [
221249 exp .AlterColumn (
222- this = self .column (array_suffix ),
250+ this = self .column (array_element_selector ),
223251 dtype = self .column_type ,
224252 )
225253 ],
226254 )
227255 elif self .is_add :
228256 alter_table = exp .AlterTable (this = exp .to_table (table_name ))
229- column = self .column_def (array_suffix )
257+ column = self .column_def (array_element_selector )
230258 alter_table .set ("actions" , [column ])
231259 if self .add_position :
232260 column .set ("position" , self .add_position .column_position_node )
233261 return alter_table
234262 elif self .is_drop :
235263 alter_table = exp .AlterTable (this = exp .to_table (table_name ))
236- drop_column = exp .Drop (this = self .column (array_suffix ), kind = "COLUMN" )
264+ drop_column = exp .Drop (this = self .column (array_element_selector ), kind = "COLUMN" )
237265 alter_table .set ("actions" , [drop_column ])
238266 return alter_table
239267 else :
@@ -260,7 +288,7 @@ class SchemaDiffer(PydanticModel):
260288
261289 support_positional_add : bool = False
262290 support_nested_operations : bool = False
263- array_suffix : str = ""
291+ array_element_selector : str = ""
264292 compatible_types : t .Dict [exp .DataType , t .Set [exp .DataType ]] = {}
265293
266294 @classmethod
@@ -287,20 +315,20 @@ def _get_matching_kwarg(
287315 current_pos : int ,
288316 ) -> t .Tuple [t .Optional [int ], t .Optional [exp .StructKwarg ]]:
289317 current_name = (
290- current_kwarg
318+ exp . to_identifier ( current_kwarg )
291319 if isinstance (current_kwarg , str )
292320 else _get_name_and_type (current_kwarg )[0 ]
293321 )
294322 # First check if we have the same column in the same position to get O(1) complexity
295323 new_kwarg = seq_get (new_struct .expressions , current_pos )
296324 if new_kwarg :
297325 new_name , new_type = _get_name_and_type (new_kwarg )
298- if current_name == new_name :
326+ if current_name . this == new_name . this :
299327 return current_pos , new_kwarg
300328 # If not, check if we have the same column in all positions with O(n) complexity
301329 for i , new_kwarg in enumerate (new_struct .expressions ):
302330 new_name , new_type = _get_name_and_type (new_kwarg )
303- if current_name == new_name :
331+ if current_name . this == new_name . this :
304332 return i , new_kwarg
305333 return None , None
306334
@@ -499,7 +527,8 @@ def compare_structs(
499527 The list of table alter operations.
500528 """
501529 return [
502- op .expression (table_name , self .array_suffix ) for op in self ._from_structs (current , new )
530+ op .expression (table_name , self .array_element_selector )
531+ for op in self ._from_structs (current , new )
503532 ]
504533
505534 def compare_columns (
@@ -523,5 +552,5 @@ def compare_columns(
523552 )
524553
525554
526- def _get_name_and_type (struct : exp .StructKwarg ) -> t .Tuple [str , exp .DataType ]:
527- return struct .alias_or_name , struct .expression
555+ def _get_name_and_type (struct : exp .StructKwarg ) -> t .Tuple [exp . Identifier , exp .DataType ]:
556+ return struct .this , struct .expression
0 commit comments