Skip to content

Commit dad0be4

Browse files
authored
[Arrow] Add API to check if Field has a valid ExtensionType (#9677)
# Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. --> - Closes #8474. # Rationale for this change Check issue <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> # What changes are included in this PR? - Added a `has_valid_extension_type` API to `Field` - Added a unit test to save behavior <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> # Are these changes tested? - yes, unit test <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> # Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. If there are any breaking changes to public APIs, please call them out. -->
1 parent 88b7fca commit dad0be4

11 files changed

Lines changed: 159 additions & 53 deletions

File tree

arrow-schema/src/extension/canonical/bool8.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ impl ExtensionType for Bool8 {
6868
fn try_new(data_type: &DataType, _metadata: Self::Metadata) -> Result<Self, ArrowError> {
6969
Self.supports_data_type(data_type).map(|_| Self)
7070
}
71+
72+
fn validate(data_type: &DataType, _metadata: Self::Metadata) -> Result<(), ArrowError> {
73+
Self.supports_data_type(data_type)
74+
}
7175
}
7276

7377
#[cfg(test)]

arrow-schema/src/extension/canonical/json.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,10 @@ impl ExtensionType for Json {
173173
json.supports_data_type(data_type)?;
174174
Ok(json)
175175
}
176+
177+
fn validate(data_type: &DataType, _metadata: Self::Metadata) -> Result<(), ArrowError> {
178+
Self::default().supports_data_type(data_type)
179+
}
176180
}
177181

178182
#[cfg(test)]

arrow-schema/src/extension/canonical/opaque.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,10 @@ impl ExtensionType for Opaque {
257257
fn try_new(_data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
258258
Ok(Self::from(metadata))
259259
}
260+
261+
fn validate(_data_type: &DataType, _metadata: Self::Metadata) -> Result<(), ArrowError> {
262+
Ok(())
263+
}
260264
}
261265

262266
#[cfg(test)]

arrow-schema/src/extension/canonical/timestamp_with_offset.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ impl ExtensionType for TimestampWithOffset {
139139
fn try_new(data_type: &DataType, _metadata: Self::Metadata) -> Result<Self, ArrowError> {
140140
Self.supports_data_type(data_type).map(|_| Self)
141141
}
142+
143+
fn validate(data_type: &DataType, _metadata: Self::Metadata) -> Result<(), ArrowError> {
144+
Self.supports_data_type(data_type)
145+
}
142146
}
143147

144148
#[cfg(test)]

arrow-schema/src/extension/canonical/uuid.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ impl ExtensionType for Uuid {
7373
fn try_new(data_type: &DataType, _metadata: Self::Metadata) -> Result<Self, ArrowError> {
7474
Self.supports_data_type(data_type).map(|_| Self)
7575
}
76+
77+
fn validate(data_type: &DataType, _metadata: Self::Metadata) -> Result<(), ArrowError> {
78+
Self.supports_data_type(data_type)
79+
}
7680
}
7781

7882
#[cfg(test)]

