@@ -30,6 +30,7 @@ use crate::datasource::provider_as_source;
3030use crate :: execution:: SessionStateDefaults ;
3131use crate :: execution:: context:: { EmptySerializerRegistry , FunctionFactory , QueryPlanner } ;
3232use crate :: physical_planner:: { DefaultPhysicalPlanner , PhysicalPlanner } ;
33+ use arrow_schema:: extension:: ExtensionType ;
3334use arrow_schema:: { DataType , FieldRef } ;
3435use datafusion_catalog:: MemoryCatalogProviderList ;
3536use datafusion_catalog:: information_schema:: {
@@ -56,7 +57,7 @@ use datafusion_expr::expr_rewriter::FunctionRewrite;
5657use datafusion_expr:: planner:: ExprPlanner ;
5758#[ cfg( feature = "sql" ) ]
5859use datafusion_expr:: planner:: { RelationPlanner , TypePlanner } ;
59- use datafusion_expr:: registry:: { FunctionRegistry , SerializerRegistry } ;
60+ use datafusion_expr:: registry:: { ExtensionTypeRegistration , ExtensionTypeRegistrationRef , ExtensionTypeRegistry , ExtensionTypeRegistryRef , FunctionRegistry , MemoryExtensionTypeRegistry , SerializerRegistry , SimpleExtensionTypeRegistration } ;
6061use datafusion_expr:: simplify:: SimplifyContext ;
6162use datafusion_expr:: { AggregateUDF , Explain , Expr , LogicalPlan , ScalarUDF , WindowUDF } ;
6263use datafusion_optimizer:: simplify_expressions:: ExprSimplifier ;
@@ -77,6 +78,7 @@ use datafusion_sql::{
7778
7879use async_trait:: async_trait;
7980use chrono:: { DateTime , Utc } ;
81+ use datafusion_common:: types:: UuidDFExtensionType ;
8082use itertools:: Itertools ;
8183use log:: { debug, info} ;
8284use object_store:: ObjectStore ;
@@ -158,6 +160,8 @@ pub struct SessionState {
158160 aggregate_functions : HashMap < String , Arc < AggregateUDF > > ,
159161 /// Window functions registered in the context
160162 window_functions : HashMap < String , Arc < WindowUDF > > ,
163+ /// Extension types registry for extensions.
164+ extension_types : ExtensionTypeRegistryRef ,
161165 /// Deserializer registry for extensions.
162166 serializer_registry : Arc < dyn SerializerRegistry > ,
163167 /// Holds registered external FileFormat implementations
@@ -266,6 +270,10 @@ impl Session for SessionState {
266270 & self . window_functions
267271 }
268272
273+ fn extension_type_registry ( & self ) -> & ExtensionTypeRegistryRef {
274+ & self . extension_types
275+ }
276+
269277 fn runtime_env ( & self ) -> & Arc < RuntimeEnv > {
270278 self . runtime_env ( )
271279 }
@@ -986,6 +994,7 @@ pub struct SessionStateBuilder {
986994 scalar_functions : Option < Vec < Arc < ScalarUDF > > > ,
987995 aggregate_functions : Option < Vec < Arc < AggregateUDF > > > ,
988996 window_functions : Option < Vec < Arc < WindowUDF > > > ,
997+ extension_types : Option < ExtensionTypeRegistryRef > ,
989998 serializer_registry : Option < Arc < dyn SerializerRegistry > > ,
990999 file_formats : Option < Vec < Arc < dyn FileFormatFactory > > > ,
9911000 config : Option < SessionConfig > ,
@@ -1026,6 +1035,7 @@ impl SessionStateBuilder {
10261035 scalar_functions : None ,
10271036 aggregate_functions : None ,
10281037 window_functions : None ,
1038+ extension_types : None ,
10291039 serializer_registry : None ,
10301040 file_formats : None ,
10311041 table_options : None ,
@@ -1081,6 +1091,7 @@ impl SessionStateBuilder {
10811091 existing. aggregate_functions . into_values ( ) . collect_vec ( ) ,
10821092 ) ,
10831093 window_functions : Some ( existing. window_functions . into_values ( ) . collect_vec ( ) ) ,
1094+ extension_types : Some ( existing. extension_types ) ,
10841095 serializer_registry : Some ( existing. serializer_registry ) ,
10851096 file_formats : Some ( existing. file_formats . into_values ( ) . collect_vec ( ) ) ,
10861097 config : Some ( new_config) ,
@@ -1126,6 +1137,11 @@ impl SessionStateBuilder {
11261137 . get_or_insert_with ( Vec :: new)
11271138 . extend ( SessionStateDefaults :: default_window_functions ( ) ) ;
11281139
1140+ self . extension_types
1141+ . get_or_insert_with ( || Arc :: new ( MemoryExtensionTypeRegistry :: new ( ) ) )
1142+ . extend ( & SessionStateDefaults :: default_extension_types ( ) )
1143+ . expect ( "MemoryExtensionTypeRegistry is not read-only." ) ;
1144+
11291145 self . table_functions
11301146 . get_or_insert_with ( HashMap :: new)
11311147 . extend (
@@ -1316,6 +1332,44 @@ impl SessionStateBuilder {
13161332 self
13171333 }
13181334
1335+ /// Set the map of [`ExtensionTypeRegistration`]s
1336+ pub fn with_extension_type (
1337+ mut self ,
1338+ registry : ExtensionTypeRegistryRef ,
1339+ ) -> Self {
1340+ self . extension_types = Some ( registry) ;
1341+ self
1342+ }
1343+
1344+ /// Registers [canonical extension types](https://arrow.apache.org/docs/format/CanonicalExtensions.html)
1345+ /// in DataFusion's extension type registry. For more information see [`ExtensionTypeRegistry`].
1346+ ///
1347+ /// # Errors
1348+ ///
1349+ /// May fail if an already registered [`ExtensionTypeRegistry`] raises an error while
1350+ /// registering the canonical extension types.
1351+ pub fn with_canonical_extension_types ( mut self ) -> datafusion_common:: Result < Self > {
1352+ let canonical_extension_types = vec ! [ SimpleExtensionTypeRegistration :: new_arc(
1353+ arrow_schema:: extension:: Uuid :: NAME ,
1354+ Arc :: new( UuidDFExtensionType :: new( ) ) ,
1355+ ) ] ;
1356+
1357+ match & self . extension_types {
1358+ None => {
1359+ let registry = Arc :: new ( MemoryExtensionTypeRegistry :: new ( ) ) ;
1360+ registry
1361+ . extend ( & canonical_extension_types)
1362+ . expect ( "Adding valid extension types to MemoryExtensionTypeRegistry always succeeds." ) ;
1363+ self . extension_types = Some ( registry) ;
1364+ }
1365+ Some ( registry) => {
1366+ registry. extend ( & canonical_extension_types) ?;
1367+ }
1368+ }
1369+
1370+ Ok ( self )
1371+ }
1372+
13191373 /// Set the [`SerializerRegistry`]
13201374 pub fn with_serializer_registry (
13211375 mut self ,
@@ -1454,6 +1508,7 @@ impl SessionStateBuilder {
14541508 scalar_functions,
14551509 aggregate_functions,
14561510 window_functions,
1511+ extension_types,
14571512 serializer_registry,
14581513 file_formats,
14591514 table_options,
@@ -1490,6 +1545,7 @@ impl SessionStateBuilder {
14901545 scalar_functions : HashMap :: new ( ) ,
14911546 aggregate_functions : HashMap :: new ( ) ,
14921547 window_functions : HashMap :: new ( ) ,
1548+ extension_types : Arc :: new ( MemoryExtensionTypeRegistry :: default ( ) ) ,
14931549 serializer_registry : serializer_registry
14941550 . unwrap_or_else ( || Arc :: new ( EmptySerializerRegistry ) ) ,
14951551 file_formats : HashMap :: new ( ) ,
@@ -1559,6 +1615,10 @@ impl SessionStateBuilder {
15591615 } ) ;
15601616 }
15611617
1618+ if let Some ( extension_types) = extension_types {
1619+ state. extension_types = extension_types;
1620+ }
1621+
15621622 if state. config . create_default_catalog_and_schema ( ) {
15631623 let default_catalog = SessionStateDefaults :: default_catalog (
15641624 & state. config ,
@@ -2071,6 +2131,35 @@ impl datafusion_execution::TaskContextProvider for SessionState {
20712131 }
20722132}
20732133
2134+ impl ExtensionTypeRegistry for SessionState {
2135+ fn extension_type_registration (
2136+ & self ,
2137+ name : & str ,
2138+ ) -> datafusion_common:: Result < ExtensionTypeRegistrationRef > {
2139+ self . extension_types . extension_type_registration ( name)
2140+ }
2141+
2142+ fn extension_type_registrations ( & self ) -> Vec < Arc < dyn ExtensionTypeRegistration > > {
2143+ self . extension_types . extension_type_registrations ( )
2144+ }
2145+
2146+ fn add_extension_type_registration (
2147+ & self ,
2148+ extension_type : ExtensionTypeRegistrationRef ,
2149+ ) -> datafusion_common:: Result < Option < ExtensionTypeRegistrationRef > > {
2150+ self . extension_types
2151+ . add_extension_type_registration ( extension_type)
2152+ }
2153+
2154+ fn remove_extension_type_registration (
2155+ & self ,
2156+ name : & str ,
2157+ ) -> datafusion_common:: Result < Option < ExtensionTypeRegistrationRef > > {
2158+ self . extension_types
2159+ . remove_extension_type_registration ( name)
2160+ }
2161+ }
2162+
20742163impl OptimizerConfig for SessionState {
20752164 fn query_execution_start_time ( & self ) -> Option < DateTime < Utc > > {
20762165 self . execution_props . query_execution_start_time
0 commit comments