Skip to content

Commit 05b37c0

Browse files
committed
Add draft for basic extension type support
1 parent aa9520e commit 05b37c0

18 files changed

Lines changed: 640 additions & 9 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/common/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ paste = { workspace = true }
8181
recursive = { workspace = true, optional = true }
8282
sqlparser = { workspace = true, optional = true }
8383
tokio = { workspace = true }
84+
uuid = { workspace = true, features = ["v4"] }
8485

8586
[target.'cfg(target_family = "wasm")'.dependencies]
8687
web-time = "1.1.0"
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
mod uuid;
2+
3+
pub use uuid::*;
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
use crate::error::_internal_err;
2+
use crate::types::extension::DFExtensionType;
3+
use arrow::array::{Array, FixedSizeBinaryArray};
4+
use arrow::datatypes::DataType;
5+
use arrow::util::display::{ArrayFormatter, DisplayIndex, FormatOptions, FormatResult};
6+
use std::fmt::Write;
7+
use uuid::{Bytes, Uuid};
8+
9+
/// Defines the extension type logic for the canonical `arrow.uuid` extension type.
10+
///
11+
/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism.
12+
#[derive(Debug)]
13+
pub struct UuidDFExtensionType();
14+
15+
impl UuidDFExtensionType {
16+
/// Create a new instance of [`UuidDFExtensionType`].
17+
pub fn new() -> Self {
18+
Self {}
19+
}
20+
}
21+
22+
impl Default for UuidDFExtensionType {
23+
fn default() -> Self {
24+
Self::new()
25+
}
26+
}
27+
28+
impl DFExtensionType for UuidDFExtensionType {
29+
fn create_array_formatter<'fmt>(
30+
&self,
31+
array: &'fmt dyn Array,
32+
options: &FormatOptions<'fmt>,
33+
) -> crate::Result<Option<ArrayFormatter<'fmt>>> {
34+
if array.data_type() != &DataType::FixedSizeBinary(16) {
35+
return _internal_err!("Wrong array type for Uuid");
36+
}
37+
38+
let display_index = UuidValueDisplayIndex {
39+
array: array.as_any().downcast_ref().unwrap(),
40+
null_str: options.null(),
41+
};
42+
Ok(Some(ArrayFormatter::new(
43+
Box::new(display_index),
44+
options.safe(),
45+
)))
46+
}
47+
}
48+
49+
/// Pretty printer for binary UUID values.
50+
#[derive(Debug, Clone, Copy)]
51+
struct UuidValueDisplayIndex<'a> {
52+
array: &'a FixedSizeBinaryArray,
53+
null_str: &'a str,
54+
}
55+
56+
impl DisplayIndex for UuidValueDisplayIndex<'_> {
57+
fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult {
58+
if self.array.is_null(idx) {
59+
write!(f, "arrow.uuid({})", self.null_str)?;
60+
return Ok(());
61+
}
62+
63+
let bytes = Bytes::try_from(self.array.value(idx))
64+
.expect("FixedSizeBinaryArray length checked in create_array_formatter");
65+
let uuid = Uuid::from_bytes(bytes);
66+
write!(f, "arrow.uuid({uuid})")?;
67+
Ok(())
68+
}
69+
}
70+
71+
#[cfg(test)]
72+
mod tests {
73+
use super::*;
74+
use crate::ScalarValue;
75+
76+
#[test]
77+
pub fn test_pretty_print_uuid() {
78+
let my_uuid = Uuid::nil();
79+
let uuid = ScalarValue::FixedSizeBinary(16, Some(my_uuid.as_bytes().to_vec()))
80+
.to_array_of_size(1)
81+
.unwrap();
82+
83+
let extension_type = UuidDFExtensionType::new();
84+
let formatter = extension_type
85+
.create_array_formatter(uuid.as_ref(), &FormatOptions::default())
86+
.unwrap()
87+
.unwrap();
88+
89+
assert_eq!(
90+
formatter.value(0).to_string(),
91+
"arrow.uuid(00000000-0000-0000-0000-000000000000)"
92+
);
93+
}
94+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
use crate::error::Result;
2+
use arrow::array::Array;
3+
use arrow::datatypes::DataType;
4+
use arrow::util::display::{ArrayFormatter, FormatOptions};
5+
use std::fmt::Debug;
6+
use std::sync::Arc;
7+
8+
/// A cheaply cloneable pointer to a [`DFExtensionType`].
9+
pub type DFExtensionTypeRef = Arc<dyn DFExtensionType>;
10+
11+
/// Represents an implementation of a DataFusion extension type, allowing users to customize the
12+
/// behavior of DataFusion for custom extension types.
13+
///
14+
/// Extension types may change the semantics of a column. For example, adding two values of
15+
/// [`DataType::Int64`] is a sensible thing to do. However, if the same data type is annotated with
16+
/// an extension type like `custom.id`, the correct interpretation of a column changes. For example,
17+
/// adding together two `custom.id` values (represented as a 64-bit integer) may no longer make
18+
/// sense.
19+
///
20+
/// Note that while helping users to navigate the semantic gap between the data type and extension
21+
/// types is a goal of this trait, DataFusion's extension type support is still evolving and does
22+
/// not cover all use cases. Currently, the following capabilities can be customized:
23+
/// - Pretty-printing values in record batches
24+
///
25+
/// # Relation to Arrow's `ExtensionType`
26+
///
27+
/// The purpose of Arrow's `ExtensionType` trait, for the time being, is to provide a way to handle
28+
/// metadata of an extension type in a type-safe manner. The trait does not provide any
29+
/// customization options such that users can customize the behavior of any kernels (e.g.,
30+
/// [`DFExtensionType::create_array_formatter`] for formatting record batches). Therefore,
31+
/// downstream users (such as DataFusion) have the flexibility to implement the extension type
32+
/// mechanism according to their needs. [`DFExtensionType`] is DataFusion's implementation of this
33+
/// extension type mechanism.
34+
///
35+
/// Furthermore, Arrow's current trait is not dyn-compatible which we need for implementing
36+
/// extension type registries. In the future, the two implementations may increasingly converge.
37+
///
38+
/// # Example
39+
///
40+
///
41+
pub trait DFExtensionType: Debug + Send + Sync {
42+
/// Returns an [`ArrayFormatter`] that can format values of this type.
43+
///
44+
/// If `Ok(None)` is returned, the default implementation will be used.
45+
/// If an error is returned, there was an error creating the formatter.
46+
fn create_array_formatter<'fmt>(
47+
&self,
48+
_array: &'fmt dyn Array,
49+
_options: &FormatOptions<'fmt>,
50+
) -> Result<Option<ArrayFormatter<'fmt>>> {
51+
Ok(None)
52+
}
53+
}

