1111 DateTime ,
1212)
1313from collections .abc import Sequence
14- from databricks .sqlalchemy import TIMESTAMP , TINYINT , DatabricksArray , DatabricksMap , DatabricksVariant
14+ from databricks .sqlalchemy import (
15+ TIMESTAMP ,
16+ TINYINT ,
17+ DatabricksArray ,
18+ DatabricksMap ,
19+ DatabricksVariant ,
20+ )
1521from sqlalchemy .orm import DeclarativeBase , Session
1622from sqlalchemy import select
1723from datetime import date , datetime , time , timedelta , timezone
2026import decimal
2127import json
2228
29+
2330class TestComplexTypes (TestSetup ):
2431 def _parse_to_common_type (self , value ):
2532 """
@@ -175,8 +182,8 @@ class VariantTable(Base):
175182 "number" : 123 ,
176183 "boolean" : True ,
177184 "array" : [1 , 2 , 3 ],
178- "object" : {"nested" : "value" }
179- }
185+ "object" : {"nested" : "value" },
186+ },
180187 }
181188
182189 return VariantTable , sample_data
@@ -239,6 +246,44 @@ def test_map_table_creation_pandas(self):
239246 df_result = pd .read_sql (stmt , engine )
240247 assert self ._recursive_compare (df_result .iloc [0 ].to_dict (), sample_data )
241248
def test_array_table_creation_pandas_multi(self):
    """Round-trip two array-table rows via ``DataFrame.to_sql(method="multi")``.

    Writes the sample row plus a copy whose ``int_col`` is 2, reads both
    rows back ordered by ``int_col`` and compares each with what was written.
    """
    table, sample_data = self.sample_array_table()
    # Build the second row once so the written and the expected values
    # cannot drift apart.
    second_row = sample_data | {"int_col": 2}

    with self.table_context(table) as engine:
        df = pd.DataFrame([sample_data, second_row])
        df.to_sql(
            table.__tablename__,
            engine,
            if_exists="append",
            index=False,
            method="multi",
        )

        # Ordering by int_col makes the positional lookups below deterministic.
        # NOTE(review): assumes sample_data["int_col"] sorts before 2 - confirm.
        stmt = select(table).order_by(table.int_col)
        df_result = pd.read_sql(stmt, engine)

        # Catch a multi-row insert that writes too few or too many rows.
        # NOTE(review): assumes table_context yields an empty table - confirm.
        assert len(df_result) == 2, f"expected 2 rows, got {len(df_result)}"
        assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data)
        assert self._recursive_compare(df_result.iloc[1].to_dict(), second_row)
267+
def test_map_table_creation_pandas_multi(self):
    """Round-trip two map-table rows via ``DataFrame.to_sql(method="multi")``.

    Writes the sample row plus a copy whose ``int_col`` is 2, reads both
    rows back ordered by ``int_col`` and compares each with what was written.
    """
    table, sample_data = self.sample_map_table()
    # Build the second row once so the written and the expected values
    # cannot drift apart.
    second_row = sample_data | {"int_col": 2}

    with self.table_context(table) as engine:
        df = pd.DataFrame([sample_data, second_row])
        df.to_sql(
            table.__tablename__,
            engine,
            if_exists="append",
            index=False,
            method="multi",
        )

        # Ordering by int_col makes the positional lookups below deterministic.
        # NOTE(review): assumes sample_data["int_col"] sorts before 2 - confirm.
        stmt = select(table).order_by(table.int_col)
        df_result = pd.read_sql(stmt, engine)

        # Catch a multi-row insert that writes too few or too many rows.
        # NOTE(review): assumes table_context yields an empty table - confirm.
        assert len(df_result) == 2, f"expected 2 rows, got {len(df_result)}"
        assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data)
        assert self._recursive_compare(df_result.iloc[1].to_dict(), second_row)
286+
242287 def test_insert_variant_table_sqlalchemy (self ):
243288 table , sample_data = self .sample_variant_table ()
244289
@@ -253,7 +298,12 @@ def test_insert_variant_table_sqlalchemy(self):
253298 result = session .scalar (stmt )
254299 compare = {key : getattr (result , key ) for key in sample_data .keys ()}
255300 # Parse JSON values back to original format for comparison
256- for key in ['variant_simple_col' , 'variant_nested_col' , 'variant_array_col' , 'variant_mixed_col' ]:
301+ for key in [
302+ "variant_simple_col" ,
303+ "variant_nested_col" ,
304+ "variant_array_col" ,
305+ "variant_mixed_col" ,
306+ ]:
257307 if compare [key ] is not None :
258308 compare [key ] = json .loads (compare [key ])
259309
@@ -263,26 +313,76 @@ def test_variant_table_creation_pandas(self):
263313 table , sample_data = self .sample_variant_table ()
264314
265315 with self .table_context (table ) as engine :
266-
316+
267317 df = pd .DataFrame ([sample_data ])
268318 dtype_mapping = {
269319 "variant_simple_col" : DatabricksVariant ,
270320 "variant_nested_col" : DatabricksVariant ,
271321 "variant_array_col" : DatabricksVariant ,
272- "variant_mixed_col" : DatabricksVariant
322+ "variant_mixed_col" : DatabricksVariant ,
273323 }
274- df .to_sql (table .__tablename__ , engine , if_exists = "append" , index = False , dtype = dtype_mapping )
275-
324+ df .to_sql (
325+ table .__tablename__ ,
326+ engine ,
327+ if_exists = "append" ,
328+ index = False ,
329+ dtype = dtype_mapping ,
330+ )
331+
276332 stmt = select (table )
277333 df_result = pd .read_sql (stmt , engine )
278334 result_dict = df_result .iloc [0 ].to_dict ()
279335 # Parse JSON values back to original format for comparison
280- for key in ['variant_simple_col' , 'variant_nested_col' , 'variant_array_col' , 'variant_mixed_col' ]:
336+ for key in [
337+ "variant_simple_col" ,
338+ "variant_nested_col" ,
339+ "variant_array_col" ,
340+ "variant_mixed_col" ,
341+ ]:
281342 if result_dict [key ] is not None :
282343 result_dict [key ] = json .loads (result_dict [key ])
283344
284345 assert result_dict == sample_data
285346
def test_variant_table_creation_pandas_multi(self):
    """Round-trip two variant-table rows via ``DataFrame.to_sql(method="multi")``.

    Writes the sample row plus a copy whose ``int_col`` is 2, passing
    ``DatabricksVariant`` as the dtype for every variant column. Both rows
    are read back ordered by ``int_col``; the variant columns are decoded
    with ``json.loads`` before each row is compared with what was written.
    """
    table, sample_data = self.sample_variant_table()
    second = sample_data | {"int_col": 2}
    # Single source of truth for the variant columns: used both for the
    # dtype mapping and for decoding the values read back.
    variant_cols = (
        "variant_simple_col",
        "variant_nested_col",
        "variant_array_col",
        "variant_mixed_col",
    )

    with self.table_context(table) as engine:
        df = pd.DataFrame([sample_data, second])
        dtype_mapping = dict.fromkeys(variant_cols, DatabricksVariant)
        df.to_sql(
            table.__tablename__,
            engine,
            if_exists="append",
            index=False,
            dtype=dtype_mapping,
            method="multi",
        )

        # Ordering by int_col makes the positional lookups below deterministic.
        # NOTE(review): assumes sample_data["int_col"] sorts before 2 - confirm.
        stmt = select(table).order_by(table.int_col)
        df_result = pd.read_sql(stmt, engine)

        # Catch a multi-row insert that writes too few or too many rows.
        # NOTE(review): assumes table_context yields an empty table - confirm.
        assert len(df_result) == 2, f"expected 2 rows, got {len(df_result)}"

        first_row = df_result.iloc[0].to_dict()
        second_row = df_result.iloc[1].to_dict()
        # Parse JSON values back to Python objects for comparison; None
        # values are left untouched.
        for row in (first_row, second_row):
            for key in variant_cols:
                if row[key] is not None:
                    row[key] = json.loads(row[key])

        assert first_row == sample_data
        assert second_row == second
385+
286386 def test_variant_literal_processor (self ):
287387 table , sample_data = self .sample_variant_table ()
288388
@@ -291,8 +391,7 @@ def test_variant_literal_processor(self):
291391
292392 try :
293393 compiled = stmt .compile (
294- dialect = engine .dialect ,
295- compile_kwargs = {"literal_binds" : True }
394+ dialect = engine .dialect , compile_kwargs = {"literal_binds" : True }
296395 )
297396 sql_str = str (compiled )
298397
@@ -311,7 +410,12 @@ def test_variant_literal_processor(self):
311410 compare = {key : getattr (result , key ) for key in sample_data .keys ()}
312411
313412 # Parse JSON values back to original Python objects
314- for key in ['variant_simple_col' , 'variant_nested_col' , 'variant_array_col' , 'variant_mixed_col' ]:
413+ for key in [
414+ "variant_simple_col" ,
415+ "variant_nested_col" ,
416+ "variant_array_col" ,
417+ "variant_mixed_col" ,
418+ ]:
315419 if compare [key ] is not None :
316420 compare [key ] = json .loads (compare [key ])
317421
0 commit comments