Skip to content

Commit 2a48e73

Browse files
committed
Add an example for custom extension types
1 parent 05b37c0 commit 2a48e73

8 files changed

Lines changed: 442 additions & 17 deletions

File tree

datafusion-examples/README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,16 @@ cargo run --example dataframe -- dataframe
125125
| mem_pool_tracking | [`execution_monitoring/memory_pool_tracking.rs`](examples/execution_monitoring/memory_pool_tracking.rs) | Demonstrates memory tracking |
126126
| tracing | [`execution_monitoring/tracing.rs`](examples/execution_monitoring/tracing.rs) | Demonstrates tracing integration |
127127

128+
## Extension Types Examples
129+
130+
### Group: `extension_types`
131+
132+
#### Category: Single Process
133+
134+
| Subcommand | File Path | Description |
135+
| --- | --- | --- |
136+
| my_id | [`extension_types/event_id.rs`](examples/extension_types/event_id.rs) | A custom wrapper around integers that represent event ids |
137+
128138
## External Dependency Examples
129139

130140
### Group: `external_dependency`
Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{Array, RecordBatch, StringArray, UInt32Array};
19+
use arrow::util::display::{ArrayFormatter, DisplayIndex, FormatOptions, FormatResult};
20+
use arrow_schema::extension::ExtensionType;
21+
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
22+
use datafusion::dataframe::DataFrame;
23+
use datafusion::error::Result;
24+
use datafusion::execution::SessionStateBuilder;
25+
use datafusion::prelude::SessionContext;
26+
use datafusion_common::internal_err;
27+
use datafusion_common::types::{DFExtensionType, DFExtensionTypeRef};
28+
use datafusion_expr::registry::{
29+
ExtensionTypeRegistration, ExtensionTypeRegistry, MemoryExtensionTypeRegistry,
30+
};
31+
use std::fmt::Write;
32+
use std::sync::Arc;
33+
34+
/// This example demonstrates using DataFusion's extension type API to create a custom identifier
35+
/// type [`EventIdExtensionType`].
36+
///
37+
/// The following use cases are demonstrated:
38+
/// - Use a custom implementation for pretty-printing data frames.
39+
pub async fn event_id_example() -> Result<()> {
40+
let ctx = create_session_context()?;
41+
register_events_table(&ctx).await?;
42+
43+
// Print the example table with the custom pretty-printer.
44+
ctx.table("example").await?.show().await
45+
}
46+
47+
/// Creates the DataFusion session context with the custom extension type implementation.
48+
fn create_session_context() -> Result<SessionContext> {
49+
// Create a registry with a reference to the custom extension type implementation.
50+
let registry = MemoryExtensionTypeRegistry::new();
51+
let event_id_registration = Arc::new(EventIdExtensionTypeRegistration {});
52+
registry.add_extension_type_registration(event_id_registration)?;
53+
54+
// Set the extension type registry in the session state so that DataFusion can use it.
55+
let state = SessionStateBuilder::default()
56+
.with_extension_type_registry(Arc::new(registry))
57+
.build();
58+
Ok(SessionContext::new_with_state(state))
59+
}
60+
61+
/// Registers the example table and returns the data frame.
62+
async fn register_events_table(ctx: &SessionContext) -> Result<DataFrame> {
63+
let schema = example_schema();
64+
let batch = RecordBatch::try_new(
65+
schema,
66+
vec![
67+
Arc::new(UInt32Array::from(vec![
68+
20_01_000000,
69+
20_01_000001,
70+
21_03_000000,
71+
21_03_000001,
72+
21_03_000002,
73+
])),
74+
Arc::new(UInt32Array::from(vec![
75+
2020_01_0000,
76+
2020_01_0001,
77+
2021_03_0000,
78+
2021_03_0001,
79+
2021_03_0002,
80+
])),
81+
Arc::new(StringArray::from(vec![
82+
"First Event Jan 2020",
83+
"Second Event Jan 2020",
84+
"First Event Mar 2021",
85+
"Second Event Mar 2021",
86+
"Third Event Mar 2021",
87+
])),
88+
],
89+
)?;
90+
91+
// Register the table and return the data frame.
92+
ctx.register_batch("example", batch)?;
93+
ctx.table("example").await
94+
}
95+
96+
/// The schema of the example table.
97+
fn example_schema() -> SchemaRef {
98+
Arc::new(Schema::new(vec![
99+
Field::new("event_id_short", DataType::UInt32, false)
100+
.with_extension_type(EventIdExtensionType(IdYearMode::Short)),
101+
Field::new("event_id_long", DataType::UInt32, false)
102+
.with_extension_type(EventIdExtensionType(IdYearMode::Long)),
103+
Field::new("name", DataType::Utf8, false),
104+
]))
105+
}
106+
107+
/// Represents a 32-bit custom identifier that represents a single event. Using this format is not
108+
/// a good idea in practice, but it is useful for demonstrating the API usage.
109+
///
110+
/// An event is constructed of three parts:
111+
/// - The year
112+
/// - The month
113+
/// - An auto-incrementing counter within the month
114+
///
115+
/// For example, the event id `2024-01-0000` represents the first event in 2024.
116+
///
117+
/// # Year Mode
118+
///
119+
/// In addition, each event id can be represented in two modes. A short year mode `24-01-000000` and
120+
/// a long year mode `2024-01-0000`. This showcases how extension types can be parameterized using
121+
/// metadata.
122+
#[derive(Debug)]
123+
pub struct EventIdExtensionType(IdYearMode);
124+
125+
/// Selects the layout of an event id: short year or long year.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IdYearMode {
    /// Two-digit year layout (e.g., `24-01-000000`). Leaves more digits for the per-month
    /// counter, so more events fit into a month.
    Short,
    /// Four-digit year layout (e.g., `2024-01-0000`). Makes it possible to tell centuries apart.
    Long,
}
133+
134+
/// Implementation of [`ExtensionType`] for [`EventIdExtensionType`].
135+
///
136+
/// This is for the arrow-rs side of the API usage. The [`ExtensionType::Metadata`] type provides
137+
/// static guarantees on the deserialized metadata for the extension type. We will use this
138+
/// implementation to read and write the type metadata to arrow [`Field`]s.
139+
///
140+
/// This trait does allow users to customize the behavior of DataFusion for this extension type.
141+
/// This is done in [`DFExtensionType`].
142+
impl ExtensionType for EventIdExtensionType {
143+
const NAME: &'static str = "custom.event_id";
144+
type Metadata = IdYearMode;
145+
146+
fn metadata(&self) -> &Self::Metadata {
147+
&self.0
148+
}
149+
150+
fn serialize_metadata(&self) -> Option<String> {
151+
// Arrow extension type metadata is encoded as a string. We simply use the lowercase name.
152+
// In a more involved scenario, more complex serialization formats such as JSON are
153+
// appropriate.
154+
Some(format!("{:?}", self.0).to_lowercase())
155+
}
156+
157+
fn deserialize_metadata(
158+
metadata: Option<&str>,
159+
) -> std::result::Result<Self::Metadata, ArrowError> {
160+
match metadata {
161+
None => Err(ArrowError::InvalidArgumentError(
162+
"Event id type requires metadata".to_owned(),
163+
)),
164+
Some(metadata) => match metadata {
165+
"short" => Ok(IdYearMode::Short),
166+
"long" => Ok(IdYearMode::Long),
167+
_ => Err(ArrowError::InvalidArgumentError(format!(
168+
"Invalid metadata for event id type: {}",
169+
metadata
170+
))),
171+
},
172+
}
173+
}
174+
175+
fn supports_data_type(
176+
&self,
177+
data_type: &DataType,
178+
) -> std::result::Result<(), ArrowError> {
179+
match data_type {
180+
DataType::UInt32 => Ok(()),
181+
_ => Err(ArrowError::InvalidArgumentError(format!(
182+
"Invalid data type: {data_type} for event id type",
183+
))),
184+
}
185+
}
186+
187+
fn try_new(
188+
data_type: &DataType,
189+
metadata: Self::Metadata,
190+
) -> std::result::Result<Self, ArrowError> {
191+
let instance = Self(metadata);
192+
instance.supports_data_type(data_type)?; // Check that the data type is supported.
193+
Ok(instance)
194+
}
195+
}
196+
197+
/// Implementation of [`ExtensionType`] for [`EventIdExtensionType`].
198+
///
199+
/// This is for the DataFusion side of the API usage. Here users can override the default behavior
200+
/// of DataFusion for supported extension points.
201+
impl DFExtensionType for EventIdExtensionType {
202+
fn create_array_formatter<'fmt>(
203+
&self,
204+
array: &'fmt dyn Array,
205+
options: &FormatOptions<'fmt>,
206+
) -> Result<Option<ArrayFormatter<'fmt>>> {
207+
if array.data_type() != &DataType::UInt32 {
208+
return internal_err!("Wrong array type for Event Id");
209+
}
210+
211+
// Create the formatter and pass in the year formatting mode of the type
212+
let display_index = EventIdDisplayIndex {
213+
array: array.as_any().downcast_ref().unwrap(),
214+
null_str: options.null(),
215+
mode: self.0,
216+
};
217+
Ok(Some(ArrayFormatter::new(
218+
Box::new(display_index),
219+
options.safe(),
220+
)))
221+
}
222+
}
223+
224+
/// Pretty printer for event ids.
225+
#[derive(Debug)]
226+
struct EventIdDisplayIndex<'a> {
227+
array: &'a UInt32Array,
228+
null_str: &'a str,
229+
mode: IdYearMode,
230+
}
231+
232+
/// This implements the arrow-rs API for printing individual values of a column. DataFusion will
233+
/// automatically pass in the reference to this implementation if a column is annotated with the
234+
/// extension type metadata.
235+
impl DisplayIndex for EventIdDisplayIndex<'_> {
236+
fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult {
237+
// Handle nulls first
238+
if self.array.is_null(idx) {
239+
write!(f, "{}", self.null_str)?;
240+
return Ok(());
241+
}
242+
243+
let value = self.array.value(idx);
244+
245+
match self.mode {
246+
IdYearMode::Short => {
247+
// Format: YY-MM-CCCCCC
248+
// Logic:
249+
// - The last 6 digits are the counter.
250+
// - The next 2 digits are the month.
251+
// - The remaining digits are the year.
252+
let counter = value % 1_000_000;
253+
let rest = value / 1_000_000;
254+
let month = rest % 100;
255+
let year = rest / 100;
256+
257+
write!(f, "{:02}-{:02}-{:06}", year, month, counter)?;
258+
}
259+
IdYearMode::Long => {
260+
// Format: YYYY-MM-CCCC
261+
// Logic:
262+
// - The last 4 digits are the counter.
263+
// - The next 2 digits are the month.
264+
// - The remaining digits are the year.
265+
let counter = value % 10_000;
266+
let rest = value / 10_000;
267+
let month = rest % 100;
268+
let year = rest / 100;
269+
270+
write!(f, "{:04}-{:02}-{:04}", year, month, counter)?;
271+
}
272+
}
273+
Ok(())
274+
}
275+
}
276+
277+
/// The final piece of the extension type implementation: the registration. It holds the logic
/// for deserializing the metadata found on arrow [`Field`]s and for creating the extension type
/// instance. The arrow-rs trait cannot be used for this because it is not dyn-compatible (its
/// Metadata type must be known at compile time).
///
/// For extension types without any parameters, `SimpleExtensionTypeRegistration` offers an
/// easier way of registering them.
#[derive(Debug)]
pub struct EventIdExtensionTypeRegistration();
286+
287+
impl ExtensionTypeRegistration for EventIdExtensionTypeRegistration {
288+
fn type_name(&self) -> &str {
289+
EventIdExtensionType::NAME
290+
}
291+
292+
fn create_df_extension_type(
293+
&self,
294+
metadata: Option<&str>,
295+
) -> Result<DFExtensionTypeRef> {
296+
let metadata = EventIdExtensionType::deserialize_metadata(metadata)?;
297+
Ok(Arc::new(EventIdExtensionType(metadata)))
298+
}
299+
}
300+
301+
#[cfg(test)]
mod tests {
    use super::*;
    use insta::assert_snapshot;

    /// Renders the example table and checks that both event id columns are printed with the
    /// custom formatter (short and long year layout) while the name column is printed as is.
    #[tokio::test]
    async fn test_print_example_table() -> Result<()> {
        let ctx = create_session_context()?;
        let table = register_events_table(&ctx).await?;

        // Every cell is padded to its column width so that the rows line up with the border.
        assert_snapshot!(
            table.to_string().await?,
            @r"
        +----------------+---------------+-----------------------+
        | event_id_short | event_id_long | name                  |
        +----------------+---------------+-----------------------+
        | 20-01-000000   | 2020-01-0000  | First Event Jan 2020  |
        | 20-01-000001   | 2020-01-0001  | Second Event Jan 2020 |
        | 21-03-000000   | 2021-03-0000  | First Event Mar 2021  |
        | 21-03-000001   | 2021-03-0001  | Second Event Mar 2021 |
        | 21-03-000002   | 2021-03-0002  | Third Event Mar 2021  |
        +----------------+---------------+-----------------------+
        "
        );

        Ok(())
    }
}

0 commit comments

Comments
 (0)