datafusion/common/src/types/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,15 @@
1616
// under the License.
1717

1818
mod builtin;
19+
mod canonical_extensions;
20+
mod extension;
1921
mod field;
2022
mod logical;
2123
mod native;
2224

2325
pub use builtin::*;
26+
pub use canonical_extensions::*;
27+
pub use extension::*;
2428
pub use field::*;
2529
pub use logical::*;
2630
pub use native::*;

datafusion/core/src/dataframe/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ use datafusion_functions_aggregate::expr_fn::{
6969
avg, count, max, median, min, stddev, sum,
7070
};
7171

72+
use crate::extension_types::DFArrayFormatterFactory;
7273
use async_trait::async_trait;
7374
use datafusion_catalog::Session;
7475

@@ -1516,6 +1517,11 @@ impl DataFrame {
15161517
let options = self.session_state.config().options().format.clone();
15171518
let arrow_options: arrow::util::display::FormatOptions = (&options).try_into()?;
15181519

1520+
let registry = self.session_state.extension_type_registry();
1521+
let formatter_factory = DFArrayFormatterFactory::new(Arc::clone(registry));
1522+
let arrow_options =
1523+
arrow_options.with_formatter_factory(Some(&formatter_factory));
1524+
15191525
let results = self.collect().await?;
15201526
Ok(
15211527
pretty::pretty_format_batches_with_options(&results, &arrow_options)?

datafusion/core/src/execution/session_state.rs

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use crate::datasource::provider_as_source;
3030
use crate::execution::SessionStateDefaults;
3131
use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner};
3232
use crate::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
33+
use arrow_schema::extension::ExtensionType;
3334
use arrow_schema::{DataType, FieldRef};
3435
use datafusion_catalog::MemoryCatalogProviderList;
3536
use datafusion_catalog::information_schema::{
@@ -56,7 +57,7 @@ use datafusion_expr::expr_rewriter::FunctionRewrite;
5657
use datafusion_expr::planner::ExprPlanner;
5758
#[cfg(feature = "sql")]
5859
use datafusion_expr::planner::{RelationPlanner, TypePlanner};
59-
use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry};
60+
use datafusion_expr::registry::{ExtensionTypeRegistration, ExtensionTypeRegistrationRef, ExtensionTypeRegistry, ExtensionTypeRegistryRef, FunctionRegistry, MemoryExtensionTypeRegistry, SerializerRegistry, SimpleExtensionTypeRegistration};
6061
use datafusion_expr::simplify::SimplifyContext;
6162
use datafusion_expr::{AggregateUDF, Explain, Expr, LogicalPlan, ScalarUDF, WindowUDF};
6263
use datafusion_optimizer::simplify_expressions::ExprSimplifier;
@@ -77,6 +78,7 @@ use datafusion_sql::{
7778

7879
use async_trait::async_trait;
7980
use chrono::{DateTime, Utc};
81+
use datafusion_common::types::UuidDFExtensionType;
8082
use itertools::Itertools;
8183
use log::{debug, info};
8284
use object_store::ObjectStore;
@@ -158,6 +160,8 @@ pub struct SessionState {
158160
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
159161
/// Window functions registered in the context
160162
window_functions: HashMap<String, Arc<WindowUDF>>,
163+
/// Extension types registry for extensions.
164+
extension_types: ExtensionTypeRegistryRef,
161165
/// Deserializer registry for extensions.
162166
serializer_registry: Arc<dyn SerializerRegistry>,
163167
/// Holds registered external FileFormat implementations
@@ -266,6 +270,10 @@ impl Session for SessionState {
266270
&self.window_functions
267271
}
268272

273+
fn extension_type_registry(&self) -> &ExtensionTypeRegistryRef {
274+
&self.extension_types
275+
}
276+
269277
fn runtime_env(&self) -> &Arc<RuntimeEnv> {
270278
self.runtime_env()
271279
}
@@ -986,6 +994,7 @@ pub struct SessionStateBuilder {
986994
scalar_functions: Option<Vec<Arc<ScalarUDF>>>,
987995
aggregate_functions: Option<Vec<Arc<AggregateUDF>>>,
988996
window_functions: Option<Vec<Arc<WindowUDF>>>,
997+
extension_types: Option<ExtensionTypeRegistryRef>,
989998
serializer_registry: Option<Arc<dyn SerializerRegistry>>,
990999
file_formats: Option<Vec<Arc<dyn FileFormatFactory>>>,
9911000
config: Option<SessionConfig>,
@@ -1026,6 +1035,7 @@ impl SessionStateBuilder {
10261035
scalar_functions: None,
10271036
aggregate_functions: None,
10281037
window_functions: None,
1038+
extension_types: None,
10291039
serializer_registry: None,
10301040
file_formats: None,
10311041
table_options: None,
@@ -1081,6 +1091,7 @@ impl SessionStateBuilder {
10811091
existing.aggregate_functions.into_values().collect_vec(),
10821092
),
10831093
window_functions: Some(existing.window_functions.into_values().collect_vec()),
1094+
extension_types: Some(existing.extension_types),
10841095
serializer_registry: Some(existing.serializer_registry),
10851096
file_formats: Some(existing.file_formats.into_values().collect_vec()),
10861097
config: Some(new_config),
@@ -1126,6 +1137,11 @@ impl SessionStateBuilder {
11261137
.get_or_insert_with(Vec::new)
11271138
.extend(SessionStateDefaults::default_window_functions());
11281139

1140+
self.extension_types
1141+
.get_or_insert_with(|| Arc::new(MemoryExtensionTypeRegistry::new()))
1142+
.extend(&SessionStateDefaults::default_extension_types())
1143+
.expect("MemoryExtensionTypeRegistry is not read-only.");
1144+
11291145
self.table_functions
11301146
.get_or_insert_with(HashMap::new)
11311147
.extend(
@@ -1316,6 +1332,44 @@ impl SessionStateBuilder {
13161332
self
13171333
}
13181334

1335+
/// Set the map of [`ExtensionTypeRegistration`]s
1336+
pub fn with_extension_type(
1337+
mut self,
1338+
registry: ExtensionTypeRegistryRef,
1339+
) -> Self {
1340+
self.extension_types = Some(registry);
1341+
self
1342+
}
1343+
1344+
/// Registers [canonical extension types](https://arrow.apache.org/docs/format/CanonicalExtensions.html)
1345+
/// in DataFusion's extension type registry. For more information see [`ExtensionTypeRegistry`].
1346+
///
1347+
/// # Errors
1348+
///
1349+
/// May fail if an already registered [`ExtensionTypeRegistry`] raises an error while
1350+
/// registering the canonical extension types.
1351+
pub fn with_canonical_extension_types(mut self) -> datafusion_common::Result<Self> {
1352+
let canonical_extension_types = vec![SimpleExtensionTypeRegistration::new_arc(
1353+
arrow_schema::extension::Uuid::NAME,
1354+
Arc::new(UuidDFExtensionType::new()),
1355+
)];
1356+
1357+
match &self.extension_types {
1358+
None => {
1359+
let registry = Arc::new(MemoryExtensionTypeRegistry::new());
1360+
registry
1361+
.extend(&canonical_extension_types)
1362+
.expect("Adding valid extension types to MemoryExtensionTypeRegistry always succeeds.");
1363+
self.extension_types = Some(registry);
1364+
}
1365+
Some(registry) => {
1366+
registry.extend(&canonical_extension_types)?;
1367+
}
1368+
}
1369+
1370+
Ok(self)
1371+
}
1372+
13191373
/// Set the [`SerializerRegistry`]
13201374
pub fn with_serializer_registry(
13211375
mut self,
@@ -1454,6 +1508,7 @@ impl SessionStateBuilder {
14541508
scalar_functions,
14551509
aggregate_functions,
14561510
window_functions,
1511+
extension_types,
14571512
serializer_registry,
14581513
file_formats,
14591514
table_options,
@@ -1490,6 +1545,7 @@ impl SessionStateBuilder {
14901545
scalar_functions: HashMap::new(),
14911546
aggregate_functions: HashMap::new(),
14921547
window_functions: HashMap::new(),
1548+
extension_types: Arc::new(MemoryExtensionTypeRegistry::default()),
14931549
serializer_registry: serializer_registry
14941550
.unwrap_or_else(|| Arc::new(EmptySerializerRegistry)),
14951551
file_formats: HashMap::new(),
@@ -1559,6 +1615,10 @@ impl SessionStateBuilder {
15591615
});
15601616
}
15611617

1618+
if let Some(extension_types) = extension_types {
1619+
state.extension_types = extension_types;
1620+
}
1621+
15621622
if state.config.create_default_catalog_and_schema() {
15631623
let default_catalog = SessionStateDefaults::default_catalog(
15641624
&state.config,
@@ -2071,6 +2131,35 @@ impl datafusion_execution::TaskContextProvider for SessionState {
20712131
}
20722132
}
20732133

2134+
impl ExtensionTypeRegistry for SessionState {
2135+
fn extension_type_registration(
2136+
&self,
2137+
name: &str,
2138+
) -> datafusion_common::Result<ExtensionTypeRegistrationRef> {
2139+
self.extension_types.extension_type_registration(name)
2140+
}
2141+
2142+
fn extension_type_registrations(&self) -> Vec<Arc<dyn ExtensionTypeRegistration>> {
2143+
self.extension_types.extension_type_registrations()
2144+
}
2145+
2146+
fn add_extension_type_registration(
2147+
&self,
2148+
extension_type: ExtensionTypeRegistrationRef,
2149+
) -> datafusion_common::Result<Option<ExtensionTypeRegistrationRef>> {
2150+
self.extension_types
2151+
.add_extension_type_registration(extension_type)
2152+
}
2153+
2154+
fn remove_extension_type_registration(
2155+
&self,
2156+
name: &str,
2157+
) -> datafusion_common::Result<Option<ExtensionTypeRegistrationRef>> {
2158+
self.extension_types
2159+
.remove_extension_type_registration(name)
2160+
}
2161+
}
2162+
20742163
impl OptimizerConfig for SessionState {
20752164
fn query_execution_start_time(&self) -> Option<DateTime<Utc>> {
20762165
self.execution_props.query_execution_start_time

0 commit comments

Comments
 (0)