@@ -21,7 +21,7 @@ use arrow::array::{
2121} ;
2222use arrow:: datatypes:: { Float32Type , Float64Type } ;
2323use arrow:: util:: display:: { ArrayFormatter , DisplayIndex , FormatOptions , FormatResult } ;
24- use arrow_schema:: extension:: ExtensionType ;
24+ use arrow_schema:: extension:: { EXTENSION_TYPE_METADATA_KEY , EXTENSION_TYPE_NAME_KEY } ;
2525use arrow_schema:: { ArrowError , DataType , Field , Schema , SchemaRef } ;
2626use datafusion:: dataframe:: DataFrame ;
2727use datafusion:: error:: Result ;
@@ -32,6 +32,7 @@ use datafusion_common::types::DFExtensionType;
3232use datafusion_expr:: registry:: {
3333 DefaultExtensionTypeRegistration , ExtensionTypeRegistry , MemoryExtensionTypeRegistry ,
3434} ;
35+ use std:: collections:: HashMap ;
3536use std:: fmt:: { Display , Write } ;
3637use std:: sync:: Arc ;
3738
@@ -50,15 +51,15 @@ fn create_session_context() -> Result<SessionContext> {
5051 let registry = MemoryExtensionTypeRegistry :: new_empty ( ) ;
5152
5253 // The registration creates a new instance of the extension type with the deserialized metadata.
53- let temp_registration =
54- DefaultExtensionTypeRegistration :: < TemperatureExtensionType > :: new_arc (
55- |storage_type, metadata| {
56- Ok ( Arc :: new ( TemperatureExtensionType :: try_new (
57- storage_type,
58- metadata,
59- ) ?) )
60- } ,
61- ) ;
54+ let temp_registration = DefaultExtensionTypeRegistration :: new_arc (
55+ TemperatureExtensionType :: NAME ,
56+ |storage_type, metadata| {
57+ Ok ( Arc :: new ( TemperatureExtensionType :: try_new (
58+ storage_type,
59+ TemperatureUnit :: try_from ( metadata) ? ,
60+ ) ?) )
61+ } ,
62+ ) ;
6263 registry. add_extension_type_registration ( temp_registration) ?;
6364
6465 let state = SessionStateBuilder :: default ( )
@@ -98,38 +99,15 @@ async fn register_temperature_table(ctx: &SessionContext) -> Result<DataFrame> {
9899fn example_schema ( ) -> SchemaRef {
99100 Arc :: new ( Schema :: new ( vec ! [
100101 Field :: new( "city" , DataType :: Utf8 , false ) ,
101- Field :: new( "celsius" , DataType :: Float64 , false ) . with_extension_type(
102- TemperatureExtensionType :: try_new(
103- & DataType :: Float64 ,
104- TemperatureUnit :: Celsius ,
105- )
106- . expect( "Valid Type" ) ,
107- ) ,
108- Field :: new( "fahrenheit" , DataType :: Float64 , false ) . with_extension_type(
109- TemperatureExtensionType :: try_new(
110- & DataType :: Float64 ,
111- TemperatureUnit :: Fahrenheit ,
112- )
113- . expect( "Valid Type" ) ,
114- ) ,
115- Field :: new( "kelvin" , DataType :: Float32 , false ) . with_extension_type(
116- TemperatureExtensionType :: try_new(
117- & DataType :: Float32 ,
118- TemperatureUnit :: Kelvin ,
119- )
120- . expect( "Valid Type" ) ,
121- ) ,
102+ Field :: new( "celsius" , DataType :: Float64 , false )
103+ . with_metadata( create_metadata( TemperatureUnit :: Celsius ) ) ,
104+ Field :: new( "fahrenheit" , DataType :: Float64 , false )
105+ . with_metadata( create_metadata( TemperatureUnit :: Fahrenheit ) ) ,
106+ Field :: new( "kelvin" , DataType :: Float32 , false )
107+ . with_metadata( create_metadata( TemperatureUnit :: Kelvin ) ) ,
122108 ] ) )
123109}
124110
125- /// Represents the unit of a temperature reading.
126- #[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
127- pub enum TemperatureUnit {
128- Celsius ,
129- Fahrenheit ,
130- Kelvin ,
131- }
132-
133111/// Represents a float that semantically represents a temperature. The temperature can be one of
134112/// the supported [`TemperatureUnit`]s.
135113///
@@ -157,51 +135,61 @@ pub struct TemperatureExtensionType {
157135}
158136
159137impl TemperatureExtensionType {
138+ /// The name of the extension type.
139+ pub const NAME : & ' static str = "custom.temperature" ;
140+
160141 /// Creates a new [`TemperatureExtensionType`].
161142 pub fn try_new (
162143 storage_type : & DataType ,
163144 temperature_unit : TemperatureUnit ,
164- ) -> Result < Self > {
145+ ) -> Result < Self , ArrowError > {
146+ match storage_type {
147+ DataType :: Float32 | DataType :: Float64 => { }
148+ _ => {
149+ return Err ( ArrowError :: InvalidArgumentError ( format ! (
150+ "Invalid data type: {storage_type} for temperature type, expected Float32 or Float64" ,
151+ ) ) ) ;
152+ }
153+ }
154+
165155 let result = Self {
166156 storage_type : storage_type. clone ( ) ,
167157 temperature_unit,
168158 } ;
169- result. supports_data_type ( storage_type) ?; // Validate the storage type
170159 Ok ( result)
171160 }
172161}
173162
174- /// Implementation of [`ExtensionType`] for [`TemperatureExtensionType`].
175- ///
176- /// This implements the arrow-rs trait for reading, writing, and validating extension types.
177- impl ExtensionType for TemperatureExtensionType {
178- /// Arrow extension type name that is stored in the `ARROW:extension:name` field.
179- const NAME : & ' static str = "custom.temperature" ;
180- type Metadata = TemperatureUnit ;
181-
182- fn metadata ( & self ) -> & Self :: Metadata {
183- & self . temperature_unit
184- }
/// Represents the unit of a temperature reading.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TemperatureUnit {
    Celsius,
    Fahrenheit,
    Kelvin,
}

impl TemperatureUnit {
    /// Serializes the unit into its lowercase name (`"celsius"`, `"fahrenheit"` or `"kelvin"`).
    ///
    /// Arrow extension type metadata is encoded as a string and stored using the
    /// `ARROW:extension:metadata` key. As we only store the name of the unit, a simple string
    /// suffices. Extension types can store more complex metadata using serialization formats like
    /// JSON.
    pub fn serialize(self) -> String {
        let result = match self {
            TemperatureUnit::Celsius => "celsius",
            TemperatureUnit::Fahrenheit => "fahrenheit",
            TemperatureUnit::Kelvin => "kelvin",
        };
        result.to_owned()
    }
}
198185
199- /// Inverse operation of [`Self::serialize_metadata`]. This creates the [`TemperatureUnit`]
200- /// value from the serialized string.
201- fn deserialize_metadata (
202- metadata : Option < & str > ,
203- ) -> std:: result:: Result < Self :: Metadata , ArrowError > {
204- match metadata {
186+ /// Inverse operation of [`TemperatureUnit::serialize`]. This creates the [`TemperatureUnit`]
187+ /// value from the serialized string.
188+ impl TryFrom < Option < & str > > for TemperatureUnit {
189+ type Error = ArrowError ;
190+
191+ fn try_from ( value : Option < & str > ) -> std:: result:: Result < Self , Self :: Error > {
192+ match value {
205193 Some ( "celsius" ) => Ok ( TemperatureUnit :: Celsius ) ,
206194 Some ( "fahrenheit" ) => Ok ( TemperatureUnit :: Fahrenheit ) ,
207195 Some ( "kelvin" ) => Ok ( TemperatureUnit :: Kelvin ) ,
@@ -213,28 +201,18 @@ impl ExtensionType for TemperatureExtensionType {
213201 ) ) ,
214202 }
215203 }
204+ }
216205
217- /// Checks that the extension type supports a given [`DataType`].
218- fn supports_data_type (
219- & self ,
220- data_type : & DataType ,
221- ) -> std:: result:: Result < ( ) , ArrowError > {
222- match data_type {
223- DataType :: Float32 | DataType :: Float64 => Ok ( ( ) ) ,
224- _ => Err ( ArrowError :: InvalidArgumentError ( format ! (
225- "Invalid data type: {data_type} for temperature type, expected Float32 or Float64" ,
226- ) ) ) ,
227- }
228- }
229-
230- fn try_new (
231- data_type : & DataType ,
232- metadata : Self :: Metadata ,
233- ) -> std:: result:: Result < Self , ArrowError > {
234- let instance = Self :: try_new ( data_type, metadata) ?;
235- instance. supports_data_type ( data_type) ?;
236- Ok ( instance)
237- }
206+ /// This creates a metadata map for the temperature type. Another way of writing the metadata can be
207+ /// implemented using arrow-rs' [`ExtensionType`](arrow_schema::extension::ExtensionType) trait.
208+ fn create_metadata ( unit : TemperatureUnit ) -> HashMap < String , String > {
209+ HashMap :: from ( [
210+ (
211+ EXTENSION_TYPE_NAME_KEY . to_owned ( ) ,
212+ TemperatureExtensionType :: NAME . to_owned ( ) ,
213+ ) ,
214+ ( EXTENSION_TYPE_METADATA_KEY . to_owned ( ) , unit. serialize ( ) ) ,
215+ ] )
238216}
239217
240218/// Implementation of [`DFExtensionType`] for [`TemperatureExtensionType`].
@@ -246,7 +224,7 @@ impl DFExtensionType for TemperatureExtensionType {
246224 }
247225
248226 fn serialize_metadata ( & self ) -> Option < String > {
249- ExtensionType :: serialize_metadata ( self )
227+ Some ( self . temperature_unit . serialize ( ) )
250228 }
251229
252230 fn create_array_formatter < ' fmt > (
0 commit comments