arrow-schema/src/extension/mod.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,15 @@ pub trait ExtensionType: Sized {
257257
/// this extension type.
258258
fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError>;
259259

260+
/// Validate this extension type for a field with the given data type and
261+
/// metadata.
262+
///
263+
/// The default implementation delegates to [`Self::try_new`]. Extension
264+
/// types may override this to validate without constructing `Self`.
265+
fn validate(data_type: &DataType, metadata: Self::Metadata) -> Result<(), ArrowError> {
266+
Self::try_new(data_type, metadata).map(|_| ())
267+
}
268+
260269
/// Construct this extension type from field metadata and data type.
261270
///
262271
/// This is a provided method that extracts extension type information from

arrow-schema/src/field.rs

Lines changed: 102 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -504,13 +504,39 @@ impl Field {
504504
.map(String::as_ref)
505505
}
506506

507+
/// Returns `true` if this [`Field`] has the given [`ExtensionType`] name
508+
/// and can be successfully validated as that extension type.
509+
///
510+
/// This first checks the extension type name and only calls
511+
/// [`ExtensionType::validate`] when the name matches.
512+
///
513+
/// This is useful when you only need a boolean validity check and do not
514+
/// need to retrieve the extension type instance.
515+
#[inline]
516+
pub fn has_valid_extension_type<E: ExtensionType>(&self) -> bool {
517+
if self.extension_type_name() != Some(E::NAME) {
518+
return false;
519+
}
520+
521+
let ext_metadata = self
522+
.metadata()
523+
.get(EXTENSION_TYPE_METADATA_KEY)
524+
.map(|s| s.as_str());
525+
526+
E::deserialize_metadata(ext_metadata)
527+
.and_then(|metadata| E::validate(self.data_type(), metadata))
528+
.is_ok()
529+
}
530+
507531
/// Returns an instance of the given [`ExtensionType`] of this [`Field`],
508532
/// if set in the [`Field::metadata`].
509533
///
510534
/// Note that using `try_extension_type` with an extension type that does
511535
/// not match the name in the metadata will return an `ArrowError` which can
512536
/// be slow due to string allocations. If you only want to check if a
513-
/// [`Field`] has a specific [`ExtensionType`], see the example below.
537+
/// [`Field`] has a specific [`ExtensionType`], first check
538+
/// [`Field::extension_type_name`], or use [`Field::has_valid_extension_type`]
539+
/// to also validate metadata and data type.
514540
///
515541
/// # Errors
516542
///
@@ -524,7 +550,7 @@ impl Field {
524550
/// fail (for example when the [`Field::data_type`] is not supported by
525551
/// the extension type ([`ExtensionType::supports_data_type`]))
526552
///
527-
/// # Examples: Check and retrieve an extension type
553+
/// # Example: Check and retrieve an extension type
528554
/// You can use this to check if a [`Field`] has a specific
529555
/// [`ExtensionType`] and retrieve it:
530556
/// ```
@@ -546,34 +572,6 @@ impl Field {
546572
/// // do something with extension_type
547573
/// }
548574
/// ```
549-
///
550-
/// # Example: Checking if a field has a specific extension type first
551-
///
552-
/// Since `try_extension_type` returns an error, it is more
553-
/// efficient to first check if the name matches before calling
554-
/// `try_extension_type`:
555-
/// ```
556-
/// # use arrow_schema::{DataType, Field, ArrowError};
557-
/// # use arrow_schema::extension::ExtensionType;
558-
/// # struct MyExtensionType;
559-
/// # impl ExtensionType for MyExtensionType {
560-
/// # const NAME: &'static str = "my_extension";
561-
/// # type Metadata = String;
562-
/// # fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> { Ok(()) }
563-
/// # fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> { Ok(Self) }
564-
/// # fn serialize_metadata(&self) -> Option<String> { unimplemented!() }
565-
/// # fn deserialize_metadata(s: Option<&str>) -> Result<Self::Metadata, ArrowError> { unimplemented!() }
566-
/// # fn metadata(&self) -> &<Self as ExtensionType>::Metadata { todo!() }
567-
/// # }
568-
/// # fn get_field() -> Field { Field::new("field", DataType::Null, false) }
569-
/// let field = get_field();
570-
/// // First check if the name matches before calling the potentially expensive `try_extension_type`
571-
/// if field.extension_type_name() == Some(MyExtensionType::NAME) {
572-
/// if let Ok(extension_type) = field.try_extension_type::<MyExtensionType>() {
573-
/// // do something with extension_type
574-
/// }
575-
/// }
576-
/// ```
577575
pub fn try_extension_type<E: ExtensionType>(&self) -> Result<E, ArrowError> {
578576
E::try_new_from_field_metadata(self.data_type(), self.metadata())
579577
}
@@ -1013,6 +1011,80 @@ mod test {
10131011
use super::*;
10141012
use std::collections::hash_map::DefaultHasher;
10151013

1014+
#[derive(Debug, Clone, Copy)]
1015+
struct TestExtensionType;
1016+
1017+
impl ExtensionType for TestExtensionType {
1018+
const NAME: &'static str = "test.extension";
1019+
type Metadata = ();
1020+
1021+
fn metadata(&self) -> &Self::Metadata {
1022+
&()
1023+
}
1024+
1025+
fn serialize_metadata(&self) -> Option<String> {
1026+
None
1027+
}
1028+
1029+
fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
1030+
metadata.map_or(Ok(()), |_| {
1031+
Err(ArrowError::InvalidArgumentError(
1032+
"TestExtensionType expects no metadata".to_owned(),
1033+
))
1034+
})
1035+
}
1036+
1037+
fn supports_data_type(&self, _data_type: &DataType) -> Result<(), ArrowError> {
1038+
Ok(())
1039+
}
1040+
1041+
fn try_new(_data_type: &DataType, _metadata: Self::Metadata) -> Result<Self, ArrowError> {
1042+
Ok(Self)
1043+
}
1044+
}
1045+
1046+
#[test]
1047+
fn test_has_valid_extension_type() {
1048+
let no_extension = Field::new("f", DataType::Null, false);
1049+
assert!(!no_extension.has_valid_extension_type::<TestExtensionType>());
1050+
1051+
let matching_name = Field::new("f", DataType::Null, false).with_metadata(
1052+
[(
1053+
EXTENSION_TYPE_NAME_KEY.to_owned(),
1054+
TestExtensionType::NAME.to_owned(),
1055+
)]
1056+
.into_iter()
1057+
.collect(),
1058+
);
1059+
assert!(matching_name.has_valid_extension_type::<TestExtensionType>());
1060+
1061+
let matching_name_with_invalid_metadata = Field::new("f", DataType::Null, false)
1062+
.with_metadata(
1063+
[
1064+
(
1065+
EXTENSION_TYPE_NAME_KEY.to_owned(),
1066+
TestExtensionType::NAME.to_owned(),
1067+
),
1068+
(EXTENSION_TYPE_METADATA_KEY.to_owned(), "invalid".to_owned()),
1069+
]
1070+
.into_iter()
1071+
.collect(),
1072+
);
1073+
assert!(
1074+
!matching_name_with_invalid_metadata.has_valid_extension_type::<TestExtensionType>()
1075+
);
1076+
1077+
let different_name = Field::new("f", DataType::Null, false).with_metadata(
1078+
[(
1079+
EXTENSION_TYPE_NAME_KEY.to_owned(),
1080+
"some.other_extension".to_owned(),
1081+
)]
1082+
.into_iter()
1083+
.collect(),
1084+
);
1085+
assert!(!different_name.has_valid_extension_type::<TestExtensionType>());
1086+
}
1087+
10161088
#[test]
10171089
fn test_new_with_string() {
10181090
// Fields should allow owned Strings to support reuse

parquet-variant-compute/src/variant_array.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ impl ExtensionType for VariantType {
8080
Self.supports_data_type(data_type)?;
8181
Ok(Self)
8282
}
83+
84+
fn validate(data_type: &DataType, _metadata: Self::Metadata) -> Result<()> {
85+
Self.supports_data_type(data_type)
86+
}
8387
}
8488

8589
/// An array of Parquet [`Variant`] values
@@ -131,9 +135,9 @@ impl ExtensionType for VariantType {
131135
/// let schema = get_schema();
132136
/// assert_eq!(schema.fields().len(), 2);
133137
/// // first field is not a Variant
134-
/// assert!(schema.field(0).try_extension_type::<VariantType>().is_err());
138+
/// assert!(!schema.field(0).has_valid_extension_type::<VariantType>());
135139
/// // second field is a Variant
136-
/// assert!(schema.field(1).try_extension_type::<VariantType>().is_ok());
140+
/// assert!(schema.field(1).has_valid_extension_type::<VariantType>());
137141
/// ```
138142
///
139143
/// # Example: Constructing the correct [`Field`] for a [`VariantArray`]

parquet/src/arrow/schema/extension.rs

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -110,19 +110,13 @@ pub(crate) fn has_extension_type(parquet_type: &Type) -> bool {
110110
/// Return the Parquet logical type to use for the specified Arrow Struct field, if any.
111111
#[cfg(feature = "variant_experimental")]
112112
pub(crate) fn logical_type_for_struct(field: &Field) -> Option<LogicalType> {
113-
use arrow_schema::extension::ExtensionType;
114113
use parquet_variant_compute::VariantType;
115-
// Check the name (= quick and cheap) and only try_extension_type if the name matches
116-
// to avoid unnecessary String allocations in ArrowError
117-
if field.extension_type_name()? != VariantType::NAME {
118-
return None;
119-
}
120-
match field.try_extension_type::<VariantType>() {
121-
Ok(VariantType) => Some(LogicalType::Variant {
114+
if field.has_valid_extension_type::<VariantType>() {
115+
Some(LogicalType::Variant {
122116
specification_version: None,
123-
}),
124-
// Given check above, this should not error, but if it does ignore
125-
Err(_e) => None,
117+
})
118+
} else {
119+
None
126120
}
127121
}
128122

@@ -137,9 +131,8 @@ pub(crate) fn logical_type_for_fixed_size_binary(field: &Field) -> Option<Logica
137131
use arrow_schema::extension::Uuid;
138132
// If set, map arrow uuid extension type to parquet uuid logical type.
139133
field
140-
.try_extension_type::<Uuid>()
141-
.ok()
142-
.map(|_| LogicalType::Uuid)
134+
.has_valid_extension_type::<Uuid>()
135+
.then_some(LogicalType::Uuid)
143136
}
144137

145138
#[cfg(not(feature = "arrow_canonical_extension_types"))]
@@ -153,9 +146,11 @@ pub(crate) fn logical_type_for_string(field: &Field) -> Option<LogicalType> {
153146
use arrow_schema::extension::Json;
154147
// Use the Json logical type if the canonical Json
155148
// extension type is set on this field.
156-
field
157-
.try_extension_type::<Json>()
158-
.map_or(Some(LogicalType::String), |_| Some(LogicalType::Json))
149+
Some(if field.has_valid_extension_type::<Json>() {
150+
LogicalType::Json
151+
} else {
152+
LogicalType::String
153+
})
159154
}
160155

161156
#[cfg(not(feature = "arrow_canonical_extension_types"))]

parquet/src/arrow/schema/virtual_type.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ impl ExtensionType for RowGroupIndex {
6969
fn try_new(data_type: &DataType, _metadata: Self::Metadata) -> Result<Self, ArrowError> {
7070
Self.supports_data_type(data_type).map(|_| Self)
7171
}
72+
73+
fn validate(data_type: &DataType, _metadata: Self::Metadata) -> Result<(), ArrowError> {
74+
Self.supports_data_type(data_type)
75+
}
7276
}
7377

7478
/// The extension type for row numbers.
@@ -113,6 +117,10 @@ impl ExtensionType for RowNumber {
113117
fn try_new(data_type: &DataType, _metadata: Self::Metadata) -> Result<Self, ArrowError> {
114118
Self.supports_data_type(data_type).map(|_| Self)
115119
}
120+
121+
fn validate(data_type: &DataType, _metadata: Self::Metadata) -> Result<(), ArrowError> {
122+
Self.supports_data_type(data_type)
123+
}
116124
}
117125

118126
/// Returns `true` if the field is a virtual column.

0 commit comments

Comments
 (0)