Skip to content

Commit fb53269

Browse files
committed
Add first draft of using type registry for uuids
1 parent 460f484 commit fb53269

7 files changed

Lines changed: 108 additions & 12 deletions

File tree

datafusion/catalog-listing/src/helpers.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,7 @@ mod tests {
541541
use std::ops::Not;
542542

543543
use super::*;
544+
use datafusion_expr::registry::MemoryExtensionTypeRegistry;
544545
use datafusion_expr::{
545546
case, col, lit, AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF,
546547
};
@@ -1059,6 +1060,10 @@ mod tests {
10591060
unimplemented!()
10601061
}
10611062

1063+
fn extension_types(&self) -> &MemoryExtensionTypeRegistry {
1064+
unimplemented!()
1065+
}
1066+
10621067
fn runtime_env(&self) -> &Arc<RuntimeEnv> {
10631068
unimplemented!()
10641069
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::types::{LogicalType, NativeType, TypeSignature};
19+
20+
/// Represents the canonical [UUID extension type](https://arrow.apache.org/docs/format/CanonicalExtensions.html#uuid).
21+
pub struct UuidType;
22+
23+
impl UuidType {
24+
/// Creates a new [UuidType].
25+
pub fn new() -> Self {
26+
Self {}
27+
}
28+
}
29+
30+
impl LogicalType for UuidType {
31+
fn native(&self) -> &NativeType {
32+
&NativeType::FixedSizeBinary(16)
33+
}
34+
35+
fn signature(&self) -> TypeSignature<'_> {
36+
TypeSignature::Extension {
37+
name: "arrow.uuid",
38+
parameters: &[],
39+
}
40+
}
41+
}

datafusion/common/src/types/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
// under the License.
1717

1818
mod builtin;
19+
mod canonical;
1920
mod field;
2021
mod logical;
2122
mod native;
2223

2324
pub use builtin::*;
25+
pub use canonical::*;
2426
pub use field::*;
2527
pub use logical::*;
2628
pub use native::*;

datafusion/core/src/execution/session_state.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,7 @@ pub struct SessionStateBuilder {
932932
scalar_functions: Option<Vec<Arc<ScalarUDF>>>,
933933
aggregate_functions: Option<Vec<Arc<AggregateUDF>>>,
934934
window_functions: Option<Vec<Arc<WindowUDF>>>,
935+
extension_types: Option<Vec<LogicalTypeRef>>,
935936
serializer_registry: Option<Arc<dyn SerializerRegistry>>,
936937
file_formats: Option<Vec<Arc<dyn FileFormatFactory>>>,
937938
config: Option<SessionConfig>,
@@ -969,6 +970,7 @@ impl SessionStateBuilder {
969970
scalar_functions: None,
970971
aggregate_functions: None,
971972
window_functions: None,
973+
extension_types: None,
972974
serializer_registry: None,
973975
file_formats: None,
974976
table_options: None,
@@ -1021,6 +1023,7 @@ impl SessionStateBuilder {
10211023
existing.aggregate_functions.into_values().collect_vec(),
10221024
),
10231025
window_functions: Some(existing.window_functions.into_values().collect_vec()),
1026+
extension_types: Some(existing.extension_types.all_types()),
10241027
serializer_registry: Some(existing.serializer_registry),
10251028
file_formats: Some(existing.file_formats.into_values().collect_vec()),
10261029
config: Some(new_config),
@@ -1407,6 +1410,7 @@ impl SessionStateBuilder {
14071410
scalar_functions: HashMap::new(),
14081411
aggregate_functions: HashMap::new(),
14091412
window_functions: HashMap::new(),
1413+
extension_types: MemoryExtensionTypeRegistry::new(),
14101414
serializer_registry: serializer_registry
14111415
.unwrap_or_else(|| Arc::new(EmptySerializerRegistry)),
14121416
file_formats: HashMap::new(),
@@ -1456,6 +1460,15 @@ impl SessionStateBuilder {
14561460
});
14571461
}
14581462

1463+
if let Some(extension_types) = extension_types {
1464+
extension_types.into_iter().for_each(|ext_type| {
1465+
let existing_type = state.register_extension_type(ext_type);
1466+
if let Ok(Some(existing_type)) = existing_type {
1467+
debug!("Overwrote an existing UDF: {}", existing_type);
1468+
}
1469+
});
1470+
}
1471+
14591472
if state.config.create_default_catalog_and_schema() {
14601473
let default_catalog = SessionStateDefaults::default_catalog(
14611474
&state.config,

datafusion/datasource/src/url.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ mod tests {
411411
use datafusion_execution::runtime_env::RuntimeEnv;
412412
use datafusion_execution::TaskContext;
413413
use datafusion_expr::execution_props::ExecutionProps;
414+
use datafusion_expr::registry::MemoryExtensionTypeRegistry;
414415
use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF};
415416
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
416417
use datafusion_physical_plan::ExecutionPlan;
@@ -767,6 +768,10 @@ mod tests {
767768
unimplemented!()
768769
}
769770

771+
fn extension_types(&self) -> &MemoryExtensionTypeRegistry {
772+
unimplemented!()
773+
}
774+
770775
fn runtime_env(&self) -> &Arc<RuntimeEnv> {
771776
&self.runtime_env
772777
}

datafusion/expr/src/registry.rs

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,14 +245,16 @@ pub trait ExtensionTypeRegistry {
245245
#[derive(Clone, Debug)]
246246
pub struct MemoryExtensionTypeRegistry {
247247
/// Holds a mapping between the name of an extension type and its logical type.
248+
///
249+
/// TODO: Use TypeSignature to support arguments
248250
extension_types: HashMap<String, LogicalTypeRef>,
249251
}
250252

251253
impl Default for MemoryExtensionTypeRegistry {
252254
fn default() -> Self {
253-
let mut registry = MemoryExtensionTypeRegistry::new();
254-
// TODO add canonical types
255-
registry
255+
MemoryExtensionTypeRegistry {
256+
extension_types: HashMap::new(),
257+
}
256258
}
257259
}
258260

@@ -263,6 +265,29 @@ impl MemoryExtensionTypeRegistry {
263265
extension_types: HashMap::new(),
264266
}
265267
}
268+
269+
/// Creates a new [MemoryExtensionTypeRegistry] with the provided `types`.
270+
///
271+
/// # Errors
272+
///
273+
/// Returns an error if one of the `types` is a native type.
274+
pub fn new_with_types(
275+
types: impl IntoIterator<Item = LogicalTypeRef>,
276+
) -> Result<Self> {
277+
let extension_types = types
278+
.into_iter()
279+
.map(|t| match t.signature() {
280+
TypeSignature::Native(_) => todo!("TODO"),
281+
TypeSignature::Extension { name, .. } => (name.to_owned(), t),
282+
})
283+
.collect();
284+
Ok(Self { extension_types })
285+
}
286+
287+
/// Returns a list of all registered types.
288+
pub fn all_types(&self) -> Vec<LogicalTypeRef> {
289+
self.extension_types.values().cloned().collect()
290+
}
266291
}
267292

268293
impl ExtensionTypeRegistry for MemoryExtensionTypeRegistry {

datafusion/functions/src/string/uuid.rs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@
1818
use std::any::Any;
1919
use std::sync::Arc;
2020

21-
use arrow::array::GenericStringBuilder;
22-
use arrow::datatypes::DataType;
21+
use arrow::array::FixedSizeBinaryBuilder;
2322
use arrow::datatypes::DataType::Utf8;
23+
use arrow::datatypes::{DataType, Field, FieldRef};
2424
use rand::Rng;
2525
use uuid::Uuid;
2626

2727
use datafusion_common::{internal_err, Result};
28-
use datafusion_expr::{ColumnarValue, Documentation, Volatility};
28+
use datafusion_common::types::UuidType;
29+
use datafusion_expr::{ColumnarValue, Documentation, ReturnFieldArgs, Volatility};
2930
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
3031
use datafusion_macros::user_doc;
3132

@@ -75,7 +76,12 @@ impl ScalarUDFImpl for UuidFunc {
7576
}
7677

7778
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
78-
Ok(Utf8)
79+
unreachable!("return_field_from_args is overwritten")
80+
}
81+
82+
fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<FieldRef> {
83+
// TODO: pass-in the registry
84+
Ok(Arc::new(Field::new("output", Utf8, false).with_extension_type(UuidType::new())))
7985
}
8086

8187
/// Prints random (v4) uuid values per row
@@ -90,18 +96,17 @@ impl ScalarUDFImpl for UuidFunc {
9096
let mut randoms = vec![0u128; args.number_rows];
9197
rng.fill(&mut randoms[..]);
9298

93-
let mut builder = GenericStringBuilder::<i32>::with_capacity(
94-
args.number_rows,
95-
args.number_rows * 36,
96-
);
99+
let mut builder = FixedSizeBinaryBuilder::with_capacity(args.number_rows, 16);
97100

98101
let mut buffer = [0u8; 36];
99102
for x in &mut randoms {
100103
// From Uuid::new_v4(): Mask out the version and variant bits
101104
*x = *x & 0xFFFFFFFFFFFF4FFFBFFFFFFFFFFFFFFF | 0x40008000000000000000;
102105
let uuid = Uuid::from_u128(*x);
103106
let fmt = uuid::fmt::Hyphenated::from_uuid(uuid);
104-
builder.append_value(fmt.encode_lower(&mut buffer));
107+
builder
108+
.append_value(fmt.encode_lower(&mut buffer))
109+
.expect("Value always has 16 bytes");
105110
}
106111

107112
Ok(ColumnarValue::Array(Arc::new(builder.finish())))

0 commit comments

Comments
 (0)