Skip to content

Commit 28578af

Browse files
committed
Remove requirement to arrow-rs ExtensionType
1 parent 8e63149 commit 28578af

2 files changed

Lines changed: 109 additions & 120 deletions

File tree

datafusion-examples/examples/extension_types/temperature.rs

Lines changed: 61 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use arrow::array::{
2121
};
2222
use arrow::datatypes::{Float32Type, Float64Type};
2323
use arrow::util::display::{ArrayFormatter, DisplayIndex, FormatOptions, FormatResult};
24-
use arrow_schema::extension::ExtensionType;
24+
use arrow_schema::extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY};
2525
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
2626
use datafusion::dataframe::DataFrame;
2727
use datafusion::error::Result;
@@ -32,6 +32,7 @@ use datafusion_common::types::DFExtensionType;
3232
use datafusion_expr::registry::{
3333
DefaultExtensionTypeRegistration, ExtensionTypeRegistry, MemoryExtensionTypeRegistry,
3434
};
35+
use std::collections::HashMap;
3536
use std::fmt::{Display, Write};
3637
use std::sync::Arc;
3738

@@ -50,15 +51,15 @@ fn create_session_context() -> Result<SessionContext> {
5051
let registry = MemoryExtensionTypeRegistry::new_empty();
5152

5253
// The registration creates a new instance of the extension type with the deserialized metadata.
53-
let temp_registration =
54-
DefaultExtensionTypeRegistration::<TemperatureExtensionType>::new_arc(
55-
|storage_type, metadata| {
56-
Ok(Arc::new(TemperatureExtensionType::try_new(
57-
storage_type,
58-
metadata,
59-
)?))
60-
},
61-
);
54+
let temp_registration = DefaultExtensionTypeRegistration::new_arc(
55+
TemperatureExtensionType::NAME,
56+
|storage_type, metadata| {
57+
Ok(Arc::new(TemperatureExtensionType::try_new(
58+
storage_type,
59+
TemperatureUnit::try_from(metadata)?,
60+
)?))
61+
},
62+
);
6263
registry.add_extension_type_registration(temp_registration)?;
6364

6465
let state = SessionStateBuilder::default()
@@ -98,38 +99,15 @@ async fn register_temperature_table(ctx: &SessionContext) -> Result<DataFrame> {
9899
fn example_schema() -> SchemaRef {
99100
Arc::new(Schema::new(vec![
100101
Field::new("city", DataType::Utf8, false),
101-
Field::new("celsius", DataType::Float64, false).with_extension_type(
102-
TemperatureExtensionType::try_new(
103-
&DataType::Float64,
104-
TemperatureUnit::Celsius,
105-
)
106-
.expect("Valid Type"),
107-
),
108-
Field::new("fahrenheit", DataType::Float64, false).with_extension_type(
109-
TemperatureExtensionType::try_new(
110-
&DataType::Float64,
111-
TemperatureUnit::Fahrenheit,
112-
)
113-
.expect("Valid Type"),
114-
),
115-
Field::new("kelvin", DataType::Float32, false).with_extension_type(
116-
TemperatureExtensionType::try_new(
117-
&DataType::Float32,
118-
TemperatureUnit::Kelvin,
119-
)
120-
.expect("Valid Type"),
121-
),
102+
Field::new("celsius", DataType::Float64, false)
103+
.with_metadata(create_metadata(TemperatureUnit::Celsius)),
104+
Field::new("fahrenheit", DataType::Float64, false)
105+
.with_metadata(create_metadata(TemperatureUnit::Fahrenheit)),
106+
Field::new("kelvin", DataType::Float32, false)
107+
.with_metadata(create_metadata(TemperatureUnit::Kelvin)),
122108
]))
123109
}
124110

125-
/// Represents the unit of a temperature reading.
126-
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
127-
pub enum TemperatureUnit {
128-
Celsius,
129-
Fahrenheit,
130-
Kelvin,
131-
}
132-
133111
/// Represents a float that semantically represents a temperature. The temperature can be one of
134112
/// the supported [`TemperatureUnit`]s.
135113
///
@@ -157,51 +135,61 @@ pub struct TemperatureExtensionType {
157135
}
158136

159137
impl TemperatureExtensionType {
138+
/// The name of the extension type.
139+
pub const NAME: &'static str = "custom.temperature";
140+
160141
/// Creates a new [`TemperatureExtensionType`].
161142
pub fn try_new(
162143
storage_type: &DataType,
163144
temperature_unit: TemperatureUnit,
164-
) -> Result<Self> {
145+
) -> Result<Self, ArrowError> {
146+
match storage_type {
147+
DataType::Float32 | DataType::Float64 => {}
148+
_ => {
149+
return Err(ArrowError::InvalidArgumentError(format!(
150+
"Invalid data type: {storage_type} for temperature type, expected Float32 or Float64",
151+
)));
152+
}
153+
}
154+
165155
let result = Self {
166156
storage_type: storage_type.clone(),
167157
temperature_unit,
168158
};
169-
result.supports_data_type(storage_type)?; // Validate the storage type
170159
Ok(result)
171160
}
172161
}
173162

174-
/// Implementation of [`ExtensionType`] for [`TemperatureExtensionType`].
175-
///
176-
/// This implements the arrow-rs trait for reading, writing, and validating extension types.
177-
impl ExtensionType for TemperatureExtensionType {
178-
/// Arrow extension type name that is stored in the `ARROW:extension:name` field.
179-
const NAME: &'static str = "custom.temperature";
180-
type Metadata = TemperatureUnit;
181-
182-
fn metadata(&self) -> &Self::Metadata {
183-
&self.temperature_unit
184-
}
163+
/// Represents the unit of a temperature reading.
164+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
165+
pub enum TemperatureUnit {
166+
Celsius,
167+
Fahrenheit,
168+
Kelvin,
169+
}
185170

171+
impl TemperatureUnit {
186172
/// Arrow extension type metadata is encoded as a string and stored using the
187173
/// `ARROW:extension:metadata` key. As we only store the name of the unit, a simple string
188174
/// suffices. Extension types can store more complex metadata using serialization formats like
189175
/// JSON.
190-
fn serialize_metadata(&self) -> Option<String> {
191-
let s = match self.temperature_unit {
176+
pub fn serialize(self) -> String {
177+
let result = match self {
192178
TemperatureUnit::Celsius => "celsius",
193179
TemperatureUnit::Fahrenheit => "fahrenheit",
194180
TemperatureUnit::Kelvin => "kelvin",
195181
};
196-
Some(s.to_string())
182+
result.to_owned()
197183
}
184+
}
198185

199-
/// Inverse operation of [`Self::serialize_metadata`]. This creates the [`TemperatureUnit`]
200-
/// value from the serialized string.
201-
fn deserialize_metadata(
202-
metadata: Option<&str>,
203-
) -> std::result::Result<Self::Metadata, ArrowError> {
204-
match metadata {
186+
/// Inverse operation of [`TemperatureUnit::serialize`]. This creates the [`TemperatureUnit`]
187+
/// value from the serialized string.
188+
impl TryFrom<Option<&str>> for TemperatureUnit {
189+
type Error = ArrowError;
190+
191+
fn try_from(value: Option<&str>) -> std::result::Result<Self, Self::Error> {
192+
match value {
205193
Some("celsius") => Ok(TemperatureUnit::Celsius),
206194
Some("fahrenheit") => Ok(TemperatureUnit::Fahrenheit),
207195
Some("kelvin") => Ok(TemperatureUnit::Kelvin),
@@ -213,28 +201,18 @@ impl ExtensionType for TemperatureExtensionType {
213201
)),
214202
}
215203
}
204+
}
216205

217-
/// Checks that the extension type supports a given [`DataType`].
218-
fn supports_data_type(
219-
&self,
220-
data_type: &DataType,
221-
) -> std::result::Result<(), ArrowError> {
222-
match data_type {
223-
DataType::Float32 | DataType::Float64 => Ok(()),
224-
_ => Err(ArrowError::InvalidArgumentError(format!(
225-
"Invalid data type: {data_type} for temperature type, expected Float32 or Float64",
226-
))),
227-
}
228-
}
229-
230-
fn try_new(
231-
data_type: &DataType,
232-
metadata: Self::Metadata,
233-
) -> std::result::Result<Self, ArrowError> {
234-
let instance = Self::try_new(data_type, metadata)?;
235-
instance.supports_data_type(data_type)?;
236-
Ok(instance)
237-
}
206+
/// This creates a metadata map for the temperature type. Another way of writing the metadata can be
207+
/// implemented using arrow-rs' [`ExtensionType`](arrow_schema::extension::ExtensionType) trait.
208+
fn create_metadata(unit: TemperatureUnit) -> HashMap<String, String> {
209+
HashMap::from([
210+
(
211+
EXTENSION_TYPE_NAME_KEY.to_owned(),
212+
TemperatureExtensionType::NAME.to_owned(),
213+
),
214+
(EXTENSION_TYPE_METADATA_KEY.to_owned(), unit.serialize()),
215+
])
238216
}
239217

240218
/// Implementation of [`DFExtensionType`] for [`TemperatureExtensionType`].
@@ -246,7 +224,7 @@ impl DFExtensionType for TemperatureExtensionType {
246224
}
247225

248226
fn serialize_metadata(&self) -> Option<String> {
249-
ExtensionType::serialize_metadata(self)
227+
Some(self.temperature_unit.serialize())
250228
}
251229

252230
fn create_array_formatter<'fmt>(

datafusion/expr/src/registry.rs

Lines changed: 48 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -336,62 +336,54 @@ pub trait ExtensionTypeRegistry: Debug + Send + Sync {
336336

337337
/// A factory that creates instances of extension types from a storage [`DataType`] and the
338338
/// metadata.
339-
pub type ExtensionTypeFactory<TExtensionType> = dyn Fn(
340-
&DataType,
341-
<TExtensionType as ExtensionType>::Metadata,
342-
) -> Result<DFExtensionTypeRef>
343-
+ Send
344-
+ Sync;
339+
pub type ExtensionTypeFactory =
340+
dyn Fn(&DataType, Option<&str>) -> Result<DFExtensionTypeRef> + Send + Sync;
345341

346342
/// A default implementation of [ExtensionTypeRegistration] that parses the metadata from the
347-
/// given extension type and passes it to a constructor function. The [`ExtensionType::NAME`] is
348-
/// used for registering the extension type.
349-
pub struct DefaultExtensionTypeRegistration<TExtensionType: ExtensionType + 'static> {
343+
/// given extension type and passes it to a constructor function.
344+
pub struct DefaultExtensionTypeRegistration {
345+
/// The name of the extension type.
346+
name: String,
350347
/// A function that creates an instance of [`DFExtensionTypeRef`] from the storage type and the
351348
/// metadata.
352-
factory: Box<ExtensionTypeFactory<TExtensionType>>,
349+
factory: Box<ExtensionTypeFactory>,
353350
}
354351

355-
impl<TExtensionType: ExtensionType + 'static>
356-
DefaultExtensionTypeRegistration<TExtensionType>
357-
{
352+
impl DefaultExtensionTypeRegistration {
358353
/// Creates a new registration for an extension type. The factory is required to validate that
359354
/// the storage [`DataType`] is compatible with the extension type.
360355
pub fn new_arc(
361-
factory: impl Fn(&DataType, TExtensionType::Metadata) -> Result<DFExtensionTypeRef>
356+
name: impl Into<String>,
357+
factory: impl Fn(&DataType, Option<&str>) -> Result<DFExtensionTypeRef>
362358
+ Send
363359
+ Sync
364360
+ 'static,
365361
) -> ExtensionTypeRegistrationRef {
366362
Arc::new(Self {
363+
name: name.into(),
367364
factory: Box::new(factory),
368365
})
369366
}
370367
}
371368

372-
impl<TExtensionType: ExtensionType> ExtensionTypeRegistration
373-
for DefaultExtensionTypeRegistration<TExtensionType>
374-
{
369+
impl ExtensionTypeRegistration for DefaultExtensionTypeRegistration {
375370
fn type_name(&self) -> &str {
376-
TExtensionType::NAME
371+
&self.name
377372
}
378373

379374
fn create_df_extension_type(
380375
&self,
381376
storage_type: &DataType,
382377
metadata: Option<&str>,
383378
) -> Result<DFExtensionTypeRef> {
384-
let metadata = TExtensionType::deserialize_metadata(metadata)?;
385379
self.factory.as_ref()(storage_type, metadata)
386380
}
387381
}
388382

389-
impl<TExtensionType: ExtensionType> Debug
390-
for DefaultExtensionTypeRegistration<TExtensionType>
391-
{
383+
impl Debug for DefaultExtensionTypeRegistration {
392384
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
393385
f.debug_struct("DefaultExtensionTypeRegistration")
394-
.field("type_name", &TExtensionType::NAME)
386+
.field("type_name", &self.name)
395387
.finish()
396388
}
397389
}
@@ -421,47 +413,66 @@ impl MemoryExtensionTypeRegistry {
421413
/// in the extension type registry.
422414
pub fn new_with_canonical_extension_types() -> Self {
423415
let mapping = [
424-
DefaultExtensionTypeRegistration::<FixedShapeTensor>::new_arc(
416+
DefaultExtensionTypeRegistration::new_arc(
417+
FixedShapeTensor::NAME,
425418
|storage_type, metadata| {
426419
Ok(Arc::new(DFFixedShapeTensor::try_new(
427420
storage_type,
428-
metadata,
421+
FixedShapeTensor::deserialize_metadata(metadata)?,
429422
)?))
430423
},
431424
),
432-
DefaultExtensionTypeRegistration::<VariableShapeTensor>::new_arc(
425+
DefaultExtensionTypeRegistration::new_arc(
426+
VariableShapeTensor::NAME,
433427
|storage_type, metadata| {
434428
Ok(Arc::new(DFVariableShapeTensor::try_new(
435429
storage_type,
436-
metadata,
430+
VariableShapeTensor::deserialize_metadata(metadata)?,
437431
)?))
438432
},
439433
),
440-
DefaultExtensionTypeRegistration::<Json>::new_arc(
434+
DefaultExtensionTypeRegistration::new_arc(
435+
Json::NAME,
441436
|storage_type, metadata| {
442-
Ok(Arc::new(DFJson::try_new(storage_type, metadata)?))
437+
Ok(Arc::new(DFJson::try_new(
438+
storage_type,
439+
Json::deserialize_metadata(metadata)?,
440+
)?))
443441
},
444442
),
445-
DefaultExtensionTypeRegistration::<Uuid>::new_arc(
443+
DefaultExtensionTypeRegistration::new_arc(
444+
Uuid::NAME,
446445
|storage_type, metadata| {
447-
Ok(Arc::new(DFUuid::try_new(storage_type, metadata)?))
446+
Ok(Arc::new(DFUuid::try_new(
447+
storage_type,
448+
Uuid::deserialize_metadata(metadata)?,
449+
)?))
448450
},
449451
),
450-
DefaultExtensionTypeRegistration::<Opaque>::new_arc(
452+
DefaultExtensionTypeRegistration::new_arc(
453+
Opaque::NAME,
451454
|storage_type, metadata| {
452-
Ok(Arc::new(DFOpaque::try_new(storage_type, metadata)?))
455+
Ok(Arc::new(DFOpaque::try_new(
456+
storage_type,
457+
Opaque::deserialize_metadata(metadata)?,
458+
)?))
453459
},
454460
),
455-
DefaultExtensionTypeRegistration::<Bool8>::new_arc(
461+
DefaultExtensionTypeRegistration::new_arc(
462+
Bool8::NAME,
456463
|storage_type, metadata| {
457-
Ok(Arc::new(DFBool8::try_new(storage_type, metadata)?))
464+
Ok(Arc::new(DFBool8::try_new(
465+
storage_type,
466+
Bool8::deserialize_metadata(metadata)?,
467+
)?))
458468
},
459469
),
460-
DefaultExtensionTypeRegistration::<TimestampWithOffset>::new_arc(
470+
DefaultExtensionTypeRegistration::new_arc(
471+
TimestampWithOffset::NAME,
461472
|storage_type, metadata| {
462473
Ok(Arc::new(DFTimestampWithOffset::try_new(
463474
storage_type,
464-
metadata,
475+
TimestampWithOffset::deserialize_metadata(metadata)?,
465476
)?))
466477
},
467478
),

0 commit comments

Comments
 (0)